tirea_agentos/runtime/
thread_run.rs

1use super::AgentOs;
2use crate::contracts::ToolCallDecision;
3use crate::runtime::loop_runner::RunCancellationToken;
4use bytes::Bytes;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
8
9/// Unified per-active-run handle managed by AgentOS.
10#[derive(Clone)]
11pub struct ThreadRunHandle {
12    thread_id: String,
13    run_id: String,
14    agent_id: String,
15    cancellation_token: RunCancellationToken,
16    decision_tx: Arc<RwLock<Option<mpsc::UnboundedSender<ToolCallDecision>>>>,
17    pending_decisions: Arc<Mutex<Vec<ToolCallDecision>>>,
18    stream_fanout: Arc<RwLock<Option<broadcast::Sender<Bytes>>>>,
19}
20
21impl ThreadRunHandle {
22    pub fn thread_id(&self) -> &str {
23        &self.thread_id
24    }
25
26    pub fn run_id(&self) -> &str {
27        &self.run_id
28    }
29
30    pub fn cancellation_token(&self) -> RunCancellationToken {
31        self.cancellation_token.clone()
32    }
33
34    pub fn can_own(&self, agent_id: &str) -> bool {
35        self.agent_id == agent_id
36    }
37
38    pub async fn bind_stream_fanout(&self, fanout: broadcast::Sender<Bytes>) {
39        *self.stream_fanout.write().await = Some(fanout);
40    }
41
42    pub async fn subscribe_stream_fanout(&self) -> Option<broadcast::Receiver<Bytes>> {
43        let fanout = self.stream_fanout.read().await;
44        fanout.as_ref().map(|sender| sender.subscribe())
45    }
46
47    pub async fn bind_decision_tx(&self, tx: mpsc::UnboundedSender<ToolCallDecision>) -> bool {
48        let pending = {
49            let mut queued = self.pending_decisions.lock().await;
50            std::mem::take(&mut *queued)
51        };
52        for decision in &pending {
53            if tx.send(decision.clone()).is_err() {
54                return false;
55            }
56        }
57        *self.decision_tx.write().await = Some(tx);
58        true
59    }
60
61    pub async fn send_decisions(&self, decisions: &[ToolCallDecision]) -> bool {
62        let decision_tx = {
63            let guard = self.decision_tx.read().await;
64            guard.clone()
65        };
66        let Some(decision_tx) = decision_tx else {
67            let mut queued = self.pending_decisions.lock().await;
68            queued.extend_from_slice(decisions);
69            return true;
70        };
71        for decision in decisions {
72            if decision_tx.send(decision.clone()).is_err() {
73                return false;
74            }
75        }
76        true
77    }
78
79    pub fn cancel(&self) -> bool {
80        self.cancellation_token.cancel();
81        true
82    }
83}
84
/// Receipt returned by the `forward_decisions_by_*` methods once decisions have
/// been accepted by an active run.
#[derive(Debug, Clone)]
pub struct ForwardedDecision {
    /// Id of the thread whose active run accepted the decisions.
    pub thread_id: String,
}
89
90#[derive(Default)]
91pub(crate) struct ActiveThreadRunRegistry {
92    handles: RwLock<HashMap<String, ThreadRunHandle>>,
93    run_index: RwLock<HashMap<String, String>>,
94}
95
96impl ActiveThreadRunRegistry {
97    async fn register(
98        &self,
99        run_id: String,
100        agent_id: &str,
101        thread_id: &str,
102        token: RunCancellationToken,
103    ) {
104        let mut run_index = self.run_index.write().await;
105        let mut handles = self.handles.write().await;
106        if let Some(old) = handles.get(thread_id) {
107            run_index.remove(&old.run_id);
108        }
109        run_index.insert(run_id.clone(), thread_id.to_string());
110        handles.insert(
111            thread_id.to_string(),
112            ThreadRunHandle {
113                thread_id: thread_id.to_string(),
114                run_id,
115                agent_id: agent_id.to_string(),
116                cancellation_token: token,
117                decision_tx: Arc::new(RwLock::new(None)),
118                pending_decisions: Arc::new(Mutex::new(Vec::new())),
119                stream_fanout: Arc::new(RwLock::new(None)),
120            },
121        );
122    }
123
124    async fn handle_by_thread(&self, thread_id: &str) -> Option<ThreadRunHandle> {
125        let handles = self.handles.read().await;
126        let handle = handles.get(thread_id)?;
127        Some(handle.clone())
128    }
129
130    async fn handle_by_run_id(&self, run_id: &str) -> Option<ThreadRunHandle> {
131        let run_index = self.run_index.read().await;
132        let thread_id = run_index.get(run_id)?;
133        let handles = self.handles.read().await;
134        let handle = handles.get(thread_id)?;
135        Some(handle.clone())
136    }
137
138    pub(super) async fn remove_by_run_id(&self, run_id: &str) {
139        if let Some(thread_id) = self.run_index.write().await.remove(run_id) {
140            self.handles.write().await.remove(&thread_id);
141        }
142    }
143}
144
145impl AgentOs {
146    pub(crate) async fn register_thread_run_handle(
147        &self,
148        run_id: String,
149        agent_id: &str,
150        thread_id: &str,
151        token: RunCancellationToken,
152    ) {
153        self.active_runs
154            .register(run_id, agent_id, thread_id, token)
155            .await;
156    }
157
158    pub(crate) async fn bind_thread_run_decision_tx(
159        &self,
160        run_id: &str,
161        tx: mpsc::UnboundedSender<ToolCallDecision>,
162    ) -> bool {
163        let Some(handle) = self.active_runs.handle_by_run_id(run_id).await else {
164            return false;
165        };
166        handle.bind_decision_tx(tx).await
167    }
168
169    pub(crate) async fn remove_thread_run_handle(&self, run_id: &str) {
170        self.active_runs.remove_by_run_id(run_id).await;
171    }
172
173    pub async fn bind_thread_run_stream_fanout(
174        &self,
175        run_id: &str,
176        fanout: broadcast::Sender<Bytes>,
177    ) -> bool {
178        let Some(handle) = self.active_runs.handle_by_run_id(run_id).await else {
179            return false;
180        };
181        handle.bind_stream_fanout(fanout).await;
182        true
183    }
184
185    pub async fn subscribe_thread_run_stream(
186        &self,
187        run_id: &str,
188    ) -> Option<broadcast::Receiver<Bytes>> {
189        let handle = self.active_runs.handle_by_run_id(run_id).await?;
190        handle.subscribe_stream_fanout().await
191    }
192
193    pub(crate) async fn active_thread_run_by_run_id(
194        &self,
195        run_id: &str,
196    ) -> Option<ThreadRunHandle> {
197        self.active_runs.handle_by_run_id(run_id).await
198    }
199
200    pub async fn active_run_id_for_thread(
201        &self,
202        agent_id: &str,
203        thread_id: &str,
204    ) -> Option<String> {
205        let handle = self.active_runs.handle_by_thread(thread_id).await?;
206        if !handle.can_own(agent_id) {
207            return None;
208        }
209        Some(handle.run_id().to_string())
210    }
211
212    pub async fn forward_decisions_by_thread(
213        &self,
214        agent_id: &str,
215        thread_id: &str,
216        decisions: &[ToolCallDecision],
217    ) -> Option<ForwardedDecision> {
218        let handle = self.active_runs.handle_by_thread(thread_id).await?;
219        if !handle.can_own(agent_id) {
220            return None;
221        }
222        if handle.send_decisions(decisions).await {
223            Some(ForwardedDecision {
224                thread_id: handle.thread_id().to_string(),
225            })
226        } else {
227            self.active_runs.remove_by_run_id(handle.run_id()).await;
228            None
229        }
230    }
231
232    pub async fn forward_decisions_by_run_id(
233        &self,
234        run_id: &str,
235        decisions: &[ToolCallDecision],
236    ) -> Option<ForwardedDecision> {
237        let handle = self.active_runs.handle_by_run_id(run_id).await?;
238        if handle.send_decisions(decisions).await {
239            Some(ForwardedDecision {
240                thread_id: handle.thread_id().to_string(),
241            })
242        } else {
243            self.active_runs.remove_by_run_id(run_id).await;
244            None
245        }
246    }
247
248    pub async fn cancel_active_run_by_id(&self, run_id: &str) -> bool {
249        let Some(handle) = self.active_runs.handle_by_run_id(run_id).await else {
250            return false;
251        };
252        if !handle.cancel() {
253            self.active_runs.remove_by_run_id(run_id).await;
254        }
255        true
256    }
257
258    pub async fn cancel_active_run_by_thread(&self, thread_id: &str) -> Option<String> {
259        let handle = self.active_runs.handle_by_thread(thread_id).await?;
260        let run_id = handle.run_id().to_string();
261        if !handle.cancel() {
262            self.active_runs.remove_by_run_id(&run_id).await;
263        }
264        Some(run_id)
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty registry for each test.
    fn new_registry() -> ActiveThreadRunRegistry {
        ActiveThreadRunRegistry::default()
    }

    /// A registered run can be found through its thread id.
    #[tokio::test]
    async fn registry_register_and_lookup_by_thread() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token)
            .await;

        let handle = reg
            .handle_by_thread("thread-1")
            .await
            .expect("handle should exist");
        assert_eq!(handle.run_id(), "run-1");
        assert!(handle.can_own("agent-a"));
    }

    /// The handle found by run id shares the cancellation token it was registered with.
    #[tokio::test]
    async fn registry_lookup_by_run_id() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token.clone())
            .await;

        let retrieved = reg
            .handle_by_run_id("run-1")
            .await
            .expect("handle should exist");
        let retrieved_token = retrieved.cancellation_token();
        retrieved_token.cancel();
        assert!(token.is_cancelled());
    }

    /// Registering a second run on the same thread replaces the first one.
    #[tokio::test]
    async fn registry_thread_key_enforces_single_run_per_thread() {
        let reg = new_registry();
        let token_a = RunCancellationToken::new();
        let token_b = RunCancellationToken::new();
        reg.register("run-a".into(), "agent-a", "shared-thread", token_a)
            .await;
        reg.register("run-b".into(), "agent-b", "shared-thread", token_b)
            .await;

        let handle = reg
            .handle_by_thread("shared-thread")
            .await
            .expect("handle should exist");
        assert_eq!(handle.run_id(), "run-b");
        assert!(handle.can_own("agent-b"));
        assert!(reg.handle_by_run_id("run-a").await.is_none());
        assert!(reg.handle_by_run_id("run-b").await.is_some());
    }

    /// Removing by run id clears both the thread map and the run index.
    #[tokio::test]
    async fn registry_remove_cleans_both_indexes() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token)
            .await;

        reg.remove_by_run_id("run-1").await;
        assert!(reg.handle_by_thread("thread-1").await.is_none());
        assert!(reg.handle_by_run_id("run-1").await.is_none());
    }

    /// A subscriber obtained through the handle receives chunks sent on the bound fanout.
    #[tokio::test]
    async fn stream_fanout_bind_and_subscribe() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token)
            .await;
        let handle = reg
            .handle_by_run_id("run-1")
            .await
            .expect("handle should exist");
        let (fanout, _rx) = broadcast::channel::<Bytes>(8);
        handle.bind_stream_fanout(fanout.clone()).await;
        let mut sub = handle
            .subscribe_stream_fanout()
            .await
            .expect("subscription should exist");
        fanout
            .send(Bytes::from_static(b"chunk"))
            .expect("send should work");
        let got = sub.recv().await.expect("subscriber should receive chunk");
        assert_eq!(got, Bytes::from_static(b"chunk"));
    }
}