tirea_agentos_server/transport/
runtime_endpoint.rs

1//! RuntimeEndpoint: message-driven Endpoint<AgentEvent, RuntimeInput>.
2//!
3//! The endpoint lifecycle is fully driven by [`RuntimeInput`] messages:
4//!
5//! 1. `Run(request)` — starts execution via the injected run factory.
6//! 2. `Decision(d)` / `Cancel` — control messages managed by AgentOS
7//!    [`ThreadRunHandle`], not by this endpoint.
8//!
9//! `close()` is transport-only and does **not** cancel the run.
10
11use std::future::Future;
12use std::pin::Pin;
13
14use async_trait::async_trait;
15use futures::StreamExt;
16use tirea_agentos::contracts::{AgentEvent, RunRequest};
17use tirea_agentos::runtime::RunStream;
18use tirea_contract::RuntimeInput;
19use tokio::sync::{mpsc, Mutex};
20
21use crate::transport::{BoxStream, Endpoint, TransportError};
22
23const DEFAULT_EVENT_BUFFER: usize = 64;
24
25/// Result produced by a run starter.
26type RunStartResult = Result<RunStream, TransportError>;
27
28/// Async factory that prepares and executes a run from a [`RunRequest`].
29///
30/// Created by protocol handlers; captures `AgentOs`, resolved agent config,
31/// and any protocol-specific state needed for run preparation.
32pub type RunStarter =
33    Box<dyn FnOnce(RunRequest) -> Pin<Box<dyn Future<Output = RunStartResult> + Send>> + Send>;
34
35/// Message-driven runtime endpoint.
36///
37/// Implements `Endpoint<AgentEvent, RuntimeInput>`. The run is started
38/// lazily when the first `RuntimeInput::Run` message arrives.
39pub struct RuntimeEndpoint {
40    event_tx: Mutex<Option<mpsc::Sender<AgentEvent>>>,
41    event_rx: Mutex<Option<mpsc::Receiver<AgentEvent>>>,
42    run_starter: Mutex<Option<RunStarter>>,
43}
44
45impl RuntimeEndpoint {
46    /// Create with a run factory that will be invoked on the first `Run` message.
47    pub fn new(starter: RunStarter) -> Self {
48        Self::with_buffer(starter, DEFAULT_EVENT_BUFFER)
49    }
50
51    /// Create with a run factory and explicit event buffer size.
52    pub fn with_buffer(starter: RunStarter, buffer: usize) -> Self {
53        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(buffer.max(1));
54        Self {
55            event_tx: Mutex::new(Some(event_tx)),
56            event_rx: Mutex::new(Some(event_rx)),
57            run_starter: Mutex::new(Some(starter)),
58        }
59    }
60
61    /// Attach an already-started run (bypasses the `Run` message).
62    ///
63    /// Useful for tests or contexts where the run was prepared externally.
64    pub fn from_run_stream(run: RunStream) -> Self {
65        Self::from_run_stream_with_buffer(run, DEFAULT_EVENT_BUFFER)
66    }
67
68    /// Attach an already-started run with explicit buffer size.
69    pub fn from_run_stream_with_buffer(run: RunStream, buffer: usize) -> Self {
70        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(buffer.max(1));
71
72        Self::spawn_event_pump(event_tx, run);
73
74        Self {
75            event_tx: Mutex::new(None),
76            event_rx: Mutex::new(Some(event_rx)),
77            run_starter: Mutex::new(None),
78        }
79    }
80
81    /// Start the run from a `RunRequest` using the stored factory.
82    async fn start_run(&self, request: RunRequest) -> Result<(), TransportError> {
83        let starter = self
84            .run_starter
85            .lock()
86            .await
87            .take()
88            .ok_or_else(|| TransportError::Internal("run already started".into()))?;
89
90        let event_tx = self
91            .event_tx
92            .lock()
93            .await
94            .take()
95            .ok_or_else(|| TransportError::Internal("event pump already started".into()))?;
96
97        let run = starter(request).await?;
98
99        Self::spawn_event_pump(event_tx, run);
100
101        Ok(())
102    }
103
104    fn spawn_event_pump(event_tx: mpsc::Sender<AgentEvent>, run: RunStream) {
105        tokio::spawn(async move {
106            let mut events = run.events;
107            while let Some(e) = events.next().await {
108                if event_tx.send(e).await.is_err() {
109                    break;
110                }
111            }
112            // event_tx is dropped here, closing the channel
113        });
114    }
115}
116
117#[async_trait]
118impl Endpoint<AgentEvent, RuntimeInput> for RuntimeEndpoint {
119    async fn recv(&self) -> Result<BoxStream<AgentEvent>, TransportError> {
120        let mut guard = self.event_rx.lock().await;
121        let mut rx = guard.take().ok_or(TransportError::Closed)?;
122        let stream = async_stream::stream! {
123            while let Some(item) = rx.recv().await {
124                yield Ok(item);
125            }
126        };
127        Ok(Box::pin(stream))
128    }
129
130    async fn send(&self, item: RuntimeInput) -> Result<(), TransportError> {
131        match item {
132            RuntimeInput::Run(request) => self.start_run(request).await,
133            RuntimeInput::Decision(_) => Err(TransportError::Internal(
134                "decision ingress must be handled by AgentOS ThreadRunHandle".into(),
135            )),
136            RuntimeInput::Cancel => Err(TransportError::Internal(
137                "cancel ingress must be handled by AgentOS ThreadRunHandle".into(),
138            )),
139        }
140    }
141
142    /// Transport-level close. Does **not** cancel the run.
143    async fn close(&self) -> Result<(), TransportError> {
144        Ok(())
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use tirea_agentos::contracts::{AgentEvent, ToolCallDecision};
    use tirea_contract::RunOrigin;

    /// Minimal request for agent `"test"` with every optional field unset.
    fn test_run_request() -> RunRequest {
        RunRequest {
            agent_id: "test".into(),
            thread_id: None,
            run_id: None,
            parent_run_id: None,
            parent_thread_id: None,
            resource_id: None,
            origin: RunOrigin::default(),
            state: None,
            messages: vec![],
            initial_decisions: vec![],
            source_mailbox_entry_id: None,
        }
    }

    /// Shorthand for a `TextDelta` event carrying `text`.
    fn text_delta(text: &'static str) -> AgentEvent {
        AgentEvent::TextDelta { delta: text.into() }
    }

    /// A `Decision` input resuming tool call `"tc1"` with a null payload.
    fn decision_input() -> RuntimeInput {
        RuntimeInput::Decision(ToolCallDecision::resume(
            "tc1",
            serde_json::Value::Null,
            0,
        ))
    }

    /// Build a `RunStream` (thread `"t1"`, run `"r1"`) whose event stream
    /// yields `events` in order and then ends.
    fn fake_run(events: Vec<AgentEvent>) -> RunStream {
        let (decision_tx, _decision_rx) = mpsc::unbounded_channel();
        let (feed_tx, mut feed_rx) = mpsc::channel::<AgentEvent>(64);

        // Feed the events from a background task; send errors are ignored.
        tokio::spawn(async move {
            for event in events {
                let _ = feed_tx.send(event).await;
            }
        });

        let events: Pin<Box<dyn futures::Stream<Item = AgentEvent> + Send>> =
            Box::pin(async_stream::stream! {
                while let Some(event) = feed_rx.recv().await {
                    yield event;
                }
            });

        RunStream {
            thread_id: String::from("t1"),
            run_id: String::from("r1"),
            decision_tx,
            events,
        }
    }

    /// Starter that ignores its request and resolves to `fake_run(events)`.
    fn fake_starter(events: Vec<AgentEvent>) -> RunStarter {
        let prepared = fake_run(events);
        Box::new(move |_request| Box::pin(async move { Ok(prepared) }))
    }

    // ── from_run_stream tests ───────────────────────────────────────

    #[tokio::test]
    async fn from_run_stream_recv_delivers_events() {
        let ep = RuntimeEndpoint::from_run_stream(fake_run(vec![
            text_delta("a"),
            text_delta("b"),
        ]));

        let stream = ep.recv().await.unwrap();
        let received: Vec<AgentEvent> = stream.map(Result::unwrap).collect().await;

        assert_eq!(received.len(), 2);
    }

    #[tokio::test]
    async fn from_run_stream_decision_is_rejected() {
        let ep = RuntimeEndpoint::from_run_stream(fake_run(vec![]));

        let outcome = ep.send(decision_input()).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn from_run_stream_close_does_not_cancel() {
        let ep = RuntimeEndpoint::from_run_stream(fake_run(vec![]));

        ep.close().await.unwrap();
    }

    // ── run starter tests ───────────────────────────────────────────

    #[tokio::test]
    async fn run_message_starts_execution() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![text_delta("x")]));
        let stream = ep.recv().await.unwrap();

        // The factory only runs once a Run message is sent.
        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();

        let received: Vec<AgentEvent> = stream.map(Result::unwrap).collect().await;
        assert_eq!(received.len(), 1);
    }

    #[tokio::test]
    async fn decision_after_run_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));
        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();

        let outcome = ep.send(decision_input()).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn decision_before_run_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));

        let outcome = ep.send(decision_input()).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn double_run_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));
        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();

        let second = ep.send(RuntimeInput::Run(test_run_request())).await;

        assert!(second.is_err());
    }

    #[tokio::test]
    async fn cancel_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));
        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();

        let outcome = ep.send(RuntimeInput::Cancel).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn recv_called_twice_returns_closed() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));

        let _first = ep.recv().await.unwrap();

        assert!(matches!(ep.recv().await, Err(TransportError::Closed)));
    }
}