tirea_agentos_server/transport/
runtime_endpoint.rs1use std::future::Future;
12use std::pin::Pin;
13
14use async_trait::async_trait;
15use futures::StreamExt;
16use tirea_agentos::contracts::{AgentEvent, RunRequest};
17use tirea_agentos::runtime::RunStream;
18use tirea_contract::RuntimeInput;
19use tokio::sync::{mpsc, Mutex};
20
21use crate::transport::{BoxStream, Endpoint, TransportError};
22
/// Capacity used for the bounded event channel when the caller does not pick one
/// (see `RuntimeEndpoint::new` and `RuntimeEndpoint::from_run_stream`).
const DEFAULT_EVENT_BUFFER: usize = 64;
24
25type RunStartResult = Result<RunStream, TransportError>;
27
28pub type RunStarter =
33 Box<dyn FnOnce(RunRequest) -> Pin<Box<dyn Future<Output = RunStartResult> + Send>> + Send>;
34
35pub struct RuntimeEndpoint {
40 event_tx: Mutex<Option<mpsc::Sender<AgentEvent>>>,
41 event_rx: Mutex<Option<mpsc::Receiver<AgentEvent>>>,
42 run_starter: Mutex<Option<RunStarter>>,
43}
44
45impl RuntimeEndpoint {
46 pub fn new(starter: RunStarter) -> Self {
48 Self::with_buffer(starter, DEFAULT_EVENT_BUFFER)
49 }
50
51 pub fn with_buffer(starter: RunStarter, buffer: usize) -> Self {
53 let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(buffer.max(1));
54 Self {
55 event_tx: Mutex::new(Some(event_tx)),
56 event_rx: Mutex::new(Some(event_rx)),
57 run_starter: Mutex::new(Some(starter)),
58 }
59 }
60
61 pub fn from_run_stream(run: RunStream) -> Self {
65 Self::from_run_stream_with_buffer(run, DEFAULT_EVENT_BUFFER)
66 }
67
68 pub fn from_run_stream_with_buffer(run: RunStream, buffer: usize) -> Self {
70 let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(buffer.max(1));
71
72 Self::spawn_event_pump(event_tx, run);
73
74 Self {
75 event_tx: Mutex::new(None),
76 event_rx: Mutex::new(Some(event_rx)),
77 run_starter: Mutex::new(None),
78 }
79 }
80
81 async fn start_run(&self, request: RunRequest) -> Result<(), TransportError> {
83 let starter = self
84 .run_starter
85 .lock()
86 .await
87 .take()
88 .ok_or_else(|| TransportError::Internal("run already started".into()))?;
89
90 let event_tx = self
91 .event_tx
92 .lock()
93 .await
94 .take()
95 .ok_or_else(|| TransportError::Internal("event pump already started".into()))?;
96
97 let run = starter(request).await?;
98
99 Self::spawn_event_pump(event_tx, run);
100
101 Ok(())
102 }
103
104 fn spawn_event_pump(event_tx: mpsc::Sender<AgentEvent>, run: RunStream) {
105 tokio::spawn(async move {
106 let mut events = run.events;
107 while let Some(e) = events.next().await {
108 if event_tx.send(e).await.is_err() {
109 break;
110 }
111 }
112 });
114 }
115}
116
117#[async_trait]
118impl Endpoint<AgentEvent, RuntimeInput> for RuntimeEndpoint {
119 async fn recv(&self) -> Result<BoxStream<AgentEvent>, TransportError> {
120 let mut guard = self.event_rx.lock().await;
121 let mut rx = guard.take().ok_or(TransportError::Closed)?;
122 let stream = async_stream::stream! {
123 while let Some(item) = rx.recv().await {
124 yield Ok(item);
125 }
126 };
127 Ok(Box::pin(stream))
128 }
129
130 async fn send(&self, item: RuntimeInput) -> Result<(), TransportError> {
131 match item {
132 RuntimeInput::Run(request) => self.start_run(request).await,
133 RuntimeInput::Decision(_) => Err(TransportError::Internal(
134 "decision ingress must be handled by AgentOS ThreadRunHandle".into(),
135 )),
136 RuntimeInput::Cancel => Err(TransportError::Internal(
137 "cancel ingress must be handled by AgentOS ThreadRunHandle".into(),
138 )),
139 }
140 }
141
142 async fn close(&self) -> Result<(), TransportError> {
144 Ok(())
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use tirea_agentos::contracts::AgentEvent;
    use tirea_agentos::contracts::ToolCallDecision;
    use tirea_contract::RunOrigin;

    /// Minimal run request for agent "test" with every optional field unset.
    fn test_run_request() -> RunRequest {
        RunRequest {
            agent_id: "test".into(),
            thread_id: None,
            run_id: None,
            parent_run_id: None,
            parent_thread_id: None,
            resource_id: None,
            origin: RunOrigin::default(),
            state: None,
            messages: vec![],
            initial_decisions: vec![],
            source_mailbox_entry_id: None,
        }
    }

    /// Shorthand for a `TextDelta` event carrying `delta`.
    fn text(delta: &'static str) -> AgentEvent {
        AgentEvent::TextDelta {
            delta: delta.into(),
        }
    }

    /// Builds a `RunStream` that emits `events` in order and then ends.
    fn fake_run(events: Vec<AgentEvent>) -> RunStream {
        let (decision_tx, _decision_rx) = mpsc::unbounded_channel();
        let (source_tx, source_rx) = mpsc::channel::<AgentEvent>(64);

        // Feed the events from a background task; its sender is dropped when the
        // task finishes, which ends the stream below.
        tokio::spawn(async move {
            for event in events {
                let _ = source_tx.send(event).await;
            }
        });

        let events: Pin<Box<dyn futures::Stream<Item = AgentEvent> + Send>> =
            Box::pin(async_stream::stream! {
                let mut rx = source_rx;
                while let Some(item) = rx.recv().await {
                    yield item;
                }
            });

        RunStream {
            thread_id: "t1".to_string(),
            run_id: "r1".to_string(),
            decision_tx,
            events,
        }
    }

    /// Starter that ignores its request and resolves to a fake run emitting `events`.
    fn fake_starter(events: Vec<AgentEvent>) -> RunStarter {
        let run = fake_run(events);
        let starter: RunStarter = Box::new(move |_request| Box::pin(async move { Ok(run) }));
        starter
    }

    /// An endpoint built from a run delivers all of that run's events.
    #[tokio::test]
    async fn from_run_stream_recv_delivers_events() {
        let ep = RuntimeEndpoint::from_run_stream(fake_run(vec![text("a"), text("b")]));
        let stream = ep.recv().await.unwrap();
        let items: Vec<AgentEvent> = stream.map(|r| r.unwrap()).collect().await;
        assert_eq!(items.len(), 2);
    }

    /// Sending a decision to an endpoint built from a run is an error.
    #[tokio::test]
    async fn from_run_stream_decision_is_rejected() {
        let ep = RuntimeEndpoint::from_run_stream(fake_run(vec![]));
        let d = ToolCallDecision::resume("tc1", serde_json::Value::Null, 0);
        assert!(ep.send(RuntimeInput::Decision(d)).await.is_err());
    }

    /// `close` on an endpoint built from a run succeeds.
    #[tokio::test]
    async fn from_run_stream_close_does_not_cancel() {
        let ep = RuntimeEndpoint::from_run_stream(fake_run(vec![]));
        ep.close().await.unwrap();
    }

    /// Sending `Run` starts the run and its events reach the stream taken earlier.
    #[tokio::test]
    async fn run_message_starts_execution() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![text("x")]));
        let stream = ep.recv().await.unwrap();

        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();

        let items: Vec<AgentEvent> = stream.map(|r| r.unwrap()).collect().await;
        assert_eq!(items.len(), 1);
    }

    /// A decision sent after the run has started is an error.
    #[tokio::test]
    async fn decision_after_run_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));

        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();

        let d = ToolCallDecision::resume("tc1", serde_json::Value::Null, 0);
        assert!(ep.send(RuntimeInput::Decision(d)).await.is_err());
    }

    /// A decision sent before any run has started is an error.
    #[tokio::test]
    async fn decision_before_run_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));
        let d = ToolCallDecision::resume("tc1", serde_json::Value::Null, 0);
        assert!(ep.send(RuntimeInput::Decision(d)).await.is_err());
    }

    /// A second `Run` on the same endpoint is an error.
    #[tokio::test]
    async fn double_run_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));
        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();
        assert!(ep.send(RuntimeInput::Run(test_run_request())).await.is_err());
    }

    /// `Cancel` sent after the run has started is an error.
    #[tokio::test]
    async fn cancel_returns_error() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));

        ep.send(RuntimeInput::Run(test_run_request()))
            .await
            .unwrap();
        assert!(ep.send(RuntimeInput::Cancel).await.is_err());
    }

    /// Only the first `recv` yields a stream; the second reports `Closed`.
    #[tokio::test]
    async fn recv_called_twice_returns_closed() {
        let ep = RuntimeEndpoint::new(fake_starter(vec![]));
        let _first = ep.recv().await.unwrap();
        assert!(matches!(ep.recv().await, Err(TransportError::Closed)));
    }
}