migrate /diff endpoint to websocket (#851)
This commit is contained in:
@@ -13,7 +13,6 @@ use std::{
|
||||
use anyhow::anyhow;
|
||||
use async_stream::try_stream;
|
||||
use async_trait::async_trait;
|
||||
use axum::response::sse::Event;
|
||||
use command_group::AsyncGroupChild;
|
||||
use db::{
|
||||
DBService,
|
||||
@@ -79,7 +78,7 @@ pub struct LocalContainerService {
|
||||
|
||||
impl LocalContainerService {
|
||||
// Max cumulative content bytes allowed per diff stream
|
||||
const MAX_CUMULATIVE_DIFF_BYTES: usize = 150 * 1024; // 150KB
|
||||
const MAX_CUMULATIVE_DIFF_BYTES: usize = 200 * 1024 * 1024; // 200MB
|
||||
|
||||
// Apply stream-level omit policy based on cumulative bytes.
|
||||
// If adding this diff's contents exceeds the cap, strip contents and set stats.
|
||||
@@ -625,12 +624,12 @@ impl LocalContainerService {
|
||||
Ok(project_repo_path)
|
||||
}
|
||||
|
||||
/// Create a diff stream for merged attempts (never changes)
|
||||
/// Create a diff log stream for merged attempts (never changes) for WebSocket
|
||||
fn create_merged_diff_stream(
|
||||
&self,
|
||||
project_repo_path: &Path,
|
||||
merge_commit_id: &str,
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
|
||||
{
|
||||
let diffs = self.git().get_diffs(
|
||||
DiffTarget::Commit {
|
||||
@@ -653,23 +652,22 @@ impl LocalContainerService {
|
||||
let entry_index = GitService::diff_path(&diff);
|
||||
let patch =
|
||||
ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
|
||||
let event = LogMsg::JsonPatch(patch).to_sse_event();
|
||||
Ok::<_, std::io::Error>(event)
|
||||
Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
|
||||
}))
|
||||
.chain(futures::stream::once(async {
|
||||
Ok::<_, std::io::Error>(LogMsg::Finished.to_sse_event())
|
||||
Ok::<_, std::io::Error>(LogMsg::Finished)
|
||||
}))
|
||||
.boxed();
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
/// Create a live diff stream for ongoing attempts
|
||||
/// Create a live diff log stream for ongoing attempts for WebSocket
|
||||
async fn create_live_diff_stream(
|
||||
&self,
|
||||
worktree_path: &Path,
|
||||
base_commit: &Commit,
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
|
||||
{
|
||||
// Get initial snapshot
|
||||
let git_service = self.git().clone();
|
||||
@@ -681,9 +679,7 @@ impl LocalContainerService {
|
||||
None,
|
||||
)?;
|
||||
|
||||
// cumulative counter for entire stream
|
||||
let cumulative = Arc::new(AtomicUsize::new(0));
|
||||
// track which file paths have been emitted with full content already
|
||||
let full_sent = Arc::new(std::sync::RwLock::new(HashSet::<String>::new()));
|
||||
let initial_diffs: Vec<_> = initial_diffs
|
||||
.into_iter()
|
||||
@@ -708,8 +704,7 @@ impl LocalContainerService {
|
||||
let entry_index = GitService::diff_path(&diff);
|
||||
let patch =
|
||||
ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
|
||||
let event = LogMsg::JsonPatch(patch).to_sse_event();
|
||||
Ok::<_, std::io::Error>(event)
|
||||
Ok::<_, std::io::Error>(LogMsg::JsonPatch(patch))
|
||||
}))
|
||||
.boxed();
|
||||
|
||||
@@ -723,7 +718,6 @@ impl LocalContainerService {
|
||||
let cumulative = Arc::clone(&cumulative);
|
||||
let full_sent = Arc::clone(&full_sent);
|
||||
try_stream! {
|
||||
// Move the expensive watcher setup to blocking thread to avoid blocking the async runtime
|
||||
let watcher_result = tokio::task::spawn_blocking(move || {
|
||||
filesystem_watcher::async_watcher(worktree_path_for_spawn)
|
||||
})
|
||||
@@ -739,7 +733,7 @@ impl LocalContainerService {
|
||||
let changed_paths = Self::extract_changed_paths(&events, &canonical_worktree_path, &worktree_path);
|
||||
|
||||
if !changed_paths.is_empty() {
|
||||
for event in Self::process_file_changes(
|
||||
for msg in Self::process_file_changes(
|
||||
&git_service,
|
||||
&worktree_path,
|
||||
&base_commit,
|
||||
@@ -750,7 +744,7 @@ impl LocalContainerService {
|
||||
tracing::error!("Error processing file changes: {}", e);
|
||||
io::Error::other(e.to_string())
|
||||
})? {
|
||||
yield event;
|
||||
yield msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -767,8 +761,6 @@ impl LocalContainerService {
|
||||
}
|
||||
}.boxed();
|
||||
|
||||
// Ensure all initial diffs are emitted before live updates, to avoid
|
||||
// earlier files being abbreviated due to interleaving ordering.
|
||||
let combined_stream = initial_stream.chain(live_stream);
|
||||
Ok(combined_stream.boxed())
|
||||
}
|
||||
@@ -792,7 +784,7 @@ impl LocalContainerService {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Process file changes and generate diff events
|
||||
/// Process file changes and generate diff messages (for WS)
|
||||
fn process_file_changes(
|
||||
git_service: &GitService,
|
||||
worktree_path: &Path,
|
||||
@@ -800,7 +792,7 @@ impl LocalContainerService {
|
||||
changed_paths: &[String],
|
||||
cumulative_bytes: &Arc<AtomicUsize>,
|
||||
full_sent_paths: &Arc<std::sync::RwLock<HashSet<String>>>,
|
||||
) -> Result<Vec<Event>, ContainerError> {
|
||||
) -> Result<Vec<LogMsg>, ContainerError> {
|
||||
let path_filter: Vec<&str> = changed_paths.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let current_diffs = git_service.get_diffs(
|
||||
@@ -811,7 +803,7 @@ impl LocalContainerService {
|
||||
Some(&path_filter),
|
||||
)?;
|
||||
|
||||
let mut events = Vec::new();
|
||||
let mut msgs = Vec::new();
|
||||
let mut files_with_diffs = HashSet::new();
|
||||
|
||||
// Add/update files that have diffs
|
||||
@@ -861,24 +853,17 @@ impl LocalContainerService {
|
||||
}
|
||||
}
|
||||
|
||||
// If this diff would be omitted and we already sent a full-content
|
||||
// version of this path earlier in the stream, skip sending a
|
||||
// degrading replacement.
|
||||
if diff.content_omitted {
|
||||
if full_sent_paths.read().unwrap().contains(&file_path) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// Track that we have sent a full-content version
|
||||
{
|
||||
let mut guard = full_sent_paths.write().unwrap();
|
||||
guard.insert(file_path.clone());
|
||||
}
|
||||
let mut guard = full_sent_paths.write().unwrap();
|
||||
guard.insert(file_path.clone());
|
||||
}
|
||||
|
||||
let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&file_path), diff);
|
||||
let event = LogMsg::JsonPatch(patch).to_sse_event();
|
||||
events.push(event);
|
||||
msgs.push(LogMsg::JsonPatch(patch));
|
||||
}
|
||||
|
||||
// Remove files that changed but no longer have diffs
|
||||
@@ -886,12 +871,11 @@ impl LocalContainerService {
|
||||
if !files_with_diffs.contains(changed_path) {
|
||||
let patch =
|
||||
ConversationPatch::remove_diff(escape_json_pointer_segment(changed_path));
|
||||
let event = LogMsg::JsonPatch(patch).to_sse_event();
|
||||
events.push(event);
|
||||
msgs.push(LogMsg::JsonPatch(patch));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
Ok(msgs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1151,10 +1135,10 @@ impl ContainerService for LocalContainerService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_diff(
|
||||
async fn stream_diff(
|
||||
&self,
|
||||
task_attempt: &TaskAttempt,
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>
|
||||
{
|
||||
let project_repo_path = self.get_project_repo_path(task_attempt).await?;
|
||||
let latest_merge =
|
||||
@@ -1177,7 +1161,6 @@ impl ContainerService for LocalContainerService {
|
||||
false
|
||||
};
|
||||
|
||||
// Show merged diff when no new work is on the branch or container
|
||||
if let Some(merge) = &latest_merge
|
||||
&& let Some(commit) = merge.merge_commit()
|
||||
&& self.is_container_clean(task_attempt).await?
|
||||
@@ -1186,17 +1169,14 @@ impl ContainerService for LocalContainerService {
|
||||
return self.create_merged_diff_stream(&project_repo_path, &commit);
|
||||
}
|
||||
|
||||
// worktree is needed for non-merged diffs
|
||||
let container_ref = self.ensure_container_exists(task_attempt).await?;
|
||||
let worktree_path = PathBuf::from(container_ref);
|
||||
|
||||
let base_commit = self.git().get_base_commit(
|
||||
&project_repo_path,
|
||||
&task_branch,
|
||||
&task_attempt.base_branch,
|
||||
)?;
|
||||
|
||||
// Handle ongoing attempts (live streaming diff)
|
||||
self.create_live_diff_stream(&worktree_path, &base_commit)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use axum::{
|
||||
BoxError, Extension, Json, Router,
|
||||
Extension, Json, Router,
|
||||
extract::{
|
||||
Query, State,
|
||||
ws::{WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
http::StatusCode,
|
||||
middleware::from_fn_with_state,
|
||||
response::{
|
||||
IntoResponse, Json as ResponseJson, Sse,
|
||||
sse::{Event, KeepAlive},
|
||||
},
|
||||
response::{IntoResponse, Json as ResponseJson},
|
||||
routing::{get, post},
|
||||
};
|
||||
use db::models::{
|
||||
@@ -32,7 +29,6 @@ use executors::{
|
||||
},
|
||||
profile::ExecutorProfileId,
|
||||
};
|
||||
use futures_util::TryStreamExt;
|
||||
use git2::BranchType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use services::services::{
|
||||
@@ -962,14 +958,50 @@ pub async fn replace_process(
|
||||
})))
|
||||
}
|
||||
|
||||
pub async fn get_task_attempt_diff(
|
||||
#[axum::debug_handler]
|
||||
pub async fn stream_task_attempt_diff_ws(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(task_attempt): Extension<TaskAttempt>,
|
||||
State(deployment): State<DeploymentImpl>,
|
||||
// ) -> Result<ResponseJson<ApiResponse<Diff>>, ApiError> {
|
||||
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, BoxError>>>, ApiError> {
|
||||
let stream = deployment.container().get_diff(&task_attempt).await?;
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| async move {
|
||||
if let Err(e) = handle_task_attempt_diff_ws(socket, deployment, task_attempt).await {
|
||||
tracing::warn!("diff WS closed: {}", e);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Ok(Sse::new(stream.map_err(|e| -> BoxError { e.into() })).keep_alive(KeepAlive::default()))
|
||||
async fn handle_task_attempt_diff_ws(
|
||||
socket: WebSocket,
|
||||
deployment: DeploymentImpl,
|
||||
task_attempt: TaskAttempt,
|
||||
) -> anyhow::Result<()> {
|
||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||
use utils::log_msg::LogMsg;
|
||||
|
||||
let mut stream = deployment
|
||||
.container()
|
||||
.stream_diff(&task_attempt)
|
||||
.await?
|
||||
.map_ok(|msg: LogMsg| msg.to_ws_message_unchecked());
|
||||
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(msg) => {
|
||||
if sender.send(msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("stream error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, TS)]
|
||||
@@ -1730,7 +1762,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
|
||||
.route("/commit-compare", get(compare_commit_to_head))
|
||||
.route("/start-dev-server", post(start_dev_server))
|
||||
.route("/branch-status", get(get_task_attempt_branch_status))
|
||||
.route("/diff", get(get_task_attempt_diff))
|
||||
.route("/diff/ws", get(stream_task_attempt_diff_ws))
|
||||
.route("/merge", post(merge_task_attempt))
|
||||
.route("/push", post(push_task_attempt_branch))
|
||||
.route("/rebase", post(rebase_task_attempt))
|
||||
|
||||
@@ -6,7 +6,6 @@ use std::{
|
||||
|
||||
use anyhow::{Error as AnyhowError, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use axum::response::sse::Event;
|
||||
use db::{
|
||||
DBService,
|
||||
models::{
|
||||
@@ -194,10 +193,11 @@ pub trait ContainerService {
|
||||
copy_files: &str,
|
||||
) -> Result<(), ContainerError>;
|
||||
|
||||
async fn get_diff(
|
||||
/// Stream diff updates as LogMsg for WebSocket endpoints.
|
||||
async fn stream_diff(
|
||||
&self,
|
||||
task_attempt: &TaskAttempt,
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, ContainerError>;
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, ContainerError>;
|
||||
|
||||
/// Fetch the MsgStore for a given execution ID, panicking if missing.
|
||||
async fn get_msg_store_by_id(&self, uuid: &Uuid) -> Option<Arc<MsgStore>> {
|
||||
|
||||
@@ -44,7 +44,7 @@ pub struct GitService {}
|
||||
|
||||
// Max inline diff size for UI (in bytes). Files larger than this will have
|
||||
// their contents omitted from the diff stream to avoid UI crashes.
|
||||
const MAX_INLINE_DIFF_BYTES: usize = 150 * 1024; // ~150KB
|
||||
const MAX_INLINE_DIFF_BYTES: usize = 2 * 1024 * 1024; // ~2MB
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, TS, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useCallback } from 'react';
|
||||
import type { PatchType } from 'shared/types';
|
||||
import { useJsonPatchStream } from './useJsonPatchStream';
|
||||
import { useJsonPatchWsStream } from './useJsonPatchWsStream';
|
||||
|
||||
interface DiffState {
|
||||
entries: Record<string, PatchType>;
|
||||
@@ -17,7 +17,7 @@ export const useDiffStream = (
|
||||
enabled: boolean
|
||||
): UseDiffStreamResult => {
|
||||
const endpoint = attemptId
|
||||
? `/api/task-attempts/${attemptId}/diff`
|
||||
? `/api/task-attempts/${attemptId}/diff/ws`
|
||||
: undefined;
|
||||
|
||||
const initialData = useCallback(
|
||||
@@ -27,7 +27,7 @@ export const useDiffStream = (
|
||||
[]
|
||||
);
|
||||
|
||||
const { data, isConnected, error } = useJsonPatchStream(
|
||||
const { data, isConnected, error } = useJsonPatchWsStream(
|
||||
endpoint,
|
||||
enabled && !!attemptId,
|
||||
initialData
|
||||
|
||||
Reference in New Issue
Block a user