You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
266 lines
7.8 KiB
266 lines
7.8 KiB
import os
|
|
import asyncio
|
|
import signal
|
|
import sys
|
|
import threading
|
|
import traceback
|
|
from pathlib import Path
|
|
from platform import system
|
|
|
|
import discord
|
|
import pinecone
|
|
from pycord.multicog import apply_multicog
|
|
|
|
from cogs.search_service_cog import SearchService
|
|
from cogs.text_service_cog import GPT3ComCon
|
|
from cogs.image_service_cog import DrawDallEService
|
|
from cogs.prompt_optimizer_cog import ImgPromptOptimizer
|
|
from cogs.moderations_service_cog import ModerationsService
|
|
from cogs.commands import Commands
|
|
from cogs.translation_service_cog import TranslationService
|
|
from models.deepl_model import TranslationModel
|
|
from services.health_service import HealthService
|
|
|
|
from services.pinecone_service import PineconeService
|
|
from services.deletion_service import Deletion
|
|
from services.message_queue_service import Message
|
|
from services.usage_service import UsageService
|
|
from services.environment_service import EnvService
|
|
|
|
from models.openai_model import Model
|
|
|
|
|
|
__version__ = "8.5.1"
|
|
|
|
PID_FILE = Path("bot.pid")
|
|
PROCESS = None
|
|
|
|
if sys.platform == "win32":
|
|
separator = "\\"
|
|
else:
|
|
separator = "/"
|
|
|
|
#
|
|
# The pinecone service is used to store and retrieve conversation embeddings.
|
|
#
|
|
|
|
try:
|
|
PINECONE_TOKEN = os.getenv("PINECONE_TOKEN")
|
|
except Exception:
|
|
PINECONE_TOKEN = None
|
|
|
|
pinecone_service = None
|
|
if PINECONE_TOKEN:
|
|
pinecone.init(api_key=PINECONE_TOKEN, environment="us-west1-gcp")
|
|
PINECONE_INDEX = "conversation-embeddings"
|
|
if PINECONE_INDEX not in pinecone.list_indexes():
|
|
print("Creating pinecone index. Please wait...")
|
|
pinecone.create_index(
|
|
PINECONE_INDEX,
|
|
dimension=1536,
|
|
metric="dotproduct",
|
|
pod_type="s1",
|
|
)
|
|
PINECONE_INDEX_SEARCH = "search-embeddings"
|
|
if EnvService.get_google_search_api_key() and EnvService.get_google_search_engine_id():
|
|
if PINECONE_INDEX_SEARCH not in pinecone.list_indexes():
|
|
print("Creating pinecone index for seraches. Please wait...")
|
|
pinecone.create_index(
|
|
PINECONE_INDEX_SEARCH,
|
|
dimension=1536,
|
|
metric="dotproduct",
|
|
pod_type="s1",
|
|
)
|
|
|
|
pinecone_service = PineconeService(pinecone.Index(PINECONE_INDEX))
|
|
pinecone_search_service = PineconeService(pinecone.Index(PINECONE_INDEX_SEARCH))
|
|
print("Got the pinecone service")
|
|
|
|
#
|
|
# Message queueing for the debug service, defer debug messages to be sent later so we don't hit rate limits.
|
|
#
|
|
message_queue = asyncio.Queue()
|
|
deletion_queue = asyncio.Queue()
|
|
asyncio.ensure_future(Message.process_message_queue(message_queue, 1.5, 5))
|
|
asyncio.ensure_future(Deletion.process_deletion_queue(deletion_queue, 1, 1))
|
|
|
|
|
|
#
|
|
# Settings for the bot
|
|
#
|
|
activity = discord.Activity(
|
|
type=discord.ActivityType.watching, name="for /help /gpt, and more!"
|
|
)
|
|
bot = discord.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity)
|
|
usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd())))
|
|
model = Model(usage_service)
|
|
|
|
|
|
#
|
|
# An encapsulating wrapper for the discord.py client. This uses the old re-write without cogs, but it gets the job done!
|
|
#
|
|
|
|
|
|
@bot.event # Using self gives u
|
|
async def on_ready(): # I can make self optional by
|
|
print("We have logged in as {0.user}".format(bot))
|
|
|
|
|
|
@bot.event
|
|
async def on_application_command_error(
|
|
ctx: discord.ApplicationContext, error: discord.DiscordException
|
|
):
|
|
if isinstance(error, discord.CheckFailure):
|
|
pass
|
|
else:
|
|
raise error
|
|
|
|
|
|
async def main():
|
|
data_path = EnvService.environment_path_with_fallback("DATA_DIR")
|
|
debug_guild = int(os.getenv("DEBUG_GUILD"))
|
|
debug_channel = int(os.getenv("DEBUG_CHANNEL"))
|
|
|
|
if not data_path.exists():
|
|
raise OSError(f"Data path: {data_path} does not exist ... create it?")
|
|
|
|
# Load the cog for the moderations service
|
|
bot.add_cog(ModerationsService(bot, usage_service, model))
|
|
|
|
# Load the main GPT3 Bot service
|
|
bot.add_cog(
|
|
GPT3ComCon(
|
|
bot,
|
|
usage_service,
|
|
model,
|
|
message_queue,
|
|
deletion_queue,
|
|
debug_guild,
|
|
debug_channel,
|
|
data_path,
|
|
pinecone_service=pinecone_service,
|
|
)
|
|
)
|
|
|
|
bot.add_cog(
|
|
DrawDallEService(
|
|
bot,
|
|
usage_service,
|
|
model,
|
|
message_queue,
|
|
deletion_queue,
|
|
bot.get_cog("GPT3ComCon"),
|
|
)
|
|
)
|
|
|
|
bot.add_cog(
|
|
ImgPromptOptimizer(
|
|
bot,
|
|
usage_service,
|
|
model,
|
|
message_queue,
|
|
deletion_queue,
|
|
bot.get_cog("GPT3ComCon"),
|
|
bot.get_cog("DrawDallEService"),
|
|
)
|
|
)
|
|
|
|
if EnvService.get_deepl_token():
|
|
bot.add_cog(TranslationService(bot, TranslationModel()))
|
|
print("The translation service is enabled.")
|
|
|
|
if (
|
|
EnvService.get_google_search_api_key()
|
|
and EnvService.get_google_search_engine_id()
|
|
):
|
|
bot.add_cog(SearchService(bot, model, pinecone_search_service))
|
|
print("The Search service is enabled.")
|
|
|
|
bot.add_cog(
|
|
Commands(
|
|
bot,
|
|
usage_service,
|
|
model,
|
|
message_queue,
|
|
deletion_queue,
|
|
bot.get_cog("GPT3ComCon"),
|
|
bot.get_cog("DrawDallEService"),
|
|
bot.get_cog("ImgPromptOptimizer"),
|
|
bot.get_cog("ModerationsService"),
|
|
bot.get_cog("TranslationService"),
|
|
bot.get_cog("SearchService"),
|
|
)
|
|
)
|
|
|
|
apply_multicog(bot)
|
|
|
|
await bot.start(os.getenv("DISCORD_TOKEN"))
|
|
|
|
|
|
def check_process_file(pid_file: Path) -> bool:
|
|
"""Check the pid file exists and if the Process ID is actually running"""
|
|
if not pid_file.exists():
|
|
return False
|
|
if system() == "Linux":
|
|
with pid_file.open("r") as pfp:
|
|
try:
|
|
proc_pid_path = Path("/proc") / "{int(pfp.read().strip())}"
|
|
print("Checking if PID proc path {proc_pid_path} exists")
|
|
except ValueError:
|
|
# We don't have a valid int in the PID File^M
|
|
pid_file.unlink()
|
|
return False
|
|
return proc_pid_path.exists()
|
|
return True
|
|
|
|
|
|
def cleanup_pid_file(signum, frame):
|
|
# Kill all threads
|
|
if PROCESS:
|
|
print("Killing all subprocesses")
|
|
PROCESS.terminate()
|
|
print("Killed all subprocesses")
|
|
# Always cleanup PID File if it exists
|
|
if PID_FILE.exists():
|
|
print(f"Removing PID file {PID_FILE}", flush=True)
|
|
PID_FILE.unlink()
|
|
|
|
|
|
# Run the bot with a token taken from an environment file.
|
|
def init():
|
|
global PROCESS
|
|
# Handle SIGTERM cleanly - Docker sends this ...
|
|
signal.signal(signal.SIGTERM, cleanup_pid_file)
|
|
|
|
if check_process_file(PID_FILE):
|
|
print("Process ID file already exists")
|
|
sys.exit(1)
|
|
else:
|
|
with PID_FILE.open("w") as f:
|
|
f.write(str(os.getpid()))
|
|
print(f"Wrote PID to file {PID_FILE}")
|
|
f.close()
|
|
try:
|
|
if EnvService.get_health_service_enabled():
|
|
try:
|
|
PROCESS = HealthService().get_process()
|
|
except:
|
|
traceback.print_exc()
|
|
print("The health service failed to start.")
|
|
|
|
asyncio.get_event_loop().run_until_complete(main())
|
|
except KeyboardInterrupt:
|
|
print("Caught keyboard interrupt, killing and removing PID")
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
print(str(e))
|
|
print("Removing PID file")
|
|
finally:
|
|
cleanup_pid_file(None, None)
|
|
|
|
sys.exit(0)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(init())
|