petals.py
from discord.ext import commands
import time
import aiohttp
import websockets
import json
import re

from util_discord import command_check, get_guild_prefix

# list of old models
models = [
    'stabilityai/StableBeluga2',
    'meta-llama/Llama-2-70b-chat-hf',
    'meta-llama/Llama-2-70b-hf',
    'timdettmers/guanaco-65b',
    'huggyllama/llama-65b',
    'bigscience/bloomz',
    'meta-llama/Meta-Llama-3-70B',
    'petals-team/StableBeluga2',
]
# TODO: support reading replied-to messages (the conversation history would just be packed into the "inputs" prompt, which is awkward)
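# A hypothetical history-bearing prompt would simply extend the "###"-separated turn
# format used below, e.g.:
#   "...###Human: first question###Assistant: first answer###Human: follow-up###Assistant:"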
async def petalsWebsocket(ctx: commands.Context, arg: str, model: int):
    """
    Connects to a WebSocket server and generates text using the selected model.

    This function connects to the WebSocket server at 'wss://chat.petals.dev/api/v2/generate',
    opens an inference session with the model at index `model` in `models`, and streams
    the text generated for the given prompt back into the Discord reply.

    Returns:
        None
    """
    if await command_check(ctx, "petals", "ai"):
        return await ctx.reply("command disabled", ephemeral=True)
    async with ctx.typing():
        msg = await ctx.reply("**Starting session…**")
        if not arg:
            arg = "Explain who you are, your functions, capabilities, limitations, and purpose."
        text = ""
        text_mod = text_inc = 50  # update the progress message every 50 generated characters
        old = round(time.time() * 1000)
        uri = "wss://chat.petals.dev/api/v2/generate"
        try:
            async with websockets.connect(uri) as ws:
                await ws.send(json.dumps({
                    "type": "open_inference_session",
                    "model": models[model],
                    "max_length": 2000
                }))
                await ws.send(json.dumps({
                    "type": "generate",
                    "inputs": f"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.###Assistant: Hi! How can I help you?###Human: {arg}###Assistant:",
                    "max_new_tokens": 1,
                    "do_sample": 1,
                    "temperature": 0.6,
                    "top_p": 0.9,
                    "extra_stop_sequences": ["</s>"],
                    "stop_sequence": "###"
                }))
                async for message in ws:
                    data = json.loads(message)
                    if data.get("ok"):
                        if data.get("outputs") is None:
                            await msg.edit(content="**Session opened, generating…**")
                        elif not data["stop"]:
                            text += data["outputs"]
                            if len(text) >= text_mod:
                                await msg.edit(content=f"**Generating response…**\nLength: {len(text)}")
                                text_mod += text_inc
                        else:
                            if text != "":
                                await send(ctx, text)
                                await msg.edit(content=f"**Took {round(time.time() * 1000)-old}ms**\nLength: {len(text)}")
                            else:
                                await msg.edit(content="**Error! :(**\nEmpty response.")
                                # follow up with the current model status
                                await PETALS(ctx)
                            await ws.close()
                    else:
                        print("Error:", data.get("traceback"))
                        # use a regular expression to extract the "Error: ..." message
                        # from the traceback, stopping before the first "File ..." frame
                        error_match = re.search(r'Error:(.*?)(?=(\n\s{2,}File|\Z))', data.get("traceback"), re.DOTALL)
                        error_message = "Error message not found."
                        if error_match:
                            error_message = error_match.group(1).strip()
                        if text != "":
                            await send(ctx, text)
                            await msg.edit(content=f"**Took {round(time.time() * 1000)-old}ms and got interrupted with an error.**\n{error_message}\nLength: {len(text)}")
                        else:
                            await msg.edit(content=f"**Error! :(**\n{error_message}")
                            # follow up with the current model status
                            await PETALS(ctx)
                        await ws.close()
        except Exception:
            await msg.edit(content="**Error! :(**\nConnection timed out.")
            await PETALS(ctx)
async def send(ctx: commands.Context, text: str):
    # split the response into 2000-character chunks (Discord's message length limit);
    # reply with the first chunk, then post the rest as follow-up messages
    chunks = [text[i:i+2000] for i in range(0, len(text), 2000)]
    replyFirst = True
    for chunk in chunks:
        if replyFirst:
            replyFirst = False
            await ctx.reply(chunk)
        else:
            await ctx.send(chunk)
async def req_real(api):
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(api) as response:
                if response.status == 200:
                    return await response.json()
    except Exception as e:
        print(e)
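# Shape of the health endpoint's JSON response, inferred from the fields read in
# PETALS() below (not taken from official Petals documentation):
#   {"model_reports": [{"name": "...", "state": "healthy" | "...", ...}, ...]}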
async def PETALS(ctx: commands.Context):
    if await command_check(ctx, "petals", "ai"):
        return await ctx.reply("command disabled", ephemeral=True)
    status = await req_real("https://health.petals.dev/api/v1/state")
    if not status:
        return await ctx.reply("**Error! :(**\nCould not fetch the Petals health state.")
    text = f"`{await get_guild_prefix(ctx)}beluga2` petals-team/StableBeluga2```diff\n"
    for i in status["model_reports"]:
        text += f"{'+ ' if i['state'] == 'healthy' else '- '}{i['name']}: {i['state']}\n"
    text += "```"
    await ctx.reply(text)
class CogPetals(commands.Cog):
    def __init__(self, bot):
        self.bot = bot

    @commands.command()
    async def petals(self, ctx: commands.Context):
        await PETALS(ctx)

    @commands.command()
    async def beluga2(self, ctx: commands.Context, *, arg=None):
        # index 7 in `models` is 'petals-team/StableBeluga2'
        await petalsWebsocket(ctx, arg, 7)

async def setup(bot: commands.Bot):
    await bot.add_cog(CogPetals(bot))