tirea_agentos_server/transport/
nats.rs1use serde::Serialize;
2use std::future::Future;
3use std::sync::Arc;
4use tirea_agentos::contracts::AgentEvent;
5use tirea_agentos::runtime::RunStream;
6use tirea_contract::{RuntimeInput, Transcoder};
7
8use crate::transport::NatsProtocolError;
9use crate::transport::{
10 relay_binding, Endpoint, RelayCancellation, RuntimeEndpoint, SessionId, TranscoderEndpoint,
11 TransportBinding, TransportCapabilities, TransportError,
12};
13
/// Tuning options for [`NatsTransport`].
#[derive(Clone, Debug)]
pub struct NatsTransportConfig {
    /// Buffer size passed to the runtime endpoint when a run stream is relayed
    /// (`RuntimeEndpoint::from_run_stream_with_buffer` in `publish_run_stream`).
    pub outbound_buffer: usize,
}

impl Default for NatsTransportConfig {
    /// Returns a configuration whose outbound buffer holds 64 entries.
    fn default() -> Self {
        const DEFAULT_OUTBOUND_BUFFER: usize = 64;
        Self {
            outbound_buffer: DEFAULT_OUTBOUND_BUFFER,
        }
    }
}
26
27#[derive(Clone)]
29pub struct NatsTransport {
30 client: async_nats::Client,
31 config: NatsTransportConfig,
32}
33
34impl NatsTransport {
35 pub fn new(client: async_nats::Client, config: NatsTransportConfig) -> Self {
36 Self { client, config }
37 }
38
39 pub fn client(&self) -> &async_nats::Client {
40 &self.client
41 }
42
43 pub async fn serve<H, Fut>(
45 &self,
46 subject: &str,
47 protocol_label: &'static str,
48 handler: H,
49 ) -> Result<(), NatsProtocolError>
50 where
51 H: Fn(NatsTransport, async_nats::Message) -> Fut + Send + Sync + 'static,
52 Fut: Future<Output = Result<(), NatsProtocolError>> + Send + 'static,
53 {
54 use futures::StreamExt;
55 let handler = Arc::new(handler);
56 let mut sub = self.client.subscribe(subject.to_string()).await?;
57 while let Some(msg) = sub.next().await {
58 let transport = self.clone();
59 let handler = handler.clone();
60 tokio::spawn(async move {
61 if let Err(e) = handler(transport, msg).await {
62 tracing::error!(error = %e, "nats {protocol_label} handler failed");
63 }
64 });
65 }
66 Ok(())
67 }
68
69 pub async fn publish_run_stream<E, BuildEncoder>(
70 &self,
71 run: RunStream,
72 reply: async_nats::Subject,
73 build_encoder: BuildEncoder,
74 ) -> Result<(), NatsProtocolError>
75 where
76 E: Transcoder<Input = AgentEvent> + 'static,
77 E::Output: Serialize + Send + 'static,
78 BuildEncoder: FnOnce(&RunStream) -> E,
79 {
80 let session_thread_id = run.thread_id.clone();
81 let encoder = build_encoder(&run);
82 let upstream = Arc::new(NatsReplyServerEndpoint::new(self.client.clone(), reply));
83 let runtime_ep = Arc::new(RuntimeEndpoint::from_run_stream_with_buffer(
84 run,
85 self.config.outbound_buffer,
86 ));
87 let downstream = Arc::new(TranscoderEndpoint::new(runtime_ep, encoder));
88 let binding = TransportBinding {
89 session: SessionId {
90 thread_id: session_thread_id,
91 },
92 caps: TransportCapabilities {
93 upstream_async: false,
94 downstream_streaming: true,
95 single_channel_bidirectional: false,
96 resumable_downstream: false,
97 },
98 upstream,
99 downstream,
100 };
101 relay_binding(binding, RelayCancellation::new())
102 .await
103 .map_err(|e| NatsProtocolError::Run(format!("transport relay failed: {e}")))?;
104
105 Ok(())
106 }
107
108 pub(crate) async fn publish_error_event<ErrEvent: Serialize>(
109 &self,
110 reply: async_nats::Subject,
111 event: ErrEvent,
112 ) -> Result<(), NatsProtocolError> {
113 let payload = serde_json::to_vec(&event)
114 .map_err(|e| NatsProtocolError::Run(format!("serialize error event failed: {e}")))?
115 .into();
116 if let Err(publish_err) = self.client.publish(reply, payload).await {
117 return Err(NatsProtocolError::Run(format!(
118 "publish error event failed: {publish_err}"
119 )));
120 }
121 Ok(())
122 }
123}
124
125pub(crate) struct NatsReplyServerEndpoint {
126 client: async_nats::Client,
127 reply: async_nats::Subject,
128}
129
130impl NatsReplyServerEndpoint {
131 pub(crate) fn new(client: async_nats::Client, reply: async_nats::Subject) -> Self {
132 Self { client, reply }
133 }
134}
135
136#[async_trait::async_trait]
137impl<Evt> Endpoint<RuntimeInput, Evt> for NatsReplyServerEndpoint
138where
139 Evt: Serialize + Send + 'static,
140{
141 async fn recv(&self) -> Result<crate::transport::BoxStream<RuntimeInput>, TransportError> {
142 let stream = futures::stream::empty::<Result<RuntimeInput, TransportError>>();
143 Ok(Box::pin(stream))
144 }
145
146 async fn send(&self, item: Evt) -> Result<(), TransportError> {
147 let payload = serde_json::to_vec(&item).map_err(|e| {
148 tracing::warn!(error = %e, "failed to serialize NATS protocol event");
149 TransportError::Io(format!("serialize event failed: {e}"))
150 })?;
151 self.client
152 .publish(self.reply.clone(), payload.into())
153 .await
154 .map_err(|e| TransportError::Io(e.to_string()))
155 }
156
157 async fn close(&self) -> Result<(), TransportError> {
158 Ok(())
159 }
160}