Refactor app state (#85)

* Init

* Refactor project endpoints

* Remaining endpoints

* Fmt
This commit is contained in:
Louis Knight-Webb
2025-07-07 14:11:16 +01:00
committed by GitHub
parent 0e40c09b0d
commit 53a3bab0c0
7 changed files with 243 additions and 199 deletions

View File

@@ -227,6 +227,10 @@ impl AppState {
config.sound_file.clone()
}
pub fn get_config(&self) -> &Arc<tokio::sync::RwLock<crate::models::config::Config>> {
&self.config
}
pub async fn track_analytics_event(
&self,
event_name: &str,

View File

@@ -2,7 +2,6 @@ use std::{str::FromStr, sync::Arc};
use axum::{
body::Body,
extract::Extension,
http::{header, HeaderValue, StatusCode},
response::{IntoResponse, Json as ResponseJson, Response},
routing::{get, post},
@@ -198,9 +197,7 @@ fn main() -> anyhow::Result<()> {
.merge(filesystem::filesystem_router())
.merge(config::config_router())
.route("/sounds/:filename", get(serve_sound_file)),
)
.layer(Extension(pool.clone()))
.layer(Extension(config_arc));
);
let app = Router::new()
.merge(public_routes)
@@ -208,8 +205,7 @@ fn main() -> anyhow::Result<()> {
// Static file serving routes
.route("/", get(index_handler))
.route("/*path", get(static_handler))
.layer(Extension(pool))
.layer(Extension(app_state))
.with_state(app_state)
.layer(CorsLayer::permissive())
.layer(NewSentryLayer::new_from_top());

View File

@@ -1,17 +1,18 @@
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use axum::{
extract::{Extension, Query},
extract::{Query, State},
response::Json as ResponseJson,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::{fs, sync::RwLock};
use tokio::fs;
use ts_rs::TS;
use crate::{
app_state::AppState,
executor::ExecutorConfig,
models::{
config::{Config, EditorConstants, SoundConstants},
@@ -20,7 +21,7 @@ use crate::{
utils,
};
pub fn config_router() -> Router {
pub fn config_router() -> Router<AppState> {
Router::new()
.route("/config", get(get_config))
.route("/config", post(update_config))
@@ -29,10 +30,8 @@ pub fn config_router() -> Router {
.route("/mcp-servers", post(update_mcp_servers))
}
async fn get_config(
Extension(config): Extension<Arc<RwLock<Config>>>,
) -> ResponseJson<ApiResponse<Config>> {
let config = config.read().await;
async fn get_config(State(app_state): State<AppState>) -> ResponseJson<ApiResponse<Config>> {
let config = app_state.get_config().read().await;
ResponseJson(ApiResponse {
success: true,
data: Some(config.clone()),
@@ -41,15 +40,14 @@ async fn get_config(
}
async fn update_config(
Extension(config_arc): Extension<Arc<RwLock<Config>>>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
Json(new_config): Json<Config>,
) -> ResponseJson<ApiResponse<Config>> {
let config_path = utils::config_path();
match new_config.save(&config_path) {
Ok(_) => {
let mut config = config_arc.write().await;
let mut config = app_state.get_config().write().await;
*config = new_config.clone();
drop(config);
@@ -119,11 +117,11 @@ fn resolve_executor_config(
}
async fn get_mcp_servers(
Extension(config): Extension<Arc<RwLock<Config>>>,
State(app_state): State<AppState>,
Query(query): Query<McpServerQuery>,
) -> ResponseJson<ApiResponse<Value>> {
let saved_config = {
let config = config.read().await;
let config = app_state.get_config().read().await;
config.executor.clone()
};
@@ -171,12 +169,12 @@ async fn get_mcp_servers(
}
async fn update_mcp_servers(
Extension(config): Extension<Arc<RwLock<Config>>>,
State(app_state): State<AppState>,
Query(query): Query<McpServerQuery>,
Json(new_servers): Json<HashMap<String, Value>>,
) -> ResponseJson<ApiResponse<String>> {
let saved_config = {
let config = config.read().await;
let config = app_state.get_config().read().await;
config.executor.clone()
};

View File

@@ -9,7 +9,7 @@ use axum::{
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::models::ApiResponse;
use crate::{app_state::AppState, models::ApiResponse};
#[derive(Debug, Serialize, TS)]
#[ts(export)]
@@ -199,7 +199,7 @@ pub async fn create_git_repo(
}
}
pub fn filesystem_router() -> Router {
pub fn filesystem_router() -> Router<AppState> {
Router::new()
.route("/filesystem/list", get(list_directory))
.route("/filesystem/validate-git", get(validate_git_path))

View File

@@ -1,28 +1,29 @@
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use axum::{
extract::{Extension, Path, Query},
extract::{Path, Query, State},
http::StatusCode,
response::Json as ResponseJson,
routing::get,
Json, Router,
};
use sqlx::SqlitePool;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::models::{
project::{
CreateBranch, CreateProject, GitBranch, Project, ProjectWithBranch, SearchMatchType,
SearchResult, UpdateProject,
use crate::{
app_state::AppState,
models::{
project::{
CreateBranch, CreateProject, GitBranch, Project, ProjectWithBranch, SearchMatchType,
SearchResult, UpdateProject,
},
ApiResponse,
},
ApiResponse,
};
pub async fn get_projects(
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Vec<Project>>>, StatusCode> {
match Project::find_all(&pool).await {
match Project::find_all(&app_state.db_pool).await {
Ok(projects) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(projects),
@@ -37,9 +38,9 @@ pub async fn get_projects(
pub async fn get_project(
Path(id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Project>>, StatusCode> {
match Project::find_by_id(&pool, id).await {
match Project::find_by_id(&app_state.db_pool, id).await {
Ok(Some(project)) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(project),
@@ -55,9 +56,9 @@ pub async fn get_project(
pub async fn get_project_with_branch(
Path(id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<ProjectWithBranch>>, StatusCode> {
match Project::find_by_id(&pool, id).await {
match Project::find_by_id(&app_state.db_pool, id).await {
Ok(Some(project)) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(project.with_branch_info()),
@@ -73,9 +74,9 @@ pub async fn get_project_with_branch(
pub async fn get_project_branches(
Path(id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Vec<GitBranch>>>, StatusCode> {
match Project::find_by_id(&pool, id).await {
match Project::find_by_id(&app_state.db_pool, id).await {
Ok(Some(project)) => match project.get_all_branches() {
Ok(branches) => Ok(ResponseJson(ApiResponse {
success: true,
@@ -97,7 +98,7 @@ pub async fn get_project_branches(
pub async fn create_project_branch(
Path(id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
Json(payload): Json<CreateBranch>,
) -> Result<ResponseJson<ApiResponse<GitBranch>>, StatusCode> {
// Validate branch name
@@ -118,7 +119,7 @@ pub async fn create_project_branch(
}));
}
match Project::find_by_id(&pool, id).await {
match Project::find_by_id(&app_state.db_pool, id).await {
Ok(Some(project)) => {
match project.create_branch(&payload.name, payload.base_branch.as_deref()) {
Ok(branch) => Ok(ResponseJson(ApiResponse {
@@ -150,8 +151,7 @@ pub async fn create_project_branch(
}
pub async fn create_project(
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
Json(payload): Json<CreateProject>,
) -> Result<ResponseJson<ApiResponse<Project>>, StatusCode> {
let id = Uuid::new_v4();
@@ -159,7 +159,7 @@ pub async fn create_project(
tracing::debug!("Creating project '{}'", payload.name);
// Check if git repo path is already used by another project
match Project::find_by_git_repo_path(&pool, &payload.git_repo_path).await {
match Project::find_by_git_repo_path(&app_state.db_pool, &payload.git_repo_path).await {
Ok(Some(_)) => {
return Ok(ResponseJson(ApiResponse {
success: false,
@@ -249,7 +249,7 @@ pub async fn create_project(
}
}
match Project::create(&pool, &payload, id).await {
match Project::create(&app_state.db_pool, &payload, id).await {
Ok(project) => {
// Track project creation event
app_state
@@ -279,11 +279,11 @@ pub async fn create_project(
pub async fn update_project(
Path(id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
Json(payload): Json<UpdateProject>,
) -> Result<ResponseJson<ApiResponse<Project>>, StatusCode> {
// Check if project exists first
let existing_project = match Project::find_by_id(&pool, id).await {
let existing_project = match Project::find_by_id(&app_state.db_pool, id).await {
Ok(Some(project)) => project,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
@@ -295,7 +295,13 @@ pub async fn update_project(
// If git_repo_path is being changed, check if the new path is already used by another project
if let Some(new_git_repo_path) = &payload.git_repo_path {
if new_git_repo_path != &existing_project.git_repo_path {
match Project::find_by_git_repo_path_excluding_id(&pool, new_git_repo_path, id).await {
match Project::find_by_git_repo_path_excluding_id(
&app_state.db_pool,
new_git_repo_path,
id,
)
.await
{
Ok(Some(_)) => {
return Ok(ResponseJson(ApiResponse {
success: false,
@@ -329,7 +335,16 @@ pub async fn update_project(
let name = name.unwrap_or(existing_project.name);
let git_repo_path = git_repo_path.unwrap_or(existing_project.git_repo_path);
match Project::update(&pool, id, name, git_repo_path, setup_script, dev_script).await {
match Project::update(
&app_state.db_pool,
id,
name,
git_repo_path,
setup_script,
dev_script,
)
.await
{
Ok(project) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(project),
@@ -344,9 +359,9 @@ pub async fn update_project(
pub async fn delete_project(
Path(id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
match Project::delete(&pool, id).await {
match Project::delete(&app_state.db_pool, id).await {
Ok(rows_affected) => {
if rows_affected == 0 {
Err(StatusCode::NOT_FOUND)
@@ -372,12 +387,11 @@ pub struct OpenEditorRequest {
pub async fn open_project_in_editor(
Path(id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
Extension(config): Extension<Arc<RwLock<crate::models::config::Config>>>,
State(app_state): State<AppState>,
Json(payload): Json<Option<OpenEditorRequest>>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Get the project
let project = match Project::find_by_id(&pool, id).await {
let project = match Project::find_by_id(&app_state.db_pool, id).await {
Ok(Some(project)) => project,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
@@ -388,7 +402,7 @@ pub async fn open_project_in_editor(
// Get editor command from config or override
let editor_command = {
let config_guard = config.read().await;
let config_guard = app_state.get_config().read().await;
if let Some(ref request) = payload {
if let Some(ref editor_type) = request.editor_type {
// Create a temporary editor config with the override
@@ -451,7 +465,7 @@ pub async fn open_project_in_editor(
pub async fn search_project_files(
Path(id): Path<Uuid>,
Query(params): Query<HashMap<String, String>>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Vec<SearchResult>>>, StatusCode> {
let query = match params.get("q") {
Some(q) if !q.trim().is_empty() => q.trim(),
@@ -465,7 +479,7 @@ pub async fn search_project_files(
};
// Check if project exists
let project = match Project::find_by_id(&pool, id).await {
let project = match Project::find_by_id(&app_state.db_pool, id).await {
Ok(Some(project)) => project,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
@@ -587,7 +601,7 @@ async fn search_files_in_repo(
Ok(results)
}
pub fn projects_router() -> Router {
pub fn projects_router() -> Router<AppState> {
use axum::routing::post;
Router::new()

View File

@@ -1,15 +1,11 @@
use std::sync::Arc;
use axum::{
extract::{Extension, Path, Query},
extract::{Path, Query, State},
http::StatusCode,
response::Json as ResponseJson,
routing::get,
Json, Router,
};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::{
@@ -45,10 +41,10 @@ pub struct CreateGitHubPRRequest {
pub async fn get_task_attempts(
Path((project_id, task_id)): Path<(Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Vec<TaskAttempt>>>, StatusCode> {
// Verify task exists in project first
match Task::exists(&pool, task_id, project_id).await {
match Task::exists(&app_state.db_pool, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task existence: {}", e);
@@ -57,7 +53,7 @@ pub async fn get_task_attempts(
Ok(true) => {}
}
match TaskAttempt::find_by_task_id(&pool, task_id).await {
match TaskAttempt::find_by_task_id(&app_state.db_pool, task_id).await {
Ok(attempts) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(attempts),
@@ -72,10 +68,10 @@ pub async fn get_task_attempts(
pub async fn get_task_attempt_activities(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Vec<TaskAttemptActivityWithPrompt>>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -85,7 +81,9 @@ pub async fn get_task_attempt_activities(
}
// Get activities with prompts for the task attempt
match TaskAttemptActivity::find_with_prompts_by_task_attempt_id(&pool, attempt_id).await {
match TaskAttemptActivity::find_with_prompts_by_task_attempt_id(&app_state.db_pool, attempt_id)
.await
{
Ok(activities) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(activities),
@@ -104,12 +102,11 @@ pub async fn get_task_attempt_activities(
pub async fn create_task_attempt(
Path((project_id, task_id)): Path<(Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
Json(payload): Json<CreateTaskAttempt>,
) -> Result<ResponseJson<ApiResponse<TaskAttempt>>, StatusCode> {
// Verify task exists in project first
match Task::exists(&pool, task_id, project_id).await {
match Task::exists(&app_state.db_pool, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task existence: {}", e);
@@ -120,7 +117,7 @@ pub async fn create_task_attempt(
let executor_string = payload.executor.as_ref().map(|exec| exec.to_string());
match TaskAttempt::create(&pool, &payload, task_id).await {
match TaskAttempt::create(&app_state.db_pool, &payload, task_id).await {
Ok(attempt) => {
app_state
.track_analytics_event(
@@ -134,12 +131,11 @@ pub async fn create_task_attempt(
.await;
// Start execution asynchronously (don't block the response)
let pool_clone = pool.clone();
let app_state_clone = app_state.clone();
let attempt_id = attempt.id;
tokio::spawn(async move {
if let Err(e) = TaskAttempt::start_execution(
&pool_clone,
&app_state_clone.db_pool,
&app_state_clone,
attempt_id,
task_id,
@@ -170,11 +166,11 @@ pub async fn create_task_attempt(
pub async fn create_task_attempt_activity(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
Json(payload): Json<CreateTaskAttemptActivity>,
) -> Result<ResponseJson<ApiResponse<TaskAttemptActivity>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -191,7 +187,7 @@ pub async fn create_task_attempt_activity(
}
// Verify the execution process exists and belongs to this task attempt
match ExecutionProcess::find_by_id(&pool, payload.execution_process_id).await {
match ExecutionProcess::find_by_id(&app_state.db_pool, payload.execution_process_id).await {
Ok(Some(process)) => {
if process.task_attempt_id != attempt_id {
return Err(StatusCode::BAD_REQUEST);
@@ -210,7 +206,7 @@ pub async fn create_task_attempt_activity(
.clone()
.unwrap_or(TaskAttemptStatus::SetupRunning);
match TaskAttemptActivity::create(&pool, &payload, id, status).await {
match TaskAttemptActivity::create(&app_state.db_pool, &payload, id, status).await {
Ok(activity) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(activity),
@@ -225,9 +221,9 @@ pub async fn create_task_attempt_activity(
pub async fn get_task_attempt_diff(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<WorktreeDiff>>, StatusCode> {
match TaskAttempt::get_diff(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::get_diff(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(diff) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(diff),
@@ -243,11 +239,10 @@ pub async fn get_task_attempt_diff(
#[axum::debug_handler]
pub async fn merge_task_attempt(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<Arc<AppState>>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -256,11 +251,11 @@ pub async fn merge_task_attempt(
Ok(true) => {}
}
match TaskAttempt::merge_changes(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::merge_changes(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(_) => {
// Update task status to Done
if let Err(e) = Task::update_status(
&pool,
&app_state.db_pool,
task_id,
project_id,
crate::models::task::TaskStatus::Done,
@@ -298,12 +293,11 @@ pub async fn merge_task_attempt(
pub async fn create_github_pr(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<Arc<AppState>>,
State(app_state): State<AppState>,
Json(request): Json<CreateGitHubPRRequest>,
) -> Result<ResponseJson<ApiResponse<String>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -341,7 +335,7 @@ pub async fn create_github_pr(
.unwrap_or_else(|| "main".to_string());
match TaskAttempt::create_github_pr(
&pool,
&app_state.db_pool,
CreatePrParams {
attempt_id,
task_id,
@@ -394,12 +388,11 @@ pub struct OpenEditorRequest {
pub async fn open_task_attempt_in_editor(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(config): Extension<Arc<RwLock<crate::models::config::Config>>>,
State(app_state): State<AppState>,
Json(payload): Json<Option<OpenEditorRequest>>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -409,7 +402,7 @@ pub async fn open_task_attempt_in_editor(
}
// Get the task attempt to access the worktree path
let attempt = match TaskAttempt::find_by_id(&pool, attempt_id).await {
let attempt = match TaskAttempt::find_by_id(&app_state.db_pool, attempt_id).await {
Ok(Some(attempt)) => attempt,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
@@ -420,7 +413,7 @@ pub async fn open_task_attempt_in_editor(
// Get editor command from config or override
let editor_command = {
let config_guard = config.read().await;
let config_guard = app_state.get_config().read().await;
if let Some(ref request) = payload {
if let Some(ref editor_type) = request.editor_type {
// Create a temporary editor config with the override
@@ -482,9 +475,10 @@ pub async fn open_task_attempt_in_editor(
pub async fn get_task_attempt_branch_status(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<BranchStatus>>, StatusCode> {
match TaskAttempt::get_branch_status(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::get_branch_status(&app_state.db_pool, attempt_id, task_id, project_id).await
{
Ok(status) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(status),
@@ -504,11 +498,11 @@ pub async fn get_task_attempt_branch_status(
#[axum::debug_handler]
pub async fn rebase_task_attempt(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
request_body: Option<Json<RebaseTaskAttemptRequest>>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -520,7 +514,14 @@ pub async fn rebase_task_attempt(
// Extract new base branch from request body if provided
let new_base_branch = request_body.and_then(|body| body.new_base_branch.clone());
match TaskAttempt::rebase_attempt(&pool, attempt_id, task_id, project_id, new_base_branch).await
match TaskAttempt::rebase_attempt(
&app_state.db_pool,
attempt_id,
task_id,
project_id,
new_base_branch,
)
.await
{
Ok(_new_base_commit) => Ok(ResponseJson(ApiResponse {
success: true,
@@ -540,10 +541,10 @@ pub async fn rebase_task_attempt(
pub async fn get_task_attempt_execution_processes(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Vec<ExecutionProcessSummary>>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -552,7 +553,8 @@ pub async fn get_task_attempt_execution_processes(
Ok(true) => {}
}
match ExecutionProcess::find_summaries_by_task_attempt_id(&pool, attempt_id).await {
match ExecutionProcess::find_summaries_by_task_attempt_id(&app_state.db_pool, attempt_id).await
{
Ok(processes) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(processes),
@@ -571,14 +573,14 @@ pub async fn get_task_attempt_execution_processes(
pub async fn get_execution_process(
Path((project_id, process_id)): Path<(Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<ExecutionProcess>>, StatusCode> {
match ExecutionProcess::find_by_id(&pool, process_id).await {
match ExecutionProcess::find_by_id(&app_state.db_pool, process_id).await {
Ok(Some(process)) => {
// Verify the process belongs to a task attempt in the correct project
match TaskAttempt::find_by_id(&pool, process.task_attempt_id).await {
match TaskAttempt::find_by_id(&app_state.db_pool, process.task_attempt_id).await {
Ok(Some(attempt)) => {
match Task::find_by_id(&pool, attempt.task_id).await {
match Task::find_by_id(&app_state.db_pool, attempt.task_id).await {
Ok(Some(task)) if task.project_id == project_id => {
Ok(ResponseJson(ApiResponse {
success: true,
@@ -612,11 +614,10 @@ pub async fn get_execution_process(
#[axum::debug_handler]
pub async fn stop_all_execution_processes(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -626,17 +627,18 @@ pub async fn stop_all_execution_processes(
}
// Get all execution processes for the task attempt
let processes = match ExecutionProcess::find_by_task_attempt_id(&pool, attempt_id).await {
Ok(processes) => processes,
Err(e) => {
tracing::error!(
"Failed to fetch execution processes for attempt {}: {}",
attempt_id,
e
);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
let processes =
match ExecutionProcess::find_by_task_attempt_id(&app_state.db_pool, attempt_id).await {
Ok(processes) => processes,
Err(e) => {
tracing::error!(
"Failed to fetch execution processes for attempt {}: {}",
attempt_id,
e
);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
let mut stopped_count = 0;
let mut errors = Vec::new();
@@ -649,7 +651,7 @@ pub async fn stop_all_execution_processes(
// Update the execution process status in the database
if let Err(e) = ExecutionProcess::update_completion(
&pool,
&app_state.db_pool,
process.id,
crate::models::execution_process::ExecutionProcessStatus::Killed,
None,
@@ -675,7 +677,7 @@ pub async fn stop_all_execution_processes(
};
if let Err(e) = TaskAttemptActivity::create(
&pool,
&app_state.db_pool,
&create_activity,
activity_id,
TaskAttemptStatus::ExecutorFailed,
@@ -734,11 +736,10 @@ pub async fn stop_all_execution_processes(
#[axum::debug_handler]
pub async fn stop_execution_process(
Path((project_id, task_id, attempt_id, process_id)): Path<(Uuid, Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -748,7 +749,7 @@ pub async fn stop_execution_process(
}
// Verify execution process exists and belongs to the task attempt
let process = match ExecutionProcess::find_by_id(&pool, process_id).await {
let process = match ExecutionProcess::find_by_id(&app_state.db_pool, process_id).await {
Ok(Some(process)) if process.task_attempt_id == attempt_id => process,
Ok(Some(_)) => return Err(StatusCode::NOT_FOUND), // Process exists but wrong attempt
Ok(None) => return Err(StatusCode::NOT_FOUND),
@@ -777,7 +778,7 @@ pub async fn stop_execution_process(
// Update the execution process status in the database
if let Err(e) = ExecutionProcess::update_completion(
&pool,
&app_state.db_pool,
process_id,
crate::models::execution_process::ExecutionProcessStatus::Killed,
None,
@@ -804,7 +805,7 @@ pub async fn stop_execution_process(
};
if let Err(e) = TaskAttemptActivity::create(
&pool,
&app_state.db_pool,
&create_activity,
activity_id,
TaskAttemptStatus::ExecutorFailed,
@@ -835,10 +836,10 @@ pub struct DeleteFileQuery {
pub async fn delete_task_attempt_file(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Query(query): Query<DeleteFileQuery>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -847,7 +848,15 @@ pub async fn delete_task_attempt_file(
Ok(true) => {}
}
match TaskAttempt::delete_file(&pool, attempt_id, task_id, project_id, &query.file_path).await {
match TaskAttempt::delete_file(
&app_state.db_pool,
attempt_id,
task_id,
project_id,
&query.file_path,
)
.await
{
Ok(_commit_id) => Ok(ResponseJson(ApiResponse {
success: true,
data: None,
@@ -871,12 +880,11 @@ pub async fn delete_task_attempt_file(
pub async fn create_followup_attempt(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
Json(payload): Json<CreateFollowUpAttempt>,
) -> Result<ResponseJson<ApiResponse<String>>, StatusCode> {
// Verify task attempt exists
if !TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id)
if !TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id)
.await
.map_err(|e| {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -888,7 +896,7 @@ pub async fn create_followup_attempt(
// Start follow-up execution synchronously to catch errors
match TaskAttempt::start_followup_execution(
&pool,
&app_state.db_pool,
&app_state,
attempt_id,
task_id,
@@ -915,11 +923,10 @@ pub async fn create_followup_attempt(
pub async fn start_dev_server(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -930,7 +937,9 @@ pub async fn start_dev_server(
// Stop any existing dev servers for this project
let existing_dev_servers =
match ExecutionProcess::find_running_dev_servers_by_project(&pool, project_id).await {
match ExecutionProcess::find_running_dev_servers_by_project(&app_state.db_pool, project_id)
.await
{
Ok(servers) => servers,
Err(e) => {
tracing::error!(
@@ -955,7 +964,7 @@ pub async fn start_dev_server(
} else {
// Update the execution process status in the database
if let Err(e) = ExecutionProcess::update_completion(
&pool,
&app_state.db_pool,
dev_server.id,
crate::models::execution_process::ExecutionProcessStatus::Killed,
None,
@@ -972,7 +981,15 @@ pub async fn start_dev_server(
}
// Start dev server execution
match TaskAttempt::start_dev_server(&pool, &app_state, attempt_id, task_id, project_id).await {
match TaskAttempt::start_dev_server(
&app_state.db_pool,
&app_state,
attempt_id,
task_id,
project_id,
)
.await
{
Ok(_) => Ok(ResponseJson(ApiResponse {
success: true,
data: None,
@@ -995,10 +1012,10 @@ pub async fn start_dev_server(
pub async fn get_task_attempt_execution_state(
Path((project_id, task_id, attempt_id)): Path<(Uuid, Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<TaskAttemptState>>, StatusCode> {
// Verify task attempt exists and belongs to the correct task
match TaskAttempt::exists_for_task(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::exists_for_task(&app_state.db_pool, attempt_id, task_id, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task attempt existence: {}", e);
@@ -1008,7 +1025,9 @@ pub async fn get_task_attempt_execution_state(
}
// Get the execution state
match TaskAttempt::get_execution_state(&pool, attempt_id, task_id, project_id).await {
match TaskAttempt::get_execution_state(&app_state.db_pool, attempt_id, task_id, project_id)
.await
{
Ok(state) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(state),
@@ -1027,10 +1046,10 @@ pub async fn get_task_attempt_execution_state(
pub async fn get_execution_process_normalized_logs(
Path((project_id, process_id)): Path<(Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<NormalizedConversation>>, StatusCode> {
// Get the execution process and verify it belongs to the correct project
let process = match ExecutionProcess::find_by_id(&pool, process_id).await {
let process = match ExecutionProcess::find_by_id(&app_state.db_pool, process_id).await {
Ok(Some(process)) => process,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
@@ -1040,7 +1059,7 @@ pub async fn get_execution_process_normalized_logs(
};
// Verify the process belongs to a task attempt in the correct project
let attempt = match TaskAttempt::find_by_id(&pool, process.task_attempt_id).await {
let attempt = match TaskAttempt::find_by_id(&app_state.db_pool, process.task_attempt_id).await {
Ok(Some(attempt)) => attempt,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
@@ -1049,7 +1068,7 @@ pub async fn get_execution_process_normalized_logs(
}
};
let _task = match Task::find_by_id(&pool, attempt.task_id).await {
let _task = match Task::find_by_id(&app_state.db_pool, attempt.task_id).await {
Ok(Some(task)) if task.project_id == project_id => task,
Ok(Some(_)) => return Err(StatusCode::NOT_FOUND), // Wrong project
Ok(None) => return Err(StatusCode::NOT_FOUND),
@@ -1066,18 +1085,22 @@ pub async fn get_execution_process_normalized_logs(
// If the process is still running, return empty logs instead of an error
if process.status == ExecutionProcessStatus::Running {
// Get executor session data for this execution process
let executor_session =
match ExecutorSession::find_by_execution_process_id(&pool, process_id).await {
Ok(session) => session,
Err(e) => {
tracing::error!(
"Failed to fetch executor session for process {}: {}",
process_id,
e
);
None
}
};
let executor_session = match ExecutorSession::find_by_execution_process_id(
&app_state.db_pool,
process_id,
)
.await
{
Ok(session) => session,
Err(e) => {
tracing::error!(
"Failed to fetch executor session for process {}: {}",
process_id,
e
);
None
}
};
return Ok(ResponseJson(ApiResponse {
success: true,
@@ -1124,7 +1147,7 @@ pub async fn get_execution_process_normalized_logs(
// Get executor session data for this execution process
let executor_session =
match ExecutorSession::find_by_execution_process_id(&pool, process_id).await {
match ExecutorSession::find_by_execution_process_id(&app_state.db_pool, process_id).await {
Ok(session) => session,
Err(e) => {
tracing::error!(
@@ -1162,7 +1185,7 @@ pub async fn get_execution_process_normalized_logs(
}
}
pub fn task_attempts_router() -> Router {
pub fn task_attempts_router() -> Router<AppState> {
use axum::routing::post;
Router::new()

View File

@@ -1,25 +1,27 @@
use axum::{
extract::{Extension, Path},
extract::{Path, State},
http::StatusCode,
response::Json as ResponseJson,
routing::get,
Json, Router,
};
use sqlx::SqlitePool;
use uuid::Uuid;
use crate::models::{
project::Project,
task::{CreateTask, CreateTaskAndStart, Task, TaskWithAttemptStatus, UpdateTask},
task_attempt::{CreateTaskAttempt, TaskAttempt},
ApiResponse,
use crate::{
app_state::AppState,
models::{
project::Project,
task::{CreateTask, CreateTaskAndStart, Task, TaskWithAttemptStatus, UpdateTask},
task_attempt::{CreateTaskAttempt, TaskAttempt},
ApiResponse,
},
};
pub async fn get_project_tasks(
Path(project_id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Vec<TaskWithAttemptStatus>>>, StatusCode> {
match Task::find_by_project_id_with_attempt_status(&pool, project_id).await {
match Task::find_by_project_id_with_attempt_status(&app_state.db_pool, project_id).await {
Ok(tasks) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(tasks),
@@ -34,9 +36,9 @@ pub async fn get_project_tasks(
pub async fn get_task(
Path((project_id, task_id)): Path<(Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<Task>>, StatusCode> {
match Task::find_by_id_and_project_id(&pool, task_id, project_id).await {
match Task::find_by_id_and_project_id(&app_state.db_pool, task_id, project_id).await {
Ok(Some(task)) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(task),
@@ -57,8 +59,7 @@ pub async fn get_task(
pub async fn create_task(
Path(project_id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
Json(mut payload): Json<CreateTask>,
) -> Result<ResponseJson<ApiResponse<Task>>, StatusCode> {
let id = Uuid::new_v4();
@@ -67,7 +68,7 @@ pub async fn create_task(
payload.project_id = project_id;
// Verify project exists first
match Project::exists(&pool, project_id).await {
match Project::exists(&app_state.db_pool, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check project existence: {}", e);
@@ -82,7 +83,7 @@ pub async fn create_task(
project_id
);
match Task::create(&pool, &payload, id).await {
match Task::create(&app_state.db_pool, &payload, id).await {
Ok(task) => {
// Track task creation event
app_state
@@ -111,8 +112,7 @@ pub async fn create_task(
pub async fn create_task_and_start(
Path(project_id): Path<Uuid>,
Extension(pool): Extension<SqlitePool>,
Extension(app_state): Extension<crate::app_state::AppState>,
State(app_state): State<AppState>,
Json(mut payload): Json<CreateTaskAndStart>,
) -> Result<ResponseJson<ApiResponse<Task>>, StatusCode> {
let task_id = Uuid::new_v4();
@@ -121,7 +121,7 @@ pub async fn create_task_and_start(
payload.project_id = project_id;
// Verify project exists first
match Project::exists(&pool, project_id).await {
match Project::exists(&app_state.db_pool, project_id).await {
Ok(false) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check project existence: {}", e);
@@ -142,7 +142,7 @@ pub async fn create_task_and_start(
title: payload.title.clone(),
description: payload.description.clone(),
};
let task = match Task::create(&pool, &create_task_payload, task_id).await {
let task = match Task::create(&app_state.db_pool, &create_task_payload, task_id).await {
Ok(task) => task,
Err(e) => {
tracing::error!("Failed to create task: {}", e);
@@ -157,7 +157,7 @@ pub async fn create_task_and_start(
base_branch: None, // Not supported in task creation endpoint, only in task attempts
};
match TaskAttempt::create(&pool, &attempt_payload, task_id).await {
match TaskAttempt::create(&app_state.db_pool, &attempt_payload, task_id).await {
Ok(attempt) => {
app_state
.track_analytics_event(
@@ -182,12 +182,11 @@ pub async fn create_task_and_start(
.await;
// Start execution asynchronously (don't block the response)
let pool_clone = pool.clone();
let app_state_clone = app_state.clone();
let attempt_id = attempt.id;
tokio::spawn(async move {
if let Err(e) = TaskAttempt::start_execution(
&pool_clone,
&app_state_clone.db_pool,
&app_state_clone,
attempt_id,
task_id,
@@ -218,25 +217,35 @@ pub async fn create_task_and_start(
pub async fn update_task(
Path((project_id, task_id)): Path<(Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
Json(payload): Json<UpdateTask>,
) -> Result<ResponseJson<ApiResponse<Task>>, StatusCode> {
// Check if task exists in the specified project
let existing_task = match Task::find_by_id_and_project_id(&pool, task_id, project_id).await {
Ok(Some(task)) => task,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task existence: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
let existing_task =
match Task::find_by_id_and_project_id(&app_state.db_pool, task_id, project_id).await {
Ok(Some(task)) => task,
Ok(None) => return Err(StatusCode::NOT_FOUND),
Err(e) => {
tracing::error!("Failed to check task existence: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
// Use existing values if not provided in update
let title = payload.title.unwrap_or(existing_task.title);
let description = payload.description.or(existing_task.description);
let status = payload.status.unwrap_or(existing_task.status);
match Task::update(&pool, task_id, project_id, title, description, status).await {
match Task::update(
&app_state.db_pool,
task_id,
project_id,
title,
description,
status,
)
.await
{
Ok(task) => Ok(ResponseJson(ApiResponse {
success: true,
data: Some(task),
@@ -251,9 +260,9 @@ pub async fn update_task(
pub async fn delete_task(
Path((project_id, task_id)): Path<(Uuid, Uuid)>,
Extension(pool): Extension<SqlitePool>,
State(app_state): State<AppState>,
) -> Result<ResponseJson<ApiResponse<()>>, StatusCode> {
match Task::delete(&pool, task_id, project_id).await {
match Task::delete(&app_state.db_pool, task_id, project_id).await {
Ok(rows_affected) => {
if rows_affected == 0 {
Err(StatusCode::NOT_FOUND)
@@ -272,7 +281,7 @@ pub async fn delete_task(
}
}
pub fn tasks_router() -> Router {
pub fn tasks_router() -> Router<AppState> {
use axum::routing::post;
Router::new()