Files
MineNasAI/src/minenasai/gateway/router.py

204 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""智能路由模块
评估任务复杂度,决定处理方式
"""
from __future__ import annotations
import re
from typing import Any
from minenasai.core import get_logger
from minenasai.gateway.protocol import TaskComplexity
logger = get_logger(__name__)
# 复杂任务关键词
COMPLEX_KEYWORDS = [
# 编程相关
"实现", "开发", "编写", "重构", "优化", "调试", "修复bug",
"创建项目", "搭建", "部署", "迁移",
# 文件操作
"批量", "遍历", "递归", "所有文件",
# 分析
"分析代码", "审查", "review", "架构设计",
]
# 简单任务关键词
SIMPLE_KEYWORDS = [
# 查询
"是什么", "什么是", "解释", "说明", "介绍",
"几点", "天气", "日期", "时间",
# 状态
"状态", "运行情况", "磁盘", "内存", "CPU",
# 简单操作
"打开", "关闭", "启动", "停止",
]
# 需要工具的关键词
TOOL_KEYWORDS = [
"查看", "读取", "列出", "搜索", "查找",
"执行", "运行", "计算",
"文件", "目录", "路径",
]
# 命令前缀
COMMAND_PREFIXES = {
"/快速": TaskComplexity.SIMPLE,
"/简单": TaskComplexity.SIMPLE,
"/深度": TaskComplexity.COMPLEX,
"/复杂": TaskComplexity.COMPLEX,
"/tui": TaskComplexity.COMPLEX,
"/TUI": TaskComplexity.COMPLEX,
}
class SmartRouter:
"""智能路由器
基于启发式规则评估任务复杂度
"""
def __init__(
self,
simple_max_length: int = 100,
complex_min_length: int = 500,
) -> None:
"""初始化路由器
Args:
simple_max_length: 简单任务的最大长度
complex_min_length: 复杂任务的最小长度
"""
self.simple_max_length = simple_max_length
self.complex_min_length = complex_min_length
def evaluate(self, content: str, metadata: dict[str, Any] | None = None) -> dict[str, Any]:
"""评估任务复杂度
Args:
content: 用户输入内容
metadata: 额外元数据(如历史上下文)
Returns:
评估结果,包含 complexity, confidence, reason, suggested_handler
"""
content = content.strip()
metadata = metadata or {}
# 检查命令前缀覆盖
for prefix, complexity in COMMAND_PREFIXES.items():
if content.startswith(prefix):
return {
"complexity": complexity,
"confidence": 1.0,
"reason": f"用户指定 {prefix}",
"suggested_handler": self._get_handler(complexity),
"content": content[len(prefix):].strip(),
}
# 计算各项得分
scores = {
"simple": 0.0,
"medium": 0.0,
"complex": 0.0,
}
# 长度评估
length = len(content)
if length <= self.simple_max_length:
scores["simple"] += 0.3
elif length >= self.complex_min_length:
scores["complex"] += 0.3
else:
scores["medium"] += 0.2
# 关键词评估
content_lower = content.lower()
simple_matches = sum(1 for kw in SIMPLE_KEYWORDS if kw in content_lower)
complex_matches = sum(1 for kw in COMPLEX_KEYWORDS if kw in content_lower)
tool_matches = sum(1 for kw in TOOL_KEYWORDS if kw in content_lower)
if simple_matches > 0:
scores["simple"] += min(0.4, simple_matches * 0.15)
if complex_matches > 0:
scores["complex"] += min(0.5, complex_matches * 0.2)
if tool_matches > 0:
scores["medium"] += min(0.3, tool_matches * 0.1)
# 问号检测(通常是简单问题)
if content.endswith("?") or content.endswith(""):
scores["simple"] += 0.1
# 代码块检测
if "```" in content or re.search(r"def\s+\w+|class\s+\w+|function\s+\w+", content):
scores["complex"] += 0.3
# 多步骤检测
if re.search(r"\d+\.\s|第[一二三四五六七八九十]+步|首先.*然后|step\s*\d+", content_lower):
scores["complex"] += 0.2
# 确定复杂度
max_score = max(scores.values())
if scores["complex"] == max_score and scores["complex"] >= 0.3:
complexity = TaskComplexity.COMPLEX
elif scores["simple"] == max_score and scores["simple"] >= 0.3:
complexity = TaskComplexity.SIMPLE
else:
complexity = TaskComplexity.MEDIUM
# 计算置信度
total = sum(scores.values())
confidence = max_score / total if total > 0 else 0.5
# 生成原因
reasons = []
if length <= self.simple_max_length:
reasons.append("短文本")
elif length >= self.complex_min_length:
reasons.append("长文本")
if simple_matches:
reasons.append(f"简单关键词x{simple_matches}")
if complex_matches:
reasons.append(f"复杂关键词x{complex_matches}")
if tool_matches:
reasons.append(f"工具关键词x{tool_matches}")
return {
"complexity": complexity,
"confidence": round(confidence, 2),
"reason": ", ".join(reasons) if reasons else "综合评估",
"suggested_handler": self._get_handler(complexity),
"scores": scores,
"content": content,
}
def _get_handler(self, complexity: TaskComplexity) -> str:
"""获取建议的处理器
Args:
complexity: 任务复杂度
Returns:
处理器名称
"""
handlers = {
TaskComplexity.SIMPLE: "quick_response",
TaskComplexity.MEDIUM: "agent_execute",
TaskComplexity.COMPLEX: "webtui_redirect",
}
return handlers.get(complexity, "agent_execute")
# 全局路由器实例
_router: SmartRouter | None = None
def get_router() -> SmartRouter:
"""获取全局路由器实例"""
global _router
if _router is None:
_router = SmartRouter()
return _router