Add agent interrupts (#1408)
* Add interrupt sender to gracefully stop claude code * Remove debug logs * Lint * interrupt agent in read loop * rm comments * Revert claude client arch change
This commit is contained in:
@@ -82,6 +82,7 @@ impl AcpAgentHarness {
|
||||
Ok(SpawnedChild {
|
||||
child,
|
||||
exit_signal: Some(exit_rx),
|
||||
interrupt_sender: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -119,6 +120,7 @@ impl AcpAgentHarness {
|
||||
Ok(SpawnedChild {
|
||||
child,
|
||||
exit_signal: Some(exit_rx),
|
||||
interrupt_sender: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -252,13 +252,17 @@ impl ClaudeCode {
|
||||
let permission_mode = self.permission_mode();
|
||||
let hooks = self.get_hooks();
|
||||
|
||||
// Create interrupt channel for graceful shutdown
|
||||
let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
|
||||
// Spawn task to handle the SDK client with control protocol
|
||||
let prompt_clone = combined_prompt.clone();
|
||||
let approvals_clone = self.approvals_service.clone();
|
||||
tokio::spawn(async move {
|
||||
let log_writer = LogWriter::new(new_stdout);
|
||||
let client = ClaudeAgentClient::new(log_writer.clone(), approvals_clone);
|
||||
let protocol_peer = ProtocolPeer::spawn(child_stdin, child_stdout, client.clone());
|
||||
let protocol_peer =
|
||||
ProtocolPeer::spawn(child_stdin, child_stdout, client.clone(), interrupt_rx);
|
||||
|
||||
// Initialize control protocol
|
||||
if let Err(e) = protocol_peer.initialize(hooks).await {
|
||||
@@ -285,6 +289,7 @@ impl ClaudeCode {
|
||||
Ok(SpawnedChild {
|
||||
child,
|
||||
exit_signal: None,
|
||||
interrupt_sender: Some(interrupt_tx),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||
process::{ChildStdin, ChildStdout},
|
||||
sync::Mutex,
|
||||
sync::{Mutex, oneshot},
|
||||
};
|
||||
|
||||
use super::types::{
|
||||
CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType,
|
||||
SDKControlRequestMessage,
|
||||
};
|
||||
use super::types::{CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType};
|
||||
use crate::executors::{
|
||||
ExecutorError,
|
||||
claude::{
|
||||
client::ClaudeAgentClient,
|
||||
types::{PermissionMode, SDKControlRequestType},
|
||||
types::{Message, PermissionMode, SDKControlRequest, SDKControlRequestType},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -25,14 +23,19 @@ pub struct ProtocolPeer {
|
||||
}
|
||||
|
||||
impl ProtocolPeer {
|
||||
pub fn spawn(stdin: ChildStdin, stdout: ChildStdout, client: Arc<ClaudeAgentClient>) -> Self {
|
||||
pub fn spawn(
|
||||
stdin: ChildStdin,
|
||||
stdout: ChildStdout,
|
||||
client: Arc<ClaudeAgentClient>,
|
||||
interrupt_rx: oneshot::Receiver<()>,
|
||||
) -> Self {
|
||||
let peer = Self {
|
||||
stdin: Arc::new(Mutex::new(stdin)),
|
||||
};
|
||||
|
||||
let reader_peer = peer.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = reader_peer.read_loop(stdout, client).await {
|
||||
if let Err(e) = reader_peer.read_loop(stdout, client, interrupt_rx).await {
|
||||
tracing::error!("Protocol reader loop error: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -44,41 +47,53 @@ impl ProtocolPeer {
|
||||
&self,
|
||||
stdout: ChildStdout,
|
||||
client: Arc<ClaudeAgentClient>,
|
||||
interrupt_rx: oneshot::Receiver<()>,
|
||||
) -> Result<(), ExecutorError> {
|
||||
let mut reader = BufReader::new(stdout);
|
||||
let mut buffer = String::new();
|
||||
// Fuse the receiver so that, once it has resolved, it stays Pending and later loop iterations keep reading stdout
|
||||
let mut interrupt_rx = interrupt_rx.fuse();
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
match reader.read_line(&mut buffer).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
let line = buffer.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// Parse message using typed enum
|
||||
match serde_json::from_str::<CLIMessage>(line) {
|
||||
Ok(CLIMessage::ControlRequest {
|
||||
request_id,
|
||||
request,
|
||||
}) => {
|
||||
self.handle_control_request(&client, request_id, request)
|
||||
.await;
|
||||
tokio::select! {
|
||||
line_result = reader.read_line(&mut buffer) => {
|
||||
match line_result {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
let line = buffer.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// Parse message using typed enum
|
||||
match serde_json::from_str::<CLIMessage>(line) {
|
||||
Ok(CLIMessage::ControlRequest {
|
||||
request_id,
|
||||
request,
|
||||
}) => {
|
||||
self.handle_control_request(&client, request_id, request)
|
||||
.await;
|
||||
}
|
||||
Ok(CLIMessage::ControlResponse { .. }) => {}
|
||||
Ok(CLIMessage::Result(_)) => {
|
||||
client.on_non_control(line).await?;
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
client.on_non_control(line).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(CLIMessage::ControlResponse { .. }) => {}
|
||||
Ok(CLIMessage::Result(_)) => {
|
||||
client.on_non_control(line).await?;
|
||||
Err(e) => {
|
||||
tracing::error!("Error reading stdout: {}", e);
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
client.on_non_control(line).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error reading stdout: {}", e);
|
||||
break;
|
||||
_ = &mut interrupt_rx => {
|
||||
if let Err(e) = self.interrupt().await {
|
||||
tracing::debug!("Failed to send interrupt to Claude: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -164,7 +179,6 @@ impl ProtocolPeer {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send JSON message to stdin
|
||||
async fn send_json<T: serde::Serialize>(&self, message: &T) -> Result<(), ExecutorError> {
|
||||
let json = serde_json::to_string(message)?;
|
||||
let mut stdin = self.stdin.lock().await;
|
||||
@@ -175,25 +189,23 @@ impl ProtocolPeer {
|
||||
}
|
||||
|
||||
pub async fn send_user_message(&self, content: String) -> Result<(), ExecutorError> {
|
||||
let message = serde_json::json!({
|
||||
"type": "user",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
});
|
||||
let message = Message::new_user(content);
|
||||
self.send_json(&message).await
|
||||
}
|
||||
|
||||
pub async fn initialize(&self, hooks: Option<serde_json::Value>) -> Result<(), ExecutorError> {
|
||||
self.send_json(&SDKControlRequestMessage::new(
|
||||
SDKControlRequestType::Initialize { hooks },
|
||||
))
|
||||
self.send_json(&SDKControlRequest::new(SDKControlRequestType::Initialize {
|
||||
hooks,
|
||||
}))
|
||||
.await
|
||||
}
|
||||
/// Sends an `Interrupt` control request to the CLI through `send_json`.
///
/// The result of `send_json` is returned unchanged; this method does not
/// wait for or inspect any reply.
pub async fn interrupt(&self) -> Result<(), ExecutorError> {
|
||||
self.send_json(&SDKControlRequest::new(SDKControlRequestType::Interrupt {}))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<(), ExecutorError> {
|
||||
self.send_json(&SDKControlRequestMessage::new(
|
||||
self.send_json(&SDKControlRequest::new(
|
||||
SDKControlRequestType::SetPermissionMode { mode },
|
||||
))
|
||||
.await
|
||||
|
||||
@@ -23,14 +23,14 @@ pub enum CLIMessage {
|
||||
|
||||
/// Control request from SDK to CLI (outgoing)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SDKControlRequestMessage {
|
||||
pub struct SDKControlRequest {
|
||||
#[serde(rename = "type")]
|
||||
message_type: String, // Always "control_request"
|
||||
pub request_id: String,
|
||||
pub request: SDKControlRequestType,
|
||||
}
|
||||
|
||||
impl SDKControlRequestMessage {
|
||||
impl SDKControlRequest {
|
||||
pub fn new(request: SDKControlRequestType) -> Self {
|
||||
use uuid::Uuid;
|
||||
Self {
|
||||
@@ -141,6 +141,29 @@ pub enum ControlResponseType {
|
||||
},
|
||||
}
|
||||
|
||||
/// Message envelope that serde tags with a `type` field.
///
/// Variant names are converted to snake_case, so `User` is written as
/// `"type": "user"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum Message {
|
||||
// A user-authored message; the payload lives under the `message` key.
User { message: ClaudeUserMessage },
|
||||
}
|
||||
|
||||
/// Payload carried by `Message::User`: a role string and the message text.
///
/// Both fields are private; `Message::new_user` builds values of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeUserMessage {
|
||||
// Set to "user" by `Message::new_user`.
role: String,
|
||||
// The message text.
content: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// Builds a `Message::User` whose payload has role `"user"` and the given
/// `content`.
pub fn new_user(content: String) -> Self {
|
||||
Self::User {
|
||||
message: ClaudeUserMessage {
|
||||
role: "user".to_string(),
|
||||
content,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "subtype", rename_all = "snake_case")]
|
||||
pub enum SDKControlRequestType {
|
||||
@@ -151,6 +174,7 @@ pub enum SDKControlRequestType {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
hooks: Option<Value>,
|
||||
},
|
||||
Interrupt {},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
||||
@@ -376,6 +376,7 @@ impl Codex {
|
||||
Ok(SpawnedChild {
|
||||
child,
|
||||
exit_signal: Some(exit_signal_rx),
|
||||
interrupt_sender: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -236,10 +236,17 @@ pub enum ExecutorExitResult {
|
||||
/// and mark it according to the result.
|
||||
pub type ExecutorExitSignal = tokio::sync::oneshot::Receiver<ExecutorExitResult>;
|
||||
|
||||
/// Sender for requesting graceful interrupt of an executor.
|
||||
/// When sent, the executor should attempt to interrupt gracefully before being killed.
|
||||
pub type InterruptSender = tokio::sync::oneshot::Sender<()>;
|
||||
|
||||
/// A spawned executor child together with its optional signalling channels.
///
/// Both channels are optional, so an executor may supply either, both or
/// neither.
#[derive(Debug)]
|
||||
pub struct SpawnedChild {
|
||||
/// Handle to the spawned child.
pub child: AsyncGroupChild,
|
||||
/// Executor → Container: signals when executor wants to exit
pub exit_signal: Option<ExecutorExitSignal>,
|
||||
/// Container → Executor: signals when container wants to interrupt
pub interrupt_sender: Option<InterruptSender>,
|
||||
}
|
||||
|
||||
impl From<AsyncGroupChild> for SpawnedChild {
|
||||
@@ -247,6 +254,7 @@ impl From<AsyncGroupChild> for SpawnedChild {
|
||||
Self {
|
||||
child,
|
||||
exit_signal: None,
|
||||
interrupt_sender: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ use executors::{
|
||||
coding_agent_initial::CodingAgentInitialRequest,
|
||||
},
|
||||
approvals::{ExecutorApprovalService, NoopExecutorApprovalService},
|
||||
executors::{BaseCodingAgent, ExecutorExitResult, ExecutorExitSignal},
|
||||
executors::{BaseCodingAgent, ExecutorExitResult, ExecutorExitSignal, InterruptSender},
|
||||
logs::{
|
||||
NormalizedEntryType,
|
||||
utils::{
|
||||
@@ -70,6 +70,7 @@ use crate::{command, copy};
|
||||
pub struct LocalContainerService {
|
||||
db: DBService,
|
||||
child_store: Arc<RwLock<HashMap<Uuid, Arc<RwLock<AsyncGroupChild>>>>>,
|
||||
interrupt_senders: Arc<RwLock<HashMap<Uuid, InterruptSender>>>,
|
||||
msg_stores: Arc<RwLock<HashMap<Uuid, Arc<MsgStore>>>>,
|
||||
config: Arc<RwLock<Config>>,
|
||||
git: GitService,
|
||||
@@ -94,10 +95,12 @@ impl LocalContainerService {
|
||||
publisher: Result<SharePublisher, RemoteClientNotConfigured>,
|
||||
) -> Self {
|
||||
let child_store = Arc::new(RwLock::new(HashMap::new()));
|
||||
let interrupt_senders = Arc::new(RwLock::new(HashMap::new()));
|
||||
|
||||
let container = LocalContainerService {
|
||||
db,
|
||||
child_store,
|
||||
interrupt_senders,
|
||||
msg_stores,
|
||||
config,
|
||||
git,
|
||||
@@ -128,6 +131,16 @@ impl LocalContainerService {
|
||||
map.remove(id);
|
||||
}
|
||||
|
||||
/// Stores `sender` in `interrupt_senders` under `id`.
///
/// Takes the map's write lock for the duration of the insert. A sender
/// already stored under the same `id` is replaced and dropped.
async fn add_interrupt_sender(&self, id: Uuid, sender: InterruptSender) {
|
||||
let mut map = self.interrupt_senders.write().await;
|
||||
map.insert(id, sender);
|
||||
}
|
||||
|
||||
/// Removes and returns the interrupt sender stored under `id`, if any.
///
/// The entry is removed from the map, so a second call with the same `id`
/// returns `None` unless a new sender has been added in between.
async fn take_interrupt_sender(&self, id: &Uuid) -> Option<InterruptSender> {
|
||||
let mut map = self.interrupt_senders.write().await;
|
||||
map.remove(id)
|
||||
}
|
||||
|
||||
/// Defensively check for externally deleted worktrees and mark them as deleted in the database
|
||||
async fn check_externally_deleted_worktrees(db: &DBService) -> Result<(), DeploymentError> {
|
||||
let active_attempts = TaskAttempt::find_by_worktree_deleted(&db.pool).await?;
|
||||
@@ -986,6 +999,12 @@ impl ContainerService for LocalContainerService {
|
||||
self.add_child_to_store(execution_process.id, spawned.child)
|
||||
.await;
|
||||
|
||||
// Store interrupt sender for graceful shutdown
|
||||
if let Some(interrupt_sender) = spawned.interrupt_sender {
|
||||
self.add_interrupt_sender(execution_process.id, interrupt_sender)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Spawn unified exit monitor: watches OS exit and optional executor signal
|
||||
let _hn = self.spawn_exit_monitor(&execution_process.id, spawned.exit_signal);
|
||||
|
||||
@@ -1012,6 +1031,36 @@ impl ContainerService for LocalContainerService {
|
||||
ExecutionProcess::update_completion(&self.db.pool, execution_process.id, status, exit_code)
|
||||
.await?;
|
||||
|
||||
// Try graceful interrupt first, then force kill
|
||||
if let Some(interrupt_sender) = self.take_interrupt_sender(&execution_process.id).await {
|
||||
// Send interrupt signal (ignore error if receiver dropped)
|
||||
let _ = interrupt_sender.send(());
|
||||
|
||||
// Wait for graceful exit with timeout
|
||||
let graceful_exit = {
|
||||
let mut child_guard = child.write().await;
|
||||
tokio::time::timeout(Duration::from_secs(5), child_guard.wait()).await
|
||||
};
|
||||
|
||||
match graceful_exit {
|
||||
Ok(Ok(_)) => {
|
||||
tracing::debug!(
|
||||
"Process {} exited gracefully after interrupt",
|
||||
execution_process.id
|
||||
);
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
tracing::info!("Error waiting for process {}: {}", execution_process.id, e);
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::debug!(
|
||||
"Graceful shutdown timed out for process {}, force killing",
|
||||
execution_process.id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Kill the child process and remove from the store
|
||||
{
|
||||
let mut child_guard = child.write().await;
|
||||
|
||||
Reference in New Issue
Block a user