tirea_agentos/composition/registry/
tool.rs1use super::sorted_registry_ids;
2use super::traits::{ToolRegistry, ToolRegistryError};
3use crate::contracts::runtime::tool_call::Tool;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7#[derive(Clone, Default)]
8pub struct InMemoryToolRegistry {
9 tools: HashMap<String, Arc<dyn Tool>>,
10}
11
12impl std::fmt::Debug for InMemoryToolRegistry {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 f.debug_struct("InMemoryToolRegistry")
15 .field("len", &self.tools.len())
16 .finish()
17 }
18}
19
20impl InMemoryToolRegistry {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 pub fn len(&self) -> usize {
26 self.tools.len()
27 }
28
29 pub fn is_empty(&self) -> bool {
30 self.tools.is_empty()
31 }
32
33 pub fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
34 self.tools.get(id).cloned()
35 }
36
37 pub fn ids(&self) -> impl Iterator<Item = &String> {
38 self.tools.keys()
39 }
40
41 pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<(), ToolRegistryError> {
42 let id = tool.descriptor().id;
43 if self.tools.contains_key(&id) {
44 return Err(ToolRegistryError::ToolIdConflict(id));
45 }
46 self.tools.insert(id, tool);
47 Ok(())
48 }
49
50 pub fn register_named(
51 &mut self,
52 id: impl Into<String>,
53 tool: Arc<dyn Tool>,
54 ) -> Result<(), ToolRegistryError> {
55 let key = id.into();
56 let descriptor_id = tool.descriptor().id;
57 if key != descriptor_id {
58 return Err(ToolRegistryError::ToolIdMismatch { key, descriptor_id });
59 }
60 if self.tools.contains_key(&key) {
61 return Err(ToolRegistryError::ToolIdConflict(key));
62 }
63 self.tools.insert(key, tool);
64 Ok(())
65 }
66
67 pub fn extend_named(
68 &mut self,
69 tools: HashMap<String, Arc<dyn Tool>>,
70 ) -> Result<(), ToolRegistryError> {
71 for (key, tool) in tools {
72 self.register_named(key, tool)?;
73 }
74 Ok(())
75 }
76
77 pub fn extend_registry(&mut self, other: &dyn ToolRegistry) -> Result<(), ToolRegistryError> {
78 self.extend_named(other.snapshot())
79 }
80
81 pub fn merge_many(
82 regs: impl IntoIterator<Item = InMemoryToolRegistry>,
83 ) -> Result<InMemoryToolRegistry, ToolRegistryError> {
84 let mut out = InMemoryToolRegistry::new();
85 for r in regs {
86 out.extend_named(r.into_map())?;
87 }
88 Ok(out)
89 }
90
91 pub fn into_map(self) -> HashMap<String, Arc<dyn Tool>> {
92 self.tools
93 }
94
95 pub fn to_map(&self) -> HashMap<String, Arc<dyn Tool>> {
96 self.tools.clone()
97 }
98}
99
100impl ToolRegistry for InMemoryToolRegistry {
101 fn len(&self) -> usize {
102 self.len()
103 }
104
105 fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
106 self.get(id)
107 }
108
109 fn ids(&self) -> Vec<String> {
110 sorted_registry_ids(&self.tools)
111 }
112
113 fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
114 self.tools.clone()
115 }
116}
117
118#[derive(Clone, Default)]
119pub struct CompositeToolRegistry {
120 registries: Vec<Arc<dyn ToolRegistry>>,
121 cached_snapshot: Arc<RwLock<HashMap<String, Arc<dyn Tool>>>>,
122}
123
124impl std::fmt::Debug for CompositeToolRegistry {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 let snapshot = match self.cached_snapshot.read() {
127 Ok(guard) => guard,
128 Err(poisoned) => poisoned.into_inner(),
129 };
130 f.debug_struct("CompositeToolRegistry")
131 .field("registries", &self.registries.len())
132 .field("len", &snapshot.len())
133 .finish()
134 }
135}
136
137impl CompositeToolRegistry {
138 pub fn try_new(
139 regs: impl IntoIterator<Item = Arc<dyn ToolRegistry>>,
140 ) -> Result<Self, ToolRegistryError> {
141 let registries: Vec<Arc<dyn ToolRegistry>> = regs.into_iter().collect();
142 let merged = Self::merge_snapshots(®istries)?;
143 Ok(Self {
144 registries,
145 cached_snapshot: Arc::new(RwLock::new(merged)),
146 })
147 }
148
149 fn merge_snapshots(
150 registries: &[Arc<dyn ToolRegistry>],
151 ) -> Result<HashMap<String, Arc<dyn Tool>>, ToolRegistryError> {
152 let mut merged = InMemoryToolRegistry::new();
153 for reg in registries {
154 merged.extend_registry(reg.as_ref())?;
155 }
156 Ok(merged.into_map())
157 }
158
159 fn refresh_snapshot(&self) -> Result<HashMap<String, Arc<dyn Tool>>, ToolRegistryError> {
160 Self::merge_snapshots(&self.registries)
161 }
162
163 fn read_cached_snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
164 match self.cached_snapshot.read() {
165 Ok(guard) => guard.clone(),
166 Err(poisoned) => poisoned.into_inner().clone(),
167 }
168 }
169
170 fn write_cached_snapshot(&self, snapshot: HashMap<String, Arc<dyn Tool>>) {
171 match self.cached_snapshot.write() {
172 Ok(mut guard) => *guard = snapshot,
173 Err(poisoned) => *poisoned.into_inner() = snapshot,
174 };
175 }
176}
177
178impl ToolRegistry for CompositeToolRegistry {
179 fn len(&self) -> usize {
180 self.snapshot().len()
181 }
182
183 fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
184 self.snapshot().get(id).cloned()
185 }
186
187 fn ids(&self) -> Vec<String> {
188 let snapshot = self.snapshot();
189 sorted_registry_ids(&snapshot)
190 }
191
192 fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
193 match self.refresh_snapshot() {
194 Ok(snapshot) => {
195 self.write_cached_snapshot(snapshot.clone());
196 snapshot
197 }
198 Err(_) => self.read_cached_snapshot(),
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError, ToolResult};
    use crate::contracts::ToolCallContext;
    use serde_json::json;

    /// Test tool whose descriptor uses `id` as both id and name.
    struct StaticTool {
        descriptor: ToolDescriptor,
    }

    impl StaticTool {
        fn new(id: &str) -> Self {
            let descriptor = ToolDescriptor::new(id, id, "test tool");
            Self { descriptor }
        }
    }

    #[async_trait::async_trait]
    impl Tool for StaticTool {
        fn descriptor(&self) -> ToolDescriptor {
            self.descriptor.clone()
        }

        async fn execute(
            &self,
            _args: serde_json::Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            let call_id = self.descriptor.id.clone();
            let payload = json!({"ok": true});
            Ok(ToolResult::success(call_id, payload))
        }
    }

    /// Source registry whose contents can be swapped at runtime, simulating a
    /// dynamic tool provider.
    #[derive(Default)]
    struct MutableToolRegistry {
        tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    }

    impl MutableToolRegistry {
        /// Replaces the whole tool set with one `StaticTool` per id.
        fn replace_ids(&self, ids: &[&str]) {
            let replacement: HashMap<String, Arc<dyn Tool>> = ids
                .iter()
                .map(|&id| (id.to_string(), Arc::new(StaticTool::new(id)) as Arc<dyn Tool>))
                .collect();
            *self
                .tools
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner) = replacement;
        }
    }

    impl ToolRegistry for MutableToolRegistry {
        fn len(&self) -> usize {
            self.snapshot().len()
        }

        fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
            self.snapshot().remove(id)
        }

        fn ids(&self) -> Vec<String> {
            let mut keys: Vec<String> = self.snapshot().into_keys().collect();
            keys.sort();
            keys
        }

        fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
            self.tools
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .clone()
        }
    }

    #[test]
    fn composite_tool_registry_reads_live_updates_from_source_registries() {
        let live = Arc::new(MutableToolRegistry::default());
        live.replace_ids(&["dynamic_a"]);

        let mut fixed = InMemoryToolRegistry::new();
        fixed
            .register_named("static_tool", Arc::new(StaticTool::new("static_tool")))
            .expect("register static tool");

        let composite = CompositeToolRegistry::try_new(vec![
            live.clone() as Arc<dyn ToolRegistry>,
            Arc::new(fixed) as Arc<dyn ToolRegistry>,
        ])
        .expect("compose registries");

        let before = composite.ids();
        assert!(before.contains(&"dynamic_a".to_string()));
        assert!(before.contains(&"static_tool".to_string()));

        // Mutate a source after composition; the composite must see the change.
        live.replace_ids(&["dynamic_a", "dynamic_b"]);

        let after = composite.ids();
        for expected in ["dynamic_a", "dynamic_b", "static_tool"] {
            assert!(after.contains(&expected.to_string()));
        }
    }

    #[test]
    fn composite_tool_registry_keeps_last_good_snapshot_on_runtime_conflict() {
        let first = Arc::new(MutableToolRegistry::default());
        first.replace_ids(&["tool_a"]);

        let second = Arc::new(MutableToolRegistry::default());
        second.replace_ids(&["tool_b"]);

        let composite = CompositeToolRegistry::try_new(vec![
            first.clone() as Arc<dyn ToolRegistry>,
            second.clone() as Arc<dyn ToolRegistry>,
        ])
        .expect("compose registries");

        let initial_ids = composite.ids();
        assert_eq!(initial_ids, ["tool_a", "tool_b"]);

        // Make the second source collide with the first: merging now fails, so the
        // composite must keep serving the snapshot it had before the conflict.
        second.replace_ids(&["tool_a"]);

        assert_eq!(composite.ids(), initial_ids);
        assert!(composite.get("tool_b").is_some());
    }
}