1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::{HashMap, HashSet};
4use tirea_contract::io::decision_translation::suspension_response_to_decision;
5use tirea_contract::io::ResumeDecisionAction;
6use tirea_contract::runtime::suspended_calls_from_state;
7use tirea_contract::{gen_message_id, RunOrigin, RunRequest, Visibility};
8use tirea_contract::{SuspensionResponse, ToolCallDecision};
9use tracing::warn;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum Role {
15 Developer,
16 System,
17 #[default]
18 Assistant,
19 User,
20 Tool,
21 Activity,
22 Reasoning,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct Message {
28 pub role: Role,
30 pub content: String,
32 #[serde(skip_serializing_if = "Option::is_none")]
34 pub id: Option<String>,
35 #[serde(rename = "toolCallId", skip_serializing_if = "Option::is_none")]
37 pub tool_call_id: Option<String>,
38}
39
40impl Message {
41 pub fn user(content: impl Into<String>) -> Self {
43 Self {
44 role: Role::User,
45 content: content.into(),
46 id: None,
47 tool_call_id: None,
48 }
49 }
50
51 pub fn assistant(content: impl Into<String>) -> Self {
53 Self {
54 role: Role::Assistant,
55 content: content.into(),
56 id: None,
57 tool_call_id: None,
58 }
59 }
60
61 pub fn system(content: impl Into<String>) -> Self {
63 Self {
64 role: Role::System,
65 content: content.into(),
66 id: None,
67 tool_call_id: None,
68 }
69 }
70
71 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
73 Self {
74 role: Role::Tool,
75 content: content.into(),
76 id: None,
77 tool_call_id: Some(tool_call_id.into()),
78 }
79 }
80
81 pub fn activity(content: impl Into<String>) -> Self {
83 Self {
84 role: Role::Activity,
85 content: content.into(),
86 id: None,
87 tool_call_id: None,
88 }
89 }
90
91 pub fn reasoning(content: impl Into<String>) -> Self {
93 Self {
94 role: Role::Reasoning,
95 content: content.into(),
96 id: None,
97 tool_call_id: None,
98 }
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
104pub struct Context {
105 pub description: String,
107 pub value: Value,
109}
110
111#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
113#[serde(rename_all = "lowercase")]
114pub enum ToolExecutionLocation {
115 Backend,
117 #[default]
119 Frontend,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
124pub struct Tool {
125 pub name: String,
127 pub description: String,
129 #[serde(skip_serializing_if = "Option::is_none")]
131 pub parameters: Option<Value>,
132 #[serde(default, skip_serializing_if = "is_default_frontend")]
134 pub execute: ToolExecutionLocation,
135}
136
137fn is_default_frontend(loc: &ToolExecutionLocation) -> bool {
138 *loc == ToolExecutionLocation::Frontend
139}
140
141impl Tool {
142 pub fn backend(name: impl Into<String>, description: impl Into<String>) -> Self {
144 Self {
145 name: name.into(),
146 description: description.into(),
147 parameters: None,
148 execute: ToolExecutionLocation::Backend,
149 }
150 }
151
152 pub fn frontend(name: impl Into<String>, description: impl Into<String>) -> Self {
154 Self {
155 name: name.into(),
156 description: description.into(),
157 parameters: None,
158 execute: ToolExecutionLocation::Frontend,
159 }
160 }
161
162 pub fn with_parameters(mut self, parameters: Value) -> Self {
164 self.parameters = Some(parameters);
165 self
166 }
167
168 pub fn is_frontend(&self) -> bool {
170 self.execute == ToolExecutionLocation::Frontend
171 }
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct RunAgentInput {
177 #[serde(rename = "threadId")]
179 pub thread_id: String,
180 #[serde(rename = "runId")]
182 pub run_id: String,
183 pub messages: Vec<Message>,
185 #[serde(default)]
187 pub tools: Vec<Tool>,
188 #[serde(default)]
190 pub context: Vec<Context>,
191 #[serde(skip_serializing_if = "Option::is_none")]
193 pub state: Option<Value>,
194 #[serde(rename = "parentRunId", skip_serializing_if = "Option::is_none")]
196 pub parent_run_id: Option<String>,
197 #[serde(
199 rename = "parentThreadId",
200 alias = "parent_thread_id",
201 skip_serializing_if = "Option::is_none"
202 )]
203 pub parent_thread_id: Option<String>,
204 #[serde(skip_serializing_if = "Option::is_none")]
206 pub model: Option<String>,
207 #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
209 pub system_prompt: Option<String>,
210 #[serde(skip_serializing_if = "Option::is_none")]
212 pub config: Option<Value>,
213 #[serde(
215 rename = "forwardedProps",
216 alias = "forwarded_props",
217 skip_serializing_if = "Option::is_none"
218 )]
219 pub forwarded_props: Option<Value>,
220}
221
222impl RunAgentInput {
223 pub fn new(thread_id: impl Into<String>, run_id: impl Into<String>) -> Self {
225 Self {
226 thread_id: thread_id.into(),
227 run_id: run_id.into(),
228 messages: Vec::new(),
229 tools: Vec::new(),
230 context: Vec::new(),
231 state: None,
232 parent_run_id: None,
233 parent_thread_id: None,
234 model: None,
235 system_prompt: None,
236 config: None,
237 forwarded_props: None,
238 }
239 }
240
241 pub fn with_message(mut self, message: Message) -> Self {
243 self.messages.push(message);
244 self
245 }
246
247 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
249 self.messages.extend(messages);
250 self
251 }
252
253 pub fn with_state(mut self, state: Value) -> Self {
255 self.state = Some(state);
256 self
257 }
258
259 pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
261 self.parent_thread_id = Some(parent_thread_id.into());
262 self
263 }
264
265 pub fn with_model(mut self, model: impl Into<String>) -> Self {
267 self.model = Some(model.into());
268 self
269 }
270
271 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
273 self.system_prompt = Some(prompt.into());
274 self
275 }
276
277 pub fn with_forwarded_props(mut self, forwarded_props: Value) -> Self {
279 self.forwarded_props = Some(forwarded_props);
280 self
281 }
282
283 pub fn validate(&self) -> Result<(), RequestError> {
285 if self.thread_id.is_empty() {
286 return Err(RequestError::invalid_field("threadId cannot be empty"));
287 }
288 if self.run_id.is_empty() {
289 return Err(RequestError::invalid_field("runId cannot be empty"));
290 }
291 Ok(())
292 }
293
294 pub fn frontend_tools(&self) -> Vec<&Tool> {
296 self.tools.iter().filter(|t| t.is_frontend()).collect()
297 }
298
299 pub fn has_any_interaction_responses(&self) -> bool {
301 !self.interaction_responses().is_empty()
302 }
303
304 pub fn has_any_suspension_decisions(&self) -> bool {
306 !self.suspension_decisions().is_empty()
307 }
308
309 pub fn has_user_input(&self) -> bool {
311 self.messages
312 .iter()
313 .any(|message| message.role == Role::User && !message.content.trim().is_empty())
314 }
315
316 pub fn into_runtime_run_request(self, agent_id: String) -> RunRequest {
324 let initial_decisions = self.suspension_decisions();
325 RunRequest {
326 agent_id,
327 thread_id: Some(self.thread_id),
328 run_id: Some(self.run_id),
329 parent_run_id: self.parent_run_id,
330 parent_thread_id: self.parent_thread_id,
331 resource_id: None,
332 origin: RunOrigin::AgUi,
333 state: self.state,
334 messages: convert_agui_messages(&self.messages),
335 initial_decisions,
336 source_mailbox_entry_id: None,
337 }
338 }
339
340 pub fn interaction_responses(&self) -> Vec<SuspensionResponse> {
342 let expected_ids = self.suspended_call_response_ids();
343 let mut latest_by_id: HashMap<String, (usize, Value)> = HashMap::new();
344
345 self.messages
346 .iter()
347 .enumerate()
348 .filter(|(_, m)| m.role == Role::Tool)
349 .filter_map(|(idx, m)| {
350 m.tool_call_id.as_ref().and_then(|id| {
351 if !expected_ids.is_empty() && !expected_ids.contains(id) {
352 return None;
353 }
354 let result = parse_interaction_result_value(&m.content);
355 Some((idx, id.clone(), result))
356 })
357 })
358 .for_each(|(idx, id, result)| {
359 latest_by_id.insert(id, (idx, result));
361 });
362
363 let mut responses: Vec<(usize, SuspensionResponse)> = latest_by_id
364 .into_iter()
365 .map(|(id, (idx, result))| (idx, SuspensionResponse::new(id, result)))
366 .collect();
367 responses.sort_by_key(|(idx, _)| *idx);
368 responses
369 .into_iter()
370 .map(|(_, response)| response)
371 .collect()
372 }
373
374 pub fn suspension_decisions(&self) -> Vec<ToolCallDecision> {
376 self.interaction_responses()
377 .into_iter()
378 .map(suspension_response_to_decision)
379 .collect()
380 }
381
382 pub fn approved_target_ids(&self) -> Vec<String> {
384 self.suspension_decisions()
385 .into_iter()
386 .filter(|d| matches!(d.resume.action, ResumeDecisionAction::Resume))
387 .map(|d| d.target_id)
388 .collect()
389 }
390
391 pub fn denied_target_ids(&self) -> Vec<String> {
393 self.suspension_decisions()
394 .into_iter()
395 .filter(|d| matches!(d.resume.action, ResumeDecisionAction::Cancel))
396 .map(|d| d.target_id)
397 .collect()
398 }
399
400 fn suspended_call_response_ids(&self) -> HashSet<String> {
401 let mut ids = HashSet::new();
402 let Some(state) = self.state.as_ref() else {
403 return ids;
404 };
405
406 let calls = suspended_calls_from_state(state);
407 for call in calls.values() {
408 ids.insert(call.ticket.pending.id.clone());
409 ids.insert(call.call_id.clone());
410 ids.insert(call.ticket.suspension.id.clone());
411 }
412
413 ids
414 }
415}
416
417fn parse_interaction_result_value(content: &str) -> Value {
418 serde_json::from_str(content).unwrap_or_else(|_| Value::String(content.to_string()))
419}
420
421pub fn core_message_from_ag_ui(msg: &Message) -> tirea_contract::Message {
423 let role = match msg.role {
424 Role::System => tirea_contract::Role::System,
425 Role::Developer => tirea_contract::Role::System,
426 Role::User => tirea_contract::Role::User,
427 Role::Assistant => tirea_contract::Role::Assistant,
428 Role::Tool => tirea_contract::Role::Tool,
429 Role::Activity => tirea_contract::Role::Assistant,
430 Role::Reasoning => tirea_contract::Role::Assistant,
431 };
432
433 tirea_contract::Message {
434 id: Some(msg.id.clone().unwrap_or_else(gen_message_id)),
435 role,
436 content: msg.content.clone(),
437 tool_calls: None,
438 tool_call_id: msg.tool_call_id.clone(),
439 visibility: Visibility::default(),
440 metadata: None,
441 }
442}
443
444pub fn convert_agui_messages(messages: &[Message]) -> Vec<tirea_contract::Message> {
446 messages
447 .iter()
448 .filter(|m| {
449 m.role != Role::Assistant && m.role != Role::Activity && m.role != Role::Reasoning
450 })
451 .map(core_message_from_ag_ui)
452 .collect()
453}
454
455#[derive(Debug, Clone, Serialize, Deserialize)]
457pub struct RequestError {
458 pub code: String,
460 pub message: String,
462}
463
464impl RequestError {
465 pub fn invalid_field(message: impl Into<String>) -> Self {
467 Self {
468 code: "INVALID_FIELD".into(),
469 message: message.into(),
470 }
471 }
472
473 pub fn validation(message: impl Into<String>) -> Self {
475 Self {
476 code: "VALIDATION_ERROR".into(),
477 message: message.into(),
478 }
479 }
480
481 pub fn internal(message: impl Into<String>) -> Self {
483 Self {
484 code: "INTERNAL_ERROR".into(),
485 message: message.into(),
486 }
487 }
488}
489
490impl std::fmt::Display for RequestError {
491 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492 write!(f, "[{}] {}", self.code, self.message)
493 }
494}
495
496impl std::error::Error for RequestError {}
497
498impl From<String> for RequestError {
499 fn from(message: String) -> Self {
500 Self::validation(message)
501 }
502}
503
504pub fn build_context_addendum(request: &RunAgentInput) -> Option<String> {
506 if request.context.is_empty() {
507 return None;
508 }
509 let mut parts = Vec::new();
510 for entry in &request.context {
511 let value_str = match &entry.value {
512 Value::String(s) => s.clone(),
513 other => match serde_json::to_string(other) {
514 Ok(value) => value,
515 Err(err) => {
516 warn!(
517 error = %err,
518 description = %entry.description,
519 "failed to stringify AG-UI context value"
520 );
521 "<unserializable-context-value>".to_string()
522 }
523 },
524 };
525 parts.push(format!("[{}]: {}", entry.description, value_str));
526 }
527 Some(format!(
528 "\n\nThe following context is available from the frontend:\n{}",
529 parts.join("\n")
530 ))
531}