254 lines
8.8 KiB
Python
254 lines
8.8 KiB
Python
|
|
"""
|
||
|
|
LLM Studio适配器
|
||
|
|
支持本地LLM Studio服务
|
||
|
|
"""
|
||
|
|
import json
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||
|
|
from utils.proxy_handler import get_http_client
|
||
|
|
|
||
|
|
|
||
|
|
class LLMStudioAdapter(BaseAdapter):
    """
    LLM Studio API adapter.

    Talks to a local service compatible with the OpenAI API format: chat
    requests are POSTed to ``{base_url}/chat/completions`` and the model list
    is read from ``{base_url}/models``.
    """

    DEFAULT_BASE_URL = "http://localhost:1234/v1"

    # Timeout for the lightweight /models requests (same units as ``timeout``,
    # presumably seconds — confirm against get_http_client). Chat requests use
    # self.timeout instead.
    _MODELS_TIMEOUT = 10

    def __init__(
        self,
        api_key: str = "lm-studio",  # LLM Studio uses a fixed placeholder key
        base_url: str = "",
        model: str = "local-model",
        use_proxy: bool = False,  # a local service does not need a proxy
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 120,  # local models may need longer to respond
        **kwargs
    ):
        """
        Args:
            api_key: Sent as ``Authorization: Bearer <api_key>``.
            base_url: API root; DEFAULT_BASE_URL is used when empty.
            model: Value placed in the ``model`` field of chat payloads.
            use_proxy: Forwarded to ``get_http_client``.
            proxy_config: Forwarded to ``get_http_client``.
            timeout: Timeout for chat requests, forwarded to ``get_http_client``.
            **kwargs: Passed through to ``BaseAdapter.__init__``.
        """
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _client(self, timeout: Optional[float] = None):
        """Return an HTTP-client context manager using this adapter's proxy settings.

        ``timeout`` defaults to ``self.timeout`` when not given.
        """
        return get_http_client(
            use_proxy=self.use_proxy,
            proxy_config=self.proxy_config,
            timeout=self.timeout if timeout is None else timeout,
        )

    def _headers(self, json_body: bool = True) -> Dict[str, str]:
        """Build request headers; add a JSON Content-Type only when a body is sent."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if json_body:
            headers["Content-Type"] = "application/json"
        return headers

    def _url(self, path: str) -> str:
        """Join ``path`` onto ``base_url`` without producing a double slash."""
        return f"{self.base_url.rstrip('/')}/{path}"

    def _build_payload(
        self,
        messages: List[ChatMessage],
        temperature: float,
        max_tokens: int,
        extra: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build an OpenAI-style chat-completions payload.

        ``extra`` is spread last, so caller kwargs can override any default
        field (including ``model`` and ``stream``).
        """
        return {
            "model": self.model,
            "messages": self._build_messages(messages),
            "temperature": temperature,
            "max_tokens": max_tokens,
            **extra,
        }

    @staticmethod
    def _first_choice(data: Any) -> Dict[str, Any]:
        """Return ``data["choices"][0]``, or ``{}`` if it is missing.

        Tolerates a non-dict body and an absent, null, or empty ``choices``
        list instead of raising.
        """
        if not isinstance(data, dict):
            return {}
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return {}
        first = choices[0]
        return first if isinstance(first, dict) else {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a non-streaming chat request.

        Args:
            messages: Conversation, converted via ``self._build_messages``.
            temperature: Sampling temperature placed in the payload.
            max_tokens: Generation limit placed in the payload.
            **kwargs: Extra payload fields; they override the fields above.

        Returns:
            On HTTP 200, an AdapterResponse with ``success=True`` and the first
            choice's content and token usage. Otherwise ``success=False`` with
            ``error`` set. Exceptions are reported through the response, not
            raised.
        """
        # NOTE(review): naive UTC timestamp kept because it is handed to
        # BaseAdapter._calculate_latency — confirm before switching to aware datetimes.
        start_time = datetime.utcnow()

        try:
            payload = self._build_payload(messages, temperature, max_tokens, kwargs)

            async with self._client() as client:
                response = await client.post(
                    self._url("chat/completions"),
                    headers=self._headers(),
                    json=payload,
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"LLM Studio API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time),
                    )

                data = response.json()
                choice = self._first_choice(data)
                # `or` (rather than a .get default) also covers explicit nulls.
                message = choice.get("message") or {}
                usage = data.get("usage") or {}

                return AdapterResponse(
                    success=True,
                    content=message.get("content") or "",
                    model=data.get("model") or self.model,
                    finish_reason=choice.get("finish_reason") or "",
                    prompt_tokens=usage.get("prompt_tokens") or 0,
                    completion_tokens=usage.get("completion_tokens") or 0,
                    total_tokens=usage.get("total_tokens") or 0,
                    latency_ms=self._calculate_latency(start_time),
                )

        except Exception as e:
            logger.error(f"LLM Studio请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time),
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request and yield content fragments.

        Parses ``data:`` lines of the event stream, stopping at ``[DONE]``.
        Lines that are not valid JSON are skipped. On an HTTP error or an
        exception, a single ``"[错误: ...]"`` string is yielded instead of
        raising.
        """
        try:
            # "stream" is set before kwargs so a caller can still override it.
            payload = self._build_payload(
                messages, temperature, max_tokens, {"stream": True, **kwargs}
            )

            async with self._client() as client:
                async with client.stream(
                    "POST",
                    self._url("chat/completions"),
                    headers=self._headers(),
                    json=payload,
                ) as response:
                    if response.status_code != 200:
                        # An error body is not an event stream; read and report
                        # it instead of silently yielding nothing.
                        await response.aread()
                        error_text = response.text
                        logger.error(f"LLM Studio API错误: {response.status_code} - {error_text}")
                        yield f"[错误: API错误: {response.status_code} - {error_text}]"
                        return

                    async for line in response.aiter_lines():
                        # Skip blank keep-alive lines and any non-data fields.
                        if not line.startswith("data:"):
                            continue
                        data_str = line[len("data:"):].strip()
                        if data_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data_str)
                        except json.JSONDecodeError:
                            continue

                        # An empty "choices" list must not abort the stream.
                        delta = self._first_choice(chunk).get("delta") or {}
                        content = delta.get("content") if isinstance(delta, dict) else None
                        if content:
                            yield content

        except Exception as e:
            logger.error(f"LLM Studio流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the server is reachable, has a model loaded, and can chat.

        Returns:
            A dict with ``success``, ``message`` and ``latency_ms``; on success
            it also contains ``model``. Exceptions become a failure dict.
        """
        start_time = datetime.utcnow()

        try:
            # Step 1: query the model list to check the service is running.
            async with self._client(timeout=self._MODELS_TIMEOUT) as client:
                models_response = await client.get(
                    self._url("models"),
                    headers=self._headers(json_body=False),
                )

                if models_response.status_code != 200:
                    return {
                        "success": False,
                        "message": "LLM Studio服务未运行或不可访问",
                        "latency_ms": self._calculate_latency(start_time),
                    }

                data = models_response.json()
                models = [m.get("id", "") for m in data.get("data", [])]

            if not models:
                return {
                    "success": False,
                    "message": "LLM Studio中没有加载的模型",
                    "latency_ms": self._calculate_latency(start_time),
                }

            # Step 2: a tiny deterministic chat round-trip. It runs after the
            # probe client is closed, because chat() opens its own client.
            chat_response = await self.chat(
                messages=[ChatMessage(role="user", content="Hello, respond with 'OK'")],
                temperature=0,
                max_tokens=10,
            )

            if chat_response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": chat_response.model,
                    "latency_ms": chat_response.latency_ms,
                }
            return {
                "success": False,
                "message": chat_response.error,
                "latency_ms": chat_response.latency_ms,
            }

        except Exception as e:
            logger.error(f"LLM Studio连接测试异常: {e}")
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time),
            }

    async def list_models(self) -> List[Dict[str, Any]]:
        """
        List the models reported by LLM Studio's ``/models`` endpoint.

        Returns:
            The ``data`` array from the response. An empty list is returned on
            a non-200 status or any exception; both cases are logged.
        """
        try:
            async with self._client(timeout=self._MODELS_TIMEOUT) as client:
                response = await client.get(
                    self._url("models"),
                    headers=self._headers(json_body=False),
                )

                if response.status_code == 200:
                    return response.json().get("data", [])

                logger.warning(f"获取LLM Studio模型列表失败: HTTP {response.status_code}")

        except Exception as e:
            logger.error(f"获取LLM Studio模型列表失败: {e}")

        return []
|