Claude approval refactor (#1080)
* WIP claude approvals * Use canusetool * Remove old exitplanmode approvals * WIP approvals * types * Remove bloat * Cleanup, exit on finish * Approval messages, cleanup * Cleanup * Fix msg types * Lint fmt * Cleanup * Send deny * add missing timeout to hooks * FIx timeout issue * Cleanup * Error handling, log writer bugs * Remove deprecated approbal endpoints * Remove tool matching strategies in favour of only id based matching * remove register session, parse result at protocol level * Remove circular peer, remove unneeded trait * Types
This commit is contained in:
@@ -26,9 +26,6 @@ impl ExecutorApprovalError {
|
||||
/// Abstraction for executor approval backends.
|
||||
#[async_trait]
|
||||
pub trait ExecutorApprovalService: Send + Sync {
|
||||
/// Registers the session identifier associated with subsequent approval requests.
|
||||
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError>;
|
||||
|
||||
/// Requests approval for a tool invocation and waits for the final decision.
|
||||
async fn request_tool_approval(
|
||||
&self,
|
||||
@@ -43,10 +40,6 @@ pub struct NoopExecutorApprovalService;
|
||||
|
||||
#[async_trait]
|
||||
impl ExecutorApprovalService for NoopExecutorApprovalService {
|
||||
async fn register_session(&self, _session_id: &str) -> Result<(), ExecutorApprovalError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn request_tool_approval(
|
||||
&self,
|
||||
_tool_name: &str,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
// SDK submodules
|
||||
pub mod client;
|
||||
pub mod protocol;
|
||||
pub mod types;
|
||||
|
||||
use std::{collections::HashMap, path::Path, process::Stdio, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -7,42 +10,33 @@ use command_group::AsyncCommandGroup;
|
||||
use futures::StreamExt;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{io::AsyncWriteExt, process::Command, sync::OnceCell};
|
||||
use tokio::process::Command;
|
||||
use ts_rs::TS;
|
||||
use workspace_utils::{
|
||||
approvals::APPROVAL_TIMEOUT_SECONDS,
|
||||
approvals::ApprovalStatus,
|
||||
diff::{concatenate_diff_hunks, create_unified_diff, create_unified_diff_hunk},
|
||||
log_msg::LogMsg,
|
||||
msg_store::MsgStore,
|
||||
path::make_path_relative,
|
||||
port_file::read_port_file,
|
||||
shell::get_shell_command,
|
||||
};
|
||||
|
||||
use self::{client::ClaudeAgentClient, protocol::ProtocolPeer, types::PermissionMode};
|
||||
use crate::{
|
||||
approvals::ExecutorApprovalService,
|
||||
command::{CmdOverrides, CommandBuilder, apply_overrides},
|
||||
executors::{AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor},
|
||||
executors::{
|
||||
AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor,
|
||||
codex::client::LogWriter,
|
||||
},
|
||||
logs::{
|
||||
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem, ToolStatus,
|
||||
stderr_processor::normalize_stderr_logs,
|
||||
utils::{EntryIndexProvider, patch::ConversationPatch},
|
||||
},
|
||||
stdout_dup::create_stdout_pipe_writer,
|
||||
};
|
||||
|
||||
static BACKEND_PORT: OnceCell<u16> = OnceCell::const_new();
|
||||
async fn get_backend_port() -> std::io::Result<u16> {
|
||||
BACKEND_PORT
|
||||
.get_or_try_init(|| async { read_port_file("vibe-kanban").await })
|
||||
.await
|
||||
.copied()
|
||||
}
|
||||
|
||||
const CONFIRM_HOOK_SCRIPT: &str = include_str!("./hooks/confirm.py");
|
||||
|
||||
/// Natural language marker we add in our Python hook to denote user feedback
|
||||
/// This marker is added by our confirm.py hook script and is robust to Claude Code format changes
|
||||
const USER_FEEDBACK_MARKER: &str = "User feedback: ";
|
||||
|
||||
fn base_command(claude_code_router: bool) -> &'static str {
|
||||
if claude_code_router {
|
||||
"npx -y @musistudio/claude-code-router@1.0.58 code"
|
||||
@@ -51,7 +45,10 @@ fn base_command(claude_code_router: bool) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema)]
|
||||
use derivative::Derivative;
|
||||
|
||||
#[derive(Derivative, Clone, Serialize, Deserialize, TS, JsonSchema)]
|
||||
#[derivative(Debug, PartialEq)]
|
||||
pub struct ClaudeCode {
|
||||
#[serde(default)]
|
||||
pub append_prompt: AppendPrompt,
|
||||
@@ -67,6 +64,11 @@ pub struct ClaudeCode {
|
||||
pub dangerously_skip_permissions: Option<bool>,
|
||||
#[serde(flatten)]
|
||||
pub cmd: CmdOverrides,
|
||||
|
||||
#[serde(skip)]
|
||||
#[ts(skip)]
|
||||
#[derivative(Debug = "ignore", PartialEq = "ignore")]
|
||||
approvals_service: Option<Arc<dyn ExecutorApprovalService>>,
|
||||
}
|
||||
|
||||
impl ClaudeCode {
|
||||
@@ -87,30 +89,14 @@ impl ClaudeCode {
|
||||
if plan && approvals {
|
||||
tracing::warn!("Both plan and approvals are enabled. Plan will take precedence.");
|
||||
}
|
||||
|
||||
if plan {
|
||||
builder = builder.extend_params(["--permission-mode=plan"]);
|
||||
}
|
||||
|
||||
if plan || approvals {
|
||||
match settings_json(plan).await {
|
||||
// TODO: Avoid quoting
|
||||
Ok(settings) => match shlex::try_quote(&settings) {
|
||||
Ok(quoted) => {
|
||||
builder = builder.extend_params(["--settings", "ed]);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to quote approvals JSON for --settings: {e}");
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to generate approvals JSON. Not running approvals: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
// Enable bypass at startup, otherwise we cannot change to it after exiting plan mode
|
||||
builder = builder.extend_params(["--permission-prompt-tool=stdio"]);
|
||||
builder = builder.extend_params([format!(
|
||||
"--permission-mode={}",
|
||||
PermissionMode::BypassPermissions
|
||||
)]);
|
||||
}
|
||||
|
||||
if self.dangerously_skip_permissions.unwrap_or(false) {
|
||||
builder = builder.extend_params(["--dangerously-skip-permissions"]);
|
||||
}
|
||||
@@ -120,49 +106,58 @@ impl ClaudeCode {
|
||||
builder = builder.extend_params([
|
||||
"--verbose",
|
||||
"--output-format=stream-json",
|
||||
"--input-format=stream-json",
|
||||
"--include-partial-messages",
|
||||
]);
|
||||
|
||||
apply_overrides(builder, &self.cmd)
|
||||
}
|
||||
|
||||
pub fn permission_mode(&self) -> PermissionMode {
|
||||
if self.plan.unwrap_or(false) {
|
||||
PermissionMode::Plan
|
||||
} else if self.approvals.unwrap_or(false) {
|
||||
PermissionMode::Default
|
||||
} else {
|
||||
PermissionMode::BypassPermissions
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_hooks(&self) -> Option<serde_json::Value> {
|
||||
if self.plan.unwrap_or(false) {
|
||||
Some(serde_json::json!({
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": "^ExitPlanMode$",
|
||||
"hookCallbackIds": ["tool_approval"],
|
||||
}
|
||||
]
|
||||
}))
|
||||
} else if self.approvals.unwrap_or(false) {
|
||||
Some(serde_json::json!({
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": "^(?!(Glob|Grep|NotebookRead|Read|Task|TodoWrite)$).*",
|
||||
"hookCallbackIds": ["tool_approval"],
|
||||
}
|
||||
]
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StandardCodingAgentExecutor for ClaudeCode {
|
||||
fn use_approvals(&mut self, approvals: Arc<dyn ExecutorApprovalService>) {
|
||||
self.approvals_service = Some(approvals);
|
||||
}
|
||||
|
||||
async fn spawn(&self, current_dir: &Path, prompt: &str) -> Result<SpawnedChild, ExecutorError> {
|
||||
let (shell_cmd, shell_arg) = get_shell_command();
|
||||
let command_builder = self.build_command_builder().await;
|
||||
let mut base_command = command_builder.build_initial();
|
||||
|
||||
if self.plan.unwrap_or(false) {
|
||||
base_command = create_watchkill_script(&base_command);
|
||||
}
|
||||
|
||||
if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) {
|
||||
write_python_hook(current_dir).await?
|
||||
}
|
||||
|
||||
let combined_prompt = self.append_prompt.combine_prompt(prompt);
|
||||
|
||||
let mut command = Command::new(shell_cmd);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.current_dir(current_dir)
|
||||
.arg(shell_arg)
|
||||
.arg(&base_command);
|
||||
|
||||
let mut child = command.group_spawn()?;
|
||||
|
||||
// Feed the prompt in, then close the pipe so Claude sees EOF
|
||||
if let Some(mut stdin) = child.inner().stdin.take() {
|
||||
stdin.write_all(combined_prompt.as_bytes()).await?;
|
||||
stdin.shutdown().await?;
|
||||
}
|
||||
|
||||
Ok(child.into())
|
||||
let base_command = command_builder.build_initial();
|
||||
self.spawn_internal(current_dir, prompt, base_command).await
|
||||
}
|
||||
|
||||
async fn spawn_follow_up(
|
||||
@@ -171,44 +166,13 @@ impl StandardCodingAgentExecutor for ClaudeCode {
|
||||
prompt: &str,
|
||||
session_id: &str,
|
||||
) -> Result<SpawnedChild, ExecutorError> {
|
||||
let (shell_cmd, shell_arg) = get_shell_command();
|
||||
let command_builder = self.build_command_builder().await;
|
||||
// Build follow-up command with --resume {session_id}
|
||||
let mut base_command = command_builder.build_follow_up(&[
|
||||
let base_command = command_builder.build_follow_up(&[
|
||||
"--fork-session".to_string(),
|
||||
"--resume".to_string(),
|
||||
session_id.to_string(),
|
||||
]);
|
||||
|
||||
if self.plan.unwrap_or(false) {
|
||||
base_command = create_watchkill_script(&base_command);
|
||||
}
|
||||
|
||||
if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) {
|
||||
write_python_hook(current_dir).await?
|
||||
}
|
||||
|
||||
let combined_prompt = self.append_prompt.combine_prompt(prompt);
|
||||
|
||||
let mut command = Command::new(shell_cmd);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.current_dir(current_dir)
|
||||
.arg(shell_arg)
|
||||
.arg(&base_command);
|
||||
|
||||
let mut child = command.group_spawn()?;
|
||||
|
||||
// Feed the followup prompt in, then close the pipe
|
||||
if let Some(mut stdin) = child.inner().stdin.take() {
|
||||
stdin.write_all(combined_prompt.as_bytes()).await?;
|
||||
stdin.shutdown().await?;
|
||||
}
|
||||
|
||||
Ok(child.into())
|
||||
self.spawn_internal(current_dir, prompt, base_command).await
|
||||
}
|
||||
|
||||
fn normalize_logs(&self, msg_store: Arc<MsgStore>, current_dir: &Path) {
|
||||
@@ -232,116 +196,74 @@ impl StandardCodingAgentExecutor for ClaudeCode {
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_python_hook(current_dir: &Path) -> Result<(), ExecutorError> {
|
||||
let hooks_dir = current_dir.join(".claude").join("hooks");
|
||||
tokio::fs::create_dir_all(&hooks_dir).await?;
|
||||
let hook_path = hooks_dir.join("confirm.py");
|
||||
impl ClaudeCode {
|
||||
async fn spawn_internal(
|
||||
&self,
|
||||
current_dir: &Path,
|
||||
prompt: &str,
|
||||
base_command: String,
|
||||
) -> Result<SpawnedChild, ExecutorError> {
|
||||
let (shell_cmd, shell_arg) = get_shell_command();
|
||||
let combined_prompt = self.append_prompt.combine_prompt(prompt);
|
||||
|
||||
let mut file = tokio::fs::File::create(&hook_path).await?;
|
||||
file.write_all(CONFIRM_HOOK_SCRIPT.as_bytes()).await?;
|
||||
file.flush().await?;
|
||||
let mut command = Command::new(shell_cmd);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.current_dir(current_dir)
|
||||
.arg(shell_arg)
|
||||
.arg(&base_command);
|
||||
|
||||
// TODO: Handle Windows permissioning
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let perm = std::fs::Permissions::from_mode(0o755);
|
||||
tokio::fs::set_permissions(&hook_path, perm).await?;
|
||||
let mut child = command.group_spawn()?;
|
||||
let child_stdout = child.inner().stdout.take().ok_or_else(|| {
|
||||
ExecutorError::Io(std::io::Error::other("Claude Code missing stdout"))
|
||||
})?;
|
||||
let child_stdin =
|
||||
child.inner().stdin.take().ok_or_else(|| {
|
||||
ExecutorError::Io(std::io::Error::other("Claude Code missing stdin"))
|
||||
})?;
|
||||
|
||||
let new_stdout = create_stdout_pipe_writer(&mut child)?;
|
||||
let permission_mode = self.permission_mode();
|
||||
let hooks = self.get_hooks();
|
||||
|
||||
// Spawn task to handle the SDK client with control protocol
|
||||
let prompt_clone = combined_prompt.clone();
|
||||
let approvals_clone = self.approvals_service.clone();
|
||||
tokio::spawn(async move {
|
||||
let log_writer = LogWriter::new(new_stdout);
|
||||
let client = ClaudeAgentClient::new(log_writer.clone(), approvals_clone);
|
||||
let protocol_peer = ProtocolPeer::spawn(child_stdin, child_stdout, client.clone());
|
||||
|
||||
// Initialize control protocol
|
||||
if let Err(e) = protocol_peer.initialize(hooks).await {
|
||||
tracing::error!("Failed to initialize control protocol: {e}");
|
||||
let _ = log_writer
|
||||
.log_raw(&format!("Error: Failed to initialize - {e}"))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(e) = protocol_peer.set_permission_mode(permission_mode).await {
|
||||
tracing::warn!("Failed to set permission mode to {permission_mode}: {e}");
|
||||
}
|
||||
|
||||
// Send user message
|
||||
if let Err(e) = protocol_peer.send_user_message(prompt_clone).await {
|
||||
tracing::error!("Failed to send prompt: {e}");
|
||||
let _ = log_writer
|
||||
.log_raw(&format!("Error: Failed to send prompt - {e}"))
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(SpawnedChild {
|
||||
child,
|
||||
exit_signal: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ignore the confirm.py script
|
||||
let gitignore_path = hooks_dir.join(".gitignore");
|
||||
if !tokio::fs::try_exists(&gitignore_path).await? {
|
||||
let mut gitignore_file = tokio::fs::File::create(&gitignore_path).await?;
|
||||
gitignore_file
|
||||
.write_all(b"confirm.py\n.gitignore\n")
|
||||
.await?;
|
||||
gitignore_file.flush().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Configure settings json
|
||||
async fn settings_json(plan: bool) -> Result<String, std::io::Error> {
|
||||
let backend_port = get_backend_port().await?;
|
||||
let backend_timeout = APPROVAL_TIMEOUT_SECONDS + 5; // add buffer
|
||||
|
||||
let matcher = if plan {
|
||||
"^ExitPlanMode$"
|
||||
} else {
|
||||
"^(?!(Glob|Grep|NotebookRead|Read|Task|TodoWrite)$).*"
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"hooks": {
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": matcher,
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": format!("$CLAUDE_PROJECT_DIR/.claude/hooks/confirm.py --timeout-seconds {backend_timeout} --poll-interval 5 --backend-port {backend_port} --feedback-marker '{USER_FEEDBACK_MARKER}'"),
|
||||
"timeout": backend_timeout + 10
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
fn create_watchkill_script(command: &str) -> String {
|
||||
// Hack: we concatenate so that Claude doesn't trigger the watchkill when reading this file
|
||||
// during development, since it contains the stop phrase
|
||||
let claude_plan_stop_indicator = concat!("Approval ", "request timed out");
|
||||
let cmd = shlex::try_quote(command).unwrap().to_string();
|
||||
|
||||
format!(
|
||||
r#"#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
word="{claude_plan_stop_indicator}"
|
||||
|
||||
exit_code=0
|
||||
while IFS= read -r line; do
|
||||
printf '%s\n' "$line"
|
||||
if [[ $line == *"$word"* ]]; then
|
||||
exit 0
|
||||
fi
|
||||
done < <(bash -lc {cmd} <&0 2>&1)
|
||||
|
||||
exit_code=${{PIPESTATUS[0]}}
|
||||
exit "$exit_code"
|
||||
"#
|
||||
)
|
||||
}
|
||||
|
||||
/// Extract user denial reason from tool result error messages
|
||||
/// Our confirm.py hook prefixes user feedback with "User feedback: " for easy extraction
|
||||
/// Supports both string content and Claude's array format: [{"type":"text","text":"..."}]
|
||||
fn extract_denial_reason(content: &serde_json::Value) -> Option<String> {
|
||||
// First try to parse as string
|
||||
let content_str = if let Some(s) = content.as_str() {
|
||||
s.to_string()
|
||||
} else if let Ok(items) =
|
||||
serde_json::from_value::<Vec<ClaudeToolResultTextItem>>(content.clone())
|
||||
{
|
||||
// Handle array format: [{"type":"text","text":"..."}]
|
||||
items
|
||||
.into_iter()
|
||||
.map(|item| item.text)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
} else {
|
||||
// Try to serialize the value as a string
|
||||
content.to_string()
|
||||
};
|
||||
|
||||
content_str
|
||||
.split_once(USER_FEEDBACK_MARKER)
|
||||
.map(|(_, rest)| rest.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -491,6 +413,7 @@ impl ClaudeLogProcessor {
|
||||
ClaudeJson::ToolResult { session_id, .. } => session_id.clone(),
|
||||
ClaudeJson::StreamEvent { session_id, .. } => session_id.clone(),
|
||||
ClaudeJson::Result { session_id, .. } => session_id.clone(),
|
||||
ClaudeJson::ApprovalResponse { .. } => None,
|
||||
ClaudeJson::Unknown { .. } => None,
|
||||
}
|
||||
}
|
||||
@@ -580,12 +503,22 @@ impl ClaudeLogProcessor {
|
||||
serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null),
|
||||
),
|
||||
}),
|
||||
ClaudeContentItem::ToolUse { tool_data, .. } => {
|
||||
ClaudeContentItem::ToolUse { tool_data, id } => {
|
||||
let name = tool_data.get_name();
|
||||
let action_type = Self::extract_action_type(tool_data, worktree_path);
|
||||
let content =
|
||||
Self::generate_concise_content(tool_data, &action_type, worktree_path);
|
||||
|
||||
// Create metadata with tool_call_id for approval matching
|
||||
let mut metadata =
|
||||
serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null);
|
||||
if let Some(obj) = metadata.as_object_mut() {
|
||||
obj.insert(
|
||||
"tool_call_id".to_string(),
|
||||
serde_json::Value::String(id.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
Some(NormalizedEntry {
|
||||
timestamp: None,
|
||||
entry_type: NormalizedEntryType::ToolUse {
|
||||
@@ -594,9 +527,7 @@ impl ClaudeLogProcessor {
|
||||
status: ToolStatus::Created,
|
||||
},
|
||||
content,
|
||||
metadata: Some(
|
||||
serde_json::to_value(content_item).unwrap_or(serde_json::Value::Null),
|
||||
),
|
||||
metadata: Some(metadata),
|
||||
})
|
||||
}
|
||||
ClaudeContentItem::ToolResult { .. } => {
|
||||
@@ -837,6 +768,17 @@ impl ClaudeLogProcessor {
|
||||
&action_type,
|
||||
worktree_path,
|
||||
);
|
||||
|
||||
// Create metadata with tool_call_id for approval matching
|
||||
let mut metadata =
|
||||
serde_json::to_value(item).unwrap_or(serde_json::Value::Null);
|
||||
if let Some(obj) = metadata.as_object_mut() {
|
||||
obj.insert(
|
||||
"tool_call_id".to_string(),
|
||||
serde_json::Value::String(id.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
let entry = NormalizedEntry {
|
||||
timestamp: None,
|
||||
entry_type: NormalizedEntryType::ToolUse {
|
||||
@@ -845,9 +787,7 @@ impl ClaudeLogProcessor {
|
||||
status: ToolStatus::Created,
|
||||
},
|
||||
content: content_text.clone(),
|
||||
metadata: Some(
|
||||
serde_json::to_value(item).unwrap_or(serde_json::Value::Null),
|
||||
),
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
let is_new = entry_index.is_none();
|
||||
let id_num = entry_index.unwrap_or_else(|| entry_index_provider.next());
|
||||
@@ -930,7 +870,7 @@ impl ClaudeLogProcessor {
|
||||
{
|
||||
let is_command = matches!(info.tool_data, ClaudeToolData::Bash { .. });
|
||||
|
||||
let display_tool_name = if is_command {
|
||||
let _display_tool_name = if is_command {
|
||||
info.tool_name.clone()
|
||||
} else {
|
||||
let raw_name = info.tool_data.get_name().to_string();
|
||||
@@ -1048,24 +988,8 @@ impl ClaudeLogProcessor {
|
||||
};
|
||||
patches.push(ConversationPatch::replace(info.entry_index, entry));
|
||||
}
|
||||
|
||||
if is_error.unwrap_or(false)
|
||||
&& let Some(denial_reason) = extract_denial_reason(content)
|
||||
{
|
||||
let user_feedback = NormalizedEntry {
|
||||
timestamp: None,
|
||||
entry_type: NormalizedEntryType::UserFeedback {
|
||||
denied_tool: display_tool_name.clone(),
|
||||
},
|
||||
content: denial_reason,
|
||||
metadata: None,
|
||||
};
|
||||
let feedback_index = entry_index_provider.next();
|
||||
patches.push(ConversationPatch::add_normalized_entry(
|
||||
feedback_index,
|
||||
user_feedback,
|
||||
));
|
||||
}
|
||||
// Note: With control protocol, denials are handled via protocol messages
|
||||
// rather than error content parsing
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1166,6 +1090,40 @@ impl ClaudeLogProcessor {
|
||||
patches.push(ConversationPatch::add_normalized_entry(idx, entry));
|
||||
}
|
||||
}
|
||||
ClaudeJson::ApprovalResponse {
|
||||
call_id: _,
|
||||
tool_name,
|
||||
approval_status,
|
||||
} => {
|
||||
// Convert denials and timeouts to visible entries (matching Codex behavior)
|
||||
let entry_opt = match approval_status {
|
||||
ApprovalStatus::Pending => None,
|
||||
ApprovalStatus::Approved => None,
|
||||
ApprovalStatus::Denied { reason } => Some(NormalizedEntry {
|
||||
timestamp: None,
|
||||
entry_type: NormalizedEntryType::UserFeedback {
|
||||
denied_tool: tool_name.clone(),
|
||||
},
|
||||
content: reason
|
||||
.as_ref()
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(|| "User denied this tool use request".to_string()),
|
||||
metadata: None,
|
||||
}),
|
||||
ApprovalStatus::TimedOut => Some(NormalizedEntry {
|
||||
timestamp: None,
|
||||
entry_type: NormalizedEntryType::ErrorMessage,
|
||||
content: format!("Approval timed out for tool {tool_name}"),
|
||||
metadata: None,
|
||||
}),
|
||||
};
|
||||
|
||||
if let Some(entry) = entry_opt {
|
||||
let idx = entry_index_provider.next();
|
||||
patches.push(ConversationPatch::add_normalized_entry(idx, entry));
|
||||
}
|
||||
}
|
||||
ClaudeJson::Unknown { data } => {
|
||||
let entry = NormalizedEntry {
|
||||
timestamp: None,
|
||||
@@ -1434,7 +1392,7 @@ impl StreamingContentState {
|
||||
}
|
||||
|
||||
// Data structures for parsing Claude's JSON output format
|
||||
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClaudeJson {
|
||||
#[serde(rename = "system")]
|
||||
@@ -1497,6 +1455,12 @@ pub enum ClaudeJson {
|
||||
#[serde(default, alias = "sessionId")]
|
||||
session_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "approval_response")]
|
||||
ApprovalResponse {
|
||||
call_id: String,
|
||||
tool_name: String,
|
||||
approval_status: ApprovalStatus,
|
||||
},
|
||||
// Catch-all for unknown message types
|
||||
#[serde(untagged)]
|
||||
Unknown {
|
||||
@@ -2006,6 +1970,7 @@ mod tests {
|
||||
base_command_override: None,
|
||||
additional_params: None,
|
||||
},
|
||||
approvals_service: None,
|
||||
};
|
||||
let msg_store = Arc::new(MsgStore::new());
|
||||
let current_dir = std::path::PathBuf::from("/tmp/test-worktree");
|
||||
|
||||
206
crates/executors/src/executors/claude/client.rs
Normal file
206
crates/executors/src/executors/claude/client.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use workspace_utils::approvals::ApprovalStatus;
|
||||
|
||||
use super::types::PermissionMode;
|
||||
use crate::{
|
||||
approvals::{ExecutorApprovalError, ExecutorApprovalService},
|
||||
executors::{
|
||||
ExecutorError,
|
||||
claude::{
|
||||
ClaudeJson,
|
||||
types::{
|
||||
PermissionResult, PermissionUpdate, PermissionUpdateDestination,
|
||||
PermissionUpdateType,
|
||||
},
|
||||
},
|
||||
codex::client::LogWriter,
|
||||
},
|
||||
};
|
||||
|
||||
const EXIT_PLAN_MODE_NAME: &str = "ExitPlanMode";
|
||||
|
||||
/// Claude Agent client with control protocol support
|
||||
pub struct ClaudeAgentClient {
|
||||
log_writer: LogWriter,
|
||||
approvals: Option<Arc<dyn ExecutorApprovalService>>,
|
||||
auto_approve: bool, // true when approvals is None
|
||||
latest_unhandled_tool_use_id: Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
impl ClaudeAgentClient {
|
||||
/// Create a new client with optional approval service
|
||||
pub fn new(
|
||||
log_writer: LogWriter,
|
||||
approvals: Option<Arc<dyn ExecutorApprovalService>>,
|
||||
) -> Arc<Self> {
|
||||
let auto_approve = approvals.is_none();
|
||||
Arc::new(Self {
|
||||
log_writer,
|
||||
approvals,
|
||||
auto_approve,
|
||||
latest_unhandled_tool_use_id: Mutex::new(None),
|
||||
})
|
||||
}
|
||||
async fn set_latest_unhandled_tool_use_id(&self, tool_use_id: String) {
|
||||
if self.latest_unhandled_tool_use_id.lock().await.is_some() {
|
||||
tracing::warn!(
|
||||
"Overwriting unhandled tool_use_id: {} with new tool_use_id: {}",
|
||||
self.latest_unhandled_tool_use_id
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
tool_use_id
|
||||
);
|
||||
}
|
||||
let mut guard = self.latest_unhandled_tool_use_id.lock().await;
|
||||
guard.replace(tool_use_id);
|
||||
}
|
||||
|
||||
async fn handle_approval(
|
||||
&self,
|
||||
tool_use_id: String,
|
||||
tool_name: String,
|
||||
tool_input: serde_json::Value,
|
||||
) -> Result<PermissionResult, ExecutorError> {
|
||||
// Use approval service to request tool approval
|
||||
let approval_service = self
|
||||
.approvals
|
||||
.as_ref()
|
||||
.ok_or(ExecutorApprovalError::ServiceUnavailable)?;
|
||||
let status = approval_service
|
||||
.request_tool_approval(&tool_name, tool_input.clone(), &tool_use_id)
|
||||
.await;
|
||||
match status {
|
||||
Ok(status) => {
|
||||
// Log the approval response so we it appears in the executor logs
|
||||
self.log_writer
|
||||
.log_raw(&serde_json::to_string(&ClaudeJson::ApprovalResponse {
|
||||
call_id: tool_use_id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
approval_status: status.clone(),
|
||||
})?)
|
||||
.await?;
|
||||
match status {
|
||||
ApprovalStatus::Approved => {
|
||||
if tool_name == EXIT_PLAN_MODE_NAME {
|
||||
Ok(PermissionResult::Allow {
|
||||
updated_input: tool_input,
|
||||
updated_permissions: Some(vec![PermissionUpdate {
|
||||
update_type: PermissionUpdateType::SetMode,
|
||||
mode: Some(PermissionMode::BypassPermissions),
|
||||
destination: PermissionUpdateDestination::Session,
|
||||
}]),
|
||||
})
|
||||
} else {
|
||||
Ok(PermissionResult::Allow {
|
||||
updated_input: tool_input,
|
||||
updated_permissions: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
ApprovalStatus::Denied { reason } => {
|
||||
let message = reason.unwrap_or("Denied by user".to_string());
|
||||
Ok(PermissionResult::Deny {
|
||||
message,
|
||||
interrupt: Some(false),
|
||||
})
|
||||
}
|
||||
ApprovalStatus::TimedOut => Ok(PermissionResult::Deny {
|
||||
message: "Approval request timed out".to_string(),
|
||||
interrupt: Some(false),
|
||||
}),
|
||||
ApprovalStatus::Pending => Ok(PermissionResult::Deny {
|
||||
message: "Approval still pending (unexpected)".to_string(),
|
||||
interrupt: Some(false),
|
||||
}),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Tool approval request failed: {e}");
|
||||
Ok(PermissionResult::Deny {
|
||||
message: "Tool approval request failed".to_string(),
|
||||
interrupt: Some(false),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn on_can_use_tool(
|
||||
&self,
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
_permission_suggestions: Option<Vec<PermissionUpdate>>,
|
||||
) -> Result<PermissionResult, ExecutorError> {
|
||||
if self.auto_approve {
|
||||
Ok(PermissionResult::Allow {
|
||||
updated_input: input,
|
||||
updated_permissions: None,
|
||||
})
|
||||
} else {
|
||||
let latest_tool_use_id = {
|
||||
let guard = self.latest_unhandled_tool_use_id.lock().await.take();
|
||||
guard.clone()
|
||||
};
|
||||
|
||||
if let Some(latest_tool_use_id) = latest_tool_use_id {
|
||||
self.handle_approval(latest_tool_use_id, tool_name, input)
|
||||
.await
|
||||
} else {
|
||||
// Auto approve tools with no matching tool_use_id.
|
||||
// This rare edge case happens if a tool call triggers no hook callback,
|
||||
// so no tool_use_id is available to match the approval request to.
|
||||
tracing::warn!(
|
||||
"No unhandled tool_use_id available for tool '{}', cannot request approval",
|
||||
tool_name
|
||||
);
|
||||
Ok(PermissionResult::Allow {
|
||||
updated_input: input,
|
||||
updated_permissions: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn on_hook_callback(
|
||||
&self,
|
||||
_callback_id: String,
|
||||
_input: serde_json::Value,
|
||||
tool_use_id: Option<String>,
|
||||
) -> Result<serde_json::Value, ExecutorError> {
|
||||
if self.auto_approve {
|
||||
Ok(serde_json::json!({
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "allow",
|
||||
"permissionDecisionReason": "Auto-approved by SDK"
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
// Hook callbacks is only used to store tool_use_id for later approval request
|
||||
// Both hook callback and can_use_tool are needed.
|
||||
// - Hook callbacks have a constant 60s timeout, so cannot be used for long approvals
|
||||
// - can_use_tool does not provide tool_use_id, so cannot be used alone
|
||||
// Together they allow matching approval requests to tool uses.
|
||||
// This works because `ask` decision in hook callback triggers a can_use_tool request
|
||||
// https://docs.claude.com/en/api/agent-sdk/permissions#permission-flow-diagram
|
||||
if let Some(tool_use_id) = tool_use_id.clone() {
|
||||
self.set_latest_unhandled_tool_use_id(tool_use_id).await;
|
||||
}
|
||||
Ok(serde_json::json!({
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "ask",
|
||||
"permissionDecisionReason": "Forwarding to canusetool service"
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn on_non_control(&self, line: &str) -> Result<(), ExecutorError> {
|
||||
// Forward all non-control messages to stdout
|
||||
self.log_writer.log_raw(line).await
|
||||
}
|
||||
}
|
||||
200
crates/executors/src/executors/claude/protocol.rs
Normal file
200
crates/executors/src/executors/claude/protocol.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||
process::{ChildStdin, ChildStdout},
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use super::types::{
|
||||
CLIMessage, ControlRequestType, ControlResponseMessage, ControlResponseType,
|
||||
SDKControlRequestMessage,
|
||||
};
|
||||
use crate::executors::{
|
||||
ExecutorError,
|
||||
claude::{
|
||||
client::ClaudeAgentClient,
|
||||
types::{PermissionMode, SDKControlRequestType},
|
||||
},
|
||||
};
|
||||
|
||||
/// Handles bidirectional control protocol communication
|
||||
#[derive(Clone)]
|
||||
pub struct ProtocolPeer {
|
||||
stdin: Arc<Mutex<ChildStdin>>,
|
||||
}
|
||||
|
||||
impl ProtocolPeer {
|
||||
pub fn spawn(stdin: ChildStdin, stdout: ChildStdout, client: Arc<ClaudeAgentClient>) -> Self {
|
||||
let peer = Self {
|
||||
stdin: Arc::new(Mutex::new(stdin)),
|
||||
};
|
||||
|
||||
let reader_peer = peer.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = reader_peer.read_loop(stdout, client).await {
|
||||
tracing::error!("Protocol reader loop error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
peer
|
||||
}
|
||||
|
||||
async fn read_loop(
|
||||
&self,
|
||||
stdout: ChildStdout,
|
||||
client: Arc<ClaudeAgentClient>,
|
||||
) -> Result<(), ExecutorError> {
|
||||
let mut reader = BufReader::new(stdout);
|
||||
let mut buffer = String::new();
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
match reader.read_line(&mut buffer).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
let line = buffer.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// Parse message using typed enum
|
||||
match serde_json::from_str::<CLIMessage>(line) {
|
||||
Ok(CLIMessage::ControlRequest {
|
||||
request_id,
|
||||
request,
|
||||
}) => {
|
||||
self.handle_control_request(&client, request_id, request)
|
||||
.await;
|
||||
}
|
||||
Ok(CLIMessage::ControlResponse { .. }) => {}
|
||||
Ok(CLIMessage::Result(_)) => {
|
||||
client.on_non_control(line).await?;
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
client.on_non_control(line).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error reading stdout: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_control_request(
|
||||
&self,
|
||||
client: &Arc<ClaudeAgentClient>,
|
||||
request_id: String,
|
||||
request: ControlRequestType,
|
||||
) {
|
||||
match request {
|
||||
ControlRequestType::CanUseTool {
|
||||
tool_name,
|
||||
input,
|
||||
permission_suggestions,
|
||||
} => {
|
||||
match client
|
||||
.on_can_use_tool(tool_name, input, permission_suggestions)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
if let Err(e) = self
|
||||
.send_hook_response(request_id, serde_json::to_value(result).unwrap())
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to send permission result: {e}");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error in on_can_use_tool: {e}");
|
||||
if let Err(e2) = self.send_error(request_id, e.to_string()).await {
|
||||
tracing::error!("Failed to send error response: {e2}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ControlRequestType::HookCallback {
|
||||
callback_id,
|
||||
input,
|
||||
tool_use_id,
|
||||
} => {
|
||||
match client
|
||||
.on_hook_callback(callback_id, input, tool_use_id)
|
||||
.await
|
||||
{
|
||||
Ok(hook_output) => {
|
||||
if let Err(e) = self.send_hook_response(request_id, hook_output).await {
|
||||
tracing::error!("Failed to send hook callback result: {e}");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error in on_hook_callback: {e}");
|
||||
if let Err(e2) = self.send_error(request_id, e.to_string()).await {
|
||||
tracing::error!("Failed to send error response: {e2}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_hook_response(
|
||||
&self,
|
||||
request_id: String,
|
||||
hook_output: serde_json::Value,
|
||||
) -> Result<(), ExecutorError> {
|
||||
self.send_json(&ControlResponseMessage::new(ControlResponseType::Success {
|
||||
request_id,
|
||||
response: Some(hook_output),
|
||||
}))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send error response to CLI
|
||||
async fn send_error(&self, request_id: String, error: String) -> Result<(), ExecutorError> {
|
||||
self.send_json(&ControlResponseMessage::new(ControlResponseType::Error {
|
||||
request_id,
|
||||
error: Some(error),
|
||||
}))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send JSON message to stdin
|
||||
async fn send_json<T: serde::Serialize>(&self, message: &T) -> Result<(), ExecutorError> {
|
||||
let json = serde_json::to_string(message)?;
|
||||
let mut stdin = self.stdin.lock().await;
|
||||
stdin.write_all(json.as_bytes()).await?;
|
||||
stdin.write_all(b"\n").await?;
|
||||
stdin.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_user_message(&self, content: String) -> Result<(), ExecutorError> {
|
||||
let message = serde_json::json!({
|
||||
"type": "user",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
});
|
||||
self.send_json(&message).await
|
||||
}
|
||||
|
||||
pub async fn initialize(&self, hooks: Option<serde_json::Value>) -> Result<(), ExecutorError> {
|
||||
self.send_json(&SDKControlRequestMessage::new(
|
||||
SDKControlRequestType::Initialize { hooks },
|
||||
))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<(), ExecutorError> {
|
||||
self.send_json(&SDKControlRequestMessage::new(
|
||||
SDKControlRequestType::SetPermissionMode { mode },
|
||||
))
|
||||
.await
|
||||
}
|
||||
}
|
||||
178
crates/executors/src/executors/claude/types.rs
Normal file
178
crates/executors/src/executors/claude/types.rs
Normal file
@@ -0,0 +1,178 @@
|
||||
//! Type definitions for Claude Code control protocol
|
||||
//!
|
||||
//! Similar to: https://github.com/ZhangHanDong/claude-code-api-rs/blob/main/claude-code-sdk-rs/src/types.rs
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Top-level message types from CLI stdout
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum CLIMessage {
|
||||
ControlRequest {
|
||||
request_id: String,
|
||||
request: ControlRequestType,
|
||||
},
|
||||
ControlResponse {
|
||||
response: ControlResponseType,
|
||||
},
|
||||
Result(serde_json::Value),
|
||||
#[serde(untagged)]
|
||||
Other(serde_json::Value),
|
||||
}
|
||||
|
||||
/// Control request from SDK to CLI (outgoing)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SDKControlRequestMessage {
|
||||
#[serde(rename = "type")]
|
||||
message_type: String, // Always "control_request"
|
||||
pub request_id: String,
|
||||
pub request: SDKControlRequestType,
|
||||
}
|
||||
|
||||
impl SDKControlRequestMessage {
|
||||
pub fn new(request: SDKControlRequestType) -> Self {
|
||||
use uuid::Uuid;
|
||||
Self {
|
||||
message_type: "control_request".to_string(),
|
||||
request_id: Uuid::new_v4().to_string(),
|
||||
request,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Control response from SDK to CLI (outgoing)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ControlResponseMessage {
|
||||
#[serde(rename = "type")]
|
||||
message_type: String, // Always "control_response"
|
||||
pub response: ControlResponseType,
|
||||
}
|
||||
|
||||
impl ControlResponseMessage {
|
||||
pub fn new(response: ControlResponseType) -> Self {
|
||||
Self {
|
||||
message_type: "control_response".to_string(),
|
||||
response,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Types of control requests
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "subtype", rename_all = "snake_case")]
|
||||
pub enum ControlRequestType {
|
||||
CanUseTool {
|
||||
tool_name: String,
|
||||
input: Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
permission_suggestions: Option<Vec<PermissionUpdate>>,
|
||||
},
|
||||
HookCallback {
|
||||
#[serde(rename = "callback_id")]
|
||||
callback_id: String,
|
||||
input: Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_use_id: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Result of permission check
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "behavior", rename_all = "camelCase")]
|
||||
pub enum PermissionResult {
|
||||
Allow {
|
||||
#[serde(rename = "updatedInput")]
|
||||
updated_input: Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "updatedPermissions")]
|
||||
updated_permissions: Option<Vec<PermissionUpdate>>,
|
||||
},
|
||||
Deny {
|
||||
message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
interrupt: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum PermissionUpdateType {
|
||||
SetMode,
|
||||
AddRules,
|
||||
RemoveRules,
|
||||
ClearRules,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum PermissionUpdateDestination {
|
||||
Session,
|
||||
UserSettings,
|
||||
ProjectSettings,
|
||||
LocalSettings,
|
||||
}
|
||||
|
||||
/// Permission update operation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PermissionUpdate {
|
||||
#[serde(rename = "type")]
|
||||
pub update_type: PermissionUpdateType,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mode: Option<PermissionMode>,
|
||||
pub destination: PermissionUpdateDestination,
|
||||
}
|
||||
|
||||
/// Control response from SDK to CLI
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "subtype", rename_all = "snake_case")]
|
||||
pub enum ControlResponseType {
|
||||
Success {
|
||||
request_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
response: Option<Value>,
|
||||
},
|
||||
Error {
|
||||
request_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "subtype", rename_all = "snake_case")]
|
||||
pub enum SDKControlRequestType {
|
||||
SetPermissionMode {
|
||||
mode: PermissionMode,
|
||||
},
|
||||
Initialize {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
hooks: Option<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum PermissionMode {
|
||||
Default,
|
||||
AcceptEdits,
|
||||
Plan,
|
||||
BypassPermissions,
|
||||
}
|
||||
|
||||
impl PermissionMode {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Default => "default",
|
||||
Self::AcceptEdits => "acceptEdits",
|
||||
Self::Plan => "plan",
|
||||
Self::BypassPermissions => "bypassPermissions",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PermissionMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.as_str())
|
||||
}
|
||||
}
|
||||
@@ -235,12 +235,6 @@ impl AppServerClient {
|
||||
let mut guard = self.conversation_id.lock().await;
|
||||
guard.replace(*conversation_id);
|
||||
}
|
||||
if let Some(approvals) = self.approvals.as_ref() {
|
||||
approvals
|
||||
.register_session(&conversation_id.to_string())
|
||||
.await
|
||||
.map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?;
|
||||
}
|
||||
self.flush_pending_feedback().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def json_error(reason: Optional[str], feedback_marker: Optional[str] = None) -> None:
|
||||
"""Emit a deny PreToolUse JSON to stdout and exit(0)."""
|
||||
# Prefix user feedback with marker for extraction if provided
|
||||
formatted_reason = reason
|
||||
if reason and feedback_marker:
|
||||
formatted_reason = f"{feedback_marker}{reason}"
|
||||
|
||||
payload = {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": formatted_reason,
|
||||
}
|
||||
}
|
||||
print(json.dumps(payload, ensure_ascii=False))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def json_success() -> None:
|
||||
payload = {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "allow",
|
||||
},
|
||||
"suppressOutput": True,
|
||||
}
|
||||
print(json.dumps(payload, ensure_ascii=False))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def http_post_json(url: str, body: dict) -> dict:
|
||||
data = json.dumps(body).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
url, data=data, headers={"Content-Type": "application/json"}, method="POST"
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
return json.loads(resp.read().decode("utf-8") or "{}")
|
||||
except (
|
||||
urllib.error.HTTPError,
|
||||
urllib.error.URLError,
|
||||
TimeoutError,
|
||||
json.JSONDecodeError,
|
||||
) as e:
|
||||
json_error(
|
||||
f"Failed to create approval request. Backend may be unavailable. ({e})"
|
||||
)
|
||||
raise # unreachable
|
||||
|
||||
|
||||
def http_get_json(url: str) -> dict:
|
||||
req = urllib.request.Request(url, method="GET")
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
return json.loads(resp.read().decode("utf-8") or "{}")
|
||||
except (
|
||||
urllib.error.HTTPError,
|
||||
urllib.error.URLError,
|
||||
TimeoutError,
|
||||
json.JSONDecodeError,
|
||||
) as e:
|
||||
json_error(f"Lost connection to approval backend: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PreToolUse approval gate. All parameters are passed via CLI."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--timeout-seconds",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Maximum time to wait for approval before timing out (seconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--poll-interval",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Seconds between polling the backend for status.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--backend-port",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Port of the approval backend running on 127.0.0.1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--feedback-marker",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Marker prefix for user feedback messages.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.timeout_seconds <= 0:
|
||||
parser.error("--timeout-seconds must be a positive integer")
|
||||
if args.poll_interval <= 0:
|
||||
parser.error("--poll-interval must be a positive integer")
|
||||
if args.poll_interval > args.timeout_seconds:
|
||||
parser.error("--poll-interval cannot be greater than --timeout-seconds")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
port = args.backend_port
|
||||
|
||||
url = f"http://127.0.0.1:{port}"
|
||||
create_endpoint = f"{url}/api/approvals/create"
|
||||
|
||||
try:
|
||||
raw_payload = sys.stdin.read()
|
||||
incoming = json.loads(raw_payload or "{}")
|
||||
except json.JSONDecodeError:
|
||||
json_error("Invalid JSON payload on stdin")
|
||||
|
||||
tool_name = incoming.get("tool_name")
|
||||
tool_input = incoming.get("tool_input")
|
||||
session_id = incoming.get("session_id", "unknown")
|
||||
|
||||
create_payload = {
|
||||
"tool_name": tool_name,
|
||||
"tool_input": tool_input,
|
||||
"session_id": session_id,
|
||||
}
|
||||
|
||||
response = http_post_json(create_endpoint, create_payload)
|
||||
approval_id = response.get("id")
|
||||
if not approval_id:
|
||||
json_error("Invalid response from approval backend")
|
||||
|
||||
print(
|
||||
f"Approval request created: {approval_id}. Waiting for user response...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
elapsed = 0
|
||||
while elapsed < args.timeout_seconds:
|
||||
result = http_get_json(f"{url}/api/approvals/{approval_id}/status")
|
||||
status = result.get("status")
|
||||
|
||||
if status == "approved":
|
||||
json_success()
|
||||
elif status == "denied":
|
||||
reason = result.get("reason")
|
||||
json_error(reason, args.feedback_marker)
|
||||
elif status == "timed_out":
|
||||
# concat to avoid triggering the watchkill script
|
||||
json_error(
|
||||
"Approval request" + f" timed out after {args.timeout_seconds} seconds"
|
||||
)
|
||||
elif status == "pending":
|
||||
time.sleep(args.poll_interval)
|
||||
elapsed += args.poll_interval
|
||||
else:
|
||||
json_error(f"Unknown approval status: {status}")
|
||||
|
||||
# concat to avoid triggering the watchkill script
|
||||
json_error("Approval request"+ f" timed out after {args.timeout_seconds} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -809,11 +809,13 @@ impl ContainerService for LocalContainerService {
|
||||
|
||||
let approvals_service: Arc<dyn ExecutorApprovalService> =
|
||||
match executor_action.base_executor() {
|
||||
Some(BaseCodingAgent::Codex) => ExecutorApprovalBridge::new(
|
||||
self.approvals.clone(),
|
||||
self.db.clone(),
|
||||
execution_process.id,
|
||||
),
|
||||
Some(BaseCodingAgent::Codex) | Some(BaseCodingAgent::ClaudeCode) => {
|
||||
ExecutorApprovalBridge::new(
|
||||
self.approvals.clone(),
|
||||
self.db.clone(),
|
||||
execution_process.id,
|
||||
)
|
||||
}
|
||||
_ => Arc::new(NoopExecutorApprovalService {}),
|
||||
};
|
||||
|
||||
|
||||
@@ -2,60 +2,13 @@ use axum::{
|
||||
Json, Router,
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
routing::{get, post},
|
||||
routing::post,
|
||||
};
|
||||
use db::models::execution_process::ExecutionProcess;
|
||||
use deployment::Deployment;
|
||||
use services::services::container::ContainerService;
|
||||
use utils::approvals::{
|
||||
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus, CreateApprovalRequest,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
};
|
||||
use utils::approvals::{ApprovalResponse, ApprovalStatus};
|
||||
|
||||
use crate::DeploymentImpl;
|
||||
|
||||
pub async fn create_approval(
|
||||
State(deployment): State<DeploymentImpl>,
|
||||
Json(request): Json<CreateApprovalRequest>,
|
||||
) -> Result<Json<ApprovalRequest>, StatusCode> {
|
||||
let service = deployment.approvals();
|
||||
|
||||
match service
|
||||
.create_from_session(&deployment.db().pool, request)
|
||||
.await
|
||||
{
|
||||
Ok(approval) => {
|
||||
deployment
|
||||
.track_if_analytics_allowed(
|
||||
"approval_created",
|
||||
serde_json::json!({
|
||||
"approval_id": approval.id,
|
||||
"tool_name": &approval.tool_name,
|
||||
"execution_process_id": approval.execution_process_id.to_string(),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(approval))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to create approval: {:?}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_approval_status(
|
||||
State(deployment): State<DeploymentImpl>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApprovalStatus>, StatusCode> {
|
||||
let service = deployment.approvals();
|
||||
match service.status(&id).await {
|
||||
Some(status) => Ok(Json(status)),
|
||||
None => Err(StatusCode::NOT_FOUND),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn respond_to_approval(
|
||||
State(deployment): State<DeploymentImpl>,
|
||||
Path(id): Path<String>,
|
||||
@@ -77,21 +30,6 @@ pub async fn respond_to_approval(
|
||||
)
|
||||
.await;
|
||||
|
||||
if matches!(status, ApprovalStatus::Approved)
|
||||
&& context.tool_name == EXIT_PLAN_MODE_TOOL_NAME
|
||||
// If exiting plan mode, automatically start a new execution process with different
|
||||
// permissions
|
||||
&& let Ok(ctx) = ExecutionProcess::load_context(
|
||||
&deployment.db().pool,
|
||||
context.execution_process_id,
|
||||
)
|
||||
.await
|
||||
&& let Err(e) = deployment.container().exit_plan_mode_tool(ctx).await
|
||||
{
|
||||
tracing::error!("failed to exit plan mode: {:?}", e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
Ok(Json(status))
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -101,18 +39,6 @@ pub async fn respond_to_approval(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pending_approvals(
|
||||
State(deployment): State<DeploymentImpl>,
|
||||
) -> Json<Vec<ApprovalPendingInfo>> {
|
||||
let service = deployment.approvals();
|
||||
let approvals = service.pending().await;
|
||||
Json(approvals)
|
||||
}
|
||||
|
||||
pub fn router() -> Router<DeploymentImpl> {
|
||||
Router::new()
|
||||
.route("/approvals/create", post(create_approval))
|
||||
.route("/approvals/{id}/status", get(get_approval_status))
|
||||
.route("/approvals/{id}/respond", post(respond_to_approval))
|
||||
.route("/approvals/pending", get(get_pending_approvals))
|
||||
Router::new().route("/approvals/{id}/respond", post(respond_to_approval))
|
||||
}
|
||||
|
||||
@@ -2,11 +2,9 @@ pub mod executor_approvals;
|
||||
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use db::models::{
|
||||
execution_process::ExecutionProcess,
|
||||
executor_session::ExecutorSession,
|
||||
task::{Task, TaskStatus},
|
||||
};
|
||||
use executors::{
|
||||
@@ -21,10 +19,7 @@ use sqlx::{Error as SqlxError, SqlitePool};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{RwLock, oneshot};
|
||||
use utils::{
|
||||
approvals::{
|
||||
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus,
|
||||
CreateApprovalRequest,
|
||||
},
|
||||
approvals::{ApprovalRequest, ApprovalResponse, ApprovalStatus},
|
||||
log_msg::LogMsg,
|
||||
msg_store::MsgStore,
|
||||
};
|
||||
@@ -36,8 +31,6 @@ struct PendingApproval {
|
||||
entry: NormalizedEntry,
|
||||
execution_process_id: Uuid,
|
||||
tool_name: String,
|
||||
requested_at: DateTime<Utc>,
|
||||
timeout_at: DateTime<Utc>,
|
||||
response_tx: oneshot::Sender<ApprovalStatus>,
|
||||
}
|
||||
|
||||
@@ -81,7 +74,7 @@ impl Approvals {
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_internal(
|
||||
pub async fn create_with_waiter(
|
||||
&self,
|
||||
request: ApprovalRequest,
|
||||
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
|
||||
@@ -94,12 +87,7 @@ impl Approvals {
|
||||
|
||||
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
|
||||
// Find the matching tool use entry by name and input
|
||||
let matching_tool = find_matching_tool_use(
|
||||
store.clone(),
|
||||
&request.tool_name,
|
||||
&request.tool_input,
|
||||
request.tool_call_id.as_deref(),
|
||||
);
|
||||
let matching_tool = find_matching_tool_use(store.clone(), &request.tool_call_id);
|
||||
|
||||
if let Some((idx, matching_tool)) = matching_tool {
|
||||
let approval_entry = matching_tool
|
||||
@@ -118,8 +106,6 @@ impl Approvals {
|
||||
entry: matching_tool,
|
||||
execution_process_id: request.execution_process_id,
|
||||
tool_name: request.tool_name.clone(),
|
||||
requested_at: request.created_at,
|
||||
timeout_at: request.timeout_at,
|
||||
response_tx: tx,
|
||||
},
|
||||
);
|
||||
@@ -147,41 +133,6 @@ impl Approvals {
|
||||
Ok((request, waiter))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, request))]
|
||||
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
|
||||
let (request, _) = self.create_internal(request).await?;
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub async fn create_with_waiter(
|
||||
&self,
|
||||
request: ApprovalRequest,
|
||||
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
|
||||
self.create_internal(request).await
|
||||
}
|
||||
|
||||
pub async fn create_from_session(
|
||||
&self,
|
||||
pool: &SqlitePool,
|
||||
payload: CreateApprovalRequest,
|
||||
) -> Result<ApprovalRequest, ApprovalError> {
|
||||
let session_id = payload.session_id.clone();
|
||||
let execution_process_id =
|
||||
match ExecutorSession::find_by_session_id(pool, &session_id).await? {
|
||||
Some(session) => session.execution_process_id,
|
||||
None => {
|
||||
tracing::warn!("No executor session found for session_id: {}", session_id);
|
||||
return Err(ApprovalError::NoExecutorSession(session_id));
|
||||
}
|
||||
};
|
||||
|
||||
// Move the task to InReview if it's still InProgress
|
||||
ensure_task_in_review(pool, execution_process_id).await;
|
||||
|
||||
let request = ApprovalRequest::from_create(payload, execution_process_id);
|
||||
self.create(request).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, id, req))]
|
||||
pub async fn respond(
|
||||
&self,
|
||||
@@ -238,39 +189,6 @@ impl Approvals {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn status(&self, id: &str) -> Option<ApprovalStatus> {
|
||||
if let Some(f) = self.completed.get(id) {
|
||||
return Some(f.clone());
|
||||
}
|
||||
if let Some(p) = self.pending.get(id) {
|
||||
if chrono::Utc::now() >= p.timeout_at {
|
||||
return Some(ApprovalStatus::TimedOut);
|
||||
}
|
||||
return Some(ApprovalStatus::Pending);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub async fn pending(&self) -> Vec<ApprovalPendingInfo> {
|
||||
self.pending
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let (id, pending) = entry.pair();
|
||||
|
||||
match &pending.entry.entry_type {
|
||||
NormalizedEntryType::ToolUse { tool_name, .. } => Some(ApprovalPendingInfo {
|
||||
approval_id: id.clone(),
|
||||
execution_process_id: pending.execution_process_id,
|
||||
tool_name: tool_name.clone(),
|
||||
requested_at: pending.requested_at,
|
||||
timeout_at: pending.timeout_at,
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, id, timeout_at, waiter))]
|
||||
fn spawn_timeout_watcher(
|
||||
&self,
|
||||
@@ -352,126 +270,37 @@ pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_i
|
||||
}
|
||||
}
|
||||
|
||||
/// Comparison strategy for matching tool use entries
|
||||
enum ToolComparisonStrategy {
|
||||
/// Compare by tool_call_id
|
||||
ToolCallId(String),
|
||||
/// Compare deserialized ClaudeToolData structures (for known tools)
|
||||
Deserialized(executors::executors::claude::ClaudeToolData),
|
||||
/// Compare raw JSON input fields (for Unknown tools like MCP)
|
||||
RawJson,
|
||||
}
|
||||
|
||||
/// Find a matching tool use entry that hasn't been assigned to an approval yet
|
||||
/// Matches by tool name and tool input to support parallel tool calls
|
||||
/// Matches by tool call id from tool metadata
|
||||
fn find_matching_tool_use(
|
||||
store: Arc<MsgStore>,
|
||||
tool_name: &str,
|
||||
tool_input: &serde_json::Value,
|
||||
tool_call_id: Option<&str>,
|
||||
tool_call_id: &str,
|
||||
) -> Option<(usize, NormalizedEntry)> {
|
||||
use executors::executors::claude::ClaudeToolData;
|
||||
|
||||
let history = store.get_history();
|
||||
|
||||
// Determine comparison strategy based on tool type
|
||||
let strategy = if let Some(call_id) = tool_call_id {
|
||||
// If tool_call_id is provided, use it for matching
|
||||
ToolComparisonStrategy::ToolCallId(call_id.to_string())
|
||||
} else {
|
||||
match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
|
||||
"name": tool_name,
|
||||
"input": tool_input
|
||||
})) {
|
||||
Ok(ClaudeToolData::Unknown { .. }) => {
|
||||
// For Unknown tools (MCP, future tools), use raw JSON comparison
|
||||
ToolComparisonStrategy::RawJson
|
||||
}
|
||||
Ok(data) => {
|
||||
// For known tools, use deserialized comparison with proper alias handling
|
||||
ToolComparisonStrategy::Deserialized(data)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to deserialize tool_input for tool '{}': {}",
|
||||
tool_name,
|
||||
e
|
||||
);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Single loop through history with strategy-based comparison
|
||||
// Single loop through history
|
||||
for msg in history.iter().rev() {
|
||||
if let LogMsg::JsonPatch(patch) = msg
|
||||
&& let Some((idx, entry)) = extract_normalized_entry_from_patch(patch)
|
||||
&& let NormalizedEntryType::ToolUse {
|
||||
tool_name: entry_tool_name,
|
||||
status,
|
||||
..
|
||||
} = &entry.entry_type
|
||||
&& let NormalizedEntryType::ToolUse { status, .. } = &entry.entry_type
|
||||
{
|
||||
// Only match tools that are in Created state
|
||||
if !matches!(status, ToolStatus::Created) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Tool name must match
|
||||
if entry_tool_name != tool_name {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply comparison strategy
|
||||
if let Some(metadata) = &entry.metadata {
|
||||
let is_match = match &strategy {
|
||||
ToolComparisonStrategy::ToolCallId(call_id) => {
|
||||
// Match by tool_call_id in metadata
|
||||
if let Ok(ToolCallMetadata {
|
||||
tool_call_id: entry_call_id,
|
||||
..
|
||||
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
|
||||
{
|
||||
entry_call_id == *call_id
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ToolComparisonStrategy::RawJson => {
|
||||
// Compare raw JSON input for Unknown tools
|
||||
if let Some(entry_input) = metadata.get("input") {
|
||||
entry_input == tool_input
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ToolComparisonStrategy::Deserialized(approval_data) => {
|
||||
// Compare deserialized structures for known tools
|
||||
if let Ok(entry_tool_data) =
|
||||
serde_json::from_value::<ClaudeToolData>(metadata.clone())
|
||||
{
|
||||
entry_tool_data == *approval_data
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if is_match {
|
||||
let strategy_name = match strategy {
|
||||
ToolComparisonStrategy::ToolCallId(call_id) => {
|
||||
format!("tool_call_id '{call_id}'")
|
||||
}
|
||||
ToolComparisonStrategy::RawJson => "raw input comparison".to_string(),
|
||||
ToolComparisonStrategy::Deserialized(_) => {
|
||||
"deserialized tool data".to_string()
|
||||
}
|
||||
};
|
||||
tracing::debug!(
|
||||
"Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}"
|
||||
);
|
||||
return Some((idx, entry));
|
||||
}
|
||||
// Match by tool call id from metadata
|
||||
if let Some(metadata) = &entry.metadata
|
||||
&& let Ok(ToolCallMetadata {
|
||||
tool_call_id: entry_call_id,
|
||||
..
|
||||
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
|
||||
&& entry_call_id == tool_call_id
|
||||
{
|
||||
tracing::debug!(
|
||||
"Matched tool use entry at index {idx} for tool call id '{tool_call_id}'"
|
||||
);
|
||||
return Some((idx, entry));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -491,19 +320,9 @@ mod tests {
|
||||
fn create_tool_use_entry(
|
||||
tool_name: &str,
|
||||
file_path: &str,
|
||||
id: &str,
|
||||
status: ToolStatus,
|
||||
) -> NormalizedEntry {
|
||||
// Create metadata that mimics the actual structure from Claude Code
|
||||
// which has an "input" field containing the original tool parameters
|
||||
let metadata = serde_json::json!({
|
||||
"type": "tool_use",
|
||||
"id": format!("test-{}", file_path),
|
||||
"name": tool_name,
|
||||
"input": {
|
||||
"file_path": file_path
|
||||
}
|
||||
});
|
||||
|
||||
NormalizedEntry {
|
||||
timestamp: None,
|
||||
entry_type: NormalizedEntryType::ToolUse {
|
||||
@@ -514,7 +333,12 @@ mod tests {
|
||||
status,
|
||||
},
|
||||
content: format!("Reading {file_path}"),
|
||||
metadata: Some(metadata),
|
||||
metadata: Some(
|
||||
serde_json::to_value(ToolCallMetadata {
|
||||
tool_call_id: id.to_string(),
|
||||
})
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,9 +347,9 @@ mod tests {
|
||||
let store = Arc::new(MsgStore::new());
|
||||
|
||||
// Setup: Simulate 3 parallel Read tool calls with different files
|
||||
let read_foo = create_tool_use_entry("Read", "foo.rs", ToolStatus::Created);
|
||||
let read_bar = create_tool_use_entry("Read", "bar.rs", ToolStatus::Created);
|
||||
let read_baz = create_tool_use_entry("Read", "baz.rs", ToolStatus::Created);
|
||||
let read_foo = create_tool_use_entry("Read", "foo.rs", "foo-id", ToolStatus::Created);
|
||||
let read_bar = create_tool_use_entry("Read", "bar.rs", "bar-id", ToolStatus::Created);
|
||||
let read_baz = create_tool_use_entry("Read", "baz.rs", "baz-id", ToolStatus::Created);
|
||||
|
||||
store.push_patch(
|
||||
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(0, read_foo),
|
||||
@@ -537,17 +361,12 @@ mod tests {
|
||||
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(2, read_baz),
|
||||
);
|
||||
|
||||
// Test 1: Each approval request matches its specific tool by input
|
||||
let foo_input = serde_json::json!({"file_path": "foo.rs"});
|
||||
let bar_input = serde_json::json!({"file_path": "bar.rs"});
|
||||
let baz_input = serde_json::json!({"file_path": "baz.rs"});
|
||||
|
||||
let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None)
|
||||
.expect("Should match foo.rs");
|
||||
let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None)
|
||||
.expect("Should match bar.rs");
|
||||
let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None)
|
||||
.expect("Should match baz.rs");
|
||||
let (idx_foo, _) =
|
||||
find_matching_tool_use(store.clone(), "foo-id").expect("Should match foo.rs");
|
||||
let (idx_bar, _) =
|
||||
find_matching_tool_use(store.clone(), "bar-id").expect("Should match bar.rs");
|
||||
let (idx_baz, _) =
|
||||
find_matching_tool_use(store.clone(), "baz-id").expect("Should match baz.rs");
|
||||
|
||||
assert_eq!(idx_foo, 0, "foo.rs should match first entry");
|
||||
assert_eq!(idx_bar, 1, "bar.rs should match second entry");
|
||||
@@ -557,6 +376,7 @@ mod tests {
|
||||
let read_pending = create_tool_use_entry(
|
||||
"Read",
|
||||
"pending.rs",
|
||||
"pending-id",
|
||||
ToolStatus::PendingApproval {
|
||||
approval_id: "test-id".to_string(),
|
||||
requested_at: chrono::Utc::now(),
|
||||
@@ -567,24 +387,15 @@ mod tests {
|
||||
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(3, read_pending),
|
||||
);
|
||||
|
||||
let pending_input = serde_json::json!({"file_path": "pending.rs"});
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(),
|
||||
find_matching_tool_use(store.clone(), "pending-id").is_none(),
|
||||
"Should not match tools in PendingApproval state"
|
||||
);
|
||||
|
||||
// Test 3: Wrong tool name returns None
|
||||
let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"});
|
||||
// Test 3: Wrong tool id returns None
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(),
|
||||
"Should not match different tool names"
|
||||
);
|
||||
|
||||
// Test 4: Wrong input parameters returns None
|
||||
let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"});
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(),
|
||||
"Should not match with different input parameters"
|
||||
find_matching_tool_use(store.clone(), "wrong-id").is_none(),
|
||||
"Should not match different tool ids"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use async_trait::async_trait;
|
||||
use db::{self, DBService};
|
||||
use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::RwLock;
|
||||
use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -14,7 +13,6 @@ pub struct ExecutorApprovalBridge {
|
||||
approvals: Approvals,
|
||||
db: DBService,
|
||||
execution_process_id: Uuid,
|
||||
session_id: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl ExecutorApprovalBridge {
|
||||
@@ -23,41 +21,25 @@ impl ExecutorApprovalBridge {
|
||||
approvals,
|
||||
db,
|
||||
execution_process_id,
|
||||
session_id: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ExecutorApprovalService for ExecutorApprovalBridge {
|
||||
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> {
|
||||
let mut guard = self.session_id.write().await;
|
||||
guard.replace(session_id.to_string());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn request_tool_approval(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
tool_input: Value,
|
||||
tool_call_id: &str,
|
||||
) -> Result<ApprovalStatus, ExecutorApprovalError> {
|
||||
let session_id = {
|
||||
let guard = self.session_id.read().await;
|
||||
guard
|
||||
.clone()
|
||||
.ok_or(ExecutorApprovalError::SessionNotRegistered)?
|
||||
};
|
||||
|
||||
super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await;
|
||||
|
||||
let request = ApprovalRequest::from_create(
|
||||
CreateApprovalRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
tool_input,
|
||||
session_id,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
tool_call_id: tool_call_id.to_string(),
|
||||
},
|
||||
self.execution_process_id,
|
||||
);
|
||||
|
||||
@@ -4,15 +4,13 @@ use ts_rs::TS;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub const APPROVAL_TIMEOUT_SECONDS: i64 = 3600; // 1 hour
|
||||
pub const EXIT_PLAN_MODE_TOOL_NAME: &str = "ExitPlanMode";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
|
||||
pub struct ApprovalRequest {
|
||||
pub id: String,
|
||||
pub tool_name: String,
|
||||
pub tool_input: serde_json::Value,
|
||||
pub session_id: String,
|
||||
pub tool_call_id: Option<String>,
|
||||
pub tool_call_id: String,
|
||||
pub execution_process_id: Uuid,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub timeout_at: DateTime<Utc>,
|
||||
@@ -25,7 +23,6 @@ impl ApprovalRequest {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
tool_name: request.tool_name,
|
||||
tool_input: request.tool_input,
|
||||
session_id: request.session_id,
|
||||
tool_call_id: request.tool_call_id,
|
||||
execution_process_id,
|
||||
created_at: now,
|
||||
@@ -39,8 +36,7 @@ impl ApprovalRequest {
|
||||
pub struct CreateApprovalRequest {
|
||||
pub tool_name: String,
|
||||
pub tool_input: serde_json::Value,
|
||||
pub session_id: String,
|
||||
pub tool_call_id: Option<String>,
|
||||
pub tool_call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
|
||||
@@ -62,13 +58,3 @@ pub struct ApprovalResponse {
|
||||
pub execution_process_id: Uuid,
|
||||
pub status: ApprovalStatus,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
|
||||
#[ts(export)]
|
||||
pub struct ApprovalPendingInfo {
|
||||
pub approval_id: String,
|
||||
pub execution_process_id: Uuid,
|
||||
pub tool_name: String,
|
||||
pub requested_at: DateTime<Utc>,
|
||||
pub timeout_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
@@ -322,7 +322,7 @@ export type PatchType = { "type": "NORMALIZED_ENTRY", "content": NormalizedEntry
|
||||
|
||||
export type ApprovalStatus = { "status": "pending" } | { "status": "approved" } | { "status": "denied", reason?: string, } | { "status": "timed_out" };
|
||||
|
||||
export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, session_id: string, tool_call_id: string | null, };
|
||||
export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, tool_call_id: string, };
|
||||
|
||||
export type ApprovalResponse = { execution_process_id: string, status: ApprovalStatus, };
|
||||
|
||||
|
||||
Reference in New Issue
Block a user