feat: implement automatic PR discovery and attachment for task attempts (#842)
* feat: implement automatic PR discovery and attachment for task attempts - Add GitHub API methods to list PRs for a branch (open and all states) - Create /pr/attach endpoint to discover and attach existing PRs - Automatically mark tasks as done when attached PR is merged - Update Merge model to support PR status on creation - Handle both open and closed/merged PRs during attachment This improves on #837 by using GitHub API to automatically discover PRs rather than requiring manual input of PR details. * fix: address PR review feedback - Fix compilation issue by using find_latest_by_task_attempt_id - Properly handle Merge enum (Direct vs Pr variants) - Remove redundant list_prs_for_branch method - Simplify PR discovery to use only list_all_prs_for_branch - Only check for existing PR merges, not direct merges * fix: resolve compilation issues - Fix SQLx cache issue by restoring exact original create_pr method - Fix API response type for GitHub token error - Fix ProjectError variant name to ProjectNotFound - Add update_status call after PR creation for non-open PRs * fix: address PR review feedback - Fix compilation issue by using find_latest_by_task_attempt_id - Properly handle Merge enum (Direct vs Pr variants) - Remove redundant list_prs_for_branch method - Simplify PR discovery to use only list_all_prs_for_branch - Only check for existing PR merges, not direct merges - Update code to match current TaskAttempt struct (branch: String, target_branch: String) * Clippy, fmt, cleanup --------- Co-authored-by: Alex Netsch <alex@bloop.ai>
This commit is contained in:
committed by
GitHub
parent
0e9d10732a
commit
2b277d3ddf
@@ -1285,6 +1285,100 @@ pub async fn stop_task_attempt_execution(
|
||||
Ok(ResponseJson(ApiResponse::success(())))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, TS)]
|
||||
pub struct AttachPrResponse {
|
||||
pub pr_attached: bool,
|
||||
pub pr_url: Option<String>,
|
||||
pub pr_number: Option<i64>,
|
||||
pub pr_status: Option<MergeStatus>,
|
||||
}
|
||||
|
||||
pub async fn attach_existing_pr(
|
||||
Extension(task_attempt): Extension<TaskAttempt>,
|
||||
State(deployment): State<DeploymentImpl>,
|
||||
) -> Result<ResponseJson<ApiResponse<AttachPrResponse>>, ApiError> {
|
||||
let pool = &deployment.db().pool;
|
||||
|
||||
// Check if PR already attached
|
||||
if let Some(Merge::Pr(pr_merge)) =
|
||||
Merge::find_latest_by_task_attempt_id(pool, task_attempt.id).await?
|
||||
{
|
||||
return Ok(ResponseJson(ApiResponse::success(AttachPrResponse {
|
||||
pr_attached: true,
|
||||
pr_url: Some(pr_merge.pr_info.url.clone()),
|
||||
pr_number: Some(pr_merge.pr_info.number),
|
||||
pr_status: Some(pr_merge.pr_info.status.clone()),
|
||||
})));
|
||||
}
|
||||
|
||||
// Get GitHub token
|
||||
let github_config = deployment.config().read().await.github.clone();
|
||||
let Some(github_token) = github_config.token() else {
|
||||
return Err(ApiError::GitHubService(GitHubServiceError::TokenInvalid));
|
||||
};
|
||||
|
||||
// Get project and repo info
|
||||
let Some(task) = task_attempt.parent_task(pool).await? else {
|
||||
return Err(ApiError::TaskAttempt(TaskAttemptError::TaskNotFound));
|
||||
};
|
||||
let Some(project) = Project::find_by_id(pool, task.project_id).await? else {
|
||||
return Err(ApiError::Project(ProjectError::ProjectNotFound));
|
||||
};
|
||||
|
||||
let github_service = GitHubService::new(&github_token)?;
|
||||
let repo_info = deployment
|
||||
.git()
|
||||
.get_github_repo_info(&project.git_repo_path)?;
|
||||
|
||||
// List all PRs for branch (open, closed, and merged)
|
||||
let prs = github_service
|
||||
.list_all_prs_for_branch(&repo_info, &task_attempt.branch)
|
||||
.await?;
|
||||
|
||||
// Take the first PR (prefer open, but also accept merged/closed)
|
||||
if let Some(pr_info) = prs.into_iter().next() {
|
||||
// Save PR info to database
|
||||
let merge = Merge::create_pr(
|
||||
pool,
|
||||
task_attempt.id,
|
||||
&task_attempt.target_branch,
|
||||
pr_info.number,
|
||||
&pr_info.url,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Update status if not open
|
||||
if !matches!(pr_info.status, MergeStatus::Open) {
|
||||
Merge::update_status(
|
||||
pool,
|
||||
merge.id,
|
||||
pr_info.status.clone(),
|
||||
pr_info.merge_commit_sha.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If PR is merged, mark task as done
|
||||
if matches!(pr_info.status, MergeStatus::Merged) {
|
||||
Task::update_status(pool, task.id, TaskStatus::Done).await?;
|
||||
}
|
||||
|
||||
Ok(ResponseJson(ApiResponse::success(AttachPrResponse {
|
||||
pr_attached: true,
|
||||
pr_url: Some(pr_info.url),
|
||||
pr_number: Some(pr_info.number),
|
||||
pr_status: Some(pr_info.status),
|
||||
})))
|
||||
} else {
|
||||
Ok(ResponseJson(ApiResponse::success(AttachPrResponse {
|
||||
pr_attached: false,
|
||||
pr_url: None,
|
||||
pr_number: None,
|
||||
pr_status: None,
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
|
||||
let task_attempt_id_router = Router::new()
|
||||
.route("/", get(get_task_attempt))
|
||||
@@ -1307,6 +1401,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
|
||||
.route("/rebase", post(rebase_task_attempt))
|
||||
.route("/conflicts/abort", post(abort_conflicts_task_attempt))
|
||||
.route("/pr", post(create_github_pr))
|
||||
.route("/pr/attach", post(attach_existing_pr))
|
||||
.route("/open-editor", post(open_task_attempt_in_editor))
|
||||
.route("/delete-file", post(delete_task_attempt_file))
|
||||
.route("/children", get(get_task_attempt_children))
|
||||
|
||||
@@ -323,6 +323,60 @@ impl GitHubService {
|
||||
}
|
||||
}
|
||||
|
||||
/// List all pull requests for a branch (including closed/merged)
|
||||
pub async fn list_all_prs_for_branch(
|
||||
&self,
|
||||
repo_info: &GitHubRepoInfo,
|
||||
branch_name: &str,
|
||||
) -> Result<Vec<PullRequestInfo>, GitHubServiceError> {
|
||||
(|| async {
|
||||
self.list_all_prs_for_branch_internal(repo_info, branch_name)
|
||||
.await
|
||||
})
|
||||
.retry(
|
||||
&ExponentialBuilder::default()
|
||||
.with_min_delay(Duration::from_secs(1))
|
||||
.with_max_delay(Duration::from_secs(30))
|
||||
.with_max_times(3)
|
||||
.with_jitter(),
|
||||
)
|
||||
.when(|e| e.should_retry())
|
||||
.notify(|err: &GitHubServiceError, dur: Duration| {
|
||||
tracing::warn!(
|
||||
"GitHub API call failed, retrying after {:.2}s: {}",
|
||||
dur.as_secs_f64(),
|
||||
err
|
||||
);
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn list_all_prs_for_branch_internal(
|
||||
&self,
|
||||
repo_info: &GitHubRepoInfo,
|
||||
branch_name: &str,
|
||||
) -> Result<Vec<PullRequestInfo>, GitHubServiceError> {
|
||||
let prs = self
|
||||
.client
|
||||
.pulls(&repo_info.owner, &repo_info.repo_name)
|
||||
.list()
|
||||
.state(octocrab::params::State::All)
|
||||
.head(format!("{}:{}", repo_info.owner, branch_name))
|
||||
.per_page(100)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| match GitHubServiceError::from(err) {
|
||||
GitHubServiceError::Client(source) => GitHubServiceError::PullRequest(format!(
|
||||
"Failed to list all PRs for branch '{branch_name}': {source}",
|
||||
)),
|
||||
other => other,
|
||||
})?;
|
||||
|
||||
let pr_infos = prs.items.into_iter().map(Self::map_pull_request).collect();
|
||||
|
||||
Ok(pr_infos)
|
||||
}
|
||||
|
||||
/// List repositories for the authenticated user with pagination
|
||||
#[cfg(feature = "cloud")]
|
||||
pub async fn list_repositories(
|
||||
|
||||
Reference in New Issue
Block a user