Initial commit: snapAna 截图智能整理工具

包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
wjl
2026-05-27 15:45:50 +08:00
commit 5c028d7952
76 changed files with 10467 additions and 0 deletions
+237
View File
@@ -0,0 +1,237 @@
"""异步任务调度器:从 jobs 表取任务并并发执行。
事务规则:
- 调度循环只用短事务 claim 任务、汇总状态。
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
绝不在 worker 这一层包裹长事务。
"""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import case, func, or_, select
from app.core.config import settings
from app.core.db import session_scope
from app.core.logger import get_logger
from app.models.job import Job, JobKind, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
logger = get_logger(__name__)
class AnalyzeWorker:
"""单实例后台 worker,负责把 jobs 表中的待处理项跑完。"""
def __init__(self) -> None:
self._task: Optional[asyncio.Task] = None
self._event = asyncio.Event()
self._stop = False
self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
self._inflight: int = 0
self._lock = asyncio.Lock()
self._loop: Optional[asyncio.AbstractEventLoop] = None
async def start(self) -> None:
"""启动主循环。"""
# 启动时把上次中断的 RUNNING 任务复位
with session_scope() as session:
running = session.scalars(
select(Job).where(Job.status == JobStatus.RUNNING.value)
).all()
for job in running:
job.status = JobStatus.PENDING.value
self._stop = False
self._loop = asyncio.get_running_loop()
self._task = asyncio.create_task(self._run(), name="analyze-worker")
self.notify()
async def stop(self) -> None:
self._stop = True
self._event.set()
if self._task is not None:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
def notify(self) -> None:
"""在事件循环线程内通知 worker 有新任务可取。"""
self._event.set()
def notify_threadsafe(self) -> None:
"""跨线程唤醒 workerFastAPI BackgroundTasks / watcher 线程)。
asyncio.Event.set() 本身不是线程安全的,必须通过
loop.call_soon_threadsafe 调度回事件循环线程。
"""
loop = self._loop
if loop is None or loop.is_closed():
return
loop.call_soon_threadsafe(self._event.set)
async def status(self) -> dict[str, int]:
"""供 API 查询当前队列状况(按 status 索引计数,适合大批量)。"""
with session_scope() as session:
pending = session.scalar(
select(Job.id)
.where(Job.status == JobStatus.PENDING.value)
.limit(1)
)
rows = session.execute(
select(Job.status, func.count())
.group_by(Job.status)
).all()
counts = {st.value: 0 for st in JobStatus}
for status, cnt in rows:
counts[status] = int(cnt)
counts["inflight"] = self._inflight
counts["has_more"] = 1 if pending is not None else 0
return counts
def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
"""把长时间 RUNNING 且无进展的任务复位为 PENDING。"""
with session_scope() as session:
q = select(Job).where(Job.status == JobStatus.RUNNING.value)
if not reset_all:
cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff)
stale = session.scalars(q).all()
for job in stale:
job.status = JobStatus.PENDING.value
job.started_at = None
count = len(stale)
if count:
logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
self.notify()
return count
def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
"""将 failed 任务重新排队。"""
with session_scope() as session:
q = select(Job).where(Job.status == JobStatus.FAILED.value)
if job_ids:
q = q.where(Job.id.in_(job_ids))
failed = session.scalars(q).all()
for job in failed:
job.status = JobStatus.PENDING.value
job.retries = 0
job.last_error = None
job.started_at = None
job.finished_at = None
count = len(failed)
if count:
logger.info("重试 %d 条 failed 任务", count)
self.notify()
return count
async def _run(self) -> None:
"""主循环。"""
idle_rounds = 0
while not self._stop:
job = self._claim_one()
if job is None:
idle_rounds += 1
# 空闲时定期清理僵尸 RUNNING,避免 inflight=0 但 DB 仍显示 running
if idle_rounds >= 3 and self._inflight == 0:
idle_rounds = 0
if self.reset_stale_running(minutes=5):
continue
self._event.clear()
try:
await asyncio.wait_for(self._event.wait(), timeout=10)
except asyncio.TimeoutError:
pass
continue
idle_rounds = 0
await self._semaphore.acquire()
async with self._lock:
self._inflight += 1
asyncio.create_task(
self._process(job["id"], job["screenshot_id"], job["kind"])
)
def _claim_one(self) -> Optional[dict]:
"""短事务:取一条 PENDING 任务;FULL 优先于 OCR 补跑。"""
with session_scope() as session:
job = session.scalar(
select(Job)
.where(
Job.status == JobStatus.PENDING.value,
or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
)
.order_by(
case(
(Job.kind == JobKind.FULL.value, 0),
(Job.kind == JobKind.VLM.value, 1),
else_=2,
),
Job.id.asc(),
)
.limit(1)
)
if job is None:
return None
job.status = JobStatus.RUNNING.value
job.started_at = datetime.utcnow()
session.flush()
return {
"id": job.id,
"screenshot_id": job.screenshot_id,
"kind": job.kind,
}
async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
"""执行单个任务,所有 DB 写入均在短事务中。"""
try:
try:
if kind == JobKind.OCR.value:
await analyze_ocr_only_by_id(screenshot_id)
else:
await analyze_screenshot_by_id(screenshot_id)
self._finish(job_id, success=True, kind=kind)
except Exception as exc: # noqa: BLE001
logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
self._finish(job_id, success=False, error=str(exc), kind=kind)
finally:
self._semaphore.release()
async with self._lock:
self._inflight -= 1
self.notify()
def _finish(
self,
job_id: int,
success: bool,
error: Optional[str] = None,
kind: str = JobKind.FULL.value,
) -> None:
"""短事务:更新 jobs 表完成状态。"""
with session_scope() as session:
job = session.get(Job, job_id)
if job is None:
return
if success:
job.status = JobStatus.DONE.value
job.last_error = None
else:
job.retries = (job.retries or 0) + 1
if job.retries >= settings.max_retries:
job.status = JobStatus.FAILED.value
# OCR 补跑失败不影响 ai_status
if kind != JobKind.OCR.value:
shot = session.get(Screenshot, job.screenshot_id)
if shot is not None:
shot.ai_status = ProcessStatus.FAILED.value
else:
job.status = JobStatus.PENDING.value
job.last_error = (error or "")[:1000]
job.finished_at = datetime.utcnow()
worker = AnalyzeWorker()