""" Ollama适配器 支持本地Ollama服务 """ import json from datetime import datetime from typing import List, Dict, Any, Optional, AsyncGenerator from loguru import logger from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse from utils.proxy_handler import get_http_client class OllamaAdapter(BaseAdapter): """ Ollama API适配器 用于连接本地Ollama服务 """ DEFAULT_BASE_URL = "http://localhost:11434" def __init__( self, api_key: str = "", # Ollama通常不需要API密钥 base_url: str = "", model: str = "llama2", use_proxy: bool = False, # 本地服务通常不需要代理 proxy_config: Optional[Dict[str, Any]] = None, timeout: int = 120, # 本地模型可能需要更长时间 **kwargs ): super().__init__( api_key=api_key, base_url=base_url or self.DEFAULT_BASE_URL, model=model, use_proxy=use_proxy, proxy_config=proxy_config, timeout=timeout, **kwargs ) async def chat( self, messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2000, **kwargs ) -> AdapterResponse: """发送聊天请求""" start_time = datetime.utcnow() try: async with get_http_client( use_proxy=self.use_proxy, proxy_config=self.proxy_config, timeout=self.timeout ) as client: payload = { "model": self.model, "messages": self._build_messages(messages), "options": { "temperature": temperature, "num_predict": max_tokens, }, "stream": False } response = await client.post( f"{self.base_url}/api/chat", json=payload ) if response.status_code != 200: error_text = response.text logger.error(f"Ollama API错误: {response.status_code} - {error_text}") return AdapterResponse( success=False, error=f"API错误: {response.status_code} - {error_text}", latency_ms=self._calculate_latency(start_time) ) data = response.json() message = data.get("message", {}) return AdapterResponse( success=True, content=message.get("content", ""), model=data.get("model", self.model), finish_reason=data.get("done_reason", "stop"), prompt_tokens=data.get("prompt_eval_count", 0), completion_tokens=data.get("eval_count", 0), total_tokens=( data.get("prompt_eval_count", 0) + data.get("eval_count", 0) ), latency_ms=self._calculate_latency(start_time) ) except Exception as e: logger.error(f"Ollama请求异常: {e}") return AdapterResponse( success=False, error=str(e), latency_ms=self._calculate_latency(start_time) ) async def chat_stream( self, messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2000, **kwargs ) -> AsyncGenerator[str, None]: """发送流式聊天请求""" try: async with get_http_client( use_proxy=self.use_proxy, proxy_config=self.proxy_config, timeout=self.timeout ) as client: payload = { "model": self.model, "messages": self._build_messages(messages), "options": { "temperature": temperature, "num_predict": max_tokens, }, "stream": True } async with client.stream( "POST", f"{self.base_url}/api/chat", json=payload ) as response: async for line in response.aiter_lines(): if line: try: data = json.loads(line) message = data.get("message", {}) content = message.get("content", "") if content: yield content # 检查是否完成 if data.get("done", False): break except json.JSONDecodeError: continue except Exception as e: logger.error(f"Ollama流式请求异常: {e}") yield f"[错误: {str(e)}]" async def test_connection(self) -> Dict[str, Any]: """测试API连接""" start_time = datetime.utcnow() try: # 首先检查服务是否在运行 async with get_http_client( use_proxy=self.use_proxy, proxy_config=self.proxy_config, timeout=10 ) as client: # 检查模型是否存在 response = await client.get(f"{self.base_url}/api/tags") if response.status_code != 200: return { "success": False, "message": "Ollama服务未运行或不可访问", "latency_ms": self._calculate_latency(start_time) } data = response.json() models = 
[m.get("name", "").split(":")[0] for m in data.get("models", [])] model_name = self.model.split(":")[0] if model_name not in models: return { "success": False, "message": f"模型 {self.model} 未安装,可用模型: {', '.join(models)}", "latency_ms": self._calculate_latency(start_time) } # 发送测试消息 test_messages = [ ChatMessage(role="user", content="Hello, respond with 'OK'") ] response = await self.chat( messages=test_messages, temperature=0, max_tokens=10 ) if response.success: return { "success": True, "message": "连接成功", "model": response.model, "latency_ms": response.latency_ms } else: return { "success": False, "message": response.error, "latency_ms": response.latency_ms } except Exception as e: return { "success": False, "message": str(e), "latency_ms": self._calculate_latency(start_time) } async def list_models(self) -> List[str]: """ 列出本地可用的模型 Returns: 模型名称列表 """ try: async with get_http_client( use_proxy=self.use_proxy, proxy_config=self.proxy_config, timeout=10 ) as client: response = await client.get(f"{self.base_url}/api/tags") if response.status_code == 200: data = response.json() return [m.get("name", "") for m in data.get("models", [])] except Exception as e: logger.error(f"获取Ollama模型列表失败: {e}") return []