Migrate followup draft SSE to WebSockets (#776)

This commit is contained in:
Solomon
2025-09-18 16:11:46 +01:00
committed by GitHub
parent 0c10e42f64
commit 46d3f3c7df
4 changed files with 88 additions and 49 deletions

View File

@@ -2,11 +2,14 @@ use std::path::PathBuf;
use axum::{ use axum::{
BoxError, Extension, Json, Router, BoxError, Extension, Json, Router,
extract::{Query, State}, extract::{
Query, State,
ws::{WebSocket, WebSocketUpgrade},
},
http::StatusCode, http::StatusCode,
middleware::from_fn_with_state, middleware::from_fn_with_state,
response::{ response::{
Json as ResponseJson, Sse, IntoResponse, Json as ResponseJson, Sse,
sse::{Event, KeepAlive}, sse::{Event, KeepAlive},
}, },
routing::{get, post}, routing::{get, post},
@@ -513,22 +516,52 @@ pub async fn save_follow_up_draft(
} }
#[axum::debug_handler] #[axum::debug_handler]
pub async fn stream_follow_up_draft( pub async fn stream_follow_up_draft_ws(
ws: WebSocketUpgrade,
Extension(task_attempt): Extension<TaskAttempt>, Extension(task_attempt): Extension<TaskAttempt>,
State(deployment): State<DeploymentImpl>, State(deployment): State<DeploymentImpl>,
) -> Result< ) -> impl IntoResponse {
Sse<impl futures_util::Stream<Item = Result<Event, Box<dyn std::error::Error + Send + Sync>>>>, ws.on_upgrade(move |socket| async move {
ApiError, if let Err(e) = handle_follow_up_draft_ws(socket, deployment, task_attempt.id).await {
> { tracing::warn!("follow-up draft WS closed: {}", e);
let stream = deployment }
})
}
async fn handle_follow_up_draft_ws(
socket: WebSocket,
deployment: DeploymentImpl,
task_attempt_id: uuid::Uuid,
) -> anyhow::Result<()> {
use futures_util::{SinkExt, StreamExt, TryStreamExt};
let mut stream = deployment
.events() .events()
.stream_follow_up_draft_for_attempt(task_attempt.id) .stream_follow_up_draft_for_attempt_raw(task_attempt_id)
.await .await?
.map_err(|e| ApiError::from(deployment::DeploymentError::from(e)))?; .map_ok(|msg| msg.to_ws_message_unchecked());
Ok(
Sse::new(stream.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })) // Split socket into sender and receiver
.keep_alive(KeepAlive::default()), let (mut sender, mut receiver) = socket.split();
)
// Drain (and ignore) any client->server messages so pings/pongs work
tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
// Forward server messages
while let Some(item) = stream.next().await {
match item {
Ok(msg) => {
if sender.send(msg).await.is_err() {
break;
}
}
Err(e) => {
tracing::error!("stream error: {}", e);
break;
}
}
}
Ok(())
} }
#[axum::debug_handler] #[axum::debug_handler]
@@ -1690,7 +1723,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
"/follow-up-draft", "/follow-up-draft",
get(get_follow_up_draft).put(save_follow_up_draft), get(get_follow_up_draft).put(save_follow_up_draft),
) )
.route("/follow-up-draft/stream", get(stream_follow_up_draft)) .route("/follow-up-draft/stream/ws", get(stream_follow_up_draft_ws))
.route("/follow-up-draft/queue", post(set_follow_up_queue)) .route("/follow-up-draft/queue", post(set_follow_up_queue))
.route("/replace-process", post(replace_process)) .route("/replace-process", post(replace_process))
.route("/commit-info", get(get_commit_info)) .route("/commit-info", get(get_commit_info))

View File

@@ -1,7 +1,6 @@
use std::{str::FromStr, sync::Arc}; use std::{str::FromStr, sync::Arc};
use anyhow::Error as AnyhowError; use anyhow::Error as AnyhowError;
use axum::response::sse::Event;
use db::{ use db::{
DBService, DBService,
models::{ models::{
@@ -10,7 +9,7 @@ use db::{
task_attempt::TaskAttempt, task_attempt::TaskAttempt,
}, },
}; };
use futures::{StreamExt, TryStreamExt}; use futures::StreamExt;
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation}; use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
@@ -802,11 +801,11 @@ impl EventService {
Ok(combined_stream) Ok(combined_stream)
} }
/// Stream follow-up draft for a specific task attempt with initial snapshot /// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket)
pub async fn stream_follow_up_draft_for_attempt( pub async fn stream_follow_up_draft_for_attempt_raw(
&self, &self,
task_attempt_id: Uuid, task_attempt_id: Uuid,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, EventError> ) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{ {
// Get initial snapshot of follow-up draft // Get initial snapshot of follow-up draft
let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id( let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id(
@@ -827,11 +826,13 @@ impl EventService {
version: 0, version: 0,
}); });
let initial_patch = json!([{ let initial_patch = json!([
"op": "replace", {
"path": "/", "op": "replace",
"value": { "follow_up_draft": draft } "path": "/",
}]); "value": { "follow_up_draft": draft }
}
]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap()); let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
// Filtered live stream, mapped into direct JSON patches that update /follow_up_draft // Filtered live stream, mapped into direct JSON patches that update /follow_up_draft
@@ -848,11 +849,13 @@ impl EventService {
RecordTypes::FollowUpDraft(draft) => { RecordTypes::FollowUpDraft(draft) => {
if draft.task_attempt_id == task_attempt_id { if draft.task_attempt_id == task_attempt_id {
// Build a direct patch to replace /follow_up_draft // Build a direct patch to replace /follow_up_draft
let direct = json!([{ let direct = json!([
"op": "replace", {
"path": "/follow_up_draft", "op": "replace",
"value": draft "path": "/follow_up_draft",
}]); "value": draft
}
]);
let direct_patch = serde_json::from_value(direct).unwrap(); let direct_patch = serde_json::from_value(direct).unwrap();
return Some(Ok(LogMsg::JsonPatch(direct_patch))); return Some(Ok(LogMsg::JsonPatch(direct_patch)));
} }
@@ -875,11 +878,13 @@ impl EventService {
"updated_at": chrono::Utc::now(), "updated_at": chrono::Utc::now(),
"version": 0 "version": 0
}); });
let direct = json!([{ let direct = json!([
"op": "replace", {
"path": "/follow_up_draft", "op": "replace",
"value": empty "path": "/follow_up_draft",
}]); "value": empty
}
]);
let direct_patch = serde_json::from_value(direct).unwrap(); let direct_patch = serde_json::from_value(direct).unwrap();
return Some(Ok(LogMsg::JsonPatch(direct_patch))); return Some(Ok(LogMsg::JsonPatch(direct_patch)));
} }
@@ -896,10 +901,7 @@ impl EventService {
); );
let initial_stream = futures::stream::once(async move { Ok(initial_msg) }); let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream let combined_stream = initial_stream.chain(filtered_stream).boxed();
.chain(filtered_stream)
.map_ok(|msg| msg.to_sse_event())
.boxed();
Ok(combined_stream) Ok(combined_stream)
} }

View File

@@ -1,5 +1,5 @@
import { useCallback, useEffect, useRef, useState } from 'react'; import { useCallback, useEffect, useRef, useState } from 'react';
import { useJsonPatchStream } from '@/hooks/useJsonPatchStream'; import { useJsonPatchWsStream } from '@/hooks/useJsonPatchWsStream';
import { attemptsApi } from '@/lib/api'; import { attemptsApi } from '@/lib/api';
import type { FollowUpDraft } from 'shared/types'; import type { FollowUpDraft } from 'shared/types';
import { inIframe } from '@/vscode/bridge'; import { inIframe } from '@/vscode/bridge';
@@ -14,7 +14,7 @@ export function useDraftStream(attemptId?: string) {
const forceNextApplyRef = useRef<boolean>(false); const forceNextApplyRef = useRef<boolean>(false);
const endpoint = attemptId const endpoint = attemptId
? `/api/task-attempts/${attemptId}/follow-up-draft/stream` ? `/api/task-attempts/${attemptId}/follow-up-draft/stream/ws`
: undefined; : undefined;
const makeInitial = useCallback( const makeInitial = useCallback(
@@ -35,7 +35,7 @@ export function useDraftStream(attemptId?: string) {
[attemptId] [attemptId]
); );
const { data, isConnected, error } = useJsonPatchStream<DraftStreamState>( const { data, isConnected, error } = useJsonPatchWsStream<DraftStreamState>(
endpoint, endpoint,
!!endpoint, !!endpoint,
makeInitial makeInitial
@@ -64,7 +64,7 @@ export function useDraftStream(attemptId?: string) {
}); });
if (!isDraftLoaded) setIsDraftLoaded(true); if (!isDraftLoaded) setIsDraftLoaded(true);
} catch { } catch {
// ignore, rely on SSE // ignore, rely on stream
} }
}; };
hydrate(); hydrate();
@@ -73,7 +73,7 @@ export function useDraftStream(attemptId?: string) {
}; };
}, [attemptId, isDraftLoaded]); }, [attemptId, isDraftLoaded]);
// Handle SSE stream // Handle stream updates
useEffect(() => { useEffect(() => {
if (!data) return; if (!data) return;
const d = data.follow_up_draft; const d = data.follow_up_draft;

View File

@@ -2,6 +2,10 @@ import { useEffect, useState, useRef } from 'react';
import { applyPatch } from 'rfc6902'; import { applyPatch } from 'rfc6902';
import type { Operation } from 'rfc6902'; import type { Operation } from 'rfc6902';
type WsJsonPatchMsg = { JsonPatch: Operation[] };
type WsFinishedMsg = { finished: boolean };
type WsMsg = WsJsonPatchMsg | WsFinishedMsg;
interface UseJsonPatchStreamOptions<T> { interface UseJsonPatchStreamOptions<T> {
/** /**
* Called once when the stream starts to inject initial data * Called once when the stream starts to inject initial data
@@ -98,10 +102,10 @@ export const useJsonPatchWsStream = <T>(
ws.onmessage = (event) => { ws.onmessage = (event) => {
try { try {
const msg = JSON.parse(event.data); const msg: WsMsg = JSON.parse(event.data);
// Handle JsonPatch messages (same as SSE json_patch event) // Handle JsonPatch messages (same as SSE json_patch event)
if (msg.JsonPatch) { if ('JsonPatch' in msg) {
const patches: Operation[] = msg.JsonPatch; const patches: Operation[] = msg.JsonPatch;
const filtered = options.deduplicatePatches const filtered = options.deduplicatePatches
? options.deduplicatePatches(patches) ? options.deduplicatePatches(patches)
@@ -119,8 +123,8 @@ export const useJsonPatchWsStream = <T>(
setData(dataRef.current); setData(dataRef.current);
} }
// Handle Finished messages (same as SSE finished event) // Handle finished messages ({finished: true})
if (msg.Finished !== undefined) { if ('finished' in msg) {
ws.close(); ws.close();
wsRef.current = null; wsRef.current = null;
setIsConnected(false); setIsConnected(false);