Bound the diff-stream channel (#983)

The bounded channel applies backpressure, so the filewatcher writer pauses when the channel is full.
This commit is contained in:
Solomon
2025-10-09 12:39:03 +01:00
committed by GitHub
parent 3493503602
commit 6a81ba77f4
2 changed files with 30 additions and 25 deletions

View File

@@ -618,6 +618,7 @@ impl LocalContainerService {
base_commit.clone(),
stats_only,
)
.await
.map_err(|e| ContainerError::Other(anyhow!("{e}")))
}
}

View File

@@ -13,7 +13,7 @@ use futures::StreamExt;
use notify_debouncer_full::DebouncedEvent;
use thiserror::Error;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::wrappers::ReceiverStream;
use utils::{
diff::{self, Diff},
log_msg::LogMsg,
@@ -27,6 +27,8 @@ use crate::services::{
/// Maximum cumulative diff bytes to stream before omitting content (200MB)
pub const MAX_CUMULATIVE_DIFF_BYTES: usize = 200 * 1024 * 1024;
const DIFF_STREAM_CHANNEL_CAPACITY: usize = 1000;
/// Errors that can occur during diff stream creation and operation
#[derive(Error, Debug)]
pub enum DiffStreamError {
@@ -85,11 +87,15 @@ struct DiffWatcherContext {
cumulative: Arc<AtomicUsize>,
full_sent: Arc<std::sync::RwLock<HashSet<String>>>,
stats_only: bool,
tx: mpsc::UnboundedSender<Result<LogMsg, io::Error>>,
tx: mpsc::Sender<Result<LogMsg, io::Error>>,
}
impl DiffWatcherContext {
fn handle_events(&self, events: Vec<DebouncedEvent>, canonical_worktree_path: &Path) -> bool {
async fn handle_events(
&self,
events: Vec<DebouncedEvent>,
canonical_worktree_path: &Path,
) -> bool {
let changed_paths =
extract_changed_paths(&events, canonical_worktree_path, &self.worktree_path);
@@ -106,17 +112,17 @@ impl DiffWatcherContext {
&self.full_sent,
self.stats_only,
) {
Ok(messages) => send_messages(&self.tx, messages),
Ok(messages) => send_messages(&self.tx, messages).await,
Err(err) => {
tracing::error!("Error processing file changes: {err}");
send_error(&self.tx, err.to_string());
send_error(&self.tx, err.to_string()).await;
false
}
}
}
}
pub fn create(
pub async fn create(
git_service: GitService,
worktree_path: PathBuf,
base_commit: Commit,
@@ -147,12 +153,9 @@ pub fn create(
}
}
let (tx, rx) = mpsc::unbounded_channel::<Result<LogMsg, io::Error>>();
if !send_initial_diffs(&tx, initial_diffs) {
return Ok(DiffStreamHandle::new(
UnboundedReceiverStream::new(rx).boxed(),
None,
));
let (tx, rx) = mpsc::channel::<Result<LogMsg, io::Error>>(DIFF_STREAM_CHANNEL_CAPACITY);
if !send_initial_diffs(&tx, initial_diffs).await {
return Ok(DiffStreamHandle::new(ReceiverStream::new(rx).boxed(), None));
}
let tx_clone = tx.clone();
@@ -177,7 +180,7 @@ pub fn create(
Ok(Ok(parts)) => parts,
Ok(Err(e)) => {
tracing::error!("Failed to set up filesystem watcher: {e}");
send_error(&ctx.tx, e.to_string());
send_error(&ctx.tx, e.to_string()).await;
return;
}
Err(join_err) => {
@@ -185,7 +188,8 @@ pub fn create(
send_error(
&ctx.tx,
format!("Failed to spawn watcher setup: {join_err}"),
);
)
.await;
return;
}
};
@@ -195,7 +199,7 @@ pub fn create(
while let Some(result) = watcher_rx.next().await {
match result {
Ok(events) => {
if !ctx.handle_events(events, &canonical_worktree_path) {
if !ctx.handle_events(events, &canonical_worktree_path).await {
return;
}
}
@@ -206,7 +210,7 @@ pub fn create(
.collect::<Vec<_>>()
.join("; ");
tracing::error!("Filesystem watcher error: {message}");
send_error(&ctx.tx, message);
send_error(&ctx.tx, message).await;
return;
}
}
@@ -216,39 +220,39 @@ pub fn create(
drop(tx);
Ok(DiffStreamHandle::new(
UnboundedReceiverStream::new(rx).boxed(),
ReceiverStream::new(rx).boxed(),
Some(watcher_task),
))
}
fn send_initial_diffs(
tx: &mpsc::UnboundedSender<Result<LogMsg, io::Error>>,
async fn send_initial_diffs(
tx: &mpsc::Sender<Result<LogMsg, io::Error>>,
diffs: Vec<Diff>,
) -> bool {
for diff in diffs {
let entry_index = GitService::diff_path(&diff);
let patch = ConversationPatch::add_diff(escape_json_pointer_segment(&entry_index), diff);
if tx.send(Ok(LogMsg::JsonPatch(patch))).is_err() {
if tx.send(Ok(LogMsg::JsonPatch(patch))).await.is_err() {
return false;
}
}
true
}
fn send_messages(
tx: &mpsc::UnboundedSender<Result<LogMsg, io::Error>>,
async fn send_messages(
tx: &mpsc::Sender<Result<LogMsg, io::Error>>,
messages: Vec<LogMsg>,
) -> bool {
for msg in messages {
if tx.send(Ok(msg)).is_err() {
if tx.send(Ok(msg)).await.is_err() {
return false;
}
}
true
}
fn send_error(tx: &mpsc::UnboundedSender<Result<LogMsg, io::Error>>, message: String) {
let _ = tx.send(Err(io::Error::other(message)));
async fn send_error(tx: &mpsc::Sender<Result<LogMsg, io::Error>>, message: String) {
let _ = tx.send(Err(io::Error::other(message))).await;
}
pub fn apply_stream_omit_policy(diff: &mut Diff, sent_bytes: &Arc<AtomicUsize>, stats_only: bool) {