tirea_agentos_server/protocol/ag_ui/
nats.rs

1use serde::Deserialize;
2use std::sync::Arc;
3use tirea_agentos::runtime::AgentOs;
4use tirea_protocol_ag_ui::{AgUiProtocolEncoder, Event, RunAgentInput};
5
6use super::runtime::apply_agui_extensions;
7use crate::service::{EnqueueOptions, MailboxService};
8use crate::transport::nats::NatsTransport;
9use crate::transport::NatsProtocolError;
10
11/// Serve AG-UI protocol over NATS.
12pub async fn serve(
13    transport: NatsTransport,
14    os: Arc<AgentOs>,
15    mailbox_service: Arc<MailboxService>,
16    subject: String,
17) -> Result<(), NatsProtocolError> {
18    transport
19        .serve(&subject, "agui", move |transport, msg| {
20            let os = os.clone();
21            let mailbox_service = mailbox_service.clone();
22            async move { handle_message(transport, os, mailbox_service, msg).await }
23        })
24        .await
25}
26
27async fn handle_message(
28    transport: NatsTransport,
29    os: Arc<AgentOs>,
30    mailbox_service: Arc<MailboxService>,
31    msg: async_nats::Message,
32) -> Result<(), NatsProtocolError> {
33    #[derive(Debug, Deserialize)]
34    struct Req {
35        #[serde(rename = "agentId")]
36        agent_id: String,
37        request: RunAgentInput,
38        #[serde(rename = "replySubject")]
39        reply_subject: Option<String>,
40    }
41
42    let req: Req = serde_json::from_slice(&msg.payload)
43        .map_err(|e| NatsProtocolError::BadRequest(e.to_string()))?;
44    req.request
45        .validate()
46        .map_err(|e| NatsProtocolError::BadRequest(e.to_string()))?;
47
48    let reply = msg.reply.or(req.reply_subject.map(Into::into));
49    let Some(reply) = reply else {
50        return Err(NatsProtocolError::BadRequest(
51            "missing reply subject".to_string(),
52        ));
53    };
54
55    let resolved = match os.resolve(&req.agent_id) {
56        Ok(r) => r,
57        Err(err) => {
58            return transport
59                .publish_error_event(reply, Event::run_error(err.to_string(), None))
60                .await;
61        }
62    };
63
64    let mut resolved = resolved;
65    apply_agui_extensions(&mut resolved, &req.request);
66    let agent_id = req.agent_id.clone();
67    let frontend_run_id = req.request.run_id.clone();
68    let mut run_request = req.request.into_runtime_run_request(agent_id.clone());
69    run_request.run_id = None;
70
71    let run = match mailbox_service
72        .submit_streaming(&agent_id, run_request, EnqueueOptions::default())
73        .await
74    {
75        Ok(run) => run,
76        Err(err) => {
77            return transport
78                .publish_error_event(reply, Event::run_error(err.to_string(), None))
79                .await;
80        }
81    };
82
83    transport
84        .publish_run_stream(run, reply, move |_run| {
85            AgUiProtocolEncoder::new_with_frontend_run_id(frontend_run_id)
86        })
87        .await
88}