- 将 pytest 配置和测试依赖添加到 requirements.txt
- 创建测试包结构和 fixtures (conftest.py)
- 添加数据库模块的 CRUD 操作测试 (test_database.py)
- 添加 Provider 插件系统测试 (test_providers.py)
- 添加调度器模块测试 (test_scheduler.py)
- 添加 API 路由测试 (test_api.py)
- 添加回归测试,覆盖边界条件和错误处理 (test_regressions.py)
- 添加健康检查端点,用于容器监控
- 修复调度器中的日历计算逻辑和任务执行参数处理
- 更新数据库函数,使其返回操作结果状态
413 lines
15 KiB
Python
413 lines
15 KiB
Python
"""API 路由测试"""
|
||
|
||
import pytest
|
||
from httpx import AsyncClient, ASGITransport
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
from app.main import app
|
||
|
||
|
||
class TestHealthCheck:
    """Tests for the service health-check endpoint."""

    @pytest.mark.asyncio
    async def test_health_endpoint(self):
        """GET /health answers 200 with status "healthy" and a "service" field."""
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as http:
            resp = await http.get("/health")

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["status"] == "healthy"
        assert "service" in payload
|
||
|
||
|
||
class TestPlansAPI:
    """Tests for the Plans CRUD API under /api/plans."""

    @pytest.mark.asyncio
    async def test_list_plans_empty(self, temp_db, temp_storage):
        """Listing plans on an empty database returns an empty list."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/plans")

        assert response.status_code == 200
        assert response.json() == []

    @pytest.mark.asyncio
    async def test_create_plan(self, temp_db, temp_storage):
        """POST /api/plans creates a plan and echoes back its id and name."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/plans",
                json={
                    "name": "Test Plan",
                    "provider_name": "openai",
                    "api_key": "sk-test",
                    "api_base": "https://api.openai.com/v1",
                    "plan_type": "coding",
                    "supported_models": ["gpt-4"],
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["name"] == "Test Plan"

    @pytest.mark.asyncio
    async def test_create_plan_invalid(self, temp_db, temp_storage):
        """POST /api/plans with an empty body is rejected."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/api/plans", json={})

        # Expect a 422 validation error.
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_get_plan(self, temp_db, temp_storage):
        """GET /api/plans/{id} returns the plan together with its quota_rules."""
        # Create the plan directly in the database first.
        from app import database as db

        plan = await db.create_plan(
            name="Test Plan",
            provider_name="openai",
            api_key="sk-test",
        )

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get(f"/api/plans/{plan['id']}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == plan["id"]
        assert data["name"] == "Test Plan"
        assert "quota_rules" in data

    @pytest.mark.asyncio
    async def test_get_plan_not_found(self, temp_db, temp_storage):
        """GET /api/plans/{id} for an unknown id returns 404."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/plans/nonexistent")

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_update_plan(self, temp_db, temp_storage):
        """PATCH /api/plans/{id} reports success and persists the new name."""
        from app import database as db

        plan = await db.create_plan(
            name="Old Name",
            provider_name="openai",
            api_key="sk-test",
        )

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.patch(
                f"/api/plans/{plan['id']}",
                json={"name": "New Name"},
            )
            assert response.status_code == 200
            assert response.json()["ok"] is True

            # An {"ok": true} reply alone does not prove the change was stored;
            # read the plan back through the API.
            fetched = await client.get(f"/api/plans/{plan['id']}")

        assert fetched.status_code == 200
        assert fetched.json()["name"] == "New Name"

    @pytest.mark.asyncio
    async def test_delete_plan(self, temp_db, temp_storage):
        """DELETE /api/plans/{id} reports success and the plan is no longer found."""
        from app import database as db

        plan = await db.create_plan(
            name="To Delete",
            provider_name="openai",
            api_key="sk-test",
        )

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.delete(f"/api/plans/{plan['id']}")
            assert response.status_code == 200
            assert response.json()["ok"] is True

            # After deletion the plan must behave like an unknown id (404).
            gone = await client.get(f"/api/plans/{plan['id']}")

        assert gone.status_code == 404
|
||
|
||
|
||
class TestQuotaRulesAPI:
    """Tests for the QuotaRule API nested under /api/plans."""

    @pytest.mark.asyncio
    async def test_create_quota_rule(self, temp_db, temp_storage):
        """POST /api/plans/{id}/rules creates a rule and echoes its id and name."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                f"/api/plans/{plan['id']}/rules",
                json={
                    "rule_name": "Daily Limit",
                    "quota_total": 100,
                    "quota_unit": "requests",
                    "refresh_type": "calendar_cycle",
                    "calendar_unit": "daily",
                    "calendar_anchor": {"hour": 0},
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["rule_name"] == "Daily Limit"

    @pytest.mark.asyncio
    async def test_list_quota_rules(self, temp_db, temp_storage):
        """GET /api/plans/{id}/rules lists every rule of the plan."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 1", quota_total=100)
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 2", quota_total=200)

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get(f"/api/plans/{plan['id']}/rules")

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 2

    @pytest.mark.asyncio
    async def test_update_quota_rule(self, temp_db, temp_storage):
        """PATCH /api/plans/rules/{id} reports success and persists quota_total."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.patch(
                f"/api/plans/rules/{rule['id']}",
                json={"quota_total": 200},
            )
            assert response.status_code == 200
            assert response.json()["ok"] is True

            # Read the rules back: {"ok": true} alone does not prove persistence.
            listed = await client.get(f"/api/plans/{plan['id']}/rules")

        assert listed.status_code == 200
        # NOTE(review): assumes each listed rule exposes "id" and "quota_total",
        # mirroring the fields accepted on creation.
        updated = next(r for r in listed.json() if r["id"] == rule["id"])
        assert updated["quota_total"] == 200

    @pytest.mark.asyncio
    async def test_delete_quota_rule(self, temp_db, temp_storage):
        """DELETE /api/plans/rules/{id} reports success and removes the rule."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.delete(f"/api/plans/rules/{rule['id']}")
            assert response.status_code == 200
            assert response.json()["ok"] is True

            listed = await client.get(f"/api/plans/{plan['id']}/rules")

        assert listed.status_code == 200
        # The deleted rule must no longer appear in the plan's rule list.
        assert all(r["id"] != rule["id"] for r in listed.json())
|
||
|
||
|
||
class TestModelRoutesAPI:
    """Tests for the model-route API under /api/plans/routes/models."""

    @pytest.mark.asyncio
    async def test_set_model_route(self, temp_db, temp_storage):
        """POST /api/plans/routes/models creates a route and echoes its id and model."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/plans/routes/models",
                json={"model_name": "gpt-4", "plan_id": plan["id"], "priority": 10},
            )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["model_name"] == "gpt-4"

    @pytest.mark.asyncio
    async def test_list_model_routes(self, temp_db, temp_storage):
        """GET /api/plans/routes/models lists every configured route."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        await db.set_model_route("gpt-4", plan["id"], priority=10)
        await db.set_model_route("gpt-3.5-turbo", plan["id"], priority=5)

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/plans/routes/models")

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 2

    @pytest.mark.asyncio
    async def test_delete_model_route(self, temp_db, temp_storage):
        """DELETE /api/plans/routes/models/{id} reports success and removes the route."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        route = await db.set_model_route("gpt-4", plan["id"], priority=10)

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.delete(f"/api/plans/routes/models/{route['id']}")
            assert response.status_code == 200
            assert response.json()["ok"] is True

            # Confirm the route is actually gone, not just acknowledged.
            remaining = await client.get("/api/plans/routes/models")

        assert remaining.status_code == 200
        assert all(r["id"] != route["id"] for r in remaining.json())
|
||
|
||
|
||
class TestTasksAPI:
    """Tests for the task-queue API under /api/queue."""

    @pytest.mark.asyncio
    async def test_create_task(self, temp_db, temp_storage):
        """POST /api/queue creates a task in the "pending" state."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/queue",
                json={
                    "task_type": "image",
                    "request_payload": {"prompt": "a cat"},
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["status"] == "pending"

    @pytest.mark.asyncio
    async def test_list_tasks(self, temp_db, temp_storage):
        """GET /api/queue lists all tasks and honours the ?status= filter."""
        from app import database as db

        pending = await db.create_task(task_type="image", request_payload={"prompt": "cat"})
        t = await db.create_task(task_type="voice", request_payload={"text": "hello"})
        await db.update_task(t["id"], status="running")

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/queue")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 2

            # Test filtering. Checking only the count would also pass if the
            # filter returned the wrong (running) task, so pin its identity.
            response = await client.get("/api/queue?status=pending")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 1
            assert data[0]["id"] == pending["id"]
            assert data[0]["status"] == "pending"

    @pytest.mark.asyncio
    async def test_get_task(self, temp_db, temp_storage):
        """GET /api/queue/{id} returns the requested task."""
        from app import database as db

        task = await db.create_task(task_type="image", request_payload={"prompt": "cat"})

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get(f"/api/queue/{task['id']}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == task["id"]
        assert data["task_type"] == "image"

    @pytest.mark.asyncio
    async def test_cancel_task(self, temp_db, temp_storage):
        """POST /api/queue/{id}/cancel marks the task as cancelled."""
        from app import database as db

        task = await db.create_task(task_type="image", request_payload={"prompt": "cat"})

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(f"/api/queue/{task['id']}/cancel")

        assert response.status_code == 200
        assert response.json()["ok"] is True

        # Verify the persisted status.
        updated = await db.get_task(task["id"])
        assert updated["status"] == "cancelled"
|
||
|
||
|
||
class TestProxyAPI:
    """Tests for the OpenAI- and Anthropic-compatible proxy endpoints.

    Tests that depend on ``settings.server.proxy_api_key`` set it through
    pytest's ``monkeypatch`` fixture. The original value is therefore restored
    after each test even when setup or the request raises, so the override
    cannot leak into other tests.
    """

    @pytest.mark.asyncio
    async def test_chat_completions_missing_auth(self, temp_db, temp_storage, monkeypatch):
        """A chat request without credentials is rejected when a proxy key is set.

        Previously this test asserted nothing, and its outcome depended on
        whatever proxy key happened to be configured. The key is now pinned
        explicitly.
        """
        import app.config as config_module

        monkeypatch.setattr(config_module.settings.server, "proxy_api_key", "test-proxy-key")

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/v1/chat/completions",
                json={"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]},
            )

        # With an empty proxy_api_key the request passes; with a key set it must
        # be rejected (401 is expected). 403 is also accepted because some
        # FastAPI security helpers answer 403 for a missing credential.
        # NOTE(review): narrow to 401 once the auth dependency is confirmed.
        assert response.status_code in (401, 403)

    @pytest.mark.asyncio
    async def test_chat_completions_invalid_model(self, temp_db, temp_storage, monkeypatch):
        """A chat request for a model with no configured route returns 404."""
        import app.config as config_module

        # An empty proxy_api_key bypasses authentication for this test.
        monkeypatch.setattr(config_module.settings.server, "proxy_api_key", "")

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/v1/chat/completions",
                json={"model": "nonexistent-model", "messages": [{"role": "user", "content": "hello"}]},
            )

        # No model route exists, so the proxy should answer 404.
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_anthropic_messages_format_conversion(self, temp_db, temp_storage, monkeypatch):
        """An Anthropic-style /v1/messages request is accepted and routed.

        There is no real upstream, so the final status is typically an upstream
        failure (e.g. 502). The test checks that the request body is not
        rejected as invalid. Any unhandled exception raised during format
        conversion propagates through ASGITransport and fails the test.
        """
        from app import database as db
        import app.config as config_module

        # Set through monkeypatch, not by direct assignment: previously the key
        # was changed before the try/finally, so a failure while creating the
        # plan, rule or route below leaked the override into later tests.
        monkeypatch.setattr(config_module.settings.server, "proxy_api_key", "")

        # Create a plan, a quota rule and a route for the requested model.
        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        await db.set_model_route("gpt-4", plan["id"])

        # NOTE(review): no upstream is mocked, so this may attempt a real
        # outbound call; consider patching the provider client.
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/v1/messages",
                json={
                    "model": "gpt-4",
                    "messages": [{"role": "user", "content": "hello"}],
                    "max_tokens": 100,
                },
            )

        # The well-formed Anthropic payload must not be rejected by validation.
        assert response.status_code != 422
|