Migrate followup draft SSE to WebSockets (#776)

This commit is contained in:
Solomon
2025-09-18 16:11:46 +01:00
committed by GitHub
parent 0c10e42f64
commit 46d3f3c7df
4 changed files with 88 additions and 49 deletions

View File

@@ -2,11 +2,14 @@ use std::path::PathBuf;
use axum::{
BoxError, Extension, Json, Router,
extract::{Query, State},
extract::{
Query, State,
ws::{WebSocket, WebSocketUpgrade},
},
http::StatusCode,
middleware::from_fn_with_state,
response::{
Json as ResponseJson, Sse,
IntoResponse, Json as ResponseJson, Sse,
sse::{Event, KeepAlive},
},
routing::{get, post},
@@ -513,22 +516,52 @@ pub async fn save_follow_up_draft(
}
#[axum::debug_handler]
pub async fn stream_follow_up_draft(
pub async fn stream_follow_up_draft_ws(
ws: WebSocketUpgrade,
Extension(task_attempt): Extension<TaskAttempt>,
State(deployment): State<DeploymentImpl>,
) -> Result<
Sse<impl futures_util::Stream<Item = Result<Event, Box<dyn std::error::Error + Send + Sync>>>>,
ApiError,
> {
let stream = deployment
) -> impl IntoResponse {
ws.on_upgrade(move |socket| async move {
if let Err(e) = handle_follow_up_draft_ws(socket, deployment, task_attempt.id).await {
tracing::warn!("follow-up draft WS closed: {}", e);
}
})
}
async fn handle_follow_up_draft_ws(
socket: WebSocket,
deployment: DeploymentImpl,
task_attempt_id: uuid::Uuid,
) -> anyhow::Result<()> {
use futures_util::{SinkExt, StreamExt, TryStreamExt};
let mut stream = deployment
.events()
.stream_follow_up_draft_for_attempt(task_attempt.id)
.await
.map_err(|e| ApiError::from(deployment::DeploymentError::from(e)))?;
Ok(
Sse::new(stream.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) }))
.keep_alive(KeepAlive::default()),
)
.stream_follow_up_draft_for_attempt_raw(task_attempt_id)
.await?
.map_ok(|msg| msg.to_ws_message_unchecked());
// Split socket into sender and receiver
let (mut sender, mut receiver) = socket.split();
// Drain (and ignore) any client->server messages so pings/pongs work
tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
// Forward server messages
while let Some(item) = stream.next().await {
match item {
Ok(msg) => {
if sender.send(msg).await.is_err() {
break;
}
}
Err(e) => {
tracing::error!("stream error: {}", e);
break;
}
}
}
Ok(())
}
#[axum::debug_handler]
@@ -1690,7 +1723,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
"/follow-up-draft",
get(get_follow_up_draft).put(save_follow_up_draft),
)
.route("/follow-up-draft/stream", get(stream_follow_up_draft))
.route("/follow-up-draft/stream/ws", get(stream_follow_up_draft_ws))
.route("/follow-up-draft/queue", post(set_follow_up_queue))
.route("/replace-process", post(replace_process))
.route("/commit-info", get(get_commit_info))

View File

@@ -1,7 +1,6 @@
use std::{str::FromStr, sync::Arc};
use anyhow::Error as AnyhowError;
use axum::response::sse::Event;
use db::{
DBService,
models::{
@@ -10,7 +9,7 @@ use db::{
task_attempt::TaskAttempt,
},
};
use futures::{StreamExt, TryStreamExt};
use futures::StreamExt;
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -802,11 +801,11 @@ impl EventService {
Ok(combined_stream)
}
/// Stream follow-up draft for a specific task attempt with initial snapshot
pub async fn stream_follow_up_draft_for_attempt(
/// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket)
pub async fn stream_follow_up_draft_for_attempt_raw(
&self,
task_attempt_id: Uuid,
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, EventError>
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
{
// Get initial snapshot of follow-up draft
let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id(
@@ -827,11 +826,13 @@ impl EventService {
version: 0,
});
let initial_patch = json!([{
"op": "replace",
"path": "/",
"value": { "follow_up_draft": draft }
}]);
let initial_patch = json!([
{
"op": "replace",
"path": "/",
"value": { "follow_up_draft": draft }
}
]);
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
// Filtered live stream, mapped into direct JSON patches that update /follow_up_draft
@@ -848,11 +849,13 @@ impl EventService {
RecordTypes::FollowUpDraft(draft) => {
if draft.task_attempt_id == task_attempt_id {
// Build a direct patch to replace /follow_up_draft
let direct = json!([{
"op": "replace",
"path": "/follow_up_draft",
"value": draft
}]);
let direct = json!([
{
"op": "replace",
"path": "/follow_up_draft",
"value": draft
}
]);
let direct_patch = serde_json::from_value(direct).unwrap();
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
}
@@ -875,11 +878,13 @@ impl EventService {
"updated_at": chrono::Utc::now(),
"version": 0
});
let direct = json!([{
"op": "replace",
"path": "/follow_up_draft",
"value": empty
}]);
let direct = json!([
{
"op": "replace",
"path": "/follow_up_draft",
"value": empty
}
]);
let direct_patch = serde_json::from_value(direct).unwrap();
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
}
@@ -896,10 +901,7 @@ impl EventService {
);
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
let combined_stream = initial_stream
.chain(filtered_stream)
.map_ok(|msg| msg.to_sse_event())
.boxed();
let combined_stream = initial_stream.chain(filtered_stream).boxed();
Ok(combined_stream)
}

View File

@@ -1,5 +1,5 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { useJsonPatchStream } from '@/hooks/useJsonPatchStream';
import { useJsonPatchWsStream } from '@/hooks/useJsonPatchWsStream';
import { attemptsApi } from '@/lib/api';
import type { FollowUpDraft } from 'shared/types';
import { inIframe } from '@/vscode/bridge';
@@ -14,7 +14,7 @@ export function useDraftStream(attemptId?: string) {
const forceNextApplyRef = useRef<boolean>(false);
const endpoint = attemptId
? `/api/task-attempts/${attemptId}/follow-up-draft/stream`
? `/api/task-attempts/${attemptId}/follow-up-draft/stream/ws`
: undefined;
const makeInitial = useCallback(
@@ -35,7 +35,7 @@ export function useDraftStream(attemptId?: string) {
[attemptId]
);
const { data, isConnected, error } = useJsonPatchStream<DraftStreamState>(
const { data, isConnected, error } = useJsonPatchWsStream<DraftStreamState>(
endpoint,
!!endpoint,
makeInitial
@@ -64,7 +64,7 @@ export function useDraftStream(attemptId?: string) {
});
if (!isDraftLoaded) setIsDraftLoaded(true);
} catch {
// ignore, rely on SSE
// ignore, rely on stream
}
};
hydrate();
@@ -73,7 +73,7 @@ export function useDraftStream(attemptId?: string) {
};
}, [attemptId, isDraftLoaded]);
// Handle SSE stream
// Handle stream updates
useEffect(() => {
if (!data) return;
const d = data.follow_up_draft;

View File

@@ -2,6 +2,10 @@ import { useEffect, useState, useRef } from 'react';
import { applyPatch } from 'rfc6902';
import type { Operation } from 'rfc6902';
type WsJsonPatchMsg = { JsonPatch: Operation[] };
type WsFinishedMsg = { finished: boolean };
type WsMsg = WsJsonPatchMsg | WsFinishedMsg;
interface UseJsonPatchStreamOptions<T> {
/**
* Called once when the stream starts to inject initial data
@@ -98,10 +102,10 @@ export const useJsonPatchWsStream = <T>(
ws.onmessage = (event) => {
try {
const msg = JSON.parse(event.data);
const msg: WsMsg = JSON.parse(event.data);
// Handle JsonPatch messages (same as SSE json_patch event)
if (msg.JsonPatch) {
if ('JsonPatch' in msg) {
const patches: Operation[] = msg.JsonPatch;
const filtered = options.deduplicatePatches
? options.deduplicatePatches(patches)
@@ -119,8 +123,8 @@ export const useJsonPatchWsStream = <T>(
setData(dataRef.current);
}
// Handle Finished messages (same as SSE finished event)
if (msg.Finished !== undefined) {
// Handle finished messages ({finished: true})
if ('finished' in msg) {
ws.close();
wsRef.current = null;
setIsConnected(false);