migrate /diff endpoint to websocket (#851)

This commit is contained in:
Solomon
2025-09-26 10:10:44 +01:00
committed by GitHub
parent 4f343fdb8f
commit bfb0c3f2ea
5 changed files with 71 additions and 59 deletions

View File

@@ -13,7 +13,6 @@ use std::{
use anyhow::anyhow;
use async_stream::try_stream;
use async_trait::async_trait;
use axum::response::sse::Event;
use command_group::AsyncGroupChild;
use db::{
DBService,
@@ -79,7 +78,7 @@ pub struct LocalContainerService {
impl LocalContainerService {
// Max cumulative content bytes allowed per diff stream
const MAX_CUMULATIVE_DIFF_BYTES: usize = 150 * 1024; // 150KB
const MAX_CUMULATIVE_DIFF_BYTES: usize = 200 * 1024 * 1024; // 200MB
// Apply stream-level omit policy based on cumulative bytes.
// If adding this diff's contents exceeds the cap, strip contents and set stats.
@@ -625,12 +624,12 @@ impl LocalContainerService {
Ok(project_repo_path)
}
/// Create a diff stream for merged attempts (never changes)
/// Create a diff log stream for merged attempts (never changes) for WebSocket
fn create_merged_diff_stream(
&self,
project_repo_path: &Path,
merge_commit_id: &str,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
{
let diffs = self.git().get_diffs(
DiffTarget::Commit {
@@ -653,23 +652,22 @@ impl LocalContainerService {
let entry_index = GitService::diff_path(&diff);
let patch =
ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
let event = LogMsg::JsonPatch(patch).to_sse_event();
Ok::<_, std::io::Error>(event)
Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
}))
.chain(futures::stream::once(async {
Ok::<_, std::io::Error>(LogMsg::Finished.to_sse_event())
Ok::<_, std::io::Error>(LogMsg::Finished)
}))
.boxed();
Ok(stream)
}
/// Create a live diff stream for ongoing attempts
/// Create a live diff log stream for ongoing attempts for WebSocket
async fn create_live_diff_stream(
&self,
worktree_path: &Path,
base_commit: &Commit,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
{
// Get initial snapshot
let git_service = self.git().clone();
@@ -681,9 +679,7 @@ impl LocalContainerService {
None,
)?;
// cumulative counter for entire stream
let cumulative = Arc::new(AtomicUsize::new(0));
// track which file paths have been emitted with full content already
let full_sent = Arc::new(std::sync::RwLock::new(HashSet::<String>::new()));
let initial_diffs: Vec<_> = initial_diffs
.into_iter()
@@ -708,8 +704,7 @@ impl LocalContainerService {
let entry_index = GitService::diff_path(&diff);
let patch =
ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
let event = LogMsg::JsonPatch(patch).to_sse_event();
Ok::<_, std::io::Error>(event)
Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
}))
.boxed();
@@ -723,7 +718,6 @@ impl LocalContainerService {
let cumulative = Arc::clone(&cumulative);
let full_sent = Arc::clone(&full_sent);
try_stream! {
// Move the expensive watcher setup to blocking thread to avoid blocking the async runtime
let watcher_result = tokio::task::spawn_blocking(move || {
filesystem_watcher::async_watcher(worktree_path_for_spawn)
})
@@ -739,7 +733,7 @@ impl LocalContainerService {
let changed_paths = Self::extract_changed_paths(&events, &canonical_worktree_path, &worktree_path);
if !changed_paths.is_empty() {
for event in Self::process_file_changes(
for msg in Self::process_file_changes(
&git_service,
&worktree_path,
&base_commit,
@@ -750,7 +744,7 @@ impl LocalContainerService {
tracing::error!("Error processing file changes: {}", e);
io::Error::other(e.to_string())
})? {
yield event;
yield msg;
}
}
}
@@ -767,8 +761,6 @@ impl LocalContainerService {
}
}.boxed();
// Ensure all initial diffs are emitted before live updates, to avoid
// earlier files being abbreviated due to interleaving ordering.
let combined_stream = initial_stream.chain(live_stream);
Ok(combined_stream.boxed())
}
@@ -792,7 +784,7 @@ impl LocalContainerService {
.collect()
}
/// Process file changes and generate diff events
/// Process file changes and generate diff messages (for WS)
fn process_file_changes(
git_service: &GitService,
worktree_path: &Path,
@@ -800,7 +792,7 @@ impl LocalContainerService {
changed_paths: &[String],
cumulative_bytes: &Arc<AtomicUsize>,
full_sent_paths: &Arc<std::sync::RwLock<HashSet<String>>>,
) -> Result<Vec<Event>, ContainerError> {
) -> Result<Vec<LogMsg>, ContainerError> {
let path_filter: Vec<&str> = changed_paths.iter().map(|s| s.as_str()).collect();
let current_diffs = git_service.get_diffs(
@@ -811,7 +803,7 @@ impl LocalContainerService {
Some(&path_filter),
)?;
let mut events = Vec::new();
let mut msgs = Vec::new();
let mut files_with_diffs = HashSet::new();
// Add/update files that have diffs
@@ -861,24 +853,17 @@ impl LocalContainerService {
}
}
// If this diff would be omitted and we already sent a full-content
// version of this path earlier in the stream, skip sending a
// degrading replacement.
if diff.content_omitted {
if full_sent_paths.read().unwrap().contains(&file_path) {
continue;
}
} else {
// Track that we have sent a full-content version
{
let mut guard = full_sent_paths.write().unwrap();
guard.insert(file_path.clone());
}
let mut guard = full_sent_paths.write().unwrap();
guard.insert(file_path.clone());
}
let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&file_path), diff);
let event = LogMsg::JsonPatch(patch).to_sse_event();
events.push(event);
msgs.push(LogMsg::JsonPatch(patch));
}
// Remove files that changed but no longer have diffs
@@ -886,12 +871,11 @@ impl LocalContainerService {
if !files_with_diffs.contains(changed_path) {
let patch =
ConversationPatch::remove_diff(escape_json_pointer_segment(changed_path));
let event = LogMsg::JsonPatch(patch).to_sse_event();
events.push(event);
msgs.push(LogMsg::JsonPatch(patch));
}
}
Ok(events)
Ok(msgs)
}
}
@@ -1151,10 +1135,10 @@ impl ContainerService for LocalContainerService {
Ok(())
}
async fn get_diff(
async fn stream_diff(
&self,
task_attempt: &TaskAttempt,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
{
let project_repo_path = self.get_project_repo_path(task_attempt).await?;
let latest_merge =
@@ -1177,7 +1161,6 @@ impl ContainerService for LocalContainerService {
false
};
// Show merged diff when no new work is on the branch or container
if let Some(merge) = &latest_merge
&& let Some(commit) = merge.merge_commit()
&& self.is_container_clean(task_attempt).await?
@@ -1186,17 +1169,14 @@ impl ContainerService for LocalContainerService {
return self.create_merged_diff_stream(&project_repo_path, &commit);
}
// worktree is needed for non-merged diffs
let container_ref = self.ensure_container_exists(task_attempt).await?;
let worktree_path = PathBuf::from(container_ref);
let base_commit = self.git().get_base_commit(
&project_repo_path,
&task_branch,
&task_attempt.base_branch,
)?;
// Handle ongoing attempts (live streaming diff)
self.create_live_diff_stream(&worktree_path, &base_commit)
.await
}