1use super::scope_context::ScopeContext;
2use serde_json::Value;
3use std::any::TypeId;
4use std::fmt;
5use tirea_state::{
6 apply_patch_with_registry, get_at_path, parse_path, LatticeRegistry, Patch, Path, TireaResult,
7 TrackedPatch,
8};
9
10pub use tirea_state::{StateScope, StateSpec};
12
13type ReduceFn = Box<dyn FnOnce(&Value, &str) -> TireaResult<Patch> + Send>;
14
15pub struct AnyStateAction {
21 state_type_id: TypeId,
22 state_type_name: &'static str,
23 scope: StateScope,
24 base_path: String,
25 call_id_override: Option<String>,
29 reduce_fn: ReduceFn,
30 register_lattice: fn(&mut LatticeRegistry),
34 serialized_payload: Value,
38}
39
40impl AnyStateAction {
41 pub fn new<S: StateSpec>(action: S::Action) -> Self {
50 assert!(
51 S::SCOPE != StateScope::ToolCall,
52 "ToolCall-scoped state must use new_for_call(); got new() for {}",
53 std::any::type_name::<S>(),
54 );
55 Self::build::<S>(action, S::SCOPE, S::PATH.to_owned(), None)
56 }
57
58 pub fn new_at<S: StateSpec>(path: impl Into<String>, action: S::Action) -> Self {
67 assert!(
68 S::SCOPE != StateScope::ToolCall,
69 "ToolCall-scoped state must use new_for_call() / new_for_call_at(); got new_at() for {}",
70 std::any::type_name::<S>(),
71 );
72 Self::build::<S>(action, S::SCOPE, path.into(), None)
73 }
74
75 pub fn new_for_call<S: StateSpec>(action: S::Action, call_id: impl Into<String>) -> Self {
84 assert!(
85 S::SCOPE == StateScope::ToolCall,
86 "new_for_call() requires ToolCall-scoped state; {} has scope {:?}",
87 std::any::type_name::<S>(),
88 S::SCOPE,
89 );
90 Self::build::<S>(
91 action,
92 StateScope::ToolCall,
93 S::PATH.to_owned(),
94 Some(call_id.into()),
95 )
96 }
97
98 pub fn new_for_call_at<S: StateSpec>(
104 path: impl Into<String>,
105 action: S::Action,
106 call_id: impl Into<String>,
107 ) -> Self {
108 assert!(
109 S::SCOPE == StateScope::ToolCall,
110 "new_for_call_at() requires ToolCall-scoped state; {} has scope {:?}",
111 std::any::type_name::<S>(),
112 S::SCOPE,
113 );
114 Self::build::<S>(
115 action,
116 StateScope::ToolCall,
117 path.into(),
118 Some(call_id.into()),
119 )
120 }
121
122 fn build<S: StateSpec>(
123 action: S::Action,
124 scope: StateScope,
125 base_path: String,
126 call_id_override: Option<String>,
127 ) -> Self {
128 let serialized_payload =
129 serde_json::to_value(&action).expect("StateSpec::Action must be serializable");
130
131 Self {
132 state_type_id: TypeId::of::<S>(),
133 state_type_name: std::any::type_name::<S>(),
134 scope,
135 base_path,
136 call_id_override,
137 reduce_fn: Self::make_reduce_fn::<S>(action),
138 register_lattice: S::register_lattice,
139 serialized_payload,
140 }
141 }
142
143 fn make_reduce_fn<S: StateSpec>(action: S::Action) -> ReduceFn {
144 Box::new(move |doc: &Value, actual_path: &str| {
145 let path = parse_path(actual_path);
146 let sub_doc = get_at_path(doc, &path).cloned().unwrap_or(Value::Null);
147 let is_creation = sub_doc.is_null() || sub_doc == Value::Object(Default::default());
152
153 let mut state = S::from_value(&sub_doc).or_else(|first_err| {
158 if sub_doc.is_null() {
159 S::from_value(&Value::Object(Default::default())).map_err(|_| first_err)
160 } else {
161 Err(first_err)
162 }
163 })?;
164
165 if is_creation && S::lattice_keys().is_empty() {
166 state.reduce(action);
170 let new_value = state.to_value()?;
171 let base_path = path_from_str(actual_path);
172 return Ok(Patch::with_ops(vec![tirea_state::Op::set(
173 base_path, new_value,
174 )]));
175 }
176
177 let old = state.clone();
180 state.reduce(action);
181
182 let base_path = path_from_str(actual_path);
183 let ops = S::diff_ops(&old, &state, &base_path)?;
184 if ops.is_empty() {
185 return Ok(Patch::default());
186 }
187 Ok(Patch::with_ops(ops))
188 })
189 }
190
191 pub fn state_type_id(&self) -> TypeId {
193 self.state_type_id
194 }
195
196 pub fn state_type_name(&self) -> &str {
198 self.state_type_name
199 }
200
201 pub fn scope(&self) -> StateScope {
203 self.scope
204 }
205
206 pub fn base_path(&self) -> &str {
208 &self.base_path
209 }
210
211 pub fn call_id_override(&self) -> Option<&str> {
213 self.call_id_override.as_deref()
214 }
215
216 pub fn serialized_payload(&self) -> &Value {
219 &self.serialized_payload
220 }
221}
222
223pub fn reduce_state_actions(
232 actions: Vec<AnyStateAction>,
233 base_snapshot: &Value,
234 default_source: &str,
235 scope_ctx: &ScopeContext,
236) -> TireaResult<Vec<TrackedPatch>> {
237 let mut local_registry = LatticeRegistry::new();
241 for action in &actions {
242 (action.register_lattice)(&mut local_registry);
243 }
244
245 let mut rolling_snapshot = base_snapshot.clone();
246 let mut tracked_patches = Vec::new();
247
248 for action in actions {
249 let actual_path = if let Some(ref cid) = action.call_id_override {
252 let override_ctx = ScopeContext::for_call(cid.as_str());
253 override_ctx.resolve_path(action.scope, action.base_path.as_str())
254 } else {
255 scope_ctx.resolve_path(action.scope, action.base_path.as_str())
256 };
257 let patch = (action.reduce_fn)(&rolling_snapshot, &actual_path)?;
258 if patch.is_empty() {
259 continue;
260 }
261 rolling_snapshot = apply_patch_with_registry(&rolling_snapshot, &patch, &local_registry)?;
262 tracked_patches.push(TrackedPatch::new(patch).with_source(default_source));
263 }
264
265 Ok(tracked_patches)
266}
267
268impl fmt::Debug for AnyStateAction {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 f.debug_struct("AnyStateAction")
271 .field("state", &self.state_type_name)
272 .field("type_id", &self.state_type_id)
273 .field("scope", &self.scope)
274 .field("payload", &self.serialized_payload)
275 .finish()
276 }
277}
278
279fn path_from_str(s: &str) -> Path {
281 let mut path = Path::root();
282 for seg in s.split('.') {
283 if !seg.is_empty() {
284 path = path.key(seg);
285 }
286 }
287 path
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_state::{
        apply_patch, conflicts_with_registry, DocCell, GCounter, LatticeRegistry, Op, PatchSink,
        Path as TPath, State,
    };

    // -- Fixture: thread-scoped counter stored at `counters.main` -----------

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
    struct Counter {
        value: i64,
    }

    struct CounterRef;

    impl State for Counter {
        type Ref<'a> = CounterRef;
        const PATH: &'static str = "counters.main";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            CounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    #[derive(Debug, Serialize, Deserialize)]
    enum CounterAction {
        Increment(i64),
        Reset,
    }

    impl StateSpec for Counter {
        type Action = CounterAction;

        fn reduce(&mut self, action: CounterAction) {
            match action {
                CounterAction::Increment(n) => self.value += n,
                CounterAction::Reset => self.value = 0,
            }
        }
    }

    // -- Fixture: default-scoped state whose from_value rejects null --------

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Unbound {
        x: i64,
    }

    struct UnboundRef;

    impl State for Unbound {
        type Ref<'a> = UnboundRef;
        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            UnboundRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for Unbound {
        type Action = ();
        fn reduce(&mut self, _: ()) {}
    }

    // -- Fixture: ToolCall-scoped counter at `tool_counter` -----------------

    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct ToolScopedCounter {
        value: i64,
    }

    struct ToolScopedCounterRef;

    impl State for ToolScopedCounter {
        type Ref<'a> = ToolScopedCounterRef;
        const PATH: &'static str = "tool_counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            ToolScopedCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for ToolScopedCounter {
        type Action = CounterAction;
        const SCOPE: StateScope = StateScope::ToolCall;

        fn reduce(&mut self, action: Self::Action) {
            match action {
                CounterAction::Increment(n) => self.value += n,
                CounterAction::Reset => self.value = 0,
            }
        }
    }

    #[test]
    fn any_state_action_increment() {
        let doc = json!({"counters": {"main": {"value": 5}}});
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(3));
        let tracked = reduce_state_actions(vec![action], &doc, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&doc, tracked[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 8);
    }

    #[test]
    fn any_state_action_reset() {
        let doc = json!({"counters": {"main": {"value": 42}}});
        let action = AnyStateAction::new::<Counter>(CounterAction::Reset);
        let tracked = reduce_state_actions(vec![action], &doc, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&doc, tracked[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 0);
    }

    #[test]
    fn any_state_action_missing_path_defaults() {
        let doc = json!({});
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        let tracked = reduce_state_actions(vec![action], &doc, "test", &ScopeContext::run()).unwrap();
        let result = apply_patch(&doc, tracked[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 1);
    }

    #[test]
    fn any_state_action_label() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        assert!(action.state_type_name().contains("Counter"));
    }

    #[test]
    fn any_state_action_debug() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        let debug = format!("{action:?}");
        assert!(debug.contains("AnyStateAction"));
        assert!(debug.contains("Counter"));
    }

    #[test]
    fn any_state_action_state_type_id() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        assert_eq!(action.state_type_id(), TypeId::of::<Counter>());
    }

    #[test]
    fn any_state_action_scope_defaults_to_thread() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(1));
        assert_eq!(action.scope(), StateScope::Thread);
    }

    #[test]
    fn any_state_action_scope_tool_call() {
        let action = AnyStateAction::new_for_call::<ToolScopedCounter>(
            CounterAction::Increment(1),
            "call_1",
        );
        assert_eq!(action.scope(), StateScope::ToolCall);
    }

    #[test]
    fn new_at_records_custom_path_without_call_id() {
        let action = AnyStateAction::new_at::<Counter>("custom.path", CounterAction::Reset);
        assert_eq!(action.base_path(), "custom.path");
        assert_eq!(action.call_id_override(), None);
    }

    #[test]
    fn new_for_call_records_call_id_and_default_path() {
        let action =
            AnyStateAction::new_for_call::<ToolScopedCounter>(CounterAction::Reset, "call_7");
        assert_eq!(action.call_id_override(), Some("call_7"));
        assert_eq!(action.base_path(), "tool_counter");
    }

    #[test]
    fn reduce_state_actions_uses_rolling_snapshot() {
        let base = json!({"counters": {"main": {"value": 1}}});
        let actions = vec![
            AnyStateAction::new::<Counter>(CounterAction::Increment(1)),
            AnyStateAction::new::<Counter>(CounterAction::Increment(1)),
        ];
        let tracked = reduce_state_actions(actions, &base, "agent", &ScopeContext::run()).unwrap();
        assert_eq!(tracked.len(), 2);

        let mut state = base.clone();
        for patch in tracked {
            state = apply_patch(&state, patch.patch()).unwrap();
        }
        assert_eq!(state["counters"]["main"]["value"], 3);
    }

    #[test]
    fn any_state_action_allows_root_path() {
        let tracked = reduce_state_actions(
            vec![AnyStateAction::new_at::<Counter>(
                "",
                CounterAction::Increment(1),
            )],
            &Value::Null,
            "test",
            &ScopeContext::run(),
        )
        .expect("root path action should reduce");
        assert_eq!(tracked.len(), 1);
        let result = apply_patch(&Value::Null, tracked[0].patch()).expect("patch should apply");
        assert_eq!(result["value"], 1);
    }

    #[test]
    fn reduce_tool_call_scoped_action_routes_to_call_namespace() {
        let base = json!({});
        let actions = vec![AnyStateAction::new_for_call::<ToolScopedCounter>(
            CounterAction::Increment(5),
            "call_42",
        )];
        let tracked = reduce_state_actions(actions, &base, "test", &ScopeContext::run()).unwrap();
        assert_eq!(tracked.len(), 1);

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        assert_eq!(
            result["__tool_call_scope"]["call_42"]["tool_counter"]["value"],
            5
        );
    }

    #[test]
    fn new_for_call_at_routes_custom_path_to_call_namespace() {
        let base = json!({});
        let actions = vec![AnyStateAction::new_for_call_at::<ToolScopedCounter>(
            "custom",
            CounterAction::Increment(2),
            "call_9",
        )];
        let tracked = reduce_state_actions(actions, &base, "test", &ScopeContext::run()).unwrap();
        assert_eq!(tracked.len(), 1);

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        assert_eq!(result["__tool_call_scope"]["call_9"]["custom"]["value"], 2);
    }

    #[test]
    fn reduce_run_scoped_action_ignores_call_context() {
        let base = json!({});
        let scope_ctx = ScopeContext::for_call("call_42");
        let actions = vec![AnyStateAction::new::<Counter>(CounterAction::Increment(7))];
        let tracked = reduce_state_actions(actions, &base, "test", &scope_ctx).unwrap();

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        assert_eq!(result["counters"]["main"]["value"], 7);
        assert!(result.get("__tool_call_scope").is_none());
    }

    #[test]
    fn new_for_call_overrides_scope_ctx() {
        let base = json!({});
        let scope_ctx = ScopeContext::for_call("ambient_call");
        let actions = vec![AnyStateAction::new_for_call::<ToolScopedCounter>(
            CounterAction::Increment(3),
            "override_call",
        )];
        let tracked = reduce_state_actions(actions, &base, "test", &scope_ctx).unwrap();

        let result = apply_patch(&base, tracked[0].patch()).unwrap();
        // The explicit call id wins over the ambient one.
        assert_eq!(
            result["__tool_call_scope"]["override_call"]["tool_counter"]["value"],
            3
        );
        assert!(result["__tool_call_scope"].get("ambient_call").is_none());
    }

    #[test]
    #[should_panic(expected = "requires ToolCall-scoped state")]
    fn new_for_call_panics_on_non_tool_call_scope() {
        let _ = AnyStateAction::new_for_call::<Unbound>((), "call_1");
    }

    #[test]
    #[should_panic(expected = "new_for_call_at() requires ToolCall-scoped state")]
    fn new_for_call_at_panics_on_non_tool_call_scope() {
        let _ = AnyStateAction::new_for_call_at::<Counter>("x", CounterAction::Reset, "call_1");
    }

    #[test]
    #[should_panic(expected = "ToolCall-scoped state must use new_for_call(); got new()")]
    fn new_panics_on_tool_call_scope() {
        let _ = AnyStateAction::new::<ToolScopedCounter>(CounterAction::Reset);
    }

    #[test]
    #[should_panic(expected = "got new_at()")]
    fn new_at_panics_on_tool_call_scope() {
        let _ = AnyStateAction::new_at::<ToolScopedCounter>("elsewhere", CounterAction::Reset);
    }

    // -- Fixture: state mixing lattice (GCounter) and plain fields ----------

    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct TokenStats {
        #[serde(default)]
        total_input: GCounter,
        #[serde(default)]
        total_output: GCounter,
        #[serde(default)]
        label: String,
    }

    struct TokenStatsRef;

    impl State for TokenStats {
        type Ref<'a> = TokenStatsRef;
        const PATH: &'static str = "token_stats";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            TokenStatsRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }

        fn lattice_keys() -> &'static [&'static str] {
            &["total_input", "total_output"]
        }

        fn register_lattice(registry: &mut LatticeRegistry) {
            registry.register::<GCounter>(parse_path("token_stats.total_input"));
            registry.register::<GCounter>(parse_path("token_stats.total_output"));
        }
    }

    #[derive(Serialize, Deserialize)]
    enum TokenStatsAction {
        AddInput(u64),
        AddOutput(u64),
    }

    impl StateSpec for TokenStats {
        type Action = TokenStatsAction;

        fn reduce(&mut self, action: TokenStatsAction) {
            match action {
                TokenStatsAction::AddInput(n) => self.total_input.increment("_", n),
                TokenStatsAction::AddOutput(n) => self.total_output.increment("_", n),
            }
        }
    }

    #[test]
    fn reducer_crdt_fields_do_not_conflict_across_sources() {
        // Regression guard: two sources independently incrementing the same
        // GCounter must not be reported as conflicting. This requires lattice
        // fields to be emitted as Op::LatticeMerge rather than a plain Op::set.
        let base = json!({});

        let patches_a = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(100),
            )],
            &base,
            "plugin_a",
            &ScopeContext::run(),
        )
        .unwrap();
        let patches_b = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(200),
            )],
            &base,
            "plugin_b",
            &ScopeContext::run(),
        )
        .unwrap();

        // Same lattice paths the reducer itself registers.
        let mut registry = LatticeRegistry::new();
        TokenStats::register_lattice(&mut registry);

        let conflicts =
            conflicts_with_registry(patches_a[0].patch(), patches_b[0].patch(), &registry);

        assert!(
            conflicts.is_empty(),
            "CRDT fields should not conflict; reducer should emit Op::LatticeMerge for lattice fields"
        );
    }

    #[test]
    fn reducer_emits_lattice_merge_ops_for_crdt_fields() {
        let base = json!({});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(100),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        let has_lattice_merge = ops.iter().any(|op| matches!(op, Op::LatticeMerge { .. }));
        assert!(
            has_lattice_merge,
            "reducer should emit Op::LatticeMerge for CRDT fields, got: {ops:?}"
        );
    }

    #[test]
    fn reducer_mixed_fields_emits_correct_op_types() {
        #[derive(Debug, Clone, Serialize, Deserialize, Default)]
        struct MixedState {
            #[serde(default)]
            counter: GCounter,
            #[serde(default)]
            name: String,
        }

        struct MixedStateRef;

        impl State for MixedState {
            type Ref<'a> = MixedStateRef;
            const PATH: &'static str = "mixed";

            fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
                MixedStateRef
            }

            fn from_value(value: &Value) -> TireaResult<Self> {
                if value.is_null() {
                    return Ok(Self::default());
                }
                serde_json::from_value(value.clone())
                    .map_err(tirea_state::TireaError::Serialization)
            }

            fn to_value(&self) -> TireaResult<Value> {
                serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
            }

            fn lattice_keys() -> &'static [&'static str] {
                &["counter"]
            }
        }

        #[derive(Serialize, Deserialize)]
        enum MixedAction {
            IncrementAndRename(u64, String),
        }

        impl StateSpec for MixedState {
            type Action = MixedAction;

            fn reduce(&mut self, action: MixedAction) {
                match action {
                    MixedAction::IncrementAndRename(n, name) => {
                        self.counter.increment("_", n);
                        self.name = name;
                    }
                }
            }
        }

        let base = json!({});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<MixedState>(
                MixedAction::IncrementAndRename(5, "new".to_string()),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        let lattice_ops: Vec<_> = ops
            .iter()
            .filter(|op| matches!(op, Op::LatticeMerge { .. }))
            .collect();
        let set_ops: Vec<_> = ops
            .iter()
            .filter(|op| matches!(op, Op::Set { .. }))
            .collect();

        assert!(
            !lattice_ops.is_empty(),
            "should have LatticeMerge for CRDT field 'counter'"
        );
        assert!(
            !set_ops.is_empty(),
            "should have Op::set for non-CRDT field 'name'"
        );
    }

    #[test]
    fn diff_ops_skips_unchanged_fields() {
        let base = json!({"token_stats": {"total_input": {}, "total_output": {}, "label": ""}});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddInput(42),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        assert_eq!(
            ops.len(),
            1,
            "should only emit op for the changed field, got: {ops:?}"
        );
        assert!(
            matches!(&ops[0], Op::LatticeMerge { .. }),
            "changed CRDT field should use LatticeMerge"
        );
    }

    #[test]
    fn diff_ops_add_output_touches_single_lattice_field() {
        let base = json!({"token_stats": {"total_input": {}, "total_output": {}, "label": ""}});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<TokenStats>(
                TokenStatsAction::AddOutput(7),
            )],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        let ops = patches[0].patch().ops();
        assert_eq!(
            ops.len(),
            1,
            "should only emit op for the changed field, got: {ops:?}"
        );
        assert!(
            matches!(&ops[0], Op::LatticeMerge { .. }),
            "changed CRDT field should use LatticeMerge"
        );
    }

    #[test]
    fn diff_ops_empty_when_no_changes() {
        let base = json!({"counters": {"main": {"value": 0}}});
        let patches = reduce_state_actions(
            vec![AnyStateAction::new::<Counter>(CounterAction::Reset)],
            &base,
            "test",
            &ScopeContext::run(),
        )
        .unwrap();

        assert!(
            patches.is_empty(),
            "no-op reduce should produce no patches, got: {patches:?}"
        );
    }

    #[test]
    fn serialized_payload_is_captured() {
        let action = AnyStateAction::new::<Counter>(CounterAction::Increment(42));
        let payload = action.serialized_payload();
        assert_eq!(*payload, json!({"Increment": 42}));
    }

    #[test]
    fn serialized_payload_captured_for_call_scoped() {
        let action =
            AnyStateAction::new_for_call::<ToolScopedCounter>(CounterAction::Reset, "call_1");
        let payload = action.serialized_payload();
        assert_eq!(*payload, json!("Reset"));
    }
}