完整实现 Swarm 多智能体协作系统
- 新增 CLIPluginAdapter 统一接口 (backend/app/core/agent_adapter.py) - 新增 LLM 服务层,支持 Anthropic/OpenAI/DeepSeek/Ollama (backend/app/services/llm_service.py) - 新增 Agent 执行引擎,支持文件锁自动管理 (backend/app/services/agent_executor.py) - 新增 NativeLLMAgent 原生 LLM 适配器 (backend/app/adapters/native_llm_agent.py) - 新增进程管理器 (backend/app/services/process_manager.py) - 新增 Agent 控制 API (backend/app/routers/agents_control.py) - 新增 WebSocket 实时通信 (backend/app/routers/websocket.py) - 更新前端 AgentsPage,支持启动/停止 Agent - 测试通过:Agent 启动、批量操作、栅栏同步 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
669
backend/app/services/llm_service.py
Normal file
669
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,669 @@
|
||||
"""
|
||||
LLM 服务层
|
||||
|
||||
提供统一的 LLM 调用接口,支持多个提供商:
|
||||
- Anthropic (Claude)
|
||||
- OpenAI (GPT)
|
||||
- DeepSeek
|
||||
- Ollama (本地模型)
|
||||
- Google (Gemini)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskType(Enum):
    """Task-type categories used by the model router.

    Each member's value is the stable string identifier consulted by the
    routing rules table.
    """

    COMPLEX_REASONING = "complex_reasoning"
    CODE_GENERATION = "code_generation"
    CODE_REVIEW = "code_review"
    SIMPLE_TASK = "simple_task"
    COST_SENSITIVE = "cost_sensitive"
    LOCAL_PRIVACY = "local_privacy"
    MULTIMODAL = "multimodal"
    ARCHITECTURE_DESIGN = "architecture_design"
    TEST_GENERATION = "test_generation"
|
||||
|
||||
|
||||
@dataclass
class LLMMessage:
    """A single chat message exchanged with an LLM provider."""

    role: str  # one of: "system", "user", "assistant"
    content: str  # message text
    images: Optional[List[str]] = None  # optional image payloads (provider-specific)
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """Normalized result of a (non-streaming) LLM completion."""

    content: str  # generated text
    model: str  # model that produced the response
    provider: str  # provider identifier (e.g. "anthropic")
    tokens_used: int = 0  # prompt + completion tokens, when reported
    finish_reason: str = ""  # provider-specific stop reason
    latency: float = 0.0  # wall-clock seconds for the call
|
||||
|
||||
|
||||
@dataclass
class LLMConfig:
    """Connection and generation settings shared by all LLM providers."""

    # Anthropic
    anthropic_api_key: Optional[str] = None
    anthropic_base_url: str = "https://api.anthropic.com"

    # OpenAI
    openai_api_key: Optional[str] = None
    openai_base_url: str = "https://api.openai.com/v1"

    # DeepSeek
    deepseek_api_key: Optional[str] = None
    deepseek_base_url: str = "https://api.deepseek.com"

    # Google
    google_api_key: Optional[str] = None

    # Ollama (local server)
    ollama_base_url: str = "http://localhost:11434"

    # Generic generation / transport settings
    default_timeout: int = 120
    max_retries: int = 3
    temperature: float = 0.7
    max_tokens: int = 4096

    @classmethod
    def from_env(cls) -> "LLMConfig":
        """Build a config from environment variables, falling back to defaults."""
        env = os.getenv
        return cls(
            anthropic_api_key=env("ANTHROPIC_API_KEY"),
            openai_api_key=env("OPENAI_API_KEY"),
            deepseek_api_key=env("DEEPSEEK_API_KEY"),
            google_api_key=env("GOOGLE_API_KEY"),
            ollama_base_url=env("OLLAMA_BASE_URL", "http://localhost:11434"),
            default_timeout=int(env("LLM_TIMEOUT", "120")),
            max_retries=int(env("LLM_MAX_RETRIES", "3")),
            temperature=float(env("LLM_TEMPERATURE", "0.7")),
            max_tokens=int(env("LLM_MAX_TOKENS", "4096")),
        )
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """Abstract base class for LLM providers.

    Concrete subclasses implement chat/stream completion against one
    vendor API and report which models they expose.
    """

    def __init__(self, config: LLMConfig):
        self.config = config

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Short provider identifier (e.g. "anthropic")."""

    @abstractmethod
    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion and return an LLMResponse."""

    @abstractmethod
    async def stream_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        **kwargs
    ):
        """Async generator yielding text chunks of a streaming completion."""

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """Return the model names this provider can serve."""

    async def _retry_with_backoff(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` with exponential backoff (1s, 2s, 4s...).

        Retries up to ``config.max_retries`` times and re-raises the last
        error when every attempt fails.

        FIX: a non-positive ``max_retries`` previously skipped the loop and
        silently returned ``None``; now at least one attempt is always made.
        """
        attempts = max(1, self.config.max_retries)
        last_error = None
        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                last_error = e
                if attempt < attempts - 1:
                    wait_time = 2 ** attempt
                    logger.warning(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"All {attempts} attempts failed")
        # Loop only falls through after the final failed attempt.
        raise last_error
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
    """Anthropic Claude provider (Messages API)."""

    @property
    def provider_name(self) -> str:
        return "anthropic"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # lazily-created anthropic.AsyncAnthropic

    def _get_client(self):
        """Lazily import and construct the async Anthropic client."""
        if self._client is None:
            try:
                import anthropic
            except ImportError:
                raise ImportError("请安装 anthropic 包: pip install anthropic")
            self._client = anthropic.AsyncAnthropic(
                api_key=self.config.anthropic_api_key,
                base_url=self.config.anthropic_base_url,
                timeout=self.config.default_timeout
            )
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion via the Anthropic Messages API."""
        start_time = time.time()

        # The Messages API takes the system prompt as a separate argument,
        # so split it out of the message list.
        system_message = ""
        user_messages = []
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                user_messages.append({
                    "role": msg.role,
                    "content": msg.content
                })

        client = self._get_client()

        # FIX: only pass `system` when one was provided — the SDK expects the
        # argument to be omitted rather than set to None.
        create_kwargs = dict(kwargs)
        if system_message:
            create_kwargs["system"] = system_message

        # FIX: explicit None checks so temperature=0.0 is honored instead of
        # silently falling back to the configured default ("or" treats 0 as
        # missing).
        response = await self._retry_with_backoff(
            client.messages.create,
            model=model,
            messages=user_messages,
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **create_kwargs
        )

        latency = time.time() - start_time

        return LLMResponse(
            content=response.content[0].text,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.input_tokens + response.usage.output_tokens,
            finish_reason=response.stop_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield streamed text chunks from the Anthropic Messages API."""
        client = self._get_client()
        system_message = ""
        user_messages = []
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                user_messages.append({"role": msg.role, "content": msg.content})

        # Same fix as chat_completion: omit `system` instead of passing None.
        stream_kwargs = dict(kwargs)
        if system_message:
            stream_kwargs["system"] = system_message

        async with client.messages.stream(
            model=model,
            messages=user_messages,
            max_tokens=self.config.max_tokens,
            **stream_kwargs
        ) as stream:
            async for text in stream.text_stream:
                yield text

    def get_available_models(self) -> List[str]:
        """Static list of supported Claude models."""
        return [
            "claude-opus-4.6",
            "claude-sonnet-4.6",
            "claude-haiku-4.6",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022"
        ]
|
||||
|
||||
|
||||
class OpenAIProvider(LLMProvider):
    """OpenAI GPT provider (Chat Completions API)."""

    @property
    def provider_name(self) -> str:
        return "openai"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # lazily-created openai.AsyncOpenAI

    def _get_client(self):
        """Lazily import and construct the async OpenAI client."""
        if self._client is None:
            try:
                import openai
            except ImportError:
                raise ImportError("请安装 openai 包: pip install openai")
            self._client = openai.AsyncOpenAI(
                api_key=self.config.openai_api_key,
                base_url=self.config.openai_base_url,
                timeout=self.config.default_timeout
            )
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion via the Chat Completions endpoint."""
        start_time = time.time()

        client = self._get_client()

        api_messages = [
            {"role": msg.role, "content": msg.content}
            for msg in messages
        ]

        # FIX: explicit None checks so temperature=0.0 / max_tokens=0 are
        # honored instead of silently falling back to the configured defaults.
        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **kwargs
        )

        latency = time.time() - start_time

        return LLMResponse(
            content=response.choices[0].message.content,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield streamed content deltas from the Chat Completions API."""
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs
        )

        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        """Static list of supported OpenAI models."""
        return [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-3.5-turbo"
        ]
|
||||
|
||||
|
||||
class DeepSeekProvider(LLMProvider):
    """DeepSeek provider (OpenAI-compatible Chat Completions API)."""

    @property
    def provider_name(self) -> str:
        return "deepseek"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._client = None  # lazily-created openai.AsyncOpenAI pointed at DeepSeek

    def _get_client(self):
        """Lazily construct an OpenAI-SDK client against the DeepSeek base URL."""
        if self._client is None:
            try:
                import openai
            except ImportError:
                raise ImportError("请安装 openai 包: pip install openai")
            self._client = openai.AsyncOpenAI(
                api_key=self.config.deepseek_api_key,
                base_url=self.config.deepseek_base_url,
                timeout=self.config.default_timeout
            )
        return self._client

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion via DeepSeek's OpenAI-compatible API."""
        start_time = time.time()

        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        # FIX: explicit None checks so temperature=0.0 / max_tokens=0 are
        # honored instead of silently falling back to the configured defaults.
        response = await self._retry_with_backoff(
            client.chat.completions.create,
            model=model,
            messages=api_messages,
            temperature=temperature if temperature is not None else self.config.temperature,
            max_tokens=max_tokens if max_tokens is not None else self.config.max_tokens,
            **kwargs
        )

        latency = time.time() - start_time

        return LLMResponse(
            content=response.choices[0].message.content,
            model=model,
            provider=self.provider_name,
            tokens_used=response.usage.total_tokens,
            finish_reason=response.choices[0].finish_reason,
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield streamed content deltas from DeepSeek's API."""
        client = self._get_client()
        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        stream = await client.chat.completions.create(
            model=model,
            messages=api_messages,
            stream=True,
            **kwargs
        )

        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_available_models(self) -> List[str]:
        """Static list of supported DeepSeek models."""
        return [
            "deepseek-chat",
            "deepseek-coder"
        ]
|
||||
|
||||
|
||||
class OllamaProvider(LLMProvider):
    """Ollama local-model provider (HTTP API at ``ollama_base_url``)."""

    @property
    def provider_name(self) -> str:
        return "ollama"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._base_url = config.ollama_base_url

    async def chat_completion(
        self,
        model: str,
        messages: List[LLMMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Non-streaming completion via POST {base_url}/api/chat."""
        import aiohttp

        start_time = time.time()

        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        payload = {
            "model": model,
            "messages": api_messages,
            "stream": False,
            "options": {
                # FIX: explicit None checks so temperature=0.0 is honored
                # instead of silently falling back to the configured default.
                "temperature": temperature if temperature is not None else self.config.temperature,
                "num_predict": max_tokens if max_tokens is not None else self.config.max_tokens
            }
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._base_url}/api/chat",
                json=payload,
                # FIX: pass a proper ClientTimeout instead of a bare int.
                timeout=aiohttp.ClientTimeout(total=self.config.default_timeout)
            ) as response:
                # FIX: fail loudly on HTTP errors instead of silently parsing
                # an error body into an empty response.
                response.raise_for_status()
                result = await response.json()

        latency = time.time() - start_time

        return LLMResponse(
            content=result.get("message", {}).get("content", ""),
            model=model,
            provider=self.provider_name,
            tokens_used=result.get("prompt_eval_count", 0) + result.get("eval_count", 0),
            latency=latency
        )

    async def stream_completion(self, model: str, messages: List[LLMMessage], **kwargs):
        """Yield streamed content chunks from Ollama's NDJSON streaming API."""
        import aiohttp

        api_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        payload = {
            "model": model,
            "messages": api_messages,
            "stream": True
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._base_url}/api/chat",
                json=payload
            ) as response:
                response.raise_for_status()
                async for line in response.content:
                    # Skip blank / keep-alive lines before JSON-decoding.
                    if line.strip():
                        data = json.loads(line)
                        if "message" in data:
                            yield data["message"].get("content", "")

    def get_available_models(self) -> List[str]:
        # Commonly-available local models; actual availability depends on
        # which models have been pulled into the local Ollama server.
        return ["llama3", "llama3.2", "mistral", "codellama", "deepseek-coder"]
|
||||
|
||||
|
||||
class ModelRouter:
    """Intelligent model router.

    Picks a provider/model for each task using a keyword-based task
    classification plus a static routing-rule table, with optional caller
    overrides (preferred provider and/or model).
    """

    # Default routing rules: task type -> (provider_name, model_name)
    ROUTING_RULES = {
        TaskType.COMPLEX_REASONING: ("anthropic", "claude-opus-4.6"),
        TaskType.CODE_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.CODE_REVIEW: ("anthropic", "claude-sonnet-4.6"),
        TaskType.ARCHITECTURE_DESIGN: ("anthropic", "claude-opus-4.6"),
        TaskType.TEST_GENERATION: ("anthropic", "claude-sonnet-4.6"),
        TaskType.SIMPLE_TASK: ("anthropic", "claude-haiku-4.6"),
        TaskType.COST_SENSITIVE: ("deepseek", "deepseek-chat"),
        TaskType.LOCAL_PRIVACY: ("ollama", "llama3"),
    }

    def __init__(self, config: LLMConfig = None):
        self.config = config or LLMConfig.from_env()
        self.providers: Dict[str, LLMProvider] = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Register every provider whose credentials are configured."""
        if self.config.anthropic_api_key:
            self.providers["anthropic"] = AnthropicProvider(self.config)
        if self.config.openai_api_key:
            self.providers["openai"] = OpenAIProvider(self.config)
        if self.config.deepseek_api_key:
            self.providers["deepseek"] = DeepSeekProvider(self.config)
        # Ollama is always registered: it is a local service with no API key.
        self.providers["ollama"] = OllamaProvider(self.config)

    def classify_task(self, task_description: str) -> TaskType:
        """Classify a task description into a TaskType.

        Uses keyword matching with simple scoring heuristics; falls back to
        SIMPLE_TASK when nothing matches.
        """
        task_lower = task_description.lower()

        # Keyword table (values are matched verbatim against the lowercased
        # description, so they stay in the task authors' language).
        keywords_map = {
            TaskType.ARCHITECTURE_DESIGN: ["架构", "设计", "系统设计", "技术选型", "架构图"],
            TaskType.CODE_GENERATION: ["实现", "编写", "生成代码", "开发", "创建函数"],
            TaskType.CODE_REVIEW: ["审查", "review", "检查", "分析代码"],
            TaskType.TEST_GENERATION: ["测试", "单元测试", "测试用例"],
            TaskType.COMPLEX_REASONING: ["分析", "推理", "判断", "复杂", "评估"],
        }

        # Score each candidate type by how many of its keywords appear.
        scores = {}
        for task_type, keywords in keywords_map.items():
            score = sum(1 for kw in keywords if kw in task_lower)
            if score > 0:
                scores[task_type] = score

        # Highest-scoring type wins.
        if scores:
            return max(scores, key=scores.get)

        return TaskType.SIMPLE_TASK

    def get_route(self, task_type: TaskType, preferred_provider: str = None) -> tuple:
        """Resolve a routing decision.

        Returns:
            (provider_name, model_name)

        Resolution order: explicit preferred provider, static routing rules,
        then the first registered provider that reports any model.
        """
        # A caller-specified provider wins when it is registered.
        if preferred_provider and preferred_provider in self.providers:
            provider = self.providers[preferred_provider]
            models = provider.get_available_models()
            if models:
                return preferred_provider, models[0]

        # Static routing table.
        if task_type in self.ROUTING_RULES:
            provider_name, model_name = self.ROUTING_RULES[task_type]
            if provider_name in self.providers:
                return provider_name, model_name

        # Fallback: first registered provider exposing any model.
        for provider_name, provider in self.providers.items():
            models = provider.get_available_models()
            if models:
                return provider_name, models[0]

        raise RuntimeError("没有可用的 LLM 提供商")

    def _infer_provider(self, model_name: str, preferred_provider: str = None) -> str:
        """Best-effort provider inference from a model-name prefix.

        Falls back to *preferred_provider* (or "anthropic") when the prefix
        is not recognized.
        """
        if model_name.startswith("claude"):
            return "anthropic"
        if model_name.startswith("gpt"):
            return "openai"
        if model_name.startswith("deepseek"):
            return "deepseek"
        # FIX: models without a recognized prefix and no "-" (e.g. Ollama's
        # "llama3", "mistral") previously caused preferred_model to be
        # ignored entirely; route them to Ollama when it serves them.
        ollama = self.providers.get("ollama")
        if ollama is not None and model_name in ollama.get_available_models():
            return "ollama"
        return preferred_provider or "anthropic"

    async def route_task(
        self,
        task: str,
        messages: List[LLMMessage] = None,
        preferred_model: str = None,
        preferred_provider: str = None,
        **kwargs
    ) -> LLMResponse:
        """Route *task* to a suitable model and run the completion.

        Args:
            task: Task description (also used as the prompt when *messages*
                is not supplied).
            messages: Optional explicit message list.
            preferred_model: Model to use directly when its provider is
                registered.
            preferred_provider: Provider hint, used both for direct model
                dispatch and for routing.
        """
        chat_messages = messages or [LLMMessage(role="user", content=task)]

        # Honor an explicitly requested model when its provider is available.
        if preferred_model:
            provider_name = self._infer_provider(preferred_model, preferred_provider)
            if provider_name in self.providers:
                provider = self.providers[provider_name]
                return await provider.chat_completion(
                    preferred_model,
                    chat_messages,
                    **kwargs
                )

        # Otherwise classify the task and consult the routing rules.
        task_type = self.classify_task(task)
        provider_name, model_name = self.get_route(task_type, preferred_provider)

        logger.info(f"路由任务: {task_type.value} -> {provider_name}/{model_name}")

        provider = self.providers[provider_name]
        return await provider.chat_completion(
            model_name,
            chat_messages,
            **kwargs
        )

    def get_available_providers(self) -> List[str]:
        """Names of all registered providers."""
        return list(self.providers.keys())

    def get_provider_models(self, provider_name: str) -> List[str]:
        """Models exposed by *provider_name*, or [] if it is not registered."""
        if provider_name in self.providers:
            return self.providers[provider_name].get_available_models()
        return []
|
||||
|
||||
|
||||
# Module-level singleton accessors
_llm_service: Optional[ModelRouter] = None


def get_llm_service(config: LLMConfig = None) -> ModelRouter:
    """Return the process-wide ModelRouter, creating it lazily on first use."""
    global _llm_service
    if _llm_service is not None:
        return _llm_service
    _llm_service = ModelRouter(config or LLMConfig.from_env())
    return _llm_service


def reset_llm_service():
    """Drop the cached singleton (primarily useful in tests)."""
    global _llm_service
    _llm_service = None
|
||||
Reference in New Issue
Block a user