From 46d3f3c7df7f6ef22285814a2ada0bfe23827e94 Mon Sep 17 00:00:00 2001 From: Solomon Date: Thu, 18 Sep 2025 16:11:46 +0100 Subject: [PATCH] Migrate followup draft SSE to WebSockets (#776) --- crates/server/src/routes/task_attempts.rs | 65 ++++++++++++++----- crates/services/src/services/events.rs | 50 +++++++------- .../src/hooks/follow-up/useDraftStream.ts | 10 +-- frontend/src/hooks/useJsonPatchWsStream.ts | 12 ++-- 4 files changed, 88 insertions(+), 49 deletions(-) diff --git a/crates/server/src/routes/task_attempts.rs b/crates/server/src/routes/task_attempts.rs index b9ed1ad8..940b1ab9 100644 --- a/crates/server/src/routes/task_attempts.rs +++ b/crates/server/src/routes/task_attempts.rs @@ -2,11 +2,14 @@ use std::path::PathBuf; use axum::{ BoxError, Extension, Json, Router, - extract::{Query, State}, + extract::{ + Query, State, + ws::{WebSocket, WebSocketUpgrade}, + }, http::StatusCode, middleware::from_fn_with_state, response::{ - Json as ResponseJson, Sse, + IntoResponse, Json as ResponseJson, Sse, sse::{Event, KeepAlive}, }, routing::{get, post}, @@ -513,22 +516,52 @@ pub async fn save_follow_up_draft( } #[axum::debug_handler] -pub async fn stream_follow_up_draft( +pub async fn stream_follow_up_draft_ws( + ws: WebSocketUpgrade, Extension(task_attempt): Extension, State(deployment): State, -) -> Result< - Sse>>>, - ApiError, -> { - let stream = deployment +) -> impl IntoResponse { + ws.on_upgrade(move |socket| async move { + if let Err(e) = handle_follow_up_draft_ws(socket, deployment, task_attempt.id).await { + tracing::warn!("follow-up draft WS closed: {}", e); + } + }) +} + +async fn handle_follow_up_draft_ws( + socket: WebSocket, + deployment: DeploymentImpl, + task_attempt_id: uuid::Uuid, +) -> anyhow::Result<()> { + use futures_util::{SinkExt, StreamExt, TryStreamExt}; + + let mut stream = deployment .events() - .stream_follow_up_draft_for_attempt(task_attempt.id) - .await - .map_err(|e| ApiError::from(deployment::DeploymentError::from(e)))?; - Ok( - Sse::new(stream.map_err(|e| -> Box { Box::new(e) })) - .keep_alive(KeepAlive::default()), - ) + .stream_follow_up_draft_for_attempt_raw(task_attempt_id) + .await? + .map_ok(|msg| msg.to_ws_message_unchecked()); + + // Split socket into sender and receiver + let (mut sender, mut receiver) = socket.split(); + + // Drain (and ignore) any client->server messages so pings/pongs work + tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} }); + + // Forward server messages + while let Some(item) = stream.next().await { + match item { + Ok(msg) => { + if sender.send(msg).await.is_err() { + break; + } + } + Err(e) => { + tracing::error!("stream error: {}", e); + break; + } + } + } + Ok(()) } #[axum::debug_handler] @@ -1690,7 +1723,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router { "/follow-up-draft", get(get_follow_up_draft).put(save_follow_up_draft), ) - .route("/follow-up-draft/stream", get(stream_follow_up_draft)) + .route("/follow-up-draft/stream/ws", get(stream_follow_up_draft_ws)) .route("/follow-up-draft/queue", post(set_follow_up_queue)) .route("/replace-process", post(replace_process)) .route("/commit-info", get(get_commit_info)) diff --git a/crates/services/src/services/events.rs b/crates/services/src/services/events.rs index bbc3e57a..63344195 100644 --- a/crates/services/src/services/events.rs +++ b/crates/services/src/services/events.rs @@ -1,7 +1,6 @@ use std::{str::FromStr, sync::Arc}; use anyhow::Error as AnyhowError; -use axum::response::sse::Event; use db::{ DBService, models::{ @@ -10,7 +9,7 @@ use db::{ task_attempt::TaskAttempt, }, }; -use futures::{StreamExt, TryStreamExt}; +use futures::StreamExt; use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation}; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -802,11 +801,11 @@ impl EventService { Ok(combined_stream) } - /// Stream follow-up draft for a specific task attempt with initial snapshot - pub async fn stream_follow_up_draft_for_attempt( + /// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket) + pub async fn stream_follow_up_draft_for_attempt_raw( &self, task_attempt_id: Uuid, - ) -> Result>, EventError> + ) -> Result>, EventError> { // Get initial snapshot of follow-up draft let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id( @@ -827,11 +826,13 @@ impl EventService { version: 0, }); - let initial_patch = json!([{ - "op": "replace", - "path": "/", - "value": { "follow_up_draft": draft } - }]); + let initial_patch = json!([ + { + "op": "replace", + "path": "/", + "value": { "follow_up_draft": draft } + } + ]); let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); // Filtered live stream, mapped into direct JSON patches that update /follow_up_draft @@ -848,11 +849,13 @@ impl EventService { RecordTypes::FollowUpDraft(draft) => { if draft.task_attempt_id == task_attempt_id { // Build a direct patch to replace /follow_up_draft - let direct = json!([{ - "op": "replace", - "path": "/follow_up_draft", - "value": draft - }]); + let direct = json!([ + { + "op": "replace", + "path": "/follow_up_draft", + "value": draft + } + ]); let direct_patch = serde_json::from_value(direct).unwrap(); return Some(Ok(LogMsg::JsonPatch(direct_patch))); } @@ -875,11 +878,13 @@ impl EventService { "updated_at": chrono::Utc::now(), "version": 0 }); - let direct = json!([{ - "op": "replace", - "path": "/follow_up_draft", - "value": empty - }]); + let direct = json!([ + { + "op": "replace", + "path": "/follow_up_draft", + "value": empty + } + ]); let direct_patch = serde_json::from_value(direct).unwrap(); return Some(Ok(LogMsg::JsonPatch(direct_patch))); } @@ -896,10 +901,7 @@ impl EventService { ); let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); - let combined_stream = initial_stream - .chain(filtered_stream) - .map_ok(|msg| msg.to_sse_event()) - .boxed(); + let combined_stream = initial_stream.chain(filtered_stream).boxed(); Ok(combined_stream) } diff --git a/frontend/src/hooks/follow-up/useDraftStream.ts b/frontend/src/hooks/follow-up/useDraftStream.ts index 8e0f701f..a64030ba 100644 --- a/frontend/src/hooks/follow-up/useDraftStream.ts +++ b/frontend/src/hooks/follow-up/useDraftStream.ts @@ -1,5 +1,5 @@ import { useCallback, useEffect, useRef, useState } from 'react'; -import { useJsonPatchStream } from '@/hooks/useJsonPatchStream'; +import { useJsonPatchWsStream } from '@/hooks/useJsonPatchWsStream'; import { attemptsApi } from '@/lib/api'; import type { FollowUpDraft } from 'shared/types'; import { inIframe } from '@/vscode/bridge'; @@ -14,7 +14,7 @@ export function useDraftStream(attemptId?: string) { const forceNextApplyRef = useRef(false); const endpoint = attemptId - ? `/api/task-attempts/${attemptId}/follow-up-draft/stream` + ? `/api/task-attempts/${attemptId}/follow-up-draft/stream/ws` : undefined; const makeInitial = useCallback( @@ -35,7 +35,7 @@ export function useDraftStream(attemptId?: string) { [attemptId] ); - const { data, isConnected, error } = useJsonPatchStream( + const { data, isConnected, error } = useJsonPatchWsStream( endpoint, !!endpoint, makeInitial @@ -64,7 +64,7 @@ export function useDraftStream(attemptId?: string) { }); if (!isDraftLoaded) setIsDraftLoaded(true); } catch { - // ignore, rely on SSE + // ignore, rely on stream } }; hydrate(); @@ -73,7 +73,7 @@ export function useDraftStream(attemptId?: string) { }; }, [attemptId, isDraftLoaded]); - // Handle SSE stream + // Handle stream updates useEffect(() => { if (!data) return; const d = data.follow_up_draft; diff --git a/frontend/src/hooks/useJsonPatchWsStream.ts b/frontend/src/hooks/useJsonPatchWsStream.ts index 63cc6b04..5f3588fa 100644 --- a/frontend/src/hooks/useJsonPatchWsStream.ts +++ b/frontend/src/hooks/useJsonPatchWsStream.ts @@ -2,6 +2,10 @@ import { useEffect, useState, useRef } from 'react'; import { applyPatch } from 'rfc6902'; import type { Operation } from 'rfc6902'; +type WsJsonPatchMsg = { JsonPatch: Operation[] }; +type WsFinishedMsg = { finished: boolean }; +type WsMsg = WsJsonPatchMsg | WsFinishedMsg; + interface UseJsonPatchStreamOptions { /** * Called once when the stream starts to inject initial data @@ -98,10 +102,10 @@ export const useJsonPatchWsStream = ( ws.onmessage = (event) => { try { - const msg = JSON.parse(event.data); + const msg: WsMsg = JSON.parse(event.data); // Handle JsonPatch messages (same as SSE json_patch event) - if (msg.JsonPatch) { + if ('JsonPatch' in msg) { const patches: Operation[] = msg.JsonPatch; const filtered = options.deduplicatePatches ? options.deduplicatePatches(patches) @@ -119,8 +123,8 @@ export const useJsonPatchWsStream = ( setData(dataRef.current); } - // Handle Finished messages (same as SSE finished event) - if (msg.Finished !== undefined) { + // Handle finished messages ({finished: true}) + if ('finished' in msg) { ws.close(); wsRef.current = null; setIsConnected(false);