codex approvals (#993)

* codex approvals

* send deny feedback

* Normalize user-feedback

* use tool call id to match normalized_entry

* store approvals in executor

* add noop approval for api consistency

---------

Co-authored-by: Gabriel Gordon-Hall <ggordonhall@gmail.com>
This commit is contained in:
Solomon
2025-10-20 18:02:58 +01:00
committed by GitHub
parent ee68b2fc43
commit 62834ea581
21 changed files with 942 additions and 139 deletions

View File

@@ -1,3 +1,5 @@
pub mod executor_approvals;
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
use chrono::{DateTime, Utc};
@@ -7,10 +9,14 @@ use db::models::{
executor_session::ExecutorSession,
task::{Task, TaskStatus},
};
use executors::logs::{
NormalizedEntry, NormalizedEntryType, ToolStatus,
utils::patch::{ConversationPatch, extract_normalized_entry_from_patch},
use executors::{
approvals::ToolCallMetadata,
logs::{
NormalizedEntry, NormalizedEntryType, ToolStatus,
utils::patch::{ConversationPatch, extract_normalized_entry_from_patch},
},
};
use futures::future::{BoxFuture, FutureExt, Shared};
use sqlx::{Error as SqlxError, SqlitePool};
use thiserror::Error;
use tokio::sync::{RwLock, oneshot};
@@ -35,6 +41,8 @@ struct PendingApproval {
response_tx: oneshot::Sender<ApprovalStatus>,
}
type ApprovalWaiter = Shared<BoxFuture<'static, ApprovalStatus>>;
#[derive(Debug)]
pub struct ToolContext {
pub tool_name: String,
@@ -73,15 +81,25 @@ impl Approvals {
}
}
#[tracing::instrument(skip(self, request))]
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
async fn create_internal(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
let (tx, rx) = oneshot::channel();
let waiter: ApprovalWaiter = rx
.map(|result| result.unwrap_or(ApprovalStatus::TimedOut))
.boxed()
.shared();
let req_id = request.id.clone();
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
// Find the matching tool use entry by name and input
let matching_tool =
find_matching_tool_use(store.clone(), &request.tool_name, &request.tool_input);
let matching_tool = find_matching_tool_use(
store.clone(),
&request.tool_name,
&request.tool_input,
request.tool_call_id.as_deref(),
);
if let Some((idx, matching_tool)) = matching_tool {
let approval_entry = matching_tool
@@ -125,10 +143,23 @@ impl Approvals {
);
}
self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, rx);
self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, waiter.clone());
Ok((request, waiter))
}
/// Registers an approval request and returns it once it has been recorded.
///
/// Thin public wrapper over `create_internal` for callers that do not need
/// to await the outcome: the waiter produced alongside the request is
/// discarded here.
///
/// # Errors
///
/// Propagates any `ApprovalError` reported by `create_internal`.
#[tracing::instrument(skip(self, request))]
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
    self.create_internal(request)
        .await
        .map(|(registered, _waiter)| registered)
}
/// Registers an approval request like [`Self::create`], but also hands back
/// the `ApprovalWaiter` so the caller can await the resulting
/// `ApprovalStatus` itself.
///
/// # Errors
///
/// Propagates any `ApprovalError` reported by `create_internal`.
pub async fn create_with_waiter(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
// Pure delegation: the internal helper already returns the (request, waiter) pair.
self.create_internal(request).await
}
pub async fn create_from_session(
&self,
pool: &SqlitePool,
@@ -145,15 +176,7 @@ impl Approvals {
};
// Move the task to InReview if it's still InProgress
if let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await
&& ctx.task.status == TaskStatus::InProgress
&& let Err(e) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await
{
tracing::warn!(
"Failed to update task status to InReview for approval request: {}",
e
);
}
ensure_task_in_review(pool, execution_process_id).await;
let request = ApprovalRequest::from_create(payload, execution_process_id);
self.create(request).await
@@ -248,12 +271,12 @@ impl Approvals {
.collect()
}
#[tracing::instrument(skip(self, id, timeout_at, rx))]
#[tracing::instrument(skip(self, id, timeout_at, waiter))]
fn spawn_timeout_watcher(
&self,
id: String,
timeout_at: chrono::DateTime<chrono::Utc>,
mut rx: oneshot::Receiver<ApprovalStatus>,
waiter: ApprovalWaiter,
) {
let pending = self.pending.clone();
let completed = self.completed.clone();
@@ -269,19 +292,18 @@ impl Approvals {
let status = tokio::select! {
biased;
r = &mut rx => match r {
Ok(status) => status,
Err(_canceled) => ApprovalStatus::TimedOut,
},
resolved = waiter.clone() => resolved,
_ = tokio::time::sleep_until(deadline) => ApprovalStatus::TimedOut,
};
let is_timeout = matches!(&status, ApprovalStatus::TimedOut);
completed.insert(id.clone(), status.clone());
let removed = pending.remove(&id);
if is_timeout && let Some((_, pending_approval)) = pending.remove(&id) {
if pending_approval.response_tx.send(status.clone()).is_err() {
tracing::debug!("approval '{}' timeout notification receiver dropped", id);
}
if is_timeout && let Some((_, pending_approval)) = removed {
let store = {
let map = msg_stores.read().await;
map.get(&pending_approval.execution_process_id).cloned()
@@ -318,8 +340,22 @@ impl Approvals {
}
}
/// Best-effort transition of the task behind `execution_process_id` from
/// `InProgress` to `InReview`.
///
/// Nothing happens when the execution context cannot be loaded or when the
/// task is in any status other than `InProgress`. A failed status update is
/// logged as a warning and otherwise ignored; this function never returns an
/// error to the caller.
pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_id: Uuid) {
    // A context that fails to load is silently skipped, not reported.
    let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await else {
        return;
    };

    // Only tasks that are still in progress are moved.
    if ctx.task.status != TaskStatus::InProgress {
        return;
    }

    if let Err(err) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await {
        tracing::warn!(
            "Failed to update task status to InReview for approval request: {}",
            err
        );
    }
}
/// Comparison strategy for matching tool use entries
enum ToolComparisonStrategy {
/// Compare by tool_call_id
ToolCallId(String),
/// Compare deserialized ClaudeToolData structures (for known tools)
Deserialized(executors::executors::claude::ClaudeToolData),
/// Compare raw JSON input fields (for Unknown tools like MCP)
@@ -332,31 +368,37 @@ fn find_matching_tool_use(
store: Arc<MsgStore>,
tool_name: &str,
tool_input: &serde_json::Value,
tool_call_id: Option<&str>,
) -> Option<(usize, NormalizedEntry)> {
use executors::executors::claude::ClaudeToolData;
let history = store.get_history();
// Determine comparison strategy based on tool type
let strategy = match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
"name": tool_name,
"input": tool_input
})) {
Ok(ClaudeToolData::Unknown { .. }) => {
// For Unknown tools (MCP, future tools), use raw JSON comparison
ToolComparisonStrategy::RawJson
}
Ok(data) => {
// For known tools, use deserialized comparison with proper alias handling
ToolComparisonStrategy::Deserialized(data)
}
Err(e) => {
tracing::warn!(
"Failed to deserialize tool_input for tool '{}': {}",
tool_name,
e
);
return None;
let strategy = if let Some(call_id) = tool_call_id {
// If tool_call_id is provided, use it for matching
ToolComparisonStrategy::ToolCallId(call_id.to_string())
} else {
match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
"name": tool_name,
"input": tool_input
})) {
Ok(ClaudeToolData::Unknown { .. }) => {
// For Unknown tools (MCP, future tools), use raw JSON comparison
ToolComparisonStrategy::RawJson
}
Ok(data) => {
// For known tools, use deserialized comparison with proper alias handling
ToolComparisonStrategy::Deserialized(data)
}
Err(e) => {
tracing::warn!(
"Failed to deserialize tool_input for tool '{}': {}",
tool_name,
e
);
return None;
}
}
};
@@ -383,6 +425,18 @@ fn find_matching_tool_use(
// Apply comparison strategy
if let Some(metadata) = &entry.metadata {
let is_match = match &strategy {
ToolComparisonStrategy::ToolCallId(call_id) => {
// Match by tool_call_id in metadata
if let Ok(ToolCallMetadata {
tool_call_id: entry_call_id,
..
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
{
entry_call_id == *call_id
} else {
false
}
}
ToolComparisonStrategy::RawJson => {
// Compare raw JSON input for Unknown tools
if let Some(entry_input) = metadata.get("input") {
@@ -405,8 +459,13 @@ fn find_matching_tool_use(
if is_match {
let strategy_name = match strategy {
ToolComparisonStrategy::RawJson => "raw input comparison",
ToolComparisonStrategy::Deserialized(_) => "deserialized tool data",
ToolComparisonStrategy::ToolCallId(call_id) => {
format!("tool_call_id '{call_id}'")
}
ToolComparisonStrategy::RawJson => "raw input comparison".to_string(),
ToolComparisonStrategy::Deserialized(_) => {
"deserialized tool data".to_string()
}
};
tracing::debug!(
"Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}"
@@ -483,12 +542,12 @@ mod tests {
let bar_input = serde_json::json!({"file_path": "bar.rs"});
let baz_input = serde_json::json!({"file_path": "baz.rs"});
let (idx_foo, _) =
find_matching_tool_use(store.clone(), "Read", &foo_input).expect("Should match foo.rs");
let (idx_bar, _) =
find_matching_tool_use(store.clone(), "Read", &bar_input).expect("Should match bar.rs");
let (idx_baz, _) =
find_matching_tool_use(store.clone(), "Read", &baz_input).expect("Should match baz.rs");
let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None)
.expect("Should match foo.rs");
let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None)
.expect("Should match bar.rs");
let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None)
.expect("Should match baz.rs");
assert_eq!(idx_foo, 0, "foo.rs should match first entry");
assert_eq!(idx_bar, 1, "bar.rs should match second entry");
@@ -510,21 +569,21 @@ mod tests {
let pending_input = serde_json::json!({"file_path": "pending.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &pending_input).is_none(),
find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(),
"Should not match tools in PendingApproval state"
);
// Test 3: Wrong tool name returns None
let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"});
assert!(
find_matching_tool_use(store.clone(), "Write", &write_input).is_none(),
find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(),
"Should not match different tool names"
);
// Test 4: Wrong input parameters returns None
let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &wrong_input).is_none(),
find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(),
"Should not match with different input parameters"
);
}

View File

@@ -0,0 +1,81 @@
use std::sync::Arc;
use async_trait::async_trait;
use db::{self, DBService};
use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService};
use serde_json::Value;
use tokio::sync::RwLock;
use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest};
use uuid::Uuid;
use crate::services::approvals::Approvals;
/// Adapter that lets an executor request tool approvals through the
/// application's `Approvals` service for one specific execution process.
pub struct ExecutorApprovalBridge {
/// Approval service that requests are created on.
approvals: Approvals,
/// Database handle; its pool is used to move the task into review.
db: DBService,
/// Execution process every request from this bridge is attributed to.
execution_process_id: Uuid,
/// Executor session id; `None` until `register_session` has been called.
session_id: RwLock<Option<String>>,
}
impl ExecutorApprovalBridge {
    /// Builds a bridge bound to `execution_process_id` and returns it behind
    /// an `Arc`.
    ///
    /// The bridge starts without a session id; one must be supplied through
    /// `register_session` before approvals can be requested.
    pub fn new(approvals: Approvals, db: DBService, execution_process_id: Uuid) -> Arc<Self> {
        // No session is known at construction time.
        let session_id = RwLock::new(None);
        let bridge = Self {
            approvals,
            db,
            execution_process_id,
            session_id,
        };
        Arc::new(bridge)
    }
}
#[async_trait]
impl ExecutorApprovalService for ExecutorApprovalBridge {
    /// Stores `session_id` on the bridge, overwriting any previously
    /// registered id.
    ///
    /// Always returns `Ok(())`.
    async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> {
        // Plain assignment: the previous value is not needed, so there is no
        // reason to go through `Option::replace` and discard its result.
        *self.session_id.write().await = Some(session_id.to_string());
        Ok(())
    }

    /// Creates an approval request for a tool call and waits for its outcome.
    ///
    /// The task is first nudged into review via `super::ensure_task_in_review`,
    /// then a request carrying `tool_name`, `tool_input`, the registered
    /// session id and `tool_call_id` is created on the approvals service, and
    /// the waiter returned by `create_with_waiter` is awaited.
    ///
    /// # Errors
    ///
    /// * `ExecutorApprovalError::SessionNotRegistered` when no session id has
    ///   been registered yet.
    /// * `ExecutorApprovalError::request_failed(..)` when creating the
    ///   approval fails, or when the waiter resolves to
    ///   `ApprovalStatus::Pending`.
    async fn request_tool_approval(
        &self,
        tool_name: &str,
        tool_input: Value,
        tool_call_id: &str,
    ) -> Result<ApprovalStatus, ExecutorApprovalError> {
        // Copy the id out of the lock; the read guard is a temporary that is
        // released at the end of this statement, before any later await.
        let session_id = self
            .session_id
            .read()
            .await
            .clone()
            .ok_or(ExecutorApprovalError::SessionNotRegistered)?;

        super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await;

        let request = ApprovalRequest::from_create(
            CreateApprovalRequest {
                tool_name: tool_name.to_string(),
                tool_input,
                session_id,
                tool_call_id: Some(tool_call_id.to_string()),
            },
            self.execution_process_id,
        );

        let (_, waiter) = self
            .approvals
            .create_with_waiter(request)
            .await
            .map_err(ExecutorApprovalError::request_failed)?;

        // The waiter is owned here, so it is awaited directly instead of
        // cloning the shared future first.
        match waiter.await {
            // A resolved approval must never still be pending.
            ApprovalStatus::Pending => Err(ExecutorApprovalError::request_failed(
                "approval finished in pending state",
            )),
            status => Ok(status),
        }
    }
}