diff --git a/cogs/search_service_cog.py b/cogs/search_service_cog.py index 5a9d43a..0093b74 100644 --- a/cogs/search_service_cog.py +++ b/cogs/search_service_cog.py @@ -14,6 +14,12 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() USER_KEY_DB = EnvService.get_api_db() +class RedoSearchUser: + def __init__(self, ctx, query, search_scope, nodes): + self.ctx = ctx + self.query = query + self.search_scope = search_scope + self.nodes = nodes class SearchService(discord.Cog, name="SearchService"): """Cog containing translation commands and retrieval of translation services""" @@ -29,6 +35,7 @@ class SearchService(discord.Cog, name="SearchService"): self.usage_service = usage_service self.model = Search(gpt_model, usage_service) self.EMBED_CUTOFF = 2000 + self.redo_users = {} # Make a mapping of all the country codes and their full country names: async def paginate_embed(self, response_text, user: discord.Member): @@ -59,7 +66,7 @@ class SearchService(discord.Cog, name="SearchService"): return pages async def search_command( - self, ctx: discord.ApplicationContext, query, search_scope, nodes + self, ctx: discord.ApplicationContext, query, search_scope, nodes, redo=None ): """Command handler for the translation command""" user_api_key = None @@ -77,7 +84,7 @@ class SearchService(discord.Cog, name="SearchService"): await ctx.respond("The search service is not enabled.") return - await ctx.defer() + await ctx.defer() if not redo else None try: response = await self.model.search( @@ -116,6 +123,23 @@ class SearchService(discord.Cog, name="SearchService"): pages=embed_pages, timeout=None, author_check=False, + custom_view=RedoButton(ctx, self), ) + self.redo_users[ctx.user.id] = RedoSearchUser(ctx, query, search_scope, nodes) + await paginator.respond(ctx.interaction) + + +# A view for a redo button +class RedoButton(discord.ui.View): + def __init__(self, ctx: discord.ApplicationContext, search_cog): + super().__init__() + self.ctx = ctx + self.search_cog = search_cog + + @discord.ui.button(label="Redo", style=discord.ButtonStyle.danger) + async def redo(self, button: discord.ui.Button, interaction: discord.Interaction): + """Redo the translation""" + await interaction.response.send_message("Redoing search...", ephemeral=True, delete_after=15) + await self.search_cog.search_command(self.search_cog.redo_users[self.ctx.user.id].ctx, self.search_cog.redo_users[self.ctx.user.id].query, self.search_cog.redo_users[self.ctx.user.id].search_scope, self.search_cog.redo_users[self.ctx.user.id].nodes, redo=True) diff --git a/models/search_model.py b/models/search_model.py index 482e615..58b65e9 100644 --- a/models/search_model.py +++ b/models/search_model.py @@ -159,7 +159,7 @@ class Search: pass async def search( - self, ctx: discord.ApplicationContext, query, user_api_key, search_scope, nodes + self, ctx: discord.ApplicationContext, query, user_api_key, search_scope, nodes, redo=None ): DEFAULT_SEARCH_NODES = 1 if not user_api_key: @@ -170,7 +170,7 @@ class Search: if ctx: in_progress_message = await ctx.respond( embed=self.build_search_started_embed() - ) + ) if not redo else await ctx.channel.send(embed=self.build_search_started_embed()) llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003")) try: