diff --git a/crates/db/.sqlx/query-417b6e6333eb2164b4cb1d9869cf786f34fa0219b30461234c47a869945c2a79.json b/crates/db/.sqlx/query-417b6e6333eb2164b4cb1d9869cf786f34fa0219b30461234c47a869945c2a79.json deleted file mode 100644 index d6c84afb..00000000 --- a/crates/db/.sqlx/query-417b6e6333eb2164b4cb1d9869cf786f34fa0219b30461234c47a869945c2a79.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "db_name": "SQLite", - "query": "INSERT INTO follow_up_drafts (id, task_attempt_id, prompt, queued, variant, image_ids)\n VALUES ($1, $2, $3, $4, $5, $6)\n ON CONFLICT(task_attempt_id) DO UPDATE SET\n prompt = excluded.prompt,\n queued = excluded.queued,\n variant = excluded.variant,\n image_ids = excluded.image_ids\n RETURNING \n id as \"id!: Uuid\",\n task_attempt_id as \"task_attempt_id!: Uuid\",\n prompt as \"prompt!: String\",\n queued as \"queued!: bool\",\n sending as \"sending!: bool\",\n variant,\n image_ids as \"image_ids?: String\",\n created_at as \"created_at!: DateTime\",\n updated_at as \"updated_at!: DateTime\",\n version as \"version!: i64\"", - "describe": { - "columns": [ - { - "name": "id!: Uuid", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "task_attempt_id!: Uuid", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "prompt!: String", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "queued!: bool", - "ordinal": 3, - "type_info": "Integer" - }, - { - "name": "sending!: bool", - "ordinal": 4, - "type_info": "Integer" - }, - { - "name": "variant", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "image_ids?: String", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at!: DateTime", - "ordinal": 7, - "type_info": "Datetime" - }, - { - "name": "updated_at!: DateTime", - "ordinal": 8, - "type_info": "Datetime" - }, - { - "name": "version!: i64", - "ordinal": 9, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 6 - }, - "nullable": [ - true, - false, - false, - false, - false, - true, - true, - false, - false, - false - ] - }, - "hash": "417b6e6333eb2164b4cb1d9869cf786f34fa0219b30461234c47a869945c2a79" -} diff --git a/crates/db/.sqlx/query-457ee97807c0e4cc329e1c3b8b765b85ec3d0eb90aa38f1f891e7ec9308278e9.json b/crates/db/.sqlx/query-457ee97807c0e4cc329e1c3b8b765b85ec3d0eb90aa38f1f891e7ec9308278e9.json new file mode 100644 index 00000000..3c9656cd --- /dev/null +++ b/crates/db/.sqlx/query-457ee97807c0e4cc329e1c3b8b765b85ec3d0eb90aa38f1f891e7ec9308278e9.json @@ -0,0 +1,86 @@ +{ + "db_name": "SQLite", + "query": "INSERT INTO drafts (id, task_attempt_id, draft_type, retry_process_id, prompt, queued, variant, image_ids)\n VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n ON CONFLICT(task_attempt_id, draft_type) DO UPDATE SET\n retry_process_id = excluded.retry_process_id,\n prompt = excluded.prompt,\n queued = excluded.queued,\n variant = excluded.variant,\n image_ids = excluded.image_ids,\n version = drafts.version + 1\n RETURNING\n id as \"id!: Uuid\",\n task_attempt_id as \"task_attempt_id!: Uuid\",\n draft_type,\n retry_process_id as \"retry_process_id?: Uuid\",\n prompt,\n queued as \"queued!: bool\",\n sending as \"sending!: bool\",\n variant,\n image_ids,\n created_at as \"created_at!: DateTime\",\n updated_at as \"updated_at!: DateTime\",\n version as \"version!: i64\"", + "describe": { + "columns": [ + { + "name": "id!: Uuid", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "task_attempt_id!: Uuid", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "draft_type", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "retry_process_id?: Uuid", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "prompt", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "queued!: bool", + "ordinal": 5, + "type_info": "Integer" + }, + { + "name": "sending!: bool", + "ordinal": 6, + "type_info": "Integer" + }, + { + "name": "variant", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "image_ids", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "created_at!: DateTime", + "ordinal": 9, + "type_info": "Datetime" + }, + { + "name": "updated_at!: DateTime", + "ordinal": 10, + "type_info": "Datetime" + }, + { + "name": "version!: i64", + "ordinal": 11, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 8 + }, + "nullable": [ + true, + false, + false, + true, + false, + false, + false, + true, + true, + false, + false, + false + ] + }, + "hash": "457ee97807c0e4cc329e1c3b8b765b85ec3d0eb90aa38f1f891e7ec9308278e9" +} diff --git a/crates/db/.sqlx/query-c98097bb6edac80896cf320ca9f670f18db291bf4d626923b63dde3445fb4a3d.json b/crates/db/.sqlx/query-971d979ba0156b060d173c37db009407e24b1d507800cec45828c6b9eef75b86.json similarity index 51% rename from crates/db/.sqlx/query-c98097bb6edac80896cf320ca9f670f18db291bf4d626923b63dde3445fb4a3d.json rename to crates/db/.sqlx/query-971d979ba0156b060d173c37db009407e24b1d507800cec45828c6b9eef75b86.json index b7463496..f0d86347 100644 --- a/crates/db/.sqlx/query-c98097bb6edac80896cf320ca9f670f18db291bf4d626923b63dde3445fb4a3d.json +++ b/crates/db/.sqlx/query-971d979ba0156b060d173c37db009407e24b1d507800cec45828c6b9eef75b86.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT \n id as \"id!: Uuid\",\n task_attempt_id as \"task_attempt_id!: Uuid\",\n prompt as \"prompt!: String\",\n queued as \"queued!: bool\",\n sending as \"sending!: bool\",\n variant,\n image_ids as \"image_ids?: String\",\n created_at as \"created_at!: DateTime\",\n updated_at as \"updated_at!: DateTime\",\n version as \"version!: i64\"\n FROM follow_up_drafts\n WHERE task_attempt_id = $1", + "query": "SELECT\n id as \"id!: Uuid\",\n task_attempt_id as \"task_attempt_id!: Uuid\",\n draft_type,\n retry_process_id as \"retry_process_id?: Uuid\",\n prompt,\n queued as \"queued!: bool\",\n sending as \"sending!: bool\",\n variant,\n image_ids,\n created_at as \"created_at!: DateTime\",\n updated_at as \"updated_at!: DateTime\",\n version as \"version!: i64\"\n FROM drafts\n WHERE task_attempt_id = $1 AND draft_type = $2", "describe": { "columns": [ { @@ -14,53 +14,65 @@ "type_info": "Text" }, { - "name": "prompt!: String", + "name": "draft_type", "ordinal": 2, "type_info": "Text" }, { - "name": "queued!: bool", + "name": "retry_process_id?: Uuid", "ordinal": 3, + "type_info": "Text" + }, + { + "name": "prompt", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "queued!: bool", + "ordinal": 5, "type_info": "Integer" }, { "name": "sending!: bool", - "ordinal": 4, + "ordinal": 6, "type_info": "Integer" }, { "name": "variant", - "ordinal": 5, + "ordinal": 7, "type_info": "Text" }, { - "name": "image_ids?: String", - "ordinal": 6, + "name": "image_ids", + "ordinal": 8, "type_info": "Text" }, { "name": "created_at!: DateTime", - "ordinal": 7, + "ordinal": 9, "type_info": "Datetime" }, { "name": "updated_at!: DateTime", - "ordinal": 8, + "ordinal": 10, "type_info": "Datetime" }, { "name": "version!: i64", - "ordinal": 9, + "ordinal": 11, "type_info": "Integer" } ], "parameters": { - "Right": 1 + "Right": 2 }, "nullable": [ true, false, false, + true, + false, false, false, true, @@ -70,5 +82,5 @@ false ] }, - "hash": "c98097bb6edac80896cf320ca9f670f18db291bf4d626923b63dde3445fb4a3d" + "hash": "971d979ba0156b060d173c37db009407e24b1d507800cec45828c6b9eef75b86" } diff --git a/crates/db/.sqlx/query-9778726648c310caa65a00d31e7f9ecc38ca88b7536300143a889eda327ed1a4.json b/crates/db/.sqlx/query-9778726648c310caa65a00d31e7f9ecc38ca88b7536300143a889eda327ed1a4.json deleted file mode 100644 index 1c152df1..00000000 --- a/crates/db/.sqlx/query-9778726648c310caa65a00d31e7f9ecc38ca88b7536300143a889eda327ed1a4.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "db_name": "SQLite", - "query": "UPDATE follow_up_drafts\n SET sending = 1, updated_at = CURRENT_TIMESTAMP, version = version + 1\n WHERE task_attempt_id = $1\n AND queued = 1\n AND sending = 0\n AND TRIM(prompt) != ''", - "describe": { - "columns": [], - "parameters": { - "Right": 1 - }, - "nullable": [] - }, - "hash": "9778726648c310caa65a00d31e7f9ecc38ca88b7536300143a889eda327ed1a4" -} diff --git a/crates/db/.sqlx/query-d3bdec518c805d8eeb37c2c7d782ce05f7dd1d4df18dab306e91d83f874efe90.json b/crates/db/.sqlx/query-d3bdec518c805d8eeb37c2c7d782ce05f7dd1d4df18dab306e91d83f874efe90.json deleted file mode 100644 index 156b6b0e..00000000 --- a/crates/db/.sqlx/query-d3bdec518c805d8eeb37c2c7d782ce05f7dd1d4df18dab306e91d83f874efe90.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "db_name": "SQLite", - "query": "UPDATE follow_up_drafts \n SET prompt = '', queued = 0, sending = 0, image_ids = NULL, updated_at = CURRENT_TIMESTAMP, version = version + 1\n WHERE task_attempt_id = $1", - "describe": { - "columns": [], - "parameters": { - "Right": 1 - }, - "nullable": [] - }, - "hash": "d3bdec518c805d8eeb37c2c7d782ce05f7dd1d4df18dab306e91d83f874efe90" -} diff --git a/crates/db/.sqlx/query-1d406258fa90610bddb8973e25fd9dc4f59b0769d943d2cc74d9008e68670f3e.json b/crates/db/.sqlx/query-eb8c35173d48f942dd8c93ce0d5b88b05fcffa19785a364727dc54fff8741bf4.json similarity index 52% rename from crates/db/.sqlx/query-1d406258fa90610bddb8973e25fd9dc4f59b0769d943d2cc74d9008e68670f3e.json rename to crates/db/.sqlx/query-eb8c35173d48f942dd8c93ce0d5b88b05fcffa19785a364727dc54fff8741bf4.json index 3e036d59..eeb07f32 100644 --- a/crates/db/.sqlx/query-1d406258fa90610bddb8973e25fd9dc4f59b0769d943d2cc74d9008e68670f3e.json +++ b/crates/db/.sqlx/query-eb8c35173d48f942dd8c93ce0d5b88b05fcffa19785a364727dc54fff8741bf4.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT \n id as \"id!: Uuid\",\n task_attempt_id as \"task_attempt_id!: Uuid\",\n prompt as \"prompt!: String\",\n queued as \"queued!: bool\",\n sending as \"sending!: bool\",\n variant,\n image_ids as \"image_ids?: String\",\n created_at as \"created_at!: DateTime\",\n updated_at as \"updated_at!: DateTime\",\n version as \"version!: i64\"\n FROM follow_up_drafts\n WHERE rowid = $1", + "query": "SELECT\n id as \"id!: Uuid\",\n task_attempt_id as \"task_attempt_id!: Uuid\",\n draft_type,\n retry_process_id as \"retry_process_id?: Uuid\",\n prompt,\n queued as \"queued!: bool\",\n sending as \"sending!: bool\",\n variant,\n image_ids,\n created_at as \"created_at!: DateTime\",\n updated_at as \"updated_at!: DateTime\",\n version as \"version!: i64\"\n FROM drafts\n WHERE rowid = $1", "describe": { "columns": [ { @@ -14,43 +14,53 @@ "type_info": "Text" }, { - "name": "prompt!: String", + "name": "draft_type", "ordinal": 2, "type_info": "Text" }, { - "name": "queued!: bool", + "name": "retry_process_id?: Uuid", "ordinal": 3, + "type_info": "Text" + }, + { + "name": "prompt", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "queued!: bool", + "ordinal": 5, "type_info": "Integer" }, { "name": "sending!: bool", - "ordinal": 4, + "ordinal": 6, "type_info": "Integer" }, { "name": "variant", - "ordinal": 5, + "ordinal": 7, "type_info": "Text" }, { - "name": "image_ids?: String", - "ordinal": 6, + "name": "image_ids", + "ordinal": 8, "type_info": "Text" }, { "name": "created_at!: DateTime", - "ordinal": 7, + "ordinal": 9, "type_info": "Datetime" }, { "name": "updated_at!: DateTime", - "ordinal": 8, + "ordinal": 10, "type_info": "Datetime" }, { "name": "version!: i64", - "ordinal": 9, + "ordinal": 11, "type_info": "Integer" } ], @@ -61,6 +71,8 @@ true, false, false, + true, + false, false, false, true, @@ -70,5 +82,5 @@ false ] }, - "hash": "1d406258fa90610bddb8973e25fd9dc4f59b0769d943d2cc74d9008e68670f3e" + "hash": "eb8c35173d48f942dd8c93ce0d5b88b05fcffa19785a364727dc54fff8741bf4" } diff --git a/crates/db/migrations/20250921222241_unify_drafts_tables.sql b/crates/db/migrations/20250921222241_unify_drafts_tables.sql new file mode 100644 index 00000000..91561736 --- /dev/null +++ b/crates/db/migrations/20250921222241_unify_drafts_tables.sql @@ -0,0 +1,53 @@ +-- Unify follow_up_drafts and retry_drafts into a single drafts table +-- This migration consolidates the duplicate code between the two draft types + +-- Create the unified drafts table +CREATE TABLE IF NOT EXISTS drafts ( + id TEXT PRIMARY KEY, + task_attempt_id TEXT NOT NULL, + draft_type TEXT NOT NULL CHECK(draft_type IN ('follow_up', 'retry')), + retry_process_id TEXT NULL, -- Only used for retry drafts + prompt TEXT NOT NULL DEFAULT '', + queued INTEGER NOT NULL DEFAULT 0, + sending INTEGER NOT NULL DEFAULT 0, + version INTEGER NOT NULL DEFAULT 0, + variant TEXT NULL, + image_ids TEXT NULL, -- JSON array of UUID strings + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY(task_attempt_id) REFERENCES task_attempts(id) ON DELETE CASCADE, + FOREIGN KEY(retry_process_id) REFERENCES execution_processes(id) ON DELETE CASCADE, + -- Unique constraint: only one draft per task_attempt_id and draft_type + UNIQUE(task_attempt_id, draft_type) +); + +-- Create indexes +CREATE INDEX IF NOT EXISTS idx_drafts_task_attempt_id + ON drafts(task_attempt_id); + +CREATE INDEX IF NOT EXISTS idx_drafts_draft_type + ON drafts(draft_type); + +CREATE INDEX IF NOT EXISTS idx_drafts_queued_sending + ON drafts(queued, sending) WHERE queued = 1; + +-- Migrate existing follow_up_drafts +INSERT INTO drafts ( + id, task_attempt_id, draft_type, retry_process_id, prompt, + queued, sending, version, variant, image_ids, created_at, updated_at +) +SELECT + id, task_attempt_id, 'follow_up', NULL, prompt, + queued, sending, version, variant, image_ids, created_at, updated_at +FROM follow_up_drafts; + +-- Drop old tables +DROP TABLE IF EXISTS follow_up_drafts; + +-- Create trigger to keep updated_at current +CREATE TRIGGER IF NOT EXISTS trg_drafts_updated_at +AFTER UPDATE ON drafts +FOR EACH ROW +BEGIN + UPDATE drafts SET updated_at = CURRENT_TIMESTAMP WHERE id = OLD.id; +END; \ No newline at end of file diff --git a/crates/db/src/models/draft.rs b/crates/db/src/models/draft.rs new file mode 100644 index 00000000..9259304f --- /dev/null +++ b/crates/db/src/models/draft.rs @@ -0,0 +1,368 @@ +use std::str::FromStr; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::{FromRow, QueryBuilder, Sqlite, SqlitePool}; +use ts_rs::TS; +use uuid::Uuid; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, TS, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +#[ts(rename_all = "snake_case")] +pub enum DraftType { + FollowUp, + Retry, +} + +impl DraftType { + pub fn as_str(&self) -> &'static str { + match self { + DraftType::FollowUp => "follow_up", + DraftType::Retry => "retry", + } + } +} + +impl FromStr for DraftType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "follow_up" => Ok(DraftType::FollowUp), + "retry" => Ok(DraftType::Retry), + _ => Err(()), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, TS)] +pub struct Draft { + pub id: Uuid, + pub task_attempt_id: Uuid, + pub draft_type: DraftType, + #[serde(skip_serializing_if = "Option::is_none")] + pub retry_process_id: Option, + pub prompt: String, + pub queued: bool, + pub sending: bool, + pub variant: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_ids: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub version: i64, +} + +#[derive(Debug, Clone, FromRow)] +struct DraftRow { + pub id: Uuid, + pub task_attempt_id: Uuid, + pub draft_type: String, + pub retry_process_id: Option, + pub prompt: String, + pub queued: bool, + pub sending: bool, + pub variant: Option, + pub image_ids: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + pub version: i64, +} + +impl From for Draft { + fn from(r: DraftRow) -> Self { + let image_ids = r + .image_ids + .as_deref() + .and_then(|s| serde_json::from_str::>(s).ok()); + Draft { + id: r.id, + task_attempt_id: r.task_attempt_id, + draft_type: DraftType::from_str(&r.draft_type).unwrap_or(DraftType::FollowUp), + retry_process_id: r.retry_process_id, + prompt: r.prompt, + queued: r.queued, + sending: r.sending, + variant: r.variant, + image_ids, + created_at: r.created_at, + updated_at: r.updated_at, + version: r.version, + } + } +} + +#[derive(Debug, Deserialize, TS)] +pub struct UpsertDraft { + pub task_attempt_id: Uuid, + pub draft_type: DraftType, + pub retry_process_id: Option, + pub prompt: String, + pub queued: bool, + pub variant: Option, + pub image_ids: Option>, +} + +impl Draft { + pub async fn find_by_rowid(pool: &SqlitePool, rowid: i64) -> Result, sqlx::Error> { + sqlx::query_as!( + DraftRow, + r#"SELECT + id as "id!: Uuid", + task_attempt_id as "task_attempt_id!: Uuid", + draft_type, + retry_process_id as "retry_process_id?: Uuid", + prompt, + queued as "queued!: bool", + sending as "sending!: bool", + variant, + image_ids, + created_at as "created_at!: DateTime", + updated_at as "updated_at!: DateTime", + version as "version!: i64" + FROM drafts + WHERE rowid = $1"#, + rowid + ) + .fetch_optional(pool) + .await + .map(|opt| opt.map(Draft::from)) + } + + pub async fn find_by_task_attempt_and_type( + pool: &SqlitePool, + task_attempt_id: Uuid, + draft_type: DraftType, + ) -> Result, sqlx::Error> { + let draft_type_str = draft_type.as_str(); + sqlx::query_as!( + DraftRow, + r#"SELECT + id as "id!: Uuid", + task_attempt_id as "task_attempt_id!: Uuid", + draft_type, + retry_process_id as "retry_process_id?: Uuid", + prompt, + queued as "queued!: bool", + sending as "sending!: bool", + variant, + image_ids, + created_at as "created_at!: DateTime", + updated_at as "updated_at!: DateTime", + version as "version!: i64" + FROM drafts + WHERE task_attempt_id = $1 AND draft_type = $2"#, + task_attempt_id, + draft_type_str + ) + .fetch_optional(pool) + .await + .map(|opt| opt.map(Draft::from)) + } + + pub async fn upsert(pool: &SqlitePool, data: &UpsertDraft) -> Result { + // Validate retry_process_id requirement + if data.draft_type == DraftType::Retry && data.retry_process_id.is_none() { + return Err(sqlx::Error::Protocol( + "retry_process_id is required for retry drafts".into(), + )); + } + + let id = Uuid::new_v4(); + let image_ids_json = data + .image_ids + .as_ref() + .map(|ids| serde_json::to_string(ids).unwrap_or_else(|_| "[]".to_string())); + let draft_type_str = data.draft_type.as_str(); + let prompt = data.prompt.clone(); + let variant = data.variant.clone(); + sqlx::query_as!( + DraftRow, + r#"INSERT INTO drafts (id, task_attempt_id, draft_type, retry_process_id, prompt, queued, variant, image_ids) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT(task_attempt_id, draft_type) DO UPDATE SET + retry_process_id = excluded.retry_process_id, + prompt = excluded.prompt, + queued = excluded.queued, + variant = excluded.variant, + image_ids = excluded.image_ids, + version = drafts.version + 1 + RETURNING + id as "id!: Uuid", + task_attempt_id as "task_attempt_id!: Uuid", + draft_type, + retry_process_id as "retry_process_id?: Uuid", + prompt, + queued as "queued!: bool", + sending as "sending!: bool", + variant, + image_ids, + created_at as "created_at!: DateTime", + updated_at as "updated_at!: DateTime", + version as "version!: i64""#, + id, + data.task_attempt_id, + draft_type_str, + data.retry_process_id, + prompt, + data.queued, + variant, + image_ids_json + ) + .fetch_one(pool) + .await + .map(Draft::from) + } + + pub async fn clear_after_send( + pool: &SqlitePool, + task_attempt_id: Uuid, + draft_type: DraftType, + ) -> Result<(), sqlx::Error> { + let draft_type_str = draft_type.as_str(); + + match draft_type { + DraftType::FollowUp => { + // Follow-up drafts: update to empty + sqlx::query( + r#"UPDATE drafts + SET prompt = '', queued = 0, sending = 0, image_ids = NULL, updated_at = CURRENT_TIMESTAMP, version = version + 1 + WHERE task_attempt_id = ? AND draft_type = ?"#, + ) + .bind(task_attempt_id) + .bind(draft_type_str) + .execute(pool) + .await?; + } + DraftType::Retry => { + // Retry drafts: delete the record + Self::delete_by_task_attempt_and_type(pool, task_attempt_id, draft_type).await?; + } + } + Ok(()) + } + + pub async fn delete_by_task_attempt_and_type( + pool: &SqlitePool, + task_attempt_id: Uuid, + draft_type: DraftType, + ) -> Result<(), sqlx::Error> { + sqlx::query(r#"DELETE FROM drafts WHERE task_attempt_id = ? AND draft_type = ?"#) + .bind(task_attempt_id) + .bind(draft_type.as_str()) + .execute(pool) + .await?; + + Ok(()) + } + + /// Attempt to atomically mark this draft as "sending" if it's currently queued and non-empty. + /// Returns true if the row was updated (we acquired the send lock), false otherwise. + pub async fn try_mark_sending( + pool: &SqlitePool, + task_attempt_id: Uuid, + draft_type: DraftType, + ) -> Result { + let draft_type_str = draft_type.as_str(); + let result = sqlx::query( + r#"UPDATE drafts + SET sending = 1, updated_at = CURRENT_TIMESTAMP, version = version + 1 + WHERE task_attempt_id = ? + AND draft_type = ? + AND queued = 1 + AND sending = 0 + AND TRIM(prompt) != ''"#, + ) + .bind(task_attempt_id) + .bind(draft_type_str) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) + } + + /// Partial update on a draft by attempt and type. Updates only provided fields + /// and bumps `updated_at` and `version` when any change occurs. + pub async fn update_partial( + pool: &SqlitePool, + task_attempt_id: Uuid, + draft_type: DraftType, + prompt: Option, + variant: Option>, + image_ids: Option>, + retry_process_id: Option, + ) -> Result<(), sqlx::Error> { + if retry_process_id.is_none() + && prompt.is_none() + && variant.is_none() + && image_ids.is_none() + { + return Ok(()); + } + let mut query = QueryBuilder::::new("UPDATE drafts SET "); + + let mut separated = query.separated(", "); + if let Some(rpid) = retry_process_id { + separated.push("retry_process_id = "); + separated.push_bind_unseparated(rpid); + } + if let Some(p) = prompt { + separated.push("prompt = "); + separated.push_bind_unseparated(p); + } + if let Some(v_opt) = variant { + separated.push("variant = "); + match v_opt { + Some(v) => separated.push_bind_unseparated(v), + None => separated.push_bind_unseparated(Option::::None), + }; + } + if let Some(ids) = image_ids { + let image_ids_json = serde_json::to_string(&ids).unwrap_or_else(|_| "[]".to_string()); + separated.push("image_ids = "); + separated.push_bind_unseparated(image_ids_json); + } + separated.push("updated_at = CURRENT_TIMESTAMP"); + separated.push("version = version + 1"); + + query.push(" WHERE task_attempt_id = "); + query.push_bind(task_attempt_id); + query.push(" AND draft_type = "); + query.push_bind(draft_type.as_str()); + query.build().execute(pool).await?; + Ok(()) + } + + /// Set queued flag (and bump metadata) for a draft by attempt and type. + pub async fn set_queued( + pool: &SqlitePool, + task_attempt_id: Uuid, + draft_type: DraftType, + queued: bool, + expected_queued: Option, + expected_version: Option, + ) -> Result { + let result = sqlx::query( + r#"UPDATE drafts + SET queued = CASE + WHEN ?1 THEN (TRIM(prompt) <> '') + ELSE 0 + END, + updated_at = CURRENT_TIMESTAMP, + version = version + 1 + WHERE task_attempt_id = ?2 + AND draft_type = ?3 + AND (?4 IS NULL OR queued = ?4) + AND (?5 IS NULL OR version = ?5)"#, + ) + .bind(queued as i64) + .bind(task_attempt_id) + .bind(draft_type.as_str()) + .bind(expected_queued.map(|value| value as i64)) + .bind(expected_version) + .execute(pool) + .await?; + + Ok(result.rows_affected()) + } +} diff --git a/crates/db/src/models/execution_process.rs b/crates/db/src/models/execution_process.rs index 02fdcb90..5edb5f0e 100644 --- a/crates/db/src/models/execution_process.rs +++ b/crates/db/src/models/execution_process.rs @@ -1,5 +1,8 @@ use chrono::{DateTime, Utc}; -use executors::actions::ExecutorAction; +use executors::{ + actions::{ExecutorAction, ExecutorActionType}, + profile::ExecutorProfileId, +}; use serde::{Deserialize, Serialize}; use serde_json::Value; use sqlx::{FromRow, SqlitePool, Type}; @@ -21,6 +24,8 @@ pub enum ExecutionProcessError { UpdateFailed(String), #[error("Invalid executor action format")] InvalidExecutorAction, + #[error("Validation error: {0}")] + ValidationError(String), } #[derive(Debug, Clone, Type, Serialize, Deserialize, PartialEq, TS)] @@ -532,4 +537,53 @@ impl ExecutionProcess { task, }) } + + /// Require latest session_id for a task attempt; error if none exists + pub async fn require_latest_session_id( + pool: &SqlitePool, + attempt_id: Uuid, + ) -> Result { + Self::find_latest_session_id_by_task_attempt(pool, attempt_id) + .await? + .ok_or_else(|| { + ExecutionProcessError::ValidationError( + "Couldn't find a prior session_id, please create a new task attempt" + .to_string(), + ) + }) + } + + /// Fetch the latest CodingAgent executor profile for a task attempt + pub async fn latest_executor_profile_for_attempt( + pool: &SqlitePool, + attempt_id: Uuid, + ) -> Result { + let latest_execution_process = Self::find_latest_by_task_attempt_and_run_reason( + pool, + attempt_id, + &ExecutionProcessRunReason::CodingAgent, + ) + .await? + .ok_or_else(|| { + ExecutionProcessError::ValidationError( + "Couldn't find initial coding agent process, has it run yet?".to_string(), + ) + })?; + + let action = latest_execution_process + .executor_action() + .map_err(|e| ExecutionProcessError::ValidationError(e.to_string()))?; + + match &action.typ { + ExecutorActionType::CodingAgentInitialRequest(request) => { + Ok(request.executor_profile_id.clone()) + } + ExecutorActionType::CodingAgentFollowUpRequest(request) => { + Ok(request.executor_profile_id.clone()) + } + _ => Err(ExecutionProcessError::ValidationError( + "Couldn't find profile from initial request".to_string(), + )), + } + } } diff --git a/crates/db/src/models/follow_up_draft.rs b/crates/db/src/models/follow_up_draft.rs deleted file mode 100644 index e87f30af..00000000 --- a/crates/db/src/models/follow_up_draft.rs +++ /dev/null @@ -1,195 +0,0 @@ -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use sqlx::{FromRow, SqlitePool}; -use ts_rs::TS; -use uuid::Uuid; - -#[derive(Debug, Clone, Serialize, Deserialize, TS)] -pub struct FollowUpDraft { - pub id: Uuid, - pub task_attempt_id: Uuid, - pub prompt: String, - pub queued: bool, - pub sending: bool, - pub variant: Option, - // Stored as JSON in the DB; serde handles Uuid <-> string in JSON - #[serde(skip_serializing_if = "Option::is_none")] - pub image_ids: Option>, - pub created_at: DateTime, - pub updated_at: DateTime, - pub version: i64, -} - -#[derive(Debug, Clone, FromRow)] -struct FollowUpDraftRow { - pub id: Uuid, - pub task_attempt_id: Uuid, - pub prompt: String, - pub queued: bool, - pub sending: bool, - pub variant: Option, - pub image_ids: Option, - pub created_at: DateTime, - pub updated_at: DateTime, - pub version: i64, -} - -impl From for FollowUpDraft { - fn from(r: FollowUpDraftRow) -> Self { - let image_ids = r - .image_ids - .as_deref() - .and_then(|s| serde_json::from_str::>(s).ok()); - FollowUpDraft { - id: r.id, - task_attempt_id: r.task_attempt_id, - prompt: r.prompt, - queued: r.queued, - sending: r.sending, - variant: r.variant, - image_ids, - created_at: r.created_at, - updated_at: r.updated_at, - version: r.version, - } - } -} - -#[derive(Debug, Deserialize, TS)] -pub struct UpsertFollowUpDraft { - pub task_attempt_id: Uuid, - pub prompt: String, - pub queued: bool, - pub variant: Option, - pub image_ids: Option>, -} - -impl FollowUpDraft { - pub async fn find_by_rowid(pool: &SqlitePool, rowid: i64) -> Result, sqlx::Error> { - sqlx::query_as!( - FollowUpDraftRow, - r#"SELECT - id as "id!: Uuid", - task_attempt_id as "task_attempt_id!: Uuid", - prompt as "prompt!: String", - queued as "queued!: bool", - sending as "sending!: bool", - variant, - image_ids as "image_ids?: String", - created_at as "created_at!: DateTime", - updated_at as "updated_at!: DateTime", - version as "version!: i64" - FROM follow_up_drafts - WHERE rowid = $1"#, - rowid - ) - .fetch_optional(pool) - .await - .map(|opt| opt.map(FollowUpDraft::from)) - } - pub async fn find_by_task_attempt_id( - pool: &SqlitePool, - task_attempt_id: Uuid, - ) -> Result, sqlx::Error> { - sqlx::query_as!( - FollowUpDraftRow, - r#"SELECT - id as "id!: Uuid", - task_attempt_id as "task_attempt_id!: Uuid", - prompt as "prompt!: String", - queued as "queued!: bool", - sending as "sending!: bool", - variant, - image_ids as "image_ids?: String", - created_at as "created_at!: DateTime", - updated_at as "updated_at!: DateTime", - version as "version!: i64" - FROM follow_up_drafts - WHERE task_attempt_id = $1"#, - task_attempt_id - ) - .fetch_optional(pool) - .await - .map(|opt| opt.map(FollowUpDraft::from)) - } - - pub async fn upsert( - pool: &SqlitePool, - data: &UpsertFollowUpDraft, - ) -> Result { - let id = Uuid::new_v4(); - { - let image_ids_json = data - .image_ids - .as_ref() - .map(|ids| serde_json::to_string(ids).unwrap_or_else(|_| "[]".to_string())); - - sqlx::query_as!( - FollowUpDraftRow, - r#"INSERT INTO follow_up_drafts (id, task_attempt_id, prompt, queued, variant, image_ids) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT(task_attempt_id) DO UPDATE SET - prompt = excluded.prompt, - queued = excluded.queued, - variant = excluded.variant, - image_ids = excluded.image_ids - RETURNING - id as "id!: Uuid", - task_attempt_id as "task_attempt_id!: Uuid", - prompt as "prompt!: String", - queued as "queued!: bool", - sending as "sending!: bool", - variant, - image_ids as "image_ids?: String", - created_at as "created_at!: DateTime", - updated_at as "updated_at!: DateTime", - version as "version!: i64""#, - id, - data.task_attempt_id, - data.prompt, - data.queued, - data.variant, - image_ids_json - ) - .fetch_one(pool) - .await - .map(FollowUpDraft::from) - } - } - - pub async fn clear_after_send( - pool: &SqlitePool, - task_attempt_id: Uuid, - ) -> Result<(), sqlx::Error> { - sqlx::query!( - r#"UPDATE follow_up_drafts - SET prompt = '', queued = 0, sending = 0, image_ids = NULL, updated_at = CURRENT_TIMESTAMP, version = version + 1 - WHERE task_attempt_id = $1"#, - task_attempt_id - ) - .execute(pool) - .await?; - Ok(()) - } - - /// Attempt to atomically mark this draft as "sending" if it's currently queued and non-empty. - /// Returns true if the row was updated (we acquired the send lock), false otherwise. - pub async fn try_mark_sending( - pool: &SqlitePool, - task_attempt_id: Uuid, - ) -> Result { - let result = sqlx::query!( - r#"UPDATE follow_up_drafts - SET sending = 1, updated_at = CURRENT_TIMESTAMP, version = version + 1 - WHERE task_attempt_id = $1 - AND queued = 1 - AND sending = 0 - AND TRIM(prompt) != ''"#, - task_attempt_id - ) - .execute(pool) - .await?; - - Ok(result.rows_affected() > 0) - } -} diff --git a/crates/db/src/models/mod.rs b/crates/db/src/models/mod.rs index 7f5605aa..4861d8f3 100644 --- a/crates/db/src/models/mod.rs +++ b/crates/db/src/models/mod.rs @@ -1,7 +1,7 @@ +pub mod draft; pub mod execution_process; pub mod execution_process_logs; pub mod executor_session; -pub mod follow_up_draft; pub mod image; pub mod merge; pub mod project; diff --git a/crates/deployment/src/lib.rs b/crates/deployment/src/lib.rs index e0621a48..42b834bd 100644 --- a/crates/deployment/src/lib.rs +++ b/crates/deployment/src/lib.rs @@ -22,6 +22,7 @@ use services::services::{ auth::{AuthError, AuthService}, config::{Config, ConfigError}, container::{ContainerError, ContainerService}, + drafts::DraftsService, events::{EventError, EventService}, file_search_cache::FileSearchCache, filesystem::{FilesystemError, FilesystemService}, @@ -105,6 +106,8 @@ pub trait Deployment: Clone + Send + Sync + 'static { fn approvals(&self) -> &Approvals; + fn drafts(&self) -> &DraftsService; + async fn update_sentry_scope(&self) -> Result<(), DeploymentError> { let user_id = self.user_id(); let config = self.config().read().await; diff --git a/crates/local-deployment/src/container.rs b/crates/local-deployment/src/container.rs index cb0a217d..924b432d 100644 --- a/crates/local-deployment/src/container.rs +++ b/crates/local-deployment/src/container.rs @@ -16,11 +16,11 @@ use command_group::AsyncGroupChild; use db::{ DBService, models::{ + draft::{Draft, DraftType}, execution_process::{ ExecutionContext, ExecutionProcess, ExecutionProcessRunReason, ExecutionProcessStatus, }, executor_session::ExecutorSession, - follow_up_draft::FollowUpDraft, image::TaskImage, merge::Merge, project::Project, @@ -1323,8 +1323,12 @@ impl LocalContainerService { } // Load draft and ensure it's eligible - let Some(draft) = - FollowUpDraft::find_by_task_attempt_id(&self.db.pool, ctx.task_attempt.id).await? + let Some(draft) = Draft::find_by_task_attempt_and_type( + &self.db.pool, + ctx.task_attempt.id, + DraftType::FollowUp, + ) + .await? else { return Ok(()); }; @@ -1334,7 +1338,7 @@ impl LocalContainerService { } // Atomically acquire sending lock; if not acquired, someone else is sending. - if !FollowUpDraft::try_mark_sending(&self.db.pool, ctx.task_attempt.id) + if !Draft::try_mark_sending(&self.db.pool, ctx.task_attempt.id, DraftType::FollowUp) .await .unwrap_or(false) { @@ -1396,19 +1400,7 @@ impl LocalContainerService { .task .parent_project(&self.db.pool) .await? - .and_then(|p| p.cleanup_script) - .map(|script| { - Box::new(executors::actions::ExecutorAction::new( - executors::actions::ExecutorActionType::ScriptRequest( - executors::actions::script::ScriptRequest { - script, - language: executors::actions::script::ScriptRequestLanguage::Bash, - context: executors::actions::script::ScriptContext::CleanupScript, - }, - ), - None, - )) - }); + .and_then(|project| self.cleanup_action(project.cleanup_script)); // Handle images: associate, copy to worktree, canonicalize prompt let mut prompt = draft.prompt.clone(); @@ -1451,7 +1443,8 @@ impl LocalContainerService { .await?; // Clear the draft to reflect that it has been consumed - let _ = FollowUpDraft::clear_after_send(&self.db.pool, ctx.task_attempt.id).await; + let _ = + Draft::clear_after_send(&self.db.pool, ctx.task_attempt.id, DraftType::FollowUp).await; Ok(()) } diff --git a/crates/local-deployment/src/lib.rs b/crates/local-deployment/src/lib.rs index 88ff41d2..d135bb0c 100644 --- a/crates/local-deployment/src/lib.rs +++ b/crates/local-deployment/src/lib.rs @@ -10,6 +10,7 @@ use services::services::{ auth::AuthService, config::{Config, load_config_from_file, save_config_to_file}, container::ContainerService, + drafts::DraftsService, events::EventService, file_search_cache::FileSearchCache, filesystem::FilesystemService, @@ -42,6 +43,7 @@ pub struct LocalDeployment { events: EventService, file_search_cache: Arc, approvals: Approvals, + drafts: DraftsService, } #[async_trait] @@ -124,6 +126,7 @@ impl Deployment for LocalDeployment { container.spawn_worktree_cleanup().await; let events = EventService::new(db.clone(), events_msg_store, events_entry_count); + let drafts = DraftsService::new(db.clone(), image.clone()); let file_search_cache = Arc::new(FileSearchCache::new()); Ok(Self { @@ -141,6 +144,7 @@ impl Deployment for LocalDeployment { events, file_search_cache, approvals, + drafts, }) } @@ -202,4 +206,8 @@ impl Deployment for LocalDeployment { fn approvals(&self) -> &Approvals { &self.approvals } + + fn drafts(&self) -> &DraftsService { + &self.drafts + } } diff --git a/crates/server/src/bin/generate_types.rs b/crates/server/src/bin/generate_types.rs index b7a607e9..92aed015 100644 --- a/crates/server/src/bin/generate_types.rs +++ b/crates/server/src/bin/generate_types.rs @@ -43,8 +43,9 @@ fn generate_types_content() -> String { server::routes::config::UpdateMcpServersBody::decl(), server::routes::config::GetMcpServerResponse::decl(), server::routes::task_attempts::CreateFollowUpAttempt::decl(), - server::routes::task_attempts::FollowUpDraftResponse::decl(), - server::routes::task_attempts::UpdateFollowUpDraftRequest::decl(), + services::services::drafts::DraftResponse::decl(), + services::services::drafts::UpdateFollowUpDraftRequest::decl(), + services::services::drafts::UpdateRetryFollowUpDraftRequest::decl(), server::routes::task_attempts::ChangeTargetBranchRequest::decl(), server::routes::task_attempts::ChangeTargetBranchResponse::decl(), server::routes::tasks::CreateAndStartTaskRequest::decl(), @@ -101,7 +102,8 @@ fn generate_types_content() -> String { db::models::merge::PrMerge::decl(), db::models::merge::MergeStatus::decl(), db::models::merge::PullRequestInfo::decl(), - db::models::follow_up_draft::FollowUpDraft::decl(), + db::models::draft::Draft::decl(), + db::models::draft::DraftType::decl(), executors::logs::CommandExitStatus::decl(), executors::logs::CommandRunResult::decl(), executors::logs::NormalizedEntry::decl(), diff --git a/crates/server/src/error.rs b/crates/server/src/error.rs index 53dc9e6d..f1e0eb31 100644 --- a/crates/server/src/error.rs +++ b/crates/server/src/error.rs @@ -11,8 +11,9 @@ use deployment::DeploymentError; use executors::executors::ExecutorError; use git2::Error as Git2Error; use services::services::{ - auth::AuthError, config::ConfigError, container::ContainerError, git::GitServiceError, - github_service::GitHubServiceError, image::ImageError, worktree_manager::WorktreeError, + auth::AuthError, config::ConfigError, container::ContainerError, drafts::DraftsServiceError, + git::GitServiceError, github_service::GitHubServiceError, image::ImageError, + worktree_manager::WorktreeError, }; use thiserror::Error; use utils::response::ApiResponse; @@ -46,6 +47,8 @@ pub enum ApiError { Config(#[from] ConfigError), #[error(transparent)] Image(#[from] ImageError), + #[error(transparent)] + Drafts(#[from] DraftsServiceError), #[error("Multipart error: {0}")] Multipart(#[from] MultipartError), #[error("IO error: {0}")] @@ -95,6 +98,19 @@ impl IntoResponse for ApiError { ImageError::NotFound => (StatusCode::NOT_FOUND, "ImageNotFound"), _ => (StatusCode::INTERNAL_SERVER_ERROR, "ImageError"), }, + ApiError::Drafts(drafts_err) => match drafts_err { + DraftsServiceError::Conflict(_) => (StatusCode::CONFLICT, "ConflictError"), + DraftsServiceError::Database(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, "DatabaseError") + } + DraftsServiceError::Container(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, "ContainerError") + } + DraftsServiceError::Image(_) => (StatusCode::INTERNAL_SERVER_ERROR, "ImageError"), + DraftsServiceError::ExecutionProcess(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, "ExecutionProcessError") + } + }, ApiError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "IoError"), ApiError::Multipart(_) => (StatusCode::BAD_REQUEST, "MultipartError"), ApiError::Conflict(_) => (StatusCode::CONFLICT, "ConflictError"), @@ -122,6 +138,15 @@ impl IntoResponse for ApiError { }, ApiError::Multipart(_) => "Failed to upload file. Please ensure the file is valid and try again.".to_string(), ApiError::Conflict(msg) => msg.clone(), + ApiError::Drafts(drafts_err) => match drafts_err { + DraftsServiceError::Conflict(msg) => msg.clone(), + DraftsServiceError::Database(_) => format!("{}: {}", error_type, drafts_err), + DraftsServiceError::Container(_) => format!("{}: {}", error_type, drafts_err), + DraftsServiceError::Image(_) => format!("{}: {}", error_type, drafts_err), + DraftsServiceError::ExecutionProcess(_) => { + format!("{}: {}", error_type, drafts_err) + } + }, _ => format!("{}: {}", error_type, self), }; let response = ApiResponse::<()>::error(&error_message); diff --git a/crates/server/src/routes/drafts.rs b/crates/server/src/routes/drafts.rs new file mode 100644 index 00000000..959bd9fc --- /dev/null +++ b/crates/server/src/routes/drafts.rs @@ -0,0 +1,67 @@ +use axum::{ + Router, + extract::{ + Query, State, + ws::{WebSocket, WebSocketUpgrade}, + }, + response::IntoResponse, + routing::get, +}; +use deployment::Deployment; +use futures_util::{SinkExt, StreamExt, TryStreamExt}; +use serde::Deserialize; +use uuid::Uuid; + +use crate::DeploymentImpl; + +#[derive(Debug, Deserialize)] +pub struct DraftsQuery { + pub project_id: Uuid, +} + +pub async fn stream_project_drafts_ws( + ws: WebSocketUpgrade, + State(deployment): State, + Query(query): Query, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| async move { + if let Err(e) = handle_project_drafts_ws(socket, deployment, query.project_id).await { + tracing::warn!("drafts WS closed: {}", e); + } + }) +} + +async fn handle_project_drafts_ws( + socket: WebSocket, + deployment: DeploymentImpl, + project_id: Uuid, +) -> anyhow::Result<()> { + let mut stream = deployment + .events() + .stream_drafts_for_project_raw(project_id) + .await? + .map_ok(|msg| msg.to_ws_message_unchecked()); + + let (mut sender, mut receiver) = socket.split(); + tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} }); + + while let Some(item) = stream.next().await { + match item { + Ok(msg) => { + if sender.send(msg).await.is_err() { + break; + } + } + Err(e) => { + tracing::error!("stream error: {}", e); + break; + } + } + } + Ok(()) +} + +pub fn router(_deployment: &DeploymentImpl) -> Router { + let inner = Router::new().route("/stream/ws", get(stream_project_drafts_ws)); + Router::new().nest("/drafts", inner) +} diff --git a/crates/server/src/routes/images.rs b/crates/server/src/routes/images.rs index af40a729..d202af00 100644 --- a/crates/server/src/routes/images.rs +++ b/crates/server/src/routes/images.rs @@ -7,10 +7,14 @@ use axum::{ routing::{delete, get, post}, }; use chrono::{DateTime, Utc}; -use db::models::image::Image; +use db::models::{ + image::{Image, TaskImage}, + task::Task, +}; use deployment::Deployment; use serde::{Deserialize, Serialize}; use services::services::image::ImageError; +use sqlx::Error as SqlxError; use tokio::fs::File; use tokio_util::io::ReaderStream; use ts_rs::TS; @@ -50,9 +54,19 @@ impl ImageResponse { pub async fn upload_image( State(deployment): State, - mut multipart: Multipart, + multipart: Multipart, ) -> Result>, ApiError> { + let image_response = process_image_upload(&deployment, multipart, None).await?; + Ok(ResponseJson(ApiResponse::success(image_response))) +} + +pub(crate) async fn process_image_upload( + deployment: &DeploymentImpl, + mut multipart: Multipart, + link_task_id: Option, +) -> Result { let image_service = deployment.image(); + while let Some(field) = multipart.next_field().await? { if field.name() == Some("image") { let filename = field @@ -63,6 +77,15 @@ pub async fn upload_image( let data = field.bytes().await?; let image = image_service.store_image(&data, &filename).await?; + if let Some(task_id) = link_task_id { + TaskImage::associate_many_dedup( + &deployment.db().pool, + task_id, + std::slice::from_ref(&image.id), + ) + .await?; + } + deployment .track_if_analytics_allowed( "image_uploaded", @@ -70,18 +93,31 @@ pub async fn upload_image( "image_id": image.id.to_string(), "size_bytes": image.size_bytes, "mime_type": image.mime_type, + "task_id": link_task_id.map(|id| id.to_string()), }), ) .await; - let image_response = ImageResponse::from_image(image); - return Ok(ResponseJson(ApiResponse::success(image_response))); + return Ok(ImageResponse::from_image(image)); } } Err(ApiError::Image(ImageError::NotFound)) } +pub async fn upload_task_image( + Path(task_id): Path, + State(deployment): State, + multipart: Multipart, +) -> Result>, ApiError> { + Task::find_by_id(&deployment.db().pool, task_id) + .await? + .ok_or(ApiError::Database(SqlxError::RowNotFound))?; + + let image_response = process_image_upload(&deployment, multipart, Some(task_id)).await?; + Ok(ResponseJson(ApiResponse::success(image_response))) +} + /// Serve an image file by ID pub async fn serve_image( Path(image_id): Path, @@ -143,4 +179,8 @@ pub fn routes() -> Router { .route("/{id}/file", get(serve_image)) .route("/{id}", delete(delete_image)) .route("/task/{task_id}", get(get_task_images)) + .route( + "/task/{task_id}/upload", + post(upload_task_image).layer(DefaultBodyLimit::max(20 * 1024 * 1024)), + ) } diff --git a/crates/server/src/routes/mod.rs b/crates/server/src/routes/mod.rs index 8c267042..629b20be 100644 --- a/crates/server/src/routes/mod.rs +++ b/crates/server/src/routes/mod.rs @@ -11,6 +11,7 @@ pub mod config; pub mod containers; pub mod filesystem; // pub mod github; +pub mod drafts; pub mod events; pub mod execution_processes; pub mod frontend; @@ -28,6 +29,7 @@ pub fn router(deployment: DeploymentImpl) -> IntoMakeService { .merge(config::router()) .merge(containers::router(&deployment)) .merge(projects::router(&deployment)) + .merge(drafts::router(&deployment)) .merge(tasks::router(&deployment)) .merge(task_attempts::router(&deployment)) .merge(execution_processes::router(&deployment)) diff --git a/crates/server/src/routes/task_attempts.rs b/crates/server/src/routes/task_attempts.rs index 1ca6b606..9a67b0af 100644 --- a/crates/server/src/routes/task_attempts.rs +++ b/crates/server/src/routes/task_attempts.rs @@ -1,4 +1,5 @@ -use std::path::PathBuf; +pub mod drafts; +pub mod util; use axum::{ Extension, Json, Router, @@ -12,9 +13,8 @@ use axum::{ routing::{get, post}, }; use db::models::{ + draft::{Draft, DraftType}, execution_process::{ExecutionProcess, ExecutionProcessRunReason}, - follow_up_draft::FollowUpDraft, - image::TaskImage, merge::{Merge, MergeStatus, PrMerge, PullRequestInfo}, project::{Project, ProjectError}, task::{Task, TaskRelationships, TaskStatus}, @@ -33,16 +33,20 @@ use git2::BranchType; use serde::{Deserialize, Serialize}; use services::services::{ container::ContainerService, - git::ConflictOp, + git::{ConflictOp, WorktreeResetOptions}, github_service::{CreatePrRequest, GitHubService, GitHubServiceError}, - image::ImageService, }; use sqlx::Error as SqlxError; use ts_rs::TS; use utils::response::ApiResponse; use uuid::Uuid; -use crate::{DeploymentImpl, error::ApiError, middleware::load_task_attempt_middleware}; +use crate::{ + DeploymentImpl, + error::ApiError, + middleware::load_task_attempt_middleware, + routes::task_attempts::util::{ensure_worktree_path, handle_images_for_prompt}, +}; #[derive(Debug, Deserialize, Serialize, TS)] pub struct RebaseTaskAttemptRequest { @@ -191,6 +195,9 @@ pub struct CreateFollowUpAttempt { pub prompt: String, pub variant: Option, pub image_ids: Option>, + pub retry_process_id: Option, + pub force_when_dirty: Option, + pub perform_git_reset: Option, } pub async fn follow_up( @@ -201,46 +208,18 @@ pub async fn follow_up( tracing::info!("{:?}", task_attempt); // Ensure worktree exists (recreate if needed for cold task support) - deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; + let _ = ensure_worktree_path(&deployment, &task_attempt).await?; // Get latest session id (ignoring dropped) - let session_id = ExecutionProcess::find_latest_session_id_by_task_attempt( - &deployment.db().pool, - task_attempt.id, - ) - .await? - .ok_or(ApiError::TaskAttempt(TaskAttemptError::ValidationError( - "Couldn't find a prior session_id, please create a new task attempt".to_string(), - )))?; + let session_id = + ExecutionProcess::require_latest_session_id(&deployment.db().pool, task_attempt.id).await?; - // Get ExecutionProcess for profile data - let latest_execution_process = ExecutionProcess::find_latest_by_task_attempt_and_run_reason( + // Get executor profile data from the latest CodingAgent process + let initial_executor_profile_id = ExecutionProcess::latest_executor_profile_for_attempt( &deployment.db().pool, task_attempt.id, - &ExecutionProcessRunReason::CodingAgent, ) - .await? - .ok_or(ApiError::TaskAttempt(TaskAttemptError::ValidationError( - "Couldn't find initial coding agent process, has it run yet?".to_string(), - )))?; - let initial_executor_profile_id = match &latest_execution_process - .executor_action() - .map_err(|e| ApiError::TaskAttempt(TaskAttemptError::ValidationError(e.to_string())))? - .typ - { - ExecutorActionType::CodingAgentInitialRequest(request) => { - Ok(request.executor_profile_id.clone()) - } - ExecutorActionType::CodingAgentFollowUpRequest(request) => { - Ok(request.executor_profile_id.clone()) - } - _ => Err(ApiError::TaskAttempt(TaskAttemptError::ValidationError( - "Couldn't find profile from initial request".to_string(), - ))), - }?; + .await?; let executor_profile_id = ExecutorProfileId { executor: initial_executor_profile_id.executor, @@ -259,33 +238,74 @@ pub async fn follow_up( .await? .ok_or(SqlxError::RowNotFound)?; - let mut prompt = payload.prompt; - if let Some(image_ids) = &payload.image_ids { - TaskImage::associate_many_dedup(&deployment.db().pool, task.id, image_ids).await?; - - // Copy new images from the image cache to the worktree - if let Some(container_ref) = &task_attempt.container_ref { - let worktree_path = std::path::PathBuf::from(container_ref); - deployment - .image() - .copy_images_by_ids_to_worktree(&worktree_path, image_ids) - .await?; - - // Update image paths in prompt with full worktree path - prompt = ImageService::canonicalise_image_paths(&prompt, &worktree_path); + // If retry settings provided, perform replace-logic before proceeding + if let Some(proc_id) = payload.retry_process_id { + let pool = &deployment.db().pool; + // Validate process belongs to attempt + let process = + ExecutionProcess::find_by_id(pool, proc_id) + .await? + .ok_or(ApiError::TaskAttempt(TaskAttemptError::ValidationError( + "Process not found".to_string(), + )))?; + if process.task_attempt_id != task_attempt.id { + return Err(ApiError::TaskAttempt(TaskAttemptError::ValidationError( + "Process does not belong to this attempt".to_string(), + ))); } + + // Determine target reset OID: before the target process + let mut target_before_oid = process.before_head_commit.clone(); + if target_before_oid.is_none() { + target_before_oid = + ExecutionProcess::find_prev_after_head_commit(pool, task_attempt.id, proc_id) + .await?; + } + + // Decide if Git reset is needed and apply it (best-effort) + let force_when_dirty = payload.force_when_dirty.unwrap_or(false); + let perform_git_reset = payload.perform_git_reset.unwrap_or(true); + if let Some(target_oid) = &target_before_oid { + let wt_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let wt = wt_buf.as_path(); + let is_dirty = deployment + .container() + .is_container_clean(&task_attempt) + .await + .map(|is_clean| !is_clean) + .unwrap_or(false); + + deployment.git().reconcile_worktree_to_commit( + wt, + target_oid, + WorktreeResetOptions::new( + perform_git_reset, + force_when_dirty, + is_dirty, + perform_git_reset, + ), + ); + } + + // Stop any running processes for this attempt + deployment.container().try_stop(&task_attempt).await; + + // Soft-drop the target process and all later processes + let _ = ExecutionProcess::drop_at_and_after(pool, task_attempt.id, proc_id).await?; + + // Best-effort: clear any retry draft for this attempt + let _ = Draft::clear_after_send(pool, task_attempt.id, DraftType::Retry).await; } - let cleanup_action = project.cleanup_script.map(|script| { - Box::new(ExecutorAction::new( - ExecutorActionType::ScriptRequest(ScriptRequest { - script, - language: ScriptRequestLanguage::Bash, - context: ScriptContext::CleanupScript, - }), - None, - )) - }); + let mut prompt = payload.prompt; + if let Some(image_ids) = &payload.image_ids { + prompt = handle_images_for_prompt(&deployment, &task_attempt, task.id, image_ids, &prompt) + .await?; + } + + let cleanup_action = deployment + .container() + .cleanup_action(project.cleanup_script); let follow_up_request = CodingAgentFollowUpRequest { prompt, @@ -307,513 +327,18 @@ pub async fn follow_up( ) .await?; - // Clear any persisted follow-up draft for this attempt to avoid stale UI after manual send - let _ = FollowUpDraft::clear_after_send(&deployment.db().pool, task_attempt.id).await; + // Clear drafts post-send: + // - If this was a retry send, the retry draft has already been cleared above. + // - Otherwise, clear the follow-up draft to avoid. + if payload.retry_process_id.is_none() { + let _ = + Draft::clear_after_send(&deployment.db().pool, task_attempt.id, DraftType::FollowUp) + .await; + } Ok(ResponseJson(ApiResponse::success(execution_process))) } -// Follow-up draft APIs and queueing -#[derive(Debug, Serialize, TS)] -pub struct FollowUpDraftResponse { - pub task_attempt_id: Uuid, - pub prompt: String, - pub queued: bool, - pub variant: Option, - pub image_ids: Option>, // attachments - pub version: i64, -} - -#[derive(Debug, Deserialize, TS)] -pub struct UpdateFollowUpDraftRequest { - pub prompt: Option, - // Present with null explicitly clears variant; absent leaves unchanged - pub variant: Option>, - pub image_ids: Option>, // send empty array to clear; omit to leave unchanged - pub version: Option, // optimistic concurrency -} - -#[derive(Debug, Deserialize, TS)] -pub struct SetQueueRequest { - pub queued: bool, - pub expected_queued: Option, - pub expected_version: Option, -} - -async fn has_running_processes_for_attempt( - pool: &sqlx::SqlitePool, - attempt_id: Uuid, -) -> Result { - let processes = ExecutionProcess::find_by_task_attempt_id(pool, attempt_id, false).await?; - Ok(processes.into_iter().any(|p| { - matches!( - p.status, - db::models::execution_process::ExecutionProcessStatus::Running - ) - })) -} - -#[axum::debug_handler] -pub async fn get_follow_up_draft( - Extension(task_attempt): Extension, - State(deployment): State, -) -> Result>, ApiError> { - let pool = &deployment.db().pool; - let draft = FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id) - .await? - .map(|d| FollowUpDraftResponse { - task_attempt_id: d.task_attempt_id, - prompt: d.prompt, - queued: d.queued, - variant: d.variant, - image_ids: d.image_ids, - version: d.version, - }) - .unwrap_or(FollowUpDraftResponse { - task_attempt_id: task_attempt.id, - prompt: "".to_string(), - queued: false, - variant: None, - image_ids: None, - version: 0, - }); - Ok(ResponseJson(ApiResponse::success(draft))) -} - -#[axum::debug_handler] -pub async fn save_follow_up_draft( - Extension(task_attempt): Extension, - State(deployment): State, - Json(payload): Json, -) -> Result>, ApiError> { - let pool = &deployment.db().pool; - - // Enforce: cannot edit while queued - let d = match FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id).await? { - Some(d) => d, - None => { - // Create empty draft implicitly - let id = uuid::Uuid::new_v4(); - sqlx::query( - r#"INSERT INTO follow_up_drafts (id, task_attempt_id, prompt, queued, sending) - VALUES (?, ?, '', 0, 0)"#, - ) - .bind(id) - .bind(task_attempt.id) - .execute(pool) - .await?; - FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id) - .await? - .ok_or(SqlxError::RowNotFound)? - } - }; - if d.queued { - return Err(ApiError::Conflict( - "Draft is queued; click Edit to unqueue before editing".to_string(), - )); - } - - // Optimistic concurrency check - if let Some(expected_version) = payload.version - && d.version != expected_version - { - return Err(ApiError::Conflict( - "Draft changed, please retry with latest".to_string(), - )); - } - - if payload.prompt.is_none() && payload.variant.is_none() && payload.image_ids.is_none() { - // nothing to change; return current - } else { - // Build a conservative UPDATE using positional binds to avoid SQL builder quirks - let mut set_clauses: Vec<&str> = Vec::new(); - let mut has_variant_null = false; - if payload.prompt.is_some() { - set_clauses.push("prompt = ?"); - } - if let Some(variant_opt) = &payload.variant { - match variant_opt { - Some(_) => set_clauses.push("variant = ?"), - None => { - has_variant_null = true; - set_clauses.push("variant = NULL"); - } - } - } - if payload.image_ids.is_some() { - set_clauses.push("image_ids = ?"); - } - // Always bump metadata when something changes - set_clauses.push("updated_at = CURRENT_TIMESTAMP"); - set_clauses.push("version = version + 1"); - - let mut sql = String::from("UPDATE follow_up_drafts SET "); - sql.push_str(&set_clauses.join(", ")); - sql.push_str(" WHERE task_attempt_id = ?"); - - let mut q = sqlx::query(&sql); - if let Some(prompt) = &payload.prompt { - q = q.bind(prompt); - } - if let Some(variant_opt) = &payload.variant - && let Some(v) = variant_opt - { - q = q.bind(v); - } - if let Some(image_ids) = &payload.image_ids { - let image_ids_json = - serde_json::to_string(image_ids).unwrap_or_else(|_| "[]".to_string()); - q = q.bind(image_ids_json); - } - // WHERE bind - q = q.bind(task_attempt.id); - q.execute(pool).await?; - let _ = has_variant_null; // silence unused (document intent) - } - - // Ensure images are associated with the task for preview/loading - if let Some(image_ids) = &payload.image_ids - && !image_ids.is_empty() - { - // get parent task - let task = task_attempt - .parent_task(&deployment.db().pool) - .await? - .ok_or(SqlxError::RowNotFound)?; - TaskImage::associate_many_dedup(pool, task.id, image_ids).await?; - } - - // If queued and no process running for this attempt, attempt to start immediately. - // Use an atomic sending lock to prevent duplicate starts when concurrent requests occur. - let current = FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id).await?; - let should_consider_start = current.as_ref().map(|c| c.queued).unwrap_or(false) - && !has_running_processes_for_attempt(pool, task_attempt.id).await?; - if should_consider_start { - if FollowUpDraft::try_mark_sending(pool, task_attempt.id) - .await - .unwrap_or(false) - { - // Start follow up with saved draft - let _ = - start_follow_up_from_draft(&deployment, &task_attempt, current.as_ref().unwrap()) - .await; - } else { - tracing::debug!( - "Follow-up draft for attempt {} already being sent or not eligible", - task_attempt.id - ); - } - } - - // Return current draft state (may have been cleared if started immediately) - let current = FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id) - .await? - .map(|d| FollowUpDraftResponse { - task_attempt_id: d.task_attempt_id, - prompt: d.prompt, - queued: d.queued, - variant: d.variant, - image_ids: d.image_ids, - version: d.version, - }) - .unwrap_or(FollowUpDraftResponse { - task_attempt_id: task_attempt.id, - prompt: "".to_string(), - queued: false, - variant: None, - image_ids: None, - version: 0, - }); - - Ok(ResponseJson(ApiResponse::success(current))) -} - -#[axum::debug_handler] -pub async fn stream_follow_up_draft_ws( - ws: WebSocketUpgrade, - Extension(task_attempt): Extension, - State(deployment): State, -) -> impl IntoResponse { - ws.on_upgrade(move |socket| async move { - if let Err(e) = handle_follow_up_draft_ws(socket, deployment, task_attempt.id).await { - tracing::warn!("follow-up draft WS closed: {}", e); - } - }) -} - -async fn handle_follow_up_draft_ws( - socket: WebSocket, - deployment: DeploymentImpl, - task_attempt_id: uuid::Uuid, -) -> anyhow::Result<()> { - use futures_util::{SinkExt, StreamExt, TryStreamExt}; - - let mut stream = deployment - .events() - .stream_follow_up_draft_for_attempt_raw(task_attempt_id) - .await? - .map_ok(|msg| msg.to_ws_message_unchecked()); - - // Split socket into sender and receiver - let (mut sender, mut receiver) = socket.split(); - - // Drain (and ignore) any client->server messages so pings/pongs work - tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} }); - - // Forward server messages - while let Some(item) = stream.next().await { - match item { - Ok(msg) => { - if sender.send(msg).await.is_err() { - break; - } - } - Err(e) => { - tracing::error!("stream error: {}", e); - break; - } - } - } - Ok(()) -} - -#[axum::debug_handler] -pub async fn set_follow_up_queue( - Extension(task_attempt): Extension, - State(deployment): State, - Json(payload): Json, -) -> Result>, ApiError> { - let pool = &deployment.db().pool; - let Some(d) = FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id).await? else { - return Err(ApiError::Conflict("No draft to queue".to_string())); - }; - - // Optimistic concurrency: ensure caller's view matches current state (if provided) - if let Some(expected) = payload.expected_queued - && d.queued != expected - { - return Err(ApiError::Conflict( - "Draft state changed, please refresh and try again".to_string(), - )); - } - if let Some(expected_v) = payload.expected_version - && d.version != expected_v - { - return Err(ApiError::Conflict( - "Draft changed, please refresh and try again".to_string(), - )); - } - - if payload.queued { - let should_queue = !d.prompt.trim().is_empty(); - sqlx::query( - r#"UPDATE follow_up_drafts - SET queued = ?, updated_at = CURRENT_TIMESTAMP, version = version + 1 - WHERE task_attempt_id = ?"#, - ) - .bind(should_queue as i64) - .bind(task_attempt.id) - .execute(pool) - .await?; - } else { - // Unqueue - sqlx::query( - r#"UPDATE follow_up_drafts - SET queued = 0, updated_at = CURRENT_TIMESTAMP, version = version + 1 - WHERE task_attempt_id = ?"#, - ) - .bind(task_attempt.id) - .execute(pool) - .await?; - } - - // If queued and no process running for this attempt, attempt to start immediately. - let current = FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id).await?; - let should_consider_start = current.as_ref().map(|c| c.queued).unwrap_or(false) - && !has_running_processes_for_attempt(pool, task_attempt.id).await?; - if should_consider_start { - if FollowUpDraft::try_mark_sending(pool, task_attempt.id) - .await - .unwrap_or(false) - { - let _ = - start_follow_up_from_draft(&deployment, &task_attempt, current.as_ref().unwrap()) - .await; - } else { - // Schedule a short delayed recheck to handle timing edges - let deployment_clone = deployment.clone(); - let task_attempt_clone = task_attempt.clone(); - tokio::spawn(async move { - use std::time::Duration; - tokio::time::sleep(Duration::from_millis(1200)).await; - let pool = &deployment_clone.db().pool; - // Still no running process? - let running = match ExecutionProcess::find_by_task_attempt_id( - pool, - task_attempt_clone.id, - false, - ) - .await - { - Ok(procs) => procs.into_iter().any(|p| { - matches!( - p.status, - db::models::execution_process::ExecutionProcessStatus::Running - ) - }), - Err(_) => true, // assume running on error to avoid duplicate starts - }; - if running { - return; - } - // Still queued and eligible? - let draft = - match FollowUpDraft::find_by_task_attempt_id(pool, task_attempt_clone.id).await - { - Ok(Some(d)) if d.queued && !d.sending && !d.prompt.trim().is_empty() => d, - _ => return, - }; - if FollowUpDraft::try_mark_sending(pool, task_attempt_clone.id) - .await - .unwrap_or(false) - { - let _ = - start_follow_up_from_draft(&deployment_clone, &task_attempt_clone, &draft) - .await; - } - }); - } - } - - let d = FollowUpDraft::find_by_task_attempt_id(pool, task_attempt.id) - .await? - .ok_or(SqlxError::RowNotFound)?; - let resp = FollowUpDraftResponse { - task_attempt_id: d.task_attempt_id, - prompt: d.prompt, - queued: d.queued, - variant: d.variant, - image_ids: d.image_ids, - version: d.version, - }; - Ok(ResponseJson(ApiResponse::success(resp))) -} - -async fn start_follow_up_from_draft( - deployment: &DeploymentImpl, - task_attempt: &TaskAttempt, - draft: &FollowUpDraft, -) -> Result { - // Ensure worktree exists - deployment - .container() - .ensure_container_exists(task_attempt) - .await?; - - // Get latest session id (ignoring dropped) - let session_id = ExecutionProcess::find_latest_session_id_by_task_attempt( - &deployment.db().pool, - task_attempt.id, - ) - .await? - .ok_or(ApiError::TaskAttempt(TaskAttemptError::ValidationError( - "Couldn't find a prior session_id, please create a new task attempt".to_string(), - )))?; - - // Get latest coding agent process to inherit executor profile - let latest_execution_process = ExecutionProcess::find_latest_by_task_attempt_and_run_reason( - &deployment.db().pool, - task_attempt.id, - &ExecutionProcessRunReason::CodingAgent, - ) - .await? - .ok_or(ApiError::TaskAttempt(TaskAttemptError::ValidationError( - "Couldn't find initial coding agent process, has it run yet?".to_string(), - )))?; - let initial_executor_profile_id = match &latest_execution_process - .executor_action() - .map_err(|e| ApiError::TaskAttempt(TaskAttemptError::ValidationError(e.to_string())))? - .typ - { - ExecutorActionType::CodingAgentInitialRequest(request) => { - Ok(request.executor_profile_id.clone()) - } - ExecutorActionType::CodingAgentFollowUpRequest(request) => { - Ok(request.executor_profile_id.clone()) - } - _ => Err(ApiError::TaskAttempt(TaskAttemptError::ValidationError( - "Couldn't find profile from initial request".to_string(), - ))), - }?; - - // Inherit executor profile; override variant if provided in draft - let executor_profile_id = ExecutorProfileId { - executor: initial_executor_profile_id.executor, - variant: draft.variant.clone(), - }; - - // Get parent task -> project and cleanup action - let task = task_attempt - .parent_task(&deployment.db().pool) - .await? - .ok_or(SqlxError::RowNotFound)?; - let project = task - .parent_project(&deployment.db().pool) - .await? - .ok_or(SqlxError::RowNotFound)?; - - let cleanup_action = project.cleanup_script.map(|script| { - Box::new(ExecutorAction::new( - ExecutorActionType::ScriptRequest(ScriptRequest { - script, - language: ScriptRequestLanguage::Bash, - context: ScriptContext::CleanupScript, - }), - None, - )) - }); - - // Handle images: associate to task, copy to worktree, and canonicalize paths in prompt - let mut prompt = draft.prompt.clone(); - if let Some(image_ids) = &draft.image_ids { - TaskImage::associate_many_dedup(&deployment.db().pool, task_attempt.task_id, image_ids) - .await?; - if let Some(container_ref) = &task_attempt.container_ref { - let worktree_path = std::path::PathBuf::from(container_ref); - deployment - .image() - .copy_images_by_ids_to_worktree(&worktree_path, image_ids) - .await?; - prompt = ImageService::canonicalise_image_paths(&prompt, &worktree_path); - } - } - - let follow_up_request = CodingAgentFollowUpRequest { - prompt, - session_id, - executor_profile_id, - }; - - let follow_up_action = ExecutorAction::new( - ExecutorActionType::CodingAgentFollowUpRequest(follow_up_request), - cleanup_action, - ); - - let execution_process = deployment - .container() - .start_execution( - task_attempt, - &follow_up_action, - &ExecutionProcessRunReason::CodingAgent, - ) - .await?; - - // Best-effort: clear the draft after scheduling the execution - let _ = FollowUpDraft::clear_after_send(&deployment.db().pool, task_attempt.id).await; - - Ok(execution_process) -} - #[axum::debug_handler] pub async fn replace_process( Extension(task_attempt): Extension, @@ -849,55 +374,23 @@ pub async fn replace_process( // Decide if Git reset is needed and apply it let mut git_reset_needed = false; let mut git_reset_applied = false; - if perform_git_reset { - if let Some(target_oid) = &target_before_oid { - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let wt = std::path::Path::new(&container_ref); - let head_oid = deployment.git().get_head_info(wt).ok().map(|h| h.oid); - let is_dirty = deployment - .container() - .is_container_clean(&task_attempt) - .await - .map(|is_clean| !is_clean) - .unwrap_or(false); - if head_oid.as_deref() != Some(target_oid.as_str()) || is_dirty { - git_reset_needed = true; - if is_dirty && !force_when_dirty { - git_reset_applied = false; // cannot reset now - } else if let Err(e) = - deployment - .git() - .reset_worktree_to_commit(wt, target_oid, force_when_dirty) - { - tracing::error!("Failed to reset worktree: {}", e); - git_reset_applied = false; - } else { - git_reset_applied = true; - } - } - } - } else { - // Only compute necessity - if let Some(target_oid) = &target_before_oid { - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let wt = std::path::Path::new(&container_ref); - let head_oid = deployment.git().get_head_info(wt).ok().map(|h| h.oid); - let is_dirty = deployment - .container() - .is_container_clean(&task_attempt) - .await - .map(|is_clean| !is_clean) - .unwrap_or(false); - if head_oid.as_deref() != Some(target_oid.as_str()) || is_dirty { - git_reset_needed = true; - } - } + if let Some(target_oid) = &target_before_oid { + let wt_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let wt = wt_buf.as_path(); + let is_dirty = deployment + .container() + .is_container_clean(&task_attempt) + .await + .map(|is_clean| !is_clean) + .unwrap_or(false); + + let outcome = deployment.git().reconcile_worktree_to_commit( + wt, + target_oid, + WorktreeResetOptions::new(perform_git_reset, force_when_dirty, is_dirty, false), + ); + git_reset_needed = outcome.needed; + git_reset_applied = outcome.applied; } // Stop any running processes for this attempt @@ -1042,11 +535,8 @@ pub async fn get_commit_info( "Missing sha param".to_string(), ))); }; - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let wt = std::path::Path::new(&container_ref); + let wt_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let wt = wt_buf.as_path(); let subject = deployment.git().get_commit_subject(wt, &sha)?; Ok(ResponseJson(ApiResponse::success(CommitInfo { sha, @@ -1073,11 +563,8 @@ pub async fn compare_commit_to_head( "Missing sha param".to_string(), ))); }; - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let wt = std::path::Path::new(&container_ref); + let wt_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let wt = wt_buf.as_path(); let head_info = deployment.git().get_head_info(wt)?; let (ahead_from_head, behind_from_head) = deployment @@ -1106,11 +593,8 @@ pub async fn merge_task_attempt( .ok_or(ApiError::TaskAttempt(TaskAttemptError::TaskNotFound))?; let ctx = TaskAttempt::load_context(pool, task_attempt.id, task.id, task.project_id).await?; - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let worktree_path = std::path::Path::new(&container_ref); + let worktree_path_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let worktree_path = worktree_path_buf.as_path(); let task_uuid_str = task.id.to_string(); let first_uuid_section = task_uuid_str.split('-').next().unwrap_or(&task_uuid_str); @@ -1169,12 +653,7 @@ pub async fn push_task_attempt_branch( let github_service = GitHubService::new(&github_token)?; github_service.check_token().await?; - let ws_path = PathBuf::from( - deployment - .container() - .ensure_container_exists(&task_attempt) - .await?, - ); + let ws_path = ensure_worktree_path(&deployment, &task_attempt).await?; deployment .git() @@ -1218,12 +697,7 @@ pub async fn create_github_pr( .await? .ok_or(ApiError::Project(ProjectError::ProjectNotFound))?; - let workspace_path = PathBuf::from( - deployment - .container() - .ensure_container_exists(&task_attempt) - .await?, - ); + let workspace_path = ensure_worktree_path(&deployment, &task_attempt).await?; // Push the branch to GitHub first if let Err(e) = @@ -1334,22 +808,14 @@ pub async fn open_task_attempt_in_editor( Json(payload): Json>, ) -> Result>, ApiError> { // Get the task attempt to access the worktree path - let attempt = &task_attempt; - let base_path = attempt.container_ref.as_ref().ok_or_else(|| { - tracing::error!( - "No container ref found for task attempt {}", - task_attempt.id - ); - ApiError::TaskAttempt(TaskAttemptError::ValidationError( - "No container ref found".to_string(), - )) - })?; + let base_path_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let base_path = base_path_buf.as_path(); // If a specific file path is provided, use it; otherwise use the base path let path = if let Some(file_path) = payload.as_ref().and_then(|req| req.file_path.as_ref()) { - std::path::Path::new(base_path).join(file_path) + base_path.join(file_path) } else { - std::path::PathBuf::from(base_path) + base_path.to_path_buf() }; let editor_config = { @@ -1418,20 +884,14 @@ pub async fn get_task_attempt_branch_status( .ok() .map(|is_clean| !is_clean); let head_oid = { - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let wt = std::path::Path::new(&container_ref); + let wt_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let wt = wt_buf.as_path(); deployment.git().get_head_info(wt).ok().map(|h| h.oid) }; // Detect conflicts and operation in progress (best-effort) let (is_rebase_in_progress, conflicted_files, conflict_op) = { - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let wt = std::path::Path::new(&container_ref); + let wt_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let wt = wt_buf.as_path(); let in_rebase = deployment.git().is_rebase_in_progress(wt).unwrap_or(false); let conflicts = deployment .git() @@ -1445,11 +905,8 @@ pub async fn get_task_attempt_branch_status( (in_rebase, conflicts, op) }; let (uncommitted_count, untracked_count) = { - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let wt = std::path::Path::new(&container_ref); + let wt_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let wt = wt_buf.as_path(); match deployment.git().get_worktree_change_counts(wt) { Ok((a, b)) => (Some(a), Some(b)), Err(_) => (None, None), @@ -1633,11 +1090,8 @@ pub async fn rebase_task_attempt( } } - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let worktree_path = std::path::Path::new(&container_ref); + let worktree_path_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let worktree_path = worktree_path_buf.as_path(); let result = deployment.git().rebase_branch( &ctx.project.git_repo_path, @@ -1677,11 +1131,8 @@ pub async fn abort_conflicts_task_attempt( State(deployment): State, ) -> Result>, ApiError> { // Resolve worktree path for this attempt - let container_ref = deployment - .container() - .ensure_container_exists(&task_attempt) - .await?; - let worktree_path = std::path::Path::new(&container_ref); + let worktree_path_buf = ensure_worktree_path(&deployment, &task_attempt).await?; + let worktree_path = worktree_path_buf.as_path(); deployment.git().abort_conflicts(worktree_path)?; @@ -1827,11 +1278,12 @@ pub fn router(deployment: &DeploymentImpl) -> Router { .route("/", get(get_task_attempt)) .route("/follow-up", post(follow_up)) .route( - "/follow-up-draft", - get(get_follow_up_draft).put(save_follow_up_draft), + "/draft", + get(drafts::get_draft) + .put(drafts::save_draft) + .delete(drafts::delete_draft), ) - .route("/follow-up-draft/stream/ws", get(stream_follow_up_draft_ws)) - .route("/follow-up-draft/queue", post(set_follow_up_queue)) + .route("/draft/queue", post(drafts::set_draft_queue)) .route("/replace-process", post(replace_process)) .route("/commit-info", get(get_commit_info)) .route("/commit-compare", get(compare_commit_to_head)) diff --git a/crates/server/src/routes/task_attempts/drafts.rs b/crates/server/src/routes/task_attempts/drafts.rs new file mode 100644 index 00000000..d355d3eb --- /dev/null +++ b/crates/server/src/routes/task_attempts/drafts.rs @@ -0,0 +1,147 @@ +use axum::{Extension, Json, extract::State, response::Json as ResponseJson}; +use db::models::{ + draft::DraftType, + task_attempt::{TaskAttempt, TaskAttemptError}, +}; +use deployment::Deployment; +use serde::Deserialize; +use services::services::drafts::{ + DraftResponse, SetQueueRequest, UpdateFollowUpDraftRequest, UpdateRetryFollowUpDraftRequest, +}; +use utils::response::ApiResponse; + +use crate::{DeploymentImpl, error::ApiError}; + +#[derive(Debug, Deserialize)] +pub struct DraftTypeQuery { + #[serde(rename = "type")] + pub draft_type: DraftType, +} + +#[axum::debug_handler] +pub async fn save_follow_up_draft( + Extension(task_attempt): Extension, + State(deployment): State, + Json(payload): Json, +) -> Result>, ApiError> { + let service = deployment.drafts(); + let resp = service + .save_follow_up_draft(&task_attempt, &payload) + .await?; + Ok(ResponseJson(ApiResponse::success(resp))) +} + +#[axum::debug_handler] +pub async fn save_retry_follow_up_draft( + Extension(task_attempt): Extension, + State(deployment): State, + Json(payload): Json, +) -> Result>, ApiError> { + let service = deployment.drafts(); + let resp = service + .save_retry_follow_up_draft(&task_attempt, &payload) + .await?; + Ok(ResponseJson(ApiResponse::success(resp))) +} + +#[axum::debug_handler] +pub async fn delete_retry_follow_up_draft( + Extension(task_attempt): Extension, + State(deployment): State, +) -> Result>, ApiError> { + let service = deployment.drafts(); + service.delete_retry_follow_up_draft(&task_attempt).await?; + Ok(ResponseJson(ApiResponse::success(()))) +} + +#[axum::debug_handler] +pub async fn set_follow_up_queue( + Extension(task_attempt): Extension, + State(deployment): State, + Json(payload): Json, +) -> Result>, ApiError> { + let service = deployment.drafts(); + let resp = service + .set_follow_up_queue(deployment.container(), &task_attempt, &payload) + .await?; + Ok(ResponseJson(ApiResponse::success(resp))) +} + +#[axum::debug_handler] +pub async fn get_draft( + Extension(task_attempt): Extension, + State(deployment): State, + axum::extract::Query(q): axum::extract::Query, +) -> Result>, ApiError> { + let service = deployment.drafts(); + let resp = service.get_draft(task_attempt.id, q.draft_type).await?; + Ok(ResponseJson(ApiResponse::success(resp))) +} + +#[axum::debug_handler] +pub async fn save_draft( + Extension(task_attempt): Extension, + State(deployment): State, + axum::extract::Query(q): axum::extract::Query, + Json(payload): Json, +) -> Result>, ApiError> { + let service = deployment.drafts(); + match q.draft_type { + DraftType::FollowUp => { + let body: UpdateFollowUpDraftRequest = + serde_json::from_value(payload).map_err(|e| { + ApiError::TaskAttempt(TaskAttemptError::ValidationError(e.to_string())) + })?; + let resp = service.save_follow_up_draft(&task_attempt, &body).await?; + Ok(ResponseJson(ApiResponse::success(resp))) + } + DraftType::Retry => { + let body: UpdateRetryFollowUpDraftRequest = + serde_json::from_value(payload).map_err(|e| { + ApiError::TaskAttempt(TaskAttemptError::ValidationError(e.to_string())) + })?; + let resp = service + .save_retry_follow_up_draft(&task_attempt, &body) + .await?; + Ok(ResponseJson(ApiResponse::success(resp))) + } + } +} + +#[axum::debug_handler] +pub async fn delete_draft( + Extension(task_attempt): Extension, + State(deployment): State, + axum::extract::Query(q): axum::extract::Query, +) -> Result>, ApiError> { + let service = deployment.drafts(); + match q.draft_type { + DraftType::FollowUp => Err(ApiError::TaskAttempt(TaskAttemptError::ValidationError( + "Cannot delete follow-up draft; unqueue or edit instead".to_string(), + ))), + DraftType::Retry => { + service.delete_retry_follow_up_draft(&task_attempt).await?; + Ok(ResponseJson(ApiResponse::success(()))) + } + } +} + +#[axum::debug_handler] +pub async fn set_draft_queue( + Extension(task_attempt): Extension, + State(deployment): State, + axum::extract::Query(q): axum::extract::Query, + Json(payload): Json, +) -> Result>, ApiError> { + if q.draft_type != DraftType::FollowUp { + return Err(ApiError::TaskAttempt(TaskAttemptError::ValidationError( + "Queue is only supported for follow-up drafts".to_string(), + ))); + } + + let service = deployment.drafts(); + let resp = service + .set_follow_up_queue(deployment.container(), &task_attempt, &payload) + .await?; + Ok(ResponseJson(ApiResponse::success(resp))) +} diff --git a/crates/server/src/routes/task_attempts/util.rs b/crates/server/src/routes/task_attempts/util.rs new file mode 100644 index 00000000..4c0d3aa3 --- /dev/null +++ b/crates/server/src/routes/task_attempts/util.rs @@ -0,0 +1,45 @@ +use db::models::image::TaskImage; +use deployment::Deployment; +use services::services::{container::ContainerService, image::ImageService}; +use uuid::Uuid; + +use crate::error::ApiError; + +/// Resolve and ensure the worktree path for a task attempt. +pub async fn ensure_worktree_path( + deployment: &crate::DeploymentImpl, + attempt: &db::models::task_attempt::TaskAttempt, +) -> Result { + let container_ref = deployment + .container() + .ensure_container_exists(attempt) + .await?; + Ok(std::path::PathBuf::from(container_ref)) +} + +/// Associate images to the task, copy into worktree, and canonicalize paths in the prompt. +/// Returns the transformed prompt. +pub async fn handle_images_for_prompt( + deployment: &crate::DeploymentImpl, + attempt: &db::models::task_attempt::TaskAttempt, + task_id: Uuid, + image_ids: &[Uuid], + prompt: &str, +) -> Result { + if image_ids.is_empty() { + return Ok(prompt.to_string()); + } + + TaskImage::associate_many_dedup(&deployment.db().pool, task_id, image_ids).await?; + + // Copy to worktree and canonicalize + let worktree_path = ensure_worktree_path(deployment, attempt).await?; + deployment + .image() + .copy_images_by_ids_to_worktree(&worktree_path, image_ids) + .await?; + Ok(ImageService::canonicalise_image_paths( + prompt, + &worktree_path, + )) +} diff --git a/crates/services/src/services/container.rs b/crates/services/src/services/container.rs index 92588974..16ffc260 100644 --- a/crates/services/src/services/container.rs +++ b/crates/services/src/services/container.rs @@ -148,6 +148,19 @@ pub trait ContainerService { Ok(()) } + fn cleanup_action(&self, cleanup_script: Option) -> Option> { + cleanup_script.map(|script| { + Box::new(ExecutorAction::new( + ExecutorActionType::ScriptRequest(ScriptRequest { + script, + language: ScriptRequestLanguage::Bash, + context: ScriptContext::CleanupScript, + }), + None, + )) + }) + } + async fn try_stop(&self, task_attempt: &TaskAttempt) { // stop all execution processes for this attempt if let Ok(processes) = @@ -499,16 +512,7 @@ pub trait ContainerService { ); let prompt = ImageService::canonicalise_image_paths(&task.to_prompt(), &worktree_path); - let cleanup_action = project.cleanup_script.map(|script| { - Box::new(ExecutorAction::new( - ExecutorActionType::ScriptRequest(ScriptRequest { - script, - language: ScriptRequestLanguage::Bash, - context: ScriptContext::CleanupScript, - }), - None, - )) - }); + let cleanup_action = self.cleanup_action(project.cleanup_script); // Choose whether to execute the setup_script or coding agent first let execution_process = if let Some(setup_script) = project.setup_script { diff --git a/crates/services/src/services/drafts.rs b/crates/services/src/services/drafts.rs new file mode 100644 index 00000000..fae75650 --- /dev/null +++ b/crates/services/src/services/drafts.rs @@ -0,0 +1,474 @@ +use std::path::{Path, PathBuf}; + +use db::{ + DBService, + models::{ + draft::{Draft, DraftType, UpsertDraft}, + execution_process::{ExecutionProcess, ExecutionProcessError, ExecutionProcessRunReason}, + image::TaskImage, + task_attempt::TaskAttempt, + }, +}; +use executors::{ + actions::{ + ExecutorAction, ExecutorActionType, coding_agent_follow_up::CodingAgentFollowUpRequest, + }, + profile::ExecutorProfileId, +}; +use serde::{Deserialize, Serialize}; +use sqlx::Error as SqlxError; +use thiserror::Error; +use ts_rs::TS; +use uuid::Uuid; + +use super::{ + container::{ContainerError, ContainerService}, + image::{ImageError, ImageService}, +}; + +#[derive(Debug, Error)] +pub enum DraftsServiceError { + #[error(transparent)] + Database(#[from] sqlx::Error), + #[error(transparent)] + Container(#[from] ContainerError), + #[error(transparent)] + Image(#[from] ImageError), + #[error(transparent)] + ExecutionProcess(#[from] ExecutionProcessError), + #[error("Conflict: {0}")] + Conflict(String), +} + +#[derive(Debug, Serialize, TS)] +pub struct DraftResponse { + pub task_attempt_id: Uuid, + pub draft_type: DraftType, + pub retry_process_id: Option, + pub prompt: String, + pub queued: bool, + pub variant: Option, + pub image_ids: Option>, + pub version: i64, +} + +#[derive(Debug, Deserialize, TS)] +pub struct UpdateFollowUpDraftRequest { + pub prompt: Option, + pub variant: Option>, + pub image_ids: Option>, + pub version: Option, +} + +#[derive(Debug, Deserialize, TS)] +pub struct UpdateRetryFollowUpDraftRequest { + pub retry_process_id: Uuid, + pub prompt: Option, + pub variant: Option>, + pub image_ids: Option>, + pub version: Option, +} + +#[derive(Debug, Deserialize, TS)] +pub struct SetQueueRequest { + pub queued: bool, + pub expected_queued: Option, + pub expected_version: Option, +} + +#[derive(Clone)] +pub struct DraftsService { + db: DBService, + image: ImageService, +} + +impl DraftsService { + pub fn new(db: DBService, image: ImageService) -> Self { + Self { db, image } + } + + fn pool(&self) -> &sqlx::SqlitePool { + &self.db.pool + } + + fn draft_to_response(d: Draft) -> DraftResponse { + DraftResponse { + task_attempt_id: d.task_attempt_id, + draft_type: d.draft_type, + retry_process_id: d.retry_process_id, + prompt: d.prompt, + queued: d.queued, + variant: d.variant, + image_ids: d.image_ids, + version: d.version, + } + } + + async fn ensure_follow_up_draft_row( + &self, + attempt_id: Uuid, + ) -> Result { + if let Some(d) = + Draft::find_by_task_attempt_and_type(self.pool(), attempt_id, DraftType::FollowUp) + .await? + { + return Ok(d); + } + + let _ = Draft::upsert( + self.pool(), + &UpsertDraft { + task_attempt_id: attempt_id, + draft_type: DraftType::FollowUp, + retry_process_id: None, + prompt: "".to_string(), + queued: false, + variant: None, + image_ids: None, + }, + ) + .await?; + + Draft::find_by_task_attempt_and_type(self.pool(), attempt_id, DraftType::FollowUp) + .await? + .ok_or(SqlxError::RowNotFound) + .map_err(DraftsServiceError::from) + } + + async fn associate_images_for_task_if_any( + &self, + task_id: Uuid, + image_ids: &Option>, + ) -> Result<(), DraftsServiceError> { + if let Some(ids) = image_ids + && !ids.is_empty() + { + TaskImage::associate_many_dedup(self.pool(), task_id, ids).await?; + } + Ok(()) + } + + async fn has_running_processes_for_attempt( + &self, + attempt_id: Uuid, + ) -> Result { + let processes = + ExecutionProcess::find_by_task_attempt_id(self.pool(), attempt_id, false).await?; + Ok(processes.into_iter().any(|p| { + matches!( + p.status, + db::models::execution_process::ExecutionProcessStatus::Running + ) + })) + } + + async fn fetch_draft_response( + &self, + task_attempt_id: Uuid, + draft_type: DraftType, + ) -> Result { + let d = + Draft::find_by_task_attempt_and_type(self.pool(), task_attempt_id, draft_type).await?; + let resp = if let Some(d) = d { + Self::draft_to_response(d) + } else { + DraftResponse { + task_attempt_id, + draft_type, + retry_process_id: None, + prompt: "".to_string(), + queued: false, + variant: None, + image_ids: None, + version: 0, + } + }; + Ok(resp) + } + + async fn handle_images_for_prompt( + &self, + task_id: Uuid, + image_ids: &[Uuid], + prompt: &str, + worktree_path: &Path, + ) -> Result { + if image_ids.is_empty() { + return Ok(prompt.to_string()); + } + + TaskImage::associate_many_dedup(self.pool(), task_id, image_ids).await?; + self.image + .copy_images_by_ids_to_worktree(worktree_path, image_ids) + .await?; + Ok(ImageService::canonicalise_image_paths( + prompt, + worktree_path, + )) + } + + async fn start_follow_up_from_draft( + &self, + container: &(dyn ContainerService + Send + Sync), + task_attempt: &TaskAttempt, + draft: &Draft, + ) -> Result { + let worktree_ref = container.ensure_container_exists(task_attempt).await?; + let worktree_path = PathBuf::from(worktree_ref); + let session_id = + ExecutionProcess::require_latest_session_id(self.pool(), task_attempt.id).await?; + + let base_profile = + ExecutionProcess::latest_executor_profile_for_attempt(self.pool(), task_attempt.id) + .await?; + let executor_profile_id = ExecutorProfileId { + executor: base_profile.executor, + variant: draft.variant.clone(), + }; + + let task = task_attempt + .parent_task(self.pool()) + .await? + .ok_or(SqlxError::RowNotFound) + .map_err(DraftsServiceError::from)?; + let project = task + .parent_project(self.pool()) + .await? + .ok_or(SqlxError::RowNotFound) + .map_err(DraftsServiceError::from)?; + + let cleanup_action = container.cleanup_action(project.cleanup_script); + + let mut prompt = draft.prompt.clone(); + if let Some(image_ids) = &draft.image_ids { + prompt = self + .handle_images_for_prompt(task_attempt.task_id, image_ids, &prompt, &worktree_path) + .await?; + } + + let follow_up_request = CodingAgentFollowUpRequest { + prompt, + session_id, + executor_profile_id, + }; + + let follow_up_action = ExecutorAction::new( + ExecutorActionType::CodingAgentFollowUpRequest(follow_up_request), + cleanup_action, + ); + + let execution_process = container + .start_execution( + task_attempt, + &follow_up_action, + &ExecutionProcessRunReason::CodingAgent, + ) + .await?; + + let _ = Draft::clear_after_send(self.pool(), task_attempt.id, DraftType::FollowUp).await; + + Ok(execution_process) + } + + pub async fn save_follow_up_draft( + &self, + task_attempt: &TaskAttempt, + payload: &UpdateFollowUpDraftRequest, + ) -> Result { + let pool = self.pool(); + let d = self.ensure_follow_up_draft_row(task_attempt.id).await?; + if d.queued { + return Err(DraftsServiceError::Conflict( + "Draft is queued; click Edit to unqueue before editing".to_string(), + )); + } + + if let Some(expected_version) = payload.version + && d.version != expected_version + { + return Err(DraftsServiceError::Conflict( + "Draft changed, please retry with latest".to_string(), + )); + } + + if payload.prompt.is_none() && payload.variant.is_none() && payload.image_ids.is_none() { + } else { + Draft::update_partial( + pool, + task_attempt.id, + DraftType::FollowUp, + payload.prompt.clone(), + payload.variant.clone(), + payload.image_ids.clone(), + None, + ) + .await?; + } + + if let Some(task) = task_attempt.parent_task(pool).await? { + self.associate_images_for_task_if_any(task.id, &payload.image_ids) + .await?; + } + + let current = + Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp) + .await? + .map(Self::draft_to_response) + .unwrap_or(DraftResponse { + task_attempt_id: task_attempt.id, + draft_type: DraftType::FollowUp, + retry_process_id: None, + prompt: "".to_string(), + queued: false, + variant: None, + image_ids: None, + version: 0, + }); + + Ok(current) + } + + pub async fn save_retry_follow_up_draft( + &self, + task_attempt: &TaskAttempt, + payload: &UpdateRetryFollowUpDraftRequest, + ) -> Result { + let pool = self.pool(); + let existing = + Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::Retry).await?; + + if let Some(d) = &existing { + if d.queued { + return Err(DraftsServiceError::Conflict( + "Retry draft is queued; unqueue before editing".to_string(), + )); + } + if let Some(expected_version) = payload.version + && d.version != expected_version + { + return Err(DraftsServiceError::Conflict( + "Retry draft changed, please retry with latest".to_string(), + )); + } + } + + if existing.is_none() { + let draft = Draft::upsert( + pool, + &UpsertDraft { + task_attempt_id: task_attempt.id, + draft_type: DraftType::Retry, + retry_process_id: Some(payload.retry_process_id), + prompt: payload.prompt.clone().unwrap_or_default(), + queued: false, + variant: payload.variant.clone().unwrap_or(None), + image_ids: payload.image_ids.clone(), + }, + ) + .await?; + + return Ok(Self::draft_to_response(draft)); + } + + if payload.prompt.is_none() && payload.variant.is_none() && payload.image_ids.is_none() { + } else { + Draft::update_partial( + pool, + task_attempt.id, + DraftType::Retry, + payload.prompt.clone(), + payload.variant.clone(), + payload.image_ids.clone(), + Some(payload.retry_process_id), + ) + .await?; + } + + if let Some(task) = task_attempt.parent_task(pool).await? { + self.associate_images_for_task_if_any(task.id, &payload.image_ids) + .await?; + } + + let draft = Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::Retry) + .await? + .ok_or(SqlxError::RowNotFound) + .map_err(DraftsServiceError::from)?; + Ok(Self::draft_to_response(draft)) + } + + pub async fn delete_retry_follow_up_draft( + &self, + task_attempt: &TaskAttempt, + ) -> Result<(), DraftsServiceError> { + Draft::delete_by_task_attempt_and_type(self.pool(), task_attempt.id, DraftType::Retry) + .await?; + + Ok(()) + } + + pub async fn set_follow_up_queue( + &self, + container: &(dyn ContainerService + Send + Sync), + task_attempt: &TaskAttempt, + payload: &SetQueueRequest, + ) -> Result { + let pool = self.pool(); + + let rows_updated = Draft::set_queued( + pool, + task_attempt.id, + DraftType::FollowUp, + payload.queued, + payload.expected_queued, + payload.expected_version, + ) + .await?; + + let draft = + Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp) + .await?; + + if rows_updated == 0 { + if draft.is_none() { + return Err(DraftsServiceError::Conflict( + "No draft to queue".to_string(), + )); + }; + + return Err(DraftsServiceError::Conflict( + "Draft changed, please refresh and try again".to_string(), + )); + } + + let should_consider_start = draft.as_ref().map(|c| c.queued).unwrap_or(false) + && !self + .has_running_processes_for_attempt(task_attempt.id) + .await?; + + if should_consider_start + && Draft::try_mark_sending(pool, task_attempt.id, DraftType::FollowUp) + .await + .unwrap_or(false) + { + let _ = self + .start_follow_up_from_draft(container, task_attempt, draft.as_ref().unwrap()) + .await; + } + + let draft = + Draft::find_by_task_attempt_and_type(pool, task_attempt.id, DraftType::FollowUp) + .await? + .ok_or(SqlxError::RowNotFound) + .map_err(DraftsServiceError::from)?; + + Ok(Self::draft_to_response(draft)) + } + + pub async fn get_draft( + &self, + task_attempt_id: Uuid, + draft_type: DraftType, + ) -> Result { + self.fetch_draft_response(task_attempt_id, draft_type).await + } +} diff --git a/crates/services/src/services/events.rs b/crates/services/src/services/events.rs index eab9b197..ea071b58 100644 --- a/crates/services/src/services/events.rs +++ b/crates/services/src/services/events.rs @@ -1,202 +1,29 @@ use std::{str::FromStr, sync::Arc}; -use anyhow::Error as AnyhowError; use db::{ DBService, models::{ + draft::{Draft, DraftType}, execution_process::ExecutionProcess, - task::{Task, TaskWithAttemptStatus}, + task::Task, task_attempt::TaskAttempt, }, }; -use futures::StreamExt; -use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation}; -use serde::{Deserialize, Serialize}; use serde_json::json; -use sqlx::{Error as SqlxError, SqlitePool, ValueRef, sqlite::SqliteOperation}; -use strum_macros::{Display, EnumString}; -use thiserror::Error; +use sqlx::{Error as SqlxError, Sqlite, SqlitePool, decode::Decode, sqlite::SqliteOperation}; use tokio::sync::RwLock; -use tokio_stream::wrappers::BroadcastStream; -use ts_rs::TS; -use utils::{log_msg::LogMsg, msg_store::MsgStore}; +use utils::msg_store::MsgStore; use uuid::Uuid; -#[derive(Debug, Error)] -pub enum EventError { - #[error(transparent)] - Sqlx(#[from] SqlxError), - #[error(transparent)] - Parse(#[from] serde_json::Error), - #[error(transparent)] - Other(#[from] AnyhowError), // Catches any unclassified errors -} +#[path = "events/patches.rs"] +pub mod patches; +#[path = "events/streams.rs"] +mod streams; +#[path = "events/types.rs"] +pub mod types; -/// Trait for types that can be used in JSON patch operations -pub trait Patchable: serde::Serialize { - const PATH_PREFIX: &'static str; - type Id: ToString + Copy; - fn id(&self) -> Self::Id; -} - -/// Implementations of Patchable for all supported types -impl Patchable for TaskWithAttemptStatus { - const PATH_PREFIX: &'static str = "/tasks"; - type Id = Uuid; - fn id(&self) -> Self::Id { - self.id - } -} - -impl Patchable for ExecutionProcess { - const PATH_PREFIX: &'static str = "/execution_processes"; - type Id = Uuid; - fn id(&self) -> Self::Id { - self.id - } -} - -impl Patchable for TaskAttempt { - const PATH_PREFIX: &'static str = "/task_attempts"; - type Id = Uuid; - fn id(&self) -> Self::Id { - self.id - } -} - -impl Patchable for db::models::follow_up_draft::FollowUpDraft { - const PATH_PREFIX: &'static str = "/follow_up_drafts"; - type Id = Uuid; - fn id(&self) -> Self::Id { - self.id - } -} - -/// Generic patch operations that work with any Patchable type -pub mod patch_ops { - use super::*; - - /// Escape JSON Pointer special characters - pub(crate) fn escape_pointer_segment(s: &str) -> String { - s.replace('~', "~0").replace('/', "~1") - } - - /// Create path for operation - fn path_for(id: T::Id) -> String { - format!( - "{}/{}", - T::PATH_PREFIX, - escape_pointer_segment(&id.to_string()) - ) - } - - /// Create patch for adding a new record - pub fn add(value: &T) -> Patch { - Patch(vec![PatchOperation::Add(AddOperation { - path: path_for::(value.id()) - .try_into() - .expect("Path should be valid"), - value: serde_json::to_value(value).expect("Serialization should not fail"), - })]) - } - - /// Create patch for updating an existing record - pub fn replace(value: &T) -> Patch { - Patch(vec![PatchOperation::Replace(ReplaceOperation { - path: path_for::(value.id()) - .try_into() - .expect("Path should be valid"), - value: serde_json::to_value(value).expect("Serialization should not fail"), - })]) - } - - /// Create patch for removing a record - pub fn remove(id: T::Id) -> Patch { - Patch(vec![PatchOperation::Remove(RemoveOperation { - path: path_for::(id).try_into().expect("Path should be valid"), - })]) - } -} - -/// Helper functions for creating task-specific patches -pub mod task_patch { - use super::*; - - /// Create patch for adding a new task - pub fn add(task: &TaskWithAttemptStatus) -> Patch { - patch_ops::add(task) - } - - /// Create patch for updating an existing task - pub fn replace(task: &TaskWithAttemptStatus) -> Patch { - patch_ops::replace(task) - } - - /// Create patch for removing a task - pub fn remove(task_id: Uuid) -> Patch { - patch_ops::remove::(task_id) - } -} - -/// Helper functions for creating execution process-specific patches -pub mod execution_process_patch { - use super::*; - - /// Create patch for adding a new execution process - pub fn add(process: &ExecutionProcess) -> Patch { - patch_ops::add(process) - } - - /// Create patch for updating an existing execution process - pub fn replace(process: &ExecutionProcess) -> Patch { - patch_ops::replace(process) - } - - /// Create patch for removing an execution process - pub fn remove(process_id: Uuid) -> Patch { - patch_ops::remove::(process_id) - } -} - -/// Helper functions for creating task attempt-specific patches -pub mod task_attempt_patch { - use super::*; - - /// Create patch for adding a new task attempt - pub fn add(attempt: &TaskAttempt) -> Patch { - patch_ops::add(attempt) - } - - /// Create patch for updating an existing task attempt - pub fn replace(attempt: &TaskAttempt) -> Patch { - patch_ops::replace(attempt) - } - - /// Create patch for removing a task attempt - pub fn remove(attempt_id: Uuid) -> Patch { - patch_ops::remove::(attempt_id) - } -} - -/// Helper functions for creating follow up draft-specific patches -pub mod follow_up_draft_patch { - use super::*; - - /// Create patch for adding a new follow up draft - pub fn add(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch { - patch_ops::add(draft) - } - - /// Create patch for updating an existing follow up draft - pub fn replace(draft: &db::models::follow_up_draft::FollowUpDraft) -> Patch { - patch_ops::replace(draft) - } - - /// Create patch for removing a follow up draft - pub fn remove(draft_id: Uuid) -> Patch { - patch_ops::remove::(draft_id) - } -} +pub use patches::{draft_patch, execution_process_patch, task_attempt_patch, task_patch}; +pub use types::{EventError, EventPatch, EventPatchInner, HookTables, RecordTypes}; #[derive(Clone)] pub struct EventService { @@ -206,40 +33,6 @@ pub struct EventService { entry_count: Arc>, } -#[derive(EnumString, Display)] -enum HookTables { - #[strum(to_string = "tasks")] - Tasks, - #[strum(to_string = "task_attempts")] - TaskAttempts, - #[strum(to_string = "execution_processes")] - ExecutionProcesses, - #[strum(to_string = "follow_up_drafts")] - FollowUpDrafts, -} - -#[derive(Serialize, Deserialize, TS)] -#[serde(tag = "type", content = "data", rename_all = "SCREAMING_SNAKE_CASE")] -pub enum RecordTypes { - Task(Task), - TaskAttempt(TaskAttempt), - ExecutionProcess(ExecutionProcess), - FollowUpDraft(db::models::follow_up_draft::FollowUpDraft), -} - -#[derive(Serialize, Deserialize, TS)] -pub struct EventPatchInner { - db_op: String, - record: RecordTypes, -} - -#[derive(Serialize, Deserialize, TS)] -pub struct EventPatch { - op: String, - path: String, - value: EventPatchInner, -} - impl EventService { /// Creates a new EventService that will work with a DBService configured with hooks pub fn new(db: DBService, msg_store: Arc, entry_count: Arc>) -> Self { @@ -297,85 +90,67 @@ impl EventService { let msg_store_for_hook = msg_store.clone(); let entry_count_for_hook = entry_count.clone(); let db_for_hook = db_service.clone(); - Box::pin(async move { let mut handle = conn.lock_handle().await?; let runtime_handle = tokio::runtime::Handle::current(); - - // Set up preupdate hook to capture task data before deletion handle.set_preupdate_hook({ let msg_store_for_preupdate = msg_store_for_hook.clone(); move |preupdate: sqlx::sqlite::PreupdateHookResult<'_>| { - if preupdate.operation == sqlx::sqlite::SqliteOperation::Delete { - match preupdate.table { - "tasks" => { - // Extract task ID from old column values before deletion - if let Ok(id_value) = preupdate.get_old_column_value(0) - && !id_value.is_null() - { - // Decode UUID from SQLite value - if let Ok(task_id) = - >::decode( - id_value, - ) - { - let patch = task_patch::remove(task_id); - msg_store_for_preupdate.push_patch(patch); - } - } - } - "execution_processes" => { - // Extract process ID from old column values before deletion - if let Ok(id_value) = preupdate.get_old_column_value(0) - && !id_value.is_null() - { - // Decode UUID from SQLite value - if let Ok(process_id) = - >::decode( - id_value, - ) - { - let patch = execution_process_patch::remove(process_id); - msg_store_for_preupdate.push_patch(patch); - } - } - } - "task_attempts" => { - // Extract attempt ID from old column values before deletion - if let Ok(id_value) = preupdate.get_old_column_value(0) - && !id_value.is_null() - { - // Decode UUID from SQLite value - if let Ok(attempt_id) = - >::decode( - id_value, - ) - { - let patch = task_attempt_patch::remove(attempt_id); - msg_store_for_preupdate.push_patch(patch); - } - } - } - "follow_up_drafts" => { - // Extract draft ID from old column values before deletion - if let Ok(id_value) = preupdate.get_old_column_value(0) - && !id_value.is_null() - { - // Decode UUID from SQLite value - if let Ok(draft_id) = - >::decode( - id_value, - ) - { - let patch = follow_up_draft_patch::remove(draft_id); - msg_store_for_preupdate.push_patch(patch); - } - } - } - _ => { - // Ignore other tables + if preupdate.operation != SqliteOperation::Delete { + return; + } + + match preupdate.table { + "tasks" => { + if let Ok(value) = preupdate.get_old_column_value(0) + && let Ok(task_id) = >::decode(value) + { + let patch = task_patch::remove(task_id); + msg_store_for_preupdate.push_patch(patch); } } + "task_attempts" => { + if let Ok(value) = preupdate.get_old_column_value(0) + && let Ok(attempt_id) = >::decode(value) + { + let patch = task_attempt_patch::remove(attempt_id); + msg_store_for_preupdate.push_patch(patch); + } + } + "execution_processes" => { + if let Ok(value) = preupdate.get_old_column_value(0) + && let Ok(process_id) = >::decode(value) + { + let patch = execution_process_patch::remove(process_id); + msg_store_for_preupdate.push_patch(patch); + } + } + "drafts" => { + let draft_type = preupdate + .get_old_column_value(2) + .ok() + .and_then(|val| >::decode(val).ok()) + .and_then(|s| DraftType::from_str(&s).ok()); + let task_attempt_id = preupdate + .get_old_column_value(1) + .ok() + .and_then(|val| >::decode(val).ok()); + + if let (Some(draft_type), Some(task_attempt_id)) = + (draft_type, task_attempt_id) + { + let patch = match draft_type { + DraftType::FollowUp => { + draft_patch::follow_up_clear(task_attempt_id) + } + DraftType::Retry => { + draft_patch::retry_clear(task_attempt_id) + } + }; + msg_store_for_preupdate.push_patch(patch); + } + } + _ => {} } } }); @@ -390,28 +165,20 @@ impl EventService { let rowid = hook.rowid; runtime_handle.spawn(async move { let record_type: RecordTypes = match (table, hook.operation.clone()) { - (HookTables::Tasks, SqliteOperation::Delete) => { - // Task deletion is now handled by preupdate hook - // Skip post-update processing to avoid duplicate patches - return; - } - (HookTables::ExecutionProcesses, SqliteOperation::Delete) => { - // Execution process deletion is now handled by preupdate hook - // Skip post-update processing to avoid duplicate patches - return; - } - (HookTables::TaskAttempts, SqliteOperation::Delete) => { - // Task attempt deletion is now handled by preupdate hook - // Skip post-update processing to avoid duplicate patches + (HookTables::Tasks, SqliteOperation::Delete) + | (HookTables::TaskAttempts, SqliteOperation::Delete) + | (HookTables::ExecutionProcesses, SqliteOperation::Delete) + | (HookTables::Drafts, SqliteOperation::Delete) => { + // Deletions handled in preupdate hook for reliable data capture return; } (HookTables::Tasks, _) => { match Task::find_by_rowid(&db.pool, rowid).await { Ok(Some(task)) => RecordTypes::Task(task), - Ok(None) => { - // Row not found - likely already deleted, skip processing - tracing::debug!("Task rowid {} not found, skipping", rowid); - return; + Ok(None) => RecordTypes::DeletedTask { + rowid, + project_id: None, + task_id: None, }, Err(e) => { tracing::error!("Failed to fetch task: {:?}", e); @@ -422,10 +189,9 @@ impl EventService { (HookTables::TaskAttempts, _) => { match TaskAttempt::find_by_rowid(&db.pool, rowid).await { Ok(Some(attempt)) => RecordTypes::TaskAttempt(attempt), - Ok(None) => { - // Row not found - likely already deleted, skip processing - tracing::debug!("TaskAttempt rowid {} not found, skipping", rowid); - return; + Ok(None) => RecordTypes::DeletedTaskAttempt { + rowid, + task_id: None, }, Err(e) => { tracing::error!( @@ -439,10 +205,10 @@ impl EventService { (HookTables::ExecutionProcesses, _) => { match ExecutionProcess::find_by_rowid(&db.pool, rowid).await { Ok(Some(process)) => RecordTypes::ExecutionProcess(process), - Ok(None) => { - // Row not found - likely already deleted, skip processing - tracing::debug!("ExecutionProcess rowid {} not found, skipping", rowid); - return; + Ok(None) => RecordTypes::DeletedExecutionProcess { + rowid, + task_attempt_id: None, + process_id: None, }, Err(e) => { tracing::error!( @@ -453,28 +219,19 @@ impl EventService { } } } - (HookTables::FollowUpDrafts, SqliteOperation::Delete) => { - // Follow up draft deletion is now handled by preupdate hook - // Skip post-update processing to avoid duplicate patches - return; - } - (HookTables::FollowUpDrafts, _) => { - match db::models::follow_up_draft::FollowUpDraft::find_by_rowid( - &db.pool, rowid, - ) - .await - { - Ok(Some(draft)) => RecordTypes::FollowUpDraft(draft), - Ok(None) => { - // Row not found - likely already deleted, skip processing - tracing::debug!("FollowUpDraft rowid {} not found, skipping", rowid); - return; + (HookTables::Drafts, _) => { + match Draft::find_by_rowid(&db.pool, rowid).await { + Ok(Some(draft)) => match draft.draft_type { + DraftType::FollowUp => RecordTypes::Draft(draft), + DraftType::Retry => RecordTypes::RetryDraft(draft), + }, + Ok(None) => RecordTypes::DeletedDraft { + rowid, + draft_type: DraftType::Retry, + task_attempt_id: None, }, Err(e) => { - tracing::error!( - "Failed to fetch follow_up_draft: {:?}", - e - ); + tracing::error!("Failed to fetch draft: {:?}", e); return; } } @@ -514,6 +271,33 @@ impl EventService { return; } } + // Draft updates: emit direct patches used by the follow-up draft stream + RecordTypes::Draft(draft) => { + let patch = draft_patch::follow_up_replace(draft); + msg_store_for_hook.push_patch(patch); + return; + } + RecordTypes::RetryDraft(draft) => { + let patch = draft_patch::retry_replace(draft); + msg_store_for_hook.push_patch(patch); + return; + } + RecordTypes::DeletedDraft { draft_type, task_attempt_id: Some(id), .. } => { + let patch = match draft_type { + DraftType::FollowUp => draft_patch::follow_up_clear(*id), + DraftType::Retry => draft_patch::retry_clear(*id), + }; + msg_store_for_hook.push_patch(patch); + return; + } + RecordTypes::DeletedTask { + task_id: Some(task_id), + .. + } => { + let patch = task_patch::remove(*task_id); + msg_store_for_hook.push_patch(patch); + return; + } RecordTypes::TaskAttempt(attempt) => { // Task attempts should update the parent task with fresh data if let Ok(Some(task)) = @@ -532,6 +316,27 @@ impl EventService { return; } } + RecordTypes::DeletedTaskAttempt { + task_id: Some(task_id), + .. + } => { + // Task attempt deletion should update the parent task with fresh data + if let Ok(Some(task)) = + Task::find_by_id(&db.pool, *task_id).await + && let Ok(task_list) = + Task::find_by_project_id_with_attempt_status( + &db.pool, + task.project_id, + ) + .await + && let Some(task_with_status) = + task_list.into_iter().find(|t| t.id == *task_id) + { + let patch = task_patch::replace(&task_with_status); + msg_store_for_hook.push_patch(patch); + return; + } + } RecordTypes::ExecutionProcess(process) => { let patch = match hook.operation { SqliteOperation::Insert => { @@ -559,6 +364,31 @@ impl EventService { return; } + RecordTypes::DeletedExecutionProcess { + process_id: Some(process_id), + task_attempt_id, + .. + } => { + let patch = execution_process_patch::remove(*process_id); + msg_store_for_hook.push_patch(patch); + + if let Some(task_attempt_id) = task_attempt_id + && let Err(err) = + EventService::push_task_update_for_attempt( + &db.pool, + msg_store_for_hook.clone(), + *task_attempt_id, + ) + .await + { + tracing::error!( + "Failed to push task update after execution process removal: {:?}", + err + ); + } + + return; + } _ => {} } @@ -597,293 +427,4 @@ impl EventService { pub fn msg_store(&self) -> &Arc { &self.msg_store } - - /// Stream raw task messages for a specific project with initial snapshot - pub async fn stream_tasks_raw( - &self, - project_id: Uuid, - ) -> Result>, EventError> - { - // Get initial snapshot of tasks - let tasks = Task::find_by_project_id_with_attempt_status(&self.db.pool, project_id).await?; - - // Convert task array to object keyed by task ID - let tasks_map: serde_json::Map = tasks - .into_iter() - .map(|task| (task.id.to_string(), serde_json::to_value(task).unwrap())) - .collect(); - - let initial_patch = json!([{ - "op": "replace", - "path": "/tasks", - "value": tasks_map - }]); - let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); - - // Clone necessary data for the async filter - let db_pool = self.db.pool.clone(); - - // Get filtered event stream - let filtered_stream = - BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| { - let db_pool = db_pool.clone(); - async move { - match msg_result { - Ok(LogMsg::JsonPatch(patch)) => { - // Filter events based on project_id - if let Some(patch_op) = patch.0.first() { - // Check if this is a direct task patch (new format) - if patch_op.path().starts_with("/tasks/") { - match patch_op { - json_patch::PatchOperation::Add(op) => { - // Parse task data directly from value - if let Ok(task) = - serde_json::from_value::( - op.value.clone(), - ) - && task.project_id == project_id - { - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - json_patch::PatchOperation::Replace(op) => { - // Parse task data directly from value - if let Ok(task) = - serde_json::from_value::( - op.value.clone(), - ) - && task.project_id == project_id - { - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - json_patch::PatchOperation::Remove(_) => { - // For remove operations, we need to check project membership differently - // We could cache this information or let it pass through for now - // Since we don't have the task data, we'll allow all removals - // and let the client handle filtering - return Some(Ok(LogMsg::JsonPatch(patch))); - } - _ => {} - } - } else if let Ok(event_patch_value) = serde_json::to_value(patch_op) - && let Ok(event_patch) = - serde_json::from_value::(event_patch_value) - { - // Handle old EventPatch format for non-task records - match &event_patch.value.record { - RecordTypes::Task(task) => { - if task.project_id == project_id { - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - RecordTypes::TaskAttempt(attempt) => { - // Check if this task_attempt belongs to a task in our project - if let Ok(Some(task)) = - Task::find_by_id(&db_pool, attempt.task_id).await - && task.project_id == project_id - { - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - _ => {} - } - } - } - None - } - Ok(other) => Some(Ok(other)), // Pass through non-patch messages - Err(_) => None, // Filter out broadcast errors - } - } - }); - - // Start with initial snapshot, then live updates - let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); - let combined_stream = initial_stream.chain(filtered_stream).boxed(); - - Ok(combined_stream) - } - - /// Stream execution processes for a specific task attempt with initial snapshot (raw LogMsg format for WebSocket) - pub async fn stream_execution_processes_for_attempt_raw( - &self, - task_attempt_id: Uuid, - show_soft_deleted: bool, - ) -> Result>, EventError> - { - // Get initial snapshot of execution processes (filtering at SQL level) - let processes = ExecutionProcess::find_by_task_attempt_id( - &self.db.pool, - task_attempt_id, - show_soft_deleted, - ) - .await?; - - // Convert processes array to object keyed by process ID - let processes_map: serde_json::Map = processes - .into_iter() - .map(|process| { - ( - process.id.to_string(), - serde_json::to_value(process).unwrap(), - ) - }) - .collect(); - - let initial_patch = json!([{ - "op": "replace", - "path": "/execution_processes", - "value": processes_map - }]); - let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); - - // Get filtered event stream - let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map( - move |msg_result| async move { - match msg_result { - Ok(LogMsg::JsonPatch(patch)) => { - // Filter events based on task_attempt_id - if let Some(patch_op) = patch.0.first() { - // Check if this is a modern execution process patch - if patch_op.path().starts_with("/execution_processes/") { - match patch_op { - json_patch::PatchOperation::Add(op) => { - // Parse execution process data directly from value - if let Ok(process) = - serde_json::from_value::( - op.value.clone(), - ) - && process.task_attempt_id == task_attempt_id - { - if !show_soft_deleted && process.dropped { - return None; - } - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - json_patch::PatchOperation::Replace(op) => { - // Parse execution process data directly from value - if let Ok(process) = - serde_json::from_value::( - op.value.clone(), - ) - && process.task_attempt_id == task_attempt_id - { - if !show_soft_deleted && process.dropped { - let remove_patch = - execution_process_patch::remove(process.id); - return Some(Ok(LogMsg::JsonPatch(remove_patch))); - } - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - json_patch::PatchOperation::Remove(_) => { - // For remove operations, we can't verify task_attempt_id - // so we allow all removals and let the client handle filtering - return Some(Ok(LogMsg::JsonPatch(patch))); - } - _ => {} - } - } - // Fallback to legacy EventPatch format for backward compatibility - else if let Ok(event_patch_value) = serde_json::to_value(patch_op) - && let Ok(event_patch) = - serde_json::from_value::(event_patch_value) - && let RecordTypes::ExecutionProcess(process) = - &event_patch.value.record - && process.task_attempt_id == task_attempt_id - { - if !show_soft_deleted && process.dropped { - let remove_patch = execution_process_patch::remove(process.id); - return Some(Ok(LogMsg::JsonPatch(remove_patch))); - } - return Some(Ok(LogMsg::JsonPatch(patch))); - } - } - None - } - Ok(other) => Some(Ok(other)), // Pass through non-patch messages - Err(_) => None, // Filter out broadcast errors - } - }, - ); - - // Start with initial snapshot, then live updates - let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); - let combined_stream = initial_stream.chain(filtered_stream).boxed(); - - Ok(combined_stream) - } - - /// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket) - pub async fn stream_follow_up_draft_for_attempt_raw( - &self, - task_attempt_id: Uuid, - ) -> Result>, EventError> - { - // Get initial snapshot of follow-up draft - let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id( - &self.db.pool, - task_attempt_id, - ) - .await? - .unwrap_or(db::models::follow_up_draft::FollowUpDraft { - id: uuid::Uuid::new_v4(), - task_attempt_id, - prompt: String::new(), - queued: false, - sending: false, - variant: None, - image_ids: None, - created_at: chrono::Utc::now(), - updated_at: chrono::Utc::now(), - version: 0, - }); - - let initial_patch = json!([ - { - "op": "replace", - "path": "/", - "value": { "follow_up_draft": draft } - } - ]); - let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); - - // Filtered live stream, mapped into direct JSON patches that update /follow_up_draft - let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map( - move |msg_result| async move { - match msg_result { - Ok(LogMsg::JsonPatch(patch)) => { - if let Some(event_patch_op) = patch.0.first() - && let Ok(event_patch_value) = serde_json::to_value(event_patch_op) - && let Ok(event_patch) = - serde_json::from_value::(event_patch_value) - && let RecordTypes::FollowUpDraft(draft) = &event_patch.value.record - && draft.task_attempt_id == task_attempt_id - { - // Build a direct patch to replace /follow_up_draft - let direct = json!([ - { - "op": "replace", - "path": "/follow_up_draft", - "value": draft - } - ]); - let direct_patch = serde_json::from_value(direct).unwrap(); - return Some(Ok(LogMsg::JsonPatch(direct_patch))); - } - None - } - Ok(other) => Some(Ok(other)), - Err(_) => None, - } - }, - ); - - let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); - let combined_stream = initial_stream.chain(filtered_stream).boxed(); - - Ok(combined_stream) - } } diff --git a/crates/services/src/services/events/patches.rs b/crates/services/src/services/events/patches.rs new file mode 100644 index 00000000..d8299371 --- /dev/null +++ b/crates/services/src/services/events/patches.rs @@ -0,0 +1,204 @@ +use db::models::{ + draft::{Draft, DraftType}, + execution_process::ExecutionProcess, + task::TaskWithAttemptStatus, + task_attempt::TaskAttempt, +}; +use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation}; +use uuid::Uuid; + +// Shared helper to escape JSON Pointer segments +fn escape_pointer_segment(s: &str) -> String { + s.replace('~', "~0").replace('/', "~1") +} + +/// Helper functions for creating task-specific patches +pub mod task_patch { + use super::*; + + fn task_path(task_id: Uuid) -> String { + format!("/tasks/{}", escape_pointer_segment(&task_id.to_string())) + } + + /// Create patch for adding a new task + pub fn add(task: &TaskWithAttemptStatus) -> Patch { + Patch(vec![PatchOperation::Add(AddOperation { + path: task_path(task.id) + .try_into() + .expect("Task path should be valid"), + value: serde_json::to_value(task).expect("Task serialization should not fail"), + })]) + } + + /// Create patch for updating an existing task + pub fn replace(task: &TaskWithAttemptStatus) -> Patch { + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: task_path(task.id) + .try_into() + .expect("Task path should be valid"), + value: serde_json::to_value(task).expect("Task serialization should not fail"), + })]) + } + + /// Create patch for removing a task + pub fn remove(task_id: Uuid) -> Patch { + Patch(vec![PatchOperation::Remove(RemoveOperation { + path: task_path(task_id) + .try_into() + .expect("Task path should be valid"), + })]) + } +} + +/// Helper functions for creating execution process-specific patches +pub mod execution_process_patch { + use super::*; + + fn execution_process_path(process_id: Uuid) -> String { + format!( + "/execution_processes/{}", + escape_pointer_segment(&process_id.to_string()) + ) + } + + /// Create patch for adding a new execution process + pub fn add(process: &ExecutionProcess) -> Patch { + Patch(vec![PatchOperation::Add(AddOperation { + path: execution_process_path(process.id) + .try_into() + .expect("Execution process path should be valid"), + value: serde_json::to_value(process) + .expect("Execution process serialization should not fail"), + })]) + } + + /// Create patch for updating an existing execution process + pub fn replace(process: &ExecutionProcess) -> Patch { + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: execution_process_path(process.id) + .try_into() + .expect("Execution process path should be valid"), + value: serde_json::to_value(process) + .expect("Execution process serialization should not fail"), + })]) + } + + /// Create patch for removing an execution process + pub fn remove(process_id: Uuid) -> Patch { + Patch(vec![PatchOperation::Remove(RemoveOperation { + path: execution_process_path(process_id) + .try_into() + .expect("Execution process path should be valid"), + })]) + } +} + +/// Helper functions for creating draft-specific patches +pub mod draft_patch { + use super::*; + + fn follow_up_path(attempt_id: Uuid) -> String { + format!("/drafts/{attempt_id}/follow_up") + } + + fn retry_path(attempt_id: Uuid) -> String { + format!("/drafts/{attempt_id}/retry") + } + + /// Replace the follow-up draft for a specific attempt + pub fn follow_up_replace(draft: &Draft) -> Patch { + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: follow_up_path(draft.task_attempt_id) + .try_into() + .expect("Path should be valid"), + value: serde_json::to_value(draft).expect("Draft serialization should not fail"), + })]) + } + + /// Replace the retry draft for a specific attempt + pub fn retry_replace(draft: &Draft) -> Patch { + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: retry_path(draft.task_attempt_id) + .try_into() + .expect("Path should be valid"), + value: serde_json::to_value(draft).expect("Draft serialization should not fail"), + })]) + } + + /// Clear the follow-up draft for an attempt (replace with an empty draft) + pub fn follow_up_clear(attempt_id: Uuid) -> Patch { + let empty = Draft { + id: uuid::Uuid::new_v4(), + task_attempt_id: attempt_id, + draft_type: DraftType::FollowUp, + retry_process_id: None, + prompt: String::new(), + queued: false, + sending: false, + variant: None, + image_ids: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + version: 0, + }; + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: follow_up_path(attempt_id) + .try_into() + .expect("Path should be valid"), + value: serde_json::to_value(empty).expect("Draft serialization should not fail"), + })]) + } + + /// Clear the retry draft for an attempt (set to null) + pub fn retry_clear(attempt_id: Uuid) -> Patch { + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: retry_path(attempt_id) + .try_into() + .expect("Path should be valid"), + value: serde_json::Value::Null, + })]) + } +} + +/// Helper functions for creating task attempt-specific patches +pub mod task_attempt_patch { + use super::*; + + fn attempt_path(attempt_id: Uuid) -> String { + format!( + "/task_attempts/{}", + escape_pointer_segment(&attempt_id.to_string()) + ) + } + + /// Create patch for adding a new task attempt + pub fn add(attempt: &TaskAttempt) -> Patch { + Patch(vec![PatchOperation::Add(AddOperation { + path: attempt_path(attempt.id) + .try_into() + .expect("Task attempt path should be valid"), + value: serde_json::to_value(attempt) + .expect("Task attempt serialization should not fail"), + })]) + } + + /// Create patch for updating an existing task attempt + pub fn replace(attempt: &TaskAttempt) -> Patch { + Patch(vec![PatchOperation::Replace(ReplaceOperation { + path: attempt_path(attempt.id) + .try_into() + .expect("Task attempt path should be valid"), + value: serde_json::to_value(attempt) + .expect("Task attempt serialization should not fail"), + })]) + } + + /// Create patch for removing a task attempt + pub fn remove(attempt_id: Uuid) -> Patch { + Patch(vec![PatchOperation::Remove(RemoveOperation { + path: attempt_path(attempt_id) + .try_into() + .expect("Task attempt path should be valid"), + })]) + } +} diff --git a/crates/services/src/services/events/streams.rs b/crates/services/src/services/events/streams.rs new file mode 100644 index 00000000..8779ea32 --- /dev/null +++ b/crates/services/src/services/events/streams.rs @@ -0,0 +1,374 @@ +use db::models::{ + draft::{Draft, DraftType}, + execution_process::ExecutionProcess, + task::{Task, TaskWithAttemptStatus}, +}; +use futures::StreamExt; +use serde_json::json; +use tokio_stream::wrappers::BroadcastStream; +use utils::log_msg::LogMsg; +use uuid::Uuid; + +use super::{ + EventService, + patches::execution_process_patch, + types::{EventError, EventPatch, RecordTypes}, +}; + +impl EventService { + /// Stream raw task messages for a specific project with initial snapshot + pub async fn stream_tasks_raw( + &self, + project_id: Uuid, + ) -> Result>, EventError> + { + // Get initial snapshot of tasks + let tasks = Task::find_by_project_id_with_attempt_status(&self.db.pool, project_id).await?; + + // Convert task array to object keyed by task ID + let tasks_map: serde_json::Map = tasks + .into_iter() + .map(|task| (task.id.to_string(), serde_json::to_value(task).unwrap())) + .collect(); + + let initial_patch = json!([{ + "op": "replace", + "path": "/tasks", + "value": tasks_map + }]); + let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); + + // Clone necessary data for the async filter + let db_pool = self.db.pool.clone(); + + // Get filtered event stream + let filtered_stream = + BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| { + let db_pool = db_pool.clone(); + async move { + match msg_result { + Ok(LogMsg::JsonPatch(patch)) => { + // Filter events based on project_id + if let Some(patch_op) = patch.0.first() { + // Check if this is a direct task patch (new format) + if patch_op.path().starts_with("/tasks/") { + match patch_op { + json_patch::PatchOperation::Add(op) => { + // Parse task data directly from value + if let Ok(task) = + serde_json::from_value::( + op.value.clone(), + ) + && task.project_id == project_id + { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + json_patch::PatchOperation::Replace(op) => { + // Parse task data directly from value + if let Ok(task) = + serde_json::from_value::( + op.value.clone(), + ) + && task.project_id == project_id + { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + json_patch::PatchOperation::Remove(_) => { + // For remove operations, we need to check project membership differently + // We could cache this information or let it pass through for now + // Since we don't have the task data, we'll allow all removals + // and let the client handle filtering + return Some(Ok(LogMsg::JsonPatch(patch))); + } + _ => {} + } + } else if let Ok(event_patch_value) = serde_json::to_value(patch_op) + && let Ok(event_patch) = + serde_json::from_value::(event_patch_value) + { + // Handle old EventPatch format for non-task records + match &event_patch.value.record { + RecordTypes::Task(task) => { + if task.project_id == project_id { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + RecordTypes::DeletedTask { + project_id: Some(deleted_project_id), + .. + } => { + if *deleted_project_id == project_id { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + RecordTypes::TaskAttempt(attempt) => { + // Check if this task_attempt belongs to a task in our project + if let Ok(Some(task)) = + Task::find_by_id(&db_pool, attempt.task_id).await + && task.project_id == project_id + { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + RecordTypes::DeletedTaskAttempt { + task_id: Some(deleted_task_id), + .. + } => { + // Check if deleted attempt belonged to a task in our project + if let Ok(Some(task)) = + Task::find_by_id(&db_pool, *deleted_task_id).await + && task.project_id == project_id + { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + _ => {} + } + } + } + None + } + Ok(other) => Some(Ok(other)), // Pass through non-patch messages + Err(_) => None, // Filter out broadcast errors + } + } + }); + + // Start with initial snapshot, then live updates + let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); + let combined_stream = initial_stream.chain(filtered_stream).boxed(); + + Ok(combined_stream) + } + + /// Stream execution processes for a specific task attempt with initial snapshot (raw LogMsg format for WebSocket) + pub async fn stream_execution_processes_for_attempt_raw( + &self, + task_attempt_id: Uuid, + show_soft_deleted: bool, + ) -> Result>, EventError> + { + // Get initial snapshot of execution processes (filtering at SQL level) + let processes = ExecutionProcess::find_by_task_attempt_id( + &self.db.pool, + task_attempt_id, + show_soft_deleted, + ) + .await?; + + // Convert processes array to object keyed by process ID + let processes_map: serde_json::Map = processes + .into_iter() + .map(|process| { + ( + process.id.to_string(), + serde_json::to_value(process).unwrap(), + ) + }) + .collect(); + + let initial_patch = json!([{ + "op": "replace", + "path": "/execution_processes", + "value": processes_map + }]); + let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); + + // Get filtered event stream + let filtered_stream = BroadcastStream::new(self.msg_store.get_receiver()).filter_map( + move |msg_result| async move { + match msg_result { + Ok(LogMsg::JsonPatch(patch)) => { + // Filter events based on task_attempt_id + if let Some(patch_op) = patch.0.first() { + // Check if this is a modern execution process patch + if patch_op.path().starts_with("/execution_processes/") { + match patch_op { + json_patch::PatchOperation::Add(op) => { + // Parse execution process data directly from value + if let Ok(process) = + serde_json::from_value::( + op.value.clone(), + ) + && process.task_attempt_id == task_attempt_id + { + if !show_soft_deleted && process.dropped { + let remove_patch = + execution_process_patch::remove(process.id); + return Some(Ok(LogMsg::JsonPatch(remove_patch))); + } + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + json_patch::PatchOperation::Replace(op) => { + // Parse execution process data directly from value + if let Ok(process) = + serde_json::from_value::( + op.value.clone(), + ) + && process.task_attempt_id == task_attempt_id + { + if !show_soft_deleted && process.dropped { + let remove_patch = + execution_process_patch::remove(process.id); + return Some(Ok(LogMsg::JsonPatch(remove_patch))); + } + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + json_patch::PatchOperation::Remove(_) => { + // For remove operations, we can't verify task_attempt_id + // so we allow all removals and let the client handle filtering + return Some(Ok(LogMsg::JsonPatch(patch))); + } + _ => {} + } + } + // Fallback to legacy EventPatch format for backward compatibility + else if let Ok(event_patch_value) = serde_json::to_value(patch_op) + && let Ok(event_patch) = + serde_json::from_value::(event_patch_value) + { + match &event_patch.value.record { + RecordTypes::ExecutionProcess(process) => { + if process.task_attempt_id == task_attempt_id { + if !show_soft_deleted && process.dropped { + let remove_patch = + execution_process_patch::remove(process.id); + return Some(Ok(LogMsg::JsonPatch(remove_patch))); + } + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + RecordTypes::DeletedExecutionProcess { + task_attempt_id: Some(deleted_attempt_id), + .. + } => { + if *deleted_attempt_id == task_attempt_id { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + _ => {} + } + } + } + None + } + Ok(other) => Some(Ok(other)), // Pass through non-patch messages + Err(_) => None, // Filter out broadcast errors + } + }, + ); + + // Start with initial snapshot, then live updates + let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); + let combined_stream = initial_stream.chain(filtered_stream).boxed(); + + Ok(combined_stream) + } + + /// Stream drafts for all task attempts in a project with initial snapshot (raw LogMsg) + pub async fn stream_drafts_for_project_raw( + &self, + project_id: Uuid, + ) -> Result>, EventError> + { + // Load all attempt ids for tasks in this project + let attempt_ids: Vec = sqlx::query_scalar( + r#"SELECT ta.id + FROM task_attempts ta + JOIN tasks t ON t.id = ta.task_id + WHERE t.project_id = ?"#, + ) + .bind(project_id) + .fetch_all(&self.db.pool) + .await?; + + // Build initial drafts map keyed by attempt_id + let mut drafts_map: serde_json::Map = serde_json::Map::new(); + for attempt_id in attempt_ids { + let fu = Draft::find_by_task_attempt_and_type( + &self.db.pool, + attempt_id, + DraftType::FollowUp, + ) + .await? + .unwrap_or(Draft { + id: uuid::Uuid::new_v4(), + task_attempt_id: attempt_id, + draft_type: DraftType::FollowUp, + retry_process_id: None, + prompt: String::new(), + queued: false, + sending: false, + variant: None, + image_ids: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + version: 0, + }); + let re = + Draft::find_by_task_attempt_and_type(&self.db.pool, attempt_id, DraftType::Retry) + .await?; + let entry = json!({ + "follow_up": fu, + "retry": serde_json::to_value(re).unwrap_or(serde_json::Value::Null), + }); + drafts_map.insert(attempt_id.to_string(), entry); + } + + let initial_patch = json!([ + { + "op": "replace", + "path": "/drafts", + "value": drafts_map + } + ]); + let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); + + let db_pool = self.db.pool.clone(); + // Live updates: accept direct draft patches and filter by project membership + let filtered_stream = + BroadcastStream::new(self.msg_store.get_receiver()).filter_map(move |msg_result| { + let db_pool = db_pool.clone(); + async move { + match msg_result { + Ok(LogMsg::JsonPatch(patch)) => { + if let Some(op) = patch.0.first() { + let path = op.path(); + if let Some(rest) = path.strip_prefix("/drafts/") + && let Some((attempt_str, _)) = rest.split_once('/') + && let Ok(attempt_id) = Uuid::parse_str(attempt_str) + { + // Check project membership + if let Ok(Some(task_attempt)) = + db::models::task_attempt::TaskAttempt::find_by_id( + &db_pool, attempt_id, + ) + .await + && let Ok(Some(task)) = db::models::task::Task::find_by_id( + &db_pool, + task_attempt.task_id, + ) + .await + && task.project_id == project_id + { + return Some(Ok(LogMsg::JsonPatch(patch))); + } + } + } + None + } + Ok(other) => Some(Ok(other)), + Err(_) => None, + } + } + }); + + let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); + let combined_stream = initial_stream.chain(filtered_stream).boxed(); + Ok(combined_stream) + } +} diff --git a/crates/services/src/services/events/types.rs b/crates/services/src/services/events/types.rs new file mode 100644 index 00000000..ba1eaec4 --- /dev/null +++ b/crates/services/src/services/events/types.rs @@ -0,0 +1,77 @@ +use anyhow::Error as AnyhowError; +use db::models::{ + draft::{Draft, DraftType}, + execution_process::ExecutionProcess, + task::Task, + task_attempt::TaskAttempt, +}; +use serde::{Deserialize, Serialize}; +use sqlx::Error as SqlxError; +use strum_macros::{Display, EnumString}; +use thiserror::Error; +use ts_rs::TS; +use uuid::Uuid; + +#[derive(Debug, Error)] +pub enum EventError { + #[error(transparent)] + Sqlx(#[from] SqlxError), + #[error(transparent)] + Parse(#[from] serde_json::Error), + #[error(transparent)] + Other(#[from] AnyhowError), // Catches any unclassified errors +} + +#[derive(EnumString, Display)] +pub enum HookTables { + #[strum(to_string = "tasks")] + Tasks, + #[strum(to_string = "task_attempts")] + TaskAttempts, + #[strum(to_string = "execution_processes")] + ExecutionProcesses, + #[strum(to_string = "drafts")] + Drafts, +} + +#[derive(Serialize, Deserialize, TS)] +#[serde(tag = "type", content = "data", rename_all = "SCREAMING_SNAKE_CASE")] +pub enum RecordTypes { + Task(Task), + TaskAttempt(TaskAttempt), + ExecutionProcess(ExecutionProcess), + Draft(Draft), + RetryDraft(Draft), + DeletedTask { + rowid: i64, + project_id: Option, + task_id: Option, + }, + DeletedTaskAttempt { + rowid: i64, + task_id: Option, + }, + DeletedExecutionProcess { + rowid: i64, + task_attempt_id: Option, + process_id: Option, + }, + DeletedDraft { + rowid: i64, + draft_type: DraftType, + task_attempt_id: Option, + }, +} + +#[derive(Serialize, Deserialize, TS)] +pub struct EventPatchInner { + pub(crate) db_op: String, + pub(crate) record: RecordTypes, +} + +#[derive(Serialize, Deserialize, TS)] +pub struct EventPatch { + pub(crate) op: String, + pub(crate) path: String, + pub(crate) value: EventPatchInner, +} diff --git a/crates/services/src/services/git.rs b/crates/services/src/services/git.rs index eef1738e..9d453058 100644 --- a/crates/services/src/services/git.rs +++ b/crates/services/src/services/git.rs @@ -89,6 +89,36 @@ impl std::fmt::Display for Commit { } } +#[derive(Debug, Clone, Copy)] +pub struct WorktreeResetOptions { + pub perform_reset: bool, + pub force_when_dirty: bool, + pub is_dirty: bool, + pub log_skip_when_dirty: bool, +} + +impl WorktreeResetOptions { + pub fn new( + perform_reset: bool, + force_when_dirty: bool, + is_dirty: bool, + log_skip_when_dirty: bool, + ) -> Self { + Self { + perform_reset, + force_when_dirty, + is_dirty, + log_skip_when_dirty, + } + } +} + +#[derive(Debug, Default, Clone, Copy)] +pub struct WorktreeResetOutcome { + pub needed: bool, + pub applied: bool, +} + /// Target for diff generation pub enum DiffTarget<'p> { /// Work-in-progress branch checked out in this worktree @@ -1074,6 +1104,47 @@ impl GitService { .map_err(|e| GitServiceError::InvalidRepository(format!("git status failed: {e}"))) } + /// Evaluate whether any action is needed to reset to `target_commit_oid` and + /// optionally perform the actions. + pub fn reconcile_worktree_to_commit( + &self, + worktree_path: &Path, + target_commit_oid: &str, + options: WorktreeResetOptions, + ) -> WorktreeResetOutcome { + let WorktreeResetOptions { + perform_reset, + force_when_dirty, + is_dirty, + log_skip_when_dirty, + } = options; + + let head_oid = self.get_head_info(worktree_path).ok().map(|h| h.oid); + let mut outcome = WorktreeResetOutcome::default(); + + if head_oid.as_deref() != Some(target_commit_oid) || is_dirty { + outcome.needed = true; + + if perform_reset { + if is_dirty && !force_when_dirty { + if log_skip_when_dirty { + tracing::warn!("Worktree dirty; skipping reset as not forced"); + } + } else if let Err(e) = self.reset_worktree_to_commit( + worktree_path, + target_commit_oid, + force_when_dirty, + ) { + tracing::error!("Failed to reset worktree: {}", e); + } else { + outcome.applied = true; + } + } + } + + outcome + } + /// Reset the given worktree to the specified commit SHA. /// If `force` is false and the worktree is dirty, returns WorktreeDirty error. pub fn reset_worktree_to_commit( diff --git a/crates/services/src/services/mod.rs b/crates/services/src/services/mod.rs index 9c842427..fcbf68b8 100644 --- a/crates/services/src/services/mod.rs +++ b/crates/services/src/services/mod.rs @@ -3,6 +3,7 @@ pub mod approvals; pub mod auth; pub mod config; pub mod container; +pub mod drafts; pub mod events; pub mod file_ranker; pub mod file_search_cache; diff --git a/frontend/package.json b/frontend/package.json index e4a4ba66..e3e846ed 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -54,6 +54,7 @@ "react-virtuoso": "^4.14.0", "react-window": "^1.8.11", "rfc6902": "^5.1.2", + "react-use-websocket": "^4.7.0", "tailwind-merge": "^2.2.0", "tailwindcss-animate": "^1.0.7", "vibe-kanban-web-companion": "^0.0.4", diff --git a/frontend/src/components/NormalizedConversation/DisplayConversationEntry.tsx b/frontend/src/components/NormalizedConversation/DisplayConversationEntry.tsx index 2b57cbb9..d67a167b 100644 --- a/frontend/src/components/NormalizedConversation/DisplayConversationEntry.tsx +++ b/frontend/src/components/NormalizedConversation/DisplayConversationEntry.tsx @@ -30,6 +30,7 @@ import RawLogText from '../common/RawLogText'; import UserMessage from './UserMessage'; import PendingApprovalEntry from './PendingApprovalEntry'; import { cn } from '@/lib/utils'; +import { useRetryUi } from '@/contexts/RetryUiContext'; type Props = { entry: NormalizedEntry | ProcessStartPayload; @@ -612,14 +613,19 @@ function DisplayConversationEntry({ entry: NormalizedEntry | ProcessStartPayload ): entry is ProcessStartPayload => 'processId' in entry; + const { isProcessGreyed } = useRetryUi(); + const greyed = isProcessGreyed(executionProcessId); + if (isProcessStart(entry)) { const toolAction: any = entry.action ?? null; return ( - +
+ +
); } @@ -644,7 +650,6 @@ function DisplayConversationEntry({ } const renderToolUse = () => { if (!isNormalizedEntry(entry)) return null; - if (entryType.type !== 'tool_use') return null; const toolEntry = entryType; @@ -698,7 +703,13 @@ function DisplayConversationEntry({ ); })(); - const content =
{body}
; + const content = ( +
+ {body} +
+ ); if (isPendingApprovalStatus(status)) { return ( @@ -720,7 +731,9 @@ function DisplayConversationEntry({ if (isSystem || isError) { return ( -
+
void; +}) { + const { t } = useTranslation(['common']); + const attemptId = attempt.id; + const { retryDraft, isRetryLoaded } = useDraftStream(attemptId); + const { isAttemptRunning, attemptData } = useAttemptExecution(attemptId); + const { data: branchStatus } = useBranchStatus(attemptId); + const { profiles } = useUserSystem(); + + // Errors are now reserved for send/cancel; creation occurs outside via useProcessRetry + const [initError] = useState(null); + + const draft = useMemo(() => { + if (!retryDraft || retryDraft.retry_process_id !== executionProcessId) { + return null; + } + return { + ...retryDraft, + retry_process_id: executionProcessId, + }; + }, [retryDraft, executionProcessId]); + + const { + message, + setMessage, + images, + setImages, + handleImageUploaded, + clearImagesAndUploads, + } = useDraftEditor({ + draft, + taskId: attempt.task_id, + }); + + // Presentation-only: show/hide image upload panel + const [showImageUpload, setShowImageUpload] = useState(false); + + // Variant selection: start with initialVariant or draft.variant + const [selectedVariant, setSelectedVariant] = useState( + draft?.variant ?? initialVariant ?? null + ); + useEffect(() => { + if (draft?.variant !== undefined) setSelectedVariant(draft.variant ?? null); + }, [draft?.variant]); + + const { isSaving, saveStatus } = useDraftAutosave({ + draftType: 'retry', + attemptId, + serverDraft: draft, + current: { + prompt: message, + variant: selectedVariant, + image_ids: images.map((img) => img.id), + retry_process_id: executionProcessId, + }, + isDraftSending: false, + }); + + const [sendError, setSendError] = useState(null); + const [isSending, setIsSending] = useState(false); + // Show overlay and keep UI disabled while waiting for server to clear retry_draft + const [isFinalizing, setIsFinalizing] = useState( + false + ); + const canSend = !isAttemptRunning && !!(message.trim() || images.length > 0); + + const onCancel = async () => { + setSendError(null); + setIsFinalizing('cancel'); + try { + await attemptsApi.deleteDraft(attemptId, 'retry'); + } catch (error: unknown) { + setIsFinalizing(false); + setSendError((error as Error)?.message || 'Failed to cancel retry'); + } + }; + + // Safety net: if server provided a draft but local message is empty, force-apply once + useEffect(() => { + if (!isRetryLoaded || !draft) return; + const serverPrompt = draft.prompt || ''; + if (message === '' && serverPrompt !== '') { + setMessage(serverPrompt); + if (import.meta.env.DEV) { + // One-shot debug to validate hydration ordering in dev + console.debug('[retry/hydrate] applied server prompt fallback', { + attemptId, + processId: executionProcessId, + len: serverPrompt.length, + }); + } + } + }, [ + attemptId, + draft, + executionProcessId, + isRetryLoaded, + message, + setMessage, + ]); + + const onSend = async () => { + if (!canSend) return; + setSendError(null); + setIsSending(true); + try { + // Fetch process details and compute confirmation payload + const proc = await executionProcessesApi.getDetails(executionProcessId); + type WithBefore = { before_head_commit?: string | null }; + const before = (proc as WithBefore)?.before_head_commit || null; + let targetSubject: string | null = null; + let commitsToReset: number | null = null; + let isLinear: boolean | null = null; + if (before) { + try { + const info = await commitsApi.getInfo(attemptId, before); + targetSubject = info.subject; + const cmp = await commitsApi.compareToHead(attemptId, before); + commitsToReset = cmp.is_linear ? cmp.ahead_from_head : null; + isLinear = cmp.is_linear; + } catch { + /* ignore */ + } + } + + const head = branchStatus?.head_oid || null; + const dirty = !!branchStatus?.has_uncommitted_changes; + const needReset = !!(before && (before !== head || dirty)); + const canGitReset = needReset && !dirty; + + // Compute later processes summary for UI + const procs = (attemptData.processes || []).filter( + (p) => !p.dropped && shouldShowInLogs(p.run_reason) + ); + const idx = procs.findIndex((p) => p.id === executionProcessId); + const later = idx >= 0 ? procs.slice(idx + 1) : []; + const laterCount = later.length; + const laterCoding = later.filter((p) => + isCodingAgent(p.run_reason) + ).length; + const laterSetup = later.filter( + (p) => p.run_reason === PROCESS_RUN_REASONS.SETUP_SCRIPT + ).length; + const laterCleanup = later.filter( + (p) => p.run_reason === PROCESS_RUN_REASONS.CLEANUP_SCRIPT + ).length; + + // Ask user for confirmation + let modalResult: RestoreLogsDialogResult | undefined; + try { + modalResult = await showModal('restore-logs', { + targetSha: before, + targetSubject, + commitsToReset, + isLinear, + laterCount, + laterCoding, + laterSetup, + laterCleanup, + needGitReset: needReset, + canGitReset, + hasRisk: dirty, + uncommittedCount: branchStatus?.uncommitted_count ?? 0, + untrackedCount: branchStatus?.untracked_count ?? 0, + initialWorktreeResetOn: true, + initialForceReset: false, + }); + } catch { + setIsSending(false); + return; // dialog closed + } + if (!modalResult || modalResult.action !== 'confirmed') { + setIsSending(false); + return; + } + + await attemptsApi.followUp(attemptId, { + prompt: message, + variant: selectedVariant, + image_ids: images.map((img) => img.id), + retry_process_id: executionProcessId, + force_when_dirty: modalResult.forceWhenDirty ?? false, + perform_git_reset: modalResult.performGitReset ?? true, + }); + clearImagesAndUploads(); + // Keep overlay up until stream clears the retry draft + setIsFinalizing('send'); + } catch (error: unknown) { + setSendError((error as Error)?.message || 'Failed to send retry'); + setIsSending(false); + setIsFinalizing(false); + } + }; + + // Once server stream clears retry_draft, exit retry mode (both cancel and send) + useEffect(() => { + const stillRetrying = !!retryDraft?.retry_process_id; + if ((isFinalizing || isSending) && !stillRetrying) { + setIsFinalizing(false); + setIsSending(false); + onCancelled?.(); + return; + } + }, [ + retryDraft?.retry_process_id, + isFinalizing, + isSending, + onCancelled, + attemptId, + ]); + + return ( +
+ {initError && ( + + + {initError} + + )} + + void 0} + disabled={isSending || !!isFinalizing} + showLoadingOverlay={isSending || !!isFinalizing} + textareaClassName="bg-background" + /> + + {/* Draft save/load status (no queue/sending for retry) */} + + +
+ {/* Image button */} + + +
+ + +
+
+ + {showImageUpload && ( +
+ imagesApi.uploadForTask(attempt.task_id, file)} + onDelete={imagesApi.delete} + onImageUploaded={(image) => { + handleImageUploaded(image); + setMessage((prev) => appendImageMarkdown(prev, image)); + }} + disabled={isSending || !!isFinalizing} + collapsible={false} + defaultExpanded={true} + /> +
+ )} + + {sendError && ( + + + {sendError} + + )} +
+ ); +} diff --git a/frontend/src/components/NormalizedConversation/UserMessage.tsx b/frontend/src/components/NormalizedConversation/UserMessage.tsx index 1376ac72..0a713f84 100644 --- a/frontend/src/components/NormalizedConversation/UserMessage.tsx +++ b/frontend/src/components/NormalizedConversation/UserMessage.tsx @@ -1,11 +1,13 @@ import MarkdownRenderer from '@/components/ui/markdown-renderer'; import { Button } from '@/components/ui/button'; -import { Pencil, Send, X } from 'lucide-react'; -import { useState } from 'react'; -import { Textarea } from '@/components/ui/textarea'; +import { Pencil } from 'lucide-react'; +import { useEffect, useState } from 'react'; import { useProcessRetry } from '@/hooks/useProcessRetry'; import { TaskAttempt, type BaseAgentCapability } from 'shared/types'; import { useUserSystem } from '@/components/config-provider'; +import { useDraftStream } from '@/hooks/follow-up/useDraftStream'; +import { RetryEditorInline } from './RetryEditorInline'; +import { useRetryUi } from '@/contexts/RetryUiContext'; const UserMessage = ({ content, @@ -17,9 +19,11 @@ const UserMessage = ({ taskAttempt?: TaskAttempt; }) => { const [isEditing, setIsEditing] = useState(false); - const [editContent, setEditContent] = useState(content); const retryHook = useProcessRetry(taskAttempt); const { capabilities } = useUserSystem(); + const attemptId = taskAttempt?.id; + const { retryDraft } = useDraftStream(attemptId); + const { activeRetryProcessId, isProcessGreyed } = useRetryUi(); const canFork = !!( taskAttempt?.executor && @@ -28,21 +32,53 @@ const UserMessage = ({ ) ); - const handleEdit = () => { - if (!executionProcessId) return; - retryHook?.retryProcess(executionProcessId, editContent).then(() => { + // Enter retry mode: create retry draft; actual editor will render inline + const startRetry = async () => { + if (!executionProcessId || !taskAttempt) return; + setIsEditing(true); + retryHook?.startRetry(executionProcessId, content).catch(() => { + // rollback if server call fails setIsEditing(false); }); }; + // Exit editing state once draft disappears (sent/cancelled) + useEffect(() => { + if (!retryDraft?.retry_process_id) setIsEditing(false); + }, [retryDraft?.retry_process_id]); + + // On reload or when server provides a retry_draft for this process, show editor + useEffect(() => { + if ( + executionProcessId && + retryDraft?.retry_process_id && + retryDraft.retry_process_id === executionProcessId + ) { + setIsEditing(true); + } + }, [executionProcessId, retryDraft?.retry_process_id]); + + const showRetryEditor = + !!executionProcessId && + isEditing && + activeRetryProcessId === executionProcessId; + const greyed = + !!executionProcessId && + isProcessGreyed(executionProcessId) && + !showRetryEditor; + return ( -
+
- {isEditing ? ( -