# multiAgentTry/backend/app/services/llm_service.py
"""
LLM 服务层
提供统一的 LLM 调用接口支持多个提供商
- Anthropic (Claude)
- OpenAI (GPT)
- DeepSeek
- Ollama (本地模型)
- Google (Gemini)
"""
import os
import json
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
from enum import Enum
import time
logger = logging.getLogger(__name__)
class TaskType(Enum):
    """Task-type classification used by ModelRouter to pick a provider/model."""
    COMPLEX_REASONING = "complex_reasoning"      # multi-step analysis / evaluation
    CODE_GENERATION = "code_generation"
    CODE_REVIEW = "code_review"
    SIMPLE_TASK = "simple_task"                  # default when no keywords match
    COST_SENSITIVE = "cost_sensitive"            # routed to the cheapest backend
    LOCAL_PRIVACY = "local_privacy"              # must stay on a local (Ollama) model
    MULTIMODAL = "multimodal"                    # NOTE(review): not present in ROUTING_RULES below
    ARCHITECTURE_DESIGN = "architecture_design"
    TEST_GENERATION = "test_generation"
@dataclass
class LLMMessage:
    """A single chat message exchanged with an LLM."""
    role: str  # "system", "user" or "assistant"
    content: str
    images: Optional[List[str]] = None  # optional image payloads; not consumed by the providers in this module
@dataclass
class LLMResponse:
    """Normalized response shape returned by every provider."""
    content: str
    model: str
    provider: str
    tokens_used: int = 0       # input + output tokens where the API reports them; 0 otherwise
    finish_reason: str = ""    # provider-specific stop reason (empty for Ollama)
    latency: float = 0.0       # wall-clock seconds for the whole round trip
@dataclass
class LLMConfig:
    """Connection and generation settings shared by all providers."""
    # Anthropic
    anthropic_api_key: Optional[str] = None
    anthropic_base_url: str = "https://api.anthropic.com"
    # OpenAI
    openai_api_key: Optional[str] = None
    openai_base_url: str = "https://api.openai.com/v1"
    # DeepSeek (OpenAI-compatible API)
    deepseek_api_key: Optional[str] = None
    deepseek_base_url: str = "https://api.deepseek.com"
    # Google
    google_api_key: Optional[str] = None
    # Ollama (local service, no API key required)
    ollama_base_url: str = "http://localhost:11434"
    # Shared generation / transport settings
    default_timeout: int = 120   # seconds per request
    max_retries: int = 3
    temperature: float = 0.7
    max_tokens: int = 4096

    @classmethod
    def from_env(cls) -> "LLMConfig":
        """Build a config from environment variables.

        Unset variables fall back to the dataclass defaults.  Base URLs are
        overridable from the environment too (the original version only
        honored OLLAMA_BASE_URL), which is needed for proxies and
        self-hosted gateways.

        Raises:
            ValueError: if a numeric variable holds a non-numeric value.
        """
        return cls(
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            anthropic_base_url=os.getenv("ANTHROPIC_BASE_URL", "https://api.anthropic.com"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
            deepseek_api_key=os.getenv("DEEPSEEK_API_KEY"),
            deepseek_base_url=os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            ollama_base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
            default_timeout=int(os.getenv("LLM_TIMEOUT", "120")),
            max_retries=int(os.getenv("LLM_MAX_RETRIES", "3")),
            temperature=float(os.getenv("LLM_TEMPERATURE", "0.7")),
            max_tokens=int(os.getenv("LLM_MAX_TOKENS", "4096")),
        )
class LLMProvider(ABC):
    """Abstract base class for a single LLM vendor backend.

    Concrete subclasses supply a provider name, a blocking chat completion,
    a streaming completion, and a static model list.  Shared retry logic
    lives here in ``_retry_with_backoff``.
    """

    def __init__(self, config: LLMConfig):
        self.config = config

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Short lowercase vendor identifier (e.g. "anthropic", "ollama")."""
        pass

    @abstractmethod
    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion and return a normalized response.

        ``temperature``/``max_tokens`` of None mean "use the config default".
        """
        pass

    @abstractmethod
    async def stream_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        **kwargs
    ):
        """Async generator yielding response text chunks as they arrive."""
        pass

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """Return the model names this provider can serve."""
        pass

    async def _retry_with_backoff(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` with exponential backoff (1s, 2s, 4s, ...).

        Makes up to ``config.max_retries`` attempts; re-raises the last
        exception once all attempts fail.  Any Exception triggers a retry,
        so non-transient errors (e.g. auth failures) are retried as well.
        """
        last_error = None
        for attempt in range(self.config.max_retries):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                last_error = e
                if attempt < self.config.max_retries - 1:
                    wait_time = 2 ** attempt
                    logger.warning(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"All {self.config.max_retries} attempts failed")
                    raise last_error
class AnthropicProvider(LLMProvider):
    """Anthropic Claude backend built on the official ``anthropic`` async SDK."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # SDK client, created lazily on first use

    @property
    def provider_name(self) -> str:
        return "anthropic"

    def _get_client(self):
        """Lazily build the async client so the SDK is only required when used."""
        if self._client is None:
            try:
                import anthropic
                self._client = anthropic.AsyncAnthropic(
                    api_key=self.config.anthropic_api_key,
                    base_url=self.config.anthropic_base_url,
                    timeout=self.config.default_timeout
                )
            except ImportError:
                raise ImportError("请安装 anthropic 包: pip install anthropic")
        return self._client

    @staticmethod
    def _split_messages(messages: List[LLMMessage]):
        """Split out the system prompt (Anthropic takes it as a separate field).

        Returns ``(system_message, chat_messages)``.  If several system
        messages are present the last one wins, matching prior behavior.
        """
        system_message = ""
        chat_messages = []
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                chat_messages.append({"role": msg.role, "content": msg.content})
        return system_message, chat_messages

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion with retry/backoff, normalized to LLMResponse."""
        start_time = time.time()
        system_message, chat_messages = self._split_messages(messages)
        client = self._get_client()
        request: Dict[str, Any] = dict(
            model=model,
            messages=chat_messages,
            # `is None` checks (not `or`) so explicit 0 / 0.0 overrides are honored
            temperature=self.config.temperature if temperature is None else temperature,
            max_tokens=self.config.max_tokens if max_tokens is None else max_tokens,
            **kwargs
        )
        # Omit `system` entirely when there is no system prompt; the SDK
        # expects the field absent rather than None.
        if system_message:
            request["system"] = system_message
        response = await self._retry_with_backoff(client.messages.create, **request)
        latency = time.time() - start_time
        return LLMResponse(
            content=response.content[0].text,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.input_tokens + response.usage.output_tokens,
            finish_reason=response.stop_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield response text chunks as they arrive (no retry on stream errors)."""
        client = self._get_client()
        system_message, chat_messages = self._split_messages(messages)
        request: Dict[str, Any] = dict(
            model=model,
            messages=chat_messages,
            max_tokens=self.config.max_tokens,
            **kwargs
        )
        if system_message:
            request["system"] = system_message
        async with client.messages.stream(**request) as stream:
            async for text in stream.text_stream:
                yield text

    def get_available_models(self) -> List[str]:
        """Static model list (not fetched from the API)."""
        return [
            "claude-opus-4.6",
            "claude-sonnet-4.6",
            "claude-haiku-4.6",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022"
        ]
class OpenAIProvider(LLMProvider):
    """OpenAI GPT backend built on the official ``openai`` async SDK."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # SDK client, created lazily on first use

    @property
    def provider_name(self) -> str:
        return "openai"

    def _get_client(self):
        """Lazily build the async client so the SDK is only required when used."""
        if self._client is None:
            try:
                import openai
                self._client = openai.AsyncOpenAI(
                    api_key=self.config.openai_api_key,
                    base_url=self.config.openai_base_url,
                    timeout=self.config.default_timeout
                )
            except ImportError:
                raise ImportError("请安装 openai 包: pip install openai")
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion with retry/backoff, normalized to LLMResponse."""
        start_time = time.time()
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            # `is None` checks (not `or`) so explicit 0 / 0.0 overrides are honored
            temperature=self.config.temperature if temperature is None else temperature,
            max_tokens=self.config.max_tokens if max_tokens is None else max_tokens,
            **kwargs
        )
        latency = time.time() - start_time
        return LLMResponse(
            # message.content can be None (e.g. tool-call responses); keep str type
            content=response.choices[0].message.content or "",
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield response text chunks as they arrive (no retry on stream errors)."""
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs
        )
        async for chunk in stream:
            # Guard: role/usage chunks may carry empty choices or no content
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        """Static model list (not fetched from the API)."""
        return [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-3.5-turbo"
        ]
class DeepSeekProvider(LLMProvider):
    """DeepSeek backend; its API is OpenAI-compatible, so the ``openai`` SDK is reused."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # SDK client, created lazily on first use

    @property
    def provider_name(self) -> str:
        return "deepseek"

    def _get_client(self):
        """Lazily build the async client so the SDK is only required when used."""
        if self._client is None:
            try:
                import openai
                self._client = openai.AsyncOpenAI(
                    api_key=self.config.deepseek_api_key,
                    base_url=self.config.deepseek_base_url,
                    timeout=self.config.default_timeout
                )
            except ImportError:
                raise ImportError("请安装 openai 包: pip install openai")
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion with retry/backoff, normalized to LLMResponse."""
        start_time = time.time()
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            # `is None` checks (not `or`) so explicit 0 / 0.0 overrides are honored
            temperature=self.config.temperature if temperature is None else temperature,
            max_tokens=self.config.max_tokens if max_tokens is None else max_tokens,
            **kwargs
        )
        latency = time.time() - start_time
        return LLMResponse(
            # message.content can be None in tool-call responses; keep str type
            content=response.choices[0].message.content or "",
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield response text chunks as they arrive (no retry on stream errors)."""
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs
        )
        async for chunk in stream:
            # Guard: some chunks carry empty choices or no content delta
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        """Static model list (not fetched from the API)."""
        return [
            "deepseek-chat",
            "deepseek-coder"
        ]
class OllamaProvider(LLMProvider):
    """Local-model backend speaking the Ollama HTTP API (no API key needed)."""

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._base_url = config.ollama_base_url

    @property
    def provider_name(self) -> str:
        return "ollama"

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion against the local ``/api/chat`` endpoint.

        NOTE(review): unlike the SDK-based providers this path has no
        retry/backoff; extra ``kwargs`` are accepted but not forwarded.
        """
        import aiohttp
        start_time = time.time()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        payload = {
            "model": model,
            "messages": api_messages,
            "stream": False,
            "options": {
                # `is None` checks (not `or`) so explicit 0 / 0.0 overrides are honored
                "temperature": self.config.temperature if temperature is None else temperature,
                "num_predict": self.config.max_tokens if max_tokens is None else max_tokens
            }
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._base_url}/api/chat",
                json=payload,
                # ClientTimeout object: passing a bare number is deprecated in aiohttp
                timeout=aiohttp.ClientTimeout(total=self.config.default_timeout)
            ) as response:
                # Fail loudly on HTTP errors instead of parsing an error body
                response.raise_for_status()
                result = await response.json()
        latency = time.time() - start_time
        return LLMResponse(
            content=result.get("message", {}).get("content", ""),
            model=model,
            provider=self.provider_name,
            tokens_used=result.get("prompt_eval_count", 0) + result.get("eval_count", 0),
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield chunks from Ollama's newline-delimited JSON stream."""
        import aiohttp
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        payload = {
            "model": model,
            "messages": api_messages,
            "stream": True
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._base_url}/api/chat",
                json=payload
            ) as response:
                response.raise_for_status()
                async for line in response.content:
                    if not line.strip():
                        continue  # skip blank keep-alive lines
                    data = json.loads(line)
                    if "message" in data:
                        yield data["message"].get("content", "")

    def get_available_models(self) -> List[str]:
        """Commonly-pulled local models; actual availability depends on what was pulled."""
        return ["llama3", "llama3.2", "mistral", "codellama", "deepseek-coder"]
class ModelRouter:
    """
    Intelligent model router.

    Classifies a task description via keyword matching and routes it to the
    most suitable available provider/model using a static routing table.
    """

    # Default routing rules: task type -> (provider_name, model_name)
    ROUTING_RULES = {
        TaskType.COMPLEX_REASONING: ("anthropic", "claude-opus-4.6"),
        TaskType.CODE_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.CODE_REVIEW: ("anthropic", "claude-sonnet-4.6"),
        TaskType.ARCHITECTURE_DESIGN: ("anthropic", "claude-opus-4.6"),
        TaskType.TEST_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.SIMPLE_TASK: ("anthropic", "claude-haiku-4.6"),
        TaskType.COST_SENSITIVE: ("deepseek", "deepseek-chat"),
        TaskType.LOCAL_PRIVACY: ("ollama", "llama3"),
    }

    def __init__(self, config: LLMConfig = None):
        self.config = config or LLMConfig.from_env()
        self.providers: Dict[str, LLMProvider] = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Register every provider whose credentials are configured."""
        if self.config.anthropic_api_key:
            self.providers["anthropic"] = AnthropicProvider(self.config)
        if self.config.openai_api_key:
            self.providers["openai"] = OpenAIProvider(self.config)
        if self.config.deepseek_api_key:
            self.providers["deepseek"] = DeepSeekProvider(self.config)
        # Ollama is always registered (local service, no key required)
        self.providers["ollama"] = OllamaProvider(self.config)

    def classify_task(self, task_description: str) -> TaskType:
        """Classify a task description using keyword matching.

        The type with the most keyword hits wins; ties resolve to the first
        max-scoring type in dict insertion order.  Defaults to SIMPLE_TASK
        when nothing matches.
        """
        task_lower = task_description.lower()
        keywords_map = {
            TaskType.ARCHITECTURE_DESIGN: ["架构", "设计", "系统设计", "技术选型", "架构图"],
            TaskType.CODE_GENERATION: ["实现", "编写", "生成代码", "开发", "创建函数"],
            TaskType.CODE_REVIEW: ["审查", "review", "检查", "分析代码"],
            TaskType.TEST_GENERATION: ["测试", "单元测试", "测试用例"],
            TaskType.COMPLEX_REASONING: ["分析", "推理", "判断", "复杂", "评估"],
        }
        scores = {
            task_type: sum(1 for kw in keywords if kw in task_lower)
            for task_type, keywords in keywords_map.items()
        }
        scores = {t: s for t, s in scores.items() if s > 0}
        if scores:
            return max(scores, key=scores.get)
        return TaskType.SIMPLE_TASK

    @staticmethod
    def _infer_provider(model_name: str, preferred_provider: str = None) -> str:
        """Infer the owning provider from a model-name prefix.

        Falls back to *preferred_provider*, then "anthropic".
        """
        if model_name.startswith("claude"):
            return "anthropic"
        if model_name.startswith("gpt"):
            return "openai"
        if model_name.startswith("deepseek"):
            return "deepseek"
        if model_name.startswith(("llama", "mistral", "codellama")):
            return "ollama"
        return preferred_provider or "anthropic"

    def get_route(self, task_type: TaskType, preferred_provider: str = None) -> tuple:
        """Resolve a routing decision for *task_type*.

        Returns:
            (provider_name, model_name), resolved in order: explicit
            preferred provider -> routing table -> first available provider.

        Raises:
            RuntimeError: when no provider is available at all.
        """
        if preferred_provider and preferred_provider in self.providers:
            models = self.providers[preferred_provider].get_available_models()
            if models:
                return preferred_provider, models[0]
        if task_type in self.ROUTING_RULES:
            provider_name, model_name = self.ROUTING_RULES[task_type]
            if provider_name in self.providers:
                return provider_name, model_name
        # Fall back to the first provider that advertises any model
        for provider_name, provider in self.providers.items():
            models = provider.get_available_models()
            if models:
                return provider_name, models[0]
        raise RuntimeError("没有可用的 LLM 提供商")

    async def route_task(
        self,
        task: str,
        messages: List[LLMMessage] = None,
        preferred_model: str = None,
        preferred_provider: str = None,
        **kwargs
    ) -> LLMResponse:
        """Route a task to a suitable model and run the completion.

        Args:
            task: Task description (also used as the user prompt when
                *messages* is None).
            messages: Optional explicit message list.
            preferred_model: Use this exact model when its (inferred)
                provider is available; otherwise fall through to routing.
            preferred_provider: Preferred provider name.
        """
        # Honor an explicit model choice first.  Fix: the old code only
        # considered names containing "-", silently ignoring models such as
        # "llama3"; now any preferred model is accepted.
        if preferred_model:
            provider_name = self._infer_provider(preferred_model, preferred_provider)
            if provider_name in self.providers:
                return await self.providers[provider_name].chat_completion(
                    preferred_model,
                    messages or [LLMMessage(role="user", content=task)],
                    **kwargs
                )
        task_type = self.classify_task(task)
        provider_name, model_name = self.get_route(task_type, preferred_provider)
        logger.info(f"路由任务: {task_type.value} -> {provider_name}/{model_name}")
        provider = self.providers[provider_name]
        return await provider.chat_completion(
            model_name,
            messages or [LLMMessage(role="user", content=task)],
            **kwargs
        )

    def get_available_providers(self) -> List[str]:
        """Names of all initialized providers."""
        return list(self.providers.keys())

    def get_provider_models(self, provider_name: str) -> List[str]:
        """Models served by *provider_name*, or [] when unknown."""
        if provider_name in self.providers:
            return self.providers[provider_name].get_available_models()
        return []
# Process-wide singleton instance, created on first access.
_llm_service: Optional[ModelRouter] = None


def get_llm_service(config: LLMConfig = None) -> ModelRouter:
    """Return the shared ModelRouter, building it on first call.

    Note that *config* is only honored on the very first call; subsequent
    calls return the already-constructed singleton unchanged.
    """
    global _llm_service
    if _llm_service is not None:
        return _llm_service
    _llm_service = ModelRouter(config or LLMConfig.from_env())
    return _llm_service


def reset_llm_service():
    """Discard the singleton so the next access rebuilds it (used by tests)."""
    global _llm_service
    _llm_service = None