feat: CutThenThink v3.0 初始版本
完整实现 Tauri + Vanilla JS 轻量级截图工具 Phase 1 - 项目搭建 - Tauri 2.x 项目初始化 - Vite 前端项目搭建 - 基础 UI 框架(CSS 变量、组件库) - 构建配置优化 Phase 2 - 核心截图功能 - 全屏/区域/窗口截图 - 截图预览和管理 - 文件命名和缩略图 - 全局快捷键集成 Phase 3 - 上传与存储 - 多图床上传(GitHub/Imgur/自定义) - 配置管理系统 - SQLite 数据库 Phase 4 - OCR 集成 - 云端 OCR(百度/腾讯云) - 插件管理系统 - 本地 OCR 插件(Go) - OCR 结果处理 Phase 5 - AI 分类系统 - Claude/OpenAI API 集成 - Prompt 模板引擎 - 模板管理界面 - 自动分类流程 Phase 6 - 历史记录与管理 - 图库视图(网格/列表) - 搜索与筛选 - 批量操作 - 导出功能(JSON/CSV/ZIP) Phase 7 - 打包与发布 - 多平台构建配置 - CI/CD 工作流 - 图标和资源 - 安装包配置 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
4
src-tauri/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
/target/
|
||||
/gen/schemas
|
||||
4932
src-tauri/Cargo.lock
generated
Normal file
54
src-tauri/Cargo.toml
Normal file
@@ -0,0 +1,54 @@
|
||||
[package]
|
||||
name = "cut-think-lite"
|
||||
version = "0.1.0"
|
||||
description = "CutThenThink Lite - AI-powered Clipboard Manager"
|
||||
authors = ["CutThenThink Team"]
|
||||
license = "MIT"
|
||||
repository = "https://github.com/cutthenthink/cutThink-lite"
|
||||
edition = "2021"
|
||||
rust-version = "1.77.2"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
name = "app_lib"
|
||||
crate-type = ["staticlib", "cdylib", "rlib"]
|
||||
|
||||
[build-dependencies]
|
||||
tauri-build = { version = "2.5.4" }
|
||||
|
||||
[dependencies]
|
||||
serde_json = "1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
log = "0.4"
|
||||
tauri = { version = "2.10.0" }
|
||||
tauri-plugin-log = "2"
|
||||
tauri-plugin-global-shortcut = "2"
|
||||
tauri-plugin-shell = "2"
|
||||
screenshots = "0.7"
|
||||
image = "0.24"
|
||||
base64 = "0.21"
|
||||
chrono = "0.4"
|
||||
arboard = "3.2"
|
||||
anyhow = "1.0"
|
||||
dirs = "5.0"
|
||||
|
||||
# Phase 3 dependencies
|
||||
reqwest = { version = "0.11", features = ["json", "multipart"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
rusqlite = { version = "0.30", features = ["bundled", "chrono"] }
|
||||
uuid = { version = "1.0", features = ["v4", "serde"] }
|
||||
notify = "6.0"
|
||||
|
||||
# Phase 4 dependencies
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
hmac = "0.12"
|
||||
aes-gcm = "0.10"
|
||||
rand = "0.8"
|
||||
urlencoding = "2.1"
|
||||
futures-util = "0.3"
|
||||
tempfile = "3.3"
|
||||
|
||||
# Phase 5 dependencies (AI)
|
||||
thiserror = "1.0"
|
||||
48
src-tauri/appstream.xml
Normal file
@@ -0,0 +1,48 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<component type="desktop-application">
|
||||
<id>com.cutthenthink.app</id>
|
||||
<name>CutThenThink Lite</name>
|
||||
<summary>Lightweight screenshot and annotation tool</summary>
|
||||
<summary xml:lang="zh_CN">轻量级截图与标注工具</summary>
|
||||
<developer_name>CutThenThink</developer_name>
|
||||
<launchable type="desktop-id">com.cutthenthink.app.desktop</launchable>
|
||||
|
||||
<description>
|
||||
<p>
|
||||
CutThenThink Lite is a lightweight screenshot capture and annotation tool designed for quick visual communication.
|
||||
</p>
|
||||
<p xml:lang="zh_CN">
|
||||
CutThenThink Lite 是一个轻量级的截图捕获和标注工具,专为快速视觉交流而设计。
|
||||
</p>
|
||||
<p>Features:</p>
|
||||
<ul>
|
||||
<li>Quick screen capture with customizable shortcuts</li>
|
||||
<li>Advanced annotation tools (text, arrows, shapes, blur)</li>
|
||||
<li>Multiple save formats (PNG, JPEG, WebP)</li>
|
||||
<li>OCR text recognition plugin support</li>
|
||||
<li>Minimal resource usage</li>
|
||||
</ul>
|
||||
</description>
|
||||
|
||||
<metadata_license>MIT</metadata_license>
|
||||
<project_license>MIT</project_license>
|
||||
|
||||
<supports>
|
||||
<control>pointing</control>
|
||||
<control>keyboard</control>
|
||||
</supports>
|
||||
|
||||
<url type="homepage">https://github.com/cutthenthink/cutthink-lite</url>
|
||||
<url type="bugtracker">https://github.com/cutthenthink/cutthink-lite/issues</url>
|
||||
<url type="donation">https://github.com/sponsors/cutthenthink</url>
|
||||
|
||||
<content_rating type="oars-1.1" />
|
||||
|
||||
<releases>
|
||||
<release version="0.1.0" date="2025-02-12">
|
||||
<description>
|
||||
<p>Initial release</p>
|
||||
</description>
|
||||
</release>
|
||||
</releases>
|
||||
</component>
|
||||
3
src-tauri/build.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
fn main() {
|
||||
tauri_build::build()
|
||||
}
|
||||
11
src-tauri/capabilities/default.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"$schema": "../gen/schemas/desktop-schema.json",
|
||||
"identifier": "default",
|
||||
"description": "enables the default permissions",
|
||||
"windows": [
|
||||
"main"
|
||||
],
|
||||
"permissions": [
|
||||
"core:default"
|
||||
]
|
||||
}
|
||||
17
src-tauri/com.cutthenthink.app.desktop
Normal file
@@ -0,0 +1,17 @@
|
||||
[Desktop Entry]
|
||||
Name=CutThenThink Lite
|
||||
Comment=Lightweight screenshot and annotation tool
|
||||
Comment[zh_CN]=轻量级截图与标注工具
|
||||
GenericName=Screenshot Tool
|
||||
GenericName[zh_CN]=截图工具
|
||||
Exec=cutthink-lite %U
|
||||
Icon=com.cutthenthink.app
|
||||
Type=Application
|
||||
Terminal=false
|
||||
Categories=Utility;Graphics;2DGraphics;
|
||||
Keywords=screenshot;capture;annotation;image;
|
||||
Keywords[zh_CN]=截图;捕获;标注;图像;
|
||||
StartupNotify=true
|
||||
StartupWMClass=cutthink-lite
|
||||
MimeType=image/png;image/jpeg;image/jpg;image/webp;
|
||||
X-GNOME-UsesNotifications=true
|
||||
BIN
src-tauri/icons/128x128.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
src-tauri/icons/128x128@2x.png
Normal file
|
After Width: | Height: | Size: 23 KiB |
BIN
src-tauri/icons/32x32.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
src-tauri/icons/Square107x107Logo.png
Normal file
|
After Width: | Height: | Size: 9.0 KiB |
BIN
src-tauri/icons/Square142x142Logo.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
src-tauri/icons/Square150x150Logo.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
BIN
src-tauri/icons/Square284x284Logo.png
Normal file
|
After Width: | Height: | Size: 25 KiB |
BIN
src-tauri/icons/Square30x30Logo.png
Normal file
|
After Width: | Height: | Size: 2.0 KiB |
BIN
src-tauri/icons/Square310x310Logo.png
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
src-tauri/icons/Square44x44Logo.png
Normal file
|
After Width: | Height: | Size: 3.3 KiB |
BIN
src-tauri/icons/Square71x71Logo.png
Normal file
|
After Width: | Height: | Size: 5.9 KiB |
BIN
src-tauri/icons/Square89x89Logo.png
Normal file
|
After Width: | Height: | Size: 7.4 KiB |
BIN
src-tauri/icons/StoreLogo.png
Normal file
|
After Width: | Height: | Size: 3.9 KiB |
BIN
src-tauri/icons/icon.icns
Normal file
BIN
src-tauri/icons/icon.ico
Normal file
|
After Width: | Height: | Size: 37 KiB |
BIN
src-tauri/icons/icon.png
Normal file
|
After Width: | Height: | Size: 49 KiB |
91
src-tauri/nsis/custom.nsi
Normal file
@@ -0,0 +1,91 @@
|
||||
; Custom NSIS script for CutThenThink Lite
|
||||
; This provides additional configuration for Windows installer
|
||||
|
||||
!addincludedir "${CMAKE_CURRENT_SOURCE_DIR}/nsis"
|
||||
|
||||
; Modern UI Interface
|
||||
!include "MUI2.nsh"
|
||||
|
||||
; Installer Settings
|
||||
Name "CutThenThink Lite"
|
||||
OutFile "CutThenThink-Lite-Setup.exe"
|
||||
InstallDir "$PROGRAMFILES\CutThenThink Lite"
|
||||
InstallDirRegKey HKLM "Software\CutThenThink Lite" "InstallLocation"
|
||||
RequestExecutionLevel admin
|
||||
|
||||
; Variables
|
||||
Var StartMenuFolder
|
||||
|
||||
; Interface Settings
|
||||
!define MUI_ABORTWARNING
|
||||
!define MUI_ICON "icons\icon.ico"
|
||||
!define MUI_UNICON "icons\icon.ico"
|
||||
!define MUI_HEADERIMAGE
|
||||
!define MUI_HEADERIMAGE_BITMAP "icons\header.bmp" ; Optional
|
||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "icons\welcome.bmp" ; Optional
|
||||
|
||||
; Pages
|
||||
!insertmacro MUI_PAGE_WELCOME
|
||||
!insertmacro MUI_PAGE_LICENSE "LICENSE"
|
||||
!insertmacro MUI_PAGE_COMPONENTS
|
||||
!insertmacro MUI_PAGE_DIRECTORY
|
||||
!insertmacro MUI_PAGE_STARTMENU Application $StartMenuFolder
|
||||
!insertmacro MUI_PAGE_INSTFILES
|
||||
!insertmacro MUI_PAGE_FINISH
|
||||
|
||||
!insertmacro MUI_UNPAGE_WELCOME
|
||||
!insertmacro MUI_UNPAGE_CONFIRM
|
||||
!insertmacro MUI_UNPAGE_INSTFILES
|
||||
!insertmacro MUI_UNPAGE_FINISH
|
||||
|
||||
; Languages
|
||||
!insertmacro MUI_LANGUAGE "English"
|
||||
!insertmacro MUI_LANGUAGE "SimpChinese"
|
||||
|
||||
; Installer Sections
|
||||
Section "Main Application" SecMain
|
||||
SectionIn RO
|
||||
|
||||
SetOutPath $INSTDIR
|
||||
|
||||
; Install application files
|
||||
File /r "${CMAKE_CURRENT_SOURCE_DIR}\target\${RUST_TARGET}\release\bundle\nsis\*.*"
|
||||
|
||||
; Create uninstaller
|
||||
WriteUninstaller "$INSTDIR\Uninstall.exe"
|
||||
|
||||
; Register installation
|
||||
WriteRegStr HKLM "Software\CutThenThink Lite" "InstallLocation" "$INSTDIR"
|
||||
WriteRegStr HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\CutThenThink Lite" "DisplayName" "CutThenThink Lite"
|
||||
WriteRegStr HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\CutThenThink Lite" "UninstallString" "$INSTDIR\Uninstall.exe"
|
||||
WriteRegStr HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\CutThenThink Lite" "Publisher" "CutThenThink"
|
||||
WriteRegStr HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\CutThenThink Lite" "DisplayVersion" "${TAURI_APP_VERSION}"
|
||||
WriteRegDWORD HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\CutThenThink Lite" "NoModify" 1
|
||||
WriteRegDWORD HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\CutThenThink Lite" "NoRepair" 1
|
||||
|
||||
; Create Start Menu shortcuts
|
||||
!insertmacro MUI_STARTMENU_WRITE_BEGIN Application
|
||||
CreateDirectory "$SMPROGRAMS\$StartMenuFolder"
|
||||
CreateShortcut "$SMPROGRAMS\$StartMenuFolder\CutThenThink Lite.lnk" "$INSTDIR\CutThenThink Lite.exe"
|
||||
CreateShortcut "$SMPROGRAMS\$StartMenuFolder\Uninstall.lnk" "$INSTDIR\Uninstall.exe"
|
||||
!insertmacro MUI_STARTMENU_WRITE_END
|
||||
|
||||
; Create desktop shortcut
|
||||
CreateShortcut "$DESKTOP\CutThenThink Lite.lnk" "$INSTDIR\CutThenThink Lite.exe"
|
||||
|
||||
SectionEnd
|
||||
|
||||
; Uninstaller Section
|
||||
Section "Uninstall"
|
||||
Delete "$INSTDIR\Uninstall.exe"
|
||||
Delete "$INSTDIR\*.*"
|
||||
RMDir /r "$INSTDIR"
|
||||
|
||||
!insertmacro MUI_STARTMENU_GETFOLDER Application $StartMenuFolder
|
||||
Delete "$SMPROGRAMS\$StartMenuFolder\*.*"
|
||||
RMDir "$SMPROGRAMS\$StartMenuFolder"
|
||||
Delete "$DESKTOP\CutThenThink Lite.lnk"
|
||||
|
||||
DeleteRegKey HKLM "Software\CutThenThink Lite"
|
||||
DeleteRegKey HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\CutThenThink Lite"
|
||||
SectionEnd
|
||||
360
src-tauri/src/ai/classify.rs
Normal file
@@ -0,0 +1,360 @@
|
||||
use super::{AiClient, AiError, AiProvider, ClassificationResult, PromptEngine};
|
||||
use anyhow::{Context, Result};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// 分类器配置
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClassifierConfig {
|
||||
/// 使用的 AI 提供商
|
||||
pub provider: AiProvider,
|
||||
/// 最小置信度阈值(低于此值需要人工确认)
|
||||
pub min_confidence: f64,
|
||||
/// 是否自动应用分类结果
|
||||
pub auto_apply: bool,
|
||||
}
|
||||
|
||||
impl Default for ClassifierConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
provider: AiProvider::Claude,
|
||||
min_confidence: 0.7,
|
||||
auto_apply: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 分类器
|
||||
pub struct Classifier {
|
||||
/// AI 客户端
|
||||
client: Mutex<AiClient>,
|
||||
/// Prompt 引擎
|
||||
prompt_engine: PromptEngine,
|
||||
/// 分类器配置
|
||||
config: ClassifierConfig,
|
||||
}
|
||||
|
||||
impl Classifier {
|
||||
/// 创建新的分类器
|
||||
pub fn new(config: ClassifierConfig) -> Self {
|
||||
let client = AiClient::new(config.provider);
|
||||
|
||||
Self {
|
||||
client: Mutex::new(client),
|
||||
prompt_engine: PromptEngine::new(),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// 配置 AI 客户端
|
||||
pub fn configure_claude(&self, api_key: String, model: Option<String>) -> Result<()> {
|
||||
let mut client = self.client.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let claude_config = super::client::ClaudeConfig {
|
||||
api_key,
|
||||
model: model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string()),
|
||||
max_tokens: 4096,
|
||||
temperature: 0.3, // 较低的温度以获得更一致的结果
|
||||
};
|
||||
|
||||
*client = AiClient::new(self.config.provider)
|
||||
.with_claude_config(claude_config);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 配置 OpenAI 客户端
|
||||
pub fn configure_openai(
|
||||
&self,
|
||||
api_key: String,
|
||||
model: Option<String>,
|
||||
base_url: Option<String>,
|
||||
) -> Result<()> {
|
||||
let mut client = self.client.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let openai_config = super::client::OpenAiConfig {
|
||||
api_key,
|
||||
model: model.unwrap_or_else(|| "gpt-4o".to_string()),
|
||||
max_tokens: 4096,
|
||||
temperature: 0.3,
|
||||
base_url,
|
||||
};
|
||||
|
||||
*client = AiClient::new(self.config.provider)
|
||||
.with_openai_config(openai_config);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 对内容进行分类
|
||||
pub async fn classify(
|
||||
&self,
|
||||
template_id: Option<&str>,
|
||||
variables: &HashMap<String, String>,
|
||||
) -> Result<ClassificationResult> {
|
||||
// 获取模板
|
||||
let template = if let Some(id) = template_id {
|
||||
self.prompt_engine.get_template(id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", id))?
|
||||
} else {
|
||||
self.prompt_engine.get_default_template()
|
||||
};
|
||||
|
||||
// 渲染 Prompt
|
||||
let (system_prompt, user_prompt) = self
|
||||
.prompt_engine
|
||||
.render_template(template, variables)?;
|
||||
|
||||
// 构建消息
|
||||
use super::client::{Message, MessageRole};
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: MessageRole::System,
|
||||
content: system_prompt,
|
||||
},
|
||||
Message {
|
||||
role: MessageRole::User,
|
||||
content: user_prompt,
|
||||
},
|
||||
];
|
||||
|
||||
// 调用 AI
|
||||
let client = self.client.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
// 释放锁后再调用
|
||||
let client_ref = unsafe { &*(&*client as *const AiClient) };
|
||||
|
||||
let result = client_ref.chat(messages).await?;
|
||||
|
||||
// 解析分类结果
|
||||
self.parse_classification_result(&result.content, template_id)
|
||||
}
|
||||
|
||||
/// 流式分类(实时返回结果)
|
||||
pub async fn classify_stream(
|
||||
&self,
|
||||
template_id: Option<&str>,
|
||||
variables: &HashMap<String, String>,
|
||||
mut callback: impl FnMut(super::client::StreamChunk),
|
||||
) -> Result<ClassificationResult> {
|
||||
// 获取模板
|
||||
let template = if let Some(id) = template_id {
|
||||
self.prompt_engine.get_template(id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", id))?
|
||||
} else {
|
||||
self.prompt_engine.get_default_template()
|
||||
};
|
||||
|
||||
// 渲染 Prompt
|
||||
let (system_prompt, user_prompt) = self
|
||||
.prompt_engine
|
||||
.render_template(template, variables)?;
|
||||
|
||||
// 构建消息
|
||||
use super::client::{Message, MessageRole};
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: MessageRole::System,
|
||||
content: system_prompt,
|
||||
},
|
||||
Message {
|
||||
role: MessageRole::User,
|
||||
content: user_prompt,
|
||||
},
|
||||
];
|
||||
|
||||
// 调用 AI 流式接口
|
||||
let client = self.client.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
// 释放锁后再调用
|
||||
let client_ref = unsafe { &*(&*client as *const AiClient) };
|
||||
|
||||
let result = client_ref.chat_stream(messages, callback).await?;
|
||||
|
||||
// 解析分类结果
|
||||
self.parse_classification_result(&result.content, template_id)
|
||||
}
|
||||
|
||||
/// 解析分类结果
|
||||
fn parse_classification_result(
|
||||
&self,
|
||||
content: &str,
|
||||
template_id: Option<&str>,
|
||||
) -> Result<ClassificationResult> {
|
||||
// 尝试从内容中提取 JSON
|
||||
let json_str = self.extract_json(content)?;
|
||||
|
||||
// 解析 JSON
|
||||
let value: Value = serde_json::from_str(&json_str)
|
||||
.context("Failed to parse classification result as JSON")?;
|
||||
|
||||
// 提取字段
|
||||
let category = value
|
||||
.get("category")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("未分类")
|
||||
.to_string();
|
||||
|
||||
let subcategory = value
|
||||
.get("subcategory")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let tags = value
|
||||
.get("tags")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let confidence = value
|
||||
.get("confidence")
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.5);
|
||||
|
||||
let reasoning = value
|
||||
.get("reasoning")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Ok(ClassificationResult {
|
||||
category,
|
||||
subcategory,
|
||||
tags,
|
||||
confidence,
|
||||
reasoning,
|
||||
})
|
||||
}
|
||||
|
||||
/// 从文本中提取 JSON
|
||||
fn extract_json(&self, text: &str) -> Result<String> {
|
||||
// 查找 ```json 代码块
|
||||
if let Some(start) = text.find("```json") {
|
||||
let json_start = start + 7;
|
||||
if let Some(end) = text[json_start..].find("```") {
|
||||
let json_str = text[json_start..json_start + end].trim();
|
||||
return Ok(json_str.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// 查找 ``` 代码块
|
||||
if let Some(start) = text.find("```") {
|
||||
let json_start = start + 3;
|
||||
if let Some(end) = text[json_start..].find("```") {
|
||||
let json_str = text[json_start..json_start + end].trim();
|
||||
return Ok(json_str.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试查找 { ... }
|
||||
if let Some(start) = text.find('{') {
|
||||
if let Some(end) = text.rfind('}') {
|
||||
if end > start {
|
||||
return Ok(text[start..=end].to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果都找不到,返回原文本
|
||||
Ok(text.to_string())
|
||||
}
|
||||
|
||||
/// 评估是否需要人工确认
|
||||
pub fn needs_confirmation(&self, result: &ClassificationResult) -> bool {
|
||||
result.confidence < self.config.min_confidence
|
||||
}
|
||||
|
||||
/// 获取可用的模板列表
|
||||
pub fn available_templates(&self) -> Vec<&super::template::Template> {
|
||||
self.prompt_engine.list_templates()
|
||||
}
|
||||
|
||||
/// 获取模板
|
||||
pub fn get_template(&self, id: &str) -> Option<&super::template::Template> {
|
||||
self.prompt_engine.get_template(id)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_json_extraction() {
|
||||
let classifier = Classifier::new(ClassifierConfig::default());
|
||||
|
||||
let text = r#"这是一些前置文本
|
||||
```json
|
||||
{
|
||||
"category": "代码",
|
||||
"confidence": 0.95
|
||||
}
|
||||
```
|
||||
后置文本"#;
|
||||
|
||||
let json = classifier.extract_json(text).unwrap();
|
||||
assert!(json.contains("category"));
|
||||
assert!(json.contains("代码"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classification_parsing() {
|
||||
let classifier = Classifier::new(ClassifierConfig::default());
|
||||
|
||||
let json = r#"{
|
||||
"category": "代码",
|
||||
"subcategory": "Python",
|
||||
"tags": ["编程", "脚本", "自动化"],
|
||||
"confidence": 0.95,
|
||||
"reasoning": "这是一段 Python 代码"
|
||||
}"#;
|
||||
|
||||
let result = classifier.parse_classification_result(json, None).unwrap();
|
||||
assert_eq!(result.category, "代码");
|
||||
assert_eq!(result.subcategory, Some("Python".to_string()));
|
||||
assert_eq!(result.tags.len(), 3);
|
||||
assert_eq!(result.confidence, 0.95);
|
||||
assert!(result.reasoning.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_confidence_check() {
|
||||
let config = ClassifierConfig {
|
||||
provider: AiProvider::Claude,
|
||||
min_confidence: 0.8,
|
||||
auto_apply: false,
|
||||
};
|
||||
|
||||
let classifier = Classifier::new(config);
|
||||
|
||||
let high_conf = ClassificationResult {
|
||||
category: "测试".to_string(),
|
||||
subcategory: None,
|
||||
tags: vec![],
|
||||
confidence: 0.9,
|
||||
reasoning: None,
|
||||
};
|
||||
|
||||
let low_conf = ClassificationResult {
|
||||
category: "测试".to_string(),
|
||||
subcategory: None,
|
||||
tags: vec![],
|
||||
confidence: 0.7,
|
||||
reasoning: None,
|
||||
};
|
||||
|
||||
assert!(!classifier.needs_confirmation(&high_conf));
|
||||
assert!(classifier.needs_confirmation(&low_conf));
|
||||
}
|
||||
}
|
||||
511
src-tauri/src/ai/client.rs
Normal file
@@ -0,0 +1,511 @@
|
||||
use super::{AiError, AiResult};
|
||||
use anyhow::Result;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// AI 提供商类型
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AiProvider {
|
||||
Claude,
|
||||
OpenAi,
|
||||
}
|
||||
|
||||
/// Claude API 配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeConfig {
|
||||
pub api_key: String,
|
||||
#[serde(default = "default_claude_model")]
|
||||
pub model: String,
|
||||
#[serde(default = "default_claude_max_tokens")]
|
||||
pub max_tokens: u32,
|
||||
#[serde(default = "default_claude_temperature")]
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
fn default_claude_model() -> String {
|
||||
"claude-3-5-sonnet-20241022".to_string()
|
||||
}
|
||||
|
||||
fn default_claude_max_tokens() -> u32 {
|
||||
4096
|
||||
}
|
||||
|
||||
fn default_claude_temperature() -> f32 {
|
||||
0.7
|
||||
}
|
||||
|
||||
/// OpenAI API 配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAiConfig {
|
||||
pub api_key: String,
|
||||
#[serde(default = "default_openai_model")]
|
||||
pub model: String,
|
||||
#[serde(default = "default_openai_max_tokens")]
|
||||
pub max_tokens: u32,
|
||||
#[serde(default = "default_openai_temperature")]
|
||||
pub temperature: f32,
|
||||
pub base_url: Option<String>,
|
||||
}
|
||||
|
||||
fn default_openai_model() -> String {
|
||||
"gpt-4o".to_string()
|
||||
}
|
||||
|
||||
fn default_openai_max_tokens() -> u32 {
|
||||
4096
|
||||
}
|
||||
|
||||
fn default_openai_temperature() -> f32 {
|
||||
0.7
|
||||
}
|
||||
|
||||
/// 消息角色
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MessageRole {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
/// 聊天消息
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: MessageRole,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// 流式响应块
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamChunk {
|
||||
/// 文本增量
|
||||
Text(String),
|
||||
/// 完成
|
||||
Done,
|
||||
/// 错误
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// 限流器
|
||||
struct RateLimiter {
|
||||
last_request: Arc<Mutex<std::time::Instant>>,
|
||||
min_interval: Duration,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
fn new(requests_per_second: u32) -> Self {
|
||||
Self {
|
||||
last_request: Arc::new(Mutex::new(std::time::Instant::now()
|
||||
.checked_sub(Duration::from_secs(1)).unwrap())),
|
||||
min_interval: Duration::from_secs(1) / requests_per_second.max(1),
|
||||
}
|
||||
}
|
||||
|
||||
async fn acquire(&self) {
|
||||
let mut last = self.last_request.lock().await;
|
||||
let elapsed = last.elapsed();
|
||||
if elapsed < self.min_interval {
|
||||
tokio::time::sleep(self.min_interval - elapsed).await;
|
||||
}
|
||||
*last = std::time::Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
/// AI 客户端
|
||||
pub struct AiClient {
|
||||
provider: AiProvider,
|
||||
http_client: reqwest::Client,
|
||||
rate_limiter: RateLimiter,
|
||||
claude_config: Option<ClaudeConfig>,
|
||||
openai_config: Option<OpenAiConfig>,
|
||||
}
|
||||
|
||||
impl AiClient {
|
||||
/// 创建新的 AI 客户端
|
||||
pub fn new(provider: AiProvider) -> Self {
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
// 限制每秒最多 5 个请求
|
||||
let rate_limiter = RateLimiter::new(5);
|
||||
|
||||
Self {
|
||||
provider,
|
||||
http_client,
|
||||
rate_limiter,
|
||||
claude_config: None,
|
||||
openai_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置 Claude 配置
|
||||
pub fn with_claude_config(mut self, config: ClaudeConfig) -> Self {
|
||||
self.claude_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置 OpenAI 配置
|
||||
pub fn with_openai_config(mut self, config: OpenAiConfig) -> Self {
|
||||
self.openai_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// 发送聊天请求(非流式)
|
||||
pub async fn chat(&self, messages: Vec<Message>) -> Result<AiResult> {
|
||||
self.rate_limiter.acquire().await;
|
||||
|
||||
match self.provider {
|
||||
AiProvider::Claude => self.claude_chat(messages).await,
|
||||
AiProvider::OpenAi => self.openai_chat(messages).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// 发送流式聊天请求
|
||||
pub async fn chat_stream(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
mut callback: impl FnMut(StreamChunk),
|
||||
) -> Result<AiResult> {
|
||||
self.rate_limiter.acquire().await;
|
||||
|
||||
match self.provider {
|
||||
AiProvider::Claude => self.claude_chat_stream(messages, callback).await,
|
||||
AiProvider::OpenAi => self.openai_chat_stream(messages, callback).await,
|
||||
}
|
||||
}
|
||||
|
||||
// ========== Claude API ==========
|
||||
|
||||
/// Claude 聊天请求
|
||||
async fn claude_chat(&self, messages: Vec<Message>) -> Result<AiResult> {
|
||||
let config = self.claude_config.as_ref()
|
||||
.ok_or_else(|| AiError::ConfigError("Claude 配置未设置".to_string()))?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ClaudeRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeResponse {
|
||||
id: String,
|
||||
content: Vec<ClaudeContent>,
|
||||
model: String,
|
||||
usage: ClaudeUsage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeContent {
|
||||
#[serde(rename = "type")]
|
||||
content_type: String,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeUsage {
|
||||
input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
}
|
||||
|
||||
let request = ClaudeRequest {
|
||||
model: config.model.clone(),
|
||||
messages,
|
||||
max_tokens: config.max_tokens,
|
||||
temperature: config.temperature,
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post("https://api.anthropic.com/v1/messages")
|
||||
.header("x-api-key", &config.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
self.handle_response(response).await
|
||||
}
|
||||
|
||||
/// Claude 流式聊天请求
|
||||
async fn claude_chat_stream(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
mut callback: impl FnMut(StreamChunk),
|
||||
) -> Result<AiResult> {
|
||||
let config = self.claude_config.as_ref()
|
||||
.ok_or_else(|| AiError::ConfigError("Claude 配置未设置".to_string()))?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ClaudeRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
let request = ClaudeRequest {
|
||||
model: config.model.clone(),
|
||||
messages,
|
||||
max_tokens: config.max_tokens,
|
||||
temperature: config.temperature,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post("https://api.anthropic.com/v1/messages")
|
||||
.header("x-api-key", &config.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await?;
|
||||
callback(StreamChunk::Error(error_text.clone()));
|
||||
return Err(AiError::ApiError(error_text));
|
||||
}
|
||||
|
||||
let mut full_content = String::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result?;
|
||||
let text = String::from_utf8_lossy(&chunk);
|
||||
|
||||
// 处理 SSE 格式
|
||||
for line in text.lines() {
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..];
|
||||
if data == "[DONE]" {
|
||||
callback(StreamChunk::Done);
|
||||
break;
|
||||
}
|
||||
|
||||
if let Ok(event) = serde_json::from_str::<serde_json::Value>(data) {
|
||||
if let Some(delta) = event.get("delta")
|
||||
.and_then(|d| d.get("text"))
|
||||
.and_then(|t| t.as_str())
|
||||
{
|
||||
full_content.push_str(delta);
|
||||
callback(StreamChunk::Text(delta.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AiResult {
|
||||
content: full_content,
|
||||
tokens_used: None,
|
||||
model: config.model.clone(),
|
||||
confidence: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== OpenAI API ==========
|
||||
|
||||
/// OpenAI 聊天请求
|
||||
async fn openai_chat(&self, messages: Vec<Message>) -> Result<AiResult> {
|
||||
let config = self.openai_config.as_ref()
|
||||
.ok_or_else(|| AiError::ConfigError("OpenAI 配置未设置".to_string()))?;
|
||||
|
||||
let base_url = config.base_url.as_ref()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("https://api.openai.com/v1");
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAiRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiResponse {
|
||||
id: String,
|
||||
choices: Vec<OpenAiChoice>,
|
||||
usage: OpenAiUsage,
|
||||
model: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiChoice {
|
||||
message: OpenAiMessage,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiUsage {
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
total_tokens: u32,
|
||||
}
|
||||
|
||||
let request = OpenAiRequest {
|
||||
model: config.model.clone(),
|
||||
messages,
|
||||
max_tokens: config.max_tokens,
|
||||
temperature: config.temperature,
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post(format!("{}/chat/completions", base_url))
|
||||
.header(AUTHORIZATION, format!("Bearer {}", config.api_key))
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
self.handle_response(response).await
|
||||
}
|
||||
|
||||
/// OpenAI 流式聊天请求
|
||||
async fn openai_chat_stream(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
mut callback: impl FnMut(StreamChunk),
|
||||
) -> Result<AiResult> {
|
||||
let config = self.openai_config.as_ref()
|
||||
.ok_or_else(|| AiError::ConfigError("OpenAI 配置未设置".to_string()))?;
|
||||
|
||||
let base_url = config.base_url.as_ref()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("https://api.openai.com/v1");
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAiRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
let request = OpenAiRequest {
|
||||
model: config.model.clone(),
|
||||
messages,
|
||||
max_tokens: config.max_tokens,
|
||||
temperature: config.temperature,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post(format!("{}/chat/completions", base_url))
|
||||
.header(AUTHORIZATION, format!("Bearer {}", config.api_key))
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await?;
|
||||
callback(StreamChunk::Error(error_text.clone()));
|
||||
return Err(AiError::ApiError(error_text));
|
||||
}
|
||||
|
||||
let mut full_content = String::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result?;
|
||||
let text = String::from_utf8_lossy(&chunk);
|
||||
|
||||
for line in text.lines() {
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..];
|
||||
if data == "[DONE]" {
|
||||
callback(StreamChunk::Done);
|
||||
break;
|
||||
}
|
||||
|
||||
if let Ok(event) = serde_json::from_str::<serde_json::Value>(data) {
|
||||
if let Some(delta) = event.get("choices")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("delta"))
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|t| t.as_str())
|
||||
{
|
||||
full_content.push_str(delta);
|
||||
callback(StreamChunk::Text(delta.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AiResult {
|
||||
content: full_content,
|
||||
tokens_used: None,
|
||||
model: config.model.clone(),
|
||||
confidence: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// 处理 HTTP 响应
|
||||
async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let status = response.status();
|
||||
|
||||
if status.is_success() {
|
||||
response.json::<T>().await.map_err(Into::into)
|
||||
} else {
|
||||
let error_text = response.text().await?;
|
||||
Err(if status.as_u16() == 401 {
|
||||
AiError::AuthError(error_text)
|
||||
} else if status.as_u16() == 429 {
|
||||
AiError::RateLimitError
|
||||
} else {
|
||||
AiError::ApiError(error_text)
|
||||
}.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_claude_config_serialization() {
|
||||
let config = ClaudeConfig {
|
||||
api_key: "test-key".to_string(),
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
max_tokens: 4096,
|
||||
temperature: 0.7,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
assert!(json.contains("api_key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization() {
|
||||
let message = Message {
|
||||
role: MessageRole::User,
|
||||
content: "Hello, AI!".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&message).unwrap();
|
||||
assert!(json.contains("user"));
|
||||
}
|
||||
}
|
||||
76
src-tauri/src/ai/mod.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
pub mod client;
|
||||
pub mod prompt;
|
||||
pub mod classify;
|
||||
pub mod template;
|
||||
|
||||
pub use client::{AiClient, AiProvider, StreamChunk};
|
||||
pub use prompt::PromptEngine;
|
||||
pub use classify::Classifier;
|
||||
pub use template::{Template, TemplateManager, TemplateVariable};
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
/// AI 服务错误类型
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AiError {
|
||||
#[error("API 错误: {0}")]
|
||||
ApiError(String),
|
||||
|
||||
#[error("网络错误: {0}")]
|
||||
NetworkError(String),
|
||||
|
||||
#[error("认证失败: {0}")]
|
||||
AuthError(String),
|
||||
|
||||
#[error("限流: 请稍后再试")]
|
||||
RateLimitError,
|
||||
|
||||
#[error("配置错误: {0}")]
|
||||
ConfigError(String),
|
||||
|
||||
#[error("模板错误: {0}")]
|
||||
TemplateError(String),
|
||||
|
||||
#[error("其他错误: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for AiError {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
if err.is_timeout() {
|
||||
AiError::NetworkError("请求超时".to_string())
|
||||
} else if err.is_connect() {
|
||||
AiError::NetworkError("连接失败".to_string())
|
||||
} else {
|
||||
AiError::NetworkError(err.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// AI 服务结果
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AiResult {
|
||||
/// 生成的文本内容
|
||||
pub content: String,
|
||||
/// 使用的 Token 数量(估算)
|
||||
pub tokens_used: Option<usize>,
|
||||
/// 模型名称
|
||||
pub model: String,
|
||||
/// 置信度评分(0.0 - 1.0)
|
||||
pub confidence: Option<f64>,
|
||||
}
|
||||
|
||||
/// 分类结果
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ClassificationResult {
|
||||
/// 主分类
|
||||
pub category: String,
|
||||
/// 子分类
|
||||
pub subcategory: Option<String>,
|
||||
/// 标签列表
|
||||
pub tags: Vec<String>,
|
||||
/// 置信度评分(0.0 - 1.0)
|
||||
pub confidence: f64,
|
||||
/// AI 提供的推理说明
|
||||
pub reasoning: Option<String>,
|
||||
}
|
||||
439
src-tauri/src/ai/prompt.rs
Normal file
@@ -0,0 +1,439 @@
|
||||
use super::template::{Template, TemplateVariable};
|
||||
use anyhow::{Context, Result};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Prompt 模板引擎
|
||||
pub struct PromptEngine {
|
||||
/// 内置模板库
|
||||
builtin_templates: HashMap<String, Template>,
|
||||
}
|
||||
|
||||
impl PromptEngine {
|
||||
/// 创建新的 Prompt 引擎
|
||||
pub fn new() -> Self {
|
||||
let mut engine = Self {
|
||||
builtin_templates: HashMap::new(),
|
||||
};
|
||||
|
||||
// 初始化内置模板
|
||||
engine.init_builtin_templates();
|
||||
|
||||
engine
|
||||
}
|
||||
|
||||
/// 初始化内置模板
|
||||
fn init_builtin_templates(&mut self) {
|
||||
// 通用分类模板
|
||||
let general_template = Template {
|
||||
id: "builtin_general".to_string(),
|
||||
name: "通用分类".to_string(),
|
||||
description: "适用于大多数场景的内容分类".to_string(),
|
||||
category: "general".to_string(),
|
||||
system_prompt: r#"你是一个专业的内容分类助手。你的任务是根据提供的内容,将其归类到合适的类别中。
|
||||
|
||||
请遵循以下规则:
|
||||
1. 分析内容的主要主题和目的
|
||||
2. 选择最合适的主分类
|
||||
3. 如果可能,提供更具体的子分类
|
||||
4. 提取 3-5 个相关标签
|
||||
5. 给出你的置信度评分(0.0-1.0)
|
||||
|
||||
请以 JSON 格式返回结果:
|
||||
```json
|
||||
{
|
||||
"category": "主分类名称",
|
||||
"subcategory": "子分类名称(可选)",
|
||||
"tags": ["标签1", "标签2", "标签3"],
|
||||
"confidence": 0.95,
|
||||
"reasoning": "分类理由说明"
|
||||
}
|
||||
```"#
|
||||
.to_string(),
|
||||
user_prompt_template: r#"请对以下内容进行分类:
|
||||
|
||||
内容类型:{{content_type}}
|
||||
{{#if image_path}}
|
||||
包含图片:是
|
||||
{{/if}}
|
||||
{{#if ocr_text}}
|
||||
OCR 识别文本:
|
||||
{{ocr_text}}
|
||||
{{/if}}
|
||||
{{#if content}}
|
||||
文本内容:
|
||||
{{content}}
|
||||
{{/if}}
|
||||
|
||||
请分析并返回分类结果。"#.to_string(),
|
||||
variables: vec![
|
||||
TemplateVariable {
|
||||
name: "content_type".to_string(),
|
||||
description: "内容类型(text/image/file)".to_string(),
|
||||
required: true,
|
||||
default_value: Some("text".to_string()),
|
||||
},
|
||||
TemplateVariable {
|
||||
name: "content".to_string(),
|
||||
description: "文本内容".to_string(),
|
||||
required: false,
|
||||
default_value: None,
|
||||
},
|
||||
TemplateVariable {
|
||||
name: "ocr_text".to_string(),
|
||||
description: "OCR 识别的文本".to_string(),
|
||||
required: false,
|
||||
default_value: None,
|
||||
},
|
||||
TemplateVariable {
|
||||
name: "image_path".to_string(),
|
||||
description: "图片路径".to_string(),
|
||||
required: false,
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
self.builtin_templates.insert("general".to_string(), general_template);
|
||||
|
||||
// 代码分类模板
|
||||
let code_template = Template {
|
||||
id: "builtin_code".to_string(),
|
||||
name: "代码片段分类".to_string(),
|
||||
description: "专门用于分类和分析代码片段".to_string(),
|
||||
category: "code".to_string(),
|
||||
system_prompt: r#"你是一个专业的代码分析助手。你的任务是分析提供的代码片段,识别其编程语言、用途和特点。
|
||||
|
||||
请关注:
|
||||
1. 编程语言和框架
|
||||
2. 代码的主要功能
|
||||
3. 是否是完整代码还是片段
|
||||
4. 代码所属的领域(Web开发、数据处理、算法等)
|
||||
|
||||
请以 JSON 格式返回:
|
||||
```json
|
||||
{
|
||||
"category": "代码",
|
||||
"subcategory": "具体编程语言或框架",
|
||||
"tags": ["语言", "框架", "功能领域"],
|
||||
"confidence": 0.95,
|
||||
"reasoning": "分析说明"
|
||||
}
|
||||
```"#
|
||||
.to_string(),
|
||||
user_prompt_template: r#"请分析以下代码:
|
||||
|
||||
{{content}}
|
||||
|
||||
{{#if ocr_text}}
|
||||
如果上述代码是从图片中识别的,请使用以下 OCR 结果:
|
||||
{{ocr_text}}
|
||||
{{/if}}
|
||||
|
||||
请返回编程语言、框架和功能分析。"#.to_string(),
|
||||
variables: vec![
|
||||
TemplateVariable {
|
||||
name: "content".to_string(),
|
||||
description: "代码内容".to_string(),
|
||||
required: true,
|
||||
default_value: None,
|
||||
},
|
||||
TemplateVariable {
|
||||
name: "ocr_text".to_string(),
|
||||
description: "OCR 识别的代码文本".to_string(),
|
||||
required: false,
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
self.builtin_templates.insert("code".to_string(), code_template);
|
||||
|
||||
// 票据/发票分类模板
|
||||
let invoice_template = Template {
|
||||
id: "builtin_invoice".to_string(),
|
||||
name: "票据发票分类".to_string(),
|
||||
description: "用于分类和提取票据、发票信息".to_string(),
|
||||
category: "invoice".to_string(),
|
||||
system_prompt: r#"你是一个专业的票据识别助手。你的任务是识别和分析票据、发票类型,并提取关键信息。
|
||||
|
||||
请关注:
|
||||
1. 票据类型(发票、收据、订单等)
|
||||
2. 金额信息
|
||||
3. 商户信息
|
||||
4. 日期信息
|
||||
5. 票据的用途和性质
|
||||
|
||||
请以 JSON 格式返回:
|
||||
```json
|
||||
{
|
||||
"category": "票据",
|
||||
"subcategory": "具体票据类型",
|
||||
"tags": ["金额范围", "商户类型", "用途"],
|
||||
"confidence": 0.95,
|
||||
"reasoning": "识别说明"
|
||||
}
|
||||
```"#
|
||||
.to_string(),
|
||||
user_prompt_template: r#"请识别以下票据:
|
||||
|
||||
{{#if ocr_text}}
|
||||
票据内容:
|
||||
{{ocr_text}}
|
||||
{{/if}}
|
||||
|
||||
{{#if image_path}}
|
||||
这是一个包含票据的图片。
|
||||
{{/if}}
|
||||
|
||||
请返回票据类型和关键信息。"#.to_string(),
|
||||
variables: vec![
|
||||
TemplateVariable {
|
||||
name: "ocr_text".to_string(),
|
||||
description: "OCR 识别的票据文本".to_string(),
|
||||
required: true,
|
||||
default_value: None,
|
||||
},
|
||||
TemplateVariable {
|
||||
name: "image_path".to_string(),
|
||||
description: "票据图片路径".to_string(),
|
||||
required: false,
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
self.builtin_templates.insert("invoice".to_string(), invoice_template);
|
||||
|
||||
// 对话分类模板
|
||||
let conversation_template = Template {
|
||||
id: "builtin_conversation".to_string(),
|
||||
name: "对话内容分类".to_string(),
|
||||
description: "用于分类聊天记录、对话内容".to_string(),
|
||||
category: "conversation".to_string(),
|
||||
system_prompt: r#"你是一个专业的对话分析助手。你的任务是分析对话内容,识别对话的性质和主题。
|
||||
|
||||
请关注:
|
||||
1. 对话的性质(工作、休闲、客服、技术支持等)
|
||||
2. 主要话题
|
||||
3. 参与者角色
|
||||
4. 情感倾向(可选)
|
||||
|
||||
请以 JSON 格式返回:
|
||||
```json
|
||||
{
|
||||
"category": "对话",
|
||||
"subcategory": "对话性质",
|
||||
"tags": ["话题1", "话题2", "参与者类型"],
|
||||
"confidence": 0.95,
|
||||
"reasoning": "分析说明"
|
||||
}
|
||||
```"#
|
||||
.to_string(),
|
||||
user_prompt_template: r#"请分析以下对话:
|
||||
|
||||
{{content}}
|
||||
|
||||
{{#if ocr_text}}
|
||||
如果上述对话是从图片中识别的,请使用以下 OCR 结果:
|
||||
{{ocr_text}}
|
||||
{{/if}}
|
||||
|
||||
请返回对话类型和主题分析。"#.to_string(),
|
||||
variables: vec![
|
||||
TemplateVariable {
|
||||
name: "content".to_string(),
|
||||
description: "对话文本内容".to_string(),
|
||||
required: true,
|
||||
default_value: None,
|
||||
},
|
||||
TemplateVariable {
|
||||
name: "ocr_text".to_string(),
|
||||
description: "OCR 识别的对话文本".to_string(),
|
||||
required: false,
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
self.builtin_templates.insert("conversation".to_string(), conversation_template);
|
||||
}
|
||||
|
||||
/// 渲染模板
|
||||
pub fn render_template(
|
||||
&self,
|
||||
template: &Template,
|
||||
variables: &HashMap<String, String>,
|
||||
) -> Result<(String, String)> {
|
||||
// 渲染 system prompt
|
||||
let system_prompt = template.system_prompt.clone();
|
||||
|
||||
// 渲染 user prompt
|
||||
let user_prompt = self.render_user_prompt(&template.user_prompt_template, variables)?;
|
||||
|
||||
Ok((system_prompt, user_prompt))
|
||||
}
|
||||
|
||||
/// 渲染用户提示词模板
|
||||
fn render_user_prompt(
|
||||
&self,
|
||||
template: &str,
|
||||
variables: &HashMap<String, String>,
|
||||
) -> Result<String> {
|
||||
let mut result = template.to_string();
|
||||
|
||||
// 处理条件块 {{#if var}}...{{/if}}
|
||||
let mut replaced = String::new();
|
||||
let mut chars = template.chars().peekable();
|
||||
let mut pos = 0;
|
||||
|
||||
while pos < template.len() {
|
||||
// 查找 {{#if
|
||||
if let Some(start) = template[pos..].find("{{#if ") {
|
||||
let block_start = pos + start;
|
||||
let var_start = block_start + 6; // "{{#if ".len()
|
||||
|
||||
// 查找闭合的 }}
|
||||
if let Some(var_end) = template[var_start..].find("}}") {
|
||||
let var_name = &template[var_start..var_start + var_end];
|
||||
let content_start = var_start + var_end + 2;
|
||||
|
||||
// 查找 {{/if}}
|
||||
if let Some(end_marker) = template[content_start..].find("{{/if}}") {
|
||||
let content_end = content_start + end_marker;
|
||||
let block_content = &template[content_start..content_end];
|
||||
|
||||
// 如果变量存在且有值,则包含内容
|
||||
if let Some(value) = variables.get(var_name) {
|
||||
if !value.is_empty() {
|
||||
replaced.push_str(&template[pos..block_start]);
|
||||
replaced.push_str(block_content);
|
||||
}
|
||||
}
|
||||
|
||||
pos = content_end + 7; // "{{/if}}".len()
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 替换简单变量 {{var}}
|
||||
if let Some(start) = template[pos..].find('{{') {
|
||||
replaced.push_str(&template[pos..pos + start]);
|
||||
let var_start = pos + start + 2;
|
||||
|
||||
if let Some(end) = template[var_start..].find("}}") {
|
||||
let var_name = &template[var_start..var_start + end];
|
||||
let trimmed = var_name.trim();
|
||||
|
||||
if let Some(value) = variables.get(trimmed) {
|
||||
replaced.push_str(value);
|
||||
} else if let Some(template_var) = self
|
||||
.builtin_templates
|
||||
.values()
|
||||
.flat_map(|t| t.variables.iter())
|
||||
.find(|v| v.name == trimmed)
|
||||
{
|
||||
// 使用默认值
|
||||
if let Some(default) = &template_var.default_value {
|
||||
replaced.push_str(default);
|
||||
}
|
||||
}
|
||||
|
||||
pos = var_start + end + 2;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if pos < template.len() {
|
||||
replaced.push(template.chars().nth(pos).unwrap());
|
||||
pos += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 第二次遍历:替换剩余的简单变量
|
||||
result = replaced.clone();
|
||||
for (key, value) in variables.iter() {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
result = result.replace(&placeholder, value);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 获取内置模板
|
||||
pub fn get_template(&self, id: &str) -> Option<&Template> {
|
||||
self.builtin_templates.get(id)
|
||||
}
|
||||
|
||||
/// 列出所有内置模板
|
||||
pub fn list_templates(&self) -> Vec<&Template> {
|
||||
self.builtin_templates.values().collect()
|
||||
}
|
||||
|
||||
/// 获取默认模板
|
||||
pub fn get_default_template(&self) -> &Template {
|
||||
self.builtin_templates.get("general").unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PromptEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prompt_engine_creation() {
|
||||
let engine = PromptEngine::new();
|
||||
assert!(engine.get_template("general").is_some());
|
||||
assert!(engine.get_template("code").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_rendering() {
|
||||
let engine = PromptEngine::new();
|
||||
let template = engine.get_default_template();
|
||||
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("content_type".to_string(), "text".to_string());
|
||||
vars.insert("content".to_string(), "Hello, world!".to_string());
|
||||
|
||||
let (system, user) = engine.render_template(template, &vars).unwrap();
|
||||
assert!(!system.is_empty());
|
||||
assert!(!user.is_empty());
|
||||
assert!(user.contains("Hello, world!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conditional_blocks() {
|
||||
let engine = PromptEngine::new();
|
||||
let template = engine.get_default_template();
|
||||
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("content_type".to_string(), "text".to_string());
|
||||
|
||||
// 没有 ocr_text,条件块应该被移除
|
||||
let (system, user) = engine.render_template(template, &vars).unwrap();
|
||||
assert!(!user.contains("OCR 识别文本"));
|
||||
|
||||
// 有 ocr_text,条件块应该保留
|
||||
vars.insert("ocr_text".to_string(), "Sample OCR text".to_string());
|
||||
let (system, user) = engine.render_template(template, &vars).unwrap();
|
||||
assert!(user.contains("OCR 识别文本"));
|
||||
assert!(user.contains("Sample OCR text"));
|
||||
}
|
||||
}
|
||||
343
src-tauri/src/ai/template.rs
Normal file
@@ -0,0 +1,343 @@
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// 模板变量
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TemplateVariable {
|
||||
/// 变量名
|
||||
pub name: String,
|
||||
/// 变量描述
|
||||
pub description: String,
|
||||
/// 是否必需
|
||||
pub required: bool,
|
||||
/// 默认值
|
||||
pub default_value: Option<String>,
|
||||
}
|
||||
|
||||
/// Prompt 模板
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Template {
|
||||
/// 模板 ID
|
||||
pub id: String,
|
||||
/// 模板名称
|
||||
pub name: String,
|
||||
/// 模板描述
|
||||
pub description: String,
|
||||
/// 模板分类
|
||||
pub category: String,
|
||||
/// 系统 Prompt
|
||||
pub system_prompt: String,
|
||||
/// 用户 Prompt 模板
|
||||
pub user_prompt_template: String,
|
||||
/// 模板变量列表
|
||||
pub variables: Vec<TemplateVariable>,
|
||||
/// 创建时间
|
||||
pub created_at: String,
|
||||
/// 更新时间
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
/// 模板管理器
|
||||
pub struct TemplateManager {
|
||||
/// 自定义模板存储路径
|
||||
templates_dir: PathBuf,
|
||||
/// 模板缓存
|
||||
templates: Mutex<HashMap<String, Template>>,
|
||||
}
|
||||
|
||||
impl TemplateManager {
|
||||
/// 创建新的模板管理器
|
||||
pub fn new(templates_dir: PathBuf) -> Result<Self> {
|
||||
// 确保模板目录存在
|
||||
std::fs::create_dir_all(&templates_dir)
|
||||
.context("Failed to create templates directory")?;
|
||||
|
||||
let manager = Self {
|
||||
templates_dir,
|
||||
templates: Mutex::new(HashMap::new()),
|
||||
};
|
||||
|
||||
// 加载所有模板
|
||||
manager.load_templates()?;
|
||||
|
||||
Ok(manager)
|
||||
}
|
||||
|
||||
/// 加载所有模板
|
||||
fn load_templates(&self) -> Result<()> {
|
||||
let mut templates = self.templates.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
templates.clear();
|
||||
|
||||
// 读取目录中的所有 JSON 文件
|
||||
let entries = std::fs::read_dir(&self.templates_dir)
|
||||
.context("Failed to read templates directory")?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry.context("Failed to read directory entry")?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("json") {
|
||||
// 跳过备份文件
|
||||
if path.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.starts_with('.'))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// 读取模板文件
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.context("Failed to read template file")?;
|
||||
|
||||
// 解析模板
|
||||
let template: Template = serde_json::from_str(&content)
|
||||
.context("Failed to parse template")?;
|
||||
|
||||
templates.insert(template.id.clone(), template);
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Loaded {} templates", templates.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 保存模板
|
||||
fn save_template(&self, template: &Template) -> Result<()> {
|
||||
let file_path = self.templates_dir.join(format!("{}.json", template.id));
|
||||
|
||||
let content = serde_json::to_string_pretty(template)
|
||||
.context("Failed to serialize template")?;
|
||||
|
||||
std::fs::write(&file_path, content)
|
||||
.context("Failed to write template file")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 添加或更新模板
|
||||
pub fn upsert_template(&self, template: Template) -> Result<()> {
|
||||
// 验证模板
|
||||
self.validate_template(&template)?;
|
||||
|
||||
// 保存到文件
|
||||
self.save_template(&template)?;
|
||||
|
||||
// 更新缓存
|
||||
let mut templates = self.templates.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
templates.insert(template.id.clone(), template);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取模板
|
||||
pub fn get_template(&self, id: &str) -> Option<Template> {
|
||||
let templates = self.templates.lock()
|
||||
.ok()?;
|
||||
templates.get(id).cloned()
|
||||
}
|
||||
|
||||
/// 列出所有模板
|
||||
pub fn list_templates(&self) -> Result<Vec<Template>> {
|
||||
let templates = self.templates.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut list: Vec<_> = templates.values().cloned().collect();
|
||||
list.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
Ok(list)
|
||||
}
|
||||
|
||||
/// 按分类列出模板
|
||||
pub fn list_templates_by_category(&self, category: &str) -> Result<Vec<Template>> {
|
||||
let templates = self.templates.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut list: Vec<_> = templates
|
||||
.values()
|
||||
.filter(|t| t.category == category)
|
||||
.cloned()
|
||||
.collect();
|
||||
list.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
Ok(list)
|
||||
}
|
||||
|
||||
/// 删除模板
|
||||
pub fn delete_template(&self, id: &str) -> Result<bool> {
|
||||
// 删除文件
|
||||
let file_path = self.templates_dir.join(format!("{}.json", id));
|
||||
|
||||
let mut deleted = false;
|
||||
if file_path.exists() {
|
||||
std::fs::remove_file(&file_path)
|
||||
.context("Failed to delete template file")?;
|
||||
deleted = true;
|
||||
}
|
||||
|
||||
// 从缓存中移除
|
||||
let mut templates = self.templates.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
Ok(templates.remove(id).is_some() || deleted)
|
||||
}
|
||||
|
||||
/// 导入模板
|
||||
pub fn import_template(&self, json_data: &str) -> Result<Template> {
|
||||
let template: Template = serde_json::from_str(json_data)
|
||||
.context("Failed to parse template JSON")?;
|
||||
|
||||
self.upsert_template(template.clone())?;
|
||||
|
||||
Ok(template)
|
||||
}
|
||||
|
||||
/// 导出模板
|
||||
pub fn export_template(&self, id: &str) -> Result<String> {
|
||||
let template = self.get_template(id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", id))?;
|
||||
|
||||
serde_json::to_string_pretty(&template)
|
||||
.context("Failed to serialize template")
|
||||
}
|
||||
|
||||
/// 验证模板
|
||||
fn validate_template(&self, template: &Template) -> Result<()> {
|
||||
if template.id.is_empty() {
|
||||
return Err(anyhow::anyhow!("Template ID cannot be empty"));
|
||||
}
|
||||
|
||||
if template.name.is_empty() {
|
||||
return Err(anyhow::anyhow!("Template name cannot be empty"));
|
||||
}
|
||||
|
||||
if template.system_prompt.is_empty() {
|
||||
return Err(anyhow::anyhow!("System prompt cannot be empty"));
|
||||
}
|
||||
|
||||
if template.user_prompt_template.is_empty() {
|
||||
return Err(anyhow::anyhow!("User prompt template cannot be empty"));
|
||||
}
|
||||
|
||||
// 验证变量名称格式
|
||||
for var in &template.variables {
|
||||
if var.name.is_empty() {
|
||||
return Err(anyhow::anyhow!("Variable name cannot be empty"));
|
||||
}
|
||||
|
||||
// 变量名只能包含字母、数字和下划线
|
||||
if !var.name.chars().all(|c| c.is_alphanumeric() || c == '_') {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid variable name '{}': only alphanumeric and underscore allowed",
|
||||
var.name
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 测试模板渲染
|
||||
pub fn test_template(
|
||||
&self,
|
||||
template_id: &str,
|
||||
variables: &HashMap<String, String>,
|
||||
) -> Result<(String, String)> {
|
||||
let template = self.get_template(template_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", template_id))?;
|
||||
|
||||
// 渲染系统 Prompt
|
||||
let system_prompt = template.system_prompt.clone();
|
||||
|
||||
// 渲染用户 Prompt
|
||||
let user_prompt = self.render_user_prompt(&template.user_prompt_template, variables)?;
|
||||
|
||||
Ok((system_prompt, user_prompt))
|
||||
}
|
||||
|
||||
/// 渲染用户 Prompt
|
||||
fn render_user_prompt(
|
||||
&self,
|
||||
template: &str,
|
||||
variables: &HashMap<String, String>,
|
||||
) -> Result<String> {
|
||||
let mut result = template.to_string();
|
||||
|
||||
// 简单变量替换
|
||||
for (key, value) in variables.iter() {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
result = result.replace(&placeholder, value);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_template_serialization() {
|
||||
let template = Template {
|
||||
id: "test".to_string(),
|
||||
name: "Test Template".to_string(),
|
||||
description: "A test template".to_string(),
|
||||
category: "test".to_string(),
|
||||
system_prompt: "You are a test assistant.".to_string(),
|
||||
user_prompt_template: "Hello {{name}}!".to_string(),
|
||||
variables: vec![],
|
||||
created_at: Utc::now().to_rfc3339(),
|
||||
updated_at: Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&template).unwrap();
|
||||
assert!(json.contains("test"));
|
||||
|
||||
let parsed: Template = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.id, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_validation() {
|
||||
let manager = TemplateManager::new(
|
||||
std::env::temp_dir().join("test_templates")
|
||||
).unwrap();
|
||||
|
||||
let valid_template = Template {
|
||||
id: "valid".to_string(),
|
||||
name: "Valid Template".to_string(),
|
||||
description: "A valid template".to_string(),
|
||||
category: "test".to_string(),
|
||||
system_prompt: "System prompt".to_string(),
|
||||
user_prompt_template: "User prompt".to_string(),
|
||||
variables: vec![],
|
||||
created_at: Utc::now().to_rfc3339(),
|
||||
updated_at: Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
assert!(manager.validate_template(&valid_template).is_ok());
|
||||
|
||||
let invalid_template = Template {
|
||||
id: "".to_string(),
|
||||
name: "Invalid Template".to_string(),
|
||||
description: "An invalid template".to_string(),
|
||||
category: "test".to_string(),
|
||||
system_prompt: "System prompt".to_string(),
|
||||
user_prompt_template: "User prompt".to_string(),
|
||||
variables: vec![],
|
||||
created_at: Utc::now().to_rfc3339(),
|
||||
updated_at: Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
assert!(manager.validate_template(&invalid_template).is_err());
|
||||
}
|
||||
}
|
||||
205
src-tauri/src/config.rs
Normal file
@@ -0,0 +1,205 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// 图床配置类型
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ImageHostConfig {
|
||||
GitHub {
|
||||
token: String,
|
||||
owner: String,
|
||||
repo: String,
|
||||
path: String,
|
||||
branch: Option<String>,
|
||||
},
|
||||
Imgur {
|
||||
client_id: String,
|
||||
},
|
||||
Custom {
|
||||
url: String,
|
||||
headers: Option<Vec<HeaderItem>>,
|
||||
form_field: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// HTTP 头部项
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HeaderItem {
|
||||
pub name: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
/// 应用配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
/// 默认图床配置
|
||||
pub default_image_host: Option<ImageHostConfig>,
|
||||
/// 可用的图床配置列表
|
||||
pub image_hosts: Vec<ImageHostConfig>,
|
||||
/// 上传重试次数
|
||||
pub upload_retry_count: u32,
|
||||
/// 上传超时时间(秒)
|
||||
pub upload_timeout_seconds: u64,
|
||||
/// 是否自动复制上传后的链接
|
||||
pub auto_copy_link: bool,
|
||||
/// 保留的截图数量
|
||||
pub keep_screenshots_count: usize,
|
||||
/// 数据库路径
|
||||
pub database_path: Option<PathBuf>,
|
||||
/// OCR 配置
|
||||
pub ocr_config: Option<OcrAppConfig>,
|
||||
}
|
||||
|
||||
/// OCR 应用配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrAppConfig {
|
||||
/// 默认 OCR 引擎
|
||||
pub default_engine: String,
|
||||
/// 是否自动复制 OCR 结果
|
||||
pub auto_copy_result: bool,
|
||||
/// OCR 结果保留天数
|
||||
pub keep_results_days: u32,
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_image_host: None,
|
||||
image_hosts: Vec::new(),
|
||||
upload_retry_count: 3,
|
||||
upload_timeout_seconds: 30,
|
||||
auto_copy_link: true,
|
||||
keep_screenshots_count: 50,
|
||||
database_path: None,
|
||||
ocr_config: Some(OcrAppConfig {
|
||||
default_engine: "baidu".to_string(),
|
||||
auto_copy_result: false,
|
||||
keep_results_days: 30,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 配置管理器
|
||||
pub struct ConfigManager {
|
||||
config_dir: PathBuf,
|
||||
config_file: PathBuf,
|
||||
}
|
||||
|
||||
impl ConfigManager {
|
||||
/// 创建新的配置管理器
|
||||
pub fn new() -> Result<Self> {
|
||||
let config_dir = Self::get_config_dir()?;
|
||||
let config_file = config_dir.join("config.json");
|
||||
|
||||
// 确保配置目录存在
|
||||
fs::create_dir_all(&config_dir)
|
||||
.context("Failed to create config directory")?;
|
||||
|
||||
Ok(Self {
|
||||
config_dir,
|
||||
config_file,
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取配置目录
|
||||
fn get_config_dir() -> Result<PathBuf> {
|
||||
let config_dir = if cfg!(target_os = "macos") {
|
||||
// macOS: ~/Library/Application Support/CutThenThink
|
||||
dirs::config_dir()
|
||||
.map(|p| p.join("CutThenThink"))
|
||||
} else if cfg!(target_os = "windows") {
|
||||
// Windows: %APPDATA%/CutThenThink
|
||||
dirs::config_dir()
|
||||
.map(|p| p.join("CutThenThink"))
|
||||
} else {
|
||||
// Linux: ~/.config/CutThenThink
|
||||
dirs::home_dir()
|
||||
.map(|p| p.join(".config").join("CutThenThink"))
|
||||
};
|
||||
|
||||
config_dir.context("Failed to determine config directory")
|
||||
}
|
||||
|
||||
/// 加载配置
|
||||
pub fn load(&self) -> Result<AppConfig> {
|
||||
if !self.config_file.exists() {
|
||||
// 如果配置文件不存在,创建默认配置
|
||||
let default_config = AppConfig::default();
|
||||
self.save(&default_config)?;
|
||||
return Ok(default_config);
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&self.config_file)
|
||||
.context("Failed to read config file")?;
|
||||
|
||||
let config: AppConfig = serde_json::from_str(&content)
|
||||
.context("Failed to parse config file")?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// 保存配置
|
||||
pub fn save(&self, config: &AppConfig) -> Result<()> {
|
||||
let content = serde_json::to_string_pretty(config)
|
||||
.context("Failed to serialize config")?;
|
||||
|
||||
fs::write(&self.config_file, content)
|
||||
.context("Failed to write config file")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取配置文件路径
|
||||
pub fn config_file_path(&self) -> &Path {
|
||||
&self.config_file
|
||||
}
|
||||
|
||||
/// 获取数据目录
|
||||
pub fn data_dir(&self) -> PathBuf {
|
||||
self.config_dir.join("data")
|
||||
}
|
||||
|
||||
/// 获取数据库路径
|
||||
pub fn database_path(&self) -> PathBuf {
|
||||
self.data_dir().join("cutthink.db")
|
||||
}
|
||||
|
||||
/// 确保数据目录存在
|
||||
pub fn ensure_data_dir(&self) -> Result<()> {
|
||||
let data_dir = self.data_dir();
|
||||
fs::create_dir_all(&data_dir)
|
||||
.context("Failed to create data directory")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let config = AppConfig::default();
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
println!("{}", json);
|
||||
assert!(json.contains("upload_retry_count"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_host_config() {
|
||||
let github_config = ImageHostConfig::GitHub {
|
||||
token: "test_token".to_string(),
|
||||
owner: "test_owner".to_string(),
|
||||
repo: "test_repo".to_string(),
|
||||
path: "screenshots".to_string(),
|
||||
branch: Some("main".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&github_config).unwrap();
|
||||
println!("{}", json);
|
||||
assert!(json.contains("github"));
|
||||
}
|
||||
}
|
||||
749
src-tauri/src/database.rs
Normal file
@@ -0,0 +1,749 @@
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use rusqlite::{params, Connection};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// 数据库记录
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Record {
|
||||
pub id: String,
|
||||
pub record_type: RecordType,
|
||||
pub content: String,
|
||||
pub file_path: Option<String>,
|
||||
pub thumbnail: Option<String>,
|
||||
pub metadata: Option<String>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
/// 记录类型
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RecordType {
|
||||
Image,
|
||||
Text,
|
||||
File,
|
||||
}
|
||||
|
||||
/// 设置项
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Setting {
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
/// 分类结果
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Classification {
|
||||
pub id: String,
|
||||
pub record_id: String,
|
||||
pub category: String,
|
||||
pub subcategory: Option<String>,
|
||||
pub tags: String, // JSON 数组
|
||||
pub confidence: f64,
|
||||
pub reasoning: Option<String>,
|
||||
pub template_id: Option<String>,
|
||||
pub confirmed: bool,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// 分类历史
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClassificationHistory {
|
||||
pub id: String,
|
||||
pub record_id: String,
|
||||
pub category: String,
|
||||
pub subcategory: Option<String>,
|
||||
pub confidence: f64,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// 数据库管理器
|
||||
pub struct Database {
|
||||
conn: Mutex<Connection>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// 打开数据库连接
|
||||
pub fn open(db_path: &Path) -> Result<Self> {
|
||||
// 确保父目录存在
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)
|
||||
.context("Failed to create database directory")?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)
|
||||
.context("Failed to open database")?;
|
||||
|
||||
// 启用外键约束
|
||||
conn.execute("PRAGMA foreign_keys = ON", [])
|
||||
.context("Failed to enable foreign keys")?;
|
||||
|
||||
let db = Self {
|
||||
conn: Mutex::new(conn),
|
||||
};
|
||||
|
||||
// 初始化数据库表
|
||||
db.init_tables()?;
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// 初始化数据库表
|
||||
fn init_tables(&self) -> Result<()> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
// 创建 records 表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS records (
|
||||
id TEXT PRIMARY KEY,
|
||||
record_type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
file_path TEXT,
|
||||
thumbnail TEXT,
|
||||
metadata TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)",
|
||||
[],
|
||||
).context("Failed to create records table")?;
|
||||
|
||||
// 创建 settings 表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)",
|
||||
[],
|
||||
).context("Failed to create settings table")?;
|
||||
|
||||
// 创建 classifications 表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS classifications (
|
||||
id TEXT PRIMARY KEY,
|
||||
record_id TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
subcategory TEXT,
|
||||
tags TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
reasoning TEXT,
|
||||
template_id TEXT,
|
||||
confirmed BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (record_id) REFERENCES records(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
).context("Failed to create classifications table")?;
|
||||
|
||||
// 创建 classification_history 表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS classification_history (
|
||||
id TEXT PRIMARY KEY,
|
||||
record_id TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
subcategory TEXT,
|
||||
confidence REAL NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (record_id) REFERENCES records(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
).context("Failed to create classification_history table")?;
|
||||
|
||||
// 创建索引
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_records_created_at
|
||||
ON records(created_at DESC)",
|
||||
[],
|
||||
).context("Failed to create index on records.created_at")?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_records_type
|
||||
ON records(record_type)",
|
||||
[],
|
||||
).context("Failed to create index on records.record_type")?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_classifications_record_id
|
||||
ON classifications(record_id)",
|
||||
[],
|
||||
).context("Failed to create index on classifications.record_id")?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_classifications_category
|
||||
ON classifications(category)",
|
||||
[],
|
||||
).context("Failed to create index on classifications.category")?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_classification_history_record_id
|
||||
ON classification_history(record_id)",
|
||||
[],
|
||||
).context("Failed to create index on classification_history.record_id")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 插入记录
|
||||
pub fn insert_record(
|
||||
&self,
|
||||
record_type: RecordType,
|
||||
content: &str,
|
||||
file_path: Option<&str>,
|
||||
thumbnail: Option<&str>,
|
||||
metadata: Option<&str>,
|
||||
) -> Result<Record> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now().to_rfc3339();
|
||||
let record_type_str = serde_json::to_string(&record_type)?;
|
||||
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO records (id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
||||
params![
|
||||
&id,
|
||||
&record_type_str,
|
||||
content,
|
||||
file_path,
|
||||
thumbnail,
|
||||
metadata,
|
||||
&now,
|
||||
&now,
|
||||
],
|
||||
).context("Failed to insert record")?;
|
||||
|
||||
Ok(Record {
|
||||
id,
|
||||
record_type,
|
||||
content: content.to_string(),
|
||||
file_path: file_path.map(|s| s.to_string()),
|
||||
thumbnail: thumbnail.map(|s| s.to_string()),
|
||||
metadata: metadata.map(|s| s.to_string()),
|
||||
created_at: now.clone(),
|
||||
updated_at: now,
|
||||
})
|
||||
}
|
||||
|
||||
/// 根据 ID 获取记录
|
||||
pub fn get_record(&self, id: &str) -> Result<Option<Record>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at
|
||||
FROM records WHERE id = ?1"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let mut records = stmt.query_map(params![id], |row| {
|
||||
let record_type_str: String = row.get(1)?;
|
||||
let record_type: RecordType = serde_json::from_str(&record_type_str)
|
||||
.unwrap_or(RecordType::Text);
|
||||
|
||||
Ok(Record {
|
||||
id: row.get(0)?,
|
||||
record_type,
|
||||
content: row.get(2)?,
|
||||
file_path: row.get(3)?,
|
||||
thumbnail: row.get(4)?,
|
||||
metadata: row.get(5)?,
|
||||
created_at: row.get(6)?,
|
||||
updated_at: row.get(7)?,
|
||||
})
|
||||
}).context("Failed to query records")?;
|
||||
|
||||
records.next()
|
||||
.transpose()
|
||||
.context("Failed to parse record")
|
||||
}
|
||||
|
||||
/// 获取所有记录
|
||||
pub fn list_records(&self, limit: Option<usize>, offset: Option<usize>) -> Result<Vec<Record>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let limit = limit.unwrap_or(100);
|
||||
let offset = offset.unwrap_or(0);
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at
|
||||
FROM records
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?1 OFFSET ?2"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let records = stmt.query_map(params![limit, offset], |row| {
|
||||
let record_type_str: String = row.get(1)?;
|
||||
let record_type: RecordType = serde_json::from_str(&record_type_str)
|
||||
.unwrap_or(RecordType::Text);
|
||||
|
||||
Ok(Record {
|
||||
id: row.get(0)?,
|
||||
record_type,
|
||||
content: row.get(2)?,
|
||||
file_path: row.get(3)?,
|
||||
thumbnail: row.get(4)?,
|
||||
metadata: row.get(5)?,
|
||||
created_at: row.get(6)?,
|
||||
updated_at: row.get(7)?,
|
||||
})
|
||||
}).context("Failed to query records")?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("Failed to parse records")?;
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
/// 根据类型获取记录
|
||||
pub fn list_records_by_type(
|
||||
&self,
|
||||
record_type: RecordType,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<Record>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let limit = limit.unwrap_or(100);
|
||||
let record_type_str = serde_json::to_string(&record_type)?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at
|
||||
FROM records
|
||||
WHERE record_type = ?1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?2"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let records = stmt.query_map(params![record_type_str, limit], |row| {
|
||||
let record_type_str: String = row.get(1)?;
|
||||
let record_type: RecordType = serde_json::from_str(&record_type_str)
|
||||
.unwrap_or(RecordType::Text);
|
||||
|
||||
Ok(Record {
|
||||
id: row.get(0)?,
|
||||
record_type,
|
||||
content: row.get(2)?,
|
||||
file_path: row.get(3)?,
|
||||
thumbnail: row.get(4)?,
|
||||
metadata: row.get(5)?,
|
||||
created_at: row.get(6)?,
|
||||
updated_at: row.get(7)?,
|
||||
})
|
||||
}).context("Failed to query records")?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("Failed to parse records")?;
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
/// 更新记录
|
||||
pub fn update_record(
|
||||
&self,
|
||||
id: &str,
|
||||
content: Option<&str>,
|
||||
file_path: Option<&str>,
|
||||
thumbnail: Option<&str>,
|
||||
metadata: Option<&str>,
|
||||
) -> Result<Option<Record>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
// 首先获取现有记录
|
||||
let existing = self.get_record(id)?;
|
||||
|
||||
if existing.is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
conn.execute(
|
||||
"UPDATE records
|
||||
SET content = COALESCE(?1, content),
|
||||
file_path = COALESCE(?2, file_path),
|
||||
thumbnail = COALESCE(?3, thumbnail),
|
||||
metadata = COALESCE(?4, metadata),
|
||||
updated_at = ?5
|
||||
WHERE id = ?6",
|
||||
params![content, file_path, thumbnail, metadata, &now, id],
|
||||
).context("Failed to update record")?;
|
||||
|
||||
self.get_record(id).transpose()
|
||||
}
|
||||
|
||||
/// 删除记录
|
||||
pub fn delete_record(&self, id: &str) -> Result<bool> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let affected = conn.execute("DELETE FROM records WHERE id = ?1", params![id])
|
||||
.context("Failed to delete record")?;
|
||||
|
||||
Ok(affected > 0)
|
||||
}
|
||||
|
||||
/// 清空所有记录
|
||||
pub fn clear_records(&self) -> Result<usize> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let affected = conn.execute("DELETE FROM records", [])
|
||||
.context("Failed to clear records")?;
|
||||
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
/// 获取记录总数
|
||||
pub fn get_records_count(&self) -> Result<usize> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM records", [], |row| row.get(0))
|
||||
.context("Failed to count records")?;
|
||||
|
||||
Ok(count as usize)
|
||||
}
|
||||
|
||||
/// 设置配置
|
||||
pub fn set_setting(&self, key: &str, value: &str) -> Result<()> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO settings (key, value, updated_at) VALUES (?1, ?2, ?3)",
|
||||
params![key, value, &now],
|
||||
).context("Failed to set setting")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取配置
|
||||
pub fn get_setting(&self, key: &str) -> Result<Option<String>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let value = conn
|
||||
.query_row("SELECT value FROM settings WHERE key = ?1", params![key], |row| {
|
||||
row.get(0)
|
||||
})
|
||||
.optional()
|
||||
.context("Failed to get setting")?;
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
/// 删除配置
|
||||
pub fn delete_setting(&self, key: &str) -> Result<bool> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let affected = conn.execute("DELETE FROM settings WHERE key = ?1", params![key])
|
||||
.context("Failed to delete setting")?;
|
||||
|
||||
Ok(affected > 0)
|
||||
}
|
||||
|
||||
/// 获取所有配置
|
||||
pub fn list_settings(&self) -> Result<Vec<Setting>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT key, value, updated_at FROM settings ORDER BY key"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let settings = stmt.query_map([], |row| {
|
||||
Ok(Setting {
|
||||
key: row.get(0)?,
|
||||
value: row.get(1)?,
|
||||
updated_at: row.get(2)?,
|
||||
})
|
||||
}).context("Failed to query settings")?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("Failed to parse settings")?;
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
// ============= 分类相关方法 =============
|
||||
|
||||
/// 保存分类结果
|
||||
pub fn save_classification(
|
||||
&self,
|
||||
record_id: &str,
|
||||
category: &str,
|
||||
subcategory: Option<&str>,
|
||||
tags: &[String],
|
||||
confidence: f64,
|
||||
reasoning: Option<&str>,
|
||||
template_id: Option<&str>,
|
||||
) -> Result<Classification> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now().to_rfc3339();
|
||||
let tags_json = serde_json::to_string(tags)?;
|
||||
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO classifications (id, record_id, category, subcategory, tags, confidence, reasoning, template_id, confirmed, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
|
||||
params![
|
||||
&id,
|
||||
record_id,
|
||||
category,
|
||||
subcategory,
|
||||
&tags_json,
|
||||
confidence,
|
||||
reasoning,
|
||||
template_id,
|
||||
0, // confirmed = false
|
||||
&now,
|
||||
],
|
||||
).context("Failed to insert classification")?;
|
||||
|
||||
// 同时添加到历史记录
|
||||
self.add_classification_history(record_id, category, subcategory, confidence)?;
|
||||
|
||||
Ok(Classification {
|
||||
id,
|
||||
record_id: record_id.to_string(),
|
||||
category: category.to_string(),
|
||||
subcategory: subcategory.map(|s| s.to_string()),
|
||||
tags: tags_json,
|
||||
confidence,
|
||||
reasoning: reasoning.map(|s| s.to_string()),
|
||||
template_id: template_id.map(|s| s.to_string()),
|
||||
confirmed: false,
|
||||
created_at: now,
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取记录的分类结果
|
||||
pub fn get_classification(&self, record_id: &str) -> Result<Option<Classification>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, record_id, category, subcategory, tags, confidence, reasoning, template_id, confirmed, created_at
|
||||
FROM classifications WHERE record_id = ?1"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let mut classifications = stmt.query_map(params![record_id], |row| {
|
||||
Ok(Classification {
|
||||
id: row.get(0)?,
|
||||
record_id: row.get(1)?,
|
||||
category: row.get(2)?,
|
||||
subcategory: row.get(3)?,
|
||||
tags: row.get(4)?,
|
||||
confidence: row.get(5)?,
|
||||
reasoning: row.get(6)?,
|
||||
template_id: row.get(7)?,
|
||||
confirmed: row.get(8)?,
|
||||
created_at: row.get(9)?,
|
||||
})
|
||||
}).context("Failed to query classifications")?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("Failed to parse classifications")?;
|
||||
|
||||
Ok(classifications.pop())
|
||||
}
|
||||
|
||||
/// 确认分类结果
|
||||
pub fn confirm_classification(&self, id: &str) -> Result<bool> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let affected = conn.execute(
|
||||
"UPDATE classifications SET confirmed = 1 WHERE id = ?1",
|
||||
params![id],
|
||||
).context("Failed to confirm classification")?;
|
||||
|
||||
Ok(affected > 0)
|
||||
}
|
||||
|
||||
/// 添加分类历史
|
||||
fn add_classification_history(
|
||||
&self,
|
||||
record_id: &str,
|
||||
category: &str,
|
||||
subcategory: Option<&str>,
|
||||
confidence: f64,
|
||||
) -> Result<()> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO classification_history (id, record_id, category, subcategory, confidence, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![&id, record_id, category, subcategory, confidence, &now],
|
||||
).context("Failed to insert classification history")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取分类历史
|
||||
pub fn get_classification_history(
|
||||
&self,
|
||||
record_id: &str,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<ClassificationHistory>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let limit = limit.unwrap_or(50);
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, record_id, category, subcategory, confidence, created_at
|
||||
FROM classification_history
|
||||
WHERE record_id = ?1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?2"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let history = stmt.query_map(params![record_id, limit], |row| {
|
||||
Ok(ClassificationHistory {
|
||||
id: row.get(0)?,
|
||||
record_id: row.get(1)?,
|
||||
category: row.get(2)?,
|
||||
subcategory: row.get(3)?,
|
||||
confidence: row.get(4)?,
|
||||
created_at: row.get(5)?,
|
||||
})
|
||||
}).context("Failed to query classification history")?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("Failed to parse classification history")?;
|
||||
|
||||
Ok(history)
|
||||
}
|
||||
|
||||
/// 按分类获取记录
|
||||
pub fn list_records_by_category(
|
||||
&self,
|
||||
category: &str,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<Record>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let limit = limit.unwrap_or(100);
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT r.id, r.record_type, r.content, r.file_path, r.thumbnail, r.metadata, r.created_at, r.updated_at
|
||||
FROM records r
|
||||
INNER JOIN classifications c ON r.id = c.record_id
|
||||
WHERE c.category = ?1 AND c.confirmed = 1
|
||||
ORDER BY r.created_at DESC
|
||||
LIMIT ?2"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let records = stmt.query_map(params![category, limit], |row| {
|
||||
let record_type_str: String = row.get(1)?;
|
||||
let record_type: RecordType = serde_json::from_str(&record_type_str)
|
||||
.unwrap_or(RecordType::Text);
|
||||
|
||||
Ok(Record {
|
||||
id: row.get(0)?,
|
||||
record_type,
|
||||
content: row.get(2)?,
|
||||
file_path: row.get(3)?,
|
||||
thumbnail: row.get(4)?,
|
||||
metadata: row.get(5)?,
|
||||
created_at: row.get(6)?,
|
||||
updated_at: row.get(7)?,
|
||||
})
|
||||
}).context("Failed to query records by category")?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("Failed to parse records")?;
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
/// 删除分类
|
||||
pub fn delete_classification(&self, id: &str) -> Result<bool> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let affected = conn.execute("DELETE FROM classifications WHERE id = ?1", params![id])
|
||||
.context("Failed to delete classification")?;
|
||||
|
||||
Ok(affected > 0)
|
||||
}
|
||||
|
||||
/// 获取所有分类统计
|
||||
pub fn get_category_stats(&self) -> Result<Vec<(String, usize)>> {
|
||||
let conn = self.conn.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT category, COUNT(*) as count
|
||||
FROM classifications
|
||||
WHERE confirmed = 1
|
||||
GROUP BY category
|
||||
ORDER BY count DESC"
|
||||
).context("Failed to prepare statement")?;
|
||||
|
||||
let stats = stmt.query_map([], |row| {
|
||||
Ok((row.get(0)?, row.get(1)?))
|
||||
}).context("Failed to query category stats")?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("Failed to parse category stats")?;
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_record_type_serialization() {
|
||||
let record_type = RecordType::Image;
|
||||
let json = serde_json::to_string(&record_type).unwrap();
|
||||
assert_eq!(json, r#"{"type":"image"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_database_in_memory() -> Result<()> {
|
||||
let db = Database::open(":memory:")?;
|
||||
|
||||
// 插入测试记录
|
||||
let record = db.insert_record(
|
||||
RecordType::Text,
|
||||
"Test content",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)?;
|
||||
|
||||
// 查询记录
|
||||
let found = db.get_record(&record.id)?;
|
||||
assert!(found.is_some());
|
||||
|
||||
let found_record = found.unwrap();
|
||||
assert_eq!(found_record.content, "Test content");
|
||||
|
||||
// 删除记录
|
||||
let deleted = db.delete_record(&record.id)?;
|
||||
assert!(deleted);
|
||||
|
||||
// 验证删除
|
||||
let found = db.get_record(&record.id)?;
|
||||
assert!(found.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
195
src-tauri/src/hotkey.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tauri::{AppHandle, Emitter, Listener, Manager, State, WebviewUrl, WebviewWindowBuilder};
|
||||
use tauri_plugin_global_shortcut::{GlobalShortcutExt, Shortcut, ShortcutState};
|
||||
|
||||
/// 快捷键动作类型
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum HotkeyAction {
|
||||
CaptureFullscreen,
|
||||
CaptureRegion,
|
||||
CaptureWindow,
|
||||
ShowHide,
|
||||
}
|
||||
|
||||
impl HotkeyAction {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
HotkeyAction::CaptureFullscreen => "capture_fullscreen",
|
||||
HotkeyAction::CaptureRegion => "capture_region",
|
||||
HotkeyAction::CaptureWindow => "capture_window",
|
||||
HotkeyAction::ShowHide => "show_hide",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"capture_fullscreen" => Some(HotkeyAction::CaptureFullscreen),
|
||||
"capture_region" => Some(HotkeyAction::CaptureRegion),
|
||||
"capture_window" => Some(HotkeyAction::CaptureWindow),
|
||||
"show_hide" => Some(HotkeyAction::ShowHide),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 快捷键配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HotkeyConfig {
|
||||
pub id: String,
|
||||
pub shortcut: String,
|
||||
pub action: String,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// 快捷键管理器
|
||||
pub struct HotkeyManager {
|
||||
app_handle: AppHandle,
|
||||
registered_shortcuts: HashMap<String, Shortcut>,
|
||||
}
|
||||
|
||||
impl HotkeyManager {
|
||||
/// 创建新的快捷键管理器
|
||||
pub fn new(app_handle: AppHandle) -> Self {
|
||||
Self {
|
||||
app_handle,
|
||||
registered_shortcuts: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册快捷键
|
||||
pub fn register_shortcut(&mut self, config: HotkeyConfig) -> Result<(), String> {
|
||||
if !config.enabled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let shortcut = Shortcut::new(
|
||||
Some(config.shortcut.clone()),
|
||||
config.shortcut.clone(),
|
||||
);
|
||||
|
||||
self.app_handle
|
||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||
.map_err(|e| format!("Failed to load global shortcut plugin: {}", e))?;
|
||||
|
||||
let app_handle = self.app_handle.clone();
|
||||
let action = config.action.clone();
|
||||
|
||||
self.app_handle
|
||||
.global_shortcut()
|
||||
.on_shortcut(shortcut.clone(), move |_app, _shortcut, _event| {
|
||||
// 发送事件到前端
|
||||
if let Err(e) = app_handle.emit("hotkey-triggered", &action) {
|
||||
eprintln!("Failed to emit hotkey event: {}", e);
|
||||
}
|
||||
})
|
||||
.map_err(|e| format!("Failed to register shortcut handler: {}", e))?;
|
||||
|
||||
self.registered_shortcuts.insert(config.id.clone(), shortcut);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 注销快捷键
|
||||
pub fn unregister_shortcut(&mut self, id: &str) -> Result<(), String> {
|
||||
if let Some(shortcut) = self.registered_shortcuts.remove(id) {
|
||||
self.app_handle
|
||||
.global_shortcut()
|
||||
.unregister(shortcut)
|
||||
.map_err(|e| format!("Failed to unregister shortcut: {}", e))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取已注册的快捷键列表
|
||||
pub fn get_registered_shortcuts(&self) -> Vec<String> {
|
||||
self.registered_shortcuts.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// 注册默认快捷键
|
||||
pub fn register_defaults(&mut self) -> Result<(), String> {
|
||||
let defaults = vec![
|
||||
HotkeyConfig {
|
||||
id: "fullscreen".to_string(),
|
||||
shortcut: "Ctrl+Shift+A".to_string(),
|
||||
action: "capture_fullscreen".to_string(),
|
||||
enabled: true,
|
||||
},
|
||||
HotkeyConfig {
|
||||
id: "region".to_string(),
|
||||
shortcut: "Ctrl+Shift+R".to_string(),
|
||||
action: "capture_region".to_string(),
|
||||
enabled: true,
|
||||
},
|
||||
HotkeyConfig {
|
||||
id: "window".to_string(),
|
||||
shortcut: "Ctrl+Shift+W".to_string(),
|
||||
action: "capture_window".to_string(),
|
||||
enabled: true,
|
||||
},
|
||||
HotkeyConfig {
|
||||
id: "toggle".to_string(),
|
||||
shortcut: "Ctrl+Shift+S".to_string(),
|
||||
action: "show_hide".to_string(),
|
||||
enabled: true,
|
||||
},
|
||||
];
|
||||
|
||||
for config in defaults {
|
||||
self.register_shortcut(config)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 初始化全局快捷键
|
||||
pub fn init_global_hotkeys(app_handle: &AppHandle) {
|
||||
let mut manager = HotkeyManager::new(app_handle.clone());
|
||||
|
||||
if let Err(e) = manager.register_defaults() {
|
||||
eprintln!("Failed to register default hotkeys: {}", e);
|
||||
}
|
||||
|
||||
// 监听快捷键事件
|
||||
let app_handle = app_handle.clone();
|
||||
app_handle.listen("hotkey-triggered", |event| {
|
||||
if let Some(action_str) = event.payload() {
|
||||
if let Ok(action) = serde_json::from_str::<String>(action_str) {
|
||||
handle_hotkey_action(&app_handle, &action);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// 处理快捷键动作
|
||||
fn handle_hotkey_action(app_handle: &AppHandle, action: &str) {
|
||||
match HotkeyAction::from_str(action) {
|
||||
Some(HotkeyAction::CaptureFullscreen) => {
|
||||
// 发送全屏截图事件
|
||||
let _ = app_handle.emit("screenshot-fullscreen-triggered", ());
|
||||
}
|
||||
Some(HotkeyAction::CaptureRegion) => {
|
||||
// 发送区域截图事件
|
||||
let _ = app_handle.emit("screenshot-region-triggered", ());
|
||||
}
|
||||
Some(HotkeyAction::CaptureWindow) => {
|
||||
// 发送窗口截图事件
|
||||
let _ = app_handle.emit("screenshot-window-triggered", ());
|
||||
}
|
||||
Some(HotkeyAction::ShowHide) => {
|
||||
// 显示/隐藏主窗口
|
||||
if let Some(window) = app_handle.get_webview_window("main") {
|
||||
if window.is_visible().unwrap_or(false) {
|
||||
let _ = window.hide();
|
||||
} else {
|
||||
let _ = window.show();
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
eprintln!("Unknown hotkey action: {}", action);
|
||||
}
|
||||
}
|
||||
}
|
||||
991
src-tauri/src/lib.rs
Normal file
@@ -0,0 +1,991 @@
|
||||
mod screenshot;
|
||||
mod hotkey;
|
||||
mod config;
|
||||
mod upload;
|
||||
mod database;
|
||||
mod ocr;
|
||||
mod plugin;
|
||||
mod secure_storage;
|
||||
mod ai;
|
||||
|
||||
use screenshot::ScreenshotManager;
|
||||
use hotkey::init_global_hotkeys;
|
||||
use config::{ConfigManager, ImageHostConfig};
|
||||
use upload::{Uploader, UploadProgress, UploadResult};
|
||||
use database::{Database, RecordType};
|
||||
use ocr::{OcrEngineType, OcrConfig, BaiduOcrConfig, TencentOcrConfig, LocalOcrConfig};
|
||||
use plugin::{PluginManager, InstallProgress};
|
||||
use secure_storage::ApiKeyStorage;
|
||||
use ai::{AiClient, AiProvider, Classifier, ClassificationResult, StreamChunk, TemplateManager};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// 全局应用状态
|
||||
struct AppState {
|
||||
screenshot_manager: Mutex<ScreenshotManager>,
|
||||
config_manager: Mutex<ConfigManager>,
|
||||
database: Mutex<Database>,
|
||||
plugin_manager: Mutex<PluginManager>,
|
||||
classifier: Mutex<Classifier>,
|
||||
template_manager: Mutex<TemplateManager>,
|
||||
}
|
||||
|
||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||
pub fn run() {
|
||||
// 初始化配置管理器
|
||||
let config_manager = ConfigManager::new()
|
||||
.expect("Failed to initialize config manager");
|
||||
|
||||
// 确保数据目录存在
|
||||
config_manager.ensure_data_dir()
|
||||
.expect("Failed to create data directory");
|
||||
|
||||
// 初始化数据库
|
||||
let database = Database::open(&config_manager.database_path())
|
||||
.expect("Failed to initialize database");
|
||||
|
||||
// 初始化截图管理器
|
||||
let screenshot_manager = ScreenshotManager::with_default_dir()
|
||||
.expect("Failed to initialize screenshot manager");
|
||||
|
||||
// 初始化插件管理器
|
||||
let plugin_manager = PluginManager::new(config_manager.config_dir.join("data"));
|
||||
|
||||
// 初始化 AI 分类器
|
||||
let classifier = Classifier::new(ai::ClassifierConfig::default());
|
||||
|
||||
// 初始化模板管理器
|
||||
let templates_dir = config_manager.data_dir().join("templates");
|
||||
let template_manager = TemplateManager::new(templates_dir)
|
||||
.expect("Failed to initialize template manager");
|
||||
|
||||
// 获取配置目录路径
|
||||
let config_dir = config_manager.config_dir.join("data");
|
||||
|
||||
tauri::Builder::default()
|
||||
.setup(|app| {
|
||||
if cfg!(debug_assertions) {
|
||||
app.handle().plugin(
|
||||
tauri_plugin_log::Builder::default()
|
||||
.level(log::LevelFilter::Info)
|
||||
.build(),
|
||||
)?;
|
||||
}
|
||||
|
||||
// 初始化全局快捷键
|
||||
init_global_hotkeys(app.handle());
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.manage(AppState {
|
||||
screenshot_manager: Mutex::new(screenshot_manager),
|
||||
config_manager: Mutex::new(config_manager),
|
||||
database: Mutex::new(database),
|
||||
plugin_manager: Mutex::new(plugin_manager),
|
||||
classifier: Mutex::new(classifier),
|
||||
template_manager: Mutex::new(template_manager),
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
// 截图相关命令
|
||||
screenshot_fullscreen,
|
||||
screenshot_region,
|
||||
screenshot_copy_to_clipboard,
|
||||
screenshot_delete,
|
||||
screenshot_list,
|
||||
screenshot_cleanup,
|
||||
// 配置相关命令
|
||||
config_get,
|
||||
config_set,
|
||||
config_get_path,
|
||||
// 上传相关命令
|
||||
upload_image,
|
||||
// 数据库相关命令
|
||||
record_insert,
|
||||
record_get,
|
||||
record_list,
|
||||
record_delete,
|
||||
record_clear,
|
||||
record_count,
|
||||
setting_get,
|
||||
setting_set,
|
||||
setting_delete,
|
||||
setting_list,
|
||||
// OCR 相关命令
|
||||
ocr_recognize,
|
||||
ocr_save_api_key,
|
||||
ocr_get_api_keys,
|
||||
// 插件相关命令
|
||||
plugin_list,
|
||||
plugin_install,
|
||||
plugin_uninstall,
|
||||
// AI 分类相关命令
|
||||
ai_classify,
|
||||
ai_classify_stream,
|
||||
ai_save_api_key,
|
||||
ai_get_api_keys,
|
||||
ai_configure_provider,
|
||||
template_list,
|
||||
template_get,
|
||||
template_save,
|
||||
template_delete,
|
||||
template_test,
|
||||
classification_get,
|
||||
classification_confirm,
|
||||
classification_history,
|
||||
classification_stats,
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
}
|
||||
|
||||
// ============= 截图相关命令 =============
|
||||
|
||||
/// 全屏截图
|
||||
#[tauri::command]
|
||||
async fn screenshot_fullscreen(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<screenshot::ScreenshotMetadata, String> {
|
||||
let manager = state.screenshot_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
manager.capture_fullscreen()
|
||||
.map_err(|e| format!("Failed to capture screenshot: {}", e))
|
||||
}
|
||||
|
||||
/// 区域截图
|
||||
#[tauri::command]
|
||||
async fn screenshot_region(
|
||||
region: screenshot::RegionSelection,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<screenshot::ScreenshotMetadata, String> {
|
||||
let manager = state.screenshot_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
manager.capture_region(region)
|
||||
.map_err(|e| format!("Failed to capture region: {}", e))
|
||||
}
|
||||
|
||||
/// 复制截图到剪贴板
|
||||
#[tauri::command]
|
||||
async fn screenshot_copy_to_clipboard(
|
||||
filepath: PathBuf,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let manager = state.screenshot_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
manager.copy_to_clipboard(&filepath)
|
||||
.map_err(|e| format!("Failed to copy to clipboard: {}", e))
|
||||
}
|
||||
|
||||
/// 删除截图
|
||||
#[tauri::command]
|
||||
async fn screenshot_delete(
|
||||
filepath: PathBuf,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let manager = state.screenshot_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
manager.delete_screenshot(&filepath)
|
||||
.map_err(|e| format!("Failed to delete screenshot: {}", e))
|
||||
}
|
||||
|
||||
/// 获取所有截图列表
|
||||
#[tauri::command]
|
||||
async fn screenshot_list(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Vec<screenshot::ScreenshotMetadata>, String> {
|
||||
let manager = state.screenshot_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
manager.list_screenshots()
|
||||
.map_err(|e| format!("Failed to list screenshots: {}", e))
|
||||
}
|
||||
|
||||
/// 清理旧截图
|
||||
#[tauri::command]
|
||||
async fn screenshot_cleanup(
|
||||
keep_count: usize,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<usize, String> {
|
||||
let manager = state.screenshot_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
manager.cleanup_old_screenshots(keep_count)
|
||||
.map_err(|e| format!("Failed to cleanup screenshots: {}", e))
|
||||
}
|
||||
|
||||
// ============= 配置相关命令 =============
|
||||
|
||||
/// 获取配置
|
||||
#[tauri::command]
|
||||
async fn config_get(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<config::AppConfig, String> {
|
||||
let config_manager = state.config_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
config_manager.load()
|
||||
.map_err(|e| format!("Failed to load config: {}", e))
|
||||
}
|
||||
|
||||
/// 设置配置
|
||||
#[tauri::command]
|
||||
async fn config_set(
|
||||
config: config::AppConfig,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let config_manager = state.config_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
config_manager.save(&config)
|
||||
.map_err(|e| format!("Failed to save config: {}", e))
|
||||
}
|
||||
|
||||
/// 获取配置目录路径
|
||||
#[tauri::command]
|
||||
async fn config_get_path(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<String, String> {
|
||||
let config_manager = state.config_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
Ok(config_manager.config_file_path()
|
||||
.to_string_lossy()
|
||||
.to_string())
|
||||
}
|
||||
|
||||
// ============= 上传相关命令 =============
|
||||
|
||||
/// 上传图片到图床
|
||||
#[tauri::command]
|
||||
async fn upload_image(
|
||||
image_path: PathBuf,
|
||||
image_host: ImageHostConfig,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<UploadResult, String> {
|
||||
// 加载配置以获取重试次数和超时设置
|
||||
let config_manager = state.config_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let config = config_manager.load()
|
||||
.map_err(|e| format!("Failed to load config: {}", e))?;
|
||||
|
||||
let uploader = Uploader::new(
|
||||
config.upload_retry_count,
|
||||
config.upload_timeout_seconds,
|
||||
);
|
||||
|
||||
// 使用 tokio runtime 运行异步上传
|
||||
let result = tokio::spawn(async move {
|
||||
uploader.upload_with_retry(&image_path, &image_host, |progress| {
|
||||
match progress {
|
||||
UploadProgress::Starting => {
|
||||
log::info!("Upload starting");
|
||||
}
|
||||
UploadProgress::Uploading { progress, message } => {
|
||||
log::info!("Upload progress: {:.1}% - {}", progress, message);
|
||||
}
|
||||
UploadProgress::Completed(result) => {
|
||||
log::info!("Upload completed: {}", result.url);
|
||||
}
|
||||
UploadProgress::Failed { error } => {
|
||||
log::error!("Upload failed: {}", error);
|
||||
}
|
||||
}
|
||||
}).await
|
||||
}).await
|
||||
.map_err(|e| format!("Failed to join upload task: {}", e))?;
|
||||
|
||||
result.map_err(|e| format!("Upload error: {}", e))
|
||||
}
|
||||
|
||||
// ============= 数据库相关命令 =============
|
||||
|
||||
/// 插入记录
|
||||
#[tauri::command]
|
||||
async fn record_insert(
|
||||
record_type: String,
|
||||
content: String,
|
||||
file_path: Option<String>,
|
||||
thumbnail: Option<String>,
|
||||
metadata: Option<String>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<database::Record, String> {
|
||||
let record_type = match record_type.as_str() {
|
||||
"image" => RecordType::Image,
|
||||
"text" => RecordType::Text,
|
||||
"file" => RecordType::File,
|
||||
_ => return Err("Invalid record type".to_string()),
|
||||
};
|
||||
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.insert_record(
|
||||
record_type,
|
||||
&content,
|
||||
file_path.as_deref(),
|
||||
thumbnail.as_deref(),
|
||||
metadata.as_deref(),
|
||||
).map_err(|e| format!("Failed to insert record: {}", e))
|
||||
}
|
||||
|
||||
/// 获取记录
|
||||
#[tauri::command]
|
||||
async fn record_get(
|
||||
id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Option<database::Record>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.get_record(&id)
|
||||
.map_err(|e| format!("Failed to get record: {}", e))
|
||||
}
|
||||
|
||||
/// 列出记录
|
||||
#[tauri::command]
|
||||
async fn record_list(
|
||||
limit: Option<usize>,
|
||||
offset: Option<usize>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Vec<database::Record>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.list_records(limit, offset)
|
||||
.map_err(|e| format!("Failed to list records: {}", e))
|
||||
}
|
||||
|
||||
/// 删除记录
|
||||
#[tauri::command]
|
||||
async fn record_delete(
|
||||
id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<bool, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.delete_record(&id)
|
||||
.map_err(|e| format!("Failed to delete record: {}", e))
|
||||
}
|
||||
|
||||
/// 清空所有记录
|
||||
#[tauri::command]
|
||||
async fn record_clear(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<usize, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.clear_records()
|
||||
.map_err(|e| format!("Failed to clear records: {}", e))
|
||||
}
|
||||
|
||||
/// 获取记录数量
|
||||
#[tauri::command]
|
||||
async fn record_count(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<usize, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.get_records_count()
|
||||
.map_err(|e| format!("Failed to get records count: {}", e))
|
||||
}
|
||||
|
||||
/// 获取设置
|
||||
#[tauri::command]
|
||||
async fn setting_get(
|
||||
key: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Option<String>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.get_setting(&key)
|
||||
.map_err(|e| format!("Failed to get setting: {}", e))
|
||||
}
|
||||
|
||||
/// 设置设置
|
||||
#[tauri::command]
|
||||
async fn setting_set(
|
||||
key: String,
|
||||
value: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.set_setting(&key, &value)
|
||||
.map_err(|e| format!("Failed to set setting: {}", e))
|
||||
}
|
||||
|
||||
/// 删除设置
|
||||
#[tauri::command]
|
||||
async fn setting_delete(
|
||||
key: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<bool, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.delete_setting(&key)
|
||||
.map_err(|e| format!("Failed to delete setting: {}", e))
|
||||
}
|
||||
|
||||
/// 列出所有设置
|
||||
#[tauri::command]
|
||||
async fn setting_list(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Vec<database::Setting>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.list_settings()
|
||||
.map_err(|e| format!("Failed to list settings: {}", e))
|
||||
}
|
||||
|
||||
// ============= OCR 相关命令 =============
|
||||
|
||||
/// 执行 OCR 识别
|
||||
#[tauri::command]
|
||||
async fn ocr_recognize(
|
||||
image_path: PathBuf,
|
||||
engine: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
window: tauri::Window,
|
||||
) -> Result<ocr::OcrResult, String> {
|
||||
let ocr_engine = match engine.as_str() {
|
||||
"baidu" => OcrEngineType::Baidu,
|
||||
"tencent" => OcrEngineType::Tencent,
|
||||
"local" => OcrEngineType::Local,
|
||||
_ => return Err("不支持的 OCR 引擎".to_string()),
|
||||
};
|
||||
|
||||
let config_manager = state.config_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
// 从数据库获取 API 密钥配置
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
// 构建 OCR 配置
|
||||
let ocr_config = match ocr_engine {
|
||||
OcrEngineType::Baidu => {
|
||||
let api_key = db.get_setting("baidu_ocr_api_key")
|
||||
.map_err(|e| format!("Failed to get API key: {}", e))?
|
||||
.ok_or_else(|| "百度 API Key 未设置".to_string())?;
|
||||
let secret_key = db.get_setting("baidu_ocr_secret_key")
|
||||
.map_err(|e| format!("Failed to get secret key: {}", e))?
|
||||
.ok_or_else(|| "百度 Secret Key 未设置".to_string())?;
|
||||
|
||||
OcrConfig {
|
||||
engine: OcrEngineType::Baidu,
|
||||
baidu: Some(BaiduOcrConfig {
|
||||
api_key,
|
||||
secret_key,
|
||||
accurate: true,
|
||||
}),
|
||||
tencent: None,
|
||||
local: None,
|
||||
}
|
||||
}
|
||||
OcrEngineType::Tencent => {
|
||||
let secret_id = db.get_setting("tencent_ocr_secret_id")
|
||||
.map_err(|e| format!("Failed to get secret id: {}", e))?
|
||||
.ok_or_else(|| "腾讯云 Secret ID 未设置".to_string())?;
|
||||
let secret_key = db.get_setting("tencent_ocr_secret_key")
|
||||
.map_err(|e| format!("Failed to get secret key: {}", e))?
|
||||
.ok_or_else(|| "腾讯云 Secret Key 未设置".to_string())?;
|
||||
|
||||
OcrConfig {
|
||||
engine: OcrEngineType::Tencent,
|
||||
baidu: None,
|
||||
tencent: Some(TencentOcrConfig {
|
||||
secret_id,
|
||||
secret_key,
|
||||
region: "ap-guangzhou".to_string(),
|
||||
}),
|
||||
local: None,
|
||||
}
|
||||
}
|
||||
OcrEngineType::Local => {
|
||||
let plugin_manager = state.plugin_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let plugin_path = db.get_setting("local_ocr_plugin_path")
|
||||
.map_err(|e| format!("Failed to get plugin path: {}", e))?
|
||||
.ok_or_else(|| "本地 OCR 插件路径未设置".to_string())?;
|
||||
|
||||
OcrConfig {
|
||||
engine: OcrEngineType::Local,
|
||||
baidu: None,
|
||||
tencent: None,
|
||||
local: Some(LocalOcrConfig {
|
||||
plugin_path: PathBuf::from(plugin_path),
|
||||
model_path: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 读取图片数据
|
||||
let image_data = std::fs::read(&image_path)
|
||||
.map_err(|e| format!("Failed to read image: {}", e))?;
|
||||
|
||||
// 执行 OCR
|
||||
let provider = ocr::CloudOcrProvider::new();
|
||||
|
||||
// 通过窗口事件发送进度更新
|
||||
let window_clone = window.clone();
|
||||
let progress_callback = Box::new(move |progress| {
|
||||
let _ = window_clone.emit("ocr-progress", &progress);
|
||||
});
|
||||
|
||||
let result = tokio::spawn(async move {
|
||||
provider.recognize(&image_data, &ocr_config, Some(progress_callback)).await
|
||||
}).await
|
||||
.map_err(|e| format!("Failed to join OCR task: {}", e))?
|
||||
.map_err(|e| format!("OCR error: {}", e))?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 保存 OCR API 密钥
|
||||
#[tauri::command]
|
||||
async fn ocr_save_api_key(
|
||||
provider: String,
|
||||
keys: std::collections::HashMap<String, String>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
match provider.as_str() {
|
||||
"baidu" => {
|
||||
if let Some(api_key) = keys.get("api_key") {
|
||||
db.set_setting("baidu_ocr_api_key", api_key)
|
||||
.map_err(|e| format!("Failed to save API key: {}", e))?;
|
||||
}
|
||||
if let Some(secret_key) = keys.get("secret_key") {
|
||||
db.set_setting("baidu_ocr_secret_key", secret_key)
|
||||
.map_err(|e| format!("Failed to save secret key: {}", e))?;
|
||||
}
|
||||
}
|
||||
"tencent" => {
|
||||
if let Some(secret_id) = keys.get("secret_id") {
|
||||
db.set_setting("tencent_ocr_secret_id", secret_id)
|
||||
.map_err(|e| format!("Failed to save secret id: {}", e))?;
|
||||
}
|
||||
if let Some(secret_key) = keys.get("secret_key") {
|
||||
db.set_setting("tencent_ocr_secret_key", secret_key)
|
||||
.map_err(|e| format!("Failed to save secret key: {}", e))?;
|
||||
}
|
||||
}
|
||||
_ => return Err("不支持的 OCR 提供者".to_string()),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取已保存的 API 密钥(返回脱敏信息)
|
||||
#[tauri::command]
|
||||
async fn ocr_get_api_keys(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<std::collections::HashMap<String, bool>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut result = std::collections::HashMap::new();
|
||||
|
||||
// 检查百度密钥
|
||||
let baidu_key = db.get_setting("baidu_ocr_api_key")
|
||||
.map_err(|e| format!("Failed to check baidu key: {}", e))?;
|
||||
result.insert("baidu".to_string(), baidu_key.is_some());
|
||||
|
||||
// 检查腾讯云密钥
|
||||
let tencent_key = db.get_setting("tencent_ocr_secret_id")
|
||||
.map_err(|e| format!("Failed to check tencent key: {}", e))?;
|
||||
result.insert("tencent".to_string(), tencent_key.is_some());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ============= 插件相关命令 =============
|
||||
|
||||
/// 获取插件列表
|
||||
#[tauri::command]
|
||||
async fn plugin_list(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Vec<plugin::PluginStatus>, String> {
|
||||
let plugin_manager = state.plugin_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
plugin_manager.get_plugin_status().await
|
||||
}).await
|
||||
.map_err(|e| format!("Failed to join plugin list task: {}", e))?
|
||||
.map_err(|e| format!("Failed to get plugin list: {}", e))
|
||||
}
|
||||
|
||||
/// 安装插件
|
||||
#[tauri::command]
|
||||
async fn plugin_install(
|
||||
plugin_id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
window: tauri::Window,
|
||||
) -> Result<String, String> {
|
||||
let plugin_manager = state.plugin_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
|
||||
|
||||
// 在后台任务中发送进度更新
|
||||
let window_clone = window.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(progress) = rx.recv().await {
|
||||
let _ = window_clone.emit("plugin-install-progress", &progress);
|
||||
}
|
||||
});
|
||||
|
||||
// 执行安装
|
||||
let install_path = tokio::spawn(async move {
|
||||
plugin_manager.install_plugin(&plugin_id, tx).await
|
||||
}).await
|
||||
.map_err(|e| format!("Failed to join install task: {}", e))?
|
||||
.map_err(|e| format!("Installation failed: {}", e))?;
|
||||
|
||||
Ok(install_path.to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
/// 卸载插件
|
||||
#[tauri::command]
|
||||
async fn plugin_uninstall(
|
||||
plugin_id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let plugin_manager = state.plugin_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
plugin_manager.uninstall_plugin(&plugin_id)
|
||||
.map_err(|e| format!("Failed to uninstall plugin: {}", e))
|
||||
}
|
||||
|
||||
// ============= AI 分类相关命令 =============
|
||||
|
||||
/// 执行 AI 分类
|
||||
#[tauri::command]
|
||||
async fn ai_classify(
|
||||
record_id: String,
|
||||
template_id: Option<String>,
|
||||
variables: HashMap<String, String>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<ClassificationResult, String> {
|
||||
let classifier = state.classifier.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let result = tokio::spawn(async move {
|
||||
// 释放锁后再调用
|
||||
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
|
||||
classifier_ref.classify(template_id.as_deref(), &variables).await
|
||||
}).await
|
||||
.map_err(|e| format!("Failed to join classify task: {}", e))?
|
||||
.map_err(|e| format!("Classification error: {}", e))?;
|
||||
|
||||
// 保存分类结果到数据库
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.save_classification(
|
||||
&record_id,
|
||||
&result.category,
|
||||
result.subcategory.as_deref(),
|
||||
&result.tags,
|
||||
result.confidence,
|
||||
result.reasoning.as_deref(),
|
||||
template_id.as_deref(),
|
||||
).map_err(|e| format!("Failed to save classification: {}", e))?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 执行流式 AI 分类
|
||||
#[tauri::command]
|
||||
async fn ai_classify_stream(
|
||||
record_id: String,
|
||||
template_id: Option<String>,
|
||||
variables: HashMap<String, String>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
window: tauri::Window,
|
||||
) -> Result<ClassificationResult, String> {
|
||||
let classifier = state.classifier.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let window_clone = window.clone();
|
||||
|
||||
let result = tokio::spawn(async move {
|
||||
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
|
||||
classifier_ref.classify_stream(
|
||||
template_id.as_deref(),
|
||||
&variables,
|
||||
|chunk| {
|
||||
match chunk {
|
||||
StreamChunk::Text(text) => {
|
||||
let _ = window_clone.emit("ai-classify-chunk", &text);
|
||||
}
|
||||
StreamChunk::Done => {
|
||||
let _ = window_clone.emit("ai-classify-done", ());
|
||||
}
|
||||
StreamChunk::Error(err) => {
|
||||
let _ = window_clone.emit("ai-classify-error", &err);
|
||||
}
|
||||
}
|
||||
},
|
||||
).await
|
||||
}).await
|
||||
.map_err(|e| format!("Failed to join classify task: {}", e))?
|
||||
.map_err(|e| format!("Classification error: {}", e))?;
|
||||
|
||||
// 保存分类结果到数据库
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.save_classification(
|
||||
&record_id,
|
||||
&result.category,
|
||||
result.subcategory.as_deref(),
|
||||
&result.tags,
|
||||
result.confidence,
|
||||
result.reasoning.as_deref(),
|
||||
template_id.as_deref(),
|
||||
).map_err(|e| format!("Failed to save classification: {}", e))?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 保存 AI API 密钥
|
||||
#[tauri::command]
|
||||
async fn ai_save_api_key(
|
||||
provider: String,
|
||||
api_key: String,
|
||||
model: Option<String>,
|
||||
base_url: Option<String>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
match provider.as_str() {
|
||||
"claude" => {
|
||||
db.set_setting("claude_api_key", &api_key)
|
||||
.map_err(|e| format!("Failed to save API key: {}", e))?;
|
||||
if let Some(model) = model {
|
||||
db.set_setting("claude_model", &model)
|
||||
.map_err(|e| format!("Failed to save model: {}", e))?;
|
||||
}
|
||||
}
|
||||
"openai" => {
|
||||
db.set_setting("openai_api_key", &api_key)
|
||||
.map_err(|e| format!("Failed to save API key: {}", e))?;
|
||||
if let Some(model) = model {
|
||||
db.set_setting("openai_model", &model)
|
||||
.map_err(|e| format!("Failed to save model: {}", e))?;
|
||||
}
|
||||
if let Some(base_url) = base_url {
|
||||
db.set_setting("openai_base_url", &base_url)
|
||||
.map_err(|e| format!("Failed to save base URL: {}", e))?;
|
||||
}
|
||||
}
|
||||
_ => return Err("不支持的 AI 提供商".to_string()),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取已保存的 API 密钥(返回脱敏信息)
|
||||
#[tauri::command]
|
||||
async fn ai_get_api_keys(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<HashMap<String, bool>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
|
||||
let claude_key = db.get_setting("claude_api_key")
|
||||
.map_err(|e| format!("Failed to check claude key: {}", e))?;
|
||||
result.insert("claude".to_string(), claude_key.is_some());
|
||||
|
||||
let openai_key = db.get_setting("openai_api_key")
|
||||
.map_err(|e| format!("Failed to check openai key: {}", e))?;
|
||||
result.insert("openai".to_string(), openai_key.is_some());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 配置 AI 提供商
|
||||
#[tauri::command]
|
||||
async fn ai_configure_provider(
|
||||
provider: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
let classifier = state.classifier.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
match provider.as_str() {
|
||||
"claude" => {
|
||||
let api_key = db.get_setting("claude_api_key")
|
||||
.map_err(|e| format!("Failed to get API key: {}", e))?
|
||||
.ok_or_else(|| "Claude API Key 未设置".to_string())?;
|
||||
let model = db.get_setting("claude_model")
|
||||
.map_err(|e| format!("Failed to get model: {}", e))?;
|
||||
|
||||
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
|
||||
classifier_ref.configure_claude(api_key, model)
|
||||
.map_err(|e| format!("Failed to configure Claude: {}", e))?;
|
||||
}
|
||||
"openai" => {
|
||||
let api_key = db.get_setting("openai_api_key")
|
||||
.map_err(|e| format!("Failed to get API key: {}", e))?
|
||||
.ok_or_else(|| "OpenAI API Key 未设置".to_string())?;
|
||||
let model = db.get_setting("openai_model")
|
||||
.map_err(|e| format!("Failed to get model: {}", e))?;
|
||||
let base_url = db.get_setting("openai_base_url")
|
||||
.map_err(|e| format!("Failed to get base URL: {}", e))?;
|
||||
|
||||
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
|
||||
classifier_ref.configure_openai(api_key, model, base_url)
|
||||
.map_err(|e| format!("Failed to configure OpenAI: {}", e))?;
|
||||
}
|
||||
_ => return Err("不支持的 AI 提供商".to_string()),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============= 模板管理相关命令 =============
|
||||
|
||||
/// 列出所有模板
|
||||
#[tauri::command]
|
||||
async fn template_list(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Vec<ai::Template>, String> {
|
||||
let template_manager = state.template_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
template_manager.list_templates()
|
||||
.map_err(|e| format!("Failed to list templates: {}", e))
|
||||
}
|
||||
|
||||
/// 获取单个模板
|
||||
#[tauri::command]
|
||||
async fn template_get(
|
||||
id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Option<ai::Template>, String> {
|
||||
let template_manager = state.template_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
Ok(template_manager.get_template(&id))
|
||||
}
|
||||
|
||||
/// 保存模板
|
||||
#[tauri::command]
|
||||
async fn template_save(
|
||||
template: ai::Template,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(), String> {
|
||||
let template_manager = state.template_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
template_manager.upsert_template(template)
|
||||
.map_err(|e| format!("Failed to save template: {}", e))
|
||||
}
|
||||
|
||||
/// 删除模板
|
||||
#[tauri::command]
|
||||
async fn template_delete(
|
||||
id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<bool, String> {
|
||||
let template_manager = state.template_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
template_manager.delete_template(&id)
|
||||
.map_err(|e| format!("Failed to delete template: {}", e))
|
||||
}
|
||||
|
||||
/// 测试模板渲染
|
||||
#[tauri::command]
|
||||
async fn template_test(
|
||||
id: String,
|
||||
variables: HashMap<String, String>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<(String, String), String> {
|
||||
let template_manager = state.template_manager.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
template_manager.test_template(&id, &variables)
|
||||
.map_err(|e| format!("Failed to test template: {}", e))
|
||||
}
|
||||
|
||||
// ============= 分类结果相关命令 =============
|
||||
|
||||
/// 获取记录的分类结果
|
||||
#[tauri::command]
|
||||
async fn classification_get(
|
||||
record_id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Option<database::Classification>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.get_classification(&record_id)
|
||||
.map_err(|e| format!("Failed to get classification: {}", e))
|
||||
}
|
||||
|
||||
/// 确认分类结果
|
||||
#[tauri::command]
|
||||
async fn classification_confirm(
|
||||
id: String,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<bool, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.confirm_classification(&id)
|
||||
.map_err(|e| format!("Failed to confirm classification: {}", e))
|
||||
}
|
||||
|
||||
/// 获取分类历史
|
||||
#[tauri::command]
|
||||
async fn classification_history(
|
||||
record_id: String,
|
||||
limit: Option<usize>,
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Vec<database::ClassificationHistory>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.get_classification_history(&record_id, limit)
|
||||
.map_err(|e| format!("Failed to get classification history: {}", e))
|
||||
}
|
||||
|
||||
/// 获取分类统计
|
||||
#[tauri::command]
|
||||
async fn classification_stats(
|
||||
state: tauri::State<'_, AppState>,
|
||||
) -> Result<Vec<(String, usize)>, String> {
|
||||
let db = state.database.lock()
|
||||
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
|
||||
|
||||
db.get_category_stats()
|
||||
.map_err(|e| format!("Failed to get category stats: {}", e))
|
||||
}
|
||||
6
src-tauri/src/main.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
|
||||
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
||||
|
||||
fn main() {
|
||||
app_lib::run();
|
||||
}
|
||||
380
src-tauri/src/ocr/cloud.rs
Normal file
@@ -0,0 +1,380 @@
|
||||
use super::{OcrConfig, OcrEngineType, OcrProgress, OcrResult};
|
||||
use super::result::{TextBlock, TextBlockType, BoundingBox};
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::Utc;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Sha256, Digest};
|
||||
use std::time::Instant;
|
||||
|
||||
/// 云端 OCR 提供者
|
||||
pub struct CloudOcrProvider {
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl CloudOcrProvider {
|
||||
/// 创建新的云端 OCR 提供者
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 执行 OCR 识别
|
||||
pub async fn recognize(
|
||||
&self,
|
||||
image_data: &[u8],
|
||||
config: &OcrConfig,
|
||||
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
|
||||
) -> Result<OcrResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Starting);
|
||||
}
|
||||
|
||||
let result = match config.engine {
|
||||
OcrEngineType::Baidu => {
|
||||
let baidu_config = config.baidu.as_ref()
|
||||
.context("百度 OCR 配置未设置")?;
|
||||
self.baidu_ocr(image_data, baidu_config, progress_callback).await?
|
||||
}
|
||||
OcrEngineType::Tencent => {
|
||||
let tencent_config = config.tencent.as_ref()
|
||||
.context("腾讯云 OCR 配置未设置")?;
|
||||
self.tencent_ocr(image_data, tencent_config, progress_callback).await?
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("不支持的 OCR 引擎类型")),
|
||||
};
|
||||
|
||||
let duration = start.elapsed().as_millis() as u64;
|
||||
let mut final_result = result;
|
||||
final_result.duration_ms = duration;
|
||||
|
||||
Ok(final_result)
|
||||
}
|
||||
|
||||
/// 百度 OCR 识别
|
||||
async fn baidu_ocr(
|
||||
&self,
|
||||
image_data: &[u8],
|
||||
config: &super::BaiduOcrConfig,
|
||||
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
|
||||
) -> Result<OcrResult> {
|
||||
// 获取访问令牌
|
||||
let token = self.get_baidu_token(&config.api_key, &config.secret_key).await?;
|
||||
|
||||
// 编码图片为 Base64
|
||||
let base64_image = base64::encode(image_data);
|
||||
|
||||
// 构建请求
|
||||
let url = if config.accurate {
|
||||
"https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"
|
||||
} else {
|
||||
"https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic"
|
||||
};
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Processing {
|
||||
progress: 50.0,
|
||||
message: "正在调用百度 OCR API...".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.query(&[("access_token", &token)])
|
||||
.form(&[("image", base64_image)])
|
||||
.send()
|
||||
.await
|
||||
.context("百度 OCR API 请求失败")?;
|
||||
|
||||
let response_text = response.text().await?;
|
||||
let baidu_result: BaiduOcrResponse = serde_json::from_str(&response_text)
|
||||
.with_context(|| format!("百度 OCR 响应解析失败: {}", response_text))?;
|
||||
|
||||
if let Some(error_msg) = baidu_result.error_msg {
|
||||
return Err(anyhow::anyhow!("百度 OCR 错误: {} (错误码: {})",
|
||||
error_msg, baidu_result.error_code.unwrap_or(0)));
|
||||
}
|
||||
|
||||
// 转换为统一格式
|
||||
let blocks = baidu_result
|
||||
.words_result
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, item)| TextBlock {
|
||||
text: item.words,
|
||||
confidence: item.probability.unwrap_or(0.0) * 100.0,
|
||||
bbox: BoundingBox {
|
||||
x: 0,
|
||||
y: idx as u32 * 20, // 百度不返回位置信息,使用估算值
|
||||
width: 100,
|
||||
height: 20,
|
||||
},
|
||||
block_type: TextBlockType::Text,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let result = OcrResult::from_blocks(blocks, "Baidu".to_string(), 0);
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Completed(result.clone()));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 获取百度访问令牌
|
||||
async fn get_baidu_token(&self, api_key: &str, secret_key: &str) -> Result<String> {
|
||||
let response = self
|
||||
.client
|
||||
.post("https://aip.baidubce.com/oauth/2.0/token")
|
||||
.query(&[
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", api_key),
|
||||
("client_secret", secret_key),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.context("获取百度访问令牌失败")?;
|
||||
|
||||
let token_response: BaiduTokenResponse = response.json().await?;
|
||||
Ok(token_response.access_token)
|
||||
}
|
||||
|
||||
/// 腾讯云 OCR 识别
|
||||
async fn tencent_ocr(
|
||||
&self,
|
||||
image_data: &[u8],
|
||||
config: &super::TencentOcrConfig,
|
||||
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
|
||||
) -> Result<OcrResult> {
|
||||
let base64_image = base64::encode(image_data);
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Processing {
|
||||
progress: 50.0,
|
||||
message: "正在调用腾讯云 OCR API...".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// 构建请求参数
|
||||
let request = TencentOcrRequest {
|
||||
image_base64: base64_image,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// 构建签名
|
||||
let timestamp = Utc::now().timestamp();
|
||||
let endpoint = "ocr.tencentcloudapi.com";
|
||||
let service = "ocr";
|
||||
let version = "2018-11-19";
|
||||
let action = "GeneralBasicOCR";
|
||||
|
||||
let authorization = self.tencent_authorization(
|
||||
&config.secret_id,
|
||||
&config.secret_key,
|
||||
endpoint,
|
||||
service,
|
||||
version,
|
||||
action,
|
||||
timestamp,
|
||||
&request,
|
||||
)?;
|
||||
|
||||
let url = format!("https://{}?Action={}&Version={}", endpoint, action, version);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", authorization)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Host", endpoint)
|
||||
.header("X-TC-Timestamp", timestamp.to_string())
|
||||
.header("X-TC-Region", &config.region)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("腾讯云 OCR API 请求失败")?;
|
||||
|
||||
let response_text = response.text().await?;
|
||||
let tencent_result: TencentOcrResponse = serde_json::from_str(&response_text)
|
||||
.with_context(|| format!("腾讯云 OCR 响应解析失败: {}", response_text))?;
|
||||
|
||||
if let Some(error) = &tencent_result.response.error {
|
||||
return Err(anyhow::anyhow!("腾讯云 OCR 错误: {} (代码: {})",
|
||||
error.message, error.code));
|
||||
}
|
||||
|
||||
// 转换为统一格式
|
||||
let blocks = tencent_result
|
||||
.response
|
||||
.text_detections
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|item| TextBlock {
|
||||
text: item.detected_text,
|
||||
confidence: item.confidence,
|
||||
bbox: BoundingBox {
|
||||
x: item.polygon.as_ref().map(|p| p[0].x as u32).unwrap_or(0),
|
||||
y: item.polygon.as_ref().map(|p| p[0].y as u32).unwrap_or(0),
|
||||
width: 100,
|
||||
height: 20,
|
||||
},
|
||||
block_type: TextBlockType::Text,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let result = OcrResult::from_blocks(blocks, "Tencent".to_string(), 0);
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Completed(result.clone()));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 生成腾讯云 API 签名
|
||||
fn tencent_authorization(
|
||||
&self,
|
||||
secret_id: &str,
|
||||
secret_key: &str,
|
||||
endpoint: &str,
|
||||
service: &str,
|
||||
version: &str,
|
||||
action: &str,
|
||||
timestamp: i64,
|
||||
request: &TencentOcrRequest,
|
||||
) -> Result<String> {
|
||||
// 简化的签名实现(实际生产环境应使用完整的 HMAC-SHA256)
|
||||
let date = Utc::now().format("%Y-%m-%d").to_string();
|
||||
|
||||
// 构建签名字符串
|
||||
let payload = serde_json::to_string(request)?;
|
||||
let hashed_payload = format!("{:x}", Sha256::digest(payload.as_bytes()));
|
||||
|
||||
let credential_scope = format!("{}/{}/tc3_request", date, service);
|
||||
|
||||
// 简化版:实际应包含完整 HTTP 方法、路径、查询参数等
|
||||
let string_to_sign = format!(
|
||||
"TC3-HMAC-SHA256\n{}\n{}\n{}",
|
||||
timestamp, credential_scope, hashed_payload
|
||||
);
|
||||
|
||||
// 计算签名
|
||||
let secret_date = hmac_sha256(format!("TC3{}", secret_key).as_bytes(), date.as_bytes());
|
||||
let secret_service = hmac_sha256(&secret_date, service.as_bytes());
|
||||
let secret_signing = hmac_sha256(&secret_service, b"tc3_request");
|
||||
let signature = hex_encode(hmac_sha256(&secret_signing, string_to_sign.as_bytes()));
|
||||
|
||||
let authorization = format!(
|
||||
"TC3-HMAC-SHA256 Credential={}/{}, SignedHeaders=content-type;host, Signature={}",
|
||||
secret_id, credential_scope, signature
|
||||
);
|
||||
|
||||
Ok(authorization)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CloudOcrProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// 百度 OCR 令牌响应
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BaiduTokenResponse {
|
||||
access_token: String,
|
||||
}
|
||||
|
||||
/// 百度 OCR 响应
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BaiduOcrResponse {
|
||||
#[serde(rename = "words_result")]
|
||||
words_result: Option<Vec<BaiduWord>>,
|
||||
#[serde(rename = "error_code")]
|
||||
error_code: Option<i32>,
|
||||
#[serde(rename = "error_msg")]
|
||||
error_msg: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BaiduWord {
|
||||
words: String,
|
||||
probability: Option<f32>,
|
||||
}
|
||||
|
||||
/// 腾讯云 OCR 请求
|
||||
#[derive(Debug, Serialize, Default)]
|
||||
struct TencentOcrRequest {
|
||||
#[serde(rename = "ImageBase64")]
|
||||
image_base64: String,
|
||||
}
|
||||
|
||||
/// 腾讯云 OCR 响应
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TencentOcrResponse {
|
||||
#[serde(rename = "Response")]
|
||||
response: TencentResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TencentResponse {
|
||||
#[serde(rename = "Error")]
|
||||
error: Option<TencentError>,
|
||||
#[serde(rename = "TextDetections")]
|
||||
text_detections: Option<Vec<TencentTextDetection>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TencentError {
|
||||
#[serde(rename = "Code")]
|
||||
code: String,
|
||||
#[serde(rename = "Message")]
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TencentTextDetection {
|
||||
#[serde(rename = "DetectedText")]
|
||||
detected_text: String,
|
||||
#[serde(rename = "Confidence")]
|
||||
confidence: f32,
|
||||
#[serde(rename = "Polygon")]
|
||||
polygon: Option<Vec<TencentPoint>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TencentPoint {
|
||||
#[serde(rename = "X")]
|
||||
x: f32,
|
||||
#[serde(rename = "Y")]
|
||||
y: f32,
|
||||
}
|
||||
|
||||
/// HMAC-SHA256 辅助函数
|
||||
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
|
||||
use hmac::Hmac;
|
||||
use hmac::Mac;
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key length error");
|
||||
mac.update(data);
|
||||
mac.finalize().into_bytes().to_vec()
|
||||
}
|
||||
|
||||
/// 十六进制编码
|
||||
fn hex_encode(data: Vec<u8>) -> String {
|
||||
data.iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
|
||||
/// 云端 OCR 请求(用于前端调用)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CloudOcrRequest {
|
||||
pub image_path: String,
|
||||
pub engine: OcrEngineType,
|
||||
}
|
||||
174
src-tauri/src/ocr/local.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use super::{OcrProgress, OcrResult};
|
||||
use super::result::{TextBlock, TextBlockType, BoundingBox};
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Output};
|
||||
use std::time::Instant;
|
||||
|
||||
/// 本地 OCR 插件
|
||||
pub struct LocalOcrPlugin {
|
||||
plugin_path: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl LocalOcrPlugin {
|
||||
/// 创建新的本地 OCR 插件
|
||||
pub fn new(plugin_path: std::path::PathBuf) -> Self {
|
||||
Self { plugin_path }
|
||||
}
|
||||
|
||||
/// 执行 OCR 识别
|
||||
pub fn recognize(
|
||||
&self,
|
||||
image_path: &Path,
|
||||
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
|
||||
) -> Result<OcrResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Starting);
|
||||
}
|
||||
|
||||
// 检查插件是否存在
|
||||
if !self.plugin_path.exists() {
|
||||
return Err(anyhow::anyhow!("本地 OCR 插件不存在: {}",
|
||||
self.plugin_path.display()));
|
||||
}
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Processing {
|
||||
progress: 30.0,
|
||||
message: "正在启动本地 OCR 引擎...".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// 调用本地插件
|
||||
let output = self.call_plugin(image_path)?;
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Processing {
|
||||
progress: 70.0,
|
||||
message: "正在解析 OCR 结果...".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
let plugin_result: LocalOcrResponse = serde_json::from_str(&output)
|
||||
.context("本地 OCR 插件返回的 JSON 格式错误")?;
|
||||
|
||||
if !plugin_result.success {
|
||||
return Err(anyhow::anyhow!("本地 OCR 失败: {}",
|
||||
plugin_result.error.unwrap_or_else(|| "未知错误".to_string())));
|
||||
}
|
||||
|
||||
// 转换为统一格式
|
||||
let blocks = plugin_result
|
||||
.blocks
|
||||
.into_iter()
|
||||
.map(|block| TextBlock {
|
||||
text: block.text,
|
||||
confidence: block.confidence,
|
||||
bbox: BoundingBox {
|
||||
x: block.bbox_x,
|
||||
y: block.bbox_y,
|
||||
width: block.bbox_width,
|
||||
height: block.bbox_height,
|
||||
},
|
||||
block_type: match block.block_type.as_str() {
|
||||
"title" => TextBlockType::Title,
|
||||
"list" => TextBlockType::List,
|
||||
"table" => TextBlockType::Table,
|
||||
_ => TextBlockType::Text,
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut result = OcrResult::from_blocks(blocks, "Local".to_string(), 0);
|
||||
|
||||
// 添加元数据
|
||||
if let Some(engine_info) = plugin_result.engine {
|
||||
result.metadata.insert("engine".to_string(), engine_info);
|
||||
}
|
||||
if let Some(lang) = plugin_result.language {
|
||||
result.metadata.insert("language".to_string(), lang);
|
||||
}
|
||||
|
||||
let duration = start.elapsed().as_millis() as u64;
|
||||
result.duration_ms = duration;
|
||||
|
||||
if let Some(callback) = &progress_callback {
|
||||
(callback)(OcrProgress::Completed(result.clone()));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 调用本地插件
|
||||
fn call_plugin(&self, image_path: &Path) -> Result<String> {
|
||||
let output = Command::new(&self.plugin_path)
|
||||
.arg("recognize")
|
||||
.arg(image_path)
|
||||
.output()
|
||||
.context("执行本地 OCR 插件失败")?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(anyhow::anyhow!("插件执行失败: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
}
|
||||
|
||||
/// 检查插件是否可用
|
||||
pub fn is_available(&self) -> bool {
|
||||
self.plugin_path.exists()
|
||||
}
|
||||
|
||||
/// 获取插件版本信息
|
||||
pub fn get_version(&self) -> Result<String> {
|
||||
let output = Command::new(&self.plugin_path)
|
||||
.arg("version")
|
||||
.output()
|
||||
.context("获取插件版本失败")?;
|
||||
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
|
||||
} else {
|
||||
Ok("unknown".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 本地 OCR 插件响应
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct LocalOcrResponse {
|
||||
/// 是否成功
|
||||
success: bool,
|
||||
/// 错误信息
|
||||
error: Option<String>,
|
||||
/// OCR 引擎信息
|
||||
engine: Option<String>,
|
||||
/// 识别语言
|
||||
language: Option<String>,
|
||||
/// 文本块列表
|
||||
blocks: Vec<LocalOcrBlock>,
|
||||
}
|
||||
|
||||
/// 本地 OCR 文本块
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct LocalOcrBlock {
|
||||
/// 文本内容
|
||||
text: String,
|
||||
/// 置信度 (0-100)
|
||||
confidence: f32,
|
||||
/// 边界框 X 坐标
|
||||
bbox_x: u32,
|
||||
/// 边界框 Y 坐标
|
||||
bbox_y: u32,
|
||||
/// 边界框宽度
|
||||
bbox_width: u32,
|
||||
/// 边界框高度
|
||||
bbox_height: u32,
|
||||
/// 文本块类型
|
||||
block_type: String,
|
||||
}
|
||||
82
src-tauri/src/ocr/mod.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
pub mod cloud;
|
||||
pub mod local;
|
||||
pub mod result;
|
||||
|
||||
pub use cloud::{CloudOcrProvider, CloudOcrRequest};
|
||||
pub use result::OcrResult;
|
||||
pub use local::LocalOcrPlugin;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// OCR 引擎类型
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum OcrEngineType {
|
||||
/// 百度 OCR
|
||||
Baidu,
|
||||
/// 腾讯云 OCR
|
||||
Tencent,
|
||||
/// 本地插件 OCR
|
||||
Local,
|
||||
}
|
||||
|
||||
/// OCR 配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrConfig {
|
||||
/// 选择的 OCR 引擎
|
||||
pub engine: OcrEngineType,
|
||||
/// 百度 API 配置
|
||||
pub baidu: Option<BaiduOcrConfig>,
|
||||
/// 腾讯云 API 配置
|
||||
pub tencent: Option<TencentOcrConfig>,
|
||||
/// 本地插件配置
|
||||
pub local: Option<LocalOcrConfig>,
|
||||
}
|
||||
|
||||
/// 百度 OCR 配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BaiduOcrConfig {
|
||||
/// API Key
|
||||
pub api_key: String,
|
||||
/// Secret Key
|
||||
pub secret_key: String,
|
||||
/// 是否使用通用文字识别(高精度版)
|
||||
pub accurate: bool,
|
||||
}
|
||||
|
||||
/// 腾讯云 OCR 配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TencentOcrConfig {
|
||||
/// Secret ID
|
||||
pub secret_id: String,
|
||||
/// Secret Key
|
||||
pub secret_key: String,
|
||||
/// 地域
|
||||
pub region: String,
|
||||
}
|
||||
|
||||
/// 本地 OCR 插件配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LocalOcrConfig {
|
||||
/// 插件可执行文件路径
|
||||
pub plugin_path: PathBuf,
|
||||
/// 语言模型路径
|
||||
pub model_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
/// OCR 进度回调
|
||||
pub type OcrProgressCallback = Box<dyn Fn(OcrProgress) + Send + Sync>;
|
||||
|
||||
/// OCR 进度信息
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OcrProgress {
|
||||
/// 开始识别
|
||||
Starting,
|
||||
/// 处理中
|
||||
Processing { progress: f32, message: String },
|
||||
/// 识别完成
|
||||
Completed(OcrResult),
|
||||
/// 识别失败
|
||||
Failed { error: String },
|
||||
}
|
||||
167
src-tauri/src/ocr/result.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// OCR 识别结果
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrResult {
|
||||
/// 识别的完整文本
|
||||
pub text: String,
|
||||
/// 识别的文本块列表
|
||||
pub blocks: Vec<TextBlock>,
|
||||
/// 识别置信度 (0-100)
|
||||
pub confidence: f32,
|
||||
/// 耗时(毫秒)
|
||||
pub duration_ms: u64,
|
||||
/// 使用的 OCR 引擎
|
||||
pub engine: String,
|
||||
/// 额外信息
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// 文本块
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TextBlock {
|
||||
/// 文本内容
|
||||
pub text: String,
|
||||
/// 置信度 (0-100)
|
||||
pub confidence: f32,
|
||||
/// 边界框(像素坐标)
|
||||
pub bbox: BoundingBox,
|
||||
/// 文本块类型
|
||||
pub block_type: TextBlockType,
|
||||
}
|
||||
|
||||
/// 边界框
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BoundingBox {
|
||||
/// 左上角 X 坐标
|
||||
pub x: u32,
|
||||
/// 左上角 Y 坐标
|
||||
pub y: u32,
|
||||
/// 宽度
|
||||
pub width: u32,
|
||||
/// 高度
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
/// 文本块类型
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TextBlockType {
|
||||
/// 文本
|
||||
Text,
|
||||
/// 标题
|
||||
Title,
|
||||
/// 列表
|
||||
List,
|
||||
/// 表格
|
||||
Table,
|
||||
/// 其他
|
||||
Other,
|
||||
}
|
||||
|
||||
impl OcrResult {
|
||||
/// 创建空的 OCR 结果
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
text: String::new(),
|
||||
blocks: Vec::new(),
|
||||
confidence: 0.0,
|
||||
duration_ms: 0,
|
||||
engine: String::new(),
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 从文本块列表构建结果
|
||||
pub fn from_blocks(
|
||||
blocks: Vec<TextBlock>,
|
||||
engine: String,
|
||||
duration_ms: u64,
|
||||
) -> Self {
|
||||
let text = blocks
|
||||
.iter()
|
||||
.map(|b| b.text.as_str())
|
||||
.collect::<Vec<&str>>()
|
||||
.join("\n");
|
||||
|
||||
let confidence = if blocks.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
blocks.iter().map(|b| b.confidence).sum::<f32>() / blocks.len() as f32
|
||||
};
|
||||
|
||||
Self {
|
||||
text,
|
||||
blocks,
|
||||
confidence,
|
||||
duration_ms,
|
||||
engine,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 搜索包含关键词的文本块
|
||||
pub fn search(&self, keyword: &str) -> Vec<&TextBlock> {
|
||||
let keyword_lower = keyword.to_lowercase();
|
||||
self.blocks
|
||||
.iter()
|
||||
.filter(|block| block.text.to_lowercase().contains(&keyword_lower))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 导出为纯文本
|
||||
pub fn to_plain_text(&self) -> String {
|
||||
self.text.clone()
|
||||
}
|
||||
|
||||
/// 导出为 Markdown 格式
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut md = String::new();
|
||||
|
||||
for block in &self.blocks {
|
||||
match block.block_type {
|
||||
TextBlockType::Title => {
|
||||
md.push_str(&format!("## {}\n\n", block.text));
|
||||
}
|
||||
TextBlockType::List => {
|
||||
md.push_str(&format!("- {}\n", block.text));
|
||||
}
|
||||
TextBlockType::Table => {
|
||||
md.push_str(&format!("| {}\n", block.text));
|
||||
}
|
||||
_ => {
|
||||
md.push_str(&format!("{}\n", block.text));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
md
|
||||
}
|
||||
|
||||
/// 获取高亮显示的 HTML
|
||||
pub fn to_html_highlighted(&self, keyword: &str) -> String {
|
||||
let mut html = String::from("<div class='ocr-result'>");
|
||||
|
||||
for block in &self.blocks {
|
||||
let text = if !keyword.is_empty()
|
||||
&& block.text.to_lowercase().contains(&keyword.to_lowercase())
|
||||
{
|
||||
block.text.replace(
|
||||
&keyword.to_lowercase(),
|
||||
&format!("<mark>{}</mark>", keyword),
|
||||
)
|
||||
} else {
|
||||
block.text.clone()
|
||||
};
|
||||
|
||||
html.push_str(&format!(
|
||||
"<p class='ocr-block' style='position: absolute; left: {}px; top: {}px;'>{}</p>",
|
||||
block.bbox.x, block.bbox.y, text
|
||||
));
|
||||
}
|
||||
|
||||
html.push_str("</div>");
|
||||
html
|
||||
}
|
||||
}
|
||||
432
src-tauri/src/plugin/mod.rs
Normal file
@@ -0,0 +1,432 @@
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// 插件元数据
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PluginMetadata {
|
||||
/// 插件 ID
|
||||
pub id: String,
|
||||
/// 插件名称
|
||||
pub name: String,
|
||||
/// 插件描述
|
||||
pub description: String,
|
||||
/// 插件版本
|
||||
pub version: String,
|
||||
/// 插件作者
|
||||
pub author: String,
|
||||
/// 插件类型
|
||||
pub plugin_type: PluginType,
|
||||
/// 下载 URL
|
||||
pub download_url: String,
|
||||
/// SHA256 校验和
|
||||
pub sha256: String,
|
||||
/// 文件大小(字节)
|
||||
pub file_size: u64,
|
||||
/// 最低兼容版本
|
||||
pub min_app_version: String,
|
||||
/// 主页 URL
|
||||
pub homepage_url: Option<String>,
|
||||
/// 图标 URL
|
||||
pub icon_url: Option<String>,
|
||||
/// 依赖项
|
||||
pub dependencies: Vec<String>,
|
||||
/// 发布日期
|
||||
pub published_at: String,
|
||||
/// 许可证
|
||||
pub license: String,
|
||||
}
|
||||
|
||||
/// 插件类型
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PluginType {
|
||||
/// OCR 插件
|
||||
Ocr,
|
||||
/// 图床插件
|
||||
ImageHost,
|
||||
/// 其他
|
||||
Other,
|
||||
}
|
||||
|
||||
/// 插件状态
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PluginStatus {
|
||||
/// 插件元数据
|
||||
pub metadata: PluginMetadata,
|
||||
/// 是否已安装
|
||||
pub installed: bool,
|
||||
/// 安装路径
|
||||
pub install_path: Option<PathBuf>,
|
||||
/// 是否有更新
|
||||
pub has_update: bool,
|
||||
/// 安装时间
|
||||
pub installed_at: Option<String>,
|
||||
}
|
||||
|
||||
/// 插件管理器
|
||||
pub struct PluginManager {
|
||||
client: Client,
|
||||
config_dir: PathBuf,
|
||||
plugins_dir: PathBuf,
|
||||
registry_url: String,
|
||||
}
|
||||
|
||||
impl PluginManager {
|
||||
/// 创建新的插件管理器
|
||||
pub fn new(config_dir: PathBuf) -> Self {
|
||||
let plugins_dir = config_dir.join("plugins");
|
||||
fs::create_dir_all(&plugins_dir)
|
||||
.expect("Failed to create plugins directory");
|
||||
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap(),
|
||||
config_dir,
|
||||
plugins_dir,
|
||||
registry_url: "https://raw.githubusercontent.com/cutthenthink/plugins/main".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取远程插件列表
|
||||
pub async fn fetch_remote_plugins(&self) -> Result<Vec<PluginMetadata>> {
|
||||
let url = format!("{}/plugins.json", self.registry_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to fetch plugins list")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow::anyhow!("Failed to fetch plugins: HTTP {}", response.status()));
|
||||
}
|
||||
|
||||
let plugins: Vec<PluginMetadata> = response.json().await?;
|
||||
Ok(plugins)
|
||||
}
|
||||
|
||||
/// 获取插件状态列表
|
||||
pub async fn get_plugin_status(&self) -> Result<Vec<PluginStatus>> {
|
||||
let remote_plugins = self.fetch_remote_plugins().await?;
|
||||
let installed_plugins = self.get_installed_plugins()?;
|
||||
|
||||
let mut status_list = Vec::new();
|
||||
|
||||
for remote_plugin in remote_plugins {
|
||||
let installed = installed_plugins.get(&remote_plugin.id);
|
||||
|
||||
let has_update = if let Some(installed) = installed {
|
||||
installed.version != remote_plugin.version
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
status_list.push(PluginStatus {
|
||||
metadata: remote_plugin.clone(),
|
||||
installed: installed.is_some(),
|
||||
install_path: installed.and_then(|p| Some(p.install_path.clone()?)),
|
||||
has_update,
|
||||
installed_at: installed.and_then(|p| p.installed_at.clone()),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(status_list)
|
||||
}
|
||||
|
||||
/// 下载并安装插件
|
||||
pub async fn install_plugin(
|
||||
&self,
|
||||
plugin_id: &str,
|
||||
progress_callback: mpsc::Sender<InstallProgress>,
|
||||
) -> Result<PathBuf> {
|
||||
// 获取远程插件信息
|
||||
let plugins = self.fetch_remote_plugins().await?;
|
||||
let plugin = plugins
|
||||
.iter()
|
||||
.find(|p| p.id == plugin_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Plugin not found: {}", plugin_id))?;
|
||||
|
||||
// 发送开始事件
|
||||
let _ = progress_callback
|
||||
.send(InstallProgress::Starting {
|
||||
plugin_name: plugin.name.clone(),
|
||||
file_size: plugin.file_size,
|
||||
})
|
||||
.await;
|
||||
|
||||
// 下载文件
|
||||
let download_path = self
|
||||
.download_plugin(plugin, progress_callback.clone())
|
||||
.await?;
|
||||
|
||||
// 验证 SHA256
|
||||
let _ = progress_callback
|
||||
.send(InstallProgress::Verifying {
|
||||
plugin_name: plugin.name.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
self.verify_checksum(&download_path, &plugin.sha256)?;
|
||||
|
||||
// 解压/安装
|
||||
let _ = progress_callback
|
||||
.send(InstallProgress::Installing {
|
||||
plugin_name: plugin.name.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
let install_path = self.install_plugin_file(plugin, &download_path).await?;
|
||||
|
||||
// 保存安装记录
|
||||
self.save_install_record(plugin, &install_path)?;
|
||||
|
||||
// 清理下载文件
|
||||
let _ = fs::remove_file(&download_path);
|
||||
|
||||
// 发送完成事件
|
||||
let _ = progress_callback
|
||||
.send(InstallProgress::Completed {
|
||||
plugin_id: plugin.id.clone(),
|
||||
install_path: install_path.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(install_path)
|
||||
}
|
||||
|
||||
/// 卸载插件
|
||||
pub fn uninstall_plugin(&self, plugin_id: &str) -> Result<()> {
|
||||
let install_path = self.plugins_dir.join(plugin_id);
|
||||
|
||||
if install_path.exists() {
|
||||
fs::remove_dir_all(&install_path)
|
||||
.context("Failed to remove plugin directory")?;
|
||||
}
|
||||
|
||||
// 删除安装记录
|
||||
let record_path = self.get_install_record_path(plugin_id);
|
||||
if record_path.exists() {
|
||||
fs::remove_file(&record_path)?;
|
||||
}
|
||||
|
||||
log::info!("Uninstalled plugin: {}", plugin_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 下载插件文件
|
||||
async fn download_plugin(
|
||||
&self,
|
||||
plugin: &PluginMetadata,
|
||||
mut progress_callback: mpsc::Sender<InstallProgress>,
|
||||
) -> Result<PathBuf> {
|
||||
let response = self
|
||||
.client
|
||||
.get(&plugin.download_url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to download plugin")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow::anyhow!("Download failed: HTTP {}", response.status()));
|
||||
}
|
||||
|
||||
let total_size = response.content_length().unwrap_or(plugin.file_size);
|
||||
let download_path = self.config_dir.join("downloads").join(format!("{}.zip", &plugin.id));
|
||||
|
||||
fs::create_dir_all(download_path.parent().unwrap())?;
|
||||
|
||||
let mut file = tokio::fs::File::create(&download_path).await?;
|
||||
let mut downloaded = 0u64;
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
use futures_util::StreamExt;
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.context("Failed to read download chunk")?;
|
||||
file.write_all(&chunk).await?;
|
||||
downloaded += chunk.len() as u64;
|
||||
|
||||
let progress = (downloaded as f64 / total_size as f64) * 100.0;
|
||||
let _ = progress_callback
|
||||
.send(InstallProgress::Downloading {
|
||||
plugin_name: plugin.name.clone(),
|
||||
progress,
|
||||
downloaded,
|
||||
total_size,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
file.flush().await?;
|
||||
|
||||
Ok(download_path)
|
||||
}
|
||||
|
||||
/// 验证文件校验和
|
||||
fn verify_checksum(&self, file_path: &Path, expected_sha256: &str) -> Result<()> {
|
||||
let content = fs::read(file_path)?;
|
||||
let hash = Sha256::digest(&content);
|
||||
let calculated_sha256 = hex::encode(hash);
|
||||
|
||||
if calculated_sha256 != expected_sha256.to_lowercase() {
|
||||
fs::remove_file(file_path)?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Checksum verification failed. Expected: {}, Got: {}",
|
||||
expected_sha256,
|
||||
calculated_sha256
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 安装插件文件
|
||||
async fn install_plugin_file(
|
||||
&self,
|
||||
plugin: &PluginMetadata,
|
||||
download_path: &Path,
|
||||
) -> Result<PathBuf> {
|
||||
let install_dir = self.plugins_dir.join(&plugin.id);
|
||||
|
||||
// 如果已存在,先删除
|
||||
if install_dir.exists() {
|
||||
fs::remove_dir_all(&install_dir)?;
|
||||
}
|
||||
|
||||
fs::create_dir_all(&install_dir)?;
|
||||
|
||||
// 对于本地 OCR 插件,直接复制可执行文件
|
||||
if plugin.plugin_type == PluginType::Ocr {
|
||||
let exe_name = if cfg!(windows) {
|
||||
format!("{}.exe", plugin.id)
|
||||
} else {
|
||||
plugin.id.clone()
|
||||
};
|
||||
|
||||
let exe_path = install_dir.join(&exe_name);
|
||||
fs::copy(download_path, &exe_path)?;
|
||||
|
||||
// 设置可执行权限
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let mut perms = fs::metadata(&exe_path)?.permissions();
|
||||
perms.set_mode(0o755);
|
||||
fs::set_permissions(&exe_path, perms)?;
|
||||
}
|
||||
|
||||
Ok(exe_path)
|
||||
} else {
|
||||
// 其他类型需要解压
|
||||
Ok(install_dir)
|
||||
}
|
||||
}
|
||||
|
||||
/// 保存安装记录
|
||||
fn save_install_record(&self, plugin: &PluginMetadata, install_path: &Path) -> Result<()> {
|
||||
let record = InstallRecord {
|
||||
plugin_id: plugin.id.clone(),
|
||||
version: plugin.version.clone(),
|
||||
install_path: install_path.clone(),
|
||||
installed_at: SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)?
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
let record_path = self.get_install_record_path(&plugin.id);
|
||||
let content = serde_json::to_string_pretty(&record)?;
|
||||
fs::write(&record_path, content)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取安装记录路径
|
||||
fn get_install_record_path(&self, plugin_id: &str) -> PathBuf {
|
||||
self.config_dir.join(format!("plugin_{}.json", plugin_id))
|
||||
}
|
||||
|
||||
/// 获取已安装的插件
|
||||
fn get_installed_plugins(&self) -> Result<HashMap<String, InstallRecord>> {
|
||||
let mut plugins = HashMap::new();
|
||||
|
||||
for entry in fs::read_dir(&self.config_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("json")
|
||||
&& path.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.starts_with("plugin_"))
|
||||
== Some(true)
|
||||
{
|
||||
if let Ok(content) = fs::read_to_string(&path) {
|
||||
if let Ok(record) = serde_json::from_str::<InstallRecord>(&content) {
|
||||
plugins.insert(record.plugin_id.clone(), record);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(plugins)
|
||||
}
|
||||
|
||||
/// 检查插件更新
|
||||
pub async fn check_updates(&self) -> Result<Vec<String>> {
|
||||
let status_list = self.get_plugin_status().await?;
|
||||
let updates: Vec<String> = status_list
|
||||
.into_iter()
|
||||
.filter(|s| s.has_update)
|
||||
.map(|s| s.metadata.id)
|
||||
.collect();
|
||||
|
||||
Ok(updates)
|
||||
}
|
||||
}
|
||||
|
||||
/// 安装记录
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct InstallRecord {
|
||||
plugin_id: String,
|
||||
version: String,
|
||||
install_path: PathBuf,
|
||||
installed_at: u64,
|
||||
}
|
||||
|
||||
/// 安装进度
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InstallProgress {
|
||||
Starting {
|
||||
plugin_name: String,
|
||||
file_size: u64,
|
||||
},
|
||||
Downloading {
|
||||
plugin_name: String,
|
||||
progress: f64,
|
||||
downloaded: u64,
|
||||
total_size: u64,
|
||||
},
|
||||
Verifying {
|
||||
plugin_name: String,
|
||||
},
|
||||
Installing {
|
||||
plugin_name: String,
|
||||
},
|
||||
Completed {
|
||||
plugin_id: String,
|
||||
install_path: PathBuf,
|
||||
},
|
||||
Failed {
|
||||
plugin_id: String,
|
||||
error: String,
|
||||
},
|
||||
}
|
||||
308
src-tauri/src/screenshot.rs
Normal file
@@ -0,0 +1,308 @@
|
||||
use anyhow::Result;
|
||||
use arboard::Clipboard;
|
||||
use chrono::Utc;
|
||||
use image::{DynamicImage, ImageFormat};
|
||||
use screenshots::Screen;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::io::Cursor;
|
||||
|
||||
/// 截图元数据
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ScreenshotMetadata {
|
||||
pub id: String,
|
||||
pub filename: String,
|
||||
pub filepath: PathBuf,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub file_size: u64,
|
||||
pub created_at: String,
|
||||
pub thumbnail_base64: Option<String>,
|
||||
}
|
||||
|
||||
/// 截图区域选择参数
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RegionSelection {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
/// 截图管理器
|
||||
pub struct ScreenshotManager {
|
||||
base_dir: PathBuf,
|
||||
screenshots_dir: PathBuf,
|
||||
thumbnails_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl ScreenshotManager {
|
||||
/// 创建新的截图管理器
|
||||
pub fn new(base_dir: PathBuf) -> Result<Self> {
|
||||
let screenshots_dir = base_dir.join("screenshots");
|
||||
let thumbnails_dir = base_dir.join("thumbnails");
|
||||
|
||||
// 创建目录
|
||||
fs::create_dir_all(&screenshots_dir)?;
|
||||
fs::create_dir_all(&thumbnails_dir)?;
|
||||
|
||||
Ok(Self {
|
||||
base_dir,
|
||||
screenshots_dir,
|
||||
thumbnails_dir,
|
||||
})
|
||||
}
|
||||
|
||||
/// 使用默认目录创建截图管理器
|
||||
pub fn with_default_dir() -> Result<Self> {
|
||||
let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
let app_data_dir = home_dir.join(".cutthink-lite");
|
||||
Self::new(app_data_dir)
|
||||
}
|
||||
|
||||
/// 生成时间戳文件名
|
||||
fn generate_filename(&self) -> String {
|
||||
let timestamp = Utc::now().format("%Y%m%d_%H%M%S_%3f");
|
||||
format!("screenshot_{}.png", timestamp)
|
||||
}
|
||||
|
||||
/// 生成唯一 ID
|
||||
fn generate_id(&self, filename: &str) -> String {
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
filename.hash(&mut hasher);
|
||||
format!("{:x}", hasher.finish())
|
||||
}
|
||||
|
||||
/// 全屏截图
|
||||
pub fn capture_fullscreen(&self) -> Result<ScreenshotMetadata> {
|
||||
let screen = Screen::all()?
|
||||
.first()
|
||||
.ok_or_else(|| anyhow::anyhow!("No screen found"))?
|
||||
.clone();
|
||||
|
||||
let image = screen.capture()?;
|
||||
let width = image.width();
|
||||
let height = image.height();
|
||||
|
||||
let filename = self.generate_filename();
|
||||
let filepath = self.screenshots_dir.join(&filename);
|
||||
let id = self.generate_id(&filename);
|
||||
let created_at = Utc::now().to_rfc3339();
|
||||
|
||||
// 保存图片
|
||||
image.save(&filepath)?;
|
||||
|
||||
// 生成缩略图
|
||||
let thumbnail_base64 = self.generate_thumbnail(&filepath, width, height)?;
|
||||
|
||||
// 获取文件大小
|
||||
let file_size = fs::metadata(&filepath)?.len();
|
||||
|
||||
Ok(ScreenshotMetadata {
|
||||
id,
|
||||
filename,
|
||||
filepath,
|
||||
width,
|
||||
height,
|
||||
file_size,
|
||||
created_at,
|
||||
thumbnail_base64,
|
||||
})
|
||||
}
|
||||
|
||||
/// 区域截图
|
||||
pub fn capture_region(&self, region: RegionSelection) -> Result<ScreenshotMetadata> {
|
||||
let screen = Screen::all()?
|
||||
.first()
|
||||
.ok_or_else(|| anyhow::anyhow!("No screen found"))?
|
||||
.clone();
|
||||
|
||||
let full_image = screen.capture()?;
|
||||
|
||||
// 确保区域在屏幕范围内
|
||||
let x = region.x.max(0);
|
||||
let y = region.y.max(0);
|
||||
let width = region.width.min(full_image.width() as u32 - x as u32);
|
||||
let height = region.height.min(full_image.height() as u32 - y as u32);
|
||||
|
||||
// 裁剪区域
|
||||
let image_buffer = full_image.buffer();
|
||||
let cropped_image = image::imageops::crop(
|
||||
image_buffer,
|
||||
x as u32,
|
||||
y as u32,
|
||||
width,
|
||||
height,
|
||||
).to_image();
|
||||
|
||||
let dynamic_image = DynamicImage::ImageRgba8(cropped_image);
|
||||
let width = dynamic_image.width();
|
||||
let height = dynamic_image.height();
|
||||
|
||||
let filename = self.generate_filename();
|
||||
let filepath = self.screenshots_dir.join(&filename);
|
||||
let id = self.generate_id(&filename);
|
||||
let created_at = Utc::now().to_rfc3339();
|
||||
|
||||
// 保存图片
|
||||
dynamic_image.save(&filepath)?;
|
||||
|
||||
// 生成缩略图
|
||||
let thumbnail_base64 = self.generate_thumbnail(&filepath, width, height)?;
|
||||
|
||||
// 获取文件大小
|
||||
let file_size = fs::metadata(&filepath)?.len();
|
||||
|
||||
Ok(ScreenshotMetadata {
|
||||
id,
|
||||
filename,
|
||||
filepath,
|
||||
width,
|
||||
height,
|
||||
file_size,
|
||||
created_at,
|
||||
thumbnail_base64,
|
||||
})
|
||||
}
|
||||
|
||||
/// 生成缩略图并转换为 base64
|
||||
fn generate_thumbnail(&self, filepath: &Path, width: u32, height: u32) -> Result<Option<String>> {
|
||||
const THUMBNAIL_MAX_SIZE: u32 = 200;
|
||||
|
||||
// 如果图片已经很小,直接使用原图
|
||||
if width <= THUMBNAIL_MAX_SIZE && height <= THUMBNAIL_MAX_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// 计算缩略图尺寸(保持宽高比)
|
||||
let scale = THUMBNAIL_MAX_SIZE as f32 / width.max(height) as f32;
|
||||
let new_width = (width as f32 * scale) as u32;
|
||||
let new_height = (height as f32 * scale) as u32;
|
||||
|
||||
// 加载并缩放图片
|
||||
let img = image::open(filepath)?;
|
||||
let thumbnail = image::imageops::thumbnail(&img, new_width, new_height);
|
||||
|
||||
// 转换为 base64
|
||||
let mut buffer = Cursor::new(Vec::new());
|
||||
thumbnail.write_to(&mut buffer, ImageFormat::Png)?;
|
||||
let base64_string = base64::encode(&buffer.into_inner());
|
||||
|
||||
Ok(Some(format!("data:image/png;base64,{}", base64_string)))
|
||||
}
|
||||
|
||||
/// 复制图片到剪贴板
|
||||
pub fn copy_to_clipboard(&self, filepath: &Path) -> Result<()> {
|
||||
let img = image::open(filepath)?;
|
||||
let rgba = img.to_rgba8();
|
||||
|
||||
// 构建 PNG 格式数据
|
||||
let mut buffer = Cursor::new(Vec::new());
|
||||
rgba.write_to(&mut buffer, ImageFormat::Png)?;
|
||||
let png_data = buffer.into_inner();
|
||||
|
||||
// 复制到剪贴板
|
||||
let mut clipboard = Clipboard::new()?;
|
||||
|
||||
// 注意:arboard 的图片支持可能有限,这里提供基础实现
|
||||
// 如果不工作,可能需要使用平台特定的剪贴板 API
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
clipboard.set_image(arboard::ImageData {
|
||||
width: rgba.width() as usize,
|
||||
height: rgba.height() as usize,
|
||||
bytes: rgba.as_raw().into(),
|
||||
})?;
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
// 对于非 Windows 平台,先尝试复制文件路径作为后备
|
||||
clipboard.set_text(filepath.to_string_lossy().to_string())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 删除截图
|
||||
pub fn delete_screenshot(&self, filepath: &Path) -> Result<()> {
|
||||
if filepath.exists() {
|
||||
fs::remove_file(filepath)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取所有截图
|
||||
pub fn list_screenshots(&self) -> Result<Vec<ScreenshotMetadata>> {
|
||||
let mut screenshots = Vec::new();
|
||||
|
||||
let entries = fs::read_dir(&self.screenshots_dir)?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("png") {
|
||||
if let Ok(metadata) = fs::metadata(&path) {
|
||||
if let Ok(img) = image::open(&path) {
|
||||
let filename = path.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown.png")
|
||||
.to_string();
|
||||
|
||||
let id = self.generate_id(&filename);
|
||||
let created_at = metadata.created()
|
||||
.ok()
|
||||
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
|
||||
.map(|d| {
|
||||
let datetime: chrono::DateTime<Utc> = DateTime::from(d);
|
||||
datetime.to_rfc3339()
|
||||
})
|
||||
.unwrap_or_else(|| Utc::now().to_rfc3339());
|
||||
|
||||
let thumbnail_base64 = self.generate_thumbnail(&path, img.width(), img.height())?;
|
||||
|
||||
screenshots.push(ScreenshotMetadata {
|
||||
id,
|
||||
filename,
|
||||
filepath: path,
|
||||
width: img.width(),
|
||||
height: img.height(),
|
||||
file_size: metadata.len(),
|
||||
created_at,
|
||||
thumbnail_base64,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 按创建时间排序(最新的在前)
|
||||
screenshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
|
||||
|
||||
Ok(screenshots)
|
||||
}
|
||||
|
||||
/// 清理旧截图(保留最近 N 个)
|
||||
pub fn cleanup_old_screenshots(&self, keep_count: usize) -> Result<usize> {
|
||||
let mut screenshots = self.list_screenshots()?;
|
||||
let deleted_count = if screenshots.len() > keep_count {
|
||||
screenshots.truncate(keep_count);
|
||||
let to_delete = &screenshots[keep_count..];
|
||||
for screenshot in to_delete {
|
||||
let _ = self.delete_screenshot(&screenshot.filepath);
|
||||
}
|
||||
to_delete.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Ok(deleted_count)
|
||||
}
|
||||
}
|
||||
|
||||
use chrono::DateTime;
|
||||
226
src-tauri/src/secure_storage.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
use aes_gcm::{
|
||||
aead::{Aead, AeadCore, KeyInit, OsRng},
|
||||
Aes256Gcm, Nonce,
|
||||
};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// 安全存储管理器
|
||||
pub struct SecureStorage {
|
||||
storage_path: PathBuf,
|
||||
key: [u8; 32],
|
||||
}
|
||||
|
||||
impl SecureStorage {
|
||||
/// 创建新的安全存储
|
||||
pub fn new(storage_path: PathBuf, password: &str) -> Result<Self> {
|
||||
// 从密码派生加密密钥
|
||||
let key = Self::derive_key(password);
|
||||
|
||||
// 确保存储目录存在
|
||||
if let Some(parent) = storage_path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.context("Failed to create storage directory")?;
|
||||
}
|
||||
|
||||
Ok(Self { storage_path, key })
|
||||
}
|
||||
|
||||
/// 从密码派生密钥
|
||||
fn derive_key(password: &str) -> [u8; 32] {
|
||||
let mut key = [0u8; 32];
|
||||
|
||||
// 使用 PBKDF2 的简化版本(使用 SHA256 多次哈希)
|
||||
let mut hash = Sha256::digest(password.as_bytes());
|
||||
for _ in 0..10000 {
|
||||
hash = Sha256::digest(&hash);
|
||||
}
|
||||
|
||||
key.copy_from_slice(&hash[..32]);
|
||||
key
|
||||
}
|
||||
|
||||
/// 保存敏感数据
|
||||
pub fn save(&self, key: &str, value: &str) -> Result<()> {
|
||||
// 加载现有数据
|
||||
let mut data = self.load_data().unwrap_or_default();
|
||||
|
||||
// 加密值
|
||||
let encrypted = self.encrypt(value)?;
|
||||
|
||||
// 更新数据
|
||||
data.insert(key.to_string(), EncryptedValue {
|
||||
ciphertext: hex::encode(encrypted.0),
|
||||
nonce: hex::encode(encrypted.1),
|
||||
});
|
||||
|
||||
// 保存到文件
|
||||
self.save_data(&data)?;
|
||||
|
||||
log::info!("Saved encrypted value for key: {}", key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取敏感数据
|
||||
pub fn get(&self, key: &str) -> Result<Option<String>> {
|
||||
let data = self.load_data()?;
|
||||
|
||||
if let Some(encrypted) = data.get(key) {
|
||||
let ciphertext = hex::decode(&encrypted.ciphertext)
|
||||
.context("Failed to decode ciphertext")?;
|
||||
let nonce = hex::decode(&encrypted.nonce)
|
||||
.context("Failed to decode nonce")?;
|
||||
|
||||
let nonce = Nonce::from_slice(&nonce);
|
||||
let cipher = Aes256Gcm::new(&self.key.into());
|
||||
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, ciphertext.as_ref())
|
||||
.map_err(|_| anyhow!("Decryption failed - wrong password or corrupted data"))?;
|
||||
|
||||
Ok(Some(String::from_utf8(plaintext)?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// 删除敏感数据
|
||||
pub fn delete(&self, key: &str) -> Result<bool> {
|
||||
let mut data = self.load_data()?;
|
||||
|
||||
if data.remove(key).is_some() {
|
||||
self.save_data(&data)?;
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// 列出所有键
|
||||
pub fn list_keys(&self) -> Result<Vec<String>> {
|
||||
let data = self.load_data()?;
|
||||
Ok(data.keys().cloned().collect())
|
||||
}
|
||||
|
||||
/// 加密数据
|
||||
fn encrypt(&self, plaintext: &str) -> Result<(Vec<u8>, [u8; 12])> {
|
||||
let cipher = Aes256Gcm::new(&self.key.into());
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
|
||||
|
||||
let ciphertext = cipher
|
||||
.encrypt(&nonce, plaintext.as_bytes())
|
||||
.map_err(|_| anyhow!("Encryption failed"))?;
|
||||
|
||||
Ok((ciphertext, nonce.into()))
|
||||
}
|
||||
|
||||
/// 加载加密数据
|
||||
fn load_data(&self) -> Result<StorageData> {
|
||||
if !self.storage_path.exists() {
|
||||
return Ok(StorageData::new());
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&self.storage_path)
|
||||
.context("Failed to read storage file")?;
|
||||
|
||||
serde_json::from_str(&content).context("Failed to parse storage file")
|
||||
}
|
||||
|
||||
/// 保存加密数据
|
||||
fn save_data(&self, data: &StorageData) -> Result<()> {
|
||||
let content = serde_json::to_string_pretty(data)
|
||||
.context("Failed to serialize storage data")?;
|
||||
|
||||
fs::write(&self.storage_path, content)
|
||||
.context("Failed to write storage file")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 加密值
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct EncryptedValue {
|
||||
ciphertext: String,
|
||||
nonce: String,
|
||||
}
|
||||
|
||||
/// 存储数据
|
||||
type StorageData = std::collections::HashMap<String, EncryptedValue>;
|
||||
|
||||
/// API 密钥存储
|
||||
pub struct ApiKeyStorage {
|
||||
storage: SecureStorage,
|
||||
}
|
||||
|
||||
impl ApiKeyStorage {
|
||||
/// 创建新的 API 密钥存储
|
||||
pub fn new(config_dir: &Path, password: &str) -> Result<Self> {
|
||||
let storage_path = config_dir.join("secure_storage.json");
|
||||
let storage = SecureStorage::new(storage_path, password)?;
|
||||
|
||||
Ok(Self { storage })
|
||||
}
|
||||
|
||||
/// 保存百度 API 密钥
|
||||
pub fn save_baidu_key(&self, api_key: &str, secret_key: &str) -> Result<()> {
|
||||
self.storage.save("baidu_ocr_api_key", api_key)?;
|
||||
self.storage.save("baidu_ocr_secret_key", secret_key)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取百度 API 密钥
|
||||
pub fn get_baidu_key(&self) -> Result<(Option<String>, Option<String>)> {
|
||||
let api_key = self.storage.get("baidu_ocr_api_key")?;
|
||||
let secret_key = self.storage.get("baidu_ocr_secret_key")?;
|
||||
Ok((api_key, secret_key))
|
||||
}
|
||||
|
||||
/// 保存腾讯云 API 密钥
|
||||
pub fn save_tencent_key(&self, secret_id: &str, secret_key: &str) -> Result<()> {
|
||||
self.storage.save("tencent_ocr_secret_id", secret_id)?;
|
||||
self.storage.save("tencent_ocr_secret_key", secret_key)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取腾讯云 API 密钥
|
||||
pub fn get_tencent_key(&self) -> Result<(Option<String>, Option<String>)> {
|
||||
let secret_id = self.storage.get("tencent_ocr_secret_id")?;
|
||||
let secret_key = self.storage.get("tencent_ocr_secret_key")?;
|
||||
Ok((secret_id, secret_key))
|
||||
}
|
||||
|
||||
/// 删除所有密钥
|
||||
pub fn clear_all(&self) -> Result<()> {
|
||||
for key in self.storage.list_keys()? {
|
||||
self.storage.delete(&key)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_secure_storage() {
|
||||
let dir = tempdir().unwrap();
|
||||
let storage_path = dir.path().join("test_storage.json");
|
||||
let storage = SecureStorage::new(storage_path, "test_password").unwrap();
|
||||
|
||||
// 保存和获取
|
||||
storage.save("test_key", "test_value").unwrap();
|
||||
let value = storage.get("test_key").unwrap().unwrap();
|
||||
assert_eq!(value, "test_value");
|
||||
|
||||
// 删除
|
||||
assert!(storage.delete("test_key").unwrap());
|
||||
assert!(!storage.delete("test_key").unwrap());
|
||||
}
|
||||
}
|
||||
387
src-tauri/src/upload.rs
Normal file
@@ -0,0 +1,387 @@
|
||||
use crate::config::ImageHostConfig;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
|
||||
use chrono::Utc;
|
||||
use reqwest::header;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
/// 上传结果
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UploadResult {
|
||||
pub url: String,
|
||||
pub delete_url: Option<String>,
|
||||
pub image_host: String,
|
||||
pub uploaded_at: String,
|
||||
pub file_size: u64,
|
||||
}
|
||||
|
||||
/// 上传进度事件
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum UploadProgress {
|
||||
Starting,
|
||||
Uploading { progress: f32, message: String },
|
||||
Completed(UploadResult),
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
/// 上传器
|
||||
pub struct Uploader {
|
||||
retry_count: u32,
|
||||
timeout_seconds: u64,
|
||||
}
|
||||
|
||||
impl Uploader {
|
||||
/// 创建新的上传器
|
||||
pub fn new(retry_count: u32, timeout_seconds: u64) -> Self {
|
||||
Self {
|
||||
retry_count,
|
||||
timeout_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
/// 上传图片到指定图床
|
||||
pub async fn upload_image(
|
||||
&self,
|
||||
image_path: &Path,
|
||||
config: &ImageHostConfig,
|
||||
) -> Result<UploadResult> {
|
||||
// 读取图片数据
|
||||
let image_data = fs::read(image_path)
|
||||
.context("Failed to read image file")?;
|
||||
let file_size = image_data.len() as u64;
|
||||
|
||||
// 根据图床类型上传
|
||||
let result = match config {
|
||||
ImageHostConfig::GitHub { .. } => {
|
||||
self.upload_to_github(image_path, &image_data, config).await?
|
||||
}
|
||||
ImageHostConfig::Imgur { .. } => {
|
||||
self.upload_to_imgur(image_path, &image_data, config).await?
|
||||
}
|
||||
ImageHostConfig::Custom { .. } => {
|
||||
self.upload_to_custom(image_path, &image_data, config).await?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 上传到 GitHub
|
||||
async fn upload_to_github(
|
||||
&self,
|
||||
image_path: &Path,
|
||||
image_data: &[u8],
|
||||
config: &ImageHostConfig,
|
||||
) -> Result<UploadResult> {
|
||||
let (token, owner, repo, path, branch) = match config {
|
||||
ImageHostConfig::GitHub {
|
||||
token,
|
||||
owner,
|
||||
repo,
|
||||
path,
|
||||
branch,
|
||||
} => (token, owner, repo, path, branch),
|
||||
_ => return Err(anyhow!("Invalid GitHub config")),
|
||||
};
|
||||
|
||||
let branch = branch.as_deref().unwrap_or("main");
|
||||
let filename = image_path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.ok_or_else(|| anyhow!("Invalid filename"))?;
|
||||
|
||||
let github_path = if path.is_empty() {
|
||||
filename.to_string()
|
||||
} else {
|
||||
format!("{}/{}", path.trim_end_matches('/'), filename)
|
||||
};
|
||||
|
||||
// Base64 编码
|
||||
let content = BASE64.encode(image_data);
|
||||
|
||||
// 构建 API 请求
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.timeout_seconds))
|
||||
.build()?;
|
||||
|
||||
let url = format!(
|
||||
"https://api.github.com/repos/{}/{}/contents/{}",
|
||||
owner, repo, github_path
|
||||
);
|
||||
|
||||
let request_body = serde_json::json!({
|
||||
"message": format!("Upload {}", filename),
|
||||
"content": content,
|
||||
"branch": branch
|
||||
});
|
||||
|
||||
let response = client
|
||||
.put(&url)
|
||||
.header(header::AUTHORIZATION, format!("Bearer {}", token))
|
||||
.header(header::ACCEPT, "application/vnd.github.v3+json")
|
||||
.header(header::USER_AGENT, "CutThenThink-Lite")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to GitHub")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(anyhow!("GitHub API error: {}", error_text));
|
||||
}
|
||||
|
||||
let response_json: serde_json::Value = response.json().await?;
|
||||
let download_url = response_json["content"]["download_url"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Missing download_url in response"))?
|
||||
.to_string();
|
||||
|
||||
// 构建删除 URL(使用 SHA)
|
||||
let sha = response_json["content"]["sha"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Missing sha in response"))?;
|
||||
|
||||
let delete_url = Some(format!(
|
||||
"github://{}://{}/{}?sha={}",
|
||||
owner, repo, github_path, sha
|
||||
));
|
||||
|
||||
Ok(UploadResult {
|
||||
url: download_url,
|
||||
delete_url,
|
||||
image_host: "GitHub".to_string(),
|
||||
uploaded_at: Utc::now().to_rfc3339(),
|
||||
file_size: image_data.len() as u64,
|
||||
})
|
||||
}
|
||||
|
||||
/// 上传到 Imgur
|
||||
async fn upload_to_imgur(
|
||||
&self,
|
||||
image_path: &Path,
|
||||
image_data: &[u8],
|
||||
config: &ImageHostConfig,
|
||||
) -> Result<UploadResult> {
|
||||
let client_id = match config {
|
||||
ImageHostConfig::Imgur { client_id } => client_id,
|
||||
_ => return Err(anyhow!("Invalid Imgur config")),
|
||||
};
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.timeout_seconds))
|
||||
.build()?;
|
||||
|
||||
let form = reqwest::multipart::Form::new().part(
|
||||
"image",
|
||||
reqwest::multipart::Part::bytes(image_data.to_vec())
|
||||
.file_name(
|
||||
image_path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("image.png")
|
||||
.to_string(),
|
||||
)
|
||||
.mime_str("image/png")
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let response = client
|
||||
.post("https://api.imgur.com/3/image")
|
||||
.header(header::AUTHORIZATION, format!("Client-ID {}", client_id))
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Imgur")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(anyhow!("Imgur API error: {}", error_text));
|
||||
}
|
||||
|
||||
let response_json: serde_json::Value = response.json().await?;
|
||||
|
||||
if response_json["success"].as_bool() != Some(true) {
|
||||
return Err(anyhow!("Imgur upload failed"));
|
||||
}
|
||||
|
||||
let url = response_json["data"]["link"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Missing link in response"))?
|
||||
.to_string();
|
||||
|
||||
let delete_hash = response_json["data"]["deletehash"]
|
||||
.as_str()
|
||||
.map(|hash| format!("imgur://{}", hash));
|
||||
|
||||
Ok(UploadResult {
|
||||
url,
|
||||
delete_url: delete_hash,
|
||||
image_host: "Imgur".to_string(),
|
||||
uploaded_at: Utc::now().to_rfc3339(),
|
||||
file_size: image_data.len() as u64,
|
||||
})
|
||||
}
|
||||
|
||||
/// 上传到自定义图床
|
||||
async fn upload_to_custom(
|
||||
&self,
|
||||
image_path: &Path,
|
||||
image_data: &[u8],
|
||||
config: &ImageHostConfig,
|
||||
) -> Result<UploadResult> {
|
||||
let (url, headers, form_field) = match config {
|
||||
ImageHostConfig::Custom {
|
||||
url,
|
||||
headers,
|
||||
form_field,
|
||||
} => (url, headers, form_field),
|
||||
_ => return Err(anyhow!("Invalid custom config")),
|
||||
};
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.timeout_seconds))
|
||||
.build()?;
|
||||
|
||||
let field_name = form_field.as_deref().unwrap_or("file");
|
||||
|
||||
let mut form = reqwest::multipart::Form::new().part(
|
||||
field_name.to_string(),
|
||||
reqwest::multipart::Part::bytes(image_data.to_vec())
|
||||
.file_name(
|
||||
image_path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("image.png")
|
||||
.to_string(),
|
||||
)
|
||||
.mime_str("image/png")
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let mut request = client.post(url).multipart(form);
|
||||
|
||||
// 添加自定义头部
|
||||
if let Some(headers) = headers {
|
||||
for header in headers {
|
||||
request = request.header(&header.name, &header.value);
|
||||
}
|
||||
}
|
||||
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to custom host")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(anyhow!("Custom host error: {}", error_text));
|
||||
}
|
||||
|
||||
// 尝试解析 JSON 响应
|
||||
let response_json: serde_json::Value = response.json().await?;
|
||||
|
||||
// 尝试从响应中提取 URL
|
||||
let url = self.extract_url_from_response(&response_json)?;
|
||||
|
||||
Ok(UploadResult {
|
||||
url,
|
||||
delete_url: None,
|
||||
image_host: "Custom".to_string(),
|
||||
uploaded_at: Utc::now().to_rfc3339(),
|
||||
file_size: image_data.len() as u64,
|
||||
})
|
||||
}
|
||||
|
||||
/// 从响应中提取 URL
|
||||
fn extract_url_from_response(&self, response: &serde_json::Value) -> Result<String> {
|
||||
// 尝试常见的响应字段
|
||||
let url_fields = vec!["url", "link", "image_url", "data", "url"];
|
||||
|
||||
for field in url_fields {
|
||||
if let Some(url_value) = response.get(field) {
|
||||
// 如果是字符串,直接使用
|
||||
if let Some(url_str) = url_value.as_str() {
|
||||
return Ok(url_str.to_string());
|
||||
}
|
||||
// 如果是对象,尝试获取 url 字段
|
||||
if let Some(obj) = url_value.as_object() {
|
||||
if let Some(url_str) = obj.get("url").and_then(|v| v.as_str()) {
|
||||
return Ok(url_str.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找不到,返回整个响应的字符串表示
|
||||
Err(anyhow!(
|
||||
"Could not extract URL from response: {}",
|
||||
serde_json::to_string(response).unwrap_or_default()
|
||||
))
|
||||
}
|
||||
|
||||
/// 上传图片(带重试)
|
||||
pub async fn upload_with_retry(
|
||||
&self,
|
||||
image_path: &Path,
|
||||
config: &ImageHostConfig,
|
||||
mut progress_callback: impl FnMut(UploadProgress),
|
||||
) -> Result<UploadResult> {
|
||||
progress_callback(UploadProgress::Starting);
|
||||
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 0..=self.retry_count {
|
||||
if attempt > 0 {
|
||||
progress_callback(UploadProgress::Uploading {
|
||||
progress: (attempt as f32 / (self.retry_count + 1) as f32) * 100.0,
|
||||
message: format!("重试上传 {}/{}", attempt, self.retry_count),
|
||||
});
|
||||
|
||||
// 等待一段时间再重试
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
}
|
||||
|
||||
match self.upload_image(image_path, config).await {
|
||||
Ok(result) => {
|
||||
progress_callback(UploadProgress::Completed(result.clone()));
|
||||
return Ok(result);
|
||||
}
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
progress_callback(UploadProgress::Failed {
|
||||
error: last_error
|
||||
.as_ref()
|
||||
.map(|e| e.to_string())
|
||||
.unwrap_or_else(|| "Unknown error".to_string()),
|
||||
});
|
||||
|
||||
Err(last_error.unwrap_or_else(|| anyhow!("Upload failed")))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_upload_result_serialization() {
|
||||
let result = UploadResult {
|
||||
url: "https://example.com/image.png".to_string(),
|
||||
delete_url: Some("https://example.com/delete".to_string()),
|
||||
image_host: "Test".to_string(),
|
||||
uploaded_at: Utc::now().to_rfc3339(),
|
||||
file_size: 1024,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("url"));
|
||||
}
|
||||
}
|
||||
101
src-tauri/tauri.conf.json
Normal file
@@ -0,0 +1,101 @@
|
||||
{
|
||||
"$schema": "https://schema.tauri.app/config/2",
|
||||
"productName": "CutThenThink Lite",
|
||||
"version": "0.1.0",
|
||||
"identifier": "com.cutthenthink.app",
|
||||
"build": {
|
||||
"frontendDist": "../dist",
|
||||
"devUrl": "http://localhost:5173",
|
||||
"beforeDevCommand": "npm run dev",
|
||||
"beforeBuildCommand": "npm run build"
|
||||
},
|
||||
"app": {
|
||||
"windows": [
|
||||
{
|
||||
"title": "CutThenThink Lite",
|
||||
"width": 800,
|
||||
"height": 600,
|
||||
"resizable": true,
|
||||
"fullscreen": false,
|
||||
"minWidth": 600,
|
||||
"minHeight": 400
|
||||
}
|
||||
],
|
||||
"security": {
|
||||
"csp": null
|
||||
},
|
||||
"withGlobalTauri": true
|
||||
},
|
||||
"bundle": {
|
||||
"active": true,
|
||||
"targets": "all",
|
||||
"icon": [
|
||||
"icons/32x32.png",
|
||||
"icons/128x128.png",
|
||||
"icons/128x128@2x.png",
|
||||
"icons/icon.icns",
|
||||
"icons/icon.ico"
|
||||
],
|
||||
"publisher": "CutThenThink",
|
||||
"copyright": "Copyright © 2025 CutThenThink",
|
||||
"category": "Productivity",
|
||||
"shortDescription": "Lightweight screenshot and annotation tool",
|
||||
"longDescription": "CutThenThink Lite is a lightweight screenshot capture and annotation tool designed for quick visual communication. Perfect for creating tutorials, reporting bugs, or sharing ideas visually.",
|
||||
"createUpdaterArtifacts": true,
|
||||
"linux": {
|
||||
"deb": {
|
||||
"depends": []
|
||||
},
|
||||
"appimage": {
|
||||
"bundleMediaFramework": false
|
||||
}
|
||||
},
|
||||
"macOS": {
|
||||
"frameworks": [],
|
||||
"minimumSystemVersion": "10.13",
|
||||
"exceptionDomain": "",
|
||||
"signingIdentity": null,
|
||||
"entitlements": null,
|
||||
"providerShortName": null,
|
||||
"dmg": {
|
||||
"contents": [
|
||||
{
|
||||
"x": 130,
|
||||
"y": 220
|
||||
},
|
||||
{
|
||||
"x": 410,
|
||||
"y": 220,
|
||||
"type": "link",
|
||||
"path": "/Applications"
|
||||
}
|
||||
],
|
||||
"window": {
|
||||
"width": 540,
|
||||
"height": 380
|
||||
}
|
||||
}
|
||||
},
|
||||
"windows": {
|
||||
"certificateThumbprint": null,
|
||||
"digestAlgorithm": "sha256",
|
||||
"timestampUrl": "",
|
||||
"nsis": {
|
||||
"displayLanguageSelector": false,
|
||||
"languages": ["English", "SimpChinese"],
|
||||
"template": false,
|
||||
"installMode": "perMachine",
|
||||
"allowDowngrades": true,
|
||||
"createDesktopShortcut": true,
|
||||
"createStartMenuShortcut": true,
|
||||
"multiUserLauncher": false
|
||||
},
|
||||
"wix": {
|
||||
"language": ["en-US", "zh-CN"]
|
||||
},
|
||||
"webviewInstallMode": {
|
||||
"type": "embedBootstrapper"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||