tirea_agentos_server/transport/
http_sse.rs1use axum::body::Body;
2use axum::http::{header, HeaderMap, HeaderValue};
3use axum::response::{IntoResponse, Response};
4use bytes::Bytes;
5use serde::Serialize;
6use std::convert::Infallible;
7use std::marker::PhantomData;
8use tirea_contract::RuntimeInput;
9use tokio::sync::{broadcast, mpsc, Mutex};
10
11use crate::transport::{BoxStream, Endpoint, TransportError};
12
13pub struct HttpSseServerEndpoint<SendMsg> {
14 ingress_rx: Mutex<Option<mpsc::UnboundedReceiver<RuntimeInput>>>,
15 sse_tx: mpsc::Sender<Bytes>,
16 fanout: Option<broadcast::Sender<Bytes>>,
17 _phantom: PhantomData<fn(SendMsg)>,
18}
19
20impl<SendMsg> HttpSseServerEndpoint<SendMsg> {
21 pub fn new(
22 ingress_rx: mpsc::UnboundedReceiver<RuntimeInput>,
23 sse_tx: mpsc::Sender<Bytes>,
24 ) -> Self {
25 Self {
26 ingress_rx: Mutex::new(Some(ingress_rx)),
27 sse_tx,
28 fanout: None,
29 _phantom: PhantomData,
30 }
31 }
32
33 pub fn with_fanout(
34 ingress_rx: mpsc::UnboundedReceiver<RuntimeInput>,
35 sse_tx: mpsc::Sender<Bytes>,
36 fanout: broadcast::Sender<Bytes>,
37 ) -> Self {
38 Self {
39 ingress_rx: Mutex::new(Some(ingress_rx)),
40 sse_tx,
41 fanout: Some(fanout),
42 _phantom: PhantomData,
43 }
44 }
45}
46
47#[async_trait::async_trait]
48impl<SendMsg> Endpoint<RuntimeInput, SendMsg> for HttpSseServerEndpoint<SendMsg>
49where
50 SendMsg: Serialize + Send + 'static,
51{
52 async fn recv(&self) -> Result<BoxStream<RuntimeInput>, TransportError> {
53 let mut guard = self.ingress_rx.lock().await;
54 let mut rx = guard.take().ok_or(TransportError::Closed)?;
55 let stream = async_stream::stream! {
56 while let Some(item) = rx.recv().await {
57 yield Ok(item);
58 }
59 };
60 Ok(Box::pin(stream))
61 }
62
63 async fn send(&self, item: SendMsg) -> Result<(), TransportError> {
64 let json = serde_json::to_string(&item).map_err(|e| {
65 tracing::warn!(error = %e, "failed to serialize SSE protocol event");
66 TransportError::Io(format!("serialize event failed: {e}"))
67 })?;
68 let chunk = Bytes::from(format!("data: {json}\n\n"));
69 if let Some(f) = &self.fanout {
70 let _ = f.send(chunk.clone());
71 }
72 self.sse_tx
73 .send(chunk)
74 .await
75 .map_err(|_| TransportError::Closed)
76 }
77
78 async fn close(&self) -> Result<(), TransportError> {
79 Ok(())
80 }
81}
82
83pub fn sse_body_stream(
84 mut rx: mpsc::Receiver<Bytes>,
85) -> impl futures::Stream<Item = Result<Bytes, Infallible>> + Send + 'static {
86 async_stream::stream! {
87 while let Some(chunk) = rx.recv().await {
88 yield Ok::<Bytes, Infallible>(chunk);
89 }
90 }
91}
92
93pub fn sse_response<S>(stream: S) -> Response
94where
95 S: futures::Stream<Item = Result<Bytes, Infallible>> + Send + 'static,
96{
97 let mut headers = HeaderMap::new();
98 headers.insert(
99 header::CONTENT_TYPE,
100 HeaderValue::from_static("text/event-stream"),
101 );
102 headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
103 headers.insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
104 (headers, Body::from_stream(stream)).into_response()
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use serde_json::json;

    #[tokio::test]
    async fn send_serializes_and_frames_as_sse() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, mut sse_rx) = mpsc::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::new(ingress_rx, sse_tx);

        let event = json!({"type": "test"});
        endpoint.send(event).await.unwrap();
        let received = sse_rx.recv().await.unwrap();
        assert_eq!(received, Bytes::from("data: {\"type\":\"test\"}\n\n"));
    }

    #[tokio::test]
    async fn send_with_fanout_broadcasts() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, mut sse_rx) = mpsc::channel::<Bytes>(4);
        let (fanout_tx, mut fanout_rx) = broadcast::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::with_fanout(ingress_rx, sse_tx, fanout_tx);

        let event = json!({"type": "test"});
        endpoint.send(event).await.unwrap();

        let received = sse_rx.recv().await.unwrap();
        let expected = Bytes::from("data: {\"type\":\"test\"}\n\n");
        assert_eq!(received, expected);

        let fanout_received = fanout_rx.recv().await.unwrap();
        assert_eq!(fanout_received, expected);
    }

    #[tokio::test]
    async fn send_returns_error_on_closed_channel() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, sse_rx) = mpsc::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::new(ingress_rx, sse_tx);
        drop(sse_rx);

        let result = endpoint.send(json!({"type": "test"})).await;
        assert!(result.is_err());
    }

    /// The ingress receiver is single-use: a second `recv` reports `Closed`.
    #[tokio::test]
    async fn recv_can_only_be_taken_once() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, _sse_rx) = mpsc::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::new(ingress_rx, sse_tx);

        assert!(endpoint.recv().await.is_ok());
        assert!(matches!(endpoint.recv().await, Err(TransportError::Closed)));
    }

    /// The ingress stream terminates once every ingress sender is dropped.
    #[tokio::test]
    async fn recv_stream_ends_when_ingress_senders_dropped() {
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, _sse_rx) = mpsc::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::new(ingress_rx, sse_tx);

        let mut stream = endpoint.recv().await.unwrap();
        drop(ingress_tx);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn sse_body_stream_yields_all_chunks() {
        let (tx, rx) = mpsc::channel::<Bytes>(4);
        let stream = sse_body_stream(rx);
        tokio::pin!(stream);

        tx.send(Bytes::from("a")).await.unwrap();
        tx.send(Bytes::from("b")).await.unwrap();
        drop(tx);

        let items: Vec<Bytes> = stream.map(|r| r.unwrap()).collect().await;
        assert_eq!(items, vec![Bytes::from("a"), Bytes::from("b")]);
    }

    /// `sse_response` advertises an uncached event stream.
    #[test]
    fn sse_response_sets_event_stream_headers() {
        let response = sse_response(futures::stream::empty::<Result<Bytes, Infallible>>());
        let headers = response.headers();
        assert_eq!(
            headers
                .get(header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("text/event-stream")
        );
        assert_eq!(
            headers
                .get(header::CACHE_CONTROL)
                .and_then(|v| v.to_str().ok()),
            Some("no-cache")
        );
    }
}
170}