modified: README.md
new file: app.py new file: config/config.toml new file: requirements.txt new file: run.bat new file: run.sh new file: src/__init__.py new file: src/file_store_api.py new file: src/mainprocess.py new file: src/modules/__init__.py new file: src/modules/plugin_modules.py new file: src/modules/user_modules.py new file: src/plugin_manager.py
This commit is contained in:
145
README.md
145
README.md
@ -1,3 +1,144 @@
|
||||
# chat_rebot-connect-with-onebot-standard-
|
||||
# OneBot Chatbot Framework
|
||||
|
||||
该项目是一个基于OneBot标准的聊天机器人后端框架,采用高度可扩展的插件架构设计,支持消息的模块化处理和插件热加载。
|
||||
该项目是一个基于OneBot标准的聊天机器人后端框架,采用高度可扩展的插件架构设计,支持消息的模块化处理和插件热加载。
|
||||
|
||||
## 项目特点
|
||||
|
||||
- **模块化设计**:每个功能作为独立插件实现,易于扩展和维护
|
||||
- **插件生命周期管理**:支持插件加载、注册、依赖处理和实例管理
|
||||
- **消息处理管道**:分阶段处理消息,支持各阶段拦截机制
|
||||
- **会话管理**:支持群组消息和私聊消息的独立管理
|
||||
- **内嵌依赖处理**:自动管理插件内嵌的Python依赖包
|
||||
- **兼容性设计**:支持新旧版本插件并存运行
|
||||
|
||||
## 核心组件
|
||||
|
||||
### 消息处理流程 (`process_message`)
|
||||
|
||||
```python
|
||||
def process_message(uid: str, gid: str | None, message: str) -> str:
|
||||
# 1. 创建消息上下文
|
||||
ctx = MessageContext(...)
|
||||
|
||||
# 2. 扫描并加载插件
|
||||
plugin_manager.scan_plugins()
|
||||
|
||||
# 3. 消息处理阶段:
|
||||
# - before_load: 加载数据前拦截点
|
||||
# - after_load: 加载数据后处理点
|
||||
# - after_save: 保存数据后处理点
|
||||
|
||||
# 4. 会话数据持久化
|
||||
ctx.chat_manager.save_message(...)
|
||||
|
||||
return ctx.response or "ok"
|
||||
```
|
||||
|
||||
### 插件管理器 (`PluginManager`)
|
||||
|
||||
```python
|
||||
class PluginManager:
|
||||
def __init__(self):
|
||||
self._plugins = {} # 插件类注册表
|
||||
self._active_instances = {} # 插件实例
|
||||
self._hook_registry = {} # 兼容旧版钩子
|
||||
self._temp_dirs = [] # 临时目录
|
||||
self._dependency_manager = DependencyManager() # 依赖处理器
|
||||
|
||||
def scan_plugins(self):
|
||||
"""扫描插件目录并加载ZIP格式插件"""
|
||||
|
||||
def load_plugin(self, zip_path: str) -> bool:
|
||||
"""动态加载ZIP格式插件"""
|
||||
|
||||
def _load_embedded_dependencies(self, plugin_dir: str) -> bool:
|
||||
"""加载插件内嵌的依赖包"""
|
||||
|
||||
def register_hook(self, hook_name: str):
|
||||
"""注册兼容旧版钩子(装饰器模式)"""
|
||||
|
||||
def cleanup(self):
|
||||
"""清理临时资源"""
|
||||
```
|
||||
|
||||
## 插件开发指南
|
||||
|
||||
### 基本插件结构
|
||||
|
||||
```python
|
||||
# process.py
|
||||
from src.modules.plugin_modules import BasePlugin, MessageContext
|
||||
|
||||
class MyPlugin(BasePlugin):
|
||||
def __init__(self, ctx: MessageContext):
|
||||
super().__init__(ctx)
|
||||
|
||||
def process(self) -> str | None:
|
||||
"""核心处理方法"""
|
||||
if self.ctx.command == "help":
|
||||
return self._show_help()
|
||||
|
||||
def before_load(self) -> str | None:
|
||||
"""数据加载前拦截点"""
|
||||
|
||||
def after_load(self) -> str | None:
|
||||
"""数据加载后处理点"""
|
||||
|
||||
def after_save(self) -> str | None:
|
||||
"""数据保存后处理点"""
|
||||
```
|
||||
|
||||
### 目录结构要求
|
||||
|
||||
插件应以ZIP格式打包,包含以下内容:
|
||||
|
||||
```
|
||||
my_plugin.zip
|
||||
├── process.py # 必需 - 插件入口文件
|
||||
├── requirements.txt # 可选 - 依赖声明
|
||||
└── packages/ # 可选 - 内嵌依赖包
|
||||
├── package1/
|
||||
└── package2/
|
||||
```
|
||||
|
||||
### 依赖声明
|
||||
|
||||
插件可通过两种方式声明依赖:
|
||||
|
||||
1. `requirements.txt` 标准格式
|
||||
2. `dependencies.json` 自定义格式
|
||||
|
||||
```json
|
||||
// dependencies.json 示例
|
||||
{
|
||||
"requirements": [
|
||||
"requests==2.28.2",
|
||||
"numpy>=1.25.0"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 快速启动
|
||||
|
||||
1. 为启动脚本授权(linux):
|
||||
|
||||
```bash
|
||||
chmod +x run.sh
|
||||
```
|
||||
2. 修改配置文件,list_port为接收消息推送端口,send_url为消息发送地址
|
||||
3. 运行启动脚本
|
||||
linux下
|
||||
|
||||
`./run.sh`
|
||||
|
||||
windows下
|
||||
|
||||
`.\run.bat`
|
||||
|
||||
## 设计优势
|
||||
|
||||
1. **解耦设计**:插件与核心系统完全解耦
|
||||
2. **安全隔离**:使用临时目录加载插件
|
||||
3. **版本兼容**:内建依赖版本验证机制
|
||||
4. **灵活扩展**:支持多个消息处理点
|
||||
5. **新旧兼容**:支持传统钩子和现代OOP插件的共存
|
||||
|
97
app.py
Normal file
97
app.py
Normal file
@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from flask import Flask, request, jsonify
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
from src import mainprocess as src
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
#===rebot===#
|
||||
# 处理私聊消息
|
||||
# 处理群聊消息
|
||||
@app.route('/', methods=["POST"])
|
||||
def handle_event():
|
||||
try:
|
||||
event = request.get_json()
|
||||
event_type = event.get('post_type')
|
||||
|
||||
# 1. 处理私聊消息
|
||||
if event_type == 'message' and event.get('message_type') == 'private':
|
||||
# 注意:私聊消息在顶层有 user_id
|
||||
uid = event.get('user_id')
|
||||
message = event.get('raw_message')
|
||||
src.process_message(uid, None, message)
|
||||
|
||||
# 2. 处理群消息
|
||||
elif event_type == 'message' and event.get('message_type') == 'group':
|
||||
gid = event.get('group_id')
|
||||
# 注意:群消息发送者在 sender 内
|
||||
sender = event.get('sender', {})
|
||||
uid = sender.get('user_id')
|
||||
message = event.get('raw_message')
|
||||
src.process_message(uid, gid, message)
|
||||
|
||||
# 3. 处理通知事件(如输入状态)
|
||||
elif event_type == 'notice':
|
||||
notice_type = event.get('notice_type')
|
||||
|
||||
if notice_type == 'notify' and event.get('sub_type') == 'input_status':
|
||||
# 仅记录,不处理
|
||||
logging.info(f"用户 {event.get('user_id')} 输入状态变化")
|
||||
|
||||
elif notice_type == 'group_recall':
|
||||
# 示例:处理群消息撤回
|
||||
logging.info(f"群 {event.get('group_id')} 撤回消息")
|
||||
|
||||
else:
|
||||
# 其他通知类型
|
||||
logging.info(f"Ignored notice: {event}")
|
||||
|
||||
# 通知事件直接返回成功
|
||||
return jsonify({
|
||||
"status": "ok",
|
||||
"retcode": 0,
|
||||
"data": None
|
||||
})
|
||||
|
||||
# 4. 处理元事件(如心跳)
|
||||
elif event_type == 'meta_event':
|
||||
# 心跳等元事件直接返回成功
|
||||
return jsonify({
|
||||
"status": "ok",
|
||||
"retcode": 0,
|
||||
"data": None
|
||||
})
|
||||
|
||||
# 5. 一切正常的消息事件返回成功
|
||||
return jsonify({
|
||||
"status": "ok",
|
||||
"retcode": 0,
|
||||
"data": "Processed successfully"
|
||||
})
|
||||
|
||||
except KeyError:
|
||||
logging.warning(f"Missing required field in event: {event}")
|
||||
return jsonify({
|
||||
"status": "failed",
|
||||
"retcode": 10001,
|
||||
"message": "Missing required field"
|
||||
}), 400
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"Error processing event: {str(e)}")
|
||||
return jsonify({
|
||||
"status": "failed",
|
||||
"retcode": 20001,
|
||||
"message": "Internal server error"
|
||||
}), 500
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
port = 25580
|
||||
app.run(debug=True, host='0.0.0.0', port=port)
|
||||
except Exception as e:
|
||||
print(f"启动失败: {e}")
|
10
config/config.toml
Normal file
10
config/config.toml
Normal file
@ -0,0 +1,10 @@
|
||||
[app]
|
||||
list_port = 25580
|
||||
send_url = "http://192.168.1.1:25570"
|
||||
system_content = "你是一个qq助手,名叫”the real“需要根据消息回答私信或者群聊消息,注意语言要尽可能的精炼,尽量不要超过20字,同时尽量不要换行,尽量不要让其他用户认出你是ai,同时注意,当聊天中出现[CQ:at,qq=1456128421]时,证明用户在针对你发言,你的uid是1456128421"
|
||||
|
||||
[plugins]
|
||||
dir = ["plugins"]
|
||||
|
||||
[rebot]
|
||||
id = ""
|
16
requirements.txt
Normal file
16
requirements.txt
Normal file
@ -0,0 +1,16 @@
|
||||
blinker==1.9.0
|
||||
certifi==2025.8.3
|
||||
charset-normalizer==3.4.2
|
||||
click==8.2.1
|
||||
colorama==0.4.6
|
||||
Flask==3.1.1
|
||||
idna==3.10
|
||||
itsdangerous==2.2.0
|
||||
Jinja2==3.1.6
|
||||
MarkupSafe==3.0.2
|
||||
packaging==25.0
|
||||
pkg==0.2
|
||||
requests==2.32.4
|
||||
toml==0.10.2
|
||||
urllib3==2.5.0
|
||||
Werkzeug==3.1.3
|
74
run.bat
Normal file
74
run.bat
Normal file
@ -0,0 +1,74 @@
|
||||
@echo off
|
||||
|
||||
|
||||
set PROJECT_DIR=%~dp0
|
||||
set VENV_DIR=%PROJECT_DIR%.venv
|
||||
|
||||
|
||||
if exist "%VENV_DIR%\Scripts\activate.bat" (
|
||||
|
||||
call "%VENV_DIR%\Scripts\activate.bat"
|
||||
) else (
|
||||
|
||||
python -m venv "%VENV_DIR%"
|
||||
call "%VENV_DIR%\Scripts\activate.bat"
|
||||
|
||||
if errorlevel 1 (
|
||||
echo error: fail to create env
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if "%VIRTUAL_ENV%" == "" (
|
||||
echo error: fail to activate env
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo installing dependence...
|
||||
|
||||
pip install -r requirements.txt
|
||||
if errorlevel 1 (
|
||||
echo error: fail to install dependence
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
pip install waitress
|
||||
if errorlevel 1 (
|
||||
echo error: fail to install waitress
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
|
||||
echo reading port from config...
|
||||
for /f "usebackq tokens=*" %%P in (`python -c "from src.file_store_api import ConfigManager; config=ConfigManager().load_config(); print(config.get('app', {}).get('list_port', 25580))"`) do (
|
||||
set PORT=%%P
|
||||
)
|
||||
|
||||
|
||||
if "%PORT%"=="" (
|
||||
set PORT=25580
|
||||
echo can't read port,use custom port:25580
|
||||
) else (
|
||||
echo success read port: %PORT%
|
||||
)
|
||||
|
||||
:: 5. 启动服务
|
||||
echo starting rebot_server...
|
||||
echo listening at: %PORT%
|
||||
|
||||
:: 正确启动 waitress
|
||||
waitress-serve --host=0.0.0.0 --port=%PORT% app:app
|
||||
|
||||
:: 6. 错误处理
|
||||
if errorlevel 1 (
|
||||
echo error,fail to start
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
pause
|
51
run.sh
Normal file
51
run.sh
Normal file
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
|
||||
VENV_DIR="$PROJECT_DIR/.venv"
|
||||
FLASK_APP="app:app"
|
||||
|
||||
|
||||
echo "Activating virtual environment..."
|
||||
|
||||
if [ -f "$VENV_DIR/bin/activate" ]; then
|
||||
source "$VENV_DIR/bin/activate"
|
||||
else
|
||||
echo "Creating new virtual environment..."
|
||||
python3 -m venv "$VENV_DIR"
|
||||
source "$VENV_DIR/bin/activate"
|
||||
fi
|
||||
|
||||
if [ -z "$VIRTUAL_ENV" ]; then
|
||||
echo "Error: Failed to activate virtual environment"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Installing dependencies..."
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install gunicorn
|
||||
|
||||
echo "Reading port from configuration..."
|
||||
|
||||
PORT=$(python3 -c \
|
||||
"
|
||||
from src.file_store_api import ConfigManager
|
||||
try:
|
||||
config = ConfigManager().load_config()
|
||||
port = config.get('app', {}).get('list_port')
|
||||
print(str(port) if port else '')
|
||||
except Exception as e:
|
||||
print('ERROR: ' + str(e))
|
||||
exit(1)
|
||||
")
|
||||
|
||||
if [[ "$PORT" == ERROR:* ]] || [ -z "$PORT" ]; then
|
||||
echo "Failed to get port from config: $PORT"
|
||||
echo "Using default port 25580"
|
||||
PORT=25580
|
||||
fi
|
||||
|
||||
echo "Starting rebot server..."
|
||||
echo "Listening on port: $PORT"
|
||||
|
||||
gunicorn -w 4 -b 0.0.0.0:$PORT "$FLASK_APP" --access-logfile - --error-logfile -
|
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
264
src/file_store_api.py
Normal file
264
src/file_store_api.py
Normal file
@ -0,0 +1,264 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import sqlite3
|
||||
import os
|
||||
import time
|
||||
import toml
|
||||
from pathlib import Path
|
||||
from http import HTTPStatus
|
||||
from datetime import datetime
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler("chat_app.log"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理类,处理应用配置"""
|
||||
def __init__(self, config_path="config"):
|
||||
self.config = {}
|
||||
self.config_path = config_path
|
||||
self.build_config_dict()
|
||||
|
||||
|
||||
def build_config_dict(self) -> dict[str, str]:
|
||||
config_dict = {}
|
||||
for config_file in Path(self.config_path).rglob("*.toml"):
|
||||
if not config_file.is_file():
|
||||
continue
|
||||
|
||||
# 获取相对路径的父目录名
|
||||
rel_path = config_file.relative_to(self.config_path)
|
||||
parent_name = rel_path.parent.name if rel_path.parent.name else None
|
||||
|
||||
if parent_name:
|
||||
key = parent_name
|
||||
else:
|
||||
key = config_file.stem # 去掉扩展名
|
||||
|
||||
config_dict[key] = str(config_file.absolute())
|
||||
self.config = config_dict
|
||||
|
||||
def load_config(self,name="config"):
|
||||
"""加载配置文件"""
|
||||
if not os.path.exists(self.config[name]):
|
||||
return {}
|
||||
with open(self.config[name], 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
return toml.load(f)
|
||||
except toml.TomlDecodeError:
|
||||
return {}
|
||||
|
||||
def save_config(self, key=None, value=None):
|
||||
"""保存配置项"""
|
||||
if key is not None and value is not None:
|
||||
# 如果提供了 key 和 value,则更新单个值
|
||||
self.config[key] = value
|
||||
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||||
toml.dump(self.config, f)
|
||||
|
||||
def update_config(self, config_dict):
|
||||
"""更新配置字典"""
|
||||
self.config.update(config_dict)
|
||||
self.save_config()
|
||||
|
||||
class MainDatabase:
|
||||
def __init__(self,db_path):
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
def _ini_db(self):
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
|
||||
""")
|
||||
configm = ConfigManager()
|
||||
|
||||
basecontent = configm.load_config("config").get("app", {}).get("system_content", "你是一个qq助手,名叫”the real“")
|
||||
|
||||
class ChatDatabase:
|
||||
"""数据库管理类,处理所有SQLite操作"""
|
||||
def __init__(self, db_path):
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库表结构,并添加初始系统消息"""
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建消息表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL, -- user/assistant/system
|
||||
content TEXT NOT NULL, -- 消息内容
|
||||
sender_id TEXT, -- 发送者ID
|
||||
timestamp REAL NOT NULL -- 时间戳
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建索引
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sender ON messages(sender_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON messages(timestamp)")
|
||||
|
||||
# 检查表是否为空(新创建的数据库)
|
||||
cursor.execute("SELECT COUNT(*) FROM messages")
|
||||
count = cursor.fetchone()[0]
|
||||
if count == 0:
|
||||
# 插入初始系统消息
|
||||
timestamp = datetime.now().timestamp()
|
||||
cursor.execute("""
|
||||
INSERT INTO messages (role, content, sender_id, timestamp)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", ('system', '你是一个qq助手,名叫”the real“需要根据消息回答私信或者群聊消息,注意语言要尽可能的精炼,尽量不要超过20字,同时尽量不要换行,尽量不要让其他用户认出你是ai,同时注意,当聊天中出现[CQ:at,qq=1456128421]时,证明用户在针对你发言,你的uid是1456128421', None, timestamp))
|
||||
logger.info(f"初始化系统消息已添加到数据库: {self.db_path}")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def save_message(self, role, content, sender_id=None):
|
||||
"""保存消息到数据库"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
timestamp = datetime.now().timestamp()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO messages (role, content, sender_id, timestamp)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (role, content, sender_id, timestamp))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def load_messages(self, limit=10, sender_id=None):
|
||||
"""从数据库加载消息"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
query = "SELECT role, content, sender_id, timestamp FROM messages"
|
||||
params = []
|
||||
|
||||
if sender_id:
|
||||
query += " WHERE sender_id = ?"
|
||||
params.append(sender_id)
|
||||
|
||||
query += " ORDER BY timestamp LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
cursor.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
# 转换为消息字典列表
|
||||
messages = list()
|
||||
for row in rows:
|
||||
messages.append({
|
||||
'role': row[0],
|
||||
'content': row[1],
|
||||
'sender_id': row[2],
|
||||
'timestamp': row[3]
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,处理所有数据库操作"""
|
||||
def __init__(self):
|
||||
self.base_dir = os.path.join("databases","chats")
|
||||
self.user_dir = os.path.join(self.base_dir, "user")
|
||||
self.group_dir = os.path.join(self.base_dir, "group")
|
||||
# 确保目录存在
|
||||
os.makedirs(self.user_dir, exist_ok=True)
|
||||
os.makedirs(self.group_dir, exist_ok=True)
|
||||
|
||||
def get_user_db(self, user_id):
|
||||
"""获取用户私聊数据库实例"""
|
||||
db_path = os.path.join(self.user_dir, f"{user_id}.db")
|
||||
return ChatDatabase(db_path)
|
||||
|
||||
def get_group_db(self, group_id):
|
||||
"""获取群聊数据库实例"""
|
||||
db_path = os.path.join(self.group_dir, f"{group_id}.db")
|
||||
return ChatDatabase(db_path)
|
||||
|
||||
def save_private_message(self, user, role, content):
|
||||
"""保存私聊消息"""
|
||||
db = self.get_user_db(user.user_id)
|
||||
db.save_message(role, content, sender_id=user.user_id)
|
||||
|
||||
def load_private_messages(self, user, limit=100):
|
||||
"""加载私聊消息"""
|
||||
db = self.get_user_db(user.user_id)
|
||||
return db.load_messages(limit)
|
||||
|
||||
def save_group_message(self, group, role, content, sender_id=None):
|
||||
"""保存群聊消息"""
|
||||
db = self.get_group_db(group.group_id)
|
||||
db.save_message(role, content, sender_id=sender_id)
|
||||
|
||||
def load_group_messages(self, group, limit=100):
|
||||
"""加载群聊消息"""
|
||||
db = self.get_group_db(group.group_id)
|
||||
return db.load_messages(limit)
|
||||
|
||||
def load_user_group_messages(self, user, group, limit=10):
|
||||
"""加载用户在群聊中的消息"""
|
||||
db = self.get_group_db(group.group_id)
|
||||
return db.load_messages(limit, sender_id=user.user_id)
|
||||
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
from modules import user_modules as chater
|
||||
# 创建聊天管理器
|
||||
chat_manager = ChatManager()
|
||||
|
||||
# 创建用户和群组(仅包含基本信息)
|
||||
user1 = chater.Qquser("12345")
|
||||
user2 = chater.Qquser("67890")
|
||||
group = chater.Qqgroup("1001")
|
||||
|
||||
# 保存私聊消息
|
||||
chat_manager.save_private_message(user1, 'user', '你好,我想问个问题')
|
||||
chat_manager.save_private_message(user1, 'assistant', '请说,我会尽力回答')
|
||||
|
||||
# 保存群聊消息
|
||||
chat_manager.save_group_message(group, 'user', '大家好,我是张三!', sender_id=user1.user_id)
|
||||
chat_manager.save_group_message(group, 'user', '大家好,我是李四!', sender_id=user2.user_id)
|
||||
chat_manager.save_group_message(group, 'assistant', '欢迎加入群聊!')
|
||||
|
||||
# 获取私聊消息
|
||||
private_messages = chat_manager.load_private_messages(user1)
|
||||
print(f"{user1.nickname}的私聊记录:")
|
||||
for msg in private_messages:
|
||||
role = "用户" if msg['role'] == 'user' else "AI助手"
|
||||
print(f"{role}: {msg['content']}")
|
||||
|
||||
# 获取群聊完整消息
|
||||
group_messages = chat_manager.load_group_messages(group)
|
||||
print(f"\n{group.nickname}的群聊记录:")
|
||||
for msg in group_messages:
|
||||
if msg['role'] == 'user':
|
||||
print(f"{msg['sender_id']}: {msg['content']}")
|
||||
else:
|
||||
print(f"AI助手: {msg['content']}")
|
||||
|
||||
# 获取用户在群聊中的消息
|
||||
user1_messages = chat_manager.load_user_group_messages(user1, group)
|
||||
print(f"\n{user1.nickname}在{group.nickname}中的消息:")
|
||||
for msg in user1_messages:
|
||||
print(f"{msg['content']}")
|
||||
config = ConfigManager()
|
||||
print(config.config)
|
79
src/mainprocess.py
Normal file
79
src/mainprocess.py
Normal file
@ -0,0 +1,79 @@
|
||||
import sys
|
||||
import src.modules.user_modules as usermod
|
||||
from src.modules.plugin_modules import BasePlugin, MessageContext
|
||||
import src.file_store_api as file_M
|
||||
import src.plugin_manager as plm
|
||||
|
||||
manager = plm.PluginManager()
|
||||
config = file_M.ConfigManager()
|
||||
rebot_id = config.load_config().get("rebot").get("id")
|
||||
def process_message(uid: str, gid: str | None, message: str) -> str:
|
||||
# 创建上下文
|
||||
ctx = MessageContext(uid=uid, gid=gid, raw_message=message,id = rebot_id)
|
||||
|
||||
plugin_manager = manager
|
||||
manager.scan_plugins()
|
||||
# 阶段1: before_load 插件(加载数据前)
|
||||
ctx.phase = "before_load"
|
||||
early_plugins = []
|
||||
for name, plugin_cls in plugin_manager._plugins.items():
|
||||
plugin = plugin_cls(ctx)
|
||||
if hasattr(plugin, 'before_load') and callable(plugin.before_load):
|
||||
early_plugins.append(plugin)
|
||||
|
||||
for plugin in early_plugins:
|
||||
try:
|
||||
result = plugin.before_load()
|
||||
if result is not None: # 拦截逻辑
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"error:Plugin {plugin.__class__.__name__} before_load error: {str(e)}")
|
||||
|
||||
# 原始加载逻辑
|
||||
if gid is not None:
|
||||
ctx.group.messages = ctx.chat_manager.load_group_messages(ctx.group)
|
||||
ctx.user.messages = ctx.chat_manager.load_user_group_messages(user=ctx.user, group=ctx.group)
|
||||
else:
|
||||
ctx.user.messages = ctx.chat_manager.load_private_messages(ctx.user)
|
||||
|
||||
# 阶段2: after_load 插件(加载数据后)
|
||||
ctx.phase = "after_load"
|
||||
loaded_plugins = []
|
||||
for name, plugin_cls in plugin_manager._plugins.items():
|
||||
plugin = plugin_cls(ctx)
|
||||
if hasattr(plugin, 'after_load') and callable(plugin.after_load):
|
||||
loaded_plugins.append(plugin)
|
||||
|
||||
for plugin in loaded_plugins:
|
||||
try:
|
||||
result = plugin.after_load()
|
||||
if result is not None:
|
||||
ctx.response = result
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"error:Plugin {plugin.__class__.__name__} after_load error: {str(e)}")
|
||||
|
||||
# 原始保存逻辑
|
||||
if gid is not None:
|
||||
ctx.chat_manager.save_group_message(ctx.group, role="user", content=ctx.raw_message, sender_id=ctx.user.user_id)
|
||||
else:
|
||||
ctx.chat_manager.save_private_message(ctx.user, role="user", content=ctx.raw_message)
|
||||
|
||||
# 阶段3: after_save 插件(保存数据后)
|
||||
ctx.phase = "after_save"
|
||||
saved_plugins = []
|
||||
for name, plugin_cls in plugin_manager._plugins.items():
|
||||
plugin = plugin_cls(ctx)
|
||||
if hasattr(plugin, 'after_save') and callable(plugin.after_save):
|
||||
saved_plugins.append(plugin)
|
||||
|
||||
for plugin in saved_plugins:
|
||||
try:
|
||||
result = plugin.after_save()
|
||||
if result is not None and ctx.response is None:
|
||||
ctx.response = result
|
||||
except Exception as e:
|
||||
print(f"error:Plugin {plugin.__class__.__name__} after_save error: {str(e)}")
|
||||
plugin_manager.cleanup()
|
||||
|
||||
return ctx.response if ctx.response is not None else "ok"
|
0
src/modules/__init__.py
Normal file
0
src/modules/__init__.py
Normal file
113
src/modules/plugin_modules.py
Normal file
113
src/modules/plugin_modules.py
Normal file
@ -0,0 +1,113 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from src.modules import user_modules as usermod
|
||||
import src.file_store_api as file_M
|
||||
|
||||
|
||||
class MessageContext:
|
||||
"""封装消息处理的上下文数据"""
|
||||
def __init__(self, uid: str, gid: Optional[str], raw_message: str,id:str):
|
||||
self.raw_message = raw_message
|
||||
self._processed = False
|
||||
self.response: Optional[str] = None
|
||||
# 核心服务实例化
|
||||
self.chat_manager = file_M.ChatManager()
|
||||
self.user = usermod.User(user_id=uid)
|
||||
self.rebot_id = id
|
||||
|
||||
# 动态加载数据
|
||||
if gid:
|
||||
self.group = usermod.Group(group_id=gid)
|
||||
self.group.current_user = self.user
|
||||
else:
|
||||
self.group = None
|
||||
|
||||
@dataclass
|
||||
class PluginPermission:
|
||||
access_private: bool = False # 允许处理私聊消息
|
||||
access_group: bool = True # 允许处理群消息
|
||||
read_history: bool = False # 允许读取历史记录
|
||||
|
||||
from pathlib import Path
|
||||
import toml
|
||||
import os
|
||||
|
||||
class BasePlugin(ABC):
|
||||
def __init__(self, ctx: MessageContext):
|
||||
self.ctx = ctx
|
||||
self._config_dir = self._get_plugin_config_path()
|
||||
self._config_manager = file_M.ConfigManager(self._config_dir)
|
||||
|
||||
def _get_plugin_resource(self, resource_path: str) -> bytes:
|
||||
plugin_name = self.__class__.__name__.lower()
|
||||
zip_path = Path("plugins") / f"{plugin_name}.zip"
|
||||
|
||||
if not zip_path.exists():
|
||||
raise FileNotFoundError(f"插件ZIP包不存在: {zip_path}")
|
||||
|
||||
# 遍历可能的ZIP内路径
|
||||
possible_paths = (
|
||||
resource_path, # 直接路径(config.toml)
|
||||
f"{plugin_name}/{resource_path}" # 插件子目录路径(myplugin/config.toml)
|
||||
)
|
||||
|
||||
with zipfile.ZipFile(zip_path, 'r') as zf:
|
||||
for path in possible_paths:
|
||||
if path in zf.namelist():
|
||||
return zf.read(path)
|
||||
|
||||
raise FileNotFoundError(f"文件 '{resource_path}' 不在ZIP包中")
|
||||
|
||||
def _ensure_config_exists(self):
|
||||
"""确保配置文件存在(不存在时从ZIP复制默认配置)"""
|
||||
config_file = Path(self._config_dir) / "config.toml"
|
||||
|
||||
# 外部配置已存在则直接返回
|
||||
if config_file.exists():
|
||||
return
|
||||
|
||||
# 尝试从ZIP复制默认配置
|
||||
try:
|
||||
|
||||
for ext in ['.toml', '.json', '.yaml']:
|
||||
try:
|
||||
config_data = self._get_plugin_resource(f"config{ext}")
|
||||
|
||||
# 确保配置目录存在
|
||||
config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 统一保存为.toml格式(或保持原格式)
|
||||
config_file.write_bytes(config_data)
|
||||
print(f"✅ 已将默认配置复制到: {config_file}")
|
||||
break
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
else:
|
||||
print(f"⚠️ 插件包内未找到默认配置文件")
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化配置失败: {str(e)}")
|
||||
|
||||
@property
|
||||
def config(self) -> dict:
|
||||
"""始终读取外部配置文件(确保已通过_ensure_config_exists初始化)"""
|
||||
|
||||
try:
|
||||
config_data = self._config_manager.load_config("config")
|
||||
except:
|
||||
self._ensure_config_exists()
|
||||
config_data = self._config_manager.load_config("config")
|
||||
return config_data or {}
|
||||
|
||||
def _get_plugin_config_path(self) -> str:
|
||||
"""获取插件配置目录路径(保持原有逻辑)"""
|
||||
plugin_name = self.__class__.__name__.lower()
|
||||
return str(Path("config") / plugin_name)
|
||||
|
||||
def save_config(self, config_dict: dict) -> bool:
|
||||
"""保存配置到外部目录(与之前逻辑一致)"""
|
||||
self._config_manager.update_config({"config": config_dict})
|
||||
return self._config_manager.save_config()
|
119
src/modules/user_modules.py
Normal file
119
src/modules/user_modules.py
Normal file
@ -0,0 +1,119 @@
|
||||
import json
|
||||
import requests
|
||||
import time
|
||||
import src.file_store_api as filer
|
||||
|
||||
config_m = filer.ConfigManager()
|
||||
url = config_m.load_config("config")
|
||||
mainurl = url.get("app").get("send_url")
|
||||
|
||||
class User:
|
||||
def __init__(self, user_id,url = mainurl):
|
||||
self.user_id = user_id
|
||||
self.url = url
|
||||
self.nickname = None
|
||||
self.get_name()
|
||||
self.messages = []
|
||||
self.signal = True
|
||||
self.db = filer.ChatManager()
|
||||
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
获取用户的名称
|
||||
"""
|
||||
try:
|
||||
response = requests.post('{0}/ArkSharePeer'.format(self.url),json={'user_id':self.user_id})
|
||||
except:
|
||||
return 0
|
||||
if response.status_code == 200:
|
||||
# 解析返回的JSON数据
|
||||
response_data = response.json()
|
||||
# 检查是否有错误信息
|
||||
if response_data.get("status") == "ok":
|
||||
# 获取用户卡片信息
|
||||
ark_json = response_data.get("data", {})
|
||||
ark_msg_str = ark_json.get("arkMsg", "{}")
|
||||
try:
|
||||
ark_json = json.loads(ark_msg_str) # 字符串转字典
|
||||
except json.JSONDecodeError:
|
||||
ark_json = {}
|
||||
user_nick = ark_json.get("meta", {}).get("contact", {}).get("nickname")
|
||||
self.nickname = user_nick
|
||||
else:
|
||||
print(f"请求失败,错误信息: {response_data.get('errMsg')}")
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
|
||||
|
||||
def set_input_status(self, status):
|
||||
payload = json.dumps({
|
||||
"user_id": self.user_id,
|
||||
"event_type": status
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
while self.signal:
|
||||
print(f"刷新 {self.nickname} 的输入状态为: {status}")
|
||||
requests.request("POST","{0}/set_input_status".format(self.url), headers=headers, data=payload)
|
||||
time.sleep(0.5)
|
||||
|
||||
def send_message(self, message):
|
||||
requests.post(url='{0}/send_private_msg'.format(self.url), json={'user_id':self.user_id, 'message':message})
|
||||
self.db.save_private_message(self,role = 'assistant',content=message)#保存发送的消息
|
||||
|
||||
class Group:
|
||||
def __init__(self, group_id,url = mainurl,user=None,users=None):
|
||||
self.url = url
|
||||
self.group_id = group_id
|
||||
self.current_user = user
|
||||
self.nickname = None
|
||||
self.get_group_name()
|
||||
self.users = users
|
||||
self.get_group_users()
|
||||
self.messages =[]
|
||||
self.db = filer.ChatManager()
|
||||
|
||||
def get_group_name(self):
|
||||
"""
|
||||
获取群组的名称
|
||||
"""
|
||||
try:
|
||||
response = requests.post('{0}/ArkSharePeer'.format(self.url), json={'group_id': self.group_id})
|
||||
except:
|
||||
return 0
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if response_data.get("status") == "ok":
|
||||
ark_json = response_data.get("data", {})
|
||||
ark_msg_str = ark_json.get("arkJson", "{}")
|
||||
try:
|
||||
ark_json = json.loads(ark_msg_str)
|
||||
except json.JSONDecodeError:
|
||||
ark_json = {}
|
||||
group_name = ark_json.get("meta", {}).get("contact", {}).get("nickname")
|
||||
self.nickname = group_name
|
||||
else:
|
||||
print(f"请求失败,错误信息: {response_data.get('errMsg')}")
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
|
||||
def get_group_users(self):
|
||||
try:
|
||||
response = requests.post('{0}/get_group_member_list'.format(self.url), json={'group_id':self.group_id,'no_cache': False})
|
||||
except:
|
||||
return 0
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if response_data.get("status") == "ok":
|
||||
group_users = response_data.get("data", {})
|
||||
self.users = group_users
|
||||
else:
|
||||
print(f"请求失败,错误信息: {response_data.get('errMsg')}")
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
|
||||
def send_message(self,message):
|
||||
requests.post(url='{0}/send_group_msg'.format(self.url), json={'group_id': self.group_id, 'message': message})
|
||||
self.db.save_group_message(self,'assistant',message, sender_id=0)#保存发送的消息
|
261
src/plugin_manager.py
Normal file
261
src/plugin_manager.py
Normal file
@ -0,0 +1,261 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import zipfile
|
||||
import tempfile
|
||||
import importlib.util
|
||||
import importlib.metadata
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Type, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
from packaging import version
|
||||
from contextlib import contextmanager
|
||||
|
||||
from src.file_store_api import ConfigManager as ConfigM
|
||||
from src.modules import plugin_modules as plugin_mod
|
||||
|
||||
BasePlugin = plugin_mod.BasePlugin
|
||||
MessageContext = plugin_mod.MessageContext
|
||||
|
||||
config_manager = ConfigM()
|
||||
plugin_config = config_manager.load_config("config").get("plugins", {})
|
||||
PLUGIN_DIR = os.path.join(*plugin_config.get("dir", ["plugins"]))
|
||||
# ---- 核心修改部分 ----
|
||||
|
||||
class DependencyManager:
|
||||
"""处理插件内嵌依赖的专用管理器"""
|
||||
def __init__(self):
|
||||
self._package_roots = set()
|
||||
|
||||
@contextmanager
|
||||
def isolated_import(self, package_root: str):
|
||||
"""上下文管理器,临时添加包路径"""
|
||||
if package_root not in sys.path:
|
||||
sys.path.insert(0, package_root)
|
||||
self._package_roots.add(package_root)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if package_root in sys.path: # 确保不重复移除
|
||||
sys.path.remove(package_root)
|
||||
self._package_roots.discard(package_root)
|
||||
|
||||
def _parse_requirement(self, req: str) -> tuple:
|
||||
"""解析依赖项字符串"""
|
||||
ops = {"==", ">=", "<=", ">", "<", "~=", "!="}
|
||||
for op in ops:
|
||||
if op in req:
|
||||
parts = req.split(op, 1)
|
||||
return parts[0].strip(), op, parts[1].strip()
|
||||
return req.strip(), None, None
|
||||
|
||||
def check_embedded_dependencies(self, plugin_dir: str) -> bool:
|
||||
"""检查插件自带的依赖是否可用"""
|
||||
packages_dir = os.path.join(plugin_dir, "packages")
|
||||
if not os.path.exists(packages_dir):
|
||||
return True
|
||||
|
||||
# 检查requirements.txt或dependencies.json
|
||||
req_file = os.path.join(plugin_dir, "requirements.txt")
|
||||
if not os.path.exists(req_file):
|
||||
req_file = os.path.join(plugin_dir, "dependencies.json")
|
||||
|
||||
if os.path.exists(req_file):
|
||||
return self._validate_dependencies(req_file, packages_dir)
|
||||
return True
|
||||
|
||||
def _validate_dependencies(self, req_file: str, packages_dir: str) -> bool:
|
||||
"""验证依赖是否满足"""
|
||||
requirements = self._load_requirements(req_file)
|
||||
if not requirements:
|
||||
return True
|
||||
|
||||
with self.isolated_import(packages_dir):
|
||||
for req in requirements:
|
||||
try:
|
||||
pkg, op, ver = self._parse_requirement(req)
|
||||
installed = importlib.metadata.version(pkg)
|
||||
|
||||
if op and ver:
|
||||
if not self._compare_versions(installed, op, ver):
|
||||
print(f"error: 依赖版本不匹配: 需要 {req},但找到 {installed}")
|
||||
return False
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
print(f"error: 依赖未找到: {req}")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _load_requirements(self, req_file: str) -> List[str]:
|
||||
"""加载依赖文件"""
|
||||
try:
|
||||
if req_file.endswith('.json'):
|
||||
with open(req_file) as f:
|
||||
data = json.load(f)
|
||||
return data.get("requirements", [])
|
||||
else:
|
||||
with open(req_file) as f:
|
||||
return [line.strip() for line in f if line.strip() and not line.startswith('#')]
|
||||
except Exception as e:
|
||||
print(f"error: 读取依赖文件失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def _compare_versions(self, installed: str, op: str, required: str) -> bool:
|
||||
"""比较版本号"""
|
||||
iv = version.parse(installed)
|
||||
rv = version.parse(required)
|
||||
|
||||
if op == "==": return iv == rv
|
||||
elif op == ">=": return iv >= rv
|
||||
elif op == "<=": return iv <= rv
|
||||
elif op == ">": return iv > rv
|
||||
elif op == "<": return iv < rv
|
||||
elif op == "~=": return iv >= rv and iv < version.parse(self._next_major(required))
|
||||
elif op == "!=": return iv != rv
|
||||
return True
|
||||
|
||||
def _next_major(self, ver: str) -> str:
|
||||
"""获取下一个主版本号"""
|
||||
parts = ver.split('.')
|
||||
if parts:
|
||||
try:
|
||||
parts[0] = str(int(parts[0]) + 1)
|
||||
return '.'.join(parts)
|
||||
except ValueError:
|
||||
pass
|
||||
return ver + ".0"
|
||||
|
||||
# ---- 修改后的插件管理器 ----
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self):
|
||||
self._plugins: Dict[str, Type[BasePlugin]] = {} # 插件类注册表
|
||||
self._active_instances: Dict[str, BasePlugin] = {} # 激活的插件实例
|
||||
self._hook_registry: Dict[str, List[Callable]] = {} # 兼容旧版钩子
|
||||
self._temp_dirs: List[str] = [] # 临时目录记录
|
||||
self._dependency_manager = DependencyManager() # 新增依赖管理器\
|
||||
self.scan_plugins
|
||||
|
||||
def scan_plugins(self):
|
||||
"""扫描并加载所有ZIP插件(保持原接口不变)"""
|
||||
if not os.path.exists(PLUGIN_DIR):
|
||||
os.makedirs(PLUGIN_DIR, exist_ok=True)
|
||||
|
||||
for item in Path(PLUGIN_DIR).glob("**/*.zip"):
|
||||
self.load_plugin(str(item))
|
||||
|
||||
def load_plugin(self, zip_path: str) -> bool:
|
||||
"""动态加载ZIP格式插件(支持内嵌依赖)"""
|
||||
try:
|
||||
# 解压到临时目录(保持原逻辑)
|
||||
temp_dir = tempfile.mkdtemp(prefix=f"plugin_{Path(zip_path).stem}_")
|
||||
self._temp_dirs.append(temp_dir)
|
||||
self._extract_zip(zip_path, temp_dir)
|
||||
|
||||
# 检查并加载内嵌依赖(新增功能)
|
||||
if not self._load_embedded_dependencies(temp_dir):
|
||||
print(f"error: 依赖检查失败,跳过插件: {Path(zip_path).name}")
|
||||
return False
|
||||
|
||||
# 动态导入主模块(保持原逻辑)
|
||||
return self._import_plugin_module(temp_dir, Path(zip_path).stem)
|
||||
|
||||
except Exception as e:
|
||||
print(f"error: 加载插件失败 {Path(zip_path).name}: {str(e)}")
|
||||
return False
|
||||
|
||||
def _extract_zip(self, zip_path: str, target_dir: str):
|
||||
"""解压ZIP文件(抽取为独立方法)"""
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(target_dir)
|
||||
|
||||
def _load_embedded_dependencies(self, plugin_dir: str) -> bool:
|
||||
"""加载插件内嵌的依赖包(新增方法)"""
|
||||
# 检查依赖是否满足
|
||||
if not self._dependency_manager.check_embedded_dependencies(plugin_dir):
|
||||
return False
|
||||
|
||||
# 如果有packages目录,添加到导入路径
|
||||
packages_dir = os.path.join(plugin_dir, "packages")
|
||||
if os.path.exists(packages_dir):
|
||||
sys.path.insert(0, packages_dir)
|
||||
print(f"已加载插件内嵌依赖: {packages_dir}")
|
||||
|
||||
return True
|
||||
|
||||
def _import_plugin_module(self, plugin_dir: str, plugin_name: str) -> bool:
|
||||
"""导入插件主模块(重构为独立方法)"""
|
||||
entry_file = Path(plugin_dir) / "process.py"
|
||||
if not entry_file.exists():
|
||||
print(f"error: 插件入口文件不存在: {plugin_name}")
|
||||
return False
|
||||
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
f"plugins.{plugin_name}",
|
||||
str(entry_file)
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 自动注册插件类(保持原逻辑)
|
||||
return self._register_plugin_classes(module, plugin_name)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 导入插件模块失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def _register_plugin_classes(self, module, plugin_name: str) -> bool:
|
||||
"""注册插件类(抽取为独立方法)"""
|
||||
plugin_registered = False
|
||||
for name, obj in module.__dict__.items():
|
||||
if isinstance(obj, type) and issubclass(obj, BasePlugin) and obj != BasePlugin:
|
||||
self._plugins[plugin_name] = obj
|
||||
print(f"✅ 已注册插件类: {plugin_name}::{name}")
|
||||
plugin_registered = True
|
||||
|
||||
# 兼容旧版钩子注册(保持原逻辑)
|
||||
if hasattr(module, "register_hooks"):
|
||||
module.register_hooks(self)
|
||||
print(f"🔄 已注册旧版钩子: {plugin_name}")
|
||||
|
||||
return plugin_registered
|
||||
|
||||
def process_message(self, uid: str, gid: Optional[str], message: str) -> str:
|
||||
"""主消息处理入口(保持完全兼容)"""
|
||||
ctx = MessageContext(uid, gid, message)
|
||||
|
||||
# 优先执行新版插件流程
|
||||
for name, plugin_cls in self._plugins.items():
|
||||
try:
|
||||
plugin = plugin_cls(ctx)
|
||||
if result := plugin.process():
|
||||
ctx.response = result
|
||||
ctx.intercepted = True
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"⚠️ 插件错误 {name}: {str(e)}")
|
||||
|
||||
# 如果没有被拦截,运行旧版钩子
|
||||
if not ctx.intercepted:
|
||||
for hook in self._hook_registry.get("on_message", []):
|
||||
hook(ctx)
|
||||
|
||||
return ctx.response or "ok"
|
||||
|
||||
def register_hook(self, hook_name: str):
|
||||
"""兼容旧版钩子注册(装饰器模式,保持原样)"""
|
||||
def decorator(func):
|
||||
if hook_name not in self._hook_registry:
|
||||
self._hook_registry[hook_name] = []
|
||||
self._hook_registry[hook_name].append(func)
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源(保持原逻辑)"""
|
||||
for temp_dir in self._temp_dirs:
|
||||
if os.path.exists(temp_dir):
|
||||
import shutil
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
Reference in New Issue
Block a user