tirea_agentos_server/transport/
mod.rs

1pub mod http_run;
2pub mod http_sse;
3pub mod nats;
4mod nats_error;
5pub mod runtime_endpoint;
6pub mod transcoder;
7
8pub use nats_error::NatsProtocolError;
9pub use runtime_endpoint::{RunStarter, RuntimeEndpoint};
10pub use transcoder::TranscoderEndpoint;
11
12use async_trait::async_trait;
13use futures::Stream;
14use futures::StreamExt;
15use std::pin::Pin;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::Arc;
18use tokio::sync::{mpsc, Mutex};
19
20/// Common boxed stream for transport endpoints.
21pub type BoxStream<T> = Pin<Box<dyn Stream<Item = Result<T, TransportError>> + Send>>;
22
/// Key identifying one chat transport binding.
///
/// Derives `Eq` and `Hash`, so it can key hash maps and sets.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SessionId {
    /// Identifier of the thread this binding belongs to.
    pub thread_id: String,
}
28
/// Capability flags a transport declares, used when checking whether
/// transports can be composed.
///
/// `Default` yields every flag set to `false`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TransportCapabilities {
    /// The caller→runtime direction supports asynchronous delivery.
    pub upstream_async: bool,
    /// The runtime→caller direction delivers events as a stream.
    pub downstream_streaming: bool,
    /// Both directions travel over one shared bidirectional channel.
    pub single_channel_bidirectional: bool,
    /// The downstream stream can be resumed after a reconnection.
    pub resumable_downstream: bool,
}
41
42#[derive(Debug, thiserror::Error)]
43pub enum TransportError {
44    #[error("session not found: {0}")]
45    SessionNotFound(String),
46    #[error("closed")]
47    Closed,
48    #[error("io: {0}")]
49    Io(String),
50    #[error("internal: {0}")]
51    Internal(String),
52}
53
/// Cooperative cancellation flag shared by relay loops.
///
/// Clones share the same underlying flag, so cancelling through any clone is
/// visible to all of them. The flag can only be set, never reset. It is a
/// plain polled flag: nothing is woken when it flips, so loops must check
/// [`RelayCancellation::is_cancelled`] themselves.
#[derive(Clone, Default, Debug)]
pub struct RelayCancellation {
    cancelled: Arc<AtomicBool>,
}

impl RelayCancellation {
    /// Creates a token that is not yet cancelled.
    pub fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Marks this token, and every clone of it, as cancelled.
    pub fn cancel(&self) {
        // Relaxed is enough: the flag guards no other shared data.
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }

    /// Returns `true` once `cancel` has been called on any clone.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }
}
73
74/// Generic endpoint view.
75///
76/// A caller only needs recv/send from one side;
77/// direction is encoded at type-level by `RecvMsg` and `SendMsg`.
78#[async_trait]
79pub trait Endpoint<RecvMsg, SendMsg>: Send + Sync
80where
81    RecvMsg: Send + 'static,
82    SendMsg: Send + 'static,
83{
84    async fn recv(&self) -> Result<BoxStream<RecvMsg>, TransportError>;
85    async fn send(&self, item: SendMsg) -> Result<(), TransportError>;
86    async fn close(&self) -> Result<(), TransportError>;
87}
88
89/// Generic downstream endpoint backed by a bounded `mpsc::Receiver` for recv
90/// and an unbounded `mpsc::UnboundedSender` for send.
91///
92/// Test-only: used in integration tests to simulate a runtime endpoint.
93#[doc(hidden)]
94pub struct ChannelDownstreamEndpoint<RecvMsg, SendMsg>
95where
96    RecvMsg: Send + 'static,
97    SendMsg: Send + 'static,
98{
99    recv_rx: Mutex<Option<mpsc::Receiver<RecvMsg>>>,
100    send_tx: mpsc::UnboundedSender<SendMsg>,
101}
102
103impl<RecvMsg, SendMsg> ChannelDownstreamEndpoint<RecvMsg, SendMsg>
104where
105    RecvMsg: Send + 'static,
106    SendMsg: Send + 'static,
107{
108    pub fn new(recv_rx: mpsc::Receiver<RecvMsg>, send_tx: mpsc::UnboundedSender<SendMsg>) -> Self {
109        Self {
110            recv_rx: Mutex::new(Some(recv_rx)),
111            send_tx,
112        }
113    }
114}
115
116#[async_trait]
117impl<RecvMsg, SendMsg> Endpoint<RecvMsg, SendMsg> for ChannelDownstreamEndpoint<RecvMsg, SendMsg>
118where
119    RecvMsg: Send + 'static,
120    SendMsg: Send + 'static,
121{
122    async fn recv(&self) -> Result<BoxStream<RecvMsg>, TransportError> {
123        let mut guard = self.recv_rx.lock().await;
124        let mut rx = guard.take().ok_or(TransportError::Closed)?;
125        let stream = async_stream::stream! {
126            while let Some(item) = rx.recv().await {
127                yield Ok(item);
128            }
129        };
130        Ok(Box::pin(stream))
131    }
132
133    async fn send(&self, item: SendMsg) -> Result<(), TransportError> {
134        self.send_tx.send(item).map_err(|_| TransportError::Closed)
135    }
136
137    async fn close(&self) -> Result<(), TransportError> {
138        Ok(())
139    }
140}
141
142/// A matched pair of endpoints representing both sides of a transport channel.
143///
144/// - `server`: the runtime/handler side — receives `Ingress`, sends `Egress`
145/// - `client`: the caller/consumer side — receives `Egress`, sends `Ingress`
146pub struct EndpointPair<Ingress, Egress>
147where
148    Ingress: Send + 'static,
149    Egress: Send + 'static,
150{
151    pub server: Arc<dyn Endpoint<Ingress, Egress>>,
152    pub client: Arc<dyn Endpoint<Egress, Ingress>>,
153}
154
155/// Create an in-memory `EndpointPair` backed by bounded channels.
156pub fn channel_pair<A, B>(buffer: usize) -> EndpointPair<A, B>
157where
158    A: Send + 'static,
159    B: Send + 'static,
160{
161    let buffer = buffer.max(1);
162    let (a_tx, a_rx) = mpsc::channel::<A>(buffer);
163    let (b_tx, b_rx) = mpsc::channel::<B>(buffer);
164
165    let server = Arc::new(BoundedChannelEndpoint::new(a_rx, b_tx));
166    let client = Arc::new(BoundedChannelEndpoint::new(b_rx, a_tx));
167
168    EndpointPair { server, client }
169}
170
171/// Channel endpoint backed by bounded `mpsc` channels. Used internally by [`channel_pair`].
172struct BoundedChannelEndpoint<RecvMsg, SendMsg>
173where
174    RecvMsg: Send + 'static,
175    SendMsg: Send + 'static,
176{
177    recv_rx: Mutex<Option<mpsc::Receiver<RecvMsg>>>,
178    send_tx: mpsc::Sender<SendMsg>,
179}
180
181impl<RecvMsg, SendMsg> BoundedChannelEndpoint<RecvMsg, SendMsg>
182where
183    RecvMsg: Send + 'static,
184    SendMsg: Send + 'static,
185{
186    fn new(recv_rx: mpsc::Receiver<RecvMsg>, send_tx: mpsc::Sender<SendMsg>) -> Self {
187        Self {
188            recv_rx: Mutex::new(Some(recv_rx)),
189            send_tx,
190        }
191    }
192}
193
194#[async_trait]
195impl<RecvMsg, SendMsg> Endpoint<RecvMsg, SendMsg> for BoundedChannelEndpoint<RecvMsg, SendMsg>
196where
197    RecvMsg: Send + 'static,
198    SendMsg: Send + 'static,
199{
200    async fn recv(&self) -> Result<BoxStream<RecvMsg>, TransportError> {
201        let mut guard = self.recv_rx.lock().await;
202        let mut rx = guard.take().ok_or(TransportError::Closed)?;
203        let stream = async_stream::stream! {
204            while let Some(item) = rx.recv().await {
205                yield Ok(item);
206            }
207        };
208        Ok(Box::pin(stream))
209    }
210
211    async fn send(&self, item: SendMsg) -> Result<(), TransportError> {
212        self.send_tx
213            .send(item)
214            .await
215            .map_err(|_| TransportError::Closed)
216    }
217
218    async fn close(&self) -> Result<(), TransportError> {
219        Ok(())
220    }
221}
222
223/// Bound transport session with both sides.
224///
225/// - `upstream`: caller-facing side: recv `UpMsg`, send `DownMsg`
226/// - `downstream`: runtime/next-hop side: recv `DownMsg`, send `UpMsg`
227pub struct TransportBinding<UpMsg, DownMsg>
228where
229    UpMsg: Send + 'static,
230    DownMsg: Send + 'static,
231{
232    pub session: SessionId,
233    pub caps: TransportCapabilities,
234    pub upstream: Arc<dyn Endpoint<UpMsg, DownMsg>>,
235    pub downstream: Arc<dyn Endpoint<DownMsg, UpMsg>>,
236}
237
238/// Relay one bound session bidirectionally:
239/// - upstream.recv -> downstream.send
240/// - downstream.recv -> upstream.send
241pub async fn relay_binding<UpMsg, DownMsg>(
242    binding: TransportBinding<UpMsg, DownMsg>,
243    cancel: RelayCancellation,
244) -> Result<(), TransportError>
245where
246    UpMsg: Send + 'static,
247    DownMsg: Send + 'static,
248{
249    let upstream = binding.upstream.clone();
250    let downstream = binding.downstream.clone();
251
252    let ingress = {
253        let cancel = cancel.clone();
254        let upstream = upstream.clone();
255        let downstream = downstream.clone();
256        tokio::spawn(async move {
257            let mut stream = upstream.recv().await?;
258            while let Some(item) = stream.next().await {
259                if cancel.is_cancelled() {
260                    break;
261                }
262                downstream.send(item?).await?;
263            }
264            Ok::<(), TransportError>(())
265        })
266    };
267
268    let egress = {
269        let cancel = cancel.clone();
270        let upstream = upstream.clone();
271        let downstream = downstream.clone();
272        tokio::spawn(async move {
273            let mut stream = downstream.recv().await?;
274            while let Some(item) = stream.next().await {
275                if cancel.is_cancelled() {
276                    break;
277                }
278                upstream.send(item?).await?;
279            }
280            Ok::<(), TransportError>(())
281        })
282    };
283
284    fn normalize_relay_result(result: Result<(), TransportError>) -> Result<(), TransportError> {
285        match result {
286            Ok(()) | Err(TransportError::Closed) => Ok(()),
287            Err(other) => Err(other),
288        }
289    }
290
291    let egress_res = egress
292        .await
293        .map_err(|e| TransportError::Internal(e.to_string()))?;
294    cancel.cancel();
295
296    let ingress_res = if ingress.is_finished() {
297        Some(
298            ingress
299                .await
300                .map_err(|e| TransportError::Internal(e.to_string()))?,
301        )
302    } else {
303        ingress.abort();
304        None
305    };
306
307    if let Some(result) = ingress_res {
308        normalize_relay_result(result)?
309    }
310
311    normalize_relay_result(egress_res)
312}
313
// Unit tests for the in-memory endpoints and for `relay_binding`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    // Test double: an endpoint over two unbounded channels. `recv` hands the
    // receiver out once (later calls return `Closed`); `send` returns `Closed`
    // once the receiving half of `send_tx` has been dropped.
    #[derive(Debug)]
    struct ChannelEndpoint<Recv, SendMsg>
    where
        Recv: std::marker::Send + 'static,
        SendMsg: std::marker::Send + 'static,
    {
        recv_rx: tokio::sync::Mutex<Option<mpsc::UnboundedReceiver<Recv>>>,
        send_tx: mpsc::UnboundedSender<SendMsg>,
    }

    impl<Recv, SendMsg> ChannelEndpoint<Recv, SendMsg>
    where
        Recv: std::marker::Send + 'static,
        SendMsg: std::marker::Send + 'static,
    {
        fn new(
            recv_rx: mpsc::UnboundedReceiver<Recv>,
            send_tx: mpsc::UnboundedSender<SendMsg>,
        ) -> Self {
            Self {
                recv_rx: tokio::sync::Mutex::new(Some(recv_rx)),
                send_tx,
            }
        }
    }

    // Test double whose `send` always fails with `TransportError::Io(error)`.
    // Its `recv` behaves like `ChannelEndpoint::recv`.
    #[derive(Debug)]
    struct FailingSendEndpoint<Recv>
    where
        Recv: std::marker::Send + 'static,
    {
        recv_rx: tokio::sync::Mutex<Option<mpsc::UnboundedReceiver<Recv>>>,
        error: &'static str,
    }

    impl<Recv> FailingSendEndpoint<Recv>
    where
        Recv: std::marker::Send + 'static,
    {
        fn new(recv_rx: mpsc::UnboundedReceiver<Recv>, error: &'static str) -> Self {
            Self {
                recv_rx: tokio::sync::Mutex::new(Some(recv_rx)),
                error,
            }
        }
    }

    #[async_trait]
    impl<Recv> Endpoint<Recv, u32> for FailingSendEndpoint<Recv>
    where
        Recv: std::marker::Send + 'static,
    {
        async fn recv(&self) -> Result<BoxStream<Recv>, TransportError> {
            let mut guard = self.recv_rx.lock().await;
            let rx = guard.take().ok_or(TransportError::Closed)?;
            let stream = async_stream::stream! {
                let mut rx = rx;
                while let Some(item) = rx.recv().await {
                    yield Ok(item);
                }
            };
            Ok(Box::pin(stream))
        }

        async fn send(&self, _item: u32) -> Result<(), TransportError> {
            Err(TransportError::Io(self.error.to_string()))
        }

        async fn close(&self) -> Result<(), TransportError> {
            Ok(())
        }
    }

    #[async_trait]
    impl<Recv, SendMsg> Endpoint<Recv, SendMsg> for ChannelEndpoint<Recv, SendMsg>
    where
        Recv: std::marker::Send + 'static,
        SendMsg: std::marker::Send + 'static,
    {
        async fn recv(&self) -> Result<BoxStream<Recv>, TransportError> {
            let mut guard = self.recv_rx.lock().await;
            let rx = guard.take().ok_or(TransportError::Closed)?;
            let stream = async_stream::stream! {
                let mut rx = rx;
                while let Some(item) = rx.recv().await {
                    yield Ok(item);
                }
            };
            Ok(Box::pin(stream))
        }

        async fn send(&self, item: SendMsg) -> Result<(), TransportError> {
            self.send_tx.send(item).map_err(|_| TransportError::Closed)
        }

        async fn close(&self) -> Result<(), TransportError> {
            Ok(())
        }
    }

    // One message each way is forwarded, then cancelling and closing both
    // sources lets the relay finish with `Ok`.
    #[tokio::test]
    async fn relay_binding_moves_messages_both_directions() {
        let (up_in_tx, up_in_rx) = mpsc::unbounded_channel::<u32>();
        let (up_send_tx, mut up_send_rx) = mpsc::unbounded_channel::<String>();

        let (down_in_tx, down_in_rx) = mpsc::unbounded_channel::<String>();
        let (down_send_tx, mut down_send_rx) = mpsc::unbounded_channel::<u32>();

        let upstream = Arc::new(ChannelEndpoint::new(up_in_rx, up_send_tx));
        let downstream = Arc::new(ChannelEndpoint::new(down_in_rx, down_send_tx));

        let binding = TransportBinding {
            session: SessionId {
                thread_id: "thread-1".to_string(),
            },
            caps: TransportCapabilities {
                upstream_async: true,
                downstream_streaming: true,
                single_channel_bidirectional: false,
                resumable_downstream: true,
            },
            upstream,
            downstream,
        };

        let cancel = RelayCancellation::new();
        let relay_task = tokio::spawn(relay_binding(binding, cancel.clone()));

        up_in_tx.send(7).unwrap();
        down_in_tx.send("evt".to_string()).unwrap();

        let up_out = up_send_rx
            .recv()
            .await
            .expect("upstream should receive event");
        let down_out = down_send_rx
            .recv()
            .await
            .expect("downstream should receive ingress");

        assert_eq!(up_out, "evt");
        assert_eq!(down_out, 7);

        cancel.cancel();
        drop(up_in_tx);
        drop(down_in_tx);

        let result = relay_task.await.expect("relay task should join");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn channel_downstream_endpoint_bridges_recv_and_send() {
        let (recv_tx, recv_rx) = mpsc::channel::<u32>(4);
        let (send_tx, mut send_rx) = mpsc::unbounded_channel::<String>();
        let endpoint = ChannelDownstreamEndpoint::new(recv_rx, send_tx);

        recv_tx.send(7).await.expect("seed recv channel");
        drop(recv_tx);

        let mut stream = endpoint.recv().await.expect("recv stream");
        let first = stream
            .next()
            .await
            .expect("stream item")
            .expect("stream ok item");
        assert_eq!(first, 7);

        endpoint
            .send("ok".to_string())
            .await
            .expect("send should work");
        let sent = send_rx.recv().await.expect("sent item");
        assert_eq!(sent, "ok");
    }

    // ── ChannelDownstreamEndpoint ──────────────────────────────────

    #[tokio::test]
    async fn channel_downstream_recv_called_twice_returns_closed() {
        let (_tx, rx) = mpsc::channel::<u32>(4);
        let (send_tx, _send_rx) = mpsc::unbounded_channel::<String>();
        let ep = ChannelDownstreamEndpoint::new(rx, send_tx);

        let _first = ep.recv().await.unwrap();
        let second = ep.recv().await;
        assert!(matches!(second, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn channel_downstream_send_after_receiver_dropped_returns_closed() {
        let (_tx, rx) = mpsc::channel::<u32>(4);
        let (send_tx, send_rx) = mpsc::unbounded_channel::<String>();
        let ep = ChannelDownstreamEndpoint::new(rx, send_tx);

        drop(send_rx);
        let result = ep.send("msg".to_string()).await;
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn channel_downstream_recv_delivers_all_items_in_order() {
        let (tx, rx) = mpsc::channel::<u32>(8);
        let (send_tx, _send_rx) = mpsc::unbounded_channel::<String>();
        let ep = ChannelDownstreamEndpoint::new(rx, send_tx);

        for i in 0..5 {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let stream = ep.recv().await.unwrap();
        let items: Vec<u32> = stream.map(|r| r.unwrap()).collect().await;
        assert_eq!(items, vec![0, 1, 2, 3, 4]);
    }

    // ── channel_pair / BoundedChannelEndpoint ───────────────────

    #[tokio::test]
    async fn channel_pair_bidirectional() {
        let pair = channel_pair::<u32, String>(4);

        // server sends String, client receives String
        pair.server.send("hello".to_string()).await.unwrap();
        let mut client_stream = pair.client.recv().await.unwrap();
        let received = client_stream.next().await.unwrap().unwrap();
        assert_eq!(received, "hello");

        // client sends u32, server receives u32
        pair.client.send(42).await.unwrap();
        let mut server_stream = pair.server.recv().await.unwrap();
        let received = server_stream.next().await.unwrap().unwrap();
        assert_eq!(received, 42);
    }

    #[tokio::test]
    async fn channel_pair_close_propagates() {
        let pair = channel_pair::<u32, String>(4);

        pair.server.send("a".to_string()).await.unwrap();
        drop(pair.server);

        let mut stream = pair.client.recv().await.unwrap();
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first, "a");

        // After server side dropped, stream ends
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn channel_pair_recv_called_twice_returns_closed() {
        let pair = channel_pair::<u32, String>(4);

        let _first = pair.server.recv().await.unwrap();
        let second = pair.server.recv().await;
        assert!(matches!(second, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn channel_pair_send_after_peer_dropped_returns_closed() {
        let pair = channel_pair::<u32, String>(4);

        drop(pair.client);
        // server sends String → client's recv channel, but client is dropped
        let result = pair.server.send("orphan".to_string()).await;
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn channel_pair_multiple_items_preserve_order() {
        let pair = channel_pair::<u32, String>(8);

        for i in 0..5 {
            pair.client.send(i).await.unwrap();
        }
        drop(pair.client);

        let stream = pair.server.recv().await.unwrap();
        let items: Vec<u32> = stream.map(|r| r.unwrap()).collect().await;
        assert_eq!(items, vec![0, 1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn channel_pair_concurrent_bidirectional() {
        let pair = channel_pair::<u32, String>(8);

        // Start consumers that read exactly 3 items each.
        // (Cannot use .collect() because each endpoint's send_tx keeps
        // the peer's recv channel open — a known trait of paired endpoints.)
        let consumer_server = tokio::spawn({
            let server = pair.server.clone();
            async move {
                let mut stream = server.recv().await.unwrap();
                let mut items = Vec::new();
                for _ in 0..3 {
                    items.push(stream.next().await.unwrap().unwrap());
                }
                items
            }
        });
        let consumer_client = tokio::spawn({
            let client = pair.client.clone();
            async move {
                let mut stream = client.recv().await.unwrap();
                let mut items = Vec::new();
                for _ in 0..3 {
                    items.push(stream.next().await.unwrap().unwrap());
                }
                items
            }
        });

        // Concurrently send in both directions
        for i in 0u32..3 {
            pair.client.send(i).await.unwrap();
        }
        for s in ["a", "b", "c"] {
            pair.server.send(s.to_string()).await.unwrap();
        }

        let server_items = consumer_server.await.unwrap();
        assert_eq!(server_items, vec![0, 1, 2]);

        let client_items = consumer_client.await.unwrap();
        assert_eq!(client_items, vec!["a", "b", "c"]);
    }

    // ── relay_binding edge cases ────────────────────────────────

    // Once both sources are closed, the relay ends cleanly after forwarding
    // the queued message.
    #[tokio::test]
    async fn relay_completes_when_downstream_closes() {
        let (up_in_tx, up_in_rx) = mpsc::unbounded_channel::<u32>();
        let (up_send_tx, _up_send_rx) = mpsc::unbounded_channel::<String>();

        let (_down_in_tx, down_in_rx) = mpsc::unbounded_channel::<String>();
        let (down_send_tx, mut down_send_rx) = mpsc::unbounded_channel::<u32>();

        let upstream = Arc::new(ChannelEndpoint::new(up_in_rx, up_send_tx));
        let downstream = Arc::new(ChannelEndpoint::new(down_in_rx, down_send_tx));

        let binding = TransportBinding {
            session: SessionId {
                thread_id: "t".to_string(),
            },
            caps: TransportCapabilities::default(),
            upstream,
            downstream,
        };

        let cancel = RelayCancellation::new();
        let relay = tokio::spawn(relay_binding(binding, cancel));

        // Send one message upstream → downstream, then close upstream ingress
        up_in_tx.send(42).unwrap();
        drop(up_in_tx);
        // Also close downstream egress source so relay's egress loop ends
        drop(_down_in_tx);

        let received = down_send_rx.recv().await.unwrap();
        assert_eq!(received, 42);

        let result = relay.await.unwrap();
        assert!(result.is_ok(), "relay should normalize Closed to Ok");
    }

    // A token that is already cancelled, with both sources closed, still
    // lets the relay return `Ok`.
    #[tokio::test]
    async fn relay_with_cancel_before_messages() {
        let (_up_tx, up_rx) = mpsc::unbounded_channel::<u32>();
        let (up_send_tx, _up_send_rx) = mpsc::unbounded_channel::<String>();
        let (down_tx, down_rx) = mpsc::unbounded_channel::<String>();
        let (down_send_tx, _down_send_rx) = mpsc::unbounded_channel::<u32>();

        let upstream = Arc::new(ChannelEndpoint::new(up_rx, up_send_tx));
        let downstream = Arc::new(ChannelEndpoint::new(down_rx, down_send_tx));

        let binding = TransportBinding {
            session: SessionId {
                thread_id: "t".to_string(),
            },
            caps: TransportCapabilities::default(),
            upstream,
            downstream,
        };

        let cancel = RelayCancellation::new();
        cancel.cancel();
        // Close sources so streams end
        drop(_up_tx);
        drop(down_tx);

        let result = relay_binding(binding, cancel).await;
        assert!(result.is_ok());
    }

    // Several messages in each direction arrive in send order.
    #[tokio::test]
    async fn relay_multiple_messages_in_sequence() {
        let (up_tx, up_rx) = mpsc::unbounded_channel::<u32>();
        let (up_send_tx, mut up_send_rx) = mpsc::unbounded_channel::<String>();
        let (down_tx, down_rx) = mpsc::unbounded_channel::<String>();
        let (down_send_tx, mut down_send_rx) = mpsc::unbounded_channel::<u32>();

        let upstream = Arc::new(ChannelEndpoint::new(up_rx, up_send_tx));
        let downstream = Arc::new(ChannelEndpoint::new(down_rx, down_send_tx));

        let binding = TransportBinding {
            session: SessionId {
                thread_id: "seq".to_string(),
            },
            caps: TransportCapabilities::default(),
            upstream,
            downstream,
        };

        let cancel = RelayCancellation::new();
        let relay = tokio::spawn(relay_binding(binding, cancel));

        // ingress: upstream → downstream
        for i in 0..3 {
            up_tx.send(i).unwrap();
        }
        for expected in 0..3 {
            assert_eq!(down_send_rx.recv().await.unwrap(), expected);
        }

        // egress: downstream → upstream
        for s in ["x", "y", "z"] {
            down_tx.send(s.to_string()).unwrap();
        }
        for expected in ["x", "y", "z"] {
            assert_eq!(up_send_rx.recv().await.unwrap(), expected);
        }

        drop(up_tx);
        drop(down_tx);
        assert!(relay.await.unwrap().is_ok());
    }

    // A non-`Closed` ingress error (downstream `send` failing) must be
    // reported even though egress ends cleanly.
    // NOTE(review): relay_binding samples the ingress task once after egress
    // ends, so this presumably relies on the default current-thread test
    // runtime polling ingress before egress completes — confirm.
    #[tokio::test]
    async fn relay_binding_propagates_ingress_error_even_if_egress_closes_cleanly() {
        let (up_tx, up_rx) = mpsc::unbounded_channel::<u32>();
        let (up_send_tx, _up_send_rx) = mpsc::unbounded_channel::<String>();
        let (down_tx, down_rx) = mpsc::unbounded_channel::<String>();

        let upstream = Arc::new(ChannelEndpoint::new(up_rx, up_send_tx));
        let downstream = Arc::new(FailingSendEndpoint::<String>::new(
            down_rx,
            "starter failed before streaming",
        ));

        let binding = TransportBinding {
            session: SessionId {
                thread_id: "ingress-error".to_string(),
            },
            caps: TransportCapabilities::default(),
            upstream,
            downstream,
        };

        up_tx.send(7).unwrap();
        drop(up_tx);
        drop(down_tx);

        let result = relay_binding(binding, RelayCancellation::new()).await;
        assert!(matches!(
            result,
            Err(TransportError::Io(message)) if message == "starter failed before streaming"
        ));
    }
}