From 53a3bab0c063b9c7ee4f8b406b542cb54fba42b1 Mon Sep 17 00:00:00 2001 From: Louis Knight-Webb Date: Mon, 7 Jul 2025 14:11:16 +0100 Subject: [PATCH] Refactor app state (#85) * Init * Refactor project endpoints * Remaining endpoints * Fmt --- backend/src/app_state.rs | 4 + backend/src/main.rs | 8 +- backend/src/routes/config.rs | 28 ++-- backend/src/routes/filesystem.rs | 4 +- backend/src/routes/projects.rs | 86 +++++----- backend/src/routes/task_attempts.rs | 233 +++++++++++++++------------- backend/src/routes/tasks.rs | 79 +++++----- 7 files changed, 243 insertions(+), 199 deletions(-) diff --git a/backend/src/app_state.rs b/backend/src/app_state.rs index 50ea5145..1de684b1 100644 --- a/backend/src/app_state.rs +++ b/backend/src/app_state.rs @@ -227,6 +227,10 @@ impl AppState { config.sound_file.clone() } + pub fn get_config(&self) -> &Arc> { + &self.config + } + pub async fn track_analytics_event( &self, event_name: &str, diff --git a/backend/src/main.rs b/backend/src/main.rs index dbcbd636..e7b802fd 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -2,7 +2,6 @@ use std::{str::FromStr, sync::Arc}; use axum::{ body::Body, - extract::Extension, http::{header, HeaderValue, StatusCode}, response::{IntoResponse, Json as ResponseJson, Response}, routing::{get, post}, @@ -198,9 +197,7 @@ fn main() -> anyhow::Result<()> { .merge(filesystem::filesystem_router()) .merge(config::config_router()) .route("/sounds/:filename", get(serve_sound_file)), - ) - .layer(Extension(pool.clone())) - .layer(Extension(config_arc)); + ); let app = Router::new() .merge(public_routes) @@ -208,8 +205,7 @@ fn main() -> anyhow::Result<()> { // Static file serving routes .route("/", get(index_handler)) .route("/*path", get(static_handler)) - .layer(Extension(pool)) - .layer(Extension(app_state)) + .with_state(app_state) .layer(CorsLayer::permissive()) .layer(NewSentryLayer::new_from_top()); diff --git a/backend/src/routes/config.rs b/backend/src/routes/config.rs index a744cb7e..ca5e2c25 100644 --- a/backend/src/routes/config.rs +++ b/backend/src/routes/config.rs @@ -1,17 +1,18 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use axum::{ - extract::{Extension, Query}, + extract::{Query, State}, response::Json as ResponseJson, routing::{get, post}, Json, Router, }; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tokio::{fs, sync::RwLock}; +use tokio::fs; use ts_rs::TS; use crate::{ + app_state::AppState, executor::ExecutorConfig, models::{ config::{Config, EditorConstants, SoundConstants}, @@ -20,7 +21,7 @@ use crate::{ utils, }; -pub fn config_router() -> Router { +pub fn config_router() -> Router { Router::new() .route("/config", get(get_config)) .route("/config", post(update_config)) @@ -29,10 +30,8 @@ pub fn config_router() -> Router { .route("/mcp-servers", post(update_mcp_servers)) } -async fn get_config( - Extension(config): Extension>>, -) -> ResponseJson> { - let config = config.read().await; +async fn get_config(State(app_state): State) -> ResponseJson> { + let config = app_state.get_config().read().await; ResponseJson(ApiResponse { success: true, data: Some(config.clone()), @@ -41,15 +40,14 @@ async fn get_config( } async fn update_config( - Extension(config_arc): Extension>>, - Extension(app_state): Extension, + State(app_state): State, Json(new_config): Json, ) -> ResponseJson> { let config_path = utils::config_path(); match new_config.save(&config_path) { Ok(_) => { - let mut config = config_arc.write().await; + let mut config = app_state.get_config().write().await; *config = new_config.clone(); drop(config); @@ -119,11 +117,11 @@ fn resolve_executor_config( } async fn get_mcp_servers( - Extension(config): Extension>>, + State(app_state): State, Query(query): Query, ) -> ResponseJson> { let saved_config = { - let config = config.read().await; + let config = app_state.get_config().read().await; config.executor.clone() }; @@ -171,12 +169,12 @@ async fn get_mcp_servers( } async fn update_mcp_servers( - Extension(config): Extension>>, + State(app_state): State, Query(query): Query, Json(new_servers): Json>, ) -> ResponseJson> { let saved_config = { - let config = config.read().await; + let config = app_state.get_config().read().await; config.executor.clone() }; diff --git a/backend/src/routes/filesystem.rs b/backend/src/routes/filesystem.rs index 81f59c78..107e1482 100644 --- a/backend/src/routes/filesystem.rs +++ b/backend/src/routes/filesystem.rs @@ -9,7 +9,7 @@ use axum::{ use serde::{Deserialize, Serialize}; use ts_rs::TS; -use crate::models::ApiResponse; +use crate::{app_state::AppState, models::ApiResponse}; #[derive(Debug, Serialize, TS)] #[ts(export)] @@ -199,7 +199,7 @@ pub async fn create_git_repo( } } -pub fn filesystem_router() -> Router { +pub fn filesystem_router() -> Router { Router::new() .route("/filesystem/list", get(list_directory)) .route("/filesystem/validate-git", get(validate_git_path)) diff --git a/backend/src/routes/projects.rs b/backend/src/routes/projects.rs index 7a37ca0c..f94c6b10 100644 --- a/backend/src/routes/projects.rs +++ b/backend/src/routes/projects.rs @@ -1,28 +1,29 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use axum::{ - extract::{Extension, Path, Query}, + extract::{Path, Query, State}, http::StatusCode, response::Json as ResponseJson, routing::get, Json, Router, }; -use sqlx::SqlitePool; -use tokio::sync::RwLock; use uuid::Uuid; -use crate::models::{ - project::{ - CreateBranch, CreateProject, GitBranch, Project, ProjectWithBranch, SearchMatchType, - SearchResult, UpdateProject, +use crate::{ + app_state::AppState, + models::{ + project::{ + CreateBranch, CreateProject, GitBranch, Project, ProjectWithBranch, SearchMatchType, + SearchResult, UpdateProject, + }, + ApiResponse, }, - ApiResponse, }; pub async fn get_projects( - Extension(pool): Extension, + State(app_state): State, ) -> Result>>, StatusCode> { - match Project::find_all(&pool).await { + match Project::find_all(&app_state.db_pool).await { Ok(projects) => Ok(ResponseJson(ApiResponse { success: true, data: Some(projects), @@ -37,9 +38,9 @@ pub async fn get_projects( pub async fn get_project( Path(id): Path, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match Project::find_by_id(&pool, id).await { + match Project::find_by_id(&app_state.db_pool, id).await { Ok(Some(project)) => Ok(ResponseJson(ApiResponse { success: true, data: Some(project), @@ -55,9 +56,9 @@ pub async fn get_project( pub async fn get_project_with_branch( Path(id): Path, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match Project::find_by_id(&pool, id).await { + match Project::find_by_id(&app_state.db_pool, id).await { Ok(Some(project)) => Ok(ResponseJson(ApiResponse { success: true, data: Some(project.with_branch_info()), @@ -73,9 +74,9 @@ pub async fn get_project_with_branch( pub async fn get_project_branches( Path(id): Path, - Extension(pool): Extension, + State(app_state): State, ) -> Result>>, StatusCode> { - match Project::find_by_id(&pool, id).await { + match Project::find_by_id(&app_state.db_pool, id).await { Ok(Some(project)) => match project.get_all_branches() { Ok(branches) => Ok(ResponseJson(ApiResponse { success: true, @@ -97,7 +98,7 @@ pub async fn get_project_branches( pub async fn create_project_branch( Path(id): Path, - Extension(pool): Extension, + State(app_state): State, Json(payload): Json, ) -> Result>, StatusCode> { // Validate branch name @@ -118,7 +119,7 @@ pub async fn create_project_branch( })); } - match Project::find_by_id(&pool, id).await { + match Project::find_by_id(&app_state.db_pool, id).await { Ok(Some(project)) => { match project.create_branch(&payload.name, payload.base_branch.as_deref()) { Ok(branch) => Ok(ResponseJson(ApiResponse { @@ -150,8 +151,7 @@ pub async fn create_project_branch( } pub async fn create_project( - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, Json(payload): Json, ) -> Result>, StatusCode> { let id = Uuid::new_v4(); @@ -159,7 +159,7 @@ pub async fn create_project( tracing::debug!("Creating project '{}'", payload.name); // Check if git repo path is already used by another project - match Project::find_by_git_repo_path(&pool, &payload.git_repo_path).await { + match Project::find_by_git_repo_path(&app_state.db_pool, &payload.git_repo_path).await { Ok(Some(_)) => { return Ok(ResponseJson(ApiResponse { success: false, @@ -249,7 +249,7 @@ pub async fn create_project( } } - match Project::create(&pool, &payload, id).await { + match Project::create(&app_state.db_pool, &payload, id).await { Ok(project) => { // Track project creation event app_state @@ -279,11 +279,11 @@ pub async fn create_project( pub async fn update_project( Path(id): Path, - Extension(pool): Extension, + State(app_state): State, Json(payload): Json, ) -> Result>, StatusCode> { // Check if project exists first - let existing_project = match Project::find_by_id(&pool, id).await { + let existing_project = match Project::find_by_id(&app_state.db_pool, id).await { Ok(Some(project)) => project, Ok(None) => return Err(StatusCode::NOT_FOUND), Err(e) => { @@ -295,7 +295,13 @@ pub async fn update_project( // If git_repo_path is being changed, check if the new path is already used by another project if let Some(new_git_repo_path) = &payload.git_repo_path { if new_git_repo_path != &existing_project.git_repo_path { - match Project::find_by_git_repo_path_excluding_id(&pool, new_git_repo_path, id).await { + match Project::find_by_git_repo_path_excluding_id( + &app_state.db_pool, + new_git_repo_path, + id, + ) + .await + { Ok(Some(_)) => { return Ok(ResponseJson(ApiResponse { success: false, @@ -329,7 +335,16 @@ pub async fn update_project( let name = name.unwrap_or(existing_project.name); let git_repo_path = git_repo_path.unwrap_or(existing_project.git_repo_path); - match Project::update(&pool, id, name, git_repo_path, setup_script, dev_script).await { + match Project::update( + &app_state.db_pool, + id, + name, + git_repo_path, + setup_script, + dev_script, + ) + .await + { Ok(project) => Ok(ResponseJson(ApiResponse { success: true, data: Some(project), @@ -344,9 +359,9 @@ pub async fn update_project( pub async fn delete_project( Path(id): Path, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match Project::delete(&pool, id).await { + match Project::delete(&app_state.db_pool, id).await { Ok(rows_affected) => { if rows_affected == 0 { Err(StatusCode::NOT_FOUND) @@ -372,12 +387,11 @@ pub struct OpenEditorRequest { pub async fn open_project_in_editor( Path(id): Path, - Extension(pool): Extension, - Extension(config): Extension>>, + State(app_state): State, Json(payload): Json>, ) -> Result>, StatusCode> { // Get the project - let project = match Project::find_by_id(&pool, id).await { + let project = match Project::find_by_id(&app_state.db_pool, id).await { Ok(Some(project)) => project, Ok(None) => return Err(StatusCode::NOT_FOUND), Err(e) => { @@ -388,7 +402,7 @@ pub async fn open_project_in_editor( // Get editor command from config or override let editor_command = { - let config_guard = config.read().await; + let config_guard = app_state.get_config().read().await; if let Some(ref request) = payload { if let Some(ref editor_type) = request.editor_type { // Create a temporary editor config with the override @@ -451,7 +465,7 @@ pub async fn open_project_in_editor( pub async fn search_project_files( Path(id): Path, Query(params): Query>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>>, StatusCode> { let query = match params.get("q") { Some(q) if !q.trim().is_empty() => q.trim(), @@ -465,7 +479,7 @@ pub async fn search_project_files( }; // Check if project exists - let project = match Project::find_by_id(&pool, id).await { + let project = match Project::find_by_id(&app_state.db_pool, id).await { Ok(Some(project)) => project, Ok(None) => return Err(StatusCode::NOT_FOUND), Err(e) => { @@ -587,7 +601,7 @@ async fn search_files_in_repo( Ok(results) } -pub fn projects_router() -> Router { +pub fn projects_router() -> Router { use axum::routing::post; Router::new() diff --git a/backend/src/routes/task_attempts.rs b/backend/src/routes/task_attempts.rs index 84c9e2cc..d45f815c 100644 --- a/backend/src/routes/task_attempts.rs +++ b/backend/src/routes/task_attempts.rs @@ -1,15 +1,11 @@ -use std::sync::Arc; - use axum::{ - extract::{Extension, Path, Query}, + extract::{Path, Query, State}, http::StatusCode, response::Json as ResponseJson, routing::get, Json, Router, }; use serde::{Deserialize, Serialize}; -use sqlx::SqlitePool; -use tokio::sync::RwLock; use uuid::Uuid; use crate::{ @@ -45,10 +41,10 @@ pub struct CreateGitHubPRRequest { pub async fn get_task_attempts( Path((project_id, task_id)): Path<(Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>>, StatusCode> { // Verify task exists in project first - match Task::exists(&pool, task_id, project_id).await { + match Task::exists(&app_state.db_pool, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task existence: {}", e); @@ -57,7 +53,7 @@ pub async fn get_task_attempts( Ok(true) => {} } - match TaskAttempt::find_by_task_id(&pool, task_id).await { + match TaskAttempt::find_by_task_id(&app_state.db_pool, task_id).await { Ok(attempts) => Ok(ResponseJson(ApiResponse { success: true, data: Some(attempts), @@ -72,10 +68,10 @@ pub async fn get_task_attempts( pub async fn get_task_attempt_activities( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -85,7 +81,9 @@ pub async fn get_task_attempt_activities( } // Get activities with prompts for the task attempt - match TaskAttemptActivity::find_with_prompts_by_task_attempt_id(&pool, attempt_id).await { + match TaskAttemptActivity::find_with_prompts_by_task_attempt_id(&app_state.db_pool, attempt_id) + .await + { Ok(activities) => Ok(ResponseJson(ApiResponse { success: true, data: Some(activities), @@ -104,12 +102,11 @@ pub async fn get_task_attempt_activities( pub async fn create_task_attempt( Path((project_id, task_id)): Path<(Uuid, Uuid)>, - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, Json(payload): Json, ) -> Result>, StatusCode> { // Verify task exists in project first - match Task::exists(&pool, task_id, project_id).await { + match Task::exists(&app_state.db_pool, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task existence: {}", e); @@ -120,7 +117,7 @@ pub async fn create_task_attempt( let executor_string = payload.executor.as_ref().map(|exec| exec.to_string()); - match TaskAttempt::create(&pool, &payload, task_id).await { + match TaskAttempt::create(&app_state.db_pool, &payload, task_id).await { Ok(attempt) => { app_state .track_analytics_event( @@ -134,12 +131,11 @@ pub async fn create_task_attempt( .await; // Start execution asynchronously (don't block the response) - let pool_clone = pool.clone(); let app_state_clone = app_state.clone(); let attempt_id = attempt.id; tokio::spawn(async move { if let Err(e) = TaskAttempt::start_execution( - &pool_clone, + &app_state_clone.db_pool, &app_state_clone, attempt_id, task_id, @@ -170,11 +166,11 @@ pub async fn create_task_attempt( pub async fn create_task_attempt_activity( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, Json(payload): Json, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -191,7 +187,7 @@ pub async fn create_task_attempt_activity( } // Verify the execution process exists and belongs to this task attempt - match ExecutionProcess::find_by_id(&pool, payload.execution_process_id).await { + match ExecutionProcess::find_by_id(&app_state.db_pool, payload.execution_process_id).await { Ok(Some(process)) => { if process.task_attempt_id != attempt_id { return Err(StatusCode::BAD_REQUEST); @@ -210,7 +206,7 @@ pub async fn create_task_attempt_activity( .clone() .unwrap_or(TaskAttemptStatus::SetupRunning); - match TaskAttemptActivity::create(&pool, &payload, id, status).await { + match TaskAttemptActivity::create(&app_state.db_pool, &payload, id, status).await { Ok(activity) => Ok(ResponseJson(ApiResponse { success: true, data: Some(activity), @@ -225,9 +221,9 @@ pub async fn create_task_attempt_activity( pub async fn get_task_attempt_diff( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match TaskAttempt::get_diff(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::get_diff(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(diff) => Ok(ResponseJson(ApiResponse { success: true, data: Some(diff), @@ -243,11 +239,10 @@ pub async fn get_task_attempt_diff( #[axum::debug_handler] pub async fn merge_task_attempt( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, - Extension(app_state): Extension>, + State(app_state): State, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -256,11 +251,11 @@ pub async fn merge_task_attempt( Ok(true) => {} } - match TaskAttempt::merge_changes(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::merge_changes(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(_) => { // Update task status to Done if let Err(e) = Task::update_status( - &pool, + &app_state.db_pool, task_id, project_id, crate::models::task::TaskStatus::Done, @@ -298,12 +293,11 @@ pub async fn merge_task_attempt( pub async fn create_github_pr( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, - Extension(app_state): Extension>, + State(app_state): State, Json(request): Json, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -341,7 +335,7 @@ pub async fn create_github_pr( .unwrap_or_else(|| "main".to_string()); match TaskAttempt::create_github_pr( - &pool, + &app_state.db_pool, CreatePrParams { attempt_id, task_id, @@ -394,12 +388,11 @@ pub struct OpenEditorRequest { pub async fn open_task_attempt_in_editor( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, - Extension(config): Extension>>, + State(app_state): State, Json(payload): Json>, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -409,7 +402,7 @@ pub async fn open_task_attempt_in_editor( } // Get the task attempt to access the worktree path - let attempt = match TaskAttempt::find_by_id(&pool, attempt_id).await { + let attempt = match TaskAttempt::find_by_id(&app_state.db_pool, attempt_id).await { Ok(Some(attempt)) => attempt, Ok(None) => return Err(StatusCode::NOT_FOUND), Err(e) => { @@ -420,7 +413,7 @@ pub async fn open_task_attempt_in_editor( // Get editor command from config or override let editor_command = { - let config_guard = config.read().await; + let config_guard = app_state.get_config().read().await; if let Some(ref request) = payload { if let Some(ref editor_type) = request.editor_type { // Create a temporary editor config with the override @@ -482,9 +475,10 @@ pub async fn open_task_attempt_in_editor( pub async fn get_task_attempt_branch_status( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match TaskAttempt::get_branch_status(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::get_branch_status(&app_state.db_pool, attempt_id, task_id, project_id).await + { Ok(status) => Ok(ResponseJson(ApiResponse { success: true, data: Some(status), @@ -504,11 +498,11 @@ pub async fn get_task_attempt_branch_status( #[axum::debug_handler] pub async fn rebase_task_attempt( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, request_body: Option>, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -520,7 +514,14 @@ pub async fn rebase_task_attempt( // Extract new base branch from request body if provided let new_base_branch = request_body.and_then(|body| body.new_base_branch.clone()); - match TaskAttempt::rebase_attempt(&pool, attempt_id, task_id, project_id, new_base_branch).await + match TaskAttempt::rebase_attempt( + &app_state.db_pool, + attempt_id, + task_id, + project_id, + new_base_branch, + ) + .await { Ok(_new_base_commit) => Ok(ResponseJson(ApiResponse { success: true, @@ -540,10 +541,10 @@ pub async fn rebase_task_attempt( pub async fn get_task_attempt_execution_processes( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -552,7 +553,8 @@ pub async fn get_task_attempt_execution_processes( Ok(true) => {} } - match ExecutionProcess::find_summaries_by_task_attempt_id(&pool, attempt_id).await { + match ExecutionProcess::find_summaries_by_task_attempt_id(&app_state.db_pool, attempt_id).await + { Ok(processes) => Ok(ResponseJson(ApiResponse { success: true, data: Some(processes), @@ -571,14 +573,14 @@ pub async fn get_task_attempt_execution_processes( pub async fn get_execution_process( Path((project_id, process_id)): Path<(Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match ExecutionProcess::find_by_id(&pool, process_id).await { + match ExecutionProcess::find_by_id(&app_state.db_pool, process_id).await { Ok(Some(process)) => { // Verify the process belongs to a task attempt in the correct project - match TaskAttempt::find_by_id(&pool, process.task_attempt_id).await { + match TaskAttempt::find_by_id(&app_state.db_pool, process.task_attempt_id).await { Ok(Some(attempt)) => { - match Task::find_by_id(&pool, attempt.task_id).await { + match Task::find_by_id(&app_state.db_pool, attempt.task_id).await { Ok(Some(task)) if task.project_id == project_id => { Ok(ResponseJson(ApiResponse { success: true, @@ -612,11 +614,10 @@ pub async fn get_execution_process( #[axum::debug_handler] pub async fn stop_all_execution_processes( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -626,17 +627,18 @@ pub async fn stop_all_execution_processes( } // Get all execution processes for the task attempt - let processes = match ExecutionProcess::find_by_task_attempt_id(&pool, attempt_id).await { - Ok(processes) => processes, - Err(e) => { - tracing::error!( - "Failed to fetch execution processes for attempt {}: {}", - attempt_id, - e - ); - return Err(StatusCode::INTERNAL_SERVER_ERROR); - } - }; + let processes = + match ExecutionProcess::find_by_task_attempt_id(&app_state.db_pool, attempt_id).await { + Ok(processes) => processes, + Err(e) => { + tracing::error!( + "Failed to fetch execution processes for attempt {}: {}", + attempt_id, + e + ); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + }; let mut stopped_count = 0; let mut errors = Vec::new(); @@ -649,7 +651,7 @@ pub async fn stop_all_execution_processes( // Update the execution process status in the database if let Err(e) = ExecutionProcess::update_completion( - &pool, + &app_state.db_pool, process.id, crate::models::execution_process::ExecutionProcessStatus::Killed, None, @@ -675,7 +677,7 @@ pub async fn stop_all_execution_processes( }; if let Err(e) = TaskAttemptActivity::create( - &pool, + &app_state.db_pool, &create_activity, activity_id, TaskAttemptStatus::ExecutorFailed, @@ -734,11 +736,10 @@ pub async fn stop_all_execution_processes( #[axum::debug_handler] pub async fn stop_execution_process( Path((project_id, task_id, attempt_id, process_id)): Path<(Uuid, Uuid, Uuid, Uuid)>, - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -748,7 +749,7 @@ pub async fn stop_execution_process( } // Verify execution process exists and belongs to the task attempt - let process = match ExecutionProcess::find_by_id(&pool, process_id).await { + let process = match ExecutionProcess::find_by_id(&app_state.db_pool, process_id).await { Ok(Some(process)) if process.task_attempt_id == attempt_id => process, Ok(Some(_)) => return Err(StatusCode::NOT_FOUND), // Process exists but wrong attempt Ok(None) => return Err(StatusCode::NOT_FOUND), @@ -777,7 +778,7 @@ pub async fn stop_execution_process( // Update the execution process status in the database if let Err(e) = ExecutionProcess::update_completion( - &pool, + &app_state.db_pool, process_id, crate::models::execution_process::ExecutionProcessStatus::Killed, None, @@ -804,7 +805,7 @@ pub async fn stop_execution_process( }; if let Err(e) = TaskAttemptActivity::create( - &pool, + &app_state.db_pool, &create_activity, activity_id, TaskAttemptStatus::ExecutorFailed, @@ -835,10 +836,10 @@ pub struct DeleteFileQuery { pub async fn delete_task_attempt_file( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, Query(query): Query, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -847,7 +848,15 @@ pub async fn delete_task_attempt_file( Ok(true) => {} } - match TaskAttempt::delete_file(&pool, attempt_id, task_id, project_id, &query.file_path).await { + match TaskAttempt::delete_file( + &app_state.db_pool, + attempt_id, + task_id, + project_id, + &query.file_path, + ) + .await + { Ok(_commit_id) => Ok(ResponseJson(ApiResponse { success: true, data: None, @@ -871,12 +880,11 @@ pub async fn delete_task_attempt_file( pub async fn create_followup_attempt( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, Json(payload): Json, ) -> Result>, StatusCode> { // Verify task attempt exists - if !TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id) + if !TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id) .await .map_err(|e| { tracing::error!("Failed to check task attempt existence: {}", e); @@ -888,7 +896,7 @@ pub async fn create_followup_attempt( // Start follow-up execution synchronously to catch errors match TaskAttempt::start_followup_execution( - &pool, + &app_state.db_pool, &app_state, attempt_id, task_id, @@ -915,11 +923,10 @@ pub async fn create_followup_attempt( pub async fn start_dev_server( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -930,7 +937,9 @@ pub async fn start_dev_server( // Stop any existing dev servers for this project let existing_dev_servers = - match ExecutionProcess::find_running_dev_servers_by_project(&pool, project_id).await { + match ExecutionProcess::find_running_dev_servers_by_project(&app_state.db_pool, project_id) + .await + { Ok(servers) => servers, Err(e) => { tracing::error!( @@ -955,7 +964,7 @@ pub async fn start_dev_server( } else { // Update the execution process status in the database if let Err(e) = ExecutionProcess::update_completion( - &pool, + &app_state.db_pool, dev_server.id, crate::models::execution_process::ExecutionProcessStatus::Killed, None, @@ -972,7 +981,15 @@ pub async fn start_dev_server( } // Start dev server execution - match TaskAttempt::start_dev_server(&pool, &app_state, attempt_id, task_id, project_id).await { + match TaskAttempt::start_dev_server( + &app_state.db_pool, + &app_state, + attempt_id, + task_id, + project_id, + ) + .await + { Ok(_) => Ok(ResponseJson(ApiResponse { success: true, data: None, @@ -995,10 +1012,10 @@ pub async fn start_dev_server( pub async fn get_task_attempt_execution_state( Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { // Verify task attempt exists and belongs to the correct task - match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check task attempt existence: {}", e); @@ -1008,7 +1025,9 @@ pub async fn get_task_attempt_execution_state( } // Get the execution state - match TaskAttempt::get_execution_state(&pool, attempt_id, task_id, project_id).await { + match TaskAttempt::get_execution_state(&app_state.db_pool, attempt_id, task_id, project_id) + .await + { Ok(state) => Ok(ResponseJson(ApiResponse { success: true, data: Some(state), @@ -1027,10 +1046,10 @@ pub async fn get_task_attempt_execution_state( pub async fn get_execution_process_normalized_logs( Path((project_id, process_id)): Path<(Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { // Get the execution process and verify it belongs to the correct project - let process = match ExecutionProcess::find_by_id(&pool, process_id).await { + let process = match ExecutionProcess::find_by_id(&app_state.db_pool, process_id).await { Ok(Some(process)) => process, Ok(None) => return Err(StatusCode::NOT_FOUND), Err(e) => { @@ -1040,7 +1059,7 @@ pub async fn get_execution_process_normalized_logs( }; // Verify the process belongs to a task attempt in the correct project - let attempt = match TaskAttempt::find_by_id(&pool, process.task_attempt_id).await { + let attempt = match TaskAttempt::find_by_id(&app_state.db_pool, process.task_attempt_id).await { Ok(Some(attempt)) => attempt, Ok(None) => return Err(StatusCode::NOT_FOUND), Err(e) => { @@ -1049,7 +1068,7 @@ pub async fn get_execution_process_normalized_logs( } }; - let _task = match Task::find_by_id(&pool, attempt.task_id).await { + let _task = match Task::find_by_id(&app_state.db_pool, attempt.task_id).await { Ok(Some(task)) if task.project_id == project_id => task, Ok(Some(_)) => return Err(StatusCode::NOT_FOUND), // Wrong project Ok(None) => return Err(StatusCode::NOT_FOUND), @@ -1066,18 +1085,22 @@ pub async fn get_execution_process_normalized_logs( // If the process is still running, return empty logs instead of an error if process.status == ExecutionProcessStatus::Running { // Get executor session data for this execution process - let executor_session = - match ExecutorSession::find_by_execution_process_id(&pool, process_id).await { - Ok(session) => session, - Err(e) => { - tracing::error!( - "Failed to fetch executor session for process {}: {}", - process_id, - e - ); - None - } - }; + let executor_session = match ExecutorSession::find_by_execution_process_id( + &app_state.db_pool, + process_id, + ) + .await + { + Ok(session) => session, + Err(e) => { + tracing::error!( + "Failed to fetch executor session for process {}: {}", + process_id, + e + ); + None + } + }; return Ok(ResponseJson(ApiResponse { success: true, @@ -1124,7 +1147,7 @@ pub async fn get_execution_process_normalized_logs( // Get executor session data for this execution process let executor_session = - match ExecutorSession::find_by_execution_process_id(&pool, process_id).await { + match ExecutorSession::find_by_execution_process_id(&app_state.db_pool, process_id).await { Ok(session) => session, Err(e) => { tracing::error!( @@ -1162,7 +1185,7 @@ pub async fn get_execution_process_normalized_logs( } } -pub fn task_attempts_router() -> Router { +pub fn task_attempts_router() -> Router { use axum::routing::post; Router::new() diff --git a/backend/src/routes/tasks.rs b/backend/src/routes/tasks.rs index d5e3e872..5fe4784a 100644 --- a/backend/src/routes/tasks.rs +++ b/backend/src/routes/tasks.rs @@ -1,25 +1,27 @@ use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, http::StatusCode, response::Json as ResponseJson, routing::get, Json, Router, }; -use sqlx::SqlitePool; use uuid::Uuid; -use crate::models::{ - project::Project, - task::{CreateTask, CreateTaskAndStart, Task, TaskWithAttemptStatus, UpdateTask}, - task_attempt::{CreateTaskAttempt, TaskAttempt}, - ApiResponse, +use crate::{ + app_state::AppState, + models::{ + project::Project, + task::{CreateTask, CreateTaskAndStart, Task, TaskWithAttemptStatus, UpdateTask}, + task_attempt::{CreateTaskAttempt, TaskAttempt}, + ApiResponse, + }, }; pub async fn get_project_tasks( Path(project_id): Path, - Extension(pool): Extension, + State(app_state): State, ) -> Result>>, StatusCode> { - match Task::find_by_project_id_with_attempt_status(&pool, project_id).await { + match Task::find_by_project_id_with_attempt_status(&app_state.db_pool, project_id).await { Ok(tasks) => Ok(ResponseJson(ApiResponse { success: true, data: Some(tasks), @@ -34,9 +36,9 @@ pub async fn get_project_tasks( pub async fn get_task( Path((project_id, task_id)): Path<(Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match Task::find_by_id_and_project_id(&pool, task_id, project_id).await { + match Task::find_by_id_and_project_id(&app_state.db_pool, task_id, project_id).await { Ok(Some(task)) => Ok(ResponseJson(ApiResponse { success: true, data: Some(task), @@ -57,8 +59,7 @@ pub async fn get_task( pub async fn create_task( Path(project_id): Path, - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, Json(mut payload): Json, ) -> Result>, StatusCode> { let id = Uuid::new_v4(); @@ -67,7 +68,7 @@ pub async fn create_task( payload.project_id = project_id; // Verify project exists first - match Project::exists(&pool, project_id).await { + match Project::exists(&app_state.db_pool, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check project existence: {}", e); @@ -82,7 +83,7 @@ pub async fn create_task( project_id ); - match Task::create(&pool, &payload, id).await { + match Task::create(&app_state.db_pool, &payload, id).await { Ok(task) => { // Track task creation event app_state @@ -111,8 +112,7 @@ pub async fn create_task( pub async fn create_task_and_start( Path(project_id): Path, - Extension(pool): Extension, - Extension(app_state): Extension, + State(app_state): State, Json(mut payload): Json, ) -> Result>, StatusCode> { let task_id = Uuid::new_v4(); @@ -121,7 +121,7 @@ pub async fn create_task_and_start( payload.project_id = project_id; // Verify project exists first - match Project::exists(&pool, project_id).await { + match Project::exists(&app_state.db_pool, project_id).await { Ok(false) => return Err(StatusCode::NOT_FOUND), Err(e) => { tracing::error!("Failed to check project existence: {}", e); @@ -142,7 +142,7 @@ pub async fn create_task_and_start( title: payload.title.clone(), description: payload.description.clone(), }; - let task = match Task::create(&pool, &create_task_payload, task_id).await { + let task = match Task::create(&app_state.db_pool, &create_task_payload, task_id).await { Ok(task) => task, Err(e) => { tracing::error!("Failed to create task: {}", e); @@ -157,7 +157,7 @@ pub async fn create_task_and_start( base_branch: None, // Not supported in task creation endpoint, only in task attempts }; - match TaskAttempt::create(&pool, &attempt_payload, task_id).await { + match TaskAttempt::create(&app_state.db_pool, &attempt_payload, task_id).await { Ok(attempt) => { app_state .track_analytics_event( @@ -182,12 +182,11 @@ pub async fn create_task_and_start( .await; // Start execution asynchronously (don't block the response) - let pool_clone = pool.clone(); let app_state_clone = app_state.clone(); let attempt_id = attempt.id; tokio::spawn(async move { if let Err(e) = TaskAttempt::start_execution( - &pool_clone, + &app_state_clone.db_pool, &app_state_clone, attempt_id, task_id, @@ -218,25 +217,35 @@ pub async fn create_task_and_start( pub async fn update_task( Path((project_id, task_id)): Path<(Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, Json(payload): Json, ) -> Result>, StatusCode> { // Check if task exists in the specified project - let existing_task = match Task::find_by_id_and_project_id(&pool, task_id, project_id).await { - Ok(Some(task)) => task, - Ok(None) => return Err(StatusCode::NOT_FOUND), - Err(e) => { - tracing::error!("Failed to check task existence: {}", e); - return Err(StatusCode::INTERNAL_SERVER_ERROR); - } - }; + let existing_task = + match Task::find_by_id_and_project_id(&app_state.db_pool, task_id, project_id).await { + Ok(Some(task)) => task, + Ok(None) => return Err(StatusCode::NOT_FOUND), + Err(e) => { + tracing::error!("Failed to check task existence: {}", e); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + }; // Use existing values if not provided in update let title = payload.title.unwrap_or(existing_task.title); let description = payload.description.or(existing_task.description); let status = payload.status.unwrap_or(existing_task.status); - match Task::update(&pool, task_id, project_id, title, description, status).await { + match Task::update( + &app_state.db_pool, + task_id, + project_id, + title, + description, + status, + ) + .await + { Ok(task) => Ok(ResponseJson(ApiResponse { success: true, data: Some(task), @@ -251,9 +260,9 @@ pub async fn update_task( pub async fn delete_task( Path((project_id, task_id)): Path<(Uuid, Uuid)>, - Extension(pool): Extension, + State(app_state): State, ) -> Result>, StatusCode> { - match Task::delete(&pool, task_id, project_id).await { + match Task::delete(&app_state.db_pool, task_id, project_id).await { Ok(rows_affected) => { if rows_affected == 0 { Err(StatusCode::NOT_FOUND) @@ -272,7 +281,7 @@ pub async fn delete_task( } } -pub fn tasks_router() -> Router { +pub fn tasks_router() -> Router { use axum::routing::post; Router::new()