Consolidate Retry and Follow-up (#800)
@@ -148,6 +148,19 @@ pub trait ContainerService {
         Ok(())
     }
 
+    fn cleanup_action(&self, cleanup_script: Option<String>) -> Option<Box<ExecutorAction>> {
+        cleanup_script.map(|script| {
+            Box::new(ExecutorAction::new(
+                ExecutorActionType::ScriptRequest(ScriptRequest {
+                    script,
+                    language: ScriptRequestLanguage::Bash,
+                    context: ScriptContext::CleanupScript,
+                }),
+                None,
+            ))
+        })
+    }
+
     async fn try_stop(&self, task_attempt: &TaskAttempt) {
         // stop all execution processes for this attempt
         if let Ok(processes) =
@@ -499,16 +512,7 @@ pub trait ContainerService {
         );
         let prompt = ImageService::canonicalise_image_paths(&task.to_prompt(), &worktree_path);
 
-        let cleanup_action = project.cleanup_script.map(|script| {
-            Box::new(ExecutorAction::new(
-                ExecutorActionType::ScriptRequest(ScriptRequest {
-                    script,
-                    language: ScriptRequestLanguage::Bash,
-                    context: ScriptContext::CleanupScript,
-                }),
-                None,
-            ))
-        });
+        let cleanup_action = self.cleanup_action(project.cleanup_script);
 
         // Choose whether to execute the setup_script or coding agent first
         let execution_process = if let Some(setup_script) = project.setup_script {
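Both hunks above are two sides of the same consolidation: the inline ExecutorAction construction that previously appeared at each call site is replaced by the shared cleanup_action helper on ContainerService. A minimal sketch of a call site after this change (setup_request is a hypothetical stand-in for whatever action runs first):

    // Hypothetical call site; cleanup_action returns None when the project
    // has no cleanup script, so no cleanup step is chained in that case.
    let cleanup_action = self.cleanup_action(project.cleanup_script);
    let action = ExecutorAction::new(
        ExecutorActionType::ScriptRequest(setup_request),
        cleanup_action, // Option<Box<ExecutorAction>>: the follow-on action, if any
    );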
crates/services/src/services/drafts.rs (new file, 474 lines)
@@ -0,0 +1,474 @@
use std::path::{Path, PathBuf};

use db::{
    DBService,
    models::{
        draft::{Draft, DraftType, UpsertDraft},
        execution_process::{ExecutionProcess, ExecutionProcessError, ExecutionProcessRunReason},
        image::TaskImage,
        task_attempt::TaskAttempt,
    },
};
use executors::{
    actions::{
        ExecutorAction, ExecutorActionType, coding_agent_follow_up::CodingAgentFollowUpRequest,
    },
    profile::ExecutorProfileId,
};
use serde::{Deserialize, Serialize};
use sqlx::Error as SqlxError;
use thiserror::Error;
use ts_rs::TS;
use uuid::Uuid;

use super::{
    container::{ContainerError, ContainerService},
    image::{ImageError, ImageService},
};

#[derive(Debug, Error)]
pub enum DraftsServiceError {
    #[error(transparent)]
    Database(#[from] sqlx::Error),
    #[error(transparent)]
    Container(#[from] ContainerError),
    #[error(transparent)]
    Image(#[from] ImageError),
    #[error(transparent)]
    ExecutionProcess(#[from] ExecutionProcessError),
    #[error("Conflict: {0}")]
    Conflict(String),
}

#[derive(Debug, Serialize, TS)]
pub struct DraftResponse {
    pub task_attempt_id: Uuid,
    pub draft_type: DraftType,
    pub retry_process_id: Option<Uuid>,
    pub prompt: String,
    pub queued: bool,
    pub variant: Option<String>,
    pub image_ids: Option<Vec<Uuid>>,
    pub version: i64,
}

#[derive(Debug, Deserialize, TS)]
pub struct UpdateFollowUpDraftRequest {
    pub prompt: Option<String>,
    pub variant: Option<Option<String>>,
    pub image_ids: Option<Vec<Uuid>>,
    pub version: Option<i64>,
}

#[derive(Debug, Deserialize, TS)]
pub struct UpdateRetryFollowUpDraftRequest {
    pub retry_process_id: Uuid,
    pub prompt: Option<String>,
    pub variant: Option<Option<String>>,
    pub image_ids: Option<Vec<Uuid>>,
    pub version: Option<i64>,
}

#[derive(Debug, Deserialize, TS)]
pub struct SetQueueRequest {
    pub queued: bool,
    pub expected_queued: Option<bool>,
    pub expected_version: Option<i64>,
}

#[derive(Clone)]
pub struct DraftsService {
    db: DBService,
    image: ImageService,
}

impl DraftsService {
    pub fn new(db: DBService, image: ImageService) -> Self {
        Self { db, image }
    }

    fn pool(&self) -> &sqlx::SqlitePool {
        &self.db.pool
    }

    fn draft_to_response(d: Draft) -> DraftResponse {
        DraftResponse {
            task_attempt_id: d.task_attempt_id,
            draft_type: d.draft_type,
            retry_process_id: d.retry_process_id,
            prompt: d.prompt,
            queued: d.queued,
            variant: d.variant,
            image_ids: d.image_ids,
            version: d.version,
        }
    }

    async fn ensure_follow_up_draft_row(
        &self,
        attempt_id: Uuid,
    ) -> Result<Draft, DraftsServiceError> {
        if let Some(d) =
            Draft::find_by_task_attempt_and_type(self.pool(), attempt_id, DraftType::FollowUp)
                .await?
        {
            return Ok(d);
        }

        let _ = Draft::upsert(
            self.pool(),
            &UpsertDraft {
                task_attempt_id: attempt_id,
                draft_type: DraftType::FollowUp,
                retry_process_id: None,
                prompt: "".to_string(),
                queued: false,
                variant: None,
                image_ids: None,
            },
        )
        .await?;

        Draft::find_by_task_attempt_and_type(self.pool(), attempt_id, DraftType::FollowUp)
            .await?
            .ok_or(SqlxError::RowNotFound)
            .map_err(DraftsServiceError::from)
    }

    async fn associate_images_for_task_if_any(
        &self,
        task_id: Uuid,
        image_ids: &Option<Vec<Uuid>>,
    ) -> Result<(), DraftsServiceError> {
        if let Some(ids) = image_ids
            && !ids.is_empty()
        {
            TaskImage::associate_many_dedup(self.pool(), task_id, ids).await?;
        }
        Ok(())
    }

    async fn has_running_processes_for_attempt(
        &self,
        attempt_id: Uuid,
    ) -> Result<bool, DraftsServiceError> {
        let processes =
            ExecutionProcess::find_by_task_attempt_id(self.pool(), attempt_id, false).await?;
        Ok(processes.into_iter().any(|p| {
            matches!(
                p.status,
                db::models::execution_process::ExecutionProcessStatus::Running
            )
        }))
    }

    async fn fetch_draft_response(
        &self,
        task_attempt_id: Uuid,
        draft_type: DraftType,
    ) -> Result<DraftResponse, DraftsServiceError> {
        let d =
            Draft::find_by_task_attempt_and_type(self.pool(), task_attempt_id, draft_type).await?;
        let resp = if let Some(d) = d {
            Self::draft_to_response(d)
        } else {
            DraftResponse {
                task_attempt_id,
                draft_type,
                retry_process_id: None,
                prompt: "".to_string(),
                queued: false,
                variant: None,
                image_ids: None,
                version: 0,
            }
        };
        Ok(resp)
    }

    async fn handle_images_for_prompt(
        &self,
        task_id: Uuid,
        image_ids: &[Uuid],
        prompt: &str,
        worktree_path: &Path,
    ) -> Result<String, DraftsServiceError> {
        if image_ids.is_empty() {
            return Ok(prompt.to_string());
        }

        TaskImage::associate_many_dedup(self.pool(), task_id, image_ids).await?;
        self.image
            .copy_images_by_ids_to_worktree(worktree_path, image_ids)
            .await?;
        Ok(ImageService::canonicalise_image_paths(
            prompt,
            worktree_path,
        ))
    }

    async fn start_follow_up_from_draft(
        &self,
        container: &(dyn ContainerService + Send + Sync),
        task_attempt: &TaskAttempt,
        draft: &Draft,
    ) -> Result<ExecutionProcess, DraftsServiceError> {
        let worktree_ref = container.ensure_container_exists(task_attempt).await?;
        let worktree_path = PathBuf::from(worktree_ref);
        let session_id =
            ExecutionProcess::require_latest_session_id(self.pool(), task_attempt.id).await?;

        let base_profile =
            ExecutionProcess::latest_executor_profile_for_attempt(self.pool(), task_attempt.id)
                .await?;
        let executor_profile_id = ExecutorProfileId {
            executor: base_profile.executor,
            variant: draft.variant.clone(),
        };

        let task = task_attempt
            .parent_task(self.pool())
            .await?
            .ok_or(SqlxError::RowNotFound)
            .map_err(DraftsServiceError::from)?;
        let project = task
            .parent_project(self.pool())
            .await?
            .ok_or(SqlxError::RowNotFound)
            .map_err(DraftsServiceError::from)?;

        let cleanup_action = container.cleanup_action(project.cleanup_script);

        let mut prompt = draft.prompt.clone();
        if let Some(image_ids) = &draft.image_ids {
            prompt = self
                .handle_images_for_prompt(task_attempt.task_id, image_ids, &prompt, &worktree_path)
                .await?;
        }

        let follow_up_request = CodingAgentFollowUpRequest {
            prompt,
            session_id,
            executor_profile_id,
        };

        let follow_up_action = ExecutorAction::new(
            ExecutorActionType::CodingAgentFollowUpRequest(follow_up_request),
            cleanup_action,
        );

        let execution_process = container
            .start_execution(
                task_attempt,
                &follow_up_action,
                &ExecutionProcessRunReason::CodingAgent,
            )
            .await?;

        let _ = Draft::clear_after_send(self.pool(), task_attempt.id, DraftType::FollowUp).await;

        Ok(execution_process)
    }

    pub async fn save_follow_up_draft(
        &self,
        task_attempt: &TaskAttempt,
        payload: &UpdateFollowUpDraftRequest,
    ) -> Result<DraftResponse, DraftsServiceError> {
        let pool = self.pool();
        let d = self.ensure_follow_up_draft_row(task_attempt.id).await?;
        if d.queued {
            return Err(DraftsServiceError::Conflict(
                "Draft is queued; click Edit to unqueue before editing".to_string(),
            ));
        }

        if let Some(expected_version) = payload.version
            && d.version != expected_version
        {
            return Err(DraftsServiceError::Conflict(
                "Draft changed, please retry with latest".to_string(),
            ));
        }

        if payload.prompt.is_none() && payload.variant.is_none() && payload.image_ids.is_none() {
        } else {
            Draft::update_partial(
                pool,
                task_attempt.id,
                DraftType::FollowUp,
                payload.prompt.clone(),
                payload.variant.clone(),
                payload.image_ids.clone(),
                None,
            )
            .await?;
        }

        if let Some(task) = task_attempt.parent_task(pool).await? {
            self.associate_images_for_task_if_any(task.id, &payload.image_ids)
                .await?;
        }

        let current =
            Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp)
                .await?
                .map(Self::draft_to_response)
                .unwrap_or(DraftResponse {
                    task_attempt_id: task_attempt.id,
                    draft_type: DraftType::FollowUp,
                    retry_process_id: None,
                    prompt: "".to_string(),
                    queued: false,
                    variant: None,
                    image_ids: None,
                    version: 0,
                });

        Ok(current)
    }

    pub async fn save_retry_follow_up_draft(
        &self,
        task_attempt: &TaskAttempt,
        payload: &UpdateRetryFollowUpDraftRequest,
    ) -> Result<DraftResponse, DraftsServiceError> {
        let pool = self.pool();
        let existing =
            Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::Retry).await?;

        if let Some(d) = &existing {
            if d.queued {
                return Err(DraftsServiceError::Conflict(
                    "Retry draft is queued; unqueue before editing".to_string(),
                ));
            }
            if let Some(expected_version) = payload.version
                && d.version != expected_version
            {
                return Err(DraftsServiceError::Conflict(
                    "Retry draft changed, please retry with latest".to_string(),
                ));
            }
        }

        if existing.is_none() {
            let draft = Draft::upsert(
                pool,
                &UpsertDraft {
                    task_attempt_id: task_attempt.id,
                    draft_type: DraftType::Retry,
                    retry_process_id: Some(payload.retry_process_id),
                    prompt: payload.prompt.clone().unwrap_or_default(),
                    queued: false,
                    variant: payload.variant.clone().unwrap_or(None),
                    image_ids: payload.image_ids.clone(),
                },
            )
            .await?;

            return Ok(Self::draft_to_response(draft));
        }

        if payload.prompt.is_none() && payload.variant.is_none() && payload.image_ids.is_none() {
        } else {
            Draft::update_partial(
                pool,
                task_attempt.id,
                DraftType::Retry,
                payload.prompt.clone(),
                payload.variant.clone(),
                payload.image_ids.clone(),
                Some(payload.retry_process_id),
            )
            .await?;
        }

        if let Some(task) = task_attempt.parent_task(pool).await? {
            self.associate_images_for_task_if_any(task.id, &payload.image_ids)
                .await?;
        }

        let draft = Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::Retry)
            .await?
            .ok_or(SqlxError::RowNotFound)
            .map_err(DraftsServiceError::from)?;
        Ok(Self::draft_to_response(draft))
    }

    pub async fn delete_retry_follow_up_draft(
        &self,
        task_attempt: &TaskAttempt,
    ) -> Result<(), DraftsServiceError> {
        Draft::delete_by_task_attempt_and_type(self.pool(), task_attempt.id, DraftType::Retry)
            .await?;

        Ok(())
    }

    pub async fn set_follow_up_queue(
        &self,
        container: &(dyn ContainerService + Send + Sync),
        task_attempt: &TaskAttempt,
        payload: &SetQueueRequest,
    ) -> Result<DraftResponse, DraftsServiceError> {
        let pool = self.pool();

        let rows_updated = Draft::set_queued(
            pool,
            task_attempt.id,
            DraftType::FollowUp,
            payload.queued,
            payload.expected_queued,
            payload.expected_version,
        )
        .await?;

        let draft =
            Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp)
                .await?;

        if rows_updated == 0 {
            if draft.is_none() {
                return Err(DraftsServiceError::Conflict(
                    "No draft to queue".to_string(),
                ));
            };

            return Err(DraftsServiceError::Conflict(
                "Draft changed, please refresh and try again".to_string(),
            ));
        }

        let should_consider_start = draft.as_ref().map(|c| c.queued).unwrap_or(false)
            && !self
                .has_running_processes_for_attempt(task_attempt.id)
                .await?;

        if should_consider_start
            && Draft::try_mark_sending(pool, task_attempt.id, DraftType::FollowUp)
                .await
                .unwrap_or(false)
        {
            let _ = self
                .start_follow_up_from_draft(container, task_attempt, draft.as_ref().unwrap())
                .await;
        }

        let draft =
            Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp)
                .await?
                .ok_or(SqlxError::RowNotFound)
                .map_err(DraftsServiceError::from)?;

        Ok(Self::draft_to_response(draft))
    }

    pub async fn get_draft(
        &self,
        task_attempt_id: Uuid,
        draft_type: DraftType,
    ) -> Result<DraftResponse, DraftsServiceError> {
        self.fetch_draft_response(task_attempt_id, draft_type).await
    }
}
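The new DraftsService guards every mutation with optimistic concurrency: requests carry the version the client last saw, and a mismatch (or an already-queued draft) surfaces as DraftsServiceError::Conflict. A hedged sketch of the intended caller flow, with drafts, attempt, and attempt_id as hypothetical handler-side bindings:

    // Read the draft first so we know its version.
    let current = drafts.get_draft(attempt_id, DraftType::FollowUp).await?;

    // Send the edit back tagged with that version.
    let payload = UpdateFollowUpDraftRequest {
        prompt: Some("Also add a regression test".to_string()),
        variant: None,
        image_ids: None,
        version: Some(current.version),
    };

    match drafts.save_follow_up_draft(&attempt, &payload).await {
        Ok(updated) => tracing::info!("draft saved at version {}", updated.version),
        // Someone else edited (or queued) the draft in the meantime:
        // re-fetch and let the user retry on top of the latest state.
        Err(DraftsServiceError::Conflict(msg)) => tracing::warn!("stale draft: {msg}"),
        Err(e) => return Err(e),
    }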
@@ -1,202 +1,29 @@
 use std::{str::FromStr, sync::Arc};
 
-use anyhow::Error as AnyhowError;
 use db::{
     DBService,
     models::{
+        draft::{Draft, DraftType},
         execution_process::ExecutionProcess,
-        task::{Task, TaskWithAttemptStatus},
+        task::Task,
         task_attempt::TaskAttempt,
     },
 };
 use futures::StreamExt;
-use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
-use serde::{Deserialize, Serialize};
 use serde_json::json;
-use sqlx::{Error as SqlxError, SqlitePool, ValueRef, sqlite::SqliteOperation};
-use strum_macros::{Display, EnumString};
-use thiserror::Error;
+use sqlx::{Error as SqlxError, Sqlite, SqlitePool, decode::Decode, sqlite::SqliteOperation};
 use tokio::sync::RwLock;
-use tokio_stream::wrappers::BroadcastStream;
-use ts_rs::TS;
-use utils::{log_msg::LogMsg, msg_store::MsgStore};
+use utils::msg_store::MsgStore;
 use uuid::Uuid;
 
-#[derive(Debug, Error)]
-pub enum EventError {
-    #[error(transparent)]
-    Sqlx(#[from] SqlxError),
-    #[error(transparent)]
-    Parse(#[from] serde_json::Error),
-    #[error(transparent)]
-    Other(#[from] AnyhowError), // Catches any unclassified errors
-}
+#[path = "events/patches.rs"]
+pub mod patches;
+#[path = "events/streams.rs"]
+mod streams;
+#[path = "events/types.rs"]
+pub mod types;
 
-/// Trait for types that can be used in JSON patch operations
-pub trait Patchable: serde::Serialize {
-    const PATH_PREFIX: &'static str;
-    type Id: ToString + Copy;
-    fn id(&self) -> Self::Id;
-}
-
-/// Implementations of Patchable for all supported types
-impl Patchable for TaskWithAttemptStatus {
-    const PATH_PREFIX: &'static str = "/tasks";
-    type Id = Uuid;
-    fn id(&self) -> Self::Id {
-        self.id
-    }
-}
-
-impl Patchable for ExecutionProcess {
-    const PATH_PREFIX: &'static str = "/execution_processes";
-    type Id = Uuid;
-    fn id(&self) -> Self::Id {
-        self.id
-    }
-}
-
-impl Patchable for TaskAttempt {
-    const PATH_PREFIX: &'static str = "/task_attempts";
-    type Id = Uuid;
-    fn id(&self) -> Self::Id {
-        self.id
-    }
-}
-
-impl Patchable for db::models::follow_up_draft::FollowUpDraft {
-    const PATH_PREFIX: &'static str = "/follow_up_drafts";
-    type Id = Uuid;
-    fn id(&self) -> Self::Id {
-        self.id
-    }
-}
-
-/// Generic patch operations that work with any Patchable type
-pub mod patch_ops {
-    use super::*;
-
-    /// Escape JSON Pointer special characters
-    pub(crate) fn escape_pointer_segment(s: &str) -> String {
-        s.replace('~', "~0").replace('/', "~1")
-    }
-
-    /// Create path for operation
-    fn path_for<T: Patchable>(id: T::Id) -> String {
-        format!(
-            "{}/{}",
-            T::PATH_PREFIX,
-            escape_pointer_segment(&id.to_string())
-        )
-    }
-
-    /// Create patch for adding a new record
-    pub fn add<T: Patchable>(value: &T) -> Patch {
-        Patch(vec![PatchOperation::Add(AddOperation {
-            path: path_for::<T>(value.id())
-                .try_into()
-                .expect("Path should be valid"),
-            value: serde_json::to_value(value).expect("Serialization should not fail"),
-        })])
-    }
-
-    /// Create patch for updating an existing record
-    pub fn replace<T: Patchable>(value: &T) -> Patch {
-        Patch(vec![PatchOperation::Replace(ReplaceOperation {
-            path: path_for::<T>(value.id())
-                .try_into()
-                .expect("Path should be valid"),
-            value: serde_json::to_value(value).expect("Serialization should not fail"),
-        })])
-    }
-
-    /// Create patch for removing a record
-    pub fn remove<T: Patchable>(id: T::Id) -> Patch {
-        Patch(vec![PatchOperation::Remove(RemoveOperation {
-            path: path_for::<T>(id).try_into().expect("Path should be valid"),
-        })])
-    }
-}
-
-/// Helper functions for creating task-specific patches
-pub mod task_patch {
-    use super::*;
-
-    /// Create patch for adding a new task
-    pub fn add(task: &TaskWithAttemptStatus) -> Patch {
-        patch_ops::add(task)
-    }
-
-    /// Create patch for updating an existing task
-    pub fn replace(task: &TaskWithAttemptStatus) -> Patch {
-        patch_ops::replace(task)
-    }
-
-    /// Create patch for removing a task
-    pub fn remove(task_id: Uuid) -> Patch {
-        patch_ops::remove::<TaskWithAttemptStatus>(task_id)
-    }
-}
-
-/// Helper functions for creating execution process-specific patches
-pub mod execution_process_patch {
-    use super::*;
-
-    /// Create patch for adding a new execution process
-    pub fn add(process: &ExecutionProcess) -> Patch {
-        patch_ops::add(process)
-    }
-
-    /// Create patch for updating an existing execution process
-    pub fn replace(process: &ExecutionProcess) -> Patch {
-        patch_ops::replace(process)
-    }
-
-    /// Create patch for removing an execution process
-    pub fn remove(process_id: Uuid) -> Patch {
-        patch_ops::remove::<ExecutionProcess>(process_id)
-    }
-}
-
-/// Helper functions for creating task attempt-specific patches
-pub mod task_attempt_patch {
-    use super::*;
-
-    /// Create patch for adding a new task attempt
-    pub fn add(attempt: &TaskAttempt) -> Patch {
-        patch_ops::add(attempt)
-    }
-
-    /// Create patch for updating an existing task attempt
-    pub fn replace(attempt: &TaskAttempt) -> Patch {
-        patch_ops::replace(attempt)
-    }
-
-    /// Create patch for removing a task attempt
-    pub fn remove(attempt_id: Uuid) -> Patch {
-        patch_ops::remove::<TaskAttempt>(attempt_id)
-    }
-}
-
-/// Helper functions for creating follow up draft-specific patches
-pub mod follow_up_draft_patch {
-    use super::*;
-
-    /// Create patch for adding a new follow up draft
-    pub fn add(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch {
-        patch_ops::add(draft)
-    }
-
-    /// Create patch for updating an existing follow up draft
-    pub fn replace(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch {
-        patch_ops::replace(draft)
-    }
-
-    /// Create patch for removing a follow up draft
-    pub fn remove(draft_id: Uuid) -> Patch {
-        patch_ops::remove::<db::models::follow_up_draft::FollowUpDraft>(draft_id)
-    }
-}
+pub use patches::{draft_patch, execution_process_patch, task_attempt_patch, task_patch};
+pub use types::{EventError, EventPatch, EventPatchInner, HookTables, RecordTypes};
 
 #[derive(Clone)]
 pub struct EventService {
@@ -206,40 +33,6 @@ pub struct EventService {
     entry_count: Arc<RwLock<usize>>,
 }
 
-#[derive(EnumString, Display)]
-enum HookTables {
-    #[strum(to_string = "tasks")]
-    Tasks,
-    #[strum(to_string = "task_attempts")]
-    TaskAttempts,
-    #[strum(to_string = "execution_processes")]
-    ExecutionProcesses,
-    #[strum(to_string = "follow_up_drafts")]
-    FollowUpDrafts,
-}
-
-#[derive(Serialize, Deserialize, TS)]
-#[serde(tag = "type", content = "data", rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum RecordTypes {
-    Task(Task),
-    TaskAttempt(TaskAttempt),
-    ExecutionProcess(ExecutionProcess),
-    FollowUpDraft(db::models::follow_up_draft::FollowUpDraft),
-}
-
-#[derive(Serialize, Deserialize, TS)]
-pub struct EventPatchInner {
-    db_op: String,
-    record: RecordTypes,
-}
-
-#[derive(Serialize, Deserialize, TS)]
-pub struct EventPatch {
-    op: String,
-    path: String,
-    value: EventPatchInner,
-}
-
 impl EventService {
     /// Creates a new EventService that will work with a DBService configured with hooks
     pub fn new(db: DBService, msg_store: Arc<MsgStore>, entry_count: Arc<RwLock<usize>>) -> Self {
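The enums deleted here move to the new events/types.rs, re-exported near the top of this file. Judging by the match arms later in this diff, HookTables now tracks the consolidated drafts table instead of follow_up_drafts; a plausible sketch of the relocated enum (the new module itself is not shown in this diff, so attributes and exact shape are assumptions):

    // Hypothetical events/types.rs excerpt:
    #[derive(EnumString, Display)]
    enum HookTables {
        #[strum(to_string = "tasks")]
        Tasks,
        #[strum(to_string = "task_attempts")]
        TaskAttempts,
        #[strum(to_string = "execution_processes")]
        ExecutionProcesses,
        #[strum(to_string = "drafts")]
        Drafts,
    }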
@@ -297,85 +90,67 @@ impl EventService {
         let msg_store_for_hook = msg_store.clone();
         let entry_count_for_hook = entry_count.clone();
         let db_for_hook = db_service.clone();
 
         Box::pin(async move {
             let mut handle = conn.lock_handle().await?;
             let runtime_handle = tokio::runtime::Handle::current();
 
             // Set up preupdate hook to capture task data before deletion
             handle.set_preupdate_hook({
                 let msg_store_for_preupdate = msg_store_for_hook.clone();
                 move |preupdate: sqlx::sqlite::PreupdateHookResult<'_>| {
-                    if preupdate.operation == sqlx::sqlite::SqliteOperation::Delete {
-                        match preupdate.table {
-                            "tasks" => {
-                                // Extract task ID from old column values before deletion
-                                if let Ok(id_value) = preupdate.get_old_column_value(0)
-                                    && !id_value.is_null()
-                                {
-                                    // Decode UUID from SQLite value
-                                    if let Ok(task_id) =
-                                        <uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
-                                            id_value,
-                                        )
-                                    {
-                                        let patch = task_patch::remove(task_id);
-                                        msg_store_for_preupdate.push_patch(patch);
-                                    }
-                                }
-                            }
-                            "execution_processes" => {
-                                // Extract process ID from old column values before deletion
-                                if let Ok(id_value) = preupdate.get_old_column_value(0)
-                                    && !id_value.is_null()
-                                {
-                                    // Decode UUID from SQLite value
-                                    if let Ok(process_id) =
-                                        <uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
-                                            id_value,
-                                        )
-                                    {
-                                        let patch = execution_process_patch::remove(process_id);
-                                        msg_store_for_preupdate.push_patch(patch);
-                                    }
-                                }
-                            }
-                            "task_attempts" => {
-                                // Extract attempt ID from old column values before deletion
-                                if let Ok(id_value) = preupdate.get_old_column_value(0)
-                                    && !id_value.is_null()
-                                {
-                                    // Decode UUID from SQLite value
-                                    if let Ok(attempt_id) =
-                                        <uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
-                                            id_value,
-                                        )
-                                    {
-                                        let patch = task_attempt_patch::remove(attempt_id);
-                                        msg_store_for_preupdate.push_patch(patch);
-                                    }
-                                }
-                            }
-                            "follow_up_drafts" => {
-                                // Extract draft ID from old column values before deletion
-                                if let Ok(id_value) = preupdate.get_old_column_value(0)
-                                    && !id_value.is_null()
-                                {
-                                    // Decode UUID from SQLite value
-                                    if let Ok(draft_id) =
-                                        <uuid::Uuid as sqlx::Decode<'_, sqlx::Sqlite>>::decode(
-                                            id_value,
-                                        )
-                                    {
-                                        let patch = follow_up_draft_patch::remove(draft_id);
-                                        msg_store_for_preupdate.push_patch(patch);
-                                    }
-                                }
-                            }
-                            _ => {
-                                // Ignore other tables
-                            }
-                        }
-                    }
+                    if preupdate.operation != SqliteOperation::Delete {
+                        return;
+                    }
+
+                    match preupdate.table {
+                        "tasks" => {
+                            if let Ok(value) = preupdate.get_old_column_value(0)
+                                && let Ok(task_id) = <Uuid as Decode<Sqlite>>::decode(value)
+                            {
+                                let patch = task_patch::remove(task_id);
+                                msg_store_for_preupdate.push_patch(patch);
+                            }
+                        }
+                        "task_attempts" => {
+                            if let Ok(value) = preupdate.get_old_column_value(0)
+                                && let Ok(attempt_id) = <Uuid as Decode<Sqlite>>::decode(value)
+                            {
+                                let patch = task_attempt_patch::remove(attempt_id);
+                                msg_store_for_preupdate.push_patch(patch);
+                            }
+                        }
+                        "execution_processes" => {
+                            if let Ok(value) = preupdate.get_old_column_value(0)
+                                && let Ok(process_id) = <Uuid as Decode<Sqlite>>::decode(value)
+                            {
+                                let patch = execution_process_patch::remove(process_id);
+                                msg_store_for_preupdate.push_patch(patch);
+                            }
+                        }
+                        "drafts" => {
+                            let draft_type = preupdate
+                                .get_old_column_value(2)
+                                .ok()
+                                .and_then(|val| <String as Decode<Sqlite>>::decode(val).ok())
+                                .and_then(|s| DraftType::from_str(&s).ok());
+                            let task_attempt_id = preupdate
+                                .get_old_column_value(1)
+                                .ok()
+                                .and_then(|val| <Uuid as Decode<Sqlite>>::decode(val).ok());
+
+                            if let (Some(draft_type), Some(task_attempt_id)) =
+                                (draft_type, task_attempt_id)
+                            {
+                                let patch = match draft_type {
+                                    DraftType::FollowUp => {
+                                        draft_patch::follow_up_clear(task_attempt_id)
+                                    }
+                                    DraftType::Retry => {
+                                        draft_patch::retry_clear(task_attempt_id)
+                                    }
+                                };
+                                msg_store_for_preupdate.push_patch(patch);
+                            }
+                        }
+                        _ => {}
+                    }
                 }
             });
@@ -390,28 +165,20 @@ impl EventService {
                 let rowid = hook.rowid;
                 runtime_handle.spawn(async move {
                     let record_type: RecordTypes = match (table, hook.operation.clone()) {
-                        (HookTables::Tasks, SqliteOperation::Delete) => {
-                            // Task deletion is now handled by preupdate hook
-                            // Skip post-update processing to avoid duplicate patches
-                            return;
-                        }
-                        (HookTables::ExecutionProcesses, SqliteOperation::Delete) => {
-                            // Execution process deletion is now handled by preupdate hook
-                            // Skip post-update processing to avoid duplicate patches
-                            return;
-                        }
-                        (HookTables::TaskAttempts, SqliteOperation::Delete) => {
-                            // Task attempt deletion is now handled by preupdate hook
-                            // Skip post-update processing to avoid duplicate patches
-                            return;
-                        }
+                        (HookTables::Tasks, SqliteOperation::Delete)
+                        | (HookTables::TaskAttempts, SqliteOperation::Delete)
+                        | (HookTables::ExecutionProcesses, SqliteOperation::Delete)
+                        | (HookTables::Drafts, SqliteOperation::Delete) => {
+                            // Deletions handled in preupdate hook for reliable data capture
+                            return;
+                        }
                         (HookTables::Tasks, _) => {
                             match Task::find_by_rowid(&db.pool, rowid).await {
                                 Ok(Some(task)) => RecordTypes::Task(task),
-                                Ok(None) => {
-                                    // Row not found - likely already deleted, skip processing
-                                    tracing::debug!("Task rowid {} not found, skipping", rowid);
-                                    return;
-                                }
+                                Ok(None) => RecordTypes::DeletedTask {
+                                    rowid,
+                                    project_id: None,
+                                    task_id: None,
+                                },
                                 Err(e) => {
                                     tracing::error!("Failed to fetch task: {:?}", e);
@@ -422,10 +189,9 @@ impl EventService {
                         (HookTables::TaskAttempts, _) => {
                             match TaskAttempt::find_by_rowid(&db.pool, rowid).await {
                                 Ok(Some(attempt)) => RecordTypes::TaskAttempt(attempt),
-                                Ok(None) => {
-                                    // Row not found - likely already deleted, skip processing
-                                    tracing::debug!("TaskAttempt rowid {} not found, skipping", rowid);
-                                    return;
-                                }
+                                Ok(None) => RecordTypes::DeletedTaskAttempt {
+                                    rowid,
+                                    task_id: None,
+                                },
                                 Err(e) => {
                                     tracing::error!(
@@ -439,10 +205,10 @@ impl EventService {
                         (HookTables::ExecutionProcesses, _) => {
                             match ExecutionProcess::find_by_rowid(&db.pool, rowid).await {
                                 Ok(Some(process)) => RecordTypes::ExecutionProcess(process),
-                                Ok(None) => {
-                                    // Row not found - likely already deleted, skip processing
-                                    tracing::debug!("ExecutionProcess rowid {} not found, skipping", rowid);
-                                    return;
-                                }
+                                Ok(None) => RecordTypes::DeletedExecutionProcess {
+                                    rowid,
+                                    task_attempt_id: None,
+                                    process_id: None,
+                                },
                                 Err(e) => {
                                     tracing::error!(
@@ -453,28 +219,19 @@ impl EventService {
                                 }
                             }
                         }
-                        (HookTables::FollowUpDrafts, SqliteOperation::Delete) => {
-                            // Follow up draft deletion is now handled by preupdate hook
-                            // Skip post-update processing to avoid duplicate patches
-                            return;
-                        }
-                        (HookTables::FollowUpDrafts, _) => {
-                            match db::models::follow_up_draft::FollowUpDraft::find_by_rowid(
-                                &db.pool, rowid,
-                            )
-                            .await
-                            {
-                                Ok(Some(draft)) => RecordTypes::FollowUpDraft(draft),
-                                Ok(None) => {
-                                    // Row not found - likely already deleted, skip processing
-                                    tracing::debug!("FollowUpDraft rowid {} not found, skipping", rowid);
-                                    return;
-                                }
+                        (HookTables::Drafts, _) => {
+                            match Draft::find_by_rowid(&db.pool, rowid).await {
+                                Ok(Some(draft)) => match draft.draft_type {
+                                    DraftType::FollowUp => RecordTypes::Draft(draft),
+                                    DraftType::Retry => RecordTypes::RetryDraft(draft),
+                                },
+                                Ok(None) => RecordTypes::DeletedDraft {
+                                    rowid,
+                                    draft_type: DraftType::Retry,
+                                    task_attempt_id: None,
+                                },
                                 Err(e) => {
-                                    tracing::error!(
-                                        "Failed to fetch follow_up_draft: {:?}",
-                                        e
-                                    );
+                                    tracing::error!("Failed to fetch draft: {:?}", e);
                                     return;
                                 }
                             }
@@ -514,6 +271,33 @@ impl EventService {
                                 return;
                             }
                         }
+                        // Draft updates: emit direct patches used by the follow-up draft stream
+                        RecordTypes::Draft(draft) => {
+                            let patch = draft_patch::follow_up_replace(draft);
+                            msg_store_for_hook.push_patch(patch);
+                            return;
+                        }
+                        RecordTypes::RetryDraft(draft) => {
+                            let patch = draft_patch::retry_replace(draft);
+                            msg_store_for_hook.push_patch(patch);
+                            return;
+                        }
+                        RecordTypes::DeletedDraft { draft_type, task_attempt_id: Some(id), .. } => {
+                            let patch = match draft_type {
+                                DraftType::FollowUp => draft_patch::follow_up_clear(*id),
+                                DraftType::Retry => draft_patch::retry_clear(*id),
+                            };
+                            msg_store_for_hook.push_patch(patch);
+                            return;
+                        }
+                        RecordTypes::DeletedTask {
+                            task_id: Some(task_id),
+                            ..
+                        } => {
+                            let patch = task_patch::remove(*task_id);
+                            msg_store_for_hook.push_patch(patch);
+                            return;
+                        }
                         RecordTypes::TaskAttempt(attempt) => {
                             // Task attempts should update the parent task with fresh data
                             if let Ok(Some(task)) =
@@ -532,6 +316,27 @@ impl EventService {
                                 return;
                             }
                         }
+                        RecordTypes::DeletedTaskAttempt {
+                            task_id: Some(task_id),
+                            ..
+                        } => {
+                            // Task attempt deletion should update the parent task with fresh data
+                            if let Ok(Some(task)) =
+                                Task::find_by_id(&db.pool, *task_id).await
+                                && let Ok(task_list) =
+                                    Task::find_by_project_id_with_attempt_status(
+                                        &db.pool,
+                                        task.project_id,
+                                    )
+                                    .await
+                                && let Some(task_with_status) =
+                                    task_list.into_iter().find(|t| t.id == *task_id)
+                            {
+                                let patch = task_patch::replace(&task_with_status);
+                                msg_store_for_hook.push_patch(patch);
+                                return;
+                            }
+                        }
                         RecordTypes::ExecutionProcess(process) => {
                             let patch = match hook.operation {
                                 SqliteOperation::Insert => {
@@ -559,6 +364,31 @@ impl EventService {
 
                             return;
                         }
+                        RecordTypes::DeletedExecutionProcess {
+                            process_id: Some(process_id),
+                            task_attempt_id,
+                            ..
+                        } => {
+                            let patch = execution_process_patch::remove(*process_id);
+                            msg_store_for_hook.push_patch(patch);
+
+                            if let Some(task_attempt_id) = task_attempt_id
+                                && let Err(err) =
+                                    EventService::push_task_update_for_attempt(
+                                        &db.pool,
+                                        msg_store_for_hook.clone(),
+                                        *task_attempt_id,
+                                    )
+                                    .await
+                            {
+                                tracing::error!(
+                                    "Failed to push task update after execution process removal: {:?}",
+                                    err
+                                );
+                            }
+
+                            return;
+                        }
                         _ => {}
                     }
 
@@ -597,293 +427,4 @@ impl EventService {
     pub fn msg_store(&self) -> &Arc<MsgStore> {
         &self.msg_store
     }
-
-    /// Stream raw task messages for a specific project with initial snapshot
-    pub async fn stream_tasks_raw(
-        &self,
-        project_id: Uuid,
-    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
-    {
-        // Get initial snapshot of tasks
-        let tasks = Task::find_by_project_id_with_attempt_status(&self.db.pool, project_id).await?;
-
-        // Convert task array to object keyed by task ID
-        let tasks_map: serde_json::Map<String, serde_json::Value> = tasks
-            .into_iter()
-            .map(|task| (task.id.to_string(), serde_json::to_value(task).unwrap()))
-            .collect();
-
-        let initial_patch = json!([{
-            "op": "replace",
-            "path": "/tasks",
-            "value": tasks_map
-        }]);
-        let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
-
-        // Clone necessary data for the async filter
-        let db_pool = self.db.pool.clone();
-
-        // Get filtered event stream
-        let filtered_stream =
-            BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| {
-                let db_pool = db_pool.clone();
-                async move {
-                    match msg_result {
-                        Ok(LogMsg::JsonPatch(patch)) => {
-                            // Filter events based on project_id
-                            if let Some(patch_op) = patch.0.first() {
-                                // Check if this is a direct task patch (new format)
-                                if patch_op.path().starts_with("/tasks/") {
-                                    match patch_op {
-                                        json_patch::PatchOperation::Add(op) => {
-                                            // Parse task data directly from value
-                                            if let Ok(task) =
-                                                serde_json::from_value::<TaskWithAttemptStatus>(
-                                                    op.value.clone(),
-                                                )
-                                                && task.project_id == project_id
-                                            {
-                                                return Some(Ok(LogMsg::JsonPatch(patch)));
-                                            }
-                                        }
-                                        json_patch::PatchOperation::Replace(op) => {
-                                            // Parse task data directly from value
-                                            if let Ok(task) =
-                                                serde_json::from_value::<TaskWithAttemptStatus>(
-                                                    op.value.clone(),
-                                                )
-                                                && task.project_id == project_id
-                                            {
-                                                return Some(Ok(LogMsg::JsonPatch(patch)));
-                                            }
-                                        }
-                                        json_patch::PatchOperation::Remove(_) => {
-                                            // For remove operations, we need to check project membership differently
-                                            // We could cache this information or let it pass through for now
-                                            // Since we don't have the task data, we'll allow all removals
-                                            // and let the client handle filtering
-                                            return Some(Ok(LogMsg::JsonPatch(patch)));
-                                        }
-                                        _ => {}
-                                    }
-                                } else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
-                                    && let Ok(event_patch) =
-                                        serde_json::from_value::<EventPatch>(event_patch_value)
-                                {
-                                    // Handle old EventPatch format for non-task records
-                                    match &event_patch.value.record {
-                                        RecordTypes::Task(task) => {
-                                            if task.project_id == project_id {
-                                                return Some(Ok(LogMsg::JsonPatch(patch)));
-                                            }
-                                        }
-                                        RecordTypes::TaskAttempt(attempt) => {
-                                            // Check if this task_attempt belongs to a task in our project
-                                            if let Ok(Some(task)) =
-                                                Task::find_by_id(&db_pool, attempt.task_id).await
-                                                && task.project_id == project_id
-                                            {
-                                                return Some(Ok(LogMsg::JsonPatch(patch)));
-                                            }
-                                        }
-                                        _ => {}
-                                    }
-                                }
-                            }
-                            None
-                        }
-                        Ok(other) => Some(Ok(other)), // Pass through non-patch messages
-                        Err(_) => None, // Filter out broadcast errors
-                    }
-                }
-            });
-
-        // Start with initial snapshot, then live updates
-        let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
-        let combined_stream = initial_stream.chain(filtered_stream).boxed();
-
-        Ok(combined_stream)
-    }
-
-    /// Stream execution processes for a specific task attempt with initial snapshot (raw LogMsg format for WebSocket)
-    pub async fn stream_execution_processes_for_attempt_raw(
-        &self,
-        task_attempt_id: Uuid,
-        show_soft_deleted: bool,
-    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
-    {
-        // Get initial snapshot of execution processes (filtering at SQL level)
-        let processes = ExecutionProcess::find_by_task_attempt_id(
-            &self.db.pool,
-            task_attempt_id,
-            show_soft_deleted,
-        )
-        .await?;
-
-        // Convert processes array to object keyed by process ID
-        let processes_map: serde_json::Map<String, serde_json::Value> = processes
-            .into_iter()
-            .map(|process| {
-                (
-                    process.id.to_string(),
-                    serde_json::to_value(process).unwrap(),
-                )
-            })
-            .collect();
-
-        let initial_patch = json!([{
-            "op": "replace",
-            "path": "/execution_processes",
-            "value": processes_map
-        }]);
-        let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
-
-        // Get filtered event stream
-        let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map(
-            move |msg_result| async move {
-                match msg_result {
-                    Ok(LogMsg::JsonPatch(patch)) => {
-                        // Filter events based on task_attempt_id
-                        if let Some(patch_op) = patch.0.first() {
-                            // Check if this is a modern execution process patch
-                            if patch_op.path().starts_with("/execution_processes/") {
-                                match patch_op {
-                                    json_patch::PatchOperation::Add(op) => {
-                                        // Parse execution process data directly from value
-                                        if let Ok(process) =
-                                            serde_json::from_value::<ExecutionProcess>(
-                                                op.value.clone(),
-                                            )
-                                            && process.task_attempt_id == task_attempt_id
-                                        {
-                                            if !show_soft_deleted && process.dropped {
-                                                return None;
-                                            }
-                                            return Some(Ok(LogMsg::JsonPatch(patch)));
-                                        }
-                                    }
-                                    json_patch::PatchOperation::Replace(op) => {
-                                        // Parse execution process data directly from value
-                                        if let Ok(process) =
-                                            serde_json::from_value::<ExecutionProcess>(
-                                                op.value.clone(),
-                                            )
-                                            && process.task_attempt_id == task_attempt_id
-                                        {
-                                            if !show_soft_deleted && process.dropped {
-                                                let remove_patch =
-                                                    execution_process_patch::remove(process.id);
-                                                return Some(Ok(LogMsg::JsonPatch(remove_patch)));
-                                            }
-                                            return Some(Ok(LogMsg::JsonPatch(patch)));
-                                        }
-                                    }
-                                    json_patch::PatchOperation::Remove(_) => {
-                                        // For remove operations, we can't verify task_attempt_id
-                                        // so we allow all removals and let the client handle filtering
-                                        return Some(Ok(LogMsg::JsonPatch(patch)));
-                                    }
-                                    _ => {}
-                                }
-                            }
-                            // Fallback to legacy EventPatch format for backward compatibility
-                            else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
-                                && let Ok(event_patch) =
-                                    serde_json::from_value::<EventPatch>(event_patch_value)
-                                && let RecordTypes::ExecutionProcess(process) =
-                                    &event_patch.value.record
-                                && process.task_attempt_id == task_attempt_id
-                            {
-                                if !show_soft_deleted && process.dropped {
-                                    let remove_patch = execution_process_patch::remove(process.id);
-                                    return Some(Ok(LogMsg::JsonPatch(remove_patch)));
-                                }
-                                return Some(Ok(LogMsg::JsonPatch(patch)));
-                            }
-                        }
-                        None
-                    }
-                    Ok(other) => Some(Ok(other)), // Pass through non-patch messages
-                    Err(_) => None, // Filter out broadcast errors
-                }
-            },
-        );
-
-        // Start with initial snapshot, then live updates
-        let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
-        let combined_stream = initial_stream.chain(filtered_stream).boxed();
-
-        Ok(combined_stream)
-    }
-
-    /// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket)
-    pub async fn stream_follow_up_draft_for_attempt_raw(
-        &self,
-        task_attempt_id: Uuid,
-    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
-    {
-        // Get initial snapshot of follow-up draft
-        let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id(
-            &self.db.pool,
-            task_attempt_id,
-        )
-        .await?
-        .unwrap_or(db::models::follow_up_draft::FollowUpDraft {
-            id: uuid::Uuid::new_v4(),
-            task_attempt_id,
-            prompt: String::new(),
-            queued: false,
-            sending: false,
-            variant: None,
-            image_ids: None,
-            created_at: chrono::Utc::now(),
-            updated_at: chrono::Utc::now(),
-            version: 0,
-        });
-
-        let initial_patch = json!([
-            {
-                "op": "replace",
-                "path": "/",
-                "value": { "follow_up_draft": draft }
-            }
-        ]);
-        let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
-
-        // Filtered live stream, mapped into direct JSON patches that update /follow_up_draft
-        let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map(
-            move |msg_result| async move {
-                match msg_result {
-                    Ok(LogMsg::JsonPatch(patch)) => {
-                        if let Some(event_patch_op) = patch.0.first()
-                            && let Ok(event_patch_value) = serde_json::to_value(event_patch_op)
-                            && let Ok(event_patch) =
-                                serde_json::from_value::<EventPatch>(event_patch_value)
-                            && let RecordTypes::FollowUpDraft(draft) = &event_patch.value.record
-                            && draft.task_attempt_id == task_attempt_id
-                        {
-                            // Build a direct patch to replace /follow_up_draft
-                            let direct = json!([
-                                {
-                                    "op": "replace",
-                                    "path": "/follow_up_draft",
-                                    "value": draft
-                                }
-                            ]);
-                            let direct_patch = serde_json::from_value(direct).unwrap();
-                            return Some(Ok(LogMsg::JsonPatch(direct_patch)));
-                        }
-                        None
-                    }
-                    Ok(other) => Some(Ok(other)),
-                    Err(_) => None,
-                }
-            },
-        );
-
-        let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
-        let combined_stream = initial_stream.chain(filtered_stream).boxed();
-
-        Ok(combined_stream)
-    }
 }
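The Deleted* and Draft variants matched in the hunks above are not defined anywhere in this diff; they live in the new events/types.rs, which is not shown. From the struct patterns and constructors visible here, the reshaped enum plausibly looks like the sketch below; field order, the rowid type, and the serde attributes are assumptions carried over from the deleted definition:

    // Hypothetical reconstruction of the RecordTypes enum in events/types.rs:
    #[derive(Serialize, Deserialize, TS)]
    #[serde(tag = "type", content = "data", rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum RecordTypes {
        Task(Task),
        DeletedTask { rowid: i64, project_id: Option<Uuid>, task_id: Option<Uuid> },
        TaskAttempt(TaskAttempt),
        DeletedTaskAttempt { rowid: i64, task_id: Option<Uuid> },
        ExecutionProcess(ExecutionProcess),
        DeletedExecutionProcess { rowid: i64, task_attempt_id: Option<Uuid>, process_id: Option<Uuid> },
        Draft(Draft),
        RetryDraft(Draft),
        DeletedDraft { rowid: i64, draft_type: DraftType, task_attempt_id: Option<Uuid> },
    }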
crates/services/src/services/events/patches.rs (new file, 204 lines)
@@ -0,0 +1,204 @@
use db::models::{
    draft::{Draft, DraftType},
    execution_process::ExecutionProcess,
    task::TaskWithAttemptStatus,
    task_attempt::TaskAttempt,
};
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
use uuid::Uuid;

// Shared helper to escape JSON Pointer segments
fn escape_pointer_segment(s: &str) -> String {
    s.replace('~', "~0").replace('/', "~1")
}

/// Helper functions for creating task-specific patches
pub mod task_patch {
    use super::*;

    fn task_path(task_id: Uuid) -> String {
        format!("/tasks/{}", escape_pointer_segment(&task_id.to_string()))
    }

    /// Create patch for adding a new task
    pub fn add(task: &TaskWithAttemptStatus) -> Patch {
        Patch(vec![PatchOperation::Add(AddOperation {
            path: task_path(task.id)
                .try_into()
                .expect("Task path should be valid"),
            value: serde_json::to_value(task).expect("Task serialization should not fail"),
        })])
    }

    /// Create patch for updating an existing task
    pub fn replace(task: &TaskWithAttemptStatus) -> Patch {
        Patch(vec![PatchOperation::Replace(ReplaceOperation {
            path: task_path(task.id)
                .try_into()
                .expect("Task path should be valid"),
            value: serde_json::to_value(task).expect("Task serialization should not fail"),
        })])
    }

    /// Create patch for removing a task
    pub fn remove(task_id: Uuid) -> Patch {
        Patch(vec![PatchOperation::Remove(RemoveOperation {
            path: task_path(task_id)
                .try_into()
                .expect("Task path should be valid"),
        })])
    }
}

/// Helper functions for creating execution process-specific patches
pub mod execution_process_patch {
    use super::*;

    fn execution_process_path(process_id: Uuid) -> String {
        format!(
            "/execution_processes/{}",
            escape_pointer_segment(&process_id.to_string())
        )
    }

    /// Create patch for adding a new execution process
    pub fn add(process: &ExecutionProcess) -> Patch {
        Patch(vec![PatchOperation::Add(AddOperation {
            path: execution_process_path(process.id)
                .try_into()
                .expect("Execution process path should be valid"),
            value: serde_json::to_value(process)
                .expect("Execution process serialization should not fail"),
        })])
    }

    /// Create patch for updating an existing execution process
    pub fn replace(process: &ExecutionProcess) -> Patch {
        Patch(vec![PatchOperation::Replace(ReplaceOperation {
            path: execution_process_path(process.id)
                .try_into()
                .expect("Execution process path should be valid"),
            value: serde_json::to_value(process)
                .expect("Execution process serialization should not fail"),
        })])
    }

    /// Create patch for removing an execution process
    pub fn remove(process_id: Uuid) -> Patch {
        Patch(vec![PatchOperation::Remove(RemoveOperation {
            path: execution_process_path(process_id)
                .try_into()
                .expect("Execution process path should be valid"),
        })])
    }
}

/// Helper functions for creating draft-specific patches
pub mod draft_patch {
    use super::*;

    fn follow_up_path(attempt_id: Uuid) -> String {
        format!("/drafts/{attempt_id}/follow_up")
    }

    fn retry_path(attempt_id: Uuid) -> String {
        format!("/drafts/{attempt_id}/retry")
    }

    /// Replace the follow-up draft for a specific attempt
    pub fn follow_up_replace(draft: &Draft) -> Patch {
        Patch(vec![PatchOperation::Replace(ReplaceOperation {
            path: follow_up_path(draft.task_attempt_id)
                .try_into()
                .expect("Path should be valid"),
            value: serde_json::to_value(draft).expect("Draft serialization should not fail"),
        })])
    }

    /// Replace the retry draft for a specific attempt
    pub fn retry_replace(draft: &Draft) -> Patch {
        Patch(vec![PatchOperation::Replace(ReplaceOperation {
            path: retry_path(draft.task_attempt_id)
                .try_into()
                .expect("Path should be valid"),
            value: serde_json::to_value(draft).expect("Draft serialization should not fail"),
        })])
    }

    /// Clear the follow-up draft for an attempt (replace with an empty draft)
    pub fn follow_up_clear(attempt_id: Uuid) -> Patch {
        let empty = Draft {
            id: uuid::Uuid::new_v4(),
            task_attempt_id: attempt_id,
            draft_type: DraftType::FollowUp,
            retry_process_id: None,
            prompt: String::new(),
            queued: false,
            sending: false,
            variant: None,
            image_ids: None,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            version: 0,
        };
        Patch(vec![PatchOperation::Replace(ReplaceOperation {
            path: follow_up_path(attempt_id)
                .try_into()
                .expect("Path should be valid"),
            value: serde_json::to_value(empty).expect("Draft serialization should not fail"),
        })])
    }

    /// Clear the retry draft for an attempt (set to null)
    pub fn retry_clear(attempt_id: Uuid) -> Patch {
        Patch(vec![PatchOperation::Replace(ReplaceOperation {
            path: retry_path(attempt_id)
                .try_into()
                .expect("Path should be valid"),
            value: serde_json::Value::Null,
        })])
    }
}

/// Helper functions for creating task attempt-specific patches
pub mod task_attempt_patch {
    use super::*;

    fn attempt_path(attempt_id: Uuid) -> String {
        format!(
            "/task_attempts/{}",
            escape_pointer_segment(&attempt_id.to_string())
        )
    }

    /// Create patch for adding a new task attempt
    pub fn add(attempt: &TaskAttempt) -> Patch {
        Patch(vec![PatchOperation::Add(AddOperation {
            path: attempt_path(attempt.id)
                .try_into()
                .expect("Task attempt path should be valid"),
            value: serde_json::to_value(attempt)
                .expect("Task attempt serialization should not fail"),
        })])
    }

    /// Create patch for updating an existing task attempt
    pub fn replace(attempt: &TaskAttempt) -> Patch {
        Patch(vec![PatchOperation::Replace(ReplaceOperation {
            path: attempt_path(attempt.id)
                .try_into()
                .expect("Task attempt path should be valid"),
            value: serde_json::to_value(attempt)
                .expect("Task attempt serialization should not fail"),
        })])
    }

    /// Create patch for removing a task attempt
    pub fn remove(attempt_id: Uuid) -> Patch {
        Patch(vec![PatchOperation::Remove(RemoveOperation {
            path: attempt_path(attempt_id)
                .try_into()
                .expect("Task attempt path should be valid"),
        })])
    }
}
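Note that the draft patch paths are keyed by attempt ID rather than by draft ID, so clients can subscribe to a fixed location per attempt. A hedged usage sketch (attempt_id and msg_store are hypothetical caller-side bindings):

    // Clearing a retry draft emits a single replace op that sets the fixed
    // per-attempt path to null; the comment shows the approximate JSON shape.
    let patch = draft_patch::retry_clear(attempt_id);
    // patch ≈ [{ "op": "replace", "path": "/drafts/<attempt_id>/retry", "value": null }]
    msg_store.push_patch(patch);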
crates/services/src/services/events/streams.rs (new file, 374 lines)
@@ -0,0 +1,374 @@
|
||||
use db::models::{
|
||||
draft::{Draft, DraftType},
|
||||
    execution_process::ExecutionProcess,
    task::{Task, TaskWithAttemptStatus},
};
use futures::StreamExt;
use serde_json::json;
use tokio_stream::wrappers::BroadcastStream;
use utils::log_msg::LogMsg;
use uuid::Uuid;

use super::{
    EventService,
    patches::execution_process_patch,
    types::{EventError, EventPatch, RecordTypes},
};

impl EventService {
    /// Stream raw task messages for a specific project with initial snapshot
    pub async fn stream_tasks_raw(
        &self,
        project_id: Uuid,
    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
    {
        // Get initial snapshot of tasks
        let tasks =
            Task::find_by_project_id_with_attempt_status(&self.db.pool, project_id).await?;

        // Convert task array to object keyed by task ID
        let tasks_map: serde_json::Map<String, serde_json::Value> = tasks
            .into_iter()
            .map(|task| (task.id.to_string(), serde_json::to_value(task).unwrap()))
            .collect();

        let initial_patch = json!([{
            "op": "replace",
            "path": "/tasks",
            "value": tasks_map
        }]);
        let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());

        // Clone necessary data for the async filter
        let db_pool = self.db.pool.clone();

        // Get filtered event stream
        let filtered_stream =
            BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| {
                let db_pool = db_pool.clone();
                async move {
                    match msg_result {
                        Ok(LogMsg::JsonPatch(patch)) => {
                            // Filter events based on project_id
                            if let Some(patch_op) = patch.0.first() {
                                // Check if this is a direct task patch (new format)
                                if patch_op.path().starts_with("/tasks/") {
                                    match patch_op {
                                        json_patch::PatchOperation::Add(op) => {
                                            // Parse task data directly from value
                                            if let Ok(task) =
                                                serde_json::from_value::<TaskWithAttemptStatus>(
                                                    op.value.clone(),
                                                )
                                                && task.project_id == project_id
                                            {
                                                return Some(Ok(LogMsg::JsonPatch(patch)));
                                            }
                                        }
                                        json_patch::PatchOperation::Replace(op) => {
                                            // Parse task data directly from value
                                            if let Ok(task) =
                                                serde_json::from_value::<TaskWithAttemptStatus>(
                                                    op.value.clone(),
                                                )
                                                && task.project_id == project_id
                                            {
                                                return Some(Ok(LogMsg::JsonPatch(patch)));
                                            }
                                        }
                                        json_patch::PatchOperation::Remove(_) => {
                                            // For remove operations, we need to check project membership differently.
                                            // We could cache this information or let it pass through for now.
                                            // Since we don't have the task data, we'll allow all removals
                                            // and let the client handle filtering.
                                            return Some(Ok(LogMsg::JsonPatch(patch)));
                                        }
                                        _ => {}
                                    }
                                } else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
                                    && let Ok(event_patch) =
                                        serde_json::from_value::<EventPatch>(event_patch_value)
                                {
                                    // Handle old EventPatch format for non-task records
                                    match &event_patch.value.record {
                                        RecordTypes::Task(task) => {
                                            if task.project_id == project_id {
                                                return Some(Ok(LogMsg::JsonPatch(patch)));
                                            }
                                        }
                                        RecordTypes::DeletedTask {
                                            project_id: Some(deleted_project_id),
                                            ..
                                        } => {
                                            if *deleted_project_id == project_id {
                                                return Some(Ok(LogMsg::JsonPatch(patch)));
                                            }
                                        }
                                        RecordTypes::TaskAttempt(attempt) => {
                                            // Check if this task_attempt belongs to a task in our project
                                            if let Ok(Some(task)) =
                                                Task::find_by_id(&db_pool, attempt.task_id).await
                                                && task.project_id == project_id
                                            {
                                                return Some(Ok(LogMsg::JsonPatch(patch)));
                                            }
                                        }
                                        RecordTypes::DeletedTaskAttempt {
                                            task_id: Some(deleted_task_id),
                                            ..
                                        } => {
                                            // Check if the deleted attempt belonged to a task in our project
                                            if let Ok(Some(task)) =
                                                Task::find_by_id(&db_pool, *deleted_task_id).await
                                                && task.project_id == project_id
                                            {
                                                return Some(Ok(LogMsg::JsonPatch(patch)));
                                            }
                                        }
                                        _ => {}
                                    }
                                }
                            }
                            None
                        }
                        Ok(other) => Some(Ok(other)), // Pass through non-patch messages
                        Err(_) => None,               // Filter out broadcast errors
                    }
                }
            });

        // Start with initial snapshot, then live updates
        let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
        let combined_stream = initial_stream.chain(filtered_stream).boxed();

        Ok(combined_stream)
    }
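
    // A minimal consumption sketch (hypothetical caller, not part of this diff):
    // the stream yields the `/tasks` snapshot patch first, then only live
    // patches whose task belongs to `project_id`.
    //
    // let mut stream = event_service.stream_tasks_raw(project_id).await?;
    // while let Some(msg) = stream.next().await {
    //     if let Ok(LogMsg::JsonPatch(patch)) = msg {
    //         // forward the patch, e.g. over an SSE or WebSocket sink
    //     }
    // }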

    /// Stream execution processes for a specific task attempt with initial snapshot (raw LogMsg format for WebSocket)
    pub async fn stream_execution_processes_for_attempt_raw(
        &self,
        task_attempt_id: Uuid,
        show_soft_deleted: bool,
    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
    {
        // Get initial snapshot of execution processes (filtering at SQL level)
        let processes = ExecutionProcess::find_by_task_attempt_id(
            &self.db.pool,
            task_attempt_id,
            show_soft_deleted,
        )
        .await?;

        // Convert processes array to object keyed by process ID
        let processes_map: serde_json::Map<String, serde_json::Value> = processes
            .into_iter()
            .map(|process| {
                (
                    process.id.to_string(),
                    serde_json::to_value(process).unwrap(),
                )
            })
            .collect();

        let initial_patch = json!([{
            "op": "replace",
            "path": "/execution_processes",
            "value": processes_map
        }]);
        let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());

        // Get filtered event stream
        let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map(
            move |msg_result| async move {
                match msg_result {
                    Ok(LogMsg::JsonPatch(patch)) => {
                        // Filter events based on task_attempt_id
                        if let Some(patch_op) = patch.0.first() {
                            // Check if this is a modern execution process patch
                            if patch_op.path().starts_with("/execution_processes/") {
                                match patch_op {
                                    json_patch::PatchOperation::Add(op) => {
                                        // Parse execution process data directly from value
                                        if let Ok(process) =
                                            serde_json::from_value::<ExecutionProcess>(
                                                op.value.clone(),
                                            )
                                            && process.task_attempt_id == task_attempt_id
                                        {
                                            if !show_soft_deleted && process.dropped {
                                                let remove_patch =
                                                    execution_process_patch::remove(process.id);
                                                return Some(Ok(LogMsg::JsonPatch(remove_patch)));
                                            }
                                            return Some(Ok(LogMsg::JsonPatch(patch)));
                                        }
                                    }
                                    json_patch::PatchOperation::Replace(op) => {
                                        // Parse execution process data directly from value
                                        if let Ok(process) =
                                            serde_json::from_value::<ExecutionProcess>(
                                                op.value.clone(),
                                            )
                                            && process.task_attempt_id == task_attempt_id
                                        {
                                            if !show_soft_deleted && process.dropped {
                                                let remove_patch =
                                                    execution_process_patch::remove(process.id);
                                                return Some(Ok(LogMsg::JsonPatch(remove_patch)));
                                            }
                                            return Some(Ok(LogMsg::JsonPatch(patch)));
                                        }
                                    }
                                    json_patch::PatchOperation::Remove(_) => {
                                        // For remove operations we can't verify task_attempt_id,
                                        // so we allow all removals and let the client handle filtering
                                        return Some(Ok(LogMsg::JsonPatch(patch)));
                                    }
                                    _ => {}
                                }
                            }
                            // Fall back to the legacy EventPatch format for backward compatibility
                            else if let Ok(event_patch_value) = serde_json::to_value(patch_op)
                                && let Ok(event_patch) =
                                    serde_json::from_value::<EventPatch>(event_patch_value)
                            {
                                match &event_patch.value.record {
                                    RecordTypes::ExecutionProcess(process) => {
                                        if process.task_attempt_id == task_attempt_id {
                                            if !show_soft_deleted && process.dropped {
                                                let remove_patch =
                                                    execution_process_patch::remove(process.id);
                                                return Some(Ok(LogMsg::JsonPatch(remove_patch)));
                                            }
                                            return Some(Ok(LogMsg::JsonPatch(patch)));
                                        }
                                    }
                                    RecordTypes::DeletedExecutionProcess {
                                        task_attempt_id: Some(deleted_attempt_id),
                                        ..
                                    } => {
                                        if *deleted_attempt_id == task_attempt_id {
                                            return Some(Ok(LogMsg::JsonPatch(patch)));
                                        }
                                    }
                                    _ => {}
                                }
                            }
                        }
                        None
                    }
                    Ok(other) => Some(Ok(other)), // Pass through non-patch messages
                    Err(_) => None,               // Filter out broadcast errors
                }
            },
        );

        // Start with initial snapshot, then live updates
        let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
        let combined_stream = initial_stream.chain(filtered_stream).boxed();

        Ok(combined_stream)
    }
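
    // Sketch of the soft-delete behaviour above (assumed client contract):
    // when `show_soft_deleted` is false and an Add/Replace patch arrives for a
    // process with `dropped == true`, the patch is rewritten into a remove
    // patch via `execution_process_patch::remove`, so clients that never asked
    // for soft-deleted rows see the process disappear rather than receiving
    // its dropped state.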

    /// Stream drafts for all task attempts in a project with initial snapshot (raw LogMsg)
    pub async fn stream_drafts_for_project_raw(
        &self,
        project_id: Uuid,
    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
    {
        // Load all attempt ids for tasks in this project
        let attempt_ids: Vec<Uuid> = sqlx::query_scalar(
            r#"SELECT ta.id
               FROM task_attempts ta
               JOIN tasks t ON t.id = ta.task_id
               WHERE t.project_id = ?"#,
        )
        .bind(project_id)
        .fetch_all(&self.db.pool)
        .await?;

        // Build initial drafts map keyed by attempt_id
        let mut drafts_map: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
        for attempt_id in attempt_ids {
            let fu = Draft::find_by_task_attempt_and_type(
                &self.db.pool,
                attempt_id,
                DraftType::FollowUp,
            )
            .await?
            .unwrap_or(Draft {
                id: uuid::Uuid::new_v4(),
                task_attempt_id: attempt_id,
                draft_type: DraftType::FollowUp,
                retry_process_id: None,
                prompt: String::new(),
                queued: false,
                sending: false,
                variant: None,
                image_ids: None,
                created_at: chrono::Utc::now(),
                updated_at: chrono::Utc::now(),
                version: 0,
            });
            let re =
                Draft::find_by_task_attempt_and_type(&self.db.pool, attempt_id, DraftType::Retry)
                    .await?;
            let entry = json!({
                "follow_up": fu,
                "retry": serde_json::to_value(re).unwrap_or(serde_json::Value::Null),
            });
            drafts_map.insert(attempt_id.to_string(), entry);
        }

        let initial_patch = json!([
            {
                "op": "replace",
                "path": "/drafts",
                "value": drafts_map
            }
        ]);
        let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());

        let db_pool = self.db.pool.clone();
        // Live updates: accept direct draft patches and filter by project membership
        let filtered_stream =
            BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| {
                let db_pool = db_pool.clone();
                async move {
                    match msg_result {
                        Ok(LogMsg::JsonPatch(patch)) => {
                            if let Some(op) = patch.0.first() {
                                let path = op.path();
                                if let Some(rest) = path.strip_prefix("/drafts/")
                                    && let Some((attempt_str, _)) = rest.split_once('/')
                                    && let Ok(attempt_id) = Uuid::parse_str(attempt_str)
                                {
                                    // Check project membership
                                    if let Ok(Some(task_attempt)) =
                                        db::models::task_attempt::TaskAttempt::find_by_id(
                                            &db_pool, attempt_id,
                                        )
                                        .await
                                        && let Ok(Some(task)) = db::models::task::Task::find_by_id(
                                            &db_pool,
                                            task_attempt.task_id,
                                        )
                                        .await
                                        && task.project_id == project_id
                                    {
                                        return Some(Ok(LogMsg::JsonPatch(patch)));
                                    }
                                }
                            }
                            None
                        }
                        Ok(other) => Some(Ok(other)),
                        Err(_) => None,
                    }
                }
            });

        let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
        let combined_stream = initial_stream.chain(filtered_stream).boxed();
        Ok(combined_stream)
    }
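
    // Shape of each `/drafts/{attempt_id}` entry built above (derived from the
    // snapshot construction; values illustrative): the `follow_up` draft is
    // always present, falling back to an empty placeholder, while `retry` may
    // be null.
    //
    // { "follow_up": { "prompt": "", "queued": false, /* ... */ }, "retry": null }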
}
77
crates/services/src/services/events/types.rs
Normal file
@@ -0,0 +1,77 @@
use anyhow::Error as AnyhowError;
use db::models::{
    draft::{Draft, DraftType},
    execution_process::ExecutionProcess,
    task::Task,
    task_attempt::TaskAttempt,
};
use serde::{Deserialize, Serialize};
use sqlx::Error as SqlxError;
use strum_macros::{Display, EnumString};
use thiserror::Error;
use ts_rs::TS;
use uuid::Uuid;

#[derive(Debug, Error)]
pub enum EventError {
    #[error(transparent)]
    Sqlx(#[from] SqlxError),
    #[error(transparent)]
    Parse(#[from] serde_json::Error),
    #[error(transparent)]
    Other(#[from] AnyhowError), // Catches any unclassified errors
}

#[derive(EnumString, Display)]
pub enum HookTables {
    #[strum(to_string = "tasks")]
    Tasks,
    #[strum(to_string = "task_attempts")]
    TaskAttempts,
    #[strum(to_string = "execution_processes")]
    ExecutionProcesses,
    #[strum(to_string = "drafts")]
    Drafts,
}
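
// Round-trip sketch (illustrative; relies on the strum derives above, where
// the `to_string` attribute drives `Display` and is also matched by `FromStr`):
//
// assert_eq!(HookTables::Drafts.to_string(), "drafts");
// assert!(matches!("drafts".parse::<HookTables>(), Ok(HookTables::Drafts)));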

#[derive(Serialize, Deserialize, TS)]
#[serde(tag = "type", content = "data", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum RecordTypes {
    Task(Task),
    TaskAttempt(TaskAttempt),
    ExecutionProcess(ExecutionProcess),
    Draft(Draft),
    RetryDraft(Draft),
    DeletedTask {
        rowid: i64,
        project_id: Option<Uuid>,
        task_id: Option<Uuid>,
    },
    DeletedTaskAttempt {
        rowid: i64,
        task_id: Option<Uuid>,
    },
    DeletedExecutionProcess {
        rowid: i64,
        task_attempt_id: Option<Uuid>,
        process_id: Option<Uuid>,
    },
    DeletedDraft {
        rowid: i64,
        draft_type: DraftType,
        task_attempt_id: Option<Uuid>,
    },
}

#[derive(Serialize, Deserialize, TS)]
pub struct EventPatchInner {
    pub(crate) db_op: String,
    pub(crate) record: RecordTypes,
}

#[derive(Serialize, Deserialize, TS)]
pub struct EventPatch {
    pub(crate) op: String,
    pub(crate) path: String,
    pub(crate) value: EventPatchInner,
}
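
// Wire-shape sketch for the adjacently tagged enum above (illustrative JSON):
// `RecordTypes::Task(..)` serializes as
//     {"type":"TASK","data":{ /* task fields */ }}
// and `RecordTypes::DeletedTask { .. }` as
//     {"type":"DELETED_TASK","data":{"rowid":1,"project_id":null,"task_id":null}}
// with variant names produced by `rename_all = "SCREAMING_SNAKE_CASE"`.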

@@ -89,6 +89,36 @@ impl std::fmt::Display for Commit {
    }
}

#[derive(Debug, Clone, Copy)]
pub struct WorktreeResetOptions {
    pub perform_reset: bool,
    pub force_when_dirty: bool,
    pub is_dirty: bool,
    pub log_skip_when_dirty: bool,
}

impl WorktreeResetOptions {
    pub fn new(
        perform_reset: bool,
        force_when_dirty: bool,
        is_dirty: bool,
        log_skip_when_dirty: bool,
    ) -> Self {
        Self {
            perform_reset,
            force_when_dirty,
            is_dirty,
            log_skip_when_dirty,
        }
    }
}

#[derive(Debug, Default, Clone, Copy)]
pub struct WorktreeResetOutcome {
    pub needed: bool,
    pub applied: bool,
}

/// Target for diff generation
pub enum DiffTarget<'p> {
    /// Work-in-progress branch checked out in this worktree

@@ -1074,6 +1104,47 @@ impl GitService {
        .map_err(|e| GitServiceError::InvalidRepository(format!("git status failed: {e}")))
    }

    /// Evaluate whether any action is needed to reset to `target_commit_oid` and
    /// optionally perform the actions.
    pub fn reconcile_worktree_to_commit(
        &self,
        worktree_path: &Path,
        target_commit_oid: &str,
        options: WorktreeResetOptions,
    ) -> WorktreeResetOutcome {
        let WorktreeResetOptions {
            perform_reset,
            force_when_dirty,
            is_dirty,
            log_skip_when_dirty,
        } = options;

        let head_oid = self.get_head_info(worktree_path).ok().map(|h| h.oid);
        let mut outcome = WorktreeResetOutcome::default();

        if head_oid.as_deref() != Some(target_commit_oid) || is_dirty {
            outcome.needed = true;

            if perform_reset {
                if is_dirty && !force_when_dirty {
                    if log_skip_when_dirty {
                        tracing::warn!("Worktree dirty; skipping reset as not forced");
                    }
                } else if let Err(e) = self.reset_worktree_to_commit(
                    worktree_path,
                    target_commit_oid,
                    force_when_dirty,
                ) {
                    tracing::error!("Failed to reset worktree: {}", e);
                } else {
                    outcome.applied = true;
                }
            }
        }

        outcome
    }
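
    // A minimal usage sketch (hypothetical call site, not part of this diff):
    // evaluate and apply in one call; `needed && !applied` signals that the
    // worktree diverged but the reset was skipped or failed.
    //
    // let outcome = git_service.reconcile_worktree_to_commit(
    //     &worktree_path,
    //     &target_oid,
    //     WorktreeResetOptions::new(true, false, is_dirty, true),
    // );
    // if outcome.needed && !outcome.applied {
    //     // worktree was dirty (not forced) or the reset errored
    // }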

    /// Reset the given worktree to the specified commit SHA.
    /// If `force` is false and the worktree is dirty, returns WorktreeDirty error.
    pub fn reset_worktree_to_commit(

@@ -3,6 +3,7 @@ pub mod approvals;
pub mod auth;
pub mod config;
pub mod container;
pub mod drafts;
pub mod events;
pub mod file_ranker;
pub mod file_search_cache;