Files
congsh 5c028d7952 Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 15:45:50 +08:00

238 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""异步任务调度器:从 jobs 表取任务并并发执行。
事务规则:
- 调度循环只用短事务 claim 任务、汇总状态。
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
绝不在 worker 这一层包裹长事务。
"""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import case, func, or_, select
from app.core.config import settings
from app.core.db import session_scope
from app.core.logger import get_logger
from app.models.job import Job, JobKind, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
# Module-level logger for the job scheduler.
logger = get_logger(__name__)
class AnalyzeWorker:
    """Single-instance background worker that drains pending items from the jobs table.

    At most ``settings.analyze_concurrency`` jobs run at once. Every DB access
    made here is a short ``session_scope()`` transaction; the long OCR/VLM
    work happens inside the ``analyze_*_by_id`` helpers.
    """

    def __init__(self) -> None:
        # Main scheduling-loop task, created by start().
        self._task: Optional[asyncio.Task] = None
        # Wake-up signal for the main loop (set by notify()/notify_threadsafe()).
        self._event = asyncio.Event()
        # Cooperative stop flag checked by the main loop.
        self._stop = False
        # Bounds how many jobs run concurrently.
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        # Dispatched jobs that have not finished yet. Only mutated on the event
        # loop with no await between read and write, so it needs no lock.
        self._inflight: int = 0
        # Strong references to running job tasks (see _dispatch).
        self._tasks: set[asyncio.Task] = set()
        # Loop the worker runs on; used to route cross-thread wake-ups.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    async def start(self) -> None:
        """Start the main loop; if it is already running, just wake it up.

        Before starting, jobs left RUNNING by an interrupted previous run are
        reset to PENDING so they get picked up again.
        """
        if self._task is not None and not self._task.done():
            # A second loop would double-claim jobs, and resetting RUNNING rows
            # now would re-queue jobs that are genuinely in flight.
            self.notify()
            return
        with session_scope() as session:
            running = session.scalars(
                select(Job).where(Job.status == JobStatus.RUNNING.value)
            ).all()
            for job in running:
                job.status = JobStatus.PENDING.value
                # Same reset as reset_stale_running().
                job.started_at = None
        self._stop = False
        self._loop = asyncio.get_running_loop()
        self._task = asyncio.create_task(self._run(), name="analyze-worker")
        self.notify()

    async def stop(self) -> None:
        """Stop the main loop and wait for it to exit.

        Already-dispatched job tasks are not cancelled here; any job still
        RUNNING in the DB is reset to PENDING by the next start().
        """
        self._stop = True
        self._event.set()
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def notify(self) -> None:
        """Wake the main loop because new work may be available.

        Safe from the worker's event-loop thread and, once start() has run,
        from any other thread: asyncio.Event.set() is not thread-safe, so
        off-loop callers (e.g. sync API handlers calling retry_failed) are
        routed through notify_threadsafe().
        """
        loop = self._loop
        if loop is None:
            # Not started yet: nothing can be waiting on the event.
            self._event.set()
            return
        try:
            running = asyncio.get_running_loop()
        except RuntimeError:
            running = None
        if running is loop:
            self._event.set()
        else:
            self.notify_threadsafe()

    def notify_threadsafe(self) -> None:
        """Wake the worker from another thread (FastAPI BackgroundTasks / watcher thread).

        asyncio.Event.set() is not thread-safe, so it is scheduled back onto the
        worker's event loop with loop.call_soon_threadsafe. Best-effort: does
        nothing if the worker was never started or its loop is closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(self._event.set)
        except RuntimeError:
            # The loop closed between the is_closed() check and this call.
            pass

    async def status(self) -> dict[str, int]:
        """Return queue statistics for the API.

        Returns one count per JobStatus value (from a single GROUP BY on
        status), plus ``inflight`` (jobs currently dispatched by this worker)
        and ``has_more`` (1 if any PENDING row exists, including rows
        _claim_one skips because they exhausted their retries).
        """
        with session_scope() as session:
            rows = session.execute(
                select(Job.status, func.count()).group_by(Job.status)
            ).all()
        counts = {st.value: 0 for st in JobStatus}
        for job_status, cnt in rows:
            counts[job_status] = int(cnt)
        counts["inflight"] = self._inflight
        # The GROUP BY already tells us whether a PENDING row exists.
        counts["has_more"] = 1 if counts[JobStatus.PENDING.value] > 0 else 0
        return counts

    def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
        """Reset RUNNING jobs that appear stuck back to PENDING.

        Args:
            minutes: Jobs whose ``started_at`` is older than this many minutes
                (clamped to at least 1) are reset. Rows with no ``started_at``
                are left alone.
            reset_all: Reset every RUNNING job regardless of ``started_at``.

        Returns:
            The number of jobs reset. The worker is woken if it is non-zero.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.RUNNING.value)
            if not reset_all:
                # Naive UTC, matching the values written by _claim_one().
                cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
                q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff)
            stale = session.scalars(q).all()
            for job in stale:
                job.status = JobStatus.PENDING.value
                job.started_at = None
            count = len(stale)
        if count:
            logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
            self.notify()
        return count

    def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
        """Re-queue FAILED jobs, clearing their retry counter, error and timestamps.

        Args:
            job_ids: Restrict to these job ids. ``None`` -- and, because the
                check is a truthiness test, also an empty list -- means every
                FAILED job.

        Returns:
            The number of jobs re-queued. The worker is woken if it is non-zero.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.FAILED.value)
            if job_ids:
                q = q.where(Job.id.in_(job_ids))
            failed = session.scalars(q).all()
            for job in failed:
                job.status = JobStatus.PENDING.value
                job.retries = 0
                job.last_error = None
                job.started_at = None
                job.finished_at = None
            count = len(failed)
        if count:
            logger.info("重试 %d 条 failed 任务", count)
            self.notify()
        return count

    async def _run(self) -> None:
        """Main loop: whenever a concurrency slot is free, claim and start one job."""
        idle_rounds = 0
        while not self._stop:
            # Reserve a slot *before* claiming, so a job is only marked RUNNING
            # when it can start immediately. Claiming first would leave one job
            # RUNNING in the DB while the loop waits for a slot.
            await self._semaphore.acquire()
            if self._stop:
                self._semaphore.release()
                break
            try:
                job = self._claim_one()
            except Exception:  # noqa: BLE001
                # A DB error must not kill the loop (nothing would restart it).
                self._semaphore.release()
                logger.exception("领取任务失败,稍后重试")
                await self._idle_wait()
                continue
            if job is not None:
                idle_rounds = 0
                self._dispatch(job)  # the spawned task now owns the slot
                continue
            self._semaphore.release()
            idle_rounds += 1
            # While idle, periodically clean up zombie RUNNING rows so the DB
            # does not show jobs as running when inflight == 0.
            if idle_rounds >= 3 and self._inflight == 0:
                idle_rounds = 0
                try:
                    if self.reset_stale_running(minutes=5):
                        continue
                except Exception:  # noqa: BLE001
                    logger.exception("清理僵尸 RUNNING 任务失败")
            await self._idle_wait()

    async def _idle_wait(self, timeout: float = 10) -> None:
        """Block until notify() fires or *timeout* seconds pass."""
        self._event.clear()
        try:
            await asyncio.wait_for(self._event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            pass

    def _dispatch(self, job: dict) -> None:
        """Spawn the task for a claimed job; it takes over the held semaphore slot."""
        self._inflight += 1
        task = asyncio.create_task(
            self._process(job["id"], job["screenshot_id"], job["kind"]),
            name=f"analyze-job-{job['id']}",
        )
        # The event loop keeps only weak references to tasks; hold a strong
        # one until the task finishes so it cannot be garbage-collected mid-run.
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    def _claim_one(self) -> Optional[dict]:
        """Short transaction: claim one PENDING job and mark it RUNNING.

        Priority is FULL, then VLM, then every other kind (e.g. OCR
        back-fills), oldest id first within a kind. Jobs that already reached
        ``settings.max_retries`` are skipped.

        Returns:
            ``{"id", "screenshot_id", "kind"}`` as plain values (not the ORM
            object, presumably so they stay usable after the session closes),
            or ``None`` if there is nothing to claim.
        """
        with session_scope() as session:
            job = session.scalar(
                select(Job)
                .where(
                    Job.status == JobStatus.PENDING.value,
                    or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
                )
                .order_by(
                    case(
                        (Job.kind == JobKind.FULL.value, 0),
                        (Job.kind == JobKind.VLM.value, 1),
                        else_=2,
                    ),
                    Job.id.asc(),
                )
                .limit(1)
            )
            if job is None:
                return None
            job.status = JobStatus.RUNNING.value
            job.started_at = datetime.utcnow()
            session.flush()
            return {
                "id": job.id,
                "screenshot_id": job.screenshot_id,
                "kind": job.kind,
            }

    async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
        """Run one job and record its outcome; always releases its concurrency slot.

        OCR jobs run analyze_ocr_only_by_id; every other kind runs
        analyze_screenshot_by_id. All DB writes happen in short transactions.
        """
        try:
            success = True
            error: Optional[str] = None
            try:
                if kind == JobKind.OCR.value:
                    await analyze_ocr_only_by_id(screenshot_id)
                else:
                    await analyze_screenshot_by_id(screenshot_id)
            except Exception as exc:  # noqa: BLE001
                logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
                success, error = False, str(exc)
            try:
                self._finish(job_id, success=success, error=error, kind=kind)
            except Exception:  # noqa: BLE001
                # The job stays RUNNING; reset_stale_running() re-queues it later.
                logger.exception("写入任务结果失败 job#%d", job_id)
        finally:
            self._semaphore.release()
            self._inflight -= 1
            self.notify()

    def _finish(
        self,
        job_id: int,
        success: bool,
        error: Optional[str] = None,
        kind: str = JobKind.FULL.value,
    ) -> None:
        """Short transaction: record a job's completion.

        On success the job becomes DONE and its error is cleared. On failure
        ``retries`` is incremented: once it reaches ``settings.max_retries`` the
        job becomes FAILED (and, unless it is an OCR job, the screenshot's
        ``ai_status`` is set to FAILED); otherwise it goes back to PENDING. The
        error text is stored truncated to 1000 characters. ``finished_at`` is
        set in every case. Unknown job ids are ignored.
        """
        with session_scope() as session:
            job = session.get(Job, job_id)
            if job is None:
                return
            if success:
                job.status = JobStatus.DONE.value
                job.last_error = None
            else:
                job.retries = (job.retries or 0) + 1
                if job.retries >= settings.max_retries:
                    job.status = JobStatus.FAILED.value
                    # A failed OCR back-fill does not change the screenshot's ai_status.
                    if kind != JobKind.OCR.value:
                        shot = session.get(Screenshot, job.screenshot_id)
                        if shot is not None:
                            shot.ai_status = ProcessStatus.FAILED.value
                else:
                    job.status = JobStatus.PENDING.value
                job.last_error = (error or "")[:1000]
            job.finished_at = datetime.utcnow()
# Module-level singleton; AnalyzeWorker is designed to run as a single
# instance, so importers are expected to share this object.
worker = AnalyzeWorker()