Migrate followup draft SSE to WebSockets (#776)
This commit is contained in:
@@ -2,11 +2,14 @@ use std::path::PathBuf;
|
|||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
BoxError, Extension, Json, Router,
|
BoxError, Extension, Json, Router,
|
||||||
extract::{Query, State},
|
extract::{
|
||||||
|
Query, State,
|
||||||
|
ws::{WebSocket, WebSocketUpgrade},
|
||||||
|
},
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
middleware::from_fn_with_state,
|
middleware::from_fn_with_state,
|
||||||
response::{
|
response::{
|
||||||
Json as ResponseJson, Sse,
|
IntoResponse, Json as ResponseJson, Sse,
|
||||||
sse::{Event, KeepAlive},
|
sse::{Event, KeepAlive},
|
||||||
},
|
},
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
@@ -513,22 +516,52 @@ pub async fn save_follow_up_draft(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[axum::debug_handler]
|
#[axum::debug_handler]
|
||||||
pub async fn stream_follow_up_draft(
|
pub async fn stream_follow_up_draft_ws(
|
||||||
|
ws: WebSocketUpgrade,
|
||||||
Extension(task_attempt): Extension<TaskAttempt>,
|
Extension(task_attempt): Extension<TaskAttempt>,
|
||||||
State(deployment): State<DeploymentImpl>,
|
State(deployment): State<DeploymentImpl>,
|
||||||
) -> Result<
|
) -> impl IntoResponse {
|
||||||
Sse<impl futures_util::Stream<Item = Result<Event, Box<dyn std::error::Error + Send + Sync>>>>,
|
ws.on_upgrade(move |socket| async move {
|
||||||
ApiError,
|
if let Err(e) = handle_follow_up_draft_ws(socket, deployment, task_attempt.id).await {
|
||||||
> {
|
tracing::warn!("follow-up draft WS closed: {}", e);
|
||||||
let stream = deployment
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_follow_up_draft_ws(
|
||||||
|
socket: WebSocket,
|
||||||
|
deployment: DeploymentImpl,
|
||||||
|
task_attempt_id: uuid::Uuid,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||||
|
|
||||||
|
let mut stream = deployment
|
||||||
.events()
|
.events()
|
||||||
.stream_follow_up_draft_for_attempt(task_attempt.id)
|
.stream_follow_up_draft_for_attempt_raw(task_attempt_id)
|
||||||
.await
|
.await?
|
||||||
.map_err(|e| ApiError::from(deployment::DeploymentError::from(e)))?;
|
.map_ok(|msg| msg.to_ws_message_unchecked());
|
||||||
Ok(
|
|
||||||
Sse::new(stream.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) }))
|
// Split socket into sender and receiver
|
||||||
.keep_alive(KeepAlive::default()),
|
let (mut sender, mut receiver) = socket.split();
|
||||||
)
|
|
||||||
|
// Drain (and ignore) any client->server messages so pings/pongs work
|
||||||
|
tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
|
||||||
|
|
||||||
|
// Forward server messages
|
||||||
|
while let Some(item) = stream.next().await {
|
||||||
|
match item {
|
||||||
|
Ok(msg) => {
|
||||||
|
if sender.send(msg).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("stream error: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[axum::debug_handler]
|
#[axum::debug_handler]
|
||||||
@@ -1690,7 +1723,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
|
|||||||
"/follow-up-draft",
|
"/follow-up-draft",
|
||||||
get(get_follow_up_draft).put(save_follow_up_draft),
|
get(get_follow_up_draft).put(save_follow_up_draft),
|
||||||
)
|
)
|
||||||
.route("/follow-up-draft/stream", get(stream_follow_up_draft))
|
.route("/follow-up-draft/stream/ws", get(stream_follow_up_draft_ws))
|
||||||
.route("/follow-up-draft/queue", post(set_follow_up_queue))
|
.route("/follow-up-draft/queue", post(set_follow_up_queue))
|
||||||
.route("/replace-process", post(replace_process))
|
.route("/replace-process", post(replace_process))
|
||||||
.route("/commit-info", get(get_commit_info))
|
.route("/commit-info", get(get_commit_info))
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
use std::{str::FromStr, sync::Arc};
|
use std::{str::FromStr, sync::Arc};
|
||||||
|
|
||||||
use anyhow::Error as AnyhowError;
|
use anyhow::Error as AnyhowError;
|
||||||
use axum::response::sse::Event;
|
|
||||||
use db::{
|
use db::{
|
||||||
DBService,
|
DBService,
|
||||||
models::{
|
models::{
|
||||||
@@ -10,7 +9,7 @@ use db::{
|
|||||||
task_attempt::TaskAttempt,
|
task_attempt::TaskAttempt,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use futures::{StreamExt, TryStreamExt};
|
use futures::StreamExt;
|
||||||
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
|
use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
@@ -802,11 +801,11 @@ impl EventService {
|
|||||||
Ok(combined_stream)
|
Ok(combined_stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stream follow-up draft for a specific task attempt with initial snapshot
|
/// Stream follow-up draft for a specific task attempt (raw LogMsg format for WebSocket)
|
||||||
pub async fn stream_follow_up_draft_for_attempt(
|
pub async fn stream_follow_up_draft_for_attempt_raw(
|
||||||
&self,
|
&self,
|
||||||
task_attempt_id: Uuid,
|
task_attempt_id: Uuid,
|
||||||
) -> Result<futures::stream::BoxStream<'static, Result<Event, std::io::Error>>, EventError>
|
) -> Result<futures::stream::BoxStream<'static, Result<LogMsg, std::io::Error>>, EventError>
|
||||||
{
|
{
|
||||||
// Get initial snapshot of follow-up draft
|
// Get initial snapshot of follow-up draft
|
||||||
let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id(
|
let draft = db::models::follow_up_draft::FollowUpDraft::find_by_task_attempt_id(
|
||||||
@@ -827,11 +826,13 @@ impl EventService {
|
|||||||
version: 0,
|
version: 0,
|
||||||
});
|
});
|
||||||
|
|
||||||
let initial_patch = json!([{
|
let initial_patch = json!([
|
||||||
"op": "replace",
|
{
|
||||||
"path": "/",
|
"op": "replace",
|
||||||
"value": { "follow_up_draft": draft }
|
"path": "/",
|
||||||
}]);
|
"value": { "follow_up_draft": draft }
|
||||||
|
}
|
||||||
|
]);
|
||||||
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
|
let initial_msg = LogMsg::JsonPatch(serde_json::from_value(initial_patch).unwrap());
|
||||||
|
|
||||||
// Filtered live stream, mapped into direct JSON patches that update /follow_up_draft
|
// Filtered live stream, mapped into direct JSON patches that update /follow_up_draft
|
||||||
@@ -848,11 +849,13 @@ impl EventService {
|
|||||||
RecordTypes::FollowUpDraft(draft) => {
|
RecordTypes::FollowUpDraft(draft) => {
|
||||||
if draft.task_attempt_id == task_attempt_id {
|
if draft.task_attempt_id == task_attempt_id {
|
||||||
// Build a direct patch to replace /follow_up_draft
|
// Build a direct patch to replace /follow_up_draft
|
||||||
let direct = json!([{
|
let direct = json!([
|
||||||
"op": "replace",
|
{
|
||||||
"path": "/follow_up_draft",
|
"op": "replace",
|
||||||
"value": draft
|
"path": "/follow_up_draft",
|
||||||
}]);
|
"value": draft
|
||||||
|
}
|
||||||
|
]);
|
||||||
let direct_patch = serde_json::from_value(direct).unwrap();
|
let direct_patch = serde_json::from_value(direct).unwrap();
|
||||||
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
|
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
|
||||||
}
|
}
|
||||||
@@ -875,11 +878,13 @@ impl EventService {
|
|||||||
"updated_at": chrono::Utc::now(),
|
"updated_at": chrono::Utc::now(),
|
||||||
"version": 0
|
"version": 0
|
||||||
});
|
});
|
||||||
let direct = json!([{
|
let direct = json!([
|
||||||
"op": "replace",
|
{
|
||||||
"path": "/follow_up_draft",
|
"op": "replace",
|
||||||
"value": empty
|
"path": "/follow_up_draft",
|
||||||
}]);
|
"value": empty
|
||||||
|
}
|
||||||
|
]);
|
||||||
let direct_patch = serde_json::from_value(direct).unwrap();
|
let direct_patch = serde_json::from_value(direct).unwrap();
|
||||||
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
|
return Some(Ok(LogMsg::JsonPatch(direct_patch)));
|
||||||
}
|
}
|
||||||
@@ -896,10 +901,7 @@ impl EventService {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
|
let initial_stream = futures::stream::once(async move { Ok(initial_msg) });
|
||||||
let combined_stream = initial_stream
|
let combined_stream = initial_stream.chain(filtered_stream).boxed();
|
||||||
.chain(filtered_stream)
|
|
||||||
.map_ok(|msg| msg.to_sse_event())
|
|
||||||
.boxed();
|
|
||||||
|
|
||||||
Ok(combined_stream)
|
Ok(combined_stream)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||||
import { useJsonPatchStream } from '@/hooks/useJsonPatchStream';
|
import { useJsonPatchWsStream } from '@/hooks/useJsonPatchWsStream';
|
||||||
import { attemptsApi } from '@/lib/api';
|
import { attemptsApi } from '@/lib/api';
|
||||||
import type { FollowUpDraft } from 'shared/types';
|
import type { FollowUpDraft } from 'shared/types';
|
||||||
import { inIframe } from '@/vscode/bridge';
|
import { inIframe } from '@/vscode/bridge';
|
||||||
@@ -14,7 +14,7 @@ export function useDraftStream(attemptId?: string) {
|
|||||||
const forceNextApplyRef = useRef<boolean>(false);
|
const forceNextApplyRef = useRef<boolean>(false);
|
||||||
|
|
||||||
const endpoint = attemptId
|
const endpoint = attemptId
|
||||||
? `/api/task-attempts/${attemptId}/follow-up-draft/stream`
|
? `/api/task-attempts/${attemptId}/follow-up-draft/stream/ws`
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
const makeInitial = useCallback(
|
const makeInitial = useCallback(
|
||||||
@@ -35,7 +35,7 @@ export function useDraftStream(attemptId?: string) {
|
|||||||
[attemptId]
|
[attemptId]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data, isConnected, error } = useJsonPatchStream<DraftStreamState>(
|
const { data, isConnected, error } = useJsonPatchWsStream<DraftStreamState>(
|
||||||
endpoint,
|
endpoint,
|
||||||
!!endpoint,
|
!!endpoint,
|
||||||
makeInitial
|
makeInitial
|
||||||
@@ -64,7 +64,7 @@ export function useDraftStream(attemptId?: string) {
|
|||||||
});
|
});
|
||||||
if (!isDraftLoaded) setIsDraftLoaded(true);
|
if (!isDraftLoaded) setIsDraftLoaded(true);
|
||||||
} catch {
|
} catch {
|
||||||
// ignore, rely on SSE
|
// ignore, rely on stream
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
hydrate();
|
hydrate();
|
||||||
@@ -73,7 +73,7 @@ export function useDraftStream(attemptId?: string) {
|
|||||||
};
|
};
|
||||||
}, [attemptId, isDraftLoaded]);
|
}, [attemptId, isDraftLoaded]);
|
||||||
|
|
||||||
// Handle SSE stream
|
// Handle stream updates
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!data) return;
|
if (!data) return;
|
||||||
const d = data.follow_up_draft;
|
const d = data.follow_up_draft;
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ import { useEffect, useState, useRef } from 'react';
|
|||||||
import { applyPatch } from 'rfc6902';
|
import { applyPatch } from 'rfc6902';
|
||||||
import type { Operation } from 'rfc6902';
|
import type { Operation } from 'rfc6902';
|
||||||
|
|
||||||
|
type WsJsonPatchMsg = { JsonPatch: Operation[] };
|
||||||
|
type WsFinishedMsg = { finished: boolean };
|
||||||
|
type WsMsg = WsJsonPatchMsg | WsFinishedMsg;
|
||||||
|
|
||||||
interface UseJsonPatchStreamOptions<T> {
|
interface UseJsonPatchStreamOptions<T> {
|
||||||
/**
|
/**
|
||||||
* Called once when the stream starts to inject initial data
|
* Called once when the stream starts to inject initial data
|
||||||
@@ -98,10 +102,10 @@ export const useJsonPatchWsStream = <T>(
|
|||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
try {
|
try {
|
||||||
const msg = JSON.parse(event.data);
|
const msg: WsMsg = JSON.parse(event.data);
|
||||||
|
|
||||||
// Handle JsonPatch messages (same as SSE json_patch event)
|
// Handle JsonPatch messages (same as SSE json_patch event)
|
||||||
if (msg.JsonPatch) {
|
if ('JsonPatch' in msg) {
|
||||||
const patches: Operation[] = msg.JsonPatch;
|
const patches: Operation[] = msg.JsonPatch;
|
||||||
const filtered = options.deduplicatePatches
|
const filtered = options.deduplicatePatches
|
||||||
? options.deduplicatePatches(patches)
|
? options.deduplicatePatches(patches)
|
||||||
@@ -119,8 +123,8 @@ export const useJsonPatchWsStream = <T>(
|
|||||||
setData(dataRef.current);
|
setData(dataRef.current);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Finished messages (same as SSE finished event)
|
// Handle finished messages ({finished: true})
|
||||||
if (msg.Finished !== undefined) {
|
if ('finished' in msg) {
|
||||||
ws.close();
|
ws.close();
|
||||||
wsRef.current = null;
|
wsRef.current = null;
|
||||||
setIsConnected(false);
|
setIsConnected(false);
|
||||||
|
|||||||
Reference in New Issue
Block a user