tirea_agentos/composition/registry/provider.rs

1use super::sorted_registry_ids;
2use super::traits::{ProviderRegistry, ProviderRegistryError};
3use genai::Client;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7#[derive(Clone, Default)]
8pub struct InMemoryProviderRegistry {
9    providers: HashMap<String, Client>,
10}
11
12impl std::fmt::Debug for InMemoryProviderRegistry {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        f.debug_struct("InMemoryProviderRegistry")
15            .field("len", &self.providers.len())
16            .finish()
17    }
18}
19
20impl InMemoryProviderRegistry {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    pub fn register(
26        &mut self,
27        provider_id: impl Into<String>,
28        client: Client,
29    ) -> Result<(), ProviderRegistryError> {
30        let provider_id = provider_id.into();
31        if provider_id.trim().is_empty() {
32            return Err(ProviderRegistryError::EmptyProviderId);
33        }
34        if self.providers.contains_key(&provider_id) {
35            return Err(ProviderRegistryError::ProviderIdConflict(provider_id));
36        }
37        self.providers.insert(provider_id, client);
38        Ok(())
39    }
40
41    pub fn extend(
42        &mut self,
43        providers: HashMap<String, Client>,
44    ) -> Result<(), ProviderRegistryError> {
45        for (id, client) in providers {
46            self.register(id, client)?;
47        }
48        Ok(())
49    }
50}
51
52impl ProviderRegistry for InMemoryProviderRegistry {
53    fn len(&self) -> usize {
54        self.providers.len()
55    }
56
57    fn get(&self, id: &str) -> Option<Client> {
58        self.providers.get(id).cloned()
59    }
60
61    fn ids(&self) -> Vec<String> {
62        sorted_registry_ids(&self.providers)
63    }
64
65    fn snapshot(&self) -> HashMap<String, Client> {
66        self.providers.clone()
67    }
68}
69
70#[derive(Clone, Default)]
71pub struct CompositeProviderRegistry {
72    merged: InMemoryProviderRegistry,
73}
74
75impl std::fmt::Debug for CompositeProviderRegistry {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("CompositeProviderRegistry")
78            .field("len", &self.merged.len())
79            .finish()
80    }
81}
82
83impl CompositeProviderRegistry {
84    pub fn try_new(
85        regs: impl IntoIterator<Item = Arc<dyn ProviderRegistry>>,
86    ) -> Result<Self, ProviderRegistryError> {
87        let mut merged = InMemoryProviderRegistry::new();
88        for r in regs {
89            merged.extend(r.snapshot())?;
90        }
91        Ok(Self { merged })
92    }
93}
94
95impl ProviderRegistry for CompositeProviderRegistry {
96    fn len(&self) -> usize {
97        self.merged.len()
98    }
99
100    fn get(&self, id: &str) -> Option<Client> {
101        self.merged.get(id)
102    }
103
104    fn ids(&self) -> Vec<String> {
105        sorted_registry_ids(&self.merged.providers)
106    }
107
108    fn snapshot(&self) -> HashMap<String, Client> {
109        self.merged.snapshot()
110    }
111}