Claude approval refactor (#1080)
* WIP claude approvals * Use canusetool * Remove old exitplanmode approvals * WIP approvals * types * Remove bloat * Cleanup, exit on finish * Approval messages, cleanup * Cleanup * Fix msg types * Lint fmt * Cleanup * Send deny * add missing timeout to hooks * Fix timeout issue * Cleanup * Error handling, log writer bugs * Remove deprecated approval endpoints * Remove tool matching strategies in favour of only id based matching * remove register session, parse result at protocol level * Remove circular peer, remove unneeded trait * Types
This commit is contained in:
@@ -2,11 +2,9 @@ pub mod executor_approvals;
|
||||
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use db::models::{
|
||||
execution_process::ExecutionProcess,
|
||||
executor_session::ExecutorSession,
|
||||
task::{Task, TaskStatus},
|
||||
};
|
||||
use executors::{
|
||||
@@ -21,10 +19,7 @@ use sqlx::{Error as SqlxError, SqlitePool};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{RwLock, oneshot};
|
||||
use utils::{
|
||||
approvals::{
|
||||
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus,
|
||||
CreateApprovalRequest,
|
||||
},
|
||||
approvals::{ApprovalRequest, ApprovalResponse, ApprovalStatus},
|
||||
log_msg::LogMsg,
|
||||
msg_store::MsgStore,
|
||||
};
|
||||
@@ -36,8 +31,6 @@ struct PendingApproval {
|
||||
entry: NormalizedEntry,
|
||||
execution_process_id: Uuid,
|
||||
tool_name: String,
|
||||
requested_at: DateTime<Utc>,
|
||||
timeout_at: DateTime<Utc>,
|
||||
response_tx: oneshot::Sender<ApprovalStatus>,
|
||||
}
|
||||
|
||||
@@ -81,7 +74,7 @@ impl Approvals {
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_internal(
|
||||
pub async fn create_with_waiter(
|
||||
&self,
|
||||
request: ApprovalRequest,
|
||||
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
|
||||
@@ -94,12 +87,7 @@ impl Approvals {
|
||||
|
||||
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
|
||||
// Find the matching tool use entry by tool call id
|
||||
let matching_tool = find_matching_tool_use(
|
||||
store.clone(),
|
||||
&request.tool_name,
|
||||
&request.tool_input,
|
||||
request.tool_call_id.as_deref(),
|
||||
);
|
||||
let matching_tool = find_matching_tool_use(store.clone(), &request.tool_call_id);
|
||||
|
||||
if let Some((idx, matching_tool)) = matching_tool {
|
||||
let approval_entry = matching_tool
|
||||
@@ -118,8 +106,6 @@ impl Approvals {
|
||||
entry: matching_tool,
|
||||
execution_process_id: request.execution_process_id,
|
||||
tool_name: request.tool_name.clone(),
|
||||
requested_at: request.created_at,
|
||||
timeout_at: request.timeout_at,
|
||||
response_tx: tx,
|
||||
},
|
||||
);
|
||||
@@ -147,41 +133,6 @@ impl Approvals {
|
||||
Ok((request, waiter))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, request))]
|
||||
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
|
||||
let (request, _) = self.create_internal(request).await?;
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub async fn create_with_waiter(
|
||||
&self,
|
||||
request: ApprovalRequest,
|
||||
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
|
||||
self.create_internal(request).await
|
||||
}
|
||||
|
||||
pub async fn create_from_session(
|
||||
&self,
|
||||
pool: &SqlitePool,
|
||||
payload: CreateApprovalRequest,
|
||||
) -> Result<ApprovalRequest, ApprovalError> {
|
||||
let session_id = payload.session_id.clone();
|
||||
let execution_process_id =
|
||||
match ExecutorSession::find_by_session_id(pool, &session_id).await? {
|
||||
Some(session) => session.execution_process_id,
|
||||
None => {
|
||||
tracing::warn!("No executor session found for session_id: {}", session_id);
|
||||
return Err(ApprovalError::NoExecutorSession(session_id));
|
||||
}
|
||||
};
|
||||
|
||||
// Move the task to InReview if it's still InProgress
|
||||
ensure_task_in_review(pool, execution_process_id).await;
|
||||
|
||||
let request = ApprovalRequest::from_create(payload, execution_process_id);
|
||||
self.create(request).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, id, req))]
|
||||
pub async fn respond(
|
||||
&self,
|
||||
@@ -238,39 +189,6 @@ impl Approvals {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn status(&self, id: &str) -> Option<ApprovalStatus> {
|
||||
if let Some(f) = self.completed.get(id) {
|
||||
return Some(f.clone());
|
||||
}
|
||||
if let Some(p) = self.pending.get(id) {
|
||||
if chrono::Utc::now() >= p.timeout_at {
|
||||
return Some(ApprovalStatus::TimedOut);
|
||||
}
|
||||
return Some(ApprovalStatus::Pending);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub async fn pending(&self) -> Vec<ApprovalPendingInfo> {
|
||||
self.pending
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let (id, pending) = entry.pair();
|
||||
|
||||
match &pending.entry.entry_type {
|
||||
NormalizedEntryType::ToolUse { tool_name, .. } => Some(ApprovalPendingInfo {
|
||||
approval_id: id.clone(),
|
||||
execution_process_id: pending.execution_process_id,
|
||||
tool_name: tool_name.clone(),
|
||||
requested_at: pending.requested_at,
|
||||
timeout_at: pending.timeout_at,
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, id, timeout_at, waiter))]
|
||||
fn spawn_timeout_watcher(
|
||||
&self,
|
||||
@@ -352,126 +270,37 @@ pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_i
|
||||
}
|
||||
}
|
||||
|
||||
/// Comparison strategy for matching tool use entries
|
||||
enum ToolComparisonStrategy {
|
||||
/// Compare by tool_call_id
|
||||
ToolCallId(String),
|
||||
/// Compare deserialized ClaudeToolData structures (for known tools)
|
||||
Deserialized(executors::executors::claude::ClaudeToolData),
|
||||
/// Compare raw JSON input fields (for Unknown tools like MCP)
|
||||
RawJson,
|
||||
}
|
||||
|
||||
/// Find a matching tool use entry that hasn't been assigned to an approval yet
|
||||
/// Matches by tool name and tool input to support parallel tool calls
|
||||
/// Matches by tool call id from tool metadata
|
||||
fn find_matching_tool_use(
|
||||
store: Arc<MsgStore>,
|
||||
tool_name: &str,
|
||||
tool_input: &serde_json::Value,
|
||||
tool_call_id: Option<&str>,
|
||||
tool_call_id: &str,
|
||||
) -> Option<(usize, NormalizedEntry)> {
|
||||
use executors::executors::claude::ClaudeToolData;
|
||||
|
||||
let history = store.get_history();
|
||||
|
||||
// Determine comparison strategy based on tool type
|
||||
let strategy = if let Some(call_id) = tool_call_id {
|
||||
// If tool_call_id is provided, use it for matching
|
||||
ToolComparisonStrategy::ToolCallId(call_id.to_string())
|
||||
} else {
|
||||
match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
|
||||
"name": tool_name,
|
||||
"input": tool_input
|
||||
})) {
|
||||
Ok(ClaudeToolData::Unknown { .. }) => {
|
||||
// For Unknown tools (MCP, future tools), use raw JSON comparison
|
||||
ToolComparisonStrategy::RawJson
|
||||
}
|
||||
Ok(data) => {
|
||||
// For known tools, use deserialized comparison with proper alias handling
|
||||
ToolComparisonStrategy::Deserialized(data)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to deserialize tool_input for tool '{}': {}",
|
||||
tool_name,
|
||||
e
|
||||
);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Single loop through history with strategy-based comparison
|
||||
// Single loop through history
|
||||
for msg in history.iter().rev() {
|
||||
if let LogMsg::JsonPatch(patch) = msg
|
||||
&& let Some((idx, entry)) = extract_normalized_entry_from_patch(patch)
|
||||
&& let NormalizedEntryType::ToolUse {
|
||||
tool_name: entry_tool_name,
|
||||
status,
|
||||
..
|
||||
} = &entry.entry_type
|
||||
&& let NormalizedEntryType::ToolUse { status, .. } = &entry.entry_type
|
||||
{
|
||||
// Only match tools that are in Created state
|
||||
if !matches!(status, ToolStatus::Created) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Tool name must match
|
||||
if entry_tool_name != tool_name {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply comparison strategy
|
||||
if let Some(metadata) = &entry.metadata {
|
||||
let is_match = match &strategy {
|
||||
ToolComparisonStrategy::ToolCallId(call_id) => {
|
||||
// Match by tool_call_id in metadata
|
||||
if let Ok(ToolCallMetadata {
|
||||
tool_call_id: entry_call_id,
|
||||
..
|
||||
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
|
||||
{
|
||||
entry_call_id == *call_id
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ToolComparisonStrategy::RawJson => {
|
||||
// Compare raw JSON input for Unknown tools
|
||||
if let Some(entry_input) = metadata.get("input") {
|
||||
entry_input == tool_input
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ToolComparisonStrategy::Deserialized(approval_data) => {
|
||||
// Compare deserialized structures for known tools
|
||||
if let Ok(entry_tool_data) =
|
||||
serde_json::from_value::<ClaudeToolData>(metadata.clone())
|
||||
{
|
||||
entry_tool_data == *approval_data
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if is_match {
|
||||
let strategy_name = match strategy {
|
||||
ToolComparisonStrategy::ToolCallId(call_id) => {
|
||||
format!("tool_call_id '{call_id}'")
|
||||
}
|
||||
ToolComparisonStrategy::RawJson => "raw input comparison".to_string(),
|
||||
ToolComparisonStrategy::Deserialized(_) => {
|
||||
"deserialized tool data".to_string()
|
||||
}
|
||||
};
|
||||
tracing::debug!(
|
||||
"Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}"
|
||||
);
|
||||
return Some((idx, entry));
|
||||
}
|
||||
// Match by tool call id from metadata
|
||||
if let Some(metadata) = &entry.metadata
|
||||
&& let Ok(ToolCallMetadata {
|
||||
tool_call_id: entry_call_id,
|
||||
..
|
||||
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
|
||||
&& entry_call_id == tool_call_id
|
||||
{
|
||||
tracing::debug!(
|
||||
"Matched tool use entry at index {idx} for tool call id '{tool_call_id}'"
|
||||
);
|
||||
return Some((idx, entry));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -491,19 +320,9 @@ mod tests {
|
||||
fn create_tool_use_entry(
|
||||
tool_name: &str,
|
||||
file_path: &str,
|
||||
id: &str,
|
||||
status: ToolStatus,
|
||||
) -> NormalizedEntry {
|
||||
// Create metadata that mimics the actual structure from Claude Code
|
||||
// which has an "input" field containing the original tool parameters
|
||||
let metadata = serde_json::json!({
|
||||
"type": "tool_use",
|
||||
"id": format!("test-{}", file_path),
|
||||
"name": tool_name,
|
||||
"input": {
|
||||
"file_path": file_path
|
||||
}
|
||||
});
|
||||
|
||||
NormalizedEntry {
|
||||
timestamp: None,
|
||||
entry_type: NormalizedEntryType::ToolUse {
|
||||
@@ -514,7 +333,12 @@ mod tests {
|
||||
status,
|
||||
},
|
||||
content: format!("Reading {file_path}"),
|
||||
metadata: Some(metadata),
|
||||
metadata: Some(
|
||||
serde_json::to_value(ToolCallMetadata {
|
||||
tool_call_id: id.to_string(),
|
||||
})
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,9 +347,9 @@ mod tests {
|
||||
let store = Arc::new(MsgStore::new());
|
||||
|
||||
// Setup: Simulate 3 parallel Read tool calls with different files
|
||||
let read_foo = create_tool_use_entry("Read", "foo.rs", ToolStatus::Created);
|
||||
let read_bar = create_tool_use_entry("Read", "bar.rs", ToolStatus::Created);
|
||||
let read_baz = create_tool_use_entry("Read", "baz.rs", ToolStatus::Created);
|
||||
let read_foo = create_tool_use_entry("Read", "foo.rs", "foo-id", ToolStatus::Created);
|
||||
let read_bar = create_tool_use_entry("Read", "bar.rs", "bar-id", ToolStatus::Created);
|
||||
let read_baz = create_tool_use_entry("Read", "baz.rs", "baz-id", ToolStatus::Created);
|
||||
|
||||
store.push_patch(
|
||||
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(0, read_foo),
|
||||
@@ -537,17 +361,12 @@ mod tests {
|
||||
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(2, read_baz),
|
||||
);
|
||||
|
||||
// Test 1: Each approval request matches its specific tool by input
|
||||
let foo_input = serde_json::json!({"file_path": "foo.rs"});
|
||||
let bar_input = serde_json::json!({"file_path": "bar.rs"});
|
||||
let baz_input = serde_json::json!({"file_path": "baz.rs"});
|
||||
|
||||
let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None)
|
||||
.expect("Should match foo.rs");
|
||||
let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None)
|
||||
.expect("Should match bar.rs");
|
||||
let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None)
|
||||
.expect("Should match baz.rs");
|
||||
let (idx_foo, _) =
|
||||
find_matching_tool_use(store.clone(), "foo-id").expect("Should match foo.rs");
|
||||
let (idx_bar, _) =
|
||||
find_matching_tool_use(store.clone(), "bar-id").expect("Should match bar.rs");
|
||||
let (idx_baz, _) =
|
||||
find_matching_tool_use(store.clone(), "baz-id").expect("Should match baz.rs");
|
||||
|
||||
assert_eq!(idx_foo, 0, "foo.rs should match first entry");
|
||||
assert_eq!(idx_bar, 1, "bar.rs should match second entry");
|
||||
@@ -557,6 +376,7 @@ mod tests {
|
||||
let read_pending = create_tool_use_entry(
|
||||
"Read",
|
||||
"pending.rs",
|
||||
"pending-id",
|
||||
ToolStatus::PendingApproval {
|
||||
approval_id: "test-id".to_string(),
|
||||
requested_at: chrono::Utc::now(),
|
||||
@@ -567,24 +387,15 @@ mod tests {
|
||||
executors::logs::utils::patch::ConversationPatch::add_normalized_entry(3, read_pending),
|
||||
);
|
||||
|
||||
let pending_input = serde_json::json!({"file_path": "pending.rs"});
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(),
|
||||
find_matching_tool_use(store.clone(), "pending-id").is_none(),
|
||||
"Should not match tools in PendingApproval state"
|
||||
);
|
||||
|
||||
// Test 3: Wrong tool name returns None
|
||||
let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"});
|
||||
// Test 3: Wrong tool id returns None
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(),
|
||||
"Should not match different tool names"
|
||||
);
|
||||
|
||||
// Test 4: Wrong input parameters returns None
|
||||
let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"});
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(),
|
||||
"Should not match with different input parameters"
|
||||
find_matching_tool_use(store.clone(), "wrong-id").is_none(),
|
||||
"Should not match different tool ids"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use async_trait::async_trait;
|
||||
use db::{self, DBService};
|
||||
use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::RwLock;
|
||||
use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -14,7 +13,6 @@ pub struct ExecutorApprovalBridge {
|
||||
approvals: Approvals,
|
||||
db: DBService,
|
||||
execution_process_id: Uuid,
|
||||
session_id: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl ExecutorApprovalBridge {
|
||||
@@ -23,41 +21,25 @@ impl ExecutorApprovalBridge {
|
||||
approvals,
|
||||
db,
|
||||
execution_process_id,
|
||||
session_id: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ExecutorApprovalService for ExecutorApprovalBridge {
|
||||
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> {
|
||||
let mut guard = self.session_id.write().await;
|
||||
guard.replace(session_id.to_string());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn request_tool_approval(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
tool_input: Value,
|
||||
tool_call_id: &str,
|
||||
) -> Result<ApprovalStatus, ExecutorApprovalError> {
|
||||
let session_id = {
|
||||
let guard = self.session_id.read().await;
|
||||
guard
|
||||
.clone()
|
||||
.ok_or(ExecutorApprovalError::SessionNotRegistered)?
|
||||
};
|
||||
|
||||
super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await;
|
||||
|
||||
let request = ApprovalRequest::from_create(
|
||||
CreateApprovalRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
tool_input,
|
||||
session_id,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
tool_call_id: tool_call_id.to_string(),
|
||||
},
|
||||
self.execution_process_id,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user