tirea_agentos_server/protocol/ai_sdk_v6/
nats.rs

1use serde::Deserialize;
2use std::sync::Arc;
3use tirea_agentos::runtime::AgentOs;
4use tirea_protocol_ai_sdk_v6::{AiSdkEncoder, AiSdkV6RunRequest, UIStreamEvent};
5
6use super::runtime::apply_ai_sdk_extensions;
7use crate::service::{EnqueueOptions, MailboxService};
8use crate::transport::nats::NatsTransport;
9use crate::transport::NatsProtocolError;
10
11/// Serve AI SDK v6 protocol over NATS.
12pub async fn serve(
13    transport: NatsTransport,
14    os: Arc<AgentOs>,
15    mailbox_service: Arc<MailboxService>,
16    subject: String,
17) -> Result<(), NatsProtocolError> {
18    transport
19        .serve(&subject, "aisdk", move |transport, msg| {
20            let os = os.clone();
21            let mailbox_service = mailbox_service.clone();
22            async move { handle_message(transport, os, mailbox_service, msg).await }
23        })
24        .await
25}
26
27async fn handle_message(
28    transport: NatsTransport,
29    os: Arc<AgentOs>,
30    mailbox_service: Arc<MailboxService>,
31    msg: async_nats::Message,
32) -> Result<(), NatsProtocolError> {
33    #[derive(Debug, Deserialize)]
34    struct Req {
35        #[serde(rename = "agentId")]
36        agent_id: String,
37        #[serde(rename = "sessionId")]
38        thread_id: String,
39        input: String,
40        #[serde(rename = "runId")]
41        run_id: Option<String>,
42        #[serde(rename = "replySubject")]
43        reply_subject: Option<String>,
44    }
45
46    let req: Req = serde_json::from_slice(&msg.payload)
47        .map_err(|e| NatsProtocolError::BadRequest(e.to_string()))?;
48    if req.input.trim().is_empty() {
49        return Err(NatsProtocolError::BadRequest(
50            "input cannot be empty".to_string(),
51        ));
52    }
53
54    let reply = msg.reply.or(req.reply_subject.map(Into::into));
55    let Some(reply) = reply else {
56        return Err(NatsProtocolError::BadRequest(
57            "missing reply subject".to_string(),
58        ));
59    };
60
61    let mut resolved = match os.resolve(&req.agent_id) {
62        Ok(r) => r,
63        Err(err) => {
64            return transport
65                .publish_error_event(reply, UIStreamEvent::error(err.to_string()))
66                .await;
67        }
68    };
69
70    let req_for_runtime = AiSdkV6RunRequest::from_thread_input(req.thread_id, req.input);
71    apply_ai_sdk_extensions(&mut resolved, &req_for_runtime);
72    let agent_id = req.agent_id.clone();
73    let mut run_request = req_for_runtime.into_runtime_run_request(agent_id.clone());
74    run_request.run_id = req.run_id;
75
76    let run = match mailbox_service
77        .submit_streaming(&agent_id, run_request, EnqueueOptions::default())
78        .await
79    {
80        Ok(run) => run,
81        Err(err) => {
82            return transport
83                .publish_error_event(reply, UIStreamEvent::error(err.to_string()))
84                .await;
85        }
86    };
87
88    transport
89        .publish_run_stream(run, reply, move |_run| AiSdkEncoder::new())
90        .await
91}