fix: handle whole assistant and thinking messages for Codex (#1036)

* fix: handle whole assistant and thinking messages for Codex

* fmt
This commit is contained in:
Solomon
2025-10-17 15:47:08 +01:00
committed by GitHub
parent dfa8694d0d
commit e4a4c004da

View File

@@ -12,12 +12,12 @@ use codex_protocol::{
config_types::ReasoningEffort,
plan_tool::{StepStatus, UpdatePlanArgs},
protocol::{
AgentMessageDeltaEvent, AgentReasoningDeltaEvent, AgentReasoningSectionBreakEvent,
BackgroundEventEvent, ErrorEvent, EventMsg, ExecCommandBeginEvent, ExecCommandEndEvent,
ExecCommandOutputDeltaEvent, ExecOutputStream, FileChange as CodexProtoFileChange,
McpInvocation, McpToolCallBeginEvent, McpToolCallEndEvent, PatchApplyBeginEvent,
PatchApplyEndEvent, StreamErrorEvent, TokenUsageInfo, ViewImageToolCallEvent,
WebSearchBeginEvent, WebSearchEndEvent,
AgentMessageDeltaEvent, AgentMessageEvent, AgentReasoningDeltaEvent, AgentReasoningEvent,
AgentReasoningSectionBreakEvent, BackgroundEventEvent, ErrorEvent, EventMsg,
ExecCommandBeginEvent, ExecCommandEndEvent, ExecCommandOutputDeltaEvent, ExecOutputStream,
FileChange as CodexProtoFileChange, McpInvocation, McpToolCallBeginEvent,
McpToolCallEndEvent, PatchApplyBeginEvent, PatchApplyEndEvent, StreamErrorEvent,
TokenUsageInfo, ViewImageToolCallEvent, WebSearchBeginEvent, WebSearchEndEvent,
},
};
use futures::StreamExt;
@@ -219,6 +219,7 @@ impl LogState {
&mut self,
content: String,
type_: StreamingTextKind,
mode: UpdateMode,
) -> (NormalizedEntry, usize, bool) {
let index_provider = &self.entry_index;
let entry = match type_ {
@@ -232,7 +233,10 @@ impl LogState {
(&entry.as_ref().unwrap().content, index)
} else {
let streaming_state = entry.as_mut().unwrap();
streaming_state.content.push_str(&content);
match mode {
UpdateMode::Append => streaming_state.content.push_str(&content),
UpdateMode::Set => streaming_state.content = content,
}
(&streaming_state.content, streaming_state.index)
};
let normalized_entry = NormalizedEntry {
@@ -247,13 +251,42 @@ impl LogState {
(normalized_entry, index, is_new)
}
fn assistant_message_update(&mut self, content: String) -> (NormalizedEntry, usize, bool) {
self.streaming_text_update(content, StreamingTextKind::Assistant)
fn streaming_text_append(
&mut self,
content: String,
type_: StreamingTextKind,
) -> (NormalizedEntry, usize, bool) {
self.streaming_text_update(content, type_, UpdateMode::Append)
}
fn thinking_update(&mut self, content: String) -> (NormalizedEntry, usize, bool) {
self.streaming_text_update(content, StreamingTextKind::Thinking)
fn streaming_text_set(
&mut self,
content: String,
type_: StreamingTextKind,
) -> (NormalizedEntry, usize, bool) {
self.streaming_text_update(content, type_, UpdateMode::Set)
}
fn assistant_message_append(&mut self, content: String) -> (NormalizedEntry, usize, bool) {
self.streaming_text_append(content, StreamingTextKind::Assistant)
}
fn thinking_append(&mut self, content: String) -> (NormalizedEntry, usize, bool) {
self.streaming_text_append(content, StreamingTextKind::Thinking)
}
fn assistant_message(&mut self, content: String) -> (NormalizedEntry, usize, bool) {
self.streaming_text_set(content, StreamingTextKind::Assistant)
}
fn thinking(&mut self, content: String) -> (NormalizedEntry, usize, bool) {
self.streaming_text_set(content, StreamingTextKind::Thinking)
}
}
enum UpdateMode {
Append,
Set,
}
fn upsert_normalized_entry(
@@ -400,18 +433,35 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
match event {
EventMsg::SessionConfigured(payload) => {
msg_store.push_session_id(payload.session_id.to_string());
handle_model_params(payload.model, payload.reasoning_effort, &msg_store, &entry_index);
handle_model_params(
payload.model,
payload.reasoning_effort,
&msg_store,
&entry_index,
);
}
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
state.thinking = None;
let (entry, index, is_new) = state.assistant_message_update(delta);
let (entry, index, is_new) = state.assistant_message_append(delta);
upsert_normalized_entry(&msg_store, index, entry, is_new);
}
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
state.assistant = None;
let (entry, index, is_new) = state.thinking_update(delta);
let (entry, index, is_new) = state.thinking_append(delta);
upsert_normalized_entry(&msg_store, index, entry, is_new);
}
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
state.thinking = None;
let (entry, index, is_new) = state.assistant_message(message);
upsert_normalized_entry(&msg_store, index, entry, is_new);
state.assistant = None;
}
EventMsg::AgentReasoning(AgentReasoningEvent { text }) => {
state.assistant = None;
let (entry, index, is_new) = state.thinking(text);
upsert_normalized_entry(&msg_store, index, entry, is_new);
state.thinking = None;
}
EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}) => {
state.assistant = None;
state.thinking = None;
@@ -438,7 +488,11 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
},
);
let command_state = state.commands.get_mut(&call_id).unwrap();
let index = add_normalized_entry(&msg_store, &entry_index, command_state.to_normalized_entry());
let index = add_normalized_entry(
&msg_store,
&entry_index,
command_state.to_normalized_entry(),
);
command_state.index = Some(index)
}
EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
@@ -459,7 +513,11 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
tracing::error!("missing entry index for existing command state");
continue;
};
replace_normalized_entry(&msg_store, index, command_state.to_normalized_entry());
replace_normalized_entry(
&msg_store,
index,
command_state.to_normalized_entry(),
);
}
}
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
@@ -483,24 +541,36 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
tracing::error!("missing entry index for existing command state");
continue;
};
replace_normalized_entry(&msg_store, index, command_state.to_normalized_entry());
replace_normalized_entry(
&msg_store,
index,
command_state.to_normalized_entry(),
);
}
}
EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => {
add_normalized_entry(&msg_store, &entry_index, NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::SystemMessage,
content: format!("Background event: {message}"),
metadata: None,
});
add_normalized_entry(
&msg_store,
&entry_index,
NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::SystemMessage,
content: format!("Background event: {message}"),
metadata: None,
},
);
}
EventMsg::StreamError(StreamErrorEvent { message }) => {
add_normalized_entry(&msg_store, &entry_index, NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ErrorMessage,
content: format!("Stream error: {message}"),
metadata: None,
});
add_normalized_entry(
&msg_store,
&entry_index,
NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ErrorMessage,
content: format!("Stream error: {message}"),
metadata: None,
},
);
}
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
call_id,
@@ -518,7 +588,11 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
},
);
let mcp_tool_state = state.mcp_tools.get_mut(&call_id).unwrap();
let index = add_normalized_entry(&msg_store, &entry_index, mcp_tool_state.to_normalized_entry());
let index = add_normalized_entry(
&msg_store,
&entry_index,
mcp_tool_state.to_normalized_entry(),
);
mcp_tool_state.index = Some(index);
}
EventMsg::McpToolCallEnd(McpToolCallEndEvent {
@@ -527,25 +601,42 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
if let Some(mut mcp_tool_state) = state.mcp_tools.remove(&call_id) {
match result {
Ok(value) => {
mcp_tool_state.status =
if value.is_error.unwrap_or(false) {
mcp_tool_state.status = if value.is_error.unwrap_or(false) {
ToolStatus::Failed
} else {
ToolStatus::Success
};
if value.content.iter().all(|block| matches!(block, ContentBlock::TextContent(_))) {
if value
.content
.iter()
.all(|block| matches!(block, ContentBlock::TextContent(_)))
{
mcp_tool_state.result = Some(ToolResult {
r#type: ToolResultValueType::Markdown,
value: Value::String(value.content.iter().map(|block| {
if let ContentBlock::TextContent(content) = block {
content.text.clone()
} else {
unreachable!()
}
}).collect::<Vec<String>>().join("\n"))
value: Value::String(
value
.content
.iter()
.map(|block| {
if let ContentBlock::TextContent(content) =
block
{
content.text.clone()
} else {
unreachable!()
}
})
.collect::<Vec<String>>()
.join("\n"),
),
});
} else {
mcp_tool_state.result = Some(ToolResult { r#type: ToolResultValueType::Json, value: value.structured_content.unwrap_or_else(|| serde_json::to_value(value.content).unwrap_or_default()) });
mcp_tool_state.result = Some(ToolResult {
r#type: ToolResultValueType::Json,
value: value.structured_content.unwrap_or_else(|| {
serde_json::to_value(value.content).unwrap_or_default()
}),
});
}
}
Err(err) => {
@@ -560,7 +651,11 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
tracing::error!("missing entry index for existing mcp tool state");
continue;
};
replace_normalized_entry(&msg_store, index, mcp_tool_state.to_normalized_entry());
replace_normalized_entry(
&msg_store,
index,
mcp_tool_state.to_normalized_entry(),
);
}
}
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
@@ -578,7 +673,11 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
status: ToolStatus::Created,
});
let patch_entry = patch_state.entries.last_mut().unwrap();
let index = add_normalized_entry(&msg_store, &entry_index, patch_entry.to_normalized_entry());
let index = add_normalized_entry(
&msg_store,
&entry_index,
patch_entry.to_normalized_entry(),
);
patch_entry.index = Some(index);
}
state.patches.insert(call_id, patch_state);
@@ -602,7 +701,11 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
tracing::error!("missing entry index for existing patch entry");
continue;
};
replace_normalized_entry(&msg_store, index, entry.to_normalized_entry());
replace_normalized_entry(
&msg_store,
index,
entry.to_normalized_entry(),
);
}
}
}
@@ -643,7 +746,9 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
timestamp: None,
entry_type: NormalizedEntryType::ToolUse {
tool_name: "view_image".to_string(),
action_type: ActionType::FileRead { path: relative_path.clone() },
action_type: ActionType::FileRead {
path: relative_path.clone(),
},
status: ToolStatus::Success,
},
content: format!("`{relative_path}`"),
@@ -692,21 +797,23 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
);
}
EventMsg::Error(ErrorEvent { message }) => {
add_normalized_entry(&msg_store, &entry_index, NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ErrorMessage,
content: message,
metadata: None,
});
add_normalized_entry(
&msg_store,
&entry_index,
NormalizedEntry {
timestamp: None,
entry_type: NormalizedEntryType::ErrorMessage,
content: message,
metadata: None,
},
);
}
EventMsg::TokenCount(payload) => {
if let Some(info) = payload.info {
state.token_usage_info = Some(info);
}
}
EventMsg::AgentReasoning(..) // content duplicated with delta events
| EventMsg::AgentMessage(..) // ditto
| EventMsg::AgentReasoningRawContent(..)
EventMsg::AgentReasoningRawContent(..)
| EventMsg::AgentReasoningRawContentDelta(..)
| EventMsg::TaskStarted(..)
| EventMsg::UserMessage(..)
@@ -720,9 +827,8 @@ pub fn normalize_logs(msg_store: Arc<MsgStore>, worktree_path: &Path) {
| EventMsg::EnteredReviewMode(..)
| EventMsg::ExitedReviewMode(..)
| EventMsg::TaskComplete(..)
|EventMsg::ExecApprovalRequest(..)
|EventMsg::ApplyPatchApprovalRequest(..)
=> {}
| EventMsg::ExecApprovalRequest(..)
| EventMsg::ApplyPatchApprovalRequest(..) => {}
}
}
});