diff --git a/src/minenasai/__init__.py b/src/minenasai/__init__.py index 830a254..44c6632 100644 --- a/src/minenasai/__init__.py +++ b/src/minenasai/__init__.py @@ -1,3 +1,36 @@ -"""MineNASAI - 基于NAS的智能个人AI助理""" +"""MineNASAI - 基于NAS的智能个人AI助理 + +MineNASAI 是一个部署在家用 NAS 上的智能 AI 助理系统,支持: +- 多渠道接入(飞书、企业微信等) +- 多 LLM 后端(OpenAI、Anthropic、Gemini、国产模型等) +- 智能任务路由与复杂度评估 +- 安全的工具执行与权限管理 +- Web 终端管理界面 + +模块结构: + - core: 核心基础设施(配置、日志、缓存、监控、数据库) + - agent: AI Agent 运行时和工具管理 + - gateway: 消息网关和渠道接入 + - llm: 多模型 LLM 客户端 + - scheduler: 定时任务调度 + - webtui: Web 管理界面 + +快速开始: + >>> from minenasai.core import get_settings, setup_logging + >>> from minenasai.llm import get_llm_manager + >>> from minenasai.agent import get_agent_runtime + >>> + >>> # 初始化 + >>> settings = get_settings() + >>> setup_logging(settings) + >>> + >>> # 获取 LLM 管理器 + >>> llm_manager = get_llm_manager() + >>> + >>> # 获取 Agent 运行时 + >>> agent = await get_agent_runtime() +""" __version__ = "0.1.0" +__author__ = "MineNASAI Team" +__license__ = "MIT" diff --git a/src/minenasai/agent/__init__.py b/src/minenasai/agent/__init__.py index 9524e5b..7a45109 100644 --- a/src/minenasai/agent/__init__.py +++ b/src/minenasai/agent/__init__.py @@ -1,21 +1,53 @@ -"""Agent 模块 +"""Agent 模块 - AI Agent 运行时和工具管理 -提供 Agent 运行时、工具管理、会话管理 +本模块提供 AI Agent 的核心功能: + +Agent 运行时 (runtime): + - AgentRuntime: Agent 执行引擎,支持多轮对话和工具调用 + - get_agent_runtime(): 获取 Agent 运行时单例 + +权限管理 (permissions): + - PermissionManager: 工具权限管理器 + - DangerLevel: 危险等级枚举(SAFE/LOW/MEDIUM/HIGH/CRITICAL) + - 支持细粒度的路径和参数权限控制 + - 高危操作需要用户确认 + +工具注册 (tool_registry): + - ToolRegistry: 工具注册表 + - @tool 装饰器: 快速注册工具函数 + - register_builtin_tools(): 注册内置工具 + +内置工具 (tools): + - read_file: 读取文件内容 + - list_directory: 列出目录文件 + - python_eval: 安全的 Python 表达式求值 + +使用示例: + >>> from minenasai.agent import get_agent_runtime, tool + >>> + >>> # 注册自定义工具 + >>> @tool(name="my_tool", description="自定义工具") + ... async def my_tool(param: str) -> str: + ... return f"处理: {param}" + >>> + >>> # 运行 Agent + >>> agent = await get_agent_runtime() + >>> response = await agent.run("帮我处理这个任务") """ -from minenasai.agent.runtime import AgentRuntime, get_agent_runtime -from minenasai.agent.tools.basic import get_basic_tools from minenasai.agent.permissions import ( DangerLevel, PermissionManager, get_permission_manager, ) +from minenasai.agent.runtime import AgentRuntime, get_agent_runtime from minenasai.agent.tool_registry import ( ToolRegistry, get_tool_registry, register_builtin_tools, tool, ) +from minenasai.agent.tools.basic import get_basic_tools __all__ = [ "AgentRuntime", diff --git a/src/minenasai/agent/permissions.py b/src/minenasai/agent/permissions.py index 4e714ac..f99faa8 100644 --- a/src/minenasai/agent/permissions.py +++ b/src/minenasai/agent/permissions.py @@ -6,17 +6,21 @@ from __future__ import annotations import asyncio +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Callable, Coroutine +from enum import StrEnum +from typing import Any from minenasai.core import get_logger, get_settings logger = get_logger(__name__) -class DangerLevel(str, Enum): - """危险等级""" +class DangerLevel(StrEnum): + """危险等级 + + 定义工具操作的风险级别,用于权限控制和确认流程。 + """ SAFE = "safe" # 只读操作,无风险 LOW = "low" # 低风险,写入工作目录 @@ -25,8 +29,11 @@ class DangerLevel(str, Enum): CRITICAL = "critical" # 极高危,系统级操作 -class ConfirmationStatus(str, Enum): - """确认状态""" +class ConfirmationStatus(StrEnum): + """确认状态 + + 高危操作确认请求的状态。 + """ PENDING = "pending" APPROVED = "approved" @@ -168,14 +175,14 @@ class PermissionManager: self, tool_name: str, params: dict[str, Any] | None = None, - user_id: str | None = None, + _user_id: str | None = None, # 预留:将来用于基于用户的权限检查 ) -> tuple[bool, str]: """检查工具执行权限 Args: tool_name: 工具名称 params: 执行参数 - user_id: 用户 ID + _user_id: 用户 ID(预留参数,将来用于基于用户的权限检查) Returns: (是否允许, 原因) @@ -375,7 +382,35 @@ _permission_manager: PermissionManager | None = None def get_permission_manager() -> PermissionManager: - """获取全局权限管理器""" + """获取全局权限管理器单例 + + 返回全局唯一的权限管理器实例,用于检查工具执行权限 + 和管理高危操作确认流程。 + + Returns: + PermissionManager: 全局权限管理器 + + Example: + >>> pm = get_permission_manager() + >>> + >>> # 检查权限 + >>> allowed, reason = pm.check_permission( + ... "write_file", + ... {"path": "/tmp/test.txt"} + ... ) + >>> + >>> # 检查是否需要确认 + >>> if pm.requires_confirmation("delete_file"): + ... request_id = await pm.request_confirmation( + ... "delete_file", + ... {"path": "/important/file.txt"} + ... ) + ... # 等待用户确认... + + Note: + 权限管理器会根据配置文件中的 security 设置 + 决定哪些危险等级的操作需要用户确认。 + """ global _permission_manager if _permission_manager is None: _permission_manager = PermissionManager() diff --git a/src/minenasai/agent/runtime.py b/src/minenasai/agent/runtime.py index 91ba7b2..231b728 100644 --- a/src/minenasai/agent/runtime.py +++ b/src/minenasai/agent/runtime.py @@ -6,11 +6,12 @@ from __future__ import annotations import time -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any from minenasai.core import get_audit_logger, get_logger, get_settings from minenasai.core.session_store import get_session_store -from minenasai.llm import LLMManager, get_llm_manager +from minenasai.llm import get_llm_manager from minenasai.llm.base import Message, Provider, ToolDefinition logger = get_logger(__name__) diff --git a/src/minenasai/agent/tool_registry.py b/src/minenasai/agent/tool_registry.py index 3202572..2179b43 100644 --- a/src/minenasai/agent/tool_registry.py +++ b/src/minenasai/agent/tool_registry.py @@ -6,11 +6,12 @@ from __future__ import annotations import inspect +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field -from typing import Any, Callable, Coroutine, get_type_hints +from typing import Any, get_type_hints -from minenasai.core import get_logger from minenasai.agent.permissions import DangerLevel, ToolPermission, get_permission_manager +from minenasai.core import get_logger from minenasai.llm.base import ToolDefinition logger = get_logger(__name__) diff --git a/src/minenasai/agent/tools/__init__.py b/src/minenasai/agent/tools/__init__.py index f552b1a..d280062 100644 --- a/src/minenasai/agent/tools/__init__.py +++ b/src/minenasai/agent/tools/__init__.py @@ -1,9 +1,9 @@ """内置工具集""" from minenasai.agent.tools.basic import ( - read_file_tool, list_directory_tool, python_eval_tool, + read_file_tool, ) __all__ = [ diff --git a/src/minenasai/agent/tools/basic.py b/src/minenasai/agent/tools/basic.py index ad50b3b..d986f82 100644 --- a/src/minenasai/agent/tools/basic.py +++ b/src/minenasai/agent/tools/basic.py @@ -5,8 +5,6 @@ from __future__ import annotations -import os -from pathlib import Path from typing import Any from minenasai.core import get_logger diff --git a/src/minenasai/core/__init__.py b/src/minenasai/core/__init__.py index f982af2..95c8c38 100644 --- a/src/minenasai/core/__init__.py +++ b/src/minenasai/core/__init__.py @@ -1,6 +1,42 @@ -"""核心模块 +"""核心模块 - 提供系统基础设施 -提供配置管理、日志系统、数据库、监控、缓存等基础功能 +本模块包含 MineNASAI 的核心基础功能: + +配置管理 (config): + - Settings: 全局配置数据类 + - get_settings(): 获取配置单例 + - load_config(): 从文件加载配置 + +日志系统 (logging): + - setup_logging(): 初始化日志系统 + - get_logger(): 获取模块日志记录器 + - AuditLogger: 审计日志记录器 + +监控系统 (monitoring): + - SystemMetrics: 系统指标收集 + - HealthChecker: 健康检查器 + - setup_monitoring(): 初始化监控 + +缓存系统 (cache): + - MemoryCache: 内存缓存,支持 TTL 和 LRU 淘汰 + - RateLimiter: 令牌桶限流器 + +数据库 (database): + - Database: SQLite 异步数据库封装 + - 支持会话、消息、任务的持久化 + +使用示例: + >>> from minenasai.core import get_settings, setup_logging, get_logger + >>> + >>> # 初始化配置和日志 + >>> settings = get_settings() + >>> setup_logging(settings) + >>> logger = get_logger(__name__) + >>> + >>> # 使用缓存 + >>> from minenasai.core import get_response_cache + >>> cache = get_response_cache() + >>> await cache.set("key", "value", ttl=300) """ from minenasai.core.cache import ( diff --git a/src/minenasai/core/cache.py b/src/minenasai/core/cache.py index 8e4f8f1..e3ff80f 100644 --- a/src/minenasai/core/cache.py +++ b/src/minenasai/core/cache.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +import contextlib import hashlib import time from dataclasses import dataclass, field @@ -71,10 +72,8 @@ class MemoryCache(Generic[T]): """停止后台清理任务""" if self._cleanup_task: self._cleanup_task.cancel() - try: + with contextlib.suppress(asyncio.CancelledError): await self._cleanup_task - except asyncio.CancelledError: - pass self._cleanup_task = None async def _cleanup_loop(self) -> None: diff --git a/src/minenasai/core/config.py b/src/minenasai/core/config.py index ebf18ac..1f47c09 100644 --- a/src/minenasai/core/config.py +++ b/src/minenasai/core/config.py @@ -296,7 +296,25 @@ _settings: Settings | None = None def get_settings() -> Settings: - """获取全局配置实例""" + """获取全局配置单例 + + 返回全局唯一的配置实例。首次调用时从配置文件加载, + 后续调用返回缓存的实例。 + + Returns: + Settings: 全局配置对象 + + Example: + >>> settings = get_settings() + >>> print(settings.app.name) + MineNASAI + >>> print(settings.llm.default_model) + claude-sonnet-4-20250514 + + Note: + 配置加载优先级:环境变量 > 配置文件 > 默认值 + 使用 reset_settings() 可重置配置(仅用于测试) + """ global _settings if _settings is None: _settings = load_config() @@ -304,6 +322,13 @@ def get_settings() -> Settings: def reset_settings() -> None: - """重置全局配置(用于测试)""" + """重置全局配置单例 + + 清除缓存的配置实例,下次调用 get_settings() 时将重新加载配置。 + 主要用于测试场景,生产环境慎用。 + + Warning: + 此函数会清除所有已加载的配置,可能影响正在运行的组件。 + """ global _settings _settings = None diff --git a/src/minenasai/core/monitoring.py b/src/minenasai/core/monitoring.py index 0efce27..92767cf 100644 --- a/src/minenasai/core/monitoring.py +++ b/src/minenasai/core/monitoring.py @@ -11,10 +11,11 @@ from __future__ import annotations import asyncio import time import traceback +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Callable, Coroutine +from typing import Any from fastapi import FastAPI, Request, Response from fastapi.responses import JSONResponse @@ -161,7 +162,7 @@ class HealthChecker: result.latency_ms = (time.time() - start_time) * 1000 self._results[name] = result return result - except asyncio.TimeoutError: + except TimeoutError: result = ComponentHealth( name=name, status=HealthStatus.UNHEALTHY, diff --git a/src/minenasai/core/session_store.py b/src/minenasai/core/session_store.py index 8185cb8..4db60a8 100644 --- a/src/minenasai/core/session_store.py +++ b/src/minenasai/core/session_store.py @@ -7,8 +7,9 @@ from __future__ import annotations import json import time +from collections.abc import Iterator from pathlib import Path -from typing import Any, Iterator +from typing import Any from minenasai.core.config import expand_path diff --git a/src/minenasai/gateway/__init__.py b/src/minenasai/gateway/__init__.py index db4e825..4299348 100644 --- a/src/minenasai/gateway/__init__.py +++ b/src/minenasai/gateway/__init__.py @@ -1,6 +1,51 @@ -"""Gateway 模块 +"""Gateway 模块 - 消息网关和渠道接入 -提供 WebSocket 服务、通讯渠道接入、智能路由 +本模块提供消息网关服务,支持多渠道消息接入: + +WebSocket 服务 (server): + - 提供实时双向通信 + - 支持多客户端连接管理 + - 心跳检测和自动重连 + +渠道接入 (channels): + - FeishuChannel: 飞书机器人接入 + - WeWorkChannel: 企业微信机器人接入 + - 统一的消息格式转换 + +智能路由 (router): + - SmartRouter: 智能任务路由器 + - 自动评估消息复杂度 + - 根据复杂度选择处理策略 + +协议定义 (protocol): + - ChatMessage: 聊天消息模型 + - MessageType: 消息类型枚举 + - ChannelType: 渠道类型枚举 + +使用示例: + >>> from minenasai.gateway.router import SmartRouter + >>> from minenasai.gateway.protocol.schema import ChatMessage + >>> + >>> # 创建路由器 + >>> router = SmartRouter() + >>> + >>> # 评估任务复杂度 + >>> message = ChatMessage(content="帮我写一个 Python 爬虫") + >>> complexity = router.route(message) """ -__all__ = [] +from minenasai.gateway.protocol.schema import ( + ChannelType, + ChatMessage, + MessageType, + TaskComplexity, +) +from minenasai.gateway.router import SmartRouter + +__all__ = [ + "ChatMessage", + "MessageType", + "ChannelType", + "TaskComplexity", + "SmartRouter", +] diff --git a/src/minenasai/gateway/channels/feishu.py b/src/minenasai/gateway/channels/feishu.py index 2d07e70..4c405eb 100644 --- a/src/minenasai/gateway/channels/feishu.py +++ b/src/minenasai/gateway/channels/feishu.py @@ -5,8 +5,6 @@ from __future__ import annotations -import hashlib -import hmac import time from typing import Any @@ -33,8 +31,12 @@ class FeishuChannel(BaseChannel): self._tenant_access_token: str | None = None self._token_expires: float = 0 - async def verify_signature(self, request: Any) -> bool: - """验证飞书签名""" + async def verify_signature(self, _request: Any) -> bool: + """验证飞书签名 + + Args: + _request: HTTP 请求对象(待实现时使用) + """ # TODO: 实现签名验证 return True @@ -142,10 +144,10 @@ class FeishuChannel(BaseChannel): # 判断是用户还是群聊 chat_id = kwargs.get("chat_id") if chat_id: - url = f"https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=chat_id" + url = "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=chat_id" receive_id = chat_id else: - url = f"https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=user_id" + url = "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=user_id" receive_id = peer_id import json diff --git a/src/minenasai/gateway/channels/wework.py b/src/minenasai/gateway/channels/wework.py index b9df4fc..877aea1 100644 --- a/src/minenasai/gateway/channels/wework.py +++ b/src/minenasai/gateway/channels/wework.py @@ -5,7 +5,6 @@ from __future__ import annotations -import hashlib import time from typing import Any @@ -33,10 +32,13 @@ class WeworkChannel(BaseChannel): self._access_token: str | None = None self._token_expires: float = 0 - async def verify_signature(self, request: Any) -> bool: + async def verify_signature(self, _request: Any) -> bool: """验证企业微信签名 - 企业微信使用 SHA1 签名验证 + 企业微信使用 SHA1 签名验证。 + + Args: + _request: HTTP 请求对象(待实现时使用) """ # TODO: 实现签名验证 # signature = sorted([self.token, timestamp, nonce, echostr]) @@ -105,9 +107,16 @@ class WeworkChannel(BaseChannel): peer_id: str, content: str, message_type: str = "text", - **kwargs: Any, + **_kwargs: Any, # 预留:将来支持更多消息参数 ) -> bool: - """发送企业微信消息""" + """发送企业微信消息 + + Args: + peer_id: 接收者用户 ID + content: 消息内容 + message_type: 消息类型 + **_kwargs: 预留参数,将来支持更多消息选项 + """ access_token = await self._get_access_token() url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}" diff --git a/src/minenasai/gateway/protocol/schema.py b/src/minenasai/gateway/protocol/schema.py index ab3c852..be7716f 100644 --- a/src/minenasai/gateway/protocol/schema.py +++ b/src/minenasai/gateway/protocol/schema.py @@ -5,14 +5,17 @@ from __future__ import annotations -from enum import Enum +from enum import StrEnum from typing import Any from pydantic import BaseModel, Field -class MessageType(str, Enum): - """消息类型""" +class MessageType(StrEnum): + """消息类型 + + 定义 WebSocket 通信中的消息类型。 + """ # 客户端 -> 服务器 CHAT = "chat" # 普通聊天消息 @@ -32,8 +35,11 @@ class MessageType(str, Enum): CONFIRM = "confirm" # 确认请求/响应 -class ChannelType(str, Enum): - """渠道类型""" +class ChannelType(StrEnum): + """渠道类型 + + 定义消息来源的渠道类型。 + """ WEWORK = "wework" FEISHU = "feishu" @@ -41,8 +47,11 @@ class ChannelType(str, Enum): WEBSOCKET = "websocket" -class TaskComplexity(str, Enum): - """任务复杂度""" +class TaskComplexity(StrEnum): + """任务复杂度 + + 用于智能路由,决定任务的处理方式。 + """ SIMPLE = "simple" # 简单查询,直接回复 MEDIUM = "medium" # 中等任务,需要工具 diff --git a/src/minenasai/gateway/server.py b/src/minenasai/gateway/server.py index 6f7dd27..91ffa32 100644 --- a/src/minenasai/gateway/server.py +++ b/src/minenasai/gateway/server.py @@ -6,8 +6,9 @@ from __future__ import annotations import asyncio +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator +from typing import Any from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -152,10 +153,16 @@ async def list_sessions(agent_id: str | None = None) -> dict[str, Any]: async def handle_chat_message( websocket: WebSocket, - client_id: str, + _client_id: str, # 预留:将来用于客户端标识和会话管理 message: ChatMessage, ) -> None: - """处理聊天消息""" + """处理聊天消息 + + Args: + websocket: WebSocket 连接 + _client_id: 客户端 ID(预留,将来用于会话管理) + message: 聊天消息 + """ router = get_router() # 发送思考状态 diff --git a/src/minenasai/llm/__init__.py b/src/minenasai/llm/__init__.py index e5ee9fa..a11f7e7 100644 --- a/src/minenasai/llm/__init__.py +++ b/src/minenasai/llm/__init__.py @@ -1,6 +1,44 @@ -"""LLM 多模型客户端模块 +"""LLM 多模型客户端模块 - 统一的 AI API 接口 -支持多个 AI API 提供商的统一接口 +本模块提供统一的 LLM 客户端接口,支持多个 AI 提供商: + +支持的提供商: + - OpenAI (GPT-4, GPT-3.5) + - Anthropic (Claude 3) + - Google (Gemini) + - DeepSeek + - Moonshot (月之暗面) + - MiniMax + - Zhipu (智谱) + +核心类: + - BaseLLMClient: LLM 客户端基类 + - LLMManager: 多模型管理器,支持负载均衡和故障转移 + - Message: 统一消息格式 + - LLMResponse: 统一响应格式 + - ToolCall: 工具调用定义 + +特性: + - 统一的 API 接口 + - 自动重试和错误处理 + - 流式响应支持 + - 工具调用 (Function Calling) 支持 + - 响应缓存 + +使用示例: + >>> from minenasai.llm import get_llm_manager, Message + >>> + >>> # 获取 LLM 管理器 + >>> manager = get_llm_manager() + >>> + >>> # 发送消息 + >>> messages = [Message(role="user", content="你好")] + >>> response = await manager.chat(messages) + >>> print(response.content) + >>> + >>> # 流式响应 + >>> async for chunk in manager.stream(messages): + ... print(chunk.content, end="") """ from minenasai.llm.base import BaseLLMClient, LLMResponse, Message, ToolCall diff --git a/src/minenasai/llm/base.py b/src/minenasai/llm/base.py index 4b9a034..3097548 100644 --- a/src/minenasai/llm/base.py +++ b/src/minenasai/llm/base.py @@ -6,13 +6,17 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import AsyncIterator from dataclasses import dataclass, field -from enum import Enum -from typing import Any, AsyncIterator +from enum import StrEnum +from typing import Any -class Provider(str, Enum): - """LLM 提供商""" +class Provider(StrEnum): + """LLM 提供商 + + 支持的 AI API 提供商枚举。 + """ ANTHROPIC = "anthropic" # Claude OPENAI = "openai" # GPT diff --git a/src/minenasai/llm/clients/__init__.py b/src/minenasai/llm/clients/__init__.py index 7618bf1..9ec640f 100644 --- a/src/minenasai/llm/clients/__init__.py +++ b/src/minenasai/llm/clients/__init__.py @@ -1,12 +1,12 @@ """LLM 客户端实现""" from minenasai.llm.clients.anthropic import AnthropicClient -from minenasai.llm.clients.openai_compat import OpenAICompatClient from minenasai.llm.clients.deepseek import DeepSeekClient -from minenasai.llm.clients.zhipu import ZhipuClient +from minenasai.llm.clients.gemini import GeminiClient from minenasai.llm.clients.minimax import MiniMaxClient from minenasai.llm.clients.moonshot import MoonshotClient -from minenasai.llm.clients.gemini import GeminiClient +from minenasai.llm.clients.openai_compat import OpenAICompatClient +from minenasai.llm.clients.zhipu import ZhipuClient __all__ = [ "AnthropicClient", diff --git a/src/minenasai/llm/clients/anthropic.py b/src/minenasai/llm/clients/anthropic.py index 3eeadce..76d5098 100644 --- a/src/minenasai/llm/clients/anthropic.py +++ b/src/minenasai/llm/clients/anthropic.py @@ -3,7 +3,8 @@ from __future__ import annotations import json -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any import httpx @@ -74,7 +75,7 @@ class AnthropicClient(BaseLLMClient): max_tokens: int = 4096, temperature: float = 0.7, tools: list[ToolDefinition] | None = None, - **kwargs: Any, + **_kwargs: Any, # 预留:基类接口扩展参数 ) -> LLMResponse: """发送聊天请求""" client = await self._get_client() @@ -142,7 +143,7 @@ class AnthropicClient(BaseLLMClient): max_tokens: int = 4096, temperature: float = 0.7, tools: list[ToolDefinition] | None = None, - **kwargs: Any, + **_kwargs: Any, # 预留:基类接口扩展参数 ) -> AsyncIterator[StreamChunk]: """流式聊天""" client = await self._get_client() diff --git a/src/minenasai/llm/clients/gemini.py b/src/minenasai/llm/clients/gemini.py index 8cead36..16114dd 100644 --- a/src/minenasai/llm/clients/gemini.py +++ b/src/minenasai/llm/clients/gemini.py @@ -6,7 +6,8 @@ Gemini API 使用独特的接口格式 from __future__ import annotations import json -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any import httpx @@ -76,7 +77,7 @@ class GeminiClient(BaseLLMClient): max_tokens: int = 4096, temperature: float = 0.7, tools: list[ToolDefinition] | None = None, - **kwargs: Any, + **_kwargs: Any, # 预留:基类接口扩展参数 ) -> LLMResponse: """发送聊天请求""" client = await self._get_client() @@ -164,10 +165,10 @@ class GeminiClient(BaseLLMClient): system: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - tools: list[ToolDefinition] | None = None, - **kwargs: Any, + _tools: list[ToolDefinition] | None = None, # 暂不支持流式工具调用 + **_kwargs: Any, # 预留:基类接口扩展参数 ) -> AsyncIterator[StreamChunk]: - """流式聊天""" + """流式聊天(暂不支持工具调用)""" client = await self._get_client() model = model or self.default_model diff --git a/src/minenasai/llm/clients/minimax.py b/src/minenasai/llm/clients/minimax.py index 2b891b4..22d27f2 100644 --- a/src/minenasai/llm/clients/minimax.py +++ b/src/minenasai/llm/clients/minimax.py @@ -5,8 +5,6 @@ MiniMax 使用 OpenAI 兼容接口 from __future__ import annotations -from typing import Any - import httpx from minenasai.llm.base import Provider diff --git a/src/minenasai/llm/clients/openai_compat.py b/src/minenasai/llm/clients/openai_compat.py index fb001d2..346b28f 100644 --- a/src/minenasai/llm/clients/openai_compat.py +++ b/src/minenasai/llm/clients/openai_compat.py @@ -6,7 +6,8 @@ from __future__ import annotations import json -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any import httpx @@ -87,7 +88,7 @@ class OpenAICompatClient(BaseLLMClient): max_tokens: int = 4096, temperature: float = 0.7, tools: list[ToolDefinition] | None = None, - **kwargs: Any, + **_kwargs: Any, # 预留:基类接口扩展参数 ) -> LLMResponse: """发送聊天请求""" client = await self._get_client() @@ -154,7 +155,7 @@ class OpenAICompatClient(BaseLLMClient): max_tokens: int = 4096, temperature: float = 0.7, tools: list[ToolDefinition] | None = None, - **kwargs: Any, + **_kwargs: Any, # 预留:基类接口扩展参数 ) -> AsyncIterator[StreamChunk]: """流式聊天""" client = await self._get_client() diff --git a/src/minenasai/llm/manager.py b/src/minenasai/llm/manager.py index a6adfbc..ac82b8e 100644 --- a/src/minenasai/llm/manager.py +++ b/src/minenasai/llm/manager.py @@ -5,7 +5,8 @@ from __future__ import annotations -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any from minenasai.core import get_logger, get_settings from minenasai.llm.base import ( @@ -142,10 +143,10 @@ class LLMManager: client = self._create_client(provider) if client: self._clients[provider] = client - logger.info(f"LLM 客户端已初始化", provider=provider.display_name) + logger.info("LLM 客户端已初始化", provider=provider.display_name) self._initialized = True - logger.info(f"LLM Manager 初始化完成", providers=list(self._clients.keys())) + logger.info("LLM Manager 初始化完成", providers=list(self._clients.keys())) def get_client(self, provider: Provider) -> BaseLLMClient | None: """获取指定提供商的客户端""" @@ -334,7 +335,30 @@ _llm_manager: LLMManager | None = None def get_llm_manager() -> LLMManager: - """获取全局 LLM 管理器""" + """获取全局 LLM 管理器单例 + + 返回全局唯一的 LLM 管理器实例。管理器会根据配置自动初始化 + 可用的 LLM 客户端。 + + Returns: + LLMManager: 全局 LLM 管理器 + + Example: + >>> manager = get_llm_manager() + >>> # 发送消息 + >>> response = await manager.chat([ + ... Message(role="user", content="你好") + ... ]) + >>> print(response.content) + >>> + >>> # 流式响应 + >>> async for chunk in manager.stream(messages): + ... print(chunk.content, end="") + + Note: + 首次调用时会自动初始化所有配置了 API Key 的客户端。 + 支持的提供商取决于环境变量或配置文件中的 API Key 设置。 + """ global _llm_manager if _llm_manager is None: _llm_manager = LLMManager() diff --git a/src/minenasai/scheduler/__init__.py b/src/minenasai/scheduler/__init__.py index 87e4829..3d5727e 100644 --- a/src/minenasai/scheduler/__init__.py +++ b/src/minenasai/scheduler/__init__.py @@ -1,6 +1,43 @@ -"""定时任务调度模块 +"""定时任务调度模块 - Cron 任务调度 -提供 Cron 任务调度功能 +本模块提供类似 Linux cron 的定时任务调度功能: + +核心类: + - CronScheduler: 调度器,管理所有定时任务 + - CronJob: 定时任务定义 + +Cron 表达式: + 支持标准 5 字段 cron 表达式:分 时 日 月 周 + - * : 任意值 + - */n : 每隔 n + - n-m : 范围 + - n,m : 列表 + +预设表达式: + - @hourly : 每小时 + - @daily : 每天 + - @weekly : 每周 + - @monthly : 每月 + +使用示例: + >>> from minenasai.scheduler import get_scheduler + >>> + >>> # 获取调度器 + >>> scheduler = get_scheduler() + >>> + >>> # 添加定时任务 + >>> async def my_task(): + ... print("任务执行中...") + >>> + >>> scheduler.add_job( + ... job_id="my_job", + ... cron_expr="0 9 * * *", # 每天 9:00 + ... func=my_task, + ... description="每日任务" + ... ) + >>> + >>> # 启动调度器 + >>> await scheduler.start() """ from minenasai.scheduler.cron import CronJob, CronScheduler, get_scheduler diff --git a/src/minenasai/scheduler/cron.py b/src/minenasai/scheduler/cron.py index 28ed3c6..c2506df 100644 --- a/src/minenasai/scheduler/cron.py +++ b/src/minenasai/scheduler/cron.py @@ -6,21 +6,24 @@ from __future__ import annotations import asyncio -import re +import contextlib import time +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field from datetime import datetime, timedelta -from enum import Enum -from typing import Any, Callable, Coroutine +from enum import StrEnum +from typing import Any from minenasai.core import get_logger -from minenasai.core.database import get_database logger = get_logger(__name__) -class JobStatus(str, Enum): - """任务状态""" +class JobStatus(StrEnum): + """任务状态 + + 定时任务的执行状态。 + """ PENDING = "pending" RUNNING = "running" @@ -283,10 +286,8 @@ class CronScheduler: self._running = False if self._task: self._task.cancel() - try: + with contextlib.suppress(asyncio.CancelledError): await self._task - except asyncio.CancelledError: - pass logger.info("定时调度器已停止") async def _run_loop(self) -> None: diff --git a/src/minenasai/webtui/__init__.py b/src/minenasai/webtui/__init__.py index 93a85bc..a8ad79d 100644 --- a/src/minenasai/webtui/__init__.py +++ b/src/minenasai/webtui/__init__.py @@ -1,6 +1,36 @@ -"""Web TUI 模块 +"""Web TUI 模块 - Web 管理界面 -提供 Web 终端界面、SSH 连接管理 +本模块提供基于 Web 的终端管理界面: + +认证管理 (auth): + - AuthManager: 认证管理器 + - AuthToken: 认证令牌 + - 支持 JWT 令牌认证 + - 令牌刷新和撤销 + +SSH 管理 (ssh_manager): + - SSHManager: SSH 连接管理器 + - SSHSession: SSH 会话封装 + - 支持多会话管理 + - WebSocket 实时终端 + +Web 服务 (server): + - FastAPI 应用 + - 静态文件服务 + - WebSocket 终端 + - RESTful API + +使用示例: + >>> from minenasai.webtui import get_auth_manager + >>> + >>> # 获取认证管理器 + >>> auth = get_auth_manager() + >>> + >>> # 生成令牌 + >>> token = auth.generate_token(user_id="admin") + >>> + >>> # 验证令牌 + >>> user_info = auth.verify_token(token.token) """ from minenasai.webtui.auth import AuthManager, AuthToken, get_auth_manager diff --git a/src/minenasai/webtui/server.py b/src/minenasai/webtui/server.py index ba8771f..fac1819 100644 --- a/src/minenasai/webtui/server.py +++ b/src/minenasai/webtui/server.py @@ -7,13 +7,14 @@ from __future__ import annotations import asyncio import uuid +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, AsyncGenerator +from typing import Any from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, HTMLResponse +from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from minenasai.core import get_logger, get_settings, setup_logging @@ -95,8 +96,12 @@ manager = ConnectionManager() @asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - """应用生命周期管理""" +async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: + """应用生命周期管理 + + Args: + _app: FastAPI 应用实例(lifespan 标准签名要求) + """ settings = get_settings() setup_logging(settings.logging) logger.info("Web TUI 服务启动", port=settings.webtui.port) diff --git a/src/minenasai/webtui/ssh_manager.py b/src/minenasai/webtui/ssh_manager.py index e4a7de9..cec4463 100644 --- a/src/minenasai/webtui/ssh_manager.py +++ b/src/minenasai/webtui/ssh_manager.py @@ -6,10 +6,11 @@ from __future__ import annotations import asyncio +import contextlib import os import time -from pathlib import Path -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import paramiko @@ -128,10 +129,8 @@ class SSHSession: if self._read_task: self._read_task.cancel() - try: + with contextlib.suppress(asyncio.CancelledError): await self._read_task - except asyncio.CancelledError: - pass if self.channel: self.channel.close() diff --git a/tests/test_core.py b/tests/test_core.py index 255cce5..380424d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,8 +4,6 @@ from __future__ import annotations import json -import pytest - from minenasai.core import Settings, get_settings, load_config, reset_settings from minenasai.core.config import expand_path diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..e98ec04 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,345 @@ +"""数据库模块测试 + +测试 SQLite 异步数据库操作 +""" + +import pytest + +from minenasai.core.database import Database + + +class TestDatabase: + """数据库基本操作测试""" + + @pytest.fixture + async def db(self, tmp_path): + """创建临时数据库""" + db_path = tmp_path / "test.db" + db = Database(db_path) + await db.connect() + yield db + await db.close() + + @pytest.mark.asyncio + async def test_connect_and_close(self, tmp_path): + """测试数据库连接和关闭""" + db_path = tmp_path / "test.db" + db = Database(db_path) + + # 连接前访问 conn 应该报错 + with pytest.raises(RuntimeError, match="数据库未连接"): + _ = db.conn + + # 连接 + await db.connect() + assert db.conn is not None + + # 关闭 + await db.close() + assert db._conn is None + + +class TestAgentOperations: + """Agent 操作测试""" + + @pytest.fixture + async def db(self, tmp_path): + """创建临时数据库""" + db_path = tmp_path / "test.db" + db = Database(db_path) + await db.connect() + yield db + await db.close() + + @pytest.mark.asyncio + async def test_create_agent(self, db): + """测试创建 Agent""" + agent = await db.create_agent( + agent_id="test-agent-1", + name="Test Agent", + workspace_path="/tmp/workspace", + ) + + assert agent["id"] == "test-agent-1" + assert agent["name"] == "Test Agent" + assert agent["workspace_path"] == "/tmp/workspace" + assert agent["model"] == "claude-sonnet-4-20250514" + assert agent["sandbox_mode"] == "workspace" + assert "created_at" in agent + assert "updated_at" in agent + + @pytest.mark.asyncio + async def test_create_agent_custom_model(self, db): + """测试使用自定义模型创建 Agent""" + agent = await db.create_agent( + agent_id="test-agent-2", + name="Custom Agent", + workspace_path="/tmp/workspace", + model="gpt-4", + sandbox_mode="strict", + ) + + assert agent["model"] == "gpt-4" + assert agent["sandbox_mode"] == "strict" + + @pytest.mark.asyncio + async def test_get_agent(self, db): + """测试获取 Agent""" + # 创建 Agent + await db.create_agent( + agent_id="test-agent-3", + name="Get Test Agent", + workspace_path="/tmp/workspace", + ) + + # 获取存在的 Agent + agent = await db.get_agent("test-agent-3") + assert agent is not None + assert agent["name"] == "Get Test Agent" + + # 获取不存在的 Agent + agent = await db.get_agent("nonexistent") + assert agent is None + + @pytest.mark.asyncio + async def test_list_agents(self, db): + """测试列出所有 Agent""" + # 初始为空 + agents = await db.list_agents() + assert len(agents) == 0 + + # 创建多个 Agent + await db.create_agent("agent-1", "Agent 1", "/tmp/1") + await db.create_agent("agent-2", "Agent 2", "/tmp/2") + await db.create_agent("agent-3", "Agent 3", "/tmp/3") + + # 列出所有 + agents = await db.list_agents() + assert len(agents) == 3 + + +class TestSessionOperations: + """Session 操作测试""" + + @pytest.fixture + async def db(self, tmp_path): + """创建临时数据库并添加测试 Agent""" + db_path = tmp_path / "test.db" + db = Database(db_path) + await db.connect() + await db.create_agent("test-agent", "Test Agent", "/tmp/workspace") + yield db + await db.close() + + @pytest.mark.asyncio + async def test_create_session(self, db): + """测试创建会话""" + session = await db.create_session( + agent_id="test-agent", + channel="websocket", + peer_id="user-123", + ) + + assert session["agent_id"] == "test-agent" + assert session["channel"] == "websocket" + assert session["peer_id"] == "user-123" + assert session["status"] == "active" + assert "session_key" in session + + @pytest.mark.asyncio + async def test_create_session_with_metadata(self, db): + """测试创建带元数据的会话""" + metadata = {"client": "web", "version": "1.0"} + session = await db.create_session( + agent_id="test-agent", + channel="websocket", + metadata=metadata, + ) + + assert session["metadata"] == metadata + + @pytest.mark.asyncio + async def test_get_session(self, db): + """测试获取会话""" + # 创建会话 + created = await db.create_session( + agent_id="test-agent", + channel="websocket", + metadata={"test": "data"}, + ) + + # 获取会话 + session = await db.get_session(created["id"]) + assert session is not None + assert session["id"] == created["id"] + assert session["metadata"] == {"test": "data"} + + # 获取不存在的会话 + session = await db.get_session("nonexistent-id") + assert session is None + + @pytest.mark.asyncio + async def test_update_session_status(self, db): + """测试更新会话状态""" + # 创建会话 + session = await db.create_session( + agent_id="test-agent", + channel="websocket", + ) + assert session["status"] == "active" + + # 更新状态 + await db.update_session_status(session["id"], "closed") + + # 验证 + updated = await db.get_session(session["id"]) + assert updated["status"] == "closed" + + @pytest.mark.asyncio + async def test_list_active_sessions(self, db): + """测试列出活跃会话""" + # 创建多个会话 + session1 = await db.create_session("test-agent", "websocket") + session2 = await db.create_session("test-agent", "feishu") + await db.create_session("test-agent", "wework") + + # 关闭一个会话 + await db.update_session_status(session2["id"], "closed") + + # 列出所有活跃会话 + active = await db.list_active_sessions() + assert len(active) == 2 + + # 按 agent_id 过滤 + active = await db.list_active_sessions("test-agent") + assert len(active) == 2 + + +class TestMessageOperations: + """Message 操作测试""" + + @pytest.fixture + async def db(self, tmp_path): + """创建临时数据库并添加测试会话""" + db_path = tmp_path / "test.db" + db = Database(db_path) + await db.connect() + await db.create_agent("test-agent", "Test Agent", "/tmp/workspace") + session = await db.create_session("test-agent", "websocket") + db._test_session_id = session["id"] + yield db + await db.close() + + @pytest.mark.asyncio + async def test_add_message(self, db): + """测试添加消息""" + session_id = db._test_session_id + + message = await db.add_message( + session_id=session_id, + role="user", + content="Hello, world!", + ) + + assert message["session_id"] == session_id + assert message["role"] == "user" + assert message["content"] == "Hello, world!" + assert message["tokens_used"] == 0 + + @pytest.mark.asyncio + async def test_add_message_with_tool_calls(self, db): + """测试添加带工具调用的消息""" + session_id = db._test_session_id + + tool_calls = [ + {"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}} + ] + message = await db.add_message( + session_id=session_id, + role="assistant", + content=None, + tool_calls=tool_calls, + tokens_used=100, + ) + + assert message["tool_calls"] == tool_calls + assert message["tokens_used"] == 100 + + @pytest.mark.asyncio + async def test_get_messages(self, db): + """测试获取会话消息""" + session_id = db._test_session_id + + # 添加多条消息 + await db.add_message(session_id, "user", "Message 1") + await db.add_message(session_id, "assistant", "Response 1") + await db.add_message(session_id, "user", "Message 2") + await db.add_message(session_id, "assistant", "Response 2") + + # 获取所有消息 + messages = await db.get_messages(session_id) + assert len(messages) == 4 + # 验证角色交替 + roles = [m["role"] for m in messages] + assert roles.count("user") == 2 + assert roles.count("assistant") == 2 + + @pytest.mark.asyncio + async def test_get_messages_with_limit(self, db): + """测试获取消息数量限制""" + session_id = db._test_session_id + + # 添加多条消息 + for i in range(10): + await db.add_message(session_id, "user", f"Message {i}") + + # 限制返回数量 + messages = await db.get_messages(session_id, limit=5) + assert len(messages) == 5 + + +class TestAuditLog: + """审计日志测试""" + + @pytest.fixture + async def db(self, tmp_path): + """创建临时数据库""" + db_path = tmp_path / "test.db" + db = Database(db_path) + await db.connect() + yield db + await db.close() + + @pytest.mark.asyncio + async def test_add_audit_log(self, db): + """测试添加审计日志""" + # 添加日志不应该报错 + await db.add_audit_log( + agent_id="test-agent", + tool_name="read_file", + danger_level="safe", + params={"path": "/tmp/test.txt"}, + result="success", + duration_ms=50, + ) + + @pytest.mark.asyncio + async def test_add_audit_log_minimal(self, db): + """测试添加最小审计日志""" + await db.add_audit_log( + agent_id=None, + tool_name="python_eval", + danger_level="low", + ) + + +class TestGlobalDatabase: + """全局数据库实例测试""" + + @pytest.mark.asyncio + async def test_import_functions(self): + """测试导入全局函数""" + from minenasai.core.database import close_database, get_database + + assert callable(get_database) + assert callable(close_database) diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 0c9fccf..778558a 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -142,3 +142,159 @@ class TestRouterEdgeCases: result = self.router.evaluate("查看 /tmp/test.txt 文件内容") assert result["complexity"] in [TaskComplexity.SIMPLE, TaskComplexity.MEDIUM] + + +class TestConnectionManager: + """WebSocket 连接管理器测试""" + + def test_import_manager(self): + """测试导入连接管理器""" + from minenasai.gateway.server import ConnectionManager + + manager = ConnectionManager() + assert manager.active_connections == {} + + @pytest.mark.asyncio + async def test_connect_and_disconnect(self): + """测试连接和断开""" + from unittest.mock import AsyncMock, MagicMock + + from minenasai.gateway.server import ConnectionManager + + manager = ConnectionManager() + + # Mock WebSocket + mock_ws = AsyncMock() + mock_ws.accept = AsyncMock() + + # 连接 + await manager.connect(mock_ws, "client-1") + assert "client-1" in manager.active_connections + mock_ws.accept.assert_called_once() + + # 断开 + manager.disconnect("client-1") + assert "client-1" not in manager.active_connections + + @pytest.mark.asyncio + async def test_disconnect_nonexistent(self): + """测试断开不存在的连接""" + from minenasai.gateway.server import ConnectionManager + + manager = ConnectionManager() + # 不应该抛出异常 + manager.disconnect("nonexistent") + + @pytest.mark.asyncio + async def test_send_message(self): + """测试发送消息""" + from unittest.mock import AsyncMock + + from minenasai.gateway.server import ConnectionManager + + manager = ConnectionManager() + + # Mock WebSocket + mock_ws = AsyncMock() + mock_ws.accept = AsyncMock() + mock_ws.send_json = AsyncMock() + + # 连接 + await manager.connect(mock_ws, "client-1") + + # 发送消息 + await manager.send_message("client-1", {"type": "test"}) + mock_ws.send_json.assert_called_once_with({"type": "test"}) + + @pytest.mark.asyncio + async def test_send_message_to_nonexistent(self): + """测试发送消息给不存在的客户端""" + from minenasai.gateway.server import ConnectionManager + + manager = ConnectionManager() + # 不应该抛出异常 + await manager.send_message("nonexistent", {"type": "test"}) + + @pytest.mark.asyncio + async def test_broadcast(self): + """测试广播消息""" + from unittest.mock import AsyncMock + + from minenasai.gateway.server import ConnectionManager + + manager = ConnectionManager() + + # Mock 多个 WebSocket + mock_ws1 = AsyncMock() + mock_ws1.accept = AsyncMock() + mock_ws1.send_json = AsyncMock() + + mock_ws2 = AsyncMock() + mock_ws2.accept = AsyncMock() + mock_ws2.send_json = AsyncMock() + + # 连接 + await manager.connect(mock_ws1, "client-1") + await manager.connect(mock_ws2, "client-2") + + # 广播 + await manager.broadcast({"type": "broadcast"}) + mock_ws1.send_json.assert_called_once_with({"type": "broadcast"}) + mock_ws2.send_json.assert_called_once_with({"type": "broadcast"}) + + +class TestGatewayServer: + """Gateway 服务器测试""" + + def test_import_app(self): + """测试导入应用""" + from minenasai.gateway.server import app + + assert app is not None + assert app.title == "MineNASAI Gateway" + + def test_import_endpoints(self): + """测试导入端点函数""" + from minenasai.gateway.server import list_agents, list_sessions, root + + assert callable(root) + assert callable(list_agents) + assert callable(list_sessions) + + +class TestMessageTypes: + """消息类型测试""" + + def test_status_message(self): + """测试状态消息""" + from minenasai.gateway.protocol import StatusMessage + + msg = StatusMessage(status="thinking", message="处理中...") + assert msg.type == MessageType.STATUS + assert msg.status == "thinking" + assert msg.message == "处理中..." + + def test_response_message(self): + """测试响应消息""" + from minenasai.gateway.protocol import ResponseMessage + + msg = ResponseMessage(content="Hello!", in_reply_to="msg-123") + assert msg.type == MessageType.RESPONSE + assert msg.content == "Hello!" + assert msg.in_reply_to == "msg-123" + + def test_error_message(self): + """测试错误消息""" + from minenasai.gateway.protocol import ErrorMessage + + msg = ErrorMessage(message="Something went wrong", code="ERR_001") + assert msg.type == MessageType.ERROR + assert msg.message == "Something went wrong" + assert msg.code == "ERR_001" + + def test_pong_message(self): + """测试心跳响应消息""" + from minenasai.gateway.protocol import PongMessage + + msg = PongMessage() + assert msg.type == MessageType.PONG diff --git a/tests/test_llm.py b/tests/test_llm.py index 443f483..b95bee9 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -2,9 +2,18 @@ from __future__ import annotations +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from minenasai.llm.base import Message, Provider, ToolCall, ToolDefinition +from minenasai.llm.base import ( + LLMResponse, + Message, + Provider, + StreamChunk, + ToolCall, + ToolDefinition, +) class TestProvider: @@ -133,3 +142,208 @@ class TestLLMManager: providers = manager.get_available_providers() # 可能为空,取决于环境变量 assert isinstance(providers, list) + + +class TestOpenAICompatClientMock: + """OpenAI 兼容客户端 Mock 测试""" + + @pytest.fixture + def client(self): + """创建测试客户端""" + from minenasai.llm.clients import OpenAICompatClient + + return OpenAICompatClient(api_key="test-key", base_url="https://api.test.com/v1") + + @pytest.mark.asyncio + async def test_chat_mock(self, client): + """测试聊天功能(Mock)""" + # Mock HTTP 响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I help you?", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + }, + } + + with patch.object(client, "_get_client") as mock_get_client: + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_http_client + + messages = [Message(role="user", content="Hello")] + response = await client.chat(messages) + + assert isinstance(response, LLMResponse) + assert response.content == "Hello! How can I help you?" + assert response.model == "gpt-4o" + assert response.finish_reason == "stop" + assert response.usage["total_tokens"] == 30 + + @pytest.mark.asyncio + async def test_chat_with_tools_mock(self, client): + """测试带工具调用的聊天(Mock)""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "id": "chatcmpl-456", + "model": "gpt-4o", + "choices": [ + { + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path": "/tmp/test.txt"}', + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, + } + + with patch.object(client, "_get_client") as mock_get_client: + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_http_client + + messages = [Message(role="user", content="Read the test file")] + tools = [ + ToolDefinition( + name="read_file", + description="Read a file", + parameters={"type": "object", "properties": {"path": {"type": "string"}}}, + ) + ] + response = await client.chat(messages, tools=tools) + + assert response.tool_calls is not None + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "read_file" + assert response.tool_calls[0].arguments == {"path": "/tmp/test.txt"} + + @pytest.mark.asyncio + async def test_close_client(self, client): + """测试关闭客户端""" + # 创建 mock 客户端 + mock_http_client = AsyncMock() + client._client = mock_http_client + + await client.close() + + mock_http_client.aclose.assert_called_once() + assert client._client is None + + +class TestAnthropicClientMock: + """Anthropic 客户端 Mock 测试""" + + @pytest.fixture + def client(self): + """创建测试客户端""" + from minenasai.llm.clients import AnthropicClient + + return AnthropicClient(api_key="test-key") + + @pytest.mark.asyncio + async def test_chat_mock(self, client): + """测试聊天功能(Mock)""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello from Claude!"}], + "model": "claude-sonnet-4-20250514", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 15}, + } + + with patch.object(client, "_get_client") as mock_get_client: + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_http_client + + messages = [Message(role="user", content="Hello")] + response = await client.chat(messages) + + assert isinstance(response, LLMResponse) + assert response.content == "Hello from Claude!" + assert response.provider == Provider.ANTHROPIC + + +class TestLLMResponse: + """LLM 响应测试""" + + def test_response_basic(self): + """测试基本响应""" + response = LLMResponse( + content="Test response", + model="gpt-4o", + provider=Provider.OPENAI, + ) + assert response.content == "Test response" + assert response.model == "gpt-4o" + assert response.provider == Provider.OPENAI + assert response.finish_reason == "stop" + + def test_response_with_tool_calls(self): + """测试带工具调用的响应""" + tool_calls = [ + ToolCall(id="tc_1", name="read_file", arguments={"path": "/test"}), + ToolCall(id="tc_2", name="list_dir", arguments={"path": "/"}), + ] + response = LLMResponse( + content="", + model="claude-sonnet-4-20250514", + provider=Provider.ANTHROPIC, + tool_calls=tool_calls, + finish_reason="tool_use", + ) + assert len(response.tool_calls) == 2 + assert response.finish_reason == "tool_use" + + +class TestStreamChunk: + """流式响应块测试""" + + def test_chunk_basic(self): + """测试基本响应块""" + chunk = StreamChunk(content="Hello") + assert chunk.content == "Hello" + assert chunk.is_final is False + + def test_chunk_final(self): + """测试最终响应块""" + chunk = StreamChunk( + content="", + is_final=True, + usage={"prompt_tokens": 10, "completion_tokens": 20}, + ) + assert chunk.is_final is True + assert chunk.usage is not None diff --git a/tests/test_permissions.py b/tests/test_permissions.py index e1d7a8f..c6edae5 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -30,7 +30,7 @@ class TestToolPermission: name="test_tool", danger_level=DangerLevel.SAFE, ) - + assert perm.requires_confirmation is False assert perm.rate_limit is None @@ -45,7 +45,7 @@ class TestPermissionManager: def test_get_default_permission(self): """测试获取默认权限""" perm = self.manager.get_permission("read_file") - + assert perm is not None assert perm.danger_level == DangerLevel.SAFE @@ -57,7 +57,7 @@ class TestPermissionManager: description="自定义工具", ) self.manager.register_tool(perm) - + result = self.manager.get_permission("custom_tool") assert result is not None assert result.danger_level == DangerLevel.MEDIUM @@ -65,13 +65,13 @@ class TestPermissionManager: def test_check_permission_allowed(self): """测试权限检查 - 允许""" allowed, reason = self.manager.check_permission("read_file") - + assert allowed is True def test_check_permission_unknown_tool(self): """测试权限检查 - 未知工具""" allowed, reason = self.manager.check_permission("unknown_tool") - + assert allowed is False assert "未知工具" in reason @@ -83,12 +83,12 @@ class TestPermissionManager: denied_paths=["/etc/", "/root/"], ) self.manager.register_tool(perm) - + allowed, reason = self.manager.check_permission( "restricted_read", params={"path": "/etc/passwd"}, ) - + assert allowed is False assert "禁止访问" in reason @@ -96,7 +96,7 @@ class TestPermissionManager: """测试确认要求 - 按等级""" # HIGH 级别需要确认 assert self.manager.requires_confirmation("delete_file") is True - + # SAFE 级别不需要确认 assert self.manager.requires_confirmation("read_file") is False @@ -108,7 +108,7 @@ class TestPermissionManager: requires_confirmation=True, ) self.manager.register_tool(perm) - + assert self.manager.requires_confirmation("explicit_confirm") is True @pytest.mark.asyncio @@ -119,9 +119,9 @@ class TestPermissionManager: tool_name="delete_file", params={"path": "/test.txt"}, ) - + assert request.status == ConfirmationStatus.PENDING - + # 批准 self.manager.approve_confirmation("req-1") assert request.status == ConfirmationStatus.APPROVED @@ -134,7 +134,7 @@ class TestPermissionManager: tool_name="delete_file", params={"path": "/test.txt"}, ) - + self.manager.deny_confirmation("req-2") assert request.status == ConfirmationStatus.DENIED @@ -146,7 +146,7 @@ class TestPermissionManager: tool_name="test", params={}, ) - + pending = self.manager.get_pending_confirmations() assert len(pending) >= 1 @@ -157,48 +157,48 @@ class TestToolRegistry: def test_import_registry(self): """测试导入注册中心""" from minenasai.agent import ToolRegistry, get_tool_registry - + registry = get_tool_registry() assert isinstance(registry, ToolRegistry) def test_register_builtin_tools(self): """测试注册内置工具""" from minenasai.agent import get_tool_registry, register_builtin_tools - + registry = get_tool_registry() initial_count = len(registry.list_tools()) - + register_builtin_tools() - + # 应该有更多工具 new_count = len(registry.list_tools()) assert new_count >= initial_count def test_tool_decorator(self): """测试工具装饰器""" - from minenasai.agent import tool, get_tool_registry - + from minenasai.agent import get_tool_registry, tool + @tool(name="decorated_tool", description="装饰器测试") async def decorated_tool(param: str) -> dict: return {"result": param} - + registry = get_tool_registry() tool_obj = registry.get("decorated_tool") - + assert tool_obj is not None assert tool_obj.description == "装饰器测试" @pytest.mark.asyncio async def test_execute_tool(self): """测试执行工具""" - from minenasai.agent import get_tool_registry, DangerLevel - + from minenasai.agent import DangerLevel, get_tool_registry + registry = get_tool_registry() - + # 注册测试工具 async def echo(message: str) -> dict: return {"echo": message} - + registry.register( name="echo", description="回显消息", @@ -210,18 +210,18 @@ class TestToolRegistry: }, danger_level=DangerLevel.SAFE, ) - + result = await registry.execute("echo", {"message": "hello"}) - + assert result["success"] is True assert result["result"]["echo"] == "hello" def test_get_stats(self): """测试获取统计""" from minenasai.agent import get_tool_registry - + registry = get_tool_registry() stats = registry.get_stats() - + assert "total_tools" in stats assert "categories" in stats diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c7a3164..6acee1a 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -15,7 +15,7 @@ class TestCronParser: def test_parse_all_stars(self): """测试全星号表达式""" result = CronParser.parse("* * * * *") - + assert len(result["minute"]) == 60 assert len(result["hour"]) == 24 assert len(result["day"]) == 31 @@ -25,39 +25,39 @@ class TestCronParser: def test_parse_specific_values(self): """测试具体值""" result = CronParser.parse("30 8 * * *") - + assert result["minute"] == {30} assert result["hour"] == {8} def test_parse_range(self): """测试范围""" result = CronParser.parse("0 9-17 * * *") - + assert result["hour"] == {9, 10, 11, 12, 13, 14, 15, 16, 17} def test_parse_step(self): """测试步进""" result = CronParser.parse("*/15 * * * *") - + assert result["minute"] == {0, 15, 30, 45} def test_parse_list(self): """测试列表""" result = CronParser.parse("0 8,12,18 * * *") - + assert result["hour"] == {8, 12, 18} def test_parse_preset_daily(self): """测试预定义表达式 @daily""" result = CronParser.parse("@daily") - + assert result["minute"] == {0} assert result["hour"] == {0} def test_parse_preset_hourly(self): """测试预定义表达式 @hourly""" result = CronParser.parse("@hourly") - + assert result["minute"] == {0} assert len(result["hour"]) == 24 @@ -73,7 +73,7 @@ class TestCronParser: "0 * * * *", after=datetime(2026, 1, 1, 10, 30) ) - + assert next_run.minute == 0 assert next_run.hour == 11 @@ -89,7 +89,7 @@ class TestCronJob: schedule="*/5 * * * *", task="测试", ) - + assert job.id == "test-job" assert job.enabled is True assert job.last_status == JobStatus.PENDING @@ -113,7 +113,7 @@ class TestCronScheduler: schedule="*/5 * * * *", callback=task, ) - + assert job.id == "test-1" assert job.next_run is not None @@ -123,7 +123,7 @@ class TestCronScheduler: pass self.scheduler.add_job("test-1", "测试", "* * * * *", task) - + assert self.scheduler.remove_job("test-1") is True assert self.scheduler.get_job("test-1") is None @@ -133,12 +133,12 @@ class TestCronScheduler: pass self.scheduler.add_job("test-1", "测试", "* * * * *", task) - + assert self.scheduler.disable_job("test-1") is True job = self.scheduler.get_job("test-1") assert job.enabled is False assert job.last_status == JobStatus.DISABLED - + assert self.scheduler.enable_job("test-1") is True assert job.enabled is True @@ -149,7 +149,7 @@ class TestCronScheduler: self.scheduler.add_job("test-1", "任务1", "* * * * *", task) self.scheduler.add_job("test-2", "任务2", "*/5 * * * *", task) - + jobs = self.scheduler.list_jobs() assert len(jobs) == 2 @@ -159,7 +159,7 @@ class TestCronScheduler: pass self.scheduler.add_job("test-1", "任务1", "* * * * *", task) - + stats = self.scheduler.get_stats() assert stats["total_jobs"] == 1 assert stats["enabled_jobs"] == 1 diff --git a/tests/test_webtui.py b/tests/test_webtui.py index ce1f227..916d4ed 100644 --- a/tests/test_webtui.py +++ b/tests/test_webtui.py @@ -4,8 +4,6 @@ from __future__ import annotations import time -import pytest - from minenasai.webtui.auth import AuthManager, AuthToken @@ -45,7 +43,7 @@ class TestAuthManager: def test_generate_token(self): """测试生成令牌""" token = self.manager.generate_token("user1") - + assert token is not None assert len(token) > 20 @@ -53,7 +51,7 @@ class TestAuthManager: """测试验证令牌""" token = self.manager.generate_token("user1") auth_token = self.manager.verify_token(token) - + assert auth_token is not None assert auth_token.user_id == "user1" @@ -65,17 +63,17 @@ class TestAuthManager: def test_verify_expired_token(self): """测试验证过期令牌""" token = self.manager.generate_token("user1", expires_in=0) - + # 等待过期 time.sleep(0.1) - + auth_token = self.manager.verify_token(token) assert auth_token is None def test_revoke_token(self): """测试撤销令牌""" token = self.manager.generate_token("user1") - + assert self.manager.revoke_token(token) is True assert self.manager.verify_token(token) is None @@ -88,9 +86,9 @@ class TestAuthManager: self.manager.generate_token("user1") self.manager.generate_token("user1") self.manager.generate_token("user2") - + count = self.manager.revoke_user_tokens("user1") - + assert count == 2 assert self.manager.get_stats()["total_tokens"] == 1 @@ -98,7 +96,7 @@ class TestAuthManager: """测试刷新令牌""" old_token = self.manager.generate_token("user1") new_token = self.manager.refresh_token(old_token) - + assert new_token is not None assert new_token != old_token assert self.manager.verify_token(old_token) is None @@ -108,9 +106,9 @@ class TestAuthManager: """测试令牌元数据""" metadata = {"channel": "wework", "task_id": "123"} token = self.manager.generate_token("user1", metadata=metadata) - + auth_token = self.manager.verify_token(token) - + assert auth_token is not None assert auth_token.metadata == metadata @@ -118,10 +116,10 @@ class TestAuthManager: """测试清理过期令牌""" self.manager.generate_token("user1", expires_in=0) self.manager.generate_token("user2", expires_in=3600) - + time.sleep(0.1) count = self.manager.cleanup_expired() - + assert count == 1 assert self.manager.get_stats()["total_tokens"] == 1 @@ -132,17 +130,17 @@ class TestSSHManager: def test_import_ssh_manager(self): """测试导入 SSH 管理器""" from minenasai.webtui import SSHManager, get_ssh_manager - + manager = get_ssh_manager() assert isinstance(manager, SSHManager) def test_ssh_manager_stats(self): """测试 SSH 管理器统计""" from minenasai.webtui import SSHManager - + manager = SSHManager() stats = manager.get_stats() - + assert "active_sessions" in stats assert stats["active_sessions"] == 0 @@ -153,6 +151,6 @@ class TestWebTUIServer: def test_import_server(self): """测试导入服务器""" from minenasai.webtui.server import app - + assert app is not None assert app.title == "MineNASAI Web TUI"