# AIChatRoom/backend/adapters/llmstudio_adapter.py
"""
LLM Studio适配器
支持本地LLM Studio服务
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class LLMStudioAdapter(BaseAdapter):
    """
    LLM Studio API adapter.

    Talks to a local service that is compatible with the OpenAI API format.
    """

    DEFAULT_BASE_URL = "http://localhost:1234/v1"

    def __init__(
        self,
        api_key: str = "lm-studio",  # LLM Studio accepts a fixed placeholder key
        base_url: str = "",
        model: str = "local-model",
        use_proxy: bool = False,  # a local service does not need a proxy
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 120,  # local models may need more time
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs
                }
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )
                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"LLM Studio API error: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API error: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )
                # Parse the OpenAI-style response body: first choice plus token usage.
                data = response.json()
                choice = data.get("choices", [{}])[0]
                message = choice.get("message", {})
                usage = data.get("usage", {})
                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=choice.get("finish_reason", ""),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time)
                )
        except Exception as e:
            logger.error(f"LLM Studio request failed: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )
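
    # For reference: a streaming /chat/completions response arrives as
    # Server-Sent Events, one JSON chunk per "data:" line, terminated by a
    # "[DONE]" sentinel. This is the shape that chat_stream() below parses
    # (the content values here are illustrative):
    #
    #   data: {"choices": [{"delta": {"content": "Hel"}}]}
    #   data: {"choices": [{"delta": {"content": "lo"}}]}
    #   data: [DONE]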

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request, yielding content deltas as they arrive."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                    **kwargs
                }
                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    # Surface HTTP errors instead of silently yielding nothing;
                    # the raised exception is caught and reported below.
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                # Skip malformed or partial SSE chunks.
                                continue
        except Exception as e:
            logger.error(f"LLM Studio streaming request failed: {e}")
            yield f"[Error: {e}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()
        try:
            # First check whether the service is running at all.
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                # Fetch the list of loaded models.
                response = await client.get(
                    f"{self.base_url}/models",
                    headers={"Authorization": f"Bearer {self.api_key}"}
                )
                if response.status_code != 200:
                    return {
                        "success": False,
                        "message": "LLM Studio service is not running or not reachable",
                        "latency_ms": self._calculate_latency(start_time)
                    }
                data = response.json()
                models = [m.get("id", "") for m in data.get("data", [])]
                if not models:
                    return {
                        "success": False,
                        "message": "No model is loaded in LLM Studio",
                        "latency_ms": self._calculate_latency(start_time)
                    }
                # Then send a minimal test message.
                test_messages = [
                    ChatMessage(role="user", content="Hello, respond with 'OK'")
                ]
                response = await self.chat(
                    messages=test_messages,
                    temperature=0,
                    max_tokens=10
                )
                if response.success:
                    return {
                        "success": True,
                        "message": "Connection successful",
                        "model": response.model,
                        "latency_ms": response.latency_ms
                    }
                else:
                    return {
                        "success": False,
                        "message": response.error,
                        "latency_ms": response.latency_ms
                    }
        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }

    async def list_models(self) -> List[Dict[str, Any]]:
        """
        List the models loaded in LLM Studio.

        Returns:
            A list of model info dicts.
        """
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                response = await client.get(
                    f"{self.base_url}/models",
                    headers={"Authorization": f"Bearer {self.api_key}"}
                )
                if response.status_code == 200:
                    data = response.json()
                    return data.get("data", [])
        except Exception as e:
            logger.error(f"Failed to fetch LLM Studio model list: {e}")
        return []
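

# A minimal smoke-test sketch, assuming an LLM Studio server is reachable at
# the default base URL with a model loaded. Because this module uses a
# relative import, run it as a module from the backend root, e.g.:
#
#   python -m adapters.llmstudio_adapter
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        adapter = LLMStudioAdapter()  # defaults to http://localhost:1234/v1

        # Check the server and the loaded model before chatting.
        status = await adapter.test_connection()
        print("connection:", status)
        if not status.get("success"):
            return

        # One-shot completion.
        result = await adapter.chat(
            messages=[ChatMessage(role="user", content="Say hello in one sentence.")],
            temperature=0.7,
            max_tokens=100,
        )
        print("chat:", result.content if result.success else result.error)

        # Streaming completion: deltas are printed as they arrive.
        async for chunk in adapter.chat_stream(
            messages=[ChatMessage(role="user", content="Count from 1 to 5.")],
            max_tokens=50,
        ):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())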