From e06dd1f6dc94171afeeb2ad3b9cba44ac46f5eeb Mon Sep 17 00:00:00 2001 From: Alex Netsch Date: Tue, 28 Oct 2025 15:36:47 +0000 Subject: [PATCH] Claude approval refactor (#1080) * WIP claude approvals * Use canusetool * Remove old exitplanmode approvals * WIP approvals * types * Remove bloat * Cleanup, exit on finish * Approval messages, cleanup * Cleanup * Fix msg types * Lint fmt * Cleanup * Send deny * add missing timeout to hooks * FIx timeout issue * Cleanup * Error handling, log writer bugs * Remove deprecated approbal endpoints * Remove tool matching strategies in favour of only id based matching * remove register session, parse result at protocol level * Remove circular peer, remove unneeded trait * Types --- crates/executors/src/approvals.rs | 7 - crates/executors/src/executors/claude.rs | 449 ++++++++---------- .../executors/src/executors/claude/client.rs | 206 ++++++++ .../src/executors/claude/protocol.rs | 200 ++++++++ .../executors/src/executors/claude/types.rs | 178 +++++++ .../executors/src/executors/codex/client.rs | 6 - .../executors/src/executors/hooks/confirm.py | 179 ------- crates/local-deployment/src/container.rs | 12 +- crates/server/src/routes/approvals.rs | 80 +--- crates/services/src/services/approvals.rs | 269 ++--------- .../services/approvals/executor_approvals.rs | 20 +- crates/utils/src/approvals.rs | 18 +- shared/types.ts | 2 +- 13 files changed, 845 insertions(+), 781 deletions(-) create mode 100644 crates/executors/src/executors/claude/client.rs create mode 100644 crates/executors/src/executors/claude/protocol.rs create mode 100644 crates/executors/src/executors/claude/types.rs delete mode 100755 crates/executors/src/executors/hooks/confirm.py diff --git a/crates/executors/src/approvals.rs b/crates/executors/src/approvals.rs index 5f12fa64..2b2e1443 100644 --- a/crates/executors/src/approvals.rs +++ b/crates/executors/src/approvals.rs @@ -26,9 +26,6 @@ impl ExecutorApprovalError { /// Abstraction for executor approval backends. #[async_trait] pub trait ExecutorApprovalService: Send + Sync { - /// Registers the session identifier associated with subsequent approval requests. - async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError>; - /// Requests approval for a tool invocation and waits for the final decision. async fn request_tool_approval( &self, @@ -43,10 +40,6 @@ pub struct NoopExecutorApprovalService; #[async_trait] impl ExecutorApprovalService for NoopExecutorApprovalService { - async fn register_session(&self, _session_id: &str) -> Result<(), ExecutorApprovalError> { - Ok(()) - } - async fn request_tool_approval( &self, _tool_name: &str, diff --git a/crates/executors/src/executors/claude.rs b/crates/executors/src/executors/claude.rs index 770d0015..f8dadd90 100644 --- a/crates/executors/src/executors/claude.rs +++ b/crates/executors/src/executors/claude.rs @@ -1,5 +1,8 @@ -#[cfg(unix)] -use std::os::unix::fs::PermissionsExt; +// SDK submodules +pub mod client; +pub mod protocol; +pub mod types; + use std::{collections::HashMap, path::Path, process::Stdio, sync::Arc}; use async_trait::async_trait; @@ -7,42 +10,33 @@ use command_group::AsyncCommandGroup; use futures::StreamExt; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use tokio::{io::AsyncWriteExt, process::Command, sync::OnceCell}; +use tokio::process::Command; use ts_rs::TS; use workspace_utils::{ - approvals::APPROVAL_TIMEOUT_SECONDS, + approvals::ApprovalStatus, diff::{concatenate_diff_hunks, create_unified_diff, create_unified_diff_hunk}, log_msg::LogMsg, msg_store::MsgStore, path::make_path_relative, - port_file::read_port_file, shell::get_shell_command, }; +use self::{client::ClaudeAgentClient, protocol::ProtocolPeer, types::PermissionMode}; use crate::{ + approvals::ExecutorApprovalService, command::{CmdOverrides, CommandBuilder, apply_overrides}, - executors::{AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor}, + executors::{ + AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor, + codex::client::LogWriter, + }, logs::{ ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem, ToolStatus, stderr_processor::normalize_stderr_logs, utils::{EntryIndexProvider, patch::ConversationPatch}, }, + stdout_dup::create_stdout_pipe_writer, }; -static BACKEND_PORT: OnceCell = OnceCell::const_new(); -async fn get_backend_port() -> std::io::Result { - BACKEND_PORT - .get_or_try_init(|| async { read_port_file("vibe-kanban").await }) - .await - .copied() -} - -const CONFIRM_HOOK_SCRIPT: &str = include_str!("./hooks/confirm.py"); - -/// Natural language marker we add in our Python hook to denote user feedback -/// This marker is added by our confirm.py hook script and is robust to Claude Code format changes -const USER_FEEDBACK_MARKER: &str = "User feedback: "; - fn base_command(claude_code_router: bool) -> &'static str { if claude_code_router { "npx -y @musistudio/claude-code-router@1.0.58 code" @@ -51,7 +45,10 @@ fn base_command(claude_code_router: bool) -> &'static str { } } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema)] +use derivative::Derivative; + +#[derive(Derivative, Clone, Serialize, Deserialize, TS, JsonSchema)] +#[derivative(Debug, PartialEq)] pub struct ClaudeCode { #[serde(default)] pub append_prompt: AppendPrompt, @@ -67,6 +64,11 @@ pub struct ClaudeCode { pub dangerously_skip_permissions: Option, #[serde(flatten)] pub cmd: CmdOverrides, + + #[serde(skip)] + #[ts(skip)] + #[derivative(Debug = "ignore", PartialEq = "ignore")] + approvals_service: Option>, } impl ClaudeCode { @@ -87,30 +89,14 @@ impl ClaudeCode { if plan && approvals { tracing::warn!("Both plan and approvals are enabled. Plan will take precedence."); } - - if plan { - builder = builder.extend_params(["--permission-mode=plan"]); - } - if plan || approvals { - match settings_json(plan).await { - // TODO: Avoid quoting - Ok(settings) => match shlex::try_quote(&settings) { - Ok(quoted) => { - builder = builder.extend_params(["--settings", "ed]); - } - Err(e) => { - tracing::error!("Failed to quote approvals JSON for --settings: {e}"); - } - }, - Err(e) => { - tracing::error!( - "Failed to generate approvals JSON. Not running approvals: {e}" - ); - } - } + // Enable bypass at startup, otherwise we cannot change to it after exiting plan mode + builder = builder.extend_params(["--permission-prompt-tool=stdio"]); + builder = builder.extend_params([format!( + "--permission-mode={}", + PermissionMode::BypassPermissions + )]); } - if self.dangerously_skip_permissions.unwrap_or(false) { builder = builder.extend_params(["--dangerously-skip-permissions"]); } @@ -120,49 +106,58 @@ impl ClaudeCode { builder = builder.extend_params([ "--verbose", "--output-format=stream-json", + "--input-format=stream-json", "--include-partial-messages", ]); apply_overrides(builder, &self.cmd) } + + pub fn permission_mode(&self) -> PermissionMode { + if self.plan.unwrap_or(false) { + PermissionMode::Plan + } else if self.approvals.unwrap_or(false) { + PermissionMode::Default + } else { + PermissionMode::BypassPermissions + } + } + + pub fn get_hooks(&self) -> Option { + if self.plan.unwrap_or(false) { + Some(serde_json::json!({ + "PreToolUse": [ + { + "matcher": "^ExitPlanMode$", + "hookCallbackIds": ["tool_approval"], + } + ] + })) + } else if self.approvals.unwrap_or(false) { + Some(serde_json::json!({ + "PreToolUse": [ + { + "matcher": "^(?!(Glob|Grep|NotebookRead|Read|Task|TodoWrite)$).*", + "hookCallbackIds": ["tool_approval"], + } + ] + })) + } else { + None + } + } } #[async_trait] impl StandardCodingAgentExecutor for ClaudeCode { + fn use_approvals(&mut self, approvals: Arc) { + self.approvals_service = Some(approvals); + } + async fn spawn(&self, current_dir: &Path, prompt: &str) -> Result { - let (shell_cmd, shell_arg) = get_shell_command(); let command_builder = self.build_command_builder().await; - let mut base_command = command_builder.build_initial(); - - if self.plan.unwrap_or(false) { - base_command = create_watchkill_script(&base_command); - } - - if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) { - write_python_hook(current_dir).await? - } - - let combined_prompt = self.append_prompt.combine_prompt(prompt); - - let mut command = Command::new(shell_cmd); - command - .kill_on_drop(true) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .current_dir(current_dir) - .arg(shell_arg) - .arg(&base_command); - - let mut child = command.group_spawn()?; - - // Feed the prompt in, then close the pipe so Claude sees EOF - if let Some(mut stdin) = child.inner().stdin.take() { - stdin.write_all(combined_prompt.as_bytes()).await?; - stdin.shutdown().await?; - } - - Ok(child.into()) + let base_command = command_builder.build_initial(); + self.spawn_internal(current_dir, prompt, base_command).await } async fn spawn_follow_up( @@ -171,44 +166,13 @@ impl StandardCodingAgentExecutor for ClaudeCode { prompt: &str, session_id: &str, ) -> Result { - let (shell_cmd, shell_arg) = get_shell_command(); let command_builder = self.build_command_builder().await; - // Build follow-up command with --resume {session_id} - let mut base_command = command_builder.build_follow_up(&[ + let base_command = command_builder.build_follow_up(&[ "--fork-session".to_string(), "--resume".to_string(), session_id.to_string(), ]); - - if self.plan.unwrap_or(false) { - base_command = create_watchkill_script(&base_command); - } - - if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) { - write_python_hook(current_dir).await? - } - - let combined_prompt = self.append_prompt.combine_prompt(prompt); - - let mut command = Command::new(shell_cmd); - command - .kill_on_drop(true) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .current_dir(current_dir) - .arg(shell_arg) - .arg(&base_command); - - let mut child = command.group_spawn()?; - - // Feed the followup prompt in, then close the pipe - if let Some(mut stdin) = child.inner().stdin.take() { - stdin.write_all(combined_prompt.as_bytes()).await?; - stdin.shutdown().await?; - } - - Ok(child.into()) + self.spawn_internal(current_dir, prompt, base_command).await } fn normalize_logs(&self, msg_store: Arc, current_dir: &Path) { @@ -232,116 +196,74 @@ impl StandardCodingAgentExecutor for ClaudeCode { } } -async fn write_python_hook(current_dir: &Path) -> Result<(), ExecutorError> { - let hooks_dir = current_dir.join(".claude").join("hooks"); - tokio::fs::create_dir_all(&hooks_dir).await?; - let hook_path = hooks_dir.join("confirm.py"); +impl ClaudeCode { + async fn spawn_internal( + &self, + current_dir: &Path, + prompt: &str, + base_command: String, + ) -> Result { + let (shell_cmd, shell_arg) = get_shell_command(); + let combined_prompt = self.append_prompt.combine_prompt(prompt); - let mut file = tokio::fs::File::create(&hook_path).await?; - file.write_all(CONFIRM_HOOK_SCRIPT.as_bytes()).await?; - file.flush().await?; + let mut command = Command::new(shell_cmd); + command + .kill_on_drop(true) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .current_dir(current_dir) + .arg(shell_arg) + .arg(&base_command); - // TODO: Handle Windows permissioning - #[cfg(unix)] - { - let perm = std::fs::Permissions::from_mode(0o755); - tokio::fs::set_permissions(&hook_path, perm).await?; + let mut child = command.group_spawn()?; + let child_stdout = child.inner().stdout.take().ok_or_else(|| { + ExecutorError::Io(std::io::Error::other("Claude Code missing stdout")) + })?; + let child_stdin = + child.inner().stdin.take().ok_or_else(|| { + ExecutorError::Io(std::io::Error::other("Claude Code missing stdin")) + })?; + + let new_stdout = create_stdout_pipe_writer(&mut child)?; + let permission_mode = self.permission_mode(); + let hooks = self.get_hooks(); + + // Spawn task to handle the SDK client with control protocol + let prompt_clone = combined_prompt.clone(); + let approvals_clone = self.approvals_service.clone(); + tokio::spawn(async move { + let log_writer = LogWriter::new(new_stdout); + let client = ClaudeAgentClient::new(log_writer.clone(), approvals_clone); + let protocol_peer = ProtocolPeer::spawn(child_stdin, child_stdout, client.clone()); + + // Initialize control protocol + if let Err(e) = protocol_peer.initialize(hooks).await { + tracing::error!("Failed to initialize control protocol: {e}"); + let _ = log_writer + .log_raw(&format!("Error: Failed to initialize - {e}")) + .await; + return; + } + + if let Err(e) = protocol_peer.set_permission_mode(permission_mode).await { + tracing::warn!("Failed to set permission mode to {permission_mode}: {e}"); + } + + // Send user message + if let Err(e) = protocol_peer.send_user_message(prompt_clone).await { + tracing::error!("Failed to send prompt: {e}"); + let _ = log_writer + .log_raw(&format!("Error: Failed to send prompt - {e}")) + .await; + } + }); + + Ok(SpawnedChild { + child, + exit_signal: None, + }) } - - // ignore the confirm.py script - let gitignore_path = hooks_dir.join(".gitignore"); - if !tokio::fs::try_exists(&gitignore_path).await? { - let mut gitignore_file = tokio::fs::File::create(&gitignore_path).await?; - gitignore_file - .write_all(b"confirm.py\n.gitignore\n") - .await?; - gitignore_file.flush().await?; - } - - Ok(()) -} - -// Configure settings json -async fn settings_json(plan: bool) -> Result { - let backend_port = get_backend_port().await?; - let backend_timeout = APPROVAL_TIMEOUT_SECONDS + 5; // add buffer - - let matcher = if plan { - "^ExitPlanMode$" - } else { - "^(?!(Glob|Grep|NotebookRead|Read|Task|TodoWrite)$).*" - }; - - Ok(serde_json::json!({ - "hooks": { - "PreToolUse": [ - { - "matcher": matcher, - "hooks": [ - { - "type": "command", - "command": format!("$CLAUDE_PROJECT_DIR/.claude/hooks/confirm.py --timeout-seconds {backend_timeout} --poll-interval 5 --backend-port {backend_port} --feedback-marker '{USER_FEEDBACK_MARKER}'"), - "timeout": backend_timeout + 10 - } - ] - } - ] - } - }) - .to_string()) -} - -fn create_watchkill_script(command: &str) -> String { - // Hack: we concatenate so that Claude doesn't trigger the watchkill when reading this file - // during development, since it contains the stop phrase - let claude_plan_stop_indicator = concat!("Approval ", "request timed out"); - let cmd = shlex::try_quote(command).unwrap().to_string(); - - format!( - r#"#!/usr/bin/env bash -set -euo pipefail - -word="{claude_plan_stop_indicator}" - -exit_code=0 -while IFS= read -r line; do - printf '%s\n' "$line" - if [[ $line == *"$word"* ]]; then - exit 0 - fi -done < <(bash -lc {cmd} <&0 2>&1) - -exit_code=${{PIPESTATUS[0]}} -exit "$exit_code" -"# - ) -} - -/// Extract user denial reason from tool result error messages -/// Our confirm.py hook prefixes user feedback with "User feedback: " for easy extraction -/// Supports both string content and Claude's array format: [{"type":"text","text":"..."}] -fn extract_denial_reason(content: &serde_json::Value) -> Option { - // First try to parse as string - let content_str = if let Some(s) = content.as_str() { - s.to_string() - } else if let Ok(items) = - serde_json::from_value::>(content.clone()) - { - // Handle array format: [{"type":"text","text":"..."}] - items - .into_iter() - .map(|item| item.text) - .collect::>() - .join("\n") - } else { - // Try to serialize the value as a string - content.to_string() - }; - - content_str - .split_once(USER_FEEDBACK_MARKER) - .map(|(_, rest)| rest.trim().to_string()) - .filter(|s| !s.is_empty()) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -491,6 +413,7 @@ impl ClaudeLogProcessor { ClaudeJson::ToolResult { session_id, .. } => session_id.clone(), ClaudeJson::StreamEvent { session_id, .. } => session_id.clone(), ClaudeJson::Result { session_id, .. } => session_id.clone(), + ClaudeJson::ApprovalResponse { .. } => None, ClaudeJson::Unknown { .. } => None, } } @@ -580,12 +503,22 @@ impl ClaudeLogProcessor { serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null), ), }), - ClaudeContentItem::ToolUse { tool_data, .. } => { + ClaudeContentItem::ToolUse { tool_data, id } => { let name = tool_data.get_name(); let action_type = Self::extract_action_type(tool_data, worktree_path); let content = Self::generate_concise_content(tool_data, &action_type, worktree_path); + // Create metadata with tool_call_id for approval matching + let mut metadata = + serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null); + if let Some(obj) = metadata.as_object_mut() { + obj.insert( + "tool_call_id".to_string(), + serde_json::Value::String(id.clone()), + ); + } + Some(NormalizedEntry { timestamp: None, entry_type: NormalizedEntryType::ToolUse { @@ -594,9 +527,7 @@ impl ClaudeLogProcessor { status: ToolStatus::Created, }, content, - metadata: Some( - serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null), - ), + metadata: Some(metadata), }) } ClaudeContentItem::ToolResult { .. } => { @@ -837,6 +768,17 @@ impl ClaudeLogProcessor { &action_type, worktree_path, ); + + // Create metadata with tool_call_id for approval matching + let mut metadata = + serde_json::to_value(item).unwrap_or(serde_json::Value::Null); + if let Some(obj) = metadata.as_object_mut() { + obj.insert( + "tool_call_id".to_string(), + serde_json::Value::String(id.clone()), + ); + } + let entry = NormalizedEntry { timestamp: None, entry_type: NormalizedEntryType::ToolUse { @@ -845,9 +787,7 @@ impl ClaudeLogProcessor { status: ToolStatus::Created, }, content: content_text.clone(), - metadata: Some( - serde_json::to_value(item).unwrap_or(serde_json::Value::Null), - ), + metadata: Some(metadata), }; let is_new = entry_index.is_none(); let id_num = entry_index.unwrap_or_else(|| entry_index_provider.next()); @@ -930,7 +870,7 @@ impl ClaudeLogProcessor { { let is_command = matches!(info.tool_data, ClaudeToolData::Bash { .. }); - let display_tool_name = if is_command { + let _display_tool_name = if is_command { info.tool_name.clone() } else { let raw_name = info.tool_data.get_name().to_string(); @@ -1048,24 +988,8 @@ impl ClaudeLogProcessor { }; patches.push(ConversationPatch::replace(info.entry_index, entry)); } - - if is_error.unwrap_or(false) - && let Some(denial_reason) = extract_denial_reason(content) - { - let user_feedback = NormalizedEntry { - timestamp: None, - entry_type: NormalizedEntryType::UserFeedback { - denied_tool: display_tool_name.clone(), - }, - content: denial_reason, - metadata: None, - }; - let feedback_index = entry_index_provider.next(); - patches.push(ConversationPatch::add_normalized_entry( - feedback_index, - user_feedback, - )); - } + // Note: With control protocol, denials are handled via protocol messages + // rather than error content parsing } } } @@ -1166,6 +1090,40 @@ impl ClaudeLogProcessor { patches.push(ConversationPatch::add_normalized_entry(idx, entry)); } } + ClaudeJson::ApprovalResponse { + call_id: _, + tool_name, + approval_status, + } => { + // Convert denials and timeouts to visible entries (matching Codex behavior) + let entry_opt = match approval_status { + ApprovalStatus::Pending => None, + ApprovalStatus::Approved => None, + ApprovalStatus::Denied { reason } => Some(NormalizedEntry { + timestamp: None, + entry_type: NormalizedEntryType::UserFeedback { + denied_tool: tool_name.clone(), + }, + content: reason + .as_ref() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| "User denied this tool use request".to_string()), + metadata: None, + }), + ApprovalStatus::TimedOut => Some(NormalizedEntry { + timestamp: None, + entry_type: NormalizedEntryType::ErrorMessage, + content: format!("Approval timed out for tool {tool_name}"), + metadata: None, + }), + }; + + if let Some(entry) = entry_opt { + let idx = entry_index_provider.next(); + patches.push(ConversationPatch::add_normalized_entry(idx, entry)); + } + } ClaudeJson::Unknown { data } => { let entry = NormalizedEntry { timestamp: None, @@ -1434,7 +1392,7 @@ impl StreamingContentState { } // Data structures for parsing Claude's JSON output format -#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] +#[derive(Deserialize, Serialize, Debug, Clone)] #[serde(tag = "type")] pub enum ClaudeJson { #[serde(rename = "system")] @@ -1497,6 +1455,12 @@ pub enum ClaudeJson { #[serde(default, alias = "sessionId")] session_id: Option, }, + #[serde(rename = "approval_response")] + ApprovalResponse { + call_id: String, + tool_name: String, + approval_status: ApprovalStatus, + }, // Catch-all for unknown message types #[serde(untagged)] Unknown { @@ -2006,6 +1970,7 @@ mod tests { base_command_override: None, additional_params: None, }, + approvals_service: None, }; let msg_store = Arc::new(MsgStore::new()); let current_dir = std::path::PathBuf::from("/tmp/test-worktree"); diff --git a/crates/executors/src/executors/claude/client.rs b/crates/executors/src/executors/claude/client.rs new file mode 100644 index 00000000..97a9340e --- /dev/null +++ b/crates/executors/src/executors/claude/client.rs @@ -0,0 +1,206 @@ +use std::sync::Arc; + +use tokio::sync::Mutex; +use workspace_utils::approvals::ApprovalStatus; + +use super::types::PermissionMode; +use crate::{ + approvals::{ExecutorApprovalError, ExecutorApprovalService}, + executors::{ + ExecutorError, + claude::{ + ClaudeJson, + types::{ + PermissionResult, PermissionUpdate, PermissionUpdateDestination, + PermissionUpdateType, + }, + }, + codex::client::LogWriter, + }, +}; + +const EXIT_PLAN_MODE_NAME: &str = "ExitPlanMode"; + +/// Claude Agent client with control protocol support +pub struct ClaudeAgentClient { + log_writer: LogWriter, + approvals: Option>, + auto_approve: bool, // true when approvals is None + latest_unhandled_tool_use_id: Mutex>, +} + +impl ClaudeAgentClient { + /// Create a new client with optional approval service + pub fn new( + log_writer: LogWriter, + approvals: Option>, + ) -> Arc { + let auto_approve = approvals.is_none(); + Arc::new(Self { + log_writer, + approvals, + auto_approve, + latest_unhandled_tool_use_id: Mutex::new(None), + }) + } + async fn set_latest_unhandled_tool_use_id(&self, tool_use_id: String) { + if self.latest_unhandled_tool_use_id.lock().await.is_some() { + tracing::warn!( + "Overwriting unhandled tool_use_id: {} with new tool_use_id: {}", + self.latest_unhandled_tool_use_id + .lock() + .await + .as_ref() + .unwrap(), + tool_use_id + ); + } + let mut guard = self.latest_unhandled_tool_use_id.lock().await; + guard.replace(tool_use_id); + } + + async fn handle_approval( + &self, + tool_use_id: String, + tool_name: String, + tool_input: serde_json::Value, + ) -> Result { + // Use approval service to request tool approval + let approval_service = self + .approvals + .as_ref() + .ok_or(ExecutorApprovalError::ServiceUnavailable)?; + let status = approval_service + .request_tool_approval(&tool_name, tool_input.clone(), &tool_use_id) + .await; + match status { + Ok(status) => { + // Log the approval response so we it appears in the executor logs + self.log_writer + .log_raw(&serde_json::to_string(&ClaudeJson::ApprovalResponse { + call_id: tool_use_id.clone(), + tool_name: tool_name.clone(), + approval_status: status.clone(), + })?) + .await?; + match status { + ApprovalStatus::Approved => { + if tool_name == EXIT_PLAN_MODE_NAME { + Ok(PermissionResult::Allow { + updated_input: tool_input, + updated_permissions: Some(vec![PermissionUpdate { + update_type: PermissionUpdateType::SetMode, + mode: Some(PermissionMode::BypassPermissions), + destination: PermissionUpdateDestination::Session, + }]), + }) + } else { + Ok(PermissionResult::Allow { + updated_input: tool_input, + updated_permissions: None, + }) + } + } + ApprovalStatus::Denied { reason } => { + let message = reason.unwrap_or("Denied by user".to_string()); + Ok(PermissionResult::Deny { + message, + interrupt: Some(false), + }) + } + ApprovalStatus::TimedOut => Ok(PermissionResult::Deny { + message: "Approval request timed out".to_string(), + interrupt: Some(false), + }), + ApprovalStatus::Pending => Ok(PermissionResult::Deny { + message: "Approval still pending (unexpected)".to_string(), + interrupt: Some(false), + }), + } + } + Err(e) => { + tracing::error!("Tool approval request failed: {e}"); + Ok(PermissionResult::Deny { + message: "Tool approval request failed".to_string(), + interrupt: Some(false), + }) + } + } + } + + pub async fn on_can_use_tool( + &self, + tool_name: String, + input: serde_json::Value, + _permission_suggestions: Option>, + ) -> Result { + if self.auto_approve { + Ok(PermissionResult::Allow { + updated_input: input, + updated_permissions: None, + }) + } else { + let latest_tool_use_id = { + let guard = self.latest_unhandled_tool_use_id.lock().await.take(); + guard.clone() + }; + + if let Some(latest_tool_use_id) = latest_tool_use_id { + self.handle_approval(latest_tool_use_id, tool_name, input) + .await + } else { + // Auto approve tools with no matching tool_use_id. + // This rare edge case happens if a tool call triggers no hook callback, + // so no tool_use_id is available to match the approval request to. + tracing::warn!( + "No unhandled tool_use_id available for tool '{}', cannot request approval", + tool_name + ); + Ok(PermissionResult::Allow { + updated_input: input, + updated_permissions: None, + }) + } + } + } + + pub async fn on_hook_callback( + &self, + _callback_id: String, + _input: serde_json::Value, + tool_use_id: Option, + ) -> Result { + if self.auto_approve { + Ok(serde_json::json!({ + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "allow", + "permissionDecisionReason": "Auto-approved by SDK" + } + })) + } else { + // Hook callbacks is only used to store tool_use_id for later approval request + // Both hook callback and can_use_tool are needed. + // - Hook callbacks have a constant 60s timeout, so cannot be used for long approvals + // - can_use_tool does not provide tool_use_id, so cannot be used alone + // Together they allow matching approval requests to tool uses. + // This works because `ask` decision in hook callback triggers a can_use_tool request + // https://docs.claude.com/en/api/agent-sdk/permissions#permission-flow-diagram + if let Some(tool_use_id) = tool_use_id.clone() { + self.set_latest_unhandled_tool_use_id(tool_use_id).await; + } + Ok(serde_json::json!({ + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "ask", + "permissionDecisionReason": "Forwarding to canusetool service" + } + })) + } + } + + pub async fn on_non_control(&self, line: &str) -> Result<(), ExecutorError> { + // Forward all non-control messages to stdout + self.log_writer.log_raw(line).await + } +} diff --git a/crates/executors/src/executors/claude/protocol.rs b/crates/executors/src/executors/claude/protocol.rs new file mode 100644 index 00000000..678e5574 --- /dev/null +++ b/crates/executors/src/executors/claude/protocol.rs @@ -0,0 +1,200 @@ +use std::sync::Arc; + +use tokio::{ + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, + process::{ChildStdin, ChildStdout}, + sync::Mutex, +}; + +use super::types::{ + CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType, + SDKControlRequestMessage, +}; +use crate::executors::{ + ExecutorError, + claude::{ + client::ClaudeAgentClient, + types::{PermissionMode, SDKControlRequestType}, + }, +}; + +/// Handles bidirectional control protocol communication +#[derive(Clone)] +pub struct ProtocolPeer { + stdin: Arc>, +} + +impl ProtocolPeer { + pub fn spawn(stdin: ChildStdin, stdout: ChildStdout, client: Arc) -> Self { + let peer = Self { + stdin: Arc::new(Mutex::new(stdin)), + }; + + let reader_peer = peer.clone(); + tokio::spawn(async move { + if let Err(e) = reader_peer.read_loop(stdout, client).await { + tracing::error!("Protocol reader loop error: {}", e); + } + }); + + peer + } + + async fn read_loop( + &self, + stdout: ChildStdout, + client: Arc, + ) -> Result<(), ExecutorError> { + let mut reader = BufReader::new(stdout); + let mut buffer = String::new(); + + loop { + buffer.clear(); + match reader.read_line(&mut buffer).await { + Ok(0) => break, // EOF + Ok(_) => { + let line = buffer.trim(); + if line.is_empty() { + continue; + } + // Parse message using typed enum + match serde_json::from_str::(line) { + Ok(CLIMessage::ControlRequest { + request_id, + request, + }) => { + self.handle_control_request(&client, request_id, request) + .await; + } + Ok(CLIMessage::ControlResponse { .. }) => {} + Ok(CLIMessage::Result(_)) => { + client.on_non_control(line).await?; + break; + } + _ => { + client.on_non_control(line).await?; + } + } + } + Err(e) => { + tracing::error!("Error reading stdout: {}", e); + break; + } + } + } + Ok(()) + } + + async fn handle_control_request( + &self, + client: &Arc, + request_id: String, + request: ControlRequestType, + ) { + match request { + ControlRequestType::CanUseTool { + tool_name, + input, + permission_suggestions, + } => { + match client + .on_can_use_tool(tool_name, input, permission_suggestions) + .await + { + Ok(result) => { + if let Err(e) = self + .send_hook_response(request_id, serde_json::to_value(result).unwrap()) + .await + { + tracing::error!("Failed to send permission result: {e}"); + } + } + Err(e) => { + tracing::error!("Error in on_can_use_tool: {e}"); + if let Err(e2) = self.send_error(request_id, e.to_string()).await { + tracing::error!("Failed to send error response: {e2}"); + } + } + } + } + ControlRequestType::HookCallback { + callback_id, + input, + tool_use_id, + } => { + match client + .on_hook_callback(callback_id, input, tool_use_id) + .await + { + Ok(hook_output) => { + if let Err(e) = self.send_hook_response(request_id, hook_output).await { + tracing::error!("Failed to send hook callback result: {e}"); + } + } + Err(e) => { + tracing::error!("Error in on_hook_callback: {e}"); + if let Err(e2) = self.send_error(request_id, e.to_string()).await { + tracing::error!("Failed to send error response: {e2}"); + } + } + } + } + } + } + + pub async fn send_hook_response( + &self, + request_id: String, + hook_output: serde_json::Value, + ) -> Result<(), ExecutorError> { + self.send_json(&ControlResponseMessage::new(ControlResponseType::Success { + request_id, + response: Some(hook_output), + })) + .await + } + + /// Send error response to CLI + async fn send_error(&self, request_id: String, error: String) -> Result<(), ExecutorError> { + self.send_json(&ControlResponseMessage::new(ControlResponseType::Error { + request_id, + error: Some(error), + })) + .await + } + + /// Send JSON message to stdin + async fn send_json(&self, message: &T) -> Result<(), ExecutorError> { + let json = serde_json::to_string(message)?; + let mut stdin = self.stdin.lock().await; + stdin.write_all(json.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + Ok(()) + } + + pub async fn send_user_message(&self, content: String) -> Result<(), ExecutorError> { + let message = serde_json::json!({ + "type": "user", + "message": { + "role": "user", + "content": content + } + }); + self.send_json(&message).await + } + + pub async fn initialize(&self, hooks: Option) -> Result<(), ExecutorError> { + self.send_json(&SDKControlRequestMessage::new( + SDKControlRequestType::Initialize { hooks }, + )) + .await + } + + pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<(), ExecutorError> { + self.send_json(&SDKControlRequestMessage::new( + SDKControlRequestType::SetPermissionMode { mode }, + )) + .await + } +} diff --git a/crates/executors/src/executors/claude/types.rs b/crates/executors/src/executors/claude/types.rs new file mode 100644 index 00000000..c4972d12 --- /dev/null +++ b/crates/executors/src/executors/claude/types.rs @@ -0,0 +1,178 @@ +//! Type definitions for Claude Code control protocol +//! +//! Similar to: https://github.com/ZhangHanDong/claude-code-api-rs/blob/main/claude-code-sdk-rs/src/types.rs + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Top-level message types from CLI stdout +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum CLIMessage { + ControlRequest { + request_id: String, + request: ControlRequestType, + }, + ControlResponse { + response: ControlResponseType, + }, + Result(serde_json::Value), + #[serde(untagged)] + Other(serde_json::Value), +} + +/// Control request from SDK to CLI (outgoing) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SDKControlRequestMessage { + #[serde(rename = "type")] + message_type: String, // Always "control_request" + pub request_id: String, + pub request: SDKControlRequestType, +} + +impl SDKControlRequestMessage { + pub fn new(request: SDKControlRequestType) -> Self { + use uuid::Uuid; + Self { + message_type: "control_request".to_string(), + request_id: Uuid::new_v4().to_string(), + request, + } + } +} + +/// Control response from SDK to CLI (outgoing) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ControlResponseMessage { + #[serde(rename = "type")] + message_type: String, // Always "control_response" + pub response: ControlResponseType, +} + +impl ControlResponseMessage { + pub fn new(response: ControlResponseType) -> Self { + Self { + message_type: "control_response".to_string(), + response, + } + } +} + +/// Types of control requests +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype", rename_all = "snake_case")] +pub enum ControlRequestType { + CanUseTool { + tool_name: String, + input: Value, + #[serde(skip_serializing_if = "Option::is_none")] + permission_suggestions: Option>, + }, + HookCallback { + #[serde(rename = "callback_id")] + callback_id: String, + input: Value, + #[serde(skip_serializing_if = "Option::is_none")] + tool_use_id: Option, + }, +} + +/// Result of permission check +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "behavior", rename_all = "camelCase")] +pub enum PermissionResult { + Allow { + #[serde(rename = "updatedInput")] + updated_input: Value, + #[serde(skip_serializing_if = "Option::is_none", rename = "updatedPermissions")] + updated_permissions: Option>, + }, + Deny { + message: String, + #[serde(skip_serializing_if = "Option::is_none")] + interrupt: Option, + }, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum PermissionUpdateType { + SetMode, + AddRules, + RemoveRules, + ClearRules, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum PermissionUpdateDestination { + Session, + UserSettings, + ProjectSettings, + LocalSettings, +} + +/// Permission update operation +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionUpdate { + #[serde(rename = "type")] + pub update_type: PermissionUpdateType, + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + pub destination: PermissionUpdateDestination, +} + +/// Control response from SDK to CLI +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype", rename_all = "snake_case")] +pub enum ControlResponseType { + Success { + request_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + response: Option, + }, + Error { + request_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype", rename_all = "snake_case")] +pub enum SDKControlRequestType { + SetPermissionMode { + mode: PermissionMode, + }, + Initialize { + #[serde(skip_serializing_if = "Option::is_none")] + hooks: Option, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum PermissionMode { + Default, + AcceptEdits, + Plan, + BypassPermissions, +} + +impl PermissionMode { + pub fn as_str(&self) -> &'static str { + match self { + Self::Default => "default", + Self::AcceptEdits => "acceptEdits", + Self::Plan => "plan", + Self::BypassPermissions => "bypassPermissions", + } + } +} + +impl std::fmt::Display for PermissionMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} diff --git a/crates/executors/src/executors/codex/client.rs b/crates/executors/src/executors/codex/client.rs index c28cd2cb..13138832 100644 --- a/crates/executors/src/executors/codex/client.rs +++ b/crates/executors/src/executors/codex/client.rs @@ -235,12 +235,6 @@ impl AppServerClient { let mut guard = self.conversation_id.lock().await; guard.replace(*conversation_id); } - if let Some(approvals) = self.approvals.as_ref() { - approvals - .register_session(&conversation_id.to_string()) - .await - .map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?; - } self.flush_pending_feedback().await; Ok(()) } diff --git a/crates/executors/src/executors/hooks/confirm.py b/crates/executors/src/executors/hooks/confirm.py deleted file mode 100755 index 35562628..00000000 --- a/crates/executors/src/executors/hooks/confirm.py +++ /dev/null @@ -1,179 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import json -import sys -import time -import urllib.error -import urllib.request -from typing import Optional - - -def json_error(reason: Optional[str], feedback_marker: Optional[str] = None) -> None: - """Emit a deny PreToolUse JSON to stdout and exit(0).""" - # Prefix user feedback with marker for extraction if provided - formatted_reason = reason - if reason and feedback_marker: - formatted_reason = f"{feedback_marker}{reason}" - - payload = { - "hookSpecificOutput": { - "hookEventName": "PreToolUse", - "permissionDecision": "deny", - "permissionDecisionReason": formatted_reason, - } - } - print(json.dumps(payload, ensure_ascii=False)) - sys.exit(0) - - -def json_success() -> None: - payload = { - "hookSpecificOutput": { - "hookEventName": "PreToolUse", - "permissionDecision": "allow", - }, - "suppressOutput": True, - } - print(json.dumps(payload, ensure_ascii=False)) - sys.exit(0) - - -def http_post_json(url: str, body: dict) -> dict: - data = json.dumps(body).encode("utf-8") - req = urllib.request.Request( - url, data=data, headers={"Content-Type": "application/json"}, method="POST" - ) - try: - with urllib.request.urlopen(req, timeout=10) as resp: - return json.loads(resp.read().decode("utf-8") or "{}") - except ( - urllib.error.HTTPError, - urllib.error.URLError, - TimeoutError, - json.JSONDecodeError, - ) as e: - json_error( - f"Failed to create approval request. Backend may be unavailable. ({e})" - ) - raise # unreachable - - -def http_get_json(url: str) -> dict: - req = urllib.request.Request(url, method="GET") - try: - with urllib.request.urlopen(req, timeout=10) as resp: - return json.loads(resp.read().decode("utf-8") or "{}") - except ( - urllib.error.HTTPError, - urllib.error.URLError, - TimeoutError, - json.JSONDecodeError, - ) as e: - json_error(f"Lost connection to approval backend: {e}") - raise - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="PreToolUse approval gate. All parameters are passed via CLI." - ) - parser.add_argument( - "-t", - "--timeout-seconds", - type=int, - required=True, - help="Maximum time to wait for approval before timing out (seconds).", - ) - parser.add_argument( - "-p", - "--poll-interval", - type=int, - required=True, - help="Seconds between polling the backend for status.", - ) - parser.add_argument( - "-b", - "--backend-port", - type=int, - required=True, - help="Port of the approval backend running on 127.0.0.1.", - ) - parser.add_argument( - "-m", - "--feedback-marker", - type=str, - required=True, - help="Marker prefix for user feedback messages.", - ) - args = parser.parse_args() - - if args.timeout_seconds <= 0: - parser.error("--timeout-seconds must be a positive integer") - if args.poll_interval <= 0: - parser.error("--poll-interval must be a positive integer") - if args.poll_interval > args.timeout_seconds: - parser.error("--poll-interval cannot be greater than --timeout-seconds") - - return args - - -def main(): - args = parse_args() - port = args.backend_port - - url = f"http://127.0.0.1:{port}" - create_endpoint = f"{url}/api/approvals/create" - - try: - raw_payload = sys.stdin.read() - incoming = json.loads(raw_payload or "{}") - except json.JSONDecodeError: - json_error("Invalid JSON payload on stdin") - - tool_name = incoming.get("tool_name") - tool_input = incoming.get("tool_input") - session_id = incoming.get("session_id", "unknown") - - create_payload = { - "tool_name": tool_name, - "tool_input": tool_input, - "session_id": session_id, - } - - response = http_post_json(create_endpoint, create_payload) - approval_id = response.get("id") - if not approval_id: - json_error("Invalid response from approval backend") - - print( - f"Approval request created: {approval_id}. Waiting for user response...", - file=sys.stderr, - ) - - elapsed = 0 - while elapsed < args.timeout_seconds: - result = http_get_json(f"{url}/api/approvals/{approval_id}/status") - status = result.get("status") - - if status == "approved": - json_success() - elif status == "denied": - reason = result.get("reason") - json_error(reason, args.feedback_marker) - elif status == "timed_out": - # concat to avoid triggering the watchkill script - json_error( - "Approval request" + f" timed out after {args.timeout_seconds} seconds" - ) - elif status == "pending": - time.sleep(args.poll_interval) - elapsed += args.poll_interval - else: - json_error(f"Unknown approval status: {status}") - - # concat to avoid triggering the watchkill script - json_error("Approval request"+ f" timed out after {args.timeout_seconds} seconds") - - -if __name__ == "__main__": - main() diff --git a/crates/local-deployment/src/container.rs b/crates/local-deployment/src/container.rs index a9925a6f..fac8bdb8 100644 --- a/crates/local-deployment/src/container.rs +++ b/crates/local-deployment/src/container.rs @@ -809,11 +809,13 @@ impl ContainerService for LocalContainerService { let approvals_service: Arc = match executor_action.base_executor() { - Some(BaseCodingAgent::Codex) => ExecutorApprovalBridge::new( - self.approvals.clone(), - self.db.clone(), - execution_process.id, - ), + Some(BaseCodingAgent::Codex) | Some(BaseCodingAgent::ClaudeCode) => { + ExecutorApprovalBridge::new( + self.approvals.clone(), + self.db.clone(), + execution_process.id, + ) + } _ => Arc::new(NoopExecutorApprovalService {}), }; diff --git a/crates/server/src/routes/approvals.rs b/crates/server/src/routes/approvals.rs index 9dc3d9a6..c4c263bb 100644 --- a/crates/server/src/routes/approvals.rs +++ b/crates/server/src/routes/approvals.rs @@ -2,60 +2,13 @@ use axum::{ Json, Router, extract::{Path, State}, http::StatusCode, - routing::{get, post}, + routing::post, }; -use db::models::execution_process::ExecutionProcess; use deployment::Deployment; -use services::services::container::ContainerService; -use utils::approvals::{ - ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus, CreateApprovalRequest, - EXIT_PLAN_MODE_TOOL_NAME, -}; +use utils::approvals::{ApprovalResponse, ApprovalStatus}; use crate::DeploymentImpl; -pub async fn create_approval( - State(deployment): State, - Json(request): Json, -) -> Result, StatusCode> { - let service = deployment.approvals(); - - match service - .create_from_session(&deployment.db().pool, request) - .await - { - Ok(approval) => { - deployment - .track_if_analytics_allowed( - "approval_created", - serde_json::json!({ - "approval_id": approval.id, - "tool_name": &approval.tool_name, - "execution_process_id": approval.execution_process_id.to_string(), - }), - ) - .await; - - Ok(Json(approval)) - } - Err(e) => { - tracing::error!("Failed to create approval: {:?}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -pub async fn get_approval_status( - State(deployment): State, - Path(id): Path, -) -> Result, StatusCode> { - let service = deployment.approvals(); - match service.status(&id).await { - Some(status) => Ok(Json(status)), - None => Err(StatusCode::NOT_FOUND), - } -} - pub async fn respond_to_approval( State(deployment): State, Path(id): Path, @@ -77,21 +30,6 @@ pub async fn respond_to_approval( ) .await; - if matches!(status, ApprovalStatus::Approved) - && context.tool_name == EXIT_PLAN_MODE_TOOL_NAME - // If exiting plan mode, automatically start a new execution process with different - // permissions - && let Ok(ctx) = ExecutionProcess::load_context( - &deployment.db().pool, - context.execution_process_id, - ) - .await - && let Err(e) = deployment.container().exit_plan_mode_tool(ctx).await - { - tracing::error!("failed to exit plan mode: {:?}", e); - return Err(StatusCode::INTERNAL_SERVER_ERROR); - } - Ok(Json(status)) } Err(e) => { @@ -101,18 +39,6 @@ pub async fn respond_to_approval( } } -pub async fn get_pending_approvals( - State(deployment): State, -) -> Json> { - let service = deployment.approvals(); - let approvals = service.pending().await; - Json(approvals) -} - pub fn router() -> Router { - Router::new() - .route("/approvals/create", post(create_approval)) - .route("/approvals/{id}/status", get(get_approval_status)) - .route("/approvals/{id}/respond", post(respond_to_approval)) - .route("/approvals/pending", get(get_pending_approvals)) + Router::new().route("/approvals/{id}/respond", post(respond_to_approval)) } diff --git a/crates/services/src/services/approvals.rs b/crates/services/src/services/approvals.rs index 35bec4c2..54bbfb39 100644 --- a/crates/services/src/services/approvals.rs +++ b/crates/services/src/services/approvals.rs @@ -2,11 +2,9 @@ pub mod executor_approvals; use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration}; -use chrono::{DateTime, Utc}; use dashmap::DashMap; use db::models::{ execution_process::ExecutionProcess, - executor_session::ExecutorSession, task::{Task, TaskStatus}, }; use executors::{ @@ -21,10 +19,7 @@ use sqlx::{Error as SqlxError, SqlitePool}; use thiserror::Error; use tokio::sync::{RwLock, oneshot}; use utils::{ - approvals::{ - ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus, - CreateApprovalRequest, - }, + approvals::{ApprovalRequest, ApprovalResponse, ApprovalStatus}, log_msg::LogMsg, msg_store::MsgStore, }; @@ -36,8 +31,6 @@ struct PendingApproval { entry: NormalizedEntry, execution_process_id: Uuid, tool_name: String, - requested_at: DateTime, - timeout_at: DateTime, response_tx: oneshot::Sender, } @@ -81,7 +74,7 @@ impl Approvals { } } - async fn create_internal( + pub async fn create_with_waiter( &self, request: ApprovalRequest, ) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> { @@ -94,12 +87,7 @@ impl Approvals { if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await { // Find the matching tool use entry by name and input - let matching_tool = find_matching_tool_use( - store.clone(), - &request.tool_name, - &request.tool_input, - request.tool_call_id.as_deref(), - ); + let matching_tool = find_matching_tool_use(store.clone(), &request.tool_call_id); if let Some((idx, matching_tool)) = matching_tool { let approval_entry = matching_tool @@ -118,8 +106,6 @@ impl Approvals { entry: matching_tool, execution_process_id: request.execution_process_id, tool_name: request.tool_name.clone(), - requested_at: request.created_at, - timeout_at: request.timeout_at, response_tx: tx, }, ); @@ -147,41 +133,6 @@ impl Approvals { Ok((request, waiter)) } - #[tracing::instrument(skip(self, request))] - pub async fn create(&self, request: ApprovalRequest) -> Result { - let (request, _) = self.create_internal(request).await?; - Ok(request) - } - - pub async fn create_with_waiter( - &self, - request: ApprovalRequest, - ) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> { - self.create_internal(request).await - } - - pub async fn create_from_session( - &self, - pool: &SqlitePool, - payload: CreateApprovalRequest, - ) -> Result { - let session_id = payload.session_id.clone(); - let execution_process_id = - match ExecutorSession::find_by_session_id(pool, &session_id).await? { - Some(session) => session.execution_process_id, - None => { - tracing::warn!("No executor session found for session_id: {}", session_id); - return Err(ApprovalError::NoExecutorSession(session_id)); - } - }; - - // Move the task to InReview if it's still InProgress - ensure_task_in_review(pool, execution_process_id).await; - - let request = ApprovalRequest::from_create(payload, execution_process_id); - self.create(request).await - } - #[tracing::instrument(skip(self, id, req))] pub async fn respond( &self, @@ -238,39 +189,6 @@ impl Approvals { } } - pub async fn status(&self, id: &str) -> Option { - if let Some(f) = self.completed.get(id) { - return Some(f.clone()); - } - if let Some(p) = self.pending.get(id) { - if chrono::Utc::now() >= p.timeout_at { - return Some(ApprovalStatus::TimedOut); - } - return Some(ApprovalStatus::Pending); - } - None - } - - pub async fn pending(&self) -> Vec { - self.pending - .iter() - .filter_map(|entry| { - let (id, pending) = entry.pair(); - - match &pending.entry.entry_type { - NormalizedEntryType::ToolUse { tool_name, .. } => Some(ApprovalPendingInfo { - approval_id: id.clone(), - execution_process_id: pending.execution_process_id, - tool_name: tool_name.clone(), - requested_at: pending.requested_at, - timeout_at: pending.timeout_at, - }), - _ => None, - } - }) - .collect() - } - #[tracing::instrument(skip(self, id, timeout_at, waiter))] fn spawn_timeout_watcher( &self, @@ -352,126 +270,37 @@ pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_i } } -/// Comparison strategy for matching tool use entries -enum ToolComparisonStrategy { - /// Compare by tool_call_id - ToolCallId(String), - /// Compare deserialized ClaudeToolData structures (for known tools) - Deserialized(executors::executors::claude::ClaudeToolData), - /// Compare raw JSON input fields (for Unknown tools like MCP) - RawJson, -} - /// Find a matching tool use entry that hasn't been assigned to an approval yet -/// Matches by tool name and tool input to support parallel tool calls +/// Matches by tool call id from tool metadata fn find_matching_tool_use( store: Arc, - tool_name: &str, - tool_input: &serde_json::Value, - tool_call_id: Option<&str>, + tool_call_id: &str, ) -> Option<(usize, NormalizedEntry)> { - use executors::executors::claude::ClaudeToolData; - let history = store.get_history(); - // Determine comparison strategy based on tool type - let strategy = if let Some(call_id) = tool_call_id { - // If tool_call_id is provided, use it for matching - ToolComparisonStrategy::ToolCallId(call_id.to_string()) - } else { - match serde_json::from_value::(serde_json::json!({ - "name": tool_name, - "input": tool_input - })) { - Ok(ClaudeToolData::Unknown { .. }) => { - // For Unknown tools (MCP, future tools), use raw JSON comparison - ToolComparisonStrategy::RawJson - } - Ok(data) => { - // For known tools, use deserialized comparison with proper alias handling - ToolComparisonStrategy::Deserialized(data) - } - Err(e) => { - tracing::warn!( - "Failed to deserialize tool_input for tool '{}': {}", - tool_name, - e - ); - return None; - } - } - }; - - // Single loop through history with strategy-based comparison + // Single loop through history for msg in history.iter().rev() { if let LogMsg::JsonPatch(patch) = msg && let Some((idx, entry)) = extract_normalized_entry_from_patch(patch) - && let NormalizedEntryType::ToolUse { - tool_name: entry_tool_name, - status, - .. - } = &entry.entry_type + && let NormalizedEntryType::ToolUse { status, .. } = &entry.entry_type { // Only match tools that are in Created state if !matches!(status, ToolStatus::Created) { continue; } - // Tool name must match - if entry_tool_name != tool_name { - continue; - } - - // Apply comparison strategy - if let Some(metadata) = &entry.metadata { - let is_match = match &strategy { - ToolComparisonStrategy::ToolCallId(call_id) => { - // Match by tool_call_id in metadata - if let Ok(ToolCallMetadata { - tool_call_id: entry_call_id, - .. - }) = serde_json::from_value::(metadata.clone()) - { - entry_call_id == *call_id - } else { - false - } - } - ToolComparisonStrategy::RawJson => { - // Compare raw JSON input for Unknown tools - if let Some(entry_input) = metadata.get("input") { - entry_input == tool_input - } else { - false - } - } - ToolComparisonStrategy::Deserialized(approval_data) => { - // Compare deserialized structures for known tools - if let Ok(entry_tool_data) = - serde_json::from_value::(metadata.clone()) - { - entry_tool_data == *approval_data - } else { - false - } - } - }; - - if is_match { - let strategy_name = match strategy { - ToolComparisonStrategy::ToolCallId(call_id) => { - format!("tool_call_id '{call_id}'") - } - ToolComparisonStrategy::RawJson => "raw input comparison".to_string(), - ToolComparisonStrategy::Deserialized(_) => { - "deserialized tool data".to_string() - } - }; - tracing::debug!( - "Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}" - ); - return Some((idx, entry)); - } + // Match by tool call id from metadata + if let Some(metadata) = &entry.metadata + && let Ok(ToolCallMetadata { + tool_call_id: entry_call_id, + .. + }) = serde_json::from_value::(metadata.clone()) + && entry_call_id == tool_call_id + { + tracing::debug!( + "Matched tool use entry at index {idx} for tool call id '{tool_call_id}'" + ); + return Some((idx, entry)); } } } @@ -491,19 +320,9 @@ mod tests { fn create_tool_use_entry( tool_name: &str, file_path: &str, + id: &str, status: ToolStatus, ) -> NormalizedEntry { - // Create metadata that mimics the actual structure from Claude Code - // which has an "input" field containing the original tool parameters - let metadata = serde_json::json!({ - "type": "tool_use", - "id": format!("test-{}", file_path), - "name": tool_name, - "input": { - "file_path": file_path - } - }); - NormalizedEntry { timestamp: None, entry_type: NormalizedEntryType::ToolUse { @@ -514,7 +333,12 @@ mod tests { status, }, content: format!("Reading {file_path}"), - metadata: Some(metadata), + metadata: Some( + serde_json::to_value(ToolCallMetadata { + tool_call_id: id.to_string(), + }) + .unwrap(), + ), } } @@ -523,9 +347,9 @@ mod tests { let store = Arc::new(MsgStore::new()); // Setup: Simulate 3 parallel Read tool calls with different files - let read_foo = create_tool_use_entry("Read", "foo.rs", ToolStatus::Created); - let read_bar = create_tool_use_entry("Read", "bar.rs", ToolStatus::Created); - let read_baz = create_tool_use_entry("Read", "baz.rs", ToolStatus::Created); + let read_foo = create_tool_use_entry("Read", "foo.rs", "foo-id", ToolStatus::Created); + let read_bar = create_tool_use_entry("Read", "bar.rs", "bar-id", ToolStatus::Created); + let read_baz = create_tool_use_entry("Read", "baz.rs", "baz-id", ToolStatus::Created); store.push_patch( executors::logs::utils::patch::ConversationPatch::add_normalized_entry(0, read_foo), @@ -537,17 +361,12 @@ mod tests { executors::logs::utils::patch::ConversationPatch::add_normalized_entry(2, read_baz), ); - // Test 1: Each approval request matches its specific tool by input - let foo_input = serde_json::json!({"file_path": "foo.rs"}); - let bar_input = serde_json::json!({"file_path": "bar.rs"}); - let baz_input = serde_json::json!({"file_path": "baz.rs"}); - - let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None) - .expect("Should match foo.rs"); - let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None) - .expect("Should match bar.rs"); - let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None) - .expect("Should match baz.rs"); + let (idx_foo, _) = + find_matching_tool_use(store.clone(), "foo-id").expect("Should match foo.rs"); + let (idx_bar, _) = + find_matching_tool_use(store.clone(), "bar-id").expect("Should match bar.rs"); + let (idx_baz, _) = + find_matching_tool_use(store.clone(), "baz-id").expect("Should match baz.rs"); assert_eq!(idx_foo, 0, "foo.rs should match first entry"); assert_eq!(idx_bar, 1, "bar.rs should match second entry"); @@ -557,6 +376,7 @@ mod tests { let read_pending = create_tool_use_entry( "Read", "pending.rs", + "pending-id", ToolStatus::PendingApproval { approval_id: "test-id".to_string(), requested_at: chrono::Utc::now(), @@ -567,24 +387,15 @@ mod tests { executors::logs::utils::patch::ConversationPatch::add_normalized_entry(3, read_pending), ); - let pending_input = serde_json::json!({"file_path": "pending.rs"}); assert!( - find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(), + find_matching_tool_use(store.clone(), "pending-id").is_none(), "Should not match tools in PendingApproval state" ); - // Test 3: Wrong tool name returns None - let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"}); + // Test 3: Wrong tool id returns None assert!( - find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(), - "Should not match different tool names" - ); - - // Test 4: Wrong input parameters returns None - let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"}); - assert!( - find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(), - "Should not match with different input parameters" + find_matching_tool_use(store.clone(), "wrong-id").is_none(), + "Should not match different tool ids" ); } } diff --git a/crates/services/src/services/approvals/executor_approvals.rs b/crates/services/src/services/approvals/executor_approvals.rs index 6dba8eaf..dbbd9882 100644 --- a/crates/services/src/services/approvals/executor_approvals.rs +++ b/crates/services/src/services/approvals/executor_approvals.rs @@ -4,7 +4,6 @@ use async_trait::async_trait; use db::{self, DBService}; use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService}; use serde_json::Value; -use tokio::sync::RwLock; use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest}; use uuid::Uuid; @@ -14,7 +13,6 @@ pub struct ExecutorApprovalBridge { approvals: Approvals, db: DBService, execution_process_id: Uuid, - session_id: RwLock>, } impl ExecutorApprovalBridge { @@ -23,41 +21,25 @@ impl ExecutorApprovalBridge { approvals, db, execution_process_id, - session_id: RwLock::new(None), }) } } #[async_trait] impl ExecutorApprovalService for ExecutorApprovalBridge { - async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> { - let mut guard = self.session_id.write().await; - guard.replace(session_id.to_string()); - - Ok(()) - } - async fn request_tool_approval( &self, tool_name: &str, tool_input: Value, tool_call_id: &str, ) -> Result { - let session_id = { - let guard = self.session_id.read().await; - guard - .clone() - .ok_or(ExecutorApprovalError::SessionNotRegistered)? - }; - super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await; let request = ApprovalRequest::from_create( CreateApprovalRequest { tool_name: tool_name.to_string(), tool_input, - session_id, - tool_call_id: Some(tool_call_id.to_string()), + tool_call_id: tool_call_id.to_string(), }, self.execution_process_id, ); diff --git a/crates/utils/src/approvals.rs b/crates/utils/src/approvals.rs index c331163f..86f09464 100644 --- a/crates/utils/src/approvals.rs +++ b/crates/utils/src/approvals.rs @@ -4,15 +4,13 @@ use ts_rs::TS; use uuid::Uuid; pub const APPROVAL_TIMEOUT_SECONDS: i64 = 3600; // 1 hour -pub const EXIT_PLAN_MODE_TOOL_NAME: &str = "ExitPlanMode"; #[derive(Debug, Clone, Serialize, Deserialize, TS)] pub struct ApprovalRequest { pub id: String, pub tool_name: String, pub tool_input: serde_json::Value, - pub session_id: String, - pub tool_call_id: Option, + pub tool_call_id: String, pub execution_process_id: Uuid, pub created_at: DateTime, pub timeout_at: DateTime, @@ -25,7 +23,6 @@ impl ApprovalRequest { id: Uuid::new_v4().to_string(), tool_name: request.tool_name, tool_input: request.tool_input, - session_id: request.session_id, tool_call_id: request.tool_call_id, execution_process_id, created_at: now, @@ -39,8 +36,7 @@ impl ApprovalRequest { pub struct CreateApprovalRequest { pub tool_name: String, pub tool_input: serde_json::Value, - pub session_id: String, - pub tool_call_id: Option, + pub tool_call_id: String, } #[derive(Debug, Clone, Serialize, Deserialize, TS)] @@ -62,13 +58,3 @@ pub struct ApprovalResponse { pub execution_process_id: Uuid, pub status: ApprovalStatus, } - -#[derive(Debug, Clone, Serialize, Deserialize, TS)] -#[ts(export)] -pub struct ApprovalPendingInfo { - pub approval_id: String, - pub execution_process_id: Uuid, - pub tool_name: String, - pub requested_at: DateTime, - pub timeout_at: DateTime, -} diff --git a/shared/types.ts b/shared/types.ts index 1514b1f9..4ced17e1 100644 --- a/shared/types.ts +++ b/shared/types.ts @@ -322,7 +322,7 @@ export type PatchType = { "type": "NORMALIZED_ENTRY", "content": NormalizedEntry export type ApprovalStatus = { "status": "pending" } | { "status": "approved" } | { "status": "denied", reason?: string, } | { "status": "timed_out" }; -export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, session_id: string, tool_call_id: string | null, }; +export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, tool_call_id: string, }; export type ApprovalResponse = { execution_process_id: string, status: ApprovalStatus, };