""" LLM 服务层 提供统一的 LLM 调用接口,支持多个提供商: - Anthropic (Claude) - OpenAI (GPT) - DeepSeek - Ollama (本地模型) - Google (Gemini) """ import os import json import asyncio import logging from abc import ABC, abstractmethod from typing import Dict, List, Optional, Any, Union from dataclasses import dataclass from enum import Enum import time logger = logging.getLogger(__name__) class TaskType(Enum): """任务类型分类""" COMPLEX_REASONING = "complex_reasoning" CODE_GENERATION = "code_generation" CODE_REVIEW = "code_review" SIMPLE_TASK = "simple_task" COST_SENSITIVE = "cost_sensitive" LOCAL_PRIVACY = "local_privacy" MULTIMODAL = "multimodal" ARCHITECTURE_DESIGN = "architecture_design" TEST_GENERATION = "test_generation" @dataclass class LLMMessage: """LLM 消息""" role: str # system, user, assistant content: str images: Optional[List[str]] = None @dataclass class LLMResponse: """LLM 响应""" content: str model: str provider: str tokens_used: int = 0 finish_reason: str = "" latency: float = 0.0 @dataclass class LLMConfig: """LLM 配置""" # Anthropic anthropic_api_key: Optional[str] = None anthropic_base_url: str = "https://api.anthropic.com" # OpenAI openai_api_key: Optional[str] = None openai_base_url: str = "https://api.openai.com/v1" # DeepSeek deepseek_api_key: Optional[str] = None deepseek_base_url: str = "https://api.deepseek.com" # Google google_api_key: Optional[str] = None # Ollama ollama_base_url: str = "http://localhost:11434" # 通用设置 default_timeout: int = 120 max_retries: int = 3 temperature: float = 0.7 max_tokens: int = 4096 @classmethod def from_env(cls) -> "LLMConfig": """从环境变量加载配置""" return cls( anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"), openai_api_key=os.getenv("OPENAI_API_KEY"), deepseek_api_key=os.getenv("DEEPSEEK_API_KEY"), google_api_key=os.getenv("GOOGLE_API_KEY"), ollama_base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"), default_timeout=int(os.getenv("LLM_TIMEOUT", "120")), max_retries=int(os.getenv("LLM_MAX_RETRIES", "3")), temperature=float(os.getenv("LLM_TEMPERATURE", "0.7")), max_tokens=int(os.getenv("LLM_MAX_TOKENS", "4096")) ) class LLMProvider(ABC): """LLM 提供商抽象基类""" def __init__(self, config: LLMConfig): self.config = config @property @abstractmethod def provider_name(self) -> str: """提供商名称""" pass @abstractmethod async def chat_completion( self, model: str, messages: List[LLMMessage], temperature: float = None, max_tokens: int = None, **kwargs ) -> LLMResponse: """聊天补全""" pass @abstractmethod async def stream_completion( self, model: str, messages: List[LLMMessage], **kwargs ): """流式补全""" pass @abstractmethod def get_available_models(self) -> List[str]: """获取可用模型列表""" pass async def _retry_with_backoff(self, func, *args, **kwargs): """带退避的重试机制""" last_error = None for attempt in range(self.config.max_retries): try: return await func(*args, **kwargs) except Exception as e: last_error = e if attempt < self.config.max_retries - 1: wait_time = 2 ** attempt logger.warning(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}") await asyncio.sleep(wait_time) else: logger.error(f"All {self.config.max_retries} attempts failed") raise last_error class AnthropicProvider(LLMProvider): """Anthropic Claude 提供商""" @property def provider_name(self) -> str: return "anthropic" def __init__(self, config: LLMConfig): super().__init__(config) self._client = None def _get_client(self): """懒加载客户端""" if self._client is None: try: import anthropic self._client = anthropic.AsyncAnthropic( api_key=self.config.anthropic_api_key, base_url=self.config.anthropic_base_url, 
class LLMProvider(ABC):
    """Abstract base class for LLM providers."""

    def __init__(self, config: LLMConfig):
        self.config = config

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Provider name."""

    @abstractmethod
    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> LLMResponse:
        """Run a chat completion."""

    @abstractmethod
    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Run a streaming completion."""

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """List the models this provider exposes."""

    async def _retry_with_backoff(self, func, *args, **kwargs):
        """Retry a coroutine with exponential backoff (1s, 2s, 4s, ...)."""
        last_error = None
        for attempt in range(self.config.max_retries):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                last_error = e
                if attempt < self.config.max_retries - 1:
                    wait_time = 2 ** attempt
                    logger.warning(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"All {self.config.max_retries} attempts failed")
                    raise last_error


class AnthropicProvider(LLMProvider):
    """Anthropic Claude provider."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None

    @property
    def provider_name(self) -> str:
        return "anthropic"

    def _get_client(self):
        """Lazily create the client."""
        if self._client is None:
            try:
                import anthropic
            except ImportError:
                raise ImportError("Please install the anthropic package: pip install anthropic")
            self._client = anthropic.AsyncAnthropic(
                api_key=self.config.anthropic_api_key,
                base_url=self.config.anthropic_base_url,
                timeout=self.config.default_timeout,
            )
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> LLMResponse:
        start_time = time.time()

        # The Anthropic API takes the system prompt as a separate parameter.
        system_message = ""
        user_messages = []
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                user_messages.append({"role": msg.role, "content": msg.content})

        client = self._get_client()
        # Only pass `system` when one was provided; the SDK rejects None.
        if system_message:
            kwargs["system"] = system_message
        response = await self._retry_with_backoff(
            client.messages.create,
            model=model,
            messages=user_messages,
            # `x if x is not None else default` so an explicit 0 is not
            # silently replaced by the config default.
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **kwargs,
        )

        latency = time.time() - start_time
        return LLMResponse(
            content=response.content[0].text,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.input_tokens + response.usage.output_tokens,
            finish_reason=response.stop_reason,
            latency=latency,
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        client = self._get_client()
        system_message = ""
        user_messages = []
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                user_messages.append({"role": msg.role, "content": msg.content})

        if system_message:
            kwargs["system"] = system_message
        async with client.messages.stream(
            model=model,
            messages=user_messages,
            max_tokens=self.config.max_tokens,
            **kwargs,
        ) as stream:
            async for text in stream.text_stream:
                yield text

    def get_available_models(self) -> List[str]:
        return [
            "claude-opus-4.6",
            "claude-sonnet-4.6",
            "claude-haiku-4.6",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
        ]


class OpenAIProvider(LLMProvider):
    """OpenAI GPT provider."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None

    @property
    def provider_name(self) -> str:
        return "openai"

    def _get_client(self):
        if self._client is None:
            try:
                import openai
            except ImportError:
                raise ImportError("Please install the openai package: pip install openai")
            self._client = openai.AsyncOpenAI(
                api_key=self.config.openai_api_key,
                base_url=self.config.openai_base_url,
                timeout=self.config.default_timeout,
            )
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> LLMResponse:
        start_time = time.time()
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **kwargs,
        )

        latency = time.time() - start_time
        return LLMResponse(
            content=response.choices[0].message.content,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency,
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs,
        )
        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        return ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"]
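
# Illustrative sketch of consuming stream_completion. The provider instance,
# model name, and prompt are examples only; any LLMProvider subclass works the
# same way, since the generator simply yields text chunks as they arrive.
async def _example_stream_usage(provider: LLMProvider) -> None:
    messages = [LLMMessage(role="user", content="Write a haiku about routers.")]
    async for chunk in provider.stream_completion("gpt-4o-mini", messages):
        print(chunk, end="", flush=True)  # print incrementally, no trailing newline
    print()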
class DeepSeekProvider(LLMProvider):
    """DeepSeek provider (OpenAI-compatible API)."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None

    @property
    def provider_name(self) -> str:
        return "deepseek"

    def _get_client(self):
        if self._client is None:
            try:
                import openai
            except ImportError:
                raise ImportError("Please install the openai package: pip install openai")
            # DeepSeek exposes an OpenAI-compatible endpoint, so the OpenAI
            # SDK is reused with a different base URL.
            self._client = openai.AsyncOpenAI(
                api_key=self.config.deepseek_api_key,
                base_url=self.config.deepseek_base_url,
                timeout=self.config.default_timeout,
            )
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> LLMResponse:
        start_time = time.time()
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **kwargs,
        )

        latency = time.time() - start_time
        return LLMResponse(
            content=response.choices[0].message.content,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency,
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs,
        )
        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        return ["deepseek-chat", "deepseek-coder"]


class OllamaProvider(LLMProvider):
    """Ollama local-model provider."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._base_url = config.ollama_base_url

    @property
    def provider_name(self) -> str:
        return "ollama"

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> LLMResponse:
        import aiohttp

        start_time = time.time()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        payload = {
            "model": model,
            "messages": api_messages,
            "stream": False,
            "options": {
                "temperature": temperature if temperature is not None else self.config.temperature,
                "num_predict": max_tokens if max_tokens is not None else self.config.max_tokens,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._base_url}/api/chat",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=self.config.default_timeout),
            ) as response:
                result = await response.json()

        latency = time.time() - start_time
        return LLMResponse(
            content=result.get("message", {}).get("content", ""),
            model=model,
            provider=self.provider_name,
            tokens_used=result.get("prompt_eval_count", 0) + result.get("eval_count", 0),
            latency=latency,
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        import aiohttp

        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        payload = {"model": model, "messages": api_messages, "stream": True}

        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self._base_url}/api/chat", json=payload) as response:
                # Ollama streams one JSON object per line.
                async for line in response.content:
                    if line:
                        data = json.loads(line)
                        if "message" in data:
                            yield data["message"].get("content", "")

    def get_available_models(self) -> List[str]:
        return ["llama3", "llama3.2", "mistral", "codellama", "deepseek-coder"]
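
# Illustrative sketch of calling the local Ollama provider directly, bypassing
# the router. Assumes an Ollama server is reachable at OLLAMA_BASE_URL and that
# the "llama3" model has already been pulled; both are assumptions for this
# example, not requirements of the module.
async def _example_local_chat() -> None:
    provider = OllamaProvider(LLMConfig.from_env())
    response = await provider.chat_completion(
        "llama3",
        [LLMMessage(role="user", content="用一句话介绍你自己")],
    )
    print(f"{response.latency:.2f}s, {response.tokens_used} tokens: {response.content}")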
class ModelRouter:
    """
    Intelligent model router.

    Automatically selects the most suitable model based on task type and
    requirements.
    """

    # Default routing rules
    ROUTING_RULES = {
        TaskType.COMPLEX_REASONING: ("anthropic", "claude-opus-4.6"),
        TaskType.CODE_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.CODE_REVIEW: ("anthropic", "claude-sonnet-4.6"),
        TaskType.ARCHITECTURE_DESIGN: ("anthropic", "claude-opus-4.6"),
        TaskType.TEST_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.SIMPLE_TASK: ("anthropic", "claude-haiku-4.6"),
        TaskType.COST_SENSITIVE: ("deepseek", "deepseek-chat"),
        TaskType.LOCAL_PRIVACY: ("ollama", "llama3"),
    }

    def __init__(self, config: Optional[LLMConfig] = None):
        self.config = config or LLMConfig.from_env()
        self.providers: Dict[str, LLMProvider] = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Register every provider that has credentials configured."""
        if self.config.anthropic_api_key:
            self.providers["anthropic"] = AnthropicProvider(self.config)
        if self.config.openai_api_key:
            self.providers["openai"] = OpenAIProvider(self.config)
        if self.config.deepseek_api_key:
            self.providers["deepseek"] = DeepSeekProvider(self.config)
        # Ollama is always registered (local service, no API key required).
        self.providers["ollama"] = OllamaProvider(self.config)

    def classify_task(self, task_description: str) -> TaskType:
        """
        Classify a task description into a TaskType using keyword matching
        and heuristic rules. The keyword lists are mostly Chinese with a few
        English terms; descriptions that match nothing fall back to
        SIMPLE_TASK.
        """
        task_lower = task_description.lower()

        # Keyword lists per task type
        keywords_map = {
            TaskType.ARCHITECTURE_DESIGN: ["架构", "设计", "系统设计", "技术选型", "架构图"],
            TaskType.CODE_GENERATION: ["实现", "编写", "生成代码", "开发", "创建函数"],
            TaskType.CODE_REVIEW: ["审查", "review", "检查", "分析代码"],
            TaskType.TEST_GENERATION: ["测试", "单元测试", "测试用例"],
            TaskType.COMPLEX_REASONING: ["分析", "推理", "判断", "复杂", "评估"],
        }

        # Score each task type by the number of matching keywords
        scores = {}
        for task_type, keywords in keywords_map.items():
            score = sum(1 for kw in keywords if kw in task_lower)
            if score > 0:
                scores[task_type] = score

        # Return the highest-scoring type
        if scores:
            return max(scores, key=scores.get)

        # Default to a simple task
        return TaskType.SIMPLE_TASK

    def get_route(self, task_type: TaskType, preferred_provider: Optional[str] = None) -> tuple:
        """
        Resolve a routing decision.

        Returns:
            (provider_name, model_name)
        """
        # If a provider was requested and is available, use its first model
        if preferred_provider and preferred_provider in self.providers:
            provider = self.providers[preferred_provider]
            models = provider.get_available_models()
            if models:
                return preferred_provider, models[0]

        # Apply the routing rules
        if task_type in self.ROUTING_RULES:
            provider_name, model_name = self.ROUTING_RULES[task_type]
            if provider_name in self.providers:
                return provider_name, model_name

        # Fall back to the first available provider
        for provider_name, provider in self.providers.items():
            models = provider.get_available_models()
            if models:
                return provider_name, models[0]

        raise RuntimeError("No LLM provider is available")

    async def route_task(
        self,
        task: str,
        messages: Optional[List[LLMMessage]] = None,
        preferred_model: Optional[str] = None,
        preferred_provider: Optional[str] = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Route a task to a suitable model.

        Args:
            task: Task description.
            messages: Message list (if None, one is created from ``task``).
            preferred_model: Preferred model; its provider is inferred from
                the model-name prefix. If that provider is unavailable, the
                preference is ignored and normal routing applies.
            preferred_provider: Preferred provider.
        """
        # If a preferred model was given, try to use it directly
        if preferred_model:
            # Only hyphenated model names are recognized; infer the provider
            # from the name prefix.
            if "-" in preferred_model:
                if preferred_model.startswith("claude"):
                    provider_name = "anthropic"
                elif preferred_model.startswith("gpt"):
                    provider_name = "openai"
                elif preferred_model.startswith("deepseek"):
                    provider_name = "deepseek"
                else:
                    provider_name = preferred_provider or "anthropic"

                if provider_name in self.providers:
                    provider = self.providers[provider_name]
                    return await provider.chat_completion(
                        preferred_model,
                        messages or [LLMMessage(role="user", content=task)],
                        **kwargs,
                    )

        # Classify the task and resolve the route
        task_type = self.classify_task(task)
        provider_name, model_name = self.get_route(task_type, preferred_provider)
        logger.info(f"Routing task: {task_type.value} -> {provider_name}/{model_name}")

        provider = self.providers[provider_name]
        return await provider.chat_completion(
            model_name,
            messages or [LLMMessage(role="user", content=task)],
            **kwargs,
        )

    def get_available_providers(self) -> List[str]:
        """List all available providers."""
        return list(self.providers.keys())

    def get_provider_models(self, provider_name: str) -> List[str]:
        """List the models available from a given provider."""
        if provider_name in self.providers:
            return self.providers[provider_name].get_available_models()
        return []
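
# Illustrative sketch of the router's synchronous pieces: classify_task scores
# keyword hits (note the keyword lists are mostly Chinese), and get_route maps
# the winning TaskType to a (provider, model) pair, falling back to whatever
# provider is actually configured. The task string below is just an example.
def _example_routing_decision() -> None:
    router = ModelRouter(LLMConfig.from_env())
    task_type = router.classify_task("审查这段代码的错误处理")  # "审查" -> CODE_REVIEW
    provider_name, model_name = router.get_route(task_type)
    print(f"{task_type.value} -> {provider_name}/{model_name}")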
# Singleton accessors
_llm_service: Optional[ModelRouter] = None


def get_llm_service(config: Optional[LLMConfig] = None) -> ModelRouter:
    """Return the shared ModelRouter singleton, creating it on first use."""
    global _llm_service
    if _llm_service is None:
        _llm_service = ModelRouter(config or LLMConfig.from_env())
    return _llm_service


def reset_llm_service():
    """Reset the singleton (mainly for tests)."""
    global _llm_service
    _llm_service = None
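
# Minimal end-to-end sketch. With no API keys set, routing falls back to the
# local Ollama provider, so this assumes either a configured key or a running
# Ollama server. The task string is an example that classify_task maps to
# CODE_GENERATION via the "实现" keyword.
if __name__ == "__main__":
    async def _demo() -> None:
        service = get_llm_service()
        response = await service.route_task("实现一个快速排序函数")
        print(f"[{response.provider}/{response.model}] {response.content}")

    asyncio.run(_demo())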