move overrides into class

Kaveen Kumarasinghe 1 year ago
parent 595feaf539
commit f473342bff

@ -306,7 +306,7 @@ class Commands(discord.Cog, name="Commands"):
description="Higher values means the model will take more risks", description="Higher values means the model will take more risks",
required=False, required=False,
min_value=0, min_value=0,
max_value=1, max_value=2,
) )
@discord.option( @discord.option(
name="top_p", name="top_p",
@ -366,7 +366,7 @@ class Commands(discord.Cog, name="Commands"):
required=False, required=False,
input_type=float, input_type=float,
min_value=0, min_value=0,
max_value=1, max_value=2,
) )
@discord.option( @discord.option(
name="top_p", name="top_p",

@ -11,6 +11,7 @@ import json
import discord import discord
from models.embed_statics_model import EmbedStatics from models.embed_statics_model import EmbedStatics
from models.openai_model import Override
from services.environment_service import EnvService from services.environment_service import EnvService
from services.message_queue_service import Message from services.message_queue_service import Message
from services.moderations_service import Moderation from services.moderations_service import Moderation
@ -720,15 +721,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await ctx.defer() await ctx.defer()
overrides = Override(temperature,top_p,frequency_penalty,presence_penalty)
await TextService.encapsulated_send( await TextService.encapsulated_send(
self, self,
user.id, user.id,
prompt, prompt,
ctx, ctx,
temp_override=temperature, overrides=overrides,
top_p_override=top_p,
frequency_penalty_override=frequency_penalty,
presence_penalty_override=presence_penalty,
from_ask_command=True, from_ask_command=True,
custom_api_key=user_api_key, custom_api_key=user_api_key,
from_action=from_action, from_action=from_action,
@ -766,13 +766,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await ctx.defer() await ctx.defer()
overrides = Override(temperature,top_p,0,0)
await TextService.encapsulated_send( await TextService.encapsulated_send(
self, self,
user.id, user.id,
prompt=text, prompt=text,
ctx=ctx, ctx=ctx,
temp_override=temperature, overrides=overrides,
top_p_override=top_p,
instruction=instruction, instruction=instruction,
from_edit_command=True, from_edit_command=True,
codex=codex, codex=codex,
@ -963,6 +964,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.conversation_threads[thread.id].count += 1 self.conversation_threads[thread.id].count += 1
overrides = Override(overrides['temperature'], overrides['top_p'], overrides['frequency_penalty'], overrides['presence_penalty'])
await TextService.encapsulated_send( await TextService.encapsulated_send(
self, self,
thread.id, thread.id,
@ -972,10 +975,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
[item.text for item in self.conversation_threads[thread.id].history] [item.text for item in self.conversation_threads[thread.id].history]
), ),
thread_message, thread_message,
temp_override=overrides["temperature"], overrides=overrides,
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
user=user, user=user,
model=self.conversation_threads[thread.id].model, model=self.conversation_threads[thread.id].model,
custom_api_key=user_api_key, custom_api_key=user_api_key,

@ -31,6 +31,13 @@ class Mode:
ALL_MODES = [TEMPERATURE, TOP_P] ALL_MODES = [TEMPERATURE, TOP_P]
class Override:
def __init__(self, temp=None, top_p=None, frequency=None, presence=None):
self.temperature = temp
self.top_p = top_p
self.frequency_penalty = frequency
self.presence_penalty = presence
class Models: class Models:
# Text models # Text models

@ -8,7 +8,7 @@ from discord.ext import pages
from models.embed_statics_model import EmbedStatics from models.embed_statics_model import EmbedStatics
from services.deletion_service import Deletion from services.deletion_service import Deletion
from models.openai_model import Model from models.openai_model import Model, Override
from models.user_model import EmbeddedConversationItem, RedoUser from models.user_model import EmbeddedConversationItem, RedoUser
from services.environment_service import EnvService from services.environment_service import EnvService
@ -26,10 +26,7 @@ class TextService:
prompt, prompt,
ctx, ctx,
response_message=None, response_message=None,
temp_override=None, overrides=None,
top_p_override=None,
frequency_penalty_override=None,
presence_penalty_override=None,
instruction=None, instruction=None,
from_ask_command=False, from_ask_command=False,
from_edit_command=False, from_edit_command=False,
@ -268,8 +265,8 @@ class TextService:
response = await converser_cog.model.send_edit_request( response = await converser_cog.model.send_edit_request(
text=new_prompt, text=new_prompt,
instruction=instruction, instruction=instruction,
temp_override=temp_override, temp_override=overrides.temperature,
top_p_override=top_p_override, top_p_override=overrides.top_p,
codex=codex, codex=codex,
custom_api_key=custom_api_key, custom_api_key=custom_api_key,
) )
@ -277,10 +274,10 @@ class TextService:
response = await converser_cog.model.send_request( response = await converser_cog.model.send_request(
new_prompt, new_prompt,
tokens=tokens, tokens=tokens,
temp_override=temp_override, temp_override=overrides.temperature,
top_p_override=top_p_override, top_p_override=overrides.top_p,
frequency_penalty_override=frequency_penalty_override, frequency_penalty_override=overrides.frequency_penalty,
presence_penalty_override=presence_penalty_override, presence_penalty_override=overrides.presence_penalty,
model=model, model=model,
stop=stop if not from_ask_command else None, stop=stop if not from_ask_command else None,
custom_api_key=custom_api_key, custom_api_key=custom_api_key,
@ -622,19 +619,18 @@ class TextService:
) )
# set conversation overrides # set conversation overrides
overrides = converser_cog.conversation_threads[ conversation_overrides = converser_cog.conversation_threads[
message.channel.id message.channel.id
].get_overrides() ].get_overrides()
overrides = Override(conversation_overrides['temperature'],conversation_overrides['top_p'],conversation_overrides['frequency_penalty'],conversation_overrides['presence_penalty'])
await TextService.encapsulated_send( await TextService.encapsulated_send(
converser_cog, converser_cog,
message.channel.id, message.channel.id,
primary_prompt, primary_prompt,
message, message,
temp_override=overrides["temperature"], overrides=overrides,
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=converser_cog.conversation_threads[message.channel.id].model, model=converser_cog.conversation_threads[message.channel.id].model,
custom_api_key=user_api_key, custom_api_key=user_api_key,
) )

Loading…
Cancel
Save