diff --git a/cogs/commands.py b/cogs/commands.py index ed6439b..b3ccc02 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -81,7 +81,7 @@ class Commands(discord.Cog, name="Commands"): name="transcribe", description="Transcription services using OpenAI Whisper", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_index_roles()], # TODO new role checker for transcribe + checks=[Check.check_index_roles()], # TODO new role checker for transcribe ) # @@ -1019,11 +1019,12 @@ class Commands(discord.Cog, name="Commands"): ctx, query, scope, nodes, deep, response_mode ) - # Transcribe commands @add_to_group("transcribe") @discord.slash_command( - name="file", description="Transcribe an audio or video file", guild_ids=ALLOWED_GUILDS + name="file", + description="Transcribe an audio or video file", + guild_ids=ALLOWED_GUILDS, ) @discord.guild_only() @discord.option( @@ -1041,6 +1042,9 @@ class Commands(discord.Cog, name="Commands"): min_value=0, ) async def transcribe_file( - self, ctx: discord.ApplicationContext, file: discord.Attachment, temperature: float + self, + ctx: discord.ApplicationContext, + file: discord.Attachment, + temperature: float, ): await self.transcribe_cog.transcribe_file_command(ctx, file, temperature) diff --git a/cogs/transcription_service_cog.py b/cogs/transcription_service_cog.py index ac7743f..aa8c1fd 100644 --- a/cogs/transcription_service_cog.py +++ b/cogs/transcription_service_cog.py @@ -12,6 +12,8 @@ from services.text_service import TextService ALLOWED_GUILDS = EnvService.get_allowed_guilds() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() USER_KEY_DB = EnvService.get_api_db() + + class TranscribeService(discord.Cog, name="TranscribeService"): """Cog containing translation commands and retrieval of transcribe services""" @@ -26,7 +28,12 @@ class TranscribeService(discord.Cog, name="TranscribeService"): self.model = model self.usage_service = usage_service - async def transcribe_file_command(self, ctx: discord.ApplicationContext, file: discord.Attachment, temperature: float): + async def transcribe_file_command( + self, + ctx: discord.ApplicationContext, + file: discord.Attachment, + temperature: float, + ): # Check if this discord file is an instance of mp3, mp4, mpeg, mpga, m4a, wav, or webm. user_api_key = None @@ -37,28 +44,37 @@ class TranscribeService(discord.Cog, name="TranscribeService"): if not user_api_key: return - if not file.filename.endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')): + if not file.filename.endswith( + (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm") + ): await ctx.respond("Please upload a valid audio/video file.") return # Also check the file metadata in case it is actually an audio/video file but with a weird ending - if not file.content_type.startswith(('audio/', 'video/')): + if not file.content_type.startswith(("audio/", "video/")): await ctx.respond("Please upload a valid audio/video file.") return - response_message = await ctx.respond(embed=EmbedStatics.build_transcribe_progress_embed()) + response_message = await ctx.respond( + embed=EmbedStatics.build_transcribe_progress_embed() + ) try: - - response = await self.model.send_transcription_request(file, temperature, user_api_key) + response = await self.model.send_transcription_request( + file, temperature, user_api_key + ) if len(response) > 4080: # Chunk the response into 2048 character chunks, each an embed page - chunks = [response[i:i+2048] for i in range(0, len(response), 2048)] + chunks = [response[i : i + 2048] for i in range(0, len(response), 2048)] embed_pages = [] for chunk in chunks: - embed_pages.append(discord.Embed(title="Transcription Page {}".format(len(embed_pages) + 1), description=chunk)) - + embed_pages.append( + discord.Embed( + title="Transcription Page {}".format(len(embed_pages) + 1), + description=chunk, + ) + ) paginator = pages.Paginator( pages=embed_pages, @@ -70,7 +86,10 @@ class TranscribeService(discord.Cog, name="TranscribeService"): await response_message.delete_original_response() return - await response_message.edit_original_response(embed=EmbedStatics.build_transcribe_success_embed(response)) + await response_message.edit_original_response( + embed=EmbedStatics.build_transcribe_success_embed(response) + ) except Exception as e: - await response_message.edit_original_response(embed=EmbedStatics.build_transcribe_failed_embed(str(e))) - + await response_message.edit_original_response( + embed=EmbedStatics.build_transcribe_failed_embed(str(e)) + ) diff --git a/models/embed_statics_model.py b/models/embed_statics_model.py index 706cd5b..725bcaf 100644 --- a/models/embed_statics_model.py +++ b/models/embed_statics_model.py @@ -276,7 +276,6 @@ class EmbedStatics: embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png") return embed - @staticmethod def build_transcribe_progress_embed(): embed = discord.Embed( diff --git a/models/openai_model.py b/models/openai_model.py index 705b71a..6facb51 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -923,26 +923,33 @@ class Model: max_tries=4, on_backoff=backoff_handler_request, ) - async def send_transcription_request(self, file: discord.Attachment, temperature_override=None, custom_api_key=None, ): - + async def send_transcription_request( + self, + file: discord.Attachment, + temperature_override=None, + custom_api_key=None, + ): async with aiohttp.ClientSession(raise_for_status=True) as session: data = aiohttp.FormData() data.add_field("model", "whisper-1") data.add_field( - "file", await file.read(), filename="audio."+file.filename.split(".")[-1], content_type=file.content_type + "file", + await file.read(), + filename="audio." + file.filename.split(".")[-1], + content_type=file.content_type, ) if temperature_override: data.add_field("temperature", temperature_override) async with session.post( - "https://api.openai.com/v1/audio/transcriptions", - headers={ - "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}", - }, - data=data, + "https://api.openai.com/v1/audio/transcriptions", + headers={ + "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}", + }, + data=data, ) as resp: response = await resp.json() - return response['text'] + return response["text"] @backoff.on_exception( backoff.expo,