diff --git a/backend/src/executors/amp.rs b/backend/src/executors/amp.rs index 38c6d0ed..92b0ada1 100644 --- a/backend/src/executors/amp.rs +++ b/backend/src/executors/amp.rs @@ -2,6 +2,7 @@ use std::path::Path; use async_trait::async_trait; use command_group::{AsyncCommandGroup, AsyncGroupChild}; +use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ @@ -22,6 +23,175 @@ pub struct AmpFollowupExecutor { pub prompt: String, } +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +#[serde(tag = "type")] +pub enum AmpJson { + #[serde(rename = "messages")] + Messages { + messages: Vec<(usize, AmpMessage)>, + #[serde(rename = "toolResults")] + tool_results: Vec, + }, + #[serde(rename = "initial")] + Initial { + #[serde(rename = "threadID")] + thread_id: Option, + }, + #[serde(rename = "token-usage")] + TokenUsage(serde_json::Value), + #[serde(rename = "state")] + State { state: String }, + #[serde(rename = "shutdown")] + Shutdown, + #[serde(rename = "tool-status")] + ToolStatus(serde_json::Value), +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +pub struct AmpMessage { + pub role: String, + pub content: Vec, + pub state: Option, + pub meta: Option, +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +pub struct AmpMeta { + #[serde(rename = "sentAt")] + pub sent_at: u64, +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +#[serde(tag = "type")] +pub enum AmpContentItem { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "thinking")] + Thinking { thinking: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + #[serde(rename = "toolUseID")] + tool_use_id: String, + run: serde_json::Value, + }, +} + +impl AmpJson { + pub fn should_process(&self) -> bool { + matches!(self, AmpJson::Messages { .. }) + } + + pub fn extract_session_id(&self) -> Option { + match self { + AmpJson::Initial { thread_id } => thread_id.clone(), + _ => None, + } + } + + pub fn has_streaming_content(&self) -> bool { + match self { + AmpJson::Messages { messages, .. } => messages.iter().any(|(_index, message)| { + if let Some(state) = &message.state { + if let Some(state_type) = state.get("type").and_then(|t| t.as_str()) { + state_type == "streaming" + } else { + false + } + } else { + false + } + }), + _ => false, + } + } + + pub fn to_normalized_entries( + &self, + executor: &AmpExecutor, + worktree_path: &str, + ) -> Vec { + match self { + AmpJson::Messages { messages, .. } => { + if self.has_streaming_content() { + return vec![]; + } + + let mut entries = Vec::new(); + for (_index, message) in messages { + let role = &message.role; + for content_item in &message.content { + if let Some(entry) = + content_item.to_normalized_entry(role, message, executor, worktree_path) + { + entries.push(entry); + } + } + } + entries + } + _ => vec![], + } + } +} + +impl AmpContentItem { + pub fn to_normalized_entry( + &self, + role: &str, + message: &AmpMessage, + executor: &AmpExecutor, + worktree_path: &str, + ) -> Option { + use serde_json::Value; + + let timestamp = message.meta.as_ref().map(|meta| meta.sent_at.to_string()); + + match self { + AmpContentItem::Text { text } => { + let entry_type = match role { + "user" => NormalizedEntryType::UserMessage, + "assistant" => NormalizedEntryType::AssistantMessage, + _ => return None, + }; + Some(NormalizedEntry { + timestamp, + entry_type, + content: text.clone(), + metadata: Some(serde_json::to_value(self).unwrap_or(Value::Null)), + }) + } + AmpContentItem::Thinking { thinking } => Some(NormalizedEntry { + timestamp, + entry_type: NormalizedEntryType::Thinking, + content: thinking.clone(), + metadata: Some(serde_json::to_value(self).unwrap_or(Value::Null)), + }), + AmpContentItem::ToolUse { name, input, .. } => { + let action_type = executor.extract_action_type(name, input, worktree_path); + let content = + executor.generate_concise_content(name, input, &action_type, worktree_path); + + Some(NormalizedEntry { + timestamp, + entry_type: NormalizedEntryType::ToolUse { + tool_name: name.clone(), + action_type, + }, + content, + metadata: Some(serde_json::to_value(self).unwrap_or(Value::Null)), + }) + } + AmpContentItem::ToolResult { .. } => None, + } + } +} + #[async_trait] impl Executor for AmpExecutor { async fn spawn( @@ -94,8 +264,6 @@ Task title: {}"#, logs: &str, worktree_path: &str, ) -> Result { - use serde_json::Value; - let mut entries = Vec::new(); let mut session_id = None; @@ -105,9 +273,9 @@ Task title: {}"#, continue; } - // Try to parse as JSON - let json: Value = match serde_json::from_str(trimmed) { - Ok(json) => json, + // Try to parse as AmpMessage + let amp_message: AmpJson = match serde_json::from_str(trimmed) { + Ok(msg) => msg, Err(_) => { // If line isn't valid JSON, add it as raw text entries.push(NormalizedEntry { @@ -120,147 +288,17 @@ Task title: {}"#, } }; - // Extract session ID (threadID in AMP) + // Extract session ID if available if session_id.is_none() { - if let Some(thread_id) = json.get("threadID").and_then(|v| v.as_str()) { - session_id = Some(thread_id.to_string()); + if let Some(id) = amp_message.extract_session_id() { + session_id = Some(id); } } - // Process different message types - let processed = if let Some(msg_type) = json.get("type").and_then(|t| t.as_str()) { - match msg_type { - "messages" => { - if let Some(messages) = json.get("messages").and_then(|m| m.as_array()) { - for message_entry in messages { - if let Some(message_data) = - message_entry.as_array().and_then(|arr| arr.get(1)) - { - if let Some(role) = - message_data.get("role").and_then(|r| r.as_str()) - { - if let Some(content) = - message_data.get("content").and_then(|c| c.as_array()) - { - for content_item in content { - if let Some(content_type) = content_item - .get("type") - .and_then(|t| t.as_str()) - { - match content_type { - "text" => { - if let Some(text) = content_item - .get("text") - .and_then(|t| t.as_str()) - { - let entry_type = match role { - "user" => NormalizedEntryType::UserMessage, - "assistant" => NormalizedEntryType::AssistantMessage, - _ => continue, - }; - entries.push(NormalizedEntry { - timestamp: message_data - .get("meta") - .and_then(|m| { - m.get("sentAt") - }) - .and_then(|s| s.as_u64()) - .map(|ts| ts.to_string()), - entry_type, - content: text.to_string(), - metadata: Some( - content_item.clone(), - ), - }); - } - } - "thinking" => { - if let Some(thinking) = content_item - .get("thinking") - .and_then(|t| t.as_str()) - { - entries.push(NormalizedEntry { - timestamp: None, - entry_type: - NormalizedEntryType::Thinking, - content: thinking.to_string(), - metadata: Some( - content_item.clone(), - ), - }); - } - } - "tool_use" => { - if let Some(tool_name) = content_item - .get("name") - .and_then(|n| n.as_str()) - { - let input = content_item - .get("input") - .unwrap_or(&Value::Null); - let action_type = self - .extract_action_type( - tool_name, - input, - worktree_path, - ); - let content = self - .generate_concise_content( - tool_name, - input, - &action_type, - worktree_path, - ); - - entries.push(NormalizedEntry { - timestamp: None, - entry_type: - NormalizedEntryType::ToolUse { - tool_name: tool_name - .to_string(), - action_type, - }, - content, - metadata: Some( - content_item.clone(), - ), - }); - } - } - _ => {} - } - } - } - } - } - } - } - } - true - } - // Ignore these JSON types - they're not relevant for task execution logs - "initial" | "token-usage" | "state" | "shutdown" => true, - _ => false, - } - } else { - false - }; - - // If JSON didn't match expected patterns, add it as unrecognized JSON - // Skip JSON with type "result" as requested - if !processed { - if let Some(msg_type) = json.get("type").and_then(|t| t.as_str()) { - if msg_type == "result" { - // Skip result entries - continue; - } - } - entries.push(NormalizedEntry { - timestamp: None, - entry_type: NormalizedEntryType::SystemMessage, - content: format!("Unrecognized JSON: {}", trimmed), - metadata: Some(json), - }); + // Process the message if it's a type we care about + if amp_message.should_process() { + let new_entries = amp_message.to_normalized_entries(self, worktree_path); + entries.extend(new_entries); } } @@ -610,3 +648,54 @@ impl Executor for AmpFollowupExecutor { main_executor.normalize_logs(logs, worktree_path) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_filter_streaming_messages() { + // Test logs that simulate the actual normalize_logs behavior + let amp_executor = AmpExecutor; + let logs = r#"{"type":"messages","messages":[[7,{"role":"assistant","content":[{"type":"text","text":"Created all three files: test1.txt, test2.txt, and test3.txt"}],"state":{"type":"streaming"}}]],"toolResults":[]} +{"type":"messages","messages":[[7,{"role":"assistant","content":[{"type":"text","text":"Created all three files: test1.txt, test2.txt, and test3.txt, each with a line of text."}],"state":{"type":"streaming"}}]],"toolResults":[]} +{"type":"messages","messages":[[7,{"role":"assistant","content":[{"type":"text","text":"Created all three files: test1.txt, test2.txt, and test3.txt, each with a line of text."}],"state":{"type":"complete","stopReason":"end_turn"}}]],"toolResults":[]}"#; + + let result = amp_executor.normalize_logs(logs, "/tmp/test"); + assert!(result.is_ok()); + + let conversation = result.unwrap(); + + // Should only have 1 assistant message (the complete one) + let assistant_messages: Vec<_> = conversation + .entries + .iter() + .filter(|e| matches!(e.entry_type, NormalizedEntryType::AssistantMessage)) + .collect(); + + assert_eq!(assistant_messages.len(), 1); + assert_eq!(assistant_messages[0].content, "Created all three files: test1.txt, test2.txt, and test3.txt, each with a line of text."); + } + + #[test] + fn test_filter_preserves_messages_without_state() { + // Test that messages without state metadata are preserved (for compatibility) + let amp_executor = AmpExecutor; + let logs = r#"{"type":"messages","messages":[[1,{"role":"assistant","content":[{"type":"text","text":"Regular message"}]}]],"toolResults":[]}"#; + + let result = amp_executor.normalize_logs(logs, "/tmp/test"); + assert!(result.is_ok()); + + let conversation = result.unwrap(); + + // Should have 1 assistant message + let assistant_messages: Vec<_> = conversation + .entries + .iter() + .filter(|e| matches!(e.entry_type, NormalizedEntryType::AssistantMessage)) + .collect(); + + assert_eq!(assistant_messages.len(), 1); + assert_eq!(assistant_messages[0].content, "Regular message"); + } +}