-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbot.py
More file actions
83 lines (70 loc) · 2.69 KB
/
bot.py
File metadata and controls
83 lines (70 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
from websockets import client
import discord
import asyncio
import config
from discord.ext import commands
from discord.ext.commands import has_permissions, CheckFailure
from config import *
# Select the model backend named by config.model and star-import it so its
# public names become globals of this module. on_message calls
# `generate_from_model`, so presumably every backend defines it — confirm in
# the models/* modules.
if config.model == "falcon":
    from models.falcon import *
    print("Loaded Falcon Model")
elif config.model == "vicuna":
    from models.vicuna import *
    print("Loaded Vicuna Model")
elif config.model == "llama":
    # NOTE(review): "llama" is served by the Vicuna backend module; confirm
    # this is intentional rather than a missing models.llama module.
    from models.vicuna import *
    print("Loaded LLaMA Model")
elif config.model == "guanaco":
    from models.guanaco import *
    print("Loaded Guanaco Model")
elif config.model == "mpt":
    from models.mpt import *
    print("Loaded MPT Model")
else:
    print("Improper model name passed, loading default model (Falcon)")
    # Import the default from the `models` package, exactly as the explicit
    # "falcon" branch above does; a bare `falcon` module is not part of it.
    from models.falcon import *
# The bot instance: commands are recognised by `bot_prefix` (from config) and
# all gateway intents are requested. Note this rebinds the name `client`,
# shadowing the `websockets.client` import above.
client = commands.Bot(
    command_prefix=bot_prefix,
    intents=discord.Intents.all(),
)
# Unregister the library's built-in help command.
client.remove_command("help")

# State shared between on_message (which starts a generation) and the stop
# command (which clears the event to ask the running generation to halt).
generation_task = None  # asyncio.Task of the active generation, if any
continue_generation = asyncio.Event()
@client.event
async def on_ready():
    """Report on stdout which account the bot has logged in as."""
    account_name = client.user.name
    print(f"Logged in as {account_name}")
def _strip_bot_mention(content, bot_id):
    """Return *content* with every mention of the bot removed and whitespace collapsed.

    Discord encodes a user mention as ``<@id>`` (or ``<@!id>`` from older
    clients). Removing those tokens, rather than dropping the first word,
    keeps the prompt intact when the mention is not the first word, is glued
    to the text, or is absent from the text (e.g. a reply that pings the bot).
    """
    for token in (f"<@{bot_id}>", f"<@!{bot_id}>"):
        # Replace with a space so "hi<@id>there" does not become "hithere".
        content = content.replace(token, " ")
    return " ".join(content.split())


async def _send_generated_sentences(channel, prompt):
    """Stream the chunks yielded by ``generate_from_model(prompt)`` into *channel*.

    Breaks out as soon as ``continue_generation`` is cleared (the stop command
    does this). Empty or whitespace-only chunks are printed locally instead of
    sent, because Discord rejects messages with no visible content.
    """
    import traceback

    try:
        # Show "typing..." for the whole run, including while the model is
        # producing the next chunk, not just during each send.
        async with channel.typing():
            # NOTE(review): generate_from_model is iterated synchronously on the
            # event loop; if each step is slow model inference this blocks the
            # loop (heartbeats, other commands) — consider running steps in an
            # executor once its thread-safety is confirmed.
            for sentence in generate_from_model(prompt):
                if not continue_generation.is_set():
                    break
                if not sentence or sentence.isspace():
                    print("")
                else:
                    await channel.send(sentence)
    except Exception:
        # This runs as a fire-and-forget task: without logging here, a failure
        # (e.g. an over-long message rejected by Discord) would only surface as
        # "Task exception was never retrieved" or re-raise inside `stop`.
        traceback.print_exc()


@client.event
async def on_message(message):
    """Start a generation when the bot is mentioned; otherwise dispatch commands.

    The prompt is the message text with the bot's mention removed. Only one
    generation runs at a time; a mention while one is active gets a notice.
    """
    global generation_task

    # Never react to our own output: a generated chunk that mentions the bot
    # would otherwise start another generation.
    if message.author == client.user:
        return

    if client.user not in message.mentions:
        await client.process_commands(message)
        return

    prompt = _strip_bot_mention(message.content, client.user.id)
    print(prompt)

    if generation_task and not generation_task.done():
        await message.channel.send("Generation is already in progress.")
        return

    # Re-arm only when actually starting a run: setting the event while a run
    # is active would cancel a `stop` that is still waiting for it to finish.
    continue_generation.set()
    generation_task = asyncio.create_task(
        _send_generated_sentences(message.channel, prompt)
    )
@client.command()
async def stop(ctx, *args):
    """Ask the active generation to halt, wait for it, and confirm in chat.

    Clearing ``continue_generation`` makes the generation loop break at its
    next check; the reply is sent only after that task has finished. Extra
    command arguments are accepted and ignored.
    """
    if not generation_task or generation_task.done():
        await ctx.send("No generation is currently in progress.")
        return

    continue_generation.clear()
    await generation_task
    await ctx.send("Generation stopped.")
@client.command()
async def model(ctx):
    """Reply with the model name configured as ``config.model``."""
    await ctx.send(f"Currently using: {config.model!s}")
if __name__ == "__main__":
    # Log in with the configured token and block until the bot shuts down.
    # Guarded so importing this module (e.g. from tests) does not connect.
    client.run(token)