diff --git a/models/openai_model.py b/models/openai_model.py
index 7b25807..5753041 100644
--- a/models/openai_model.py
+++ b/models/openai_model.py
@@ -634,7 +634,7 @@ class Model:
             await self.usage_service.update_usage(tokens_used)
         except Exception as e:
             traceback.print_exc()
-            if 'error' in response:
+            if "error" in response:
                 raise ValueError(
                     "The API returned an invalid response: "
                     + str(response["error"]["message"])
@@ -1018,7 +1018,11 @@ class Model:
 
         # Validate that all the parameters are in a good state before we send the request
         if not max_tokens_override:
-            if model and model not in Models.GPT4_MODELS and model not in Models.CHATGPT_MODELS:
+            if (
+                model
+                and model not in Models.GPT4_MODELS
+                and model not in Models.CHATGPT_MODELS
+            ):
                 max_tokens_override = Models.get_max_tokens(model) - tokens
 
         print(f"The prompt about to be sent is {prompt}")
diff --git a/tests/test_requests.py b/tests/test_requests.py
index aec41fd..0daf2a6 100644
--- a/tests/test_requests.py
+++ b/tests/test_requests.py
@@ -10,24 +10,24 @@ from services.usage_service import UsageService
 # Non-ChatGPT -> TODO: make generic test and loop through text models
 @pytest.mark.asyncio
 async def test_send_req():
-
     usage_service = UsageService(Path("../tests"))
     model = Model(usage_service)
-    prompt = 'how many hours are in a day?'
+    prompt = "how many hours are in a day?"
     tokens = len(GPT2TokenizerFast.from_pretrained("gpt2")(prompt)["input_ids"])
     res = await model.send_request(prompt, tokens)
-    assert '24' in res['choices'][0]['text']
+    assert "24" in res["choices"][0]["text"]
 
 
 # ChatGPT version
 @pytest.mark.asyncio
 async def test_send_req_gpt():
-
     usage_service = UsageService(Path("../tests"))
     model = Model(usage_service)
-    prompt = 'how many hours are in a day?'
-    res = await model.send_request(prompt, None, is_chatgpt_request=True, model="gpt-3.5-turbo")
-    assert '24' in res['choices'][0]['message']['content']
+    prompt = "how many hours are in a day?"
+    res = await model.send_request(
+        prompt, None, is_chatgpt_request=True, model="gpt-3.5-turbo"
+    )
+    assert "24" in res["choices"][0]["message"]["content"]
 
 
 # GPT4 version
@@ -35,9 +35,9 @@ async def test_send_req_gpt():
 async def test_send_req_gpt4():
     usage_service = UsageService(Path("../tests"))
     model = Model(usage_service)
-    prompt = 'how many hours are in a day?'
+    prompt = "how many hours are in a day?"
     res = await model.send_request(prompt, None, is_chatgpt_request=True, model="gpt-4")
-    assert '24' in res['choices'][0]['message']['content']
+    assert "24" in res["choices"][0]["message"]["content"]
 
 
 # Edit request -> currently broken due to endpoint
@@ -48,4 +48,3 @@ async def test_send_req_gpt4():
 #     text = 'how many hours are in a day?'
 #     res = await model.send_edit_request(text, codex=True)
 #     assert '24' in res['choices'][0]['text']
-