5c028d7952
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
238 lines
8.8 KiB
Python
238 lines
8.8 KiB
Python
"""异步任务调度器:从 jobs 表取任务并并发执行。
|
||
|
||
事务规则:
|
||
- 调度循环只用短事务 claim 任务、汇总状态。
|
||
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
|
||
绝不在 worker 这一层包裹长事务。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional
|
||
|
||
from sqlalchemy import case, func, or_, select
|
||
|
||
from app.core.config import settings
|
||
from app.core.db import session_scope
|
||
from app.core.logger import get_logger
|
||
from app.models.job import Job, JobKind, JobStatus
|
||
from app.models.screenshot import ProcessStatus, Screenshot
|
||
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
|
||
|
||
|
||
# Logger for this module, named after ``__name__``.
logger = get_logger(__name__)
|
||
|
||
|
||
class AnalyzeWorker:
    """Single-instance background worker that drains pending rows of the ``jobs`` table.

    Transaction rules: the scheduling loop only opens short transactions to claim
    jobs and record results; the analysis calls themselves
    (``analyze_screenshot_by_id`` / ``analyze_ocr_only_by_id``) manage their own
    short transactions, so no long transaction is ever held at this layer.
    """

    # Seconds to wait when the queue is empty (or claiming failed) before polling
    # the jobs table again; notify() cuts the idle wait short.
    _IDLE_WAIT_SECONDS = 10

    def __init__(self) -> None:
        # Handle of the main scheduling loop task (None when not started/stopped).
        self._task: Optional[asyncio.Task] = None
        # Set by notify()/notify_threadsafe() to wake the idle loop early.
        self._event = asyncio.Event()
        self._stop = False
        # Caps concurrently running analyses at settings.analyze_concurrency.
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        # Number of spawned _process tasks that have not finished yet. It is only
        # mutated on the event-loop thread between await points, so no lock is needed.
        self._inflight: int = 0
        # Strong references to running _process tasks: the event loop keeps only
        # weak references, so an unreferenced task can be garbage-collected mid-run.
        self._tasks: set[asyncio.Task] = set()
        # Loop captured in start(); used by notify_threadsafe() from other threads.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    async def start(self) -> None:
        """Start the scheduling loop on the currently running event loop.

        Jobs left RUNNING by a previous, interrupted run are re-queued first.
        Calling start() while the loop is already running is a no-op.
        """
        if self._task is not None and not self._task.done():
            # A second loop would double-claim jobs, and the reset below would
            # re-queue jobs that are genuinely in flight.
            return
        # With the loop not running, RUNNING rows are leftovers from an earlier
        # run (crash or stop()) that will never finish on their own.
        self.reset_stale_running(reset_all=True)
        self._stop = False
        self._loop = asyncio.get_running_loop()
        self._task = asyncio.create_task(self._run(), name="analyze-worker")
        self.notify()

    async def stop(self) -> None:
        """Stop the scheduling loop and cancel in-flight analyses.

        Cancelled jobs stay RUNNING in the database; the next start() re-queues
        them. Unexpected exceptions from the stopped tasks are logged, not raised.
        """
        self._stop = True
        self._event.set()
        tasks: list[asyncio.Task] = []
        if self._task is not None:
            tasks.append(self._task)
            self._task = None
        tasks.extend(self._tasks)
        for task in tasks:
            task.cancel()
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException) and not isinstance(
                result, asyncio.CancelledError
            ):
                logger.error("worker 任务异常退出: %r", result, exc_info=result)
        # A task cancelled before its first step never runs its `finally`, so the
        # semaphore / in-flight counter may be off. Nothing runs now: rebuild them.
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        self._inflight = 0

    def notify(self) -> None:
        """Wake the worker because new jobs may be available (event-loop thread only)."""
        self._event.set()

    def notify_threadsafe(self) -> None:
        """Wake the worker from another thread (FastAPI BackgroundTasks / watcher thread).

        asyncio.Event.set() is not thread-safe, so the call is scheduled back onto
        the event-loop thread via loop.call_soon_threadsafe. Does nothing if the
        worker has never been started or its loop is closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(self._event.set)
        except RuntimeError:
            # The loop closed between the check above and this call (shutdown
            # race); nobody is left to wake, so dropping the notification is fine.
            pass

    async def status(self) -> dict[str, int]:
        """Return queue counters for the API.

        Keys: one per JobStatus value (job count in that status, from a single
        GROUP BY query), plus "inflight" (tasks currently running in this process)
        and "has_more" (1 if any job is PENDING, else 0).
        """
        with session_scope() as session:
            rows = session.execute(
                select(Job.status, func.count()).group_by(Job.status)
            ).all()
        counts = {st.value: 0 for st in JobStatus}
        for st_value, cnt in rows:
            counts[st_value] = int(cnt)
        counts["inflight"] = self._inflight
        # Derived from the same snapshot instead of issuing a second query.
        counts["has_more"] = 1 if counts[JobStatus.PENDING.value] > 0 else 0
        return counts

    def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
        """Re-queue RUNNING jobs that appear to be stuck.

        Args:
            minutes: a RUNNING job is stale when its started_at is older than this
                many minutes (values below 1 are treated as 1). Jobs with a NULL
                started_at are only reset when reset_all is True.
            reset_all: re-queue every RUNNING job regardless of started_at.

        Returns:
            Number of jobs moved back to PENDING (their started_at is cleared).
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.RUNNING.value)
            if not reset_all:
                cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
                q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff)
            stale = session.scalars(q).all()
            for job in stale:
                job.status = JobStatus.PENDING.value
                job.started_at = None
            count = len(stale)
        if count:
            logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
            # Wake the loop only after the transaction above has been committed.
            self.notify()
        return count

    def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
        """Re-queue FAILED jobs with a fresh retry budget.

        Args:
            job_ids: restrict to these job ids; None (or an empty list) retries
                every failed job.

        Returns:
            Number of jobs moved back to PENDING.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.FAILED.value)
            if job_ids:
                q = q.where(Job.id.in_(job_ids))
            failed = session.scalars(q).all()
            for job in failed:
                job.status = JobStatus.PENDING.value
                job.retries = 0
                job.last_error = None
                job.started_at = None
                job.finished_at = None
            count = len(failed)
        if count:
            logger.info("重试 %d 条 failed 任务", count)
            self.notify()
        return count

    async def _run(self) -> None:
        """Main scheduling loop: claim jobs and fan them out to _process tasks."""
        idle_rounds = 0
        while not self._stop:
            # Take a concurrency slot *before* claiming, so a job is only marked
            # RUNNING (and its started_at stamped) once it can actually start.
            await self._semaphore.acquire()
            try:
                job = self._claim_one()
            except Exception:  # noqa: BLE001
                # A transient DB error must not kill the scheduler for good.
                # Plain sleep (not the event) so frequent notify() calls cannot
                # turn this into a tight error loop.
                self._semaphore.release()
                logger.exception("领取任务失败,稍后重试")
                await asyncio.sleep(self._IDLE_WAIT_SECONDS)
                continue

            if job is None:
                self._semaphore.release()
                idle_rounds += 1
                # While idle, periodically sweep zombie RUNNING rows so the DB does
                # not keep showing "running" while inflight == 0.
                if idle_rounds >= 3 and self._inflight == 0:
                    idle_rounds = 0
                    try:
                        if self.reset_stale_running(minutes=5):
                            continue
                    except Exception:  # noqa: BLE001
                        logger.exception("复位僵尸任务失败")
                self._event.clear()
                try:
                    await asyncio.wait_for(
                        self._event.wait(), timeout=self._IDLE_WAIT_SECONDS
                    )
                except asyncio.TimeoutError:
                    pass
                continue

            idle_rounds = 0
            self._spawn(job)

    def _spawn(self, job: dict) -> None:
        """Start a _process task for a claimed job; the caller holds a semaphore slot."""
        self._inflight += 1
        task = asyncio.create_task(
            self._process(job["id"], job["screenshot_id"], job["kind"]),
            name=f"analyze-job-{job['id']}",
        )
        # Keep a strong reference until the task finishes (see __init__).
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    def _claim_one(self) -> Optional[dict]:
        """Claim one PENDING job in a short transaction and mark it RUNNING.

        Priority: FULL jobs first, then VLM, then everything else (e.g. OCR-only
        re-runs); FIFO by id within a priority. Jobs whose retries reached
        settings.max_retries are skipped.

        Returns:
            A plain dict with "id", "screenshot_id" and "kind" (not the ORM object,
            which is not used after its session closes), or None if nothing is
            claimable.
        """
        with session_scope() as session:
            job = session.scalar(
                select(Job)
                .where(
                    Job.status == JobStatus.PENDING.value,
                    or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
                )
                .order_by(
                    case(
                        (Job.kind == JobKind.FULL.value, 0),
                        (Job.kind == JobKind.VLM.value, 1),
                        else_=2,
                    ),
                    Job.id.asc(),
                )
                .limit(1)
            )
            if job is None:
                return None
            job.status = JobStatus.RUNNING.value
            job.started_at = datetime.utcnow()
            session.flush()
            return {
                "id": job.id,
                "screenshot_id": job.screenshot_id,
                "kind": job.kind,
            }

    async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
        """Run one job and record its outcome; every DB write is a short transaction.

        OCR jobs run the OCR-only analysis; all other kinds run the full analysis.
        Always releases the semaphore slot taken in _run, even on cancellation.
        """
        try:
            try:
                if kind == JobKind.OCR.value:
                    await analyze_ocr_only_by_id(screenshot_id)
                else:
                    await analyze_screenshot_by_id(screenshot_id)
            except Exception as exc:  # noqa: BLE001
                logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
                self._finish(job_id, success=False, error=str(exc), kind=kind)
            else:
                # Kept out of the try above: a failure to *record* success must not
                # be counted as a failed analysis attempt.
                self._finish(job_id, success=True, kind=kind)
        except Exception:  # noqa: BLE001
            # Recording the result failed; the job stays RUNNING and is re-queued
            # later by reset_stale_running().
            logger.exception("写回任务 #%d 状态失败", job_id)
        finally:
            self._semaphore.release()
            self._inflight -= 1
            self.notify()

    def _finish(
        self,
        job_id: int,
        success: bool,
        error: Optional[str] = None,
        kind: str = JobKind.FULL.value,
    ) -> None:
        """Record a job's outcome in a short transaction.

        On success the job becomes DONE. On failure its retry counter is bumped:
        once it reaches settings.max_retries the job becomes FAILED (and, unless it
        was an OCR-only job, the screenshot's ai_status is set to FAILED);
        otherwise it goes back to PENDING. The error text is stored truncated to
        1000 characters. finished_at is stamped in every case; a missing job is
        ignored.
        """
        with session_scope() as session:
            job = session.get(Job, job_id)
            if job is None:
                return
            if success:
                job.status = JobStatus.DONE.value
                job.last_error = None
            else:
                job.retries = (job.retries or 0) + 1
                if job.retries >= settings.max_retries:
                    job.status = JobStatus.FAILED.value
                    # A failed OCR-only re-run must not mark the AI analysis failed.
                    if kind != JobKind.OCR.value:
                        shot = session.get(Screenshot, job.screenshot_id)
                        if shot is not None:
                            shot.ai_status = ProcessStatus.FAILED.value
                else:
                    job.status = JobStatus.PENDING.value
                job.last_error = (error or "")[:1000]
            job.finished_at = datetime.utcnow()
|
||
|
||
|
||
# Module-level instance of the (single-instance) AnalyzeWorker.
# Constructing it here runs AnalyzeWorker.__init__ at import time, which reads
# settings.analyze_concurrency and creates asyncio primitives before any event
# loop is running.
# NOTE(review): on Python < 3.10, asyncio.Event/Semaphore bind to the loop that is
# current at construction; confirm the runtime is 3.10+ or that this import happens
# on the serving loop.
worker = AnalyzeWorker()
|