tirea_contract/runtime/state/
serialized_state_action.rs1use super::spec::{AnyStateAction, StateScope};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use tirea_state::StateSpec;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SerializedStateAction {
15 pub state_type_name: String,
17 pub base_path: String,
19 pub scope: StateScope,
21 pub call_id_override: Option<String>,
23 pub payload: Value,
25}
26
27#[derive(Debug, thiserror::Error)]
29pub enum StateActionDecodeError {
30 #[error("unknown state type: {0}")]
31 UnknownStateType(String),
32 #[error("action deserialization failed for {state_type}: {source}")]
33 DeserializationFailed {
34 state_type: String,
35 source: serde_json::Error,
36 },
37}
38
39impl AnyStateAction {
44 pub fn to_serialized_state_action(&self) -> SerializedStateAction {
46 SerializedStateAction {
47 state_type_name: self.state_type_name().to_owned(),
48 base_path: self.base_path().to_owned(),
49 scope: self.scope(),
50 call_id_override: self.call_id_override().map(str::to_owned),
51 payload: self.serialized_payload().clone(),
52 }
53 }
54}
55
56type ActionFactory = Box<
61 dyn Fn(&SerializedStateAction) -> Result<AnyStateAction, StateActionDecodeError> + Send + Sync,
62>;
63
64pub struct StateActionDeserializerRegistry {
70 factories: HashMap<String, ActionFactory>,
71}
72
73impl StateActionDeserializerRegistry {
74 pub fn new() -> Self {
75 Self {
76 factories: HashMap::new(),
77 }
78 }
79
80 pub fn register<S: StateSpec>(&mut self) {
82 let type_name = std::any::type_name::<S>().to_owned();
83 self.factories.insert(
84 type_name,
85 Box::new(|entry: &SerializedStateAction| {
86 let action: S::Action =
87 serde_json::from_value(entry.payload.clone()).map_err(|e| {
88 StateActionDecodeError::DeserializationFailed {
89 state_type: entry.state_type_name.clone(),
90 source: e,
91 }
92 })?;
93 match entry.scope {
94 StateScope::Thread | StateScope::Run => {
95 Ok(AnyStateAction::new_at::<S>(entry.base_path.clone(), action))
96 }
97 StateScope::ToolCall => {
98 let call_id = entry.call_id_override.as_deref().unwrap_or("");
99 Ok(AnyStateAction::new_for_call_at::<S>(
100 entry.base_path.clone(),
101 action,
102 call_id.to_owned(),
103 ))
104 }
105 }
106 }),
107 );
108 }
109
110 pub fn deserialize(
112 &self,
113 entry: &SerializedStateAction,
114 ) -> Result<AnyStateAction, StateActionDecodeError> {
115 let factory = self.factories.get(&entry.state_type_name).ok_or_else(|| {
116 StateActionDecodeError::UnknownStateType(entry.state_type_name.clone())
117 })?;
118 factory(entry)
119 }
120}
121
122impl Default for StateActionDeserializerRegistry {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl std::fmt::Debug for StateActionDeserializerRegistry {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("StateActionDeserializerRegistry")
131 .field(
132 "registered_types",
133 &self.factories.keys().collect::<Vec<_>>(),
134 )
135 .finish()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::super::scope_context::ScopeContext;
    use super::super::spec::reduce_state_actions;
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_state::{apply_patch, DocCell, PatchSink, Path, State, TireaResult};

    /// Counter state with the default (thread) scope, stored at `test_counter`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct TestCounter {
        value: i64,
    }

    struct TestCounterRef;

    impl State for TestCounter {
        type Ref<'a> = TestCounterRef;
        const PATH: &'static str = "test_counter";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            TestCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    /// Action type shared by both test counters.
    #[derive(Debug, Serialize, Deserialize)]
    enum TestCounterAction {
        Increment(i64),
        Reset,
    }

    /// Reducer logic shared by both test counters.
    fn apply_counter_action(value: &mut i64, action: TestCounterAction) {
        match action {
            TestCounterAction::Increment(n) => *value += n,
            TestCounterAction::Reset => *value = 0,
        }
    }

    impl StateSpec for TestCounter {
        type Action = TestCounterAction;

        fn reduce(&mut self, action: TestCounterAction) {
            apply_counter_action(&mut self.value, action);
        }
    }

    /// Counter state with tool-call scope, stored at `tc_counter`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct ToolCallTestCounter {
        value: i64,
    }

    struct ToolCallTestCounterRef;

    impl State for ToolCallTestCounter {
        type Ref<'a> = ToolCallTestCounterRef;
        const PATH: &'static str = "tc_counter";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            ToolCallTestCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for ToolCallTestCounter {
        type Action = TestCounterAction;
        const SCOPE: StateScope = StateScope::ToolCall;

        fn reduce(&mut self, action: TestCounterAction) {
            apply_counter_action(&mut self.value, action);
        }
    }

    #[test]
    fn to_serialized_state_action_roundtrip() {
        let original = AnyStateAction::new::<TestCounter>(TestCounterAction::Increment(42));
        let serialized = original.to_serialized_state_action();

        // The registry keys decoders by `type_name`, so the recorded name must match it exactly.
        assert_eq!(
            serialized.state_type_name,
            std::any::type_name::<TestCounter>()
        );
        assert_eq!(serialized.base_path, "test_counter");
        assert_eq!(serialized.scope, StateScope::Thread);
        assert!(serialized.call_id_override.is_none());
        assert_eq!(serialized.payload, json!({"Increment": 42}));
    }

    #[test]
    fn registry_deserialize_and_reduce_roundtrip() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<TestCounter>();

        let original = AnyStateAction::new::<TestCounter>(TestCounterAction::Increment(7));
        let serialized = original.to_serialized_state_action();

        let reconstructed = registry.deserialize(&serialized).unwrap();

        // Reducing the original and the reconstructed action must yield the same document.
        let base = json!({});
        let original_patches =
            reduce_state_actions(vec![original], &base, "test", &ScopeContext::run()).unwrap();
        let reconstructed_patches =
            reduce_state_actions(vec![reconstructed], &base, "test", &ScopeContext::run()).unwrap();

        let result_a = apply_patch(&base, original_patches[0].patch()).unwrap();
        let result_b = apply_patch(&base, reconstructed_patches[0].patch()).unwrap();
        assert_eq!(result_a, result_b);
        assert_eq!(result_a["test_counter"]["value"], 7);
    }

    #[test]
    fn registry_unknown_type_returns_error() {
        let registry = StateActionDeserializerRegistry::default();
        let entry = SerializedStateAction {
            state_type_name: "unknown::Type".into(),
            base_path: "x".into(),
            scope: StateScope::Run,
            call_id_override: None,
            payload: json!(null),
        };
        let err = registry.deserialize(&entry).unwrap_err();
        assert!(matches!(err, StateActionDecodeError::UnknownStateType(_)));
    }

    #[test]
    fn registry_bad_payload_returns_deserialization_error() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<TestCounter>();

        let entry = SerializedStateAction {
            state_type_name: std::any::type_name::<TestCounter>().into(),
            base_path: "test_counter".into(),
            scope: StateScope::Run,
            call_id_override: None,
            payload: json!({"BadVariant": 99}),
        };
        let err = registry.deserialize(&entry).unwrap_err();
        assert!(matches!(
            err,
            StateActionDecodeError::DeserializationFailed { .. }
        ));
    }

    #[test]
    fn tool_call_scoped_roundtrip() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<ToolCallTestCounter>();

        let original = AnyStateAction::new_for_call::<ToolCallTestCounter>(
            TestCounterAction::Increment(3),
            "call_99",
        );
        let serialized = original.to_serialized_state_action();
        assert_eq!(serialized.scope, StateScope::ToolCall);
        assert_eq!(serialized.call_id_override, Some("call_99".into()));

        let reconstructed = registry.deserialize(&serialized).unwrap();

        let base = json!({});
        let patches =
            reduce_state_actions(vec![reconstructed], &base, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&base, patches[0].patch()).unwrap();
        assert_eq!(
            result["__tool_call_scope"]["call_99"]["tc_counter"]["value"],
            3
        );
    }

    #[test]
    fn serialized_state_action_survives_json_text_roundtrip() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<ToolCallTestCounter>();

        let original = AnyStateAction::new_for_call::<ToolCallTestCounter>(
            TestCounterAction::Increment(5),
            "call_7",
        );
        // Go through the textual wire format, as a persisted entry would.
        let text = serde_json::to_string(&original.to_serialized_state_action()).unwrap();
        let decoded: SerializedStateAction = serde_json::from_str(&text).unwrap();

        assert_eq!(
            decoded.state_type_name,
            std::any::type_name::<ToolCallTestCounter>()
        );
        assert_eq!(decoded.base_path, "tc_counter");
        assert_eq!(decoded.scope, StateScope::ToolCall);
        assert_eq!(decoded.call_id_override.as_deref(), Some("call_7"));
        assert_eq!(decoded.payload, json!({"Increment": 5}));

        let reconstructed = registry.deserialize(&decoded).unwrap();
        let base = json!({});
        let patches =
            reduce_state_actions(vec![reconstructed], &base, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&base, patches[0].patch()).unwrap();
        assert_eq!(
            result["__tool_call_scope"]["call_7"]["tc_counter"]["value"],
            5
        );
    }

    #[test]
    fn registry_debug_lists_registered_types() {
        let mut registry = StateActionDeserializerRegistry::new();
        registry.register::<TestCounter>();

        let rendered = format!("{registry:?}");
        assert!(rendered.contains("StateActionDeserializerRegistry"));
        assert!(rendered.contains(std::any::type_name::<TestCounter>()));
    }
}