diff --git a/Cargo.lock b/Cargo.lock index 393938e1..5edffce2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4043,6 +4043,7 @@ dependencies = [ "serde", "serde_json", "services", + "shlex", "sqlx", "strip-ansi-escapes", "tempfile", @@ -5243,6 +5244,7 @@ dependencies = [ "bytes", "chrono", "directories", + "dirs 5.0.1", "futures", "futures-util", "git2", @@ -5256,6 +5258,7 @@ dependencies = [ "serde", "serde_json", "shellexpand", + "shlex", "similar", "tokio", "tokio-stream", diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index 74e9793b..4d0f2574 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -15,6 +15,7 @@ utils = { path = "../utils" } db = { path = "../db" } services = { path = "../services" } tokio = { workspace = true } +shlex = "1.3.0" tokio-util = { version = "0.7", features = ["io"] } axum = { workspace = true } serde = { workspace = true } diff --git a/crates/server/src/routes/task_attempts/cursor_setup.rs b/crates/server/src/routes/task_attempts/cursor_setup.rs index 25f0e9a1..426c576f 100644 --- a/crates/server/src/routes/task_attempts/cursor_setup.rs +++ b/crates/server/src/routes/task_attempts/cursor_setup.rs @@ -13,6 +13,8 @@ use executors::{ executors::cursor::CursorAgent, }; use services::services::container::ContainerService; +use shlex::try_quote; +use utils::shell::UnixShell; use crate::{error::ApiError, routes::task_attempts::ensure_worktree_path}; @@ -55,8 +57,9 @@ async fn get_setup_helper_action() -> Result { #[cfg(unix)] { let base_command = CursorAgent::base_command(); - // First action: Install - let install_script = format!( + + // Install script with PATH setup + let mut install_script = format!( r#"#!/bin/bash set -e if ! command -v {base_command} &> /dev/null; then @@ -65,18 +68,33 @@ if ! command -v {base_command} &> /dev/null; then echo "Installation complete!" else echo "Cursor CLI already installed" -fi -"# +fi"# ); + let shell = UnixShell::current_shell(); + if let Some(config_file) = shell.config_file() + && let Ok(config_file_str) = try_quote(config_file.to_string_lossy().as_ref()) + { + install_script.push_str(&format!( + r#" + echo "Setting up PATH..." + echo 'export PATH="$HOME/.local/bin:$PATH"' >> {config_file_str} + "# + )); + } let install_request = ScriptRequest { script: install_script, language: ScriptRequestLanguage::Bash, context: ScriptContext::SetupScript, }; - // Second action (chained): Login - let login_script = format!("{base_command} login"); + let login_script = format!( + r#"#!/bin/bash +set -e +export PATH="$HOME/.local/bin:$PATH" +{base_command} login +"# + ); let login_request = ScriptRequest { script: login_script, language: ScriptRequestLanguage::Bash, diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 6f87b40f..a42a4dda 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [dependencies] tokio-util = { version = "0.7", features = ["io", "codec"] } bytes = "1.0" +shlex = "1.3.0" axum = { workspace = true, features = ["ws"] } serde = { workspace = true } serde_json = { workspace = true } @@ -32,6 +33,7 @@ shellexpand = "3.1.1" which = "8.0.0" similar = "2" git2 = "0.18" +dirs = "5.0" [target.'cfg(windows)'.dependencies] winreg = "0.55" diff --git a/crates/utils/src/shell.rs b/crates/utils/src/shell.rs index b8dfd249..7f1c91c9 100644 --- a/crates/utils/src/shell.rs +++ b/crates/utils/src/shell.rs @@ -5,6 +5,7 @@ use std::{ env::{join_paths, split_paths}, ffi::{OsStr, OsString}, path::{Path, PathBuf}, + process::Stdio, }; use crate::tokio::block_on; @@ -18,21 +19,7 @@ pub fn get_shell_command() -> (String, &'static str) { if cfg!(windows) { ("cmd".into(), "/C") } else { - // Prefer SHELL env var if set and valid - if let Ok(shell) = std::env::var("SHELL") { - let path = Path::new(&shell); - if path.is_absolute() && path.is_file() { - return (shell, "-c"); - } - } - // Prefer zsh or bash if available, fallback to sh - if std::path::Path::new("/bin/zsh").exists() { - ("zsh".into(), "-c") - } else if std::path::Path::new("/bin/bash").exists() { - ("bash".into(), "-c") - } else { - ("sh".into(), "-c") - } + UnixShell::current_shell().get_shell_command() } } @@ -114,20 +101,106 @@ async fn which(executable: &str) -> Option { .and_then(|result| result.ok()) } +#[derive(Debug, Clone, PartialEq)] +pub enum UnixShell { + Zsh, + Bash, + Sh, + Other(String), +} + +impl UnixShell { + pub fn path(&self) -> PathBuf { + match self { + UnixShell::Zsh => PathBuf::from("/bin/zsh"), + UnixShell::Bash => PathBuf::from("/bin/bash"), + UnixShell::Sh => PathBuf::from("/bin/sh"), + UnixShell::Other(path) => PathBuf::from(path), + } + } + pub fn login(&self) -> bool { + match self { + UnixShell::Zsh => true, + UnixShell::Bash => true, + UnixShell::Sh => false, + UnixShell::Other(_) => false, + } + } + pub fn config_file(&self) -> Option { + let home = dirs::home_dir()?; + let config_file = match self { + UnixShell::Zsh => Some(home.join(".zshrc")), + UnixShell::Bash => Some(home.join(".bashrc")), + UnixShell::Sh => None, + UnixShell::Other(_) => None, + }; + if let Some(config_file) = config_file + && config_file.is_file() + { + Some(config_file) + } else { + None + } + } + + pub fn source_command(&self) -> Option { + if let Some(source_file) = self.config_file() + && let Ok(escaped_source_file) = + shlex::try_quote(source_file.to_string_lossy().as_ref()) + { + Some(format!("source {escaped_source_file}")) + } else { + None + } + } + pub fn current_shell() -> UnixShell { + if let Ok(shell) = std::env::var("SHELL") + && let Some(shell) = UnixShell::from_path(Path::new(&shell)) + { + return shell; + } + UnixShell::Sh + } + pub fn from_path(path: &Path) -> Option { + if path.is_absolute() && path.is_file() { + if path.file_name() == Some(OsStr::new("zsh")) { + Some(UnixShell::Zsh) + } else if path.file_name() == Some(OsStr::new("bash")) { + Some(UnixShell::Bash) + } else if path.file_name() == Some(OsStr::new("sh")) { + Some(UnixShell::Sh) + } else { + Some(UnixShell::Other(path.to_string_lossy().into_owned())) + } + } else { + None + } + } + pub fn get_shell_command(&self) -> (String, &'static str) { + (self.path().to_string_lossy().into_owned(), "-c") + } +} + #[cfg(not(windows))] async fn get_fresh_path() -> Option { use std::time::Duration; use tokio::process::Command; - async fn run(shell: &Path, login: bool) -> Option { - let mut cmd = Command::new(shell); - if login { + async fn run(shell: &UnixShell) -> Option { + let mut cmd = Command::new(shell.path()); + if shell.login() { cmd.arg("-l"); } - cmd.arg("-c") - .arg("printf '%s' \"$PATH\"") - .env("TERM", "dumb") + if let Some(source_command) = shell.source_command() { + cmd.arg("-c") + .arg(format!("{source_command}; printf '%s' \"$PATH\"")); + } else { + cmd.arg("-c").arg("printf '%s' \"$PATH\""); + } + cmd.env("TERM", "dumb") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) .kill_on_drop(true); const PATH_REFRESH_COMMAND_TIMEOUT: Duration = Duration::from_secs(5); @@ -142,7 +215,7 @@ async fn get_fresh_path() -> Option { Ok(Ok(output)) => output, Ok(Err(err)) => { tracing::debug!( - shell = %shell.display(), + shell = %shell.path().display(), ?err, "Failed to retrieve PATH from login shell" ); @@ -150,7 +223,7 @@ async fn get_fresh_path() -> Option { } Err(_) => { tracing::warn!( - shell = %shell.display(), + shell = %shell.path().display(), timeout_secs = PATH_REFRESH_COMMAND_TIMEOUT.as_secs(), "Timed out retrieving PATH from login shell" ); @@ -167,33 +240,15 @@ async fn get_fresh_path() -> Option { let mut paths = Vec::new(); - let shells = vec![ - (PathBuf::from("/bin/zsh"), true), - (PathBuf::from("/bin/bash"), true), - (PathBuf::from("/bin/sh"), false), - ]; - - let mut current_shell_name = None; - if let Ok(shell) = std::env::var("SHELL") { - let path = Path::new(&shell); - if path.is_absolute() && path.is_file() { - current_shell_name = path.file_name().and_then(OsStr::to_str).map(String::from); - if let Some(path) = run(path, true).await { - paths.push(path); - } - } + let current_shell = UnixShell::current_shell(); + if let Some(path) = run(¤t_shell).await { + paths.push(path); } - for (shell_path, login) in shells { - if !shell_path.exists() { - continue; - } - let shell_name = shell_path - .file_name() - .and_then(OsStr::to_str) - .map(String::from); - if current_shell_name != shell_name - && let Some(path) = run(&shell_path, login).await + let shells = vec![UnixShell::Zsh, UnixShell::Bash, UnixShell::Sh]; + for shell in shells { + if !(shell == current_shell) + && let Some(path) = run(&shell).await { paths.push(path); }