- remove AbortController in PendingApprovalEntry (#902)

- fix find execution process check
- force stop process with non-killed status
This commit is contained in:
Gabriel Gordon-Hall
2025-10-01 16:51:23 +01:00
committed by GitHub
parent c78f48ae02
commit 0ace01b55f
13 changed files with 101 additions and 181 deletions

View File

@@ -11,7 +11,10 @@ use sqlx::{Error as SqlxError, SqlitePool};
use thiserror::Error;
use tokio::sync::{RwLock, oneshot};
use utils::{
approvals::{ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus},
approvals::{
ApprovalPendingInfo, ApprovalRequest, ApprovalResponse, ApprovalStatus,
CreateApprovalRequest,
},
log_msg::LogMsg,
msg_store::MsgStore,
};
@@ -38,7 +41,6 @@ pub struct ToolContext {
pub struct Approvals {
pending: Arc<DashMap<String, PendingApproval>>,
completed: Arc<DashMap<String, ApprovalStatus>>,
db_pool: SqlitePool,
msg_stores: Arc<RwLock<HashMap<Uuid, Arc<MsgStore>>>>,
}
@@ -59,32 +61,20 @@ pub enum ApprovalError {
}
impl Approvals {
pub fn new(db_pool: SqlitePool, msg_stores: Arc<RwLock<HashMap<Uuid, Arc<MsgStore>>>>) -> Self {
pub fn new(msg_stores: Arc<RwLock<HashMap<Uuid, Arc<MsgStore>>>>) -> Self {
Self {
pending: Arc::new(DashMap::new()),
completed: Arc::new(DashMap::new()),
db_pool,
msg_stores,
}
}
#[tracing::instrument(skip(self, request))]
pub async fn create(&self, request: ApprovalRequest) -> Result<ApprovalRequest, ApprovalError> {
let execution_process_id = if let Some(executor_session) =
ExecutorSession::find_by_session_id(&self.db_pool, &request.session_id).await?
{
executor_session.execution_process_id
} else {
tracing::warn!(
"No executor session found for session_id: {}",
request.session_id
);
return Err(ApprovalError::NoExecutorSession(request.session_id.clone()));
};
let (tx, rx) = oneshot::channel();
let req_id = request.id.clone();
if let Some(store) = self.msg_store_by_id(&execution_process_id).await {
if let Some(store) = self.msg_store_by_id(&request.execution_process_id).await {
let last_tool = get_last_tool_use(store.clone());
if let Some((idx, last_tool)) = last_tool {
let approval_entry = last_tool
@@ -101,7 +91,7 @@ impl Approvals {
PendingApproval {
entry_index: idx,
entry: last_tool,
execution_process_id,
execution_process_id: request.execution_process_id,
tool_name: request.tool_name.clone(),
requested_at: request.created_at,
timeout_at: request.timeout_at,
@@ -112,7 +102,7 @@ impl Approvals {
} else {
tracing::warn!(
"No msg_store found for execution_process_id: {}",
execution_process_id
request.execution_process_id
);
}
@@ -120,6 +110,26 @@ impl Approvals {
Ok(request)
}
pub async fn create_from_session(
&self,
pool: &SqlitePool,
payload: CreateApprovalRequest,
) -> Result<ApprovalRequest, ApprovalError> {
let session_id = payload.session_id.clone();
let execution_process_id =
match ExecutorSession::find_by_session_id(pool, &session_id).await? {
Some(session) => session.execution_process_id,
None => {
tracing::warn!("No executor session found for session_id: {}", session_id);
return Err(ApprovalError::NoExecutorSession(session_id));
}
};
let request = ApprovalRequest::from_create(payload, execution_process_id);
self.create(request).await
}
#[tracing::instrument(skip(self, id, req))]
pub async fn respond(
&self,
id: &str,
@@ -129,7 +139,7 @@ impl Approvals {
self.completed.insert(id.to_string(), req.status.clone());
let _ = p.response_tx.send(req.status.clone());
if let Some(store) = self.msg_store_by_id(&req.execution_process_id).await {
if let Some(store) = self.msg_store_by_id(&p.execution_process_id).await {
let status = ToolStatus::from_approval_status(&req.status).ok_or(
ApprovalError::Custom(anyhow::anyhow!("Invalid approval status")),
)?;
@@ -142,7 +152,7 @@ impl Approvals {
} else {
tracing::warn!(
"No msg_store found for execution_process_id: {}",
req.execution_process_id
p.execution_process_id
);
}
@@ -191,6 +201,7 @@ impl Approvals {
.collect()
}
#[tracing::instrument(skip(self, id, timeout_at, rx))]
fn spawn_timeout_watcher(
&self,
id: String,

View File

@@ -168,14 +168,16 @@ pub trait ContainerService {
{
for process in processes {
if process.status == ExecutionProcessStatus::Running {
self.stop_execution(&process).await.unwrap_or_else(|e| {
tracing::debug!(
"Failed to stop execution process {} for task attempt {}: {}",
process.id,
task_attempt.id,
e
);
});
self.stop_execution(&process, ExecutionProcessStatus::Killed)
.await
.unwrap_or_else(|e| {
tracing::debug!(
"Failed to stop execution process {} for task attempt {}: {}",
process.id,
task_attempt.id,
e
);
});
}
}
}
@@ -199,6 +201,7 @@ pub trait ContainerService {
async fn stop_execution(
&self,
execution_process: &ExecutionProcess,
status: ExecutionProcessStatus,
) -> Result<(), ContainerError>;
async fn try_commit_changes(&self, ctx: &ExecutionContext) -> Result<bool, ContainerError>;
@@ -708,17 +711,13 @@ pub trait ContainerService {
async fn exit_plan_mode_tool(&self, ctx: ExecutionContext) -> Result<(), ContainerError> {
let execution_id = ctx.execution_process.id;
if let Err(err) = self.stop_execution(&ctx.execution_process).await {
if let Err(err) = self
.stop_execution(&ctx.execution_process, ExecutionProcessStatus::Completed)
.await
{
tracing::error!("Failed to stop execution process {}: {}", execution_id, err);
return Err(err);
}
let _ = ExecutionProcess::update_completion(
&self.db().pool,
execution_id,
ExecutionProcessStatus::Completed,
Some(0),
)
.await;
let action = ctx.execution_process.executor_action()?;
let executor_profile_id = match action.typ() {
@@ -736,6 +735,7 @@ pub trait ContainerService {
ExecutorSession::find_by_execution_process_id(&self.db().pool, execution_id)
.await?
.and_then(|s| s.session_id);
if session_id.is_none() {
tracing::warn!(
"No executor session found for execution process {}",