migrate /diff endpoint to websocket (#851)

Solomon authored on 2025-09-26 10:10:44 +01:00; committed by GitHub
parent 4f343fdb8f
commit bfb0c3f2ea
5 changed files with 71 additions and 59 deletions

View File

@@ -13,7 +13,6 @@ use std::{
use anyhow::anyhow;
use async_stream::try_stream;
use async_trait::async_trait;
use axum::response::sse::Event;
use command_group::AsyncGroupChild;
use db::{
DBService,
@@ -79,7 +78,7 @@ pub struct LocalContainerService {
impl LocalContainerService {
// Max cumulative content bytes allowed per diff stream
const MAX_CUMULATIVE_DIFF_BYTES: usize = 150 * 1024; // 150KB
const MAX_CUMULATIVE_DIFF_BYTES: usize = 200 * 1024 * 1024; // 200MB
// Apply stream-level omit policy based on cumulative bytes.
// If adding this diff's contents exceeds the cap, strip contents and set stats.
@@ -625,12 +624,12 @@ impl LocalContainerService {
Ok(project_repo_path)
}
/// Create a diff stream for merged attempts (never changes)
/// Create a diff log stream for merged attempts (never changes) for WebSocket
fn create_merged_diff_stream(
&self,
project_repo_path: &Path,
merge_commit_id: &str,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
{
let diffs = self.git().get_diffs(
DiffTarget::Commit {
@@ -653,23 +652,22 @@ impl LocalContainerService {
let entry_index = GitService::diff_path(&diff);
let patch =
ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
let event = LogMsg::JsonPatch(patch).to_sse_event();
Ok::<_, std::io::Error>(event)
Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
}))
.chain(futures::stream::once(async {
Ok::<_, std::io::Error>(LogMsg::Finished.to_sse_event())
Ok::<_, std::io::Error>(LogMsg::Finished)
}))
.boxed();
Ok(stream)
}
/// Create a live diff stream for ongoing attempts
/// Create a live diff log stream for ongoing attempts for WebSocket
async fn create_live_diff_stream(
&self,
worktree_path: &Path,
base_commit: &Commit,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
{
// Get initial snapshot
let git_service = self.git().clone();
@@ -681,9 +679,7 @@ impl LocalContainerService {
None,
)?;
// cumulative counter for entire stream
let cumulative = Arc::new(AtomicUsize::new(0));
// track which file paths have been emitted with full content already
let full_sent = Arc::new(std::sync::RwLock::new(HashSet::<String>::new()));
let initial_diffs: Vec<_> = initial_diffs
.into_iter()
@@ -708,8 +704,7 @@ impl LocalContainerService {
let entry_index = GitService::diff_path(&diff);
let patch =
ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
let event = LogMsg::JsonPatch(patch).to_sse_event();
Ok::<_, std::io::Error>(event)
Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
}))
.boxed();
@@ -723,7 +718,6 @@ impl LocalContainerService {
let cumulative = Arc::clone(&cumulative);
let full_sent = Arc::clone(&full_sent);
try_stream! {
// Move the expensive watcher setup to blocking thread to avoid blocking the async runtime
let watcher_result = tokio::task::spawn_blocking(move || {
filesystem_watcher::async_watcher(worktree_path_for_spawn)
})
@@ -739,7 +733,7 @@ impl LocalContainerService {
let changed_paths = Self::extract_changed_paths(&events, &canonical_worktree_path, &worktree_path);
if !changed_paths.is_empty() {
for event in Self::process_file_changes(
for msg in Self::process_file_changes(
&git_service,
&worktree_path,
&base_commit,
@@ -750,7 +744,7 @@ impl LocalContainerService {
tracing::error!("Error processing file changes: {}", e);
io::Error::other(e.to_string())
})? {
yield event;
yield msg;
}
}
}
@@ -767,8 +761,6 @@ impl LocalContainerService {
}
}.boxed();
// Ensure all initial diffs are emitted before live updates, to avoid
// earlier files being abbreviated due to interleaving ordering.
let combined_stream = initial_stream.chain(live_stream);
Ok(combined_stream.boxed())
}
@@ -792,7 +784,7 @@ impl LocalContainerService {
.collect()
}
/// Process file changes and generate diff events
/// Process file changes and generate diff messages (for WS)
fn process_file_changes(
git_service: &GitService,
worktree_path: &Path,
@@ -800,7 +792,7 @@ impl LocalContainerService {
changed_paths: &[String],
cumulative_bytes: &Arc<AtomicUsize>,
full_sent_paths: &Arc<std::sync::RwLock<HashSet<String>>>,
) -> Result<Vec<Event>, ContainerError> {
) -> Result<Vec<LogMsg>, ContainerError> {
let path_filter: Vec<&str> = changed_paths.iter().map(|s| s.as_str()).collect();
let current_diffs = git_service.get_diffs(
@@ -811,7 +803,7 @@ impl LocalContainerService {
Some(&path_filter),
)?;
let mut events = Vec::new();
let mut msgs = Vec::new();
let mut files_with_diffs = HashSet::new();
// Add/update files that have diffs
@@ -861,24 +853,17 @@ impl LocalContainerService {
}
}
// If this diff would be omitted and we already sent a full-content
// version of this path earlier in the stream, skip sending a
// degrading replacement.
if diff.content_omitted {
if full_sent_paths.read().unwrap().contains(&file_path) {
continue;
}
} else {
// Track that we have sent a full-content version
{
let mut guard = full_sent_paths.write().unwrap();
guard.insert(file_path.clone());
}
let mut guard = full_sent_paths.write().unwrap();
guard.insert(file_path.clone());
}
let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&file_path), diff);
let event = LogMsg::JsonPatch(patch).to_sse_event();
events.push(event);
msgs.push(LogMsg::JsonPatch(patch));
}
// Remove files that changed but no longer have diffs
@@ -886,12 +871,11 @@ impl LocalContainerService {
if !files_with_diffs.contains(changed_path) {
let patch =
ConversationPatch::remove_diff(escape_json_pointer_segment(changed_path));
let event = LogMsg::JsonPatch(patch).to_sse_event();
events.push(event);
msgs.push(LogMsg::JsonPatch(patch));
}
}
Ok(events)
Ok(msgs)
}
}
@@ -1151,10 +1135,10 @@ impl ContainerService for LocalContainerService {
Ok(())
}
async fn get_diff(
async fn stream_diff(
&self,
task_attempt: &TaskAttempt,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
{
let project_repo_path = self.get_project_repo_path(task_attempt).await?;
let latest_merge =
@@ -1177,7 +1161,6 @@ impl ContainerService for LocalContainerService {
false
};
// Show merged diff when no new work is on the branch or container
if let Some(merge) = &latest_merge
&& let Some(commit) = merge.merge_commit()
&& self.is_container_clean(task_attempt).await?
@@ -1186,17 +1169,14 @@ impl ContainerService for LocalContainerService {
return self.create_merged_diff_stream(&project_repo_path, &commit);
}
// worktree is needed for non-merged diffs
let container_ref = self.ensure_container_exists(task_attempt).await?;
let worktree_path = PathBuf::from(container_ref);
let base_commit = self.git().get_base_commit(
&project_repo_path,
&task_branch,
&task_attempt.base_branch,
)?;
// Handle ongoing attempts (live streaming diff)
self.create_live_diff_stream(&worktree_path, &base_commit)
.await
}
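The net effect in this file is that every diff stream now yields LogMsg values instead of pre-rendered SSE Events: the axum Event import and the to_sse_event() calls are gone, and transport framing moves to the route layer. As a rough sketch of the type these streams now carry (only the JsonPatch and Finished variants are visible in this diff, plus to_sse_event here and to_ws_message_unchecked in the route below; everything else is an assumption about utils::log_msg::LogMsg, not its real definition):

// Hedged sketch of utils::log_msg::LogMsg as implied by this diff; the payload
// type, derives, and any further variants are assumptions.
use axum::extract::ws::Message;
use serde::Serialize;

#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum LogMsg {
    JsonPatch(json_patch::Patch),
    Finished,
}

impl LogMsg {
    // Render as a WS text frame; "unchecked" presumably means a serialization
    // failure panics rather than being surfaced as an error.
    pub fn to_ws_message_unchecked(&self) -> Message {
        Message::Text(
            serde_json::to_string(self)
                .expect("LogMsg serializes to JSON")
                .into(),
        )
    }
}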

View File

@@ -1,17 +1,14 @@
use std::path::PathBuf;
use axum::{
BoxError, Extension, Json, Router,
Extension, Json, Router,
extract::{
Query, State,
ws::{WebSocket, WebSocketUpgrade},
},
http::StatusCode,
middleware::from_fn_with_state,
response::{
IntoResponse, Json as ResponseJson, Sse,
sse::{Event, KeepAlive},
},
response::{IntoResponse, Json as ResponseJson},
routing::{get, post},
};
use db::models::{
@@ -32,7 +29,6 @@ use executors::{
},
profile::ExecutorProfileId,
};
use futures_util::TryStreamExt;
use git2::BranchType;
use serde::{Deserialize, Serialize};
use services::services::{
@@ -962,14 +958,50 @@ pub async fn replace_process(
})))
}
pub async fn get_task_attempt_diff(
#[axum::debug_handler]
pub async fn stream_task_attempt_diff_ws(
ws: WebSocketUpgrade,
Extension(task_attempt): Extension<TaskAttempt>,
State(deployment): State<DeploymentImpl>,
// ) -> Result<ResponseJson<ApiResponse<Diff>>, ApiError> {
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, BoxError>>>, ApiError> {
let stream = deployment.container().get_diff(&task_attempt).await?;
) -> impl IntoResponse {
ws.on_upgrade(move |socket| async move {
if let Err(e) = handle_task_attempt_diff_ws(socket, deployment, task_attempt).await {
tracing::warn!("diff WS closed: {}", e);
}
})
}
Ok(Sse::new(stream.map_err(|e| -> BoxError { e.into() })).keep_alive(KeepAlive::default()))
async fn handle_task_attempt_diff_ws(
socket: WebSocket,
deployment: DeploymentImpl,
task_attempt: TaskAttempt,
) -> anyhow::Result<()> {
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use utils::log_msg::LogMsg;
let mut stream = deployment
.container()
.stream_diff(&task_attempt)
.await?
.map_ok(|msg: LogMsg| msg.to_ws_message_unchecked());
let (mut sender, mut receiver) = socket.split();
tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
while let Some(item) = stream.next().await {
match item {
Ok(msg) => {
if sender.send(msg).await.is_err() {
break;
}
}
Err(e) => {
tracing::error!("stream error: {}", e);
break;
}
}
}
Ok(())
}
#[derive(Debug, Serialize, TS)]
@@ -1730,7 +1762,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
.route("/commit-compare", get(compare_commit_to_head))
.route("/start-dev-server", post(start_dev_server))
.route("/branch-status", get(get_task_attempt_branch_status))
.route("/diff", get(get_task_attempt_diff))
.route("/diff/ws", get(stream_task_attempt_diff_ws))
.route("/merge", post(merge_task_attempt))
.route("/push", post(push_task_attempt_branch))
.route("/rebase", post(rebase_task_attempt))

View File

@@ -6,7 +6,6 @@ use std::{
use anyhow::{Error as AnyhowError, anyhow};
use async_trait::async_trait;
use axum::response::sse::Event;
use db::{
DBService,
models::{
@@ -194,10 +193,11 @@ pub trait ContainerService {
copy_files: &str,
) -> Result<(), ContainerError>;
async fn get_diff(
/// Stream diff updates as LogMsg for WebSocket endpoints.
async fn stream_diff(
&self,
task_attempt: &TaskAttempt,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>;
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>;
/// Fetch the MsgStore for a given execution ID, panicking if missing.
async fn get_msg_store_by_id(&self, uuid: &Uuid) -> Option<Arc<MsgStore>> {
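For other implementors of ContainerService, the trait change is the rename from get_diff to stream_diff plus the new item type: a boxed stream of LogMsg with no transport framing attached. A minimal, hypothetical stream that satisfies the new signature, assuming use utils::log_msg::LogMsg and the futures crate, would be:

// Hedged sketch: the smallest stream_diff-shaped stream under the new signature,
// emitting only the terminal marker. Not code from this repository.
use futures::{StreamExt, stream};

fn finished_only_stream()
-> futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>> {
    stream::once(async { Ok::<_, std::io::Error>(LogMsg::Finished) }).boxed()
}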

View File

@@ -44,7 +44,7 @@ pub struct GitService {}
// Max inline diff size for UI (in bytes). Files larger than this will have
// their contents omitted from the diff stream to avoid UI crashes.
const MAX_INLINE_DIFF_BYTES: usize = 150 * 1024; // ~150KB
const MAX_INLINE_DIFF_BYTES: usize = 2 * 1024 * 1024; // ~2MB
#[derive(Debug, Clone, Serialize, Deserialize, TS, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
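Only the inline cap moves here (from roughly 150KB to 2MB); the omission policy it feeds is unchanged, and the container service above still branches on content_omitted. Roughly, and with every name below other than content_omitted and the constant being an assumption rather than the crate's real diff type, the policy amounts to:

// Hedged sketch of how an inline size cap like this is typically applied;
// the Diff struct is a stand-in, only content_omitted appears in this change.
const MAX_INLINE_DIFF_BYTES: usize = 2 * 1024 * 1024; // ~2MB

struct Diff {
    old_content: Option<String>,
    new_content: Option<String>,
    content_omitted: bool,
}

fn apply_inline_cap(diff: &mut Diff) {
    let bytes = diff.old_content.as_deref().map_or(0, str::len)
        + diff.new_content.as_deref().map_or(0, str::len);
    if bytes > MAX_INLINE_DIFF_BYTES {
        // Strip the payload but keep the entry so the UI can still show stats.
        diff.old_content = None;
        diff.new_content = None;
        diff.content_omitted = true;
    }
}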

View File

@@ -1,6 +1,6 @@
import { useCallback } from 'react';
import type { PatchType } from 'shared/types';
import { useJsonPatchStream } from './useJsonPatchStream';
import { useJsonPatchWsStream } from './useJsonPatchWsStream';
interface DiffState {
entries: Record<string, PatchType>;
@@ -17,7 +17,7 @@ export const useDiffStream = (
enabled: boolean
): UseDiffStreamResult => {
const endpoint = attemptId
? `/api/task-attempts/${attemptId}/diff`
? `/api/task-attempts/${attemptId}/diff/ws`
: undefined;
const initialData = useCallback(
@@ -27,7 +27,7 @@ export const useDiffStream = (
[]
);
const { data, isConnected, error } = useJsonPatchStream(
const { data, isConnected, error } = useJsonPatchWsStream(
endpoint,
enabled && !!attemptId,
initialData