feat: implement automatic PR discovery and attachment for task attempts (#842)

* feat: implement automatic PR discovery and attachment for task attempts

- Add GitHub API methods to list PRs for a branch (open and all states)
- Create /pr/attach endpoint to discover and attach existing PRs
- Automatically mark tasks as done when attached PR is merged
- Update Merge model to support PR status on creation
- Handle both open and closed/merged PRs during attachment

This improves on #837 by using GitHub API to automatically discover PRs
rather than requiring manual input of PR details.

* fix: address PR review feedback

- Fix compilation issue by using find_latest_by_task_attempt_id
- Properly handle Merge enum (Direct vs Pr variants)
- Remove redundant list_prs_for_branch method
- Simplify PR discovery to use only list_all_prs_for_branch
- Only check for existing PR merges, not direct merges

* fix: resolve compilation issues

- Fix SQLx cache issue by restoring exact original create_pr method
- Fix API response type for GitHub token error
- Fix ProjectError variant name to ProjectNotFound
- Add update_status call after PR creation for non-open PRs

* fix: address PR review feedback

- Fix compilation issue by using find_latest_by_task_attempt_id
- Properly handle Merge enum (Direct vs Pr variants)
- Remove redundant list_prs_for_branch method
- Simplify PR discovery to use only list_all_prs_for_branch
- Only check for existing PR merges, not direct merges
- Update code to match current TaskAttempt struct (branch: String, target_branch: String)

* Clippy, fmt, cleanup

---------

Co-authored-by: Alex Netsch <alex@bloop.ai>
This commit is contained in:
Jacek Tomaszewski
2025-10-01 18:31:50 +02:00
committed by GitHub
parent 0e9d10732a
commit 2b277d3ddf
3 changed files with 150 additions and 1 deletions

View File

@@ -1285,6 +1285,100 @@ pub async fn stop_task_attempt_execution(
Ok(ResponseJson(ApiResponse::success(())))
}
#[derive(Debug, Serialize, TS)]
pub struct AttachPrResponse {
pub pr_attached: bool,
pub pr_url: Option<String>,
pub pr_number: Option<i64>,
pub pr_status: Option<MergeStatus>,
}
pub async fn attach_existing_pr(
Extension(task_attempt): Extension<TaskAttempt>,
State(deployment): State<DeploymentImpl>,
) -> Result<ResponseJson<ApiResponse<AttachPrResponse>>, ApiError> {
let pool = &deployment.db().pool;
// Check if PR already attached
if let Some(Merge::Pr(pr_merge)) =
Merge::find_latest_by_task_attempt_id(pool, task_attempt.id).await?
{
return Ok(ResponseJson(ApiResponse::success(AttachPrResponse {
pr_attached: true,
pr_url: Some(pr_merge.pr_info.url.clone()),
pr_number: Some(pr_merge.pr_info.number),
pr_status: Some(pr_merge.pr_info.status.clone()),
})));
}
// Get GitHub token
let github_config = deployment.config().read().await.github.clone();
let Some(github_token) = github_config.token() else {
return Err(ApiError::GitHubService(GitHubServiceError::TokenInvalid));
};
// Get project and repo info
let Some(task) = task_attempt.parent_task(pool).await? else {
return Err(ApiError::TaskAttempt(TaskAttemptError::TaskNotFound));
};
let Some(project) = Project::find_by_id(pool, task.project_id).await? else {
return Err(ApiError::Project(ProjectError::ProjectNotFound));
};
let github_service = GitHubService::new(&github_token)?;
let repo_info = deployment
.git()
.get_github_repo_info(&project.git_repo_path)?;
// List all PRs for branch (open, closed, and merged)
let prs = github_service
.list_all_prs_for_branch(&repo_info, &task_attempt.branch)
.await?;
// Take the first PR (prefer open, but also accept merged/closed)
if let Some(pr_info) = prs.into_iter().next() {
// Save PR info to database
let merge = Merge::create_pr(
pool,
task_attempt.id,
&task_attempt.target_branch,
pr_info.number,
&pr_info.url,
)
.await?;
// Update status if not open
if !matches!(pr_info.status, MergeStatus::Open) {
Merge::update_status(
pool,
merge.id,
pr_info.status.clone(),
pr_info.merge_commit_sha.clone(),
)
.await?;
}
// If PR is merged, mark task as done
if matches!(pr_info.status, MergeStatus::Merged) {
Task::update_status(pool, task.id, TaskStatus::Done).await?;
}
Ok(ResponseJson(ApiResponse::success(AttachPrResponse {
pr_attached: true,
pr_url: Some(pr_info.url),
pr_number: Some(pr_info.number),
pr_status: Some(pr_info.status),
})))
} else {
Ok(ResponseJson(ApiResponse::success(AttachPrResponse {
pr_attached: false,
pr_url: None,
pr_number: None,
pr_status: None,
})))
}
}
pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
let task_attempt_id_router = Router::new()
.route("/", get(get_task_attempt))
@@ -1307,6 +1401,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
.route("/rebase", post(rebase_task_attempt))
.route("/conflicts/abort", post(abort_conflicts_task_attempt))
.route("/pr", post(create_github_pr))
.route("/pr/attach", post(attach_existing_pr))
.route("/open-editor", post(open_task_attempt_in_editor))
.route("/delete-file", post(delete_task_attempt_file))
.route("/children", get(get_task_attempt_children))

View File

@@ -323,6 +323,60 @@ impl GitHubService {
}
}
/// List all pull requests for a branch (including closed/merged)
pub async fn list_all_prs_for_branch(
&self,
repo_info: &GitHubRepoInfo,
branch_name: &str,
) -> Result<Vec<PullRequestInfo>, GitHubServiceError> {
(|| async {
self.list_all_prs_for_branch_internal(repo_info, branch_name)
.await
})
.retry(
&ExponentialBuilder::default()
.with_min_delay(Duration::from_secs(1))
.with_max_delay(Duration::from_secs(30))
.with_max_times(3)
.with_jitter(),
)
.when(|e| e.should_retry())
.notify(|err: &GitHubServiceError, dur: Duration| {
tracing::warn!(
"GitHub API call failed, retrying after {:.2}s: {}",
dur.as_secs_f64(),
err
);
})
.await
}
async fn list_all_prs_for_branch_internal(
&self,
repo_info: &GitHubRepoInfo,
branch_name: &str,
) -> Result<Vec<PullRequestInfo>, GitHubServiceError> {
let prs = self
.client
.pulls(&repo_info.owner, &repo_info.repo_name)
.list()
.state(octocrab::params::State::All)
.head(format!("{}:{}", repo_info.owner, branch_name))
.per_page(100)
.send()
.await
.map_err(|err| match GitHubServiceError::from(err) {
GitHubServiceError::Client(source) => GitHubServiceError::PullRequest(format!(
"Failed to list all PRs for branch '{branch_name}': {source}",
)),
other => other,
})?;
let pr_infos = prs.items.into_iter().map(Self::map_pull_request).collect();
Ok(pr_infos)
}
/// List repositories for the authenticated user with pagination
#[cfg(feature = "cloud")]
pub async fn list_repositories(

View File

@@ -42,4 +42,4 @@
"@dnd-kit/utilities": "^3.2.2",
"@ebay/nice-modal-react": "^1.2.13"
}
}
}