AIChatRoom/backend/adapters/ollama_adapter.py

"""
Ollama适配器
支持本地Ollama服务
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client

class OllamaAdapter(BaseAdapter):
    """
    Ollama API adapter
    Connects to a local Ollama service.
    """

    DEFAULT_BASE_URL = "http://localhost:11434"

    def __init__(
        self,
        api_key: str = "",  # Ollama usually does not require an API key
        base_url: str = "",
        model: str = "llama2",
        use_proxy: bool = False,  # a local service usually does not need a proxy
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 120,  # local models may need more time to respond
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "options": {
                        "temperature": temperature,
                        "num_predict": max_tokens,
                    },
                    "stream": False
                }
                response = await client.post(
                    f"{self.base_url}/api/chat",
                    json=payload
                )
                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Ollama API error: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API error: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )
                data = response.json()
                message = data.get("message", {})
                # Ollama reports token usage as prompt_eval_count / eval_count.
                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=data.get("done_reason", "stop"),
                    prompt_tokens=data.get("prompt_eval_count", 0),
                    completion_tokens=data.get("eval_count", 0),
                    total_tokens=(
                        data.get("prompt_eval_count", 0) +
                        data.get("eval_count", 0)
                    ),
                    latency_ms=self._calculate_latency(start_time)
                )
        except Exception as e:
            logger.error(f"Ollama request failed: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "options": {
                        "temperature": temperature,
                        "num_predict": max_tokens,
                    },
                    "stream": True
                }
                async with client.stream(
                    "POST",
                    f"{self.base_url}/api/chat",
                    json=payload
                ) as response:
                    # Ollama streams one JSON object per line.
                    async for line in response.aiter_lines():
                        if line:
                            try:
                                data = json.loads(line)
                                message = data.get("message", {})
                                content = message.get("content", "")
                                if content:
                                    yield content
                                # Stop once the stream reports completion.
                                if data.get("done", False):
                                    break
                            except json.JSONDecodeError:
                                continue
        except Exception as e:
            logger.error(f"Ollama streaming request failed: {e}")
            yield f"[Error: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()
        try:
            # First check whether the service is running.
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                # Check whether the model is installed.
                response = await client.get(f"{self.base_url}/api/tags")
                if response.status_code != 200:
                    return {
                        "success": False,
                        "message": "Ollama service is not running or not reachable",
                        "latency_ms": self._calculate_latency(start_time)
                    }
                data = response.json()
                models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]
                model_name = self.model.split(":")[0]
                if model_name not in models:
                    return {
                        "success": False,
                        "message": f"Model {self.model} is not installed; available models: {', '.join(models)}",
                        "latency_ms": self._calculate_latency(start_time)
                    }
                # Send a test message.
                test_messages = [
                    ChatMessage(role="user", content="Hello, respond with 'OK'")
                ]
                response = await self.chat(
                    messages=test_messages,
                    temperature=0,
                    max_tokens=10
                )
                if response.success:
                    return {
                        "success": True,
                        "message": "Connection successful",
                        "model": response.model,
                        "latency_ms": response.latency_ms
                    }
                else:
                    return {
                        "success": False,
                        "message": response.error,
                        "latency_ms": response.latency_ms
                    }
        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }

    async def list_models(self) -> List[str]:
        """
        List locally available models.

        Returns:
            A list of model names.
        """
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                response = await client.get(f"{self.base_url}/api/tags")
                if response.status_code == 200:
                    data = response.json()
                    return [m.get("name", "") for m in data.get("models", [])]
        except Exception as e:
            logger.error(f"Failed to fetch the Ollama model list: {e}")
        return []
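

# --- Usage sketch (illustrative addition, not part of the original adapter) ---
# A minimal way to exercise the adapter locally, assuming an Ollama instance is
# listening on the default URL and the target model has already been pulled
# (e.g. `ollama pull llama2`). ChatMessage is imported above from base_adapter;
# only the role/content fields used elsewhere in this module are relied on.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        adapter = OllamaAdapter(model="llama2")

        # Verify the service is reachable and the model is installed.
        status = await adapter.test_connection()
        print("test_connection:", status)
        if not status.get("success"):
            return

        # Non-streaming request.
        reply = await adapter.chat(
            messages=[ChatMessage(role="user", content="Say hello in one sentence.")],
            temperature=0.2,
            max_tokens=64,
        )
        print("chat:", reply.content if reply.success else reply.error)

        # Streaming request: print chunks as they arrive.
        async for chunk in adapter.chat_stream(
            messages=[ChatMessage(role="user", content="Count from 1 to 5.")],
            temperature=0.2,
            max_tokens=64,
        ):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())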