codex approvals (#993)

* codex approvals

* send deny feedback

* Normalize user-feedback

* use tool call id to match normalized_entry

* store approvals in executor

* add noop approval for api consistency

---------

Co-authored-by: Gabriel Gordon-Hall <ggordonhall@gmail.com>
This commit is contained in:
Solomon
2025-10-20 18:02:58 +01:00
committed by GitHub
parent ee68b2fc43
commit 62834ea581
21 changed files with 942 additions and 139 deletions

View File

@@ -47,3 +47,4 @@ codex-protocol = { git = "https://github.com/openai/codex.git", package = "codex
codex-app-server-protocol = { git = "https://github.com/openai/codex.git", package = "codex-app-server-protocol", rev = "488ec061bf4d36916b8f477c700ea4fde4162a7a" }
codex-mcp-types = { git = "https://github.com/openai/codex.git", package = "mcp-types", rev = "488ec061bf4d36916b8f477c700ea4fde4162a7a" }
sha2 = "0.10"
derivative = "2.2.0"

View File

@@ -49,6 +49,12 @@
"sandbox": "danger-full-access",
"model_reasoning_effort": "high"
}
},
"APPROVALS": {
"CODEX": {
"sandbox": "workspace-write",
"ask_for_approval": "unless-trusted"
}
}
},
"OPENCODE": {

View File

@@ -1,4 +1,4 @@
use std::path::Path;
use std::{path::Path, sync::Arc};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
@@ -6,7 +6,8 @@ use ts_rs::TS;
use crate::{
actions::Executable,
executors::{ExecutorError, SpawnedChild, StandardCodingAgentExecutor},
approvals::ExecutorApprovalService,
executors::{BaseCodingAgent, ExecutorError, SpawnedChild, StandardCodingAgentExecutor},
profile::{ExecutorConfigs, ExecutorProfileId},
};
@@ -25,18 +26,28 @@ impl CodingAgentFollowUpRequest {
pub fn get_executor_profile_id(&self) -> ExecutorProfileId {
self.executor_profile_id.clone()
}
pub fn base_executor(&self) -> BaseCodingAgent {
self.executor_profile_id.executor
}
}
#[async_trait]
impl Executable for CodingAgentFollowUpRequest {
async fn spawn(&self, current_dir: &Path) -> Result<SpawnedChild, ExecutorError> {
async fn spawn(
&self,
current_dir: &Path,
approvals: Arc<dyn ExecutorApprovalService>,
) -> Result<SpawnedChild, ExecutorError> {
let executor_profile_id = self.get_executor_profile_id();
let agent = ExecutorConfigs::get_cached()
let mut agent = ExecutorConfigs::get_cached()
.get_coding_agent(&executor_profile_id)
.ok_or(ExecutorError::UnknownExecutorType(
executor_profile_id.to_string(),
))?;
agent.use_approvals(approvals.clone());
agent
.spawn_follow_up(current_dir, &self.prompt, &self.session_id)
.await

View File

@@ -1,4 +1,4 @@
use std::path::Path;
use std::{path::Path, sync::Arc};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
@@ -6,7 +6,8 @@ use ts_rs::TS;
use crate::{
actions::Executable,
executors::{ExecutorError, SpawnedChild, StandardCodingAgentExecutor},
approvals::ExecutorApprovalService,
executors::{BaseCodingAgent, ExecutorError, SpawnedChild, StandardCodingAgentExecutor},
profile::{ExecutorConfigs, ExecutorProfileId},
};
@@ -19,16 +20,28 @@ pub struct CodingAgentInitialRequest {
pub executor_profile_id: ExecutorProfileId,
}
impl CodingAgentInitialRequest {
pub fn base_executor(&self) -> BaseCodingAgent {
self.executor_profile_id.executor
}
}
#[async_trait]
impl Executable for CodingAgentInitialRequest {
async fn spawn(&self, current_dir: &Path) -> Result<SpawnedChild, ExecutorError> {
async fn spawn(
&self,
current_dir: &Path,
approvals: Arc<dyn ExecutorApprovalService>,
) -> Result<SpawnedChild, ExecutorError> {
let executor_profile_id = self.executor_profile_id.clone();
let agent = ExecutorConfigs::get_cached()
let mut agent = ExecutorConfigs::get_cached()
.get_coding_agent(&executor_profile_id)
.ok_or(ExecutorError::UnknownExecutorType(
executor_profile_id.to_string(),
))?;
agent.use_approvals(approvals.clone());
agent.spawn(current_dir, &self.prompt).await
}
}

View File

@@ -1,4 +1,4 @@
use std::path::Path;
use std::{path::Path, sync::Arc};
use async_trait::async_trait;
use enum_dispatch::enum_dispatch;
@@ -10,7 +10,8 @@ use crate::{
coding_agent_follow_up::CodingAgentFollowUpRequest,
coding_agent_initial::CodingAgentInitialRequest, script::ScriptRequest,
},
executors::{ExecutorError, SpawnedChild},
approvals::ExecutorApprovalService,
executors::{BaseCodingAgent, ExecutorError, SpawnedChild},
};
pub mod coding_agent_follow_up;
pub mod coding_agent_initial;
@@ -43,17 +44,35 @@ impl ExecutorAction {
pub fn next_action(&self) -> Option<&ExecutorAction> {
self.next_action.as_deref()
}
pub fn base_executor(&self) -> Option<BaseCodingAgent> {
match self.typ() {
ExecutorActionType::CodingAgentInitialRequest(request) => Some(request.base_executor()),
ExecutorActionType::CodingAgentFollowUpRequest(request) => {
Some(request.base_executor())
}
ExecutorActionType::ScriptRequest(_) => None,
}
}
}
#[async_trait]
#[enum_dispatch(ExecutorActionType)]
pub trait Executable {
async fn spawn(&self, current_dir: &Path) -> Result<SpawnedChild, ExecutorError>;
async fn spawn(
&self,
current_dir: &Path,
approvals: Arc<dyn ExecutorApprovalService>,
) -> Result<SpawnedChild, ExecutorError>;
}
#[async_trait]
impl Executable for ExecutorAction {
async fn spawn(&self, current_dir: &Path) -> Result<SpawnedChild, ExecutorError> {
self.typ.spawn(current_dir).await
async fn spawn(
&self,
current_dir: &Path,
approvals: Arc<dyn ExecutorApprovalService>,
) -> Result<SpawnedChild, ExecutorError> {
self.typ.spawn(current_dir, approvals).await
}
}

View File

@@ -1,4 +1,4 @@
use std::path::Path;
use std::{path::Path, sync::Arc};
use async_trait::async_trait;
use command_group::AsyncCommandGroup;
@@ -9,6 +9,7 @@ use workspace_utils::shell::get_shell_command;
use crate::{
actions::Executable,
approvals::ExecutorApprovalService,
executors::{ExecutorError, SpawnedChild},
};
@@ -33,7 +34,11 @@ pub struct ScriptRequest {
#[async_trait]
impl Executable for ScriptRequest {
async fn spawn(&self, current_dir: &Path) -> Result<SpawnedChild, ExecutorError> {
async fn spawn(
&self,
current_dir: &Path,
_approvals: Arc<dyn ExecutorApprovalService>,
) -> Result<SpawnedChild, ExecutorError> {
let (shell_cmd, shell_arg) = get_shell_command();
let mut command = Command::new(shell_cmd);
command

View File

@@ -0,0 +1,63 @@
use std::fmt;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use workspace_utils::approvals::ApprovalStatus;
/// Errors emitted by executor approval services.
#[derive(Debug, Error)]
pub enum ExecutorApprovalError {
#[error("executor approval session not registered")]
SessionNotRegistered,
#[error("executor approval request failed: {0}")]
RequestFailed(String),
#[error("executor approval service unavailable")]
ServiceUnavailable,
}
impl ExecutorApprovalError {
pub fn request_failed<E: fmt::Display>(err: E) -> Self {
Self::RequestFailed(err.to_string())
}
}
/// Abstraction for executor approval backends.
#[async_trait]
pub trait ExecutorApprovalService: Send + Sync {
/// Registers the session identifier associated with subsequent approval requests.
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError>;
/// Requests approval for a tool invocation and waits for the final decision.
async fn request_tool_approval(
&self,
tool_name: &str,
tool_input: Value,
tool_call_id: &str,
) -> Result<ApprovalStatus, ExecutorApprovalError>;
}
#[derive(Debug, Default)]
pub struct NoopExecutorApprovalService;
#[async_trait]
impl ExecutorApprovalService for NoopExecutorApprovalService {
async fn register_session(&self, _session_id: &str) -> Result<(), ExecutorApprovalError> {
Ok(())
}
async fn request_tool_approval(
&self,
_tool_name: &str,
_tool_input: Value,
_tool_call_id: &str,
) -> Result<ApprovalStatus, ExecutorApprovalError> {
Ok(ApprovalStatus::Approved)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallMetadata {
pub tool_call_id: String,
}

View File

@@ -2,7 +2,6 @@ pub mod client;
pub mod jsonrpc;
pub mod normalize_logs;
pub mod session;
use std::{
collections::HashMap,
path::{Path, PathBuf},
@@ -11,8 +10,11 @@ use std::{
use async_trait::async_trait;
use codex_app_server_protocol::NewConversationParams;
use codex_protocol::config_types::SandboxMode as CodexSandboxMode;
use codex_protocol::{
config_types::SandboxMode as CodexSandboxMode, protocol::AskForApproval as CodexAskForApproval,
};
use command_group::AsyncCommandGroup;
use derivative::Derivative;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -28,6 +30,7 @@ use self::{
session::SessionHandler,
};
use crate::{
approvals::ExecutorApprovalService,
command::{CmdOverrides, CommandBuilder, apply_overrides},
executors::{
AppendPrompt, ExecutorError, SpawnedChild, StandardCodingAgentExecutor,
@@ -47,6 +50,25 @@ pub enum SandboxMode {
DangerFullAccess,
}
/// Determines when the user is consulted to approve Codex actions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema, AsRefStr)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]
pub enum AskForApproval {
/// Read-only commands are auto-approved. Everything else will ask the user to approve.
UnlessTrusted,
/// All commands run in a restricted sandbox initially.
/// If the command fails, the user is asked to approve execution without the sandbox.
OnFailure,
/// The model decides when to ask the user for approval.
OnRequest,
/// Never ask the user to approve commands. Commands that fail in the restricted sandbox will not be retried.
Never,
}
/// Reasoning effort for the underlying model
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema, AsRefStr)]
#[serde(rename_all = "kebab-case")]
@@ -77,13 +99,16 @@ pub enum ReasoningSummaryFormat {
Experimental,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS, JsonSchema)]
#[derive(Derivative, Clone, Serialize, Deserialize, TS, JsonSchema)]
#[derivative(Debug, PartialEq)]
pub struct Codex {
#[serde(default)]
pub append_prompt: AppendPrompt,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub sandbox: Option<SandboxMode>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub ask_for_approval: Option<AskForApproval>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub oss: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
@@ -103,10 +128,19 @@ pub struct Codex {
pub include_apply_patch_tool: Option<bool>,
#[serde(flatten)]
pub cmd: CmdOverrides,
#[serde(skip)]
#[ts(skip)]
#[derivative(Debug = "ignore", PartialEq = "ignore")]
approvals: Option<Arc<dyn ExecutorApprovalService>>,
}
#[async_trait]
impl StandardCodingAgentExecutor for Codex {
fn use_approvals(&mut self, approvals: Arc<dyn ExecutorApprovalService>) {
self.approvals = Some(approvals);
}
async fn spawn(&self, current_dir: &Path, prompt: &str) -> Result<SpawnedChild, ExecutorError> {
let command = self.build_command_builder().build_initial();
self.spawn(current_dir, prompt, command, None).await
@@ -151,11 +185,18 @@ impl Codex {
Some(SandboxMode::DangerFullAccess) => Some(CodexSandboxMode::DangerFullAccess),
};
let approval_policy = self.ask_for_approval.as_ref().map(|policy| match policy {
AskForApproval::UnlessTrusted => CodexAskForApproval::UnlessTrusted,
AskForApproval::OnFailure => CodexAskForApproval::OnFailure,
AskForApproval::OnRequest => CodexAskForApproval::OnRequest,
AskForApproval::Never => CodexAskForApproval::Never,
});
NewConversationParams {
model: self.model.clone(),
profile: self.profile.clone(),
cwd: Some(cwd.to_string_lossy().to_string()),
approval_policy: None,
approval_policy,
sandbox,
config: self.build_config_overrides(),
base_instructions: self.base_instructions.clone(),
@@ -234,6 +275,11 @@ impl Codex {
let params = self.build_new_conversation_params(current_dir);
let resume_session = resume_session.map(|s| s.to_string());
let auto_approve = matches!(
(&self.sandbox, &self.ask_for_approval),
(Some(SandboxMode::DangerFullAccess), None)
);
let approvals = self.approvals.clone();
tokio::spawn(async move {
let exit_signal_tx = ExitSignalSender::new(exit_signal_tx);
let log_writer = LogWriter::new(new_stdout);
@@ -245,6 +291,8 @@ impl Codex {
child_stdin,
log_writer.clone(),
exit_signal_tx.clone(),
approvals,
auto_approve,
)
.await
{
@@ -268,6 +316,7 @@ impl Codex {
})
}
#[allow(clippy::too_many_arguments)]
async fn launch_codex_app_server(
conversation_params: NewConversationParams,
resume_session: Option<String>,
@@ -276,8 +325,10 @@ impl Codex {
child_stdin: tokio::process::ChildStdin,
log_writer: LogWriter,
exit_signal_tx: ExitSignalSender,
approvals: Option<Arc<dyn ExecutorApprovalService>>,
auto_approve: bool,
) -> Result<(), ExecutorError> {
let client = AppServerClient::new(log_writer);
let client = AppServerClient::new(log_writer, approvals, auto_approve);
let rpc_peer =
JsonRpcPeer::spawn(child_stdin, child_stdout, client.clone(), exit_signal_tx);
client.connect(rpc_peer);
@@ -286,11 +337,11 @@ impl Codex {
None => {
let params = conversation_params;
let response = client.new_conversation(params).await?;
let conversation_id = response.conversation_id;
client.register_session(&conversation_id).await?;
client.add_conversation_listener(conversation_id).await?;
client
.add_conversation_listener(response.conversation_id)
.await?;
client
.send_user_message(response.conversation_id, combined_prompt)
.send_user_message(conversation_id, combined_prompt)
.await?;
}
Some(session_id) => {
@@ -306,11 +357,11 @@ impl Codex {
rollout_path.display(),
response
);
let conversation_id = response.conversation_id;
client.register_session(&conversation_id).await?;
client.add_conversation_listener(conversation_id).await?;
client
.add_conversation_listener(response.conversation_id)
.await?;
client
.send_user_message(response.conversation_id, combined_prompt)
.send_user_message(conversation_id, combined_prompt)
.await?;
}
}

View File

@@ -1,5 +1,6 @@
use std::{
borrow::Cow,
collections::VecDeque,
io,
sync::{Arc, OnceLock},
};
@@ -13,26 +14,43 @@ use codex_app_server_protocol::{
ResumeConversationParams, ResumeConversationResponse, SendUserMessageParams,
SendUserMessageResponse, ServerNotification, ServerRequest,
};
use codex_protocol::{ConversationId, protocol::ReviewDecision};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use serde_json::{self, Value};
use tokio::{
io::{AsyncWrite, AsyncWriteExt, BufWriter},
sync::Mutex,
};
use workspace_utils::approvals::ApprovalStatus;
use super::jsonrpc::{JsonRpcCallbacks, JsonRpcPeer};
use crate::executors::ExecutorError;
use crate::{
approvals::{ExecutorApprovalError, ExecutorApprovalService},
executors::{ExecutorError, codex::normalize_logs::Approval},
};
pub struct AppServerClient {
rpc: OnceLock<JsonRpcPeer>,
log_writer: LogWriter,
approvals: Option<Arc<dyn ExecutorApprovalService>>,
conversation_id: Mutex<Option<ConversationId>>,
pending_feedback: Mutex<VecDeque<String>>,
auto_approve: bool,
}
impl AppServerClient {
pub fn new(log_writer: LogWriter) -> Arc<Self> {
pub fn new(
log_writer: LogWriter,
approvals: Option<Arc<dyn ExecutorApprovalService>>,
auto_approve: bool,
) -> Arc<Self> {
Arc::new(Self {
rpc: OnceLock::new(),
log_writer,
approvals,
auto_approve,
conversation_id: Mutex::new(None),
pending_feedback: Mutex::new(VecDeque::new()),
})
}
@@ -113,6 +131,120 @@ impl AppServerClient {
self.send_request(request, "sendUserMessage").await
}
async fn handle_server_request(
&self,
peer: &JsonRpcPeer,
request: ServerRequest,
) -> Result<(), ExecutorError> {
match request {
ServerRequest::ApplyPatchApproval { request_id, params } => {
let input = serde_json::to_value(&params)
.map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?;
let status = match self
.request_tool_approval("edit", input, &params.call_id)
.await
{
Ok(status) => status,
Err(err) => {
tracing::error!("failed to request patch approval: {err}");
ApprovalStatus::Denied {
reason: Some("approval service error".to_string()),
}
}
};
self.log_writer
.log_raw(
&Approval::approval_response(
params.call_id,
"codex.apply_patch".to_string(),
status.clone(),
)
.raw(),
)
.await?;
let (decision, feedback) = self.review_decision(&status).await?;
let response = ApplyPatchApprovalResponse { decision };
send_server_response(peer, request_id, response).await?;
if let Some(message) = feedback {
tracing::debug!("queueing patch denial feedback: {message}");
self.enqueue_feedback(message).await;
}
Ok(())
}
ServerRequest::ExecCommandApproval { request_id, params } => {
let input = serde_json::to_value(&params)
.map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?;
let status = match self
.request_tool_approval("bash", input, &params.call_id)
.await
{
Ok(status) => status,
Err(err) => {
tracing::error!("failed to request command approval: {err}");
ApprovalStatus::Denied {
reason: Some("approval service error".to_string()),
}
}
};
self.log_writer
.log_raw(
&Approval::approval_response(
params.call_id,
"codex.exec_command".to_string(),
status.clone(),
)
.raw(),
)
.await?;
let (decision, feedback) = self.review_decision(&status).await?;
let response = ExecCommandApprovalResponse { decision };
send_server_response(peer, request_id, response).await?;
if let Some(message) = feedback {
tracing::debug!("queueing exec denial feedback: {message}");
self.enqueue_feedback(message).await;
}
Ok(())
}
}
}
async fn request_tool_approval(
&self,
tool_name: &str,
tool_input: Value,
tool_call_id: &str,
) -> Result<ApprovalStatus, ExecutorError> {
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
if self.auto_approve {
return Ok(ApprovalStatus::Approved);
}
Ok(self
.approvals
.as_ref()
.ok_or(ExecutorApprovalError::ServiceUnavailable)?
.request_tool_approval(tool_name, tool_input, tool_call_id)
.await?)
}
pub async fn register_session(
&self,
conversation_id: &ConversationId,
) -> Result<(), ExecutorError> {
{
let mut guard = self.conversation_id.lock().await;
guard.replace(*conversation_id);
}
if let Some(approvals) = self.approvals.as_ref() {
approvals
.register_session(&conversation_id.to_string())
.await
.map_err(|err| ExecutorError::Io(io::Error::other(err.to_string())))?;
}
self.flush_pending_feedback().await;
Ok(())
}
async fn send_message<M>(&self, message: &M) -> Result<(), ExecutorError>
where
M: Serialize + Sync,
@@ -131,6 +263,94 @@ impl AppServerClient {
fn next_request_id(&self) -> RequestId {
self.rpc().next_request_id()
}
async fn review_decision(
&self,
status: &ApprovalStatus,
) -> Result<(ReviewDecision, Option<String>), ExecutorError> {
if self.auto_approve {
return Ok((ReviewDecision::ApprovedForSession, None));
}
let outcome = match status {
ApprovalStatus::Approved => (ReviewDecision::Approved, None),
ApprovalStatus::Denied { reason } => {
let feedback = reason
.as_ref()
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| s.to_string());
if feedback.is_some() {
(ReviewDecision::Abort, feedback)
} else {
(ReviewDecision::Denied, None)
}
}
ApprovalStatus::TimedOut => (ReviewDecision::Denied, None),
ApprovalStatus::Pending => (ReviewDecision::Denied, None),
};
Ok(outcome)
}
async fn enqueue_feedback(&self, message: String) {
if message.trim().is_empty() {
return;
}
let mut guard = self.pending_feedback.lock().await;
guard.push_back(message);
}
async fn flush_pending_feedback(&self) {
let messages: Vec<String> = {
let mut guard = self.pending_feedback.lock().await;
guard.drain(..).collect()
};
if messages.is_empty() {
return;
}
let Some(conversation_id) = *self.conversation_id.lock().await else {
tracing::warn!(
"pending Codex feedback but conversation id unavailable; dropping {} messages",
messages.len()
);
return;
};
for message in messages {
let trimmed = message.trim();
if trimmed.is_empty() {
continue;
}
self.spawn_feedback_message(conversation_id, trimmed.to_string());
}
}
fn spawn_feedback_message(&self, conversation_id: ConversationId, feedback: String) {
let peer = self.rpc().clone();
let request = ClientRequest::SendUserMessage {
request_id: peer.next_request_id(),
params: SendUserMessageParams {
conversation_id,
items: vec![InputItem::Text {
text: format!("User feedback: {feedback}"),
}],
},
};
tokio::spawn(async move {
if let Err(err) = peer
.request::<SendUserMessageResponse, _>(
request_id(&request),
&request,
"sendUserMessage",
)
.await
{
tracing::error!("failed to send feedback follow-up message: {err}");
}
});
}
}
#[async_trait]
@@ -143,7 +363,7 @@ impl JsonRpcCallbacks for AppServerClient {
) -> Result<(), ExecutorError> {
self.log_writer.log_raw(raw).await?;
match ServerRequest::try_from(request.clone()) {
Ok(server_request) => handle_server_request(peer, server_request).await,
Ok(server_request) => self.handle_server_request(peer, server_request).await,
Err(err) => {
tracing::debug!("Unhandled server request `{}`: {err}", request.method);
let response = JSONRPCResponse {
@@ -200,6 +420,12 @@ impl JsonRpcCallbacks for AppServerClient {
return Ok(false);
}
if method.ends_with("turn_aborted") {
tracing::debug!("codex turn aborted; flushing feedback queue");
self.flush_pending_feedback().await;
return Ok(false);
}
let has_finished = method
.strip_prefix("codex/event/")
.is_some_and(|suffix| suffix == "task_complete");
@@ -213,27 +439,6 @@ impl JsonRpcCallbacks for AppServerClient {
}
}
// Aprovals
async fn handle_server_request(
peer: &JsonRpcPeer,
request: ServerRequest,
) -> Result<(), ExecutorError> {
match request {
ServerRequest::ApplyPatchApproval { request_id, .. } => {
let response = ApplyPatchApprovalResponse {
decision: codex_protocol::protocol::ReviewDecision::ApprovedForSession,
};
send_server_response(peer, request_id, response).await
}
ServerRequest::ExecCommandApproval { request_id, .. } => {
let response = ExecCommandApprovalResponse {
decision: codex_protocol::protocol::ReviewDecision::ApprovedForSession,
};
send_server_response(peer, request_id, response).await
}
}
}
async fn send_server_response<T>(
peer: &JsonRpcPeer,
request_id: RequestId,

View File

@@ -13,11 +13,12 @@ use codex_protocol::{
plan_tool::{StepStatus, UpdatePlanArgs},
protocol::{
AgentMessageDeltaEvent, AgentMessageEvent, AgentReasoningDeltaEvent, AgentReasoningEvent,
AgentReasoningSectionBreakEvent, BackgroundEventEvent, ErrorEvent, EventMsg,
ExecCommandBeginEvent, ExecCommandEndEvent, ExecCommandOutputDeltaEvent, ExecOutputStream,
FileChange as CodexProtoFileChange, McpInvocation, McpToolCallBeginEvent,
McpToolCallEndEvent, PatchApplyBeginEvent, PatchApplyEndEvent, StreamErrorEvent,
TokenUsageInfo, ViewImageToolCallEvent, WebSearchBeginEvent, WebSearchEndEvent,
AgentReasoningSectionBreakEvent, ApplyPatchApprovalRequestEvent, BackgroundEventEvent,
ErrorEvent, EventMsg, ExecApprovalRequestEvent, ExecCommandBeginEvent, ExecCommandEndEvent,
ExecCommandOutputDeltaEvent, ExecOutputStream, FileChange as CodexProtoFileChange,
McpInvocation, McpToolCallBeginEvent, McpToolCallEndEvent, PatchApplyBeginEvent,
PatchApplyEndEvent, StreamErrorEvent, TokenUsageInfo, ViewImageToolCallEvent,
WebSearchBeginEvent, WebSearchEndEvent,
},
};
use futures::StreamExt;
@@ -26,12 +27,14 @@ use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use workspace_utils::{
approvals::ApprovalStatus,
diff::{concatenate_diff_hunks, extract_unified_diff_hunks},
msg_store::MsgStore,
path::make_path_relative,
};
use crate::{
approvals::ToolCallMetadata,
executors::codex::session::SessionHandler,
logs::{
ActionType, CommandExitStatus, CommandRunResult, FileChange, NormalizedEntry,
@@ -45,6 +48,10 @@ trait ToNormalizedEntry {
fn to_normalized_entry(&self) -> NormalizedEntry;
}
trait ToNormalizedEntryOpt {
fn to_normalized_entry_opt(&self) -> Option<NormalizedEntry>;
}
#[derive(Debug, Deserialize)]
struct CodexNotificationParams {
#[serde(rename = "msg")]
@@ -66,10 +73,14 @@ struct CommandState {
formatted_output: Option<String>,
status: ToolStatus,
exit_code: Option<i32>,
awaiting_approval: bool,
call_id: String,
}
impl ToNormalizedEntry for CommandState {
fn to_normalized_entry(&self) -> NormalizedEntry {
let content = format!("`{}`", self.command);
NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -89,8 +100,11 @@ impl ToNormalizedEntry for CommandState {
},
status: self.status.clone(),
},
content: format!("`{}`", self.command),
metadata: None,
content,
metadata: serde_json::to_value(ToolCallMetadata {
tool_call_id: self.call_id.clone(),
})
.ok(),
}
}
}
@@ -165,10 +179,14 @@ struct PatchEntry {
path: String,
changes: Vec<FileChange>,
status: ToolStatus,
awaiting_approval: bool,
call_id: String,
}
impl ToNormalizedEntry for PatchEntry {
fn to_normalized_entry(&self) -> NormalizedEntry {
let content = self.path.clone();
NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -179,8 +197,11 @@ impl ToNormalizedEntry for PatchEntry {
},
status: self.status.clone(),
},
content: self.path.clone(),
metadata: None,
content,
metadata: serde_json::to_value(ToolCallMetadata {
tool_call_id: self.call_id.clone(),
})
.ok(),
}
}
}
@@ -385,6 +406,13 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
continue;
}
if let Ok(approval) = serde_json::from_str::<Approval>(&line) {
if let Some(entry) = approval.to_normalized_entry_opt() {
add_normalized_entry(&msg_store, &entry_index, entry);
}
continue;
}
if let Ok(response) = serde_json::from_str::<JSONRPCResponse>(&line) {
handle_jsonrpc_response(response, &msg_store, &entry_index);
continue;
@@ -466,6 +494,80 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
state.assistant = None;
state.thinking = None;
}
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
call_id,
command,
cwd: _,
reason,
}) => {
state.assistant = None;
state.thinking = None;
let command_text = if command.is_empty() {
reason
.filter(|r| !r.is_empty())
.unwrap_or_else(|| "command execution".to_string())
} else {
command.join(" ")
};
let command_state = state.commands.entry(call_id.clone()).or_default();
if command_state.command.is_empty() {
command_state.command = command_text;
}
command_state.awaiting_approval = true;
if let Some(index) = command_state.index {
replace_normalized_entry(
&msg_store,
index,
command_state.to_normalized_entry(),
);
} else {
let index = add_normalized_entry(
&msg_store,
&entry_index,
command_state.to_normalized_entry(),
);
command_state.index = Some(index);
}
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
call_id,
changes,
reason: _,
grant_root: _,
}) => {
state.assistant = None;
state.thinking = None;
let normalized = normalize_file_changes(&worktree_path_str, &changes);
let patch_state = state.patches.entry(call_id.clone()).or_default();
for entry in patch_state.entries.drain(..) {
if let Some(index) = entry.index {
msg_store.push_patch(ConversationPatch::remove(index));
}
}
for (path, file_changes) in normalized {
let mut entry = PatchEntry {
index: None,
path,
changes: file_changes,
status: ToolStatus::Created,
awaiting_approval: true,
call_id: call_id.clone(),
};
let index = add_normalized_entry(
&msg_store,
&entry_index,
entry.to_normalized_entry(),
);
entry.index = Some(index);
patch_state.entries.push(entry);
}
}
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
call_id, command, ..
}) => {
@@ -485,6 +587,8 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
formatted_output: None,
status: ToolStatus::Created,
exit_code: None,
awaiting_approval: false,
call_id: call_id.clone(),
},
);
let command_state = state.commands.get_mut(&call_id).unwrap();
@@ -532,6 +636,7 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
if let Some(mut command_state) = state.commands.remove(&call_id) {
command_state.formatted_output = Some(formatted_output);
command_state.exit_code = Some(exit_code);
command_state.awaiting_approval = false;
command_state.status = if exit_code == 0 {
ToolStatus::Success
} else {
@@ -664,23 +769,68 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
state.assistant = None;
state.thinking = None;
let normalized = normalize_file_changes(&worktree_path_str, &changes);
let mut patch_state = PatchState::default();
for (path, file_changes) in normalized {
patch_state.entries.push(PatchEntry {
index: None,
path,
changes: file_changes,
status: ToolStatus::Created,
});
let patch_entry = patch_state.entries.last_mut().unwrap();
let index = add_normalized_entry(
&msg_store,
&entry_index,
patch_entry.to_normalized_entry(),
);
patch_entry.index = Some(index);
if let Some(patch_state) = state.patches.get_mut(&call_id) {
let mut iter = normalized.into_iter();
for entry in &mut patch_state.entries {
if let Some((path, file_changes)) = iter.next() {
entry.path = path;
entry.changes = file_changes;
}
entry.status = ToolStatus::Created;
entry.awaiting_approval = false;
if let Some(index) = entry.index {
replace_normalized_entry(
&msg_store,
index,
entry.to_normalized_entry(),
);
} else {
let index = add_normalized_entry(
&msg_store,
&entry_index,
entry.to_normalized_entry(),
);
entry.index = Some(index);
}
}
for (path, file_changes) in iter {
let mut entry = PatchEntry {
index: None,
path,
changes: file_changes,
status: ToolStatus::Created,
awaiting_approval: false,
call_id: call_id.clone(),
};
let index = add_normalized_entry(
&msg_store,
&entry_index,
entry.to_normalized_entry(),
);
entry.index = Some(index);
patch_state.entries.push(entry);
}
} else {
let mut patch_state = PatchState::default();
for (path, file_changes) in normalized {
patch_state.entries.push(PatchEntry {
index: None,
path,
changes: file_changes,
status: ToolStatus::Created,
awaiting_approval: false,
call_id: call_id.clone(),
});
let patch_entry = patch_state.entries.last_mut().unwrap();
let index = add_normalized_entry(
&msg_store,
&entry_index,
patch_entry.to_normalized_entry(),
);
patch_entry.index = Some(index);
}
state.patches.insert(call_id, patch_state);
}
state.patches.insert(call_id, patch_state);
}
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
call_id,
@@ -826,9 +976,7 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
| EventMsg::ConversationPath(..)
| EventMsg::EnteredReviewMode(..)
| EventMsg::ExitedReviewMode(..)
| EventMsg::TaskComplete(..)
| EventMsg::ExecApprovalRequest(..)
| EventMsg::ApplyPatchApprovalRequest(..) => {}
| EventMsg::TaskComplete(..) => {}
}
}
});
@@ -937,3 +1085,73 @@ impl ToNormalizedEntry for Error {
}
}
}
#[derive(Serialize, Deserialize, Debug)]
pub enum Approval {
ApprovalResponse {
call_id: String,
tool_name: String,
approval_status: ApprovalStatus,
},
}
impl Approval {
pub fn approval_response(
call_id: String,
tool_name: String,
approval_status: ApprovalStatus,
) -> Self {
Self::ApprovalResponse {
call_id,
tool_name,
approval_status,
}
}
pub fn raw(&self) -> String {
serde_json::to_string(self).unwrap_or_default()
}
pub fn display_tool_name(&self) -> String {
let Self::ApprovalResponse { tool_name, .. } = self;
match tool_name.as_str() {
"codex.exec_command" => "Exec Command".to_string(),
"codex.apply_patch" => "Edit".to_string(),
other => other.to_string(),
}
}
}
impl ToNormalizedEntryOpt for Approval {
fn to_normalized_entry_opt(&self) -> Option<NormalizedEntry> {
let Self::ApprovalResponse {
call_id: _,
tool_name: _,
approval_status,
} = self;
let tool_name = self.display_tool_name();
match approval_status {
ApprovalStatus::Pending => None,
ApprovalStatus::Approved => None,
ApprovalStatus::Denied { reason } => Some(NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::UserFeedback {
denied_tool: tool_name.clone(),
},
content: reason
.clone()
.unwrap_or_else(|| "User denied this tool use request".to_string())
.trim()
.to_string(),
metadata: None,
}),
ApprovalStatus::TimedOut => Some(NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ErrorMessage,
content: format!("Approval timed out for tool {tool_name}"),
metadata: None,
}),
}
}
}

View File

@@ -13,6 +13,7 @@ use ts_rs::TS;
use workspace_utils::msg_store::MsgStore;
use crate::{
approvals::ExecutorApprovalService,
executors::{
amp::Amp, claude::ClaudeCode, codex::Codex, copilot::Copilot, cursor::Cursor,
gemini::Gemini, opencode::Opencode, qwen::QwenCode,
@@ -52,6 +53,8 @@ pub enum ExecutorError {
TomlSerialize(#[from] toml::ser::Error),
#[error(transparent)]
TomlDeserialize(#[from] toml::de::Error),
#[error(transparent)]
ExecutorApprovalError(#[from] crate::approvals::ExecutorApprovalError),
}
#[enum_dispatch]
@@ -138,6 +141,8 @@ impl CodingAgent {
#[async_trait]
#[enum_dispatch(CodingAgent)]
pub trait StandardCodingAgentExecutor {
fn use_approvals(&mut self, _approvals: Arc<dyn ExecutorApprovalService>) {}
async fn spawn(&self, current_dir: &Path, prompt: &str) -> Result<SpawnedChild, ExecutorError>;
async fn spawn_follow_up(
&self,

View File

@@ -1,4 +1,5 @@
pub mod actions;
pub mod approvals;
pub mod command;
pub mod executors;
pub mod logs;

View File

@@ -113,6 +113,14 @@ impl ConversationPatch {
from_value(json!([patch_entry])).unwrap()
}
pub fn remove(entry_index: usize) -> Patch {
from_value(json!([{
"op": PatchOperation::Remove,
"path": format!("/entries/{entry_index}"),
}]))
.unwrap()
}
}
/// Extract the entry index and `NormalizedEntry` from a JsonPatch if it contains one

View File

@@ -27,6 +27,8 @@ use db::{
use deployment::DeploymentError;
use executors::{
actions::{Executable, ExecutorAction},
approvals::{ExecutorApprovalService, NoopExecutorApprovalService},
executors::BaseCodingAgent,
logs::{
NormalizedEntryType,
utils::{
@@ -39,6 +41,7 @@ use futures::{FutureExt, StreamExt, TryStreamExt, stream::select};
use serde_json::json;
use services::services::{
analytics::AnalyticsContext,
approvals::{Approvals, executor_approvals::ExecutorApprovalBridge},
config::Config,
container::{ContainerError, ContainerRef, ContainerService},
diff_stream::{self, DiffStreamHandle},
@@ -67,6 +70,7 @@ pub struct LocalContainerService {
git: GitService,
image_service: ImageService,
analytics: Option<AnalyticsContext>,
approvals: Approvals,
}
impl LocalContainerService {
@@ -77,6 +81,7 @@ impl LocalContainerService {
git: GitService,
image_service: ImageService,
analytics: Option<AnalyticsContext>,
approvals: Approvals,
) -> Self {
let child_store = Arc::new(RwLock::new(HashMap::new()));
@@ -88,6 +93,7 @@ impl LocalContainerService {
git,
image_service,
analytics,
approvals,
}
}
@@ -801,8 +807,20 @@ impl ContainerService for LocalContainerService {
)))?;
let current_dir = PathBuf::from(container_ref);
let approvals_service: Arc<dyn ExecutorApprovalService> =
match executor_action.base_executor() {
Some(BaseCodingAgent::Codex) => ExecutorApprovalBridge::new(
self.approvals.clone(),
self.db.clone(),
execution_process.id,
),
_ => Arc::new(NoopExecutorApprovalService {}),
};
// Create the child and stream, add to execution tracker
let mut spawned = executor_action.spawn(&current_dir).await?;
let mut spawned = executor_action
.spawn(&current_dir, approvals_service)
.await?;
self.track_child_msgs_in_store(execution_process.id, &mut spawned.child)
.await;

View File

@@ -118,6 +118,7 @@ impl Deployment for LocalDeployment {
git.clone(),
image.clone(),
analytics_ctx,
approvals.clone(),
);
container.spawn_worktree_cleanup().await;

View File

@@ -78,6 +78,7 @@ fn generate_types_content() -> String {
executors::executors::amp::Amp::decl(),
executors::executors::codex::Codex::decl(),
executors::executors::codex::SandboxMode::decl(),
executors::executors::codex::AskForApproval::decl(),
executors::executors::codex::ReasoningEffort::decl(),
executors::executors::codex::ReasoningSummary::decl(),
executors::executors::codex::ReasoningSummaryFormat::decl(),

View File

@@ -1,3 +1,5 @@
pub mod executor_approvals;
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
use chrono::{DateTime, Utc};
@@ -7,10 +9,14 @@ use db::models::{
executor_session::ExecutorSession,
task::{Task, TaskStatus},
};
use executors::logs::{
NormalizedEntry, NormalizedEntryType, ToolStatus,
utils::patch::{ConversationPatch, extract_normalized_entry_from_patch},
use executors::{
approvals::ToolCallMetadata,
logs::{
NormalizedEntry, NormalizedEntryType, ToolStatus,
utils::patch::{ConversationPatch, extract_normalized_entry_from_patch},
},
};
use futures::future::{BoxFuture, FutureExt, Shared};
use sqlx::{Error as SqlxError, SqlitePool};
use thiserror::Error;
use tokio::sync::{RwLock, oneshot};
@@ -35,6 +41,8 @@ struct PendingApproval {
response_tx: oneshot::Sender<ApprovalStatus>,
}
type ApprovalWaiter = Shared<BoxFuture<'static, ApprovalStatus>>;
#[derive(Debug)]
pub struct ToolContext {
pub tool_name: String,
@@ -73,15 +81,25 @@ impl Approvals {
}
}
#[tracing::instrument(skip(self, request))]
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
async fn create_internal(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
let (tx, rx) = oneshot::channel();
let waiter: ApprovalWaiter = rx
.map(|result| result.unwrap_or(ApprovalStatus::TimedOut))
.boxed()
.shared();
let req_id = request.id.clone();
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
// Find the matching tool use entry by name and input
let matching_tool =
find_matching_tool_use(store.clone(), &request.tool_name, &request.tool_input);
let matching_tool = find_matching_tool_use(
store.clone(),
&request.tool_name,
&request.tool_input,
request.tool_call_id.as_deref(),
);
if let Some((idx, matching_tool)) = matching_tool {
let approval_entry = matching_tool
@@ -125,10 +143,23 @@ impl Approvals {
);
}
self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, rx);
self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, waiter.clone());
Ok((request, waiter))
}
#[tracing::instrument(skip(self, request))]
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
let (request, _) = self.create_internal(request).await?;
Ok(request)
}
pub async fn create_with_waiter(
&self,
request: ApprovalRequest,
) -> Result<(ApprovalRequest, ApprovalWaiter), ApprovalError> {
self.create_internal(request).await
}
pub async fn create_from_session(
&self,
pool: &SqlitePool,
@@ -145,15 +176,7 @@ impl Approvals {
};
// Move the task to InReview if it's still InProgress
if let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await
&& ctx.task.status == TaskStatus::InProgress
&& let Err(e) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await
{
tracing::warn!(
"Failed to update task status to InReview for approval request: {}",
e
);
}
ensure_task_in_review(pool, execution_process_id).await;
let request = ApprovalRequest::from_create(payload, execution_process_id);
self.create(request).await
@@ -248,12 +271,12 @@ impl Approvals {
.collect()
}
#[tracing::instrument(skip(self, id, timeout_at, rx))]
#[tracing::instrument(skip(self, id, timeout_at, waiter))]
fn spawn_timeout_watcher(
&self,
id: String,
timeout_at: chrono::DateTime<chrono::Utc>,
mut rx: oneshot::Receiver<ApprovalStatus>,
waiter: ApprovalWaiter,
) {
let pending = self.pending.clone();
let completed = self.completed.clone();
@@ -269,19 +292,18 @@ impl Approvals {
let status = tokio::select! {
biased;
r = &mut rx => match r {
Ok(status) => status,
Err(_canceled) => ApprovalStatus::TimedOut,
},
resolved = waiter.clone() => resolved,
_ = tokio::time::sleep_until(deadline) => ApprovalStatus::TimedOut,
};
let is_timeout = matches!(&status, ApprovalStatus::TimedOut);
completed.insert(id.clone(), status.clone());
let removed = pending.remove(&id);
if is_timeout && let Some((_, pending_approval)) = pending.remove(&id) {
if pending_approval.response_tx.send(status.clone()).is_err() {
tracing::debug!("approval '{}' timeout notification receiver dropped", id);
}
if is_timeout && let Some((_, pending_approval)) = removed {
let store = {
let map = msg_stores.read().await;
map.get(&pending_approval.execution_process_id).cloned()
@@ -318,8 +340,22 @@ impl Approvals {
}
}
pub(crate) async fn ensure_task_in_review(pool: &SqlitePool, execution_process_id: Uuid) {
if let Ok(ctx) = ExecutionProcess::load_context(pool, execution_process_id).await
&& ctx.task.status == TaskStatus::InProgress
&& let Err(e) = Task::update_status(pool, ctx.task.id, TaskStatus::InReview).await
{
tracing::warn!(
"Failed to update task status to InReview for approval request: {}",
e
);
}
}
/// Comparison strategy for matching tool use entries
enum ToolComparisonStrategy {
/// Compare by tool_call_id
ToolCallId(String),
/// Compare deserialized ClaudeToolData structures (for known tools)
Deserialized(executors::executors::claude::ClaudeToolData),
/// Compare raw JSON input fields (for Unknown tools like MCP)
@@ -332,31 +368,37 @@ fn find_matching_tool_use(
store: Arc<MsgStore>,
tool_name: &str,
tool_input: &serde_json::Value,
tool_call_id: Option<&str>,
) -> Option<(usize, NormalizedEntry)> {
use executors::executors::claude::ClaudeToolData;
let history = store.get_history();
// Determine comparison strategy based on tool type
let strategy = match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
"name": tool_name,
"input": tool_input
})) {
Ok(ClaudeToolData::Unknown { .. }) => {
// For Unknown tools (MCP, future tools), use raw JSON comparison
ToolComparisonStrategy::RawJson
}
Ok(data) => {
// For known tools, use deserialized comparison with proper alias handling
ToolComparisonStrategy::Deserialized(data)
}
Err(e) => {
tracing::warn!(
"Failed to deserialize tool_input for tool '{}': {}",
tool_name,
e
);
return None;
let strategy = if let Some(call_id) = tool_call_id {
// If tool_call_id is provided, use it for matching
ToolComparisonStrategy::ToolCallId(call_id.to_string())
} else {
match serde_json::from_value::<ClaudeToolData>(serde_json::json!({
"name": tool_name,
"input": tool_input
})) {
Ok(ClaudeToolData::Unknown { .. }) => {
// For Unknown tools (MCP, future tools), use raw JSON comparison
ToolComparisonStrategy::RawJson
}
Ok(data) => {
// For known tools, use deserialized comparison with proper alias handling
ToolComparisonStrategy::Deserialized(data)
}
Err(e) => {
tracing::warn!(
"Failed to deserialize tool_input for tool '{}': {}",
tool_name,
e
);
return None;
}
}
};
@@ -383,6 +425,18 @@ fn find_matching_tool_use(
// Apply comparison strategy
if let Some(metadata) = &entry.metadata {
let is_match = match &strategy {
ToolComparisonStrategy::ToolCallId(call_id) => {
// Match by tool_call_id in metadata
if let Ok(ToolCallMetadata {
tool_call_id: entry_call_id,
..
}) = serde_json::from_value::<ToolCallMetadata>(metadata.clone())
{
entry_call_id == *call_id
} else {
false
}
}
ToolComparisonStrategy::RawJson => {
// Compare raw JSON input for Unknown tools
if let Some(entry_input) = metadata.get("input") {
@@ -405,8 +459,13 @@ fn find_matching_tool_use(
if is_match {
let strategy_name = match strategy {
ToolComparisonStrategy::RawJson => "raw input comparison",
ToolComparisonStrategy::Deserialized(_) => "deserialized tool data",
ToolComparisonStrategy::ToolCallId(call_id) => {
format!("tool_call_id '{call_id}'")
}
ToolComparisonStrategy::RawJson => "raw input comparison".to_string(),
ToolComparisonStrategy::Deserialized(_) => {
"deserialized tool data".to_string()
}
};
tracing::debug!(
"Matched tool use entry at index {idx} for tool '{tool_name}' by {strategy_name}"
@@ -483,12 +542,12 @@ mod tests {
let bar_input = serde_json::json!({"file_path": "bar.rs"});
let baz_input = serde_json::json!({"file_path": "baz.rs"});
let (idx_foo, _) =
find_matching_tool_use(store.clone(), "Read", &foo_input).expect("Should match foo.rs");
let (idx_bar, _) =
find_matching_tool_use(store.clone(), "Read", &bar_input).expect("Should match bar.rs");
let (idx_baz, _) =
find_matching_tool_use(store.clone(), "Read", &baz_input).expect("Should match baz.rs");
let (idx_foo, _) = find_matching_tool_use(store.clone(), "Read", &foo_input, None)
.expect("Should match foo.rs");
let (idx_bar, _) = find_matching_tool_use(store.clone(), "Read", &bar_input, None)
.expect("Should match bar.rs");
let (idx_baz, _) = find_matching_tool_use(store.clone(), "Read", &baz_input, None)
.expect("Should match baz.rs");
assert_eq!(idx_foo, 0, "foo.rs should match first entry");
assert_eq!(idx_bar, 1, "bar.rs should match second entry");
@@ -510,21 +569,21 @@ mod tests {
let pending_input = serde_json::json!({"file_path": "pending.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &pending_input).is_none(),
find_matching_tool_use(store.clone(), "Read", &pending_input, None).is_none(),
"Should not match tools in PendingApproval state"
);
// Test 3: Wrong tool name returns None
let write_input = serde_json::json!({"file_path": "foo.rs", "content": "test"});
assert!(
find_matching_tool_use(store.clone(), "Write", &write_input).is_none(),
find_matching_tool_use(store.clone(), "Write", &write_input, None).is_none(),
"Should not match different tool names"
);
// Test 4: Wrong input parameters returns None
let wrong_input = serde_json::json!({"file_path": "nonexistent.rs"});
assert!(
find_matching_tool_use(store.clone(), "Read", &wrong_input).is_none(),
find_matching_tool_use(store.clone(), "Read", &wrong_input, None).is_none(),
"Should not match with different input parameters"
);
}

View File

@@ -0,0 +1,81 @@
use std::sync::Arc;
use async_trait::async_trait;
use db::{self, DBService};
use executors::approvals::{ExecutorApprovalError, ExecutorApprovalService};
use serde_json::Value;
use tokio::sync::RwLock;
use utils::approvals::{ApprovalRequest, ApprovalStatus, CreateApprovalRequest};
use uuid::Uuid;
use crate::services::approvals::Approvals;
pub struct ExecutorApprovalBridge {
approvals: Approvals,
db: DBService,
execution_process_id: Uuid,
session_id: RwLock<Option<String>>,
}
impl ExecutorApprovalBridge {
pub fn new(approvals: Approvals, db: DBService, execution_process_id: Uuid) -> Arc<Self> {
Arc::new(Self {
approvals,
db,
execution_process_id,
session_id: RwLock::new(None),
})
}
}
#[async_trait]
impl ExecutorApprovalService for ExecutorApprovalBridge {
async fn register_session(&self, session_id: &str) -> Result<(), ExecutorApprovalError> {
let mut guard = self.session_id.write().await;
guard.replace(session_id.to_string());
Ok(())
}
async fn request_tool_approval(
&self,
tool_name: &str,
tool_input: Value,
tool_call_id: &str,
) -> Result<ApprovalStatus, ExecutorApprovalError> {
let session_id = {
let guard = self.session_id.read().await;
guard
.clone()
.ok_or(ExecutorApprovalError::SessionNotRegistered)?
};
super::ensure_task_in_review(&self.db.pool, self.execution_process_id).await;
let request = ApprovalRequest::from_create(
CreateApprovalRequest {
tool_name: tool_name.to_string(),
tool_input,
session_id,
tool_call_id: Some(tool_call_id.to_string()),
},
self.execution_process_id,
);
let (_, waiter) = self
.approvals
.create_with_waiter(request)
.await
.map_err(ExecutorApprovalError::request_failed)?;
let status = waiter.clone().await;
if matches!(status, ApprovalStatus::Pending) {
return Err(ExecutorApprovalError::request_failed(
"approval finished in pending state",
));
}
Ok(status)
}
}

View File

@@ -12,6 +12,7 @@ pub struct ApprovalRequest {
pub tool_name: String,
pub tool_input: serde_json::Value,
pub session_id: String,
pub tool_call_id: Option<String>,
pub execution_process_id: Uuid,
pub created_at: DateTime<Utc>,
pub timeout_at: DateTime<Utc>,
@@ -25,6 +26,7 @@ impl ApprovalRequest {
tool_name: request.tool_name,
tool_input: request.tool_input,
session_id: request.session_id,
tool_call_id: request.tool_call_id,
execution_process_id,
created_at: now,
timeout_at: now + Duration::seconds(APPROVAL_TIMEOUT_SECONDS),
@@ -38,6 +40,7 @@ pub struct CreateApprovalRequest {
pub tool_name: String,
pub tool_input: serde_json::Value,
pub session_id: String,
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]