diff --git a/crates/executors/Cargo.toml b/crates/executors/Cargo.toml index 94e06be9..52059169 100644 --- a/crates/executors/Cargo.toml +++ b/crates/executors/Cargo.toml @@ -47,3 +47,4 @@ codex-protocol = { git = "https://github.com/openai/codex.git", package = "codex codex-app-server-protocol = { git = "https://github.com/openai/codex.git", package = "codex-app-server-protocol", rev = "488ec061bf4d36916b8f477c700ea4fde4162a7a" } codex-mcp-types = { git = "https://github.com/openai/codex.git", package = "mcp-types", rev = "488ec061bf4d36916b8f477c700ea4fde4162a7a" } sha2 = "0.10" +derivative = "2.2.0" diff --git a/crates/executors/default_profiles.json b/crates/executors/default_profiles.json index 71c7db14..da350164 100644 --- a/crates/executors/default_profiles.json +++ b/crates/executors/default_profiles.json @@ -49,6 +49,12 @@ "sandbox": "danger-full-access", "model_reasoning_effort": "high" } + }, + "APPROVALS": { + "CODEX": { + "sandbox": "workspace-write", + "ask_for_approval": "unless-trusted" + } } }, "OPENCODE": { diff --git a/crates/executors/src/actions/coding_agent_follow_up.rs b/crates/executors/src/actions/coding_agent_follow_up.rs index 2cba1053..c76c21d1 100644 --- a/crates/executors/src/actions/coding_agent_follow_up.rs +++ b/crates/executors/src/actions/coding_agent_follow_up.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::{path::Path, sync::Arc}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -6,7 +6,8 @@ use ts_rs::TS; use crate::{ actions::Executable, - executors::{ExecutorError, SpawnedChild, StandardCodingAgentExecutor}, + approvals::ExecutorApprovalService, + executors::{BaseCodingAgent, ExecutorError, SpawnedChild, StandardCodingAgentExecutor}, profile::{ExecutorConfigs, ExecutorProfileId}, }; @@ -25,18 +26,28 @@ impl CodingAgentFollowUpRequest { pub fn get_executor_profile_id(&self) -> ExecutorProfileId { self.executor_profile_id.clone() } + + pub fn base_executor(&self) -> BaseCodingAgent { + self.executor_profile_id.executor + } } #[async_trait] impl Executable for CodingAgentFollowUpRequest { - async fn spawn(&self, current_dir: &Path) -> Result { + async fn spawn( + &self, + current_dir: &Path, + approvals: Arc, + ) -> Result { let executor_profile_id = self.get_executor_profile_id(); - let agent = ExecutorConfigs::get_cached() + let mut agent = ExecutorConfigs::get_cached() .get_coding_agent(&executor_profile_id) .ok_or(ExecutorError::UnknownExecutorType( executor_profile_id.to_string(), ))?; + agent.use_approvals(approvals.clone()); + agent .spawn_follow_up(current_dir, &self.prompt, &self.session_id) .await diff --git a/crates/executors/src/actions/coding_agent_initial.rs b/crates/executors/src/actions/coding_agent_initial.rs index 48dde87a..c313171d 100644 --- a/crates/executors/src/actions/coding_agent_initial.rs +++ b/crates/executors/src/actions/coding_agent_initial.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::{path::Path, sync::Arc}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -6,7 +6,8 @@ use ts_rs::TS; use crate::{ actions::Executable, - executors::{ExecutorError, SpawnedChild, StandardCodingAgentExecutor}, + approvals::ExecutorApprovalService, + executors::{BaseCodingAgent, ExecutorError, SpawnedChild, StandardCodingAgentExecutor}, profile::{ExecutorConfigs, ExecutorProfileId}, }; @@ -19,16 +20,28 @@ pub struct CodingAgentInitialRequest { pub executor_profile_id: ExecutorProfileId, } +impl CodingAgentInitialRequest { + pub fn base_executor(&self) -> BaseCodingAgent { + self.executor_profile_id.executor + } +} + #[async_trait] impl Executable for CodingAgentInitialRequest { - async fn spawn(&self, current_dir: &Path) -> Result { + async fn spawn( + &self, + current_dir: &Path, + approvals: Arc, + ) -> Result { let executor_profile_id = self.executor_profile_id.clone(); - let agent = ExecutorConfigs::get_cached() + let mut agent = ExecutorConfigs::get_cached() .get_coding_agent(&executor_profile_id) .ok_or(ExecutorError::UnknownExecutorType( executor_profile_id.to_string(), ))?; + agent.use_approvals(approvals.clone()); + agent.spawn(current_dir, &self.prompt).await } } diff --git a/crates/executors/src/actions/mod.rs b/crates/executors/src/actions/mod.rs index 9a5a23aa..f2d8a1b3 100644 --- a/crates/executors/src/actions/mod.rs +++ b/crates/executors/src/actions/mod.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::{path::Path, sync::Arc}; use async_trait::async_trait; use enum_dispatch::enum_dispatch; @@ -10,7 +10,8 @@ use crate::{ coding_agent_follow_up::CodingAgentFollowUpRequest, coding_agent_initial::CodingAgentInitialRequest, script::ScriptRequest, }, - executors::{ExecutorError, SpawnedChild}, + approvals::ExecutorApprovalService, + executors::{BaseCodingAgent, ExecutorError, SpawnedChild}, }; pub mod coding_agent_follow_up; pub mod coding_agent_initial; @@ -43,17 +44,35 @@ impl ExecutorAction { pub fn next_action(&self) -> Option<&ExecutorAction> { self.next_action.as_deref() } + + pub fn base_executor(&self) -> Option { + match self.typ() { + ExecutorActionType::CodingAgentInitialRequest(request) => Some(request.base_executor()), + ExecutorActionType::CodingAgentFollowUpRequest(request) => { + Some(request.base_executor()) + } + ExecutorActionType::ScriptRequest(_) => None, + } + } } #[async_trait] #[enum_dispatch(ExecutorActionType)] pub trait Executable { - async fn spawn(&self, current_dir: &Path) -> Result; + async fn spawn( + &self, + current_dir: &Path, + approvals: Arc, + ) -> Result; } #[async_trait] impl Executable for ExecutorAction { - async fn spawn(&self, current_dir: &Path) -> Result { - self.typ.spawn(current_dir).await + async fn spawn( + &self, + current_dir: &Path, + approvals: Arc, + ) -> Result { + self.typ.spawn(current_dir, approvals).await } } diff --git a/crates/executors/src/actions/script.rs b/crates/executors/src/actions/script.rs index fddc71da..fc4632ee 100644 --- a/crates/executors/src/actions/script.rs +++ b/crates/executors/src/actions/script.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::{path::Path, sync::Arc}; use async_trait::async_trait; use command_group::AsyncCommandGroup; @@ -9,6 +9,7 @@ use workspace_utils::shell::get_shell_command; use crate::{ actions::Executable, + approvals::ExecutorApprovalService, executors::{ExecutorError, SpawnedChild}, }; @@ -33,7 +34,11 @@ pub struct ScriptRequest { #[async_trait] impl Executable for ScriptRequest { - async fn spawn(&self, current_dir: &Path) -> Result { + async fn spawn( + &self, + current_dir: &Path, + _approvals: Arc, + ) -> Result { let (shell_cmd, shell_arg) = get_shell_command(); let mut command = Command::new(shell_cmd); command diff --git a/crates/executors/src/approvals.rs b/crates/executors/src/approvals.rs new file mode 100644 index 00000000..5f12fa64 --- /dev/null +++ b/crates/executors/src/approvals.rs @@ -0,0 +1,63 @@ +use std::fmt; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; +use workspace_utils::approvals::ApprovalStatus; + +/// Errors emitted by executor approval services. +#[derive(Debug, Error)] +pub enum ExecutorApprovalError { + #[error("executor approval session not registered")] + SessionNotRegistered, + #[error("executor approval request failed: {0}")] + RequestFailed(String), + #[error("executor approval service unavailable")] + ServiceUnavailable, +} + +impl ExecutorApprovalError { + pub fn request_failed(err: E) -> Self { + Self::RequestFailed(err.to_string()) + } +} + +/// Abstraction for executor approval backends. +#[async_trait] +pub trait ExecutorApprovalService: Send + Sync { + /// Registers the session identifier associated with subsequent approval requests. + async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError>; + + /// Requests approval for a tool invocation and waits for the final decision. + async fn request_tool_approval( + &self, + tool_name: &str, + tool_input: Value, + tool_call_id: &str, + ) -> Result; +} + +#[derive(Debug, Default)] +pub struct NoopExecutorApprovalService; + +#[async_trait] +impl ExecutorApprovalService for NoopExecutorApprovalService { + async fn register_session(&self, _session_id: &str) -> Result<(), ExecutorApprovalError> { + Ok(()) + } + + async fn request_tool_approval( + &self, + _tool_name: &str, + _tool_input: Value, + _tool_call_id: &str, + ) -> Result { + Ok(ApprovalStatus::Approved) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ToolCallMetadata { + pub tool_call_id: String, +} diff --git a/crates/executors/src/executors/codex.rs b/crates/executors/src/executors/codex.rs index 8eb6e3e3..d1f9f075 100644 --- a/crates/executors/src/executors/codex.rs +++ b/crates/executors/src/executors/codex.rs @@ -2,7 +2,6 @@ pub mod client; pub mod jsonrpc; pub mod normalize_logs; pub mod session; - use std::{ collections::HashMap, path::{Path, PathBuf}, @@ -11,8 +10,11 @@ use std::{ use async_trait::async_trait; use codex_app_server_protocol::NewConversationParams; -use codex_protocol::config_types::SandboxMode as CodexSandboxMode; +use codex_protocol::{ + config_types::SandboxMode as CodexSandboxMode, protocol::AskForApproval as CodexAskForApproval, +}; use command_group::AsyncCommandGroup; +use derivative::Derivative; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -28,6 +30,7 @@ use self::{ session::SessionHandler, }; use crate::{ + approvals::ExecutorApprovalService, command::{CmdOverrides, CommandBuilder, apply_overrides}, executors::{ AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor, @@ -47,6 +50,25 @@ pub enum SandboxMode { DangerFullAccess, } +/// Determines when the user is consulted to approve Codex actions. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema, AsRefStr)] +#[serde(rename_all = "kebab-case")] +#[strum(serialize_all = "kebab-case")] +pub enum AskForApproval { + /// Read-only commands are auto-approved. Everything else will ask the user to approve. + UnlessTrusted, + + /// All commands run in a restricted sandbox initially. + /// If the command fails, the user is asked to approve execution without the sandbox. + OnFailure, + + /// The model decides when to ask the user for approval. + OnRequest, + + /// Never ask the user to approve commands. Commands that fail in the restricted sandbox will not be retried. + Never, +} + /// Reasoning effort for the underlying model #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema, AsRefStr)] #[serde(rename_all = "kebab-case")] @@ -77,13 +99,16 @@ pub enum ReasoningSummaryFormat { Experimental, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema)] +#[derive(Derivative, Clone, Serialize, Deserialize, TS, JsonSchema)] +#[derivative(Debug, PartialEq)] pub struct Codex { #[serde(default)] pub append_prompt: AppendPrompt, #[serde(default, skip_serializing_if = "Option::is_none")] pub sandbox: Option, #[serde(default, skip_serializing_if = "Option::is_none")] + pub ask_for_approval: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub oss: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub model: Option, @@ -103,10 +128,19 @@ pub struct Codex { pub include_apply_patch_tool: Option, #[serde(flatten)] pub cmd: CmdOverrides, + + #[serde(skip)] + #[ts(skip)] + #[derivative(Debug = "ignore", PartialEq = "ignore")] + approvals: Option>, } #[async_trait] impl StandardCodingAgentExecutor for Codex { + fn use_approvals(&mut self, approvals: Arc) { + self.approvals = Some(approvals); + } + async fn spawn(&self, current_dir: &Path, prompt: &str) -> Result { let command = self.build_command_builder().build_initial(); self.spawn(current_dir, prompt, command, None).await @@ -151,11 +185,18 @@ impl Codex { Some(SandboxMode::DangerFullAccess) => Some(CodexSandboxMode::DangerFullAccess), }; + let approval_policy = self.ask_for_approval.as_ref().map(|policy| match policy { + AskForApproval::UnlessTrusted => CodexAskForApproval::UnlessTrusted, + AskForApproval::OnFailure => CodexAskForApproval::OnFailure, + AskForApproval::OnRequest => CodexAskForApproval::OnRequest, + AskForApproval::Never => CodexAskForApproval::Never, + }); + NewConversationParams { model: self.model.clone(), profile: self.profile.clone(), cwd: Some(cwd.to_string_lossy().to_string()), - approval_policy: None, + approval_policy, sandbox, config: self.build_config_overrides(), base_instructions: self.base_instructions.clone(), @@ -234,6 +275,11 @@ impl Codex { let params = self.build_new_conversation_params(current_dir); let resume_session = resume_session.map(|s| s.to_string()); + let auto_approve = matches!( + (&self.sandbox, &self.ask_for_approval), + (Some(SandboxMode::DangerFullAccess), None) + ); + let approvals = self.approvals.clone(); tokio::spawn(async move { let exit_signal_tx = ExitSignalSender::new(exit_signal_tx); let log_writer = LogWriter::new(new_stdout); @@ -245,6 +291,8 @@ impl Codex { child_stdin, log_writer.clone(), exit_signal_tx.clone(), + approvals, + auto_approve, ) .await { @@ -268,6 +316,7 @@ impl Codex { }) } + #[allow(clippy::too_many_arguments)] async fn launch_codex_app_server( conversation_params: NewConversationParams, resume_session: Option, @@ -276,8 +325,10 @@ impl Codex { child_stdin: tokio::process::ChildStdin, log_writer: LogWriter, exit_signal_tx: ExitSignalSender, + approvals: Option>, + auto_approve: bool, ) -> Result<(), ExecutorError> { - let client = AppServerClient::new(log_writer); + let client = AppServerClient::new(log_writer, approvals, auto_approve); let rpc_peer = JsonRpcPeer::spawn(child_stdin, child_stdout, client.clone(), exit_signal_tx); client.connect(rpc_peer); @@ -286,11 +337,11 @@ impl Codex { None => { let params = conversation_params; let response = client.new_conversation(params).await?; + let conversation_id = response.conversation_id; + client.register_session(&conversation_id).await?; + client.add_conversation_listener(conversation_id).await?; client - .add_conversation_listener(response.conversation_id) - .await?; - client - .send_user_message(response.conversation_id, combined_prompt) + .send_user_message(conversation_id, combined_prompt) .await?; } Some(session_id) => { @@ -306,11 +357,11 @@ impl Codex { rollout_path.display(), response ); + let conversation_id = response.conversation_id; + client.register_session(&conversation_id).await?; + client.add_conversation_listener(conversation_id).await?; client - .add_conversation_listener(response.conversation_id) - .await?; - client - .send_user_message(response.conversation_id, combined_prompt) + .send_user_message(conversation_id, combined_prompt) .await?; } } diff --git a/crates/executors/src/executors/codex/client.rs b/crates/executors/src/executors/codex/client.rs index 037c2610..c28cd2cb 100644 --- a/crates/executors/src/executors/codex/client.rs +++ b/crates/executors/src/executors/codex/client.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + collections::VecDeque, io, sync::{Arc, OnceLock}, }; @@ -13,26 +14,43 @@ use codex_app_server_protocol::{ ResumeConversationParams, ResumeConversationResponse, SendUserMessageParams, SendUserMessageResponse, ServerNotification, ServerRequest, }; +use codex_protocol::{ConversationId, protocol::ReviewDecision}; use serde::{Serialize, de::DeserializeOwned}; -use serde_json::Value; +use serde_json::{self, Value}; use tokio::{ io::{AsyncWrite, AsyncWriteExt, BufWriter}, sync::Mutex, }; +use workspace_utils::approvals::ApprovalStatus; use super::jsonrpc::{JsonRpcCallbacks, JsonRpcPeer}; -use crate::executors::ExecutorError; +use crate::{ + approvals::{ExecutorApprovalError, ExecutorApprovalService}, + executors::{ExecutorError, codex::normalize_logs::Approval}, +}; pub struct AppServerClient { rpc: OnceLock, log_writer: LogWriter, + approvals: Option>, + conversation_id: Mutex>, + pending_feedback: Mutex>, + auto_approve: bool, } impl AppServerClient { - pub fn new(log_writer: LogWriter) -> Arc { + pub fn new( + log_writer: LogWriter, + approvals: Option>, + auto_approve: bool, + ) -> Arc { Arc::new(Self { rpc: OnceLock::new(), log_writer, + approvals, + auto_approve, + conversation_id: Mutex::new(None), + pending_feedback: Mutex::new(VecDeque::new()), }) } @@ -113,6 +131,120 @@ impl AppServerClient { self.send_request(request, "sendUserMessage").await } + async fn handle_server_request( + &self, + peer: &JsonRpcPeer, + request: ServerRequest, + ) -> Result<(), ExecutorError> { + match request { + ServerRequest::ApplyPatchApproval { request_id, params } => { + let input = serde_json::to_value(¶ms) + .map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?; + let status = match self + .request_tool_approval("edit", input, ¶ms.call_id) + .await + { + Ok(status) => status, + Err(err) => { + tracing::error!("failed to request patch approval: {err}"); + ApprovalStatus::Denied { + reason: Some("approval service error".to_string()), + } + } + }; + self.log_writer + .log_raw( + &Approval::approval_response( + params.call_id, + "codex.apply_patch".to_string(), + status.clone(), + ) + .raw(), + ) + .await?; + let (decision, feedback) = self.review_decision(&status).await?; + let response = ApplyPatchApprovalResponse { decision }; + send_server_response(peer, request_id, response).await?; + if let Some(message) = feedback { + tracing::debug!("queueing patch denial feedback: {message}"); + self.enqueue_feedback(message).await; + } + Ok(()) + } + ServerRequest::ExecCommandApproval { request_id, params } => { + let input = serde_json::to_value(¶ms) + .map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?; + let status = match self + .request_tool_approval("bash", input, ¶ms.call_id) + .await + { + Ok(status) => status, + Err(err) => { + tracing::error!("failed to request command approval: {err}"); + ApprovalStatus::Denied { + reason: Some("approval service error".to_string()), + } + } + }; + self.log_writer + .log_raw( + &Approval::approval_response( + params.call_id, + "codex.exec_command".to_string(), + status.clone(), + ) + .raw(), + ) + .await?; + + let (decision, feedback) = self.review_decision(&status).await?; + let response = ExecCommandApprovalResponse { decision }; + send_server_response(peer, request_id, response).await?; + if let Some(message) = feedback { + tracing::debug!("queueing exec denial feedback: {message}"); + self.enqueue_feedback(message).await; + } + Ok(()) + } + } + } + + async fn request_tool_approval( + &self, + tool_name: &str, + tool_input: Value, + tool_call_id: &str, + ) -> Result { + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + if self.auto_approve { + return Ok(ApprovalStatus::Approved); + } + Ok(self + .approvals + .as_ref() + .ok_or(ExecutorApprovalError::ServiceUnavailable)? + .request_tool_approval(tool_name, tool_input, tool_call_id) + .await?) + } + + pub async fn register_session( + &self, + conversation_id: &ConversationId, + ) -> Result<(), ExecutorError> { + { + let mut guard = self.conversation_id.lock().await; + guard.replace(*conversation_id); + } + if let Some(approvals) = self.approvals.as_ref() { + approvals + .register_session(&conversation_id.to_string()) + .await + .map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?; + } + self.flush_pending_feedback().await; + Ok(()) + } + async fn send_message(&self, message: &M) -> Result<(), ExecutorError> where M: Serialize + Sync, @@ -131,6 +263,94 @@ impl AppServerClient { fn next_request_id(&self) -> RequestId { self.rpc().next_request_id() } + + async fn review_decision( + &self, + status: &ApprovalStatus, + ) -> Result<(ReviewDecision, Option), ExecutorError> { + if self.auto_approve { + return Ok((ReviewDecision::ApprovedForSession, None)); + } + + let outcome = match status { + ApprovalStatus::Approved => (ReviewDecision::Approved, None), + ApprovalStatus::Denied { reason } => { + let feedback = reason + .as_ref() + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()); + if feedback.is_some() { + (ReviewDecision::Abort, feedback) + } else { + (ReviewDecision::Denied, None) + } + } + ApprovalStatus::TimedOut => (ReviewDecision::Denied, None), + ApprovalStatus::Pending => (ReviewDecision::Denied, None), + }; + Ok(outcome) + } + + async fn enqueue_feedback(&self, message: String) { + if message.trim().is_empty() { + return; + } + let mut guard = self.pending_feedback.lock().await; + guard.push_back(message); + } + + async fn flush_pending_feedback(&self) { + let messages: Vec = { + let mut guard = self.pending_feedback.lock().await; + guard.drain(..).collect() + }; + + if messages.is_empty() { + return; + } + + let Some(conversation_id) = *self.conversation_id.lock().await else { + tracing::warn!( + "pending Codex feedback but conversation id unavailable; dropping {} messages", + messages.len() + ); + return; + }; + + for message in messages { + let trimmed = message.trim(); + if trimmed.is_empty() { + continue; + } + self.spawn_feedback_message(conversation_id, trimmed.to_string()); + } + } + + fn spawn_feedback_message(&self, conversation_id: ConversationId, feedback: String) { + let peer = self.rpc().clone(); + let request = ClientRequest::SendUserMessage { + request_id: peer.next_request_id(), + params: SendUserMessageParams { + conversation_id, + items: vec![InputItem::Text { + text: format!("User feedback: {feedback}"), + }], + }, + }; + tokio::spawn(async move { + if let Err(err) = peer + .request::( + request_id(&request), + &request, + "sendUserMessage", + ) + .await + { + tracing::error!("failed to send feedback follow-up message: {err}"); + } + }); + } } #[async_trait] @@ -143,7 +363,7 @@ impl JsonRpcCallbacks for AppServerClient { ) -> Result<(), ExecutorError> { self.log_writer.log_raw(raw).await?; match ServerRequest::try_from(request.clone()) { - Ok(server_request) => handle_server_request(peer, server_request).await, + Ok(server_request) => self.handle_server_request(peer, server_request).await, Err(err) => { tracing::debug!("Unhandled server request `{}`: {err}", request.method); let response = JSONRPCResponse { @@ -200,6 +420,12 @@ impl JsonRpcCallbacks for AppServerClient { return Ok(false); } + if method.ends_with("turn_aborted") { + tracing::debug!("codex turn aborted; flushing feedback queue"); + self.flush_pending_feedback().await; + return Ok(false); + } + let has_finished = method .strip_prefix("codex/event/") .is_some_and(|suffix| suffix == "task_complete"); @@ -213,27 +439,6 @@ impl JsonRpcCallbacks for AppServerClient { } } -// Aprovals -async fn handle_server_request( - peer: &JsonRpcPeer, - request: ServerRequest, -) -> Result<(), ExecutorError> { - match request { - ServerRequest::ApplyPatchApproval { request_id, .. } => { - let response = ApplyPatchApprovalResponse { - decision: codex_protocol::protocol::ReviewDecision::ApprovedForSession, - }; - send_server_response(peer, request_id, response).await - } - ServerRequest::ExecCommandApproval { request_id, .. } => { - let response = ExecCommandApprovalResponse { - decision: codex_protocol::protocol::ReviewDecision::ApprovedForSession, - }; - send_server_response(peer, request_id, response).await - } - } -} - async fn send_server_response( peer: &JsonRpcPeer, request_id: RequestId, diff --git a/crates/executors/src/executors/codex/normalize_logs.rs b/crates/executors/src/executors/codex/normalize_logs.rs index f6ba8a91..f56574d4 100644 --- a/crates/executors/src/executors/codex/normalize_logs.rs +++ b/crates/executors/src/executors/codex/normalize_logs.rs @@ -13,11 +13,12 @@ use codex_protocol::{ plan_tool::{StepStatus, UpdatePlanArgs}, protocol::{ AgentMessageDeltaEvent, AgentMessageEvent, AgentReasoningDeltaEvent, AgentReasoningEvent, - AgentReasoningSectionBreakEvent, BackgroundEventEvent, ErrorEvent, EventMsg, - ExecCommandBeginEvent, ExecCommandEndEvent, ExecCommandOutputDeltaEvent, ExecOutputStream, - FileChange as CodexProtoFileChange, McpInvocation, McpToolCallBeginEvent, - McpToolCallEndEvent, PatchApplyBeginEvent, PatchApplyEndEvent, StreamErrorEvent, - TokenUsageInfo, ViewImageToolCallEvent, WebSearchBeginEvent, WebSearchEndEvent, + AgentReasoningSectionBreakEvent, ApplyPatchApprovalRequestEvent, BackgroundEventEvent, + ErrorEvent, EventMsg, ExecApprovalRequestEvent, ExecCommandBeginEvent, ExecCommandEndEvent, + ExecCommandOutputDeltaEvent, ExecOutputStream, FileChange as CodexProtoFileChange, + McpInvocation, McpToolCallBeginEvent, McpToolCallEndEvent, PatchApplyBeginEvent, + PatchApplyEndEvent, StreamErrorEvent, TokenUsageInfo, ViewImageToolCallEvent, + WebSearchBeginEvent, WebSearchEndEvent, }, }; use futures::StreamExt; @@ -26,12 +27,14 @@ use regex::Regex; use serde::{Deserialize, Serialize}; use serde_json::Value; use workspace_utils::{ + approvals::ApprovalStatus, diff::{concatenate_diff_hunks, extract_unified_diff_hunks}, msg_store::MsgStore, path::make_path_relative, }; use crate::{ + approvals::ToolCallMetadata, executors::codex::session::SessionHandler, logs::{ ActionType, CommandExitStatus, CommandRunResult, FileChange, NormalizedEntry, @@ -45,6 +48,10 @@ trait ToNormalizedEntry { fn to_normalized_entry(&self) -> NormalizedEntry; } +trait ToNormalizedEntryOpt { + fn to_normalized_entry_opt(&self) -> Option; +} + #[derive(Debug, Deserialize)] struct CodexNotificationParams { #[serde(rename = "msg")] @@ -66,10 +73,14 @@ struct CommandState { formatted_output: Option, status: ToolStatus, exit_code: Option, + awaiting_approval: bool, + call_id: String, } impl ToNormalizedEntry for CommandState { fn to_normalized_entry(&self) -> NormalizedEntry { + let content = format!("`{}`", self.command); + NormalizedEntry { timestamp: None, entry_type: NormalizedEntryType::ToolUse { @@ -89,8 +100,11 @@ impl ToNormalizedEntry for CommandState { }, status: self.status.clone(), }, - content: format!("`{}`", self.command), - metadata: None, + content, + metadata: serde_json::to_value(ToolCallMetadata { + tool_call_id: self.call_id.clone(), + }) + .ok(), } } } @@ -165,10 +179,14 @@ struct PatchEntry { path: String, changes: Vec, status: ToolStatus, + awaiting_approval: bool, + call_id: String, } impl ToNormalizedEntry for PatchEntry { fn to_normalized_entry(&self) -> NormalizedEntry { + let content = self.path.clone(); + NormalizedEntry { timestamp: None, entry_type: NormalizedEntryType::ToolUse { @@ -179,8 +197,11 @@ impl ToNormalizedEntry for PatchEntry { }, status: self.status.clone(), }, - content: self.path.clone(), - metadata: None, + content, + metadata: serde_json::to_value(ToolCallMetadata { + tool_call_id: self.call_id.clone(), + }) + .ok(), } } } @@ -385,6 +406,13 @@ pub fn normalize_logs(msg_store: Arc, worktree_path: &Path) { continue; } + if let Ok(approval) = serde_json::from_str::(&line) { + if let Some(entry) = approval.to_normalized_entry_opt() { + add_normalized_entry(&msg_store, &entry_index, entry); + } + continue; + } + if let Ok(response) = serde_json::from_str::(&line) { handle_jsonrpc_response(response, &msg_store, &entry_index); continue; @@ -466,6 +494,80 @@ pub fn normalize_logs(msg_store: Arc, worktree_path: &Path) { state.assistant = None; state.thinking = None; } + EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + call_id, + command, + cwd: _, + reason, + }) => { + state.assistant = None; + state.thinking = None; + + let command_text = if command.is_empty() { + reason + .filter(|r| !r.is_empty()) + .unwrap_or_else(|| "command execution".to_string()) + } else { + command.join(" ") + }; + + let command_state = state.commands.entry(call_id.clone()).or_default(); + + if command_state.command.is_empty() { + command_state.command = command_text; + } + command_state.awaiting_approval = true; + if let Some(index) = command_state.index { + replace_normalized_entry( + &msg_store, + index, + command_state.to_normalized_entry(), + ); + } else { + let index = add_normalized_entry( + &msg_store, + &entry_index, + command_state.to_normalized_entry(), + ); + command_state.index = Some(index); + } + } + EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + call_id, + changes, + reason: _, + grant_root: _, + }) => { + state.assistant = None; + state.thinking = None; + + let normalized = normalize_file_changes(&worktree_path_str, &changes); + let patch_state = state.patches.entry(call_id.clone()).or_default(); + + for entry in patch_state.entries.drain(..) { + if let Some(index) = entry.index { + msg_store.push_patch(ConversationPatch::remove(index)); + } + } + + for (path, file_changes) in normalized { + let mut entry = PatchEntry { + index: None, + path, + changes: file_changes, + status: ToolStatus::Created, + awaiting_approval: true, + call_id: call_id.clone(), + }; + let index = add_normalized_entry( + &msg_store, + &entry_index, + entry.to_normalized_entry(), + ); + entry.index = Some(index); + patch_state.entries.push(entry); + } + } EventMsg::ExecCommandBegin(ExecCommandBeginEvent { call_id, command, .. }) => { @@ -485,6 +587,8 @@ pub fn normalize_logs(msg_store: Arc, worktree_path: &Path) { formatted_output: None, status: ToolStatus::Created, exit_code: None, + awaiting_approval: false, + call_id: call_id.clone(), }, ); let command_state = state.commands.get_mut(&call_id).unwrap(); @@ -532,6 +636,7 @@ pub fn normalize_logs(msg_store: Arc, worktree_path: &Path) { if let Some(mut command_state) = state.commands.remove(&call_id) { command_state.formatted_output = Some(formatted_output); command_state.exit_code = Some(exit_code); + command_state.awaiting_approval = false; command_state.status = if exit_code == 0 { ToolStatus::Success } else { @@ -664,23 +769,68 @@ pub fn normalize_logs(msg_store: Arc, worktree_path: &Path) { state.assistant = None; state.thinking = None; let normalized = normalize_file_changes(&worktree_path_str, &changes); - let mut patch_state = PatchState::default(); - for (path, file_changes) in normalized { - patch_state.entries.push(PatchEntry { - index: None, - path, - changes: file_changes, - status: ToolStatus::Created, - }); - let patch_entry = patch_state.entries.last_mut().unwrap(); - let index = add_normalized_entry( - &msg_store, - &entry_index, - patch_entry.to_normalized_entry(), - ); - patch_entry.index = Some(index); + if let Some(patch_state) = state.patches.get_mut(&call_id) { + let mut iter = normalized.into_iter(); + for entry in &mut patch_state.entries { + if let Some((path, file_changes)) = iter.next() { + entry.path = path; + entry.changes = file_changes; + } + entry.status = ToolStatus::Created; + entry.awaiting_approval = false; + if let Some(index) = entry.index { + replace_normalized_entry( + &msg_store, + index, + entry.to_normalized_entry(), + ); + } else { + let index = add_normalized_entry( + &msg_store, + &entry_index, + entry.to_normalized_entry(), + ); + entry.index = Some(index); + } + } + for (path, file_changes) in iter { + let mut entry = PatchEntry { + index: None, + path, + changes: file_changes, + status: ToolStatus::Created, + awaiting_approval: false, + call_id: call_id.clone(), + }; + let index = add_normalized_entry( + &msg_store, + &entry_index, + entry.to_normalized_entry(), + ); + entry.index = Some(index); + patch_state.entries.push(entry); + } + } else { + let mut patch_state = PatchState::default(); + for (path, file_changes) in normalized { + patch_state.entries.push(PatchEntry { + index: None, + path, + changes: file_changes, + status: ToolStatus::Created, + awaiting_approval: false, + call_id: call_id.clone(), + }); + let patch_entry = patch_state.entries.last_mut().unwrap(); + let index = add_normalized_entry( + &msg_store, + &entry_index, + patch_entry.to_normalized_entry(), + ); + patch_entry.index = Some(index); + } + state.patches.insert(call_id, patch_state); } - state.patches.insert(call_id, patch_state); } EventMsg::PatchApplyEnd(PatchApplyEndEvent { call_id, @@ -826,9 +976,7 @@ pub fn normalize_logs(msg_store: Arc, worktree_path: &Path) { | EventMsg::ConversationPath(..) | EventMsg::EnteredReviewMode(..) | EventMsg::ExitedReviewMode(..) - | EventMsg::TaskComplete(..) - | EventMsg::ExecApprovalRequest(..) - | EventMsg::ApplyPatchApprovalRequest(..) => {} + | EventMsg::TaskComplete(..) => {} } } }); @@ -937,3 +1085,73 @@ impl ToNormalizedEntry for Error { } } } + +#[derive(Serialize, Deserialize, Debug)] +pub enum Approval { + ApprovalResponse { + call_id: String, + tool_name: String, + approval_status: ApprovalStatus, + }, +} + +impl Approval { + pub fn approval_response( + call_id: String, + tool_name: String, + approval_status: ApprovalStatus, + ) -> Self { + Self::ApprovalResponse { + call_id, + tool_name, + approval_status, + } + } + + pub fn raw(&self) -> String { + serde_json::to_string(self).unwrap_or_default() + } + + pub fn display_tool_name(&self) -> String { + let Self::ApprovalResponse { tool_name, .. } = self; + match tool_name.as_str() { + "codex.exec_command" => "Exec Command".to_string(), + "codex.apply_patch" => "Edit".to_string(), + other => other.to_string(), + } + } +} + +impl ToNormalizedEntryOpt for Approval { + fn to_normalized_entry_opt(&self) -> Option { + let Self::ApprovalResponse { + call_id: _, + tool_name: _, + approval_status, + } = self; + let tool_name = self.display_tool_name(); + + match approval_status { + ApprovalStatus::Pending => None, + ApprovalStatus::Approved => None, + ApprovalStatus::Denied { reason } => Some(NormalizedEntry { + timestamp: None, + entry_type: NormalizedEntryType::UserFeedback { + denied_tool: tool_name.clone(), + }, + content: reason + .clone() + .unwrap_or_else(|| "User denied this tool use request".to_string()) + .trim() + .to_string(), + metadata: None, + }), + ApprovalStatus::TimedOut => Some(NormalizedEntry { + timestamp: None, + entry_type: NormalizedEntryType::ErrorMessage, + content: format!("Approval timed out for tool {tool_name}"), + metadata: None, + }), + } + } +} diff --git a/crates/executors/src/executors/mod.rs b/crates/executors/src/executors/mod.rs index 6a612ab0..7426f257 100644 --- a/crates/executors/src/executors/mod.rs +++ b/crates/executors/src/executors/mod.rs @@ -13,6 +13,7 @@ use ts_rs::TS; use workspace_utils::msg_store::MsgStore; use crate::{ + approvals::ExecutorApprovalService, executors::{ amp::Amp, claude::ClaudeCode, codex::Codex, copilot::Copilot, cursor::Cursor, gemini::Gemini, opencode::Opencode, qwen::QwenCode, @@ -52,6 +53,8 @@ pub enum ExecutorError { TomlSerialize(#[from] toml::ser::Error), #[error(transparent)] TomlDeserialize(#[from] toml::de::Error), + #[error(transparent)] + ExecutorApprovalError(#[from] crate::approvals::ExecutorApprovalError), } #[enum_dispatch] @@ -138,6 +141,8 @@ impl CodingAgent { #[async_trait] #[enum_dispatch(CodingAgent)] pub trait StandardCodingAgentExecutor { + fn use_approvals(&mut self, _approvals: Arc) {} + async fn spawn(&self, current_dir: &Path, prompt: &str) -> Result; async fn spawn_follow_up( &self, diff --git a/crates/executors/src/lib.rs b/crates/executors/src/lib.rs index 61e04c21..13d77407 100644 --- a/crates/executors/src/lib.rs +++ b/crates/executors/src/lib.rs @@ -1,4 +1,5 @@ pub mod actions; +pub mod approvals; pub mod command; pub mod executors; pub mod logs; diff --git a/crates/executors/src/logs/utils/patch.rs b/crates/executors/src/logs/utils/patch.rs index 91685263..746cc2c2 100644 --- a/crates/executors/src/logs/utils/patch.rs +++ b/crates/executors/src/logs/utils/patch.rs @@ -113,6 +113,14 @@ impl ConversationPatch { from_value(json!([patch_entry])).unwrap() } + + pub fn remove(entry_index: usize) -> Patch { + from_value(json!([{ + "op": PatchOperation::Remove, + "path": format!("/entries/{entry_index}"), + }])) + .unwrap() + } } /// Extract the entry index and `NormalizedEntry` from a JsonPatch if it contains one diff --git a/crates/local-deployment/src/container.rs b/crates/local-deployment/src/container.rs index 354d472f..a9925a6f 100644 --- a/crates/local-deployment/src/container.rs +++ b/crates/local-deployment/src/container.rs @@ -27,6 +27,8 @@ use db::{ use deployment::DeploymentError; use executors::{ actions::{Executable, ExecutorAction}, + approvals::{ExecutorApprovalService, NoopExecutorApprovalService}, + executors::BaseCodingAgent, logs::{ NormalizedEntryType, utils::{ @@ -39,6 +41,7 @@ use futures::{FutureExt, StreamExt, TryStreamExt, stream::select}; use serde_json::json; use services::services::{ analytics::AnalyticsContext, + approvals::{Approvals, executor_approvals::ExecutorApprovalBridge}, config::Config, container::{ContainerError, ContainerRef, ContainerService}, diff_stream::{self, DiffStreamHandle}, @@ -67,6 +70,7 @@ pub struct LocalContainerService { git: GitService, image_service: ImageService, analytics: Option, + approvals: Approvals, } impl LocalContainerService { @@ -77,6 +81,7 @@ impl LocalContainerService { git: GitService, image_service: ImageService, analytics: Option, + approvals: Approvals, ) -> Self { let child_store = Arc::new(RwLock::new(HashMap::new())); @@ -88,6 +93,7 @@ impl LocalContainerService { git, image_service, analytics, + approvals, } } @@ -801,8 +807,20 @@ impl ContainerService for LocalContainerService { )))?; let current_dir = PathBuf::from(container_ref); + let approvals_service: Arc = + match executor_action.base_executor() { + Some(BaseCodingAgent::Codex) => ExecutorApprovalBridge::new( + self.approvals.clone(), + self.db.clone(), + execution_process.id, + ), + _ => Arc::new(NoopExecutorApprovalService {}), + }; + // Create the child and stream, add to execution tracker - let mut spawned = executor_action.spawn(¤t_dir).await?; + let mut spawned = executor_action + .spawn(¤t_dir, approvals_service) + .await?; self.track_child_msgs_in_store(execution_process.id, &mut spawned.child) .await; diff --git a/crates/local-deployment/src/lib.rs b/crates/local-deployment/src/lib.rs index b4b8bc80..12aa4b3a 100644 --- a/crates/local-deployment/src/lib.rs +++ b/crates/local-deployment/src/lib.rs @@ -118,6 +118,7 @@ impl Deployment for LocalDeployment { git.clone(), image.clone(), analytics_ctx, + approvals.clone(), ); container.spawn_worktree_cleanup().await; diff --git a/crates/server/src/bin/generate_types.rs b/crates/server/src/bin/generate_types.rs index a94ec25e..98cd5c1b 100644 --- a/crates/server/src/bin/generate_types.rs +++ b/crates/server/src/bin/generate_types.rs @@ -78,6 +78,7 @@ fn generate_types_content() -> String { executors::executors::amp::Amp::decl(), executors::executors::codex::Codex::decl(), executors::executors::codex::SandboxMode::decl(), + executors::executors::codex::AskForApproval::decl(), executors::executors::codex::ReasoningEffort::decl(), executors::executors::codex::ReasoningSummary::decl(), executors::executors::codex::ReasoningSummaryFormat::decl(), diff --git a/crates/services/src/services/approvals.rs b/crates/services/src/services/approvals.rs index 68898ab8..35bec4c2 100644 --- a/crates/services/src/services/approvals.rs +++ b/crates/services/src/services/approvals.rs @@ -1,3 +1,5 @@ +pub mod executor_approvals; + use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration}; use chrono::{DateTime, Utc}; @@ -7,10 +9,14 @@ use db::models::{ executor_session::ExecutorSession, task::{Task, TaskStatus}, }; -use executors::logs::{ - NormalizedEntry, NormalizedEntryType, ToolStatus, - utils::patch::{ConversationPatch, extract_normalized_entry_from_patch}, +use executors::{ + approvals::ToolCallMetadata, + logs::{ + NormalizedEntry, NormalizedEntryType, ToolStatus, + utils::patch::{ConversationPatch, extract_normalized_entry_from_patch}, + }, }; +use futures::future::{BoxFuture, FutureExt, Shared}; use sqlx::{Error as SqlxError, SqlitePool}; use thiserror::Error; use tokio::sync::{RwLock, oneshot}; @@ -35,6 +41,8 @@ struct PendingApproval { response_tx: oneshot::Sender, } +type ApprovalWaiter = Shared>; + #[derive(Debug)] pub struct ToolContext { pub tool_name: String, @@ -73,15 +81,25 @@ impl Approvals { } } - #[tracing::instrument(skip(self, request))] - pub async fn create(&self, request: ApprovalRequest) -> Result { + async fn create_internal( + &self, + request: ApprovalRequest, + ) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> { let (tx, rx) = oneshot::channel(); + let waiter: ApprovalWaiter = rx + .map(|result| result.unwrap_or(ApprovalStatus::TimedOut)) + .boxed() + .shared(); let req_id = request.id.clone(); if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await { // Find the matching tool use entry by name and input - let matching_tool = - find_matching_tool_use(store.clone(), &request.tool_name, &request.tool_input); + let matching_tool = find_matching_tool_use( + store.clone(), + &request.tool_name, + &request.tool_input, + request.tool_call_id.as_deref(), + ); if let Some((idx, matching_tool)) = matching_tool { let approval_entry = matching_tool @@ -125,10 +143,23 @@ impl Approvals { ); } - self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, rx); + self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, waiter.clone()); + Ok((request, waiter)) + } + + #[tracing::instrument(skip(self, request))] + pub async fn create(&self, request: ApprovalRequest) -> Result { + let (request, _) = self.create_internal(request).await?; Ok(request) } + pub async fn create_with_waiter( + &self, + request: ApprovalRequest, + ) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> { + self.create_internal(request).await + } + pub async fn create_from_session( &self, pool: &SqlitePool, @@ -145,15 +176,7 @@ impl Approvals { }; // Move the task to InReview if it's still InProgress - if let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await - && ctx.task.status == TaskStatus::InProgress - && let Err(e) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await - { - tracing::warn!( - "Failed to update task status to InReview for approval request: {}", - e - ); - } + ensure_task_in_review(pool, execution_process_id).await; let request = ApprovalRequest::from_create(payload, execution_process_id); self.create(request).await @@ -248,12 +271,12 @@ impl Approvals { .collect() } - #[tracing::instrument(skip(self, id, timeout_at, rx))] + #[tracing::instrument(skip(self, id, timeout_at, waiter))] fn spawn_timeout_watcher( &self, id: String, timeout_at: chrono::DateTime, - mut rx: oneshot::Receiver, + waiter: ApprovalWaiter, ) { let pending = self.pending.clone(); let completed = self.completed.clone(); @@ -269,19 +292,18 @@ impl Approvals { let status = tokio::select! { biased; - r = &mut rx => match r { - Ok(status) => status, - Err(_canceled) => ApprovalStatus::TimedOut, - }, + resolved = waiter.clone() => resolved, _ = tokio::time::sleep_until(deadline) => ApprovalStatus::TimedOut, }; let is_timeout = matches!(&status, ApprovalStatus::TimedOut); completed.insert(id.clone(), status.clone()); - let removed = pending.remove(&id); + if is_timeout && let Some((_, pending_approval)) = pending.remove(&id) { + if pending_approval.response_tx.send(status.clone()).is_err() { + tracing::debug!("approval '{}' timeout notification receiver dropped", id); + } - if is_timeout && let Some((_, pending_approval)) = removed { let store = { let map = msg_stores.read().await; map.get(&pending_approval.execution_process_id).cloned() @@ -318,8 +340,22 @@ impl Approvals { } } +pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_id: Uuid) { + if let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await + && ctx.task.status == TaskStatus::InProgress + && let Err(e) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await + { + tracing::warn!( + "Failed to update task status to InReview for approval request: {}", + e + ); + } +} + /// Comparison strategy for matching tool use entries enum ToolComparisonStrategy { + /// Compare by tool_call_id + ToolCallId(String), /// Compare deserialized ClaudeToolData structures (for known tools) Deserialized(executors::executors::claude::ClaudeToolData), /// Compare raw JSON input fields (for Unknown tools like MCP) @@ -332,31 +368,37 @@ fn find_matching_tool_use( store: Arc, tool_name: &str, tool_input: &serde_json::Value, + tool_call_id: Option<&str>, ) -> Option<(usize, NormalizedEntry)> { use executors::executors::claude::ClaudeToolData; let history = store.get_history(); // Determine comparison strategy based on tool type - let strategy = match serde_json::from_value::(serde_json::json!({ - "name": tool_name, - "input": tool_input - })) { - Ok(ClaudeToolData::Unknown { .. }) => { - // For Unknown tools (MCP, future tools), use raw JSON comparison - ToolComparisonStrategy::RawJson - } - Ok(data) => { - // For known tools, use deserialized comparison with proper alias handling - ToolComparisonStrategy::Deserialized(data) - } - Err(e) => { - tracing::warn!( - "Failed to deserialize tool_input for tool '{}': {}", - tool_name, - e - ); - return None; + let strategy = if let Some(call_id) = tool_call_id { + // If tool_call_id is provided, use it for matching + ToolComparisonStrategy::ToolCallId(call_id.to_string()) + } else { + match serde_json::from_value::(serde_json::json!({ + "name": tool_name, + "input": tool_input + })) { + Ok(ClaudeToolData::Unknown { .. }) => { + // For Unknown tools (MCP, future tools), use raw JSON comparison + ToolComparisonStrategy::RawJson + } + Ok(data) => { + // For known tools, use deserialized comparison with proper alias handling + ToolComparisonStrategy::Deserialized(data) + } + Err(e) => { + tracing::warn!( + "Failed to deserialize tool_input for tool '{}': {}", + tool_name, + e + ); + return None; + } } }; @@ -383,6 +425,18 @@ fn find_matching_tool_use( // Apply comparison strategy if let Some(metadata) = &entry.metadata { let is_match = match &strategy { + ToolComparisonStrategy::ToolCallId(call_id) => { + // Match by tool_call_id in metadata + if let Ok(ToolCallMetadata { + tool_call_id: entry_call_id, + .. + }) = serde_json::from_value::(metadata.clone()) + { + entry_call_id == *call_id + } else { + false + } + } ToolComparisonStrategy::RawJson => { // Compare raw JSON input for Unknown tools if let Some(entry_input) = metadata.get("input") { @@ -405,8 +459,13 @@ fn find_matching_tool_use( if is_match { let strategy_name = match strategy { - ToolComparisonStrategy::RawJson => "raw input comparison", - ToolComparisonStrategy::Deserialized(_) => "deserialized tool data", + ToolComparisonStrategy::ToolCallId(call_id) => { + format!("tool_call_id '{call_id}'") + } + ToolComparisonStrategy::RawJson => "raw input comparison".to_string(), + ToolComparisonStrategy::Deserialized(_) => { + "deserialized tool data".to_string() + } }; tracing::debug!( "Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}" @@ -483,12 +542,12 @@ mod tests { let bar_input = serde_json::json!({"file_path": "bar.rs"}); let baz_input = serde_json::json!({"file_path": "baz.rs"}); - let (idx_foo, _) = - find_matching_tool_use(store.clone(), "Read", &foo_input).expect("Should match foo.rs"); - let (idx_bar, _) = - find_matching_tool_use(store.clone(), "Read", &bar_input).expect("Should match bar.rs"); - let (idx_baz, _) = - find_matching_tool_use(store.clone(), "Read", &baz_input).expect("Should match baz.rs"); + let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None) + .expect("Should match foo.rs"); + let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None) + .expect("Should match bar.rs"); + let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None) + .expect("Should match baz.rs"); assert_eq!(idx_foo, 0, "foo.rs should match first entry"); assert_eq!(idx_bar, 1, "bar.rs should match second entry"); @@ -510,21 +569,21 @@ mod tests { let pending_input = serde_json::json!({"file_path": "pending.rs"}); assert!( - find_matching_tool_use(store.clone(), "Read", &pending_input).is_none(), + find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(), "Should not match tools in PendingApproval state" ); // Test 3: Wrong tool name returns None let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"}); assert!( - find_matching_tool_use(store.clone(), "Write", &write_input).is_none(), + find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(), "Should not match different tool names" ); // Test 4: Wrong input parameters returns None let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"}); assert!( - find_matching_tool_use(store.clone(), "Read", &wrong_input).is_none(), + find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(), "Should not match with different input parameters" ); } diff --git a/crates/services/src/services/approvals/executor_approvals.rs b/crates/services/src/services/approvals/executor_approvals.rs new file mode 100644 index 00000000..6dba8eaf --- /dev/null +++ b/crates/services/src/services/approvals/executor_approvals.rs @@ -0,0 +1,81 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use db::{self, DBService}; +use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService}; +use serde_json::Value; +use tokio::sync::RwLock; +use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest}; +use uuid::Uuid; + +use crate::services::approvals::Approvals; + +pub struct ExecutorApprovalBridge { + approvals: Approvals, + db: DBService, + execution_process_id: Uuid, + session_id: RwLock>, +} + +impl ExecutorApprovalBridge { + pub fn new(approvals: Approvals, db: DBService, execution_process_id: Uuid) -> Arc { + Arc::new(Self { + approvals, + db, + execution_process_id, + session_id: RwLock::new(None), + }) + } +} + +#[async_trait] +impl ExecutorApprovalService for ExecutorApprovalBridge { + async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> { + let mut guard = self.session_id.write().await; + guard.replace(session_id.to_string()); + + Ok(()) + } + + async fn request_tool_approval( + &self, + tool_name: &str, + tool_input: Value, + tool_call_id: &str, + ) -> Result { + let session_id = { + let guard = self.session_id.read().await; + guard + .clone() + .ok_or(ExecutorApprovalError::SessionNotRegistered)? + }; + + super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await; + + let request = ApprovalRequest::from_create( + CreateApprovalRequest { + tool_name: tool_name.to_string(), + tool_input, + session_id, + tool_call_id: Some(tool_call_id.to_string()), + }, + self.execution_process_id, + ); + + let (_, waiter) = self + .approvals + .create_with_waiter(request) + .await + .map_err(ExecutorApprovalError::request_failed)?; + + let status = waiter.clone().await; + + if matches!(status, ApprovalStatus::Pending) { + return Err(ExecutorApprovalError::request_failed( + "approval finished in pending state", + )); + } + + Ok(status) + } +} diff --git a/crates/utils/src/approvals.rs b/crates/utils/src/approvals.rs index efe3c87e..c331163f 100644 --- a/crates/utils/src/approvals.rs +++ b/crates/utils/src/approvals.rs @@ -12,6 +12,7 @@ pub struct ApprovalRequest { pub tool_name: String, pub tool_input: serde_json::Value, pub session_id: String, + pub tool_call_id: Option, pub execution_process_id: Uuid, pub created_at: DateTime, pub timeout_at: DateTime, @@ -25,6 +26,7 @@ impl ApprovalRequest { tool_name: request.tool_name, tool_input: request.tool_input, session_id: request.session_id, + tool_call_id: request.tool_call_id, execution_process_id, created_at: now, timeout_at: now + Duration::seconds(APPROVAL_TIMEOUT_SECONDS), @@ -38,6 +40,7 @@ pub struct CreateApprovalRequest { pub tool_name: String, pub tool_input: serde_json::Value, pub session_id: String, + pub tool_call_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize, TS)] diff --git a/shared/schemas/codex.json b/shared/schemas/codex.json index 7ebb3790..69cbbc46 100644 --- a/shared/schemas/codex.json +++ b/shared/schemas/codex.json @@ -25,6 +25,38 @@ null ] }, + "ask_for_approval": { + "anyOf": [ + { + "description": "Determines when the user is consulted to approve Codex actions.", + "oneOf": [ + { + "description": "Read-only commands are auto-approved. Everything else will ask the user to approve.", + "type": "string", + "const": "unless-trusted" + }, + { + "description": "All commands run in a restricted sandbox initially.\nIf the command fails, the user is asked to approve execution without the sandbox.", + "type": "string", + "const": "on-failure" + }, + { + "description": "The model decides when to ask the user for approval.", + "type": "string", + "const": "on-request" + }, + { + "description": "Never ask the user to approve commands. Commands that fail in the restricted sandbox will not be retried.", + "type": "string", + "const": "never" + } + ] + }, + { + "type": "null" + } + ] + }, "oss": { "type": [ "boolean", diff --git a/shared/types.ts b/shared/types.ts index a01b1026..54487537 100644 --- a/shared/types.ts +++ b/shared/types.ts @@ -164,10 +164,12 @@ export type GeminiModel = "default" | "flash"; export type Amp = { append_prompt: AppendPrompt, dangerously_allow_all?: boolean | null, base_command_override?: string | null, additional_params?: Array | null, }; -export type Codex = { append_prompt: AppendPrompt, sandbox?: SandboxMode | null, oss?: boolean | null, model?: string | null, model_reasoning_effort?: ReasoningEffort | null, model_reasoning_summary?: ReasoningSummary | null, model_reasoning_summary_format?: ReasoningSummaryFormat | null, profile?: string | null, base_instructions?: string | null, include_plan_tool?: boolean | null, include_apply_patch_tool?: boolean | null, base_command_override?: string | null, additional_params?: Array | null, }; +export type Codex = { append_prompt: AppendPrompt, sandbox?: SandboxMode | null, ask_for_approval?: AskForApproval | null, oss?: boolean | null, model?: string | null, model_reasoning_effort?: ReasoningEffort | null, model_reasoning_summary?: ReasoningSummary | null, model_reasoning_summary_format?: ReasoningSummaryFormat | null, profile?: string | null, base_instructions?: string | null, include_plan_tool?: boolean | null, include_apply_patch_tool?: boolean | null, base_command_override?: string | null, additional_params?: Array | null, }; export type SandboxMode = "auto" | "read-only" | "workspace-write" | "danger-full-access"; +export type AskForApproval = "unless-trusted" | "on-failure" | "on-request" | "never"; + export type ReasoningEffort = "low" | "medium" | "high"; export type ReasoningSummary = "auto" | "concise" | "detailed" | "none"; @@ -318,7 +320,7 @@ export type PatchType = { "type": "NORMALIZED_ENTRY", "content": NormalizedEntry export type ApprovalStatus = { "status": "pending" } | { "status": "approved" } | { "status": "denied", reason?: string, } | { "status": "timed_out" }; -export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, session_id: string, }; +export type CreateApprovalRequest = { tool_name: string, tool_input: JsonValue, session_id: string, tool_call_id: string | null, }; export type ApprovalResponse = { execution_process_id: string, status: ApprovalStatus, };