Migrate followup draft SSE to WebSockets (#776)
This commit is contained in:
@@ -2,11 +2,14 @@ use std::path::PathBuf;
|
||||
|
||||
use axum::{
|
||||
BoxError, Extension, Json, Router,
|
||||
extract::{Query, State},
|
||||
extract::{
|
||||
Query, State,
|
||||
ws::{WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
http::StatusCode,
|
||||
middleware::from_fn_with_state,
|
||||
response::{
|
||||
Json as ResponseJson, Sse,
|
||||
IntoResponse, Json as ResponseJson, Sse,
|
||||
sse::{Event, KeepAlive},
|
||||
},
|
||||
routing::{get, post},
|
||||
@@ -513,22 +516,52 @@ pub async fn save_follow_up_draft(
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn stream_follow_up_draft(
|
||||
pub async fn stream_follow_up_draft_ws(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(task_attempt): Extension<TaskAttempt>,
|
||||
State(deployment): State<DeploymentImpl>,
|
||||
) -> Result<
|
||||
Sse<impl futures_util::Stream<Item = Result<Event, Box<dyn std::error::Error + Send + Sync>>>>,
|
||||
ApiError,
|
||||
> {
|
||||
let stream = deployment
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| async move {
|
||||
if let Err(e) = handle_follow_up_draft_ws(socket, deployment, task_attempt.id).await {
|
||||
tracing::warn!("follow-up draft WS closed: {}", e);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_follow_up_draft_ws(
|
||||
socket: WebSocket,
|
||||
deployment: DeploymentImpl,
|
||||
task_attempt_id: uuid::Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||
|
||||
let mut stream = deployment
|
||||
.events()
|
||||
.stream_follow_up_draft_for_attempt(task_attempt.id)
|
||||
.await
|
||||
.map_err(|e| ApiError::from(deployment::DeploymentError::from(e)))?;
|
||||
Ok(
|
||||
Sse::new(stream.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) }))
|
||||
.keep_alive(KeepAlive::default()),
|
||||
)
|
||||
.stream_follow_up_draft_for_attempt_raw(task_attempt_id)
|
||||
.await?
|
||||
.map_ok(|msg| msg.to_ws_message_unchecked());
|
||||
|
||||
// Split socket into sender and receiver
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Drain (and ignore) any client->server messages so pings/pongs work
|
||||
tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
|
||||
|
||||
// Forward server messages
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(msg) => {
|
||||
if sender.send(msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("stream error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
@@ -1690,7 +1723,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
|
||||
"/follow-up-draft",
|
||||
get(get_follow_up_draft).put(save_follow_up_draft),
|
||||
)
|
||||
.route("/follow-up-draft/stream", get(stream_follow_up_draft))
|
||||
.route("/follow-up-draft/stream/ws", get(stream_follow_up_draft_ws))
|
||||
.route("/follow-up-draft/queue", post(set_follow_up_queue))
|
||||
.route("/replace-process", post(replace_process))
|
||||
.route("/commit-info", get(get_commit_info))
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
|
||||
use anyhow::Error as AnyhowError;
|
||||
use axum::response::sse::Event;
|
||||
use db::{
|
||||
DBService,
|
||||
models::{
|
||||
@@ -10,7 +9,7 @@ use db::{
|
||||
task_attempt::TaskAttempt,
|
||||
},
|
||||
};
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use futures::StreamExt;
|
||||
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
@@ -802,11 +801,11 @@ impl EventService {
|
||||
Ok(combined_stream)
|
||||
}
|
||||
|
||||
/// Stream follow-up draft for a specific task attempt with initial snapshot
|
||||
pub async fn stream_follow_up_draft_for_attempt(
|
||||
/// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket)
|
||||
pub async fn stream_follow_up_draft_for_attempt_raw(
|
||||
&self,
|
||||
task_attempt_id: Uuid,
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, EventError>
|
||||
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
|
||||
{
|
||||
// Get initial snapshot of follow-up draft
|
||||
let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id(
|
||||
@@ -827,11 +826,13 @@ impl EventService {
|
||||
version: 0,
|
||||
});
|
||||
|
||||
let initial_patch = json!([{
|
||||
"op": "replace",
|
||||
"path": "/",
|
||||
"value": { "follow_up_draft": draft }
|
||||
}]);
|
||||
let initial_patch = json!([
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/",
|
||||
"value": { "follow_up_draft": draft }
|
||||
}
|
||||
]);
|
||||
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
|
||||
|
||||
// Filtered live stream, mapped into direct JSON patches that update /follow_up_draft
|
||||
@@ -848,11 +849,13 @@ impl EventService {
|
||||
RecordTypes::FollowUpDraft(draft) => {
|
||||
if draft.task_attempt_id == task_attempt_id {
|
||||
// Build a direct patch to replace /follow_up_draft
|
||||
let direct = json!([{
|
||||
"op": "replace",
|
||||
"path": "/follow_up_draft",
|
||||
"value": draft
|
||||
}]);
|
||||
let direct = json!([
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/follow_up_draft",
|
||||
"value": draft
|
||||
}
|
||||
]);
|
||||
let direct_patch = serde_json::from_value(direct).unwrap();
|
||||
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
|
||||
}
|
||||
@@ -875,11 +878,13 @@ impl EventService {
|
||||
"updated_at": chrono::Utc::now(),
|
||||
"version": 0
|
||||
});
|
||||
let direct = json!([{
|
||||
"op": "replace",
|
||||
"path": "/follow_up_draft",
|
||||
"value": empty
|
||||
}]);
|
||||
let direct = json!([
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/follow_up_draft",
|
||||
"value": empty
|
||||
}
|
||||
]);
|
||||
let direct_patch = serde_json::from_value(direct).unwrap();
|
||||
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
|
||||
}
|
||||
@@ -896,10 +901,7 @@ impl EventService {
|
||||
);
|
||||
|
||||
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
|
||||
let combined_stream = initial_stream
|
||||
.chain(filtered_stream)
|
||||
.map_ok(|msg| msg.to_sse_event())
|
||||
.boxed();
|
||||
let combined_stream = initial_stream.chain(filtered_stream).boxed();
|
||||
|
||||
Ok(combined_stream)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useJsonPatchStream } from '@/hooks/useJsonPatchStream';
|
||||
import { useJsonPatchWsStream } from '@/hooks/useJsonPatchWsStream';
|
||||
import { attemptsApi } from '@/lib/api';
|
||||
import type { FollowUpDraft } from 'shared/types';
|
||||
import { inIframe } from '@/vscode/bridge';
|
||||
@@ -14,7 +14,7 @@ export function useDraftStream(attemptId?: string) {
|
||||
const forceNextApplyRef = useRef<boolean>(false);
|
||||
|
||||
const endpoint = attemptId
|
||||
? `/api/task-attempts/${attemptId}/follow-up-draft/stream`
|
||||
? `/api/task-attempts/${attemptId}/follow-up-draft/stream/ws`
|
||||
: undefined;
|
||||
|
||||
const makeInitial = useCallback(
|
||||
@@ -35,7 +35,7 @@ export function useDraftStream(attemptId?: string) {
|
||||
[attemptId]
|
||||
);
|
||||
|
||||
const { data, isConnected, error } = useJsonPatchStream<DraftStreamState>(
|
||||
const { data, isConnected, error } = useJsonPatchWsStream<DraftStreamState>(
|
||||
endpoint,
|
||||
!!endpoint,
|
||||
makeInitial
|
||||
@@ -64,7 +64,7 @@ export function useDraftStream(attemptId?: string) {
|
||||
});
|
||||
if (!isDraftLoaded) setIsDraftLoaded(true);
|
||||
} catch {
|
||||
// ignore, rely on SSE
|
||||
// ignore, rely on stream
|
||||
}
|
||||
};
|
||||
hydrate();
|
||||
@@ -73,7 +73,7 @@ export function useDraftStream(attemptId?: string) {
|
||||
};
|
||||
}, [attemptId, isDraftLoaded]);
|
||||
|
||||
// Handle SSE stream
|
||||
// Handle stream updates
|
||||
useEffect(() => {
|
||||
if (!data) return;
|
||||
const d = data.follow_up_draft;
|
||||
|
||||
@@ -2,6 +2,10 @@ import { useEffect, useState, useRef } from 'react';
|
||||
import { applyPatch } from 'rfc6902';
|
||||
import type { Operation } from 'rfc6902';
|
||||
|
||||
type WsJsonPatchMsg = { JsonPatch: Operation[] };
|
||||
type WsFinishedMsg = { finished: boolean };
|
||||
type WsMsg = WsJsonPatchMsg | WsFinishedMsg;
|
||||
|
||||
interface UseJsonPatchStreamOptions<T> {
|
||||
/**
|
||||
* Called once when the stream starts to inject initial data
|
||||
@@ -98,10 +102,10 @@ export const useJsonPatchWsStream = <T>(
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const msg = JSON.parse(event.data);
|
||||
const msg: WsMsg = JSON.parse(event.data);
|
||||
|
||||
// Handle JsonPatch messages (same as SSE json_patch event)
|
||||
if (msg.JsonPatch) {
|
||||
if ('JsonPatch' in msg) {
|
||||
const patches: Operation[] = msg.JsonPatch;
|
||||
const filtered = options.deduplicatePatches
|
||||
? options.deduplicatePatches(patches)
|
||||
@@ -119,8 +123,8 @@ export const useJsonPatchWsStream = <T>(
|
||||
setData(dataRef.current);
|
||||
}
|
||||
|
||||
// Handle Finished messages (same as SSE finished event)
|
||||
if (msg.Finished !== undefined) {
|
||||
// Handle finished messages ({finished: true})
|
||||
if ('finished' in msg) {
|
||||
ws.close();
|
||||
wsRef.current = null;
|
||||
setIsConnected(false);
|
||||
|
||||
Reference in New Issue
Block a user