feat: CutThenThink v3.0 初始版本

完整实现 Tauri + Vanilla JS 轻量级截图工具

Phase 1 - 项目搭建
- Tauri 2.x 项目初始化
- Vite 前端项目搭建
- 基础 UI 框架(CSS 变量、组件库)
- 构建配置优化

Phase 2 - 核心截图功能
- 全屏/区域/窗口截图
- 截图预览和管理
- 文件命名和缩略图
- 全局快捷键集成

Phase 3 - 上传与存储
- 多图床上传(GitHub/Imgur/自定义)
- 配置管理系统
- SQLite 数据库

Phase 4 - OCR 集成
- 云端 OCR(百度/腾讯云)
- 插件管理系统
- 本地 OCR 插件(Go)
- OCR 结果处理

Phase 5 - AI 分类系统
- Claude/OpenAI API 集成
- Prompt 模板引擎
- 模板管理界面
- 自动分类流程

Phase 6 - 历史记录与管理
- 图库视图(网格/列表)
- 搜索与筛选
- 批量操作
- 导出功能(JSON/CSV/ZIP)

Phase 7 - 打包与发布
- 多平台构建配置
- CI/CD 工作流
- 图标和资源
- 安装包配置

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude
2026-02-12 18:58:40 +08:00
commit e2ea309ee6
142 changed files with 38818 additions and 0 deletions

View File

@@ -0,0 +1,360 @@
use super::{AiClient, AiError, AiProvider, ClassificationResult, PromptEngine};
use anyhow::{Context, Result};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Mutex;
/// 分类器配置
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
/// 使用的 AI 提供商
pub provider: AiProvider,
/// 最小置信度阈值(低于此值需要人工确认)
pub min_confidence: f64,
/// 是否自动应用分类结果
pub auto_apply: bool,
}
impl Default for ClassifierConfig {
fn default() -> Self {
Self {
provider: AiProvider::Claude,
min_confidence: 0.7,
auto_apply: false,
}
}
}
/// 分类器
pub struct Classifier {
/// AI 客户端
client: Mutex<AiClient>,
/// Prompt 引擎
prompt_engine: PromptEngine,
/// 分类器配置
config: ClassifierConfig,
}
impl Classifier {
/// 创建新的分类器
pub fn new(config: ClassifierConfig) -> Self {
let client = AiClient::new(config.provider);
Self {
client: Mutex::new(client),
prompt_engine: PromptEngine::new(),
config,
}
}
/// 配置 AI 客户端
pub fn configure_claude(&self, api_key: String, model: Option<String>) -> Result<()> {
let mut client = self.client.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let claude_config = super::client::ClaudeConfig {
api_key,
model: model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string()),
max_tokens: 4096,
temperature: 0.3, // 较低的温度以获得更一致的结果
};
*client = AiClient::new(self.config.provider)
.with_claude_config(claude_config);
Ok(())
}
/// 配置 OpenAI 客户端
pub fn configure_openai(
&self,
api_key: String,
model: Option<String>,
base_url: Option<String>,
) -> Result<()> {
let mut client = self.client.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let openai_config = super::client::OpenAiConfig {
api_key,
model: model.unwrap_or_else(|| "gpt-4o".to_string()),
max_tokens: 4096,
temperature: 0.3,
base_url,
};
*client = AiClient::new(self.config.provider)
.with_openai_config(openai_config);
Ok(())
}
/// 对内容进行分类
pub async fn classify(
&self,
template_id: Option<&str>,
variables: &HashMap<String, String>,
) -> Result<ClassificationResult> {
// 获取模板
let template = if let Some(id) = template_id {
self.prompt_engine.get_template(id)
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", id))?
} else {
self.prompt_engine.get_default_template()
};
// 渲染 Prompt
let (system_prompt, user_prompt) = self
.prompt_engine
.render_template(template, variables)?;
// 构建消息
use super::client::{Message, MessageRole};
let messages = vec![
Message {
role: MessageRole::System,
content: system_prompt,
},
Message {
role: MessageRole::User,
content: user_prompt,
},
];
// 调用 AI
let client = self.client.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
// 释放锁后再调用
let client_ref = unsafe { &*(&*client as *const AiClient) };
let result = client_ref.chat(messages).await?;
// 解析分类结果
self.parse_classification_result(&result.content, template_id)
}
/// 流式分类(实时返回结果)
pub async fn classify_stream(
&self,
template_id: Option<&str>,
variables: &HashMap<String, String>,
mut callback: impl FnMut(super::client::StreamChunk),
) -> Result<ClassificationResult> {
// 获取模板
let template = if let Some(id) = template_id {
self.prompt_engine.get_template(id)
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", id))?
} else {
self.prompt_engine.get_default_template()
};
// 渲染 Prompt
let (system_prompt, user_prompt) = self
.prompt_engine
.render_template(template, variables)?;
// 构建消息
use super::client::{Message, MessageRole};
let messages = vec![
Message {
role: MessageRole::System,
content: system_prompt,
},
Message {
role: MessageRole::User,
content: user_prompt,
},
];
// 调用 AI 流式接口
let client = self.client.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
// 释放锁后再调用
let client_ref = unsafe { &*(&*client as *const AiClient) };
let result = client_ref.chat_stream(messages, callback).await?;
// 解析分类结果
self.parse_classification_result(&result.content, template_id)
}
/// 解析分类结果
fn parse_classification_result(
&self,
content: &str,
template_id: Option<&str>,
) -> Result<ClassificationResult> {
// 尝试从内容中提取 JSON
let json_str = self.extract_json(content)?;
// 解析 JSON
let value: Value = serde_json::from_str(&json_str)
.context("Failed to parse classification result as JSON")?;
// 提取字段
let category = value
.get("category")
.and_then(|v| v.as_str())
.unwrap_or("未分类")
.to_string();
let subcategory = value
.get("subcategory")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let tags = value
.get("tags")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str())
.map(|s| s.to_string())
.collect()
})
.unwrap_or_default();
let confidence = value
.get("confidence")
.and_then(|v| v.as_f64())
.unwrap_or(0.5);
let reasoning = value
.get("reasoning")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
Ok(ClassificationResult {
category,
subcategory,
tags,
confidence,
reasoning,
})
}
/// 从文本中提取 JSON
fn extract_json(&self, text: &str) -> Result<String> {
// 查找 ```json 代码块
if let Some(start) = text.find("```json") {
let json_start = start + 7;
if let Some(end) = text[json_start..].find("```") {
let json_str = text[json_start..json_start + end].trim();
return Ok(json_str.to_string());
}
}
// 查找 ``` 代码块
if let Some(start) = text.find("```") {
let json_start = start + 3;
if let Some(end) = text[json_start..].find("```") {
let json_str = text[json_start..json_start + end].trim();
return Ok(json_str.to_string());
}
}
// 尝试查找 { ... }
if let Some(start) = text.find('{') {
if let Some(end) = text.rfind('}') {
if end > start {
return Ok(text[start..=end].to_string());
}
}
}
// 如果都找不到,返回原文本
Ok(text.to_string())
}
/// 评估是否需要人工确认
pub fn needs_confirmation(&self, result: &ClassificationResult) -> bool {
result.confidence < self.config.min_confidence
}
/// 获取可用的模板列表
pub fn available_templates(&self) -> Vec<&super::template::Template> {
self.prompt_engine.list_templates()
}
/// 获取模板
pub fn get_template(&self, id: &str) -> Option<&super::template::Template> {
self.prompt_engine.get_template(id)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_json_extraction() {
let classifier = Classifier::new(ClassifierConfig::default());
let text = r#"这是一些前置文本
```json
{
"category": "代码",
"confidence": 0.95
}
```
后置文本"#;
let json = classifier.extract_json(text).unwrap();
assert!(json.contains("category"));
assert!(json.contains("代码"));
}
#[test]
fn test_classification_parsing() {
let classifier = Classifier::new(ClassifierConfig::default());
let json = r#"{
"category": "代码",
"subcategory": "Python",
"tags": ["编程", "脚本", "自动化"],
"confidence": 0.95,
"reasoning": "这是一段 Python 代码"
}"#;
let result = classifier.parse_classification_result(json, None).unwrap();
assert_eq!(result.category, "代码");
assert_eq!(result.subcategory, Some("Python".to_string()));
assert_eq!(result.tags.len(), 3);
assert_eq!(result.confidence, 0.95);
assert!(result.reasoning.is_some());
}
#[test]
fn test_confidence_check() {
let config = ClassifierConfig {
provider: AiProvider::Claude,
min_confidence: 0.8,
auto_apply: false,
};
let classifier = Classifier::new(config);
let high_conf = ClassificationResult {
category: "测试".to_string(),
subcategory: None,
tags: vec![],
confidence: 0.9,
reasoning: None,
};
let low_conf = ClassificationResult {
category: "测试".to_string(),
subcategory: None,
tags: vec![],
confidence: 0.7,
reasoning: None,
};
assert!(!classifier.needs_confirmation(&high_conf));
assert!(classifier.needs_confirmation(&low_conf));
}
}

511
src-tauri/src/ai/client.rs Normal file
View File

@@ -0,0 +1,511 @@
use super::{AiError, AiResult};
use anyhow::Result;
use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::Mutex;
use std::sync::Arc;
/// AI 提供商类型
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AiProvider {
Claude,
OpenAi,
}
/// Claude API 配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeConfig {
pub api_key: String,
#[serde(default = "default_claude_model")]
pub model: String,
#[serde(default = "default_claude_max_tokens")]
pub max_tokens: u32,
#[serde(default = "default_claude_temperature")]
pub temperature: f32,
}
fn default_claude_model() -> String {
"claude-3-5-sonnet-20241022".to_string()
}
fn default_claude_max_tokens() -> u32 {
4096
}
fn default_claude_temperature() -> f32 {
0.7
}
/// OpenAI API 配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiConfig {
pub api_key: String,
#[serde(default = "default_openai_model")]
pub model: String,
#[serde(default = "default_openai_max_tokens")]
pub max_tokens: u32,
#[serde(default = "default_openai_temperature")]
pub temperature: f32,
pub base_url: Option<String>,
}
fn default_openai_model() -> String {
"gpt-4o".to_string()
}
fn default_openai_max_tokens() -> u32 {
4096
}
fn default_openai_temperature() -> f32 {
0.7
}
/// 消息角色
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
User,
Assistant,
}
/// 聊天消息
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
}
/// 流式响应块
#[derive(Debug, Clone)]
pub enum StreamChunk {
/// 文本增量
Text(String),
/// 完成
Done,
/// 错误
Error(String),
}
/// 限流器
struct RateLimiter {
last_request: Arc<Mutex<std::time::Instant>>,
min_interval: Duration,
}
impl RateLimiter {
fn new(requests_per_second: u32) -> Self {
Self {
last_request: Arc::new(Mutex::new(std::time::Instant::now()
.checked_sub(Duration::from_secs(1)).unwrap())),
min_interval: Duration::from_secs(1) / requests_per_second.max(1),
}
}
async fn acquire(&self) {
let mut last = self.last_request.lock().await;
let elapsed = last.elapsed();
if elapsed < self.min_interval {
tokio::time::sleep(self.min_interval - elapsed).await;
}
*last = std::time::Instant::now();
}
}
/// AI 客户端
pub struct AiClient {
provider: AiProvider,
http_client: reqwest::Client,
rate_limiter: RateLimiter,
claude_config: Option<ClaudeConfig>,
openai_config: Option<OpenAiConfig>,
}
impl AiClient {
/// 创建新的 AI 客户端
pub fn new(provider: AiProvider) -> Self {
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(120))
.build()
.expect("Failed to create HTTP client");
// 限制每秒最多 5 个请求
let rate_limiter = RateLimiter::new(5);
Self {
provider,
http_client,
rate_limiter,
claude_config: None,
openai_config: None,
}
}
/// 设置 Claude 配置
pub fn with_claude_config(mut self, config: ClaudeConfig) -> Self {
self.claude_config = Some(config);
self
}
/// 设置 OpenAI 配置
pub fn with_openai_config(mut self, config: OpenAiConfig) -> Self {
self.openai_config = Some(config);
self
}
/// 发送聊天请求(非流式)
pub async fn chat(&self, messages: Vec<Message>) -> Result<AiResult> {
self.rate_limiter.acquire().await;
match self.provider {
AiProvider::Claude => self.claude_chat(messages).await,
AiProvider::OpenAi => self.openai_chat(messages).await,
}
}
/// 发送流式聊天请求
pub async fn chat_stream(
&self,
messages: Vec<Message>,
mut callback: impl FnMut(StreamChunk),
) -> Result<AiResult> {
self.rate_limiter.acquire().await;
match self.provider {
AiProvider::Claude => self.claude_chat_stream(messages, callback).await,
AiProvider::OpenAi => self.openai_chat_stream(messages, callback).await,
}
}
// ========== Claude API ==========
/// Claude 聊天请求
async fn claude_chat(&self, messages: Vec<Message>) -> Result<AiResult> {
let config = self.claude_config.as_ref()
.ok_or_else(|| AiError::ConfigError("Claude 配置未设置".to_string()))?;
#[derive(Serialize)]
struct ClaudeRequest {
model: String,
messages: Vec<Message>,
max_tokens: u32,
temperature: f32,
}
#[derive(Deserialize)]
struct ClaudeResponse {
id: String,
content: Vec<ClaudeContent>,
model: String,
usage: ClaudeUsage,
}
#[derive(Deserialize)]
struct ClaudeContent {
#[serde(rename = "type")]
content_type: String,
text: String,
}
#[derive(Deserialize)]
struct ClaudeUsage {
input_tokens: u32,
output_tokens: u32,
}
let request = ClaudeRequest {
model: config.model.clone(),
messages,
max_tokens: config.max_tokens,
temperature: config.temperature,
};
let response = self.http_client
.post("https://api.anthropic.com/v1/messages")
.header("x-api-key", &config.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&request)
.send()
.await?;
self.handle_response(response).await
}
/// Claude 流式聊天请求
async fn claude_chat_stream(
&self,
messages: Vec<Message>,
mut callback: impl FnMut(StreamChunk),
) -> Result<AiResult> {
let config = self.claude_config.as_ref()
.ok_or_else(|| AiError::ConfigError("Claude 配置未设置".to_string()))?;
#[derive(Serialize)]
struct ClaudeRequest {
model: String,
messages: Vec<Message>,
max_tokens: u32,
temperature: f32,
stream: bool,
}
let request = ClaudeRequest {
model: config.model.clone(),
messages,
max_tokens: config.max_tokens,
temperature: config.temperature,
stream: true,
};
let response = self.http_client
.post("https://api.anthropic.com/v1/messages")
.header("x-api-key", &config.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error_text = response.text().await?;
callback(StreamChunk::Error(error_text.clone()));
return Err(AiError::ApiError(error_text));
}
let mut full_content = String::new();
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result?;
let text = String::from_utf8_lossy(&chunk);
// 处理 SSE 格式
for line in text.lines() {
if line.starts_with("data: ") {
let data = &line[6..];
if data == "[DONE]" {
callback(StreamChunk::Done);
break;
}
if let Ok(event) = serde_json::from_str::<serde_json::Value>(data) {
if let Some(delta) = event.get("delta")
.and_then(|d| d.get("text"))
.and_then(|t| t.as_str())
{
full_content.push_str(delta);
callback(StreamChunk::Text(delta.to_string()));
}
}
}
}
}
Ok(AiResult {
content: full_content,
tokens_used: None,
model: config.model.clone(),
confidence: None,
})
}
// ========== OpenAI API ==========
/// OpenAI 聊天请求
async fn openai_chat(&self, messages: Vec<Message>) -> Result<AiResult> {
let config = self.openai_config.as_ref()
.ok_or_else(|| AiError::ConfigError("OpenAI 配置未设置".to_string()))?;
let base_url = config.base_url.as_ref()
.map(|s| s.as_str())
.unwrap_or("https://api.openai.com/v1");
#[derive(Serialize)]
struct OpenAiRequest {
model: String,
messages: Vec<Message>,
max_tokens: u32,
temperature: f32,
}
#[derive(Deserialize)]
struct OpenAiResponse {
id: String,
choices: Vec<OpenAiChoice>,
usage: OpenAiUsage,
model: String,
}
#[derive(Deserialize)]
struct OpenAiChoice {
message: OpenAiMessage,
finish_reason: String,
}
#[derive(Deserialize)]
struct OpenAiMessage {
role: String,
content: String,
}
#[derive(Deserialize)]
struct OpenAiUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
let request = OpenAiRequest {
model: config.model.clone(),
messages,
max_tokens: config.max_tokens,
temperature: config.temperature,
};
let response = self.http_client
.post(format!("{}/chat/completions", base_url))
.header(AUTHORIZATION, format!("Bearer {}", config.api_key))
.header(CONTENT_TYPE, "application/json")
.json(&request)
.send()
.await?;
self.handle_response(response).await
}
/// OpenAI 流式聊天请求
async fn openai_chat_stream(
&self,
messages: Vec<Message>,
mut callback: impl FnMut(StreamChunk),
) -> Result<AiResult> {
let config = self.openai_config.as_ref()
.ok_or_else(|| AiError::ConfigError("OpenAI 配置未设置".to_string()))?;
let base_url = config.base_url.as_ref()
.map(|s| s.as_str())
.unwrap_or("https://api.openai.com/v1");
#[derive(Serialize)]
struct OpenAiRequest {
model: String,
messages: Vec<Message>,
max_tokens: u32,
temperature: f32,
stream: bool,
}
let request = OpenAiRequest {
model: config.model.clone(),
messages,
max_tokens: config.max_tokens,
temperature: config.temperature,
stream: true,
};
let response = self.http_client
.post(format!("{}/chat/completions", base_url))
.header(AUTHORIZATION, format!("Bearer {}", config.api_key))
.header(CONTENT_TYPE, "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error_text = response.text().await?;
callback(StreamChunk::Error(error_text.clone()));
return Err(AiError::ApiError(error_text));
}
let mut full_content = String::new();
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result?;
let text = String::from_utf8_lossy(&chunk);
for line in text.lines() {
if line.starts_with("data: ") {
let data = &line[6..];
if data == "[DONE]" {
callback(StreamChunk::Done);
break;
}
if let Ok(event) = serde_json::from_str::<serde_json::Value>(data) {
if let Some(delta) = event.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("delta"))
.and_then(|d| d.get("content"))
.and_then(|t| t.as_str())
{
full_content.push_str(delta);
callback(StreamChunk::Text(delta.to_string()));
}
}
}
}
}
Ok(AiResult {
content: full_content,
tokens_used: None,
model: config.model.clone(),
confidence: None,
})
}
/// 处理 HTTP 响应
async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
where
T: for<'de> serde::Deserialize<'de>,
{
let status = response.status();
if status.is_success() {
response.json::<T>().await.map_err(Into::into)
} else {
let error_text = response.text().await?;
Err(if status.as_u16() == 401 {
AiError::AuthError(error_text)
} else if status.as_u16() == 429 {
AiError::RateLimitError
} else {
AiError::ApiError(error_text)
}.into())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_claude_config_serialization() {
let config = ClaudeConfig {
api_key: "test-key".to_string(),
model: "claude-3-5-sonnet-20241022".to_string(),
max_tokens: 4096,
temperature: 0.7,
};
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains("api_key"));
}
#[test]
fn test_message_serialization() {
let message = Message {
role: MessageRole::User,
content: "Hello, AI!".to_string(),
};
let json = serde_json::to_string(&message).unwrap();
assert!(json.contains("user"));
}
}

76
src-tauri/src/ai/mod.rs Normal file
View File

@@ -0,0 +1,76 @@
pub mod client;
pub mod prompt;
pub mod classify;
pub mod template;
pub use client::{AiClient, AiProvider, StreamChunk};
pub use prompt::PromptEngine;
pub use classify::Classifier;
pub use template::{Template, TemplateManager, TemplateVariable};
use anyhow::Result;
/// AI 服务错误类型
#[derive(Debug, thiserror::Error)]
pub enum AiError {
#[error("API 错误: {0}")]
ApiError(String),
#[error("网络错误: {0}")]
NetworkError(String),
#[error("认证失败: {0}")]
AuthError(String),
#[error("限流: 请稍后再试")]
RateLimitError,
#[error("配置错误: {0}")]
ConfigError(String),
#[error("模板错误: {0}")]
TemplateError(String),
#[error("其他错误: {0}")]
Other(String),
}
impl From<reqwest::Error> for AiError {
fn from(err: reqwest::Error) -> Self {
if err.is_timeout() {
AiError::NetworkError("请求超时".to_string())
} else if err.is_connect() {
AiError::NetworkError("连接失败".to_string())
} else {
AiError::NetworkError(err.to_string())
}
}
}
/// AI 服务结果
#[derive(Debug, Clone)]
pub struct AiResult {
/// 生成的文本内容
pub content: String,
/// 使用的 Token 数量(估算)
pub tokens_used: Option<usize>,
/// 模型名称
pub model: String,
/// 置信度评分0.0 - 1.0
pub confidence: Option<f64>,
}
/// 分类结果
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ClassificationResult {
/// 主分类
pub category: String,
/// 子分类
pub subcategory: Option<String>,
/// 标签列表
pub tags: Vec<String>,
/// 置信度评分0.0 - 1.0
pub confidence: f64,
/// AI 提供的推理说明
pub reasoning: Option<String>,
}

439
src-tauri/src/ai/prompt.rs Normal file
View File

@@ -0,0 +1,439 @@
use super::template::{Template, TemplateVariable};
use anyhow::{Context, Result};
use std::collections::HashMap;
/// Prompt 模板引擎
pub struct PromptEngine {
/// 内置模板库
builtin_templates: HashMap<String, Template>,
}
impl PromptEngine {
/// 创建新的 Prompt 引擎
pub fn new() -> Self {
let mut engine = Self {
builtin_templates: HashMap::new(),
};
// 初始化内置模板
engine.init_builtin_templates();
engine
}
/// 初始化内置模板
fn init_builtin_templates(&mut self) {
// 通用分类模板
let general_template = Template {
id: "builtin_general".to_string(),
name: "通用分类".to_string(),
description: "适用于大多数场景的内容分类".to_string(),
category: "general".to_string(),
system_prompt: r#"你是一个专业的内容分类助手。你的任务是根据提供的内容,将其归类到合适的类别中。
请遵循以下规则:
1. 分析内容的主要主题和目的
2. 选择最合适的主分类
3. 如果可能,提供更具体的子分类
4. 提取 3-5 个相关标签
5. 给出你的置信度评分0.0-1.0
请以 JSON 格式返回结果:
```json
{
"category": "",
"subcategory": "",
"tags": ["1", "2", "3"],
"confidence": 0.95,
"reasoning": ""
}
```"#
.to_string(),
user_prompt_template: r#"请对以下内容进行分类:
内容类型:{{content_type}}
{{#if image_path}}
包含图片:是
{{/if}}
{{#if ocr_text}}
OCR 识别文本:
{{ocr_text}}
{{/if}}
{{#if content}}
文本内容:
{{content}}
{{/if}}
请分析并返回分类结果。"#.to_string(),
variables: vec![
TemplateVariable {
name: "content_type".to_string(),
description: "内容类型text/image/file".to_string(),
required: true,
default_value: Some("text".to_string()),
},
TemplateVariable {
name: "content".to_string(),
description: "文本内容".to_string(),
required: false,
default_value: None,
},
TemplateVariable {
name: "ocr_text".to_string(),
description: "OCR 识别的文本".to_string(),
required: false,
default_value: None,
},
TemplateVariable {
name: "image_path".to_string(),
description: "图片路径".to_string(),
required: false,
default_value: None,
},
],
created_at: chrono::Utc::now().to_rfc3339(),
updated_at: chrono::Utc::now().to_rfc3339(),
};
self.builtin_templates.insert("general".to_string(), general_template);
// 代码分类模板
let code_template = Template {
id: "builtin_code".to_string(),
name: "代码片段分类".to_string(),
description: "专门用于分类和分析代码片段".to_string(),
category: "code".to_string(),
system_prompt: r#"你是一个专业的代码分析助手。你的任务是分析提供的代码片段,识别其编程语言、用途和特点。
请关注:
1. 编程语言和框架
2. 代码的主要功能
3. 是否是完整代码还是片段
4. 代码所属的领域Web开发、数据处理、算法等
请以 JSON 格式返回:
```json
{
"category": "",
"subcategory": "",
"tags": ["", "", ""],
"confidence": 0.95,
"reasoning": ""
}
```"#
.to_string(),
user_prompt_template: r#"请分析以下代码:
{{content}}
{{#if ocr_text}}
如果上述代码是从图片中识别的,请使用以下 OCR 结果:
{{ocr_text}}
{{/if}}
请返回编程语言、框架和功能分析。"#.to_string(),
variables: vec![
TemplateVariable {
name: "content".to_string(),
description: "代码内容".to_string(),
required: true,
default_value: None,
},
TemplateVariable {
name: "ocr_text".to_string(),
description: "OCR 识别的代码文本".to_string(),
required: false,
default_value: None,
},
],
created_at: chrono::Utc::now().to_rfc3339(),
updated_at: chrono::Utc::now().to_rfc3339(),
};
self.builtin_templates.insert("code".to_string(), code_template);
// 票据/发票分类模板
let invoice_template = Template {
id: "builtin_invoice".to_string(),
name: "票据发票分类".to_string(),
description: "用于分类和提取票据、发票信息".to_string(),
category: "invoice".to_string(),
system_prompt: r#"你是一个专业的票据识别助手。你的任务是识别和分析票据、发票类型,并提取关键信息。
请关注:
1. 票据类型(发票、收据、订单等)
2. 金额信息
3. 商户信息
4. 日期信息
5. 票据的用途和性质
请以 JSON 格式返回:
```json
{
"category": "",
"subcategory": "",
"tags": ["", "", ""],
"confidence": 0.95,
"reasoning": ""
}
```"#
.to_string(),
user_prompt_template: r#"请识别以下票据:
{{#if ocr_text}}
票据内容:
{{ocr_text}}
{{/if}}
{{#if image_path}}
这是一个包含票据的图片。
{{/if}}
请返回票据类型和关键信息。"#.to_string(),
variables: vec![
TemplateVariable {
name: "ocr_text".to_string(),
description: "OCR 识别的票据文本".to_string(),
required: true,
default_value: None,
},
TemplateVariable {
name: "image_path".to_string(),
description: "票据图片路径".to_string(),
required: false,
default_value: None,
},
],
created_at: chrono::Utc::now().to_rfc3339(),
updated_at: chrono::Utc::now().to_rfc3339(),
};
self.builtin_templates.insert("invoice".to_string(), invoice_template);
// 对话分类模板
let conversation_template = Template {
id: "builtin_conversation".to_string(),
name: "对话内容分类".to_string(),
description: "用于分类聊天记录、对话内容".to_string(),
category: "conversation".to_string(),
system_prompt: r#"你是一个专业的对话分析助手。你的任务是分析对话内容,识别对话的性质和主题。
请关注:
1. 对话的性质(工作、休闲、客服、技术支持等)
2. 主要话题
3. 参与者角色
4. 情感倾向(可选)
请以 JSON 格式返回:
```json
{
"category": "",
"subcategory": "",
"tags": ["1", "2", ""],
"confidence": 0.95,
"reasoning": ""
}
```"#
.to_string(),
user_prompt_template: r#"请分析以下对话:
{{content}}
{{#if ocr_text}}
如果上述对话是从图片中识别的,请使用以下 OCR 结果:
{{ocr_text}}
{{/if}}
请返回对话类型和主题分析。"#.to_string(),
variables: vec![
TemplateVariable {
name: "content".to_string(),
description: "对话文本内容".to_string(),
required: true,
default_value: None,
},
TemplateVariable {
name: "ocr_text".to_string(),
description: "OCR 识别的对话文本".to_string(),
required: false,
default_value: None,
},
],
created_at: chrono::Utc::now().to_rfc3339(),
updated_at: chrono::Utc::now().to_rfc3339(),
};
self.builtin_templates.insert("conversation".to_string(), conversation_template);
}
/// 渲染模板
pub fn render_template(
&self,
template: &Template,
variables: &HashMap<String, String>,
) -> Result<(String, String)> {
// 渲染 system prompt
let system_prompt = template.system_prompt.clone();
// 渲染 user prompt
let user_prompt = self.render_user_prompt(&template.user_prompt_template, variables)?;
Ok((system_prompt, user_prompt))
}
/// 渲染用户提示词模板
fn render_user_prompt(
&self,
template: &str,
variables: &HashMap<String, String>,
) -> Result<String> {
let mut result = template.to_string();
// 处理条件块 {{#if var}}...{{/if}}
let mut replaced = String::new();
let mut chars = template.chars().peekable();
let mut pos = 0;
while pos < template.len() {
// 查找 {{#if
if let Some(start) = template[pos..].find("{{#if ") {
let block_start = pos + start;
let var_start = block_start + 6; // "{{#if ".len()
// 查找闭合的 }}
if let Some(var_end) = template[var_start..].find("}}") {
let var_name = &template[var_start..var_start + var_end];
let content_start = var_start + var_end + 2;
// 查找 {{/if}}
if let Some(end_marker) = template[content_start..].find("{{/if}}") {
let content_end = content_start + end_marker;
let block_content = &template[content_start..content_end];
// 如果变量存在且有值,则包含内容
if let Some(value) = variables.get(var_name) {
if !value.is_empty() {
replaced.push_str(&template[pos..block_start]);
replaced.push_str(block_content);
}
}
pos = content_end + 7; // "{{/if}}".len()
continue;
}
}
}
// 替换简单变量 {{var}}
if let Some(start) = template[pos..].find('{{') {
replaced.push_str(&template[pos..pos + start]);
let var_start = pos + start + 2;
if let Some(end) = template[var_start..].find("}}") {
let var_name = &template[var_start..var_start + end];
let trimmed = var_name.trim();
if let Some(value) = variables.get(trimmed) {
replaced.push_str(value);
} else if let Some(template_var) = self
.builtin_templates
.values()
.flat_map(|t| t.variables.iter())
.find(|v| v.name == trimmed)
{
// 使用默认值
if let Some(default) = &template_var.default_value {
replaced.push_str(default);
}
}
pos = var_start + end + 2;
continue;
}
}
if pos < template.len() {
replaced.push(template.chars().nth(pos).unwrap());
pos += 1;
} else {
break;
}
}
// 第二次遍历:替换剩余的简单变量
result = replaced.clone();
for (key, value) in variables.iter() {
let placeholder = format!("{{{{{}}}}}", key);
result = result.replace(&placeholder, value);
}
Ok(result)
}
/// 获取内置模板
pub fn get_template(&self, id: &str) -> Option<&Template> {
self.builtin_templates.get(id)
}
/// 列出所有内置模板
pub fn list_templates(&self) -> Vec<&Template> {
self.builtin_templates.values().collect()
}
/// 获取默认模板
pub fn get_default_template(&self) -> &Template {
self.builtin_templates.get("general").unwrap()
}
}
impl Default for PromptEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prompt_engine_creation() {
let engine = PromptEngine::new();
assert!(engine.get_template("general").is_some());
assert!(engine.get_template("code").is_some());
}
#[test]
fn test_template_rendering() {
let engine = PromptEngine::new();
let template = engine.get_default_template();
let mut vars = HashMap::new();
vars.insert("content_type".to_string(), "text".to_string());
vars.insert("content".to_string(), "Hello, world!".to_string());
let (system, user) = engine.render_template(template, &vars).unwrap();
assert!(!system.is_empty());
assert!(!user.is_empty());
assert!(user.contains("Hello, world!"));
}
#[test]
fn test_conditional_blocks() {
let engine = PromptEngine::new();
let template = engine.get_default_template();
let mut vars = HashMap::new();
vars.insert("content_type".to_string(), "text".to_string());
// 没有 ocr_text条件块应该被移除
let (system, user) = engine.render_template(template, &vars).unwrap();
assert!(!user.contains("OCR 识别文本"));
// 有 ocr_text条件块应该保留
vars.insert("ocr_text".to_string(), "Sample OCR text".to_string());
let (system, user) = engine.render_template(template, &vars).unwrap();
assert!(user.contains("OCR 识别文本"));
assert!(user.contains("Sample OCR text"));
}
}

View File

@@ -0,0 +1,343 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
/// 模板变量
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateVariable {
/// 变量名
pub name: String,
/// 变量描述
pub description: String,
/// 是否必需
pub required: bool,
/// 默认值
pub default_value: Option<String>,
}
/// Prompt 模板
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Template {
/// 模板 ID
pub id: String,
/// 模板名称
pub name: String,
/// 模板描述
pub description: String,
/// 模板分类
pub category: String,
/// 系统 Prompt
pub system_prompt: String,
/// 用户 Prompt 模板
pub user_prompt_template: String,
/// 模板变量列表
pub variables: Vec<TemplateVariable>,
/// 创建时间
pub created_at: String,
/// 更新时间
pub updated_at: String,
}
/// 模板管理器
pub struct TemplateManager {
/// 自定义模板存储路径
templates_dir: PathBuf,
/// 模板缓存
templates: Mutex<HashMap<String, Template>>,
}
impl TemplateManager {
/// 创建新的模板管理器
pub fn new(templates_dir: PathBuf) -> Result<Self> {
// 确保模板目录存在
std::fs::create_dir_all(&templates_dir)
.context("Failed to create templates directory")?;
let manager = Self {
templates_dir,
templates: Mutex::new(HashMap::new()),
};
// 加载所有模板
manager.load_templates()?;
Ok(manager)
}
/// 加载所有模板
fn load_templates(&self) -> Result<()> {
let mut templates = self.templates.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
templates.clear();
// 读取目录中的所有 JSON 文件
let entries = std::fs::read_dir(&self.templates_dir)
.context("Failed to read templates directory")?;
for entry in entries {
let entry = entry.context("Failed to read directory entry")?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
// 跳过备份文件
if path.file_name()
.and_then(|s| s.to_str())
.map(|s| s.starts_with('.'))
.unwrap_or(false)
{
continue;
}
// 读取模板文件
let content = std::fs::read_to_string(&path)
.context("Failed to read template file")?;
// 解析模板
let template: Template = serde_json::from_str(&content)
.context("Failed to parse template")?;
templates.insert(template.id.clone(), template);
}
}
log::info!("Loaded {} templates", templates.len());
Ok(())
}
/// 保存模板
fn save_template(&self, template: &Template) -> Result<()> {
let file_path = self.templates_dir.join(format!("{}.json", template.id));
let content = serde_json::to_string_pretty(template)
.context("Failed to serialize template")?;
std::fs::write(&file_path, content)
.context("Failed to write template file")?;
Ok(())
}
/// 添加或更新模板
pub fn upsert_template(&self, template: Template) -> Result<()> {
// 验证模板
self.validate_template(&template)?;
// 保存到文件
self.save_template(&template)?;
// 更新缓存
let mut templates = self.templates.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
templates.insert(template.id.clone(), template);
Ok(())
}
/// 获取模板
pub fn get_template(&self, id: &str) -> Option<Template> {
let templates = self.templates.lock()
.ok()?;
templates.get(id).cloned()
}
/// 列出所有模板
pub fn list_templates(&self) -> Result<Vec<Template>> {
let templates = self.templates.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let mut list: Vec<_> = templates.values().cloned().collect();
list.sort_by(|a, b| a.name.cmp(&b.name));
Ok(list)
}
/// 按分类列出模板
pub fn list_templates_by_category(&self, category: &str) -> Result<Vec<Template>> {
let templates = self.templates.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let mut list: Vec<_> = templates
.values()
.filter(|t| t.category == category)
.cloned()
.collect();
list.sort_by(|a, b| a.name.cmp(&b.name));
Ok(list)
}
/// 删除模板
pub fn delete_template(&self, id: &str) -> Result<bool> {
// 删除文件
let file_path = self.templates_dir.join(format!("{}.json", id));
let mut deleted = false;
if file_path.exists() {
std::fs::remove_file(&file_path)
.context("Failed to delete template file")?;
deleted = true;
}
// 从缓存中移除
let mut templates = self.templates.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
Ok(templates.remove(id).is_some() || deleted)
}
/// 导入模板
pub fn import_template(&self, json_data: &str) -> Result<Template> {
let template: Template = serde_json::from_str(json_data)
.context("Failed to parse template JSON")?;
self.upsert_template(template.clone())?;
Ok(template)
}
/// 导出模板
pub fn export_template(&self, id: &str) -> Result<String> {
let template = self.get_template(id)
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", id))?;
serde_json::to_string_pretty(&template)
.context("Failed to serialize template")
}
/// 验证模板
fn validate_template(&self, template: &Template) -> Result<()> {
if template.id.is_empty() {
return Err(anyhow::anyhow!("Template ID cannot be empty"));
}
if template.name.is_empty() {
return Err(anyhow::anyhow!("Template name cannot be empty"));
}
if template.system_prompt.is_empty() {
return Err(anyhow::anyhow!("System prompt cannot be empty"));
}
if template.user_prompt_template.is_empty() {
return Err(anyhow::anyhow!("User prompt template cannot be empty"));
}
// 验证变量名称格式
for var in &template.variables {
if var.name.is_empty() {
return Err(anyhow::anyhow!("Variable name cannot be empty"));
}
// 变量名只能包含字母、数字和下划线
if !var.name.chars().all(|c| c.is_alphanumeric() || c == '_') {
return Err(anyhow::anyhow!(
"Invalid variable name '{}': only alphanumeric and underscore allowed",
var.name
));
}
}
Ok(())
}
/// 测试模板渲染
pub fn test_template(
&self,
template_id: &str,
variables: &HashMap<String, String>,
) -> Result<(String, String)> {
let template = self.get_template(template_id)
.ok_or_else(|| anyhow::anyhow!("Template not found: {}", template_id))?;
// 渲染系统 Prompt
let system_prompt = template.system_prompt.clone();
// 渲染用户 Prompt
let user_prompt = self.render_user_prompt(&template.user_prompt_template, variables)?;
Ok((system_prompt, user_prompt))
}
/// 渲染用户 Prompt
fn render_user_prompt(
&self,
template: &str,
variables: &HashMap<String, String>,
) -> Result<String> {
let mut result = template.to_string();
// 简单变量替换
for (key, value) in variables.iter() {
let placeholder = format!("{{{{{}}}}}", key);
result = result.replace(&placeholder, value);
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_template_serialization() {
let template = Template {
id: "test".to_string(),
name: "Test Template".to_string(),
description: "A test template".to_string(),
category: "test".to_string(),
system_prompt: "You are a test assistant.".to_string(),
user_prompt_template: "Hello {{name}}!".to_string(),
variables: vec![],
created_at: Utc::now().to_rfc3339(),
updated_at: Utc::now().to_rfc3339(),
};
let json = serde_json::to_string(&template).unwrap();
assert!(json.contains("test"));
let parsed: Template = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, "test");
}
#[test]
fn test_template_validation() {
let manager = TemplateManager::new(
std::env::temp_dir().join("test_templates")
).unwrap();
let valid_template = Template {
id: "valid".to_string(),
name: "Valid Template".to_string(),
description: "A valid template".to_string(),
category: "test".to_string(),
system_prompt: "System prompt".to_string(),
user_prompt_template: "User prompt".to_string(),
variables: vec![],
created_at: Utc::now().to_rfc3339(),
updated_at: Utc::now().to_rfc3339(),
};
assert!(manager.validate_template(&valid_template).is_ok());
let invalid_template = Template {
id: "".to_string(),
name: "Invalid Template".to_string(),
description: "An invalid template".to_string(),
category: "test".to_string(),
system_prompt: "System prompt".to_string(),
user_prompt_template: "User prompt".to_string(),
variables: vec![],
created_at: Utc::now().to_rfc3339(),
updated_at: Utc::now().to_rfc3339(),
};
assert!(manager.validate_template(&invalid_template).is_err());
}
}

205
src-tauri/src/config.rs Normal file
View File

@@ -0,0 +1,205 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
/// 图床配置类型
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageHostConfig {
GitHub {
token: String,
owner: String,
repo: String,
path: String,
branch: Option<String>,
},
Imgur {
client_id: String,
},
Custom {
url: String,
headers: Option<Vec<HeaderItem>>,
form_field: Option<String>,
},
}
/// HTTP 头部项
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeaderItem {
pub name: String,
pub value: String,
}
/// 应用配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
/// 默认图床配置
pub default_image_host: Option<ImageHostConfig>,
/// 可用的图床配置列表
pub image_hosts: Vec<ImageHostConfig>,
/// 上传重试次数
pub upload_retry_count: u32,
/// 上传超时时间(秒)
pub upload_timeout_seconds: u64,
/// 是否自动复制上传后的链接
pub auto_copy_link: bool,
/// 保留的截图数量
pub keep_screenshots_count: usize,
/// 数据库路径
pub database_path: Option<PathBuf>,
/// OCR 配置
pub ocr_config: Option<OcrAppConfig>,
}
/// OCR 应用配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrAppConfig {
/// 默认 OCR 引擎
pub default_engine: String,
/// 是否自动复制 OCR 结果
pub auto_copy_result: bool,
/// OCR 结果保留天数
pub keep_results_days: u32,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
default_image_host: None,
image_hosts: Vec::new(),
upload_retry_count: 3,
upload_timeout_seconds: 30,
auto_copy_link: true,
keep_screenshots_count: 50,
database_path: None,
ocr_config: Some(OcrAppConfig {
default_engine: "baidu".to_string(),
auto_copy_result: false,
keep_results_days: 30,
}),
}
}
}
/// 配置管理器
pub struct ConfigManager {
config_dir: PathBuf,
config_file: PathBuf,
}
impl ConfigManager {
/// 创建新的配置管理器
pub fn new() -> Result<Self> {
let config_dir = Self::get_config_dir()?;
let config_file = config_dir.join("config.json");
// 确保配置目录存在
fs::create_dir_all(&config_dir)
.context("Failed to create config directory")?;
Ok(Self {
config_dir,
config_file,
})
}
/// 获取配置目录
fn get_config_dir() -> Result<PathBuf> {
let config_dir = if cfg!(target_os = "macos") {
// macOS: ~/Library/Application Support/CutThenThink
dirs::config_dir()
.map(|p| p.join("CutThenThink"))
} else if cfg!(target_os = "windows") {
// Windows: %APPDATA%/CutThenThink
dirs::config_dir()
.map(|p| p.join("CutThenThink"))
} else {
// Linux: ~/.config/CutThenThink
dirs::home_dir()
.map(|p| p.join(".config").join("CutThenThink"))
};
config_dir.context("Failed to determine config directory")
}
/// 加载配置
pub fn load(&self) -> Result<AppConfig> {
if !self.config_file.exists() {
// 如果配置文件不存在,创建默认配置
let default_config = AppConfig::default();
self.save(&default_config)?;
return Ok(default_config);
}
let content = fs::read_to_string(&self.config_file)
.context("Failed to read config file")?;
let config: AppConfig = serde_json::from_str(&content)
.context("Failed to parse config file")?;
Ok(config)
}
/// 保存配置
pub fn save(&self, config: &AppConfig) -> Result<()> {
let content = serde_json::to_string_pretty(config)
.context("Failed to serialize config")?;
fs::write(&self.config_file, content)
.context("Failed to write config file")?;
Ok(())
}
/// 获取配置文件路径
pub fn config_file_path(&self) -> &Path {
&self.config_file
}
/// 获取数据目录
pub fn data_dir(&self) -> PathBuf {
self.config_dir.join("data")
}
/// 获取数据库路径
pub fn database_path(&self) -> PathBuf {
self.data_dir().join("cutthink.db")
}
/// 确保数据目录存在
pub fn ensure_data_dir(&self) -> Result<()> {
let data_dir = self.data_dir();
fs::create_dir_all(&data_dir)
.context("Failed to create data directory")?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_serialization() {
let config = AppConfig::default();
let json = serde_json::to_string(&config).unwrap();
println!("{}", json);
assert!(json.contains("upload_retry_count"));
}
#[test]
fn test_image_host_config() {
let github_config = ImageHostConfig::GitHub {
token: "test_token".to_string(),
owner: "test_owner".to_string(),
repo: "test_repo".to_string(),
path: "screenshots".to_string(),
branch: Some("main".to_string()),
};
let json = serde_json::to_string(&github_config).unwrap();
println!("{}", json);
assert!(json.contains("github"));
}
}

749
src-tauri/src/database.rs Normal file
View File

@@ -0,0 +1,749 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use uuid::Uuid;
/// 数据库记录
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Record {
pub id: String,
pub record_type: RecordType,
pub content: String,
pub file_path: Option<String>,
pub thumbnail: Option<String>,
pub metadata: Option<String>,
pub created_at: String,
pub updated_at: String,
}
/// 记录类型
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RecordType {
Image,
Text,
File,
}
/// 设置项
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Setting {
pub key: String,
pub value: String,
pub updated_at: String,
}
/// 分类结果
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Classification {
pub id: String,
pub record_id: String,
pub category: String,
pub subcategory: Option<String>,
pub tags: String, // JSON 数组
pub confidence: f64,
pub reasoning: Option<String>,
pub template_id: Option<String>,
pub confirmed: bool,
pub created_at: String,
}
/// 分类历史
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationHistory {
pub id: String,
pub record_id: String,
pub category: String,
pub subcategory: Option<String>,
pub confidence: f64,
pub created_at: String,
}
/// 数据库管理器
pub struct Database {
conn: Mutex<Connection>,
}
impl Database {
/// 打开数据库连接
pub fn open(db_path: &Path) -> Result<Self> {
// 确保父目录存在
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)
.context("Failed to create database directory")?;
}
let conn = Connection::open(db_path)
.context("Failed to open database")?;
// 启用外键约束
conn.execute("PRAGMA foreign_keys = ON", [])
.context("Failed to enable foreign keys")?;
let db = Self {
conn: Mutex::new(conn),
};
// 初始化数据库表
db.init_tables()?;
Ok(db)
}
/// 初始化数据库表
fn init_tables(&self) -> Result<()> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
// 创建 records 表
conn.execute(
"CREATE TABLE IF NOT EXISTS records (
id TEXT PRIMARY KEY,
record_type TEXT NOT NULL,
content TEXT NOT NULL,
file_path TEXT,
thumbnail TEXT,
metadata TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)",
[],
).context("Failed to create records table")?;
// 创建 settings 表
conn.execute(
"CREATE TABLE IF NOT EXISTS settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL
)",
[],
).context("Failed to create settings table")?;
// 创建 classifications 表
conn.execute(
"CREATE TABLE IF NOT EXISTS classifications (
id TEXT PRIMARY KEY,
record_id TEXT NOT NULL,
category TEXT NOT NULL,
subcategory TEXT,
tags TEXT NOT NULL,
confidence REAL NOT NULL,
reasoning TEXT,
template_id TEXT,
confirmed BOOLEAN NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
FOREIGN KEY (record_id) REFERENCES records(id) ON DELETE CASCADE
)",
[],
).context("Failed to create classifications table")?;
// 创建 classification_history 表
conn.execute(
"CREATE TABLE IF NOT EXISTS classification_history (
id TEXT PRIMARY KEY,
record_id TEXT NOT NULL,
category TEXT NOT NULL,
subcategory TEXT,
confidence REAL NOT NULL,
created_at TEXT NOT NULL,
FOREIGN KEY (record_id) REFERENCES records(id) ON DELETE CASCADE
)",
[],
).context("Failed to create classification_history table")?;
// 创建索引
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_records_created_at
ON records(created_at DESC)",
[],
).context("Failed to create index on records.created_at")?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_records_type
ON records(record_type)",
[],
).context("Failed to create index on records.record_type")?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_classifications_record_id
ON classifications(record_id)",
[],
).context("Failed to create index on classifications.record_id")?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_classifications_category
ON classifications(category)",
[],
).context("Failed to create index on classifications.category")?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_classification_history_record_id
ON classification_history(record_id)",
[],
).context("Failed to create index on classification_history.record_id")?;
Ok(())
}
/// 插入记录
pub fn insert_record(
&self,
record_type: RecordType,
content: &str,
file_path: Option<&str>,
thumbnail: Option<&str>,
metadata: Option<&str>,
) -> Result<Record> {
let id = Uuid::new_v4().to_string();
let now = Utc::now().to_rfc3339();
let record_type_str = serde_json::to_string(&record_type)?;
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
conn.execute(
"INSERT INTO records (id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
params![
&id,
&record_type_str,
content,
file_path,
thumbnail,
metadata,
&now,
&now,
],
).context("Failed to insert record")?;
Ok(Record {
id,
record_type,
content: content.to_string(),
file_path: file_path.map(|s| s.to_string()),
thumbnail: thumbnail.map(|s| s.to_string()),
metadata: metadata.map(|s| s.to_string()),
created_at: now.clone(),
updated_at: now,
})
}
/// 根据 ID 获取记录
pub fn get_record(&self, id: &str) -> Result<Option<Record>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let mut stmt = conn.prepare(
"SELECT id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at
FROM records WHERE id = ?1"
).context("Failed to prepare statement")?;
let mut records = stmt.query_map(params![id], |row| {
let record_type_str: String = row.get(1)?;
let record_type: RecordType = serde_json::from_str(&record_type_str)
.unwrap_or(RecordType::Text);
Ok(Record {
id: row.get(0)?,
record_type,
content: row.get(2)?,
file_path: row.get(3)?,
thumbnail: row.get(4)?,
metadata: row.get(5)?,
created_at: row.get(6)?,
updated_at: row.get(7)?,
})
}).context("Failed to query records")?;
records.next()
.transpose()
.context("Failed to parse record")
}
/// 获取所有记录
pub fn list_records(&self, limit: Option<usize>, offset: Option<usize>) -> Result<Vec<Record>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let limit = limit.unwrap_or(100);
let offset = offset.unwrap_or(0);
let mut stmt = conn.prepare(
"SELECT id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at
FROM records
ORDER BY created_at DESC
LIMIT ?1 OFFSET ?2"
).context("Failed to prepare statement")?;
let records = stmt.query_map(params![limit, offset], |row| {
let record_type_str: String = row.get(1)?;
let record_type: RecordType = serde_json::from_str(&record_type_str)
.unwrap_or(RecordType::Text);
Ok(Record {
id: row.get(0)?,
record_type,
content: row.get(2)?,
file_path: row.get(3)?,
thumbnail: row.get(4)?,
metadata: row.get(5)?,
created_at: row.get(6)?,
updated_at: row.get(7)?,
})
}).context("Failed to query records")?
.collect::<std::result::Result<Vec<_>, _>>()
.context("Failed to parse records")?;
Ok(records)
}
/// 根据类型获取记录
pub fn list_records_by_type(
&self,
record_type: RecordType,
limit: Option<usize>,
) -> Result<Vec<Record>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let limit = limit.unwrap_or(100);
let record_type_str = serde_json::to_string(&record_type)?;
let mut stmt = conn.prepare(
"SELECT id, record_type, content, file_path, thumbnail, metadata, created_at, updated_at
FROM records
WHERE record_type = ?1
ORDER BY created_at DESC
LIMIT ?2"
).context("Failed to prepare statement")?;
let records = stmt.query_map(params![record_type_str, limit], |row| {
let record_type_str: String = row.get(1)?;
let record_type: RecordType = serde_json::from_str(&record_type_str)
.unwrap_or(RecordType::Text);
Ok(Record {
id: row.get(0)?,
record_type,
content: row.get(2)?,
file_path: row.get(3)?,
thumbnail: row.get(4)?,
metadata: row.get(5)?,
created_at: row.get(6)?,
updated_at: row.get(7)?,
})
}).context("Failed to query records")?
.collect::<std::result::Result<Vec<_>, _>>()
.context("Failed to parse records")?;
Ok(records)
}
/// 更新记录
pub fn update_record(
&self,
id: &str,
content: Option<&str>,
file_path: Option<&str>,
thumbnail: Option<&str>,
metadata: Option<&str>,
) -> Result<Option<Record>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
// 首先获取现有记录
let existing = self.get_record(id)?;
if existing.is_none() {
return Ok(None);
}
let now = Utc::now().to_rfc3339();
conn.execute(
"UPDATE records
SET content = COALESCE(?1, content),
file_path = COALESCE(?2, file_path),
thumbnail = COALESCE(?3, thumbnail),
metadata = COALESCE(?4, metadata),
updated_at = ?5
WHERE id = ?6",
params![content, file_path, thumbnail, metadata, &now, id],
).context("Failed to update record")?;
self.get_record(id).transpose()
}
/// 删除记录
pub fn delete_record(&self, id: &str) -> Result<bool> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let affected = conn.execute("DELETE FROM records WHERE id = ?1", params![id])
.context("Failed to delete record")?;
Ok(affected > 0)
}
/// 清空所有记录
pub fn clear_records(&self) -> Result<usize> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let affected = conn.execute("DELETE FROM records", [])
.context("Failed to clear records")?;
Ok(affected)
}
/// 获取记录总数
pub fn get_records_count(&self) -> Result<usize> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let count: i64 = conn.query_row("SELECT COUNT(*) FROM records", [], |row| row.get(0))
.context("Failed to count records")?;
Ok(count as usize)
}
/// 设置配置
pub fn set_setting(&self, key: &str, value: &str) -> Result<()> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let now = Utc::now().to_rfc3339();
conn.execute(
"INSERT OR REPLACE INTO settings (key, value, updated_at) VALUES (?1, ?2, ?3)",
params![key, value, &now],
).context("Failed to set setting")?;
Ok(())
}
/// 获取配置
pub fn get_setting(&self, key: &str) -> Result<Option<String>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let value = conn
.query_row("SELECT value FROM settings WHERE key = ?1", params![key], |row| {
row.get(0)
})
.optional()
.context("Failed to get setting")?;
Ok(value)
}
/// 删除配置
pub fn delete_setting(&self, key: &str) -> Result<bool> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let affected = conn.execute("DELETE FROM settings WHERE key = ?1", params![key])
.context("Failed to delete setting")?;
Ok(affected > 0)
}
/// 获取所有配置
pub fn list_settings(&self) -> Result<Vec<Setting>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let mut stmt = conn.prepare(
"SELECT key, value, updated_at FROM settings ORDER BY key"
).context("Failed to prepare statement")?;
let settings = stmt.query_map([], |row| {
Ok(Setting {
key: row.get(0)?,
value: row.get(1)?,
updated_at: row.get(2)?,
})
}).context("Failed to query settings")?
.collect::<std::result::Result<Vec<_>, _>>()
.context("Failed to parse settings")?;
Ok(settings)
}
// ============= 分类相关方法 =============
/// 保存分类结果
pub fn save_classification(
&self,
record_id: &str,
category: &str,
subcategory: Option<&str>,
tags: &[String],
confidence: f64,
reasoning: Option<&str>,
template_id: Option<&str>,
) -> Result<Classification> {
let id = Uuid::new_v4().to_string();
let now = Utc::now().to_rfc3339();
let tags_json = serde_json::to_string(tags)?;
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
conn.execute(
"INSERT INTO classifications (id, record_id, category, subcategory, tags, confidence, reasoning, template_id, confirmed, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
params![
&id,
record_id,
category,
subcategory,
&tags_json,
confidence,
reasoning,
template_id,
0, // confirmed = false
&now,
],
).context("Failed to insert classification")?;
// 同时添加到历史记录
self.add_classification_history(record_id, category, subcategory, confidence)?;
Ok(Classification {
id,
record_id: record_id.to_string(),
category: category.to_string(),
subcategory: subcategory.map(|s| s.to_string()),
tags: tags_json,
confidence,
reasoning: reasoning.map(|s| s.to_string()),
template_id: template_id.map(|s| s.to_string()),
confirmed: false,
created_at: now,
})
}
/// 获取记录的分类结果
pub fn get_classification(&self, record_id: &str) -> Result<Option<Classification>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let mut stmt = conn.prepare(
"SELECT id, record_id, category, subcategory, tags, confidence, reasoning, template_id, confirmed, created_at
FROM classifications WHERE record_id = ?1"
).context("Failed to prepare statement")?;
let mut classifications = stmt.query_map(params![record_id], |row| {
Ok(Classification {
id: row.get(0)?,
record_id: row.get(1)?,
category: row.get(2)?,
subcategory: row.get(3)?,
tags: row.get(4)?,
confidence: row.get(5)?,
reasoning: row.get(6)?,
template_id: row.get(7)?,
confirmed: row.get(8)?,
created_at: row.get(9)?,
})
}).context("Failed to query classifications")?
.collect::<std::result::Result<Vec<_>, _>>()
.context("Failed to parse classifications")?;
Ok(classifications.pop())
}
/// 确认分类结果
pub fn confirm_classification(&self, id: &str) -> Result<bool> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let affected = conn.execute(
"UPDATE classifications SET confirmed = 1 WHERE id = ?1",
params![id],
).context("Failed to confirm classification")?;
Ok(affected > 0)
}
/// 添加分类历史
fn add_classification_history(
&self,
record_id: &str,
category: &str,
subcategory: Option<&str>,
confidence: f64,
) -> Result<()> {
let id = Uuid::new_v4().to_string();
let now = Utc::now().to_rfc3339();
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
conn.execute(
"INSERT INTO classification_history (id, record_id, category, subcategory, confidence, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![&id, record_id, category, subcategory, confidence, &now],
).context("Failed to insert classification history")?;
Ok(())
}
/// 获取分类历史
pub fn get_classification_history(
&self,
record_id: &str,
limit: Option<usize>,
) -> Result<Vec<ClassificationHistory>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let limit = limit.unwrap_or(50);
let mut stmt = conn.prepare(
"SELECT id, record_id, category, subcategory, confidence, created_at
FROM classification_history
WHERE record_id = ?1
ORDER BY created_at DESC
LIMIT ?2"
).context("Failed to prepare statement")?;
let history = stmt.query_map(params![record_id, limit], |row| {
Ok(ClassificationHistory {
id: row.get(0)?,
record_id: row.get(1)?,
category: row.get(2)?,
subcategory: row.get(3)?,
confidence: row.get(4)?,
created_at: row.get(5)?,
})
}).context("Failed to query classification history")?
.collect::<std::result::Result<Vec<_>, _>>()
.context("Failed to parse classification history")?;
Ok(history)
}
/// 按分类获取记录
pub fn list_records_by_category(
&self,
category: &str,
limit: Option<usize>,
) -> Result<Vec<Record>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let limit = limit.unwrap_or(100);
let mut stmt = conn.prepare(
"SELECT r.id, r.record_type, r.content, r.file_path, r.thumbnail, r.metadata, r.created_at, r.updated_at
FROM records r
INNER JOIN classifications c ON r.id = c.record_id
WHERE c.category = ?1 AND c.confirmed = 1
ORDER BY r.created_at DESC
LIMIT ?2"
).context("Failed to prepare statement")?;
let records = stmt.query_map(params![category, limit], |row| {
let record_type_str: String = row.get(1)?;
let record_type: RecordType = serde_json::from_str(&record_type_str)
.unwrap_or(RecordType::Text);
Ok(Record {
id: row.get(0)?,
record_type,
content: row.get(2)?,
file_path: row.get(3)?,
thumbnail: row.get(4)?,
metadata: row.get(5)?,
created_at: row.get(6)?,
updated_at: row.get(7)?,
})
}).context("Failed to query records by category")?
.collect::<std::result::Result<Vec<_>, _>>()
.context("Failed to parse records")?;
Ok(records)
}
/// 删除分类
pub fn delete_classification(&self, id: &str) -> Result<bool> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let affected = conn.execute("DELETE FROM classifications WHERE id = ?1", params![id])
.context("Failed to delete classification")?;
Ok(affected > 0)
}
/// 获取所有分类统计
pub fn get_category_stats(&self) -> Result<Vec<(String, usize)>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire lock: {}", e))?;
let mut stmt = conn.prepare(
"SELECT category, COUNT(*) as count
FROM classifications
WHERE confirmed = 1
GROUP BY category
ORDER BY count DESC"
).context("Failed to prepare statement")?;
let stats = stmt.query_map([], |row| {
Ok((row.get(0)?, row.get(1)?))
}).context("Failed to query category stats")?
.collect::<std::result::Result<Vec<_>, _>>()
.context("Failed to parse category stats")?;
Ok(stats)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_record_type_serialization() {
let record_type = RecordType::Image;
let json = serde_json::to_string(&record_type).unwrap();
assert_eq!(json, r#"{"type":"image"}"#);
}
#[test]
fn test_database_in_memory() -> Result<()> {
let db = Database::open(":memory:")?;
// 插入测试记录
let record = db.insert_record(
RecordType::Text,
"Test content",
None,
None,
None,
)?;
// 查询记录
let found = db.get_record(&record.id)?;
assert!(found.is_some());
let found_record = found.unwrap();
assert_eq!(found_record.content, "Test content");
// 删除记录
let deleted = db.delete_record(&record.id)?;
assert!(deleted);
// 验证删除
let found = db.get_record(&record.id)?;
assert!(found.is_none());
Ok(())
}
}

195
src-tauri/src/hotkey.rs Normal file
View File

@@ -0,0 +1,195 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tauri::{AppHandle, Emitter, Listener, Manager, State, WebviewUrl, WebviewWindowBuilder};
use tauri_plugin_global_shortcut::{GlobalShortcutExt, Shortcut, ShortcutState};
/// 快捷键动作类型
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HotkeyAction {
CaptureFullscreen,
CaptureRegion,
CaptureWindow,
ShowHide,
}
impl HotkeyAction {
pub fn as_str(&self) -> &'static str {
match self {
HotkeyAction::CaptureFullscreen => "capture_fullscreen",
HotkeyAction::CaptureRegion => "capture_region",
HotkeyAction::CaptureWindow => "capture_window",
HotkeyAction::ShowHide => "show_hide",
}
}
pub fn from_str(s: &str) -> Option<Self> {
match s {
"capture_fullscreen" => Some(HotkeyAction::CaptureFullscreen),
"capture_region" => Some(HotkeyAction::CaptureRegion),
"capture_window" => Some(HotkeyAction::CaptureWindow),
"show_hide" => Some(HotkeyAction::ShowHide),
_ => None,
}
}
}
/// 快捷键配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HotkeyConfig {
pub id: String,
pub shortcut: String,
pub action: String,
pub enabled: bool,
}
/// 快捷键管理器
pub struct HotkeyManager {
app_handle: AppHandle,
registered_shortcuts: HashMap<String, Shortcut>,
}
impl HotkeyManager {
/// 创建新的快捷键管理器
pub fn new(app_handle: AppHandle) -> Self {
Self {
app_handle,
registered_shortcuts: HashMap::new(),
}
}
/// 注册快捷键
pub fn register_shortcut(&mut self, config: HotkeyConfig) -> Result<(), String> {
if !config.enabled {
return Ok(());
}
let shortcut = Shortcut::new(
Some(config.shortcut.clone()),
config.shortcut.clone(),
);
self.app_handle
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
.map_err(|e| format!("Failed to load global shortcut plugin: {}", e))?;
let app_handle = self.app_handle.clone();
let action = config.action.clone();
self.app_handle
.global_shortcut()
.on_shortcut(shortcut.clone(), move |_app, _shortcut, _event| {
// 发送事件到前端
if let Err(e) = app_handle.emit("hotkey-triggered", &action) {
eprintln!("Failed to emit hotkey event: {}", e);
}
})
.map_err(|e| format!("Failed to register shortcut handler: {}", e))?;
self.registered_shortcuts.insert(config.id.clone(), shortcut);
Ok(())
}
/// 注销快捷键
pub fn unregister_shortcut(&mut self, id: &str) -> Result<(), String> {
if let Some(shortcut) = self.registered_shortcuts.remove(id) {
self.app_handle
.global_shortcut()
.unregister(shortcut)
.map_err(|e| format!("Failed to unregister shortcut: {}", e))?;
}
Ok(())
}
/// 获取已注册的快捷键列表
pub fn get_registered_shortcuts(&self) -> Vec<String> {
self.registered_shortcuts.keys().cloned().collect()
}
/// 注册默认快捷键
pub fn register_defaults(&mut self) -> Result<(), String> {
let defaults = vec![
HotkeyConfig {
id: "fullscreen".to_string(),
shortcut: "Ctrl+Shift+A".to_string(),
action: "capture_fullscreen".to_string(),
enabled: true,
},
HotkeyConfig {
id: "region".to_string(),
shortcut: "Ctrl+Shift+R".to_string(),
action: "capture_region".to_string(),
enabled: true,
},
HotkeyConfig {
id: "window".to_string(),
shortcut: "Ctrl+Shift+W".to_string(),
action: "capture_window".to_string(),
enabled: true,
},
HotkeyConfig {
id: "toggle".to_string(),
shortcut: "Ctrl+Shift+S".to_string(),
action: "show_hide".to_string(),
enabled: true,
},
];
for config in defaults {
self.register_shortcut(config)?;
}
Ok(())
}
}
/// 初始化全局快捷键
pub fn init_global_hotkeys(app_handle: &AppHandle) {
let mut manager = HotkeyManager::new(app_handle.clone());
if let Err(e) = manager.register_defaults() {
eprintln!("Failed to register default hotkeys: {}", e);
}
// 监听快捷键事件
let app_handle = app_handle.clone();
app_handle.listen("hotkey-triggered", |event| {
if let Some(action_str) = event.payload() {
if let Ok(action) = serde_json::from_str::<String>(action_str) {
handle_hotkey_action(&app_handle, &action);
}
}
});
}
/// 处理快捷键动作
fn handle_hotkey_action(app_handle: &AppHandle, action: &str) {
match HotkeyAction::from_str(action) {
Some(HotkeyAction::CaptureFullscreen) => {
// 发送全屏截图事件
let _ = app_handle.emit("screenshot-fullscreen-triggered", ());
}
Some(HotkeyAction::CaptureRegion) => {
// 发送区域截图事件
let _ = app_handle.emit("screenshot-region-triggered", ());
}
Some(HotkeyAction::CaptureWindow) => {
// 发送窗口截图事件
let _ = app_handle.emit("screenshot-window-triggered", ());
}
Some(HotkeyAction::ShowHide) => {
// 显示/隐藏主窗口
if let Some(window) = app_handle.get_webview_window("main") {
if window.is_visible().unwrap_or(false) {
let _ = window.hide();
} else {
let _ = window.show();
let _ = window.set_focus();
}
}
}
None => {
eprintln!("Unknown hotkey action: {}", action);
}
}
}

991
src-tauri/src/lib.rs Normal file
View File

@@ -0,0 +1,991 @@
mod screenshot;
mod hotkey;
mod config;
mod upload;
mod database;
mod ocr;
mod plugin;
mod secure_storage;
mod ai;
use screenshot::ScreenshotManager;
use hotkey::init_global_hotkeys;
use config::{ConfigManager, ImageHostConfig};
use upload::{Uploader, UploadProgress, UploadResult};
use database::{Database, RecordType};
use ocr::{OcrEngineType, OcrConfig, BaiduOcrConfig, TencentOcrConfig, LocalOcrConfig};
use plugin::{PluginManager, InstallProgress};
use secure_storage::ApiKeyStorage;
use ai::{AiClient, AiProvider, Classifier, ClassificationResult, StreamChunk, TemplateManager};
use std::path::PathBuf;
use std::sync::Mutex;
use std::collections::HashMap;
// 全局应用状态
struct AppState {
screenshot_manager: Mutex<ScreenshotManager>,
config_manager: Mutex<ConfigManager>,
database: Mutex<Database>,
plugin_manager: Mutex<PluginManager>,
classifier: Mutex<Classifier>,
template_manager: Mutex<TemplateManager>,
}
#[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() {
// 初始化配置管理器
let config_manager = ConfigManager::new()
.expect("Failed to initialize config manager");
// 确保数据目录存在
config_manager.ensure_data_dir()
.expect("Failed to create data directory");
// 初始化数据库
let database = Database::open(&config_manager.database_path())
.expect("Failed to initialize database");
// 初始化截图管理器
let screenshot_manager = ScreenshotManager::with_default_dir()
.expect("Failed to initialize screenshot manager");
// 初始化插件管理器
let plugin_manager = PluginManager::new(config_manager.config_dir.join("data"));
// 初始化 AI 分类器
let classifier = Classifier::new(ai::ClassifierConfig::default());
// 初始化模板管理器
let templates_dir = config_manager.data_dir().join("templates");
let template_manager = TemplateManager::new(templates_dir)
.expect("Failed to initialize template manager");
// 获取配置目录路径
let config_dir = config_manager.config_dir.join("data");
tauri::Builder::default()
.setup(|app| {
if cfg!(debug_assertions) {
app.handle().plugin(
tauri_plugin_log::Builder::default()
.level(log::LevelFilter::Info)
.build(),
)?;
}
// 初始化全局快捷键
init_global_hotkeys(app.handle());
Ok(())
})
.manage(AppState {
screenshot_manager: Mutex::new(screenshot_manager),
config_manager: Mutex::new(config_manager),
database: Mutex::new(database),
plugin_manager: Mutex::new(plugin_manager),
classifier: Mutex::new(classifier),
template_manager: Mutex::new(template_manager),
})
.invoke_handler(tauri::generate_handler![
// 截图相关命令
screenshot_fullscreen,
screenshot_region,
screenshot_copy_to_clipboard,
screenshot_delete,
screenshot_list,
screenshot_cleanup,
// 配置相关命令
config_get,
config_set,
config_get_path,
// 上传相关命令
upload_image,
// 数据库相关命令
record_insert,
record_get,
record_list,
record_delete,
record_clear,
record_count,
setting_get,
setting_set,
setting_delete,
setting_list,
// OCR 相关命令
ocr_recognize,
ocr_save_api_key,
ocr_get_api_keys,
// 插件相关命令
plugin_list,
plugin_install,
plugin_uninstall,
// AI 分类相关命令
ai_classify,
ai_classify_stream,
ai_save_api_key,
ai_get_api_keys,
ai_configure_provider,
template_list,
template_get,
template_save,
template_delete,
template_test,
classification_get,
classification_confirm,
classification_history,
classification_stats,
])
.run(tauri::generate_context!())
.expect("error while running tauri application");
}
// ============= 截图相关命令 =============
/// 全屏截图
#[tauri::command]
async fn screenshot_fullscreen(
state: tauri::State<'_, AppState>,
) -> Result<screenshot::ScreenshotMetadata, String> {
let manager = state.screenshot_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
manager.capture_fullscreen()
.map_err(|e| format!("Failed to capture screenshot: {}", e))
}
/// 区域截图
#[tauri::command]
async fn screenshot_region(
region: screenshot::RegionSelection,
state: tauri::State<'_, AppState>,
) -> Result<screenshot::ScreenshotMetadata, String> {
let manager = state.screenshot_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
manager.capture_region(region)
.map_err(|e| format!("Failed to capture region: {}", e))
}
/// 复制截图到剪贴板
#[tauri::command]
async fn screenshot_copy_to_clipboard(
filepath: PathBuf,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let manager = state.screenshot_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
manager.copy_to_clipboard(&filepath)
.map_err(|e| format!("Failed to copy to clipboard: {}", e))
}
/// 删除截图
#[tauri::command]
async fn screenshot_delete(
filepath: PathBuf,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let manager = state.screenshot_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
manager.delete_screenshot(&filepath)
.map_err(|e| format!("Failed to delete screenshot: {}", e))
}
/// 获取所有截图列表
#[tauri::command]
async fn screenshot_list(
state: tauri::State<'_, AppState>,
) -> Result<Vec<screenshot::ScreenshotMetadata>, String> {
let manager = state.screenshot_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
manager.list_screenshots()
.map_err(|e| format!("Failed to list screenshots: {}", e))
}
/// 清理旧截图
#[tauri::command]
async fn screenshot_cleanup(
keep_count: usize,
state: tauri::State<'_, AppState>,
) -> Result<usize, String> {
let manager = state.screenshot_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
manager.cleanup_old_screenshots(keep_count)
.map_err(|e| format!("Failed to cleanup screenshots: {}", e))
}
// ============= 配置相关命令 =============
/// 获取配置
#[tauri::command]
async fn config_get(
state: tauri::State<'_, AppState>,
) -> Result<config::AppConfig, String> {
let config_manager = state.config_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
config_manager.load()
.map_err(|e| format!("Failed to load config: {}", e))
}
/// 设置配置
#[tauri::command]
async fn config_set(
config: config::AppConfig,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let config_manager = state.config_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
config_manager.save(&config)
.map_err(|e| format!("Failed to save config: {}", e))
}
/// 获取配置目录路径
#[tauri::command]
async fn config_get_path(
state: tauri::State<'_, AppState>,
) -> Result<String, String> {
let config_manager = state.config_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
Ok(config_manager.config_file_path()
.to_string_lossy()
.to_string())
}
// ============= 上传相关命令 =============
/// 上传图片到图床
#[tauri::command]
async fn upload_image(
image_path: PathBuf,
image_host: ImageHostConfig,
state: tauri::State<'_, AppState>,
) -> Result<UploadResult, String> {
// 加载配置以获取重试次数和超时设置
let config_manager = state.config_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let config = config_manager.load()
.map_err(|e| format!("Failed to load config: {}", e))?;
let uploader = Uploader::new(
config.upload_retry_count,
config.upload_timeout_seconds,
);
// 使用 tokio runtime 运行异步上传
let result = tokio::spawn(async move {
uploader.upload_with_retry(&image_path, &image_host, |progress| {
match progress {
UploadProgress::Starting => {
log::info!("Upload starting");
}
UploadProgress::Uploading { progress, message } => {
log::info!("Upload progress: {:.1}% - {}", progress, message);
}
UploadProgress::Completed(result) => {
log::info!("Upload completed: {}", result.url);
}
UploadProgress::Failed { error } => {
log::error!("Upload failed: {}", error);
}
}
}).await
}).await
.map_err(|e| format!("Failed to join upload task: {}", e))?;
result.map_err(|e| format!("Upload error: {}", e))
}
// ============= 数据库相关命令 =============
/// 插入记录
#[tauri::command]
async fn record_insert(
record_type: String,
content: String,
file_path: Option<String>,
thumbnail: Option<String>,
metadata: Option<String>,
state: tauri::State<'_, AppState>,
) -> Result<database::Record, String> {
let record_type = match record_type.as_str() {
"image" => RecordType::Image,
"text" => RecordType::Text,
"file" => RecordType::File,
_ => return Err("Invalid record type".to_string()),
};
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.insert_record(
record_type,
&content,
file_path.as_deref(),
thumbnail.as_deref(),
metadata.as_deref(),
).map_err(|e| format!("Failed to insert record: {}", e))
}
/// 获取记录
#[tauri::command]
async fn record_get(
id: String,
state: tauri::State<'_, AppState>,
) -> Result<Option<database::Record>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.get_record(&id)
.map_err(|e| format!("Failed to get record: {}", e))
}
/// 列出记录
#[tauri::command]
async fn record_list(
limit: Option<usize>,
offset: Option<usize>,
state: tauri::State<'_, AppState>,
) -> Result<Vec<database::Record>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.list_records(limit, offset)
.map_err(|e| format!("Failed to list records: {}", e))
}
/// 删除记录
#[tauri::command]
async fn record_delete(
id: String,
state: tauri::State<'_, AppState>,
) -> Result<bool, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.delete_record(&id)
.map_err(|e| format!("Failed to delete record: {}", e))
}
/// 清空所有记录
#[tauri::command]
async fn record_clear(
state: tauri::State<'_, AppState>,
) -> Result<usize, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.clear_records()
.map_err(|e| format!("Failed to clear records: {}", e))
}
/// 获取记录数量
#[tauri::command]
async fn record_count(
state: tauri::State<'_, AppState>,
) -> Result<usize, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.get_records_count()
.map_err(|e| format!("Failed to get records count: {}", e))
}
/// 获取设置
#[tauri::command]
async fn setting_get(
key: String,
state: tauri::State<'_, AppState>,
) -> Result<Option<String>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.get_setting(&key)
.map_err(|e| format!("Failed to get setting: {}", e))
}
/// 设置设置
#[tauri::command]
async fn setting_set(
key: String,
value: String,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.set_setting(&key, &value)
.map_err(|e| format!("Failed to set setting: {}", e))
}
/// 删除设置
#[tauri::command]
async fn setting_delete(
key: String,
state: tauri::State<'_, AppState>,
) -> Result<bool, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.delete_setting(&key)
.map_err(|e| format!("Failed to delete setting: {}", e))
}
/// 列出所有设置
#[tauri::command]
async fn setting_list(
state: tauri::State<'_, AppState>,
) -> Result<Vec<database::Setting>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.list_settings()
.map_err(|e| format!("Failed to list settings: {}", e))
}
// ============= OCR 相关命令 =============
/// 执行 OCR 识别
#[tauri::command]
async fn ocr_recognize(
image_path: PathBuf,
engine: String,
state: tauri::State<'_, AppState>,
window: tauri::Window,
) -> Result<ocr::OcrResult, String> {
let ocr_engine = match engine.as_str() {
"baidu" => OcrEngineType::Baidu,
"tencent" => OcrEngineType::Tencent,
"local" => OcrEngineType::Local,
_ => return Err("不支持的 OCR 引擎".to_string()),
};
let config_manager = state.config_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
// 从数据库获取 API 密钥配置
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
// 构建 OCR 配置
let ocr_config = match ocr_engine {
OcrEngineType::Baidu => {
let api_key = db.get_setting("baidu_ocr_api_key")
.map_err(|e| format!("Failed to get API key: {}", e))?
.ok_or_else(|| "百度 API Key 未设置".to_string())?;
let secret_key = db.get_setting("baidu_ocr_secret_key")
.map_err(|e| format!("Failed to get secret key: {}", e))?
.ok_or_else(|| "百度 Secret Key 未设置".to_string())?;
OcrConfig {
engine: OcrEngineType::Baidu,
baidu: Some(BaiduOcrConfig {
api_key,
secret_key,
accurate: true,
}),
tencent: None,
local: None,
}
}
OcrEngineType::Tencent => {
let secret_id = db.get_setting("tencent_ocr_secret_id")
.map_err(|e| format!("Failed to get secret id: {}", e))?
.ok_or_else(|| "腾讯云 Secret ID 未设置".to_string())?;
let secret_key = db.get_setting("tencent_ocr_secret_key")
.map_err(|e| format!("Failed to get secret key: {}", e))?
.ok_or_else(|| "腾讯云 Secret Key 未设置".to_string())?;
OcrConfig {
engine: OcrEngineType::Tencent,
baidu: None,
tencent: Some(TencentOcrConfig {
secret_id,
secret_key,
region: "ap-guangzhou".to_string(),
}),
local: None,
}
}
OcrEngineType::Local => {
let plugin_manager = state.plugin_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let plugin_path = db.get_setting("local_ocr_plugin_path")
.map_err(|e| format!("Failed to get plugin path: {}", e))?
.ok_or_else(|| "本地 OCR 插件路径未设置".to_string())?;
OcrConfig {
engine: OcrEngineType::Local,
baidu: None,
tencent: None,
local: Some(LocalOcrConfig {
plugin_path: PathBuf::from(plugin_path),
model_path: None,
}),
}
}
};
// 读取图片数据
let image_data = std::fs::read(&image_path)
.map_err(|e| format!("Failed to read image: {}", e))?;
// 执行 OCR
let provider = ocr::CloudOcrProvider::new();
// 通过窗口事件发送进度更新
let window_clone = window.clone();
let progress_callback = Box::new(move |progress| {
let _ = window_clone.emit("ocr-progress", &progress);
});
let result = tokio::spawn(async move {
provider.recognize(&image_data, &ocr_config, Some(progress_callback)).await
}).await
.map_err(|e| format!("Failed to join OCR task: {}", e))?
.map_err(|e| format!("OCR error: {}", e))?;
Ok(result)
}
/// 保存 OCR API 密钥
#[tauri::command]
async fn ocr_save_api_key(
provider: String,
keys: std::collections::HashMap<String, String>,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
match provider.as_str() {
"baidu" => {
if let Some(api_key) = keys.get("api_key") {
db.set_setting("baidu_ocr_api_key", api_key)
.map_err(|e| format!("Failed to save API key: {}", e))?;
}
if let Some(secret_key) = keys.get("secret_key") {
db.set_setting("baidu_ocr_secret_key", secret_key)
.map_err(|e| format!("Failed to save secret key: {}", e))?;
}
}
"tencent" => {
if let Some(secret_id) = keys.get("secret_id") {
db.set_setting("tencent_ocr_secret_id", secret_id)
.map_err(|e| format!("Failed to save secret id: {}", e))?;
}
if let Some(secret_key) = keys.get("secret_key") {
db.set_setting("tencent_ocr_secret_key", secret_key)
.map_err(|e| format!("Failed to save secret key: {}", e))?;
}
}
_ => return Err("不支持的 OCR 提供者".to_string()),
}
Ok(())
}
/// 获取已保存的 API 密钥(返回脱敏信息)
#[tauri::command]
async fn ocr_get_api_keys(
state: tauri::State<'_, AppState>,
) -> Result<std::collections::HashMap<String, bool>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let mut result = std::collections::HashMap::new();
// 检查百度密钥
let baidu_key = db.get_setting("baidu_ocr_api_key")
.map_err(|e| format!("Failed to check baidu key: {}", e))?;
result.insert("baidu".to_string(), baidu_key.is_some());
// 检查腾讯云密钥
let tencent_key = db.get_setting("tencent_ocr_secret_id")
.map_err(|e| format!("Failed to check tencent key: {}", e))?;
result.insert("tencent".to_string(), tencent_key.is_some());
Ok(result)
}
// ============= 插件相关命令 =============
/// 获取插件列表
#[tauri::command]
async fn plugin_list(
state: tauri::State<'_, AppState>,
) -> Result<Vec<plugin::PluginStatus>, String> {
let plugin_manager = state.plugin_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
tokio::spawn(async move {
plugin_manager.get_plugin_status().await
}).await
.map_err(|e| format!("Failed to join plugin list task: {}", e))?
.map_err(|e| format!("Failed to get plugin list: {}", e))
}
/// 安装插件
#[tauri::command]
async fn plugin_install(
plugin_id: String,
state: tauri::State<'_, AppState>,
window: tauri::Window,
) -> Result<String, String> {
let plugin_manager = state.plugin_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
// 在后台任务中发送进度更新
let window_clone = window.clone();
tokio::spawn(async move {
while let Some(progress) = rx.recv().await {
let _ = window_clone.emit("plugin-install-progress", &progress);
}
});
// 执行安装
let install_path = tokio::spawn(async move {
plugin_manager.install_plugin(&plugin_id, tx).await
}).await
.map_err(|e| format!("Failed to join install task: {}", e))?
.map_err(|e| format!("Installation failed: {}", e))?;
Ok(install_path.to_string_lossy().to_string())
}
/// 卸载插件
#[tauri::command]
async fn plugin_uninstall(
plugin_id: String,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let plugin_manager = state.plugin_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
plugin_manager.uninstall_plugin(&plugin_id)
.map_err(|e| format!("Failed to uninstall plugin: {}", e))
}
// ============= AI 分类相关命令 =============
/// 执行 AI 分类
#[tauri::command]
async fn ai_classify(
record_id: String,
template_id: Option<String>,
variables: HashMap<String, String>,
state: tauri::State<'_, AppState>,
) -> Result<ClassificationResult, String> {
let classifier = state.classifier.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let result = tokio::spawn(async move {
// 释放锁后再调用
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
classifier_ref.classify(template_id.as_deref(), &variables).await
}).await
.map_err(|e| format!("Failed to join classify task: {}", e))?
.map_err(|e| format!("Classification error: {}", e))?;
// 保存分类结果到数据库
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.save_classification(
&record_id,
&result.category,
result.subcategory.as_deref(),
&result.tags,
result.confidence,
result.reasoning.as_deref(),
template_id.as_deref(),
).map_err(|e| format!("Failed to save classification: {}", e))?;
Ok(result)
}
/// 执行流式 AI 分类
#[tauri::command]
async fn ai_classify_stream(
record_id: String,
template_id: Option<String>,
variables: HashMap<String, String>,
state: tauri::State<'_, AppState>,
window: tauri::Window,
) -> Result<ClassificationResult, String> {
let classifier = state.classifier.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let window_clone = window.clone();
let result = tokio::spawn(async move {
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
classifier_ref.classify_stream(
template_id.as_deref(),
&variables,
|chunk| {
match chunk {
StreamChunk::Text(text) => {
let _ = window_clone.emit("ai-classify-chunk", &text);
}
StreamChunk::Done => {
let _ = window_clone.emit("ai-classify-done", ());
}
StreamChunk::Error(err) => {
let _ = window_clone.emit("ai-classify-error", &err);
}
}
},
).await
}).await
.map_err(|e| format!("Failed to join classify task: {}", e))?
.map_err(|e| format!("Classification error: {}", e))?;
// 保存分类结果到数据库
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.save_classification(
&record_id,
&result.category,
result.subcategory.as_deref(),
&result.tags,
result.confidence,
result.reasoning.as_deref(),
template_id.as_deref(),
).map_err(|e| format!("Failed to save classification: {}", e))?;
Ok(result)
}
/// 保存 AI API 密钥
#[tauri::command]
async fn ai_save_api_key(
provider: String,
api_key: String,
model: Option<String>,
base_url: Option<String>,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
match provider.as_str() {
"claude" => {
db.set_setting("claude_api_key", &api_key)
.map_err(|e| format!("Failed to save API key: {}", e))?;
if let Some(model) = model {
db.set_setting("claude_model", &model)
.map_err(|e| format!("Failed to save model: {}", e))?;
}
}
"openai" => {
db.set_setting("openai_api_key", &api_key)
.map_err(|e| format!("Failed to save API key: {}", e))?;
if let Some(model) = model {
db.set_setting("openai_model", &model)
.map_err(|e| format!("Failed to save model: {}", e))?;
}
if let Some(base_url) = base_url {
db.set_setting("openai_base_url", &base_url)
.map_err(|e| format!("Failed to save base URL: {}", e))?;
}
}
_ => return Err("不支持的 AI 提供商".to_string()),
}
Ok(())
}
/// 获取已保存的 API 密钥(返回脱敏信息)
#[tauri::command]
async fn ai_get_api_keys(
state: tauri::State<'_, AppState>,
) -> Result<HashMap<String, bool>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let mut result = HashMap::new();
let claude_key = db.get_setting("claude_api_key")
.map_err(|e| format!("Failed to check claude key: {}", e))?;
result.insert("claude".to_string(), claude_key.is_some());
let openai_key = db.get_setting("openai_api_key")
.map_err(|e| format!("Failed to check openai key: {}", e))?;
result.insert("openai".to_string(), openai_key.is_some());
Ok(result)
}
/// 配置 AI 提供商
#[tauri::command]
async fn ai_configure_provider(
provider: String,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
let classifier = state.classifier.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
match provider.as_str() {
"claude" => {
let api_key = db.get_setting("claude_api_key")
.map_err(|e| format!("Failed to get API key: {}", e))?
.ok_or_else(|| "Claude API Key 未设置".to_string())?;
let model = db.get_setting("claude_model")
.map_err(|e| format!("Failed to get model: {}", e))?;
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
classifier_ref.configure_claude(api_key, model)
.map_err(|e| format!("Failed to configure Claude: {}", e))?;
}
"openai" => {
let api_key = db.get_setting("openai_api_key")
.map_err(|e| format!("Failed to get API key: {}", e))?
.ok_or_else(|| "OpenAI API Key 未设置".to_string())?;
let model = db.get_setting("openai_model")
.map_err(|e| format!("Failed to get model: {}", e))?;
let base_url = db.get_setting("openai_base_url")
.map_err(|e| format!("Failed to get base URL: {}", e))?;
let classifier_ref = unsafe { &*(&*classifier as *const Classifier) };
classifier_ref.configure_openai(api_key, model, base_url)
.map_err(|e| format!("Failed to configure OpenAI: {}", e))?;
}
_ => return Err("不支持的 AI 提供商".to_string()),
}
Ok(())
}
// ============= 模板管理相关命令 =============
/// 列出所有模板
#[tauri::command]
async fn template_list(
state: tauri::State<'_, AppState>,
) -> Result<Vec<ai::Template>, String> {
let template_manager = state.template_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
template_manager.list_templates()
.map_err(|e| format!("Failed to list templates: {}", e))
}
/// 获取单个模板
#[tauri::command]
async fn template_get(
id: String,
state: tauri::State<'_, AppState>,
) -> Result<Option<ai::Template>, String> {
let template_manager = state.template_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
Ok(template_manager.get_template(&id))
}
/// 保存模板
#[tauri::command]
async fn template_save(
template: ai::Template,
state: tauri::State<'_, AppState>,
) -> Result<(), String> {
let template_manager = state.template_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
template_manager.upsert_template(template)
.map_err(|e| format!("Failed to save template: {}", e))
}
/// 删除模板
#[tauri::command]
async fn template_delete(
id: String,
state: tauri::State<'_, AppState>,
) -> Result<bool, String> {
let template_manager = state.template_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
template_manager.delete_template(&id)
.map_err(|e| format!("Failed to delete template: {}", e))
}
/// 测试模板渲染
#[tauri::command]
async fn template_test(
id: String,
variables: HashMap<String, String>,
state: tauri::State<'_, AppState>,
) -> Result<(String, String), String> {
let template_manager = state.template_manager.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
template_manager.test_template(&id, &variables)
.map_err(|e| format!("Failed to test template: {}", e))
}
// ============= 分类结果相关命令 =============
/// 获取记录的分类结果
#[tauri::command]
async fn classification_get(
record_id: String,
state: tauri::State<'_, AppState>,
) -> Result<Option<database::Classification>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.get_classification(&record_id)
.map_err(|e| format!("Failed to get classification: {}", e))
}
/// 确认分类结果
#[tauri::command]
async fn classification_confirm(
id: String,
state: tauri::State<'_, AppState>,
) -> Result<bool, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.confirm_classification(&id)
.map_err(|e| format!("Failed to confirm classification: {}", e))
}
/// 获取分类历史
#[tauri::command]
async fn classification_history(
record_id: String,
limit: Option<usize>,
state: tauri::State<'_, AppState>,
) -> Result<Vec<database::ClassificationHistory>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.get_classification_history(&record_id, limit)
.map_err(|e| format!("Failed to get classification history: {}", e))
}
/// 获取分类统计
#[tauri::command]
async fn classification_stats(
state: tauri::State<'_, AppState>,
) -> Result<Vec<(String, usize)>, String> {
let db = state.database.lock()
.map_err(|e| format!("Failed to acquire lock: {}", e))?;
db.get_category_stats()
.map_err(|e| format!("Failed to get category stats: {}", e))
}

6
src-tauri/src/main.rs Normal file
View File

@@ -0,0 +1,6 @@
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
fn main() {
app_lib::run();
}

380
src-tauri/src/ocr/cloud.rs Normal file
View File

@@ -0,0 +1,380 @@
use super::{OcrConfig, OcrEngineType, OcrProgress, OcrResult};
use super::result::{TextBlock, TextBlockType, BoundingBox};
use anyhow::{Context, Result};
use chrono::Utc;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Sha256, Digest};
use std::time::Instant;
/// 云端 OCR 提供者
pub struct CloudOcrProvider {
client: Client,
}
impl CloudOcrProvider {
/// 创建新的云端 OCR 提供者
pub fn new() -> Self {
Self {
client: Client::new(),
}
}
/// 执行 OCR 识别
pub async fn recognize(
&self,
image_data: &[u8],
config: &OcrConfig,
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
) -> Result<OcrResult> {
let start = Instant::now();
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Starting);
}
let result = match config.engine {
OcrEngineType::Baidu => {
let baidu_config = config.baidu.as_ref()
.context("百度 OCR 配置未设置")?;
self.baidu_ocr(image_data, baidu_config, progress_callback).await?
}
OcrEngineType::Tencent => {
let tencent_config = config.tencent.as_ref()
.context("腾讯云 OCR 配置未设置")?;
self.tencent_ocr(image_data, tencent_config, progress_callback).await?
}
_ => return Err(anyhow::anyhow!("不支持的 OCR 引擎类型")),
};
let duration = start.elapsed().as_millis() as u64;
let mut final_result = result;
final_result.duration_ms = duration;
Ok(final_result)
}
/// 百度 OCR 识别
async fn baidu_ocr(
&self,
image_data: &[u8],
config: &super::BaiduOcrConfig,
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
) -> Result<OcrResult> {
// 获取访问令牌
let token = self.get_baidu_token(&config.api_key, &config.secret_key).await?;
// 编码图片为 Base64
let base64_image = base64::encode(image_data);
// 构建请求
let url = if config.accurate {
"https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"
} else {
"https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic"
};
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Processing {
progress: 50.0,
message: "正在调用百度 OCR API...".to_string(),
});
}
let response = self
.client
.post(url)
.query(&[("access_token", &token)])
.form(&[("image", base64_image)])
.send()
.await
.context("百度 OCR API 请求失败")?;
let response_text = response.text().await?;
let baidu_result: BaiduOcrResponse = serde_json::from_str(&response_text)
.with_context(|| format!("百度 OCR 响应解析失败: {}", response_text))?;
if let Some(error_msg) = baidu_result.error_msg {
return Err(anyhow::anyhow!("百度 OCR 错误: {} (错误码: {})",
error_msg, baidu_result.error_code.unwrap_or(0)));
}
// 转换为统一格式
let blocks = baidu_result
.words_result
.unwrap_or_default()
.into_iter()
.enumerate()
.map(|(idx, item)| TextBlock {
text: item.words,
confidence: item.probability.unwrap_or(0.0) * 100.0,
bbox: BoundingBox {
x: 0,
y: idx as u32 * 20, // 百度不返回位置信息,使用估算值
width: 100,
height: 20,
},
block_type: TextBlockType::Text,
})
.collect();
let result = OcrResult::from_blocks(blocks, "Baidu".to_string(), 0);
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Completed(result.clone()));
}
Ok(result)
}
/// 获取百度访问令牌
async fn get_baidu_token(&self, api_key: &str, secret_key: &str) -> Result<String> {
let response = self
.client
.post("https://aip.baidubce.com/oauth/2.0/token")
.query(&[
("grant_type", "client_credentials"),
("client_id", api_key),
("client_secret", secret_key),
])
.send()
.await
.context("获取百度访问令牌失败")?;
let token_response: BaiduTokenResponse = response.json().await?;
Ok(token_response.access_token)
}
/// 腾讯云 OCR 识别
async fn tencent_ocr(
&self,
image_data: &[u8],
config: &super::TencentOcrConfig,
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
) -> Result<OcrResult> {
let base64_image = base64::encode(image_data);
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Processing {
progress: 50.0,
message: "正在调用腾讯云 OCR API...".to_string(),
});
}
// 构建请求参数
let request = TencentOcrRequest {
image_base64: base64_image,
..Default::default()
};
// 构建签名
let timestamp = Utc::now().timestamp();
let endpoint = "ocr.tencentcloudapi.com";
let service = "ocr";
let version = "2018-11-19";
let action = "GeneralBasicOCR";
let authorization = self.tencent_authorization(
&config.secret_id,
&config.secret_key,
endpoint,
service,
version,
action,
timestamp,
&request,
)?;
let url = format!("https://{}?Action={}&Version={}", endpoint, action, version);
let response = self
.client
.post(&url)
.header("Authorization", authorization)
.header("Content-Type", "application/json")
.header("Host", endpoint)
.header("X-TC-Timestamp", timestamp.to_string())
.header("X-TC-Region", &config.region)
.json(&request)
.send()
.await
.context("腾讯云 OCR API 请求失败")?;
let response_text = response.text().await?;
let tencent_result: TencentOcrResponse = serde_json::from_str(&response_text)
.with_context(|| format!("腾讯云 OCR 响应解析失败: {}", response_text))?;
if let Some(error) = &tencent_result.response.error {
return Err(anyhow::anyhow!("腾讯云 OCR 错误: {} (代码: {})",
error.message, error.code));
}
// 转换为统一格式
let blocks = tencent_result
.response
.text_detections
.unwrap_or_default()
.into_iter()
.map(|item| TextBlock {
text: item.detected_text,
confidence: item.confidence,
bbox: BoundingBox {
x: item.polygon.as_ref().map(|p| p[0].x as u32).unwrap_or(0),
y: item.polygon.as_ref().map(|p| p[0].y as u32).unwrap_or(0),
width: 100,
height: 20,
},
block_type: TextBlockType::Text,
})
.collect();
let result = OcrResult::from_blocks(blocks, "Tencent".to_string(), 0);
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Completed(result.clone()));
}
Ok(result)
}
/// 生成腾讯云 API 签名
fn tencent_authorization(
&self,
secret_id: &str,
secret_key: &str,
endpoint: &str,
service: &str,
version: &str,
action: &str,
timestamp: i64,
request: &TencentOcrRequest,
) -> Result<String> {
// 简化的签名实现(实际生产环境应使用完整的 HMAC-SHA256
let date = Utc::now().format("%Y-%m-%d").to_string();
// 构建签名字符串
let payload = serde_json::to_string(request)?;
let hashed_payload = format!("{:x}", Sha256::digest(payload.as_bytes()));
let credential_scope = format!("{}/{}/tc3_request", date, service);
// 简化版:实际应包含完整 HTTP 方法、路径、查询参数等
let string_to_sign = format!(
"TC3-HMAC-SHA256\n{}\n{}\n{}",
timestamp, credential_scope, hashed_payload
);
// 计算签名
let secret_date = hmac_sha256(format!("TC3{}", secret_key).as_bytes(), date.as_bytes());
let secret_service = hmac_sha256(&secret_date, service.as_bytes());
let secret_signing = hmac_sha256(&secret_service, b"tc3_request");
let signature = hex_encode(hmac_sha256(&secret_signing, string_to_sign.as_bytes()));
let authorization = format!(
"TC3-HMAC-SHA256 Credential={}/{}, SignedHeaders=content-type;host, Signature={}",
secret_id, credential_scope, signature
);
Ok(authorization)
}
}
impl Default for CloudOcrProvider {
fn default() -> Self {
Self::new()
}
}
/// 百度 OCR 令牌响应
#[derive(Debug, Deserialize)]
struct BaiduTokenResponse {
access_token: String,
}
/// 百度 OCR 响应
#[derive(Debug, Deserialize)]
struct BaiduOcrResponse {
#[serde(rename = "words_result")]
words_result: Option<Vec<BaiduWord>>,
#[serde(rename = "error_code")]
error_code: Option<i32>,
#[serde(rename = "error_msg")]
error_msg: Option<String>,
}
#[derive(Debug, Deserialize)]
struct BaiduWord {
words: String,
probability: Option<f32>,
}
/// 腾讯云 OCR 请求
#[derive(Debug, Serialize, Default)]
struct TencentOcrRequest {
#[serde(rename = "ImageBase64")]
image_base64: String,
}
/// 腾讯云 OCR 响应
#[derive(Debug, Deserialize)]
struct TencentOcrResponse {
#[serde(rename = "Response")]
response: TencentResponse,
}
#[derive(Debug, Deserialize)]
struct TencentResponse {
#[serde(rename = "Error")]
error: Option<TencentError>,
#[serde(rename = "TextDetections")]
text_detections: Option<Vec<TencentTextDetection>>,
}
#[derive(Debug, Deserialize)]
struct TencentError {
#[serde(rename = "Code")]
code: String,
#[serde(rename = "Message")]
message: String,
}
#[derive(Debug, Deserialize)]
struct TencentTextDetection {
#[serde(rename = "DetectedText")]
detected_text: String,
#[serde(rename = "Confidence")]
confidence: f32,
#[serde(rename = "Polygon")]
polygon: Option<Vec<TencentPoint>>,
}
#[derive(Debug, Deserialize)]
struct TencentPoint {
#[serde(rename = "X")]
x: f32,
#[serde(rename = "Y")]
y: f32,
}
/// HMAC-SHA256 辅助函数
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
use hmac::Hmac;
use hmac::Mac;
type HmacSha256 = Hmac<Sha256>;
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key length error");
mac.update(data);
mac.finalize().into_bytes().to_vec()
}
/// 十六进制编码
fn hex_encode(data: Vec<u8>) -> String {
data.iter().map(|b| format!("{:02x}", b)).collect()
}
/// 云端 OCR 请求(用于前端调用)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CloudOcrRequest {
pub image_path: String,
pub engine: OcrEngineType,
}

174
src-tauri/src/ocr/local.rs Normal file
View File

@@ -0,0 +1,174 @@
use super::{OcrProgress, OcrResult};
use super::result::{TextBlock, TextBlockType, BoundingBox};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::process::{Command, Output};
use std::time::Instant;
/// 本地 OCR 插件
pub struct LocalOcrPlugin {
plugin_path: std::path::PathBuf,
}
impl LocalOcrPlugin {
/// 创建新的本地 OCR 插件
pub fn new(plugin_path: std::path::PathBuf) -> Self {
Self { plugin_path }
}
/// 执行 OCR 识别
pub fn recognize(
&self,
image_path: &Path,
progress_callback: Option<Box<dyn Fn(OcrProgress) + Send + Sync>>,
) -> Result<OcrResult> {
let start = Instant::now();
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Starting);
}
// 检查插件是否存在
if !self.plugin_path.exists() {
return Err(anyhow::anyhow!("本地 OCR 插件不存在: {}",
self.plugin_path.display()));
}
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Processing {
progress: 30.0,
message: "正在启动本地 OCR 引擎...".to_string(),
});
}
// 调用本地插件
let output = self.call_plugin(image_path)?;
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Processing {
progress: 70.0,
message: "正在解析 OCR 结果...".to_string(),
});
}
// 解析结果
let plugin_result: LocalOcrResponse = serde_json::from_str(&output)
.context("本地 OCR 插件返回的 JSON 格式错误")?;
if !plugin_result.success {
return Err(anyhow::anyhow!("本地 OCR 失败: {}",
plugin_result.error.unwrap_or_else(|| "未知错误".to_string())));
}
// 转换为统一格式
let blocks = plugin_result
.blocks
.into_iter()
.map(|block| TextBlock {
text: block.text,
confidence: block.confidence,
bbox: BoundingBox {
x: block.bbox_x,
y: block.bbox_y,
width: block.bbox_width,
height: block.bbox_height,
},
block_type: match block.block_type.as_str() {
"title" => TextBlockType::Title,
"list" => TextBlockType::List,
"table" => TextBlockType::Table,
_ => TextBlockType::Text,
},
})
.collect();
let mut result = OcrResult::from_blocks(blocks, "Local".to_string(), 0);
// 添加元数据
if let Some(engine_info) = plugin_result.engine {
result.metadata.insert("engine".to_string(), engine_info);
}
if let Some(lang) = plugin_result.language {
result.metadata.insert("language".to_string(), lang);
}
let duration = start.elapsed().as_millis() as u64;
result.duration_ms = duration;
if let Some(callback) = &progress_callback {
(callback)(OcrProgress::Completed(result.clone()));
}
Ok(result)
}
/// 调用本地插件
fn call_plugin(&self, image_path: &Path) -> Result<String> {
let output = Command::new(&self.plugin_path)
.arg("recognize")
.arg(image_path)
.output()
.context("执行本地 OCR 插件失败")?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(anyhow::anyhow!("插件执行失败: {}", stderr));
}
Ok(String::from_utf8_lossy(&output.stdout).to_string())
}
/// 检查插件是否可用
pub fn is_available(&self) -> bool {
self.plugin_path.exists()
}
/// 获取插件版本信息
pub fn get_version(&self) -> Result<String> {
let output = Command::new(&self.plugin_path)
.arg("version")
.output()
.context("获取插件版本失败")?;
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
} else {
Ok("unknown".to_string())
}
}
}
/// 本地 OCR 插件响应
#[derive(Debug, Deserialize, Serialize)]
struct LocalOcrResponse {
/// 是否成功
success: bool,
/// 错误信息
error: Option<String>,
/// OCR 引擎信息
engine: Option<String>,
/// 识别语言
language: Option<String>,
/// 文本块列表
blocks: Vec<LocalOcrBlock>,
}
/// 本地 OCR 文本块
#[derive(Debug, Deserialize, Serialize)]
struct LocalOcrBlock {
/// 文本内容
text: String,
/// 置信度 (0-100)
confidence: f32,
/// 边界框 X 坐标
bbox_x: u32,
/// 边界框 Y 坐标
bbox_y: u32,
/// 边界框宽度
bbox_width: u32,
/// 边界框高度
bbox_height: u32,
/// 文本块类型
block_type: String,
}

82
src-tauri/src/ocr/mod.rs Normal file
View File

@@ -0,0 +1,82 @@
pub mod cloud;
pub mod local;
pub mod result;
pub use cloud::{CloudOcrProvider, CloudOcrRequest};
pub use result::OcrResult;
pub use local::LocalOcrPlugin;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// OCR 引擎类型
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OcrEngineType {
/// 百度 OCR
Baidu,
/// 腾讯云 OCR
Tencent,
/// 本地插件 OCR
Local,
}
/// OCR 配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrConfig {
/// 选择的 OCR 引擎
pub engine: OcrEngineType,
/// 百度 API 配置
pub baidu: Option<BaiduOcrConfig>,
/// 腾讯云 API 配置
pub tencent: Option<TencentOcrConfig>,
/// 本地插件配置
pub local: Option<LocalOcrConfig>,
}
/// 百度 OCR 配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaiduOcrConfig {
/// API Key
pub api_key: String,
/// Secret Key
pub secret_key: String,
/// 是否使用通用文字识别(高精度版)
pub accurate: bool,
}
/// 腾讯云 OCR 配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TencentOcrConfig {
/// Secret ID
pub secret_id: String,
/// Secret Key
pub secret_key: String,
/// 地域
pub region: String,
}
/// 本地 OCR 插件配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalOcrConfig {
/// 插件可执行文件路径
pub plugin_path: PathBuf,
/// 语言模型路径
pub model_path: Option<PathBuf>,
}
/// OCR 进度回调
pub type OcrProgressCallback = Box<dyn Fn(OcrProgress) + Send + Sync>;
/// OCR 进度信息
#[derive(Debug, Clone)]
pub enum OcrProgress {
/// 开始识别
Starting,
/// 处理中
Processing { progress: f32, message: String },
/// 识别完成
Completed(OcrResult),
/// 识别失败
Failed { error: String },
}

167
src-tauri/src/ocr/result.rs Normal file
View File

@@ -0,0 +1,167 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// OCR 识别结果
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
/// 识别的完整文本
pub text: String,
/// 识别的文本块列表
pub blocks: Vec<TextBlock>,
/// 识别置信度 (0-100)
pub confidence: f32,
/// 耗时(毫秒)
pub duration_ms: u64,
/// 使用的 OCR 引擎
pub engine: String,
/// 额外信息
pub metadata: HashMap<String, String>,
}
/// 文本块
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextBlock {
/// 文本内容
pub text: String,
/// 置信度 (0-100)
pub confidence: f32,
/// 边界框(像素坐标)
pub bbox: BoundingBox,
/// 文本块类型
pub block_type: TextBlockType,
}
/// 边界框
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
/// 左上角 X 坐标
pub x: u32,
/// 左上角 Y 坐标
pub y: u32,
/// 宽度
pub width: u32,
/// 高度
pub height: u32,
}
/// 文本块类型
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TextBlockType {
/// 文本
Text,
/// 标题
Title,
/// 列表
List,
/// 表格
Table,
/// 其他
Other,
}
impl OcrResult {
/// 创建空的 OCR 结果
pub fn empty() -> Self {
Self {
text: String::new(),
blocks: Vec::new(),
confidence: 0.0,
duration_ms: 0,
engine: String::new(),
metadata: HashMap::new(),
}
}
/// 从文本块列表构建结果
pub fn from_blocks(
blocks: Vec<TextBlock>,
engine: String,
duration_ms: u64,
) -> Self {
let text = blocks
.iter()
.map(|b| b.text.as_str())
.collect::<Vec<&str>>()
.join("\n");
let confidence = if blocks.is_empty() {
0.0
} else {
blocks.iter().map(|b| b.confidence).sum::<f32>() / blocks.len() as f32
};
Self {
text,
blocks,
confidence,
duration_ms,
engine,
metadata: HashMap::new(),
}
}
/// 搜索包含关键词的文本块
pub fn search(&self, keyword: &str) -> Vec<&TextBlock> {
let keyword_lower = keyword.to_lowercase();
self.blocks
.iter()
.filter(|block| block.text.to_lowercase().contains(&keyword_lower))
.collect()
}
/// 导出为纯文本
pub fn to_plain_text(&self) -> String {
self.text.clone()
}
/// 导出为 Markdown 格式
pub fn to_markdown(&self) -> String {
let mut md = String::new();
for block in &self.blocks {
match block.block_type {
TextBlockType::Title => {
md.push_str(&format!("## {}\n\n", block.text));
}
TextBlockType::List => {
md.push_str(&format!("- {}\n", block.text));
}
TextBlockType::Table => {
md.push_str(&format!("| {}\n", block.text));
}
_ => {
md.push_str(&format!("{}\n", block.text));
}
}
}
md
}
/// 获取高亮显示的 HTML
pub fn to_html_highlighted(&self, keyword: &str) -> String {
let mut html = String::from("<div class='ocr-result'>");
for block in &self.blocks {
let text = if !keyword.is_empty()
&& block.text.to_lowercase().contains(&keyword.to_lowercase())
{
block.text.replace(
&keyword.to_lowercase(),
&format!("<mark>{}</mark>", keyword),
)
} else {
block.text.clone()
};
html.push_str(&format!(
"<p class='ocr-block' style='position: absolute; left: {}px; top: {}px;'>{}</p>",
block.bbox.x, block.bbox.y, text
));
}
html.push_str("</div>");
html
}
}

432
src-tauri/src/plugin/mod.rs Normal file
View File

@@ -0,0 +1,432 @@
use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc;
/// 插件元数据
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginMetadata {
/// 插件 ID
pub id: String,
/// 插件名称
pub name: String,
/// 插件描述
pub description: String,
/// 插件版本
pub version: String,
/// 插件作者
pub author: String,
/// 插件类型
pub plugin_type: PluginType,
/// 下载 URL
pub download_url: String,
/// SHA256 校验和
pub sha256: String,
/// 文件大小(字节)
pub file_size: u64,
/// 最低兼容版本
pub min_app_version: String,
/// 主页 URL
pub homepage_url: Option<String>,
/// 图标 URL
pub icon_url: Option<String>,
/// 依赖项
pub dependencies: Vec<String>,
/// 发布日期
pub published_at: String,
/// 许可证
pub license: String,
}
/// 插件类型
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum PluginType {
/// OCR 插件
Ocr,
/// 图床插件
ImageHost,
/// 其他
Other,
}
/// 插件状态
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginStatus {
/// 插件元数据
pub metadata: PluginMetadata,
/// 是否已安装
pub installed: bool,
/// 安装路径
pub install_path: Option<PathBuf>,
/// 是否有更新
pub has_update: bool,
/// 安装时间
pub installed_at: Option<String>,
}
/// 插件管理器
pub struct PluginManager {
client: Client,
config_dir: PathBuf,
plugins_dir: PathBuf,
registry_url: String,
}
impl PluginManager {
/// 创建新的插件管理器
pub fn new(config_dir: PathBuf) -> Self {
let plugins_dir = config_dir.join("plugins");
fs::create_dir_all(&plugins_dir)
.expect("Failed to create plugins directory");
Self {
client: Client::builder()
.timeout(Duration::from_secs(30))
.build()
.unwrap(),
config_dir,
plugins_dir,
registry_url: "https://raw.githubusercontent.com/cutthenthink/plugins/main".to_string(),
}
}
/// 获取远程插件列表
pub async fn fetch_remote_plugins(&self) -> Result<Vec<PluginMetadata>> {
let url = format!("{}/plugins.json", self.registry_url);
let response = self
.client
.get(&url)
.send()
.await
.context("Failed to fetch plugins list")?;
if !response.status().is_success() {
return Err(anyhow::anyhow!("Failed to fetch plugins: HTTP {}", response.status()));
}
let plugins: Vec<PluginMetadata> = response.json().await?;
Ok(plugins)
}
/// 获取插件状态列表
pub async fn get_plugin_status(&self) -> Result<Vec<PluginStatus>> {
let remote_plugins = self.fetch_remote_plugins().await?;
let installed_plugins = self.get_installed_plugins()?;
let mut status_list = Vec::new();
for remote_plugin in remote_plugins {
let installed = installed_plugins.get(&remote_plugin.id);
let has_update = if let Some(installed) = installed {
installed.version != remote_plugin.version
} else {
false
};
status_list.push(PluginStatus {
metadata: remote_plugin.clone(),
installed: installed.is_some(),
install_path: installed.and_then(|p| Some(p.install_path.clone()?)),
has_update,
installed_at: installed.and_then(|p| p.installed_at.clone()),
});
}
Ok(status_list)
}
/// 下载并安装插件
pub async fn install_plugin(
&self,
plugin_id: &str,
progress_callback: mpsc::Sender<InstallProgress>,
) -> Result<PathBuf> {
// 获取远程插件信息
let plugins = self.fetch_remote_plugins().await?;
let plugin = plugins
.iter()
.find(|p| p.id == plugin_id)
.ok_or_else(|| anyhow::anyhow!("Plugin not found: {}", plugin_id))?;
// 发送开始事件
let _ = progress_callback
.send(InstallProgress::Starting {
plugin_name: plugin.name.clone(),
file_size: plugin.file_size,
})
.await;
// 下载文件
let download_path = self
.download_plugin(plugin, progress_callback.clone())
.await?;
// 验证 SHA256
let _ = progress_callback
.send(InstallProgress::Verifying {
plugin_name: plugin.name.clone(),
})
.await;
self.verify_checksum(&download_path, &plugin.sha256)?;
// 解压/安装
let _ = progress_callback
.send(InstallProgress::Installing {
plugin_name: plugin.name.clone(),
})
.await;
let install_path = self.install_plugin_file(plugin, &download_path).await?;
// 保存安装记录
self.save_install_record(plugin, &install_path)?;
// 清理下载文件
let _ = fs::remove_file(&download_path);
// 发送完成事件
let _ = progress_callback
.send(InstallProgress::Completed {
plugin_id: plugin.id.clone(),
install_path: install_path.clone(),
})
.await;
Ok(install_path)
}
/// 卸载插件
pub fn uninstall_plugin(&self, plugin_id: &str) -> Result<()> {
let install_path = self.plugins_dir.join(plugin_id);
if install_path.exists() {
fs::remove_dir_all(&install_path)
.context("Failed to remove plugin directory")?;
}
// 删除安装记录
let record_path = self.get_install_record_path(plugin_id);
if record_path.exists() {
fs::remove_file(&record_path)?;
}
log::info!("Uninstalled plugin: {}", plugin_id);
Ok(())
}
/// 下载插件文件
async fn download_plugin(
&self,
plugin: &PluginMetadata,
mut progress_callback: mpsc::Sender<InstallProgress>,
) -> Result<PathBuf> {
let response = self
.client
.get(&plugin.download_url)
.send()
.await
.context("Failed to download plugin")?;
if !response.status().is_success() {
return Err(anyhow::anyhow!("Download failed: HTTP {}", response.status()));
}
let total_size = response.content_length().unwrap_or(plugin.file_size);
let download_path = self.config_dir.join("downloads").join(format!("{}.zip", &plugin.id));
fs::create_dir_all(download_path.parent().unwrap())?;
let mut file = tokio::fs::File::create(&download_path).await?;
let mut downloaded = 0u64;
let mut stream = response.bytes_stream();
use futures_util::StreamExt;
while let Some(chunk) = stream.next().await {
let chunk = chunk.context("Failed to read download chunk")?;
file.write_all(&chunk).await?;
downloaded += chunk.len() as u64;
let progress = (downloaded as f64 / total_size as f64) * 100.0;
let _ = progress_callback
.send(InstallProgress::Downloading {
plugin_name: plugin.name.clone(),
progress,
downloaded,
total_size,
})
.await;
}
file.flush().await?;
Ok(download_path)
}
/// 验证文件校验和
fn verify_checksum(&self, file_path: &Path, expected_sha256: &str) -> Result<()> {
let content = fs::read(file_path)?;
let hash = Sha256::digest(&content);
let calculated_sha256 = hex::encode(hash);
if calculated_sha256 != expected_sha256.to_lowercase() {
fs::remove_file(file_path)?;
return Err(anyhow::anyhow!(
"Checksum verification failed. Expected: {}, Got: {}",
expected_sha256,
calculated_sha256
));
}
Ok(())
}
/// 安装插件文件
async fn install_plugin_file(
&self,
plugin: &PluginMetadata,
download_path: &Path,
) -> Result<PathBuf> {
let install_dir = self.plugins_dir.join(&plugin.id);
// 如果已存在,先删除
if install_dir.exists() {
fs::remove_dir_all(&install_dir)?;
}
fs::create_dir_all(&install_dir)?;
// 对于本地 OCR 插件,直接复制可执行文件
if plugin.plugin_type == PluginType::Ocr {
let exe_name = if cfg!(windows) {
format!("{}.exe", plugin.id)
} else {
plugin.id.clone()
};
let exe_path = install_dir.join(&exe_name);
fs::copy(download_path, &exe_path)?;
// 设置可执行权限
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(&exe_path)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&exe_path, perms)?;
}
Ok(exe_path)
} else {
// 其他类型需要解压
Ok(install_dir)
}
}
/// 保存安装记录
fn save_install_record(&self, plugin: &PluginMetadata, install_path: &Path) -> Result<()> {
let record = InstallRecord {
plugin_id: plugin.id.clone(),
version: plugin.version.clone(),
install_path: install_path.clone(),
installed_at: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_secs(),
};
let record_path = self.get_install_record_path(&plugin.id);
let content = serde_json::to_string_pretty(&record)?;
fs::write(&record_path, content)?;
Ok(())
}
/// 获取安装记录路径
fn get_install_record_path(&self, plugin_id: &str) -> PathBuf {
self.config_dir.join(format!("plugin_{}.json", plugin_id))
}
/// 获取已安装的插件
fn get_installed_plugins(&self) -> Result<HashMap<String, InstallRecord>> {
let mut plugins = HashMap::new();
for entry in fs::read_dir(&self.config_dir)? {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json")
&& path.file_name()
.and_then(|s| s.to_str())
.map(|s| s.starts_with("plugin_"))
== Some(true)
{
if let Ok(content) = fs::read_to_string(&path) {
if let Ok(record) = serde_json::from_str::<InstallRecord>(&content) {
plugins.insert(record.plugin_id.clone(), record);
}
}
}
}
Ok(plugins)
}
/// 检查插件更新
pub async fn check_updates(&self) -> Result<Vec<String>> {
let status_list = self.get_plugin_status().await?;
let updates: Vec<String> = status_list
.into_iter()
.filter(|s| s.has_update)
.map(|s| s.metadata.id)
.collect();
Ok(updates)
}
}
/// 安装记录
#[derive(Debug, Clone, Serialize, Deserialize)]
struct InstallRecord {
plugin_id: String,
version: String,
install_path: PathBuf,
installed_at: u64,
}
/// 安装进度
#[derive(Debug, Clone)]
pub enum InstallProgress {
Starting {
plugin_name: String,
file_size: u64,
},
Downloading {
plugin_name: String,
progress: f64,
downloaded: u64,
total_size: u64,
},
Verifying {
plugin_name: String,
},
Installing {
plugin_name: String,
},
Completed {
plugin_id: String,
install_path: PathBuf,
},
Failed {
plugin_id: String,
error: String,
},
}

308
src-tauri/src/screenshot.rs Normal file
View File

@@ -0,0 +1,308 @@
use anyhow::Result;
use arboard::Clipboard;
use chrono::Utc;
use image::{DynamicImage, ImageFormat};
use screenshots::Screen;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use std::io::Cursor;
/// 截图元数据
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScreenshotMetadata {
pub id: String,
pub filename: String,
pub filepath: PathBuf,
pub width: u32,
pub height: u32,
pub file_size: u64,
pub created_at: String,
pub thumbnail_base64: Option<String>,
}
/// 截图区域选择参数
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegionSelection {
pub x: i32,
pub y: i32,
pub width: u32,
pub height: u32,
}
/// 截图管理器
pub struct ScreenshotManager {
base_dir: PathBuf,
screenshots_dir: PathBuf,
thumbnails_dir: PathBuf,
}
impl ScreenshotManager {
/// 创建新的截图管理器
pub fn new(base_dir: PathBuf) -> Result<Self> {
let screenshots_dir = base_dir.join("screenshots");
let thumbnails_dir = base_dir.join("thumbnails");
// 创建目录
fs::create_dir_all(&screenshots_dir)?;
fs::create_dir_all(&thumbnails_dir)?;
Ok(Self {
base_dir,
screenshots_dir,
thumbnails_dir,
})
}
/// 使用默认目录创建截图管理器
pub fn with_default_dir() -> Result<Self> {
let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
let app_data_dir = home_dir.join(".cutthink-lite");
Self::new(app_data_dir)
}
/// 生成时间戳文件名
fn generate_filename(&self) -> String {
let timestamp = Utc::now().format("%Y%m%d_%H%M%S_%3f");
format!("screenshot_{}.png", timestamp)
}
/// 生成唯一 ID
fn generate_id(&self, filename: &str) -> String {
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
filename.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
/// 全屏截图
pub fn capture_fullscreen(&self) -> Result<ScreenshotMetadata> {
let screen = Screen::all()?
.first()
.ok_or_else(|| anyhow::anyhow!("No screen found"))?
.clone();
let image = screen.capture()?;
let width = image.width();
let height = image.height();
let filename = self.generate_filename();
let filepath = self.screenshots_dir.join(&filename);
let id = self.generate_id(&filename);
let created_at = Utc::now().to_rfc3339();
// 保存图片
image.save(&filepath)?;
// 生成缩略图
let thumbnail_base64 = self.generate_thumbnail(&filepath, width, height)?;
// 获取文件大小
let file_size = fs::metadata(&filepath)?.len();
Ok(ScreenshotMetadata {
id,
filename,
filepath,
width,
height,
file_size,
created_at,
thumbnail_base64,
})
}
/// 区域截图
pub fn capture_region(&self, region: RegionSelection) -> Result<ScreenshotMetadata> {
let screen = Screen::all()?
.first()
.ok_or_else(|| anyhow::anyhow!("No screen found"))?
.clone();
let full_image = screen.capture()?;
// 确保区域在屏幕范围内
let x = region.x.max(0);
let y = region.y.max(0);
let width = region.width.min(full_image.width() as u32 - x as u32);
let height = region.height.min(full_image.height() as u32 - y as u32);
// 裁剪区域
let image_buffer = full_image.buffer();
let cropped_image = image::imageops::crop(
image_buffer,
x as u32,
y as u32,
width,
height,
).to_image();
let dynamic_image = DynamicImage::ImageRgba8(cropped_image);
let width = dynamic_image.width();
let height = dynamic_image.height();
let filename = self.generate_filename();
let filepath = self.screenshots_dir.join(&filename);
let id = self.generate_id(&filename);
let created_at = Utc::now().to_rfc3339();
// 保存图片
dynamic_image.save(&filepath)?;
// 生成缩略图
let thumbnail_base64 = self.generate_thumbnail(&filepath, width, height)?;
// 获取文件大小
let file_size = fs::metadata(&filepath)?.len();
Ok(ScreenshotMetadata {
id,
filename,
filepath,
width,
height,
file_size,
created_at,
thumbnail_base64,
})
}
/// 生成缩略图并转换为 base64
fn generate_thumbnail(&self, filepath: &Path, width: u32, height: u32) -> Result<Option<String>> {
const THUMBNAIL_MAX_SIZE: u32 = 200;
// 如果图片已经很小,直接使用原图
if width <= THUMBNAIL_MAX_SIZE && height <= THUMBNAIL_MAX_SIZE {
return Ok(None);
}
// 计算缩略图尺寸(保持宽高比)
let scale = THUMBNAIL_MAX_SIZE as f32 / width.max(height) as f32;
let new_width = (width as f32 * scale) as u32;
let new_height = (height as f32 * scale) as u32;
// 加载并缩放图片
let img = image::open(filepath)?;
let thumbnail = image::imageops::thumbnail(&img, new_width, new_height);
// 转换为 base64
let mut buffer = Cursor::new(Vec::new());
thumbnail.write_to(&mut buffer, ImageFormat::Png)?;
let base64_string = base64::encode(&buffer.into_inner());
Ok(Some(format!("data:image/png;base64,{}", base64_string)))
}
/// 复制图片到剪贴板
pub fn copy_to_clipboard(&self, filepath: &Path) -> Result<()> {
let img = image::open(filepath)?;
let rgba = img.to_rgba8();
// 构建 PNG 格式数据
let mut buffer = Cursor::new(Vec::new());
rgba.write_to(&mut buffer, ImageFormat::Png)?;
let png_data = buffer.into_inner();
// 复制到剪贴板
let mut clipboard = Clipboard::new()?;
// 注意arboard 的图片支持可能有限,这里提供基础实现
// 如果不工作,可能需要使用平台特定的剪贴板 API
#[cfg(target_os = "windows")]
{
clipboard.set_image(arboard::ImageData {
width: rgba.width() as usize,
height: rgba.height() as usize,
bytes: rgba.as_raw().into(),
})?;
}
#[cfg(not(target_os = "windows"))]
{
// 对于非 Windows 平台,先尝试复制文件路径作为后备
clipboard.set_text(filepath.to_string_lossy().to_string())?;
}
Ok(())
}
/// 删除截图
pub fn delete_screenshot(&self, filepath: &Path) -> Result<()> {
if filepath.exists() {
fs::remove_file(filepath)?;
}
Ok(())
}
/// 获取所有截图
pub fn list_screenshots(&self) -> Result<Vec<ScreenshotMetadata>> {
let mut screenshots = Vec::new();
let entries = fs::read_dir(&self.screenshots_dir)?;
for entry in entries {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("png") {
if let Ok(metadata) = fs::metadata(&path) {
if let Ok(img) = image::open(&path) {
let filename = path.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown.png")
.to_string();
let id = self.generate_id(&filename);
let created_at = metadata.created()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
.map(|d| {
let datetime: chrono::DateTime<Utc> = DateTime::from(d);
datetime.to_rfc3339()
})
.unwrap_or_else(|| Utc::now().to_rfc3339());
let thumbnail_base64 = self.generate_thumbnail(&path, img.width(), img.height())?;
screenshots.push(ScreenshotMetadata {
id,
filename,
filepath: path,
width: img.width(),
height: img.height(),
file_size: metadata.len(),
created_at,
thumbnail_base64,
});
}
}
}
}
// 按创建时间排序(最新的在前)
screenshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(screenshots)
}
/// 清理旧截图(保留最近 N 个)
pub fn cleanup_old_screenshots(&self, keep_count: usize) -> Result<usize> {
let mut screenshots = self.list_screenshots()?;
let deleted_count = if screenshots.len() > keep_count {
screenshots.truncate(keep_count);
let to_delete = &screenshots[keep_count..];
for screenshot in to_delete {
let _ = self.delete_screenshot(&screenshot.filepath);
}
to_delete.len()
} else {
0
};
Ok(deleted_count)
}
}
use chrono::DateTime;

View File

@@ -0,0 +1,226 @@
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Nonce,
};
use anyhow::{anyhow, Context, Result};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::fs;
use std::path::{Path, PathBuf};
/// 安全存储管理器
pub struct SecureStorage {
storage_path: PathBuf,
key: [u8; 32],
}
impl SecureStorage {
/// 创建新的安全存储
pub fn new(storage_path: PathBuf, password: &str) -> Result<Self> {
// 从密码派生加密密钥
let key = Self::derive_key(password);
// 确保存储目录存在
if let Some(parent) = storage_path.parent() {
fs::create_dir_all(parent)
.context("Failed to create storage directory")?;
}
Ok(Self { storage_path, key })
}
/// 从密码派生密钥
fn derive_key(password: &str) -> [u8; 32] {
let mut key = [0u8; 32];
// 使用 PBKDF2 的简化版本(使用 SHA256 多次哈希)
let mut hash = Sha256::digest(password.as_bytes());
for _ in 0..10000 {
hash = Sha256::digest(&hash);
}
key.copy_from_slice(&hash[..32]);
key
}
/// 保存敏感数据
pub fn save(&self, key: &str, value: &str) -> Result<()> {
// 加载现有数据
let mut data = self.load_data().unwrap_or_default();
// 加密值
let encrypted = self.encrypt(value)?;
// 更新数据
data.insert(key.to_string(), EncryptedValue {
ciphertext: hex::encode(encrypted.0),
nonce: hex::encode(encrypted.1),
});
// 保存到文件
self.save_data(&data)?;
log::info!("Saved encrypted value for key: {}", key);
Ok(())
}
/// 获取敏感数据
pub fn get(&self, key: &str) -> Result<Option<String>> {
let data = self.load_data()?;
if let Some(encrypted) = data.get(key) {
let ciphertext = hex::decode(&encrypted.ciphertext)
.context("Failed to decode ciphertext")?;
let nonce = hex::decode(&encrypted.nonce)
.context("Failed to decode nonce")?;
let nonce = Nonce::from_slice(&nonce);
let cipher = Aes256Gcm::new(&self.key.into());
let plaintext = cipher
.decrypt(nonce, ciphertext.as_ref())
.map_err(|_| anyhow!("Decryption failed - wrong password or corrupted data"))?;
Ok(Some(String::from_utf8(plaintext)?))
} else {
Ok(None)
}
}
/// 删除敏感数据
pub fn delete(&self, key: &str) -> Result<bool> {
let mut data = self.load_data()?;
if data.remove(key).is_some() {
self.save_data(&data)?;
Ok(true)
} else {
Ok(false)
}
}
/// 列出所有键
pub fn list_keys(&self) -> Result<Vec<String>> {
let data = self.load_data()?;
Ok(data.keys().cloned().collect())
}
/// 加密数据
fn encrypt(&self, plaintext: &str) -> Result<(Vec<u8>, [u8; 12])> {
let cipher = Aes256Gcm::new(&self.key.into());
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ciphertext = cipher
.encrypt(&nonce, plaintext.as_bytes())
.map_err(|_| anyhow!("Encryption failed"))?;
Ok((ciphertext, nonce.into()))
}
/// 加载加密数据
fn load_data(&self) -> Result<StorageData> {
if !self.storage_path.exists() {
return Ok(StorageData::new());
}
let content = fs::read_to_string(&self.storage_path)
.context("Failed to read storage file")?;
serde_json::from_str(&content).context("Failed to parse storage file")
}
/// 保存加密数据
fn save_data(&self, data: &StorageData) -> Result<()> {
let content = serde_json::to_string_pretty(data)
.context("Failed to serialize storage data")?;
fs::write(&self.storage_path, content)
.context("Failed to write storage file")?;
Ok(())
}
}
/// 加密值
#[derive(Debug, Clone, Serialize, Deserialize)]
struct EncryptedValue {
ciphertext: String,
nonce: String,
}
/// 存储数据
type StorageData = std::collections::HashMap<String, EncryptedValue>;
/// API 密钥存储
pub struct ApiKeyStorage {
storage: SecureStorage,
}
impl ApiKeyStorage {
/// 创建新的 API 密钥存储
pub fn new(config_dir: &Path, password: &str) -> Result<Self> {
let storage_path = config_dir.join("secure_storage.json");
let storage = SecureStorage::new(storage_path, password)?;
Ok(Self { storage })
}
/// 保存百度 API 密钥
pub fn save_baidu_key(&self, api_key: &str, secret_key: &str) -> Result<()> {
self.storage.save("baidu_ocr_api_key", api_key)?;
self.storage.save("baidu_ocr_secret_key", secret_key)?;
Ok(())
}
/// 获取百度 API 密钥
pub fn get_baidu_key(&self) -> Result<(Option<String>, Option<String>)> {
let api_key = self.storage.get("baidu_ocr_api_key")?;
let secret_key = self.storage.get("baidu_ocr_secret_key")?;
Ok((api_key, secret_key))
}
/// 保存腾讯云 API 密钥
pub fn save_tencent_key(&self, secret_id: &str, secret_key: &str) -> Result<()> {
self.storage.save("tencent_ocr_secret_id", secret_id)?;
self.storage.save("tencent_ocr_secret_key", secret_key)?;
Ok(())
}
/// 获取腾讯云 API 密钥
pub fn get_tencent_key(&self) -> Result<(Option<String>, Option<String>)> {
let secret_id = self.storage.get("tencent_ocr_secret_id")?;
let secret_key = self.storage.get("tencent_ocr_secret_key")?;
Ok((secret_id, secret_key))
}
/// 删除所有密钥
pub fn clear_all(&self) -> Result<()> {
for key in self.storage.list_keys()? {
self.storage.delete(&key)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn test_secure_storage() {
let dir = tempdir().unwrap();
let storage_path = dir.path().join("test_storage.json");
let storage = SecureStorage::new(storage_path, "test_password").unwrap();
// 保存和获取
storage.save("test_key", "test_value").unwrap();
let value = storage.get("test_key").unwrap().unwrap();
assert_eq!(value, "test_value");
// 删除
assert!(storage.delete("test_key").unwrap());
assert!(!storage.delete("test_key").unwrap());
}
}

387
src-tauri/src/upload.rs Normal file
View File

@@ -0,0 +1,387 @@
use crate::config::ImageHostConfig;
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use chrono::Utc;
use reqwest::header;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::time::sleep;
/// 上传结果
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UploadResult {
pub url: String,
pub delete_url: Option<String>,
pub image_host: String,
pub uploaded_at: String,
pub file_size: u64,
}
/// 上传进度事件
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UploadProgress {
Starting,
Uploading { progress: f32, message: String },
Completed(UploadResult),
Failed { error: String },
}
/// 上传器
pub struct Uploader {
retry_count: u32,
timeout_seconds: u64,
}
impl Uploader {
/// 创建新的上传器
pub fn new(retry_count: u32, timeout_seconds: u64) -> Self {
Self {
retry_count,
timeout_seconds,
}
}
/// 上传图片到指定图床
pub async fn upload_image(
&self,
image_path: &Path,
config: &ImageHostConfig,
) -> Result<UploadResult> {
// 读取图片数据
let image_data = fs::read(image_path)
.context("Failed to read image file")?;
let file_size = image_data.len() as u64;
// 根据图床类型上传
let result = match config {
ImageHostConfig::GitHub { .. } => {
self.upload_to_github(image_path, &image_data, config).await?
}
ImageHostConfig::Imgur { .. } => {
self.upload_to_imgur(image_path, &image_data, config).await?
}
ImageHostConfig::Custom { .. } => {
self.upload_to_custom(image_path, &image_data, config).await?
}
};
Ok(result)
}
/// 上传到 GitHub
async fn upload_to_github(
&self,
image_path: &Path,
image_data: &[u8],
config: &ImageHostConfig,
) -> Result<UploadResult> {
let (token, owner, repo, path, branch) = match config {
ImageHostConfig::GitHub {
token,
owner,
repo,
path,
branch,
} => (token, owner, repo, path, branch),
_ => return Err(anyhow!("Invalid GitHub config")),
};
let branch = branch.as_deref().unwrap_or("main");
let filename = image_path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow!("Invalid filename"))?;
let github_path = if path.is_empty() {
filename.to_string()
} else {
format!("{}/{}", path.trim_end_matches('/'), filename)
};
// Base64 编码
let content = BASE64.encode(image_data);
// 构建 API 请求
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_seconds))
.build()?;
let url = format!(
"https://api.github.com/repos/{}/{}/contents/{}",
owner, repo, github_path
);
let request_body = serde_json::json!({
"message": format!("Upload {}", filename),
"content": content,
"branch": branch
});
let response = client
.put(&url)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
.header(header::ACCEPT, "application/vnd.github.v3+json")
.header(header::USER_AGENT, "CutThenThink-Lite")
.json(&request_body)
.send()
.await
.context("Failed to send request to GitHub")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(anyhow!("GitHub API error: {}", error_text));
}
let response_json: serde_json::Value = response.json().await?;
let download_url = response_json["content"]["download_url"]
.as_str()
.ok_or_else(|| anyhow!("Missing download_url in response"))?
.to_string();
// 构建删除 URL使用 SHA
let sha = response_json["content"]["sha"]
.as_str()
.ok_or_else(|| anyhow!("Missing sha in response"))?;
let delete_url = Some(format!(
"github://{}://{}/{}?sha={}",
owner, repo, github_path, sha
));
Ok(UploadResult {
url: download_url,
delete_url,
image_host: "GitHub".to_string(),
uploaded_at: Utc::now().to_rfc3339(),
file_size: image_data.len() as u64,
})
}
/// 上传到 Imgur
async fn upload_to_imgur(
&self,
image_path: &Path,
image_data: &[u8],
config: &ImageHostConfig,
) -> Result<UploadResult> {
let client_id = match config {
ImageHostConfig::Imgur { client_id } => client_id,
_ => return Err(anyhow!("Invalid Imgur config")),
};
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_seconds))
.build()?;
let form = reqwest::multipart::Form::new().part(
"image",
reqwest::multipart::Part::bytes(image_data.to_vec())
.file_name(
image_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("image.png")
.to_string(),
)
.mime_str("image/png")
.unwrap(),
);
let response = client
.post("https://api.imgur.com/3/image")
.header(header::AUTHORIZATION, format!("Client-ID {}", client_id))
.multipart(form)
.send()
.await
.context("Failed to send request to Imgur")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(anyhow!("Imgur API error: {}", error_text));
}
let response_json: serde_json::Value = response.json().await?;
if response_json["success"].as_bool() != Some(true) {
return Err(anyhow!("Imgur upload failed"));
}
let url = response_json["data"]["link"]
.as_str()
.ok_or_else(|| anyhow!("Missing link in response"))?
.to_string();
let delete_hash = response_json["data"]["deletehash"]
.as_str()
.map(|hash| format!("imgur://{}", hash));
Ok(UploadResult {
url,
delete_url: delete_hash,
image_host: "Imgur".to_string(),
uploaded_at: Utc::now().to_rfc3339(),
file_size: image_data.len() as u64,
})
}
/// 上传到自定义图床
async fn upload_to_custom(
&self,
image_path: &Path,
image_data: &[u8],
config: &ImageHostConfig,
) -> Result<UploadResult> {
let (url, headers, form_field) = match config {
ImageHostConfig::Custom {
url,
headers,
form_field,
} => (url, headers, form_field),
_ => return Err(anyhow!("Invalid custom config")),
};
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_seconds))
.build()?;
let field_name = form_field.as_deref().unwrap_or("file");
let mut form = reqwest::multipart::Form::new().part(
field_name.to_string(),
reqwest::multipart::Part::bytes(image_data.to_vec())
.file_name(
image_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("image.png")
.to_string(),
)
.mime_str("image/png")
.unwrap(),
);
let mut request = client.post(url).multipart(form);
// 添加自定义头部
if let Some(headers) = headers {
for header in headers {
request = request.header(&header.name, &header.value);
}
}
let response = request
.send()
.await
.context("Failed to send request to custom host")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(anyhow!("Custom host error: {}", error_text));
}
// 尝试解析 JSON 响应
let response_json: serde_json::Value = response.json().await?;
// 尝试从响应中提取 URL
let url = self.extract_url_from_response(&response_json)?;
Ok(UploadResult {
url,
delete_url: None,
image_host: "Custom".to_string(),
uploaded_at: Utc::now().to_rfc3339(),
file_size: image_data.len() as u64,
})
}
/// 从响应中提取 URL
fn extract_url_from_response(&self, response: &serde_json::Value) -> Result<String> {
// 尝试常见的响应字段
let url_fields = vec!["url", "link", "image_url", "data", "url"];
for field in url_fields {
if let Some(url_value) = response.get(field) {
// 如果是字符串,直接使用
if let Some(url_str) = url_value.as_str() {
return Ok(url_str.to_string());
}
// 如果是对象,尝试获取 url 字段
if let Some(obj) = url_value.as_object() {
if let Some(url_str) = obj.get("url").and_then(|v| v.as_str()) {
return Ok(url_str.to_string());
}
}
}
}
// 如果找不到,返回整个响应的字符串表示
Err(anyhow!(
"Could not extract URL from response: {}",
serde_json::to_string(response).unwrap_or_default()
))
}
/// 上传图片(带重试)
pub async fn upload_with_retry(
&self,
image_path: &Path,
config: &ImageHostConfig,
mut progress_callback: impl FnMut(UploadProgress),
) -> Result<UploadResult> {
progress_callback(UploadProgress::Starting);
let mut last_error = None;
for attempt in 0..=self.retry_count {
if attempt > 0 {
progress_callback(UploadProgress::Uploading {
progress: (attempt as f32 / (self.retry_count + 1) as f32) * 100.0,
message: format!("重试上传 {}/{}", attempt, self.retry_count),
});
// 等待一段时间再重试
sleep(Duration::from_secs(2)).await;
}
match self.upload_image(image_path, config).await {
Ok(result) => {
progress_callback(UploadProgress::Completed(result.clone()));
return Ok(result);
}
Err(e) => {
last_error = Some(e);
}
}
}
progress_callback(UploadProgress::Failed {
error: last_error
.as_ref()
.map(|e| e.to_string())
.unwrap_or_else(|| "Unknown error".to_string()),
});
Err(last_error.unwrap_or_else(|| anyhow!("Upload failed")))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_upload_result_serialization() {
let result = UploadResult {
url: "https://example.com/image.png".to_string(),
delete_url: Some("https://example.com/delete".to_string()),
image_host: "Test".to_string(),
uploaded_at: Utc::now().to_rfc3339(),
file_size: 1024,
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("url"));
}
}