tirea_agentos/composition/registry/
model.rs

1use super::sorted_registry_ids;
2use super::traits::{ModelDefinition, ModelRegistry, ModelRegistryError};
3use std::collections::HashMap;
4use std::sync::Arc;
5
6#[derive(Debug, Clone, Default)]
7pub struct InMemoryModelRegistry {
8    models: HashMap<String, ModelDefinition>,
9}
10
11impl InMemoryModelRegistry {
12    pub fn new() -> Self {
13        Self::default()
14    }
15
16    pub fn len(&self) -> usize {
17        self.models.len()
18    }
19
20    pub fn is_empty(&self) -> bool {
21        self.models.is_empty()
22    }
23
24    pub fn get(&self, id: &str) -> Option<ModelDefinition> {
25        self.models.get(id).cloned()
26    }
27
28    pub fn ids(&self) -> impl Iterator<Item = &String> {
29        self.models.keys()
30    }
31
32    pub fn register(
33        &mut self,
34        model_id: impl Into<String>,
35        mut def: ModelDefinition,
36    ) -> Result<(), ModelRegistryError> {
37        let model_id = model_id.into();
38        if self.models.contains_key(&model_id) {
39            return Err(ModelRegistryError::ModelIdConflict(model_id));
40        }
41        def.provider = def.provider.trim().to_string();
42        def.model = def.model.trim().to_string();
43        if def.provider.is_empty() {
44            return Err(ModelRegistryError::EmptyProviderId);
45        }
46        if def.model.is_empty() {
47            return Err(ModelRegistryError::EmptyModelName);
48        }
49        self.models.insert(model_id, def);
50        Ok(())
51    }
52
53    pub fn extend(
54        &mut self,
55        defs: HashMap<String, ModelDefinition>,
56    ) -> Result<(), ModelRegistryError> {
57        for (id, def) in defs {
58            self.register(id, def)?;
59        }
60        Ok(())
61    }
62
63    pub fn extend_registry(&mut self, other: &dyn ModelRegistry) -> Result<(), ModelRegistryError> {
64        self.extend(other.snapshot())
65    }
66}
67
68impl ModelRegistry for InMemoryModelRegistry {
69    fn len(&self) -> usize {
70        self.len()
71    }
72
73    fn get(&self, id: &str) -> Option<ModelDefinition> {
74        self.get(id)
75    }
76
77    fn ids(&self) -> Vec<String> {
78        sorted_registry_ids(&self.models)
79    }
80
81    fn snapshot(&self) -> HashMap<String, ModelDefinition> {
82        self.models.clone()
83    }
84}
85
86#[derive(Clone, Default)]
87pub struct CompositeModelRegistry {
88    merged: InMemoryModelRegistry,
89}
90
91impl std::fmt::Debug for CompositeModelRegistry {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("CompositeModelRegistry")
94            .field("len", &self.merged.len())
95            .finish()
96    }
97}
98
99impl CompositeModelRegistry {
100    pub fn try_new(
101        regs: impl IntoIterator<Item = Arc<dyn ModelRegistry>>,
102    ) -> Result<Self, ModelRegistryError> {
103        let mut merged = InMemoryModelRegistry::new();
104        for r in regs {
105            merged.extend_registry(r.as_ref())?;
106        }
107        Ok(Self { merged })
108    }
109}
110
111impl ModelRegistry for CompositeModelRegistry {
112    fn len(&self) -> usize {
113        self.merged.len()
114    }
115
116    fn get(&self, id: &str) -> Option<ModelDefinition> {
117        self.merged.get(id)
118    }
119
120    fn ids(&self) -> Vec<String> {
121        sorted_registry_ids(&self.merged.models)
122    }
123
124    fn snapshot(&self) -> HashMap<String, ModelDefinition> {
125        self.merged.models.clone()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Surrounding whitespace on provider and model names is removed before storage.
    #[test]
    fn model_registry_trims_provider_and_model_names() {
        let padded = ModelDefinition::new(" openai ", " gemini-2.5-flash ");

        let mut registry = InMemoryModelRegistry::new();
        registry.register("m1", padded).expect("register model");

        let stored = registry.get("m1").expect("stored model");
        assert_eq!(stored.provider, "openai");
        assert_eq!(stored.model, "gemini-2.5-flash");
    }

    /// A provider or model name consisting solely of whitespace is rejected as empty.
    #[test]
    fn model_registry_rejects_whitespace_only_provider_or_model() {
        let mut registry = InMemoryModelRegistry::new();

        let blank_provider = registry.register("m1", ModelDefinition::new("   ", "gpt-4o-mini"));
        assert!(matches!(
            blank_provider,
            Err(ModelRegistryError::EmptyProviderId)
        ));

        let blank_model = registry.register("m2", ModelDefinition::new("openai", "   "));
        assert!(matches!(
            blank_model,
            Err(ModelRegistryError::EmptyModelName)
        ));
    }
}