feat: manual approvals (#748)

* manual user approvals

* refactor implementation

* cleanup

* fix lint errors

* i18n

* remove isLastEntry frontend check

* address fe feedback

* always run claude plan with approvals

* add watchkill script back to plan mode

* update timeout

* tooltip hover

* use response type

* put back watchkill append hack
This commit is contained in:
Gabriel Gordon-Hall
2025-09-22 16:02:42 +01:00
committed by GitHub
parent eaff3dee9e
commit 798bcb80a3
51 changed files with 1808 additions and 198 deletions

View File

@@ -0,0 +1,62 @@
{
"db_name": "SQLite",
"query": "SELECT\n id as \"id!: Uuid\",\n task_attempt_id as \"task_attempt_id!: Uuid\",\n execution_process_id as \"execution_process_id!: Uuid\",\n session_id,\n prompt,\n summary,\n created_at as \"created_at!: DateTime<Utc>\",\n updated_at as \"updated_at!: DateTime<Utc>\"\n FROM executor_sessions\n WHERE session_id = ?",
"describe": {
"columns": [
{
"name": "id!: Uuid",
"ordinal": 0,
"type_info": "Blob"
},
{
"name": "task_attempt_id!: Uuid",
"ordinal": 1,
"type_info": "Blob"
},
{
"name": "execution_process_id!: Uuid",
"ordinal": 2,
"type_info": "Blob"
},
{
"name": "session_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "prompt",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "summary",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "created_at!: DateTime<Utc>",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "updated_at!: DateTime<Utc>",
"ordinal": 7,
"type_info": "Text"
}
],
"parameters": {
"Right": 1
},
"nullable": [
true,
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "87b21a85d12a8d5494464574a460b1b78024a9bb43d6246f161a073629e463ff"
}

View File

@@ -104,6 +104,29 @@ impl ExecutorSession {
.await
}
pub async fn find_by_session_id(
pool: &SqlitePool,
session_id: &str,
) -> Result<Option<Self>, sqlx::Error> {
sqlx::query_as!(
ExecutorSession,
r#"SELECT
id as "id!: Uuid",
task_attempt_id as "task_attempt_id!: Uuid",
execution_process_id as "execution_process_id!: Uuid",
session_id,
prompt,
summary,
created_at as "created_at!: DateTime<Utc>",
updated_at as "updated_at!: DateTime<Utc>"
FROM executor_sessions
WHERE session_id = ?"#,
session_id
)
.fetch_optional(pool)
.await
}
/// Create a new executor session
pub async fn create(
pool: &SqlitePool,

View File

@@ -18,6 +18,7 @@ use git2::Error as Git2Error;
use serde_json::Value;
use services::services::{
analytics::AnalyticsService,
approvals::Approvals,
auth::{AuthError, AuthService},
config::{Config, ConfigError},
container::{ContainerError, ContainerService},
@@ -102,6 +103,8 @@ pub trait Deployment: Clone + Send + Sync + 'static {
fn file_search_cache(&self) -> &Arc<FileSearchCache>;
fn approvals(&self) -> &Approvals;
async fn update_sentry_scope(&self) -> Result<(), DeploymentError> {
let user_id = self.user_id();
let config = self.config().read().await;

View File

@@ -4,7 +4,7 @@ version = "0.0.94"
edition = "2024"
[dependencies]
utils = { path = "../utils" }
workspace_utils = { path = "../utils", package = "utils" }
tokio = { workspace = true }
tokio-util = { version = "0.7", features = ["io"] }
bytes = "1.0"
@@ -15,7 +15,7 @@ toml = "0.8"
tracing-subscriber = { workspace = true }
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
ts-rs = { workspace = true, features = ["serde-json-impl"]}
ts-rs = { workspace = true }
schemars = { workspace = true }
dirs = "5.0"
xdg = "3.0"
@@ -41,3 +41,4 @@ strum_macros = "0.27.2"
convert_case = "0.6"
sqlx = "0.8.6"
axum = { workspace = true }
shlex = "1.3.0"

View File

@@ -10,6 +10,11 @@
"CLAUDE_CODE": {
"plan": true
}
},
"APPROVALS": {
"CLAUDE_CODE": {
"approvals": true
}
}
},
"AMP": {

View File

@@ -5,7 +5,7 @@ use command_group::{AsyncCommandGroup, AsyncGroupChild};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use ts_rs::TS;
use utils::shell::get_shell_command;
use workspace_utils::shell::get_shell_command;
use crate::{actions::Executable, executors::ExecutorError};

View File

@@ -6,7 +6,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, process::Command};
use ts_rs::TS;
use utils::{msg_store::MsgStore, shell::get_shell_command};
use workspace_utils::{msg_store::MsgStore, shell::get_shell_command};
use crate::{
command::{CmdOverrides, CommandBuilder, apply_overrides},

View File

@@ -1,17 +1,19 @@
use std::{path::Path, process::Stdio, sync::Arc};
use std::{os::unix::fs::PermissionsExt, path::Path, process::Stdio, sync::Arc};
use async_trait::async_trait;
use command_group::{AsyncCommandGroup, AsyncGroupChild};
use futures::StreamExt;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, process::Command};
use tokio::{io::AsyncWriteExt, process::Command, sync::OnceCell};
use ts_rs::TS;
use utils::{
use workspace_utils::{
approvals::APPROVAL_TIMEOUT_SECONDS,
diff::{concatenate_diff_hunks, create_unified_diff, create_unified_diff_hunk},
log_msg::LogMsg,
msg_store::MsgStore,
path::make_path_relative,
port_file::read_port_file,
shell::get_shell_command,
};
@@ -19,12 +21,22 @@ use crate::{
command::{CmdOverrides, CommandBuilder, apply_overrides},
executors::{AppendPrompt, ExecutorError, StandardCodingAgentExecutor},
logs::{
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem,
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem, ToolStatus,
stderr_processor::normalize_stderr_logs,
utils::{EntryIndexProvider, patch::ConversationPatch},
},
};
static BACKEND_PORT: OnceCell<u16> = OnceCell::const_new();
async fn get_backend_port() -> std::io::Result<u16> {
BACKEND_PORT
.get_or_try_init(|| async { read_port_file("vibe-kanban").await })
.await
.copied()
}
const CONFIRM_HOOK_SCRIPT: &str = include_str!("./hooks/confirm.py");
fn base_command(claude_code_router: bool) -> &'static str {
if claude_code_router {
"npx -y @musistudio/claude-code-router code"
@@ -42,6 +54,8 @@ pub struct ClaudeCode {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub plan: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub approvals: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dangerously_skip_permissions: Option<bool>,
@@ -50,7 +64,7 @@ pub struct ClaudeCode {
}
impl ClaudeCode {
fn build_command_builder(&self) -> CommandBuilder {
async fn build_command_builder(&self) -> CommandBuilder {
// If base_command_override is provided and claude_code_router is also set, log a warning
if self.cmd.base_command_override.is_some() && self.claude_code_router.is_some() {
tracing::warn!(
@@ -62,9 +76,35 @@ impl ClaudeCode {
CommandBuilder::new(base_command(self.claude_code_router.unwrap_or(false)))
.params(["-p"]);
if self.plan.unwrap_or(false) {
let plan = self.plan.unwrap_or(false);
let approvals = self.approvals.unwrap_or(false);
if plan && approvals {
tracing::warn!("Both plan and approvals are enabled. Plan will take precedence.");
}
if plan {
builder = builder.extend_params(["--permission-mode=plan"]);
}
if plan || approvals {
match settings_json(plan).await {
// TODO: Avoid quoting
Ok(settings) => match shlex::try_quote(&settings) {
Ok(quoted) => {
builder = builder.extend_params(["--settings", &quoted]);
}
Err(e) => {
tracing::error!("Failed to quote approvals JSON for --settings: {e}");
}
},
Err(e) => {
tracing::error!(
"Failed to generate approvals JSON. Not running approvals: {e}"
);
}
}
}
if self.dangerously_skip_permissions.unwrap_or(false) {
builder = builder.extend_params(["--dangerously-skip-permissions"]);
}
@@ -85,13 +125,16 @@ impl StandardCodingAgentExecutor for ClaudeCode {
prompt: &str,
) -> Result<AsyncGroupChild, ExecutorError> {
let (shell_cmd, shell_arg) = get_shell_command();
let command_builder = self.build_command_builder();
let base_command = command_builder.build_initial();
let claude_command = if self.plan.unwrap_or(false) {
create_watchkill_script(&base_command)
} else {
base_command
};
let command_builder = self.build_command_builder().await;
let mut base_command = command_builder.build_initial();
if self.plan.unwrap_or(false) {
base_command = create_watchkill_script(&base_command);
}
if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) {
write_python_hook(current_dir).await?
}
let combined_prompt = self.append_prompt.combine_prompt(prompt);
@@ -103,7 +146,7 @@ impl StandardCodingAgentExecutor for ClaudeCode {
.stderr(Stdio::piped())
.current_dir(current_dir)
.arg(shell_arg)
.arg(&claude_command);
.arg(&base_command);
let mut child = command.group_spawn()?;
@@ -123,15 +166,18 @@ impl StandardCodingAgentExecutor for ClaudeCode {
session_id: &str,
) -> Result<AsyncGroupChild, ExecutorError> {
let (shell_cmd, shell_arg) = get_shell_command();
let command_builder = self.build_command_builder();
let command_builder = self.build_command_builder().await;
// Build follow-up command with --resume {session_id}
let base_command =
let mut base_command =
command_builder.build_follow_up(&["--resume".to_string(), session_id.to_string()]);
let claude_command = if self.plan.unwrap_or(false) {
create_watchkill_script(&base_command)
} else {
base_command
};
if self.plan.unwrap_or(false) {
base_command = create_watchkill_script(&base_command);
}
if self.approvals.unwrap_or(false) || self.plan.unwrap_or(false) {
write_python_hook(current_dir).await?
}
let combined_prompt = self.append_prompt.combine_prompt(prompt);
@@ -143,7 +189,7 @@ impl StandardCodingAgentExecutor for ClaudeCode {
.stderr(Stdio::piped())
.current_dir(current_dir)
.arg(shell_arg)
.arg(&claude_command);
.arg(&base_command);
let mut child = command.group_spawn()?;
@@ -177,14 +223,80 @@ impl StandardCodingAgentExecutor for ClaudeCode {
}
}
async fn write_python_hook(current_dir: &Path) -> Result<(), ExecutorError> {
let hooks_dir = current_dir.join(".claude").join("hooks");
tokio::fs::create_dir_all(&hooks_dir).await?;
let hook_path = hooks_dir.join("confirm.py");
if tokio::fs::try_exists(&hook_path).await? {
return Ok(());
}
let mut file = tokio::fs::File::create(&hook_path).await?;
file.write_all(CONFIRM_HOOK_SCRIPT.as_bytes()).await?;
file.flush().await?;
// TODO: Handle Windows permissioning
#[cfg(unix)]
{
let perm = std::fs::Permissions::from_mode(0o755);
tokio::fs::set_permissions(&hook_path, perm).await?;
}
// ignore the confirm.py script
let gitignore_path = hooks_dir.join(".gitignore");
if !tokio::fs::try_exists(&gitignore_path).await? {
let mut gitignore_file = tokio::fs::File::create(&gitignore_path).await?;
gitignore_file
.write_all(b"confirm.py\n.gitignore\n")
.await?;
gitignore_file.flush().await?;
}
Ok(())
}
// Configure settings json
async fn settings_json(plan: bool) -> Result<String, std::io::Error> {
let backend_port = get_backend_port().await?;
let backend_timeout = APPROVAL_TIMEOUT_SECONDS + 5; // add buffer
let matcher = if plan {
"^ExitPlanMode$"
} else {
"^(?!(Glob|Grep|NotebookRead|Read|Task|TodoWrite)$).*"
};
Ok(serde_json::json!({
"hooks": {
"PreToolUse": [
{
"matcher": matcher,
"hooks": [
{
"type": "command",
"command": format!("$CLAUDE_PROJECT_DIR/.claude/hooks/confirm.py --timeout-seconds {backend_timeout} --poll-interval 5 --backend-port {backend_port}"),
"timeout": backend_timeout + 10
}
]
}
]
}
})
.to_string())
}
fn create_watchkill_script(command: &str) -> String {
let claude_plan_stop_indicator = concat!("Exit ", "plan mode?");
// Hack: we concatenate so that Claude doesn't trigger the watchkill when reading this file
// during development, since it contains the stop phrase
let claude_plan_stop_indicator = concat!("Approval ", "request timed out");
let cmd = shlex::try_quote(command).unwrap().to_string();
format!(
r#"#!/usr/bin/env bash
set -euo pipefail
word="{claude_plan_stop_indicator}"
command="{command}"
exit_code=0
while IFS= read -r line; do
@@ -192,7 +304,7 @@ while IFS= read -r line; do
if [[ $line == *"$word"* ]]; then
exit 0
fi
done < <($command <&0 2>&1)
done < <(bash -lc {cmd} <&0 2>&1)
exit_code=${{PIPESTATUS[0]}}
exit "$exit_code"
@@ -325,6 +437,7 @@ impl ClaudeLogProcessor {
entry_type: NormalizedEntryType::ToolUse {
tool_name: tool_name.clone(),
action_type,
status: ToolStatus::Created,
},
content: content_text.clone(),
metadata: Some(
@@ -457,6 +570,12 @@ impl ClaudeLogProcessor {
})
};
let status = if is_error.unwrap_or(false) {
ToolStatus::Failed
} else {
ToolStatus::Success
};
let entry = NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -465,6 +584,7 @@ impl ClaudeLogProcessor {
command: info.content.clone(),
result,
},
status,
},
content: info.content.clone(),
metadata: None,
@@ -520,6 +640,12 @@ impl ClaudeLogProcessor {
tool_name.clone()
};
let status = if is_error.unwrap_or(false) {
ToolStatus::Failed
} else {
ToolStatus::Success
};
let entry = NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -534,6 +660,7 @@ impl ClaudeLogProcessor {
},
),
},
status,
},
content: info.content.clone(),
metadata: None,
@@ -727,6 +854,7 @@ impl ClaudeLogProcessor {
entry_type: NormalizedEntryType::ToolUse {
tool_name: tool_name.to_string(),
action_type,
status: ToolStatus::Created,
},
content,
metadata: Some(
@@ -834,6 +962,7 @@ impl ClaudeLogProcessor {
entry_type: NormalizedEntryType::ToolUse {
tool_name: name.to_string(),
action_type,
status: ToolStatus::Created,
},
content,
metadata: Some(
@@ -1549,11 +1678,12 @@ mod tests {
async fn test_streaming_patch_generation() {
use std::sync::Arc;
use utils::msg_store::MsgStore;
use workspace_utils::msg_store::MsgStore;
let executor = ClaudeCode {
claude_code_router: Some(false),
plan: None,
approvals: None,
model: None,
append_prompt: AppendPrompt::default(),
dangerously_skip_permissions: None,
@@ -1582,7 +1712,7 @@ mod tests {
let history = msg_store.get_history();
let patch_count = history
.iter()
.filter(|msg| matches!(msg, utils::log_msg::LogMsg::JsonPatch(_)))
.filter(|msg| matches!(msg, workspace_utils::log_msg::LogMsg::JsonPatch(_)))
.count();
assert!(
patch_count > 0,

View File

@@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
use strum_macros::AsRefStr;
use tokio::{io::AsyncWriteExt, process::Command};
use ts_rs::TS;
use utils::{
use workspace_utils::{
diff::{concatenate_diff_hunks, extract_unified_diff_hunks},
msg_store::MsgStore,
path::make_path_relative,
@@ -27,7 +27,7 @@ use crate::{
AppendPrompt, ExecutorError, StandardCodingAgentExecutor, codex::session::SessionHandler,
},
logs::{
ActionType, FileChange, NormalizedEntry, NormalizedEntryType,
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, ToolStatus,
utils::{EntryIndexProvider, patch::ConversationPatch},
},
};
@@ -245,6 +245,7 @@ impl StandardCodingAgentExecutor for Codex {
command: command_str.clone(),
result: None,
},
status: ToolStatus::Created,
},
content: format!("`{command_str}`"),
metadata: None,
@@ -317,6 +318,24 @@ impl StandardCodingAgentExecutor for Codex {
crate::logs::CommandExitStatus::ExitCode { code: *code }
})
};
let status = if let Some(s) = success {
if *s {
ToolStatus::Success
} else {
ToolStatus::Failed
}
} else if let Some(code) = exit_code {
if *code == 0 {
ToolStatus::Success
} else {
ToolStatus::Failed
}
} else {
// Default to failed status
ToolStatus::Failed
};
let entry = NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -328,6 +347,7 @@ impl StandardCodingAgentExecutor for Codex {
output,
}),
},
status,
},
content: prev_content,
metadata: None,
@@ -351,6 +371,7 @@ impl StandardCodingAgentExecutor for Codex {
arguments: invocation.arguments.clone(),
result: None,
},
status: ToolStatus::Created,
},
content: content_str.clone(),
metadata: None,
@@ -386,6 +407,7 @@ impl StandardCodingAgentExecutor for Codex {
value: result.clone(),
}),
},
status: ToolStatus::Success,
},
content: prev_content,
metadata: None,
@@ -710,6 +732,7 @@ impl CodexJson {
path: relative_path.clone(),
changes,
},
status: ToolStatus::Success,
},
content: relative_path,
metadata: None,
@@ -1109,6 +1132,7 @@ invalid json line here
if let NormalizedEntryType::ToolUse {
tool_name,
action_type,
status: _,
} = &entries[0].entry_type
{
assert_eq!(tool_name, "edit");

View File

@@ -2,7 +2,7 @@ use std::{path::PathBuf, sync::Arc};
use futures::StreamExt;
use regex::Regex;
use utils::msg_store::MsgStore;
use workspace_utils::msg_store::MsgStore;
/// Handles session management for Codex
pub struct SessionHandler;

View File

@@ -8,7 +8,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, process::Command};
use ts_rs::TS;
use utils::{
use workspace_utils::{
diff::{
concatenate_diff_hunks, create_unified_diff, create_unified_diff_hunk,
extract_unified_diff_hunks,
@@ -22,7 +22,7 @@ use crate::{
command::{CmdOverrides, CommandBuilder, apply_overrides},
executors::{AppendPrompt, ExecutorError, StandardCodingAgentExecutor},
logs::{
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem,
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem, ToolStatus,
plain_text_processor::PlainTextLogProcessor,
utils::{ConversationPatch, EntryIndexProvider},
},
@@ -255,6 +255,7 @@ impl StandardCodingAgentExecutor for Cursor {
entry_type: NormalizedEntryType::ToolUse {
tool_name,
action_type,
status: ToolStatus::Created,
},
content,
metadata: None,
@@ -358,6 +359,7 @@ impl StandardCodingAgentExecutor for Cursor {
}),
};
}
let entry = NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
@@ -374,6 +376,7 @@ impl StandardCodingAgentExecutor for Cursor {
_ => tool_call.get_name().to_string(),
},
action_type: new_action,
status: ToolStatus::Success,
},
content: content_str,
metadata: None,
@@ -1063,7 +1066,7 @@ Tests
mod tests {
use std::sync::Arc;
use utils::msg_store::MsgStore;
use workspace_utils::msg_store::MsgStore;
use super::*;
@@ -1103,7 +1106,7 @@ mod tests {
let history = msg_store.get_history();
let patch_count = history
.iter()
.filter(|m| matches!(m, utils::log_msg::LogMsg::JsonPatch(_)))
.filter(|m| matches!(m, workspace_utils::log_msg::LogMsg::JsonPatch(_)))
.count();
assert!(
patch_count >= 2,

View File

@@ -15,7 +15,7 @@ use tokio::{
process::Command,
};
use ts_rs::TS;
use utils::{msg_store::MsgStore, shell::get_shell_command};
use workspace_utils::{msg_store::MsgStore, shell::get_shell_command};
use crate::{
command::{CmdOverrides, CommandBuilder, apply_overrides},
@@ -381,7 +381,7 @@ You are continuing work on the above task. The execution history shows the previ
fn get_legacy_sessions_base_dir() -> PathBuf {
// Previous location was under the temp-based vibe-kanban dir
utils::path::get_vibe_kanban_temp_dir().join("gemini_sessions")
workspace_utils::path::get_vibe_kanban_temp_dir().join("gemini_sessions")
}
async fn get_session_file_path(current_dir: &Path) -> PathBuf {

View File

@@ -0,0 +1,165 @@
#!/usr/bin/env python3
import argparse
import json
import sys
import time
import urllib.error
import urllib.request
from typing import Optional
def json_error(reason: Optional[str]) -> None:
"""Emit a deny PreToolUse JSON to stdout and exit(0)."""
payload = {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": reason,
}
}
print(json.dumps(payload, ensure_ascii=False))
sys.exit(0)
def json_success() -> None:
payload = {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "allow",
},
"suppressOutput": True,
}
print(json.dumps(payload, ensure_ascii=False))
sys.exit(0)
def http_post_json(url: str, body: dict) -> dict:
data = json.dumps(body).encode("utf-8")
req = urllib.request.Request(
url, data=data, headers={"Content-Type": "application/json"}, method="POST"
)
try:
with urllib.request.urlopen(req, timeout=10) as resp:
return json.loads(resp.read().decode("utf-8") or "{}")
except (
urllib.error.HTTPError,
urllib.error.URLError,
TimeoutError,
json.JSONDecodeError,
) as e:
json_error(
f"Failed to create approval request. Backend may be unavailable. ({e})"
)
raise # unreachable
def http_get_json(url: str) -> dict:
req = urllib.request.Request(url, method="GET")
try:
with urllib.request.urlopen(req, timeout=10) as resp:
return json.loads(resp.read().decode("utf-8") or "{}")
except (
urllib.error.HTTPError,
urllib.error.URLError,
TimeoutError,
json.JSONDecodeError,
) as e:
json_error(f"Lost connection to approval backend: {e}")
raise
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="PreToolUse approval gate. All parameters are passed via CLI."
)
parser.add_argument(
"-t",
"--timeout-seconds",
type=int,
required=True,
help="Maximum time to wait for approval before timing out (seconds).",
)
parser.add_argument(
"-p",
"--poll-interval",
type=int,
required=True,
help="Seconds between polling the backend for status.",
)
parser.add_argument(
"-b",
"--backend-port",
type=int,
required=True,
help="Port of the approval backend running on 127.0.0.1.",
)
args = parser.parse_args()
if args.timeout_seconds <= 0:
parser.error("--timeout-seconds must be a positive integer")
if args.poll_interval <= 0:
parser.error("--poll-interval must be a positive integer")
if args.poll_interval > args.timeout_seconds:
parser.error("--poll-interval cannot be greater than --timeout-seconds")
return args
def main():
args = parse_args()
port = args.backend_port
url = f"http://127.0.0.1:{port}"
create_endpoint = f"{url}/api/approvals/create"
try:
raw_payload = sys.stdin.read()
incoming = json.loads(raw_payload or "{}")
except json.JSONDecodeError:
json_error("Invalid JSON payload on stdin")
tool_name = incoming.get("tool_name")
tool_input = incoming.get("tool_input")
session_id = incoming.get("session_id", "unknown")
create_payload = {
"tool_name": tool_name,
"tool_input": tool_input,
"session_id": session_id,
}
response = http_post_json(create_endpoint, create_payload)
approval_id = response.get("id")
if not approval_id:
json_error("Invalid response from approval backend")
print(
f"Approval request created: {approval_id}. Waiting for user response...",
file=sys.stderr,
)
elapsed = 0
while elapsed < args.timeout_seconds:
result = http_get_json(f"{url}/api/approvals/{approval_id}/status")
status = result.get("status")
if status == "approved":
json_success()
elif status == "denied":
reason = result.get("reason")
json_error(reason)
elif status == "timed_out":
json_error(
f"Approval request timed out after {args.timeout_seconds} seconds"
)
elif status == "pending":
time.sleep(args.poll_interval)
elapsed += args.poll_interval
else:
json_error(f"Unknown approval status: {status}")
json_error(f"Approval request timed out after {args.timeout_seconds} seconds")
if __name__ == "__main__":
main()

View File

@@ -10,7 +10,7 @@ use sqlx::Type;
use strum_macros::{Display, EnumDiscriminants, EnumString, VariantNames};
use thiserror::Error;
use ts_rs::TS;
use utils::msg_store::MsgStore;
use workspace_utils::msg_store::MsgStore;
use crate::{
executors::{

View File

@@ -16,7 +16,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, process::Command};
use ts_rs::TS;
use utils::{msg_store::MsgStore, path::make_path_relative, shell::get_shell_command};
use workspace_utils::{msg_store::MsgStore, path::make_path_relative, shell::get_shell_command};
use crate::{
command::{CmdOverrides, CommandBuilder, apply_overrides},
@@ -25,7 +25,7 @@ use crate::{
opencode::share_bridge::Bridge as ShareBridge,
},
logs::{
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem,
ActionType, FileChange, NormalizedEntry, NormalizedEntryType, TodoItem, ToolStatus,
utils::EntryIndexProvider,
},
stdout_dup,
@@ -796,6 +796,7 @@ impl Opencode {
entry_type: NormalizedEntryType::ToolUse {
tool_name: tool.clone(),
action_type: resolved_action_type,
status: ToolStatus::Success,
},
content: content_text,
metadata: None,

View File

@@ -6,7 +6,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, process::Command};
use ts_rs::TS;
use utils::{msg_store::MsgStore, shell::get_shell_command};
use workspace_utils::{msg_store::MsgStore, shell::get_shell_command};
use crate::{
command::{CmdOverrides, CommandBuilder, apply_overrides},

View File

@@ -1,5 +1,7 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use workspace_utils::approvals::ApprovalStatus;
pub mod plain_text_processor;
pub mod stderr_processor;
@@ -45,6 +47,7 @@ pub struct NormalizedConversation {
pub summary: Option<String>,
}
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum NormalizedEntryType {
@@ -53,6 +56,7 @@ pub enum NormalizedEntryType {
ToolUse {
tool_name: String,
action_type: ActionType,
status: ToolStatus,
},
SystemMessage,
ErrorMessage,
@@ -69,6 +73,59 @@ pub struct NormalizedEntry {
pub metadata: Option<serde_json::Value>,
}
impl NormalizedEntry {
pub fn with_tool_status(&self, status: ToolStatus) -> Option<Self> {
if let NormalizedEntryType::ToolUse {
tool_name,
action_type,
..
} = &self.entry_type
{
Some(Self {
entry_type: NormalizedEntryType::ToolUse {
tool_name: tool_name.clone(),
action_type: action_type.clone(),
status,
},
..self.clone()
})
} else {
None
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolStatus {
Created,
Success,
Failed,
Denied {
reason: Option<String>,
},
PendingApproval {
approval_id: String,
requested_at: DateTime<Utc>,
timeout_at: DateTime<Utc>,
},
TimedOut,
}
impl ToolStatus {
pub fn from_approval_status(status: &ApprovalStatus) -> Option<Self> {
match status {
ApprovalStatus::Approved => Some(ToolStatus::Created),
ApprovalStatus::Denied { reason } => Some(ToolStatus::Denied {
reason: reason.clone(),
}),
ApprovalStatus::TimedOut => Some(ToolStatus::TimedOut),
ApprovalStatus::Pending => None, // this should not happen
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export)]
pub struct TodoItem {

View File

@@ -362,7 +362,7 @@ impl PlainTextLogProcessor {
#[cfg(test)]
mod tests {
use super::*;
use crate::logs::NormalizedEntryType;
use crate::logs::{NormalizedEntryType, ToolStatus};
#[test]
fn test_plain_buffer_flush() {
@@ -432,6 +432,7 @@ mod tests {
action_type: super::super::ActionType::Other {
description: tool_name.to_string(),
},
status: ToolStatus::Success,
},
content,
metadata: None,

View File

@@ -11,7 +11,7 @@
use std::{sync::Arc, time::Duration};
use futures::StreamExt;
use utils::msg_store::MsgStore;
use workspace_utils::msg_store::MsgStore;
use super::{NormalizedEntry, NormalizedEntryType, plain_text_processor::PlainTextLogProcessor};
use crate::logs::utils::EntryIndexProvider;

View File

@@ -6,7 +6,7 @@ use std::sync::{
};
use json_patch::PatchOperation;
use utils::{log_msg::LogMsg, msg_store::MsgStore};
use workspace_utils::{log_msg::LogMsg, msg_store::MsgStore};
/// Thread-safe provider for monotonically increasing entry indexes
#[derive(Debug, Clone)]

View File

@@ -1,8 +1,8 @@
use json_patch::Patch;
use serde::{Deserialize, Serialize};
use serde_json::{from_value, json};
use serde_json::{from_value, json, to_value};
use ts_rs::TS;
use utils::diff::Diff;
use workspace_utils::diff::Diff;
use crate::logs::NormalizedEntry;
@@ -114,3 +114,20 @@ impl ConversationPatch {
from_value(json!([patch_entry])).unwrap()
}
}
/// Extract the entry index and `NormalizedEntry` from a JsonPatch if it contains one
pub fn extract_normalized_entry_from_patch(patch: &Patch) -> Option<(usize, NormalizedEntry)> {
let value = to_value(patch).ok()?;
let ops = value.as_array()?;
ops.iter().rev().find_map(|op| {
let path = op.get("path")?.as_str()?;
let entry_index = path.strip_prefix("/entries/")?.parse::<usize>().ok()?;
let value = op.get("value")?;
(value.get("type")?.as_str()? == "NORMALIZED_ENTRY")
.then(|| value.get("content"))
.flatten()
.and_then(|c| from_value::<NormalizedEntry>(c.clone()).ok())
.map(|entry| (entry_index, entry))
})
}

View File

@@ -204,7 +204,7 @@ impl ExecutorConfigs {
/// Load executor profiles from file or defaults
pub fn load() -> Self {
let profiles_path = utils::assets::profiles_path();
let profiles_path = workspace_utils::assets::profiles_path();
// Load defaults first
let mut defaults = Self::from_defaults();
@@ -238,7 +238,7 @@ impl ExecutorConfigs {
/// Save user profile overrides to file (only saves what differs from defaults)
pub fn save_overrides(&self) -> Result<(), ProfileError> {
let profiles_path = utils::assets::profiles_path();
let profiles_path = workspace_utils::assets::profiles_path();
let mut defaults = Self::from_defaults();
defaults.canonicalise();
@@ -425,3 +425,10 @@ impl ExecutorConfigs {
Err(ProfileError::NoAvailableExecutorProfile)
}
}
pub fn to_default_variant(id: &ExecutorProfileId) -> ExecutorProfileId {
ExecutorProfileId {
executor: id.executor,
variant: None,
}
}

View File

@@ -33,8 +33,11 @@ use deployment::DeploymentError;
use executors::{
actions::{Executable, ExecutorAction},
logs::{
NormalizedEntry, NormalizedEntryType,
utils::{ConversationPatch, patch::escape_json_pointer_segment},
NormalizedEntryType,
utils::{
ConversationPatch,
patch::{escape_json_pointer_segment, extract_normalized_entry_from_patch},
},
},
};
use futures::{StreamExt, TryStreamExt, stream::select};
@@ -1298,7 +1301,7 @@ impl LocalContainerService {
for msg in history.iter().rev() {
if let LogMsg::JsonPatch(patch) = msg {
// Try to extract a NormalizedEntry from the patch
if let Some(entry) = self.extract_normalized_entry_from_patch(patch)
if let Some((_, entry)) = extract_normalized_entry_from_patch(patch)
&& matches!(entry.entry_type, NormalizedEntryType::AssistantMessage)
{
let content = entry.content.trim();
@@ -1317,32 +1320,6 @@ impl LocalContainerService {
None
}
/// Extract a NormalizedEntry from a JsonPatch if it contains one
fn extract_normalized_entry_from_patch(
&self,
patch: &json_patch::Patch,
) -> Option<NormalizedEntry> {
// Convert the patch to JSON to examine its structure
if let Ok(patch_json) = serde_json::to_value(patch)
&& let Some(operations) = patch_json.as_array()
{
for operation in operations {
if let Some(value) = operation.get("value") {
// Try to extract a NormalizedEntry from the value
if let Some(patch_type) = value.get("type").and_then(|t| t.as_str())
&& patch_type == "NORMALIZED_ENTRY"
&& let Some(content) = value.get("content")
&& let Ok(entry) =
serde_json::from_value::<NormalizedEntry>(content.clone())
{
return Some(entry);
}
}
}
}
None
}
/// Update the executor session summary with the final assistant message
async fn update_executor_session_summary(&self, exec_id: &Uuid) -> Result<(), anyhow::Error> {
// Check if there's an executor session for this execution process

View File

@@ -6,6 +6,7 @@ use deployment::{Deployment, DeploymentError};
use executors::profile::ExecutorConfigs;
use services::services::{
analytics::{AnalyticsConfig, AnalyticsContext, AnalyticsService, generate_user_id},
approvals::Approvals,
auth::AuthService,
config::{Config, load_config_from_file, save_config_to_file},
container::ContainerService,
@@ -40,6 +41,7 @@ pub struct LocalDeployment {
filesystem: FilesystemService,
events: EventService,
file_search_cache: Arc<FileSearchCache>,
approvals: Approvals,
}
#[async_trait]
@@ -103,6 +105,8 @@ impl Deployment for LocalDeployment {
});
}
let approvals = Approvals::new(db.pool.clone(), msg_stores.clone());
// We need to make analytics accessible to the ContainerService
// TODO: Handle this more gracefully
let analytics_ctx = analytics.as_ref().map(|s| AnalyticsContext {
@@ -136,6 +140,7 @@ impl Deployment for LocalDeployment {
filesystem,
events,
file_search_cache,
approvals,
})
}
@@ -193,4 +198,8 @@ impl Deployment for LocalDeployment {
fn file_search_cache(&self) -> &Arc<FileSearchCache> {
&self.file_search_cache
}
fn approvals(&self) -> &Approvals {
&self.approvals
}
}

View File

@@ -25,7 +25,7 @@ tracing-subscriber = { workspace = true }
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid"] }
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
ts-rs = { workspace = true, features = ["serde-json-impl"]}
ts-rs = { workspace = true }
async-trait = "0.1"
command-group = { version = "5.0", features = ["with-tokio"] }
nix = { version = "0.29", features = ["signal", "process"] }

View File

@@ -108,7 +108,11 @@ fn generate_types_content() -> String {
executors::logs::TodoItem::decl(),
executors::logs::ToolResult::decl(),
executors::logs::ToolResultValueType::decl(),
executors::logs::ToolStatus::decl(),
executors::logs::utils::patch::PatchType::decl(),
utils::approvals::ApprovalStatus::decl(),
utils::approvals::CreateApprovalRequest::decl(),
utils::approvals::ApprovalResponse::decl(),
serde_json::Value::decl(),
];

View File

@@ -81,9 +81,7 @@ async fn main() -> Result<(), VibeKanbanError> {
let actual_port = listener.local_addr()?.port(); // get → 53427 (example)
// Write port file for discovery if prod, warn on fail
if !cfg!(debug_assertions)
&& let Err(e) = write_port_file(actual_port).await
{
if let Err(e) = write_port_file(actual_port).await {
tracing::warn!("Failed to write port file: {}", e);
}

View File

@@ -0,0 +1,92 @@
use axum::{
Json, Router,
extract::{Path, State},
http::StatusCode,
routing::{get, post},
};
use db::models::execution_process::ExecutionProcess;
use deployment::Deployment;
use services::services::container::ContainerService;
use utils::approvals::{
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus, CreateApprovalRequest,
EXIT_PLAN_MODE_TOOL_NAME,
};
use crate::DeploymentImpl;
pub async fn create_approval(
State(deployment): State<DeploymentImpl>,
Json(request): Json<CreateApprovalRequest>,
) -> Result<Json<ApprovalRequest>, StatusCode> {
let service = deployment.approvals();
let approval_request = ApprovalRequest::from_create(request);
match service.create(approval_request).await {
Ok(approval) => Ok(Json(approval)),
Err(e) => {
tracing::error!("Failed to create approval: {:?}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
pub async fn get_approval_status(
State(deployment): State<DeploymentImpl>,
Path(id): Path<String>,
) -> Result<Json<ApprovalStatus>, StatusCode> {
let service = deployment.approvals();
match service.status(&id).await {
Some(status) => Ok(Json(status)),
None => Err(StatusCode::NOT_FOUND),
}
}
pub async fn respond_to_approval(
State(deployment): State<DeploymentImpl>,
Path(id): Path<String>,
Json(request): Json<ApprovalResponse>,
) -> Result<Json<ApprovalStatus>, StatusCode> {
let service = deployment.approvals();
match service.respond(&id, request).await {
Ok((status, context)) => {
if matches!(status, ApprovalStatus::Approved)
&& context.tool_name == EXIT_PLAN_MODE_TOOL_NAME
{
// If exiting plan mode, automatically start a new execution process with different
// permissions
if let Ok(ctx) = ExecutionProcess::load_context(
&deployment.db().pool,
context.execution_process_id,
)
.await
&& let Err(e) = deployment.container().exit_plan_mode_tool(ctx).await
{
tracing::error!("failed to exit plan mode: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}
Ok(Json(status))
}
Err(e) => {
tracing::error!("Failed to respond to approval: {:?}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
pub async fn get_pending_approvals(
State(deployment): State<DeploymentImpl>,
) -> Json<Vec<ApprovalPendingInfo>> {
let service = deployment.approvals();
let approvals = service.pending().await;
Json(approvals)
}
pub fn router() -> Router<DeploymentImpl> {
Router::new()
.route("/approvals/create", post(create_approval))
.route("/approvals/{id}/status", get(get_approval_status))
.route("/approvals/{id}/respond", post(respond_to_approval))
.route("/approvals/pending", get(get_pending_approvals))
}

View File

@@ -5,6 +5,7 @@ use axum::{
use crate::DeploymentImpl;
pub mod approvals;
pub mod auth;
pub mod config;
pub mod containers;
@@ -34,6 +35,7 @@ pub fn router(deployment: DeploymentImpl) -> IntoMakeService<Router> {
.merge(auth::router(&deployment))
.merge(filesystem::router())
.merge(events::router(&deployment))
.merge(approvals::router())
.nest("/images", images::routes())
.with_state(deployment);

View File

@@ -0,0 +1,274 @@
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use db::models::executor_session::ExecutorSession;
use executors::logs::{
NormalizedEntry, NormalizedEntryType, ToolStatus,
utils::patch::{ConversationPatch, extract_normalized_entry_from_patch},
};
use sqlx::{Error as SqlxError, SqlitePool};
use thiserror::Error;
use tokio::sync::{RwLock, oneshot};
use utils::{
approvals::{ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus},
log_msg::LogMsg,
msg_store::MsgStore,
};
use uuid::Uuid;
#[derive(Debug)]
struct PendingApproval {
entry_index: usize,
entry: NormalizedEntry,
execution_process_id: Uuid,
tool_name: String,
requested_at: DateTime<Utc>,
timeout_at: DateTime<Utc>,
response_tx: oneshot::Sender<ApprovalStatus>,
}
#[derive(Debug)]
pub struct ToolContext {
pub tool_name: String,
pub execution_process_id: Uuid,
}
#[derive(Clone)]
pub struct Approvals {
pending: Arc<DashMap<String, PendingApproval>>,
completed: Arc<DashMap<String, ApprovalStatus>>,
db_pool: SqlitePool,
msg_stores: Arc<RwLock<HashMap<Uuid, Arc<MsgStore>>>>,
}
#[derive(Debug, Error)]
pub enum ApprovalError {
#[error("approval request not found")]
NotFound,
#[error("approval request already completed")]
AlreadyCompleted,
#[error("no executor session found for session_id: {0}")]
NoExecutorSession(String),
#[error("corresponding tool use entry not found for approval request")]
NoToolUseEntry,
#[error(transparent)]
Custom(#[from] anyhow::Error),
#[error(transparent)]
Sqlx(#[from] SqlxError),
}
impl Approvals {
pub fn new(db_pool: SqlitePool, msg_stores: Arc<RwLock<HashMap<Uuid, Arc<MsgStore>>>>) -> Self {
Self {
pending: Arc::new(DashMap::new()),
completed: Arc::new(DashMap::new()),
db_pool,
msg_stores,
}
}
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
let execution_process_id = if let Some(executor_session) =
ExecutorSession::find_by_session_id(&self.db_pool, &request.session_id).await?
{
executor_session.execution_process_id
} else {
tracing::warn!(
"No executor session found for session_id: {}",
request.session_id
);
return Err(ApprovalError::NoExecutorSession(request.session_id.clone()));
};
let (tx, rx) = oneshot::channel();
let req_id = request.id.clone();
if let Some(store) = self.msg_store_by_id(&execution_process_id).await {
let last_tool = get_last_tool_use(store.clone());
if let Some((idx, last_tool)) = last_tool {
let approval_entry = last_tool
.with_tool_status(ToolStatus::PendingApproval {
approval_id: req_id.clone(),
requested_at: request.created_at,
timeout_at: request.timeout_at,
})
.ok_or(ApprovalError::NoToolUseEntry)?;
store.push_patch(ConversationPatch::replace(idx, approval_entry));
self.pending.insert(
req_id.clone(),
PendingApproval {
entry_index: idx,
entry: last_tool,
execution_process_id,
tool_name: request.tool_name.clone(),
requested_at: request.created_at,
timeout_at: request.timeout_at,
response_tx: tx,
},
);
}
} else {
tracing::warn!(
"No msg_store found for execution_process_id: {}",
execution_process_id
);
}
self.spawn_timeout_watcher(req_id.clone(), request.timeout_at, rx);
Ok(request)
}
pub async fn respond(
&self,
id: &str,
req: ApprovalResponse,
) -> Result<(ApprovalStatus, ToolContext), ApprovalError> {
if let Some((_, p)) = self.pending.remove(id) {
self.completed.insert(id.to_string(), req.status.clone());
let _ = p.response_tx.send(req.status.clone());
if let Some(store) = self.msg_store_by_id(&req.execution_process_id).await {
let status = ToolStatus::from_approval_status(&req.status).ok_or(
ApprovalError::Custom(anyhow::anyhow!("Invalid approval status")),
)?;
let updated_entry = p
.entry
.with_tool_status(status)
.ok_or(ApprovalError::NoToolUseEntry)?;
store.push_patch(ConversationPatch::replace(p.entry_index, updated_entry));
} else {
tracing::warn!(
"No msg_store found for execution_process_id: {}",
req.execution_process_id
);
}
let tool_ctx = ToolContext {
tool_name: p.tool_name,
execution_process_id: p.execution_process_id,
};
Ok((req.status, tool_ctx))
} else if self.completed.contains_key(id) {
Err(ApprovalError::AlreadyCompleted)
} else {
Err(ApprovalError::NotFound)
}
}
pub async fn status(&self, id: &str) -> Option<ApprovalStatus> {
if let Some(f) = self.completed.get(id) {
return Some(f.clone());
}
if let Some(p) = self.pending.get(id) {
if chrono::Utc::now() >= p.timeout_at {
return Some(ApprovalStatus::TimedOut);
}
return Some(ApprovalStatus::Pending);
}
None
}
pub async fn pending(&self) -> Vec<ApprovalPendingInfo> {
self.pending
.iter()
.filter_map(|entry| {
let (id, pending) = entry.pair();
match &pending.entry.entry_type {
NormalizedEntryType::ToolUse { tool_name, .. } => Some(ApprovalPendingInfo {
approval_id: id.clone(),
execution_process_id: pending.execution_process_id,
tool_name: tool_name.clone(),
requested_at: pending.requested_at,
timeout_at: pending.timeout_at,
}),
_ => None,
}
})
.collect()
}
fn spawn_timeout_watcher(
&self,
id: String,
timeout_at: chrono::DateTime<chrono::Utc>,
mut rx: oneshot::Receiver<ApprovalStatus>,
) {
let pending = self.pending.clone();
let completed = self.completed.clone();
let msg_stores = self.msg_stores.clone();
let now = chrono::Utc::now();
let to_wait = (timeout_at - now)
.to_std()
.unwrap_or_else(|_| StdDuration::from_secs(0));
let deadline = tokio::time::Instant::now() + to_wait;
tokio::spawn(async move {
let status = tokio::select! {
biased;
r = &mut rx => match r {
Ok(status) => status,
Err(_canceled) => ApprovalStatus::TimedOut,
},
_ = tokio::time::sleep_until(deadline) => ApprovalStatus::TimedOut,
};
let is_timeout = matches!(&status, ApprovalStatus::TimedOut);
completed.insert(id.clone(), status.clone());
let removed = pending.remove(&id);
if is_timeout && let Some((_, pending_approval)) = removed {
let store = {
let map = msg_stores.read().await;
map.get(&pending_approval.execution_process_id).cloned()
};
if let Some(store) = store {
if let Some(updated_entry) = pending_approval
.entry
.with_tool_status(ToolStatus::TimedOut)
{
store.push_patch(ConversationPatch::replace(
pending_approval.entry_index,
updated_entry,
));
} else {
tracing::warn!(
"Timed out approval '{}' but couldn't update tool status (no tool-use entry).",
id
);
}
} else {
tracing::warn!(
"No msg_store found for execution_process_id: {}",
pending_approval.execution_process_id
);
}
}
});
}
async fn msg_store_by_id(&self, execution_process_id: &Uuid) -> Option<Arc<MsgStore>> {
let map = self.msg_stores.read().await;
map.get(execution_process_id).cloned()
}
}
fn get_last_tool_use(store: Arc<MsgStore>) -> Option<(usize, NormalizedEntry)> {
let history = store.get_history();
for msg in history.iter().rev() {
if let LogMsg::JsonPatch(patch) = msg
&& let Some((idx, entry)) = extract_normalized_entry_from_patch(patch)
&& matches!(entry.entry_type, NormalizedEntryType::ToolUse { .. })
{
return Some((idx, entry));
}
}
None
}

View File

@@ -23,11 +23,12 @@ use db::{
use executors::{
actions::{
ExecutorAction, ExecutorActionType,
coding_agent_follow_up::CodingAgentFollowUpRequest,
coding_agent_initial::CodingAgentInitialRequest,
script::{ScriptContext, ScriptRequest, ScriptRequestLanguage},
},
executors::{ExecutorError, StandardCodingAgentExecutor},
profile::{ExecutorConfigs, ExecutorProfileId},
profile::{ExecutorConfigs, ExecutorProfileId, to_default_variant},
};
use futures::{StreamExt, future};
use sqlx::Error as SqlxError;
@@ -689,4 +690,66 @@ pub trait ContainerService {
tracing::debug!("Started next action: {:?}", next_action);
Ok(())
}
async fn exit_plan_mode_tool(&self, ctx: ExecutionContext) -> Result<(), ContainerError> {
let execution_id = ctx.execution_process.id;
if let Err(err) = self.stop_execution(&ctx.execution_process).await {
tracing::error!("Failed to stop execution process {}: {}", execution_id, err);
return Err(err);
}
let _ = ExecutionProcess::update_completion(
&self.db().pool,
execution_id,
ExecutionProcessStatus::Completed,
Some(0),
)
.await;
let action = ctx.execution_process.executor_action()?;
let executor_profile_id = match action.typ() {
ExecutorActionType::CodingAgentInitialRequest(req) => req.executor_profile_id.clone(),
ExecutorActionType::CodingAgentFollowUpRequest(req) => req.executor_profile_id.clone(),
_ => {
return Err(ContainerError::Other(anyhow::anyhow!(
"exit plan mode tool called on non-coding agent action"
)));
}
};
let cleanup_chain = action.next_action().cloned();
let session_id =
ExecutorSession::find_by_execution_process_id(&self.db().pool, execution_id)
.await?
.and_then(|s| s.session_id);
if session_id.is_none() {
tracing::warn!(
"No executor session found for execution process {}",
execution_id
);
return Err(ContainerError::Other(anyhow::anyhow!(
"No executor session found"
)));
}
let default_profile = to_default_variant(&executor_profile_id);
let follow_up = CodingAgentFollowUpRequest {
prompt: String::from("The plan has been approved, please execute it."),
session_id: session_id.unwrap(),
executor_profile_id: default_profile,
};
let action = ExecutorAction::new(
ExecutorActionType::CodingAgentFollowUpRequest(follow_up),
cleanup_chain.map(Box::new),
);
let _ = self
.start_execution(
&ctx.task_attempt,
&action,
&ExecutionProcessRunReason::CodingAgent,
)
.await?;
Ok(())
}
}

View File

@@ -1,4 +1,5 @@
pub mod analytics;
pub mod approvals;
pub mod auth;
pub mod config;
pub mod container;

View File

@@ -0,0 +1,69 @@
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use uuid::Uuid;
pub const APPROVAL_TIMEOUT_SECONDS: i64 = 3600; // 1 hour
pub const EXIT_PLAN_MODE_TOOL_NAME: &str = "ExitPlanMode";
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
pub struct ApprovalRequest {
pub id: String,
pub tool_name: String,
pub tool_input: serde_json::Value,
pub session_id: String,
pub created_at: DateTime<Utc>,
pub timeout_at: DateTime<Utc>,
}
impl ApprovalRequest {
pub fn from_create(request: CreateApprovalRequest) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
tool_name: request.tool_name,
tool_input: request.tool_input,
session_id: request.session_id,
created_at: now,
timeout_at: now + Duration::seconds(APPROVAL_TIMEOUT_SECONDS),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export)]
pub struct CreateApprovalRequest {
pub tool_name: String,
pub tool_input: serde_json::Value,
pub session_id: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ApprovalStatus {
Pending,
Approved,
Denied {
#[ts(optional)]
reason: Option<String>,
},
TimedOut,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export)]
pub struct ApprovalResponse {
pub execution_process_id: Uuid,
pub status: ApprovalStatus,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export)]
pub struct ApprovalPendingInfo {
pub approval_id: String,
pub execution_process_id: Uuid,
pub tool_name: String,
pub requested_at: DateTime<Utc>,
pub timeout_at: DateTime<Utc>,
}

View File

@@ -2,6 +2,7 @@ use std::{env, sync::OnceLock};
use directories::ProjectDirs;
pub mod approvals;
pub mod assets;
pub mod browser;
pub mod diff;

View File

@@ -67,6 +67,7 @@ impl MsgStore {
pub fn push_stdout<S: Into<String>>(&self, s: S) {
self.push(LogMsg::Stdout(s.into()));
}
pub fn push_stderr<S: Into<String>>(&self, s: S) {
self.push(LogMsg::Stderr(s.into()));
}
@@ -85,6 +86,7 @@ impl MsgStore {
pub fn get_receiver(&self) -> broadcast::Receiver<LogMsg> {
self.sender.subscribe()
}
pub fn get_history(&self) -> Vec<LogMsg> {
self.inner
.read()

View File

@@ -10,3 +10,25 @@ pub async fn write_port_file(port: u16) -> std::io::Result<PathBuf> {
fs::write(&path, port.to_string()).await?;
Ok(path)
}
pub async fn read_port_file(app_name: &str) -> std::io::Result<u16> {
let base = if cfg!(target_os = "linux") {
match env::var("XDG_RUNTIME_DIR") {
Ok(val) if !val.is_empty() => PathBuf::from(val),
_ => env::temp_dir(),
}
} else {
env::temp_dir()
};
let path = base.join(app_name).join(format!("{app_name}.port"));
tracing::debug!("Reading port from {:?}", path);
let content = fs::read_to_string(&path).await?;
let port: u16 = content
.trim()
.parse()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(port)
}