93 lines
2.6 KiB
Python
93 lines
2.6 KiB
Python
"""LLM API 客户端,兼容 OpenAI API 格式"""
|
|
import json
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from openai import OpenAI, APIError
|
|
|
|
from config import settings
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AIClient:
    """Thin wrapper around an OpenAI-compatible chat completion API.

    Any constructor argument left unset falls back to the matching
    ``config.settings.OPENAI_*`` value. Retries and timeouts are not
    implemented here; they are handed to the ``openai.OpenAI`` client via its
    ``max_retries`` / ``timeout`` arguments.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
    ):
        """Store connection settings; the SDK client is built lazily.

        Args:
            api_key: API key; ``None`` or ``""`` falls back to
                ``settings.OPENAI_API_KEY``.
            base_url: API base URL; ``None`` or ``""`` falls back to
                ``settings.OPENAI_BASE_URL``.
            model: Model name; ``None`` or ``""`` falls back to
                ``settings.OPENAI_MODEL``.
            timeout: Request timeout handed to the SDK client; only ``None``
                falls back to ``settings.OPENAI_TIMEOUT``.
            max_retries: Retry count handed to the SDK client; only ``None``
                falls back to ``settings.OPENAI_MAX_RETRIES``, so an explicit
                ``0`` is kept.
        """
        # Strings keep the ``or`` fallback so an explicitly passed empty
        # string still picks up the configured default.
        self.api_key = api_key or settings.OPENAI_API_KEY
        self.base_url = base_url or settings.OPENAI_BASE_URL
        self.model = model or settings.OPENAI_MODEL
        # Numbers need an explicit ``None`` check: ``0`` is a meaningful value
        # (e.g. ``max_retries=0``) and must not be swapped for the default.
        self.timeout = settings.OPENAI_TIMEOUT if timeout is None else timeout
        self.max_retries = (
            settings.OPENAI_MAX_RETRIES if max_retries is None else max_retries
        )

        self._client: Optional[OpenAI] = None

    @property
    def client(self) -> OpenAI:
        """The underlying ``openai.OpenAI`` client, created on first access."""
        if self._client is None:
            self._client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
                timeout=self.timeout,
                max_retries=self.max_retries,
            )
        return self._client

    def chat_completion(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.3,
        json_mode: bool = False,
    ) -> str:
        """Send a system + user prompt and return the reply text.

        Args:
            system_prompt: Content of the ``system`` message.
            user_prompt: Content of the ``user`` message.
            temperature: Sampling temperature forwarded to the API.
            json_mode: If true, request ``response_format={"type": "json_object"}``.

        Returns:
            The first choice's message content with surrounding whitespace
            stripped, or ``""`` when the response has no choices or no content.

        Raises:
            openai.APIError: Logged and re-raised when the API call fails.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
        }
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        try:
            resp = self.client.chat.completions.create(**kwargs)
        except APIError as exc:
            logger.error("LLM API 调用失败: %s", exc)
            raise

        # Guard against an empty ``choices`` list so callers get ``""``
        # instead of an unlogged IndexError.
        if not resp.choices:
            logger.warning("LLM 返回结果中没有 choices")
            return ""
        content = resp.choices[0].message.content or ""
        return content.strip()

    def chat_completion_json(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.3,
    ) -> dict:
        """Call the LLM in JSON mode and parse the reply into a dict.

        Args:
            system_prompt: Content of the ``system`` message.
            user_prompt: Content of the ``user`` message.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The parsed JSON object.

        Raises:
            json.JSONDecodeError: Logged and re-raised when the reply is not a
                JSON object (including valid JSON of a non-object type).
            openai.APIError: Propagated from :meth:`chat_completion`.
        """
        content = self.chat_completion(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=temperature,
            json_mode=True,
        )
        try:
            return self._parse_json_object(content)
        except json.JSONDecodeError as exc:
            logger.error("LLM 返回不是合法 JSON: %s - content=%s", exc, content[:500])
            raise

    @staticmethod
    def _parse_json_object(content: str) -> dict:
        """Parse ``content`` as a JSON object, tolerating a Markdown code fence.

        A reply wrapped in a fenced code block (optionally tagged, e.g.
        ``json``) is unwrapped before parsing, since backends that ignore
        ``response_format`` may fence their JSON. JSON that is not an object
        (list, string, number, ...) is rejected with ``json.JSONDecodeError``,
        so callers only need to handle a single error type.
        """
        text = content.strip()
        if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
            inner = text[3:-3]
            # Drop an optional language tag on the opening fence line, but keep
            # the line if it already starts the JSON payload itself.
            tag, newline, body = inner.partition("\n")
            if newline and not tag.lstrip().startswith(("{", "[")):
                inner = body
            text = inner.strip()
        data = json.loads(text)
        if not isinstance(data, dict):
            raise json.JSONDecodeError(
                f"Expected a JSON object, got {type(data).__name__}", text, 0
            )
        return data
|
|
|
|
|
|
# Shared default instance; every connection setting comes from
# ``config.settings`` because no constructor arguments are passed.
ai_client: AIClient = AIClient()
|