From f473342bffbed0114f699812cf310b0c4d092422 Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Fri, 27 Jan 2023 14:37:47 -0500 Subject: [PATCH] move overrides into class --- cogs/commands.py | 4 ++-- cogs/text_service_cog.py | 20 ++++++++++---------- models/openai_model.py | 7 +++++++ services/text_service.py | 28 ++++++++++++---------------- 4 files changed, 31 insertions(+), 28 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index ae7904b..78b440b 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -306,7 +306,7 @@ class Commands(discord.Cog, name="Commands"): description="Higher values means the model will take more risks", required=False, min_value=0, - max_value=1, + max_value=2, ) @discord.option( name="top_p", @@ -366,7 +366,7 @@ class Commands(discord.Cog, name="Commands"): required=False, input_type=float, min_value=0, - max_value=1, + max_value=2, ) @discord.option( name="top_p", diff --git a/cogs/text_service_cog.py b/cogs/text_service_cog.py index 0c0af98..75d1d3a 100644 --- a/cogs/text_service_cog.py +++ b/cogs/text_service_cog.py @@ -11,6 +11,7 @@ import json import discord from models.embed_statics_model import EmbedStatics +from models.openai_model import Override from services.environment_service import EnvService from services.message_queue_service import Message from services.moderations_service import Moderation @@ -720,15 +721,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await ctx.defer() + overrides = Override(temperature,top_p,frequency_penalty,presence_penalty) + await TextService.encapsulated_send( self, user.id, prompt, ctx, - temp_override=temperature, - top_p_override=top_p, - frequency_penalty_override=frequency_penalty, - presence_penalty_override=presence_penalty, + overrides=overrides, from_ask_command=True, custom_api_key=user_api_key, from_action=from_action, @@ -766,13 +766,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await ctx.defer() + overrides = Override(temperature,top_p,0,0) + await TextService.encapsulated_send( self, user.id, prompt=text, ctx=ctx, - temp_override=temperature, - top_p_override=top_p, + overrides=overrides, instruction=instruction, from_edit_command=True, codex=codex, @@ -963,6 +964,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): self.conversation_threads[thread.id].count += 1 + overrides = Override(overrides['temperature'], overrides['top_p'], overrides['frequency_penalty'], overrides['presence_penalty']) + await TextService.encapsulated_send( self, thread.id, @@ -972,10 +975,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): [item.text for item in self.conversation_threads[thread.id].history] ), thread_message, - temp_override=overrides["temperature"], - top_p_override=overrides["top_p"], - frequency_penalty_override=overrides["frequency_penalty"], - presence_penalty_override=overrides["presence_penalty"], + overrides=overrides, user=user, model=self.conversation_threads[thread.id].model, custom_api_key=user_api_key, diff --git a/models/openai_model.py b/models/openai_model.py index bd124c9..289bb9a 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -31,6 +31,13 @@ class Mode: ALL_MODES = [TEMPERATURE, TOP_P] +class Override: + def __init__(self, temp=None, top_p=None, frequency=None, presence=None): + self.temperature = temp + self.top_p = top_p + self.frequency_penalty = frequency + self.presence_penalty = presence + class Models: # Text models diff --git a/services/text_service.py b/services/text_service.py index ce4af21..4cd0d90 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -8,7 +8,7 @@ from discord.ext import pages from models.embed_statics_model import EmbedStatics from services.deletion_service import Deletion -from models.openai_model import Model +from models.openai_model import Model, Override from models.user_model import EmbeddedConversationItem, RedoUser from services.environment_service import EnvService @@ -26,10 +26,7 @@ class TextService: prompt, ctx, response_message=None, - temp_override=None, - top_p_override=None, - frequency_penalty_override=None, - presence_penalty_override=None, + overrides=None, instruction=None, from_ask_command=False, from_edit_command=False, @@ -268,8 +265,8 @@ class TextService: response = await converser_cog.model.send_edit_request( text=new_prompt, instruction=instruction, - temp_override=temp_override, - top_p_override=top_p_override, + temp_override=overrides.temperature, + top_p_override=overrides.top_p, codex=codex, custom_api_key=custom_api_key, ) @@ -277,10 +274,10 @@ class TextService: response = await converser_cog.model.send_request( new_prompt, tokens=tokens, - temp_override=temp_override, - top_p_override=top_p_override, - frequency_penalty_override=frequency_penalty_override, - presence_penalty_override=presence_penalty_override, + temp_override=overrides.temperature, + top_p_override=overrides.top_p, + frequency_penalty_override=overrides.frequency_penalty, + presence_penalty_override=overrides.presence_penalty, model=model, stop=stop if not from_ask_command else None, custom_api_key=custom_api_key, @@ -622,19 +619,18 @@ class TextService: ) # set conversation overrides - overrides = converser_cog.conversation_threads[ + conversation_overrides = converser_cog.conversation_threads[ message.channel.id ].get_overrides() + overrides = Override(conversation_overrides['temperature'],conversation_overrides['top_p'],conversation_overrides['frequency_penalty'],conversation_overrides['presence_penalty']) + await TextService.encapsulated_send( converser_cog, message.channel.id, primary_prompt, message, - temp_override=overrides["temperature"], - top_p_override=overrides["top_p"], - frequency_penalty_override=overrides["frequency_penalty"], - presence_penalty_override=overrides["presence_penalty"], + overrides=overrides, model=converser_cog.conversation_threads[message.channel.id].model, custom_api_key=user_api_key, )