修复递归导致的栈溢出
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
[base]
|
||||
maxcount = 20
|
||||
[siloconflow]
|
||||
api_key = ""
|
||||
modules = ["Qwen/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct"]
|
||||
|
||||
@ -6,44 +6,48 @@ import time
|
||||
|
||||
class SiloconFlowAPI:
|
||||
def __init__(self, api_key,modules,base_url):
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
self.api_key = api_key
|
||||
self.module = modules
|
||||
|
||||
def get_ai_message(self, message,i=0):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.module[i],
|
||||
messages=message,
|
||||
temperature=0.7,
|
||||
max_tokens=4096
|
||||
)
|
||||
print("token usage {0}".format(response.usage.total_tokens))
|
||||
return response.choices[0].message.content
|
||||
except:
|
||||
if i<len(self.module)-1:
|
||||
i+=1
|
||||
self.get_ai_message(message,i)
|
||||
else:
|
||||
print("server busy,waiting")
|
||||
time.sleep(10)
|
||||
i = 0
|
||||
self.get_ai_message(message,i)
|
||||
self.url = base_url
|
||||
self.count = 0
|
||||
def get_ai_message(self, message,maxcount,i=0):
|
||||
with OpenAI(api_key=self.api_key, base_url=self.base_url) as client:
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=self.module[i],
|
||||
messages=message,
|
||||
temperature=0.7,
|
||||
max_tokens=4096
|
||||
)
|
||||
print("token usage {0}".format(response.usage.total_tokens))
|
||||
return response.choices[0].message.content
|
||||
except:
|
||||
if i<len(self.module)-1:
|
||||
i+=1
|
||||
self.count = self.count+1
|
||||
return self.get_ai_message(message,i)
|
||||
else:
|
||||
print("server busy,waiting")
|
||||
time.sleep(10)
|
||||
i = 0
|
||||
if(self.count>maxcount)
|
||||
return self.get_ai_message(message,i)
|
||||
|
||||
|
||||
|
||||
class Ai_reply(BasePlugin):
|
||||
def after_save(self):
|
||||
config = self.config.get("siloconflow")
|
||||
maxcount = self.config.get("base").get("maxcount")
|
||||
ai_server = SiloconFlowAPI(base_url= config.get("base_url"), api_key= config.get("api_key"), modules= config.get("modules"))
|
||||
if self.ctx.group is None:
|
||||
self.ctx.user.messages.append({'role': 'user',
|
||||
'content': f"当前用户为:{self.ctx.user.nickname},其uid为:{self.ctx.user.user_id},以上为背景信息,根据背景信息回复用户消息。用户消息:{self.ctx.raw_message}"})
|
||||
message = ai_server.get_ai_message(self.ctx.user.messages)
|
||||
'content':
|
||||
f"当前用户为:{self.ctx.user.nickname},其uid为:{self.ctx.user.user_id},以上为背景信息,根据背景信息回复用户消息。用户消息:{self.ctx.raw_message}"})
|
||||
typing_thread = threading.Thread(target=self.ctx.user.set_input_status, args=(1,))
|
||||
typing_thread.start()
|
||||
#获取ai返回值
|
||||
message = ai_server.get_ai_message(self.ctx.user.messages)
|
||||
message = ai_server.get_ai_message(self.ctx.user.messages,maxcount)
|
||||
#结束刷新状态
|
||||
self.ctx.user.signal = False
|
||||
#回收线程
|
||||
@ -57,7 +61,7 @@ class Ai_reply(BasePlugin):
|
||||
self.ctx.group.messages.append({'role': 'user',
|
||||
'content':
|
||||
f"群聊名称为:{self.ctx.group.nickname},用户在群中的近十条消息为:{self.ctx.group.current_user.messages},用户名叫:{self.ctx.group.current_user.nickname},以上为背景信息,根据背景信息回复消息。消息为:{self.ctx.raw_message}"})
|
||||
message = ai_server.get_ai_message(self.ctx.group.messages)
|
||||
message = ai_server.get_ai_message(self.ctx.group.messages,maxcount)
|
||||
|
||||
self.ctx.group.send_message(message)
|
||||
return "ok"
|
||||
Reference in New Issue
Block a user