tirea_agentos/composition/registry/
model.rs1use super::sorted_registry_ids;
2use super::traits::{ModelDefinition, ModelRegistry, ModelRegistryError};
3use std::collections::HashMap;
4use std::sync::Arc;
5
6#[derive(Debug, Clone, Default)]
7pub struct InMemoryModelRegistry {
8 models: HashMap<String, ModelDefinition>,
9}
10
11impl InMemoryModelRegistry {
12 pub fn new() -> Self {
13 Self::default()
14 }
15
16 pub fn len(&self) -> usize {
17 self.models.len()
18 }
19
20 pub fn is_empty(&self) -> bool {
21 self.models.is_empty()
22 }
23
24 pub fn get(&self, id: &str) -> Option<ModelDefinition> {
25 self.models.get(id).cloned()
26 }
27
28 pub fn ids(&self) -> impl Iterator<Item = &String> {
29 self.models.keys()
30 }
31
32 pub fn register(
33 &mut self,
34 model_id: impl Into<String>,
35 mut def: ModelDefinition,
36 ) -> Result<(), ModelRegistryError> {
37 let model_id = model_id.into();
38 if self.models.contains_key(&model_id) {
39 return Err(ModelRegistryError::ModelIdConflict(model_id));
40 }
41 def.provider = def.provider.trim().to_string();
42 def.model = def.model.trim().to_string();
43 if def.provider.is_empty() {
44 return Err(ModelRegistryError::EmptyProviderId);
45 }
46 if def.model.is_empty() {
47 return Err(ModelRegistryError::EmptyModelName);
48 }
49 self.models.insert(model_id, def);
50 Ok(())
51 }
52
53 pub fn extend(
54 &mut self,
55 defs: HashMap<String, ModelDefinition>,
56 ) -> Result<(), ModelRegistryError> {
57 for (id, def) in defs {
58 self.register(id, def)?;
59 }
60 Ok(())
61 }
62
63 pub fn extend_registry(&mut self, other: &dyn ModelRegistry) -> Result<(), ModelRegistryError> {
64 self.extend(other.snapshot())
65 }
66}
67
68impl ModelRegistry for InMemoryModelRegistry {
69 fn len(&self) -> usize {
70 self.len()
71 }
72
73 fn get(&self, id: &str) -> Option<ModelDefinition> {
74 self.get(id)
75 }
76
77 fn ids(&self) -> Vec<String> {
78 sorted_registry_ids(&self.models)
79 }
80
81 fn snapshot(&self) -> HashMap<String, ModelDefinition> {
82 self.models.clone()
83 }
84}
85
86#[derive(Clone, Default)]
87pub struct CompositeModelRegistry {
88 merged: InMemoryModelRegistry,
89}
90
91impl std::fmt::Debug for CompositeModelRegistry {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("CompositeModelRegistry")
94 .field("len", &self.merged.len())
95 .finish()
96 }
97}
98
99impl CompositeModelRegistry {
100 pub fn try_new(
101 regs: impl IntoIterator<Item = Arc<dyn ModelRegistry>>,
102 ) -> Result<Self, ModelRegistryError> {
103 let mut merged = InMemoryModelRegistry::new();
104 for r in regs {
105 merged.extend_registry(r.as_ref())?;
106 }
107 Ok(Self { merged })
108 }
109}
110
111impl ModelRegistry for CompositeModelRegistry {
112 fn len(&self) -> usize {
113 self.merged.len()
114 }
115
116 fn get(&self, id: &str) -> Option<ModelDefinition> {
117 self.merged.get(id)
118 }
119
120 fn ids(&self) -> Vec<String> {
121 sorted_registry_ids(&self.merged.models)
122 }
123
124 fn snapshot(&self) -> HashMap<String, ModelDefinition> {
125 self.merged.models.clone()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_registry_trims_provider_and_model_names() {
        let mut registry = InMemoryModelRegistry::new();
        registry
            .register("m1", ModelDefinition::new(" openai ", " gemini-2.5-flash "))
            .expect("register model");

        let model = registry.get("m1").expect("stored model");
        assert_eq!(model.provider, "openai");
        assert_eq!(model.model, "gemini-2.5-flash");
    }

    #[test]
    fn model_registry_rejects_whitespace_only_provider_or_model() {
        let mut registry = InMemoryModelRegistry::new();
        assert!(matches!(
            registry.register("m1", ModelDefinition::new(" ", "gpt-4o-mini")),
            Err(ModelRegistryError::EmptyProviderId)
        ));
        assert!(matches!(
            registry.register("m2", ModelDefinition::new("openai", " ")),
            Err(ModelRegistryError::EmptyModelName)
        ));
    }

    #[test]
    fn model_registry_rejects_duplicate_model_id_and_keeps_original() {
        let mut registry = InMemoryModelRegistry::new();
        registry
            .register("m1", ModelDefinition::new("openai", "gpt-4o-mini"))
            .expect("first registration");

        assert!(matches!(
            registry.register("m1", ModelDefinition::new("other", "other-model")),
            Err(ModelRegistryError::ModelIdConflict(id)) if id == "m1"
        ));
        // The conflicting registration must not overwrite the stored definition.
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get("m1").expect("stored model").provider, "openai");
    }

    #[test]
    fn composite_registry_merges_disjoint_sources_and_rejects_conflicts() {
        let mut first = InMemoryModelRegistry::new();
        first
            .register("m1", ModelDefinition::new("openai", "gpt-4o-mini"))
            .expect("register m1");
        let mut second = InMemoryModelRegistry::new();
        second
            .register("m2", ModelDefinition::new("openai", "gpt-4o"))
            .expect("register m2");

        let sources: Vec<Arc<dyn ModelRegistry>> =
            vec![Arc::new(first.clone()), Arc::new(second)];
        let composite = CompositeModelRegistry::try_new(sources).expect("disjoint merge");
        assert_eq!(ModelRegistry::len(&composite), 2);
        assert!(ModelRegistry::get(&composite, "m1").is_some());
        assert!(ModelRegistry::get(&composite, "m2").is_some());

        let duplicated: Vec<Arc<dyn ModelRegistry>> =
            vec![Arc::new(first.clone()), Arc::new(first)];
        assert!(matches!(
            CompositeModelRegistry::try_new(duplicated),
            Err(ModelRegistryError::ModelIdConflict(id)) if id == "m1"
        ));
    }
}
157}