"""API 代理路由 -- OpenAI / Anthropic 兼容端点""" from __future__ import annotations import json from typing import Any from fastapi import APIRouter, Header, HTTPException, Request from fastapi.responses import StreamingResponse from app import database as db from app.config import settings from app.providers import ProviderRegistry router = APIRouter() def _verify_key(authorization: str | None): expected = settings.server.proxy_api_key if not expected: return if not authorization: raise HTTPException(401, "Missing Authorization header") token = authorization.removeprefix("Bearer ").strip() if token != expected: raise HTTPException(403, "Invalid API key") async def _resolve_plan(model: str, plan_id_header: str | None) -> tuple[dict, str]: """解析目标 Plan: 优先 X-Plan-Id header, 否则按 model 路由表查找""" if plan_id_header: plan = await db.get_plan(plan_id_header) if not plan: raise HTTPException(404, f"Plan {plan_id_header} not found") return plan, model resolved_plan_id = await db.resolve_model(model) if not resolved_plan_id: raise HTTPException(404, f"No plan found for model '{model}'") plan = await db.get_plan(resolved_plan_id) if not plan: raise HTTPException(500, "Resolved plan missing from DB") return plan, model async def _stream_and_count(provider, messages, model, plan, stream, **kwargs): """流式转发并统计 token 消耗""" total_tokens = 0 async for chunk_data in provider.chat(messages, model, plan, stream=stream, **kwargs): yield chunk_data # 尝试从 chunk 中提取 usage if not stream and chunk_data: try: resp_obj = json.loads(chunk_data) usage = resp_obj.get("usage", {}) total_tokens = usage.get("total_tokens", 0) except (json.JSONDecodeError, TypeError): pass # 流式模式下无法精确统计 token,按请求次数 +1 计费 await db.increment_quota_used(plan["id"], token_count=total_tokens) # ── OpenAI 兼容: /v1/chat/completions ───────────────── @router.post("/v1/chat/completions") async def openai_chat_completions( request: Request, authorization: str | None = Header(None), x_plan_id: str | None = Header(None, alias="X-Plan-Id"), ): _verify_key(authorization) body = await request.json() model = body.get("model", "") messages = body.get("messages", []) stream = body.get("stream", False) if not model or not messages: raise HTTPException(400, "model and messages are required") plan, model = await _resolve_plan(model, x_plan_id) # 检查额度 if not await db.check_plan_available(plan["id"]): raise HTTPException(429, f"Plan '{plan['name']}' quota exhausted") provider = ProviderRegistry.get(plan["provider_name"]) if not provider: raise HTTPException(500, f"Provider '{plan['provider_name']}' not registered") extra_kwargs = {k: v for k, v in body.items() if k not in ("model", "messages", "stream")} if stream: return StreamingResponse( _stream_and_count(provider, messages, model, plan, True, **extra_kwargs), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Plan-Id": plan["id"]}, ) else: chunks = [] async for c in _stream_and_count(provider, messages, model, plan, False, **extra_kwargs): chunks.append(c) result = json.loads(chunks[0]) if chunks else {} return result # ── Anthropic 兼容: /v1/messages ────────────────────── @router.post("/v1/messages") async def anthropic_messages( request: Request, authorization: str | None = Header(None), x_plan_id: str | None = Header(None, alias="X-Plan-Id"), x_api_key: str | None = Header(None, alias="x-api-key"), ): auth = authorization or (f"Bearer {x_api_key}" if x_api_key else None) _verify_key(auth) body = await request.json() model = body.get("model", "") messages = body.get("messages", []) stream = body.get("stream", False) system_msg = body.get("system", "") if not model or not messages: raise HTTPException(400, "model and messages are required") # Anthropic 格式 -> OpenAI 格式 messages oai_messages = [] if system_msg: oai_messages.append({"role": "system", "content": system_msg}) for m in messages: content = m.get("content", "") if isinstance(content, list): # Anthropic 多模态 content blocks -> 取文本 text_parts = [c.get("text", "") for c in content if c.get("type") == "text"] content = "\n".join(text_parts) oai_messages.append({"role": m.get("role", "user"), "content": content}) plan, model = await _resolve_plan(model, x_plan_id) if not await db.check_plan_available(plan["id"]): raise HTTPException(429, f"Plan '{plan['name']}' quota exhausted") provider = ProviderRegistry.get(plan["provider_name"]) if not provider: raise HTTPException(500, f"Provider '{plan['provider_name']}' not registered") extra_kwargs = {k: v for k, v in body.items() if k not in ("model", "messages", "stream", "system")} if stream: async def anthropic_stream(): """将 OpenAI SSE 格式转换为 Anthropic SSE 格式""" yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': 'msg_proxy', 'type': 'message', 'role': 'assistant', 'model': model, 'content': []}})}\n\n" yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n" async for chunk_data in _stream_and_count( provider, oai_messages, model, plan, True, **extra_kwargs ): if chunk_data.startswith("data: [DONE]"): break if chunk_data.startswith("data: "): try: oai_chunk = json.loads(chunk_data[6:].strip()) delta = oai_chunk.get("choices", [{}])[0].get("delta", {}) text = delta.get("content", "") if text: yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': text}})}\n\n" except (json.JSONDecodeError, IndexError): pass yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n" yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n" return StreamingResponse( anthropic_stream(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"}, ) else: chunks = [] async for c in _stream_and_count( provider, oai_messages, model, plan, False, **extra_kwargs ): chunks.append(c) oai_resp = json.loads(chunks[0]) if chunks else {} # OpenAI 响应 -> Anthropic 响应 content_text = "" choices = oai_resp.get("choices", []) if choices: content_text = choices[0].get("message", {}).get("content", "") usage = oai_resp.get("usage", {}) return { "id": "msg_proxy", "type": "message", "role": "assistant", "model": model, "content": [{"type": "text", "text": content_text}], "stop_reason": "end_turn", "usage": { "input_tokens": usage.get("prompt_tokens", 0), "output_tokens": usage.get("completion_tokens", 0), }, }