79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
"""去重模块测试"""
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from app.deduplicator import _title_similarity, _find_duplicate_clusters, deduplicate_articles
|
|
from models import EnrichedArticle
|
|
|
|
|
|
def test_title_similarity_identical():
    """Two identical titles must score above 0.95 similarity."""
    title = "OpenAI 发布 GPT-5"
    score = _title_similarity(title, title)
    assert score > 0.95
|
|
|
|
|
|
def test_title_similarity_different():
    """Titles about unrelated topics must score below 0.5 similarity."""
    left = "OpenAI 发布 GPT-5"
    right = "苹果发布新款 iPhone"
    score = _title_similarity(left, right)
    assert score < 0.5
|
|
|
|
|
|
def test_find_duplicate_clusters(db):
    """The two near-identical GPT-5 articles must form the only cluster.

    The assertions expect exactly one cluster, equal to the index set {0, 1}.
    The iPhone article (index 2) therefore belongs to no cluster.
    """
    # (rk_article_id, title, content); list position is the index the
    # clusters are expected to reference.
    samples = [
        (1, "OpenAI 发布 GPT-5,性能大幅提升", "OpenAI 今天发布了 GPT-5,性能大幅提升。"),
        (2, "OpenAI 发布 GPT-5 性能大幅提升", "OpenAI 发布了 GPT-5,性能提升明显。"),
        (3, "苹果发布新款 iPhone", "苹果公司发布了新款 iPhone。"),
    ]
    articles = [
        EnrichedArticle(rk_article_id=rk_id, title=title, content=content)
        for rk_id, title, content in samples
    ]

    found = _find_duplicate_clusters(articles, title_threshold=0.85, content_threshold=0.80)

    assert len(found) == 1
    assert {0, 1} in found
|
|
|
|
|
|
def test_deduplicate_articles(db):
    """Persist three articles for today, deduplicate them, and check the results.

    Checks the stats returned by ``deduplicate_articles`` and the
    ``is_representative`` / ``duplicate_group_id`` attributes re-read from
    the session afterwards.
    """
    today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    # strptime yields a naive datetime at midnight of the current UTC date.
    day_start = datetime.strptime(today, "%Y-%m-%d")

    def make(rk_id, title, content, offset_minutes):
        # Build one article fetched `offset_minutes` after day_start.
        return EnrichedArticle(
            rk_article_id=rk_id,
            title=title,
            content=content,
            fetched_at=day_start + timedelta(minutes=offset_minutes),
        )

    a1 = make(1, "OpenAI 发布 GPT-5", "OpenAI 今天发布了 GPT-5。", 0)
    a2 = make(2, "OpenAI 发布 GPT-5 性能提升", "OpenAI 发布了 GPT-5,性能提升。", 10)
    a3 = make(3, "苹果发布新款 iPhone", "苹果发布了 iPhone。", 20)
    stored = [a1, a2, a3]

    db.add_all(stored)
    db.commit()

    stats = deduplicate_articles(db, date_str=today, title_threshold=0.85, content_threshold=0.80)

    assert stats["total"] == 3
    assert stats["duplicate_groups"] == 1
    assert stats["representatives"] == 1

    # Re-read each instance from the session so attribute values reflect
    # whatever deduplicate_articles persisted.
    for article in stored:
        db.refresh(article)

    flagged = [article for article in stored if article.is_representative]
    assert len(flagged) == 1
    assert flagged[0].duplicate_group_id is not None
|