diff --git a/crates/local-deployment/src/container.rs b/crates/local-deployment/src/container.rs index 0e39aae2..392afc4e 100644 --- a/crates/local-deployment/src/container.rs +++ b/crates/local-deployment/src/container.rs @@ -13,7 +13,6 @@ use std::{ use anyhow::anyhow; use async_stream::try_stream; use async_trait::async_trait; -use axum::response::sse::Event; use command_group::AsyncGroupChild; use db::{ DBService, @@ -79,7 +78,7 @@ pub struct LocalContainerService { impl LocalContainerService { // Max cumulative content bytes allowed per diff stream - const MAX_CUMULATIVE_DIFF_BYTES: usize = 150 * 1024; // 150KB + const MAX_CUMULATIVE_DIFF_BYTES: usize = 200 * 1024 * 1024; // 200MB // Apply stream-level omit policy based on cumulative bytes. // If adding this diff's contents exceeds the cap, strip contents and set stats. @@ -625,12 +624,12 @@ impl LocalContainerService { Ok(project_repo_path) } - /// Create a diff stream for merged attempts (never changes) + /// Create a diff log stream for merged attempts (never changes) for WebSocket fn create_merged_diff_stream( &self, project_repo_path: &Path, merge_commit_id: &str, - ) -> Result>, ContainerError> + ) -> Result>, ContainerError> { let diffs = self.git().get_diffs( DiffTarget::Commit { @@ -653,23 +652,22 @@ impl LocalContainerService { let entry_index = GitService::diff_path(&diff); let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff); - let event = LogMsg::JsonPatch(patch).to_sse_event(); - Ok::<_, std::io::Error>(event) + Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch)) })) .chain(futures::stream::once(async { - Ok::<_, std::io::Error>(LogMsg::Finished.to_sse_event()) + Ok::<_, std::io::Error>(LogMsg::Finished) })) .boxed(); Ok(stream) } - /// Create a live diff stream for ongoing attempts + /// Create a live diff log stream for ongoing attempts for WebSocket async fn create_live_diff_stream( &self, worktree_path: &Path, base_commit: &Commit, - ) -> Result>, ContainerError> + ) -> Result>, ContainerError> { // Get initial snapshot let git_service = self.git().clone(); @@ -681,9 +679,7 @@ impl LocalContainerService { None, )?; - // cumulative counter for entire stream let cumulative = Arc::new(AtomicUsize::new(0)); - // track which file paths have been emitted with full content already let full_sent = Arc::new(std::sync::RwLock::new(HashSet::::new())); let initial_diffs: Vec<_> = initial_diffs .into_iter() @@ -708,8 +704,7 @@ impl LocalContainerService { let entry_index = GitService::diff_path(&diff); let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff); - let event = LogMsg::JsonPatch(patch).to_sse_event(); - Ok::<_, std::io::Error>(event) + Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch)) })) .boxed(); @@ -723,7 +718,6 @@ impl LocalContainerService { let cumulative = Arc::clone(&cumulative); let full_sent = Arc::clone(&full_sent); try_stream! { - // Move the expensive watcher setup to blocking thread to avoid blocking the async runtime let watcher_result = tokio::task::spawn_blocking(move || { filesystem_watcher::async_watcher(worktree_path_for_spawn) }) @@ -739,7 +733,7 @@ impl LocalContainerService { let changed_paths = Self::extract_changed_paths(&events, &canonical_worktree_path, &worktree_path); if !changed_paths.is_empty() { - for event in Self::process_file_changes( + for msg in Self::process_file_changes( &git_service, &worktree_path, &base_commit, @@ -750,7 +744,7 @@ impl LocalContainerService { tracing::error!("Error processing file changes: {}", e); io::Error::other(e.to_string()) })? { - yield event; + yield msg; } } } @@ -767,8 +761,6 @@ impl LocalContainerService { } }.boxed(); - // Ensure all initial diffs are emitted before live updates, to avoid - // earlier files being abbreviated due to interleaving ordering. let combined_stream = initial_stream.chain(live_stream); Ok(combined_stream.boxed()) } @@ -792,7 +784,7 @@ impl LocalContainerService { .collect() } - /// Process file changes and generate diff events + /// Process file changes and generate diff messages (for WS) fn process_file_changes( git_service: &GitService, worktree_path: &Path, @@ -800,7 +792,7 @@ impl LocalContainerService { changed_paths: &[String], cumulative_bytes: &Arc, full_sent_paths: &Arc>>, - ) -> Result, ContainerError> { + ) -> Result, ContainerError> { let path_filter: Vec<&str> = changed_paths.iter().map(|s| s.as_str()).collect(); let current_diffs = git_service.get_diffs( @@ -811,7 +803,7 @@ impl LocalContainerService { Some(&path_filter), )?; - let mut events = Vec::new(); + let mut msgs = Vec::new(); let mut files_with_diffs = HashSet::new(); // Add/update files that have diffs @@ -861,24 +853,17 @@ impl LocalContainerService { } } - // If this diff would be omitted and we already sent a full-content - // version of this path earlier in the stream, skip sending a - // degrading replacement. if diff.content_omitted { if full_sent_paths.read().unwrap().contains(&file_path) { continue; } } else { - // Track that we have sent a full-content version - { - let mut guard = full_sent_paths.write().unwrap(); - guard.insert(file_path.clone()); - } + let mut guard = full_sent_paths.write().unwrap(); + guard.insert(file_path.clone()); } let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&file_path), diff); - let event = LogMsg::JsonPatch(patch).to_sse_event(); - events.push(event); + msgs.push(LogMsg::JsonPatch(patch)); } // Remove files that changed but no longer have diffs @@ -886,12 +871,11 @@ impl LocalContainerService { if !files_with_diffs.contains(changed_path) { let patch = ConversationPatch::remove_diff(escape_json_pointer_segment(changed_path)); - let event = LogMsg::JsonPatch(patch).to_sse_event(); - events.push(event); + msgs.push(LogMsg::JsonPatch(patch)); } } - Ok(events) + Ok(msgs) } } @@ -1151,10 +1135,10 @@ impl ContainerService for LocalContainerService { Ok(()) } - async fn get_diff( + async fn stream_diff( &self, task_attempt: &TaskAttempt, - ) -> Result>, ContainerError> + ) -> Result>, ContainerError> { let project_repo_path = self.get_project_repo_path(task_attempt).await?; let latest_merge = @@ -1177,7 +1161,6 @@ impl ContainerService for LocalContainerService { false }; - // Show merged diff when no new work is on the branch or container if let Some(merge) = &latest_merge && let Some(commit) = merge.merge_commit() && self.is_container_clean(task_attempt).await? @@ -1186,17 +1169,14 @@ impl ContainerService for LocalContainerService { return self.create_merged_diff_stream(&project_repo_path, &commit); } - // worktree is needed for non-merged diffs let container_ref = self.ensure_container_exists(task_attempt).await?; let worktree_path = PathBuf::from(container_ref); - let base_commit = self.git().get_base_commit( &project_repo_path, &task_branch, &task_attempt.base_branch, )?; - // Handle ongoing attempts (live streaming diff) self.create_live_diff_stream(&worktree_path, &base_commit) .await } diff --git a/crates/server/src/routes/task_attempts.rs b/crates/server/src/routes/task_attempts.rs index 940b1ab9..a501e24c 100644 --- a/crates/server/src/routes/task_attempts.rs +++ b/crates/server/src/routes/task_attempts.rs @@ -1,17 +1,14 @@ use std::path::PathBuf; use axum::{ - BoxError, Extension, Json, Router, + Extension, Json, Router, extract::{ Query, State, ws::{WebSocket, WebSocketUpgrade}, }, http::StatusCode, middleware::from_fn_with_state, - response::{ - IntoResponse, Json as ResponseJson, Sse, - sse::{Event, KeepAlive}, - }, + response::{IntoResponse, Json as ResponseJson}, routing::{get, post}, }; use db::models::{ @@ -32,7 +29,6 @@ use executors::{ }, profile::ExecutorProfileId, }; -use futures_util::TryStreamExt; use git2::BranchType; use serde::{Deserialize, Serialize}; use services::services::{ @@ -962,14 +958,50 @@ pub async fn replace_process( }))) } -pub async fn get_task_attempt_diff( +#[axum::debug_handler] +pub async fn stream_task_attempt_diff_ws( + ws: WebSocketUpgrade, Extension(task_attempt): Extension, State(deployment): State, - // ) -> Result>, ApiError> { -) -> Result>>, ApiError> { - let stream = deployment.container().get_diff(&task_attempt).await?; +) -> impl IntoResponse { + ws.on_upgrade(move |socket| async move { + if let Err(e) = handle_task_attempt_diff_ws(socket, deployment, task_attempt).await { + tracing::warn!("diff WS closed: {}", e); + } + }) +} - Ok(Sse::new(stream.map_err(|e| -> BoxError { e.into() })).keep_alive(KeepAlive::default())) +async fn handle_task_attempt_diff_ws( + socket: WebSocket, + deployment: DeploymentImpl, + task_attempt: TaskAttempt, +) -> anyhow::Result<()> { + use futures_util::{SinkExt, StreamExt, TryStreamExt}; + use utils::log_msg::LogMsg; + + let mut stream = deployment + .container() + .stream_diff(&task_attempt) + .await? + .map_ok(|msg: LogMsg| msg.to_ws_message_unchecked()); + + let (mut sender, mut receiver) = socket.split(); + tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} }); + + while let Some(item) = stream.next().await { + match item { + Ok(msg) => { + if sender.send(msg).await.is_err() { + break; + } + } + Err(e) => { + tracing::error!("stream error: {}", e); + break; + } + } + } + Ok(()) } #[derive(Debug, Serialize, TS)] @@ -1730,7 +1762,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router { .route("/commit-compare", get(compare_commit_to_head)) .route("/start-dev-server", post(start_dev_server)) .route("/branch-status", get(get_task_attempt_branch_status)) - .route("/diff", get(get_task_attempt_diff)) + .route("/diff/ws", get(stream_task_attempt_diff_ws)) .route("/merge", post(merge_task_attempt)) .route("/push", post(push_task_attempt_branch)) .route("/rebase", post(rebase_task_attempt)) diff --git a/crates/services/src/services/container.rs b/crates/services/src/services/container.rs index f01bcd91..3aac4126 100644 --- a/crates/services/src/services/container.rs +++ b/crates/services/src/services/container.rs @@ -6,7 +6,6 @@ use std::{ use anyhow::{Error as AnyhowError, anyhow}; use async_trait::async_trait; -use axum::response::sse::Event; use db::{ DBService, models::{ @@ -194,10 +193,11 @@ pub trait ContainerService { copy_files: &str, ) -> Result<(), ContainerError>; - async fn get_diff( + /// Stream diff updates as LogMsg for WebSocket endpoints. + async fn stream_diff( &self, task_attempt: &TaskAttempt, - ) -> Result>, ContainerError>; + ) -> Result>, ContainerError>; /// Fetch the MsgStore for a given execution ID, panicking if missing. async fn get_msg_store_by_id(&self, uuid: &Uuid) -> Option> { diff --git a/crates/services/src/services/git.rs b/crates/services/src/services/git.rs index f07bf74a..8c5efa58 100644 --- a/crates/services/src/services/git.rs +++ b/crates/services/src/services/git.rs @@ -44,7 +44,7 @@ pub struct GitService {} // Max inline diff size for UI (in bytes). Files larger than this will have // their contents omitted from the diff stream to avoid UI crashes. -const MAX_INLINE_DIFF_BYTES: usize = 150 * 1024; // ~150KB +const MAX_INLINE_DIFF_BYTES: usize = 2 * 1024 * 1024; // ~2MB #[derive(Debug, Clone, Serialize, Deserialize, TS, PartialEq, Eq)] #[serde(rename_all = "snake_case")] diff --git a/frontend/src/hooks/useDiffStream.ts b/frontend/src/hooks/useDiffStream.ts index ae9dba0b..e39fa1be 100644 --- a/frontend/src/hooks/useDiffStream.ts +++ b/frontend/src/hooks/useDiffStream.ts @@ -1,6 +1,6 @@ import { useCallback } from 'react'; import type { PatchType } from 'shared/types'; -import { useJsonPatchStream } from './useJsonPatchStream'; +import { useJsonPatchWsStream } from './useJsonPatchWsStream'; interface DiffState { entries: Record; @@ -17,7 +17,7 @@ export const useDiffStream = ( enabled: boolean ): UseDiffStreamResult => { const endpoint = attemptId - ? `/api/task-attempts/${attemptId}/diff` + ? `/api/task-attempts/${attemptId}/diff/ws` : undefined; const initialData = useCallback( @@ -27,7 +27,7 @@ export const useDiffStream = ( [] ); - const { data, isConnected, error } = useJsonPatchStream( + const { data, isConnected, error } = useJsonPatchWsStream( endpoint, enabled && !!attemptId, initialData