tirea_contract/runtime/state/
spec.rs

1use super::scope_context::ScopeContext;
2use serde_json::Value;
3use std::any::TypeId;
4use std::fmt;
5use tirea_state::{
6    apply_patch_with_registry, get_at_path, parse_path, LatticeRegistry, Patch, Path, TireaResult,
7    TrackedPatch,
8};
9
10// Re-export from tirea-state so downstream code still works.
11pub use tirea_state::{StateScope, StateSpec};
12
13type ReduceFn = Box<dyn FnOnce(&Value, &str) -> TireaResult<Patch> + Send>;
14
15/// Type-erased state action that can be applied to a JSON document.
16///
17/// Created via [`AnyStateAction::new`] / [`new_at`](Self::new_at) /
18/// [`new_for_call`](Self::new_for_call) from a concrete `StateSpec` type and
19/// reducer action.
20pub struct AnyStateAction {
21    state_type_id: TypeId,
22    state_type_name: &'static str,
23    scope: StateScope,
24    base_path: String,
25    /// When set, overrides the `ScopeContext` call_id for path resolution.
26    /// Used by recovery/framework-internal scenarios that must target a
27    /// specific call_id without a live `ScopeContext`.
28    call_id_override: Option<String>,
29    reduce_fn: ReduceFn,
30    /// Type-erased lattice registration function captured from `S::register_lattice`.
31    /// Enables `reduce_state_actions` to build a local registry for
32    /// CRDT-aware rolling snapshot application.
33    register_lattice: fn(&mut LatticeRegistry),
34    /// Serialized action payload captured before the action is moved into
35    /// `reduce_fn`. Enables pending-write persistence for crash recovery
36    /// without requiring access to the concrete action type.
37    serialized_payload: Value,
38}
39
40impl AnyStateAction {
41    /// Create a type-erased action for non-ToolCall-scoped state `S`.
42    ///
43    /// The scope is read from `S::SCOPE` (Thread or Run). For ToolCall-scoped
44    /// state, use [`new_for_call`](Self::new_for_call) instead.
45    ///
46    /// # Panics
47    ///
48    /// Panics if `S::PATH` is empty or `S::SCOPE` is `ToolCall`.
49    pub fn new<S: StateSpec>(action: S::Action) -> Self {
50        assert!(
51            S::SCOPE != StateScope::ToolCall,
52            "ToolCall-scoped state must use new_for_call(); got new() for {}",
53            std::any::type_name::<S>(),
54        );
55        Self::build::<S>(action, S::SCOPE, S::PATH.to_owned(), None)
56    }
57
58    /// Create a type-erased action targeting an explicit thread/run base path.
59    ///
60    /// This is the preferred way to use typed reducers with dynamically chosen
61    /// state paths while still avoiding raw patch actions.
62    ///
63    /// # Panics
64    ///
65    /// Panics if `S::SCOPE` is `ToolCall`.
66    pub fn new_at<S: StateSpec>(path: impl Into<String>, action: S::Action) -> Self {
67        assert!(
68            S::SCOPE != StateScope::ToolCall,
69            "ToolCall-scoped state must use new_for_call() / new_for_call_at(); got new_at() for {}",
70            std::any::type_name::<S>(),
71        );
72        Self::build::<S>(action, S::SCOPE, path.into(), None)
73    }
74
75    /// Create a type-erased action targeting a specific tool call scope.
76    ///
77    /// The `call_id` determines which `__tool_call_scope.<id>` namespace the
78    /// action is routed to.
79    ///
80    /// # Panics
81    ///
82    /// Panics if `S::PATH` is empty or `S::SCOPE` is not `ToolCall`.
83    pub fn new_for_call<S: StateSpec>(action: S::Action, call_id: impl Into<String>) -> Self {
84        assert!(
85            S::SCOPE == StateScope::ToolCall,
86            "new_for_call() requires ToolCall-scoped state; {} has scope {:?}",
87            std::any::type_name::<S>(),
88            S::SCOPE,
89        );
90        Self::build::<S>(
91            action,
92            StateScope::ToolCall,
93            S::PATH.to_owned(),
94            Some(call_id.into()),
95        )
96    }
97
98    /// Create a type-erased tool-call-scoped action targeting an explicit path.
99    ///
100    /// # Panics
101    ///
102    /// Panics if `S::SCOPE` is not `ToolCall`.
103    pub fn new_for_call_at<S: StateSpec>(
104        path: impl Into<String>,
105        action: S::Action,
106        call_id: impl Into<String>,
107    ) -> Self {
108        assert!(
109            S::SCOPE == StateScope::ToolCall,
110            "new_for_call_at() requires ToolCall-scoped state; {} has scope {:?}",
111            std::any::type_name::<S>(),
112            S::SCOPE,
113        );
114        Self::build::<S>(
115            action,
116            StateScope::ToolCall,
117            path.into(),
118            Some(call_id.into()),
119        )
120    }
121
122    fn build<S: StateSpec>(
123        action: S::Action,
124        scope: StateScope,
125        base_path: String,
126        call_id_override: Option<String>,
127    ) -> Self {
128        let serialized_payload =
129            serde_json::to_value(&action).expect("StateSpec::Action must be serializable");
130
131        Self {
132            state_type_id: TypeId::of::<S>(),
133            state_type_name: std::any::type_name::<S>(),
134            scope,
135            base_path,
136            call_id_override,
137            reduce_fn: Self::make_reduce_fn::<S>(action),
138            register_lattice: S::register_lattice,
139            serialized_payload,
140        }
141    }
142
143    fn make_reduce_fn<S: StateSpec>(action: S::Action) -> ReduceFn {
144        Box::new(move |doc: &Value, actual_path: &str| {
145            let path = parse_path(actual_path);
146            let sub_doc = get_at_path(doc, &path).cloned().unwrap_or(Value::Null);
147            // Track whether the state is being created for the first time.
148            // When true, we must emit a whole-state Op::set rather than a
149            // per-field diff, because diff_ops would skip fields that match
150            // the default (e.g., status=Running when default is Running).
151            let is_creation = sub_doc.is_null() || sub_doc == Value::Object(Default::default());
152
153            // When the path doesn't exist (Null) and from_value fails,
154            // fall back to an empty object. This handles derive(State) structs
155            // whose #[serde(default)] fields can deserialize from `{}` but not
156            // from `null` (serde_json rejects null for struct types).
157            let mut state = S::from_value(&sub_doc).or_else(|first_err| {
158                if sub_doc.is_null() {
159                    S::from_value(&Value::Object(Default::default())).map_err(|_| first_err)
160                } else {
161                    Err(first_err)
162                }
163            })?;
164
165            if is_creation && S::lattice_keys().is_empty() {
166                // First-time creation of non-CRDT state: emit whole-state
167                // Op::set so all fields (including those matching defaults)
168                // are materialised in the document.
169                state.reduce(action);
170                let new_value = state.to_value()?;
171                let base_path = path_from_str(actual_path);
172                return Ok(Patch::with_ops(vec![tirea_state::Op::set(
173                    base_path, new_value,
174                )]));
175            }
176
177            // For CRDT types (or updates to existing state): use diff_ops
178            // so lattice fields correctly emit Op::LatticeMerge.
179            let old = state.clone();
180            state.reduce(action);
181
182            let base_path = path_from_str(actual_path);
183            let ops = S::diff_ops(&old, &state, &base_path)?;
184            if ops.is_empty() {
185                return Ok(Patch::default());
186            }
187            Ok(Patch::with_ops(ops))
188        })
189    }
190
191    /// The [`TypeId`] of the state type this action targets.
192    pub fn state_type_id(&self) -> TypeId {
193        self.state_type_id
194    }
195
196    /// Human-readable name of the state type (for diagnostics).
197    pub fn state_type_name(&self) -> &str {
198        self.state_type_name
199    }
200
201    /// Scope of the targeted state.
202    pub fn scope(&self) -> StateScope {
203        self.scope
204    }
205
206    /// Canonical base JSON path for the targeted state.
207    pub fn base_path(&self) -> &str {
208        &self.base_path
209    }
210
211    /// Optional tool-call scope override captured for recovery/internal flows.
212    pub fn call_id_override(&self) -> Option<&str> {
213        self.call_id_override.as_deref()
214    }
215
216    /// The serialized action payload captured before the action is moved into
217    /// the reduce closure.
218    pub fn serialized_payload(&self) -> &Value {
219        &self.serialized_payload
220    }
221}
222
223/// Reduce a batch of state actions into tracked patches with rolling snapshot semantics.
224///
225/// Typed actions are reduced against a snapshot that is updated after each
226/// action, so sequential actions in one batch compose deterministically.
227///
228/// `scope_ctx` controls how `ToolCall`-scoped actions are routed to per-call
229/// namespaces. For Thread/Run phases (anything outside a tool-call), pass
230/// `ScopeContext::run()`.
231pub fn reduce_state_actions(
232    actions: Vec<AnyStateAction>,
233    base_snapshot: &Value,
234    default_source: &str,
235    scope_ctx: &ScopeContext,
236) -> TireaResult<Vec<TrackedPatch>> {
237    // Build a local lattice registry from Typed actions so that the rolling
238    // snapshot correctly handles Op::LatticeMerge ops (rather than falling back
239    // to Op::Set semantics).
240    let mut local_registry = LatticeRegistry::new();
241    for action in &actions {
242        (action.register_lattice)(&mut local_registry);
243    }
244
245    let mut rolling_snapshot = base_snapshot.clone();
246    let mut tracked_patches = Vec::new();
247
248    for action in actions {
249        // Resolve actual storage path: call_id_override takes priority,
250        // then fall back to the ambient scope_ctx.
251        let actual_path = if let Some(ref cid) = action.call_id_override {
252            let override_ctx = ScopeContext::for_call(cid.as_str());
253            override_ctx.resolve_path(action.scope, action.base_path.as_str())
254        } else {
255            scope_ctx.resolve_path(action.scope, action.base_path.as_str())
256        };
257        let patch = (action.reduce_fn)(&rolling_snapshot, &actual_path)?;
258        if patch.is_empty() {
259            continue;
260        }
261        rolling_snapshot = apply_patch_with_registry(&rolling_snapshot, &patch, &local_registry)?;
262        tracked_patches.push(TrackedPatch::new(patch).with_source(default_source));
263    }
264
265    Ok(tracked_patches)
266}
267
268impl fmt::Debug for AnyStateAction {
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        f.debug_struct("AnyStateAction")
271            .field("state", &self.state_type_name)
272            .field("type_id", &self.state_type_id)
273            .field("scope", &self.scope)
274            .field("payload", &self.serialized_payload)
275            .finish()
276    }
277}
278
279/// Convert a dot-separated path string to a `Path` for use in `Op::set`.
280fn path_from_str(s: &str) -> Path {
281    let mut path = Path::root();
282    for seg in s.split('.') {
283        if !seg.is_empty() {
284            path = path.key(seg);
285        }
286    }
287    path
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_state::{
        apply_patch, conflicts_with_registry, DocCell, GCounter, LatticeRegistry, Op, PatchSink,
        Path as TPath, State,
    };

    // -- Manual State + StateSpec impl for testing --

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
    struct Counter {
        value: i64,
    }

    struct CounterRef;

    impl State for Counter {
        type Ref<'a> = CounterRef;
        const PATH: &'static str = "counters.main";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            CounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    #[derive(Debug, Serialize, Deserialize)]
    enum CounterAction {
        Increment(i64),
        Reset,
    }

    impl StateSpec for Counter {
        type Action = CounterAction;

        fn reduce(&mut self, action: CounterAction) {
            match action {
                CounterAction::Increment(n) => self.value += n,
                CounterAction::Reset => self.value = 0,
            }
        }
    }

    // -- Thread-scoped state with no bound path (used for scope-mismatch panic tests) --

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Unbound {
        x: i64,
    }

    struct UnboundRef;

    impl State for Unbound {
        type Ref<'a> = UnboundRef;
        // PATH defaults to "" (no bound path)

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            UnboundRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for Unbound {
        type Action = ();
        fn reduce(&mut self, _: ()) {}
    }

    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct ToolScopedCounter {
        value: i64,
    }

    struct ToolScopedCounterRef;

    impl State for ToolScopedCounter {
        type Ref<'a> = ToolScopedCounterRef;
        const PATH: &'static str = "tool_counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            ToolScopedCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for ToolScopedCounter {
        type Action = CounterAction;
        const SCOPE: StateScope = StateScope::ToolCall;

        fn reduce(&mut self, action: Self::Action) {
            match action {
                CounterAction::Increment(n) => self.value += n,
                CounterAction::Reset => self.value = 0,
            }
        }
    }

    // -- Tests --

    #[test]
    fn any_state_action_increment() {
        let doc = json!({"counters": {"main": {"value": 5}}});
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(3));
        let patch = reduce_state_actions(vec![action], &doc, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&doc, patch[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 8);
    }

    #[test]
    fn any_state_action_reset() {
        let doc = json!({"counters": {"main": {"value": 42}}});
        let action = AnyStateAction::new::<Counter>(CounterAction::Reset);
        let patch = reduce_state_actions(vec![action], &doc, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&doc, patch[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 0);
    }

    #[test]
    fn any_state_action_missing_path_defaults() {
        let doc = json!({});
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        let patch = reduce_state_actions(vec![action], &doc, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&doc, patch[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 1);
    }

    #[test]
    fn any_state_action_label() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        assert!(action.state_type_name().contains("Counter"));
    }

    #[test]
    fn any_state_action_debug() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        let debug = format!("{action:?}");
        assert!(debug.contains("AnyStateAction"));
        assert!(debug.contains("Counter"));
    }

    #[test]
    fn any_state_action_state_type_id() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        assert_eq!(action.state_type_id(), TypeId::of::<Counter>());
    }

    #[test]
    fn any_state_action_scope_defaults_to_thread() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        assert_eq!(action.scope(), StateScope::Thread);
    }

    #[test]
    fn any_state_action_scope_tool_call() {
        let action = AnyStateAction::new_for_call::<ToolScopedCounter>(
            CounterAction::Increment(1),
            "call_1",
        );
        assert_eq!(action.scope(), StateScope::ToolCall);
    }

    #[test]
    fn new_uses_spec_path_and_no_call_override() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        assert_eq!(action.base_path(), "counters.main");
        assert_eq!(action.call_id_override(), None);
    }

    #[test]
    fn new_for_call_records_spec_path_and_call_override() {
        let action = AnyStateAction::new_for_call::<ToolScopedCounter>(
            CounterAction::Increment(1),
            "call_1",
        );
        assert_eq!(action.base_path(), "tool_counter");
        assert_eq!(action.call_id_override(), Some("call_1"));
    }

    #[test]
    fn reduce_state_actions_uses_rolling_snapshot() {
        let base = json!({"counters": {"main": {"value": 1}}});
        let actions = vec![
            AnyStateAction::new::<Counter>(CounterAction::Increment(1)),
            AnyStateAction::new::<Counter>(CounterAction::Increment(1)),
        ];
        let tracked = reduce_state_actions(actions, &base, "agent", &ScopeContext::run()).unwrap();
        assert_eq!(tracked.len(), 2);

        let mut state = base.clone();
        for patch in tracked {
            state = apply_patch(&state, patch.patch()).unwrap();
        }
        assert_eq!(state["counters"]["main"]["value"], 3);
    }

    #[test]
    fn any_state_action_allows_root_path() {
        let tracked = reduce_state_actions(
            vec![AnyStateAction::new_at::<Counter>(
                "",
                CounterAction::Increment(1),
            )],
            &Value::Null,
            "test",
            &ScopeContext::run(),
        )
        .expect("root path action should reduce");
        assert_eq!(tracked.len(), 1);
        let result = apply_patch(&Value::Null, tracked[0].patch()).expect("patch should apply");
        assert_eq!(result["value"], 1);
    }

    #[test]
    fn new_at_routes_to_explicit_path() {
        let base = json!({});
        let action = AnyStateAction::new_at::<Counter>("alt.counter", CounterAction::Increment(4));
        assert_eq!(action.base_path(), "alt.counter");
        assert_eq!(action.call_id_override(), None);

        let tracked = reduce_state_actions(vec![action], &base, "test", &ScopeContext::run()).unwrap();
        assert_eq!(tracked.len(), 1);

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        assert_eq!(result["alt"]["counter"]["value"], 4);
        // The spec's own PATH must not be touched.
        assert!(result.get("counters").is_none());
    }

    #[test]
    fn reduce_tool_call_scoped_action_routes_to_call_namespace() {
        let base = json!({});
        let actions = vec![AnyStateAction::new_for_call::<ToolScopedCounter>(
            CounterAction::Increment(5),
            "call_42",
        )];
        let tracked = reduce_state_actions(actions, &base, "test", &ScopeContext::run()).unwrap();
        assert_eq!(tracked.len(), 1);

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        assert_eq!(
            result["__tool_call_scope"]["call_42"]["tool_counter"]["value"],
            5
        );
    }

    #[test]
    fn new_for_call_at_routes_to_explicit_path_in_call_namespace() {
        let base = json!({});
        let action = AnyStateAction::new_for_call_at::<ToolScopedCounter>(
            "custom_counter",
            CounterAction::Increment(2),
            "call_9",
        );
        assert_eq!(action.scope(), StateScope::ToolCall);
        assert_eq!(action.base_path(), "custom_counter");
        assert_eq!(action.call_id_override(), Some("call_9"));

        let tracked = reduce_state_actions(vec![action], &base, "test", &ScopeContext::run()).unwrap();
        assert_eq!(tracked.len(), 1);

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        assert_eq!(
            result["__tool_call_scope"]["call_9"]["custom_counter"]["value"],
            2
        );
    }

    #[test]
    fn reduce_run_scoped_action_ignores_call_context() {
        let base = json!({});
        let scope_ctx = ScopeContext::for_call("call_42");
        let actions = vec![AnyStateAction::new::<Counter>(CounterAction::Increment(7))];
        let tracked = reduce_state_actions(actions, &base, "test", &scope_ctx).unwrap();

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 7);
        assert!(result.get("__tool_call_scope").is_none());
    }

    #[test]
    fn new_for_call_overrides_scope_ctx() {
        let base = json!({});
        let scope_ctx = ScopeContext::for_call("ambient_call");
        let actions = vec![AnyStateAction::new_for_call::<ToolScopedCounter>(
            CounterAction::Increment(3),
            "override_call",
        )];
        let tracked = reduce_state_actions(actions, &base, "test", &scope_ctx).unwrap();

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        // Should use the override call_id, not the ambient one
        assert_eq!(
            result["__tool_call_scope"]["override_call"]["tool_counter"]["value"],
            3
        );
        assert!(result["__tool_call_scope"].get("ambient_call").is_none());
    }

    #[test]
    #[should_panic(expected = "requires ToolCall-scoped state")]
    fn new_for_call_panics_on_non_tool_call_scope() {
        let _ = AnyStateAction::new_for_call::<Unbound>((), "call_1");
    }

    #[test]
    #[should_panic(expected = "requires ToolCall-scoped state")]
    fn new_for_call_at_panics_on_non_tool_call_scope() {
        let _ = AnyStateAction::new_for_call_at::<Unbound>("somewhere", (), "call_1");
    }

    #[test]
    #[should_panic(expected = "ToolCall-scoped state must use new_for_call()")]
    fn new_panics_on_tool_call_scope() {
        let _ = AnyStateAction::new::<ToolScopedCounter>(CounterAction::Increment(1));
    }

    #[test]
    #[should_panic(expected = "ToolCall-scoped state must use new_for_call()")]
    fn new_at_panics_on_tool_call_scope() {
        let _ = AnyStateAction::new_at::<ToolScopedCounter>("x", CounterAction::Increment(1));
    }

    // -- CRDT (lattice) field test types --

    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct TokenStats {
        #[serde(default)]
        total_input: GCounter,
        #[serde(default)]
        total_output: GCounter,
        #[serde(default)]
        label: String,
    }

    struct TokenStatsRef;

    impl State for TokenStats {
        type Ref<'a> = TokenStatsRef;
        const PATH: &'static str = "token_stats";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            TokenStatsRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }

        fn lattice_keys() -> &'static [&'static str] {
            &["total_input", "total_output"]
        }

        fn register_lattice(registry: &mut LatticeRegistry) {
            registry.register::<GCounter>(parse_path("token_stats.total_input"));
            registry.register::<GCounter>(parse_path("token_stats.total_output"));
        }
    }

    #[derive(Serialize, Deserialize)]
    enum TokenStatsAction {
        AddInput(u64),
        AddOutput(u64),
    }

    impl StateSpec for TokenStats {
        type Action = TokenStatsAction;

        fn reduce(&mut self, action: TokenStatsAction) {
            match action {
                TokenStatsAction::AddInput(n) => self.total_input.increment("_", n),
                TokenStatsAction::AddOutput(n) => self.total_output.increment("_", n),
            }
        }
    }

    #[test]
    fn parallel_crdt_reductions_do_not_conflict() {
        // Two plugins independently record tokens → parallel patches
        let base = json!({});

        let patches_a = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(100),
            )],
            &base,
            "plugin_a",
            &ScopeContext::run(),
        )
        .unwrap();
        let patches_b = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(200),
            )],
            &base,
            "plugin_b",
            &ScopeContext::run(),
        )
        .unwrap();

        // Register GCounter at field paths
        let mut registry = LatticeRegistry::new();
        registry.register::<GCounter>(parse_path("token_stats.total_input"));
        registry.register::<GCounter>(parse_path("token_stats.total_output"));

        let conflicts =
            conflicts_with_registry(patches_a[0].patch(), patches_b[0].patch(), &registry);

        // CRDT fields must use Op::LatticeMerge, so parallel writes don't conflict.
        assert!(
            conflicts.is_empty(),
            "CRDT fields should not conflict; reducer should emit Op::LatticeMerge for lattice fields"
        );
    }

    #[test]
    fn reducer_emits_lattice_merge_ops_for_crdt_fields() {
        let base = json!({});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(100),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        // Should have per-field ops, not a single whole-state Op::set
        let has_lattice_merge = ops.iter().any(|op| matches!(op, Op::LatticeMerge { .. }));
        assert!(
            has_lattice_merge,
            "reducer should emit Op::LatticeMerge for CRDT fields, got: {ops:?}"
        );
    }

    #[test]
    fn reducer_mixed_fields_emits_correct_op_types() {
        // Custom action that modifies both CRDT and non-CRDT fields
        #[derive(Debug, Clone, Serialize, Deserialize, Default)]
        struct MixedState {
            #[serde(default)]
            counter: GCounter,
            #[serde(default)]
            name: String,
        }

        struct MixedStateRef;

        impl State for MixedState {
            type Ref<'a> = MixedStateRef;
            const PATH: &'static str = "mixed";

            fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
                MixedStateRef
            }

            fn from_value(value: &Value) -> TireaResult<Self> {
                if value.is_null() {
                    return Ok(Self::default());
                }
                serde_json::from_value(value.clone())
                    .map_err(tirea_state::TireaError::Serialization)
            }

            fn to_value(&self) -> TireaResult<Value> {
                serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
            }

            fn lattice_keys() -> &'static [&'static str] {
                &["counter"]
            }
        }

        #[derive(Serialize, Deserialize)]
        enum MixedAction {
            IncrementAndRename(u64, String),
        }

        impl StateSpec for MixedState {
            type Action = MixedAction;

            fn reduce(&mut self, action: MixedAction) {
                match action {
                    MixedAction::IncrementAndRename(n, name) => {
                        self.counter.increment("_", n);
                        self.name = name;
                    }
                }
            }
        }

        let base = json!({});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<MixedState>(
                MixedAction::IncrementAndRename(5, "new".to_string()),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        let lattice_ops: Vec<_> = ops
            .iter()
            .filter(|op| matches!(op, Op::LatticeMerge { .. }))
            .collect();
        let set_ops: Vec<_> = ops
            .iter()
            .filter(|op| matches!(op, Op::Set { .. }))
            .collect();

        assert!(
            !lattice_ops.is_empty(),
            "should have LatticeMerge for CRDT field 'counter'"
        );
        assert!(
            !set_ops.is_empty(),
            "should have Op::set for non-CRDT field 'name'"
        );
    }

    #[test]
    fn diff_ops_skips_unchanged_fields() {
        // Only modify one CRDT field; the other fields should not appear in ops.
        let base = json!({"token_stats": {"total_input": {}, "total_output": {}, "label": ""}});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(42),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        // Only total_input changed → exactly one op
        assert_eq!(
            ops.len(),
            1,
            "should only emit op for the changed field, got: {ops:?}"
        );
        assert!(
            matches!(&ops[0], Op::LatticeMerge { .. }),
            "changed CRDT field should use LatticeMerge"
        );
    }

    #[test]
    fn diff_ops_skips_unchanged_fields_for_output_counter() {
        // Mirror of the test above for the second CRDT field.
        let base = json!({"token_stats": {"total_input": {}, "total_output": {}, "label": ""}});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddOutput(7),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        // Only total_output changed → exactly one op
        assert_eq!(
            ops.len(),
            1,
            "should only emit op for the changed field, got: {ops:?}"
        );
        assert!(
            matches!(&ops[0], Op::LatticeMerge { .. }),
            "changed CRDT field should use LatticeMerge"
        );
    }

    #[test]
    fn diff_ops_empty_when_no_changes() {
        // A reset on an already-zero counter produces no state change.
        let base = json!({"counters": {"main": {"value": 0}}});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<Counter>(CounterAction::Reset)],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        // No change → no patches emitted
        assert!(
            patches.is_empty(),
            "no-op reduce should produce no patches, got: {patches:?}"
        );
    }

    #[test]
    fn serialized_payload_is_captured() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(42));
        let payload = action.serialized_payload();
        assert_eq!(*payload, json!({"Increment": 42}));
    }

    #[test]
    fn serialized_payload_captured_for_call_scoped() {
        let action =
            AnyStateAction::new_for_call::<ToolScopedCounter>(CounterAction::Reset, "call_1");
        let payload = action.serialized_payload();
        assert_eq!(*payload, json!("Reset"));
    }
}