modified: README.md
new file: app.py new file: config/config.toml new file: requirements.txt new file: run.bat new file: run.sh new file: src/__init__.py new file: src/file_store_api.py new file: src/mainprocess.py new file: src/modules/__init__.py new file: src/modules/plugin_modules.py new file: src/modules/user_modules.py new file: src/plugin_manager.py
This commit is contained in:
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
264
src/file_store_api.py
Normal file
264
src/file_store_api.py
Normal file
@ -0,0 +1,264 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import sqlite3
|
||||
import os
|
||||
import time
|
||||
import toml
|
||||
from pathlib import Path
|
||||
from http import HTTPStatus
|
||||
from datetime import datetime
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler("chat_app.log"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理类,处理应用配置"""
|
||||
def __init__(self, config_path="config"):
|
||||
self.config = {}
|
||||
self.config_path = config_path
|
||||
self.build_config_dict()
|
||||
|
||||
|
||||
def build_config_dict(self) -> dict[str, str]:
|
||||
config_dict = {}
|
||||
for config_file in Path(self.config_path).rglob("*.toml"):
|
||||
if not config_file.is_file():
|
||||
continue
|
||||
|
||||
# 获取相对路径的父目录名
|
||||
rel_path = config_file.relative_to(self.config_path)
|
||||
parent_name = rel_path.parent.name if rel_path.parent.name else None
|
||||
|
||||
if parent_name:
|
||||
key = parent_name
|
||||
else:
|
||||
key = config_file.stem # 去掉扩展名
|
||||
|
||||
config_dict[key] = str(config_file.absolute())
|
||||
self.config = config_dict
|
||||
|
||||
def load_config(self,name="config"):
|
||||
"""加载配置文件"""
|
||||
if not os.path.exists(self.config[name]):
|
||||
return {}
|
||||
with open(self.config[name], 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
return toml.load(f)
|
||||
except toml.TomlDecodeError:
|
||||
return {}
|
||||
|
||||
def save_config(self, key=None, value=None):
|
||||
"""保存配置项"""
|
||||
if key is not None and value is not None:
|
||||
# 如果提供了 key 和 value,则更新单个值
|
||||
self.config[key] = value
|
||||
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||||
toml.dump(self.config, f)
|
||||
|
||||
def update_config(self, config_dict):
|
||||
"""更新配置字典"""
|
||||
self.config.update(config_dict)
|
||||
self.save_config()
|
||||
|
||||
class MainDatabase:
|
||||
def __init__(self,db_path):
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
def _ini_db(self):
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
|
||||
""")
|
||||
configm = ConfigManager()
|
||||
|
||||
basecontent = configm.load_config("config").get("app", {}).get("system_content", "你是一个qq助手,名叫”the real“")
|
||||
|
||||
class ChatDatabase:
|
||||
"""数据库管理类,处理所有SQLite操作"""
|
||||
def __init__(self, db_path):
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库表结构,并添加初始系统消息"""
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建消息表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL, -- user/assistant/system
|
||||
content TEXT NOT NULL, -- 消息内容
|
||||
sender_id TEXT, -- 发送者ID
|
||||
timestamp REAL NOT NULL -- 时间戳
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建索引
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sender ON messages(sender_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON messages(timestamp)")
|
||||
|
||||
# 检查表是否为空(新创建的数据库)
|
||||
cursor.execute("SELECT COUNT(*) FROM messages")
|
||||
count = cursor.fetchone()[0]
|
||||
if count == 0:
|
||||
# 插入初始系统消息
|
||||
timestamp = datetime.now().timestamp()
|
||||
cursor.execute("""
|
||||
INSERT INTO messages (role, content, sender_id, timestamp)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", ('system', '你是一个qq助手,名叫”the real“需要根据消息回答私信或者群聊消息,注意语言要尽可能的精炼,尽量不要超过20字,同时尽量不要换行,尽量不要让其他用户认出你是ai,同时注意,当聊天中出现[CQ:at,qq=1456128421]时,证明用户在针对你发言,你的uid是1456128421', None, timestamp))
|
||||
logger.info(f"初始化系统消息已添加到数据库: {self.db_path}")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def save_message(self, role, content, sender_id=None):
|
||||
"""保存消息到数据库"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
timestamp = datetime.now().timestamp()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO messages (role, content, sender_id, timestamp)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (role, content, sender_id, timestamp))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def load_messages(self, limit=10, sender_id=None):
|
||||
"""从数据库加载消息"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
query = "SELECT role, content, sender_id, timestamp FROM messages"
|
||||
params = []
|
||||
|
||||
if sender_id:
|
||||
query += " WHERE sender_id = ?"
|
||||
params.append(sender_id)
|
||||
|
||||
query += " ORDER BY timestamp LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
cursor.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
# 转换为消息字典列表
|
||||
messages = list()
|
||||
for row in rows:
|
||||
messages.append({
|
||||
'role': row[0],
|
||||
'content': row[1],
|
||||
'sender_id': row[2],
|
||||
'timestamp': row[3]
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,处理所有数据库操作"""
|
||||
def __init__(self):
|
||||
self.base_dir = os.path.join("databases","chats")
|
||||
self.user_dir = os.path.join(self.base_dir, "user")
|
||||
self.group_dir = os.path.join(self.base_dir, "group")
|
||||
# 确保目录存在
|
||||
os.makedirs(self.user_dir, exist_ok=True)
|
||||
os.makedirs(self.group_dir, exist_ok=True)
|
||||
|
||||
def get_user_db(self, user_id):
|
||||
"""获取用户私聊数据库实例"""
|
||||
db_path = os.path.join(self.user_dir, f"{user_id}.db")
|
||||
return ChatDatabase(db_path)
|
||||
|
||||
def get_group_db(self, group_id):
|
||||
"""获取群聊数据库实例"""
|
||||
db_path = os.path.join(self.group_dir, f"{group_id}.db")
|
||||
return ChatDatabase(db_path)
|
||||
|
||||
def save_private_message(self, user, role, content):
|
||||
"""保存私聊消息"""
|
||||
db = self.get_user_db(user.user_id)
|
||||
db.save_message(role, content, sender_id=user.user_id)
|
||||
|
||||
def load_private_messages(self, user, limit=100):
|
||||
"""加载私聊消息"""
|
||||
db = self.get_user_db(user.user_id)
|
||||
return db.load_messages(limit)
|
||||
|
||||
def save_group_message(self, group, role, content, sender_id=None):
|
||||
"""保存群聊消息"""
|
||||
db = self.get_group_db(group.group_id)
|
||||
db.save_message(role, content, sender_id=sender_id)
|
||||
|
||||
def load_group_messages(self, group, limit=100):
|
||||
"""加载群聊消息"""
|
||||
db = self.get_group_db(group.group_id)
|
||||
return db.load_messages(limit)
|
||||
|
||||
def load_user_group_messages(self, user, group, limit=10):
|
||||
"""加载用户在群聊中的消息"""
|
||||
db = self.get_group_db(group.group_id)
|
||||
return db.load_messages(limit, sender_id=user.user_id)
|
||||
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
from modules import user_modules as chater
|
||||
# 创建聊天管理器
|
||||
chat_manager = ChatManager()
|
||||
|
||||
# 创建用户和群组(仅包含基本信息)
|
||||
user1 = chater.Qquser("12345")
|
||||
user2 = chater.Qquser("67890")
|
||||
group = chater.Qqgroup("1001")
|
||||
|
||||
# 保存私聊消息
|
||||
chat_manager.save_private_message(user1, 'user', '你好,我想问个问题')
|
||||
chat_manager.save_private_message(user1, 'assistant', '请说,我会尽力回答')
|
||||
|
||||
# 保存群聊消息
|
||||
chat_manager.save_group_message(group, 'user', '大家好,我是张三!', sender_id=user1.user_id)
|
||||
chat_manager.save_group_message(group, 'user', '大家好,我是李四!', sender_id=user2.user_id)
|
||||
chat_manager.save_group_message(group, 'assistant', '欢迎加入群聊!')
|
||||
|
||||
# 获取私聊消息
|
||||
private_messages = chat_manager.load_private_messages(user1)
|
||||
print(f"{user1.nickname}的私聊记录:")
|
||||
for msg in private_messages:
|
||||
role = "用户" if msg['role'] == 'user' else "AI助手"
|
||||
print(f"{role}: {msg['content']}")
|
||||
|
||||
# 获取群聊完整消息
|
||||
group_messages = chat_manager.load_group_messages(group)
|
||||
print(f"\n{group.nickname}的群聊记录:")
|
||||
for msg in group_messages:
|
||||
if msg['role'] == 'user':
|
||||
print(f"{msg['sender_id']}: {msg['content']}")
|
||||
else:
|
||||
print(f"AI助手: {msg['content']}")
|
||||
|
||||
# 获取用户在群聊中的消息
|
||||
user1_messages = chat_manager.load_user_group_messages(user1, group)
|
||||
print(f"\n{user1.nickname}在{group.nickname}中的消息:")
|
||||
for msg in user1_messages:
|
||||
print(f"{msg['content']}")
|
||||
config = ConfigManager()
|
||||
print(config.config)
|
79
src/mainprocess.py
Normal file
79
src/mainprocess.py
Normal file
@ -0,0 +1,79 @@
|
||||
import sys
|
||||
import src.modules.user_modules as usermod
|
||||
from src.modules.plugin_modules import BasePlugin, MessageContext
|
||||
import src.file_store_api as file_M
|
||||
import src.plugin_manager as plm
|
||||
|
||||
manager = plm.PluginManager()
|
||||
config = file_M.ConfigManager()
|
||||
rebot_id = config.load_config().get("rebot").get("id")
|
||||
def process_message(uid: str, gid: str | None, message: str) -> str:
|
||||
# 创建上下文
|
||||
ctx = MessageContext(uid=uid, gid=gid, raw_message=message,id = rebot_id)
|
||||
|
||||
plugin_manager = manager
|
||||
manager.scan_plugins()
|
||||
# 阶段1: before_load 插件(加载数据前)
|
||||
ctx.phase = "before_load"
|
||||
early_plugins = []
|
||||
for name, plugin_cls in plugin_manager._plugins.items():
|
||||
plugin = plugin_cls(ctx)
|
||||
if hasattr(plugin, 'before_load') and callable(plugin.before_load):
|
||||
early_plugins.append(plugin)
|
||||
|
||||
for plugin in early_plugins:
|
||||
try:
|
||||
result = plugin.before_load()
|
||||
if result is not None: # 拦截逻辑
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"error:Plugin {plugin.__class__.__name__} before_load error: {str(e)}")
|
||||
|
||||
# 原始加载逻辑
|
||||
if gid is not None:
|
||||
ctx.group.messages = ctx.chat_manager.load_group_messages(ctx.group)
|
||||
ctx.user.messages = ctx.chat_manager.load_user_group_messages(user=ctx.user, group=ctx.group)
|
||||
else:
|
||||
ctx.user.messages = ctx.chat_manager.load_private_messages(ctx.user)
|
||||
|
||||
# 阶段2: after_load 插件(加载数据后)
|
||||
ctx.phase = "after_load"
|
||||
loaded_plugins = []
|
||||
for name, plugin_cls in plugin_manager._plugins.items():
|
||||
plugin = plugin_cls(ctx)
|
||||
if hasattr(plugin, 'after_load') and callable(plugin.after_load):
|
||||
loaded_plugins.append(plugin)
|
||||
|
||||
for plugin in loaded_plugins:
|
||||
try:
|
||||
result = plugin.after_load()
|
||||
if result is not None:
|
||||
ctx.response = result
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"error:Plugin {plugin.__class__.__name__} after_load error: {str(e)}")
|
||||
|
||||
# 原始保存逻辑
|
||||
if gid is not None:
|
||||
ctx.chat_manager.save_group_message(ctx.group, role="user", content=ctx.raw_message, sender_id=ctx.user.user_id)
|
||||
else:
|
||||
ctx.chat_manager.save_private_message(ctx.user, role="user", content=ctx.raw_message)
|
||||
|
||||
# 阶段3: after_save 插件(保存数据后)
|
||||
ctx.phase = "after_save"
|
||||
saved_plugins = []
|
||||
for name, plugin_cls in plugin_manager._plugins.items():
|
||||
plugin = plugin_cls(ctx)
|
||||
if hasattr(plugin, 'after_save') and callable(plugin.after_save):
|
||||
saved_plugins.append(plugin)
|
||||
|
||||
for plugin in saved_plugins:
|
||||
try:
|
||||
result = plugin.after_save()
|
||||
if result is not None and ctx.response is None:
|
||||
ctx.response = result
|
||||
except Exception as e:
|
||||
print(f"error:Plugin {plugin.__class__.__name__} after_save error: {str(e)}")
|
||||
plugin_manager.cleanup()
|
||||
|
||||
return ctx.response if ctx.response is not None else "ok"
|
0
src/modules/__init__.py
Normal file
0
src/modules/__init__.py
Normal file
113
src/modules/plugin_modules.py
Normal file
113
src/modules/plugin_modules.py
Normal file
@ -0,0 +1,113 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from src.modules import user_modules as usermod
|
||||
import src.file_store_api as file_M
|
||||
|
||||
|
||||
class MessageContext:
|
||||
"""封装消息处理的上下文数据"""
|
||||
def __init__(self, uid: str, gid: Optional[str], raw_message: str,id:str):
|
||||
self.raw_message = raw_message
|
||||
self._processed = False
|
||||
self.response: Optional[str] = None
|
||||
# 核心服务实例化
|
||||
self.chat_manager = file_M.ChatManager()
|
||||
self.user = usermod.User(user_id=uid)
|
||||
self.rebot_id = id
|
||||
|
||||
# 动态加载数据
|
||||
if gid:
|
||||
self.group = usermod.Group(group_id=gid)
|
||||
self.group.current_user = self.user
|
||||
else:
|
||||
self.group = None
|
||||
|
||||
@dataclass
|
||||
class PluginPermission:
|
||||
access_private: bool = False # 允许处理私聊消息
|
||||
access_group: bool = True # 允许处理群消息
|
||||
read_history: bool = False # 允许读取历史记录
|
||||
|
||||
from pathlib import Path
|
||||
import toml
|
||||
import os
|
||||
|
||||
class BasePlugin(ABC):
|
||||
def __init__(self, ctx: MessageContext):
|
||||
self.ctx = ctx
|
||||
self._config_dir = self._get_plugin_config_path()
|
||||
self._config_manager = file_M.ConfigManager(self._config_dir)
|
||||
|
||||
def _get_plugin_resource(self, resource_path: str) -> bytes:
|
||||
plugin_name = self.__class__.__name__.lower()
|
||||
zip_path = Path("plugins") / f"{plugin_name}.zip"
|
||||
|
||||
if not zip_path.exists():
|
||||
raise FileNotFoundError(f"插件ZIP包不存在: {zip_path}")
|
||||
|
||||
# 遍历可能的ZIP内路径
|
||||
possible_paths = (
|
||||
resource_path, # 直接路径(config.toml)
|
||||
f"{plugin_name}/{resource_path}" # 插件子目录路径(myplugin/config.toml)
|
||||
)
|
||||
|
||||
with zipfile.ZipFile(zip_path, 'r') as zf:
|
||||
for path in possible_paths:
|
||||
if path in zf.namelist():
|
||||
return zf.read(path)
|
||||
|
||||
raise FileNotFoundError(f"文件 '{resource_path}' 不在ZIP包中")
|
||||
|
||||
def _ensure_config_exists(self):
|
||||
"""确保配置文件存在(不存在时从ZIP复制默认配置)"""
|
||||
config_file = Path(self._config_dir) / "config.toml"
|
||||
|
||||
# 外部配置已存在则直接返回
|
||||
if config_file.exists():
|
||||
return
|
||||
|
||||
# 尝试从ZIP复制默认配置
|
||||
try:
|
||||
|
||||
for ext in ['.toml', '.json', '.yaml']:
|
||||
try:
|
||||
config_data = self._get_plugin_resource(f"config{ext}")
|
||||
|
||||
# 确保配置目录存在
|
||||
config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 统一保存为.toml格式(或保持原格式)
|
||||
config_file.write_bytes(config_data)
|
||||
print(f"✅ 已将默认配置复制到: {config_file}")
|
||||
break
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
else:
|
||||
print(f"⚠️ 插件包内未找到默认配置文件")
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化配置失败: {str(e)}")
|
||||
|
||||
@property
|
||||
def config(self) -> dict:
|
||||
"""始终读取外部配置文件(确保已通过_ensure_config_exists初始化)"""
|
||||
|
||||
try:
|
||||
config_data = self._config_manager.load_config("config")
|
||||
except:
|
||||
self._ensure_config_exists()
|
||||
config_data = self._config_manager.load_config("config")
|
||||
return config_data or {}
|
||||
|
||||
def _get_plugin_config_path(self) -> str:
|
||||
"""获取插件配置目录路径(保持原有逻辑)"""
|
||||
plugin_name = self.__class__.__name__.lower()
|
||||
return str(Path("config") / plugin_name)
|
||||
|
||||
def save_config(self, config_dict: dict) -> bool:
|
||||
"""保存配置到外部目录(与之前逻辑一致)"""
|
||||
self._config_manager.update_config({"config": config_dict})
|
||||
return self._config_manager.save_config()
|
119
src/modules/user_modules.py
Normal file
119
src/modules/user_modules.py
Normal file
@ -0,0 +1,119 @@
|
||||
import json
|
||||
import requests
|
||||
import time
|
||||
import src.file_store_api as filer
|
||||
|
||||
config_m = filer.ConfigManager()
|
||||
url = config_m.load_config("config")
|
||||
mainurl = url.get("app").get("send_url")
|
||||
|
||||
class User:
|
||||
def __init__(self, user_id,url = mainurl):
|
||||
self.user_id = user_id
|
||||
self.url = url
|
||||
self.nickname = None
|
||||
self.get_name()
|
||||
self.messages = []
|
||||
self.signal = True
|
||||
self.db = filer.ChatManager()
|
||||
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
获取用户的名称
|
||||
"""
|
||||
try:
|
||||
response = requests.post('{0}/ArkSharePeer'.format(self.url),json={'user_id':self.user_id})
|
||||
except:
|
||||
return 0
|
||||
if response.status_code == 200:
|
||||
# 解析返回的JSON数据
|
||||
response_data = response.json()
|
||||
# 检查是否有错误信息
|
||||
if response_data.get("status") == "ok":
|
||||
# 获取用户卡片信息
|
||||
ark_json = response_data.get("data", {})
|
||||
ark_msg_str = ark_json.get("arkMsg", "{}")
|
||||
try:
|
||||
ark_json = json.loads(ark_msg_str) # 字符串转字典
|
||||
except json.JSONDecodeError:
|
||||
ark_json = {}
|
||||
user_nick = ark_json.get("meta", {}).get("contact", {}).get("nickname")
|
||||
self.nickname = user_nick
|
||||
else:
|
||||
print(f"请求失败,错误信息: {response_data.get('errMsg')}")
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
|
||||
|
||||
def set_input_status(self, status):
|
||||
payload = json.dumps({
|
||||
"user_id": self.user_id,
|
||||
"event_type": status
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
while self.signal:
|
||||
print(f"刷新 {self.nickname} 的输入状态为: {status}")
|
||||
requests.request("POST","{0}/set_input_status".format(self.url), headers=headers, data=payload)
|
||||
time.sleep(0.5)
|
||||
|
||||
def send_message(self, message):
|
||||
requests.post(url='{0}/send_private_msg'.format(self.url), json={'user_id':self.user_id, 'message':message})
|
||||
self.db.save_private_message(self,role = 'assistant',content=message)#保存发送的消息
|
||||
|
||||
class Group:
|
||||
def __init__(self, group_id,url = mainurl,user=None,users=None):
|
||||
self.url = url
|
||||
self.group_id = group_id
|
||||
self.current_user = user
|
||||
self.nickname = None
|
||||
self.get_group_name()
|
||||
self.users = users
|
||||
self.get_group_users()
|
||||
self.messages =[]
|
||||
self.db = filer.ChatManager()
|
||||
|
||||
def get_group_name(self):
|
||||
"""
|
||||
获取群组的名称
|
||||
"""
|
||||
try:
|
||||
response = requests.post('{0}/ArkSharePeer'.format(self.url), json={'group_id': self.group_id})
|
||||
except:
|
||||
return 0
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if response_data.get("status") == "ok":
|
||||
ark_json = response_data.get("data", {})
|
||||
ark_msg_str = ark_json.get("arkJson", "{}")
|
||||
try:
|
||||
ark_json = json.loads(ark_msg_str)
|
||||
except json.JSONDecodeError:
|
||||
ark_json = {}
|
||||
group_name = ark_json.get("meta", {}).get("contact", {}).get("nickname")
|
||||
self.nickname = group_name
|
||||
else:
|
||||
print(f"请求失败,错误信息: {response_data.get('errMsg')}")
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
|
||||
def get_group_users(self):
|
||||
try:
|
||||
response = requests.post('{0}/get_group_member_list'.format(self.url), json={'group_id':self.group_id,'no_cache': False})
|
||||
except:
|
||||
return 0
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if response_data.get("status") == "ok":
|
||||
group_users = response_data.get("data", {})
|
||||
self.users = group_users
|
||||
else:
|
||||
print(f"请求失败,错误信息: {response_data.get('errMsg')}")
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
|
||||
def send_message(self,message):
|
||||
requests.post(url='{0}/send_group_msg'.format(self.url), json={'group_id': self.group_id, 'message': message})
|
||||
self.db.save_group_message(self,'assistant',message, sender_id=0)#保存发送的消息
|
261
src/plugin_manager.py
Normal file
261
src/plugin_manager.py
Normal file
@ -0,0 +1,261 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import zipfile
|
||||
import tempfile
|
||||
import importlib.util
|
||||
import importlib.metadata
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Type, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
from packaging import version
|
||||
from contextlib import contextmanager
|
||||
|
||||
from src.file_store_api import ConfigManager as ConfigM
|
||||
from src.modules import plugin_modules as plugin_mod
|
||||
|
||||
BasePlugin = plugin_mod.BasePlugin
|
||||
MessageContext = plugin_mod.MessageContext
|
||||
|
||||
config_manager = ConfigM()
|
||||
plugin_config = config_manager.load_config("config").get("plugins", {})
|
||||
PLUGIN_DIR = os.path.join(*plugin_config.get("dir", ["plugins"]))
|
||||
# ---- 核心修改部分 ----
|
||||
|
||||
class DependencyManager:
|
||||
"""处理插件内嵌依赖的专用管理器"""
|
||||
def __init__(self):
|
||||
self._package_roots = set()
|
||||
|
||||
@contextmanager
|
||||
def isolated_import(self, package_root: str):
|
||||
"""上下文管理器,临时添加包路径"""
|
||||
if package_root not in sys.path:
|
||||
sys.path.insert(0, package_root)
|
||||
self._package_roots.add(package_root)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if package_root in sys.path: # 确保不重复移除
|
||||
sys.path.remove(package_root)
|
||||
self._package_roots.discard(package_root)
|
||||
|
||||
def _parse_requirement(self, req: str) -> tuple:
|
||||
"""解析依赖项字符串"""
|
||||
ops = {"==", ">=", "<=", ">", "<", "~=", "!="}
|
||||
for op in ops:
|
||||
if op in req:
|
||||
parts = req.split(op, 1)
|
||||
return parts[0].strip(), op, parts[1].strip()
|
||||
return req.strip(), None, None
|
||||
|
||||
def check_embedded_dependencies(self, plugin_dir: str) -> bool:
|
||||
"""检查插件自带的依赖是否可用"""
|
||||
packages_dir = os.path.join(plugin_dir, "packages")
|
||||
if not os.path.exists(packages_dir):
|
||||
return True
|
||||
|
||||
# 检查requirements.txt或dependencies.json
|
||||
req_file = os.path.join(plugin_dir, "requirements.txt")
|
||||
if not os.path.exists(req_file):
|
||||
req_file = os.path.join(plugin_dir, "dependencies.json")
|
||||
|
||||
if os.path.exists(req_file):
|
||||
return self._validate_dependencies(req_file, packages_dir)
|
||||
return True
|
||||
|
||||
def _validate_dependencies(self, req_file: str, packages_dir: str) -> bool:
|
||||
"""验证依赖是否满足"""
|
||||
requirements = self._load_requirements(req_file)
|
||||
if not requirements:
|
||||
return True
|
||||
|
||||
with self.isolated_import(packages_dir):
|
||||
for req in requirements:
|
||||
try:
|
||||
pkg, op, ver = self._parse_requirement(req)
|
||||
installed = importlib.metadata.version(pkg)
|
||||
|
||||
if op and ver:
|
||||
if not self._compare_versions(installed, op, ver):
|
||||
print(f"error: 依赖版本不匹配: 需要 {req},但找到 {installed}")
|
||||
return False
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
print(f"error: 依赖未找到: {req}")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _load_requirements(self, req_file: str) -> List[str]:
|
||||
"""加载依赖文件"""
|
||||
try:
|
||||
if req_file.endswith('.json'):
|
||||
with open(req_file) as f:
|
||||
data = json.load(f)
|
||||
return data.get("requirements", [])
|
||||
else:
|
||||
with open(req_file) as f:
|
||||
return [line.strip() for line in f if line.strip() and not line.startswith('#')]
|
||||
except Exception as e:
|
||||
print(f"error: 读取依赖文件失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def _compare_versions(self, installed: str, op: str, required: str) -> bool:
|
||||
"""比较版本号"""
|
||||
iv = version.parse(installed)
|
||||
rv = version.parse(required)
|
||||
|
||||
if op == "==": return iv == rv
|
||||
elif op == ">=": return iv >= rv
|
||||
elif op == "<=": return iv <= rv
|
||||
elif op == ">": return iv > rv
|
||||
elif op == "<": return iv < rv
|
||||
elif op == "~=": return iv >= rv and iv < version.parse(self._next_major(required))
|
||||
elif op == "!=": return iv != rv
|
||||
return True
|
||||
|
||||
def _next_major(self, ver: str) -> str:
|
||||
"""获取下一个主版本号"""
|
||||
parts = ver.split('.')
|
||||
if parts:
|
||||
try:
|
||||
parts[0] = str(int(parts[0]) + 1)
|
||||
return '.'.join(parts)
|
||||
except ValueError:
|
||||
pass
|
||||
return ver + ".0"
|
||||
|
||||
# ---- 修改后的插件管理器 ----
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self):
|
||||
self._plugins: Dict[str, Type[BasePlugin]] = {} # 插件类注册表
|
||||
self._active_instances: Dict[str, BasePlugin] = {} # 激活的插件实例
|
||||
self._hook_registry: Dict[str, List[Callable]] = {} # 兼容旧版钩子
|
||||
self._temp_dirs: List[str] = [] # 临时目录记录
|
||||
self._dependency_manager = DependencyManager() # 新增依赖管理器\
|
||||
self.scan_plugins
|
||||
|
||||
def scan_plugins(self):
|
||||
"""扫描并加载所有ZIP插件(保持原接口不变)"""
|
||||
if not os.path.exists(PLUGIN_DIR):
|
||||
os.makedirs(PLUGIN_DIR, exist_ok=True)
|
||||
|
||||
for item in Path(PLUGIN_DIR).glob("**/*.zip"):
|
||||
self.load_plugin(str(item))
|
||||
|
||||
def load_plugin(self, zip_path: str) -> bool:
|
||||
"""动态加载ZIP格式插件(支持内嵌依赖)"""
|
||||
try:
|
||||
# 解压到临时目录(保持原逻辑)
|
||||
temp_dir = tempfile.mkdtemp(prefix=f"plugin_{Path(zip_path).stem}_")
|
||||
self._temp_dirs.append(temp_dir)
|
||||
self._extract_zip(zip_path, temp_dir)
|
||||
|
||||
# 检查并加载内嵌依赖(新增功能)
|
||||
if not self._load_embedded_dependencies(temp_dir):
|
||||
print(f"error: 依赖检查失败,跳过插件: {Path(zip_path).name}")
|
||||
return False
|
||||
|
||||
# 动态导入主模块(保持原逻辑)
|
||||
return self._import_plugin_module(temp_dir, Path(zip_path).stem)
|
||||
|
||||
except Exception as e:
|
||||
print(f"error: 加载插件失败 {Path(zip_path).name}: {str(e)}")
|
||||
return False
|
||||
|
||||
def _extract_zip(self, zip_path: str, target_dir: str):
|
||||
"""解压ZIP文件(抽取为独立方法)"""
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(target_dir)
|
||||
|
||||
def _load_embedded_dependencies(self, plugin_dir: str) -> bool:
|
||||
"""加载插件内嵌的依赖包(新增方法)"""
|
||||
# 检查依赖是否满足
|
||||
if not self._dependency_manager.check_embedded_dependencies(plugin_dir):
|
||||
return False
|
||||
|
||||
# 如果有packages目录,添加到导入路径
|
||||
packages_dir = os.path.join(plugin_dir, "packages")
|
||||
if os.path.exists(packages_dir):
|
||||
sys.path.insert(0, packages_dir)
|
||||
print(f"已加载插件内嵌依赖: {packages_dir}")
|
||||
|
||||
return True
|
||||
|
||||
def _import_plugin_module(self, plugin_dir: str, plugin_name: str) -> bool:
|
||||
"""导入插件主模块(重构为独立方法)"""
|
||||
entry_file = Path(plugin_dir) / "process.py"
|
||||
if not entry_file.exists():
|
||||
print(f"error: 插件入口文件不存在: {plugin_name}")
|
||||
return False
|
||||
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
f"plugins.{plugin_name}",
|
||||
str(entry_file)
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 自动注册插件类(保持原逻辑)
|
||||
return self._register_plugin_classes(module, plugin_name)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 导入插件模块失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def _register_plugin_classes(self, module, plugin_name: str) -> bool:
|
||||
"""注册插件类(抽取为独立方法)"""
|
||||
plugin_registered = False
|
||||
for name, obj in module.__dict__.items():
|
||||
if isinstance(obj, type) and issubclass(obj, BasePlugin) and obj != BasePlugin:
|
||||
self._plugins[plugin_name] = obj
|
||||
print(f"✅ 已注册插件类: {plugin_name}::{name}")
|
||||
plugin_registered = True
|
||||
|
||||
# 兼容旧版钩子注册(保持原逻辑)
|
||||
if hasattr(module, "register_hooks"):
|
||||
module.register_hooks(self)
|
||||
print(f"🔄 已注册旧版钩子: {plugin_name}")
|
||||
|
||||
return plugin_registered
|
||||
|
||||
def process_message(self, uid: str, gid: Optional[str], message: str) -> str:
|
||||
"""主消息处理入口(保持完全兼容)"""
|
||||
ctx = MessageContext(uid, gid, message)
|
||||
|
||||
# 优先执行新版插件流程
|
||||
for name, plugin_cls in self._plugins.items():
|
||||
try:
|
||||
plugin = plugin_cls(ctx)
|
||||
if result := plugin.process():
|
||||
ctx.response = result
|
||||
ctx.intercepted = True
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"⚠️ 插件错误 {name}: {str(e)}")
|
||||
|
||||
# 如果没有被拦截,运行旧版钩子
|
||||
if not ctx.intercepted:
|
||||
for hook in self._hook_registry.get("on_message", []):
|
||||
hook(ctx)
|
||||
|
||||
return ctx.response or "ok"
|
||||
|
||||
def register_hook(self, hook_name: str):
|
||||
"""兼容旧版钩子注册(装饰器模式,保持原样)"""
|
||||
def decorator(func):
|
||||
if hook_name not in self._hook_registry:
|
||||
self._hook_registry[hook_name] = []
|
||||
self._hook_registry[hook_name].append(func)
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源(保持原逻辑)"""
|
||||
for temp_dir in self._temp_dirs:
|
||||
if os.path.exists(temp_dir):
|
||||
import shutil
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
Reference in New Issue
Block a user