basic transcribe from file support

Kaveen Kumarasinghe 2 years ago
parent 76beaca608
commit 96ec0a4228

@ -29,6 +29,7 @@ class Commands(discord.Cog, name="Commands"):
index_cog,
translations_cog=None,
search_cog=None,
transcribe_cog=None,
):
super().__init__()
self.bot = bot
@ -43,6 +44,7 @@ class Commands(discord.Cog, name="Commands"):
self.index_cog = index_cog
self.translations_cog = translations_cog
self.search_cog = search_cog
self.transcribe_cog = transcribe_cog
# Create slash command groups
dalle = discord.SlashCommandGroup(
@ -75,6 +77,12 @@ class Commands(discord.Cog, name="Commands"):
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_index_roles()],
)
transcribe = discord.SlashCommandGroup(
name="transcribe",
description="Transcription services using OpenAI Whisper",
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_index_roles()], # TODO new role checker for transcribe
)
#
# System commands
@ -1010,3 +1018,29 @@ class Commands(discord.Cog, name="Commands"):
await self.search_cog.search_command(
ctx, query, scope, nodes, deep, response_mode
)
# Transcribe commands
@add_to_group("transcribe")
@discord.slash_command(
name="file", description="Transcribe an audio or video file", guild_ids=ALLOWED_GUILDS
)
@discord.guild_only()
@discord.option(
name="file",
description="A file to transcribe",
required=True,
input_type=discord.SlashCommandOptionType.attachment,
)
@discord.option(
name="temperature",
description="The higher the value, the riskier the model will be",
required=False,
input_type=discord.SlashCommandOptionType.number,
max_value=1,
min_value=0,
)
async def transcribe_file(
self, ctx: discord.ApplicationContext, file: discord.Attachment, temperature: float
):
await self.transcribe_cog.transcribe_file_command(ctx, file, temperature)

@ -0,0 +1,76 @@
import traceback
import aiohttp
import discord
from discord.ext import pages
from models.deepl_model import TranslationModel
from models.embed_statics_model import EmbedStatics
from services.environment_service import EnvService
from services.text_service import TextService
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = EnvService.get_api_db()
class TranscribeService(discord.Cog, name="TranscribeService"):
"""Cog containing translation commands and retrieval of transcribe services"""
def __init__(
self,
bot,
model,
usage_service,
):
super().__init__()
self.bot = bot
self.model = model
self.usage_service = usage_service
async def transcribe_file_command(self, ctx: discord.ApplicationContext, file: discord.Attachment, temperature: float):
# Check if this discord file is an instance of mp3, mp4, mpeg, mpga, m4a, wav, or webm.
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(
ctx.user.id, ctx, USER_KEY_DB
)
if not user_api_key:
return
if not file.filename.endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
await ctx.respond("Please upload a valid audio/video file.")
return
# Also check the file metadata in case it is actually an audio/video file but with a weird ending
if not file.content_type.startswith(('audio/', 'video/')):
await ctx.respond("Please upload a valid audio/video file.")
return
response_message = await ctx.respond(embed=EmbedStatics.build_transcribe_progress_embed())
try:
response = await self.model.send_transcription_request(file, temperature, user_api_key)
if len(response) > 4080:
# Chunk the response into 2048 character chunks, each an embed page
chunks = [response[i:i+2048] for i in range(0, len(response), 2048)]
embed_pages = []
for chunk in chunks:
embed_pages.append(discord.Embed(title="Transcription Page {}".format(len(embed_pages) + 1), description=chunk))
paginator = pages.Paginator(
pages=embed_pages,
timeout=None,
author_check=False,
)
await paginator.respond(ctx.interaction)
await response_message.delete_original_response()
return
await response_message.edit_original_response(embed=EmbedStatics.build_transcribe_success_embed(response))
except Exception as e:
await response_message.edit_original_response(embed=EmbedStatics.build_transcribe_failed_embed(str(e)))

@ -17,6 +17,7 @@ from cogs.image_service_cog import DrawDallEService
from cogs.prompt_optimizer_cog import ImgPromptOptimizer
from cogs.moderations_service_cog import ModerationsService
from cogs.commands import Commands
from cogs.transcription_service_cog import TranscribeService
from cogs.translation_service_cog import TranslationService
from cogs.index_service_cog import IndexService
from models.deepl_model import TranslationModel
@ -31,7 +32,7 @@ from services.environment_service import EnvService
from models.openai_model import Model
__version__ = "10.9.0"
__version__ = "10.9.1"
PID_FILE = Path("bot.pid")
@ -174,6 +175,14 @@ async def main():
bot.add_cog(SearchService(bot, model, usage_service))
print("The Search service is enabled.")
bot.add_cog(
TranscribeService(
bot,
model,
usage_service,
)
)
bot.add_cog(
Commands(
bot,
@ -188,6 +197,7 @@ async def main():
bot.get_cog("IndexService"),
bot.get_cog("TranslationService"),
bot.get_cog("SearchService"),
bot.get_cog("TranscribeService"),
)
)

@ -275,3 +275,35 @@ class EmbedStatics:
# thumbnail of https://i.imgur.com/7JF0oGD.png
embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png")
return embed
@staticmethod
def build_transcribe_progress_embed():
embed = discord.Embed(
title="Transcriber",
description=f"Your transcription request has been sent, this may take a while.",
color=discord.Color.blurple(),
)
embed.set_thumbnail(url="https://i.imgur.com/txHhNzL.png")
return embed
@staticmethod
def build_transcribe_success_embed(transcribed_text):
embed = discord.Embed(
title="Transcriber",
description=f"Transcribed successfully:\n`{transcribed_text}`",
color=discord.Color.green(),
)
# thumbnail of https://i.imgur.com/7JF0oGD.png
embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png")
return embed
@staticmethod
def build_transcribe_failed_embed(message):
embed = discord.Embed(
title="Transcriber",
description=f"Transcription failed: " + message,
color=discord.Color.red(),
)
embed.set_thumbnail(url="https://i.imgur.com/VLJ32x7.png")
return embed

@ -915,6 +915,35 @@ class Model:
return response
@backoff.on_exception(
backoff.expo,
ValueError,
factor=3,
base=5,
max_tries=4,
on_backoff=backoff_handler_request,
)
async def send_transcription_request(self, file: discord.Attachment, temperature_override=None, custom_api_key=None, ):
async with aiohttp.ClientSession(raise_for_status=True) as session:
data = aiohttp.FormData()
data.add_field("model", "whisper-1")
data.add_field(
"file", await file.read(), filename="audio."+file.filename.split(".")[-1], content_type=file.content_type
)
if temperature_override:
data.add_field("temperature", temperature_override)
async with session.post(
"https://api.openai.com/v1/audio/transcriptions",
headers={
"Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}",
},
data=data,
) as resp:
response = await resp.json()
return response['text']
@backoff.on_exception(
backoff.expo,
ValueError,

Loading…
Cancel
Save