tirea_protocol_ag_ui/
context.rs

1use crate::events::{Event, ReasoningEncryptedValueSubtype};
2use serde_json::Value;
3use std::collections::{HashMap, HashSet};
4use tirea_contract::{AgentEvent, TerminationReason};
5use tracing::warn;
6
7// AG-UI Context
8// ============================================================================
9
/// Context for AG-UI event conversion.
///
/// Maintains the state needed to convert internal [`AgentEvent`]s into AG-UI
/// [`Event`]s: which text/reasoning streams are currently open, the current
/// message and step identifiers, the last state snapshot (used to compute
/// deltas), and whether a terminal event has already been emitted.
///
/// `message_id` is populated when the first `RunStart` event in the stream is
/// converted.
#[derive(Debug, Clone)]
pub struct AgUiEventContext {
    /// AG-UI-facing run id for event payloads; if set, replaces the internal
    /// run_id in outward-facing lifecycle events (run started / run finished).
    frontend_run_id: Option<String>,
    /// Current message identifier, attached to text, reasoning, encrypted
    /// reasoning and tool-call-start events.
    pub message_id: String,
    /// Number of steps started so far; used to generate `step_N` names.
    step_counter: u32,
    /// Whether a text message stream is currently open (start emitted, end not
    /// yet emitted).
    pub(super) text_started: bool,
    /// Whether a reasoning stream is currently open.
    reasoning_started: bool,
    /// Whether a text stream has been ended since the last step reset; when
    /// set, the next text start generates a fresh `message_id`.
    text_ever_ended: bool,
    /// Name of the most recently started step, if any.
    current_step: Option<String>,
    /// Whether a terminal event (RunFinish/Error) has been emitted.
    /// After this, all subsequent events are suppressed.
    stopped: bool,
    /// Last emitted state snapshot, used to compute RFC 6902 deltas.
    last_state: Option<Value>,
    /// Tool call IDs that have received at least one ToolCallDelta (args chunk).
    /// Used to avoid double-emitting TOOL_CALL_ARGS: when ToolCallReady arrives
    /// for a tool call that never received deltas (e.g. frontend tool invocations),
    /// we emit the full arguments as a single TOOL_CALL_ARGS event.
    tool_ids_with_deltas: HashSet<String>,
}
42
43impl Default for AgUiEventContext {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl AgUiEventContext {
50    /// Create a new AG-UI context.
51    ///
52    /// The context is fully initialized when the first `RunStart` event
53    /// arrives, which generates an independent `message_id` (UUID v7).
54    pub fn new() -> Self {
55        Self {
56            frontend_run_id: None,
57            message_id: String::new(),
58            step_counter: 0,
59            text_started: false,
60            reasoning_started: false,
61            text_ever_ended: false,
62            current_step: None,
63            stopped: false,
64            last_state: None,
65            tool_ids_with_deltas: HashSet::new(),
66        }
67    }
68
69    /// Set a frontend run id used in outward-facing AG-UI lifecycle events.
70    pub fn with_frontend_run_id(mut self, run_id: impl Into<String>) -> Self {
71        self.frontend_run_id = Some(run_id.into());
72        self
73    }
74
75    fn outward_run_id(&self, internal_run_id: &str) -> String {
76        self.frontend_run_id
77            .as_ref()
78            .cloned()
79            .unwrap_or_else(|| internal_run_id.to_string())
80    }
81
82    /// Generate the next step name.
83    pub fn next_step_name(&mut self) -> String {
84        self.step_counter += 1;
85        let name = format!("step_{}", self.step_counter);
86        self.current_step = Some(name.clone());
87        name
88    }
89
90    /// Get the current step name.
91    pub fn current_step_name(&self) -> String {
92        self.current_step
93            .clone()
94            .unwrap_or_else(|| format!("step_{}", self.step_counter))
95    }
96
97    /// Mark text stream as started.
98    ///
99    /// If text was previously ended, a new `message_id` is generated so that
100    /// each TEXT_MESSAGE_START / TEXT_MESSAGE_END cycle uses a unique ID.
101    /// This prevents CopilotKit / AG-UI runtimes from confusing reopened
102    /// message IDs with already-ended ones (which causes "text-end for
103    /// missing text part" errors on the frontend).
104    pub fn start_text(&mut self) -> bool {
105        let was_started = self.text_started;
106        self.text_started = true;
107        if !was_started {
108            // Generate a fresh message_id when restarting text after a prior end.
109            if self.text_ever_ended {
110                self.new_message_id();
111            }
112            true
113        } else {
114            false
115        }
116    }
117
118    /// Mark text stream as ended and return whether it was active.
119    pub fn end_text(&mut self) -> bool {
120        let was_started = self.text_started;
121        self.text_started = false;
122        if was_started {
123            self.text_ever_ended = true;
124        }
125        was_started
126    }
127
128    /// Whether a text stream is currently open.
129    pub fn is_text_open(&self) -> bool {
130        self.text_started
131    }
132
133    /// Mark reasoning stream as started.
134    pub fn start_reasoning(&mut self) -> bool {
135        let was_started = self.reasoning_started;
136        self.reasoning_started = true;
137        !was_started
138    }
139
140    /// Mark reasoning stream as ended and return whether it was active.
141    pub fn end_reasoning(&mut self) -> bool {
142        let was_started = self.reasoning_started;
143        self.reasoning_started = false;
144        was_started
145    }
146
147    /// Reset text lifecycle state for a new step with a pre-generated message ID.
148    ///
149    /// This ensures that the streaming message ID matches the stored `Message.id`
150    /// for the assistant message produced by this step.
151    pub fn reset_for_step(&mut self, message_id: String) {
152        self.message_id = message_id;
153        self.text_started = false;
154        self.reasoning_started = false;
155        self.text_ever_ended = false;
156    }
157
158    /// Generate a new message ID.
159    pub fn new_message_id(&mut self) -> String {
160        self.message_id = tirea_contract::gen_message_id();
161        self.message_id.clone()
162    }
163
164    fn close_reasoning_stream(&mut self) -> Vec<Event> {
165        if self.end_reasoning() {
166            vec![
167                Event::reasoning_message_end(&self.message_id),
168                Event::reasoning_end(&self.message_id),
169            ]
170        } else {
171            Vec::new()
172        }
173    }
174
175    fn close_open_streams(&mut self) -> Vec<Event> {
176        let mut events = self.close_reasoning_stream();
177        if self.end_text() {
178            events.push(Event::text_message_end(&self.message_id));
179        }
180        events
181    }
182
183    /// Convert an AgentEvent to AG-UI protocol compatible events.
184    ///
185    /// Handles full stream lifecycle: text start/end pairs, step counters,
186    /// terminal event suppression (after Error), and Pending event filtering.
187    pub fn on_agent_event(&mut self, ev: &AgentEvent) -> Vec<Event> {
188        // After a terminal event (RunFinish/Error), suppress everything.
189        if self.stopped {
190            return Vec::new();
191        }
192
193        // Lifecycle bookkeeping before conversion.
194        match ev {
195            AgentEvent::RunFinish { .. } | AgentEvent::Error { .. } => {
196                self.stopped = true;
197            }
198            AgentEvent::ToolCallResumed { .. } => {
199                return vec![];
200            }
201            _ => {}
202        }
203
204        match ev {
205            AgentEvent::RunStart {
206                thread_id,
207                run_id,
208                parent_run_id,
209            } => {
210                let outward_run_id = self.outward_run_id(run_id);
211                self.message_id = tirea_contract::gen_message_id();
212                vec![Event::run_started(
213                    thread_id,
214                    outward_run_id,
215                    parent_run_id.clone(),
216                )]
217            }
218            AgentEvent::RunFinish {
219                thread_id,
220                run_id,
221                result,
222                termination,
223            } => {
224                let mut events = self.close_open_streams();
225                let outward_run_id = self.outward_run_id(run_id);
226                match termination {
227                    TerminationReason::Cancelled => {
228                        events.push(Event::run_error(
229                            "Run cancelled",
230                            Some("CANCELLED".to_string()),
231                        ));
232                    }
233                    TerminationReason::Error(ref msg) => {
234                        events.push(Event::run_error(msg, Some("ERROR".to_string())));
235                    }
236                    _ => {
237                        events.push(Event::run_finished(
238                            thread_id,
239                            outward_run_id,
240                            result.clone(),
241                        ));
242                    }
243                }
244                events
245            }
246
247            AgentEvent::TextDelta { delta } => {
248                let mut events = vec![];
249                if self.start_text() {
250                    events.push(Event::text_message_start(&self.message_id));
251                }
252                events.push(Event::text_message_content(&self.message_id, delta));
253                events
254            }
255            AgentEvent::ReasoningDelta { delta } => {
256                let mut events = vec![];
257                if self.start_reasoning() {
258                    events.push(Event::reasoning_start(&self.message_id));
259                    events.push(Event::reasoning_message_start(&self.message_id));
260                }
261                events.push(Event::reasoning_message_content(&self.message_id, delta));
262                events
263            }
264            AgentEvent::ReasoningEncryptedValue { encrypted_value } => {
265                vec![Event::reasoning_encrypted_value(
266                    ReasoningEncryptedValueSubtype::Message,
267                    self.message_id.clone(),
268                    encrypted_value.clone(),
269                )]
270            }
271
272            AgentEvent::ToolCallStart { id, name } => {
273                let mut events = self.close_open_streams();
274                events.push(Event::tool_call_start(
275                    id,
276                    name,
277                    Some(self.message_id.clone()),
278                ));
279                events
280            }
281            AgentEvent::ToolCallDelta { id, args_delta } => {
282                self.tool_ids_with_deltas.insert(id.clone());
283                vec![Event::tool_call_args(id, args_delta)]
284            }
285            AgentEvent::ToolCallReady { id, arguments, .. } => {
286                let mut events = Vec::new();
287                // For frontend tool invocations (and any tool call that skipped
288                // the streaming ToolCallDelta path), emit the full arguments as
289                // a single TOOL_CALL_ARGS event before TOOL_CALL_END.
290                if !self.tool_ids_with_deltas.contains(id.as_str()) {
291                    let args_str = arguments.to_string();
292                    if args_str != "{}" && args_str != "null" {
293                        events.push(Event::tool_call_args(id.clone(), args_str));
294                    }
295                }
296                events.push(Event::tool_call_end(id));
297                events
298            }
299            AgentEvent::ToolCallDone {
300                id,
301                result,
302                message_id,
303                ..
304            } => {
305                let content = match serde_json::to_string(&result.to_json()) {
306                    Ok(content) => content,
307                    Err(err) => {
308                        warn!(error = %err, tool_call_id = %id, "failed to serialize tool result for AG-UI");
309                        r#"{"error":"failed to serialize tool result"}"#.to_string()
310                    }
311                };
312                let msg_id = if message_id.is_empty() {
313                    format!("result_{id}")
314                } else {
315                    message_id.clone()
316                };
317                vec![Event::tool_call_result(msg_id, id, content)]
318            }
319
320            AgentEvent::StepStart { message_id } => {
321                if !message_id.is_empty() {
322                    self.reset_for_step(message_id.clone());
323                }
324                vec![Event::step_started(self.next_step_name())]
325            }
326            AgentEvent::StepEnd => {
327                let mut events = self.close_reasoning_stream();
328                events.push(Event::step_finished(self.current_step_name()));
329                events
330            }
331
332            AgentEvent::StateSnapshot { snapshot } => {
333                let mut events = Vec::new();
334                // Emit RFC 6902 delta if we have a previous state to diff against.
335                if let Some(ref old) = self.last_state {
336                    let patch = json_patch::diff(old, snapshot);
337                    if !patch.0.is_empty() {
338                        let delta = patch
339                            .0
340                            .iter()
341                            .map(|op| serde_json::to_value(op).expect("RFC 6902 op serializes"))
342                            .collect();
343                        events.push(Event::state_delta(delta));
344                    }
345                }
346                self.last_state = Some(snapshot.clone());
347                events.push(Event::state_snapshot(snapshot.clone()));
348                events
349            }
350            AgentEvent::StateDelta { delta } => {
351                vec![Event::state_delta(delta.clone())]
352            }
353            AgentEvent::MessagesSnapshot { messages } => {
354                vec![Event::messages_snapshot(messages.clone())]
355            }
356
357            AgentEvent::ActivitySnapshot {
358                message_id,
359                activity_type,
360                content,
361                replace,
362            } => {
363                vec![Event::activity_snapshot(
364                    message_id.clone(),
365                    activity_type.clone(),
366                    value_to_map(content),
367                    *replace,
368                )]
369            }
370            AgentEvent::ActivityDelta {
371                message_id,
372                activity_type,
373                patch,
374            } => {
375                vec![Event::activity_delta(
376                    message_id.clone(),
377                    activity_type.clone(),
378                    patch.clone(),
379                )]
380            }
381
382            AgentEvent::Error { message, code } => {
383                let mut events = self.close_reasoning_stream();
384                events.push(Event::run_error(message, code.clone()));
385                events
386            }
387            AgentEvent::InferenceComplete {
388                model,
389                usage,
390                duration_ms,
391            } => {
392                let mut content = serde_json::Map::new();
393                content.insert(
394                    "model".to_string(),
395                    serde_json::Value::String(model.clone()),
396                );
397                content.insert(
398                    "duration_ms".to_string(),
399                    serde_json::Value::Number((*duration_ms).into()),
400                );
401                if let Some(u) = usage {
402                    if let Ok(v) = serde_json::to_value(u) {
403                        content.insert("usage".to_string(), v);
404                    }
405                }
406                vec![Event::activity_snapshot(
407                    self.message_id.clone(),
408                    "inference_complete".to_string(),
409                    content.into_iter().collect(),
410                    Some(false),
411                )]
412            }
413            AgentEvent::ToolCallResumed { .. } => unreachable!(),
414        }
415    }
416}
417
418pub(super) fn value_to_map(value: &Value) -> HashMap<String, Value> {
419    match value.as_object() {
420        Some(map) => map
421            .iter()
422            .map(|(key, value)| (key.clone(), value.clone()))
423            .collect(),
424        None => {
425            let mut map = HashMap::new();
426            map.insert("value".to_string(), value.clone());
427            map
428        }
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_contract::TokenUsage;

    /// Build a context that has already processed a `RunStart` event.
    fn make_ctx() -> AgUiEventContext {
        let run_start = AgentEvent::RunStart {
            thread_id: "t1".into(),
            run_id: "run_12345678".into(),
            parent_run_id: None,
        };
        let mut ctx = AgUiEventContext::new();
        ctx.on_agent_event(&run_start);
        ctx
    }

    /// Serialize every AG-UI event to JSON for field-level assertions.
    fn to_json(events: &[Event]) -> Vec<serde_json::Value> {
        events
            .iter()
            .map(|event| serde_json::to_value(event).expect("serialize ag-ui event"))
            .collect()
    }

    #[test]
    fn inference_complete_emits_activity_snapshot() {
        let usage = TokenUsage {
            prompt_tokens: Some(100),
            completion_tokens: Some(50),
            ..Default::default()
        };
        let mut ctx = make_ctx();
        let out = to_json(&ctx.on_agent_event(&AgentEvent::InferenceComplete {
            model: "gpt-4o".into(),
            usage: Some(usage),
            duration_ms: 1234,
        }));

        assert_eq!(out.len(), 1);
        let snapshot = &out[0];
        assert_eq!(snapshot["type"], "ACTIVITY_SNAPSHOT");
        assert_eq!(snapshot["content"]["model"], "gpt-4o");
        assert_eq!(snapshot["content"]["duration_ms"], 1234);
        assert!(snapshot["content"]["usage"].is_object());
    }

    #[test]
    fn inference_complete_without_usage() {
        let mut ctx = make_ctx();
        let out = to_json(&ctx.on_agent_event(&AgentEvent::InferenceComplete {
            model: "gpt-4o-mini".into(),
            usage: None,
            duration_ms: 500,
        }));

        assert_eq!(out.len(), 1);
        let snapshot = &out[0];
        assert_eq!(snapshot["type"], "ACTIVITY_SNAPSHOT");
        let body = &snapshot["content"];
        assert_eq!(body["model"], "gpt-4o-mini");
        assert!(body.get("usage").is_none());
    }

    #[test]
    fn reasoning_delta_emits_reasoning_events() {
        let mut ctx = make_ctx();
        let out = to_json(&ctx.on_agent_event(&AgentEvent::ReasoningDelta {
            delta: "step-by-step".into(),
        }));

        assert_eq!(out.len(), 3);
        let kinds: Vec<&serde_json::Value> = out.iter().map(|v| &v["type"]).collect();
        assert_eq!(kinds[0], "REASONING_START");
        assert_eq!(kinds[1], "REASONING_MESSAGE_START");
        assert_eq!(kinds[2], "REASONING_MESSAGE_CONTENT");
        assert_eq!(out[2]["delta"], "step-by-step");
    }

    #[test]
    fn reasoning_encrypted_value_maps_to_message_entity() {
        let mut ctx = make_ctx();
        let out = to_json(&ctx.on_agent_event(&AgentEvent::ReasoningEncryptedValue {
            encrypted_value: "opaque-token".into(),
        }));

        assert_eq!(out.len(), 1);
        let encrypted = &out[0];
        assert_eq!(encrypted["type"], "REASONING_ENCRYPTED_VALUE");
        assert_eq!(encrypted["subtype"], "message");
        assert_eq!(encrypted["encryptedValue"], "opaque-token");
    }

    #[test]
    fn tool_call_progress_activity_snapshot_maps_to_agui_example() {
        let progress = json!({
            "type": "tool-call-progress",
            "schema": "tool-call-progress.v1",
            "node_id": "tool_call:call_1",
            "parent_call_id": "call_parent_1",
            "parent_node_id": "tool_call:call_parent_1",
            "status": "running",
            "progress": 0.35,
            "message": "calling MCP"
        });
        let mut ctx = make_ctx();
        let out = to_json(&ctx.on_agent_event(&AgentEvent::ActivitySnapshot {
            message_id: "tool_call:call_1".into(),
            activity_type: "tool-call-progress".into(),
            content: progress,
            replace: Some(true),
        }));

        assert_eq!(out.len(), 1);
        let snapshot = &out[0];
        assert_eq!(snapshot["type"], "ACTIVITY_SNAPSHOT");
        assert_eq!(snapshot["activityType"], "tool-call-progress");
        assert_eq!(snapshot["content"]["schema"], "tool-call-progress.v1");
        assert_eq!(snapshot["content"]["parent_call_id"], "call_parent_1");
        assert_eq!(snapshot["content"]["progress"], 0.35);
    }
}
546
547// ============================================================================