tirea_agentos/composition/registry/stop_policy.rs

1use super::sorted_registry_ids;
2use crate::runtime::StopPolicy;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6#[derive(Debug, thiserror::Error)]
7pub enum StopPolicyRegistryError {
8    #[error("stop policy id already registered: {0}")]
9    StopPolicyIdConflict(String),
10
11    #[error("stop policy id mismatch: key={key} policy.name()={policy_name}")]
12    StopPolicyIdMismatch { key: String, policy_name: String },
13}
14
15pub trait StopPolicyRegistry: Send + Sync {
16    fn len(&self) -> usize;
17
18    fn is_empty(&self) -> bool {
19        self.len() == 0
20    }
21
22    fn get(&self, id: &str) -> Option<Arc<dyn StopPolicy>>;
23
24    fn ids(&self) -> Vec<String>;
25
26    fn snapshot(&self) -> HashMap<String, Arc<dyn StopPolicy>>;
27}
28
29#[derive(Clone, Default)]
30pub struct InMemoryStopPolicyRegistry {
31    policies: HashMap<String, Arc<dyn StopPolicy>>,
32}
33
34impl std::fmt::Debug for InMemoryStopPolicyRegistry {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("InMemoryStopPolicyRegistry")
37            .field("len", &self.policies.len())
38            .finish()
39    }
40}
41
42impl InMemoryStopPolicyRegistry {
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    pub fn register_named(
48        &mut self,
49        id: impl Into<String>,
50        policy: Arc<dyn StopPolicy>,
51    ) -> Result<(), StopPolicyRegistryError> {
52        let key = id.into();
53        if self.policies.contains_key(&key) {
54            return Err(StopPolicyRegistryError::StopPolicyIdConflict(key));
55        }
56        self.policies.insert(key, policy);
57        Ok(())
58    }
59
60    pub fn extend_named(
61        &mut self,
62        policies: HashMap<String, Arc<dyn StopPolicy>>,
63    ) -> Result<(), StopPolicyRegistryError> {
64        for (key, policy) in policies {
65            self.register_named(key, policy)?;
66        }
67        Ok(())
68    }
69
70    pub fn extend_registry(
71        &mut self,
72        other: &dyn StopPolicyRegistry,
73    ) -> Result<(), StopPolicyRegistryError> {
74        self.extend_named(other.snapshot())
75    }
76}
77
78impl StopPolicyRegistry for InMemoryStopPolicyRegistry {
79    fn len(&self) -> usize {
80        self.policies.len()
81    }
82
83    fn get(&self, id: &str) -> Option<Arc<dyn StopPolicy>> {
84        self.policies.get(id).cloned()
85    }
86
87    fn ids(&self) -> Vec<String> {
88        sorted_registry_ids(&self.policies)
89    }
90
91    fn snapshot(&self) -> HashMap<String, Arc<dyn StopPolicy>> {
92        self.policies.clone()
93    }
94}
95
96#[derive(Clone, Default)]
97pub struct CompositeStopPolicyRegistry {
98    merged: InMemoryStopPolicyRegistry,
99}
100
101impl std::fmt::Debug for CompositeStopPolicyRegistry {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("CompositeStopPolicyRegistry")
104            .field("len", &self.merged.len())
105            .finish()
106    }
107}
108
109impl CompositeStopPolicyRegistry {
110    pub fn try_new(
111        regs: impl IntoIterator<Item = Arc<dyn StopPolicyRegistry>>,
112    ) -> Result<Self, StopPolicyRegistryError> {
113        let mut merged = InMemoryStopPolicyRegistry::new();
114        for r in regs {
115            merged.extend_registry(r.as_ref())?;
116        }
117        Ok(Self { merged })
118    }
119}
120
121impl StopPolicyRegistry for CompositeStopPolicyRegistry {
122    fn len(&self) -> usize {
123        self.merged.len()
124    }
125
126    fn get(&self, id: &str) -> Option<Arc<dyn StopPolicy>> {
127        self.merged.get(id)
128    }
129
130    fn ids(&self) -> Vec<String> {
131        self.merged.ids()
132    }
133
134    fn snapshot(&self) -> HashMap<String, Arc<dyn StopPolicy>> {
135        self.merged.snapshot()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::StoppedReason;
    use crate::runtime::StopPolicyInput;

    /// Policy stub that never requests a stop; only its id matters here.
    #[derive(Debug)]
    struct MockStopPolicy {
        id: String,
    }

    impl MockStopPolicy {
        fn new(id: &str) -> Self {
            Self { id: id.to_owned() }
        }
    }

    impl StopPolicy for MockStopPolicy {
        fn id(&self) -> &str {
            self.id.as_str()
        }

        fn evaluate(&self, _input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
            None
        }
    }

    /// Wraps a fresh mock in the trait-object form the registry stores.
    fn mock(id: &str) -> Arc<dyn StopPolicy> {
        Arc::new(MockStopPolicy::new(id))
    }

    /// Builds an in-memory registry holding one mock per id, keyed by that id.
    fn registry_of(ids: &[&str]) -> InMemoryStopPolicyRegistry {
        let mut reg = InMemoryStopPolicyRegistry::new();
        for &id in ids {
            reg.register_named(id, mock(id)).unwrap();
        }
        reg
    }

    /// Erases a concrete registry into the shared form `try_new` accepts.
    fn shared(reg: InMemoryStopPolicyRegistry) -> Arc<dyn StopPolicyRegistry> {
        Arc::new(reg)
    }

    #[test]
    fn in_memory_register_and_get() {
        let reg = registry_of(&["max_rounds"]);
        assert_eq!(reg.len(), 1);
        assert!(reg.get("max_rounds").is_some());
        assert!(reg.get("other").is_none());
    }

    #[test]
    fn in_memory_rejects_duplicate() {
        let mut reg = registry_of(&["p1"]);
        let err = reg.register_named("p1", mock("p1")).unwrap_err();
        assert!(matches!(
            err,
            StopPolicyRegistryError::StopPolicyIdConflict(ref id) if id == "p1"
        ));
    }

    #[test]
    fn composite_merges_registries() {
        let sources = vec![shared(registry_of(&["p1"])), shared(registry_of(&["p2"]))];
        let composite = CompositeStopPolicyRegistry::try_new(sources).unwrap();

        assert_eq!(composite.len(), 2);
        assert!(composite.get("p1").is_some());
        assert!(composite.get("p2").is_some());
    }

    #[test]
    fn composite_rejects_cross_registry_duplicate() {
        let sources = vec![shared(registry_of(&["dup"])), shared(registry_of(&["dup"]))];
        let err = CompositeStopPolicyRegistry::try_new(sources).unwrap_err();
        assert!(matches!(
            err,
            StopPolicyRegistryError::StopPolicyIdConflict(ref id) if id == "dup"
        ));
    }
}