diff --git a/crates/db/Cargo.toml b/crates/db/Cargo.toml index 0910aee5..8cd9983a 100644 --- a/crates/db/Cargo.toml +++ b/crates/db/Cargo.toml @@ -14,7 +14,7 @@ serde_json = { workspace = true } anyhow = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } -sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] } +sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "sqlite-preupdate-hook", "chrono", "uuid"] } chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1.0", features = ["v4", "serde"] } ts-rs = { workspace = true } diff --git a/crates/local-deployment/Cargo.toml b/crates/local-deployment/Cargo.toml index 9120df4a..beeee086 100644 --- a/crates/local-deployment/Cargo.toml +++ b/crates/local-deployment/Cargo.toml @@ -17,7 +17,7 @@ serde_json = { workspace = true } anyhow = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } -sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] } +sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "sqlite-preupdate-hook", "chrono", "uuid"] } chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1.0", features = ["v4", "serde"] } ts-rs = { workspace = true } diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index 1fb45109..66dbca28 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -22,7 +22,7 @@ serde_json = { workspace = true } anyhow = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } -sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] } +sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "sqlite-preupdate-hook", "chrono", "uuid"] } chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1.0", features = ["v4", "serde"] } ts-rs = { workspace = true } diff --git a/crates/services/Cargo.toml b/crates/services/Cargo.toml index 36818d04..9ee917de 100644 --- a/crates/services/Cargo.toml +++ b/crates/services/Cargo.toml @@ -19,7 +19,7 @@ serde_json = { workspace = true } anyhow = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } -sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] } +sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "sqlite-preupdate-hook", "chrono", "uuid"] } chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1.0", features = ["v4", "serde"] } ts-rs = { workspace = true } diff --git a/crates/services/src/services/events.rs b/crates/services/src/services/events.rs index 63344195..eab9b197 100644 --- a/crates/services/src/services/events.rs +++ b/crates/services/src/services/events.rs @@ -13,7 +13,7 @@ use futures::StreamExt; use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation}; use serde::{Deserialize, Serialize}; use serde_json::json; -use sqlx::{Error as SqlxError, SqlitePool, sqlite::SqliteOperation}; +use sqlx::{Error as SqlxError, SqlitePool, ValueRef, sqlite::SqliteOperation}; use strum_macros::{Display, EnumString}; use thiserror::Error; use tokio::sync::RwLock; @@ -32,98 +32,169 @@ pub enum EventError { Other(#[from] AnyhowError), // Catches any unclassified errors } +/// Trait for types that can be used in JSON patch operations +pub trait Patchable: serde::Serialize { + const PATH_PREFIX: &'static str; + type Id: ToString + Copy; + fn id(&self) -> Self::Id; +} + +/// Implementations of Patchable for all supported types +impl Patchable for TaskWithAttemptStatus { + const PATH_PREFIX: &'static str = "/tasks"; + type Id = Uuid; + fn id(&self) -> Self::Id { + self.id + } +} + +impl Patchable for ExecutionProcess { + const PATH_PREFIX: &'static str = "/execution_processes"; + type Id = Uuid; + fn id(&self) -> Self::Id { + self.id + } +} + +impl Patchable for TaskAttempt { + const PATH_PREFIX: &'static str = "/task_attempts"; + type Id = Uuid; + fn id(&self) -> Self::Id { + self.id + } +} + +impl Patchable for db::models::follow_up_draft::FollowUpDraft { + const PATH_PREFIX: &'static str = "/follow_up_drafts"; + type Id = Uuid; + fn id(&self) -> Self::Id { + self.id + } +} + +/// Generic patch operations that work with any Patchable type +pub mod patch_ops { + use super::*; + + /// Escape JSON Pointer special characters + pub(crate) fn escape_pointer_segment(s: &str) -> String { + s.replace('~', "~0").replace('/', "~1") + } + + /// Create path for operation + fn path_for(id: T::Id) -> String { + format!( + "{}/{}", + T::PATH_PREFIX, + escape_pointer_segment(&id.to_string()) + ) + } + + /// Create patch for adding a new record + pub fn add(value: &T) -> Patch { + Patch(vec![PatchOperation::Add(AddOperation { + path: path_for::(value.id()) + .try_into() + .expect("Path should be valid"), + value: serde_json::to_value(value).expect("Serialization should not fail"), + })]) + } + + /// Create patch for updating an existing record + pub fn replace(value: &T) -> Patch { + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: path_for::(value.id()) + .try_into() + .expect("Path should be valid"), + value: serde_json::to_value(value).expect("Serialization should not fail"), + })]) + } + + /// Create patch for removing a record + pub fn remove(id: T::Id) -> Patch { + Patch(vec![PatchOperation::Remove(RemoveOperation { + path: path_for::(id).try_into().expect("Path should be valid"), + })]) + } +} + /// Helper functions for creating task-specific patches pub mod task_patch { use super::*; - /// Escape JSON Pointer special characters - fn escape_pointer_segment(s: &str) -> String { - s.replace('~', "~0").replace('/', "~1") - } - - /// Create path for task operation - fn task_path(task_id: Uuid) -> String { - format!("/tasks/{}", escape_pointer_segment(&task_id.to_string())) - } - /// Create patch for adding a new task pub fn add(task: &TaskWithAttemptStatus) -> Patch { - Patch(vec![PatchOperation::Add(AddOperation { - path: task_path(task.id) - .try_into() - .expect("Task path should be valid"), - value: serde_json::to_value(task).expect("Task serialization should not fail"), - })]) + patch_ops::add(task) } /// Create patch for updating an existing task pub fn replace(task: &TaskWithAttemptStatus) -> Patch { - Patch(vec![PatchOperation::Replace(ReplaceOperation { - path: task_path(task.id) - .try_into() - .expect("Task path should be valid"), - value: serde_json::to_value(task).expect("Task serialization should not fail"), - })]) + patch_ops::replace(task) } /// Create patch for removing a task pub fn remove(task_id: Uuid) -> Patch { - Patch(vec![PatchOperation::Remove(RemoveOperation { - path: task_path(task_id) - .try_into() - .expect("Task path should be valid"), - })]) + patch_ops::remove::(task_id) } } /// Helper functions for creating execution process-specific patches pub mod execution_process_patch { - use db::models::execution_process::ExecutionProcess; - use super::*; - /// Escape JSON Pointer special characters - fn escape_pointer_segment(s: &str) -> String { - s.replace('~', "~0").replace('/', "~1") - } - - /// Create path for execution process operation - fn execution_process_path(process_id: Uuid) -> String { - format!( - "/execution_processes/{}", - escape_pointer_segment(&process_id.to_string()) - ) - } - /// Create patch for adding a new execution process pub fn add(process: &ExecutionProcess) -> Patch { - Patch(vec![PatchOperation::Add(AddOperation { - path: execution_process_path(process.id) - .try_into() - .expect("Execution process path should be valid"), - value: serde_json::to_value(process) - .expect("Execution process serialization should not fail"), - })]) + patch_ops::add(process) } /// Create patch for updating an existing execution process pub fn replace(process: &ExecutionProcess) -> Patch { - Patch(vec![PatchOperation::Replace(ReplaceOperation { - path: execution_process_path(process.id) - .try_into() - .expect("Execution process path should be valid"), - value: serde_json::to_value(process) - .expect("Execution process serialization should not fail"), - })]) + patch_ops::replace(process) } /// Create patch for removing an execution process pub fn remove(process_id: Uuid) -> Patch { - Patch(vec![PatchOperation::Remove(RemoveOperation { - path: execution_process_path(process_id) - .try_into() - .expect("Execution process path should be valid"), - })]) + patch_ops::remove::(process_id) + } +} + +/// Helper functions for creating task attempt-specific patches +pub mod task_attempt_patch { + use super::*; + + /// Create patch for adding a new task attempt + pub fn add(attempt: &TaskAttempt) -> Patch { + patch_ops::add(attempt) + } + + /// Create patch for updating an existing task attempt + pub fn replace(attempt: &TaskAttempt) -> Patch { + patch_ops::replace(attempt) + } + + /// Create patch for removing a task attempt + pub fn remove(attempt_id: Uuid) -> Patch { + patch_ops::remove::(attempt_id) + } +} + +/// Helper functions for creating follow up draft-specific patches +pub mod follow_up_draft_patch { + use super::*; + + /// Create patch for adding a new follow up draft + pub fn add(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch { + patch_ops::add(draft) + } + + /// Create patch for updating an existing follow up draft + pub fn replace(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch { + patch_ops::replace(draft) + } + + /// Create patch for removing a follow up draft + pub fn remove(draft_id: Uuid) -> Patch { + patch_ops::remove::(draft_id) } } @@ -154,24 +225,6 @@ pub enum RecordTypes { TaskAttempt(TaskAttempt), ExecutionProcess(ExecutionProcess), FollowUpDraft(db::models::follow_up_draft::FollowUpDraft), - DeletedTask { - rowid: i64, - project_id: Option, - task_id: Option, - }, - DeletedTaskAttempt { - rowid: i64, - task_id: Option, - }, - DeletedExecutionProcess { - rowid: i64, - task_attempt_id: Option, - process_id: Option, - }, - DeletedFollowUpDraft { - rowid: i64, - task_attempt_id: Option, - }, } #[derive(Serialize, Deserialize, TS)] @@ -248,6 +301,85 @@ impl EventService { Box::pin(async move { let mut handle = conn.lock_handle().await?; let runtime_handle = tokio::runtime::Handle::current(); + + // Set up preupdate hook to capture task data before deletion + handle.set_preupdate_hook({ + let msg_store_for_preupdate = msg_store_for_hook.clone(); + move |preupdate: sqlx::sqlite::PreupdateHookResult<'_>| { + if preupdate.operation == sqlx::sqlite::SqliteOperation::Delete { + match preupdate.table { + "tasks" => { + // Extract task ID from old column values before deletion + if let Ok(id_value) = preupdate.get_old_column_value(0) + && !id_value.is_null() + { + // Decode UUID from SQLite value + if let Ok(task_id) = + >::decode( + id_value, + ) + { + let patch = task_patch::remove(task_id); + msg_store_for_preupdate.push_patch(patch); + } + } + } + "execution_processes" => { + // Extract process ID from old column values before deletion + if let Ok(id_value) = preupdate.get_old_column_value(0) + && !id_value.is_null() + { + // Decode UUID from SQLite value + if let Ok(process_id) = + >::decode( + id_value, + ) + { + let patch = execution_process_patch::remove(process_id); + msg_store_for_preupdate.push_patch(patch); + } + } + } + "task_attempts" => { + // Extract attempt ID from old column values before deletion + if let Ok(id_value) = preupdate.get_old_column_value(0) + && !id_value.is_null() + { + // Decode UUID from SQLite value + if let Ok(attempt_id) = + >::decode( + id_value, + ) + { + let patch = task_attempt_patch::remove(attempt_id); + msg_store_for_preupdate.push_patch(patch); + } + } + } + "follow_up_drafts" => { + // Extract draft ID from old column values before deletion + if let Ok(id_value) = preupdate.get_old_column_value(0) + && !id_value.is_null() + { + // Decode UUID from SQLite value + if let Ok(draft_id) = + >::decode( + id_value, + ) + { + let patch = follow_up_draft_patch::remove(draft_id); + msg_store_for_preupdate.push_patch(patch); + } + } + } + _ => { + // Ignore other tables + } + } + } + } + }); + handle.set_update_hook(move |hook: sqlx::sqlite::UpdateHookResult<'_>| { let runtime_handle = runtime_handle.clone(); let entry_count_for_hook = entry_count_for_hook.clone(); @@ -259,49 +391,27 @@ impl EventService { runtime_handle.spawn(async move { let record_type: RecordTypes = match (table, hook.operation.clone()) { (HookTables::Tasks, SqliteOperation::Delete) => { - // Try to get task before deletion to capture project_id and task_id - let task_info = - Task::find_by_rowid(&db.pool, rowid).await.ok().flatten(); - RecordTypes::DeletedTask { - rowid, - project_id: task_info.as_ref().map(|t| t.project_id), - task_id: task_info.as_ref().map(|t| t.id), - } - } - (HookTables::TaskAttempts, SqliteOperation::Delete) => { - // Try to get task_attempt before deletion to capture task_id - let task_id = TaskAttempt::find_by_rowid(&db.pool, rowid) - .await - .ok() - .flatten() - .map(|attempt| attempt.task_id); - RecordTypes::DeletedTaskAttempt { rowid, task_id } + // Task deletion is now handled by preupdate hook + // Skip post-update processing to avoid duplicate patches + return; } (HookTables::ExecutionProcesses, SqliteOperation::Delete) => { - // Try to get execution_process before deletion to capture full process data - if let Ok(Some(process)) = - ExecutionProcess::find_by_rowid(&db.pool, rowid).await - { - RecordTypes::DeletedExecutionProcess { - rowid, - task_attempt_id: Some(process.task_attempt_id), - process_id: Some(process.id), - } - } else { - RecordTypes::DeletedExecutionProcess { - rowid, - task_attempt_id: None, - process_id: None, - } - } + // Execution process deletion is now handled by preupdate hook + // Skip post-update processing to avoid duplicate patches + return; + } + (HookTables::TaskAttempts, SqliteOperation::Delete) => { + // Task attempt deletion is now handled by preupdate hook + // Skip post-update processing to avoid duplicate patches + return; } (HookTables::Tasks, _) => { match Task::find_by_rowid(&db.pool, rowid).await { Ok(Some(task)) => RecordTypes::Task(task), - Ok(None) => RecordTypes::DeletedTask { - rowid, - project_id: None, - task_id: None, + Ok(None) => { + // Row not found - likely already deleted, skip processing + tracing::debug!("Task rowid {} not found, skipping", rowid); + return; }, Err(e) => { tracing::error!("Failed to fetch task: {:?}", e); @@ -312,9 +422,10 @@ impl EventService { (HookTables::TaskAttempts, _) => { match TaskAttempt::find_by_rowid(&db.pool, rowid).await { Ok(Some(attempt)) => RecordTypes::TaskAttempt(attempt), - Ok(None) => RecordTypes::DeletedTaskAttempt { - rowid, - task_id: None, + Ok(None) => { + // Row not found - likely already deleted, skip processing + tracing::debug!("TaskAttempt rowid {} not found, skipping", rowid); + return; }, Err(e) => { tracing::error!( @@ -328,10 +439,10 @@ impl EventService { (HookTables::ExecutionProcesses, _) => { match ExecutionProcess::find_by_rowid(&db.pool, rowid).await { Ok(Some(process)) => RecordTypes::ExecutionProcess(process), - Ok(None) => RecordTypes::DeletedExecutionProcess { - rowid, - task_attempt_id: None, - process_id: None, + Ok(None) => { + // Row not found - likely already deleted, skip processing + tracing::debug!("ExecutionProcess rowid {} not found, skipping", rowid); + return; }, Err(e) => { tracing::error!( @@ -343,19 +454,9 @@ impl EventService { } } (HookTables::FollowUpDrafts, SqliteOperation::Delete) => { - // Try to get draft before deletion to capture attempt id - let attempt_id = - db::models::follow_up_draft::FollowUpDraft::find_by_rowid( - &db.pool, rowid, - ) - .await - .ok() - .flatten() - .map(|d| d.task_attempt_id); - RecordTypes::DeletedFollowUpDraft { - rowid, - task_attempt_id: attempt_id, - } + // Follow up draft deletion is now handled by preupdate hook + // Skip post-update processing to avoid duplicate patches + return; } (HookTables::FollowUpDrafts, _) => { match db::models::follow_up_draft::FollowUpDraft::find_by_rowid( @@ -364,9 +465,10 @@ impl EventService { .await { Ok(Some(draft)) => RecordTypes::FollowUpDraft(draft), - Ok(None) => RecordTypes::DeletedFollowUpDraft { - rowid, - task_attempt_id: None, + Ok(None) => { + // Row not found - likely already deleted, skip processing + tracing::debug!("FollowUpDraft rowid {} not found, skipping", rowid); + return; }, Err(e) => { tracing::error!( @@ -412,14 +514,6 @@ impl EventService { return; } } - RecordTypes::DeletedTask { - task_id: Some(task_id), - .. - } => { - let patch = task_patch::remove(*task_id); - msg_store_for_hook.push_patch(patch); - return; - } RecordTypes::TaskAttempt(attempt) => { // Task attempts should update the parent task with fresh data if let Ok(Some(task)) = @@ -438,27 +532,6 @@ impl EventService { return; } } - RecordTypes::DeletedTaskAttempt { - task_id: Some(task_id), - .. - } => { - // Task attempt deletion should update the parent task with fresh data - if let Ok(Some(task)) = - Task::find_by_id(&db.pool, *task_id).await - && let Ok(task_list) = - Task::find_by_project_id_with_attempt_status( - &db.pool, - task.project_id, - ) - .await - && let Some(task_with_status) = - task_list.into_iter().find(|t| t.id == *task_id) - { - let patch = task_patch::replace(&task_with_status); - msg_store_for_hook.push_patch(patch); - return; - } - } RecordTypes::ExecutionProcess(process) => { let patch = match hook.operation { SqliteOperation::Insert => { @@ -486,31 +559,6 @@ impl EventService { return; } - RecordTypes::DeletedExecutionProcess { - process_id: Some(process_id), - task_attempt_id, - .. - } => { - let patch = execution_process_patch::remove(*process_id); - msg_store_for_hook.push_patch(patch); - - if let Some(task_attempt_id) = task_attempt_id - && let Err(err) = - EventService::push_task_update_for_attempt( - &db.pool, - msg_store_for_hook.clone(), - *task_attempt_id, - ) - .await - { - tracing::error!( - "Failed to push task update after execution process removal: {:?}", - err - ); - } - - return; - } _ => {} } @@ -629,14 +677,6 @@ impl EventService { return Some(Ok(LogMsg::JsonPatch(patch))); } } - RecordTypes::DeletedTask { - project_id: Some(deleted_project_id), - .. - } => { - if *deleted_project_id == project_id { - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } RecordTypes::TaskAttempt(attempt) => { // Check if this task_attempt belongs to a task in our project if let Ok(Some(task)) = @@ -646,18 +686,6 @@ impl EventService { return Some(Ok(LogMsg::JsonPatch(patch))); } } - RecordTypes::DeletedTaskAttempt { - task_id: Some(deleted_task_id), - .. - } => { - // Check if deleted attempt belonged to a task in our project - if let Ok(Some(task)) = - Task::find_by_id(&db_pool, *deleted_task_id).await - && task.project_id == project_id - { - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } _ => {} } } @@ -762,28 +790,15 @@ impl EventService { else if let Ok(event_patch_value) = serde_json::to_value(patch_op) && let Ok(event_patch) = serde_json::from_value::(event_patch_value) + && let RecordTypes::ExecutionProcess(process) = + &event_patch.value.record + && process.task_attempt_id == task_attempt_id { - match &event_patch.value.record { - RecordTypes::ExecutionProcess(process) => { - if process.task_attempt_id == task_attempt_id { - if !show_soft_deleted && process.dropped { - let remove_patch = - execution_process_patch::remove(process.id); - return Some(Ok(LogMsg::JsonPatch(remove_patch))); - } - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - RecordTypes::DeletedExecutionProcess { - task_attempt_id: Some(deleted_attempt_id), - .. - } => { - if *deleted_attempt_id == task_attempt_id { - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - _ => {} + if !show_soft_deleted && process.dropped { + let remove_patch = execution_process_patch::remove(process.id); + return Some(Ok(LogMsg::JsonPatch(remove_patch))); } + return Some(Ok(LogMsg::JsonPatch(patch))); } } None @@ -844,53 +859,19 @@ impl EventService { && let Ok(event_patch_value) = serde_json::to_value(event_patch_op) && let Ok(event_patch) = serde_json::from_value::(event_patch_value) + && let RecordTypes::FollowUpDraft(draft) = &event_patch.value.record + && draft.task_attempt_id == task_attempt_id { - match &event_patch.value.record { - RecordTypes::FollowUpDraft(draft) => { - if draft.task_attempt_id == task_attempt_id { - // Build a direct patch to replace /follow_up_draft - let direct = json!([ - { - "op": "replace", - "path": "/follow_up_draft", - "value": draft - } - ]); - let direct_patch = serde_json::from_value(direct).unwrap(); - return Some(Ok(LogMsg::JsonPatch(direct_patch))); - } + // Build a direct patch to replace /follow_up_draft + let direct = json!([ + { + "op": "replace", + "path": "/follow_up_draft", + "value": draft } - RecordTypes::DeletedFollowUpDraft { - task_attempt_id: Some(id), - .. - } => { - if *id == task_attempt_id { - // Replace with empty draft state - let empty = json!({ - "id": uuid::Uuid::new_v4(), - "task_attempt_id": id, - "prompt": "", - "queued": false, - "sending": false, - "variant": null, - "image_ids": null, - "created_at": chrono::Utc::now(), - "updated_at": chrono::Utc::now(), - "version": 0 - }); - let direct = json!([ - { - "op": "replace", - "path": "/follow_up_draft", - "value": empty - } - ]); - let direct_patch = serde_json::from_value(direct).unwrap(); - return Some(Ok(LogMsg::JsonPatch(direct_patch))); - } - } - _ => {} - } + ]); + let direct_patch = serde_json::from_value(direct).unwrap(); + return Some(Ok(LogMsg::JsonPatch(direct_patch))); } None }