95 lines
2.3 KiB
Python
95 lines
2.3 KiB
Python
|
|
"""
|
|||
|
|
MongoDB数据库连接模块
|
|||
|
|
使用Motor异步驱动
|
|||
|
|
"""
|
|||
|
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
|||
|
|
from beanie import init_beanie
|
|||
|
|
from loguru import logger
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from config import settings
|
|||
|
|
|
|||
|
|
# 全局数据库客户端和数据库实例
|
|||
|
|
_client: Optional[AsyncIOMotorClient] = None
|
|||
|
|
_database: Optional[AsyncIOMotorDatabase] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def connect_db() -> None:
|
|||
|
|
"""
|
|||
|
|
连接MongoDB数据库
|
|||
|
|
初始化Beanie ODM
|
|||
|
|
"""
|
|||
|
|
global _client, _database
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
_client = AsyncIOMotorClient(settings.MONGODB_URL)
|
|||
|
|
_database = _client[settings.MONGODB_DB]
|
|||
|
|
|
|||
|
|
# 导入所有文档模型用于初始化Beanie
|
|||
|
|
from models.ai_provider import AIProvider
|
|||
|
|
from models.agent import Agent
|
|||
|
|
from models.chatroom import ChatRoom
|
|||
|
|
from models.message import Message
|
|||
|
|
from models.discussion_result import DiscussionResult
|
|||
|
|
from models.agent_memory import AgentMemory
|
|||
|
|
|
|||
|
|
# 初始化Beanie
|
|||
|
|
await init_beanie(
|
|||
|
|
database=_database,
|
|||
|
|
document_models=[
|
|||
|
|
AIProvider,
|
|||
|
|
Agent,
|
|||
|
|
ChatRoom,
|
|||
|
|
Message,
|
|||
|
|
DiscussionResult,
|
|||
|
|
AgentMemory,
|
|||
|
|
]
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
logger.info(f"已连接到MongoDB数据库: {settings.MONGODB_DB}")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"数据库连接失败: {e}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def close_db() -> None:
|
|||
|
|
"""
|
|||
|
|
关闭数据库连接
|
|||
|
|
"""
|
|||
|
|
global _client
|
|||
|
|
|
|||
|
|
if _client:
|
|||
|
|
_client.close()
|
|||
|
|
logger.info("数据库连接已关闭")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_database() -> AsyncIOMotorDatabase:
|
|||
|
|
"""
|
|||
|
|
获取数据库实例
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
MongoDB数据库实例
|
|||
|
|
|
|||
|
|
Raises:
|
|||
|
|
RuntimeError: 数据库未初始化
|
|||
|
|
"""
|
|||
|
|
if _database is None:
|
|||
|
|
raise RuntimeError("数据库未初始化,请先调用connect_db()")
|
|||
|
|
return _database
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_client() -> AsyncIOMotorClient:
|
|||
|
|
"""
|
|||
|
|
获取数据库客户端
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
MongoDB客户端实例
|
|||
|
|
|
|||
|
|
Raises:
|
|||
|
|
RuntimeError: 客户端未初始化
|
|||
|
|
"""
|
|||
|
|
if _client is None:
|
|||
|
|
raise RuntimeError("数据库客户端未初始化")
|
|||
|
|
return _client
|