"""异步任务调度器:从 jobs 表取任务并并发执行。 事务规则: - 调度循环只用短事务 claim 任务、汇总状态。 - 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务, 绝不在 worker 这一层包裹长事务。 """ from __future__ import annotations import asyncio from datetime import datetime, timedelta from typing import Optional from sqlalchemy import case, func, or_, select from app.core.config import settings from app.core.db import session_scope from app.core.logger import get_logger from app.models.job import Job, JobKind, JobStatus from app.models.screenshot import ProcessStatus, Screenshot from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id logger = get_logger(__name__) class AnalyzeWorker: """单实例后台 worker,负责把 jobs 表中的待处理项跑完。""" def __init__(self) -> None: self._task: Optional[asyncio.Task] = None self._event = asyncio.Event() self._stop = False self._semaphore = asyncio.Semaphore(settings.analyze_concurrency) self._inflight: int = 0 self._lock = asyncio.Lock() self._loop: Optional[asyncio.AbstractEventLoop] = None async def start(self) -> None: """启动主循环。""" # 启动时把上次中断的 RUNNING 任务复位 with session_scope() as session: running = session.scalars( select(Job).where(Job.status == JobStatus.RUNNING.value) ).all() for job in running: job.status = JobStatus.PENDING.value self._stop = False self._loop = asyncio.get_running_loop() self._task = asyncio.create_task(self._run(), name="analyze-worker") self.notify() async def stop(self) -> None: self._stop = True self._event.set() if self._task is not None: self._task.cancel() try: await self._task except asyncio.CancelledError: pass def notify(self) -> None: """在事件循环线程内通知 worker 有新任务可取。""" self._event.set() def notify_threadsafe(self) -> None: """跨线程唤醒 worker(FastAPI BackgroundTasks / watcher 线程)。 asyncio.Event.set() 本身不是线程安全的,必须通过 loop.call_soon_threadsafe 调度回事件循环线程。 """ loop = self._loop if loop is None or loop.is_closed(): return loop.call_soon_threadsafe(self._event.set) async def status(self) -> dict[str, int]: """供 API 查询当前队列状况(按 status 索引计数,适合大批量)。""" with session_scope() as session: pending = session.scalar( select(Job.id) .where(Job.status == JobStatus.PENDING.value) .limit(1) ) rows = session.execute( select(Job.status, func.count()) .group_by(Job.status) ).all() counts = {st.value: 0 for st in JobStatus} for status, cnt in rows: counts[status] = int(cnt) counts["inflight"] = self._inflight counts["has_more"] = 1 if pending is not None else 0 return counts def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int: """把长时间 RUNNING 且无进展的任务复位为 PENDING。""" with session_scope() as session: q = select(Job).where(Job.status == JobStatus.RUNNING.value) if not reset_all: cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1)) q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff) stale = session.scalars(q).all() for job in stale: job.status = JobStatus.PENDING.value job.started_at = None count = len(stale) if count: logger.info("复位 %d 条 RUNNING 任务为 PENDING", count) self.notify() return count def retry_failed(self, job_ids: Optional[list[int]] = None) -> int: """将 failed 任务重新排队。""" with session_scope() as session: q = select(Job).where(Job.status == JobStatus.FAILED.value) if job_ids: q = q.where(Job.id.in_(job_ids)) failed = session.scalars(q).all() for job in failed: job.status = JobStatus.PENDING.value job.retries = 0 job.last_error = None job.started_at = None job.finished_at = None count = len(failed) if count: logger.info("重试 %d 条 failed 任务", count) self.notify() return count async def _run(self) -> None: """主循环。""" idle_rounds = 0 while not self._stop: job = self._claim_one() if job is None: idle_rounds += 1 # 空闲时定期清理僵尸 RUNNING,避免 inflight=0 但 DB 仍显示 running if idle_rounds >= 3 and self._inflight == 0: idle_rounds = 0 if self.reset_stale_running(minutes=5): continue self._event.clear() try: await asyncio.wait_for(self._event.wait(), timeout=10) except asyncio.TimeoutError: pass continue idle_rounds = 0 await self._semaphore.acquire() async with self._lock: self._inflight += 1 asyncio.create_task( self._process(job["id"], job["screenshot_id"], job["kind"]) ) def _claim_one(self) -> Optional[dict]: """短事务:取一条 PENDING 任务;FULL 优先于 OCR 补跑。""" with session_scope() as session: job = session.scalar( select(Job) .where( Job.status == JobStatus.PENDING.value, or_(Job.retries < settings.max_retries, Job.retries.is_(None)), ) .order_by( case( (Job.kind == JobKind.FULL.value, 0), (Job.kind == JobKind.VLM.value, 1), else_=2, ), Job.id.asc(), ) .limit(1) ) if job is None: return None job.status = JobStatus.RUNNING.value job.started_at = datetime.utcnow() session.flush() return { "id": job.id, "screenshot_id": job.screenshot_id, "kind": job.kind, } async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None: """执行单个任务,所有 DB 写入均在短事务中。""" try: try: if kind == JobKind.OCR.value: await analyze_ocr_only_by_id(screenshot_id) else: await analyze_screenshot_by_id(screenshot_id) self._finish(job_id, success=True, kind=kind) except Exception as exc: # noqa: BLE001 logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc) self._finish(job_id, success=False, error=str(exc), kind=kind) finally: self._semaphore.release() async with self._lock: self._inflight -= 1 self.notify() def _finish( self, job_id: int, success: bool, error: Optional[str] = None, kind: str = JobKind.FULL.value, ) -> None: """短事务:更新 jobs 表完成状态。""" with session_scope() as session: job = session.get(Job, job_id) if job is None: return if success: job.status = JobStatus.DONE.value job.last_error = None else: job.retries = (job.retries or 0) + 1 if job.retries >= settings.max_retries: job.status = JobStatus.FAILED.value # OCR 补跑失败不影响 ai_status if kind != JobKind.OCR.value: shot = session.get(Screenshot, job.screenshot_id) if shot is not None: shot.ai_status = ProcessStatus.FAILED.value else: job.status = JobStatus.PENDING.value job.last_error = (error or "")[:1000] job.finished_at = datetime.utcnow() worker = AnalyzeWorker()