tirea_protocol_ag_ui/
types.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::{HashMap, HashSet};
4use tirea_contract::io::decision_translation::suspension_response_to_decision;
5use tirea_contract::io::ResumeDecisionAction;
6use tirea_contract::runtime::suspended_calls_from_state;
7use tirea_contract::{gen_message_id, RunOrigin, RunRequest, Visibility};
8use tirea_contract::{SuspensionResponse, ToolCallDecision};
9use tracing::warn;
10
11/// Role for AG-UI input/output messages.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum Role {
15    Developer,
16    System,
17    #[default]
18    Assistant,
19    User,
20    Tool,
21    Activity,
22    Reasoning,
23}
24
25/// AG-UI message in a conversation.
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct Message {
28    /// Message role (user, assistant, system, tool, developer, activity, reasoning).
29    pub role: Role,
30    /// Message content.
31    pub content: String,
32    /// Optional message ID.
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub id: Option<String>,
35    /// Optional tool call ID (for tool messages).
36    #[serde(rename = "toolCallId", skip_serializing_if = "Option::is_none")]
37    pub tool_call_id: Option<String>,
38}
39
40impl Message {
41    /// Create a user message.
42    pub fn user(content: impl Into<String>) -> Self {
43        Self {
44            role: Role::User,
45            content: content.into(),
46            id: None,
47            tool_call_id: None,
48        }
49    }
50
51    /// Create an assistant message.
52    pub fn assistant(content: impl Into<String>) -> Self {
53        Self {
54            role: Role::Assistant,
55            content: content.into(),
56            id: None,
57            tool_call_id: None,
58        }
59    }
60
61    /// Create a system message.
62    pub fn system(content: impl Into<String>) -> Self {
63        Self {
64            role: Role::System,
65            content: content.into(),
66            id: None,
67            tool_call_id: None,
68        }
69    }
70
71    /// Create a tool result message.
72    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
73        Self {
74            role: Role::Tool,
75            content: content.into(),
76            id: None,
77            tool_call_id: Some(tool_call_id.into()),
78        }
79    }
80
81    /// Create an activity message.
82    pub fn activity(content: impl Into<String>) -> Self {
83        Self {
84            role: Role::Activity,
85            content: content.into(),
86            id: None,
87            tool_call_id: None,
88        }
89    }
90
91    /// Create a reasoning message.
92    pub fn reasoning(content: impl Into<String>) -> Self {
93        Self {
94            role: Role::Reasoning,
95            content: content.into(),
96            id: None,
97            tool_call_id: None,
98        }
99    }
100}
101
102/// AG-UI context entry from frontend readable values.
103#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
104pub struct Context {
105    /// Human-readable description of the context.
106    pub description: String,
107    /// The context value.
108    pub value: Value,
109}
110
111/// Tool execution location.
112#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
113#[serde(rename_all = "lowercase")]
114pub enum ToolExecutionLocation {
115    /// Tool executes on the backend (server-side).
116    Backend,
117    /// Tool executes on the frontend (client-side).
118    #[default]
119    Frontend,
120}
121
122/// AG-UI tool definition.
123#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
124pub struct Tool {
125    /// Tool name.
126    pub name: String,
127    /// Tool description.
128    pub description: String,
129    /// JSON Schema for tool parameters.
130    #[serde(skip_serializing_if = "Option::is_none")]
131    pub parameters: Option<Value>,
132    /// Where the tool executes (frontend or backend).
133    #[serde(default, skip_serializing_if = "is_default_frontend")]
134    pub execute: ToolExecutionLocation,
135}
136
137fn is_default_frontend(loc: &ToolExecutionLocation) -> bool {
138    *loc == ToolExecutionLocation::Frontend
139}
140
141impl Tool {
142    /// Create a new backend tool definition.
143    pub fn backend(name: impl Into<String>, description: impl Into<String>) -> Self {
144        Self {
145            name: name.into(),
146            description: description.into(),
147            parameters: None,
148            execute: ToolExecutionLocation::Backend,
149        }
150    }
151
152    /// Create a new frontend tool definition.
153    pub fn frontend(name: impl Into<String>, description: impl Into<String>) -> Self {
154        Self {
155            name: name.into(),
156            description: description.into(),
157            parameters: None,
158            execute: ToolExecutionLocation::Frontend,
159        }
160    }
161
162    /// Set the JSON Schema parameters.
163    pub fn with_parameters(mut self, parameters: Value) -> Self {
164        self.parameters = Some(parameters);
165        self
166    }
167
168    /// Check if this is a frontend tool.
169    pub fn is_frontend(&self) -> bool {
170        self.execute == ToolExecutionLocation::Frontend
171    }
172}
173
174/// Request to run an AG-UI agent.
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct RunAgentInput {
177    /// Thread identifier.
178    #[serde(rename = "threadId")]
179    pub thread_id: String,
180    /// Run identifier.
181    #[serde(rename = "runId")]
182    pub run_id: String,
183    /// Conversation messages.
184    pub messages: Vec<Message>,
185    /// Available tools.
186    #[serde(default)]
187    pub tools: Vec<Tool>,
188    /// Frontend readable context entries.
189    #[serde(default)]
190    pub context: Vec<Context>,
191    /// Initial state.
192    #[serde(skip_serializing_if = "Option::is_none")]
193    pub state: Option<Value>,
194    /// Parent run ID (for sub-runs).
195    #[serde(rename = "parentRunId", skip_serializing_if = "Option::is_none")]
196    pub parent_run_id: Option<String>,
197    /// Parent thread ID (for delegated/sub-agent lineage).
198    #[serde(
199        rename = "parentThreadId",
200        alias = "parent_thread_id",
201        skip_serializing_if = "Option::is_none"
202    )]
203    pub parent_thread_id: Option<String>,
204    /// Model to use.
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub model: Option<String>,
207    /// System prompt.
208    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
209    pub system_prompt: Option<String>,
210    /// Additional configuration.
211    #[serde(skip_serializing_if = "Option::is_none")]
212    pub config: Option<Value>,
213    /// Additional forwarded properties from AG-UI client runtimes.
214    #[serde(
215        rename = "forwardedProps",
216        alias = "forwarded_props",
217        skip_serializing_if = "Option::is_none"
218    )]
219    pub forwarded_props: Option<Value>,
220}
221
222impl RunAgentInput {
223    /// Create a new request with minimal required fields.
224    pub fn new(thread_id: impl Into<String>, run_id: impl Into<String>) -> Self {
225        Self {
226            thread_id: thread_id.into(),
227            run_id: run_id.into(),
228            messages: Vec::new(),
229            tools: Vec::new(),
230            context: Vec::new(),
231            state: None,
232            parent_run_id: None,
233            parent_thread_id: None,
234            model: None,
235            system_prompt: None,
236            config: None,
237            forwarded_props: None,
238        }
239    }
240
241    /// Add a message.
242    pub fn with_message(mut self, message: Message) -> Self {
243        self.messages.push(message);
244        self
245    }
246
247    /// Add messages.
248    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
249        self.messages.extend(messages);
250        self
251    }
252
253    /// Set initial state.
254    pub fn with_state(mut self, state: Value) -> Self {
255        self.state = Some(state);
256        self
257    }
258
259    /// Set parent thread ID.
260    pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
261        self.parent_thread_id = Some(parent_thread_id.into());
262        self
263    }
264
265    /// Set model.
266    pub fn with_model(mut self, model: impl Into<String>) -> Self {
267        self.model = Some(model.into());
268        self
269    }
270
271    /// Set system prompt.
272    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
273        self.system_prompt = Some(prompt.into());
274        self
275    }
276
277    /// Set forwarded props.
278    pub fn with_forwarded_props(mut self, forwarded_props: Value) -> Self {
279        self.forwarded_props = Some(forwarded_props);
280        self
281    }
282
283    /// Validate the request.
284    pub fn validate(&self) -> Result<(), RequestError> {
285        if self.thread_id.is_empty() {
286            return Err(RequestError::invalid_field("threadId cannot be empty"));
287        }
288        if self.run_id.is_empty() {
289            return Err(RequestError::invalid_field("runId cannot be empty"));
290        }
291        Ok(())
292    }
293
294    /// Get frontend tools from the request.
295    pub fn frontend_tools(&self) -> Vec<&Tool> {
296        self.tools.iter().filter(|t| t.is_frontend()).collect()
297    }
298
299    /// Check if any interaction responses exist in this request.
300    pub fn has_any_interaction_responses(&self) -> bool {
301        !self.interaction_responses().is_empty()
302    }
303
304    /// Check if any suspension decisions exist in this request.
305    pub fn has_any_suspension_decisions(&self) -> bool {
306        !self.suspension_decisions().is_empty()
307    }
308
309    /// Check if this request contains non-empty user input.
310    pub fn has_user_input(&self) -> bool {
311        self.messages
312            .iter()
313            .any(|message| message.role == Role::User && !message.content.trim().is_empty())
314    }
315
316    /// Convert this AG-UI request to the internal runtime request.
317    ///
318    /// Mapping rules:
319    /// - `thread_id`, `run_id`, `parent_run_id`, `state` are forwarded directly.
320    /// - `messages` are converted via `convert_agui_messages` (assistant/activity/reasoning
321    ///   inbound messages are intentionally skipped at runtime input boundary).
322    /// - `resource_id` is not provided by AG-UI and remains `None`.
323    pub fn into_runtime_run_request(self, agent_id: String) -> RunRequest {
324        let initial_decisions = self.suspension_decisions();
325        RunRequest {
326            agent_id,
327            thread_id: Some(self.thread_id),
328            run_id: Some(self.run_id),
329            parent_run_id: self.parent_run_id,
330            parent_thread_id: self.parent_thread_id,
331            resource_id: None,
332            origin: RunOrigin::AgUi,
333            state: self.state,
334            messages: convert_agui_messages(&self.messages),
335            initial_decisions,
336            source_mailbox_entry_id: None,
337        }
338    }
339
340    /// Extract all interaction responses from tool messages.
341    pub fn interaction_responses(&self) -> Vec<SuspensionResponse> {
342        let expected_ids = self.suspended_call_response_ids();
343        let mut latest_by_id: HashMap<String, (usize, Value)> = HashMap::new();
344
345        self.messages
346            .iter()
347            .enumerate()
348            .filter(|(_, m)| m.role == Role::Tool)
349            .filter_map(|(idx, m)| {
350                m.tool_call_id.as_ref().and_then(|id| {
351                    if !expected_ids.is_empty() && !expected_ids.contains(id) {
352                        return None;
353                    }
354                    let result = parse_interaction_result_value(&m.content);
355                    Some((idx, id.clone(), result))
356                })
357            })
358            .for_each(|(idx, id, result)| {
359                // Last write wins for duplicate IDs.
360                latest_by_id.insert(id, (idx, result));
361            });
362
363        let mut responses: Vec<(usize, SuspensionResponse)> = latest_by_id
364            .into_iter()
365            .map(|(id, (idx, result))| (idx, SuspensionResponse::new(id, result)))
366            .collect();
367        responses.sort_by_key(|(idx, _)| *idx);
368        responses
369            .into_iter()
370            .map(|(_, response)| response)
371            .collect()
372    }
373
374    /// Extract all suspension decisions from tool messages.
375    pub fn suspension_decisions(&self) -> Vec<ToolCallDecision> {
376        self.interaction_responses()
377            .into_iter()
378            .map(suspension_response_to_decision)
379            .collect()
380    }
381
382    /// Get all approved interaction IDs.
383    pub fn approved_target_ids(&self) -> Vec<String> {
384        self.suspension_decisions()
385            .into_iter()
386            .filter(|d| matches!(d.resume.action, ResumeDecisionAction::Resume))
387            .map(|d| d.target_id)
388            .collect()
389    }
390
391    /// Get all denied interaction IDs.
392    pub fn denied_target_ids(&self) -> Vec<String> {
393        self.suspension_decisions()
394            .into_iter()
395            .filter(|d| matches!(d.resume.action, ResumeDecisionAction::Cancel))
396            .map(|d| d.target_id)
397            .collect()
398    }
399
400    fn suspended_call_response_ids(&self) -> HashSet<String> {
401        let mut ids = HashSet::new();
402        let Some(state) = self.state.as_ref() else {
403            return ids;
404        };
405
406        let calls = suspended_calls_from_state(state);
407        for call in calls.values() {
408            ids.insert(call.ticket.pending.id.clone());
409            ids.insert(call.call_id.clone());
410            ids.insert(call.ticket.suspension.id.clone());
411        }
412
413        ids
414    }
415}
416
417fn parse_interaction_result_value(content: &str) -> Value {
418    serde_json::from_str(content).unwrap_or_else(|_| Value::String(content.to_string()))
419}
420
421/// Convert AG-UI message to internal message.
422pub fn core_message_from_ag_ui(msg: &Message) -> tirea_contract::Message {
423    let role = match msg.role {
424        Role::System => tirea_contract::Role::System,
425        Role::Developer => tirea_contract::Role::System,
426        Role::User => tirea_contract::Role::User,
427        Role::Assistant => tirea_contract::Role::Assistant,
428        Role::Tool => tirea_contract::Role::Tool,
429        Role::Activity => tirea_contract::Role::Assistant,
430        Role::Reasoning => tirea_contract::Role::Assistant,
431    };
432
433    tirea_contract::Message {
434        id: Some(msg.id.clone().unwrap_or_else(gen_message_id)),
435        role,
436        content: msg.content.clone(),
437        tool_calls: None,
438        tool_call_id: msg.tool_call_id.clone(),
439        visibility: Visibility::default(),
440        metadata: None,
441    }
442}
443
444/// Convert AG-UI messages to internal messages.
445pub fn convert_agui_messages(messages: &[Message]) -> Vec<tirea_contract::Message> {
446    messages
447        .iter()
448        .filter(|m| {
449            m.role != Role::Assistant && m.role != Role::Activity && m.role != Role::Reasoning
450        })
451        .map(core_message_from_ag_ui)
452        .collect()
453}
454
455/// Error type for request processing.
456#[derive(Debug, Clone, Serialize, Deserialize)]
457pub struct RequestError {
458    /// Error code.
459    pub code: String,
460    /// Error message.
461    pub message: String,
462}
463
464impl RequestError {
465    /// Create an invalid field error.
466    pub fn invalid_field(message: impl Into<String>) -> Self {
467        Self {
468            code: "INVALID_FIELD".into(),
469            message: message.into(),
470        }
471    }
472
473    /// Create a validation error.
474    pub fn validation(message: impl Into<String>) -> Self {
475        Self {
476            code: "VALIDATION_ERROR".into(),
477            message: message.into(),
478        }
479    }
480
481    /// Create an internal error.
482    pub fn internal(message: impl Into<String>) -> Self {
483        Self {
484            code: "INTERNAL_ERROR".into(),
485            message: message.into(),
486        }
487    }
488}
489
490impl std::fmt::Display for RequestError {
491    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492        write!(f, "[{}] {}", self.code, self.message)
493    }
494}
495
496impl std::error::Error for RequestError {}
497
498impl From<String> for RequestError {
499    fn from(message: String) -> Self {
500        Self::validation(message)
501    }
502}
503
504/// Build a context string from AG-UI context entries to append to the system prompt.
505pub fn build_context_addendum(request: &RunAgentInput) -> Option<String> {
506    if request.context.is_empty() {
507        return None;
508    }
509    let mut parts = Vec::new();
510    for entry in &request.context {
511        let value_str = match &entry.value {
512            Value::String(s) => s.clone(),
513            other => match serde_json::to_string(other) {
514                Ok(value) => value,
515                Err(err) => {
516                    warn!(
517                        error = %err,
518                        description = %entry.description,
519                        "failed to stringify AG-UI context value"
520                    );
521                    "<unserializable-context-value>".to_string()
522                }
523            },
524        };
525        parts.push(format!("[{}]: {}", entry.description, value_str));
526    }
527    Some(format!(
528        "\n\nThe following context is available from the frontend:\n{}",
529        parts.join("\n")
530    ))
531}