Claude approval refactor (#1080)

* WIP claude approvals

* Use canusetool

* Remove old exitplanmode approvals

* WIP approvals

* types

* Remove bloat

* Cleanup, exit on finish

* Approval messages, cleanup

* Cleanup

* Fix msg types

* Lint fmt

* Cleanup

* Send deny

* add missing timeout to hooks

* FIx timeout issue

* Cleanup

* Error handling, log writer bugs

* Remove deprecated approbal endpoints

* Remove tool matching strategies in favour of only id based matching

* remove register session, parse result at protocol level

* Remove circular peer, remove unneeded trait

* Types
This commit is contained in:
Alex Netsch
2025-10-28 15:36:47 +00:00
committed by GitHub
parent a70a7bfbad
commit e06dd1f6dc
13 changed files with 845 additions and 781 deletions

View File

@@ -2,11 +2,9 @@ pub mod executor_approvals;
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use db::models::{
execution_process::ExecutionProcess,
executor_session::ExecutorSession,
task::{Task, TaskStatus},
};
use executors::{
@@ -21,10 +19,7 @@ use sqlx::{Error as SqlxError, SqlitePool};
use thiserror::Error;
use tokio::sync::{RwLock, oneshot};
use utils::{
approvals::{
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus,
CreateApprovalRequest,
},
approvals::{ApprovalRequest, ApprovalResponse, ApprovalStatus},
log_msg::LogMsg,
msg_store::MsgStore,
};
@@ -36,8 +31,6 @@ struct PendingApproval {
entry: NormalizedEntry,
execution_process_id: Uuid,
tool_name: String,
requested_at: DateTime<Utc>,
timeout_at: DateTime<Utc>,
response_tx: oneshot::Sender<ApprovalStatus>,
}
@@ -81,7 +74,7 @@ impl Approvals {
}
}
async fn create_internal(
pub async fn create_with_waiter(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
@@ -94,12 +87,7 @@ impl Approvals {
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
// Find the matching tool use entry by name and input
let matching_tool = find_matching_tool_use(
store.clone(),
&request.tool_name,
&request.tool_input,
request.tool_call_id.as_deref(),
);
let matching_tool = find_matching_tool_use(store.clone(), &request.tool_call_id);
if let Some((idx, matching_tool)) = matching_tool {
let approval_entry = matching_tool
@@ -118,8 +106,6 @@ impl Approvals {
entry: matching_tool,
execution_process_id: request.execution_process_id,
tool_name: request.tool_name.clone(),
requested_at: request.created_at,
timeout_at: request.timeout_at,
response_tx: tx,
},
);
@@ -147,41 +133,6 @@ impl Approvals {
Ok((request, waiter))
}
#[tracing::instrument(skip(self, request))]
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
let (request, _) = self.create_internal(request).await?;
Ok(request)
}
pub async fn create_with_waiter(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
self.create_internal(request).await
}
pub async fn create_from_session(
&self,
pool: &SqlitePool,
payload: CreateApprovalRequest,
) -> Result<ApprovalRequest, ApprovalError> {
let session_id = payload.session_id.clone();
let execution_process_id =
match ExecutorSession::find_by_session_id(pool, &session_id).await? {
Some(session) => session.execution_process_id,
None => {
tracing::warn!("No executor session found for session_id: {}", session_id);
return Err(ApprovalError::NoExecutorSession(session_id));
}
};
// Move the task to InReview if it's still InProgress
ensure_task_in_review(pool, execution_process_id).await;
let request = ApprovalRequest::from_create(payload, execution_process_id);
self.create(request).await
}
#[tracing::instrument(skip(self, id, req))]
pub async fn respond(
&self,
@@ -238,39 +189,6 @@ impl Approvals {
}
}
pub async fn status(&self, id: &str) -> Option<ApprovalStatus> {
if let Some(f) = self.completed.get(id) {
return Some(f.clone());
}
if let Some(p) = self.pending.get(id) {
if chrono::Utc::now() >= p.timeout_at {
return Some(ApprovalStatus::TimedOut);
}
return Some(ApprovalStatus::Pending);
}
None
}
pub async fn pending(&self) -> Vec<ApprovalPendingInfo> {
self.pending
.iter()
.filter_map(|entry| {
let (id, pending) = entry.pair();
match &pending.entry.entry_type {
NormalizedEntryType::ToolUse { tool_name, .. } => Some(ApprovalPendingInfo {
approval_id: id.clone(),
execution_process_id: pending.execution_process_id,
tool_name: tool_name.clone(),
requested_at: pending.requested_at,
timeout_at: pending.timeout_at,
}),
_ => None,
}
})
.collect()
}
#[tracing::instrument(skip(self, id, timeout_at, waiter))]
fn spawn_timeout_watcher(
&self,
@@ -352,126 +270,37 @@ pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_i
}
}
/// Comparison strategy for matching tool use entries
enum ToolComparisonStrategy {
/// Compare by tool_call_id
ToolCallId(String),
/// Compare deserialized ClaudeToolData structures (for known tools)
Deserialized(executors::executors::claude::ClaudeToolData),
/// Compare raw JSON input fields (for Unknown tools like MCP)
RawJson,
}
/// Find a matching tool use entry that hasn't been assigned to an approval yet
/// Matches by tool name and tool input to support parallel tool calls
/// Matches by tool call id from tool metadata
fn find_matching_tool_use(
store: Arc<MsgStore>,
tool_name: &str,
tool_input: &serde_json::Value,
tool_call_id: Option<&str>,
tool_call_id: &str,
) -> Option<(usize, NormalizedEntry)> {
use executors::executors::claude::ClaudeToolData;
let history = store.get_history();
// Determine comparison strategy based on tool type
let strategy = if let Some(call_id) = tool_call_id {
// If tool_call_id is provided, use it for matching
ToolComparisonStrategy::ToolCallId(call_id.to_string())
} else {
match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
"name": tool_name,
"input": tool_input
})) {
Ok(ClaudeToolData::Unknown { .. }) => {
// For Unknown tools (MCP, future tools), use raw JSON comparison
ToolComparisonStrategy::RawJson
}
Ok(data) => {
// For known tools, use deserialized comparison with proper alias handling
ToolComparisonStrategy::Deserialized(data)
}
Err(e) => {
tracing::warn!(
"Failed to deserialize tool_input for tool '{}': {}",
tool_name,
e
);
return None;
}
}
};
// Single loop through history with strategy-based comparison
// Single loop through history
for msg in history.iter().rev() {
if let LogMsg::JsonPatch(patch) = msg
&& let Some((idx, entry)) = extract_normalized_entry_from_patch(patch)
&& let NormalizedEntryType::ToolUse {
tool_name: entry_tool_name,
status,
..
} = &entry.entry_type
&& let NormalizedEntryType::ToolUse { status, .. } = &entry.entry_type
{
// Only match tools that are in Created state
if !matches!(status, ToolStatus::Created) {
continue;
}
// Tool name must match
if entry_tool_name != tool_name {
continue;
}
// Apply comparison strategy
if let Some(metadata) = &entry.metadata {
let is_match = match &strategy {
ToolComparisonStrategy::ToolCallId(call_id) => {
// Match by tool_call_id in metadata
if let Ok(ToolCallMetadata {
tool_call_id: entry_call_id,
..
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
{
entry_call_id == *call_id
} else {
false
}
}
ToolComparisonStrategy::RawJson => {
// Compare raw JSON input for Unknown tools
if let Some(entry_input) = metadata.get("input") {
entry_input == tool_input
} else {
false
}
}
ToolComparisonStrategy::Deserialized(approval_data) => {
// Compare deserialized structures for known tools
if let Ok(entry_tool_data) =
serde_json::from_value::<ClaudeToolData>(metadata.clone())
{
entry_tool_data == *approval_data
} else {
false
}
}
};
if is_match {
let strategy_name = match strategy {
ToolComparisonStrategy::ToolCallId(call_id) => {
format!("tool_call_id '{call_id}'")
}
ToolComparisonStrategy::RawJson => "raw input comparison".to_string(),
ToolComparisonStrategy::Deserialized(_) => {
"deserialized tool data".to_string()
}
};
tracing::debug!(
"Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}"
);
return Some((idx, entry));
}
// Match by tool call id from metadata
if let Some(metadata) = &entry.metadata
&& let Ok(ToolCallMetadata {
tool_call_id: entry_call_id,
..
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
&& entry_call_id == tool_call_id
{
tracing::debug!(
"Matched tool use entry at index {idx} for tool call id '{tool_call_id}'"
);
return Some((idx, entry));
}
}
}
@@ -491,19 +320,9 @@ mod tests {
fn create_tool_use_entry(
tool_name: &str,
file_path: &str,
id: &str,
status: ToolStatus,
) -> NormalizedEntry {
// Create metadata that mimics the actual structure from Claude Code
// which has an "input" field containing the original tool parameters
let metadata = serde_json::json!({
"type": "tool_use",
"id": format!("test-{}", file_path),
"name": tool_name,
"input": {
"file_path": file_path
}
});
NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -514,7 +333,12 @@ mod tests {
status,
},
content: format!("Reading {file_path}"),
metadata: Some(metadata),
metadata: Some(
serde_json::to_value(ToolCallMetadata {
tool_call_id: id.to_string(),
})
.unwrap(),
),
}
}
@@ -523,9 +347,9 @@ mod tests {
let store = Arc::new(MsgStore::new());
// Setup: Simulate 3 parallel Read tool calls with different files
let read_foo = create_tool_use_entry("Read", "foo.rs", ToolStatus::Created);
let read_bar = create_tool_use_entry("Read", "bar.rs", ToolStatus::Created);
let read_baz = create_tool_use_entry("Read", "baz.rs", ToolStatus::Created);
let read_foo = create_tool_use_entry("Read", "foo.rs", "foo-id", ToolStatus::Created);
let read_bar = create_tool_use_entry("Read", "bar.rs", "bar-id", ToolStatus::Created);
let read_baz = create_tool_use_entry("Read", "baz.rs", "baz-id", ToolStatus::Created);
store.push_patch(
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(0, read_foo),
@@ -537,17 +361,12 @@ mod tests {
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(2, read_baz),
);
// Test 1: Each approval request matches its specific tool by input
let foo_input = serde_json::json!({"file_path": "foo.rs"});
let bar_input = serde_json::json!({"file_path": "bar.rs"});
let baz_input = serde_json::json!({"file_path": "baz.rs"});
let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None)
.expect("Should match foo.rs");
let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None)
.expect("Should match bar.rs");
let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None)
.expect("Should match baz.rs");
let (idx_foo, _) =
find_matching_tool_use(store.clone(), "foo-id").expect("Should match foo.rs");
let (idx_bar, _) =
find_matching_tool_use(store.clone(), "bar-id").expect("Should match bar.rs");
let (idx_baz, _) =
find_matching_tool_use(store.clone(), "baz-id").expect("Should match baz.rs");
assert_eq!(idx_foo, 0, "foo.rs should match first entry");
assert_eq!(idx_bar, 1, "bar.rs should match second entry");
@@ -557,6 +376,7 @@ mod tests {
let read_pending = create_tool_use_entry(
"Read",
"pending.rs",
"pending-id",
ToolStatus::PendingApproval {
approval_id: "test-id".to_string(),
requested_at: chrono::Utc::now(),
@@ -567,24 +387,15 @@ mod tests {
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(3, read_pending),
);
let pending_input = serde_json::json!({"file_path": "pending.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(),
find_matching_tool_use(store.clone(), "pending-id").is_none(),
"Should not match tools in PendingApproval state"
);
// Test 3: Wrong tool name returns None
let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"});
// Test 3: Wrong tool id returns None
assert!(
find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(),
"Should not match different tool names"
);
// Test 4: Wrong input parameters returns None
let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(),
"Should not match with different input parameters"
find_matching_tool_use(store.clone(), "wrong-id").is_none(),
"Should not match different tool ids"
);
}
}

View File

@@ -4,7 +4,6 @@ use async_trait::async_trait;
use db::{self, DBService};
use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService};
use serde_json::Value;
use tokio::sync::RwLock;
use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest};
use uuid::Uuid;
@@ -14,7 +13,6 @@ pub struct ExecutorApprovalBridge {
approvals: Approvals,
db: DBService,
execution_process_id: Uuid,
session_id: RwLock<Option<String>>,
}
impl ExecutorApprovalBridge {
@@ -23,41 +21,25 @@ impl ExecutorApprovalBridge {
approvals,
db,
execution_process_id,
session_id: RwLock::new(None),
})
}
}
#[async_trait]
impl ExecutorApprovalService for ExecutorApprovalBridge {
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> {
let mut guard = self.session_id.write().await;
guard.replace(session_id.to_string());
Ok(())
}
async fn request_tool_approval(
&self,
tool_name: &str,
tool_input: Value,
tool_call_id: &str,
) -> Result<ApprovalStatus, ExecutorApprovalError> {
let session_id = {
let guard = self.session_id.read().await;
guard
.clone()
.ok_or(ExecutorApprovalError::SessionNotRegistered)?
};
super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await;
let request = ApprovalRequest::from_create(
CreateApprovalRequest {
tool_name: tool_name.to_string(),
tool_input,
session_id,
tool_call_id: Some(tool_call_id.to_string()),
tool_call_id: tool_call_id.to_string(),
},
self.execution_process_id,
);