"""
LLM service layer.

Provides a unified LLM calling interface supporting multiple providers:
- Anthropic (Claude)
- OpenAI (GPT)
- DeepSeek
- Ollama (local models)
- Google (Gemini)
"""

import os
import json
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
from enum import Enum
import time

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
class TaskType(Enum):
    """Task categories used by ModelRouter to pick a provider/model.

    The values are the serialized identifiers that appear in routing
    logs (see ModelRouter.route_task).
    """
    COMPLEX_REASONING = "complex_reasoning"      # multi-step analysis / evaluation
    CODE_GENERATION = "code_generation"
    CODE_REVIEW = "code_review"
    SIMPLE_TASK = "simple_task"                  # default when classify_task finds no keywords
    COST_SENSITIVE = "cost_sensitive"            # routed to DeepSeek by the default rules
    LOCAL_PRIVACY = "local_privacy"              # routed to a local Ollama model
    MULTIMODAL = "multimodal"
    ARCHITECTURE_DESIGN = "architecture_design"
    TEST_GENERATION = "test_generation"
@dataclass
class LLMMessage:
    """A single chat message sent to an LLM provider."""
    role: str  # one of: "system", "user", "assistant"
    content: str  # message text
    # Optional image payloads for multimodal use.
    # NOTE(review): no provider in this file currently reads this field.
    images: Optional[List[str]] = None
@dataclass
class LLMResponse:
    """Normalized result of a chat completion, identical across providers."""
    content: str  # generated text
    model: str  # model identifier that served the request
    provider: str  # provider name, e.g. "anthropic", "openai"
    tokens_used: int = 0  # prompt + completion tokens (0 when the provider omits usage)
    finish_reason: str = ""  # provider-specific stop reason ("" when unavailable)
    latency: float = 0.0  # wall-clock seconds for the whole round trip
@dataclass
class LLMConfig:
    """Connection and generation settings shared by all LLM providers."""

    # Anthropic
    anthropic_api_key: Optional[str] = None
    anthropic_base_url: str = "https://api.anthropic.com"

    # OpenAI
    openai_api_key: Optional[str] = None
    openai_base_url: str = "https://api.openai.com/v1"

    # DeepSeek (served through an OpenAI-compatible API)
    deepseek_api_key: Optional[str] = None
    deepseek_base_url: str = "https://api.deepseek.com"

    # Google
    google_api_key: Optional[str] = None

    # Ollama (local server, no API key required)
    ollama_base_url: str = "http://localhost:11434"

    # Shared generation / transport settings
    default_timeout: int = 120  # per-request timeout, in seconds
    max_retries: int = 3  # attempts made by LLMProvider._retry_with_backoff
    temperature: float = 0.7
    max_tokens: int = 4096

    @classmethod
    def from_env(cls) -> "LLMConfig":
        """Build a config from environment variables.

        API keys come from the ``*_API_KEY`` variables.  Base URLs can be
        overridden with ANTHROPIC_BASE_URL / OPENAI_BASE_URL /
        DEEPSEEK_BASE_URL / OLLAMA_BASE_URL (previously only the Ollama URL
        was overridable, even though the other base URLs are config fields).
        Unset variables fall back to the defaults shown above.

        Raises:
            ValueError: if a numeric variable holds a non-numeric value.
        """
        return cls(
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            anthropic_base_url=os.getenv("ANTHROPIC_BASE_URL", "https://api.anthropic.com"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
            deepseek_api_key=os.getenv("DEEPSEEK_API_KEY"),
            deepseek_base_url=os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            ollama_base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
            default_timeout=int(os.getenv("LLM_TIMEOUT", "120")),
            max_retries=int(os.getenv("LLM_MAX_RETRIES", "3")),
            temperature=float(os.getenv("LLM_TEMPERATURE", "0.7")),
            max_tokens=int(os.getenv("LLM_MAX_TOKENS", "4096"))
        )
class LLMProvider(ABC):
    """Abstract base class for LLM providers.

    Concrete subclasses translate between the provider's native API and the
    LLMMessage/LLMResponse types; this base supplies shared config storage
    and a retry-with-exponential-backoff helper.
    """

    def __init__(self, config: LLMConfig):
        self.config = config

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Short provider identifier, e.g. "anthropic"."""

    @abstractmethod
    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion and return the normalized response.

        ``temperature`` / ``max_tokens`` of None mean "use the config default".
        """

    @abstractmethod
    async def stream_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        **kwargs
    ):
        """Async generator yielding text chunks as the model produces them."""

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """Return model identifiers this provider can serve."""

    async def _retry_with_backoff(self, func, *args, **kwargs):
        """Call ``await func(*args, **kwargs)`` with exponential backoff.

        Makes up to ``config.max_retries`` attempts — at least one, even if
        the configured value is zero or negative (the original silently
        returned None in that case) — sleeping 1s, 2s, 4s, ... between
        failures.  Re-raises the last exception when all attempts fail.
        """
        attempts = max(1, self.config.max_retries)
        last_error = None
        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except Exception as e:  # broad on purpose: provider SDKs raise varied types
                last_error = e
                if attempt < attempts - 1:
                    wait_time = 2 ** attempt
                    logger.warning(
                        "Attempt %d failed, retrying in %ds: %s",
                        attempt + 1, wait_time, e,
                    )
                    await asyncio.sleep(wait_time)
        logger.error("All %d attempts failed", attempts)
        raise last_error
class AnthropicProvider(LLMProvider):
    """Anthropic Claude provider (uses the official ``anthropic`` SDK)."""

    @property
    def provider_name(self) -> str:
        return "anthropic"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # lazily created AsyncAnthropic instance

    def _get_client(self):
        """Lazily construct the async client so the SDK is only required when used."""
        if self._client is None:
            try:
                import anthropic
                self._client = anthropic.AsyncAnthropic(
                    api_key=self.config.anthropic_api_key,
                    base_url=self.config.anthropic_base_url,
                    timeout=self.config.default_timeout
                )
            except ImportError:
                raise ImportError("请安装 anthropic 包: pip install anthropic")
        return self._client

    @staticmethod
    def _split_messages(messages: List[LLMMessage]):
        """Separate the system prompt (Anthropic takes it out-of-band) from chat turns.

        Returns (system_message, user_messages); system_message is "" when absent.
        """
        system_message = ""
        user_messages = []
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                user_messages.append({"role": msg.role, "content": msg.content})
        return system_message, user_messages

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a chat completion against the Anthropic Messages API."""
        start_time = time.time()
        system_message, user_messages = self._split_messages(messages)
        client = self._get_client()

        request_kwargs = dict(kwargs)
        # Only include `system` when one was supplied: the SDK expects the
        # argument to be omitted, not passed as None, when there is no
        # system prompt.
        if system_message:
            request_kwargs["system"] = system_message

        response = await self._retry_with_backoff(
            client.messages.create,
            model=model,
            messages=user_messages,
            # `or` would discard an explicit 0 / 0.0, so compare against None.
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **request_kwargs
        )

        latency = time.time() - start_time

        return LLMResponse(
            content=response.content[0].text,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.input_tokens + response.usage.output_tokens,
            # stop_reason may be None mid-stream-shaped responses; keep the
            # LLMResponse contract of a plain str.
            finish_reason=response.stop_reason or "",
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield text chunks from a streaming Messages API call."""
        client = self._get_client()
        system_message, user_messages = self._split_messages(messages)

        stream_kwargs = dict(kwargs)
        if system_message:
            stream_kwargs["system"] = system_message

        async with client.messages.stream(
            model=model,
            messages=user_messages,
            max_tokens=self.config.max_tokens,
            **stream_kwargs
        ) as stream:
            async for text in stream.text_stream:
                yield text

    def get_available_models(self) -> List[str]:
        """Known Claude model identifiers (static list; not fetched from the API)."""
        return [
            "claude-opus-4.6",
            "claude-sonnet-4.6",
            "claude-haiku-4.6",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022"
        ]
class OpenAIProvider(LLMProvider):
    """OpenAI GPT provider (uses the official ``openai`` async SDK)."""

    @property
    def provider_name(self) -> str:
        return "openai"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # lazily created AsyncOpenAI instance

    def _get_client(self):
        """Lazily construct the async client so the SDK is only required when used."""
        if self._client is None:
            try:
                import openai
                self._client = openai.AsyncOpenAI(
                    api_key=self.config.openai_api_key,
                    base_url=self.config.openai_base_url,
                    timeout=self.config.default_timeout
                )
            except ImportError:
                raise ImportError("请安装 openai 包: pip install openai")
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a chat completion via the OpenAI Chat Completions API."""
        start_time = time.time()

        client = self._get_client()

        api_messages = [
            {"role": msg.role, "content": msg.content}
            for msg in messages
        ]

        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            # `or` would discard an explicit 0 / 0.0, so compare against None.
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **kwargs
        )

        latency = time.time() - start_time

        return LLMResponse(
            content=response.choices[0].message.content,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield content deltas from a streaming chat completion."""
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs
        )

        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        """Known GPT model identifiers (static list; not fetched from the API)."""
        return [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-3.5-turbo"
        ]
class DeepSeekProvider(LLMProvider):
    """DeepSeek provider.

    DeepSeek exposes an OpenAI-compatible API, so this reuses the ``openai``
    SDK pointed at the DeepSeek base URL.
    """

    @property
    def provider_name(self) -> str:
        return "deepseek"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # lazily created AsyncOpenAI instance

    def _get_client(self):
        """Lazily construct the async client so the SDK is only required when used."""
        if self._client is None:
            try:
                import openai
                self._client = openai.AsyncOpenAI(
                    api_key=self.config.deepseek_api_key,
                    base_url=self.config.deepseek_base_url,
                    timeout=self.config.default_timeout
                )
            except ImportError:
                raise ImportError("请安装 openai 包: pip install openai")
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a chat completion via DeepSeek's OpenAI-compatible endpoint."""
        start_time = time.time()

        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            # `or` would discard an explicit 0 / 0.0, so compare against None.
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **kwargs
        )

        latency = time.time() - start_time

        return LLMResponse(
            content=response.choices[0].message.content,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield content deltas from a streaming chat completion."""
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs
        )

        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        """Known DeepSeek model identifiers (static list)."""
        return [
            "deepseek-chat",
            "deepseek-coder"
        ]
class OllamaProvider(LLMProvider):
    """Ollama local-model provider (talks to the Ollama REST API over HTTP)."""

    @property
    def provider_name(self) -> str:
        return "ollama"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._base_url = config.ollama_base_url

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat request against ``POST /api/chat``."""
        import aiohttp  # local import: aiohttp is only needed when Ollama is used

        start_time = time.time()

        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        payload = {
            "model": model,
            "messages": api_messages,
            "stream": False,
            "options": {
                # `or` would discard an explicit 0 / 0.0, so compare against None.
                "temperature": temperature if temperature is not None else self.config.temperature,
                "num_predict": max_tokens if max_tokens is not None else self.config.max_tokens
            }
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._base_url}/api/chat",
                json=payload,
                # ClientTimeout is the supported form; passing a bare int is
                # deprecated in aiohttp.
                timeout=aiohttp.ClientTimeout(total=self.config.default_timeout)
            ) as response:
                # Fail loudly on HTTP errors instead of silently parsing an
                # error body as if it were a completion.
                response.raise_for_status()
                result = await response.json()

        latency = time.time() - start_time

        return LLMResponse(
            content=result.get("message", {}).get("content", ""),
            model=model,
            provider=self.provider_name,
            tokens_used=result.get("prompt_eval_count", 0) + result.get("eval_count", 0),
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield message chunks from a streaming ``POST /api/chat`` request.

        Ollama streams one JSON object per line; each line's message content
        is yielded as it arrives.
        """
        import aiohttp

        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        payload = {
            "model": model,
            "messages": api_messages,
            "stream": True
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._base_url}/api/chat",
                json=payload
            ) as response:
                response.raise_for_status()
                async for line in response.content:
                    if line:
                        data = json.loads(line)
                        if "message" in data:
                            yield data["message"].get("content", "")

    def get_available_models(self) -> List[str]:
        """Commonly-used local model names.

        NOTE(review): this is a static guess — actual availability depends on
        what the local Ollama install has pulled.
        """
        return ["llama3", "llama3.2", "mistral", "codellama", "deepseek-coder"]
class ModelRouter:
    """
    Intelligent model router.

    Classifies a task description and dispatches it to the most suitable
    configured provider/model combination.
    """

    # Default routing rules: TaskType -> (provider name, model name).
    ROUTING_RULES = {
        TaskType.COMPLEX_REASONING: ("anthropic", "claude-opus-4.6"),
        TaskType.CODE_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.CODE_REVIEW: ("anthropic", "claude-sonnet-4.6"),
        TaskType.ARCHITECTURE_DESIGN: ("anthropic", "claude-opus-4.6"),
        TaskType.TEST_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.SIMPLE_TASK: ("anthropic", "claude-haiku-4.6"),
        TaskType.COST_SENSITIVE: ("deepseek", "deepseek-chat"),
        TaskType.LOCAL_PRIVACY: ("ollama", "llama3"),
    }

    def __init__(self, config: Optional[LLMConfig] = None):
        self.config = config or LLMConfig.from_env()
        self.providers: Dict[str, LLMProvider] = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Instantiate a provider for every service whose API key is configured."""
        if self.config.anthropic_api_key:
            self.providers["anthropic"] = AnthropicProvider(self.config)
        if self.config.openai_api_key:
            self.providers["openai"] = OpenAIProvider(self.config)
        if self.config.deepseek_api_key:
            self.providers["deepseek"] = DeepSeekProvider(self.config)
        # Ollama is always registered (local service, no API key required).
        self.providers["ollama"] = OllamaProvider(self.config)

    def classify_task(self, task_description: str) -> TaskType:
        """
        Classify a task description into a TaskType.

        Uses simple keyword matching; returns SIMPLE_TASK when nothing matches.
        """
        task_lower = task_description.lower()

        # Keyword lists are matched as substrings of the lower-cased description.
        keywords_map = {
            TaskType.ARCHITECTURE_DESIGN: ["架构", "设计", "系统设计", "技术选型", "架构图"],
            TaskType.CODE_GENERATION: ["实现", "编写", "生成代码", "开发", "创建函数"],
            TaskType.CODE_REVIEW: ["审查", "review", "检查", "分析代码"],
            TaskType.TEST_GENERATION: ["测试", "单元测试", "测试用例"],
            TaskType.COMPLEX_REASONING: ["分析", "推理", "判断", "复杂", "评估"],
        }

        # Score each type by how many of its keywords appear.
        scores = {}
        for task_type, keywords in keywords_map.items():
            score = sum(1 for kw in keywords if kw in task_lower)
            if score > 0:
                scores[task_type] = score

        if scores:
            return max(scores, key=scores.get)

        # Nothing matched: fall back to the cheapest tier.
        return TaskType.SIMPLE_TASK

    def get_route(self, task_type: TaskType, preferred_provider: Optional[str] = None) -> tuple:
        """
        Resolve a (provider_name, model_name) pair for a task type.

        Preference order: an explicitly requested provider, the static
        ROUTING_RULES, then any provider that advertises at least one model.

        Raises:
            RuntimeError: when no provider is available at all.
        """
        if preferred_provider and preferred_provider in self.providers:
            provider = self.providers[preferred_provider]
            models = provider.get_available_models()
            if models:
                return preferred_provider, models[0]

        if task_type in self.ROUTING_RULES:
            provider_name, model_name = self.ROUTING_RULES[task_type]
            if provider_name in self.providers:
                return provider_name, model_name

        # Fall back to the first provider that advertises any model.
        for provider_name, provider in self.providers.items():
            models = provider.get_available_models()
            if models:
                return provider_name, models[0]

        raise RuntimeError("没有可用的 LLM 提供商")

    def _infer_provider(self, model: str, preferred_provider: Optional[str] = None) -> str:
        """Best-effort mapping from a model name to a provider name."""
        # Exact match against each provider's advertised model list first.
        for name, provider in self.providers.items():
            if model in provider.get_available_models():
                return name
        # Otherwise use well-known name prefixes.
        if model.startswith("claude"):
            return "anthropic"
        if model.startswith("gpt"):
            return "openai"
        if model.startswith("deepseek"):
            return "deepseek"
        return preferred_provider or "anthropic"

    async def route_task(
        self,
        task: str,
        messages: Optional[List[LLMMessage]] = None,
        preferred_model: Optional[str] = None,
        preferred_provider: Optional[str] = None,
        **kwargs
    ) -> LLMResponse:
        """
        Route a task to a suitable model and return its completion.

        Args:
            task: task description (used both for classification and, when
                ``messages`` is None, as the single user message).
            messages: explicit message list; built from ``task`` when None.
            preferred_model: explicit model name; honored when it maps to a
                configured provider.
            preferred_provider: explicit provider name hint.
        """
        if messages is None:
            messages = [LLMMessage(role="user", content=task)]

        # Honor an explicit model request when we can map it to a provider.
        # (The original only inferred a provider for names containing "-",
        # which raised NameError for names like "llama3".)
        if preferred_model:
            provider_name = self._infer_provider(preferred_model, preferred_provider)
            if provider_name in self.providers:
                return await self.providers[provider_name].chat_completion(
                    preferred_model,
                    messages,
                    **kwargs
                )

        # No usable explicit model: classify and apply the routing rules.
        task_type = self.classify_task(task)
        provider_name, model_name = self.get_route(task_type, preferred_provider)

        logger.info(f"路由任务: {task_type.value} -> {provider_name}/{model_name}")

        provider = self.providers[provider_name]
        return await provider.chat_completion(
            model_name,
            messages,
            **kwargs
        )

    def get_available_providers(self) -> List[str]:
        """Names of all providers that were successfully initialized."""
        return list(self.providers.keys())

    def get_provider_models(self, provider_name: str) -> List[str]:
        """Models advertised by the given provider, or [] when unknown."""
        if provider_name in self.providers:
            return self.providers[provider_name].get_available_models()
        return []
# Module-level singleton storage used by get_llm_service()/reset_llm_service().
_llm_service: Optional[ModelRouter] = None
def get_llm_service(config: LLMConfig = None) -> ModelRouter:
    """Return the process-wide ModelRouter singleton.

    The router is built lazily on first call; on subsequent calls any
    ``config`` argument is ignored (call reset_llm_service() to rebuild).
    """
    global _llm_service
    if _llm_service is None:
        chosen = config if config is not None else LLMConfig.from_env()
        _llm_service = ModelRouter(chosen)
    return _llm_service
def reset_llm_service():
    """Drop the cached ModelRouter singleton (primarily for tests)."""
    global _llm_service
    _llm_service = None