242 lines
7.7 KiB
Python
242 lines
7.7 KiB
Python
|
|
"""
|
|||
|
|
AI接口管理路由
|
|||
|
|
"""
|
|||
|
|
from typing import List, Optional, Dict, Any
|
|||
|
|
from fastapi import APIRouter, HTTPException, status
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from services.ai_provider_service import AIProviderService
|
|||
|
|
from utils.encryption import mask_api_key
|
|||
|
|
|
|||
|
|
|
|||
|
|
router = APIRouter()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ 请求/响应模型 ============
|
|||
|
|
|
|||
|
|
class ProxyConfigModel(BaseModel):
    """Proxy configuration model.

    Every field is optional; an empty payload yields no proxy URLs and an
    empty ``no_proxy`` list.
    """

    # Proxy URL for HTTP requests (presumably e.g. "http://127.0.0.1:7890";
    # the exact format expected by the service is not visible here).
    http_proxy: Optional[str] = None

    # Proxy URL for HTTPS requests.
    https_proxy: Optional[str] = None

    # Presumably hosts/patterns that bypass the proxy -- confirm against the
    # service that consumes this model.
    no_proxy: List[str] = []
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RateLimitModel(BaseModel):
    """Rate limit model.

    Both limits are plain integers with defaults; no range validation is
    applied by this model.
    """

    # Request budget per minute (defaults to 60).
    requests_per_minute: int = 60

    # Token budget per minute (defaults to 100000).
    tokens_per_minute: int = 100000
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProviderCreateRequest(BaseModel):
    """Request body for creating an AI provider configuration.

    ``provider_type``, ``name`` and ``model`` are required; every other
    field has a default. The nested ``Config`` only contributes an example
    payload to the generated JSON schema.
    """

    # --- required fields ---
    provider_type: str = Field(
        ...,
        description="提供商类型: minimax, zhipu, openrouter, kimi, deepseek, gemini, ollama, llmstudio",
    )
    name: str = Field(
        ...,
        description="自定义名称",
    )
    model: str = Field(
        ...,
        description="模型名称",
    )

    # --- connection settings ---
    api_key: str = Field(
        default="",
        description="API密钥",
    )
    base_url: str = Field(
        default="",
        description="API基础URL",
    )

    # --- proxy / throttling ---
    use_proxy: bool = Field(
        default=False,
        description="是否使用代理",
    )
    proxy_config: Optional[ProxyConfigModel] = None
    rate_limit: Optional[RateLimitModel] = None

    # --- miscellaneous ---
    timeout: int = Field(
        default=60,
        description="超时时间(秒)",
    )
    extra_params: Dict[str, Any] = Field(
        default_factory=dict,
        description="额外参数",
    )

    class Config:
        # Example payload surfaced in the generated schema / API docs.
        json_schema_extra = {
            "example": {
                "provider_type": "openrouter",
                "name": "OpenRouter GPT-4",
                "model": "openai/gpt-4-turbo",
                "api_key": "sk-xxx",
                "use_proxy": True,
                "proxy_config": {
                    "http_proxy": "http://127.0.0.1:7890",
                    "https_proxy": "http://127.0.0.1:7890"
                }
            }
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProviderUpdateRequest(BaseModel):
    """Request body for updating an AI provider configuration.

    Every field is optional and defaults to ``None`` so that a client can
    send only the attributes it wants to change.
    """

    # Basic identity / connection attributes.
    name: Optional[str] = None
    model: Optional[str] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None

    # Proxy and throttling attributes.
    use_proxy: Optional[bool] = None
    proxy_config: Optional[ProxyConfigModel] = None
    rate_limit: Optional[RateLimitModel] = None

    # Miscellaneous attributes.
    timeout: Optional[int] = None
    extra_params: Optional[Dict[str, Any]] = None

    # Only present on update; the create request has no such field.
    enabled: Optional[bool] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProviderResponse(BaseModel):
    """Response body describing one AI provider configuration.

    All fields are required. The API key is exposed only through
    ``api_key_masked``; there is no raw key field on this model.
    """

    # Identity.
    provider_id: str
    provider_type: str
    name: str

    # Connection details (key is masked before it reaches this model).
    api_key_masked: str
    base_url: str
    model: str

    # Proxy and throttling settings, returned as plain dictionaries.
    use_proxy: bool
    proxy_config: Dict[str, Any]
    rate_limit: Dict[str, int]

    # Miscellaneous settings.
    timeout: int
    extra_params: Dict[str, Any]
    enabled: bool

    # Timestamps serialized as strings.
    created_at: str
    updated_at: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestConfigRequest(BaseModel):
    """Request body for testing a provider configuration.

    ``provider_type`` and ``api_key`` are required; the remaining fields
    have defaults.
    """

    # Required fields.
    provider_type: str
    api_key: str

    # Optional connection details; empty strings are the defaults.
    base_url: str = ""
    model: str = ""

    # Optional proxy settings.
    use_proxy: bool = False
    proxy_config: Optional[ProxyConfigModel] = None

    # Defaults to 30, unlike the create request whose default is 60.
    timeout: int = 30
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestResponse(BaseModel):
    """Response body for a provider connectivity test."""

    # Outcome flag and accompanying message; both are required.
    success: bool
    message: str

    # Optional details; ``None`` when not supplied.
    model: Optional[str] = None
    latency_ms: Optional[float] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ 路由处理 ============
|
|||
|
|
|
|||
|
|
@router.post("", response_model=ProviderResponse, status_code=status.HTTP_201_CREATED)
async def create_provider(request: ProviderCreateRequest):
    """Create a new AI provider configuration.

    Args:
        request: Validated creation payload. Nested ``proxy_config`` and
            ``rate_limit`` models are passed to the service as plain dicts,
            or ``None`` when absent.

    Returns:
        The provider returned by ``AIProviderService.create_provider``,
        converted with ``_to_response``.

    Raises:
        HTTPException: 400 with the error text when the service call raises
            ``ValueError``; 500 with a generic message for any other failure
            of the service call or of building the response.
    """
    proxy_config = request.proxy_config.dict() if request.proxy_config else None
    rate_limit = request.rate_limit.dict() if request.rate_limit else None

    # Only the service call is mapped to 400. Response-building errors used
    # to share this handler, so a pydantic ValidationError (a ValueError
    # subclass) raised while serializing the stored provider was reported to
    # the client as a bad request.
    try:
        provider = await AIProviderService.create_provider(
            provider_type=request.provider_type,
            name=request.name,
            model=request.model,
            api_key=request.api_key,
            base_url=request.base_url,
            use_proxy=request.use_proxy,
            proxy_config=proxy_config,
            rate_limit=rate_limit,
            timeout=request.timeout,
            extra_params=request.extra_params,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"创建AI接口失败: {e}")
        raise HTTPException(status_code=500, detail="创建失败") from e

    # A failure here is a server-side fault, so it maps to 500.
    try:
        return _to_response(provider)
    except Exception as e:
        logger.error(f"创建AI接口失败: {e}")
        raise HTTPException(status_code=500, detail="创建失败") from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("", response_model=List[ProviderResponse])
async def list_providers(enabled_only: bool = False):
    """Return all AI provider configurations.

    ``enabled_only`` is forwarded unchanged to
    ``AIProviderService.get_all_providers``; each result is converted with
    ``_to_response``.
    """
    responses = []
    for item in await AIProviderService.get_all_providers(enabled_only):
        responses.append(_to_response(item))
    return responses
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{provider_id}", response_model=ProviderResponse)
async def get_provider(provider_id: str):
    """Return the AI provider configuration with the given id.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    found = await AIProviderService.get_provider(provider_id)
    if found:
        return _to_response(found)
    raise HTTPException(status_code=404, detail="AI接口不存在")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.put("/{provider_id}", response_model=ProviderResponse)
async def update_provider(provider_id: str, request: ProviderUpdateRequest):
    """Update an existing AI provider configuration.

    Only the fields explicitly present in the request body are forwarded to
    ``AIProviderService.update_provider`` as keyword arguments.

    Args:
        provider_id: Identifier taken from the URL path.
        request: Validated partial-update payload.

    Returns:
        The updated provider converted with ``_to_response``.

    Raises:
        HTTPException: 400 with the error text when the service call raises
            ``ValueError``; 404 when the service returns a falsy result.
    """
    # ``dict(exclude_unset=True)`` already serializes nested models
    # (proxy_config, rate_limit) into plain dicts, so no per-field
    # conversion is needed afterwards.
    update_data = request.dict(exclude_unset=True)

    # Keep the try body to the service call so that only its ValueError is
    # reported as a client error; failures while building the response are
    # not mislabeled as 400.
    try:
        provider = await AIProviderService.update_provider(provider_id, **update_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    if not provider:
        raise HTTPException(status_code=404, detail="AI接口不存在")
    return _to_response(provider)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider(provider_id: str):
    """Delete the AI provider configuration with the given id.

    Returns nothing on success (the route is declared as 204).

    Raises:
        HTTPException: 404 when the service reports a falsy result.
    """
    deleted = await AIProviderService.delete_provider(provider_id)
    if deleted:
        return
    raise HTTPException(status_code=404, detail="AI接口不存在")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/{provider_id}/test", response_model=TestResponse)
async def test_provider(provider_id: str):
    """Test the connection of a stored AI provider.

    The mapping returned by ``AIProviderService.test_provider`` is unpacked
    directly into ``TestResponse``.
    """
    return TestResponse(**(await AIProviderService.test_provider(provider_id)))
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/test", response_model=TestResponse)
async def test_provider_config(request: TestConfigRequest):
    """Test an AI provider configuration without saving it.

    The request fields are forwarded as keyword arguments to
    ``AIProviderService.test_provider_config`` and the returned mapping is
    unpacked into ``TestResponse``.
    """
    proxy = request.proxy_config
    service_kwargs = {
        "provider_type": request.provider_type,
        "api_key": request.api_key,
        "base_url": request.base_url,
        "model": request.model,
        "use_proxy": request.use_proxy,
        # The nested model is handed over as a plain dict, or None if absent.
        "proxy_config": proxy.dict() if proxy else None,
        "timeout": request.timeout,
    }
    outcome = await AIProviderService.test_provider_config(**service_kwargs)
    return TestResponse(**outcome)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ 辅助函数 ============
|
|||
|
|
|
|||
|
|
def _to_response(provider) -> ProviderResponse:
    """Convert a provider object into a ``ProviderResponse``.

    The API key is never copied verbatim: a truthy key is passed through
    ``mask_api_key`` and a falsy key becomes an empty string. Timestamps
    are rendered with ``isoformat()``; all other attributes are copied
    as-is.
    """
    raw_key = provider.api_key
    fields = {
        "provider_id": provider.provider_id,
        "provider_type": provider.provider_type,
        "name": provider.name,
        "api_key_masked": mask_api_key(raw_key) if raw_key else "",
        "base_url": provider.base_url,
        "model": provider.model,
        "use_proxy": provider.use_proxy,
        "proxy_config": provider.proxy_config,
        "rate_limit": provider.rate_limit,
        "timeout": provider.timeout,
        "extra_params": provider.extra_params,
        "enabled": provider.enabled,
        "created_at": provider.created_at.isoformat(),
        "updated_at": provider.updated_at.isoformat(),
    }
    return ProviderResponse(**fields)
|