Merge pull request #128 from Kav-K/gpt-index
Gpt-index implementation
commit 83cc24e6d0
@@ -0,0 +1,111 @@
import traceback

import discord

from services.environment_service import EnvService
from services.text_service import TextService
from models.index_model import Index_handler

USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = EnvService.get_api_db()


class IndexService(discord.Cog, name="IndexService"):
    """Cog containing gpt-index commands"""

    def __init__(
        self,
        bot,
        usage_service,
    ):
        super().__init__()
        self.bot = bot
        self.index_handler = Index_handler(bot, usage_service)

    async def set_index_command(self, ctx, file: discord.Attachment = None, link: str = None):
        """Command handler to set a file or link as your personal index"""
        if not file and not link:
            await ctx.respond("Please provide a file or a link")
            return

        if file and link:
            await ctx.respond("Please provide either a file or a link, not both.")
            return

        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
            if not user_api_key:
                return

        await ctx.defer(ephemeral=True)
        if file:
            await self.index_handler.set_file_index(ctx, file, user_api_key=user_api_key)
        elif link:
            await self.index_handler.set_link_index(ctx, link, user_api_key=user_api_key)

    async def set_discord_command(self, ctx, channel: discord.TextChannel = None):
        """Command handler to set a channel as your personal index"""
        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
            if not user_api_key:
                return

        await ctx.defer(ephemeral=True)
        await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key)

    async def reset_command(self, ctx):
        """Command handler to reset your indexes"""
        await ctx.defer(ephemeral=True)
        try:
            self.index_handler.reset_indexes(ctx.user.id)
            await ctx.respond("Your indexes have been reset")
        except Exception:
            traceback.print_exc()
            await ctx.respond("Something went wrong while resetting your indexes. Contact the server admin.")

    async def discord_backup_command(self, ctx):
        """Command handler to back up the entire server"""
        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
            if not user_api_key:
                return

        await ctx.defer(ephemeral=True)
        await self.index_handler.backup_discord(ctx, user_api_key=user_api_key)

    async def load_index_command(self, ctx, index):
        """Command handler to load one of your saved indexes"""
        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
            if not user_api_key:
                return

        await ctx.defer(ephemeral=True)
        await self.index_handler.load_index(ctx, index, user_api_key)

    async def query_command(self, ctx, query, response_mode):
        """Command handler to query your index"""
        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
            if not user_api_key:
                return

        await ctx.defer()
        await self.index_handler.query(ctx, query, response_mode, user_api_key)

    async def compose_command(self, ctx, name):
        """Command handler to compose from your indexes"""
        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
            if not user_api_key:
                return

        await self.index_handler.compose(ctx, name, user_api_key)
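
A minimal sketch of how this cog could be registered at startup, under py-cord's extension convention; the UsageService import path is an assumption, and the real bootstrap lives elsewhere in the repo:

from services.usage_service import UsageService  # assumed path


def setup(bot):
    # Called by py-cord when the extension is loaded via bot.load_extension
    bot.add_cog(IndexService(bot, UsageService()))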
@@ -0,0 +1,488 @@
import os
import traceback
import asyncio
from collections import defaultdict

import discord
import aiofiles
from functools import partial
from typing import List, Optional, Tuple, Union
from pathlib import Path
from datetime import date, datetime

from gpt_index.readers import YoutubeTranscriptReader
from gpt_index.readers.schema.base import Document
from gpt_index import (
    GPTSimpleVectorIndex,
    SimpleDirectoryReader,
    QuestionAnswerPrompt,
    BeautifulSoupWebReader,
    GPTFaissIndex,
    GPTListIndex,
    QueryMode,
    GPTTreeIndex,
    GoogleDocsReader,
    MockLLMPredictor,
    QueryConfig,
    IndexStructType,
)
from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR

from gpt_index.composability import ComposableGraph

from services.environment_service import EnvService, app_root_path

SHORT_TO_LONG_CACHE = {}


def get_and_query(user_id, index_storage, query, llm_predictor):
    # TODO Do prediction here for token usage
    index: Union[GPTSimpleVectorIndex, ComposableGraph] = index_storage[user_id].get_index_or_throw()
    if isinstance(index, GPTTreeIndex):
        response = index.query(query, verbose=True, child_branch_factor=2)
    else:
        response = index.query(query, verbose=True)
    return response
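

# A minimal sketch of what get_and_query amounts to for a plain vector index,
# outside the Discord plumbing (hypothetical filename; gpt_index 0.x API as
# used in this file):
#
#     index = GPTSimpleVectorIndex.load_from_disk("indexes/1234/myfile_1_23.json")
#     response = index.query("What does the document say about X?", verbose=True)
#     print(response.response)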


class IndexData:
    def __init__(self):
        self.queryable_index = None
        self.individual_indexes = []

    # A safety check for the future
    def get_index_or_throw(self):
        if not self.queryable():
            raise Exception(
                "An index access was attempted before an index was created. "
                "This is a programmer error, please report this to the maintainers."
            )
        return self.queryable_index

    def queryable(self):
        return self.queryable_index is not None

    def has_indexes(self, user_id):
        try:
            return len(os.listdir(f"{app_root_path()}/indexes/{user_id}")) > 0
        except Exception:
            return False

    def add_index(self, index, user_id, file_name):
        self.individual_indexes.append(index)
        self.queryable_index = index

        # Create a folder called "indexes/{USER_ID}" if it doesn't exist already
        Path(f"{app_root_path()}/indexes/{user_id}").mkdir(parents=True, exist_ok=True)
        # Save the index to disk under the user id
        index.save_to_disk(
            app_root_path()
            / "indexes"
            / f"{str(user_id)}"
            / f"{file_name}_{date.today().month}_{date.today().day}.json"
        )

    def reset_indexes(self, user_id):
        self.individual_indexes = []
        self.queryable_index = None

        # Delete the user's index files from disk
        try:
            for file in os.listdir(f"{app_root_path()}/indexes/{user_id}"):
                os.remove(f"{app_root_path()}/indexes/{user_id}/{file}")
        except Exception:
            traceback.print_exc()
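
    # On disk, each user's indexes live under indexes/<user_id>/ at the app
    # root, one JSON file per index, suffixed with the month and day. For
    # example (hypothetical user id and date), a file "report.pdf" indexed
    # on Jan 23 is saved as:
    #
    #     indexes/1234/report.pdf_1_23.json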


class Index_handler:
    def __init__(self, bot, usage_service):
        self.bot = bot
        self.openai_key = os.getenv("OPENAI_TOKEN")
        self.index_storage = defaultdict(IndexData)
        self.loop = asyncio.get_running_loop()
        self.usage_service = usage_service
        self.qaprompt = QuestionAnswerPrompt(
            "Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n"
            "---------------------\n"
            "{context_str}"
            "\n---------------------\n"
            "Never say '<|endofstatement|>'\n"
            "Given the context information and not prior knowledge, "
            "answer the question: {query_str}\n"
        )
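
        # A minimal sketch of how a QuestionAnswerPrompt is typically supplied
        # to a gpt_index query; an assumption, since this file defines qaprompt
        # but get_and_query does not pass it yet:
        #
        #     response = index.query(query, text_qa_template=self.qaprompt)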

    # TODO We need to do predictions below for token usage.
    def index_file(self, file_path) -> GPTSimpleVectorIndex:
        document = SimpleDirectoryReader(file_path).load_data()
        index = GPTSimpleVectorIndex(document)
        return index

    def index_gdoc(self, doc_id) -> GPTSimpleVectorIndex:
        document = GoogleDocsReader().load_data(doc_id)
        index = GPTSimpleVectorIndex(document)
        return index

    def index_youtube_transcript(self, link) -> GPTSimpleVectorIndex:
        documents = YoutubeTranscriptReader().load_data(ytlinks=[link])
        index = GPTSimpleVectorIndex(documents)
        return index

    def index_load_file(self, file_path) -> Union[GPTSimpleVectorIndex, ComposableGraph]:
        # Deep-composed indexes are saved as tree indexes; everything else is
        # a simple vector index
        if "composed_deep" in str(file_path):
            index = GPTTreeIndex.load_from_disk(file_path)
        else:
            index = GPTSimpleVectorIndex.load_from_disk(file_path)
        return index

    def index_discord(self, document) -> GPTSimpleVectorIndex:
        index = GPTSimpleVectorIndex(document)
        return index

    def index_webpage(self, url) -> GPTSimpleVectorIndex:
        documents = BeautifulSoupWebReader(
            website_extractor=DEFAULT_WEBSITE_EXTRACTOR
        ).load_data(urls=[url])
        index = GPTSimpleVectorIndex(documents)
        return index

    def reset_indexes(self, user_id):
        self.index_storage[user_id].reset_indexes(user_id)

    async def set_file_index(self, ctx: discord.ApplicationContext, file: discord.Attachment, user_api_key):
        if not user_api_key:
            os.environ["OPENAI_API_KEY"] = self.openai_key
        else:
            os.environ["OPENAI_API_KEY"] = user_api_key

        try:
            print(file.content_type)
            if file.content_type.startswith("text/plain"):
                suffix = ".txt"
            elif file.content_type.startswith("application/pdf"):
                suffix = ".pdf"
            # Allow for images too
            elif file.content_type.startswith("image/png"):
                suffix = ".png"
            elif file.content_type.startswith("image/"):
                suffix = ".jpg"
            elif "csv" in file.content_type:
                suffix = ".csv"
            elif "vnd." in file.content_type:
                suffix = ".pptx"
            # Catch all audio files and suffix with "mp3"
            elif file.content_type.startswith("audio/"):
                suffix = ".mp3"
            # Catch video files
            elif file.content_type.startswith("video/"):
                suffix = ".mp4"
            else:
                await ctx.respond("Only accepts text, pdf, images, spreadsheets, powerpoint, and audio/video files.")
                return
            async with aiofiles.tempfile.TemporaryDirectory() as temp_path:
                async with aiofiles.tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path, delete=False) as temp_file:
                    await file.save(temp_file.name)
                    index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path))

                    file_name = file.filename
                    self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name)
                    await ctx.respond("Index added to your indexes.")
        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()

    async def set_link_index(self, ctx: discord.ApplicationContext, link: str, user_api_key):
        if not user_api_key:
            os.environ["OPENAI_API_KEY"] = self.openai_key
        else:
            os.environ["OPENAI_API_KEY"] = user_api_key

        # TODO Link validation
        try:
            # Check if the link is a youtube video
            if "youtube" in link:
                index = await self.loop.run_in_executor(None, partial(self.index_youtube_transcript, link))
            else:
                index = await self.loop.run_in_executor(None, partial(self.index_webpage, link))

            # Make the url into a nice filename: strip the scheme and replace
            # url punctuation with underscores
            file_name = (
                link.replace("https://", "")
                .replace("http://", "")
                .replace("www.", "")
                .replace("/", "_")
                .replace("?", "_")
                .replace("&", "_")
                .replace("=", "_")
                .replace("-", "_")
                .replace(".", "_")
            )

            self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name)

        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()
            return

        await ctx.respond("Index set")

    async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key):
        if not user_api_key:
            os.environ["OPENAI_API_KEY"] = self.openai_key
        else:
            os.environ["OPENAI_API_KEY"] = user_api_key

        try:
            document = await self.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
            index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
            self.index_storage[ctx.user.id].add_index(index, ctx.user.id, channel.name)
            await ctx.respond("Index set")
        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()

    async def load_index(self, ctx: discord.ApplicationContext, index, user_api_key):
        if not user_api_key:
            os.environ["OPENAI_API_KEY"] = self.openai_key
        else:
            os.environ["OPENAI_API_KEY"] = user_api_key

        try:
            index_file = EnvService.find_shared_file(f"indexes/{ctx.user.id}/{index}")
            index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file))
            self.index_storage[ctx.user.id].queryable_index = index
            await ctx.respond("Loaded index")
        except Exception as e:
            await ctx.respond(str(e))

    async def compose_indexes(self, user_id, indexes, name, deep_compose):
        # Load all the indexes first
        index_objects = []
        for _index in indexes:
            index_file = EnvService.find_shared_file(f"indexes/{user_id}/{_index}")
            index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file))
            index_objects.append(index)

        # Collect every Document from every index's docstore
        documents = []
        for _index in index_objects:
            for doc_id in _index.docstore.docs.keys():
                document = _index.docstore.get_document(doc_id)
                if isinstance(document, Document):
                    documents.append(document)

        if deep_compose:
            # Deep compose: rebuild all the documents into a GPTTreeIndex
            tree_index = GPTTreeIndex(documents=documents)

            if not name:
                name = f"composed_deep_index_{date.today().month}_{date.today().day}"

            # Save the composed index
            tree_index.save_to_disk(f"indexes/{user_id}/{name}.json")

            self.index_storage[user_id].queryable_index = tree_index
        else:
            # Shallow compose: add everything into a simple vector index
            simple_index = GPTSimpleVectorIndex(documents=documents)

            if not name:
                name = f"composed_index_{date.today().month}_{date.today().day}"

            # Save the composed index
            simple_index.save_to_disk(f"indexes/{user_id}/{name}.json")
            self.index_storage[user_id].queryable_index = simple_index
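
    # A minimal sketch (hypothetical user id and filenames) of how the
    # ComposeModal below drives this method:
    #
    #     await handler.compose_indexes(
    #         1234,
    #         ["report.pdf_1_23.json", "notes.txt_1_23.json"],
    #         name=None,           # defaults to composed_index_<month>_<day>
    #         deep_compose=False,  # True rebuilds the documents into a GPTTreeIndex
    #     )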

    async def backup_discord(self, ctx: discord.ApplicationContext, user_api_key):
        if not user_api_key:
            os.environ["OPENAI_API_KEY"] = self.openai_key
        else:
            os.environ["OPENAI_API_KEY"] = user_api_key

        try:
            channel_ids: List[int] = []
            for c in ctx.guild.text_channels:
                channel_ids.append(c.id)
            document = await self.load_data(channel_ids=channel_ids, limit=3000, oldest_first=False)
            index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
            Path(app_root_path() / "indexes").mkdir(parents=True, exist_ok=True)
            index.save_to_disk(
                app_root_path()
                / "indexes"
                / f"{ctx.guild.name.replace(' ', '-')}_{date.today().month}_{date.today().day}.json"
            )

            await ctx.respond("Backup saved")
        except Exception:
            await ctx.respond("Failed to save backup")
            traceback.print_exc()

    async def query(self, ctx: discord.ApplicationContext, query: str, response_mode, user_api_key):
        if not user_api_key:
            os.environ["OPENAI_API_KEY"] = self.openai_key
        else:
            os.environ["OPENAI_API_KEY"] = user_api_key

        try:
            llm_predictor = MockLLMPredictor(max_tokens=256)
            response = await self.loop.run_in_executor(
                None, partial(get_and_query, ctx.user.id, self.index_storage, query, llm_predictor)
            )
            print("The last token usage was ", llm_predictor.last_token_usage)
            await self.usage_service.update_usage(llm_predictor.last_token_usage)
            await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
        except Exception:
            traceback.print_exc()
            await ctx.respond(
                "Failed to send query. You may not have an index set; load one with /index load",
                delete_after=10,
            )

    # Extracted functions from DiscordReader

    async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> Tuple[str, str]:
        """Async read a channel; returns (channel_content, channel_name)."""

        messages: List[discord.Message] = []

        try:
            channel = self.bot.get_channel(channel_id)
            print(f"Added {channel.name} from {channel.guild.name}")
            # only works for text channels for now
            if not isinstance(channel, discord.TextChannel):
                raise ValueError(
                    f"Channel {channel_id} is not a text channel. "
                    "Only text channels are supported for now."
                )
            # thread_dict maps thread_id to thread
            thread_dict = {}
            for thread in channel.threads:
                thread_dict[thread.id] = thread

            async for msg in channel.history(limit=limit, oldest_first=oldest_first):
                # Skip bot messages; for user messages, also pull in any thread
                # that hangs off the message
                if msg.author.bot:
                    continue
                messages.append(msg)
                if msg.id in thread_dict:
                    thread = thread_dict[msg.id]
                    async for thread_msg in thread.history(limit=limit, oldest_first=oldest_first):
                        messages.append(thread_msg)
        except Exception as e:
            print("Encountered error: " + str(e))

        channel = self.bot.get_channel(channel_id)
        msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages]

        return ("<|endofstatement|>\n\n".join(msg_txt_list), channel.name)

    async def load_data(
        self,
        channel_ids: List[int],
        limit: Optional[int] = None,
        oldest_first: bool = True,
    ) -> List[Document]:
        """Load messages from the given channels into Documents.

        Args:
            channel_ids (List[int]): List of channel ids to read.
            limit (Optional[int]): Maximum number of messages to read.
            oldest_first (bool): Whether to read oldest messages first.
                Defaults to `True`.

        Returns:
            List[Document]: List of documents.

        """
        results: List[Document] = []
        for channel_id in channel_ids:
            if not isinstance(channel_id, int):
                raise ValueError(
                    f"Channel id {channel_id} must be an integer, "
                    f"not {type(channel_id)}."
                )
            (channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first)
            results.append(
                Document(channel_content, extra_info={"channel_name": channel_name})
            )
        return results

    async def compose(self, ctx: discord.ApplicationContext, name, user_api_key):
        # Send the ComposeModal
        if not user_api_key:
            os.environ["OPENAI_API_KEY"] = self.openai_key
        else:
            os.environ["OPENAI_API_KEY"] = user_api_key

        if not self.index_storage[ctx.user.id].has_indexes(ctx.user.id):
            await ctx.respond("You must have at least one saved index before composing")
            return

        await ctx.respond(
            "Select the index(es) to compose. You can compose multiple indexes together, or Deep Compose a single index.",
            view=ComposeModal(self, ctx.user.id, name),
            ephemeral=True,
        )


class ComposeModal(discord.ui.View):
    def __init__(self, index_cog, user_id, name=None, deep=None) -> None:
        super().__init__()
        self.index_cog = index_cog
        self.user_id = user_id
        self.deep = deep

        # Get all the indexes for the user
        self.indexes = [
            file
            for file in os.listdir(EnvService.find_shared_file(f"indexes/{str(user_id)}/"))
        ]

        # Map everything into the short to long cache
        for index in self.indexes:
            SHORT_TO_LONG_CACHE[index[:99]] = index

        # The name to use for the composed index
        self.name = name

        # A discord UI select menu with all the indexes. Limited to 25 entries;
        # the label field of each SelectOption is cut off at 100 characters to
        # prevent the message from being too long
        self.index_select = discord.ui.Select(
            placeholder="Select index(es) to compose",
            options=[
                discord.SelectOption(label=str(index)[:99], value=index[:99])
                for index in self.indexes
            ][0:25],
            max_values=len(self.indexes) if len(self.indexes) < 25 else 25,
            min_values=1,
        )
        # Add the select menu to the modal
        self.add_item(self.index_select)

        # If we have more than 25 entries, add more Select fields as necessary
        self.extra_index_selects = []
        if len(self.indexes) > 25:
            for i in range(25, len(self.indexes), 25):
                self.extra_index_selects.append(
                    discord.ui.Select(
                        placeholder="Select index(es) to compose",
                        options=[
                            discord.SelectOption(label=index[:99], value=index[:99])
                            for index in self.indexes
                        ][i : i + 25],
                        max_values=len(self.indexes[i : i + 25]),
                        min_values=1,
                    )
                )
                self.add_item(self.extra_index_selects[-1])

        # Add a select field for "Deep" compose, a "yes" or "no" option, default no
        self.deep_select = discord.ui.Select(
            placeholder="Deep Compose",
            options=[
                discord.SelectOption(label="Yes", value="yes"),
                discord.SelectOption(label="No", value="no"),
            ],
            max_values=1,
            min_values=1,
        )
        self.add_item(self.deep_select)

        # Add a button to the modal called "Compose"
        self.add_item(
            discord.ui.Button(label="Compose", style=discord.ButtonStyle.green, custom_id="compose")
        )

    # The callback for the button
    async def interaction_check(self, interaction: discord.Interaction) -> bool:
        # Check that the interaction was for custom_id "compose"
        if interaction.data["custom_id"] == "compose":
            # The total list of indexes is the union of the values of all the select menus
            indexes = self.index_select.values + [
                value for select in self.extra_index_selects for value in select.values
            ]

            # Remap them from the SHORT_TO_LONG_CACHE
            indexes = [SHORT_TO_LONG_CACHE[index] for index in indexes]

            # Check that the user selected at least one index
            if len(indexes) < 1:
                await interaction.response.send_message("You must select at least 1 index", ephemeral=True)
            else:
                composing_message = await interaction.response.send_message(
                    "Composing indexes, this may take a long time...", ephemeral=True, delete_after=120
                )
                # Compose the indexes
                await self.index_cog.compose_indexes(
                    self.user_id,
                    indexes,
                    self.name,
                    deep_compose=bool(self.deep_select.values and self.deep_select.values[0] == "yes"),
                )
                await interaction.followup.send("Composed indexes", ephemeral=True, delete_after=10)

                try:
                    await composing_message.delete()
                except Exception:
                    pass
        else:
            await interaction.response.defer(ephemeral=True)