1use super::AgentOs;
2use crate::contracts::ToolCallDecision;
3use crate::runtime::loop_runner::RunCancellationToken;
4use bytes::Bytes;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
8
9#[derive(Clone)]
11pub struct ThreadRunHandle {
12 thread_id: String,
13 run_id: String,
14 agent_id: String,
15 cancellation_token: RunCancellationToken,
16 decision_tx: Arc<RwLock<Option<mpsc::UnboundedSender<ToolCallDecision>>>>,
17 pending_decisions: Arc<Mutex<Vec<ToolCallDecision>>>,
18 stream_fanout: Arc<RwLock<Option<broadcast::Sender<Bytes>>>>,
19}
20
21impl ThreadRunHandle {
22 pub fn thread_id(&self) -> &str {
23 &self.thread_id
24 }
25
26 pub fn run_id(&self) -> &str {
27 &self.run_id
28 }
29
30 pub fn cancellation_token(&self) -> RunCancellationToken {
31 self.cancellation_token.clone()
32 }
33
34 pub fn can_own(&self, agent_id: &str) -> bool {
35 self.agent_id == agent_id
36 }
37
38 pub async fn bind_stream_fanout(&self, fanout: broadcast::Sender<Bytes>) {
39 *self.stream_fanout.write().await = Some(fanout);
40 }
41
42 pub async fn subscribe_stream_fanout(&self) -> Option<broadcast::Receiver<Bytes>> {
43 let fanout = self.stream_fanout.read().await;
44 fanout.as_ref().map(|sender| sender.subscribe())
45 }
46
47 pub async fn bind_decision_tx(&self, tx: mpsc::UnboundedSender<ToolCallDecision>) -> bool {
48 let pending = {
49 let mut queued = self.pending_decisions.lock().await;
50 std::mem::take(&mut *queued)
51 };
52 for decision in &pending {
53 if tx.send(decision.clone()).is_err() {
54 return false;
55 }
56 }
57 *self.decision_tx.write().await = Some(tx);
58 true
59 }
60
61 pub async fn send_decisions(&self, decisions: &[ToolCallDecision]) -> bool {
62 let decision_tx = {
63 let guard = self.decision_tx.read().await;
64 guard.clone()
65 };
66 let Some(decision_tx) = decision_tx else {
67 let mut queued = self.pending_decisions.lock().await;
68 queued.extend_from_slice(decisions);
69 return true;
70 };
71 for decision in decisions {
72 if decision_tx.send(decision.clone()).is_err() {
73 return false;
74 }
75 }
76 true
77 }
78
79 pub fn cancel(&self) -> bool {
80 self.cancellation_token.cancel();
81 true
82 }
83}
84
/// Acknowledgement that decisions were accepted for an active run.
///
/// "Accepted" means either sent on the run's decision channel or queued until one is bound.
#[derive(Clone, Debug)]
pub struct ForwardedDecision {
    /// Thread whose active run accepted the decisions.
    pub thread_id: String,
}
89
90#[derive(Default)]
91pub(crate) struct ActiveThreadRunRegistry {
92 handles: RwLock<HashMap<String, ThreadRunHandle>>,
93 run_index: RwLock<HashMap<String, String>>,
94}
95
96impl ActiveThreadRunRegistry {
97 async fn register(
98 &self,
99 run_id: String,
100 agent_id: &str,
101 thread_id: &str,
102 token: RunCancellationToken,
103 ) {
104 let mut run_index = self.run_index.write().await;
105 let mut handles = self.handles.write().await;
106 if let Some(old) = handles.get(thread_id) {
107 run_index.remove(&old.run_id);
108 }
109 run_index.insert(run_id.clone(), thread_id.to_string());
110 handles.insert(
111 thread_id.to_string(),
112 ThreadRunHandle {
113 thread_id: thread_id.to_string(),
114 run_id,
115 agent_id: agent_id.to_string(),
116 cancellation_token: token,
117 decision_tx: Arc::new(RwLock::new(None)),
118 pending_decisions: Arc::new(Mutex::new(Vec::new())),
119 stream_fanout: Arc::new(RwLock::new(None)),
120 },
121 );
122 }
123
124 async fn handle_by_thread(&self, thread_id: &str) -> Option<ThreadRunHandle> {
125 let handles = self.handles.read().await;
126 let handle = handles.get(thread_id)?;
127 Some(handle.clone())
128 }
129
130 async fn handle_by_run_id(&self, run_id: &str) -> Option<ThreadRunHandle> {
131 let run_index = self.run_index.read().await;
132 let thread_id = run_index.get(run_id)?;
133 let handles = self.handles.read().await;
134 let handle = handles.get(thread_id)?;
135 Some(handle.clone())
136 }
137
138 pub(super) async fn remove_by_run_id(&self, run_id: &str) {
139 if let Some(thread_id) = self.run_index.write().await.remove(run_id) {
140 self.handles.write().await.remove(&thread_id);
141 }
142 }
143}
144
145impl AgentOs {
146 pub(crate) async fn register_thread_run_handle(
147 &self,
148 run_id: String,
149 agent_id: &str,
150 thread_id: &str,
151 token: RunCancellationToken,
152 ) {
153 self.active_runs
154 .register(run_id, agent_id, thread_id, token)
155 .await;
156 }
157
158 pub(crate) async fn bind_thread_run_decision_tx(
159 &self,
160 run_id: &str,
161 tx: mpsc::UnboundedSender<ToolCallDecision>,
162 ) -> bool {
163 let Some(handle) = self.active_runs.handle_by_run_id(run_id).await else {
164 return false;
165 };
166 handle.bind_decision_tx(tx).await
167 }
168
169 pub(crate) async fn remove_thread_run_handle(&self, run_id: &str) {
170 self.active_runs.remove_by_run_id(run_id).await;
171 }
172
173 pub async fn bind_thread_run_stream_fanout(
174 &self,
175 run_id: &str,
176 fanout: broadcast::Sender<Bytes>,
177 ) -> bool {
178 let Some(handle) = self.active_runs.handle_by_run_id(run_id).await else {
179 return false;
180 };
181 handle.bind_stream_fanout(fanout).await;
182 true
183 }
184
185 pub async fn subscribe_thread_run_stream(
186 &self,
187 run_id: &str,
188 ) -> Option<broadcast::Receiver<Bytes>> {
189 let handle = self.active_runs.handle_by_run_id(run_id).await?;
190 handle.subscribe_stream_fanout().await
191 }
192
193 pub(crate) async fn active_thread_run_by_run_id(
194 &self,
195 run_id: &str,
196 ) -> Option<ThreadRunHandle> {
197 self.active_runs.handle_by_run_id(run_id).await
198 }
199
200 pub async fn active_run_id_for_thread(
201 &self,
202 agent_id: &str,
203 thread_id: &str,
204 ) -> Option<String> {
205 let handle = self.active_runs.handle_by_thread(thread_id).await?;
206 if !handle.can_own(agent_id) {
207 return None;
208 }
209 Some(handle.run_id().to_string())
210 }
211
212 pub async fn forward_decisions_by_thread(
213 &self,
214 agent_id: &str,
215 thread_id: &str,
216 decisions: &[ToolCallDecision],
217 ) -> Option<ForwardedDecision> {
218 let handle = self.active_runs.handle_by_thread(thread_id).await?;
219 if !handle.can_own(agent_id) {
220 return None;
221 }
222 if handle.send_decisions(decisions).await {
223 Some(ForwardedDecision {
224 thread_id: handle.thread_id().to_string(),
225 })
226 } else {
227 self.active_runs.remove_by_run_id(handle.run_id()).await;
228 None
229 }
230 }
231
232 pub async fn forward_decisions_by_run_id(
233 &self,
234 run_id: &str,
235 decisions: &[ToolCallDecision],
236 ) -> Option<ForwardedDecision> {
237 let handle = self.active_runs.handle_by_run_id(run_id).await?;
238 if handle.send_decisions(decisions).await {
239 Some(ForwardedDecision {
240 thread_id: handle.thread_id().to_string(),
241 })
242 } else {
243 self.active_runs.remove_by_run_id(run_id).await;
244 None
245 }
246 }
247
248 pub async fn cancel_active_run_by_id(&self, run_id: &str) -> bool {
249 let Some(handle) = self.active_runs.handle_by_run_id(run_id).await else {
250 return false;
251 };
252 if !handle.cancel() {
253 self.active_runs.remove_by_run_id(run_id).await;
254 }
255 true
256 }
257
258 pub async fn cancel_active_run_by_thread(&self, thread_id: &str) -> Option<String> {
259 let handle = self.active_runs.handle_by_thread(thread_id).await?;
260 let run_id = handle.run_id().to_string();
261 if !handle.cancel() {
262 self.active_runs.remove_by_run_id(&run_id).await;
263 }
264 Some(run_id)
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    fn new_registry() -> ActiveThreadRunRegistry {
        ActiveThreadRunRegistry::default()
    }

    #[tokio::test]
    async fn registry_register_and_lookup_by_thread() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token)
            .await;

        let handle = reg
            .handle_by_thread("thread-1")
            .await
            .expect("handle should exist");
        assert_eq!(handle.run_id(), "run-1");
        assert!(handle.can_own("agent-a"));
    }

    #[tokio::test]
    async fn registry_lookup_by_run_id() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token.clone())
            .await;

        let retrieved = reg
            .handle_by_run_id("run-1")
            .await
            .expect("handle should exist");
        let retrieved_token = retrieved.cancellation_token();
        retrieved_token.cancel();
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn registry_thread_key_enforces_single_run_per_thread() {
        let reg = new_registry();
        let token_a = RunCancellationToken::new();
        let token_b = RunCancellationToken::new();
        reg.register("run-a".into(), "agent-a", "shared-thread", token_a)
            .await;
        reg.register("run-b".into(), "agent-b", "shared-thread", token_b)
            .await;

        let handle = reg
            .handle_by_thread("shared-thread")
            .await
            .expect("handle should exist");
        assert_eq!(handle.run_id(), "run-b");
        assert!(handle.can_own("agent-b"));
        assert!(reg.handle_by_run_id("run-a").await.is_none());
        assert!(reg.handle_by_run_id("run-b").await.is_some());
    }

    #[tokio::test]
    async fn registry_remove_cleans_both_indexes() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token)
            .await;

        reg.remove_by_run_id("run-1").await;
        assert!(reg.handle_by_thread("thread-1").await.is_none());
        assert!(reg.handle_by_run_id("run-1").await.is_none());
    }

    #[tokio::test]
    async fn registry_remove_of_replaced_run_keeps_current_handle() {
        let reg = new_registry();
        reg.register("run-a".into(), "agent-a", "thread-1", RunCancellationToken::new())
            .await;
        reg.register("run-b".into(), "agent-a", "thread-1", RunCancellationToken::new())
            .await;

        // Removing the superseded run must not evict the run that replaced it.
        reg.remove_by_run_id("run-a").await;

        let handle = reg
            .handle_by_thread("thread-1")
            .await
            .expect("replacement handle should survive");
        assert_eq!(handle.run_id(), "run-b");
        assert!(reg.handle_by_run_id("run-b").await.is_some());
    }

    #[tokio::test]
    async fn registry_remove_unknown_run_is_noop() {
        let reg = new_registry();
        reg.register("run-1".into(), "agent-a", "thread-1", RunCancellationToken::new())
            .await;

        reg.remove_by_run_id("does-not-exist").await;
        assert!(reg.handle_by_thread("thread-1").await.is_some());
        assert!(reg.handle_by_run_id("run-1").await.is_some());
    }

    #[tokio::test]
    async fn stream_fanout_subscribe_before_bind_returns_none() {
        let reg = new_registry();
        reg.register("run-1".into(), "agent-a", "thread-1", RunCancellationToken::new())
            .await;
        let handle = reg
            .handle_by_run_id("run-1")
            .await
            .expect("handle should exist");
        assert!(handle.subscribe_stream_fanout().await.is_none());
    }

    #[tokio::test]
    async fn stream_fanout_bind_and_subscribe() {
        let reg = new_registry();
        let token = RunCancellationToken::new();
        reg.register("run-1".into(), "agent-a", "thread-1", token)
            .await;
        let handle = reg
            .handle_by_run_id("run-1")
            .await
            .expect("handle should exist");
        let (fanout, _rx) = broadcast::channel::<Bytes>(8);
        handle.bind_stream_fanout(fanout.clone()).await;
        let mut sub = handle
            .subscribe_stream_fanout()
            .await
            .expect("subscription should exist");
        fanout
            .send(Bytes::from_static(b"chunk"))
            .expect("send should work");
        let got = sub.recv().await.expect("subscriber should receive chunk");
        assert_eq!(got, Bytes::from_static(b"chunk"));
    }
}