codex approvals (#993)
* codex approvals * send deny feedback * Normalize user-feedback * use tool call id to match normalized_entry * store approvals in executor * add noop approval for api consistency --------- Co-authored-by: Gabriel Gordon-Hall <ggordonhall@gmail.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
pub mod executor_approvals;
|
||||
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -7,10 +9,14 @@ use db::models::{
|
||||
executor_session::ExecutorSession,
|
||||
task::{Task, TaskStatus},
|
||||
};
|
||||
use executors::logs::{
|
||||
NormalizedEntry, NormalizedEntryType, ToolStatus,
|
||||
utils::patch::{ConversationPatch, extract_normalized_entry_from_patch},
|
||||
use executors::{
|
||||
approvals::ToolCallMetadata,
|
||||
logs::{
|
||||
NormalizedEntry, NormalizedEntryType, ToolStatus,
|
||||
utils::patch::{ConversationPatch, extract_normalized_entry_from_patch},
|
||||
},
|
||||
};
|
||||
use futures::future::{BoxFuture, FutureExt, Shared};
|
||||
use sqlx::{Error as SqlxError, SqlitePool};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{RwLock, oneshot};
|
||||
@@ -35,6 +41,8 @@ struct PendingApproval {
|
||||
response_tx: oneshot::Sender<ApprovalStatus>,
|
||||
}
|
||||
|
||||
type ApprovalWaiter = Shared<BoxFuture<'static, ApprovalStatus>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolContext {
|
||||
pub tool_name: String,
|
||||
@@ -73,15 +81,25 @@ impl Approvals {
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, request))]
|
||||
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
|
||||
async fn create_internal(
|
||||
&self,
|
||||
request: ApprovalRequest,
|
||||
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let waiter: ApprovalWaiter = rx
|
||||
.map(|result| result.unwrap_or(ApprovalStatus::TimedOut))
|
||||
.boxed()
|
||||
.shared();
|
||||
let req_id = request.id.clone();
|
||||
|
||||
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
|
||||
// Find the matching tool use entry by name and input
|
||||
let matching_tool =
|
||||
find_matching_tool_use(store.clone(), &request.tool_name, &request.tool_input);
|
||||
let matching_tool = find_matching_tool_use(
|
||||
store.clone(),
|
||||
&request.tool_name,
|
||||
&request.tool_input,
|
||||
request.tool_call_id.as_deref(),
|
||||
);
|
||||
|
||||
if let Some((idx, matching_tool)) = matching_tool {
|
||||
let approval_entry = matching_tool
|
||||
@@ -125,10 +143,23 @@ impl Approvals {
|
||||
);
|
||||
}
|
||||
|
||||
self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, rx);
|
||||
self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, waiter.clone());
|
||||
Ok((request, waiter))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, request))]
|
||||
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
|
||||
let (request, _) = self.create_internal(request).await?;
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub async fn create_with_waiter(
|
||||
&self,
|
||||
request: ApprovalRequest,
|
||||
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
|
||||
self.create_internal(request).await
|
||||
}
|
||||
|
||||
pub async fn create_from_session(
|
||||
&self,
|
||||
pool: &SqlitePool,
|
||||
@@ -145,15 +176,7 @@ impl Approvals {
|
||||
};
|
||||
|
||||
// Move the task to InReview if it's still InProgress
|
||||
if let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await
|
||||
&& ctx.task.status == TaskStatus::InProgress
|
||||
&& let Err(e) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await
|
||||
{
|
||||
tracing::warn!(
|
||||
"Failed to update task status to InReview for approval request: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
ensure_task_in_review(pool, execution_process_id).await;
|
||||
|
||||
let request = ApprovalRequest::from_create(payload, execution_process_id);
|
||||
self.create(request).await
|
||||
@@ -248,12 +271,12 @@ impl Approvals {
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, id, timeout_at, rx))]
|
||||
#[tracing::instrument(skip(self, id, timeout_at, waiter))]
|
||||
fn spawn_timeout_watcher(
|
||||
&self,
|
||||
id: String,
|
||||
timeout_at: chrono::DateTime<chrono::Utc>,
|
||||
mut rx: oneshot::Receiver<ApprovalStatus>,
|
||||
waiter: ApprovalWaiter,
|
||||
) {
|
||||
let pending = self.pending.clone();
|
||||
let completed = self.completed.clone();
|
||||
@@ -269,19 +292,18 @@ impl Approvals {
|
||||
let status = tokio::select! {
|
||||
biased;
|
||||
|
||||
r = &mut rx => match r {
|
||||
Ok(status) => status,
|
||||
Err(_canceled) => ApprovalStatus::TimedOut,
|
||||
},
|
||||
resolved = waiter.clone() => resolved,
|
||||
_ = tokio::time::sleep_until(deadline) => ApprovalStatus::TimedOut,
|
||||
};
|
||||
|
||||
let is_timeout = matches!(&status, ApprovalStatus::TimedOut);
|
||||
completed.insert(id.clone(), status.clone());
|
||||
|
||||
let removed = pending.remove(&id);
|
||||
if is_timeout && let Some((_, pending_approval)) = pending.remove(&id) {
|
||||
if pending_approval.response_tx.send(status.clone()).is_err() {
|
||||
tracing::debug!("approval '{}' timeout notification receiver dropped", id);
|
||||
}
|
||||
|
||||
if is_timeout && let Some((_, pending_approval)) = removed {
|
||||
let store = {
|
||||
let map = msg_stores.read().await;
|
||||
map.get(&pending_approval.execution_process_id).cloned()
|
||||
@@ -318,8 +340,22 @@ impl Approvals {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_id: Uuid) {
|
||||
if let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await
|
||||
&& ctx.task.status == TaskStatus::InProgress
|
||||
&& let Err(e) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await
|
||||
{
|
||||
tracing::warn!(
|
||||
"Failed to update task status to InReview for approval request: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Comparison strategy for matching tool use entries
|
||||
enum ToolComparisonStrategy {
|
||||
/// Compare by tool_call_id
|
||||
ToolCallId(String),
|
||||
/// Compare deserialized ClaudeToolData structures (for known tools)
|
||||
Deserialized(executors::executors::claude::ClaudeToolData),
|
||||
/// Compare raw JSON input fields (for Unknown tools like MCP)
|
||||
@@ -332,31 +368,37 @@ fn find_matching_tool_use(
|
||||
store: Arc<MsgStore>,
|
||||
tool_name: &str,
|
||||
tool_input: &serde_json::Value,
|
||||
tool_call_id: Option<&str>,
|
||||
) -> Option<(usize, NormalizedEntry)> {
|
||||
use executors::executors::claude::ClaudeToolData;
|
||||
|
||||
let history = store.get_history();
|
||||
|
||||
// Determine comparison strategy based on tool type
|
||||
let strategy = match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
|
||||
"name": tool_name,
|
||||
"input": tool_input
|
||||
})) {
|
||||
Ok(ClaudeToolData::Unknown { .. }) => {
|
||||
// For Unknown tools (MCP, future tools), use raw JSON comparison
|
||||
ToolComparisonStrategy::RawJson
|
||||
}
|
||||
Ok(data) => {
|
||||
// For known tools, use deserialized comparison with proper alias handling
|
||||
ToolComparisonStrategy::Deserialized(data)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to deserialize tool_input for tool '{}': {}",
|
||||
tool_name,
|
||||
e
|
||||
);
|
||||
return None;
|
||||
let strategy = if let Some(call_id) = tool_call_id {
|
||||
// If tool_call_id is provided, use it for matching
|
||||
ToolComparisonStrategy::ToolCallId(call_id.to_string())
|
||||
} else {
|
||||
match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
|
||||
"name": tool_name,
|
||||
"input": tool_input
|
||||
})) {
|
||||
Ok(ClaudeToolData::Unknown { .. }) => {
|
||||
// For Unknown tools (MCP, future tools), use raw JSON comparison
|
||||
ToolComparisonStrategy::RawJson
|
||||
}
|
||||
Ok(data) => {
|
||||
// For known tools, use deserialized comparison with proper alias handling
|
||||
ToolComparisonStrategy::Deserialized(data)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to deserialize tool_input for tool '{}': {}",
|
||||
tool_name,
|
||||
e
|
||||
);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -383,6 +425,18 @@ fn find_matching_tool_use(
|
||||
// Apply comparison strategy
|
||||
if let Some(metadata) = &entry.metadata {
|
||||
let is_match = match &strategy {
|
||||
ToolComparisonStrategy::ToolCallId(call_id) => {
|
||||
// Match by tool_call_id in metadata
|
||||
if let Ok(ToolCallMetadata {
|
||||
tool_call_id: entry_call_id,
|
||||
..
|
||||
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
|
||||
{
|
||||
entry_call_id == *call_id
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ToolComparisonStrategy::RawJson => {
|
||||
// Compare raw JSON input for Unknown tools
|
||||
if let Some(entry_input) = metadata.get("input") {
|
||||
@@ -405,8 +459,13 @@ fn find_matching_tool_use(
|
||||
|
||||
if is_match {
|
||||
let strategy_name = match strategy {
|
||||
ToolComparisonStrategy::RawJson => "raw input comparison",
|
||||
ToolComparisonStrategy::Deserialized(_) => "deserialized tool data",
|
||||
ToolComparisonStrategy::ToolCallId(call_id) => {
|
||||
format!("tool_call_id '{call_id}'")
|
||||
}
|
||||
ToolComparisonStrategy::RawJson => "raw input comparison".to_string(),
|
||||
ToolComparisonStrategy::Deserialized(_) => {
|
||||
"deserialized tool data".to_string()
|
||||
}
|
||||
};
|
||||
tracing::debug!(
|
||||
"Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}"
|
||||
@@ -483,12 +542,12 @@ mod tests {
|
||||
let bar_input = serde_json::json!({"file_path": "bar.rs"});
|
||||
let baz_input = serde_json::json!({"file_path": "baz.rs"});
|
||||
|
||||
let (idx_foo, _) =
|
||||
find_matching_tool_use(store.clone(), "Read", &foo_input).expect("Should match foo.rs");
|
||||
let (idx_bar, _) =
|
||||
find_matching_tool_use(store.clone(), "Read", &bar_input).expect("Should match bar.rs");
|
||||
let (idx_baz, _) =
|
||||
find_matching_tool_use(store.clone(), "Read", &baz_input).expect("Should match baz.rs");
|
||||
let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None)
|
||||
.expect("Should match foo.rs");
|
||||
let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None)
|
||||
.expect("Should match bar.rs");
|
||||
let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None)
|
||||
.expect("Should match baz.rs");
|
||||
|
||||
assert_eq!(idx_foo, 0, "foo.rs should match first entry");
|
||||
assert_eq!(idx_bar, 1, "bar.rs should match second entry");
|
||||
@@ -510,21 +569,21 @@ mod tests {
|
||||
|
||||
let pending_input = serde_json::json!({"file_path": "pending.rs"});
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Read", &pending_input).is_none(),
|
||||
find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(),
|
||||
"Should not match tools in PendingApproval state"
|
||||
);
|
||||
|
||||
// Test 3: Wrong tool name returns None
|
||||
let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"});
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Write", &write_input).is_none(),
|
||||
find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(),
|
||||
"Should not match different tool names"
|
||||
);
|
||||
|
||||
// Test 4: Wrong input parameters returns None
|
||||
let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"});
|
||||
assert!(
|
||||
find_matching_tool_use(store.clone(), "Read", &wrong_input).is_none(),
|
||||
find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(),
|
||||
"Should not match with different input parameters"
|
||||
);
|
||||
}
|
||||
|
||||
81
crates/services/src/services/approvals/executor_approvals.rs
Normal file
81
crates/services/src/services/approvals/executor_approvals.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use db::{self, DBService};
|
||||
use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::RwLock;
|
||||
use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::services::approvals::Approvals;
|
||||
|
||||
pub struct ExecutorApprovalBridge {
|
||||
approvals: Approvals,
|
||||
db: DBService,
|
||||
execution_process_id: Uuid,
|
||||
session_id: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl ExecutorApprovalBridge {
|
||||
pub fn new(approvals: Approvals, db: DBService, execution_process_id: Uuid) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
approvals,
|
||||
db,
|
||||
execution_process_id,
|
||||
session_id: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ExecutorApprovalService for ExecutorApprovalBridge {
|
||||
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> {
|
||||
let mut guard = self.session_id.write().await;
|
||||
guard.replace(session_id.to_string());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn request_tool_approval(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
tool_input: Value,
|
||||
tool_call_id: &str,
|
||||
) -> Result<ApprovalStatus, ExecutorApprovalError> {
|
||||
let session_id = {
|
||||
let guard = self.session_id.read().await;
|
||||
guard
|
||||
.clone()
|
||||
.ok_or(ExecutorApprovalError::SessionNotRegistered)?
|
||||
};
|
||||
|
||||
super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await;
|
||||
|
||||
let request = ApprovalRequest::from_create(
|
||||
CreateApprovalRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
tool_input,
|
||||
session_id,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
},
|
||||
self.execution_process_id,
|
||||
);
|
||||
|
||||
let (_, waiter) = self
|
||||
.approvals
|
||||
.create_with_waiter(request)
|
||||
.await
|
||||
.map_err(ExecutorApprovalError::request_failed)?;
|
||||
|
||||
let status = waiter.clone().await;
|
||||
|
||||
if matches!(status, ApprovalStatus::Pending) {
|
||||
return Err(ExecutorApprovalError::request_failed(
|
||||
"approval finished in pending state",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user