Add agent interrupts (#1408)

* Add interrupt sender to gracefully stop claude code

* Remove debug logs

* Lint

* interrupt agent in read loop

* rm comments

* Revert claude client arch change
This commit is contained in:
Alex Netsch
2025-12-04 15:36:34 +00:00
committed by GitHub
parent ef1ba1b4bb
commit 9f4fabc285
7 changed files with 149 additions and 48 deletions

View File

@@ -82,6 +82,7 @@ impl AcpAgentHarness {
Ok(SpawnedChild {
child,
exit_signal: Some(exit_rx),
interrupt_sender: None,
})
}
@@ -119,6 +120,7 @@ impl AcpAgentHarness {
Ok(SpawnedChild {
child,
exit_signal: Some(exit_rx),
interrupt_sender: None,
})
}

View File

@@ -252,13 +252,17 @@ impl ClaudeCode {
let permission_mode = self.permission_mode();
let hooks = self.get_hooks();
// Create interrupt channel for graceful shutdown
let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel::<()>();
// Spawn task to handle the SDK client with control protocol
let prompt_clone = combined_prompt.clone();
let approvals_clone = self.approvals_service.clone();
tokio::spawn(async move {
let log_writer = LogWriter::new(new_stdout);
let client = ClaudeAgentClient::new(log_writer.clone(), approvals_clone);
let protocol_peer = ProtocolPeer::spawn(child_stdin, child_stdout, client.clone());
let protocol_peer =
ProtocolPeer::spawn(child_stdin, child_stdout, client.clone(), interrupt_rx);
// Initialize control protocol
if let Err(e) = protocol_peer.initialize(hooks).await {
@@ -285,6 +289,7 @@ impl ClaudeCode {
Ok(SpawnedChild {
child,
exit_signal: None,
interrupt_sender: Some(interrupt_tx),
})
}
}

View File

@@ -1,20 +1,18 @@
use std::sync::Arc;
use futures::FutureExt;
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
process::{ChildStdin, ChildStdout},
sync::Mutex,
sync::{Mutex, oneshot},
};
use super::types::{
CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType,
SDKControlRequestMessage,
};
use super::types::{CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType};
use crate::executors::{
ExecutorError,
claude::{
client::ClaudeAgentClient,
types::{PermissionMode, SDKControlRequestType},
types::{Message, PermissionMode, SDKControlRequest, SDKControlRequestType},
},
};
@@ -25,14 +23,19 @@ pub struct ProtocolPeer {
}
impl ProtocolPeer {
pub fn spawn(stdin: ChildStdin, stdout: ChildStdout, client: Arc<ClaudeAgentClient>) -> Self {
pub fn spawn(
stdin: ChildStdin,
stdout: ChildStdout,
client: Arc<ClaudeAgentClient>,
interrupt_rx: oneshot::Receiver<()>,
) -> Self {
let peer = Self {
stdin: Arc::new(Mutex::new(stdin)),
};
let reader_peer = peer.clone();
tokio::spawn(async move {
if let Err(e) = reader_peer.read_loop(stdout, client).await {
if let Err(e) = reader_peer.read_loop(stdout, client, interrupt_rx).await {
tracing::error!("Protocol reader loop error: {}", e);
}
});
@@ -44,41 +47,53 @@ impl ProtocolPeer {
&self,
stdout: ChildStdout,
client: Arc<ClaudeAgentClient>,
interrupt_rx: oneshot::Receiver<()>,
) -> Result<(), ExecutorError> {
let mut reader = BufReader::new(stdout);
let mut buffer = String::new();
// Fuse the receiver so it returns Pending forever after completing
let mut interrupt_rx = interrupt_rx.fuse();
loop {
buffer.clear();
match reader.read_line(&mut buffer).await {
Ok(0) => break, // EOF
Ok(_) => {
let line = buffer.trim();
if line.is_empty() {
continue;
}
// Parse message using typed enum
match serde_json::from_str::<CLIMessage>(line) {
Ok(CLIMessage::ControlRequest {
request_id,
request,
}) => {
self.handle_control_request(&client, request_id, request)
.await;
tokio::select! {
line_result = reader.read_line(&mut buffer) => {
match line_result {
Ok(0) => break, // EOF
Ok(_) => {
let line = buffer.trim();
if line.is_empty() {
continue;
}
// Parse message using typed enum
match serde_json::from_str::<CLIMessage>(line) {
Ok(CLIMessage::ControlRequest {
request_id,
request,
}) => {
self.handle_control_request(&client, request_id, request)
.await;
}
Ok(CLIMessage::ControlResponse { .. }) => {}
Ok(CLIMessage::Result(_)) => {
client.on_non_control(line).await?;
break;
}
_ => {
client.on_non_control(line).await?;
}
}
}
Ok(CLIMessage::ControlResponse { .. }) => {}
Ok(CLIMessage::Result(_)) => {
client.on_non_control(line).await?;
Err(e) => {
tracing::error!("Error reading stdout: {}", e);
break;
}
_ => {
client.on_non_control(line).await?;
}
}
}
Err(e) => {
tracing::error!("Error reading stdout: {}", e);
break;
_ = &mut interrupt_rx => {
if let Err(e) = self.interrupt().await {
tracing::debug!("Failed to send interrupt to Claude: {e}");
}
}
}
}
@@ -164,7 +179,6 @@ impl ProtocolPeer {
.await
}
/// Send JSON message to stdin
async fn send_json<T: serde::Serialize>(&self, message: &T) -> Result<(), ExecutorError> {
let json = serde_json::to_string(message)?;
let mut stdin = self.stdin.lock().await;
@@ -175,25 +189,23 @@ impl ProtocolPeer {
}
/// Sends a user-role message to the agent.
///
/// `content` is wrapped with `Message::new_user`, which serializes as
/// `{"type":"user","message":{"role":"user","content":<content>}}`, and the
/// result is written out through `send_json`.
///
/// # Errors
///
/// Returns whatever error `send_json` reports.
pub async fn send_user_message(&self, content: String) -> Result<(), ExecutorError> {
    // Build the payload once, through the typed constructor, so the wire
    // shape is defined in a single place (the `Message` type).
    let message = Message::new_user(content);
    self.send_json(&message).await
}
/// Sends the `Initialize` control request, passing `hooks` along as given.
///
/// # Errors
///
/// Returns whatever error `send_json` reports.
pub async fn initialize(&self, hooks: Option<serde_json::Value>) -> Result<(), ExecutorError> {
    // Exactly one request is built and sent per call.
    let request = SDKControlRequest::new(SDKControlRequestType::Initialize { hooks });
    self.send_json(&request).await
}
/// Sends an `Interrupt` control request through `send_json`.
///
/// # Errors
///
/// Returns whatever error `send_json` reports.
pub async fn interrupt(&self) -> Result<(), ExecutorError> {
    let request = SDKControlRequest::new(SDKControlRequestType::Interrupt {});
    self.send_json(&request).await
}
pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<(), ExecutorError> {
self.send_json(&SDKControlRequestMessage::new(
self.send_json(&SDKControlRequest::new(
SDKControlRequestType::SetPermissionMode { mode },
))
.await

View File

@@ -23,14 +23,14 @@ pub enum CLIMessage {
/// Control request from SDK to CLI (outgoing)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SDKControlRequestMessage {
pub struct SDKControlRequest {
#[serde(rename = "type")]
message_type: String, // Always "control_request"
pub request_id: String,
pub request: SDKControlRequestType,
}
impl SDKControlRequestMessage {
impl SDKControlRequest {
pub fn new(request: SDKControlRequestType) -> Self {
use uuid::Uuid;
Self {
@@ -141,6 +141,29 @@ pub enum ControlResponseType {
},
}
/// Typed message payload.
///
/// Serde tags each variant with a `type` field holding the snake_case variant
/// name, so `User` is written as `{"type":"user","message":{...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Message {
    /// A message attributed to the user; its body is nested under `message`.
    User { message: ClaudeUserMessage },
}

/// Body of a [`Message::User`]: a role label plus the text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeUserMessage {
    role: String,
    content: String,
}

impl Message {
    /// Builds a `User` message with the role fixed to "user" and the given
    /// `content` as its text.
    pub fn new_user(content: String) -> Self {
        let role = "user".to_string();
        let message = ClaudeUserMessage { role, content };
        Self::User { message }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
pub enum SDKControlRequestType {
@@ -151,6 +174,7 @@ pub enum SDKControlRequestType {
#[serde(skip_serializing_if = "Option::is_none")]
hooks: Option<Value>,
},
Interrupt {},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]

View File

@@ -376,6 +376,7 @@ impl Codex {
Ok(SpawnedChild {
child,
exit_signal: Some(exit_signal_rx),
interrupt_sender: None,
})
}

View File

@@ -236,10 +236,17 @@ pub enum ExecutorExitResult {
/// and mark it according to the result.
pub type ExecutorExitSignal = tokio::sync::oneshot::Receiver<ExecutorExitResult>;
/// Sender for requesting graceful interrupt of an executor.
/// When sent, the executor should attempt to interrupt gracefully before being killed.
pub type InterruptSender = tokio::sync::oneshot::Sender<()>;
/// A spawned executor process together with its optional signalling channels.
///
/// Either channel may be `None` when the executor does not provide it.
#[derive(Debug)]
pub struct SpawnedChild {
pub child: AsyncGroupChild,
/// Executor → Container: signals when executor wants to exit
pub exit_signal: Option<ExecutorExitSignal>,
/// Container → Executor: signals when container wants to interrupt
pub interrupt_sender: Option<InterruptSender>,
}
impl From<AsyncGroupChild> for SpawnedChild {
@@ -247,6 +254,7 @@ impl From<AsyncGroupChild> for SpawnedChild {
Self {
child,
exit_signal: None,
interrupt_sender: None,
}
}
}

View File

@@ -31,7 +31,7 @@ use executors::{
coding_agent_initial::CodingAgentInitialRequest,
},
approvals::{ExecutorApprovalService, NoopExecutorApprovalService},
executors::{BaseCodingAgent, ExecutorExitResult, ExecutorExitSignal},
executors::{BaseCodingAgent, ExecutorExitResult, ExecutorExitSignal, InterruptSender},
logs::{
NormalizedEntryType,
utils::{
@@ -70,6 +70,7 @@ use crate::{command, copy};
pub struct LocalContainerService {
db: DBService,
child_store: Arc<RwLock<HashMap<Uuid, Arc<RwLock<AsyncGroupChild>>>>>,
interrupt_senders: Arc<RwLock<HashMap<Uuid, InterruptSender>>>,
msg_stores: Arc<RwLock<HashMap<Uuid, Arc<MsgStore>>>>,
config: Arc<RwLock<Config>>,
git: GitService,
@@ -94,10 +95,12 @@ impl LocalContainerService {
publisher: Result<SharePublisher, RemoteClientNotConfigured>,
) -> Self {
let child_store = Arc::new(RwLock::new(HashMap::new()));
let interrupt_senders = Arc::new(RwLock::new(HashMap::new()));
let container = LocalContainerService {
db,
child_store,
interrupt_senders,
msg_stores,
config,
git,
@@ -128,6 +131,16 @@ impl LocalContainerService {
map.remove(id);
}
/// Stores `sender` under `id` in the interrupt-sender map, replacing any
/// entry already registered for that id.
async fn add_interrupt_sender(&self, id: Uuid, sender: InterruptSender) {
    // The write guard is a temporary and is released at the end of this statement.
    self.interrupt_senders.write().await.insert(id, sender);
}
/// Removes the interrupt sender registered under `id` and returns it, or
/// `None` when no sender is stored for that id.
async fn take_interrupt_sender(&self, id: &Uuid) -> Option<InterruptSender> {
    self.interrupt_senders.write().await.remove(id)
}
/// Defensively check for externally deleted worktrees and mark them as deleted in the database
async fn check_externally_deleted_worktrees(db: &DBService) -> Result<(), DeploymentError> {
let active_attempts = TaskAttempt::find_by_worktree_deleted(&db.pool).await?;
@@ -986,6 +999,12 @@ impl ContainerService for LocalContainerService {
self.add_child_to_store(execution_process.id, spawned.child)
.await;
// Store interrupt sender for graceful shutdown
if let Some(interrupt_sender) = spawned.interrupt_sender {
self.add_interrupt_sender(execution_process.id, interrupt_sender)
.await;
}
// Spawn unified exit monitor: watches OS exit and optional executor signal
let _hn = self.spawn_exit_monitor(&execution_process.id, spawned.exit_signal);
@@ -1012,6 +1031,36 @@ impl ContainerService for LocalContainerService {
ExecutionProcess::update_completion(&self.db.pool, execution_process.id, status, exit_code)
.await?;
// Try graceful interrupt first, then force kill
if let Some(interrupt_sender) = self.take_interrupt_sender(&execution_process.id).await {
// Send interrupt signal (ignore error if receiver dropped)
let _ = interrupt_sender.send(());
// Wait for graceful exit with timeout
let graceful_exit = {
let mut child_guard = child.write().await;
tokio::time::timeout(Duration::from_secs(5), child_guard.wait()).await
};
match graceful_exit {
Ok(Ok(_)) => {
tracing::debug!(
"Process {} exited gracefully after interrupt",
execution_process.id
);
}
Ok(Err(e)) => {
tracing::info!("Error waiting for process {}: {}", execution_process.id, e);
}
Err(_) => {
tracing::debug!(
"Graceful shutdown timed out for process {}, force killing",
execution_process.id
);
}
}
}
// Kill the child process and remove from the store
{
let mut child_guard = child.write().await;