tirea_agentos/composition/registry/
tool.rs

1use super::sorted_registry_ids;
2use super::traits::{ToolRegistry, ToolRegistryError};
3use crate::contracts::runtime::tool_call::Tool;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7#[derive(Clone, Default)]
8pub struct InMemoryToolRegistry {
9    tools: HashMap<String, Arc<dyn Tool>>,
10}
11
12impl std::fmt::Debug for InMemoryToolRegistry {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        f.debug_struct("InMemoryToolRegistry")
15            .field("len", &self.tools.len())
16            .finish()
17    }
18}
19
20impl InMemoryToolRegistry {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    pub fn len(&self) -> usize {
26        self.tools.len()
27    }
28
29    pub fn is_empty(&self) -> bool {
30        self.tools.is_empty()
31    }
32
33    pub fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
34        self.tools.get(id).cloned()
35    }
36
37    pub fn ids(&self) -> impl Iterator<Item = &String> {
38        self.tools.keys()
39    }
40
41    pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<(), ToolRegistryError> {
42        let id = tool.descriptor().id;
43        if self.tools.contains_key(&id) {
44            return Err(ToolRegistryError::ToolIdConflict(id));
45        }
46        self.tools.insert(id, tool);
47        Ok(())
48    }
49
50    pub fn register_named(
51        &mut self,
52        id: impl Into<String>,
53        tool: Arc<dyn Tool>,
54    ) -> Result<(), ToolRegistryError> {
55        let key = id.into();
56        let descriptor_id = tool.descriptor().id;
57        if key != descriptor_id {
58            return Err(ToolRegistryError::ToolIdMismatch { key, descriptor_id });
59        }
60        if self.tools.contains_key(&key) {
61            return Err(ToolRegistryError::ToolIdConflict(key));
62        }
63        self.tools.insert(key, tool);
64        Ok(())
65    }
66
67    pub fn extend_named(
68        &mut self,
69        tools: HashMap<String, Arc<dyn Tool>>,
70    ) -> Result<(), ToolRegistryError> {
71        for (key, tool) in tools {
72            self.register_named(key, tool)?;
73        }
74        Ok(())
75    }
76
77    pub fn extend_registry(&mut self, other: &dyn ToolRegistry) -> Result<(), ToolRegistryError> {
78        self.extend_named(other.snapshot())
79    }
80
81    pub fn merge_many(
82        regs: impl IntoIterator<Item = InMemoryToolRegistry>,
83    ) -> Result<InMemoryToolRegistry, ToolRegistryError> {
84        let mut out = InMemoryToolRegistry::new();
85        for r in regs {
86            out.extend_named(r.into_map())?;
87        }
88        Ok(out)
89    }
90
91    pub fn into_map(self) -> HashMap<String, Arc<dyn Tool>> {
92        self.tools
93    }
94
95    pub fn to_map(&self) -> HashMap<String, Arc<dyn Tool>> {
96        self.tools.clone()
97    }
98}
99
100impl ToolRegistry for InMemoryToolRegistry {
101    fn len(&self) -> usize {
102        self.len()
103    }
104
105    fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
106        self.get(id)
107    }
108
109    fn ids(&self) -> Vec<String> {
110        sorted_registry_ids(&self.tools)
111    }
112
113    fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
114        self.tools.clone()
115    }
116}
117
118#[derive(Clone, Default)]
119pub struct CompositeToolRegistry {
120    registries: Vec<Arc<dyn ToolRegistry>>,
121    cached_snapshot: Arc<RwLock<HashMap<String, Arc<dyn Tool>>>>,
122}
123
124impl std::fmt::Debug for CompositeToolRegistry {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        let snapshot = match self.cached_snapshot.read() {
127            Ok(guard) => guard,
128            Err(poisoned) => poisoned.into_inner(),
129        };
130        f.debug_struct("CompositeToolRegistry")
131            .field("registries", &self.registries.len())
132            .field("len", &snapshot.len())
133            .finish()
134    }
135}
136
137impl CompositeToolRegistry {
138    pub fn try_new(
139        regs: impl IntoIterator<Item = Arc<dyn ToolRegistry>>,
140    ) -> Result<Self, ToolRegistryError> {
141        let registries: Vec<Arc<dyn ToolRegistry>> = regs.into_iter().collect();
142        let merged = Self::merge_snapshots(&registries)?;
143        Ok(Self {
144            registries,
145            cached_snapshot: Arc::new(RwLock::new(merged)),
146        })
147    }
148
149    fn merge_snapshots(
150        registries: &[Arc<dyn ToolRegistry>],
151    ) -> Result<HashMap<String, Arc<dyn Tool>>, ToolRegistryError> {
152        let mut merged = InMemoryToolRegistry::new();
153        for reg in registries {
154            merged.extend_registry(reg.as_ref())?;
155        }
156        Ok(merged.into_map())
157    }
158
159    fn refresh_snapshot(&self) -> Result<HashMap<String, Arc<dyn Tool>>, ToolRegistryError> {
160        Self::merge_snapshots(&self.registries)
161    }
162
163    fn read_cached_snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
164        match self.cached_snapshot.read() {
165            Ok(guard) => guard.clone(),
166            Err(poisoned) => poisoned.into_inner().clone(),
167        }
168    }
169
170    fn write_cached_snapshot(&self, snapshot: HashMap<String, Arc<dyn Tool>>) {
171        match self.cached_snapshot.write() {
172            Ok(mut guard) => *guard = snapshot,
173            Err(poisoned) => *poisoned.into_inner() = snapshot,
174        };
175    }
176}
177
178impl ToolRegistry for CompositeToolRegistry {
179    fn len(&self) -> usize {
180        self.snapshot().len()
181    }
182
183    fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
184        self.snapshot().get(id).cloned()
185    }
186
187    fn ids(&self) -> Vec<String> {
188        let snapshot = self.snapshot();
189        sorted_registry_ids(&snapshot)
190    }
191
192    fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
193        match self.refresh_snapshot() {
194            Ok(snapshot) => {
195                self.write_cached_snapshot(snapshot.clone());
196                snapshot
197            }
198            Err(_) => self.read_cached_snapshot(),
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError, ToolResult};
    use crate::contracts::ToolCallContext;
    use serde_json::json;

    /// Trivial tool used as registry content; its descriptor is built from one id.
    struct StaticTool {
        descriptor: ToolDescriptor,
    }

    impl StaticTool {
        fn new(id: &str) -> Self {
            let descriptor = ToolDescriptor::new(id, id, "test tool");
            Self { descriptor }
        }
    }

    #[async_trait::async_trait]
    impl Tool for StaticTool {
        fn descriptor(&self) -> ToolDescriptor {
            self.descriptor.clone()
        }

        async fn execute(
            &self,
            _args: serde_json::Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            let tool_id = self.descriptor.id.clone();
            let payload = json!({"ok": true});
            Ok(ToolResult::success(tool_id, payload))
        }
    }

    /// Registry whose contents can be swapped after construction. It stands in for a
    /// source that changes while a composite is reading from it.
    #[derive(Default)]
    struct MutableToolRegistry {
        tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    }

    impl MutableToolRegistry {
        /// Replaces the whole contents with one `StaticTool` per id.
        fn replace_ids(&self, ids: &[&str]) {
            let next: HashMap<String, Arc<dyn Tool>> = ids
                .iter()
                .map(|id| {
                    let tool: Arc<dyn Tool> = Arc::new(StaticTool::new(id));
                    (id.to_string(), tool)
                })
                .collect();
            let mut slot = self
                .tools
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            *slot = next;
        }
    }

    impl ToolRegistry for MutableToolRegistry {
        fn len(&self) -> usize {
            self.snapshot().len()
        }

        fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
            // `snapshot` returns a fresh copy, so taking the entry out of it is safe.
            self.snapshot().remove(id)
        }

        fn ids(&self) -> Vec<String> {
            let mut ids: Vec<String> = self.snapshot().into_iter().map(|(id, _)| id).collect();
            ids.sort_unstable();
            ids
        }

        fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
            self.tools
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .clone()
        }
    }

    /// Returns `true` if `wanted` appears in `ids`.
    fn has_id(ids: &[String], wanted: &str) -> bool {
        ids.iter().any(|id| id == wanted)
    }

    #[test]
    fn composite_tool_registry_reads_live_updates_from_source_registries() {
        let dynamic = Arc::new(MutableToolRegistry::default());
        dynamic.replace_ids(&["dynamic_a"]);

        let mut fixed = InMemoryToolRegistry::new();
        fixed
            .register_named("static_tool", Arc::new(StaticTool::new("static_tool")))
            .expect("register static tool");

        let composite = CompositeToolRegistry::try_new(vec![
            Arc::clone(&dynamic) as Arc<dyn ToolRegistry>,
            Arc::new(fixed) as Arc<dyn ToolRegistry>,
        ])
        .expect("compose registries");

        let before = composite.ids();
        assert!(has_id(&before, "dynamic_a"));
        assert!(has_id(&before, "static_tool"));

        // Grow the dynamic source after composition; the composite must see it.
        dynamic.replace_ids(&["dynamic_a", "dynamic_b"]);

        let after = composite.ids();
        for expected in &["dynamic_a", "dynamic_b", "static_tool"] {
            assert!(has_id(&after, expected));
        }
    }

    #[test]
    fn composite_tool_registry_keeps_last_good_snapshot_on_runtime_conflict() {
        let first = Arc::new(MutableToolRegistry::default());
        first.replace_ids(&["tool_a"]);

        let second = Arc::new(MutableToolRegistry::default());
        second.replace_ids(&["tool_b"]);

        let composite = CompositeToolRegistry::try_new(vec![
            Arc::clone(&first) as Arc<dyn ToolRegistry>,
            Arc::clone(&second) as Arc<dyn ToolRegistry>,
        ])
        .expect("compose registries");

        let last_good = composite.ids();
        assert_eq!(last_good, ["tool_a", "tool_b"]);

        // Make both sources expose `tool_a`. The live merge now fails, so the
        // composite should keep serving the previous successful merge.
        second.replace_ids(&["tool_a"]);

        assert_eq!(composite.ids(), last_good);
        assert!(composite.get("tool_b").is_some());
    }
}