Refactor amp, fix duplicated messages in frontend (#133)

* Refactor amp, fix duplicated messages in frontend

* Fmt
This commit is contained in:
Alex Netsch
2025-07-14 09:31:41 +01:00
committed by GitHub
parent 8a8c7a16f6
commit f6bbece4c1

View File

@@ -2,6 +2,7 @@ use std::path::Path;
use async_trait::async_trait;
use command_group::{AsyncCommandGroup, AsyncGroupChild};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{
@@ -22,6 +23,175 @@ pub struct AmpFollowupExecutor {
pub prompt: String,
}
/// Top-level JSON event emitted by the `amp` CLI, discriminated by the
/// `"type"` field. Variants that are irrelevant to log normalization keep
/// their raw payload as `serde_json::Value`.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum AmpJson {
    /// A batch of conversation messages; each element pairs a message index
    /// with the message itself.
    #[serde(rename = "messages")]
    Messages {
        messages: Vec<(usize, AmpMessage)>,
        #[serde(rename = "toolResults")]
        tool_results: Vec<serde_json::Value>,
    },
    /// Initial handshake event; carries the amp thread ID (used as the
    /// session ID) when present.
    #[serde(rename = "initial")]
    Initial {
        #[serde(rename = "threadID")]
        thread_id: Option<String>,
    },
    /// Token accounting payload; not normalized.
    #[serde(rename = "token-usage")]
    TokenUsage(serde_json::Value),
    /// Execution state change notification.
    #[serde(rename = "state")]
    State { state: String },
    /// Emitted when the amp process shuts down.
    #[serde(rename = "shutdown")]
    Shutdown,
    /// Tool execution status payload; not normalized.
    #[serde(rename = "tool-status")]
    ToolStatus(serde_json::Value),
}
/// A single conversation message inside an `AmpJson::Messages` batch.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct AmpMessage {
    // Author role; "user" and "assistant" produce entries, other roles are
    // skipped during normalization.
    pub role: String,
    // Ordered content items (text, thinking, tool use, tool result).
    pub content: Vec<AmpContentItem>,
    // Raw delivery state; a `{"type":"streaming"}` state marks a partial
    // snapshot that is filtered out to avoid duplicated frontend entries.
    pub state: Option<serde_json::Value>,
    // Optional metadata (currently only the sent-at timestamp).
    pub meta: Option<AmpMeta>,
}
/// Metadata attached to an [`AmpMessage`].
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct AmpMeta {
    // Timestamp at which the message was sent (serialized as `sentAt`).
    // NOTE(review): unit (seconds vs milliseconds) is not established by this
    // file — confirm against the amp CLI output before relying on it.
    #[serde(rename = "sentAt")]
    pub sent_at: u64,
}
/// One content item within an [`AmpMessage`], discriminated by `"type"`.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum AmpContentItem {
    /// Plain text authored by the user or assistant.
    #[serde(rename = "text")]
    Text { text: String },
    /// Assistant reasoning text.
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    /// A tool invocation request with its JSON input payload.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The outcome of a prior tool invocation; produces no normalized entry.
    #[serde(rename = "tool_result")]
    ToolResult {
        #[serde(rename = "toolUseID")]
        tool_use_id: String,
        run: serde_json::Value,
    },
}
impl AmpJson {
    /// Returns `true` for events that can contribute conversation entries;
    /// only `messages` batches do.
    pub fn should_process(&self) -> bool {
        matches!(self, AmpJson::Messages { .. })
    }

    /// Extracts the amp thread ID (used as the session ID) from an
    /// `initial` event; every other event kind yields `None`.
    pub fn extract_session_id(&self) -> Option<String> {
        match self {
            AmpJson::Initial { thread_id } => thread_id.clone(),
            _ => None,
        }
    }

    /// Returns `true` when any message in this batch is still streaming
    /// (its `state.type` is `"streaming"`). Streaming batches are partial
    /// snapshots and must be skipped to avoid duplicated entries.
    pub fn has_streaming_content(&self) -> bool {
        match self {
            AmpJson::Messages { messages, .. } => messages.iter().any(|(_index, message)| {
                // Combinator chain replaces the nested if-let pyramid:
                // absent state or absent/non-string `type` both read as
                // "not streaming".
                message
                    .state
                    .as_ref()
                    .and_then(|state| state.get("type"))
                    .and_then(|state_type| state_type.as_str())
                    == Some("streaming")
            }),
            _ => false,
        }
    }

    /// Converts a complete (non-streaming) `messages` batch into normalized
    /// conversation entries; all other event kinds yield no entries.
    pub fn to_normalized_entries(
        &self,
        executor: &AmpExecutor,
        worktree_path: &str,
    ) -> Vec<NormalizedEntry> {
        match self {
            AmpJson::Messages { messages, .. } => {
                // Partial streaming snapshots would duplicate the final
                // complete message, so drop them entirely.
                if self.has_streaming_content() {
                    return vec![];
                }
                messages
                    .iter()
                    .flat_map(|(_index, message)| {
                        message.content.iter().filter_map(move |content_item| {
                            content_item.to_normalized_entry(
                                &message.role,
                                message,
                                executor,
                                worktree_path,
                            )
                        })
                    })
                    .collect()
            }
            _ => vec![],
        }
    }
}
impl AmpContentItem {
    /// Converts this content item into a [`NormalizedEntry`], or `None` when
    /// it has no normalized representation (tool results, unrecognized roles).
    pub fn to_normalized_entry(
        &self,
        role: &str,
        message: &AmpMessage,
        executor: &AmpExecutor,
        worktree_path: &str,
    ) -> Option<NormalizedEntry> {
        use serde_json::Value;

        // Timestamp comes from the message's `sentAt` metadata when present.
        let timestamp = message.meta.as_ref().map(|meta| meta.sent_at.to_string());
        // Every produced entry carries the raw content item as metadata.
        let metadata = Some(serde_json::to_value(self).unwrap_or(Value::Null));

        // Resolve the entry type and display content per item kind, bailing
        // out for kinds/roles that produce no entry.
        let (entry_type, content) = match self {
            AmpContentItem::Text { text } => {
                let entry_type = match role {
                    "user" => NormalizedEntryType::UserMessage,
                    "assistant" => NormalizedEntryType::AssistantMessage,
                    _ => return None,
                };
                (entry_type, text.clone())
            }
            AmpContentItem::Thinking { thinking } => {
                (NormalizedEntryType::Thinking, thinking.clone())
            }
            AmpContentItem::ToolUse { name, input, .. } => {
                let action_type = executor.extract_action_type(name, input, worktree_path);
                let concise =
                    executor.generate_concise_content(name, input, &action_type, worktree_path);
                (
                    NormalizedEntryType::ToolUse {
                        tool_name: name.clone(),
                        action_type,
                    },
                    concise,
                )
            }
            // Tool results produce no standalone entry.
            AmpContentItem::ToolResult { .. } => return None,
        };

        Some(NormalizedEntry {
            timestamp,
            entry_type,
            content,
            metadata,
        })
    }
}
#[async_trait]
impl Executor for AmpExecutor {
async fn spawn(
@@ -94,8 +264,6 @@ Task title: {}"#,
logs: &str,
worktree_path: &str,
) -> Result<NormalizedConversation, String> {
use serde_json::Value;
let mut entries = Vec::new();
let mut session_id = None;
@@ -105,9 +273,9 @@ Task title: {}"#,
continue;
}
// Try to parse as JSON
let json: Value = match serde_json::from_str(trimmed) {
Ok(json) => json,
// Try to parse as AmpMessage
let amp_message: AmpJson = match serde_json::from_str(trimmed) {
Ok(msg) => msg,
Err(_) => {
// If line isn't valid JSON, add it as raw text
entries.push(NormalizedEntry {
@@ -120,147 +288,17 @@ Task title: {}"#,
}
};
// Extract session ID (threadID in AMP)
// Extract session ID if available
if session_id.is_none() {
if let Some(thread_id) = json.get("threadID").and_then(|v| v.as_str()) {
session_id = Some(thread_id.to_string());
if let Some(id) = amp_message.extract_session_id() {
session_id = Some(id);
}
}
// Process different message types
let processed = if let Some(msg_type) = json.get("type").and_then(|t| t.as_str()) {
match msg_type {
"messages" => {
if let Some(messages) = json.get("messages").and_then(|m| m.as_array()) {
for message_entry in messages {
if let Some(message_data) =
message_entry.as_array().and_then(|arr| arr.get(1))
{
if let Some(role) =
message_data.get("role").and_then(|r| r.as_str())
{
if let Some(content) =
message_data.get("content").and_then(|c| c.as_array())
{
for content_item in content {
if let Some(content_type) = content_item
.get("type")
.and_then(|t| t.as_str())
{
match content_type {
"text" => {
if let Some(text) = content_item
.get("text")
.and_then(|t| t.as_str())
{
let entry_type = match role {
"user" => NormalizedEntryType::UserMessage,
"assistant" => NormalizedEntryType::AssistantMessage,
_ => continue,
};
entries.push(NormalizedEntry {
timestamp: message_data
.get("meta")
.and_then(|m| {
m.get("sentAt")
})
.and_then(|s| s.as_u64())
.map(|ts| ts.to_string()),
entry_type,
content: text.to_string(),
metadata: Some(
content_item.clone(),
),
});
}
}
"thinking" => {
if let Some(thinking) = content_item
.get("thinking")
.and_then(|t| t.as_str())
{
entries.push(NormalizedEntry {
timestamp: None,
entry_type:
NormalizedEntryType::Thinking,
content: thinking.to_string(),
metadata: Some(
content_item.clone(),
),
});
}
}
"tool_use" => {
if let Some(tool_name) = content_item
.get("name")
.and_then(|n| n.as_str())
{
let input = content_item
.get("input")
.unwrap_or(&Value::Null);
let action_type = self
.extract_action_type(
tool_name,
input,
worktree_path,
);
let content = self
.generate_concise_content(
tool_name,
input,
&action_type,
worktree_path,
);
entries.push(NormalizedEntry {
timestamp: None,
entry_type:
NormalizedEntryType::ToolUse {
tool_name: tool_name
.to_string(),
action_type,
},
content,
metadata: Some(
content_item.clone(),
),
});
}
}
_ => {}
}
}
}
}
}
}
}
}
true
}
// Ignore these JSON types - they're not relevant for task execution logs
"initial" | "token-usage" | "state" | "shutdown" => true,
_ => false,
}
} else {
false
};
// If JSON didn't match expected patterns, add it as unrecognized JSON
// Skip JSON with type "result" as requested
if !processed {
if let Some(msg_type) = json.get("type").and_then(|t| t.as_str()) {
if msg_type == "result" {
// Skip result entries
continue;
}
}
entries.push(NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::SystemMessage,
content: format!("Unrecognized JSON: {}", trimmed),
metadata: Some(json),
});
// Process the message if it's a type we care about
if amp_message.should_process() {
let new_entries = amp_message.to_normalized_entries(self, worktree_path);
entries.extend(new_entries);
}
}
@@ -610,3 +648,54 @@ impl Executor for AmpFollowupExecutor {
main_executor.normalize_logs(logs, worktree_path)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the final (non-streaming) snapshot of an assistant message may
    /// survive normalization; intermediate streaming snapshots are dropped.
    #[test]
    fn test_filter_streaming_messages() {
        let executor = AmpExecutor;
        let logs = r#"{"type":"messages","messages":[[7,{"role":"assistant","content":[{"type":"text","text":"Created all three files: test1.txt, test2.txt, and test3.txt"}],"state":{"type":"streaming"}}]],"toolResults":[]}
{"type":"messages","messages":[[7,{"role":"assistant","content":[{"type":"text","text":"Created all three files: test1.txt, test2.txt, and test3.txt, each with a line of text."}],"state":{"type":"streaming"}}]],"toolResults":[]}
{"type":"messages","messages":[[7,{"role":"assistant","content":[{"type":"text","text":"Created all three files: test1.txt, test2.txt, and test3.txt, each with a line of text."}],"state":{"type":"complete","stopReason":"end_turn"}}]],"toolResults":[]}"#;

        let conversation = executor
            .normalize_logs(logs, "/tmp/test")
            .expect("normalize_logs should succeed");

        let assistant_entries: Vec<_> = conversation
            .entries
            .iter()
            .filter(|entry| matches!(entry.entry_type, NormalizedEntryType::AssistantMessage))
            .collect();

        // Exactly one entry: the complete (non-streaming) message.
        assert_eq!(assistant_entries.len(), 1);
        assert_eq!(assistant_entries[0].content, "Created all three files: test1.txt, test2.txt, and test3.txt, each with a line of text.");
    }

    /// Messages that carry no `state` field at all must still be normalized
    /// (backward compatibility with older amp output).
    #[test]
    fn test_filter_preserves_messages_without_state() {
        let executor = AmpExecutor;
        let logs = r#"{"type":"messages","messages":[[1,{"role":"assistant","content":[{"type":"text","text":"Regular message"}]}]],"toolResults":[]}"#;

        let conversation = executor
            .normalize_logs(logs, "/tmp/test")
            .expect("normalize_logs should succeed");

        let assistant_entries: Vec<_> = conversation
            .entries
            .iter()
            .filter(|entry| matches!(entry.entry_type, NormalizedEntryType::AssistantMessage))
            .collect();

        assert_eq!(assistant_entries.len(), 1);
        assert_eq!(assistant_entries[0].content, "Regular message");
    }
}