tirea_contract/runtime/state/
scope_registry.rs1use super::spec::{AnyStateAction, StateScope};
2use std::any::TypeId;
3use std::collections::HashMap;
4use tirea_state::StateSpec;
5
6#[derive(Debug, Clone, Default)]
17pub struct StateScopeRegistry {
18 by_type_id: HashMap<TypeId, (&'static str, StateScope, &'static str)>,
19}
20
21impl StateScopeRegistry {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn register<S: StateSpec>(&mut self, scope: StateScope) {
28 self.by_type_id.insert(
29 TypeId::of::<S>(),
30 (std::any::type_name::<S>(), scope, S::PATH),
31 );
32 }
33
34 pub fn scope_for_type_id(&self, type_id: TypeId) -> Option<StateScope> {
36 self.by_type_id.get(&type_id).map(|(_, scope, _)| *scope)
37 }
38
39 pub fn run_scoped_paths(&self) -> Vec<&'static str> {
44 self.by_type_id
45 .values()
46 .filter(|(_, scope, _)| *scope == StateScope::Run)
47 .map(|(_, _, path)| *path)
48 .collect()
49 }
50
51 pub fn resolve(&self, action: &AnyStateAction) -> StateScope {
56 if let Some(scope) = self.scope_for_type_id(action.state_type_id()) {
57 return scope;
58 }
59 action.scope()
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use tirea_state::{DocCell, PatchSink, Path, State, TireaResult};

    // Implements `State` for a serde-backed test struct stored at `$path`,
    // treating a JSON `null` as the struct's default value. `$ref_ty` is the
    // (empty) reference type generated for it.
    macro_rules! impl_test_state {
        ($ty:ident, $ref_ty:ident, $path:literal) => {
            struct $ref_ty;

            impl State for $ty {
                type Ref<'a> = $ref_ty;
                const PATH: &'static str = $path;

                fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
                    $ref_ty
                }

                fn from_value(value: &Value) -> TireaResult<Self> {
                    if value.is_null() {
                        return Ok(Self::default());
                    }
                    serde_json::from_value(value.clone())
                        .map_err(tirea_state::TireaError::Serialization)
                }

                fn to_value(&self) -> TireaResult<Value> {
                    serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
                }
            }
        };
    }

    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct RunScoped {
        value: i64,
    }

    impl_test_state!(RunScoped, RunScopedRef, "run_scoped");

    // Uses the default `StateSpec::SCOPE`, so its actions report their own
    // scope as Thread unless a registry overrides it.
    impl StateSpec for RunScoped {
        type Action = ();
        fn reduce(&mut self, _: ()) {}
    }

    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct ToolScoped {
        value: i64,
    }

    impl_test_state!(ToolScoped, ToolScopedRef, "tool_scoped");

    impl StateSpec for ToolScoped {
        type Action = ();
        const SCOPE: StateScope = StateScope::ToolCall;
        fn reduce(&mut self, _: ()) {}
    }

    #[test]
    fn register_and_lookup() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<ToolScoped>(StateScope::ToolCall);

        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<RunScoped>()),
            Some(StateScope::Run)
        );
        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<ToolScoped>()),
            Some(StateScope::ToolCall)
        );
    }

    #[test]
    fn reregistering_replaces_scope() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<RunScoped>(StateScope::Thread);
        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<RunScoped>()),
            Some(StateScope::Thread)
        );
    }

    #[test]
    fn unregistered_type_returns_none() {
        let reg = StateScopeRegistry::new();
        assert_eq!(reg.scope_for_type_id(TypeId::of::<RunScoped>()), None);
    }

    #[test]
    fn resolve_falls_back_to_action_scope() {
        let reg = StateScopeRegistry::new();
        let action = AnyStateAction::new::<RunScoped>(());
        assert_eq!(reg.resolve(&action), StateScope::Thread);
    }

    #[test]
    fn resolve_uses_registered_scope() {
        let mut reg = StateScopeRegistry::new();
        // Without a registration this action resolves to Thread (see
        // `resolve_falls_back_to_action_scope`), so a Run result can only come
        // from the registry entry.
        reg.register::<RunScoped>(StateScope::Run);
        let action = AnyStateAction::new::<RunScoped>(());
        assert_eq!(reg.resolve(&action), StateScope::Run);
    }

    #[test]
    fn run_scoped_paths_returns_run_types() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<ToolScoped>(StateScope::ToolCall);

        let paths = reg.run_scoped_paths();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "run_scoped");
    }

    #[test]
    fn run_scoped_paths_empty_when_none_registered() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<ToolScoped>(StateScope::ToolCall);
        assert!(reg.run_scoped_paths().is_empty());
    }
}