diff --git a/crates/server/src/routes/task_attempts.rs b/crates/server/src/routes/task_attempts.rs index c656c20d..53d70e29 100644 --- a/crates/server/src/routes/task_attempts.rs +++ b/crates/server/src/routes/task_attempts.rs @@ -1285,6 +1285,100 @@ pub async fn stop_task_attempt_execution( Ok(ResponseJson(ApiResponse::success(()))) } +#[derive(Debug, Serialize, TS)] +pub struct AttachPrResponse { + pub pr_attached: bool, + pub pr_url: Option, + pub pr_number: Option, + pub pr_status: Option, +} + +pub async fn attach_existing_pr( + Extension(task_attempt): Extension, + State(deployment): State, +) -> Result>, ApiError> { + let pool = &deployment.db().pool; + + // Check if PR already attached + if let Some(Merge::Pr(pr_merge)) = + Merge::find_latest_by_task_attempt_id(pool, task_attempt.id).await? + { + return Ok(ResponseJson(ApiResponse::success(AttachPrResponse { + pr_attached: true, + pr_url: Some(pr_merge.pr_info.url.clone()), + pr_number: Some(pr_merge.pr_info.number), + pr_status: Some(pr_merge.pr_info.status.clone()), + }))); + } + + // Get GitHub token + let github_config = deployment.config().read().await.github.clone(); + let Some(github_token) = github_config.token() else { + return Err(ApiError::GitHubService(GitHubServiceError::TokenInvalid)); + }; + + // Get project and repo info + let Some(task) = task_attempt.parent_task(pool).await? else { + return Err(ApiError::TaskAttempt(TaskAttemptError::TaskNotFound)); + }; + let Some(project) = Project::find_by_id(pool, task.project_id).await? else { + return Err(ApiError::Project(ProjectError::ProjectNotFound)); + }; + + let github_service = GitHubService::new(&github_token)?; + let repo_info = deployment + .git() + .get_github_repo_info(&project.git_repo_path)?; + + // List all PRs for branch (open, closed, and merged) + let prs = github_service + .list_all_prs_for_branch(&repo_info, &task_attempt.branch) + .await?; + + // Take the first PR (prefer open, but also accept merged/closed) + if let Some(pr_info) = prs.into_iter().next() { + // Save PR info to database + let merge = Merge::create_pr( + pool, + task_attempt.id, + &task_attempt.target_branch, + pr_info.number, + &pr_info.url, + ) + .await?; + + // Update status if not open + if !matches!(pr_info.status, MergeStatus::Open) { + Merge::update_status( + pool, + merge.id, + pr_info.status.clone(), + pr_info.merge_commit_sha.clone(), + ) + .await?; + } + + // If PR is merged, mark task as done + if matches!(pr_info.status, MergeStatus::Merged) { + Task::update_status(pool, task.id, TaskStatus::Done).await?; + } + + Ok(ResponseJson(ApiResponse::success(AttachPrResponse { + pr_attached: true, + pr_url: Some(pr_info.url), + pr_number: Some(pr_info.number), + pr_status: Some(pr_info.status), + }))) + } else { + Ok(ResponseJson(ApiResponse::success(AttachPrResponse { + pr_attached: false, + pr_url: None, + pr_number: None, + pr_status: None, + }))) + } +} + pub fn router(deployment: &DeploymentImpl) -> Router { let task_attempt_id_router = Router::new() .route("/", get(get_task_attempt)) @@ -1307,6 +1401,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router { .route("/rebase", post(rebase_task_attempt)) .route("/conflicts/abort", post(abort_conflicts_task_attempt)) .route("/pr", post(create_github_pr)) + .route("/pr/attach", post(attach_existing_pr)) .route("/open-editor", post(open_task_attempt_in_editor)) .route("/delete-file", post(delete_task_attempt_file)) .route("/children", get(get_task_attempt_children)) diff --git a/crates/services/src/services/github_service.rs b/crates/services/src/services/github_service.rs index b3c4a7ea..7e18174a 100644 --- a/crates/services/src/services/github_service.rs +++ b/crates/services/src/services/github_service.rs @@ -323,6 +323,60 @@ impl GitHubService { } } + /// List all pull requests for a branch (including closed/merged) + pub async fn list_all_prs_for_branch( + &self, + repo_info: &GitHubRepoInfo, + branch_name: &str, + ) -> Result, GitHubServiceError> { + (|| async { + self.list_all_prs_for_branch_internal(repo_info, branch_name) + .await + }) + .retry( + &ExponentialBuilder::default() + .with_min_delay(Duration::from_secs(1)) + .with_max_delay(Duration::from_secs(30)) + .with_max_times(3) + .with_jitter(), + ) + .when(|e| e.should_retry()) + .notify(|err: &GitHubServiceError, dur: Duration| { + tracing::warn!( + "GitHub API call failed, retrying after {:.2}s: {}", + dur.as_secs_f64(), + err + ); + }) + .await + } + + async fn list_all_prs_for_branch_internal( + &self, + repo_info: &GitHubRepoInfo, + branch_name: &str, + ) -> Result, GitHubServiceError> { + let prs = self + .client + .pulls(&repo_info.owner, &repo_info.repo_name) + .list() + .state(octocrab::params::State::All) + .head(format!("{}:{}", repo_info.owner, branch_name)) + .per_page(100) + .send() + .await + .map_err(|err| match GitHubServiceError::from(err) { + GitHubServiceError::Client(source) => GitHubServiceError::PullRequest(format!( + "Failed to list all PRs for branch '{branch_name}': {source}", + )), + other => other, + })?; + + let pr_infos = prs.items.into_iter().map(Self::map_pull_request).collect(); + + Ok(pr_infos) + } + /// List repositories for the authenticated user with pagination #[cfg(feature = "cloud")] pub async fn list_repositories( diff --git a/package.json b/package.json index 2c08495f..c6af5753 100644 --- a/package.json +++ b/package.json @@ -42,4 +42,4 @@ "@dnd-kit/utilities": "^3.2.2", "@ebay/nice-modal-react": "^1.2.13" } -} +} \ No newline at end of file