tirea_contract/runtime/state/serialized_state_action.rs

1use super::spec::{AnyStateAction, StateScope};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use tirea_state::StateSpec;
6
7/// Serialized state action, sufficient to reconstruct an [`AnyStateAction`].
8///
9/// Captured at the point where a tool completes execution, before the batch
10/// commit. On crash recovery, these entries are deserialized back into
11/// `AnyStateAction` via [`StateActionDeserializerRegistry`] and re-reduced against
12/// the base state.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SerializedStateAction {
15    /// `std::any::type_name::<S>()` — used as the registry lookup key.
16    pub state_type_name: String,
17    /// `S::PATH` — the canonical JSON path for this state type.
18    pub base_path: String,
19    /// Whether this action targets thread-, run-, or tool-call-level state.
20    pub scope: StateScope,
21    /// When set, overrides the scope context call_id for path resolution.
22    pub call_id_override: Option<String>,
23    /// The serialized `S::Action` value.
24    pub payload: Value,
25}
26
27/// Errors from action deserialization operations.
28#[derive(Debug, thiserror::Error)]
29pub enum StateActionDecodeError {
30    #[error("unknown state type: {0}")]
31    UnknownStateType(String),
32    #[error("action deserialization failed for {state_type}: {source}")]
33    DeserializationFailed {
34        state_type: String,
35        source: serde_json::Error,
36    },
37}
38
39// ---------------------------------------------------------------------------
40// AnyStateAction → SerializedStateAction
41// ---------------------------------------------------------------------------
42
43impl AnyStateAction {
44    /// Convert this action into a serialized form for persistence.
45    pub fn to_serialized_state_action(&self) -> SerializedStateAction {
46        SerializedStateAction {
47            state_type_name: self.state_type_name().to_owned(),
48            base_path: self.base_path().to_owned(),
49            scope: self.scope(),
50            call_id_override: self.call_id_override().map(str::to_owned),
51            payload: self.serialized_payload().clone(),
52        }
53    }
54}
55
56// ---------------------------------------------------------------------------
57// StateActionDeserializerRegistry
58// ---------------------------------------------------------------------------
59
60type ActionFactory = Box<
61    dyn Fn(&SerializedStateAction) -> Result<AnyStateAction, StateActionDecodeError> + Send + Sync,
62>;
63
64/// Registry that maps `state_type_name` → factory closure for reconstructing
65/// `AnyStateAction` from a [`SerializedStateAction`].
66///
67/// Built once at agent construction (alongside `StateScopeRegistry` and
68/// `LatticeRegistry`) by calling `register::<S>()` for every `StateSpec` type.
69pub struct StateActionDeserializerRegistry {
70    factories: HashMap<String, ActionFactory>,
71}
72
73impl StateActionDeserializerRegistry {
74    pub fn new() -> Self {
75        Self {
76            factories: HashMap::new(),
77        }
78    }
79
80    /// Register a `StateSpec` type so its actions can be deserialized.
81    pub fn register<S: StateSpec>(&mut self) {
82        let type_name = std::any::type_name::<S>().to_owned();
83        self.factories.insert(
84            type_name,
85            Box::new(|entry: &SerializedStateAction| {
86                let action: S::Action =
87                    serde_json::from_value(entry.payload.clone()).map_err(|e| {
88                        StateActionDecodeError::DeserializationFailed {
89                            state_type: entry.state_type_name.clone(),
90                            source: e,
91                        }
92                    })?;
93                match entry.scope {
94                    StateScope::Thread | StateScope::Run => {
95                        Ok(AnyStateAction::new_at::<S>(entry.base_path.clone(), action))
96                    }
97                    StateScope::ToolCall => {
98                        let call_id = entry.call_id_override.as_deref().unwrap_or("");
99                        Ok(AnyStateAction::new_for_call_at::<S>(
100                            entry.base_path.clone(),
101                            action,
102                            call_id.to_owned(),
103                        ))
104                    }
105                }
106            }),
107        );
108    }
109
110    /// Deserialize a [`SerializedStateAction`] back into an [`AnyStateAction`].
111    pub fn deserialize(
112        &self,
113        entry: &SerializedStateAction,
114    ) -> Result<AnyStateAction, StateActionDecodeError> {
115        let factory = self.factories.get(&entry.state_type_name).ok_or_else(|| {
116            StateActionDecodeError::UnknownStateType(entry.state_type_name.clone())
117        })?;
118        factory(entry)
119    }
120}
121
122impl Default for StateActionDeserializerRegistry {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl std::fmt::Debug for StateActionDeserializerRegistry {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_struct("StateActionDeserializerRegistry")
131            .field(
132                "registered_types",
133                &self.factories.keys().collect::<Vec<_>>(),
134            )
135            .finish()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::super::scope_context::ScopeContext;
    use super::super::spec::reduce_state_actions;
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_state::{apply_patch, DocCell, PatchSink, Path, State, TireaResult};

    /// Reducer body shared by both test counter types.
    fn apply_counter_action(value: &mut i64, action: TestCounterAction) {
        *value = match action {
            TestCounterAction::Increment(delta) => *value + delta,
            TestCounterAction::Reset => 0,
        };
    }

    #[derive(Debug, Serialize, Deserialize)]
    enum TestCounterAction {
        Increment(i64),
        Reset,
    }

    // -- Test state type using the default scope --

    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct TestCounter {
        value: i64,
    }

    struct TestCounterRef;

    impl State for TestCounter {
        type Ref<'a> = TestCounterRef;
        const PATH: &'static str = "test_counter";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            TestCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                Ok(Self::default())
            } else {
                serde_json::from_value(value.clone())
                    .map_err(tirea_state::TireaError::Serialization)
            }
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for TestCounter {
        type Action = TestCounterAction;

        fn reduce(&mut self, action: TestCounterAction) {
            apply_counter_action(&mut self.value, action);
        }
    }

    // -- ToolCall-scoped test state type --

    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct ToolCallTestCounter {
        value: i64,
    }

    struct ToolCallTestCounterRef;

    impl State for ToolCallTestCounter {
        type Ref<'a> = ToolCallTestCounterRef;
        const PATH: &'static str = "tc_counter";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            ToolCallTestCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                Ok(Self::default())
            } else {
                serde_json::from_value(value.clone())
                    .map_err(tirea_state::TireaError::Serialization)
            }
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for ToolCallTestCounter {
        type Action = TestCounterAction;
        const SCOPE: StateScope = StateScope::ToolCall;

        fn reduce(&mut self, action: TestCounterAction) {
            apply_counter_action(&mut self.value, action);
        }
    }

    #[test]
    fn to_serialized_state_action_roundtrip() {
        let serialized = AnyStateAction::new::<TestCounter>(TestCounterAction::Increment(42))
            .to_serialized_state_action();

        assert!(serialized.state_type_name.contains("TestCounter"));
        assert_eq!(serialized.base_path, "test_counter");
        assert_eq!(serialized.scope, StateScope::Thread);
        assert_eq!(serialized.call_id_override, None);
        assert_eq!(serialized.payload, json!({ "Increment": 42 }));
    }

    #[test]
    fn registry_deserialize_and_reduce_roundtrip() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<TestCounter>();

        let original = AnyStateAction::new::<TestCounter>(TestCounterAction::Increment(7));
        let serialized = original.to_serialized_state_action();
        let reconstructed = registry
            .deserialize(&serialized)
            .expect("registered type should deserialize");

        // Reduce each action against the same empty base, then apply its first patch.
        let base = json!({});
        let reduce_then_apply = |action: AnyStateAction| {
            let patches =
                reduce_state_actions(vec![action], &base, "test", &ScopeContext::run())
                    .expect("reduce should succeed");
            apply_patch(&base, patches[0].patch()).expect("patch should apply")
        };
        let from_original = reduce_then_apply(original);
        let from_reconstructed = reduce_then_apply(reconstructed);

        assert_eq!(from_original, from_reconstructed);
        assert_eq!(from_original["test_counter"]["value"], 7);
    }

    #[test]
    fn registry_unknown_type_returns_error() {
        let registry = StateActionDeserializerRegistry::new();
        let entry = SerializedStateAction {
            state_type_name: "unknown::Type".into(),
            base_path: "x".into(),
            scope: StateScope::Run,
            call_id_override: None,
            payload: json!(null),
        };

        let outcome = registry.deserialize(&entry);
        assert!(matches!(
            outcome,
            Err(StateActionDecodeError::UnknownStateType(_))
        ));
    }

    #[test]
    fn registry_bad_payload_returns_deserialization_error() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<TestCounter>();

        let entry = SerializedStateAction {
            state_type_name: std::any::type_name::<TestCounter>().into(),
            base_path: "test_counter".into(),
            scope: StateScope::Run,
            call_id_override: None,
            payload: json!({"BadVariant": 99}),
        };

        let outcome = registry.deserialize(&entry);
        assert!(matches!(
            outcome,
            Err(StateActionDecodeError::DeserializationFailed { .. })
        ));
    }

    #[test]
    fn tool_call_scoped_roundtrip() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<ToolCallTestCounter>();

        let serialized = AnyStateAction::new_for_call::<ToolCallTestCounter>(
            TestCounterAction::Increment(3),
            "call_99",
        )
        .to_serialized_state_action();
        assert_eq!(serialized.scope, StateScope::ToolCall);
        assert_eq!(serialized.call_id_override.as_deref(), Some("call_99"));

        let reconstructed = registry
            .deserialize(&serialized)
            .expect("registered tool-call type should deserialize");

        // Reduced under a run context: the persisted call_id override still
        // routes the write into the `call_99` tool-call bucket.
        let base = json!({});
        let patches =
            reduce_state_actions(vec![reconstructed], &base, "test", &ScopeContext::run())
                .expect("reduce should succeed");
        let applied = apply_patch(&base, patches[0].patch()).expect("patch should apply");
        assert_eq!(
            applied["__tool_call_scope"]["call_99"]["tc_counter"]["value"],
            3
        );
    }
}
329}