141 lines
5.0 KiB
Python
141 lines
5.0 KiB
Python
|
|
"""分类/标签/打分规则体系的初始化与维护"""
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from typing import List, Dict, Any
|
||
|
|
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.ai_client import ai_client
|
||
|
|
from app.rss_client import rss_client
|
||
|
|
from models import Taxonomy
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
TAXONOMY_SYSTEM_PROMPT = """你是一位专业的信息分类与内容分析专家。
|
||
|
|
请根据用户提供的 RSS 文章样本,生成一套适合的中文内容分类体系、标签体系和打分规则。
|
||
|
|
|
||
|
|
输出必须是合法的 JSON,格式如下:
|
||
|
|
{
|
||
|
|
"categories": [
|
||
|
|
{"name": "科技", "description": "人工智能、芯片、互联网、软件等", "keywords": ["AI", "芯片", "大模型", ...]}
|
||
|
|
],
|
||
|
|
"tags": [
|
||
|
|
{"name": "人工智能", "description": "...", "keywords": ["AI", "人工智能", "大模型", ...]}
|
||
|
|
],
|
||
|
|
"heat_rules": [
|
||
|
|
{"name": "热点事件", "keywords": ["突发", "重磅", "刚刚", "发布"], "weight": 1.5}
|
||
|
|
],
|
||
|
|
"importance_rules": [
|
||
|
|
{"name": "政策法规", "keywords": ["政策", "监管", "法规", "征求意见"], "weight": 1.5}
|
||
|
|
],
|
||
|
|
"duplication_indicators": [
|
||
|
|
{"name": "同一事件", "keywords": ["宣布", "发布", "推出"], "weight": 1.0}
|
||
|
|
]
|
||
|
|
}
|
||
|
|
|
||
|
|
要求:
|
||
|
|
1. categories 数量控制在 8-12 个,覆盖科技、财经、新闻、设计、生活等常见 RSS 主题。
|
||
|
|
2. tags 数量控制在 30-50 个,尽量细化但避免过度重叠。
|
||
|
|
3. heat_rules 和 importance_rules 各 10-20 条,weight 范围 0.5-2.0。
|
||
|
|
4. 所有 keywords 用中文或中英双语,便于后续关键词匹配。
|
||
|
|
5. 不要输出任何解释文字,只输出 JSON。
|
||
|
|
"""
|
||
|
|
|
||
|
|
|
||
|
|
def _build_sample_prompt(articles: List[Dict[str, Any]]) -> str:
|
||
|
|
lines = [f"共有 {len(articles)} 篇文章样本:"]
|
||
|
|
for idx, art in enumerate(articles[:50], 1):
|
||
|
|
title = art.get("title", "")
|
||
|
|
summary = art.get("summary", "") or art.get("content", "")[:300]
|
||
|
|
feed = art.get("feed_title", "")
|
||
|
|
cat = art.get("category", "")
|
||
|
|
lines.append(f"\n[{idx}] 标题:{title}")
|
||
|
|
lines.append(f" 来源:{feed} | 源分类:{cat}")
|
||
|
|
lines.append(f" 摘要:{summary[:400]}")
|
||
|
|
return "\n".join(lines)
|
||
|
|
|
||
|
|
|
||
|
|
def bootstrap_taxonomy(db: Session, force: bool = False) -> bool:
|
||
|
|
"""
|
||
|
|
初始化分类/标签/打分规则。
|
||
|
|
若 force=True 则清空后重建;否则仅在表为空时初始化。
|
||
|
|
"""
|
||
|
|
existing = db.query(Taxonomy).first()
|
||
|
|
if existing and not force:
|
||
|
|
logger.info("taxonomy 表已存在,跳过初始化")
|
||
|
|
return False
|
||
|
|
|
||
|
|
if force:
|
||
|
|
db.query(Taxonomy).delete()
|
||
|
|
db.commit()
|
||
|
|
logger.info("强制重新初始化 taxonomy")
|
||
|
|
|
||
|
|
logger.info("开始从 rssKeeper 拉取样本文章并生成分类体系...")
|
||
|
|
articles = rss_client.fetch_recent(hours=24 * 7, limit=200)
|
||
|
|
if not articles:
|
||
|
|
logger.warning("未获取到样本文章,无法生成分类体系")
|
||
|
|
return False
|
||
|
|
|
||
|
|
user_prompt = _build_sample_prompt(articles)
|
||
|
|
try:
|
||
|
|
result = ai_client.chat_completion_json(
|
||
|
|
system_prompt=TAXONOMY_SYSTEM_PROMPT,
|
||
|
|
user_prompt=user_prompt,
|
||
|
|
temperature=0.5,
|
||
|
|
)
|
||
|
|
except Exception as exc:
|
||
|
|
logger.error("生成分类体系失败: %s", exc)
|
||
|
|
return False
|
||
|
|
|
||
|
|
_save_taxonomy(db, result)
|
||
|
|
logger.info("taxonomy 初始化完成,共写入 %d 条规则", db.query(Taxonomy).count())
|
||
|
|
return True
|
||
|
|
|
||
|
|
|
||
|
|
def _save_taxonomy(db: Session, data: Dict[str, Any]) -> None:
|
||
|
|
"""把 LLM 返回的分类体系写入数据库"""
|
||
|
|
|
||
|
|
def _add(kind: str, items: List[Dict[str, Any]], default_weight: float = 1.0):
|
||
|
|
for item in items:
|
||
|
|
name = item.get("name", "").strip()
|
||
|
|
if not name:
|
||
|
|
continue
|
||
|
|
keywords = item.get("keywords", [])
|
||
|
|
if isinstance(keywords, str):
|
||
|
|
keywords = [keywords]
|
||
|
|
db.add(
|
||
|
|
Taxonomy(
|
||
|
|
name=name,
|
||
|
|
kind=kind,
|
||
|
|
description=item.get("description", ""),
|
||
|
|
keywords=keywords,
|
||
|
|
weight=float(item.get("weight", default_weight)),
|
||
|
|
created_by_ai=True,
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
_add("category", data.get("categories", []))
|
||
|
|
_add("tag", data.get("tags", []))
|
||
|
|
_add("heat_rule", data.get("heat_rules", []), default_weight=1.0)
|
||
|
|
_add("importance_rule", data.get("importance_rules", []), default_weight=1.0)
|
||
|
|
_add("duplication_rule", data.get("duplication_indicators", []), default_weight=1.0)
|
||
|
|
|
||
|
|
db.commit()
|
||
|
|
|
||
|
|
|
||
|
|
def ensure_taxonomy(db: Session) -> bool:
|
||
|
|
"""确保 taxonomy 表非空,若为空则触发初始化"""
|
||
|
|
existing = db.query(Taxonomy).first()
|
||
|
|
if existing:
|
||
|
|
return True
|
||
|
|
return bootstrap_taxonomy(db)
|
||
|
|
|
||
|
|
|
||
|
|
def list_taxonomy(db: Session, kind: str = None) -> List[Taxonomy]:
|
||
|
|
"""列出分类体系规则"""
|
||
|
|
query = db.query(Taxonomy)
|
||
|
|
if kind:
|
||
|
|
query = query.filter(Taxonomy.kind == kind)
|
||
|
|
return query.order_by(Taxonomy.kind, Taxonomy.name).all()
|