diff --git a/cogs/commands.py b/cogs/commands.py
index dc1a24e..ed6439b 100644
--- a/cogs/commands.py
+++ b/cogs/commands.py
@@ -29,6 +29,7 @@ class Commands(discord.Cog, name="Commands"):
         index_cog,
         translations_cog=None,
         search_cog=None,
+        transcribe_cog=None,
     ):
         super().__init__()
         self.bot = bot
@@ -43,6 +44,7 @@ class Commands(discord.Cog, name="Commands"):
         self.index_cog = index_cog
         self.translations_cog = translations_cog
         self.search_cog = search_cog
+        self.transcribe_cog = transcribe_cog
 
     # Create slash command groups
     dalle = discord.SlashCommandGroup(
@@ -75,6 +77,12 @@ class Commands(discord.Cog, name="Commands"):
         guild_ids=ALLOWED_GUILDS,
         checks=[Check.check_index_roles()],
     )
+    transcribe = discord.SlashCommandGroup(
+        name="transcribe",
+        description="Transcription services using OpenAI Whisper",
+        guild_ids=ALLOWED_GUILDS,
+        checks=[Check.check_index_roles()],  # TODO new role checker for transcribe
+    )
 
     #
     # System commands
@@ -1010,3 +1018,38 @@ class Commands(discord.Cog, name="Commands"):
         await self.search_cog.search_command(
             ctx, query, scope, nodes, deep, response_mode
         )
+
+    # Transcribe commands
+    @add_to_group("transcribe")
+    @discord.slash_command(
+        name="file",
+        description="Transcribe an audio or video file",
+        guild_ids=ALLOWED_GUILDS,
+    )
+    @discord.guild_only()
+    @discord.option(
+        name="file",
+        description="A file to transcribe",
+        required=True,
+        input_type=discord.SlashCommandOptionType.attachment,
+    )
+    @discord.option(
+        name="temperature",
+        description="The higher the value, the riskier the model will be",
+        required=False,
+        input_type=discord.SlashCommandOptionType.number,
+        max_value=1,
+        min_value=0,
+    )
+    async def transcribe_file(
+        self,
+        ctx: discord.ApplicationContext,
+        file: discord.Attachment,
+        temperature: float = None,
+    ):
+        """Handle `/transcribe file` by delegating to the transcription cog.
+
+        `temperature` is declared as a non-required option, so it may be
+        omitted; the explicit None default makes the signature say so.
+        """
+        await self.transcribe_cog.transcribe_file_command(ctx, file, temperature)
diff --git a/cogs/transcription_service_cog.py b/cogs/transcription_service_cog.py
new file mode 100644
index 0000000..ac7743f
--- 
/dev/null
+++ b/cogs/transcription_service_cog.py
@@ -0,0 +1,120 @@
+import traceback
+
+import aiohttp
+import discord
+from discord.ext import pages
+
+from models.deepl_model import TranslationModel
+from models.embed_statics_model import EmbedStatics
+from services.environment_service import EnvService
+from services.text_service import TextService
+
+ALLOWED_GUILDS = EnvService.get_allowed_guilds()
+USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
+USER_KEY_DB = EnvService.get_api_db()
+
+# File extensions accepted for transcription
+SUPPORTED_EXTENSIONS = (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
+# Longest transcription shown in a single embed; leaves headroom under Discord's
+# 4096-character embed description limit for the text the success embed adds
+MAX_SINGLE_EMBED_LENGTH = 4000
+# Number of characters per page when a long transcription is paginated
+PAGE_LENGTH = 2048
+
+
+class TranscribeService(discord.Cog, name="TranscribeService"):
+    """Cog containing transcription commands and retrieval of transcribe services"""
+
+    def __init__(
+        self,
+        bot,
+        model,
+        usage_service,
+    ):
+        super().__init__()
+        self.bot = bot
+        self.model = model
+        self.usage_service = usage_service
+
+    async def transcribe_file_command(
+        self,
+        ctx: discord.ApplicationContext,
+        file: discord.Attachment,
+        temperature: float,
+    ):
+        """Transcribe an uploaded audio/video attachment and reply with the text.
+
+        Replies with an error message when the attachment is not a supported
+        audio/video file. Otherwise posts a progress embed, requests the
+        transcription from the model, and replaces the progress embed with the
+        result: a single embed for short text, a paginator for long text, or a
+        failure embed if anything raises.
+
+        Args:
+            ctx: Context of the slash command invocation.
+            file: The attachment to transcribe.
+            temperature: Sampling temperature forwarded to the model; may be None.
+        """
+        user_api_key = None
+        if USER_INPUT_API_KEYS:
+            user_api_key = await TextService.get_user_api_key(
+                ctx.user.id, ctx, USER_KEY_DB
+            )
+            if not user_api_key:
+                return
+
+        # The extension must be one of the supported formats. Compare
+        # case-insensitively so that names such as "CLIP.MP3" are accepted.
+        if not file.filename.lower().endswith(SUPPORTED_EXTENSIONS):
+            await ctx.respond("Please upload a valid audio/video file.")
+            return
+
+        # The MIME type reported by Discord must also be audio or video.
+        # content_type can be None, which is treated as "not audio/video"
+        # instead of raising AttributeError.
+        content_type = file.content_type or ""
+        if not content_type.startswith(("audio/", "video/")):
+            await ctx.respond("Please upload a valid audio/video file.")
+            return
+
+        response_message = await ctx.respond(
+            embed=EmbedStatics.build_transcribe_progress_embed()
+        )
+
+        try:
+            response = await self.model.send_transcription_request(
+                file, temperature, user_api_key
+            )
+
+            if len(response) > MAX_SINGLE_EMBED_LENGTH:
+                # Too long for one embed: split the text into fixed-size pages
+                embed_pages = [
+                    discord.Embed(
+                        title=f"Transcription Page {page_number}",
+                        description=response[start : start + PAGE_LENGTH],
+                    )
+                    for page_number, start in enumerate(
+                        range(0, len(response), PAGE_LENGTH), start=1
+                    )
+                ]
+
+                paginator = pages.Paginator(
+                    pages=embed_pages,
+                    timeout=None,
+                    author_check=False,
+                )
+
+                await paginator.respond(ctx.interaction)
+                await response_message.delete_original_response()
+                return
+
+            await response_message.edit_original_response(
+                embed=EmbedStatics.build_transcribe_success_embed(response)
+            )
+        except Exception as e:
+            # Print the full traceback for the operator; the user only sees
+            # the short message in the failure embed.
+            traceback.print_exc()
+            await response_message.edit_original_response(
+                embed=EmbedStatics.build_transcribe_failed_embed(str(e))
+            )
diff --git a/gpt3discord.py b/gpt3discord.py
index a07513f..10f4844 100644
--- a/gpt3discord.py
+++ b/gpt3discord.py
@@ -17,6 +17,7 @@ from cogs.image_service_cog import DrawDallEService
 from cogs.prompt_optimizer_cog import ImgPromptOptimizer
 from cogs.moderations_service_cog import ModerationsService
 from cogs.commands import Commands
+from cogs.transcription_service_cog import TranscribeService
 from cogs.translation_service_cog import 
TranslationService
 from cogs.index_service_cog import IndexService
 from models.deepl_model import TranslationModel
@@ -31,7 +32,7 @@ from services.environment_service import EnvService
 from models.openai_model import Model
 
 
-__version__ = "10.9.0"
+__version__ = "10.9.1"
 
 PID_FILE = Path("bot.pid")
 
@@ -174,6 +175,14 @@ async def main():
         bot.add_cog(SearchService(bot, model, usage_service))
         print("The Search service is enabled.")
 
+    bot.add_cog(
+        TranscribeService(
+            bot,
+            model,
+            usage_service,
+        )
+    )
+
     bot.add_cog(
         Commands(
             bot,
@@ -188,6 +197,7 @@ async def main():
             bot.get_cog("IndexService"),
             bot.get_cog("TranslationService"),
             bot.get_cog("SearchService"),
+            bot.get_cog("TranscribeService"),
         )
     )
 
diff --git a/models/embed_statics_model.py b/models/embed_statics_model.py
index c6cec55..706cd5b 100644
--- a/models/embed_statics_model.py
+++ b/models/embed_statics_model.py
@@ -275,3 +275,62 @@ class EmbedStatics:
         # thumbnail of https://i.imgur.com/7JF0oGD.png
         embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png")
         return embed
+
+    @staticmethod
+    def _fit_description(prefix, text, suffix="", limit=4096):
+        """Join prefix, text and suffix, shortening text so the result fits limit.
+
+        4096 is Discord's maximum embed description length. Shortened text
+        ends in "..." to show that it was cut.
+        """
+        text = str(text)
+        room = limit - len(prefix) - len(suffix)
+        if len(text) > room:
+            text = text[: room - 3] + "..."
+        return f"{prefix}{text}{suffix}"
+
+    @staticmethod
+    def build_transcribe_progress_embed():
+        """Build the embed shown while a transcription request is in flight."""
+        embed = discord.Embed(
+            title="Transcriber",
+            description="Your transcription request has been sent, this may take a while.",
+            color=discord.Color.blurple(),
+        )
+        embed.set_thumbnail(url="https://i.imgur.com/txHhNzL.png")
+        return embed
+
+    @staticmethod
+    def build_transcribe_success_embed(transcribed_text):
+        """Build the embed that shows a finished transcription as inline code.
+
+        The text is shortened if needed so that Discord does not reject the
+        embed for exceeding the description length limit.
+        """
+        embed = discord.Embed(
+            title="Transcriber",
+            description=EmbedStatics._fit_description(
+                "Transcribed successfully:\n`", transcribed_text, "`"
+            ),
+            color=discord.Color.green(),
+        )
+        # thumbnail of https://i.imgur.com/7JF0oGD.png
+        embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png")
+        return embed
+
+    @staticmethod
+    def build_transcribe_failed_embed(message):
+        """Build the embed that reports why a transcription failed.
+
+        The message is converted to a string and shortened if needed so that
+        Discord does not reject the embed for exceeding the length limit.
+        """
+        embed = discord.Embed(
+            title="Transcriber",
+            description=EmbedStatics._fit_description(
+                "Transcription failed: ", message
+            ),
+            color=discord.Color.red(),
+        )
+        embed.set_thumbnail(url="https://i.imgur.com/VLJ32x7.png")
+        return embed
diff --git 
a/models/openai_model.py b/models/openai_model.py
index 7b15c6c..705b71a 100644
--- a/models/openai_model.py
+++ b/models/openai_model.py
@@ -915,6 +915,60 @@ class Model:
 
         return response
 
+    @backoff.on_exception(
+        backoff.expo,
+        ValueError,
+        factor=3,
+        base=5,
+        max_tries=4,
+        on_backoff=backoff_handler_request,
+    )
+    async def send_transcription_request(
+        self,
+        file: discord.Attachment,
+        temperature_override=None,
+        custom_api_key=None,
+    ):
+        """Send an attachment to the OpenAI transcription endpoint and return its text.
+
+        Args:
+            file: Attachment whose bytes are uploaded; its extension and content
+                type are forwarded with the upload.
+            temperature_override: Optional sampling temperature. It is sent
+                whenever it is not None, including an explicit 0.
+            custom_api_key: API key to use instead of this model's own key.
+
+        Returns:
+            The "text" field of the JSON response.
+
+        Raises:
+            aiohttp.ClientResponseError: If the API answers with an error status.
+        """
+        api_key = custom_api_key if custom_api_key else self.openai_key
+        extension = file.filename.split(".")[-1]
+
+        data = aiohttp.FormData()
+        data.add_field("model", "whisper-1")
+        data.add_field(
+            "file",
+            await file.read(),
+            filename=f"audio.{extension}",
+            content_type=file.content_type,
+        )
+        if temperature_override is not None:
+            # Multipart form fields must be str or bytes; aiohttp cannot
+            # serialize a float and would raise TypeError when sending.
+            data.add_field("temperature", str(temperature_override))
+
+        async with aiohttp.ClientSession(raise_for_status=True) as session:
+            async with session.post(
+                "https://api.openai.com/v1/audio/transcriptions",
+                headers={"Authorization": f"Bearer {api_key}"},
+                data=data,
+            ) as resp:
+                response = await resp.json()
+                return response["text"]
+
     @backoff.on_exception(
         backoff.expo,
         ValueError,