tirea_agentos_server/transport/
nats.rs

1use serde::Serialize;
2use std::future::Future;
3use std::sync::Arc;
4use tirea_agentos::contracts::AgentEvent;
5use tirea_agentos::runtime::RunStream;
6use tirea_contract::{RuntimeInput, Transcoder};
7
8use crate::transport::NatsProtocolError;
9use crate::transport::{
10    relay_binding, Endpoint, RelayCancellation, RuntimeEndpoint, SessionId, TranscoderEndpoint,
11    TransportBinding, TransportCapabilities, TransportError,
12};
13
/// Tunables applied by `NatsTransport` when relaying runs over NATS.
#[derive(Debug, Clone)]
pub struct NatsTransportConfig {
    /// Buffer size handed to `RuntimeEndpoint::from_run_stream_with_buffer`
    /// when a run stream is relayed to a reply subject.
    pub outbound_buffer: usize,
}

impl Default for NatsTransportConfig {
    /// Builds a configuration whose `outbound_buffer` is 64.
    fn default() -> Self {
        let outbound_buffer = 64;
        NatsTransportConfig { outbound_buffer }
    }
}
26
27/// Owns a NATS connection and transport configuration.
28#[derive(Clone)]
29pub struct NatsTransport {
30    client: async_nats::Client,
31    config: NatsTransportConfig,
32}
33
34impl NatsTransport {
35    pub fn new(client: async_nats::Client, config: NatsTransportConfig) -> Self {
36        Self { client, config }
37    }
38
39    pub fn client(&self) -> &async_nats::Client {
40        &self.client
41    }
42
43    /// Subscribe to a NATS subject and dispatch each message to a handler.
44    pub async fn serve<H, Fut>(
45        &self,
46        subject: &str,
47        protocol_label: &'static str,
48        handler: H,
49    ) -> Result<(), NatsProtocolError>
50    where
51        H: Fn(NatsTransport, async_nats::Message) -> Fut + Send + Sync + 'static,
52        Fut: Future<Output = Result<(), NatsProtocolError>> + Send + 'static,
53    {
54        use futures::StreamExt;
55        let handler = Arc::new(handler);
56        let mut sub = self.client.subscribe(subject.to_string()).await?;
57        while let Some(msg) = sub.next().await {
58            let transport = self.clone();
59            let handler = handler.clone();
60            tokio::spawn(async move {
61                if let Err(e) = handler(transport, msg).await {
62                    tracing::error!(error = %e, "nats {protocol_label} handler failed");
63                }
64            });
65        }
66        Ok(())
67    }
68
69    pub async fn publish_run_stream<E, BuildEncoder>(
70        &self,
71        run: RunStream,
72        reply: async_nats::Subject,
73        build_encoder: BuildEncoder,
74    ) -> Result<(), NatsProtocolError>
75    where
76        E: Transcoder<Input = AgentEvent> + 'static,
77        E::Output: Serialize + Send + 'static,
78        BuildEncoder: FnOnce(&RunStream) -> E,
79    {
80        let session_thread_id = run.thread_id.clone();
81        let encoder = build_encoder(&run);
82        let upstream = Arc::new(NatsReplyServerEndpoint::new(self.client.clone(), reply));
83        let runtime_ep = Arc::new(RuntimeEndpoint::from_run_stream_with_buffer(
84            run,
85            self.config.outbound_buffer,
86        ));
87        let downstream = Arc::new(TranscoderEndpoint::new(runtime_ep, encoder));
88        let binding = TransportBinding {
89            session: SessionId {
90                thread_id: session_thread_id,
91            },
92            caps: TransportCapabilities {
93                upstream_async: false,
94                downstream_streaming: true,
95                single_channel_bidirectional: false,
96                resumable_downstream: false,
97            },
98            upstream,
99            downstream,
100        };
101        relay_binding(binding, RelayCancellation::new())
102            .await
103            .map_err(|e| NatsProtocolError::Run(format!("transport relay failed: {e}")))?;
104
105        Ok(())
106    }
107
108    pub(crate) async fn publish_error_event<ErrEvent: Serialize>(
109        &self,
110        reply: async_nats::Subject,
111        event: ErrEvent,
112    ) -> Result<(), NatsProtocolError> {
113        let payload = serde_json::to_vec(&event)
114            .map_err(|e| NatsProtocolError::Run(format!("serialize error event failed: {e}")))?
115            .into();
116        if let Err(publish_err) = self.client.publish(reply, payload).await {
117            return Err(NatsProtocolError::Run(format!(
118                "publish error event failed: {publish_err}"
119            )));
120        }
121        Ok(())
122    }
123}
124
125pub(crate) struct NatsReplyServerEndpoint {
126    client: async_nats::Client,
127    reply: async_nats::Subject,
128}
129
130impl NatsReplyServerEndpoint {
131    pub(crate) fn new(client: async_nats::Client, reply: async_nats::Subject) -> Self {
132        Self { client, reply }
133    }
134}
135
136#[async_trait::async_trait]
137impl<Evt> Endpoint<RuntimeInput, Evt> for NatsReplyServerEndpoint
138where
139    Evt: Serialize + Send + 'static,
140{
141    async fn recv(&self) -> Result<crate::transport::BoxStream<RuntimeInput>, TransportError> {
142        let stream = futures::stream::empty::<Result<RuntimeInput, TransportError>>();
143        Ok(Box::pin(stream))
144    }
145
146    async fn send(&self, item: Evt) -> Result<(), TransportError> {
147        let payload = serde_json::to_vec(&item).map_err(|e| {
148            tracing::warn!(error = %e, "failed to serialize NATS protocol event");
149            TransportError::Io(format!("serialize event failed: {e}"))
150        })?;
151        self.client
152            .publish(self.reply.clone(), payload.into())
153            .await
154            .map_err(|e| TransportError::Io(e.to_string()))
155    }
156
157    async fn close(&self) -> Result<(), TransportError> {
158        Ok(())
159    }
160}