diff --git a/crates/executors/src/executors/acp/harness.rs b/crates/executors/src/executors/acp/harness.rs index d0ad87a0..b318e21f 100644 --- a/crates/executors/src/executors/acp/harness.rs +++ b/crates/executors/src/executors/acp/harness.rs @@ -82,6 +82,7 @@ impl AcpAgentHarness { Ok(SpawnedChild { child, exit_signal: Some(exit_rx), + interrupt_sender: None, }) } @@ -119,6 +120,7 @@ impl AcpAgentHarness { Ok(SpawnedChild { child, exit_signal: Some(exit_rx), + interrupt_sender: None, }) } diff --git a/crates/executors/src/executors/claude.rs b/crates/executors/src/executors/claude.rs index 9904bbd5..8f2d595f 100644 --- a/crates/executors/src/executors/claude.rs +++ b/crates/executors/src/executors/claude.rs @@ -252,13 +252,17 @@ impl ClaudeCode { let permission_mode = self.permission_mode(); let hooks = self.get_hooks(); + // Create interrupt channel for graceful shutdown + let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel::<()>(); + // Spawn task to handle the SDK client with control protocol let prompt_clone = combined_prompt.clone(); let approvals_clone = self.approvals_service.clone(); tokio::spawn(async move { let log_writer = LogWriter::new(new_stdout); let client = ClaudeAgentClient::new(log_writer.clone(), approvals_clone); - let protocol_peer = ProtocolPeer::spawn(child_stdin, child_stdout, client.clone()); + let protocol_peer = + ProtocolPeer::spawn(child_stdin, child_stdout, client.clone(), interrupt_rx); // Initialize control protocol if let Err(e) = protocol_peer.initialize(hooks).await { @@ -285,6 +289,7 @@ impl ClaudeCode { Ok(SpawnedChild { child, exit_signal: None, + interrupt_sender: Some(interrupt_tx), }) } } diff --git a/crates/executors/src/executors/claude/protocol.rs b/crates/executors/src/executors/claude/protocol.rs index 8789b7c5..d1bed326 100644 --- a/crates/executors/src/executors/claude/protocol.rs +++ b/crates/executors/src/executors/claude/protocol.rs @@ -1,20 +1,18 @@ use std::sync::Arc; +use futures::FutureExt; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, process::{ChildStdin, ChildStdout}, - sync::Mutex, + sync::{Mutex, oneshot}, }; -use super::types::{ - CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType, - SDKControlRequestMessage, -}; +use super::types::{CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType}; use crate::executors::{ ExecutorError, claude::{ client::ClaudeAgentClient, - types::{PermissionMode, SDKControlRequestType}, + types::{Message, PermissionMode, SDKControlRequest, SDKControlRequestType}, }, }; @@ -25,14 +23,19 @@ pub struct ProtocolPeer { } impl ProtocolPeer { - pub fn spawn(stdin: ChildStdin, stdout: ChildStdout, client: Arc) -> Self { + pub fn spawn( + stdin: ChildStdin, + stdout: ChildStdout, + client: Arc, + interrupt_rx: oneshot::Receiver<()>, + ) -> Self { let peer = Self { stdin: Arc::new(Mutex::new(stdin)), }; let reader_peer = peer.clone(); tokio::spawn(async move { - if let Err(e) = reader_peer.read_loop(stdout, client).await { + if let Err(e) = reader_peer.read_loop(stdout, client, interrupt_rx).await { tracing::error!("Protocol reader loop error: {}", e); } }); @@ -44,41 +47,53 @@ impl ProtocolPeer { &self, stdout: ChildStdout, client: Arc, + interrupt_rx: oneshot::Receiver<()>, ) -> Result<(), ExecutorError> { let mut reader = BufReader::new(stdout); let mut buffer = String::new(); + // Fuse the receiver so it returns Pending forever after completing + let mut interrupt_rx = interrupt_rx.fuse(); loop { buffer.clear(); - match reader.read_line(&mut buffer).await { - Ok(0) => break, // EOF - Ok(_) => { - let line = buffer.trim(); - if line.is_empty() { - continue; - } - // Parse message using typed enum - match serde_json::from_str::(line) { - Ok(CLIMessage::ControlRequest { - request_id, - request, - }) => { - self.handle_control_request(&client, request_id, request) - .await; + tokio::select! { + line_result = reader.read_line(&mut buffer) => { + match line_result { + Ok(0) => break, // EOF + Ok(_) => { + let line = buffer.trim(); + if line.is_empty() { + continue; + } + // Parse message using typed enum + match serde_json::from_str::(line) { + Ok(CLIMessage::ControlRequest { + request_id, + request, + }) => { + self.handle_control_request(&client, request_id, request) + .await; + } + Ok(CLIMessage::ControlResponse { .. }) => {} + Ok(CLIMessage::Result(_)) => { + client.on_non_control(line).await?; + break; + } + _ => { + client.on_non_control(line).await?; + } + } } - Ok(CLIMessage::ControlResponse { .. }) => {} - Ok(CLIMessage::Result(_)) => { - client.on_non_control(line).await?; + Err(e) => { + tracing::error!("Error reading stdout: {}", e); break; } - _ => { - client.on_non_control(line).await?; - } } } - Err(e) => { - tracing::error!("Error reading stdout: {}", e); - break; + _ = &mut interrupt_rx => { + if let Err(e) = self.interrupt().await { + tracing::debug!("Failed to send interrupt to Claude: {e}"); + } } } } @@ -164,7 +179,6 @@ impl ProtocolPeer { .await } - /// Send JSON message to stdin async fn send_json(&self, message: &T) -> Result<(), ExecutorError> { let json = serde_json::to_string(message)?; let mut stdin = self.stdin.lock().await; @@ -175,25 +189,23 @@ impl ProtocolPeer { } pub async fn send_user_message(&self, content: String) -> Result<(), ExecutorError> { - let message = serde_json::json!({ - "type": "user", - "message": { - "role": "user", - "content": content - } - }); + let message = Message::new_user(content); self.send_json(&message).await } pub async fn initialize(&self, hooks: Option) -> Result<(), ExecutorError> { - self.send_json(&SDKControlRequestMessage::new( - SDKControlRequestType::Initialize { hooks }, - )) + self.send_json(&SDKControlRequest::new(SDKControlRequestType::Initialize { + hooks, + })) .await } + pub async fn interrupt(&self) -> Result<(), ExecutorError> { + self.send_json(&SDKControlRequest::new(SDKControlRequestType::Interrupt {})) + .await + } pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<(), ExecutorError> { - self.send_json(&SDKControlRequestMessage::new( + self.send_json(&SDKControlRequest::new( SDKControlRequestType::SetPermissionMode { mode }, )) .await diff --git a/crates/executors/src/executors/claude/types.rs b/crates/executors/src/executors/claude/types.rs index 1e5802bf..2e76f8ed 100644 --- a/crates/executors/src/executors/claude/types.rs +++ b/crates/executors/src/executors/claude/types.rs @@ -23,14 +23,14 @@ pub enum CLIMessage { /// Control request from SDK to CLI (outgoing) #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SDKControlRequestMessage { +pub struct SDKControlRequest { #[serde(rename = "type")] message_type: String, // Always "control_request" pub request_id: String, pub request: SDKControlRequestType, } -impl SDKControlRequestMessage { +impl SDKControlRequest { pub fn new(request: SDKControlRequestType) -> Self { use uuid::Uuid; Self { @@ -141,6 +141,29 @@ pub enum ControlResponseType { }, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Message { + User { message: ClaudeUserMessage }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClaudeUserMessage { + role: String, + content: String, +} + +impl Message { + pub fn new_user(content: String) -> Self { + Self::User { + message: ClaudeUserMessage { + role: "user".to_string(), + content, + }, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "subtype", rename_all = "snake_case")] pub enum SDKControlRequestType { @@ -151,6 +174,7 @@ pub enum SDKControlRequestType { #[serde(skip_serializing_if = "Option::is_none")] hooks: Option, }, + Interrupt {}, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] diff --git a/crates/executors/src/executors/codex.rs b/crates/executors/src/executors/codex.rs index 8bf8ce16..e9f6d9e3 100644 --- a/crates/executors/src/executors/codex.rs +++ b/crates/executors/src/executors/codex.rs @@ -376,6 +376,7 @@ impl Codex { Ok(SpawnedChild { child, exit_signal: Some(exit_signal_rx), + interrupt_sender: None, }) } diff --git a/crates/executors/src/executors/mod.rs b/crates/executors/src/executors/mod.rs index a13d5564..8c4fab2c 100644 --- a/crates/executors/src/executors/mod.rs +++ b/crates/executors/src/executors/mod.rs @@ -236,10 +236,17 @@ pub enum ExecutorExitResult { /// and mark it according to the result. pub type ExecutorExitSignal = tokio::sync::oneshot::Receiver; +/// Sender for requesting graceful interrupt of an executor. +/// When sent, the executor should attempt to interrupt gracefully before being killed. +pub type InterruptSender = tokio::sync::oneshot::Sender<()>; + #[derive(Debug)] pub struct SpawnedChild { pub child: AsyncGroupChild, + /// Executor → Container: signals when executor wants to exit pub exit_signal: Option, + /// Container → Executor: signals when container wants to interrupt + pub interrupt_sender: Option, } impl From for SpawnedChild { @@ -247,6 +254,7 @@ impl From for SpawnedChild { Self { child, exit_signal: None, + interrupt_sender: None, } } } diff --git a/crates/local-deployment/src/container.rs b/crates/local-deployment/src/container.rs index 18a60ad0..a80db077 100644 --- a/crates/local-deployment/src/container.rs +++ b/crates/local-deployment/src/container.rs @@ -31,7 +31,7 @@ use executors::{ coding_agent_initial::CodingAgentInitialRequest, }, approvals::{ExecutorApprovalService, NoopExecutorApprovalService}, - executors::{BaseCodingAgent, ExecutorExitResult, ExecutorExitSignal}, + executors::{BaseCodingAgent, ExecutorExitResult, ExecutorExitSignal, InterruptSender}, logs::{ NormalizedEntryType, utils::{ @@ -70,6 +70,7 @@ use crate::{command, copy}; pub struct LocalContainerService { db: DBService, child_store: Arc>>>>, + interrupt_senders: Arc>>, msg_stores: Arc>>>, config: Arc>, git: GitService, @@ -94,10 +95,12 @@ impl LocalContainerService { publisher: Result, ) -> Self { let child_store = Arc::new(RwLock::new(HashMap::new())); + let interrupt_senders = Arc::new(RwLock::new(HashMap::new())); let container = LocalContainerService { db, child_store, + interrupt_senders, msg_stores, config, git, @@ -128,6 +131,16 @@ impl LocalContainerService { map.remove(id); } + async fn add_interrupt_sender(&self, id: Uuid, sender: InterruptSender) { + let mut map = self.interrupt_senders.write().await; + map.insert(id, sender); + } + + async fn take_interrupt_sender(&self, id: &Uuid) -> Option { + let mut map = self.interrupt_senders.write().await; + map.remove(id) + } + /// Defensively check for externally deleted worktrees and mark them as deleted in the database async fn check_externally_deleted_worktrees(db: &DBService) -> Result<(), DeploymentError> { let active_attempts = TaskAttempt::find_by_worktree_deleted(&db.pool).await?; @@ -986,6 +999,12 @@ impl ContainerService for LocalContainerService { self.add_child_to_store(execution_process.id, spawned.child) .await; + // Store interrupt sender for graceful shutdown + if let Some(interrupt_sender) = spawned.interrupt_sender { + self.add_interrupt_sender(execution_process.id, interrupt_sender) + .await; + } + // Spawn unified exit monitor: watches OS exit and optional executor signal let _hn = self.spawn_exit_monitor(&execution_process.id, spawned.exit_signal); @@ -1012,6 +1031,36 @@ impl ContainerService for LocalContainerService { ExecutionProcess::update_completion(&self.db.pool, execution_process.id, status, exit_code) .await?; + // Try graceful interrupt first, then force kill + if let Some(interrupt_sender) = self.take_interrupt_sender(&execution_process.id).await { + // Send interrupt signal (ignore error if receiver dropped) + let _ = interrupt_sender.send(()); + + // Wait for graceful exit with timeout + let graceful_exit = { + let mut child_guard = child.write().await; + tokio::time::timeout(Duration::from_secs(5), child_guard.wait()).await + }; + + match graceful_exit { + Ok(Ok(_)) => { + tracing::debug!( + "Process {} exited gracefully after interrupt", + execution_process.id + ); + } + Ok(Err(e)) => { + tracing::info!("Error waiting for process {}: {}", execution_process.id, e); + } + Err(_) => { + tracing::debug!( + "Graceful shutdown timed out for process {}, force killing", + execution_process.id + ); + } + } + } + // Kill the child process and remove from the store { let mut child_guard = child.write().await;