tirea_agentos_server/transport/http_run.rs

1use bytes::Bytes;
2use serde::Serialize;
3use std::future::Future;
4use std::sync::Arc;
5use tirea_agentos::contracts::AgentEvent;
6use tirea_contract::{RuntimeInput, Transcoder};
7use tokio::sync::{broadcast, mpsc};
8use tracing::warn;
9
10use super::http_sse::HttpSseServerEndpoint;
11use super::runtime_endpoint::{RunStarter, RuntimeEndpoint};
12use super::{
13    relay_binding, RelayCancellation, SessionId, TranscoderEndpoint, TransportBinding,
14    TransportCapabilities,
15};
16
17pub struct HttpSseRelayConfig<F, ErrFmt> {
18    pub thread_id: String,
19    pub fanout: Option<broadcast::Sender<Bytes>>,
20    pub resumable_downstream: bool,
21    pub protocol_label: &'static str,
22    pub on_relay_done: F,
23    pub error_formatter: ErrFmt,
24}
25
26/// Wire an HTTP SSE relay for a single agent run.
27///
28/// Sets up the full pipeline: runtime endpoint -> protocol transcoder ->
29/// transport binding -> SSE server endpoint, then spawns the relay and
30/// event pump tasks.
31///
32/// Returns the SSE byte receiver to feed into an HTTP response body.
33pub fn wire_http_sse_relay<E, F, Fut, ErrFmt>(
34    run_starter: RunStarter,
35    encoder: E,
36    ingress_rx: mpsc::UnboundedReceiver<RuntimeInput>,
37    config: HttpSseRelayConfig<F, ErrFmt>,
38) -> mpsc::Receiver<Bytes>
39where
40    E: Transcoder<Input = AgentEvent> + 'static,
41    E::Output: Serialize + Send + 'static,
42    F: FnOnce(mpsc::Sender<Bytes>) -> Fut + Send + 'static,
43    Fut: Future<Output = ()> + Send,
44    ErrFmt: Fn(String) -> Bytes + Send + 'static,
45{
46    let HttpSseRelayConfig {
47        thread_id,
48        fanout,
49        resumable_downstream,
50        protocol_label,
51        on_relay_done,
52        error_formatter,
53    } = config;
54    let (sse_tx, sse_rx) = mpsc::channel::<Bytes>(64);
55
56    let upstream: Arc<HttpSseServerEndpoint<E::Output>> = match fanout {
57        Some(f) => Arc::new(HttpSseServerEndpoint::with_fanout(
58            ingress_rx,
59            sse_tx.clone(),
60            f,
61        )),
62        None => Arc::new(HttpSseServerEndpoint::new(ingress_rx, sse_tx.clone())),
63    };
64
65    let runtime_ep = Arc::new(RuntimeEndpoint::new(run_starter));
66    let downstream = Arc::new(TranscoderEndpoint::new(runtime_ep, encoder));
67
68    let binding = TransportBinding {
69        session: SessionId { thread_id },
70        caps: TransportCapabilities {
71            upstream_async: true,
72            downstream_streaming: true,
73            single_channel_bidirectional: false,
74            resumable_downstream,
75        },
76        upstream,
77        downstream,
78    };
79    let relay_cancel = RelayCancellation::new();
80    tokio::spawn(async move {
81        if let Err(err) = relay_binding(binding, relay_cancel.clone()).await {
82            warn!(error = %err, "{protocol_label} transport relay failed");
83            let _ = sse_tx.send(error_formatter(err.to_string())).await;
84        }
85        on_relay_done(sse_tx).await;
86    });
87
88    sse_rx
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tirea_agentos::contracts::{AgentEvent, RunRequest, ToolCallDecision};
    use tirea_agentos::runtime::RunStream;
    use tirea_contract::RunOrigin;
    use tirea_contract::Transcoder;

    /// Minimal encoder: maps each AgentEvent to a JSON string event.
    struct TestEncoder;

    impl Transcoder for TestEncoder {
        type Input = AgentEvent;
        type Output = String;

        fn prologue(&mut self) -> Vec<String> {
            vec!["[start]".to_string()]
        }

        fn transcode(&mut self, item: &AgentEvent) -> Vec<String> {
            match item {
                AgentEvent::TextDelta { delta } => vec![format!("text:{delta}")],
                _ => vec!["other".to_string()],
            }
        }

        fn epilogue(&mut self) -> Vec<String> {
            vec!["[end]".to_string()]
        }
    }

    /// Builds a `RunStream` that yields `events` in order, plus the receiver
    /// for decisions forwarded into the run.
    fn fake_run_stream(
        events: Vec<AgentEvent>,
    ) -> (RunStream, mpsc::UnboundedReceiver<ToolCallDecision>) {
        let (decision_tx, decision_rx) = mpsc::unbounded_channel();
        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(64);

        tokio::spawn(async move {
            for e in events {
                let _ = event_tx.send(e).await;
            }
        });

        let stream: Pin<Box<dyn futures::Stream<Item = AgentEvent> + Send>> =
            Box::pin(async_stream::stream! {
                let mut rx = event_rx;
                while let Some(item) = rx.recv().await {
                    yield item;
                }
            });

        let run = RunStream {
            thread_id: "thread-1".to_string(),
            run_id: "run-1".to_string(),
            decision_tx,
            events: stream,
        };

        (run, decision_rx)
    }

    /// Wraps `fake_run_stream` in a `RunStarter` that ignores the request.
    fn fake_starter(
        events: Vec<AgentEvent>,
    ) -> (RunStarter, mpsc::UnboundedReceiver<ToolCallDecision>) {
        let (run, decision_rx) = fake_run_stream(events);
        let starter: RunStarter = Box::new(move |_request| Box::pin(async move { Ok(run) }));
        (starter, decision_rx)
    }

    fn test_run_request() -> RunRequest {
        RunRequest {
            agent_id: "test".into(),
            thread_id: None,
            run_id: None,
            parent_run_id: None,
            parent_thread_id: None,
            resource_id: None,
            origin: RunOrigin::default(),
            state: None,
            messages: vec![],
            initial_decisions: vec![],
            source_mailbox_entry_id: None,
        }
    }

    /// Error formatter shared by all tests: wraps the message in a `data:` frame.
    /// The message is not JSON-escaped; this is test-only code.
    fn test_error_formatter(msg: String) -> Bytes {
        Bytes::from(format!("data: {{\"error\":\"{msg}\"}}\n\n"))
    }

    /// Decodes `data: <json string>` SSE chunks into their string payloads,
    /// skipping anything that does not match that shape.
    fn collect_sse_strings(chunks: Vec<Bytes>) -> Vec<String> {
        chunks
            .into_iter()
            .filter_map(|b| {
                let s = String::from_utf8(b.to_vec()).ok()?;
                let trimmed = s.trim();
                let payload = trimmed.strip_prefix("data: ")?;
                serde_json::from_str::<String>(payload).ok()
            })
            .collect()
    }

    #[tokio::test]
    async fn events_flow_through_to_sse_bytes() {
        let events = vec![
            AgentEvent::TextDelta {
                delta: "hello".to_string(),
            },
            AgentEvent::TextDelta {
                delta: "world".to_string(),
            },
        ];
        let (starter, _decision_rx) = fake_starter(events);
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();

        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .unwrap();
        drop(ingress_tx);

        let mut sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            HttpSseRelayConfig {
                thread_id: "thread-1".to_string(),
                fanout: None,
                resumable_downstream: false,
                protocol_label: "test",
                on_relay_done: |_sse_tx| async {},
                error_formatter: test_error_formatter,
            },
        );

        let mut chunks = Vec::new();
        while let Some(chunk) = sse_rx.recv().await {
            chunks.push(chunk);
        }

        // Compare the whole sequence at once so a short or reordered stream
        // reports the actual events instead of an index-out-of-bounds panic.
        let events = collect_sse_strings(chunks);
        assert_eq!(events, ["[start]", "text:hello", "text:world", "[end]"]);
    }

    #[tokio::test]
    async fn on_relay_done_callback_is_invoked() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let (starter, _decision_rx) = fake_starter(vec![]);
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .unwrap();
        drop(ingress_tx);

        let mut sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            HttpSseRelayConfig {
                thread_id: "thread-1".to_string(),
                fanout: None,
                resumable_downstream: false,
                protocol_label: "test",
                on_relay_done: move |_sse_tx| async move {
                    called_clone.store(true, Ordering::SeqCst);
                },
                error_formatter: test_error_formatter,
            },
        );

        while sse_rx.recv().await.is_some() {}

        assert!(
            called.load(Ordering::SeqCst),
            "on_relay_done should be called"
        );
    }

    #[tokio::test]
    async fn trailer_via_callback_sse_tx() {
        let (starter, _decision_rx) = fake_starter(vec![AgentEvent::TextDelta {
            delta: "x".to_string(),
        }]);
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .unwrap();
        drop(ingress_tx);

        let mut sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            HttpSseRelayConfig {
                thread_id: "thread-1".to_string(),
                fanout: None,
                resumable_downstream: false,
                protocol_label: "test",
                on_relay_done: |sse_tx: mpsc::Sender<Bytes>| async move {
                    let _ = sse_tx.send(Bytes::from("data: [DONE]\n\n")).await;
                },
                error_formatter: test_error_formatter,
            },
        );

        let mut chunks = Vec::new();
        while let Some(chunk) = sse_rx.recv().await {
            chunks.push(chunk);
        }

        let last_chunk = chunks
            .last()
            .expect("stream should emit at least the trailer chunk");
        let last = String::from_utf8(last_chunk.to_vec()).unwrap();
        assert_eq!(last.trim(), "data: [DONE]");
    }

    #[tokio::test]
    async fn fanout_receives_all_sse_events() {
        let events = vec![AgentEvent::TextDelta {
            delta: "hi".to_string(),
        }];
        let (starter, _decision_rx) = fake_starter(events);
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        let (fanout_tx, mut fanout_rx) = broadcast::channel::<Bytes>(64);

        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .unwrap();
        drop(ingress_tx);

        let mut sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            HttpSseRelayConfig {
                thread_id: "thread-1".to_string(),
                fanout: Some(fanout_tx),
                resumable_downstream: true,
                protocol_label: "test",
                on_relay_done: |_sse_tx| async {},
                error_formatter: test_error_formatter,
            },
        );

        let mut sse_chunks = Vec::new();
        while let Some(chunk) = sse_rx.recv().await {
            sse_chunks.push(chunk);
        }

        let mut fanout_chunks = Vec::new();
        while let Ok(chunk) = fanout_rx.try_recv() {
            fanout_chunks.push(chunk);
        }

        let sse_events = collect_sse_strings(sse_chunks);
        let fanout_events = collect_sse_strings(fanout_chunks);

        assert!(sse_events.contains(&"text:hi".to_string()));
        assert!(fanout_events.contains(&"text:hi".to_string()));
    }

    #[tokio::test]
    async fn decision_ingress_does_not_break_sse_stream() {
        let (starter, _decision_rx) = fake_starter(vec![AgentEvent::TextDelta {
            delta: "a".to_string(),
        }]);
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();

        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .unwrap();

        let mut sse_rx = wire_http_sse_relay(
            starter,
            TestEncoder,
            ingress_rx,
            HttpSseRelayConfig {
                thread_id: "thread-1".to_string(),
                fanout: None,
                resumable_downstream: false,
                protocol_label: "test",
                on_relay_done: |_sse_tx| async {},
                error_formatter: test_error_formatter,
            },
        );

        // Send a decision while the run is in flight, then close ingress.
        let decision = ToolCallDecision::resume("d1", serde_json::json!({"approved": true}), 1);
        ingress_tx.send(RuntimeInput::Decision(decision)).unwrap();
        drop(ingress_tx);

        let mut chunks = Vec::new();
        while let Some(chunk) = sse_rx.recv().await {
            chunks.push(chunk);
        }
        let text = String::from_utf8(
            chunks
                .into_iter()
                .flat_map(|chunk| chunk.to_vec())
                .collect::<Vec<_>>(),
        )
        .unwrap_or_default();
        assert!(
            text.contains("text:a"),
            "stream should still emit run output: {text}"
        );
    }
}