413 lines
15 KiB
Python
413 lines
15 KiB
Python
|
|
"""API 路由测试"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
from httpx import AsyncClient, ASGITransport
|
|||
|
|
from unittest.mock import AsyncMock, patch
|
|||
|
|
|
|||
|
|
from app.main import app
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestHealthCheck:
    """Health-check endpoint tests."""

    @pytest.mark.asyncio
    async def test_health_endpoint(self):
        """GET /health answers 200 with a healthy status and a service field."""
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as http:
            reply = await http.get("/health")

        assert reply.status_code == 200
        payload = reply.json()
        assert payload["status"] == "healthy"
        assert "service" in payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestPlansAPI:
    """Plans API tests.

    Every test requests the ``temp_db`` and ``temp_storage`` fixtures, which
    are defined outside this file (presumably they give each test an isolated
    database and storage location -- confirm in conftest).
    """

    @pytest.mark.asyncio
    async def test_list_plans_empty(self, temp_db, temp_storage):
        """GET /api/plans returns an empty list when no plan exists."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/plans")

        assert response.status_code == 200
        assert response.json() == []

    @pytest.mark.asyncio
    async def test_create_plan(self, temp_db, temp_storage):
        """POST /api/plans with a full payload returns the created plan."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/plans",
                json={
                    "name": "Test Plan",
                    "provider_name": "openai",
                    "api_key": "sk-test",
                    "api_base": "https://api.openai.com/v1",
                    "plan_type": "coding",
                    "supported_models": ["gpt-4"],
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["name"] == "Test Plan"

    @pytest.mark.asyncio
    async def test_create_plan_invalid(self, temp_db, temp_storage):
        """POST /api/plans with an empty body is rejected."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/api/plans", json={})

        # An empty body should fail request validation with 422.
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_get_plan(self, temp_db, temp_storage):
        """GET /api/plans/{id} returns the stored plan, including quota_rules."""
        from app import database as db

        # Seed the plan directly through the database layer first.
        plan = await db.create_plan(
            name="Test Plan",
            provider_name="openai",
            api_key="sk-test",
        )

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get(f"/api/plans/{plan['id']}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == plan["id"]
        assert data["name"] == "Test Plan"
        assert "quota_rules" in data

    @pytest.mark.asyncio
    async def test_get_plan_not_found(self, temp_db, temp_storage):
        """GET /api/plans/{id} answers 404 for an unknown id."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/api/plans/nonexistent")

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_update_plan(self, temp_db, temp_storage):
        """PATCH /api/plans/{id} acknowledges the update and persists the new name."""
        from app import database as db

        plan = await db.create_plan(
            name="Old Name",
            provider_name="openai",
            api_key="sk-test",
        )

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.patch(
                f"/api/plans/{plan['id']}",
                json={"name": "New Name"},
            )
            # Read the plan back so the test fails if the PATCH was
            # acknowledged but never written.
            fetched = await client.get(f"/api/plans/{plan['id']}")

        assert response.status_code == 200
        assert response.json()["ok"] is True
        assert fetched.status_code == 200
        assert fetched.json()["name"] == "New Name"

    @pytest.mark.asyncio
    async def test_delete_plan(self, temp_db, temp_storage):
        """DELETE /api/plans/{id} acknowledges the deletion."""
        from app import database as db

        plan = await db.create_plan(
            name="To Delete",
            provider_name="openai",
            api_key="sk-test",
        )

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.delete(f"/api/plans/{plan['id']}")

        assert response.status_code == 200
        assert response.json()["ok"] is True
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestQuotaRulesAPI:
    """Quota-rule API tests."""

    @pytest.mark.asyncio
    async def test_create_quota_rule(self, temp_db, temp_storage):
        """POSTing a rule to a plan returns the new rule with an id and its name."""
        from app import database

        owner = await database.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        rule_spec = {
            "rule_name": "Daily Limit",
            "quota_total": 100,
            "quota_unit": "requests",
            "refresh_type": "calendar_cycle",
            "calendar_unit": "daily",
            "calendar_anchor": {"hour": 0},
        }

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.post(f"/api/plans/{owner['id']}/rules", json=rule_spec)

        assert reply.status_code == 200
        created = reply.json()
        assert "id" in created
        assert created["rule_name"] == "Daily Limit"

    @pytest.mark.asyncio
    async def test_list_quota_rules(self, temp_db, temp_storage):
        """Listing a plan's rules returns every rule created for it."""
        from app import database

        owner = await database.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        # Seed two rules, in the same order and with the same totals as before.
        for label, total in (("Rule 1", 100), ("Rule 2", 200)):
            await database.create_quota_rule(
                plan_id=owner["id"], rule_name=label, quota_total=total
            )

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.get(f"/api/plans/{owner['id']}/rules")

        assert reply.status_code == 200
        listed = reply.json()
        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_update_quota_rule(self, temp_db, temp_storage):
        """PATCHing a rule's quota_total is acknowledged with ok=True."""
        from app import database

        owner = await database.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        existing = await database.create_quota_rule(
            plan_id=owner["id"], rule_name="Rule", quota_total=100
        )

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.patch(
                f"/api/plans/rules/{existing['id']}",
                json={"quota_total": 200},
            )

        assert reply.status_code == 200
        assert reply.json()["ok"] is True

    @pytest.mark.asyncio
    async def test_delete_quota_rule(self, temp_db, temp_storage):
        """DELETEing a rule is acknowledged with ok=True."""
        from app import database

        owner = await database.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        doomed = await database.create_quota_rule(
            plan_id=owner["id"], rule_name="Rule", quota_total=100
        )

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.delete(f"/api/plans/rules/{doomed['id']}")

        assert reply.status_code == 200
        assert reply.json()["ok"] is True
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestModelRoutesAPI:
    """Model-route API tests."""

    @pytest.mark.asyncio
    async def test_set_model_route(self, temp_db, temp_storage):
        """POSTing a model route returns the route with an id and the model name."""
        from app import database

        target = await database.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        route_spec = {"model_name": "gpt-4", "plan_id": target["id"], "priority": 10}

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.post("/api/plans/routes/models", json=route_spec)

        assert reply.status_code == 200
        created = reply.json()
        assert "id" in created
        assert created["model_name"] == "gpt-4"

    @pytest.mark.asyncio
    async def test_list_model_routes(self, temp_db, temp_storage):
        """Listing model routes returns every route that was registered."""
        from app import database

        target = await database.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        # Register two routes on the same plan, in the original order.
        for model, rank in (("gpt-4", 10), ("gpt-3.5-turbo", 5)):
            await database.set_model_route(model, target["id"], priority=rank)

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.get("/api/plans/routes/models")

        assert reply.status_code == 200
        listed = reply.json()
        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_delete_model_route(self, temp_db, temp_storage):
        """DELETEing a model route is acknowledged with ok=True."""
        from app import database

        target = await database.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        doomed = await database.set_model_route("gpt-4", target["id"], priority=10)

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.delete(f"/api/plans/routes/models/{doomed['id']}")

        assert reply.status_code == 200
        assert reply.json()["ok"] is True
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestTasksAPI:
    """Task-queue API tests."""

    @pytest.mark.asyncio
    async def test_create_task(self, temp_db, temp_storage):
        """POST /api/queue returns the new task with an id and status 'pending'."""
        task_spec = {
            "task_type": "image",
            "request_payload": {"prompt": "a cat"},
        }

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.post("/api/queue", json=task_spec)

        assert reply.status_code == 200
        created = reply.json()
        assert "id" in created
        assert created["status"] == "pending"

    @pytest.mark.asyncio
    async def test_list_tasks(self, temp_db, temp_storage):
        """GET /api/queue lists all tasks and honours the status filter."""
        from app import database

        # One task keeps its initial status; the other is moved to "running".
        await database.create_task(task_type="image", request_payload={"prompt": "cat"})
        started = await database.create_task(
            task_type="voice", request_payload={"text": "hello"}
        )
        await database.update_task(started["id"], status="running")

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            everything = await http.get("/api/queue")
            assert everything.status_code == 200
            assert len(everything.json()) == 2

            # Filtering by status must leave only the untouched task.
            pending_only = await http.get("/api/queue?status=pending")
            assert pending_only.status_code == 200
            assert len(pending_only.json()) == 1

    @pytest.mark.asyncio
    async def test_get_task(self, temp_db, temp_storage):
        """GET /api/queue/{id} returns the stored task."""
        from app import database

        stored = await database.create_task(
            task_type="image", request_payload={"prompt": "cat"}
        )

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.get(f"/api/queue/{stored['id']}")

        assert reply.status_code == 200
        fetched = reply.json()
        assert fetched["id"] == stored["id"]
        assert fetched["task_type"] == "image"

    @pytest.mark.asyncio
    async def test_cancel_task(self, temp_db, temp_storage):
        """POST /api/queue/{id}/cancel acknowledges and marks the task cancelled."""
        from app import database

        stored = await database.create_task(
            task_type="image", request_payload={"prompt": "cat"}
        )

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            reply = await http.post(f"/api/queue/{stored['id']}/cancel")

        assert reply.status_code == 200
        assert reply.json()["ok"] is True

        # Confirm the new status through the database layer.
        refreshed = await database.get_task(stored["id"])
        assert refreshed["status"] == "cancelled"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestProxyAPI:
    """Proxy API tests (``/v1/chat/completions`` and ``/v1/messages``).

    The tests override ``settings.server.proxy_api_key`` on the shared config
    object. Each override goes through ``patch.object`` so the original value
    is restored even when the request or an assertion fails.
    """

    @pytest.mark.asyncio
    async def test_chat_completions_missing_auth(self, temp_db, temp_storage):
        """A request without credentials is rejected only when a proxy key is set."""
        import app.config as config_module

        server_settings = config_module.settings.server
        request_body = {
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "hello"}],
        }

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Empty key: authentication should be skipped, so whatever the
            # outcome is, it must not be an auth rejection.
            with patch.object(server_settings, "proxy_api_key", ""):
                open_response = await client.post(
                    "/v1/chat/completions", json=request_body
                )

            # Configured key and no credentials sent: expected to be 401.
            # NOTE(review): 401 comes from the original author's comment;
            # confirm against the proxy's auth handler.
            with patch.object(server_settings, "proxy_api_key", "test-proxy-key"):
                locked_response = await client.post(
                    "/v1/chat/completions", json=request_body
                )

        assert open_response.status_code != 401
        assert locked_response.status_code == 401

    @pytest.mark.asyncio
    async def test_chat_completions_invalid_model(self, temp_db, temp_storage):
        """A chat request for a model with no route answers 404."""
        import app.config as config_module

        transport = ASGITransport(app=app)
        # Empty proxy key so the request is not stopped by authentication.
        with patch.object(config_module.settings.server, "proxy_api_key", ""):
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                response = await client.post(
                    "/v1/chat/completions",
                    json={
                        "model": "nonexistent-model",
                        "messages": [{"role": "user", "content": "hello"}],
                    },
                )

        # No model route matches, so the proxy should answer 404.
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_anthropic_messages_format_conversion(self, temp_db, temp_storage):
        """An Anthropic-style /v1/messages request is handled without raising."""
        from app import database as db

        import app.config as config_module

        # Create the route *before* touching global settings, so a failure in
        # this setup cannot leave the proxy key overridden.
        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        await db.set_model_route("gpt-4", plan["id"])

        transport = ASGITransport(app=app)
        with patch.object(config_module.settings.server, "proxy_api_key", ""):
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                await client.post(
                    "/v1/messages",
                    json={
                        "model": "gpt-4",
                        "messages": [{"role": "user", "content": "hello"}],
                        "max_tokens": 100,
                    },
                )

        # The status code is deliberately not asserted: there is no real
        # upstream API, so the proxy is expected to answer 502 or a similar
        # error. The test only checks that the format-conversion path does
        # not raise. NOTE(review): this relies on ASGITransport re-raising
        # application exceptions by default -- confirm for the pinned httpx.
|