Claude approval refactor (#1080)

* WIP claude approvals

* Use canusetool

* Remove old exitplanmode approvals

* WIP approvals

* types

* Remove bloat

* Cleanup, exit on finish

* Approval messages, cleanup

* Cleanup

* Fix msg types

* Lint fmt

* Cleanup

* Send deny

* add missing timeout to hooks

* FIx timeout issue

* Cleanup

* Error handling, log writer bugs

* Remove deprecated approbal endpoints

* Remove tool matching strategies in favour of only id based matching

* remove register session, parse result at protocol level

* Remove circular peer, remove unneeded trait

* Types
This commit is contained in:
Alex Netsch
2025-10-28 15:36:47 +00:00
committed by GitHub
parent a70a7bfbad
commit e06dd1f6dc
13 changed files with 845 additions and 781 deletions

View File

@@ -26,9 +26,6 @@ impl ExecutorApprovalError {
/// Abstraction for executor approval backends.
#[async_trait]
pub trait ExecutorApprovalService: Send + Sync {
/// Registers the session identifier associated with subsequent approval requests.
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError>;
/// Requests approval for a tool invocation and waits for the final decision.
async fn request_tool_approval(
&self,
@@ -43,10 +40,6 @@ pub struct NoopExecutorApprovalService;
#[async_trait]
impl ExecutorApprovalService for NoopExecutorApprovalService {
async fn register_session(&self, _session_id: &str) -> Result<(), ExecutorApprovalError> {
Ok(())
}
async fn request_tool_approval(
&self,
_tool_name: &str,

View File

@@ -1,5 +1,8 @@
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
// SDK submodules
pub mod client;
pub mod protocol;
pub mod types;
use std::{collections::HashMap, path::Path, process::Stdio, sync::Arc};
use async_trait::async_trait;
@@ -7,42 +10,33 @@ use command_group::AsyncCommandGroup;
use futures::StreamExt;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, process::Command, sync::OnceCell};
use tokio::process::Command;
use ts_rs::TS;
use workspace_utils::{
approvals::APPROVAL_TIMEOUT_SECONDS,
approvals::ApprovalStatus,
diff::{concatenate_diff_hunks, create_unified_diff, create_unified_diff_hunk},
log_msg::LogMsg,
msg_store::MsgStore,
path::make_path_relative,
port_file::read_port_file,
shell::get_shell_command,
};
use self::{client::ClaudeAgentClient, protocol::ProtocolPeer, types::PermissionMode};
use crate::{
approvals::ExecutorApprovalService,
command::{CmdOverrides, CommandBuilder, apply_overrides},
executors::{AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor},
executors::{
AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor,
codex::client::LogWriter,
},
logs::{
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem, ToolStatus,
stderr_processor::normalize_stderr_logs,
utils::{EntryIndexProvider, patch::ConversationPatch},
},
stdout_dup::create_stdout_pipe_writer,
};
static BACKEND_PORT: OnceCell<u16> = OnceCell::const_new();
async fn get_backend_port() -> std::io::Result<u16> {
BACKEND_PORT
.get_or_try_init(|| async { read_port_file("vibe-kanban").await })
.await
.copied()
}
const CONFIRM_HOOK_SCRIPT: &str = include_str!("./hooks/confirm.py");
/// Natural language marker we add in our Python hook to denote user feedback
/// This marker is added by our confirm.py hook script and is robust to Claude Code format changes
const USER_FEEDBACK_MARKER: &str = "User feedback: ";
fn base_command(claude_code_router: bool) -> &'static str {
if claude_code_router {
"npx -y @musistudio/claude-code-router@1.0.58 code"
@@ -51,7 +45,10 @@ fn base_command(claude_code_router: bool) -> &'static str {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema)]
use derivative::Derivative;
#[derive(Derivative, Clone, Serialize, Deserialize, TS, JsonSchema)]
#[derivative(Debug, PartialEq)]
pub struct ClaudeCode {
#[serde(default)]
pub append_prompt: AppendPrompt,
@@ -67,6 +64,11 @@ pub struct ClaudeCode {
pub dangerously_skip_permissions: Option<bool>,
#[serde(flatten)]
pub cmd: CmdOverrides,
#[serde(skip)]
#[ts(skip)]
#[derivative(Debug = "ignore", PartialEq = "ignore")]
approvals_service: Option<Arc<dyn ExecutorApprovalService>>,
}
impl ClaudeCode {
@@ -87,30 +89,14 @@ impl ClaudeCode {
if plan && approvals {
tracing::warn!("Both plan and approvals are enabled. Plan will take precedence.");
}
if plan {
builder = builder.extend_params(["--permission-mode=plan"]);
}
if plan || approvals {
match settings_json(plan).await {
// TODO: Avoid quoting
Ok(settings) => match shlex::try_quote(&settings) {
Ok(quoted) => {
builder = builder.extend_params(["--settings", &quoted]);
}
Err(e) => {
tracing::error!("Failed to quote approvals JSON for --settings: {e}");
}
},
Err(e) => {
tracing::error!(
"Failed to generate approvals JSON. Not running approvals: {e}"
);
}
}
// Enable bypass at startup, otherwise we cannot change to it after exiting plan mode
builder = builder.extend_params(["--permission-prompt-tool=stdio"]);
builder = builder.extend_params([format!(
"--permission-mode={}",
PermissionMode::BypassPermissions
)]);
}
if self.dangerously_skip_permissions.unwrap_or(false) {
builder = builder.extend_params(["--dangerously-skip-permissions"]);
}
@@ -120,49 +106,58 @@ impl ClaudeCode {
builder = builder.extend_params([
"--verbose",
"--output-format=stream-json",
"--input-format=stream-json",
"--include-partial-messages",
]);
apply_overrides(builder, &self.cmd)
}
pub fn permission_mode(&self) -> PermissionMode {
if self.plan.unwrap_or(false) {
PermissionMode::Plan
} else if self.approvals.unwrap_or(false) {
PermissionMode::Default
} else {
PermissionMode::BypassPermissions
}
}
pub fn get_hooks(&self) -> Option<serde_json::Value> {
if self.plan.unwrap_or(false) {
Some(serde_json::json!({
"PreToolUse": [
{
"matcher": "^ExitPlanMode$",
"hookCallbackIds": ["tool_approval"],
}
]
}))
} else if self.approvals.unwrap_or(false) {
Some(serde_json::json!({
"PreToolUse": [
{
"matcher": "^(?!(Glob|Grep|NotebookRead|Read|Task|TodoWrite)$).*",
"hookCallbackIds": ["tool_approval"],
}
]
}))
} else {
None
}
}
}
#[async_trait]
impl StandardCodingAgentExecutor for ClaudeCode {
fn use_approvals(&mut self, approvals: Arc<dyn ExecutorApprovalService>) {
self.approvals_service = Some(approvals);
}
async fn spawn(&self, current_dir: &Path, prompt: &str) -> Result<SpawnedChild, ExecutorError> {
let (shell_cmd, shell_arg) = get_shell_command();
let command_builder = self.build_command_builder().await;
let mut base_command = command_builder.build_initial();
if self.plan.unwrap_or(false) {
base_command = create_watchkill_script(&base_command);
}
if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) {
write_python_hook(current_dir).await?
}
let combined_prompt = self.append_prompt.combine_prompt(prompt);
let mut command = Command::new(shell_cmd);
command
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(current_dir)
.arg(shell_arg)
.arg(&base_command);
let mut child = command.group_spawn()?;
// Feed the prompt in, then close the pipe so Claude sees EOF
if let Some(mut stdin) = child.inner().stdin.take() {
stdin.write_all(combined_prompt.as_bytes()).await?;
stdin.shutdown().await?;
}
Ok(child.into())
let base_command = command_builder.build_initial();
self.spawn_internal(current_dir, prompt, base_command).await
}
async fn spawn_follow_up(
@@ -171,44 +166,13 @@ impl StandardCodingAgentExecutor for ClaudeCode {
prompt: &str,
session_id: &str,
) -> Result<SpawnedChild, ExecutorError> {
let (shell_cmd, shell_arg) = get_shell_command();
let command_builder = self.build_command_builder().await;
// Build follow-up command with --resume {session_id}
let mut base_command = command_builder.build_follow_up(&[
let base_command = command_builder.build_follow_up(&[
"--fork-session".to_string(),
"--resume".to_string(),
session_id.to_string(),
]);
if self.plan.unwrap_or(false) {
base_command = create_watchkill_script(&base_command);
}
if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) {
write_python_hook(current_dir).await?
}
let combined_prompt = self.append_prompt.combine_prompt(prompt);
let mut command = Command::new(shell_cmd);
command
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(current_dir)
.arg(shell_arg)
.arg(&base_command);
let mut child = command.group_spawn()?;
// Feed the followup prompt in, then close the pipe
if let Some(mut stdin) = child.inner().stdin.take() {
stdin.write_all(combined_prompt.as_bytes()).await?;
stdin.shutdown().await?;
}
Ok(child.into())
self.spawn_internal(current_dir, prompt, base_command).await
}
fn normalize_logs(&self, msg_store: Arc<MsgStore>, current_dir: &Path) {
@@ -232,116 +196,74 @@ impl StandardCodingAgentExecutor for ClaudeCode {
}
}
async fn write_python_hook(current_dir: &Path) -> Result<(), ExecutorError> {
let hooks_dir = current_dir.join(".claude").join("hooks");
tokio::fs::create_dir_all(&hooks_dir).await?;
let hook_path = hooks_dir.join("confirm.py");
impl ClaudeCode {
async fn spawn_internal(
&self,
current_dir: &Path,
prompt: &str,
base_command: String,
) -> Result<SpawnedChild, ExecutorError> {
let (shell_cmd, shell_arg) = get_shell_command();
let combined_prompt = self.append_prompt.combine_prompt(prompt);
let mut file = tokio::fs::File::create(&hook_path).await?;
file.write_all(CONFIRM_HOOK_SCRIPT.as_bytes()).await?;
file.flush().await?;
let mut command = Command::new(shell_cmd);
command
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(current_dir)
.arg(shell_arg)
.arg(&base_command);
// TODO: Handle Windows permissioning
#[cfg(unix)]
{
let perm = std::fs::Permissions::from_mode(0o755);
tokio::fs::set_permissions(&hook_path, perm).await?;
let mut child = command.group_spawn()?;
let child_stdout = child.inner().stdout.take().ok_or_else(|| {
ExecutorError::Io(std::io::Error::other("Claude Code missing stdout"))
})?;
let child_stdin =
child.inner().stdin.take().ok_or_else(|| {
ExecutorError::Io(std::io::Error::other("Claude Code missing stdin"))
})?;
let new_stdout = create_stdout_pipe_writer(&mut child)?;
let permission_mode = self.permission_mode();
let hooks = self.get_hooks();
// Spawn task to handle the SDK client with control protocol
let prompt_clone = combined_prompt.clone();
let approvals_clone = self.approvals_service.clone();
tokio::spawn(async move {
let log_writer = LogWriter::new(new_stdout);
let client = ClaudeAgentClient::new(log_writer.clone(), approvals_clone);
let protocol_peer = ProtocolPeer::spawn(child_stdin, child_stdout, client.clone());
// Initialize control protocol
if let Err(e) = protocol_peer.initialize(hooks).await {
tracing::error!("Failed to initialize control protocol: {e}");
let _ = log_writer
.log_raw(&format!("Error: Failed to initialize - {e}"))
.await;
return;
}
if let Err(e) = protocol_peer.set_permission_mode(permission_mode).await {
tracing::warn!("Failed to set permission mode to {permission_mode}: {e}");
}
// Send user message
if let Err(e) = protocol_peer.send_user_message(prompt_clone).await {
tracing::error!("Failed to send prompt: {e}");
let _ = log_writer
.log_raw(&format!("Error: Failed to send prompt - {e}"))
.await;
}
});
Ok(SpawnedChild {
child,
exit_signal: None,
})
}
// ignore the confirm.py script
let gitignore_path = hooks_dir.join(".gitignore");
if !tokio::fs::try_exists(&gitignore_path).await? {
let mut gitignore_file = tokio::fs::File::create(&gitignore_path).await?;
gitignore_file
.write_all(b"confirm.py\n.gitignore\n")
.await?;
gitignore_file.flush().await?;
}
Ok(())
}
// Configure settings json
async fn settings_json(plan: bool) -> Result<String, std::io::Error> {
let backend_port = get_backend_port().await?;
let backend_timeout = APPROVAL_TIMEOUT_SECONDS + 5; // add buffer
let matcher = if plan {
"^ExitPlanMode$"
} else {
"^(?!(Glob|Grep|NotebookRead|Read|Task|TodoWrite)$).*"
};
Ok(serde_json::json!({
"hooks": {
"PreToolUse": [
{
"matcher": matcher,
"hooks": [
{
"type": "command",
"command": format!("$CLAUDE_PROJECT_DIR/.claude/hooks/confirm.py --timeout-seconds {backend_timeout} --poll-interval 5 --backend-port {backend_port} --feedback-marker '{USER_FEEDBACK_MARKER}'"),
"timeout": backend_timeout + 10
}
]
}
]
}
})
.to_string())
}
fn create_watchkill_script(command: &str) -> String {
// Hack: we concatenate so that Claude doesn't trigger the watchkill when reading this file
// during development, since it contains the stop phrase
let claude_plan_stop_indicator = concat!("Approval ", "request timed out");
let cmd = shlex::try_quote(command).unwrap().to_string();
format!(
r#"#!/usr/bin/env bash
set -euo pipefail
word="{claude_plan_stop_indicator}"
exit_code=0
while IFS= read -r line; do
printf '%s\n' "$line"
if [[ $line == *"$word"* ]]; then
exit 0
fi
done < <(bash -lc {cmd} <&0 2>&1)
exit_code=${{PIPESTATUS[0]}}
exit "$exit_code"
"#
)
}
/// Extract user denial reason from tool result error messages
/// Our confirm.py hook prefixes user feedback with "User feedback: " for easy extraction
/// Supports both string content and Claude's array format: [{"type":"text","text":"..."}]
fn extract_denial_reason(content: &serde_json::Value) -> Option<String> {
// First try to parse as string
let content_str = if let Some(s) = content.as_str() {
s.to_string()
} else if let Ok(items) =
serde_json::from_value::<Vec<ClaudeToolResultTextItem>>(content.clone())
{
// Handle array format: [{"type":"text","text":"..."}]
items
.into_iter()
.map(|item| item.text)
.collect::<Vec<_>>()
.join("\n")
} else {
// Try to serialize the value as a string
content.to_string()
};
content_str
.split_once(USER_FEEDBACK_MARKER)
.map(|(_, rest)| rest.trim().to_string())
.filter(|s| !s.is_empty())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -491,6 +413,7 @@ impl ClaudeLogProcessor {
ClaudeJson::ToolResult { session_id, .. } => session_id.clone(),
ClaudeJson::StreamEvent { session_id, .. } => session_id.clone(),
ClaudeJson::Result { session_id, .. } => session_id.clone(),
ClaudeJson::ApprovalResponse { .. } => None,
ClaudeJson::Unknown { .. } => None,
}
}
@@ -580,12 +503,22 @@ impl ClaudeLogProcessor {
serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null),
),
}),
ClaudeContentItem::ToolUse { tool_data, .. } => {
ClaudeContentItem::ToolUse { tool_data, id } => {
let name = tool_data.get_name();
let action_type = Self::extract_action_type(tool_data, worktree_path);
let content =
Self::generate_concise_content(tool_data, &action_type, worktree_path);
// Create metadata with tool_call_id for approval matching
let mut metadata =
serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null);
if let Some(obj) = metadata.as_object_mut() {
obj.insert(
"tool_call_id".to_string(),
serde_json::Value::String(id.clone()),
);
}
Some(NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -594,9 +527,7 @@ impl ClaudeLogProcessor {
status: ToolStatus::Created,
},
content,
metadata: Some(
serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null),
),
metadata: Some(metadata),
})
}
ClaudeContentItem::ToolResult { .. } => {
@@ -837,6 +768,17 @@ impl ClaudeLogProcessor {
&action_type,
worktree_path,
);
// Create metadata with tool_call_id for approval matching
let mut metadata =
serde_json::to_value(item).unwrap_or(serde_json::Value::Null);
if let Some(obj) = metadata.as_object_mut() {
obj.insert(
"tool_call_id".to_string(),
serde_json::Value::String(id.clone()),
);
}
let entry = NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -845,9 +787,7 @@ impl ClaudeLogProcessor {
status: ToolStatus::Created,
},
content: content_text.clone(),
metadata: Some(
serde_json::to_value(item).unwrap_or(serde_json::Value::Null),
),
metadata: Some(metadata),
};
let is_new = entry_index.is_none();
let id_num = entry_index.unwrap_or_else(|| entry_index_provider.next());
@@ -930,7 +870,7 @@ impl ClaudeLogProcessor {
{
let is_command = matches!(info.tool_data, ClaudeToolData::Bash { .. });
let display_tool_name = if is_command {
let _display_tool_name = if is_command {
info.tool_name.clone()
} else {
let raw_name = info.tool_data.get_name().to_string();
@@ -1048,24 +988,8 @@ impl ClaudeLogProcessor {
};
patches.push(ConversationPatch::replace(info.entry_index, entry));
}
if is_error.unwrap_or(false)
&& let Some(denial_reason) = extract_denial_reason(content)
{
let user_feedback = NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::UserFeedback {
denied_tool: display_tool_name.clone(),
},
content: denial_reason,
metadata: None,
};
let feedback_index = entry_index_provider.next();
patches.push(ConversationPatch::add_normalized_entry(
feedback_index,
user_feedback,
));
}
// Note: With control protocol, denials are handled via protocol messages
// rather than error content parsing
}
}
}
@@ -1166,6 +1090,40 @@ impl ClaudeLogProcessor {
patches.push(ConversationPatch::add_normalized_entry(idx, entry));
}
}
ClaudeJson::ApprovalResponse {
call_id: _,
tool_name,
approval_status,
} => {
// Convert denials and timeouts to visible entries (matching Codex behavior)
let entry_opt = match approval_status {
ApprovalStatus::Pending => None,
ApprovalStatus::Approved => None,
ApprovalStatus::Denied { reason } => Some(NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::UserFeedback {
denied_tool: tool_name.clone(),
},
content: reason
.as_ref()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "User denied this tool use request".to_string()),
metadata: None,
}),
ApprovalStatus::TimedOut => Some(NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ErrorMessage,
content: format!("Approval timed out for tool {tool_name}"),
metadata: None,
}),
};
if let Some(entry) = entry_opt {
let idx = entry_index_provider.next();
patches.push(ConversationPatch::add_normalized_entry(idx, entry));
}
}
ClaudeJson::Unknown { data } => {
let entry = NormalizedEntry {
timestamp: None,
@@ -1434,7 +1392,7 @@ impl StreamingContentState {
}
// Data structures for parsing Claude's JSON output format
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum ClaudeJson {
#[serde(rename = "system")]
@@ -1497,6 +1455,12 @@ pub enum ClaudeJson {
#[serde(default, alias = "sessionId")]
session_id: Option<String>,
},
#[serde(rename = "approval_response")]
ApprovalResponse {
call_id: String,
tool_name: String,
approval_status: ApprovalStatus,
},
// Catch-all for unknown message types
#[serde(untagged)]
Unknown {
@@ -2006,6 +1970,7 @@ mod tests {
base_command_override: None,
additional_params: None,
},
approvals_service: None,
};
let msg_store = Arc::new(MsgStore::new());
let current_dir = std::path::PathBuf::from("/tmp/test-worktree");

View File

@@ -0,0 +1,206 @@
use std::sync::Arc;
use tokio::sync::Mutex;
use workspace_utils::approvals::ApprovalStatus;
use super::types::PermissionMode;
use crate::{
approvals::{ExecutorApprovalError, ExecutorApprovalService},
executors::{
ExecutorError,
claude::{
ClaudeJson,
types::{
PermissionResult, PermissionUpdate, PermissionUpdateDestination,
PermissionUpdateType,
},
},
codex::client::LogWriter,
},
};
const EXIT_PLAN_MODE_NAME: &str = "ExitPlanMode";
/// Claude Agent client with control protocol support
pub struct ClaudeAgentClient {
log_writer: LogWriter,
approvals: Option<Arc<dyn ExecutorApprovalService>>,
auto_approve: bool, // true when approvals is None
latest_unhandled_tool_use_id: Mutex<Option<String>>,
}
impl ClaudeAgentClient {
/// Create a new client with optional approval service
pub fn new(
log_writer: LogWriter,
approvals: Option<Arc<dyn ExecutorApprovalService>>,
) -> Arc<Self> {
let auto_approve = approvals.is_none();
Arc::new(Self {
log_writer,
approvals,
auto_approve,
latest_unhandled_tool_use_id: Mutex::new(None),
})
}
async fn set_latest_unhandled_tool_use_id(&self, tool_use_id: String) {
if self.latest_unhandled_tool_use_id.lock().await.is_some() {
tracing::warn!(
"Overwriting unhandled tool_use_id: {} with new tool_use_id: {}",
self.latest_unhandled_tool_use_id
.lock()
.await
.as_ref()
.unwrap(),
tool_use_id
);
}
let mut guard = self.latest_unhandled_tool_use_id.lock().await;
guard.replace(tool_use_id);
}
async fn handle_approval(
&self,
tool_use_id: String,
tool_name: String,
tool_input: serde_json::Value,
) -> Result<PermissionResult, ExecutorError> {
// Use approval service to request tool approval
let approval_service = self
.approvals
.as_ref()
.ok_or(ExecutorApprovalError::ServiceUnavailable)?;
let status = approval_service
.request_tool_approval(&tool_name, tool_input.clone(), &tool_use_id)
.await;
match status {
Ok(status) => {
// Log the approval response so we it appears in the executor logs
self.log_writer
.log_raw(&serde_json::to_string(&ClaudeJson::ApprovalResponse {
call_id: tool_use_id.clone(),
tool_name: tool_name.clone(),
approval_status: status.clone(),
})?)
.await?;
match status {
ApprovalStatus::Approved => {
if tool_name == EXIT_PLAN_MODE_NAME {
Ok(PermissionResult::Allow {
updated_input: tool_input,
updated_permissions: Some(vec![PermissionUpdate {
update_type: PermissionUpdateType::SetMode,
mode: Some(PermissionMode::BypassPermissions),
destination: PermissionUpdateDestination::Session,
}]),
})
} else {
Ok(PermissionResult::Allow {
updated_input: tool_input,
updated_permissions: None,
})
}
}
ApprovalStatus::Denied { reason } => {
let message = reason.unwrap_or("Denied by user".to_string());
Ok(PermissionResult::Deny {
message,
interrupt: Some(false),
})
}
ApprovalStatus::TimedOut => Ok(PermissionResult::Deny {
message: "Approval request timed out".to_string(),
interrupt: Some(false),
}),
ApprovalStatus::Pending => Ok(PermissionResult::Deny {
message: "Approval still pending (unexpected)".to_string(),
interrupt: Some(false),
}),
}
}
Err(e) => {
tracing::error!("Tool approval request failed: {e}");
Ok(PermissionResult::Deny {
message: "Tool approval request failed".to_string(),
interrupt: Some(false),
})
}
}
}
pub async fn on_can_use_tool(
&self,
tool_name: String,
input: serde_json::Value,
_permission_suggestions: Option<Vec<PermissionUpdate>>,
) -> Result<PermissionResult, ExecutorError> {
if self.auto_approve {
Ok(PermissionResult::Allow {
updated_input: input,
updated_permissions: None,
})
} else {
let latest_tool_use_id = {
let guard = self.latest_unhandled_tool_use_id.lock().await.take();
guard.clone()
};
if let Some(latest_tool_use_id) = latest_tool_use_id {
self.handle_approval(latest_tool_use_id, tool_name, input)
.await
} else {
// Auto approve tools with no matching tool_use_id.
// This rare edge case happens if a tool call triggers no hook callback,
// so no tool_use_id is available to match the approval request to.
tracing::warn!(
"No unhandled tool_use_id available for tool '{}', cannot request approval",
tool_name
);
Ok(PermissionResult::Allow {
updated_input: input,
updated_permissions: None,
})
}
}
}
pub async fn on_hook_callback(
&self,
_callback_id: String,
_input: serde_json::Value,
tool_use_id: Option<String>,
) -> Result<serde_json::Value, ExecutorError> {
if self.auto_approve {
Ok(serde_json::json!({
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "allow",
"permissionDecisionReason": "Auto-approved by SDK"
}
}))
} else {
// Hook callbacks is only used to store tool_use_id for later approval request
// Both hook callback and can_use_tool are needed.
// - Hook callbacks have a constant 60s timeout, so cannot be used for long approvals
// - can_use_tool does not provide tool_use_id, so cannot be used alone
// Together they allow matching approval requests to tool uses.
// This works because `ask` decision in hook callback triggers a can_use_tool request
// https://docs.claude.com/en/api/agent-sdk/permissions#permission-flow-diagram
if let Some(tool_use_id) = tool_use_id.clone() {
self.set_latest_unhandled_tool_use_id(tool_use_id).await;
}
Ok(serde_json::json!({
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "ask",
"permissionDecisionReason": "Forwarding to canusetool service"
}
}))
}
}
pub async fn on_non_control(&self, line: &str) -> Result<(), ExecutorError> {
// Forward all non-control messages to stdout
self.log_writer.log_raw(line).await
}
}

View File

@@ -0,0 +1,200 @@
use std::sync::Arc;
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
process::{ChildStdin, ChildStdout},
sync::Mutex,
};
use super::types::{
CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType,
SDKControlRequestMessage,
};
use crate::executors::{
ExecutorError,
claude::{
client::ClaudeAgentClient,
types::{PermissionMode, SDKControlRequestType},
},
};
/// Handles bidirectional control protocol communication
#[derive(Clone)]
pub struct ProtocolPeer {
stdin: Arc<Mutex<ChildStdin>>,
}
impl ProtocolPeer {
pub fn spawn(stdin: ChildStdin, stdout: ChildStdout, client: Arc<ClaudeAgentClient>) -> Self {
let peer = Self {
stdin: Arc::new(Mutex::new(stdin)),
};
let reader_peer = peer.clone();
tokio::spawn(async move {
if let Err(e) = reader_peer.read_loop(stdout, client).await {
tracing::error!("Protocol reader loop error: {}", e);
}
});
peer
}
async fn read_loop(
&self,
stdout: ChildStdout,
client: Arc<ClaudeAgentClient>,
) -> Result<(), ExecutorError> {
let mut reader = BufReader::new(stdout);
let mut buffer = String::new();
loop {
buffer.clear();
match reader.read_line(&mut buffer).await {
Ok(0) => break, // EOF
Ok(_) => {
let line = buffer.trim();
if line.is_empty() {
continue;
}
// Parse message using typed enum
match serde_json::from_str::<CLIMessage>(line) {
Ok(CLIMessage::ControlRequest {
request_id,
request,
}) => {
self.handle_control_request(&client, request_id, request)
.await;
}
Ok(CLIMessage::ControlResponse { .. }) => {}
Ok(CLIMessage::Result(_)) => {
client.on_non_control(line).await?;
break;
}
_ => {
client.on_non_control(line).await?;
}
}
}
Err(e) => {
tracing::error!("Error reading stdout: {}", e);
break;
}
}
}
Ok(())
}
async fn handle_control_request(
&self,
client: &Arc<ClaudeAgentClient>,
request_id: String,
request: ControlRequestType,
) {
match request {
ControlRequestType::CanUseTool {
tool_name,
input,
permission_suggestions,
} => {
match client
.on_can_use_tool(tool_name, input, permission_suggestions)
.await
{
Ok(result) => {
if let Err(e) = self
.send_hook_response(request_id, serde_json::to_value(result).unwrap())
.await
{
tracing::error!("Failed to send permission result: {e}");
}
}
Err(e) => {
tracing::error!("Error in on_can_use_tool: {e}");
if let Err(e2) = self.send_error(request_id, e.to_string()).await {
tracing::error!("Failed to send error response: {e2}");
}
}
}
}
ControlRequestType::HookCallback {
callback_id,
input,
tool_use_id,
} => {
match client
.on_hook_callback(callback_id, input, tool_use_id)
.await
{
Ok(hook_output) => {
if let Err(e) = self.send_hook_response(request_id, hook_output).await {
tracing::error!("Failed to send hook callback result: {e}");
}
}
Err(e) => {
tracing::error!("Error in on_hook_callback: {e}");
if let Err(e2) = self.send_error(request_id, e.to_string()).await {
tracing::error!("Failed to send error response: {e2}");
}
}
}
}
}
}
pub async fn send_hook_response(
&self,
request_id: String,
hook_output: serde_json::Value,
) -> Result<(), ExecutorError> {
self.send_json(&ControlResponseMessage::new(ControlResponseType::Success {
request_id,
response: Some(hook_output),
}))
.await
}
/// Send error response to CLI
async fn send_error(&self, request_id: String, error: String) -> Result<(), ExecutorError> {
self.send_json(&ControlResponseMessage::new(ControlResponseType::Error {
request_id,
error: Some(error),
}))
.await
}
/// Send JSON message to stdin
async fn send_json<T: serde::Serialize>(&self, message: &T) -> Result<(), ExecutorError> {
let json = serde_json::to_string(message)?;
let mut stdin = self.stdin.lock().await;
stdin.write_all(json.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
Ok(())
}
pub async fn send_user_message(&self, content: String) -> Result<(), ExecutorError> {
let message = serde_json::json!({
"type": "user",
"message": {
"role": "user",
"content": content
}
});
self.send_json(&message).await
}
pub async fn initialize(&self, hooks: Option<serde_json::Value>) -> Result<(), ExecutorError> {
self.send_json(&SDKControlRequestMessage::new(
SDKControlRequestType::Initialize { hooks },
))
.await
}
pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<(), ExecutorError> {
self.send_json(&SDKControlRequestMessage::new(
SDKControlRequestType::SetPermissionMode { mode },
))
.await
}
}

View File

@@ -0,0 +1,178 @@
//! Type definitions for Claude Code control protocol
//!
//! Similar to: https://github.com/ZhangHanDong/claude-code-api-rs/blob/main/claude-code-sdk-rs/src/types.rs
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Top-level message types from CLI stdout
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CLIMessage {
ControlRequest {
request_id: String,
request: ControlRequestType,
},
ControlResponse {
response: ControlResponseType,
},
Result(serde_json::Value),
#[serde(untagged)]
Other(serde_json::Value),
}
/// Control request from SDK to CLI (outgoing)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SDKControlRequestMessage {
#[serde(rename = "type")]
message_type: String, // Always "control_request"
pub request_id: String,
pub request: SDKControlRequestType,
}
impl SDKControlRequestMessage {
pub fn new(request: SDKControlRequestType) -> Self {
use uuid::Uuid;
Self {
message_type: "control_request".to_string(),
request_id: Uuid::new_v4().to_string(),
request,
}
}
}
/// Control response from SDK to CLI (outgoing)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ControlResponseMessage {
#[serde(rename = "type")]
message_type: String, // Always "control_response"
pub response: ControlResponseType,
}
impl ControlResponseMessage {
pub fn new(response: ControlResponseType) -> Self {
Self {
message_type: "control_response".to_string(),
response,
}
}
}
/// Types of control requests
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
pub enum ControlRequestType {
CanUseTool {
tool_name: String,
input: Value,
#[serde(skip_serializing_if = "Option::is_none")]
permission_suggestions: Option<Vec<PermissionUpdate>>,
},
HookCallback {
#[serde(rename = "callback_id")]
callback_id: String,
input: Value,
#[serde(skip_serializing_if = "Option::is_none")]
tool_use_id: Option<String>,
},
}
/// Result of permission check
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "behavior", rename_all = "camelCase")]
pub enum PermissionResult {
Allow {
#[serde(rename = "updatedInput")]
updated_input: Value,
#[serde(skip_serializing_if = "Option::is_none", rename = "updatedPermissions")]
updated_permissions: Option<Vec<PermissionUpdate>>,
},
Deny {
message: String,
#[serde(skip_serializing_if = "Option::is_none")]
interrupt: Option<bool>,
},
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PermissionUpdateType {
SetMode,
AddRules,
RemoveRules,
ClearRules,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PermissionUpdateDestination {
Session,
UserSettings,
ProjectSettings,
LocalSettings,
}
/// Permission update operation
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionUpdate {
#[serde(rename = "type")]
pub update_type: PermissionUpdateType,
#[serde(skip_serializing_if = "Option::is_none")]
pub mode: Option<PermissionMode>,
pub destination: PermissionUpdateDestination,
}
/// Control response from SDK to CLI
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
pub enum ControlResponseType {
Success {
request_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
response: Option<Value>,
},
Error {
request_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
pub enum SDKControlRequestType {
SetPermissionMode {
mode: PermissionMode,
},
Initialize {
#[serde(skip_serializing_if = "Option::is_none")]
hooks: Option<Value>,
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PermissionMode {
Default,
AcceptEdits,
Plan,
BypassPermissions,
}
impl PermissionMode {
pub fn as_str(&self) -> &'static str {
match self {
Self::Default => "default",
Self::AcceptEdits => "acceptEdits",
Self::Plan => "plan",
Self::BypassPermissions => "bypassPermissions",
}
}
}
impl std::fmt::Display for PermissionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}

View File

@@ -235,12 +235,6 @@ impl AppServerClient {
let mut guard = self.conversation_id.lock().await;
guard.replace(*conversation_id);
}
if let Some(approvals) = self.approvals.as_ref() {
approvals
.register_session(&conversation_id.to_string())
.await
.map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?;
}
self.flush_pending_feedback().await;
Ok(())
}

View File

@@ -1,179 +0,0 @@
#!/usr/bin/env python3
import argparse
import json
import sys
import time
import urllib.error
import urllib.request
from typing import Optional
def json_error(reason: Optional[str], feedback_marker: Optional[str] = None) -> None:
"""Emit a deny PreToolUse JSON to stdout and exit(0)."""
# Prefix user feedback with marker for extraction if provided
formatted_reason = reason
if reason and feedback_marker:
formatted_reason = f"{feedback_marker}{reason}"
payload = {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": formatted_reason,
}
}
print(json.dumps(payload, ensure_ascii=False))
sys.exit(0)
def json_success() -> None:
payload = {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "allow",
},
"suppressOutput": True,
}
print(json.dumps(payload, ensure_ascii=False))
sys.exit(0)
def http_post_json(url: str, body: dict) -> dict:
data = json.dumps(body).encode("utf-8")
req = urllib.request.Request(
url, data=data, headers={"Content-Type": "application/json"}, method="POST"
)
try:
with urllib.request.urlopen(req, timeout=10) as resp:
return json.loads(resp.read().decode("utf-8") or "{}")
except (
urllib.error.HTTPError,
urllib.error.URLError,
TimeoutError,
json.JSONDecodeError,
) as e:
json_error(
f"Failed to create approval request. Backend may be unavailable. ({e})"
)
raise # unreachable
def http_get_json(url: str) -> dict:
req = urllib.request.Request(url, method="GET")
try:
with urllib.request.urlopen(req, timeout=10) as resp:
return json.loads(resp.read().decode("utf-8") or "{}")
except (
urllib.error.HTTPError,
urllib.error.URLError,
TimeoutError,
json.JSONDecodeError,
) as e:
json_error(f"Lost connection to approval backend: {e}")
raise
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="PreToolUse approval gate. All parameters are passed via CLI."
)
parser.add_argument(
"-t",
"--timeout-seconds",
type=int,
required=True,
help="Maximum time to wait for approval before timing out (seconds).",
)
parser.add_argument(
"-p",
"--poll-interval",
type=int,
required=True,
help="Seconds between polling the backend for status.",
)
parser.add_argument(
"-b",
"--backend-port",
type=int,
required=True,
help="Port of the approval backend running on 127.0.0.1.",
)
parser.add_argument(
"-m",
"--feedback-marker",
type=str,
required=True,
help="Marker prefix for user feedback messages.",
)
args = parser.parse_args()
if args.timeout_seconds <= 0:
parser.error("--timeout-seconds must be a positive integer")
if args.poll_interval <= 0:
parser.error("--poll-interval must be a positive integer")
if args.poll_interval > args.timeout_seconds:
parser.error("--poll-interval cannot be greater than --timeout-seconds")
return args
def main():
args = parse_args()
port = args.backend_port
url = f"http://127.0.0.1:{port}"
create_endpoint = f"{url}/api/approvals/create"
try:
raw_payload = sys.stdin.read()
incoming = json.loads(raw_payload or "{}")
except json.JSONDecodeError:
json_error("Invalid JSON payload on stdin")
tool_name = incoming.get("tool_name")
tool_input = incoming.get("tool_input")
session_id = incoming.get("session_id", "unknown")
create_payload = {
"tool_name": tool_name,
"tool_input": tool_input,
"session_id": session_id,
}
response = http_post_json(create_endpoint, create_payload)
approval_id = response.get("id")
if not approval_id:
json_error("Invalid response from approval backend")
print(
f"Approval request created: {approval_id}. Waiting for user response...",
file=sys.stderr,
)
elapsed = 0
while elapsed < args.timeout_seconds:
result = http_get_json(f"{url}/api/approvals/{approval_id}/status")
status = result.get("status")
if status == "approved":
json_success()
elif status == "denied":
reason = result.get("reason")
json_error(reason, args.feedback_marker)
elif status == "timed_out":
# concat to avoid triggering the watchkill script
json_error(
"Approval request" + f" timed out after {args.timeout_seconds} seconds"
)
elif status == "pending":
time.sleep(args.poll_interval)
elapsed += args.poll_interval
else:
json_error(f"Unknown approval status: {status}")
# concat to avoid triggering the watchkill script
json_error("Approval request"+ f" timed out after {args.timeout_seconds} seconds")
if __name__ == "__main__":
main()

View File

@@ -809,11 +809,13 @@ impl ContainerService for LocalContainerService {
let approvals_service: Arc<dyn ExecutorApprovalService> =
match executor_action.base_executor() {
Some(BaseCodingAgent::Codex) => ExecutorApprovalBridge::new(
self.approvals.clone(),
self.db.clone(),
execution_process.id,
),
Some(BaseCodingAgent::Codex) | Some(BaseCodingAgent::ClaudeCode) => {
ExecutorApprovalBridge::new(
self.approvals.clone(),
self.db.clone(),
execution_process.id,
)
}
_ => Arc::new(NoopExecutorApprovalService {}),
};

View File

@@ -2,60 +2,13 @@ use axum::{
Json, Router,
extract::{Path, State},
http::StatusCode,
routing::{get, post},
routing::post,
};
use db::models::execution_process::ExecutionProcess;
use deployment::Deployment;
use services::services::container::ContainerService;
use utils::approvals::{
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus, CreateApprovalRequest,
EXIT_PLAN_MODE_TOOL_NAME,
};
use utils::approvals::{ApprovalResponse, ApprovalStatus};
use crate::DeploymentImpl;
pub async fn create_approval(
State(deployment): State<DeploymentImpl>,
Json(request): Json<CreateApprovalRequest>,
) -> Result<Json<ApprovalRequest>, StatusCode> {
let service = deployment.approvals();
match service
.create_from_session(&deployment.db().pool, request)
.await
{
Ok(approval) => {
deployment
.track_if_analytics_allowed(
"approval_created",
serde_json::json!({
"approval_id": approval.id,
"tool_name": &approval.tool_name,
"execution_process_id": approval.execution_process_id.to_string(),
}),
)
.await;
Ok(Json(approval))
}
Err(e) => {
tracing::error!("Failed to create approval: {:?}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
pub async fn get_approval_status(
State(deployment): State<DeploymentImpl>,
Path(id): Path<String>,
) -> Result<Json<ApprovalStatus>, StatusCode> {
let service = deployment.approvals();
match service.status(&id).await {
Some(status) => Ok(Json(status)),
None => Err(StatusCode::NOT_FOUND),
}
}
pub async fn respond_to_approval(
State(deployment): State<DeploymentImpl>,
Path(id): Path<String>,
@@ -77,21 +30,6 @@ pub async fn respond_to_approval(
)
.await;
if matches!(status, ApprovalStatus::Approved)
&& context.tool_name == EXIT_PLAN_MODE_TOOL_NAME
// If exiting plan mode, automatically start a new execution process with different
// permissions
&& let Ok(ctx) = ExecutionProcess::load_context(
&deployment.db().pool,
context.execution_process_id,
)
.await
&& let Err(e) = deployment.container().exit_plan_mode_tool(ctx).await
{
tracing::error!("failed to exit plan mode: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
Ok(Json(status))
}
Err(e) => {
@@ -101,18 +39,6 @@ pub async fn respond_to_approval(
}
}
pub async fn get_pending_approvals(
State(deployment): State<DeploymentImpl>,
) -> Json<Vec<ApprovalPendingInfo>> {
let service = deployment.approvals();
let approvals = service.pending().await;
Json(approvals)
}
pub fn router() -> Router<DeploymentImpl> {
Router::new()
.route("/approvals/create", post(create_approval))
.route("/approvals/{id}/status", get(get_approval_status))
.route("/approvals/{id}/respond", post(respond_to_approval))
.route("/approvals/pending", get(get_pending_approvals))
Router::new().route("/approvals/{id}/respond", post(respond_to_approval))
}

View File

@@ -2,11 +2,9 @@ pub mod executor_approvals;
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use db::models::{
execution_process::ExecutionProcess,
executor_session::ExecutorSession,
task::{Task, TaskStatus},
};
use executors::{
@@ -21,10 +19,7 @@ use sqlx::{Error as SqlxError, SqlitePool};
use thiserror::Error;
use tokio::sync::{RwLock, oneshot};
use utils::{
approvals::{
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus,
CreateApprovalRequest,
},
approvals::{ApprovalRequest, ApprovalResponse, ApprovalStatus},
log_msg::LogMsg,
msg_store::MsgStore,
};
@@ -36,8 +31,6 @@ struct PendingApproval {
entry: NormalizedEntry,
execution_process_id: Uuid,
tool_name: String,
requested_at: DateTime<Utc>,
timeout_at: DateTime<Utc>,
response_tx: oneshot::Sender<ApprovalStatus>,
}
@@ -81,7 +74,7 @@ impl Approvals {
}
}
async fn create_internal(
pub async fn create_with_waiter(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
@@ -94,12 +87,7 @@ impl Approvals {
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
// Find the matching tool use entry by name and input
let matching_tool = find_matching_tool_use(
store.clone(),
&request.tool_name,
&request.tool_input,
request.tool_call_id.as_deref(),
);
let matching_tool = find_matching_tool_use(store.clone(), &request.tool_call_id);
if let Some((idx, matching_tool)) = matching_tool {
let approval_entry = matching_tool
@@ -118,8 +106,6 @@ impl Approvals {
entry: matching_tool,
execution_process_id: request.execution_process_id,
tool_name: request.tool_name.clone(),
requested_at: request.created_at,
timeout_at: request.timeout_at,
response_tx: tx,
},
);
@@ -147,41 +133,6 @@ impl Approvals {
Ok((request, waiter))
}
#[tracing::instrument(skip(self, request))]
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
let (request, _) = self.create_internal(request).await?;
Ok(request)
}
pub async fn create_with_waiter(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
self.create_internal(request).await
}
pub async fn create_from_session(
&self,
pool: &SqlitePool,
payload: CreateApprovalRequest,
) -> Result<ApprovalRequest, ApprovalError> {
let session_id = payload.session_id.clone();
let execution_process_id =
match ExecutorSession::find_by_session_id(pool, &session_id).await? {
Some(session) => session.execution_process_id,
None => {
tracing::warn!("No executor session found for session_id: {}", session_id);
return Err(ApprovalError::NoExecutorSession(session_id));
}
};
// Move the task to InReview if it's still InProgress
ensure_task_in_review(pool, execution_process_id).await;
let request = ApprovalRequest::from_create(payload, execution_process_id);
self.create(request).await
}
#[tracing::instrument(skip(self, id, req))]
pub async fn respond(
&self,
@@ -238,39 +189,6 @@ impl Approvals {
}
}
pub async fn status(&self, id: &str) -> Option<ApprovalStatus> {
if let Some(f) = self.completed.get(id) {
return Some(f.clone());
}
if let Some(p) = self.pending.get(id) {
if chrono::Utc::now() >= p.timeout_at {
return Some(ApprovalStatus::TimedOut);
}
return Some(ApprovalStatus::Pending);
}
None
}
pub async fn pending(&self) -> Vec<ApprovalPendingInfo> {
self.pending
.iter()
.filter_map(|entry| {
let (id, pending) = entry.pair();
match &pending.entry.entry_type {
NormalizedEntryType::ToolUse { tool_name, .. } => Some(ApprovalPendingInfo {
approval_id: id.clone(),
execution_process_id: pending.execution_process_id,
tool_name: tool_name.clone(),
requested_at: pending.requested_at,
timeout_at: pending.timeout_at,
}),
_ => None,
}
})
.collect()
}
#[tracing::instrument(skip(self, id, timeout_at, waiter))]
fn spawn_timeout_watcher(
&self,
@@ -352,126 +270,37 @@ pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_i
}
}
/// Comparison strategy for matching tool use entries
enum ToolComparisonStrategy {
/// Compare by tool_call_id
ToolCallId(String),
/// Compare deserialized ClaudeToolData structures (for known tools)
Deserialized(executors::executors::claude::ClaudeToolData),
/// Compare raw JSON input fields (for Unknown tools like MCP)
RawJson,
}
/// Find a matching tool use entry that hasn't been assigned to an approval yet
/// Matches by tool name and tool input to support parallel tool calls
/// Matches by tool call id from tool metadata
fn find_matching_tool_use(
store: Arc<MsgStore>,
tool_name: &str,
tool_input: &serde_json::Value,
tool_call_id: Option<&str>,
tool_call_id: &str,
) -> Option<(usize, NormalizedEntry)> {
use executors::executors::claude::ClaudeToolData;
let history = store.get_history();
// Determine comparison strategy based on tool type
let strategy = if let Some(call_id) = tool_call_id {
// If tool_call_id is provided, use it for matching
ToolComparisonStrategy::ToolCallId(call_id.to_string())
} else {
match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
"name": tool_name,
"input": tool_input
})) {
Ok(ClaudeToolData::Unknown { .. }) => {
// For Unknown tools (MCP, future tools), use raw JSON comparison
ToolComparisonStrategy::RawJson
}
Ok(data) => {
// For known tools, use deserialized comparison with proper alias handling
ToolComparisonStrategy::Deserialized(data)
}
Err(e) => {
tracing::warn!(
"Failed to deserialize tool_input for tool '{}': {}",
tool_name,
e
);
return None;
}
}
};
// Single loop through history with strategy-based comparison
// Single loop through history
for msg in history.iter().rev() {
if let LogMsg::JsonPatch(patch) = msg
&& let Some((idx, entry)) = extract_normalized_entry_from_patch(patch)
&& let NormalizedEntryType::ToolUse {
tool_name: entry_tool_name,
status,
..
} = &entry.entry_type
&& let NormalizedEntryType::ToolUse { status, .. } = &entry.entry_type
{
// Only match tools that are in Created state
if !matches!(status, ToolStatus::Created) {
continue;
}
// Tool name must match
if entry_tool_name != tool_name {
continue;
}
// Apply comparison strategy
if let Some(metadata) = &entry.metadata {
let is_match = match &strategy {
ToolComparisonStrategy::ToolCallId(call_id) => {
// Match by tool_call_id in metadata
if let Ok(ToolCallMetadata {
tool_call_id: entry_call_id,
..
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
{
entry_call_id == *call_id
} else {
false
}
}
ToolComparisonStrategy::RawJson => {
// Compare raw JSON input for Unknown tools
if let Some(entry_input) = metadata.get("input") {
entry_input == tool_input
} else {
false
}
}
ToolComparisonStrategy::Deserialized(approval_data) => {
// Compare deserialized structures for known tools
if let Ok(entry_tool_data) =
serde_json::from_value::<ClaudeToolData>(metadata.clone())
{
entry_tool_data == *approval_data
} else {
false
}
}
};
if is_match {
let strategy_name = match strategy {
ToolComparisonStrategy::ToolCallId(call_id) => {
format!("tool_call_id '{call_id}'")
}
ToolComparisonStrategy::RawJson => "raw input comparison".to_string(),
ToolComparisonStrategy::Deserialized(_) => {
"deserialized tool data".to_string()
}
};
tracing::debug!(
"Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}"
);
return Some((idx, entry));
}
// Match by tool call id from metadata
if let Some(metadata) = &entry.metadata
&& let Ok(ToolCallMetadata {
tool_call_id: entry_call_id,
..
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
&& entry_call_id == tool_call_id
{
tracing::debug!(
"Matched tool use entry at index {idx} for tool call id '{tool_call_id}'"
);
return Some((idx, entry));
}
}
}
@@ -491,19 +320,9 @@ mod tests {
fn create_tool_use_entry(
tool_name: &str,
file_path: &str,
id: &str,
status: ToolStatus,
) -> NormalizedEntry {
// Create metadata that mimics the actual structure from Claude Code
// which has an "input" field containing the original tool parameters
let metadata = serde_json::json!({
"type": "tool_use",
"id": format!("test-{}", file_path),
"name": tool_name,
"input": {
"file_path": file_path
}
});
NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -514,7 +333,12 @@ mod tests {
status,
},
content: format!("Reading {file_path}"),
metadata: Some(metadata),
metadata: Some(
serde_json::to_value(ToolCallMetadata {
tool_call_id: id.to_string(),
})
.unwrap(),
),
}
}
@@ -523,9 +347,9 @@ mod tests {
let store = Arc::new(MsgStore::new());
// Setup: Simulate 3 parallel Read tool calls with different files
let read_foo = create_tool_use_entry("Read", "foo.rs", ToolStatus::Created);
let read_bar = create_tool_use_entry("Read", "bar.rs", ToolStatus::Created);
let read_baz = create_tool_use_entry("Read", "baz.rs", ToolStatus::Created);
let read_foo = create_tool_use_entry("Read", "foo.rs", "foo-id", ToolStatus::Created);
let read_bar = create_tool_use_entry("Read", "bar.rs", "bar-id", ToolStatus::Created);
let read_baz = create_tool_use_entry("Read", "baz.rs", "baz-id", ToolStatus::Created);
store.push_patch(
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(0, read_foo),
@@ -537,17 +361,12 @@ mod tests {
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(2, read_baz),
);
// Test 1: Each approval request matches its specific tool by input
let foo_input = serde_json::json!({"file_path": "foo.rs"});
let bar_input = serde_json::json!({"file_path": "bar.rs"});
let baz_input = serde_json::json!({"file_path": "baz.rs"});
let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None)
.expect("Should match foo.rs");
let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None)
.expect("Should match bar.rs");
let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None)
.expect("Should match baz.rs");
let (idx_foo, _) =
find_matching_tool_use(store.clone(), "foo-id").expect("Should match foo.rs");
let (idx_bar, _) =
find_matching_tool_use(store.clone(), "bar-id").expect("Should match bar.rs");
let (idx_baz, _) =
find_matching_tool_use(store.clone(), "baz-id").expect("Should match baz.rs");
assert_eq!(idx_foo, 0, "foo.rs should match first entry");
assert_eq!(idx_bar, 1, "bar.rs should match second entry");
@@ -557,6 +376,7 @@ mod tests {
let read_pending = create_tool_use_entry(
"Read",
"pending.rs",
"pending-id",
ToolStatus::PendingApproval {
approval_id: "test-id".to_string(),
requested_at: chrono::Utc::now(),
@@ -567,24 +387,15 @@ mod tests {
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(3, read_pending),
);
let pending_input = serde_json::json!({"file_path": "pending.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(),
find_matching_tool_use(store.clone(), "pending-id").is_none(),
"Should not match tools in PendingApproval state"
);
// Test 3: Wrong tool name returns None
let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"});
// Test 3: Wrong tool id returns None
assert!(
find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(),
"Should not match different tool names"
);
// Test 4: Wrong input parameters returns None
let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(),
"Should not match with different input parameters"
find_matching_tool_use(store.clone(), "wrong-id").is_none(),
"Should not match different tool ids"
);
}
}

View File

@@ -4,7 +4,6 @@ use async_trait::async_trait;
use db::{self, DBService};
use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService};
use serde_json::Value;
use tokio::sync::RwLock;
use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest};
use uuid::Uuid;
@@ -14,7 +13,6 @@ pub struct ExecutorApprovalBridge {
approvals: Approvals,
db: DBService,
execution_process_id: Uuid,
session_id: RwLock<Option<String>>,
}
impl ExecutorApprovalBridge {
@@ -23,41 +21,25 @@ impl ExecutorApprovalBridge {
approvals,
db,
execution_process_id,
session_id: RwLock::new(None),
})
}
}
#[async_trait]
impl ExecutorApprovalService for ExecutorApprovalBridge {
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> {
let mut guard = self.session_id.write().await;
guard.replace(session_id.to_string());
Ok(())
}
async fn request_tool_approval(
&self,
tool_name: &str,
tool_input: Value,
tool_call_id: &str,
) -> Result<ApprovalStatus, ExecutorApprovalError> {
let session_id = {
let guard = self.session_id.read().await;
guard
.clone()
.ok_or(ExecutorApprovalError::SessionNotRegistered)?
};
super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await;
let request = ApprovalRequest::from_create(
CreateApprovalRequest {
tool_name: tool_name.to_string(),
tool_input,
session_id,
tool_call_id: Some(tool_call_id.to_string()),
tool_call_id: tool_call_id.to_string(),
},
self.execution_process_id,
);

View File

@@ -4,15 +4,13 @@ use ts_rs::TS;
use uuid::Uuid;
pub const APPROVAL_TIMEOUT_SECONDS: i64 = 3600; // 1 hour
pub const EXIT_PLAN_MODE_TOOL_NAME: &str = "ExitPlanMode";
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
pub struct ApprovalRequest {
pub id: String,
pub tool_name: String,
pub tool_input: serde_json::Value,
pub session_id: String,
pub tool_call_id: Option<String>,
pub tool_call_id: String,
pub execution_process_id: Uuid,
pub created_at: DateTime<Utc>,
pub timeout_at: DateTime<Utc>,
@@ -25,7 +23,6 @@ impl ApprovalRequest {
id: Uuid::new_v4().to_string(),
tool_name: request.tool_name,
tool_input: request.tool_input,
session_id: request.session_id,
tool_call_id: request.tool_call_id,
execution_process_id,
created_at: now,
@@ -39,8 +36,7 @@ impl ApprovalRequest {
pub struct CreateApprovalRequest {
pub tool_name: String,
pub tool_input: serde_json::Value,
pub session_id: String,
pub tool_call_id: Option<String>,
pub tool_call_id: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
@@ -62,13 +58,3 @@ pub struct ApprovalResponse {
pub execution_process_id: Uuid,
pub status: ApprovalStatus,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export)]
pub struct ApprovalPendingInfo {
pub approval_id: String,
pub execution_process_id: Uuid,
pub tool_name: String,
pub requested_at: DateTime<Utc>,
pub timeout_at: DateTime<Utc>,
}

View File

@@ -322,7 +322,7 @@ export type PatchType = { "type": "NORMALIZED_ENTRY", "content": NormalizedEntry
export type ApprovalStatus = { "status": "pending" } | { "status": "approved" } | { "status": "denied", reason?: string, } | { "status": "timed_out" };
export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, session_id: string, tool_call_id: string | null, };
export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, tool_call_id: string, };
export type ApprovalResponse = { execution_process_id: string, status: ApprovalStatus, };