Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,237 @@
|
||||
"""异步任务调度器:从 jobs 表取任务并并发执行。
|
||||
|
||||
事务规则:
|
||||
- 调度循环只用短事务 claim 任务、汇总状态。
|
||||
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
|
||||
绝不在 worker 这一层包裹长事务。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import case, func, or_, select
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.db import session_scope
|
||||
from app.core.logger import get_logger
|
||||
from app.models.job import Job, JobKind, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
|
||||
|
||||
|
||||
# Module-level logger, named after this module via the project's logging helper.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnalyzeWorker:
    """Single-instance background worker that drains pending rows of the jobs table.

    The scheduling loop (`_run`) claims one PENDING job at a time and runs it
    as its own asyncio task; the number of concurrently running jobs is
    bounded by a semaphore sized from ``settings.analyze_concurrency``.

    All database access goes through short ``session_scope()`` blocks
    (presumably committing on exit -- confirm against ``app.core.db``); the
    analysis coroutines manage their own transactions.
    """

    # Seconds the idle loop waits for a wake-up before polling the table again.
    IDLE_POLL_SECONDS = 10
    # Consecutive empty polls (with nothing in flight) before stale RUNNING
    # rows are swept back to PENDING.
    IDLE_ROUNDS_BEFORE_SWEEP = 3
    # Minimum age, in minutes, of a RUNNING row for the idle sweep to reset it.
    STALE_RUNNING_MINUTES = 5
    # Pause after an unexpected error in the scheduling loop, so a broken
    # database does not turn the loop into a busy spin.
    ERROR_BACKOFF_SECONDS = 5

    def __init__(self) -> None:
        self._task: Optional[asyncio.Task] = None
        self._event = asyncio.Event()
        self._stop = False
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        self._inflight: int = 0
        # Retained for backward compatibility. The in-flight counter is only
        # ever touched on the event-loop thread with no await in between, so
        # it does not need a lock.
        self._lock = asyncio.Lock()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to running job tasks: the event loop itself only
        # keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Start the scheduling loop on the currently running event loop.

        RUNNING rows left over from an interrupted previous run are reset to
        PENDING first. Calling ``start()`` while the loop is already running
        only wakes the loop up; it never spawns a second one.
        """
        if self._task is not None and not self._task.done():
            self.notify()
            return
        self._loop = asyncio.get_running_loop()
        self._stop = False
        # Anything still marked RUNNING cannot be running in this process yet.
        self.reset_stale_running(reset_all=True)
        self._task = asyncio.create_task(self._run(), name="analyze-worker")
        self.notify()

    async def stop(self) -> None:
        """Stop the scheduling loop and wait for it to exit.

        Only the scheduling loop is cancelled; job tasks that are already
        running are left alone.
        """
        self._stop = True
        self._event.set()
        loop_task, self._task = self._task, None
        if loop_task is None:
            return
        loop_task.cancel()
        try:
            await loop_task
        except asyncio.CancelledError:
            pass

    def notify(self) -> None:
        """Wake the worker; must be called from the event-loop thread."""
        self._event.set()

    def notify_threadsafe(self) -> None:
        """Wake the worker from another thread.

        ``asyncio.Event.set()`` is not thread-safe, so the call is marshalled
        onto the worker's event loop with ``loop.call_soon_threadsafe``. Does
        nothing if the worker has never been started or its loop is closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(self._event.set)
        except RuntimeError:
            # The loop was closed between the check above and the call.
            pass

    def _wake(self) -> None:
        """Wake the worker from whichever thread the caller happens to be on."""
        loop = self._loop
        try:
            current = asyncio.get_running_loop()
        except RuntimeError:
            current = None
        if loop is None or current is loop:
            self._event.set()
        else:
            self.notify_threadsafe()

    async def status(self) -> dict[str, int]:
        """Return queue statistics for the API.

        Returns:
            A dict with one count per ``JobStatus`` value (0 when no row has
            that status), plus ``"inflight"`` (jobs currently executing in
            this process) and ``"has_more"`` (1 if any PENDING row exists,
            else 0).
        """
        with session_scope() as session:
            rows = session.execute(
                select(Job.status, func.count()).group_by(Job.status)
            ).all()
        counts = {st.value: 0 for st in JobStatus}
        for job_status, total in rows:
            counts[job_status] = int(total)
        # Derived from the grouped counts instead of a second query.
        has_pending = counts.get(JobStatus.PENDING.value, 0) > 0
        counts["inflight"] = self._inflight
        counts["has_more"] = 1 if has_pending else 0
        return counts

    def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
        """Put RUNNING jobs back to PENDING.

        Args:
            minutes: Minimum age of ``started_at`` for a job to count as
                stale; values below 1 are treated as 1. Ignored when
                ``reset_all`` is true.
            reset_all: Reset every RUNNING job regardless of age. Without it,
                RUNNING rows whose ``started_at`` is NULL are not touched.

        Returns:
            The number of jobs that were reset.
        """
        query = select(Job).where(Job.status == JobStatus.RUNNING.value)
        if not reset_all:
            cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
            query = query.where(Job.started_at.is_not(None), Job.started_at < cutoff)
        with session_scope() as session:
            stale = session.scalars(query).all()
            for job in stale:
                job.status = JobStatus.PENDING.value
                job.started_at = None
            count = len(stale)
        if count:
            logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
            # May be called from a non-loop thread, so pick the safe wake-up.
            self._wake()
        return count

    def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
        """Re-queue FAILED jobs with their retry bookkeeping cleared.

        Args:
            job_ids: Restrict the retry to these job ids. ``None`` or an
                empty list re-queues every FAILED job.

        Returns:
            The number of jobs that were re-queued.
        """
        query = select(Job).where(Job.status == JobStatus.FAILED.value)
        if job_ids:
            query = query.where(Job.id.in_(job_ids))
        with session_scope() as session:
            failed = session.scalars(query).all()
            for job in failed:
                job.status = JobStatus.PENDING.value
                job.retries = 0
                job.last_error = None
                job.started_at = None
                job.finished_at = None
            count = len(failed)
        if count:
            logger.info("重试 %d 条 failed 任务", count)
            # May be called from a non-loop thread, so pick the safe wake-up.
            self._wake()
        return count

    async def _run(self) -> None:
        """Scheduling loop: claim jobs one by one and run them concurrently."""
        idle_rounds = 0
        while not self._stop:
            # Take the concurrency slot BEFORE claiming, so a job is never
            # marked RUNNING while it is merely waiting for a free slot (or
            # orphaned in RUNNING if the loop is cancelled during that wait).
            await self._semaphore.acquire()
            try:
                job = self._claim_one()
            except Exception:  # noqa: BLE001 - the loop must survive DB hiccups
                self._semaphore.release()
                logger.exception(
                    "领取任务失败,%d 秒后重试", self.ERROR_BACKOFF_SECONDS
                )
                await asyncio.sleep(self.ERROR_BACKOFF_SECONDS)
                continue

            if job is not None:
                idle_rounds = 0
                # No await between claim and spawn: the slot and the RUNNING
                # row are always handed to a task that releases/finishes them.
                self._spawn(job)
                continue

            # Nothing to do: give the slot back before going idle.
            self._semaphore.release()
            idle_rounds += 1
            # While idle, periodically sweep zombie RUNNING rows so the DB
            # does not keep showing "running" when nothing is in flight.
            if idle_rounds >= self.IDLE_ROUNDS_BEFORE_SWEEP and self._inflight == 0:
                idle_rounds = 0
                try:
                    revived = self.reset_stale_running(
                        minutes=self.STALE_RUNNING_MINUTES
                    )
                except Exception:  # noqa: BLE001 - the loop must survive DB hiccups
                    logger.exception("清理僵尸 RUNNING 任务失败")
                    revived = 0
                if revived:
                    continue
            self._event.clear()
            try:
                await asyncio.wait_for(
                    self._event.wait(), timeout=self.IDLE_POLL_SECONDS
                )
            except asyncio.TimeoutError:
                pass

    def _spawn(self, job: dict) -> None:
        """Run a claimed job as its own task and keep a reference to it."""
        task = asyncio.create_task(
            self._process(job["id"], job["screenshot_id"], job["kind"]),
            name=f"analyze-job-{job['id']}",
        )
        self._inflight += 1
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    def _claim_one(self) -> Optional[dict]:
        """Claim one PENDING job in a short transaction.

        Jobs that have used up ``settings.max_retries`` are skipped. FULL jobs
        are taken first, then VLM, then every other kind; ties are broken by
        ascending id. The claimed row is marked RUNNING with ``started_at``
        set to the current UTC time.

        Returns:
            ``{"id", "screenshot_id", "kind"}`` for the claimed job, or
            ``None`` when nothing is claimable.
        """
        kind_priority = case(
            (Job.kind == JobKind.FULL.value, 0),
            (Job.kind == JobKind.VLM.value, 1),
            else_=2,
        )
        claimable = (
            select(Job)
            .where(
                Job.status == JobStatus.PENDING.value,
                or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
            )
            .order_by(kind_priority, Job.id.asc())
            .limit(1)
        )
        with session_scope() as session:
            job = session.scalar(claimable)
            if job is None:
                return None
            job.status = JobStatus.RUNNING.value
            job.started_at = datetime.utcnow()
            session.flush()
            # Copy plain values out; the ORM object is not used outside the
            # session.
            return {
                "id": job.id,
                "screenshot_id": job.screenshot_id,
                "kind": job.kind,
            }

    async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
        """Run one claimed job and record its outcome.

        OCR jobs call ``analyze_ocr_only_by_id``; every other kind calls
        ``analyze_screenshot_by_id``. The concurrency slot is always released
        and the scheduling loop is always woken afterwards.
        """
        try:
            try:
                if kind == JobKind.OCR.value:
                    await analyze_ocr_only_by_id(screenshot_id)
                else:
                    await analyze_screenshot_by_id(screenshot_id)
            except Exception as exc:  # noqa: BLE001
                logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
                self._finish_safely(job_id, success=False, error=str(exc), kind=kind)
            else:
                # Kept out of the try body: a bookkeeping error must not be
                # recorded as an analysis failure.
                self._finish_safely(job_id, success=True, kind=kind)
        finally:
            self._semaphore.release()
            self._inflight -= 1
            self.notify()

    def _finish_safely(
        self,
        job_id: int,
        *,
        success: bool,
        error: Optional[str] = None,
        kind: str,
    ) -> None:
        """Call `_finish`, logging instead of raising if the DB write fails.

        A row whose outcome could not be written stays RUNNING until a later
        stale-RUNNING reset picks it up.
        """
        try:
            self._finish(job_id, success=success, error=error, kind=kind)
        except Exception:  # noqa: BLE001 - nobody awaits the job task
            logger.exception("更新任务 #%d 状态失败", job_id)

    def _finish(
        self,
        job_id: int,
        success: bool,
        error: Optional[str] = None,
        kind: str = JobKind.FULL.value,
    ) -> None:
        """Record a job's outcome in a short transaction.

        Args:
            job_id: Primary key of the job; a missing row is silently ignored.
            success: Whether the analysis completed without raising.
            error: Failure message; stored truncated to 1000 characters.
            kind: Job kind. A non-OCR job that exhausts its retries also marks
                its screenshot's ``ai_status`` as FAILED.

        On failure ``retries`` is incremented; the job goes back to PENDING
        until ``settings.max_retries`` is reached, then becomes FAILED.
        ``finished_at`` is stamped on every outcome.
        """
        with session_scope() as session:
            job = session.get(Job, job_id)
            if job is None:
                return
            if success:
                job.status = JobStatus.DONE.value
                job.last_error = None
            else:
                job.retries = (job.retries or 0) + 1
                job.last_error = (error or "")[:1000]
                if job.retries < settings.max_retries:
                    job.status = JobStatus.PENDING.value
                else:
                    job.status = JobStatus.FAILED.value
                    # A failed OCR back-fill does not affect ai_status.
                    if kind != JobKind.OCR.value:
                        shot = session.get(Screenshot, job.screenshot_id)
                        if shot is not None:
                            shot.ai_status = ProcessStatus.FAILED.value
            job.finished_at = datetime.utcnow()
|
||||
|
||||
|
||||
# Module-level instance created at import time; the class is documented as a
# single-instance worker, so importers are expected to share this object.
worker = AnalyzeWorker()
|
||||
Reference in New Issue
Block a user