tirea_agentos/composition/registry/
provider.rs1use super::sorted_registry_ids;
2use super::traits::{ProviderRegistry, ProviderRegistryError};
3use genai::Client;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7#[derive(Clone, Default)]
8pub struct InMemoryProviderRegistry {
9 providers: HashMap<String, Client>,
10}
11
12impl std::fmt::Debug for InMemoryProviderRegistry {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 f.debug_struct("InMemoryProviderRegistry")
15 .field("len", &self.providers.len())
16 .finish()
17 }
18}
19
20impl InMemoryProviderRegistry {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 pub fn register(
26 &mut self,
27 provider_id: impl Into<String>,
28 client: Client,
29 ) -> Result<(), ProviderRegistryError> {
30 let provider_id = provider_id.into();
31 if provider_id.trim().is_empty() {
32 return Err(ProviderRegistryError::EmptyProviderId);
33 }
34 if self.providers.contains_key(&provider_id) {
35 return Err(ProviderRegistryError::ProviderIdConflict(provider_id));
36 }
37 self.providers.insert(provider_id, client);
38 Ok(())
39 }
40
41 pub fn extend(
42 &mut self,
43 providers: HashMap<String, Client>,
44 ) -> Result<(), ProviderRegistryError> {
45 for (id, client) in providers {
46 self.register(id, client)?;
47 }
48 Ok(())
49 }
50}
51
52impl ProviderRegistry for InMemoryProviderRegistry {
53 fn len(&self) -> usize {
54 self.providers.len()
55 }
56
57 fn get(&self, id: &str) -> Option<Client> {
58 self.providers.get(id).cloned()
59 }
60
61 fn ids(&self) -> Vec<String> {
62 sorted_registry_ids(&self.providers)
63 }
64
65 fn snapshot(&self) -> HashMap<String, Client> {
66 self.providers.clone()
67 }
68}
69
70#[derive(Clone, Default)]
71pub struct CompositeProviderRegistry {
72 merged: InMemoryProviderRegistry,
73}
74
75impl std::fmt::Debug for CompositeProviderRegistry {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("CompositeProviderRegistry")
78 .field("len", &self.merged.len())
79 .finish()
80 }
81}
82
83impl CompositeProviderRegistry {
84 pub fn try_new(
85 regs: impl IntoIterator<Item = Arc<dyn ProviderRegistry>>,
86 ) -> Result<Self, ProviderRegistryError> {
87 let mut merged = InMemoryProviderRegistry::new();
88 for r in regs {
89 merged.extend(r.snapshot())?;
90 }
91 Ok(Self { merged })
92 }
93}
94
95impl ProviderRegistry for CompositeProviderRegistry {
96 fn len(&self) -> usize {
97 self.merged.len()
98 }
99
100 fn get(&self, id: &str) -> Option<Client> {
101 self.merged.get(id)
102 }
103
104 fn ids(&self) -> Vec<String> {
105 sorted_registry_ids(&self.merged.providers)
106 }
107
108 fn snapshot(&self) -> HashMap<String, Client> {
109 self.merged.snapshot()
110 }
111}