migrate /diff endpoint to websocket (#851)

Solomon
2025-09-26 10:10:44 +01:00
committed by GitHub
parent 4f343fdb8f
commit bfb0c3f2ea
5 changed files with 71 additions and 59 deletions

View File

@@ -13,7 +13,6 @@ use std::{
 use anyhow::anyhow;
 use async_stream::try_stream;
 use async_trait::async_trait;
-use axum::response::sse::Event;
 use command_group::AsyncGroupChild;
 use db::{
     DBService,
@@ -79,7 +78,7 @@ pub struct LocalContainerService {
 impl LocalContainerService {
     // Max cumulative content bytes allowed per diff stream
-    const MAX_CUMULATIVE_DIFF_BYTES: usize = 150 * 1024; // 150KB
+    const MAX_CUMULATIVE_DIFF_BYTES: usize = 200 * 1024 * 1024; // 200MB

     // Apply stream-level omit policy based on cumulative bytes.
     // If adding this diff's contents exceeds the cap, strip contents and set stats.
@@ -625,12 +624,12 @@ impl LocalContainerService {
         Ok(project_repo_path)
     }

-    /// Create a diff stream for merged attempts (never changes)
+    /// Create a diff log stream for merged attempts (never changes) for WebSocket
     fn create_merged_diff_stream(
         &self,
         project_repo_path: &Path,
         merge_commit_id: &str,
-    ) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
+    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
     {
         let diffs = self.git().get_diffs(
             DiffTarget::Commit {
@@ -653,23 +652,22 @@ impl LocalContainerService {
                 let entry_index = GitService::diff_path(&diff);
                 let patch =
                     ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
-                let event = LogMsg::JsonPatch(patch).to_sse_event();
-                Ok::<_, std::io::Error>(event)
+                Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
             }))
             .chain(futures::stream::once(async {
-                Ok::<_, std::io::Error>(LogMsg::Finished.to_sse_event())
+                Ok::<_, std::io::Error>(LogMsg::Finished)
             }))
             .boxed();

         Ok(stream)
     }

-    /// Create a live diff stream for ongoing attempts
+    /// Create a live diff log stream for ongoing attempts for WebSocket
     async fn create_live_diff_stream(
         &self,
         worktree_path: &Path,
         base_commit: &Commit,
-    ) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
+    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
     {
         // Get initial snapshot
         let git_service = self.git().clone();
@@ -681,9 +679,7 @@ impl LocalContainerService {
             None,
         )?;

-        // cumulative counter for entire stream
         let cumulative = Arc::new(AtomicUsize::new(0));
-        // track which file paths have been emitted with full content already
         let full_sent = Arc::new(std::sync::RwLock::new(HashSet::<String>::new()));

         let initial_diffs: Vec<_> = initial_diffs
             .into_iter()
@@ -708,8 +704,7 @@ impl LocalContainerService {
                 let entry_index = GitService::diff_path(&diff);
                 let patch =
                     ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
-                let event = LogMsg::JsonPatch(patch).to_sse_event();
-                Ok::<_, std::io::Error>(event)
+                Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
             }))
             .boxed();
@@ -723,7 +718,6 @@ impl LocalContainerService {
             let cumulative = Arc::clone(&cumulative);
             let full_sent = Arc::clone(&full_sent);
             try_stream! {
-                // Move the expensive watcher setup to blocking thread to avoid blocking the async runtime
                 let watcher_result = tokio::task::spawn_blocking(move || {
                     filesystem_watcher::async_watcher(worktree_path_for_spawn)
                 })
@@ -739,7 +733,7 @@ impl LocalContainerService {
                         let changed_paths = Self::extract_changed_paths(&events, &canonical_worktree_path, &worktree_path);
                         if !changed_paths.is_empty() {
-                            for event in Self::process_file_changes(
+                            for msg in Self::process_file_changes(
                                 &git_service,
                                 &worktree_path,
                                 &base_commit,
@@ -750,7 +744,7 @@ impl LocalContainerService {
tracing::error!("Error processing file changes: {}", e); tracing::error!("Error processing file changes: {}", e);
io::Error::other(e.to_string()) io::Error::other(e.to_string())
})? { })? {
yield event; yield msg;
} }
} }
} }
@@ -767,8 +761,6 @@ impl LocalContainerService {
} }
}.boxed(); }.boxed();
// Ensure all initial diffs are emitted before live updates, to avoid
// earlier files being abbreviated due to interleaving ordering.
let combined_stream = initial_stream.chain(live_stream); let combined_stream = initial_stream.chain(live_stream);
Ok(combined_stream.boxed()) Ok(combined_stream.boxed())
} }
@@ -792,7 +784,7 @@ impl LocalContainerService {
             .collect()
     }

-    /// Process file changes and generate diff events
+    /// Process file changes and generate diff messages (for WS)
     fn process_file_changes(
         git_service: &GitService,
         worktree_path: &Path,
@@ -800,7 +792,7 @@ impl LocalContainerService {
         changed_paths: &[String],
         cumulative_bytes: &Arc<AtomicUsize>,
         full_sent_paths: &Arc<std::sync::RwLock<HashSet<String>>>,
-    ) -> Result<Vec<Event>, ContainerError> {
+    ) -> Result<Vec<LogMsg>, ContainerError> {
         let path_filter: Vec<&str> = changed_paths.iter().map(|s| s.as_str()).collect();
         let current_diffs = git_service.get_diffs(
@@ -811,7 +803,7 @@ impl LocalContainerService {
             Some(&path_filter),
         )?;

-        let mut events = Vec::new();
+        let mut msgs = Vec::new();
         let mut files_with_diffs = HashSet::new();

         // Add/update files that have diffs
@@ -861,24 +853,17 @@ impl LocalContainerService {
                 }
             }

-            // If this diff would be omitted and we already sent a full-content
-            // version of this path earlier in the stream, skip sending a
-            // degrading replacement.
             if diff.content_omitted {
                 if full_sent_paths.read().unwrap().contains(&file_path) {
                     continue;
                 }
             } else {
-                // Track that we have sent a full-content version
-                {
-                    let mut guard = full_sent_paths.write().unwrap();
-                    guard.insert(file_path.clone());
-                }
+                let mut guard = full_sent_paths.write().unwrap();
+                guard.insert(file_path.clone());
             }

             let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&file_path), diff);
-            let event = LogMsg::JsonPatch(patch).to_sse_event();
-            events.push(event);
+            msgs.push(LogMsg::JsonPatch(patch));
         }

         // Remove files that changed but no longer have diffs
@@ -886,12 +871,11 @@ impl LocalContainerService {
             if !files_with_diffs.contains(changed_path) {
                 let patch =
                     ConversationPatch::remove_diff(escape_json_pointer_segment(changed_path));
-                let event = LogMsg::JsonPatch(patch).to_sse_event();
-                events.push(event);
+                msgs.push(LogMsg::JsonPatch(patch));
             }
         }

-        Ok(events)
+        Ok(msgs)
     }
 }
@@ -1151,10 +1135,10 @@ impl ContainerService for LocalContainerService {
         Ok(())
     }

-    async fn get_diff(
+    async fn stream_diff(
         &self,
         task_attempt: &TaskAttempt,
-    ) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
+    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
     {
         let project_repo_path = self.get_project_repo_path(task_attempt).await?;
         let latest_merge =
@@ -1177,7 +1161,6 @@ impl ContainerService for LocalContainerService {
             false
         };

-        // Show merged diff when no new work is on the branch or container
         if let Some(merge) = &latest_merge
             && let Some(commit) = merge.merge_commit()
             && self.is_container_clean(task_attempt).await?
@@ -1186,17 +1169,14 @@ impl ContainerService for LocalContainerService {
             return self.create_merged_diff_stream(&project_repo_path, &commit);
         }

-        // worktree is needed for non-merged diffs
         let container_ref = self.ensure_container_exists(task_attempt).await?;
         let worktree_path = PathBuf::from(container_ref);

         let base_commit = self.git().get_base_commit(
             &project_repo_path,
             &task_branch,
             &task_attempt.base_branch,
         )?;

-        // Handle ongoing attempts (live streaming diff)
         self.create_live_diff_stream(&worktree_path, &base_commit)
             .await
     }

View File

@@ -1,17 +1,14 @@
 use std::path::PathBuf;

 use axum::{
-    BoxError, Extension, Json, Router,
+    Extension, Json, Router,
     extract::{
         Query, State,
         ws::{WebSocket, WebSocketUpgrade},
     },
     http::StatusCode,
     middleware::from_fn_with_state,
-    response::{
-        IntoResponse, Json as ResponseJson, Sse,
-        sse::{Event, KeepAlive},
-    },
+    response::{IntoResponse, Json as ResponseJson},
     routing::{get, post},
 };
 use db::models::{
@@ -32,7 +29,6 @@ use executors::{
     },
     profile::ExecutorProfileId,
 };

-use futures_util::TryStreamExt;
 use git2::BranchType;
 use serde::{Deserialize, Serialize};
 use services::services::{
@@ -962,14 +958,50 @@ pub async fn replace_process(
     })))
 }

-pub async fn get_task_attempt_diff(
+#[axum::debug_handler]
+pub async fn stream_task_attempt_diff_ws(
+    ws: WebSocketUpgrade,
     Extension(task_attempt): Extension<TaskAttempt>,
     State(deployment): State<DeploymentImpl>,
-// ) -> Result<ResponseJson<ApiResponse<Diff>>, ApiError> {
-) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, BoxError>>>, ApiError> {
-    let stream = deployment.container().get_diff(&task_attempt).await?;
-    Ok(Sse::new(stream.map_err(|e| -> BoxError { e.into() })).keep_alive(KeepAlive::default()))
+) -> impl IntoResponse {
+    ws.on_upgrade(move |socket| async move {
+        if let Err(e) = handle_task_attempt_diff_ws(socket, deployment, task_attempt).await {
+            tracing::warn!("diff WS closed: {}", e);
+        }
+    })
+}
+
+async fn handle_task_attempt_diff_ws(
+    socket: WebSocket,
+    deployment: DeploymentImpl,
+    task_attempt: TaskAttempt,
+) -> anyhow::Result<()> {
+    use futures_util::{SinkExt, StreamExt, TryStreamExt};
+    use utils::log_msg::LogMsg;
+
+    let mut stream = deployment
+        .container()
+        .stream_diff(&task_attempt)
+        .await?
+        .map_ok(|msg: LogMsg| msg.to_ws_message_unchecked());
+
+    let (mut sender, mut receiver) = socket.split();
+    tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
+
+    while let Some(item) = stream.next().await {
+        match item {
+            Ok(msg) => {
+                if sender.send(msg).await.is_err() {
+                    break;
+                }
+            }
+            Err(e) => {
+                tracing::error!("stream error: {}", e);
+                break;
+            }
+        }
+    }
+
+    Ok(())
 }

 #[derive(Debug, Serialize, TS)]
@@ -1730,7 +1762,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
.route("/commit-compare", get(compare_commit_to_head)) .route("/commit-compare", get(compare_commit_to_head))
.route("/start-dev-server", post(start_dev_server)) .route("/start-dev-server", post(start_dev_server))
.route("/branch-status", get(get_task_attempt_branch_status)) .route("/branch-status", get(get_task_attempt_branch_status))
.route("/diff", get(get_task_attempt_diff)) .route("/diff/ws", get(stream_task_attempt_diff_ws))
.route("/merge", post(merge_task_attempt)) .route("/merge", post(merge_task_attempt))
.route("/push", post(push_task_attempt_branch)) .route("/push", post(push_task_attempt_branch))
.route("/rebase", post(rebase_task_attempt)) .route("/rebase", post(rebase_task_attempt))

View File

@@ -6,7 +6,6 @@ use std::{
 use anyhow::{Error as AnyhowError, anyhow};
 use async_trait::async_trait;
-use axum::response::sse::Event;
 use db::{
     DBService,
     models::{
@@ -194,10 +193,11 @@ pub trait ContainerService {
         copy_files: &str,
     ) -> Result<(), ContainerError>;

-    async fn get_diff(
+    /// Stream diff updates as LogMsg for WebSocket endpoints.
+    async fn stream_diff(
         &self,
         task_attempt: &TaskAttempt,
-    ) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>;
+    ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>;

     /// Fetch the MsgStore for a given execution ID, panicking if missing.
     async fn get_msg_store_by_id(&self, uuid: &Uuid) -> Option<Arc<MsgStore>> {

View File

@@ -44,7 +44,7 @@ pub struct GitService {}
 // Max inline diff size for UI (in bytes). Files larger than this will have
 // their contents omitted from the diff stream to avoid UI crashes.
-const MAX_INLINE_DIFF_BYTES: usize = 150 * 1024; // ~150KB
+const MAX_INLINE_DIFF_BYTES: usize = 2 * 1024 * 1024; // ~2MB

 #[derive(Debug, Clone, Serialize, Deserialize, TS, PartialEq, Eq)]
 #[serde(rename_all = "snake_case")]

View File

@@ -1,6 +1,6 @@
 import { useCallback } from 'react';
 import type { PatchType } from 'shared/types';
-import { useJsonPatchStream } from './useJsonPatchStream';
+import { useJsonPatchWsStream } from './useJsonPatchWsStream';

 interface DiffState {
   entries: Record<string, PatchType>;
@@ -17,7 +17,7 @@ export const useDiffStream = (
   enabled: boolean
 ): UseDiffStreamResult => {
   const endpoint = attemptId
-    ? `/api/task-attempts/${attemptId}/diff`
+    ? `/api/task-attempts/${attemptId}/diff/ws`
     : undefined;

   const initialData = useCallback(
@@ -27,7 +27,7 @@ export const useDiffStream = (
     []
   );

-  const { data, isConnected, error } = useJsonPatchStream(
+  const { data, isConnected, error } = useJsonPatchWsStream(
     endpoint,
     enabled && !!attemptId,
     initialData
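
For reference, the useJsonPatchWsStream hook that this file now imports is not shown in this commit. A rough sketch of what such a hook could look like, assuming the server sends RFC 6902 JSON Patch operations as WebSocket text frames (the message shape and the fast-json-patch library are assumptions, not taken from the source):

// Hypothetical sketch of useJsonPatchWsStream; the real hook is defined elsewhere.
// Assumes each text frame is an array of RFC 6902 operations (an assumption).
import { useEffect, useRef, useState } from 'react';
import { applyPatch, type Operation } from 'fast-json-patch';

export const useJsonPatchWsStream = <T,>(
  endpoint: string | undefined,
  enabled: boolean,
  initialData: () => T
) => {
  const [data, setData] = useState<T | undefined>(undefined);
  const [isConnected, setIsConnected] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const dataRef = useRef<T | undefined>(undefined);

  useEffect(() => {
    if (!endpoint || !enabled) return;

    // Resolve the relative API path against the current origin and switch to ws(s).
    const url = new URL(endpoint, window.location.href);
    url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:';
    const ws = new WebSocket(url.toString());

    dataRef.current = initialData();
    setData(dataRef.current);

    ws.onopen = () => setIsConnected(true);
    ws.onerror = () => setError('WebSocket error');
    ws.onclose = () => setIsConnected(false);
    ws.onmessage = (event) => {
      const ops = JSON.parse(event.data) as Operation[];
      // Apply the patch immutably so React sees a new object reference.
      dataRef.current = applyPatch(dataRef.current as T, ops, false, false).newDocument;
      setData(dataRef.current);
    };

    return () => ws.close();
  }, [endpoint, enabled, initialData]);

  return { data, isConnected, error };
};

Because the endpoint only changed from /diff to /diff/ws, the rest of useDiffStream keeps the same shape: the hook still exposes data, isConnected, and error, so callers are unaffected by the SSE-to-WebSocket migration.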