238 lines
8.8 KiB
Python
238 lines
8.8 KiB
Python
|
|
"""异步任务调度器:从 jobs 表取任务并并发执行。
|
|||
|
|
|
|||
|
|
事务规则:
|
|||
|
|
- 调度循环只用短事务 claim 任务、汇总状态。
|
|||
|
|
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
|
|||
|
|
绝不在 worker 这一层包裹长事务。
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
from datetime import datetime, timedelta
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from sqlalchemy import case, func, or_, select
|
|||
|
|
|
|||
|
|
from app.core.config import settings
|
|||
|
|
from app.core.db import session_scope
|
|||
|
|
from app.core.logger import get_logger
|
|||
|
|
from app.models.job import Job, JobKind, JobStatus
|
|||
|
|
from app.models.screenshot import ProcessStatus, Screenshot
|
|||
|
|
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-level logger obtained from the project's get_logger, keyed by this module's name.
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AnalyzeWorker:
    """Single-instance background worker that drains pending rows of the jobs table.

    A scheduling loop (``_run``) claims one PENDING job at a time and hands it
    to a ``_process`` task; the number of concurrently running tasks is capped
    by a semaphore sized from ``settings.analyze_concurrency``.

    All state mutations (``_inflight``, the semaphore, the task set) happen on
    the event-loop thread without an intervening ``await``, so no extra lock is
    needed around them.
    """

    def __init__(self) -> None:
        """Create an idle worker; nothing runs until ``start()`` is awaited."""
        # Handle of the scheduling-loop task, or None while stopped.
        self._task: Optional[asyncio.Task] = None
        # Strong references to in-flight job tasks. The event loop only keeps
        # weak references to tasks, so a task that nobody else references
        # could be garbage-collected before it finishes.
        self._jobs: set[asyncio.Task] = set()
        # Set whenever there may be new work for the loop to pick up.
        self._event = asyncio.Event()
        self._stop = False
        # One permit per concurrently running job.
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        # Number of jobs currently being processed (reported by ``status``).
        self._inflight: int = 0
        # Loop the worker runs on; used to marshal wake-ups from other threads.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    async def start(self) -> None:
        """Start the scheduling loop.

        Jobs left in RUNNING by a previous, interrupted run are put back to
        PENDING first (including clearing ``started_at``). Calling ``start``
        while the loop is already running is a no-op.
        """
        if self._task is not None and not self._task.done():
            return
        self._stop = False
        self._loop = asyncio.get_running_loop()
        # Nothing can legitimately be RUNNING before the loop starts, so
        # every RUNNING row is a leftover and is reset unconditionally.
        self.reset_stale_running(reset_all=True)
        self._task = asyncio.create_task(self._run(), name="analyze-worker")
        self.notify()

    async def stop(self) -> None:
        """Stop the scheduling loop and wait until its task has ended.

        Job tasks that are already in flight are not cancelled or awaited
        here; only the loop that claims new jobs is stopped.
        """
        self._stop = True
        self._event.set()
        task = self._task
        if task is None:
            return
        self._task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:  # noqa: BLE001
            # The loop had already died with an error; report it instead of
            # letting it break the caller's shutdown sequence.
            logger.exception("analyze-worker 主循环异常退出")

    def notify(self) -> None:
        """Wake the scheduling loop because new work may be available.

        When called on the worker's own event loop (or before the worker has
        a loop) the event is set directly. When called from anywhere else,
        e.g. from a synchronous caller running in another thread, the wake-up
        is marshalled onto the worker's loop, because ``asyncio.Event.set()``
        is not thread-safe.
        """
        target = self._loop
        try:
            current: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop()
        except RuntimeError:
            current = None
        if target is not None and current is not target and not target.is_closed():
            target.call_soon_threadsafe(self._event.set)
            return
        self._event.set()

    def notify_threadsafe(self) -> None:
        """Wake the worker from another thread (FastAPI BackgroundTasks / watcher thread).

        ``asyncio.Event.set()`` is not thread-safe, so the call is scheduled
        onto the worker's event loop with ``loop.call_soon_threadsafe``. Does
        nothing if the worker has no loop yet or the loop is closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        loop.call_soon_threadsafe(self._event.set)

    async def status(self) -> dict[str, int]:
        """Return queue statistics for the API.

        The result maps every ``JobStatus`` value to its row count (0 when
        absent) and adds two extra keys: ``"inflight"`` (jobs currently being
        processed by this worker) and ``"has_more"`` (1 if at least one
        PENDING row exists, else 0).
        """
        with session_scope() as session:
            rows = session.execute(
                select(Job.status, func.count()).group_by(Job.status)
            ).all()
        counts: dict[str, int] = {st.value: 0 for st in JobStatus}
        for job_status, cnt in rows:
            counts[job_status] = int(cnt)
        # The grouped counts already tell whether anything is pending, so no
        # separate probe query is needed.
        has_pending = counts.get(JobStatus.PENDING.value, 0) > 0
        counts["inflight"] = self._inflight
        counts["has_more"] = 1 if has_pending else 0
        return counts

    def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
        """Put RUNNING jobs back to PENDING and return how many were reset.

        Args:
            minutes: Only jobs whose ``started_at`` is older than this many
                minutes are reset (values below 1 are treated as 1). Ignored
                when ``reset_all`` is true.
            reset_all: Reset every RUNNING job regardless of ``started_at``.
        """
        with session_scope() as session:
            query = select(Job).where(Job.status == JobStatus.RUNNING.value)
            if not reset_all:
                cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
                query = query.where(Job.started_at.is_not(None), Job.started_at < cutoff)
            stale = session.scalars(query).all()
            for job in stale:
                job.status = JobStatus.PENDING.value
                job.started_at = None
            count = len(stale)
        # Wake the worker only after the session scope has exited, so the
        # loop does not look for the rows before the change is written out
        # (presumably session_scope commits on exit; confirm).
        if count:
            logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
            self.notify()
        return count

    def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
        """Re-queue FAILED jobs and return how many were re-queued.

        Args:
            job_ids: Restrict the retry to these job ids. ``None`` or an empty
                list means "all FAILED jobs".

        Re-queued jobs get their retry counter, last error and timestamps
        cleared.
        """
        with session_scope() as session:
            query = select(Job).where(Job.status == JobStatus.FAILED.value)
            if job_ids:
                query = query.where(Job.id.in_(job_ids))
            failed = session.scalars(query).all()
            for job in failed:
                job.status = JobStatus.PENDING.value
                job.retries = 0
                job.last_error = None
                job.started_at = None
                job.finished_at = None
            count = len(failed)
        # Notify after the session scope has exited (see reset_stale_running).
        if count:
            logger.info("重试 %d 条 failed 任务", count)
            self.notify()
        return count

    async def _run(self) -> None:
        """Scheduling loop: claim jobs while slots are free, otherwise sleep.

        A concurrency slot is acquired *before* a job is claimed. That way a
        job is only flipped to RUNNING when it can start immediately, and a
        cancellation while waiting for a slot cannot strand a claimed job.
        """
        idle_rounds = 0
        while not self._stop:
            await self._semaphore.acquire()
            job: Optional[dict] = None
            try:
                job = self._claim_one()
            finally:
                # Nothing claimed (or the claim raised): give the slot back.
                if job is None:
                    self._semaphore.release()

            if job is None:
                idle_rounds += 1
                # While idle, periodically clean up zombie RUNNING rows so the
                # DB does not keep showing "running" when inflight is 0.
                if idle_rounds >= 3 and self._inflight == 0:
                    idle_rounds = 0
                    if self.reset_stale_running(minutes=5):
                        continue
                self._event.clear()
                try:
                    await asyncio.wait_for(self._event.wait(), timeout=10)
                except asyncio.TimeoutError:
                    pass
                continue

            idle_rounds = 0
            # No await between the claim and the spawn: the slot and the
            # in-flight counter are always handed over to a _process task,
            # which releases both in its ``finally``.
            task = asyncio.create_task(
                self._process(job["id"], job["screenshot_id"], job["kind"])
            )
            self._inflight += 1
            self._jobs.add(task)
            task.add_done_callback(self._jobs.discard)

    def _claim_one(self) -> Optional[dict]:
        """Short transaction: claim one PENDING job, or return None.

        Only jobs whose retry counter is below ``settings.max_retries`` (or
        unset) are eligible. FULL jobs come first, then VLM, then everything
        else (e.g. OCR back-fill); ties are broken by ascending id.

        Returns:
            A dict with the keys ``"id"``, ``"screenshot_id"`` and ``"kind"``
            of the claimed job, which has been marked RUNNING with
            ``started_at`` set to the current UTC time.
        """
        priority = case(
            (Job.kind == JobKind.FULL.value, 0),
            (Job.kind == JobKind.VLM.value, 1),
            else_=2,
        )
        query = (
            select(Job)
            .where(
                Job.status == JobStatus.PENDING.value,
                or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
            )
            .order_by(priority, Job.id.asc())
            .limit(1)
        )
        with session_scope() as session:
            job = session.scalar(query)
            if job is None:
                return None
            job.status = JobStatus.RUNNING.value
            job.started_at = datetime.utcnow()
            session.flush()
            # Copy plain values out while the session is still open.
            return {
                "id": job.id,
                "screenshot_id": job.screenshot_id,
                "kind": job.kind,
            }

    async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
        """Run one job and record its outcome; all DB writes use short transactions.

        OCR jobs call ``analyze_ocr_only_by_id``; every other kind calls
        ``analyze_screenshot_by_id``. Whatever happens, the concurrency slot
        and the in-flight counter taken by ``_run`` are released and the
        scheduling loop is woken up.
        """
        try:
            try:
                if kind == JobKind.OCR.value:
                    await analyze_ocr_only_by_id(screenshot_id)
                else:
                    await analyze_screenshot_by_id(screenshot_id)
                self._finish(job_id, success=True, kind=kind)
            except Exception as exc:  # noqa: BLE001
                logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
                try:
                    self._finish(job_id, success=False, error=str(exc), kind=kind)
                except Exception:  # noqa: BLE001
                    # Recording the failure failed too; log it rather than
                    # leaving an unretrieved task exception behind.
                    logger.exception("记录任务失败状态时出错 job #%d", job_id)
        finally:
            self._semaphore.release()
            self._inflight -= 1
            self.notify()

    def _finish(
        self,
        job_id: int,
        success: bool,
        error: Optional[str] = None,
        kind: str = JobKind.FULL.value,
    ) -> None:
        """Short transaction: store the outcome of a job in the jobs table.

        Args:
            job_id: Primary key of the job; a missing row is ignored.
            success: True marks the job DONE and clears ``last_error``.
                False increments ``retries``; the job becomes FAILED once
                ``settings.max_retries`` is reached and PENDING otherwise.
            error: Error text stored in ``last_error`` on failure (truncated
                to 1000 characters).
            kind: Job kind; for kinds other than OCR a final failure also
                marks the screenshot's ``ai_status`` as FAILED.
        """
        with session_scope() as session:
            job = session.get(Job, job_id)
            if job is None:
                return
            if success:
                job.status = JobStatus.DONE.value
                job.last_error = None
            else:
                job.retries = (job.retries or 0) + 1
                if job.retries >= settings.max_retries:
                    job.status = JobStatus.FAILED.value
                    # A failed OCR back-fill must not affect ai_status.
                    if kind != JobKind.OCR.value:
                        shot = session.get(Screenshot, job.screenshot_id)
                        if shot is not None:
                            shot.ai_status = ProcessStatus.FAILED.value
                else:
                    job.status = JobStatus.PENDING.value
                job.last_error = (error or "")[:1000]
            job.finished_at = datetime.utcnow()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-level AnalyzeWorker instance, created when this module is imported.
worker = AnalyzeWorker()
|