Consolidate Retry and Follow-up (#800)

This commit is contained in:
Solomon
2025-09-30 13:09:50 +01:00
committed by GitHub
parent 71bfe9ac0b
commit f9878e9183
55 changed files with 3644 additions and 2294 deletions

View File

@@ -148,6 +148,19 @@ pub trait ContainerService {
Ok(())
}
fn cleanup_action(&self, cleanup_script: Option<String>) -> Option<Box<ExecutorAction>> {
cleanup_script.map(|script| {
Box::new(ExecutorAction::new(
ExecutorActionType::ScriptRequest(ScriptRequest {
script,
language: ScriptRequestLanguage::Bash,
context: ScriptContext::CleanupScript,
}),
None,
))
})
}
async fn try_stop(&self, task_attempt: &TaskAttempt) {
// stop all execution processes for this attempt
if let Ok(processes) =
@@ -499,16 +512,7 @@ pub trait ContainerService {
);
let prompt = ImageService::canonicalise_image_paths(&task.to_prompt(), &worktree_path);
let cleanup_action = project.cleanup_script.map(|script| {
Box::new(ExecutorAction::new(
ExecutorActionType::ScriptRequest(ScriptRequest {
script,
language: ScriptRequestLanguage::Bash,
context: ScriptContext::CleanupScript,
}),
None,
))
});
let cleanup_action = self.cleanup_action(project.cleanup_script);
// Choose whether to execute the setup_script or coding agent first
let execution_process = if let Some(setup_script) = project.setup_script {

View File

@@ -0,0 +1,474 @@
use std::path::{Path, PathBuf};
use db::{
DBService,
models::{
draft::{Draft, DraftType, UpsertDraft},
execution_process::{ExecutionProcess, ExecutionProcessError, ExecutionProcessRunReason},
image::TaskImage,
task_attempt::TaskAttempt,
},
};
use executors::{
actions::{
ExecutorAction, ExecutorActionType, coding_agent_follow_up::CodingAgentFollowUpRequest,
},
profile::ExecutorProfileId,
};
use serde::{Deserialize, Serialize};
use sqlx::Error as SqlxError;
use thiserror::Error;
use ts_rs::TS;
use uuid::Uuid;
use super::{
container::{ContainerError, ContainerService},
image::{ImageError, ImageService},
};
#[derive(Debug, Error)]
pub enum DraftsServiceError {
#[error(transparent)]
Database(#[from] sqlx::Error),
#[error(transparent)]
Container(#[from] ContainerError),
#[error(transparent)]
Image(#[from] ImageError),
#[error(transparent)]
ExecutionProcess(#[from] ExecutionProcessError),
#[error("Conflict: {0}")]
Conflict(String),
}
#[derive(Debug, Serialize, TS)]
pub struct DraftResponse {
pub task_attempt_id: Uuid,
pub draft_type: DraftType,
pub retry_process_id: Option<Uuid>,
pub prompt: String,
pub queued: bool,
pub variant: Option<String>,
pub image_ids: Option<Vec<Uuid>>,
pub version: i64,
}
#[derive(Debug, Deserialize, TS)]
pub struct UpdateFollowUpDraftRequest {
pub prompt: Option<String>,
pub variant: Option<Option<String>>,
pub image_ids: Option<Vec<Uuid>>,
pub version: Option<i64>,
}
#[derive(Debug, Deserialize, TS)]
pub struct UpdateRetryFollowUpDraftRequest {
pub retry_process_id: Uuid,
pub prompt: Option<String>,
pub variant: Option<Option<String>>,
pub image_ids: Option<Vec<Uuid>>,
pub version: Option<i64>,
}
#[derive(Debug, Deserialize, TS)]
pub struct SetQueueRequest {
pub queued: bool,
pub expected_queued: Option<bool>,
pub expected_version: Option<i64>,
}
#[derive(Clone)]
pub struct DraftsService {
db: DBService,
image: ImageService,
}
impl DraftsService {
pub fn new(db: DBService, image: ImageService) -> Self {
Self { db, image }
}
fn pool(&self) -> &sqlx::SqlitePool {
&self.db.pool
}
fn draft_to_response(d: Draft) -> DraftResponse {
DraftResponse {
task_attempt_id: d.task_attempt_id,
draft_type: d.draft_type,
retry_process_id: d.retry_process_id,
prompt: d.prompt,
queued: d.queued,
variant: d.variant,
image_ids: d.image_ids,
version: d.version,
}
}
async fn ensure_follow_up_draft_row(
&self,
attempt_id: Uuid,
) -> Result<Draft, DraftsServiceError> {
if let Some(d) =
Draft::find_by_task_attempt_and_type(self.pool(), attempt_id, DraftType::FollowUp)
.await?
{
return Ok(d);
}
let _ = Draft::upsert(
self.pool(),
&UpsertDraft {
task_attempt_id: attempt_id,
draft_type: DraftType::FollowUp,
retry_process_id: None,
prompt: "".to_string(),
queued: false,
variant: None,
image_ids: None,
},
)
.await?;
Draft::find_by_task_attempt_and_type(self.pool(), attempt_id, DraftType::FollowUp)
.await?
.ok_or(SqlxError::RowNotFound)
.map_err(DraftsServiceError::from)
}
async fn associate_images_for_task_if_any(
&self,
task_id: Uuid,
image_ids: &Option<Vec<Uuid>>,
) -> Result<(), DraftsServiceError> {
if let Some(ids) = image_ids
&& !ids.is_empty()
{
TaskImage::associate_many_dedup(self.pool(), task_id, ids).await?;
}
Ok(())
}
async fn has_running_processes_for_attempt(
&self,
attempt_id: Uuid,
) -> Result<bool, DraftsServiceError> {
let processes =
ExecutionProcess::find_by_task_attempt_id(self.pool(), attempt_id, false).await?;
Ok(processes.into_iter().any(|p| {
matches!(
p.status,
db::models::execution_process::ExecutionProcessStatus::Running
)
}))
}
async fn fetch_draft_response(
&self,
task_attempt_id: Uuid,
draft_type: DraftType,
) -> Result<DraftResponse, DraftsServiceError> {
let d =
Draft::find_by_task_attempt_and_type(self.pool(), task_attempt_id, draft_type).await?;
let resp = if let Some(d) = d {
Self::draft_to_response(d)
} else {
DraftResponse {
task_attempt_id,
draft_type,
retry_process_id: None,
prompt: "".to_string(),
queued: false,
variant: None,
image_ids: None,
version: 0,
}
};
Ok(resp)
}
async fn handle_images_for_prompt(
&self,
task_id: Uuid,
image_ids: &[Uuid],
prompt: &str,
worktree_path: &Path,
) -> Result<String, DraftsServiceError> {
if image_ids.is_empty() {
return Ok(prompt.to_string());
}
TaskImage::associate_many_dedup(self.pool(), task_id, image_ids).await?;
self.image
.copy_images_by_ids_to_worktree(worktree_path, image_ids)
.await?;
Ok(ImageService::canonicalise_image_paths(
prompt,
worktree_path,
))
}
async fn start_follow_up_from_draft(
&self,
container: &(dyn ContainerService + Send + Sync),
task_attempt: &TaskAttempt,
draft: &Draft,
) -> Result<ExecutionProcess, DraftsServiceError> {
let worktree_ref = container.ensure_container_exists(task_attempt).await?;
let worktree_path = PathBuf::from(worktree_ref);
let session_id =
ExecutionProcess::require_latest_session_id(self.pool(), task_attempt.id).await?;
let base_profile =
ExecutionProcess::latest_executor_profile_for_attempt(self.pool(), task_attempt.id)
.await?;
let executor_profile_id = ExecutorProfileId {
executor: base_profile.executor,
variant: draft.variant.clone(),
};
let task = task_attempt
.parent_task(self.pool())
.await?
.ok_or(SqlxError::RowNotFound)
.map_err(DraftsServiceError::from)?;
let project = task
.parent_project(self.pool())
.await?
.ok_or(SqlxError::RowNotFound)
.map_err(DraftsServiceError::from)?;
let cleanup_action = container.cleanup_action(project.cleanup_script);
let mut prompt = draft.prompt.clone();
if let Some(image_ids) = &draft.image_ids {
prompt = self
.handle_images_for_prompt(task_attempt.task_id, image_ids, &prompt, &worktree_path)
.await?;
}
let follow_up_request = CodingAgentFollowUpRequest {
prompt,
session_id,
executor_profile_id,
};
let follow_up_action = ExecutorAction::new(
ExecutorActionType::CodingAgentFollowUpRequest(follow_up_request),
cleanup_action,
);
let execution_process = container
.start_execution(
task_attempt,
&follow_up_action,
&ExecutionProcessRunReason::CodingAgent,
)
.await?;
let _ = Draft::clear_after_send(self.pool(), task_attempt.id, DraftType::FollowUp).await;
Ok(execution_process)
}
pub async fn save_follow_up_draft(
&self,
task_attempt: &TaskAttempt,
payload: &UpdateFollowUpDraftRequest,
) -> Result<DraftResponse, DraftsServiceError> {
let pool = self.pool();
let d = self.ensure_follow_up_draft_row(task_attempt.id).await?;
if d.queued {
return Err(DraftsServiceError::Conflict(
"Draft is queued; click Edit to unqueue before editing".to_string(),
));
}
if let Some(expected_version) = payload.version
&& d.version != expected_version
{
return Err(DraftsServiceError::Conflict(
"Draft changed, please retry with latest".to_string(),
));
}
if payload.prompt.is_none() && payload.variant.is_none() && payload.image_ids.is_none() {
} else {
Draft::update_partial(
pool,
task_attempt.id,
DraftType::FollowUp,
payload.prompt.clone(),
payload.variant.clone(),
payload.image_ids.clone(),
None,
)
.await?;
}
if let Some(task) = task_attempt.parent_task(pool).await? {
self.associate_images_for_task_if_any(task.id, &payload.image_ids)
.await?;
}
let current =
Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp)
.await?
.map(Self::draft_to_response)
.unwrap_or(DraftResponse {
task_attempt_id: task_attempt.id,
draft_type: DraftType::FollowUp,
retry_process_id: None,
prompt: "".to_string(),
queued: false,
variant: None,
image_ids: None,
version: 0,
});
Ok(current)
}
pub async fn save_retry_follow_up_draft(
&self,
task_attempt: &TaskAttempt,
payload: &UpdateRetryFollowUpDraftRequest,
) -> Result<DraftResponse, DraftsServiceError> {
let pool = self.pool();
let existing =
Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::Retry).await?;
if let Some(d) = &existing {
if d.queued {
return Err(DraftsServiceError::Conflict(
"Retry draft is queued; unqueue before editing".to_string(),
));
}
if let Some(expected_version) = payload.version
&& d.version != expected_version
{
return Err(DraftsServiceError::Conflict(
"Retry draft changed, please retry with latest".to_string(),
));
}
}
if existing.is_none() {
let draft = Draft::upsert(
pool,
&UpsertDraft {
task_attempt_id: task_attempt.id,
draft_type: DraftType::Retry,
retry_process_id: Some(payload.retry_process_id),
prompt: payload.prompt.clone().unwrap_or_default(),
queued: false,
variant: payload.variant.clone().unwrap_or(None),
image_ids: payload.image_ids.clone(),
},
)
.await?;
return Ok(Self::draft_to_response(draft));
}
if payload.prompt.is_none() && payload.variant.is_none() && payload.image_ids.is_none() {
} else {
Draft::update_partial(
pool,
task_attempt.id,
DraftType::Retry,
payload.prompt.clone(),
payload.variant.clone(),
payload.image_ids.clone(),
Some(payload.retry_process_id),
)
.await?;
}
if let Some(task) = task_attempt.parent_task(pool).await? {
self.associate_images_for_task_if_any(task.id, &payload.image_ids)
.await?;
}
let draft = Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::Retry)
.await?
.ok_or(SqlxError::RowNotFound)
.map_err(DraftsServiceError::from)?;
Ok(Self::draft_to_response(draft))
}
pub async fn delete_retry_follow_up_draft(
&self,
task_attempt: &TaskAttempt,
) -> Result<(), DraftsServiceError> {
Draft::delete_by_task_attempt_and_type(self.pool(), task_attempt.id, DraftType::Retry)
.await?;
Ok(())
}
pub async fn set_follow_up_queue(
&self,
container: &(dyn ContainerService + Send + Sync),
task_attempt: &TaskAttempt,
payload: &SetQueueRequest,
) -> Result<DraftResponse, DraftsServiceError> {
let pool = self.pool();
let rows_updated = Draft::set_queued(
pool,
task_attempt.id,
DraftType::FollowUp,
payload.queued,
payload.expected_queued,
payload.expected_version,
)
.await?;
let draft =
Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp)
.await?;
if rows_updated == 0 {
if draft.is_none() {
return Err(DraftsServiceError::Conflict(
"No draft to queue".to_string(),
));
};
return Err(DraftsServiceError::Conflict(
"Draft changed, please refresh and try again".to_string(),
));
}
let should_consider_start = draft.as_ref().map(|c| c.queued).unwrap_or(false)
&& !self
.has_running_processes_for_attempt(task_attempt.id)
.await?;
if should_consider_start
&& Draft::try_mark_sending(pool, task_attempt.id, DraftType::FollowUp)
.await
.unwrap_or(false)
{
let _ = self
.start_follow_up_from_draft(container, task_attempt, draft.as_ref().unwrap())
.await;
}
let draft =
Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp)
.await?
.ok_or(SqlxError::RowNotFound)
.map_err(DraftsServiceError::from)?;
Ok(Self::draft_to_response(draft))
}
pub async fn get_draft(
&self,
task_attempt_id: Uuid,
draft_type: DraftType,
) -> Result<DraftResponse, DraftsServiceError> {
self.fetch_draft_response(task_attempt_id, draft_type).await
}
}

View File

@@ -1,202 +1,29 @@
use std::{str::FromStr, sync::Arc};
use anyhow::Error as AnyhowError;
use db::{
DBService,
models::{
draft::{Draft, DraftType},
execution_process::ExecutionProcess,
task::{Task, TaskWithAttemptStatus},
task::Task,
task_attempt::TaskAttempt,
},
};
use futures::StreamExt;
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::{Error as SqlxError, SqlitePool, ValueRef, sqlite::SqliteOperation};
use strum_macros::{Display, EnumString};
use thiserror::Error;
use sqlx::{Error as SqlxError, Sqlite, SqlitePool, decode::Decode, sqlite::SqliteOperation};
use tokio::sync::RwLock;
use tokio_stream::wrappers::BroadcastStream;
use ts_rs::TS;
use utils::{log_msg::LogMsg, msg_store::MsgStore};
use utils::msg_store::MsgStore;
use uuid::Uuid;
#[derive(Debug, Error)]
pub enum EventError {
#[error(transparent)]
Sqlx(#[from] SqlxError),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Other(#[from] AnyhowError), // Catches any unclassified errors
}
#[path = "events/patches.rs"]
pub mod patches;
#[path = "events/streams.rs"]
mod streams;
#[path = "events/types.rs"]
pub mod types;
/// Trait for types that can be used in JSON patch operations
pub trait Patchable: serde::Serialize {
const PATH_PREFIX: &'static str;
type Id: ToString + Copy;
fn id(&self) -> Self::Id;
}
/// Implementations of Patchable for all supported types
impl Patchable for TaskWithAttemptStatus {
const PATH_PREFIX: &'static str = "/tasks";
type Id = Uuid;
fn id(&self) -> Self::Id {
self.id
}
}
impl Patchable for ExecutionProcess {
const PATH_PREFIX: &'static str = "/execution_processes";
type Id = Uuid;
fn id(&self) -> Self::Id {
self.id
}
}
impl Patchable for TaskAttempt {
const PATH_PREFIX: &'static str = "/task_attempts";
type Id = Uuid;
fn id(&self) -> Self::Id {
self.id
}
}
impl Patchable for db::models::follow_up_draft::FollowUpDraft {
const PATH_PREFIX: &'static str = "/follow_up_drafts";
type Id = Uuid;
fn id(&self) -> Self::Id {
self.id
}
}
/// Generic patch operations that work with any Patchable type
pub mod patch_ops {
use super::*;
/// Escape JSON Pointer special characters
pub(crate) fn escape_pointer_segment(s: &str) -> String {
s.replace('~', "~0").replace('/', "~1")
}
/// Create path for operation
fn path_for<T: Patchable>(id: T::Id) -> String {
format!(
"{}/{}",
T::PATH_PREFIX,
escape_pointer_segment(&id.to_string())
)
}
/// Create patch for adding a new record
pub fn add<T: Patchable>(value: &T) -> Patch {
Patch(vec![PatchOperation::Add(AddOperation {
path: path_for::<T>(value.id())
.try_into()
.expect("Path should be valid"),
value: serde_json::to_value(value).expect("Serialization should not fail"),
})])
}
/// Create patch for updating an existing record
pub fn replace<T: Patchable>(value: &T) -> Patch {
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: path_for::<T>(value.id())
.try_into()
.expect("Path should be valid"),
value: serde_json::to_value(value).expect("Serialization should not fail"),
})])
}
/// Create patch for removing a record
pub fn remove<T: Patchable>(id: T::Id) -> Patch {
Patch(vec![PatchOperation::Remove(RemoveOperation {
path: path_for::<T>(id).try_into().expect("Path should be valid"),
})])
}
}
/// Helper functions for creating task-specific patches
pub mod task_patch {
use super::*;
/// Create patch for adding a new task
pub fn add(task: &TaskWithAttemptStatus) -> Patch {
patch_ops::add(task)
}
/// Create patch for updating an existing task
pub fn replace(task: &TaskWithAttemptStatus) -> Patch {
patch_ops::replace(task)
}
/// Create patch for removing a task
pub fn remove(task_id: Uuid) -> Patch {
patch_ops::remove::<TaskWithAttemptStatus>(task_id)
}
}
/// Helper functions for creating execution process-specific patches
pub mod execution_process_patch {
use super::*;
/// Create patch for adding a new execution process
pub fn add(process: &ExecutionProcess) -> Patch {
patch_ops::add(process)
}
/// Create patch for updating an existing execution process
pub fn replace(process: &ExecutionProcess) -> Patch {
patch_ops::replace(process)
}
/// Create patch for removing an execution process
pub fn remove(process_id: Uuid) -> Patch {
patch_ops::remove::<ExecutionProcess>(process_id)
}
}
/// Helper functions for creating task attempt-specific patches
pub mod task_attempt_patch {
use super::*;
/// Create patch for adding a new task attempt
pub fn add(attempt: &TaskAttempt) -> Patch {
patch_ops::add(attempt)
}
/// Create patch for updating an existing task attempt
pub fn replace(attempt: &TaskAttempt) -> Patch {
patch_ops::replace(attempt)
}
/// Create patch for removing a task attempt
pub fn remove(attempt_id: Uuid) -> Patch {
patch_ops::remove::<TaskAttempt>(attempt_id)
}
}
/// Helper functions for creating follow up draft-specific patches
pub mod follow_up_draft_patch {
use super::*;
/// Create patch for adding a new follow up draft
pub fn add(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch {
patch_ops::add(draft)
}
/// Create patch for updating an existing follow up draft
pub fn replace(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch {
patch_ops::replace(draft)
}
/// Create patch for removing a follow up draft
pub fn remove(draft_id: Uuid) -> Patch {
patch_ops::remove::<db::models::follow_up_draft::FollowUpDraft>(draft_id)
}
}
pub use patches::{draft_patch, execution_process_patch, task_attempt_patch, task_patch};
pub use types::{EventError, EventPatch, EventPatchInner, HookTables, RecordTypes};
#[derive(Clone)]
pub struct EventService {
@@ -206,40 +33,6 @@ pub struct EventService {
entry_count: Arc<RwLock<usize>>,
}
#[derive(EnumString, Display)]
enum HookTables {
#[strum(to_string = "tasks")]
Tasks,
#[strum(to_string = "task_attempts")]
TaskAttempts,
#[strum(to_string = "execution_processes")]
ExecutionProcesses,
#[strum(to_string = "follow_up_drafts")]
FollowUpDrafts,
}
#[derive(Serialize, Deserialize, TS)]
#[serde(tag = "type", content = "data", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum RecordTypes {
Task(Task),
TaskAttempt(TaskAttempt),
ExecutionProcess(ExecutionProcess),
FollowUpDraft(db::models::follow_up_draft::FollowUpDraft),
}
#[derive(Serialize, Deserialize, TS)]
pub struct EventPatchInner {
db_op: String,
record: RecordTypes,
}
#[derive(Serialize, Deserialize, TS)]
pub struct EventPatch {
op: String,
path: String,
value: EventPatchInner,
}
impl EventService {
/// Creates a new EventService that will work with a DBService configured with hooks
pub fn new(db: DBService, msg_store: Arc<MsgStore>, entry_count: Arc<RwLock<usize>>) -> Self {
@@ -297,85 +90,67 @@ impl EventService {
let msg_store_for_hook = msg_store.clone();
let entry_count_for_hook = entry_count.clone();
let db_for_hook = db_service.clone();
Box::pin(async move {
let mut handle = conn.lock_handle().await?;
let runtime_handle = tokio::runtime::Handle::current();
// Set up preupdate hook to capture task data before deletion
handle.set_preupdate_hook({
let msg_store_for_preupdate = msg_store_for_hook.clone();
move |preupdate: sqlx::sqlite::PreupdateHookResult<'_>| {
if preupdate.operation == sqlx::sqlite::SqliteOperation::Delete {
match preupdate.table {
"tasks" => {
// Extract task ID from old column values before deletion
if let Ok(id_value) = preupdate.get_old_column_value(0)
&& !id_value.is_null()
{
// Decode UUID from SQLite value
if let Ok(task_id) =
<uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
id_value,
)
{
let patch = task_patch::remove(task_id);
msg_store_for_preupdate.push_patch(patch);
}
}
}
"execution_processes" => {
// Extract process ID from old column values before deletion
if let Ok(id_value) = preupdate.get_old_column_value(0)
&& !id_value.is_null()
{
// Decode UUID from SQLite value
if let Ok(process_id) =
<uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
id_value,
)
{
let patch = execution_process_patch::remove(process_id);
msg_store_for_preupdate.push_patch(patch);
}
}
}
"task_attempts" => {
// Extract attempt ID from old column values before deletion
if let Ok(id_value) = preupdate.get_old_column_value(0)
&& !id_value.is_null()
{
// Decode UUID from SQLite value
if let Ok(attempt_id) =
<uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
id_value,
)
{
let patch = task_attempt_patch::remove(attempt_id);
msg_store_for_preupdate.push_patch(patch);
}
}
}
"follow_up_drafts" => {
// Extract draft ID from old column values before deletion
if let Ok(id_value) = preupdate.get_old_column_value(0)
&& !id_value.is_null()
{
// Decode UUID from SQLite value
if let Ok(draft_id) =
<uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
id_value,
)
{
let patch = follow_up_draft_patch::remove(draft_id);
msg_store_for_preupdate.push_patch(patch);
}
}
}
_ => {
// Ignore other tables
if preupdate.operation != SqliteOperation::Delete {
return;
}
match preupdate.table {
"tasks" => {
if let Ok(value) = preupdate.get_old_column_value(0)
&& let Ok(task_id) = <Uuid as Decode<Sqlite>>::decode(value)
{
let patch = task_patch::remove(task_id);
msg_store_for_preupdate.push_patch(patch);
}
}
"task_attempts" => {
if let Ok(value) = preupdate.get_old_column_value(0)
&& let Ok(attempt_id) = <Uuid as Decode<Sqlite>>::decode(value)
{
let patch = task_attempt_patch::remove(attempt_id);
msg_store_for_preupdate.push_patch(patch);
}
}
"execution_processes" => {
if let Ok(value) = preupdate.get_old_column_value(0)
&& let Ok(process_id) = <Uuid as Decode<Sqlite>>::decode(value)
{
let patch = execution_process_patch::remove(process_id);
msg_store_for_preupdate.push_patch(patch);
}
}
"drafts" => {
let draft_type = preupdate
.get_old_column_value(2)
.ok()
.and_then(|val| <String as Decode<Sqlite>>::decode(val).ok())
.and_then(|s| DraftType::from_str(&s).ok());
let task_attempt_id = preupdate
.get_old_column_value(1)
.ok()
.and_then(|val| <Uuid as Decode<Sqlite>>::decode(val).ok());
if let (Some(draft_type), Some(task_attempt_id)) =
(draft_type, task_attempt_id)
{
let patch = match draft_type {
DraftType::FollowUp => {
draft_patch::follow_up_clear(task_attempt_id)
}
DraftType::Retry => {
draft_patch::retry_clear(task_attempt_id)
}
};
msg_store_for_preupdate.push_patch(patch);
}
}
_ => {}
}
}
});
@@ -390,28 +165,20 @@ impl EventService {
let rowid = hook.rowid;
runtime_handle.spawn(async move {
let record_type: RecordTypes = match (table, hook.operation.clone()) {
(HookTables::Tasks, SqliteOperation::Delete) => {
// Task deletion is now handled by preupdate hook
// Skip post-update processing to avoid duplicate patches
return;
}
(HookTables::ExecutionProcesses, SqliteOperation::Delete) => {
// Execution process deletion is now handled by preupdate hook
// Skip post-update processing to avoid duplicate patches
return;
}
(HookTables::TaskAttempts, SqliteOperation::Delete) => {
// Task attempt deletion is now handled by preupdate hook
// Skip post-update processing to avoid duplicate patches
(HookTables::Tasks, SqliteOperation::Delete)
| (HookTables::TaskAttempts, SqliteOperation::Delete)
| (HookTables::ExecutionProcesses, SqliteOperation::Delete)
| (HookTables::Drafts, SqliteOperation::Delete) => {
// Deletions handled in preupdate hook for reliable data capture
return;
}
(HookTables::Tasks, _) => {
match Task::find_by_rowid(&db.pool, rowid).await {
Ok(Some(task)) => RecordTypes::Task(task),
Ok(None) => {
// Row not found - likely already deleted, skip processing
tracing::debug!("Task rowid {} not found, skipping", rowid);
return;
Ok(None) => RecordTypes::DeletedTask {
rowid,
project_id: None,
task_id: None,
},
Err(e) => {
tracing::error!("Failed to fetch task: {:?}", e);
@@ -422,10 +189,9 @@ impl EventService {
(HookTables::TaskAttempts, _) => {
match TaskAttempt::find_by_rowid(&db.pool, rowid).await {
Ok(Some(attempt)) => RecordTypes::TaskAttempt(attempt),
Ok(None) => {
// Row not found - likely already deleted, skip processing
tracing::debug!("TaskAttempt rowid {} not found, skipping", rowid);
return;
Ok(None) => RecordTypes::DeletedTaskAttempt {
rowid,
task_id: None,
},
Err(e) => {
tracing::error!(
@@ -439,10 +205,10 @@ impl EventService {
(HookTables::ExecutionProcesses, _) => {
match ExecutionProcess::find_by_rowid(&db.pool, rowid).await {
Ok(Some(process)) => RecordTypes::ExecutionProcess(process),
Ok(None) => {
// Row not found - likely already deleted, skip processing
tracing::debug!("ExecutionProcess rowid {} not found, skipping", rowid);
return;
Ok(None) => RecordTypes::DeletedExecutionProcess {
rowid,
task_attempt_id: None,
process_id: None,
},
Err(e) => {
tracing::error!(
@@ -453,28 +219,19 @@ impl EventService {
}
}
}
(HookTables::FollowUpDrafts, SqliteOperation::Delete) => {
// Follow up draft deletion is now handled by preupdate hook
// Skip post-update processing to avoid duplicate patches
return;
}
(HookTables::FollowUpDrafts, _) => {
match db::models::follow_up_draft::FollowUpDraft::find_by_rowid(
&db.pool, rowid,
)
.await
{
Ok(Some(draft)) => RecordTypes::FollowUpDraft(draft),
Ok(None) => {
// Row not found - likely already deleted, skip processing
tracing::debug!("FollowUpDraft rowid {} not found, skipping", rowid);
return;
(HookTables::Drafts, _) => {
match Draft::find_by_rowid(&db.pool, rowid).await {
Ok(Some(draft)) => match draft.draft_type {
DraftType::FollowUp => RecordTypes::Draft(draft),
DraftType::Retry => RecordTypes::RetryDraft(draft),
},
Ok(None) => RecordTypes::DeletedDraft {
rowid,
draft_type: DraftType::Retry,
task_attempt_id: None,
},
Err(e) => {
tracing::error!(
"Failed to fetch follow_up_draft: {:?}",
e
);
tracing::error!("Failed to fetch draft: {:?}", e);
return;
}
}
@@ -514,6 +271,33 @@ impl EventService {
return;
}
}
// Draft updates: emit direct patches used by the follow-up draft stream
RecordTypes::Draft(draft) => {
let patch = draft_patch::follow_up_replace(draft);
msg_store_for_hook.push_patch(patch);
return;
}
RecordTypes::RetryDraft(draft) => {
let patch = draft_patch::retry_replace(draft);
msg_store_for_hook.push_patch(patch);
return;
}
RecordTypes::DeletedDraft { draft_type, task_attempt_id: Some(id), .. } => {
let patch = match draft_type {
DraftType::FollowUp => draft_patch::follow_up_clear(*id),
DraftType::Retry => draft_patch::retry_clear(*id),
};
msg_store_for_hook.push_patch(patch);
return;
}
RecordTypes::DeletedTask {
task_id: Some(task_id),
..
} => {
let patch = task_patch::remove(*task_id);
msg_store_for_hook.push_patch(patch);
return;
}
RecordTypes::TaskAttempt(attempt) => {
// Task attempts should update the parent task with fresh data
if let Ok(Some(task)) =
@@ -532,6 +316,27 @@ impl EventService {
return;
}
}
RecordTypes::DeletedTaskAttempt {
task_id: Some(task_id),
..
} => {
// Task attempt deletion should update the parent task with fresh data
if let Ok(Some(task)) =
Task::find_by_id(&db.pool, *task_id).await
&& let Ok(task_list) =
Task::find_by_project_id_with_attempt_status(
&db.pool,
task.project_id,
)
.await
&& let Some(task_with_status) =
task_list.into_iter().find(|t| t.id == *task_id)
{
let patch = task_patch::replace(&task_with_status);
msg_store_for_hook.push_patch(patch);
return;
}
}
RecordTypes::ExecutionProcess(process) => {
let patch = match hook.operation {
SqliteOperation::Insert => {
@@ -559,6 +364,31 @@ impl EventService {
return;
}
RecordTypes::DeletedExecutionProcess {
process_id: Some(process_id),
task_attempt_id,
..
} => {
let patch = execution_process_patch::remove(*process_id);
msg_store_for_hook.push_patch(patch);
if let Some(task_attempt_id) = task_attempt_id
&& let Err(err) =
EventService::push_task_update_for_attempt(
&db.pool,
msg_store_for_hook.clone(),
*task_attempt_id,
)
.await
{
tracing::error!(
"Failed to push task update after execution process removal: {:?}",
err
);
}
return;
}
_ => {}
}
@@ -597,293 +427,4 @@ impl EventService {
pub fn msg_store(&self) -> &Arc<MsgStore> {
&self.msg_store
}
/// Stream raw task messages for a specific project with initial snapshot
pub async fn stream_tasks_raw(
&self,
project_id: Uuid,
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{
// Get initial snapshot of tasks
let tasks = Task::find_by_project_id_with_attempt_status(&self.db.pool, project_id).await?;
// Convert task array to object keyed by task ID
let tasks_map: serde_json::Map<String, serde_json::Value> = tasks
.into_iter()
.map(|task| (task.id.to_string(), serde_json::to_value(task).unwrap()))
.collect();
let initial_patch = json!([{
"op": "replace",
"path": "/tasks",
"value": tasks_map
}]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
// Clone necessary data for the async filter
let db_pool = self.db.pool.clone();
// Get filtered event stream
let filtered_stream =
BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| {
let db_pool = db_pool.clone();
async move {
match msg_result {
Ok(LogMsg::JsonPatch(patch)) => {
// Filter events based on project_id
if let Some(patch_op) = patch.0.first() {
// Check if this is a direct task patch (new format)
if patch_op.path().starts_with("/tasks/") {
match patch_op {
json_patch::PatchOperation::Add(op) => {
// Parse task data directly from value
if let Ok(task) =
serde_json::from_value::<TaskWithAttemptStatus>(
op.value.clone(),
)
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Replace(op) => {
// Parse task data directly from value
if let Ok(task) =
serde_json::from_value::<TaskWithAttemptStatus>(
op.value.clone(),
)
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Remove(_) => {
// For remove operations, we need to check project membership differently
// We could cache this information or let it pass through for now
// Since we don't have the task data, we'll allow all removals
// and let the client handle filtering
return Some(Ok(LogMsg::JsonPatch(patch)));
}
_ => {}
}
} else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
&& let Ok(event_patch) =
serde_json::from_value::<EventPatch>(event_patch_value)
{
// Handle old EventPatch format for non-task records
match &event_patch.value.record {
RecordTypes::Task(task) => {
if task.project_id == project_id {
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
RecordTypes::TaskAttempt(attempt) => {
// Check if this task_attempt belongs to a task in our project
if let Ok(Some(task)) =
Task::find_by_id(&db_pool, attempt.task_id).await
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
_ => {}
}
}
}
None
}
Ok(other) => Some(Ok(other)), // Pass through non-patch messages
Err(_) => None, // Filter out broadcast errors
}
}
});
// Start with initial snapshot, then live updates
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream.chain(filtered_stream).boxed();
Ok(combined_stream)
}
/// Stream execution processes for a specific task attempt with initial snapshot (raw LogMsg format for WebSocket)
pub async fn stream_execution_processes_for_attempt_raw(
&self,
task_attempt_id: Uuid,
show_soft_deleted: bool,
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{
// Get initial snapshot of execution processes (filtering at SQL level)
let processes = ExecutionProcess::find_by_task_attempt_id(
&self.db.pool,
task_attempt_id,
show_soft_deleted,
)
.await?;
// Convert processes array to object keyed by process ID
let processes_map: serde_json::Map<String, serde_json::Value> = processes
.into_iter()
.map(|process| {
(
process.id.to_string(),
serde_json::to_value(process).unwrap(),
)
})
.collect();
let initial_patch = json!([{
"op": "replace",
"path": "/execution_processes",
"value": processes_map
}]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
// Get filtered event stream
let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map(
move |msg_result| async move {
match msg_result {
Ok(LogMsg::JsonPatch(patch)) => {
// Filter events based on task_attempt_id
if let Some(patch_op) = patch.0.first() {
// Check if this is a modern execution process patch
if patch_op.path().starts_with("/execution_processes/") {
match patch_op {
json_patch::PatchOperation::Add(op) => {
// Parse execution process data directly from value
if let Ok(process) =
serde_json::from_value::<ExecutionProcess>(
op.value.clone(),
)
&& process.task_attempt_id == task_attempt_id
{
if !show_soft_deleted && process.dropped {
return None;
}
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Replace(op) => {
// Parse execution process data directly from value
if let Ok(process) =
serde_json::from_value::<ExecutionProcess>(
op.value.clone(),
)
&& process.task_attempt_id == task_attempt_id
{
if !show_soft_deleted && process.dropped {
let remove_patch =
execution_process_patch::remove(process.id);
return Some(Ok(LogMsg::JsonPatch(remove_patch)));
}
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Remove(_) => {
// For remove operations, we can't verify task_attempt_id
// so we allow all removals and let the client handle filtering
return Some(Ok(LogMsg::JsonPatch(patch)));
}
_ => {}
}
}
// Fallback to legacy EventPatch format for backward compatibility
else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
&& let Ok(event_patch) =
serde_json::from_value::<EventPatch>(event_patch_value)
&& let RecordTypes::ExecutionProcess(process) =
&event_patch.value.record
&& process.task_attempt_id == task_attempt_id
{
if !show_soft_deleted && process.dropped {
let remove_patch = execution_process_patch::remove(process.id);
return Some(Ok(LogMsg::JsonPatch(remove_patch)));
}
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
None
}
Ok(other) => Some(Ok(other)), // Pass through non-patch messages
Err(_) => None, // Filter out broadcast errors
}
},
);
// Start with initial snapshot, then live updates
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream.chain(filtered_stream).boxed();
Ok(combined_stream)
}
/// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket)
pub async fn stream_follow_up_draft_for_attempt_raw(
&self,
task_attempt_id: Uuid,
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{
// Get initial snapshot of follow-up draft
let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id(
&self.db.pool,
task_attempt_id,
)
.await?
.unwrap_or(db::models::follow_up_draft::FollowUpDraft {
id: uuid::Uuid::new_v4(),
task_attempt_id,
prompt: String::new(),
queued: false,
sending: false,
variant: None,
image_ids: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
version: 0,
});
let initial_patch = json!([
{
"op": "replace",
"path": "/",
"value": { "follow_up_draft": draft }
}
]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
// Filtered live stream, mapped into direct JSON patches that update /follow_up_draft
let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map(
move |msg_result| async move {
match msg_result {
Ok(LogMsg::JsonPatch(patch)) => {
if let Some(event_patch_op) = patch.0.first()
&& let Ok(event_patch_value) = serde_json::to_value(event_patch_op)
&& let Ok(event_patch) =
serde_json::from_value::<EventPatch>(event_patch_value)
&& let RecordTypes::FollowUpDraft(draft) = &event_patch.value.record
&& draft.task_attempt_id == task_attempt_id
{
// Build a direct patch to replace /follow_up_draft
let direct = json!([
{
"op": "replace",
"path": "/follow_up_draft",
"value": draft
}
]);
let direct_patch = serde_json::from_value(direct).unwrap();
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
}
None
}
Ok(other) => Some(Ok(other)),
Err(_) => None,
}
},
);
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream.chain(filtered_stream).boxed();
Ok(combined_stream)
}
}

View File

@@ -0,0 +1,204 @@
use db::models::{
draft::{Draft, DraftType},
execution_process::ExecutionProcess,
task::TaskWithAttemptStatus,
task_attempt::TaskAttempt,
};
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
use uuid::Uuid;
// Shared helper to escape JSON Pointer segments
fn escape_pointer_segment(s: &str) -> String {
s.replace('~', "~0").replace('/', "~1")
}
/// Helper functions for creating task-specific patches
pub mod task_patch {
use super::*;
fn task_path(task_id: Uuid) -> String {
format!("/tasks/{}", escape_pointer_segment(&task_id.to_string()))
}
/// Create patch for adding a new task
pub fn add(task: &TaskWithAttemptStatus) -> Patch {
Patch(vec![PatchOperation::Add(AddOperation {
path: task_path(task.id)
.try_into()
.expect("Task path should be valid"),
value: serde_json::to_value(task).expect("Task serialization should not fail"),
})])
}
/// Create patch for updating an existing task
pub fn replace(task: &TaskWithAttemptStatus) -> Patch {
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: task_path(task.id)
.try_into()
.expect("Task path should be valid"),
value: serde_json::to_value(task).expect("Task serialization should not fail"),
})])
}
/// Create patch for removing a task
pub fn remove(task_id: Uuid) -> Patch {
Patch(vec![PatchOperation::Remove(RemoveOperation {
path: task_path(task_id)
.try_into()
.expect("Task path should be valid"),
})])
}
}
/// Helper functions for creating execution process-specific patches
pub mod execution_process_patch {
use super::*;
fn execution_process_path(process_id: Uuid) -> String {
format!(
"/execution_processes/{}",
escape_pointer_segment(&process_id.to_string())
)
}
/// Create patch for adding a new execution process
pub fn add(process: &ExecutionProcess) -> Patch {
Patch(vec![PatchOperation::Add(AddOperation {
path: execution_process_path(process.id)
.try_into()
.expect("Execution process path should be valid"),
value: serde_json::to_value(process)
.expect("Execution process serialization should not fail"),
})])
}
/// Create patch for updating an existing execution process
pub fn replace(process: &ExecutionProcess) -> Patch {
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: execution_process_path(process.id)
.try_into()
.expect("Execution process path should be valid"),
value: serde_json::to_value(process)
.expect("Execution process serialization should not fail"),
})])
}
/// Create patch for removing an execution process
pub fn remove(process_id: Uuid) -> Patch {
Patch(vec![PatchOperation::Remove(RemoveOperation {
path: execution_process_path(process_id)
.try_into()
.expect("Execution process path should be valid"),
})])
}
}
/// Helper functions for creating draft-specific patches
pub mod draft_patch {
use super::*;
fn follow_up_path(attempt_id: Uuid) -> String {
format!("/drafts/{attempt_id}/follow_up")
}
fn retry_path(attempt_id: Uuid) -> String {
format!("/drafts/{attempt_id}/retry")
}
/// Replace the follow-up draft for a specific attempt
pub fn follow_up_replace(draft: &Draft) -> Patch {
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: follow_up_path(draft.task_attempt_id)
.try_into()
.expect("Path should be valid"),
value: serde_json::to_value(draft).expect("Draft serialization should not fail"),
})])
}
/// Replace the retry draft for a specific attempt
pub fn retry_replace(draft: &Draft) -> Patch {
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: retry_path(draft.task_attempt_id)
.try_into()
.expect("Path should be valid"),
value: serde_json::to_value(draft).expect("Draft serialization should not fail"),
})])
}
/// Clear the follow-up draft for an attempt (replace with an empty draft)
pub fn follow_up_clear(attempt_id: Uuid) -> Patch {
let empty = Draft {
id: uuid::Uuid::new_v4(),
task_attempt_id: attempt_id,
draft_type: DraftType::FollowUp,
retry_process_id: None,
prompt: String::new(),
queued: false,
sending: false,
variant: None,
image_ids: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
version: 0,
};
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: follow_up_path(attempt_id)
.try_into()
.expect("Path should be valid"),
value: serde_json::to_value(empty).expect("Draft serialization should not fail"),
})])
}
/// Clear the retry draft for an attempt (set to null)
pub fn retry_clear(attempt_id: Uuid) -> Patch {
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: retry_path(attempt_id)
.try_into()
.expect("Path should be valid"),
value: serde_json::Value::Null,
})])
}
}
/// Helper functions for creating task attempt-specific patches
pub mod task_attempt_patch {
use super::*;
fn attempt_path(attempt_id: Uuid) -> String {
format!(
"/task_attempts/{}",
escape_pointer_segment(&attempt_id.to_string())
)
}
/// Create patch for adding a new task attempt
pub fn add(attempt: &TaskAttempt) -> Patch {
Patch(vec![PatchOperation::Add(AddOperation {
path: attempt_path(attempt.id)
.try_into()
.expect("Task attempt path should be valid"),
value: serde_json::to_value(attempt)
.expect("Task attempt serialization should not fail"),
})])
}
/// Create patch for updating an existing task attempt
pub fn replace(attempt: &TaskAttempt) -> Patch {
Patch(vec![PatchOperation::Replace(ReplaceOperation {
path: attempt_path(attempt.id)
.try_into()
.expect("Task attempt path should be valid"),
value: serde_json::to_value(attempt)
.expect("Task attempt serialization should not fail"),
})])
}
/// Create patch for removing a task attempt
pub fn remove(attempt_id: Uuid) -> Patch {
Patch(vec![PatchOperation::Remove(RemoveOperation {
path: attempt_path(attempt_id)
.try_into()
.expect("Task attempt path should be valid"),
})])
}
}

View File

@@ -0,0 +1,374 @@
use db::models::{
draft::{Draft, DraftType},
execution_process::ExecutionProcess,
task::{Task, TaskWithAttemptStatus},
};
use futures::StreamExt;
use serde_json::json;
use tokio_stream::wrappers::BroadcastStream;
use utils::log_msg::LogMsg;
use uuid::Uuid;
use super::{
EventService,
patches::execution_process_patch,
types::{EventError, EventPatch, RecordTypes},
};
impl EventService {
/// Stream raw task messages for a specific project with initial snapshot
pub async fn stream_tasks_raw(
&self,
project_id: Uuid,
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{
// Get initial snapshot of tasks
let tasks = Task::find_by_project_id_with_attempt_status(&self.db.pool, project_id).await?;
// Convert task array to object keyed by task ID
let tasks_map: serde_json::Map<String, serde_json::Value> = tasks
.into_iter()
.map(|task| (task.id.to_string(), serde_json::to_value(task).unwrap()))
.collect();
let initial_patch = json!([{
"op": "replace",
"path": "/tasks",
"value": tasks_map
}]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
// Clone necessary data for the async filter
let db_pool = self.db.pool.clone();
// Get filtered event stream
let filtered_stream =
BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| {
let db_pool = db_pool.clone();
async move {
match msg_result {
Ok(LogMsg::JsonPatch(patch)) => {
// Filter events based on project_id
if let Some(patch_op) = patch.0.first() {
// Check if this is a direct task patch (new format)
if patch_op.path().starts_with("/tasks/") {
match patch_op {
json_patch::PatchOperation::Add(op) => {
// Parse task data directly from value
if let Ok(task) =
serde_json::from_value::<TaskWithAttemptStatus>(
op.value.clone(),
)
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Replace(op) => {
// Parse task data directly from value
if let Ok(task) =
serde_json::from_value::<TaskWithAttemptStatus>(
op.value.clone(),
)
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Remove(_) => {
// For remove operations, we need to check project membership differently
// We could cache this information or let it pass through for now
// Since we don't have the task data, we'll allow all removals
// and let the client handle filtering
return Some(Ok(LogMsg::JsonPatch(patch)));
}
_ => {}
}
} else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
&& let Ok(event_patch) =
serde_json::from_value::<EventPatch>(event_patch_value)
{
// Handle old EventPatch format for non-task records
match &event_patch.value.record {
RecordTypes::Task(task) => {
if task.project_id == project_id {
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
RecordTypes::DeletedTask {
project_id: Some(deleted_project_id),
..
} => {
if *deleted_project_id == project_id {
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
RecordTypes::TaskAttempt(attempt) => {
// Check if this task_attempt belongs to a task in our project
if let Ok(Some(task)) =
Task::find_by_id(&db_pool, attempt.task_id).await
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
RecordTypes::DeletedTaskAttempt {
task_id: Some(deleted_task_id),
..
} => {
// Check if deleted attempt belonged to a task in our project
if let Ok(Some(task)) =
Task::find_by_id(&db_pool, *deleted_task_id).await
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
_ => {}
}
}
}
None
}
Ok(other) => Some(Ok(other)), // Pass through non-patch messages
Err(_) => None, // Filter out broadcast errors
}
}
});
// Start with initial snapshot, then live updates
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream.chain(filtered_stream).boxed();
Ok(combined_stream)
}
/// Stream execution processes for a specific task attempt with initial snapshot (raw LogMsg format for WebSocket)
pub async fn stream_execution_processes_for_attempt_raw(
&self,
task_attempt_id: Uuid,
show_soft_deleted: bool,
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{
// Get initial snapshot of execution processes (filtering at SQL level)
let processes = ExecutionProcess::find_by_task_attempt_id(
&self.db.pool,
task_attempt_id,
show_soft_deleted,
)
.await?;
// Convert processes array to object keyed by process ID
let processes_map: serde_json::Map<String, serde_json::Value> = processes
.into_iter()
.map(|process| {
(
process.id.to_string(),
serde_json::to_value(process).unwrap(),
)
})
.collect();
let initial_patch = json!([{
"op": "replace",
"path": "/execution_processes",
"value": processes_map
}]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
// Get filtered event stream
let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map(
move |msg_result| async move {
match msg_result {
Ok(LogMsg::JsonPatch(patch)) => {
// Filter events based on task_attempt_id
if let Some(patch_op) = patch.0.first() {
// Check if this is a modern execution process patch
if patch_op.path().starts_with("/execution_processes/") {
match patch_op {
json_patch::PatchOperation::Add(op) => {
// Parse execution process data directly from value
if let Ok(process) =
serde_json::from_value::<ExecutionProcess>(
op.value.clone(),
)
&& process.task_attempt_id == task_attempt_id
{
if !show_soft_deleted && process.dropped {
let remove_patch =
execution_process_patch::remove(process.id);
return Some(Ok(LogMsg::JsonPatch(remove_patch)));
}
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Replace(op) => {
// Parse execution process data directly from value
if let Ok(process) =
serde_json::from_value::<ExecutionProcess>(
op.value.clone(),
)
&& process.task_attempt_id == task_attempt_id
{
if !show_soft_deleted && process.dropped {
let remove_patch =
execution_process_patch::remove(process.id);
return Some(Ok(LogMsg::JsonPatch(remove_patch)));
}
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
json_patch::PatchOperation::Remove(_) => {
// For remove operations, we can't verify task_attempt_id
// so we allow all removals and let the client handle filtering
return Some(Ok(LogMsg::JsonPatch(patch)));
}
_ => {}
}
}
// Fallback to legacy EventPatch format for backward compatibility
else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
&& let Ok(event_patch) =
serde_json::from_value::<EventPatch>(event_patch_value)
{
match &event_patch.value.record {
RecordTypes::ExecutionProcess(process) => {
if process.task_attempt_id == task_attempt_id {
if !show_soft_deleted && process.dropped {
let remove_patch =
execution_process_patch::remove(process.id);
return Some(Ok(LogMsg::JsonPatch(remove_patch)));
}
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
RecordTypes::DeletedExecutionProcess {
task_attempt_id: Some(deleted_attempt_id),
..
} => {
if *deleted_attempt_id == task_attempt_id {
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
_ => {}
}
}
}
None
}
Ok(other) => Some(Ok(other)), // Pass through non-patch messages
Err(_) => None, // Filter out broadcast errors
}
},
);
// Start with initial snapshot, then live updates
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream.chain(filtered_stream).boxed();
Ok(combined_stream)
}
/// Stream drafts for all task attempts in a project with initial snapshot (raw LogMsg)
pub async fn stream_drafts_for_project_raw(
&self,
project_id: Uuid,
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{
// Load all attempt ids for tasks in this project
let attempt_ids: Vec<Uuid> = sqlx::query_scalar(
r#"SELECT ta.id
FROM task_attempts ta
JOIN tasks t ON t.id = ta.task_id
WHERE t.project_id = ?"#,
)
.bind(project_id)
.fetch_all(&self.db.pool)
.await?;
// Build initial drafts map keyed by attempt_id
let mut drafts_map: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
for attempt_id in attempt_ids {
let fu = Draft::find_by_task_attempt_and_type(
&self.db.pool,
attempt_id,
DraftType::FollowUp,
)
.await?
.unwrap_or(Draft {
id: uuid::Uuid::new_v4(),
task_attempt_id: attempt_id,
draft_type: DraftType::FollowUp,
retry_process_id: None,
prompt: String::new(),
queued: false,
sending: false,
variant: None,
image_ids: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
version: 0,
});
let re =
Draft::find_by_task_attempt_and_type(&self.db.pool, attempt_id, DraftType::Retry)
.await?;
let entry = json!({
"follow_up": fu,
"retry": serde_json::to_value(re).unwrap_or(serde_json::Value::Null),
});
drafts_map.insert(attempt_id.to_string(), entry);
}
let initial_patch = json!([
{
"op": "replace",
"path": "/drafts",
"value": drafts_map
}
]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
let db_pool = self.db.pool.clone();
// Live updates: accept direct draft patches and filter by project membership
let filtered_stream =
BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| {
let db_pool = db_pool.clone();
async move {
match msg_result {
Ok(LogMsg::JsonPatch(patch)) => {
if let Some(op) = patch.0.first() {
let path = op.path();
if let Some(rest) = path.strip_prefix("/drafts/")
&& let Some((attempt_str, _)) = rest.split_once('/')
&& let Ok(attempt_id) = Uuid::parse_str(attempt_str)
{
// Check project membership
if let Ok(Some(task_attempt)) =
db::models::task_attempt::TaskAttempt::find_by_id(
&db_pool, attempt_id,
)
.await
&& let Ok(Some(task)) = db::models::task::Task::find_by_id(
&db_pool,
task_attempt.task_id,
)
.await
&& task.project_id == project_id
{
return Some(Ok(LogMsg::JsonPatch(patch)));
}
}
}
None
}
Ok(other) => Some(Ok(other)),
Err(_) => None,
}
}
});
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream.chain(filtered_stream).boxed();
Ok(combined_stream)
}
}

View File

@@ -0,0 +1,77 @@
use anyhow::Error as AnyhowError;
use db::models::{
draft::{Draft, DraftType},
execution_process::ExecutionProcess,
task::Task,
task_attempt::TaskAttempt,
};
use serde::{Deserialize, Serialize};
use sqlx::Error as SqlxError;
use strum_macros::{Display, EnumString};
use thiserror::Error;
use ts_rs::TS;
use uuid::Uuid;
#[derive(Debug, Error)]
pub enum EventError {
#[error(transparent)]
Sqlx(#[from] SqlxError),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Other(#[from] AnyhowError), // Catches any unclassified errors
}
#[derive(EnumString, Display)]
pub enum HookTables {
#[strum(to_string = "tasks")]
Tasks,
#[strum(to_string = "task_attempts")]
TaskAttempts,
#[strum(to_string = "execution_processes")]
ExecutionProcesses,
#[strum(to_string = "drafts")]
Drafts,
}
#[derive(Serialize, Deserialize, TS)]
#[serde(tag = "type", content = "data", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum RecordTypes {
Task(Task),
TaskAttempt(TaskAttempt),
ExecutionProcess(ExecutionProcess),
Draft(Draft),
RetryDraft(Draft),
DeletedTask {
rowid: i64,
project_id: Option<Uuid>,
task_id: Option<Uuid>,
},
DeletedTaskAttempt {
rowid: i64,
task_id: Option<Uuid>,
},
DeletedExecutionProcess {
rowid: i64,
task_attempt_id: Option<Uuid>,
process_id: Option<Uuid>,
},
DeletedDraft {
rowid: i64,
draft_type: DraftType,
task_attempt_id: Option<Uuid>,
},
}
#[derive(Serialize, Deserialize, TS)]
pub struct EventPatchInner {
pub(crate) db_op: String,
pub(crate) record: RecordTypes,
}
#[derive(Serialize, Deserialize, TS)]
pub struct EventPatch {
pub(crate) op: String,
pub(crate) path: String,
pub(crate) value: EventPatchInner,
}

View File

@@ -89,6 +89,36 @@ impl std::fmt::Display for Commit {
}
}
#[derive(Debug, Clone, Copy)]
pub struct WorktreeResetOptions {
pub perform_reset: bool,
pub force_when_dirty: bool,
pub is_dirty: bool,
pub log_skip_when_dirty: bool,
}
impl WorktreeResetOptions {
pub fn new(
perform_reset: bool,
force_when_dirty: bool,
is_dirty: bool,
log_skip_when_dirty: bool,
) -> Self {
Self {
perform_reset,
force_when_dirty,
is_dirty,
log_skip_when_dirty,
}
}
}
#[derive(Debug, Default, Clone, Copy)]
pub struct WorktreeResetOutcome {
pub needed: bool,
pub applied: bool,
}
/// Target for diff generation
pub enum DiffTarget<'p> {
/// Work-in-progress branch checked out in this worktree
@@ -1074,6 +1104,47 @@ impl GitService {
.map_err(|e| GitServiceError::InvalidRepository(format!("git status failed: {e}")))
}
/// Evaluate whether any action is needed to reset to `target_commit_oid` and
/// optionally perform the actions.
pub fn reconcile_worktree_to_commit(
&self,
worktree_path: &Path,
target_commit_oid: &str,
options: WorktreeResetOptions,
) -> WorktreeResetOutcome {
let WorktreeResetOptions {
perform_reset,
force_when_dirty,
is_dirty,
log_skip_when_dirty,
} = options;
let head_oid = self.get_head_info(worktree_path).ok().map(|h| h.oid);
let mut outcome = WorktreeResetOutcome::default();
if head_oid.as_deref() != Some(target_commit_oid) || is_dirty {
outcome.needed = true;
if perform_reset {
if is_dirty && !force_when_dirty {
if log_skip_when_dirty {
tracing::warn!("Worktree dirty; skipping reset as not forced");
}
} else if let Err(e) = self.reset_worktree_to_commit(
worktree_path,
target_commit_oid,
force_when_dirty,
) {
tracing::error!("Failed to reset worktree: {}", e);
} else {
outcome.applied = true;
}
}
}
outcome
}
/// Reset the given worktree to the specified commit SHA.
/// If `force` is false and the worktree is dirty, returns WorktreeDirty error.
pub fn reset_worktree_to_commit(

View File

@@ -3,6 +3,7 @@ pub mod approvals;
pub mod auth;
pub mod config;
pub mod container;
pub mod drafts;
pub mod events;
pub mod file_ranker;
pub mod file_search_cache;