tirea_agentos_server/transport/http_sse.rs

1use axum::body::Body;
2use axum::http::{header, HeaderMap, HeaderValue};
3use axum::response::{IntoResponse, Response};
4use bytes::Bytes;
5use serde::Serialize;
6use std::convert::Infallible;
7use std::marker::PhantomData;
8use tirea_contract::RuntimeInput;
9use tokio::sync::{broadcast, mpsc, Mutex};
10
11use crate::transport::{BoxStream, Endpoint, TransportError};
12
13pub struct HttpSseServerEndpoint<SendMsg> {
14    ingress_rx: Mutex<Option<mpsc::UnboundedReceiver<RuntimeInput>>>,
15    sse_tx: mpsc::Sender<Bytes>,
16    fanout: Option<broadcast::Sender<Bytes>>,
17    _phantom: PhantomData<fn(SendMsg)>,
18}
19
20impl<SendMsg> HttpSseServerEndpoint<SendMsg> {
21    pub fn new(
22        ingress_rx: mpsc::UnboundedReceiver<RuntimeInput>,
23        sse_tx: mpsc::Sender<Bytes>,
24    ) -> Self {
25        Self {
26            ingress_rx: Mutex::new(Some(ingress_rx)),
27            sse_tx,
28            fanout: None,
29            _phantom: PhantomData,
30        }
31    }
32
33    pub fn with_fanout(
34        ingress_rx: mpsc::UnboundedReceiver<RuntimeInput>,
35        sse_tx: mpsc::Sender<Bytes>,
36        fanout: broadcast::Sender<Bytes>,
37    ) -> Self {
38        Self {
39            ingress_rx: Mutex::new(Some(ingress_rx)),
40            sse_tx,
41            fanout: Some(fanout),
42            _phantom: PhantomData,
43        }
44    }
45}
46
47#[async_trait::async_trait]
48impl<SendMsg> Endpoint<RuntimeInput, SendMsg> for HttpSseServerEndpoint<SendMsg>
49where
50    SendMsg: Serialize + Send + 'static,
51{
52    async fn recv(&self) -> Result<BoxStream<RuntimeInput>, TransportError> {
53        let mut guard = self.ingress_rx.lock().await;
54        let mut rx = guard.take().ok_or(TransportError::Closed)?;
55        let stream = async_stream::stream! {
56            while let Some(item) = rx.recv().await {
57                yield Ok(item);
58            }
59        };
60        Ok(Box::pin(stream))
61    }
62
63    async fn send(&self, item: SendMsg) -> Result<(), TransportError> {
64        let json = serde_json::to_string(&item).map_err(|e| {
65            tracing::warn!(error = %e, "failed to serialize SSE protocol event");
66            TransportError::Io(format!("serialize event failed: {e}"))
67        })?;
68        let chunk = Bytes::from(format!("data: {json}\n\n"));
69        if let Some(f) = &self.fanout {
70            let _ = f.send(chunk.clone());
71        }
72        self.sse_tx
73            .send(chunk)
74            .await
75            .map_err(|_| TransportError::Closed)
76    }
77
78    async fn close(&self) -> Result<(), TransportError> {
79        Ok(())
80    }
81}
82
83pub fn sse_body_stream(
84    mut rx: mpsc::Receiver<Bytes>,
85) -> impl futures::Stream<Item = Result<Bytes, Infallible>> + Send + 'static {
86    async_stream::stream! {
87        while let Some(chunk) = rx.recv().await {
88            yield Ok::<Bytes, Infallible>(chunk);
89        }
90    }
91}
92
93pub fn sse_response<S>(stream: S) -> Response
94where
95    S: futures::Stream<Item = Result<Bytes, Infallible>> + Send + 'static,
96{
97    let mut headers = HeaderMap::new();
98    headers.insert(
99        header::CONTENT_TYPE,
100        HeaderValue::from_static("text/event-stream"),
101    );
102    headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
103    headers.insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
104    (headers, Body::from_stream(stream)).into_response()
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use serde_json::json;

    #[tokio::test]
    async fn send_serializes_and_frames_as_sse() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, mut sse_rx) = mpsc::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::new(ingress_rx, sse_tx);

        let event = json!({"type": "test"});
        endpoint.send(event).await.unwrap();
        let received = sse_rx.recv().await.unwrap();
        assert_eq!(received, Bytes::from("data: {\"type\":\"test\"}\n\n"));
    }

    #[tokio::test]
    async fn send_with_fanout_broadcasts() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, mut sse_rx) = mpsc::channel::<Bytes>(4);
        let (fanout_tx, mut fanout_rx) = broadcast::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::with_fanout(ingress_rx, sse_tx, fanout_tx);

        let event = json!({"type": "test"});
        endpoint.send(event).await.unwrap();

        let received = sse_rx.recv().await.unwrap();
        let expected = Bytes::from("data: {\"type\":\"test\"}\n\n");
        assert_eq!(received, expected);

        let fanout_received = fanout_rx.recv().await.unwrap();
        assert_eq!(fanout_received, expected);
    }

    #[tokio::test]
    async fn send_with_fanout_succeeds_without_subscribers() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, mut sse_rx) = mpsc::channel::<Bytes>(4);
        let (fanout_tx, fanout_rx) = broadcast::channel::<Bytes>(4);
        // With no live subscriber the broadcast send fails; fan-out is
        // best-effort, so the primary SSE delivery must still succeed.
        drop(fanout_rx);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::with_fanout(ingress_rx, sse_tx, fanout_tx);

        assert!(endpoint.send(json!({"type": "test"})).await.is_ok());
        let received = sse_rx.recv().await.unwrap();
        assert_eq!(received, Bytes::from("data: {\"type\":\"test\"}\n\n"));
    }

    #[tokio::test]
    async fn send_returns_error_on_closed_channel() {
        let (_ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, sse_rx) = mpsc::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::new(ingress_rx, sse_tx);
        drop(sse_rx);

        let result = endpoint.send(json!({"type": "test"})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn recv_hands_out_ingress_stream_only_once() {
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (sse_tx, _sse_rx) = mpsc::channel::<Bytes>(4);
        let endpoint: HttpSseServerEndpoint<serde_json::Value> =
            HttpSseServerEndpoint::new(ingress_rx, sse_tx);

        let mut stream = match endpoint.recv().await {
            Ok(stream) => stream,
            Err(_) => panic!("first recv() must return the ingress stream"),
        };
        // The receiver was moved out by the first call.
        assert!(matches!(endpoint.recv().await, Err(TransportError::Closed)));

        // Dropping the last ingress sender terminates the stream.
        drop(ingress_tx);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn sse_body_stream_yields_all_chunks() {
        let (tx, rx) = mpsc::channel::<Bytes>(4);
        let stream = sse_body_stream(rx);
        tokio::pin!(stream);

        tx.send(Bytes::from("a")).await.unwrap();
        tx.send(Bytes::from("b")).await.unwrap();
        drop(tx);

        let items: Vec<Bytes> = stream.map(|r| r.unwrap()).collect().await;
        assert_eq!(items, vec![Bytes::from("a"), Bytes::from("b")]);
    }

    #[test]
    fn sse_response_sets_event_stream_headers() {
        let response = sse_response(futures::stream::empty::<Result<Bytes, Infallible>>());
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let headers = response.headers();
        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap(), "text/event-stream");
        assert_eq!(headers.get(header::CACHE_CONTROL).unwrap(), "no-cache");
        assert_eq!(headers.get(header::CONNECTION).unwrap(), "keep-alive");
    }
}