tirea_contract/runtime/state/scope_registry.rs

1use super::spec::{AnyStateAction, StateScope};
2use std::any::TypeId;
3use std::collections::HashMap;
4use tirea_state::StateSpec;
5
6/// Registry mapping `StateSpec` types to their declared [`StateScope`] and path.
7///
8/// Built once at agent construction by calling
9/// [`AgentBehavior::register_state_scopes`] on each behavior. The loop then
10/// uses [`resolve`] to determine the effective scope of any [`AnyStateAction`],
11/// overriding the action-carried default when a registered type provides a
12/// canonical scope.
13///
14/// Also exposes [`run_scoped_paths`] for enumerating all Run-scoped state
15/// paths, enabling framework-driven cleanup at the start of each run.
16#[derive(Debug, Clone, Default)]
17pub struct StateScopeRegistry {
18    by_type_id: HashMap<TypeId, (&'static str, StateScope, &'static str)>,
19}
20
21impl StateScopeRegistry {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Register a [`StateSpec`] type with an explicit [`StateScope`].
27    pub fn register<S: StateSpec>(&mut self, scope: StateScope) {
28        self.by_type_id.insert(
29            TypeId::of::<S>(),
30            (std::any::type_name::<S>(), scope, S::PATH),
31        );
32    }
33
34    /// Look up the scope of a registered type.
35    pub fn scope_for_type_id(&self, type_id: TypeId) -> Option<StateScope> {
36        self.by_type_id.get(&type_id).map(|(_, scope, _)| *scope)
37    }
38
39    /// Return the canonical paths of all registered Run-scoped state types.
40    ///
41    /// Used by `prepare_run` to emit delete patches for stale run-scoped
42    /// state before starting a new run.
43    pub fn run_scoped_paths(&self) -> Vec<&'static str> {
44        self.by_type_id
45            .values()
46            .filter(|(_, scope, _)| *scope == StateScope::Run)
47            .map(|(_, _, path)| *path)
48            .collect()
49    }
50
51    /// Resolve the scope of an [`AnyStateAction`].
52    ///
53    /// If the action targets a registered type, returns the registered scope.
54    /// Otherwise falls back to [`AnyStateAction::scope`].
55    pub fn resolve(&self, action: &AnyStateAction) -> StateScope {
56        if let Some(scope) = self.scope_for_type_id(action.state_type_id()) {
57            return scope;
58        }
59        action.scope()
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use tirea_state::{DocCell, PatchSink, Path, State, TireaResult};

    /// Test state that does not override `StateSpec::SCOPE`; the tests
    /// register it as `Run` to exercise registry behavior.
    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct RunScoped {
        value: i64,
    }

    struct RunScopedRef;

    impl State for RunScoped {
        type Ref<'a> = RunScopedRef;
        const PATH: &'static str = "run_scoped";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            RunScopedRef
        }
        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }
        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for RunScoped {
        type Action = ();
        fn reduce(&mut self, _: ()) {}
    }

    /// Test state declaring `SCOPE = ToolCall`.
    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct ToolScoped {
        value: i64,
    }

    struct ToolScopedRef;

    impl State for ToolScoped {
        type Ref<'a> = ToolScopedRef;
        const PATH: &'static str = "tool_scoped";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            ToolScopedRef
        }
        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }
        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for ToolScoped {
        type Action = ();
        const SCOPE: StateScope = StateScope::ToolCall;
        fn reduce(&mut self, _: ()) {}
    }

    #[test]
    fn register_and_lookup() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<ToolScoped>(StateScope::ToolCall);

        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<RunScoped>()),
            Some(StateScope::Run)
        );
        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<ToolScoped>()),
            Some(StateScope::ToolCall)
        );
    }

    #[test]
    fn unregistered_type_returns_none() {
        let reg = StateScopeRegistry::new();
        assert_eq!(reg.scope_for_type_id(TypeId::of::<RunScoped>()), None);
    }

    #[test]
    fn register_overwrites_previous_scope() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<RunScoped>(StateScope::Thread);

        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<RunScoped>()),
            Some(StateScope::Thread)
        );
        assert!(reg.run_scoped_paths().is_empty());
    }

    #[test]
    fn resolve_falls_back_to_action_scope() {
        let reg = StateScopeRegistry::new();
        let action = AnyStateAction::new::<RunScoped>(());
        assert_eq!(reg.resolve(&action), StateScope::Thread);
    }

    #[test]
    fn resolve_uses_registered_scope() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        // Without a registration this action resolves to `Thread`
        // (see `resolve_falls_back_to_action_scope`), so `Run` here proves
        // the registered scope overrides the action-carried one.
        let action = AnyStateAction::new::<RunScoped>(());
        assert_eq!(reg.resolve(&action), StateScope::Run);
    }

    #[test]
    fn run_scoped_paths_returns_run_types() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<ToolScoped>(StateScope::ToolCall);

        let paths = reg.run_scoped_paths();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "run_scoped");
    }

    #[test]
    fn run_scoped_paths_empty_when_none_registered() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<ToolScoped>(StateScope::ToolCall);
        assert!(reg.run_scoped_paths().is_empty());
    }
}