tirea_agentos/composition/registry/
stop_policy.rs1use super::sorted_registry_ids;
2use crate::runtime::StopPolicy;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6#[derive(Debug, thiserror::Error)]
7pub enum StopPolicyRegistryError {
8 #[error("stop policy id already registered: {0}")]
9 StopPolicyIdConflict(String),
10
11 #[error("stop policy id mismatch: key={key} policy.name()={policy_name}")]
12 StopPolicyIdMismatch { key: String, policy_name: String },
13}
14
15pub trait StopPolicyRegistry: Send + Sync {
16 fn len(&self) -> usize;
17
18 fn is_empty(&self) -> bool {
19 self.len() == 0
20 }
21
22 fn get(&self, id: &str) -> Option<Arc<dyn StopPolicy>>;
23
24 fn ids(&self) -> Vec<String>;
25
26 fn snapshot(&self) -> HashMap<String, Arc<dyn StopPolicy>>;
27}
28
29#[derive(Clone, Default)]
30pub struct InMemoryStopPolicyRegistry {
31 policies: HashMap<String, Arc<dyn StopPolicy>>,
32}
33
34impl std::fmt::Debug for InMemoryStopPolicyRegistry {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("InMemoryStopPolicyRegistry")
37 .field("len", &self.policies.len())
38 .finish()
39 }
40}
41
42impl InMemoryStopPolicyRegistry {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn register_named(
48 &mut self,
49 id: impl Into<String>,
50 policy: Arc<dyn StopPolicy>,
51 ) -> Result<(), StopPolicyRegistryError> {
52 let key = id.into();
53 if self.policies.contains_key(&key) {
54 return Err(StopPolicyRegistryError::StopPolicyIdConflict(key));
55 }
56 self.policies.insert(key, policy);
57 Ok(())
58 }
59
60 pub fn extend_named(
61 &mut self,
62 policies: HashMap<String, Arc<dyn StopPolicy>>,
63 ) -> Result<(), StopPolicyRegistryError> {
64 for (key, policy) in policies {
65 self.register_named(key, policy)?;
66 }
67 Ok(())
68 }
69
70 pub fn extend_registry(
71 &mut self,
72 other: &dyn StopPolicyRegistry,
73 ) -> Result<(), StopPolicyRegistryError> {
74 self.extend_named(other.snapshot())
75 }
76}
77
78impl StopPolicyRegistry for InMemoryStopPolicyRegistry {
79 fn len(&self) -> usize {
80 self.policies.len()
81 }
82
83 fn get(&self, id: &str) -> Option<Arc<dyn StopPolicy>> {
84 self.policies.get(id).cloned()
85 }
86
87 fn ids(&self) -> Vec<String> {
88 sorted_registry_ids(&self.policies)
89 }
90
91 fn snapshot(&self) -> HashMap<String, Arc<dyn StopPolicy>> {
92 self.policies.clone()
93 }
94}
95
96#[derive(Clone, Default)]
97pub struct CompositeStopPolicyRegistry {
98 merged: InMemoryStopPolicyRegistry,
99}
100
101impl std::fmt::Debug for CompositeStopPolicyRegistry {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("CompositeStopPolicyRegistry")
104 .field("len", &self.merged.len())
105 .finish()
106 }
107}
108
109impl CompositeStopPolicyRegistry {
110 pub fn try_new(
111 regs: impl IntoIterator<Item = Arc<dyn StopPolicyRegistry>>,
112 ) -> Result<Self, StopPolicyRegistryError> {
113 let mut merged = InMemoryStopPolicyRegistry::new();
114 for r in regs {
115 merged.extend_registry(r.as_ref())?;
116 }
117 Ok(Self { merged })
118 }
119}
120
121impl StopPolicyRegistry for CompositeStopPolicyRegistry {
122 fn len(&self) -> usize {
123 self.merged.len()
124 }
125
126 fn get(&self, id: &str) -> Option<Arc<dyn StopPolicy>> {
127 self.merged.get(id)
128 }
129
130 fn ids(&self) -> Vec<String> {
131 self.merged.ids()
132 }
133
134 fn snapshot(&self) -> HashMap<String, Arc<dyn StopPolicy>> {
135 self.merged.snapshot()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::StoppedReason;
    use crate::runtime::StopPolicyInput;

    /// Stop policy stub that never requests a stop; only its id matters here.
    #[derive(Debug)]
    struct MockStopPolicy {
        name: String,
    }

    impl MockStopPolicy {
        fn new(name: &str) -> Self {
            let name = name.to_owned();
            Self { name }
        }
    }

    impl StopPolicy for MockStopPolicy {
        fn id(&self) -> &str {
            self.name.as_str()
        }

        fn evaluate(&self, _input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
            None
        }
    }

    /// Builds a registry holding a single mock policy registered under `id`.
    fn registry_with(id: &str) -> InMemoryStopPolicyRegistry {
        let mut registry = InMemoryStopPolicyRegistry::new();
        registry
            .register_named(id, Arc::new(MockStopPolicy::new(id)))
            .unwrap();
        registry
    }

    /// Wraps a registry as the trait object `CompositeStopPolicyRegistry::try_new` takes.
    fn shared(registry: InMemoryStopPolicyRegistry) -> Arc<dyn StopPolicyRegistry> {
        Arc::new(registry)
    }

    #[test]
    fn in_memory_register_and_get() {
        let registry = registry_with("max_rounds");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("max_rounds").is_some());
        assert!(registry.get("other").is_none());
    }

    #[test]
    fn in_memory_rejects_duplicate() {
        let mut registry = registry_with("p1");

        let outcome = registry.register_named("p1", Arc::new(MockStopPolicy::new("p1")));

        assert!(matches!(
            outcome,
            Err(StopPolicyRegistryError::StopPolicyIdConflict(ref id)) if id == "p1"
        ));
    }

    #[test]
    fn composite_merges_registries() {
        let sources = vec![shared(registry_with("p1")), shared(registry_with("p2"))];

        let composite = CompositeStopPolicyRegistry::try_new(sources).unwrap();

        assert_eq!(composite.len(), 2);
        assert!(composite.get("p1").is_some());
        assert!(composite.get("p2").is_some());
    }

    #[test]
    fn composite_rejects_cross_registry_duplicate() {
        let sources = vec![shared(registry_with("dup")), shared(registry_with("dup"))];

        let outcome = CompositeStopPolicyRegistry::try_new(sources);

        assert!(matches!(
            outcome,
            Err(StopPolicyRegistryError::StopPolicyIdConflict(ref id)) if id == "dup"
        ));
    }
}
231}