Format Python code with psf/black push

github-actions 2 years ago
parent 2bdd9baced
commit 8a4442a149

@ -81,7 +81,7 @@ class Commands(discord.Cog, name="Commands"):
name="transcribe", name="transcribe",
description="Transcription services using OpenAI Whisper", description="Transcription services using OpenAI Whisper",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_index_roles()], # TODO new role checker for transcribe checks=[Check.check_index_roles()], # TODO new role checker for transcribe
) )
# #
@ -1019,11 +1019,12 @@ class Commands(discord.Cog, name="Commands"):
ctx, query, scope, nodes, deep, response_mode ctx, query, scope, nodes, deep, response_mode
) )
# Transcribe commands # Transcribe commands
@add_to_group("transcribe") @add_to_group("transcribe")
@discord.slash_command( @discord.slash_command(
name="file", description="Transcribe an audio or video file", guild_ids=ALLOWED_GUILDS name="file",
description="Transcribe an audio or video file",
guild_ids=ALLOWED_GUILDS,
) )
@discord.guild_only() @discord.guild_only()
@discord.option( @discord.option(
@ -1041,6 +1042,9 @@ class Commands(discord.Cog, name="Commands"):
min_value=0, min_value=0,
) )
async def transcribe_file( async def transcribe_file(
self, ctx: discord.ApplicationContext, file: discord.Attachment, temperature: float self,
ctx: discord.ApplicationContext,
file: discord.Attachment,
temperature: float,
): ):
await self.transcribe_cog.transcribe_file_command(ctx, file, temperature) await self.transcribe_cog.transcribe_file_command(ctx, file, temperature)

@ -12,6 +12,8 @@ from services.text_service import TextService
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = EnvService.get_api_db() USER_KEY_DB = EnvService.get_api_db()
class TranscribeService(discord.Cog, name="TranscribeService"): class TranscribeService(discord.Cog, name="TranscribeService"):
"""Cog containing translation commands and retrieval of transcribe services""" """Cog containing translation commands and retrieval of transcribe services"""
@ -26,7 +28,12 @@ class TranscribeService(discord.Cog, name="TranscribeService"):
self.model = model self.model = model
self.usage_service = usage_service self.usage_service = usage_service
async def transcribe_file_command(self, ctx: discord.ApplicationContext, file: discord.Attachment, temperature: float): async def transcribe_file_command(
self,
ctx: discord.ApplicationContext,
file: discord.Attachment,
temperature: float,
):
# Check if this discord file is an instance of mp3, mp4, mpeg, mpga, m4a, wav, or webm. # Check if this discord file is an instance of mp3, mp4, mpeg, mpga, m4a, wav, or webm.
user_api_key = None user_api_key = None
@ -37,28 +44,37 @@ class TranscribeService(discord.Cog, name="TranscribeService"):
if not user_api_key: if not user_api_key:
return return
if not file.filename.endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')): if not file.filename.endswith(
(".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
):
await ctx.respond("Please upload a valid audio/video file.") await ctx.respond("Please upload a valid audio/video file.")
return return
# Also check the file metadata in case it is actually an audio/video file but with a weird ending # Also check the file metadata in case it is actually an audio/video file but with a weird ending
if not file.content_type.startswith(('audio/', 'video/')): if not file.content_type.startswith(("audio/", "video/")):
await ctx.respond("Please upload a valid audio/video file.") await ctx.respond("Please upload a valid audio/video file.")
return return
response_message = await ctx.respond(embed=EmbedStatics.build_transcribe_progress_embed()) response_message = await ctx.respond(
embed=EmbedStatics.build_transcribe_progress_embed()
)
try: try:
response = await self.model.send_transcription_request(
response = await self.model.send_transcription_request(file, temperature, user_api_key) file, temperature, user_api_key
)
if len(response) > 4080: if len(response) > 4080:
# Chunk the response into 2048 character chunks, each an embed page # Chunk the response into 2048 character chunks, each an embed page
chunks = [response[i:i+2048] for i in range(0, len(response), 2048)] chunks = [response[i : i + 2048] for i in range(0, len(response), 2048)]
embed_pages = [] embed_pages = []
for chunk in chunks: for chunk in chunks:
embed_pages.append(discord.Embed(title="Transcription Page {}".format(len(embed_pages) + 1), description=chunk)) embed_pages.append(
discord.Embed(
title="Transcription Page {}".format(len(embed_pages) + 1),
description=chunk,
)
)
paginator = pages.Paginator( paginator = pages.Paginator(
pages=embed_pages, pages=embed_pages,
@ -70,7 +86,10 @@ class TranscribeService(discord.Cog, name="TranscribeService"):
await response_message.delete_original_response() await response_message.delete_original_response()
return return
await response_message.edit_original_response(embed=EmbedStatics.build_transcribe_success_embed(response)) await response_message.edit_original_response(
embed=EmbedStatics.build_transcribe_success_embed(response)
)
except Exception as e: except Exception as e:
await response_message.edit_original_response(embed=EmbedStatics.build_transcribe_failed_embed(str(e))) await response_message.edit_original_response(
embed=EmbedStatics.build_transcribe_failed_embed(str(e))
)

@ -276,7 +276,6 @@ class EmbedStatics:
embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png") embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png")
return embed return embed
@staticmethod @staticmethod
def build_transcribe_progress_embed(): def build_transcribe_progress_embed():
embed = discord.Embed( embed = discord.Embed(

@ -923,26 +923,33 @@ class Model:
max_tries=4, max_tries=4,
on_backoff=backoff_handler_request, on_backoff=backoff_handler_request,
) )
async def send_transcription_request(self, file: discord.Attachment, temperature_override=None, custom_api_key=None, ): async def send_transcription_request(
self,
file: discord.Attachment,
temperature_override=None,
custom_api_key=None,
):
async with aiohttp.ClientSession(raise_for_status=True) as session: async with aiohttp.ClientSession(raise_for_status=True) as session:
data = aiohttp.FormData() data = aiohttp.FormData()
data.add_field("model", "whisper-1") data.add_field("model", "whisper-1")
data.add_field( data.add_field(
"file", await file.read(), filename="audio."+file.filename.split(".")[-1], content_type=file.content_type "file",
await file.read(),
filename="audio." + file.filename.split(".")[-1],
content_type=file.content_type,
) )
if temperature_override: if temperature_override:
data.add_field("temperature", temperature_override) data.add_field("temperature", temperature_override)
async with session.post( async with session.post(
"https://api.openai.com/v1/audio/transcriptions", "https://api.openai.com/v1/audio/transcriptions",
headers={ headers={
"Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}", "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}",
}, },
data=data, data=data,
) as resp: ) as resp:
response = await resp.json() response = await resp.json()
return response['text'] return response["text"]
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,

Loading…
Cancel
Save