1use bytes::Bytes;
2use serde::Serialize;
3use std::future::Future;
4use std::sync::Arc;
5use tirea_agentos::contracts::AgentEvent;
6use tirea_contract::{RuntimeInput, Transcoder};
7use tokio::sync::{broadcast, mpsc};
8use tracing::warn;
9
10use super::http_sse::HttpSseServerEndpoint;
11use super::runtime_endpoint::{RunStarter, RuntimeEndpoint};
12use super::{
13 relay_binding, RelayCancellation, SessionId, TranscoderEndpoint, TransportBinding,
14 TransportCapabilities,
15};
16
/// Per-connection settings consumed by [`wire_http_sse_relay`].
///
/// `F` is the completion callback type and `ErrFmt` the error-frame formatter type;
/// their exact bounds are declared on `wire_http_sse_relay`.
pub struct HttpSseRelayConfig<F, ErrFmt> {
    /// Thread identifier placed in the binding's `SessionId`.
    pub thread_id: String,
    /// Optional broadcast sender. When `Some`, the upstream endpoint is built with
    /// `HttpSseServerEndpoint::with_fanout`; otherwise with `HttpSseServerEndpoint::new`.
    pub fanout: Option<broadcast::Sender<Bytes>>,
    /// Copied verbatim into `TransportCapabilities::resumable_downstream`.
    pub resumable_downstream: bool,
    /// Protocol name interpolated into the warning logged when the relay fails.
    pub protocol_label: &'static str,
    /// Called once with the SSE sender after the relay task finishes, whether or not the
    /// relay returned an error; it may send trailing frames (e.g. a `[DONE]` marker).
    pub on_relay_done: F,
    /// Turns a relay error message into the bytes of one frame sent on the SSE channel.
    pub error_formatter: ErrFmt,
}
25
26pub fn wire_http_sse_relay<E, F, Fut, ErrFmt>(
34 run_starter: RunStarter,
35 encoder: E,
36 ingress_rx: mpsc::UnboundedReceiver<RuntimeInput>,
37 config: HttpSseRelayConfig<F, ErrFmt>,
38) -> mpsc::Receiver<Bytes>
39where
40 E: Transcoder<Input = AgentEvent> + 'static,
41 E::Output: Serialize + Send + 'static,
42 F: FnOnce(mpsc::Sender<Bytes>) -> Fut + Send + 'static,
43 Fut: Future<Output = ()> + Send,
44 ErrFmt: Fn(String) -> Bytes + Send + 'static,
45{
46 let HttpSseRelayConfig {
47 thread_id,
48 fanout,
49 resumable_downstream,
50 protocol_label,
51 on_relay_done,
52 error_formatter,
53 } = config;
54 let (sse_tx, sse_rx) = mpsc::channel::<Bytes>(64);
55
56 let upstream: Arc<HttpSseServerEndpoint<E::Output>> = match fanout {
57 Some(f) => Arc::new(HttpSseServerEndpoint::with_fanout(
58 ingress_rx,
59 sse_tx.clone(),
60 f,
61 )),
62 None => Arc::new(HttpSseServerEndpoint::new(ingress_rx, sse_tx.clone())),
63 };
64
65 let runtime_ep = Arc::new(RuntimeEndpoint::new(run_starter));
66 let downstream = Arc::new(TranscoderEndpoint::new(runtime_ep, encoder));
67
68 let binding = TransportBinding {
69 session: SessionId { thread_id },
70 caps: TransportCapabilities {
71 upstream_async: true,
72 downstream_streaming: true,
73 single_channel_bidirectional: false,
74 resumable_downstream,
75 },
76 upstream,
77 downstream,
78 };
79 let relay_cancel = RelayCancellation::new();
80 tokio::spawn(async move {
81 if let Err(err) = relay_binding(binding, relay_cancel.clone()).await {
82 warn!(error = %err, "{protocol_label} transport relay failed");
83 let _ = sse_tx.send(error_formatter(err.to_string())).await;
84 }
85 on_relay_done(sse_tx).await;
86 });
87
88 sse_rx
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tirea_agentos::contracts::{AgentEvent, RunRequest, ToolCallDecision};
    use tirea_agentos::runtime::RunStream;
    use tirea_contract::{RunOrigin, Transcoder};

    /// Minimal encoder: brackets the stream with `[start]`/`[end]`, renders text deltas as
    /// `text:<delta>` and every other event as `other`.
    struct TestEncoder;

    impl Transcoder for TestEncoder {
        type Input = AgentEvent;
        type Output = String;

        fn prologue(&mut self) -> Vec<String> {
            vec!["[start]".to_string()]
        }

        fn transcode(&mut self, item: &AgentEvent) -> Vec<String> {
            match item {
                AgentEvent::TextDelta { delta } => vec![format!("text:{delta}")],
                _ => vec!["other".to_string()],
            }
        }

        fn epilogue(&mut self) -> Vec<String> {
            vec!["[end]".to_string()]
        }
    }

    /// Builds a `RunStream` yielding `events` in order (fed through a spawned task), and
    /// returns the decision receiver so callers can keep the decision channel open.
    fn fake_run_stream(
        events: Vec<AgentEvent>,
    ) -> (RunStream, mpsc::UnboundedReceiver<ToolCallDecision>) {
        let (decision_tx, decision_rx) = mpsc::unbounded_channel();
        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(64);

        tokio::spawn(async move {
            for e in events {
                let _ = event_tx.send(e).await;
            }
        });

        let stream: Pin<Box<dyn futures::Stream<Item = AgentEvent> + Send>> =
            Box::pin(async_stream::stream! {
                let mut rx = event_rx;
                while let Some(item) = rx.recv().await {
                    yield item;
                }
            });

        let run = RunStream {
            thread_id: "thread-1".to_string(),
            run_id: "run-1".to_string(),
            decision_tx,
            events: stream,
        };

        (run, decision_rx)
    }

    /// A `RunStarter` that ignores the request and hands back a fake run over `events`.
    fn fake_starter(
        events: Vec<AgentEvent>,
    ) -> (RunStarter, mpsc::UnboundedReceiver<ToolCallDecision>) {
        let (run, decision_rx) = fake_run_stream(events);
        let starter: RunStarter = Box::new(move |_request| Box::pin(async move { Ok(run) }));
        (starter, decision_rx)
    }

    fn test_run_request() -> RunRequest {
        RunRequest {
            agent_id: "test".into(),
            thread_id: None,
            run_id: None,
            parent_run_id: None,
            parent_thread_id: None,
            resource_id: None,
            origin: RunOrigin::default(),
            state: None,
            messages: vec![],
            initial_decisions: vec![],
            source_mailbox_entry_id: None,
        }
    }

    /// Error frame formatter shared by every test relay.
    fn error_frame(msg: String) -> Bytes {
        Bytes::from(format!("data: {{\"error\":\"{msg}\"}}\n\n"))
    }

    /// Relay configuration shared by the tests; only the knobs a test varies are parameters.
    fn relay_config<F, Fut>(
        fanout: Option<broadcast::Sender<Bytes>>,
        resumable_downstream: bool,
        on_relay_done: F,
    ) -> HttpSseRelayConfig<F, fn(String) -> Bytes>
    where
        F: FnOnce(mpsc::Sender<Bytes>) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
    {
        HttpSseRelayConfig {
            thread_id: "thread-1".to_string(),
            fanout,
            resumable_downstream,
            protocol_label: "test",
            on_relay_done,
            error_formatter: error_frame as fn(String) -> Bytes,
        }
    }

    /// Opens an ingress channel already holding one `RuntimeInput::Run`.
    fn ingress_with_run() -> (
        mpsc::UnboundedSender<RuntimeInput>,
        mpsc::UnboundedReceiver<RuntimeInput>,
    ) {
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .expect("ingress receiver is still alive");
        (ingress_tx, ingress_rx)
    }

    /// Reads the SSE receiver until every sender has been dropped.
    async fn drain(mut sse_rx: mpsc::Receiver<Bytes>) -> Vec<Bytes> {
        let mut chunks = Vec::new();
        while let Some(chunk) = sse_rx.recv().await {
            chunks.push(chunk);
        }
        chunks
    }

    /// Decodes `data: <json string>` frames, skipping anything that is not such a frame.
    fn collect_sse_strings(chunks: Vec<Bytes>) -> Vec<String> {
        chunks
            .into_iter()
            .filter_map(|b| {
                let s = String::from_utf8(b.to_vec()).ok()?;
                let trimmed = s.trim();
                let payload = trimmed.strip_prefix("data: ")?;
                serde_json::from_str::<String>(payload).ok()
            })
            .collect()
    }

    #[tokio::test]
    async fn events_flow_through_to_sse_bytes() {
        let events = vec![
            AgentEvent::TextDelta {
                delta: "hello".to_string(),
            },
            AgentEvent::TextDelta {
                delta: "world".to_string(),
            },
        ];
        let (starter, _decision_rx) = fake_starter(events);
        let (ingress_tx, ingress_rx) = ingress_with_run();
        drop(ingress_tx);

        let sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            relay_config(None, false, |_sse_tx| async {}),
        );

        // Compare the whole sequence so a short or extra stream shows the actual output.
        let events = collect_sse_strings(drain(sse_rx).await);
        assert_eq!(events, ["[start]", "text:hello", "text:world", "[end]"]);
    }

    #[tokio::test]
    async fn on_relay_done_callback_is_invoked() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = Arc::clone(&called);

        let (starter, _decision_rx) = fake_starter(vec![]);
        let (ingress_tx, ingress_rx) = ingress_with_run();
        drop(ingress_tx);

        let sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            relay_config(None, false, move |sse_tx| async move {
                called_clone.store(true, Ordering::SeqCst);
                // Close the stream only after the flag is set, so draining implies the store.
                drop(sse_tx);
            }),
        );

        drain(sse_rx).await;

        assert!(
            called.load(Ordering::SeqCst),
            "on_relay_done should be called"
        );
    }

    #[tokio::test]
    async fn trailer_via_callback_sse_tx() {
        let (starter, _decision_rx) = fake_starter(vec![AgentEvent::TextDelta {
            delta: "x".to_string(),
        }]);
        let (ingress_tx, ingress_rx) = ingress_with_run();
        drop(ingress_tx);

        let sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            relay_config(None, false, |sse_tx: mpsc::Sender<Bytes>| async move {
                let _ = sse_tx.send(Bytes::from("data: [DONE]\n\n")).await;
            }),
        );

        let chunks = drain(sse_rx).await;
        let last = chunks
            .last()
            .expect("relay should emit at least one chunk");
        let last = std::str::from_utf8(last).expect("SSE chunk should be valid UTF-8");
        assert_eq!(last.trim(), "data: [DONE]");
    }

    #[tokio::test]
    async fn fanout_receives_all_sse_events() {
        let (starter, _decision_rx) = fake_starter(vec![AgentEvent::TextDelta {
            delta: "hi".to_string(),
        }]);
        let (ingress_tx, ingress_rx) = ingress_with_run();
        let (fanout_tx, mut fanout_rx) = broadcast::channel::<Bytes>(64);
        drop(ingress_tx);

        let sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            relay_config(Some(fanout_tx), true, |_sse_tx| async {}),
        );

        let sse_events = collect_sse_strings(drain(sse_rx).await);

        let mut fanout_chunks = Vec::new();
        while let Ok(chunk) = fanout_rx.try_recv() {
            fanout_chunks.push(chunk);
        }
        let fanout_events = collect_sse_strings(fanout_chunks);

        let expected = "text:hi".to_string();
        assert!(
            sse_events.contains(&expected),
            "SSE stream missing {expected:?}: {sse_events:?}"
        );
        assert!(
            fanout_events.contains(&expected),
            "fanout missing {expected:?}: {fanout_events:?}"
        );
    }

    #[tokio::test]
    async fn decision_ingress_does_not_break_sse_stream() {
        let (starter, _decision_rx) = fake_starter(vec![AgentEvent::TextDelta {
            delta: "a".to_string(),
        }]);
        let (ingress_tx, ingress_rx) = ingress_with_run();

        let sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            relay_config(None, false, |_sse_tx| async {}),
        );

        let decision = ToolCallDecision::resume("d1", serde_json::json!({"approved": true}), 1);
        ingress_tx
            .send(RuntimeInput::Decision(decision))
            .expect("ingress receiver is still alive");
        drop(ingress_tx);

        let bytes: Vec<u8> = drain(sse_rx)
            .await
            .iter()
            .flat_map(|chunk| chunk.iter().copied())
            .collect();
        // Fail loudly on invalid UTF-8 instead of asserting against an empty string.
        let text = String::from_utf8(bytes).expect("SSE output should be valid UTF-8");
        assert!(
            text.contains("text:a"),
            "stream should still emit run output: {text}"
        );
    }
}