tirea_agentos/runtime/loop_runner/
state_commit.rs1use super::{AgentLoopError, RunIdentity, StateCommitError, StateCommitter};
2use crate::contracts::storage::{RunOrigin, VersionPrecondition};
3use crate::contracts::thread::CheckpointReason;
4use crate::contracts::{RunContext, RunMeta, TerminationReason, ThreadChangeSet};
5use async_trait::async_trait;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct ChannelStateCommitter {
11 tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>,
12 version: Arc<AtomicU64>,
13}
14
15impl ChannelStateCommitter {
16 pub fn new(tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>) -> Self {
17 Self {
18 tx,
19 version: Arc::new(AtomicU64::new(0)),
20 }
21 }
22}
23
24#[async_trait]
25impl StateCommitter for ChannelStateCommitter {
26 async fn commit(
27 &self,
28 _thread_id: &str,
29 changeset: ThreadChangeSet,
30 _precondition: VersionPrecondition,
31 ) -> Result<u64, StateCommitError> {
32 let next_version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
33 self.tx
34 .send(changeset)
35 .map_err(|e| StateCommitError::new(format!("channel state commit failed: {e}")))?;
36 Ok(next_version)
37 }
38}
39
40pub(super) async fn commit_pending_delta(
41 run_ctx: &mut RunContext,
42 reason: CheckpointReason,
43 force: bool,
44 state_committer: Option<&Arc<dyn StateCommitter>>,
45 run_identity: &RunIdentity,
46 termination: Option<&TerminationReason>,
47) -> Result<(), AgentLoopError> {
48 let Some(committer) = state_committer else {
49 return Ok(());
50 };
51
52 let delta = run_ctx.take_delta();
53 if !force && delta.is_empty() {
54 return Ok(());
55 }
56
57 let snapshot = if reason == CheckpointReason::RunFinished {
60 match run_ctx.snapshot() {
61 Ok(state) => Some(state),
62 Err(e) => {
63 tracing::warn!(error = %e, "failed to compute RunFinished snapshot; continuing without snapshot");
64 None
65 }
66 }
67 } else {
68 None
69 };
70
71 let mut changeset = ThreadChangeSet::from_parts(
72 run_identity.run_id.clone(),
73 run_identity.parent_run_id.clone(),
74 reason,
75 delta.messages,
76 delta.patches,
77 delta.state_actions,
78 snapshot,
79 );
80
81 if let Some(termination) = termination {
85 let agent_id = run_identity.agent_id.clone();
86 let origin: RunOrigin = run_identity.origin;
87 let parent_thread_id = None; let (status, termination_code, termination_detail) = map_termination(termination);
89 changeset.run_meta = Some(RunMeta {
90 agent_id,
91 origin,
92 status,
93 parent_thread_id,
94 termination_code,
95 termination_detail,
96 source_mailbox_entry_id: None,
97 });
98 }
99
100 let precondition = VersionPrecondition::Exact(run_ctx.version());
101 let committed_version = committer
102 .commit(run_ctx.thread_id(), changeset, precondition)
103 .await
104 .map_err(|e| AgentLoopError::StateError(format!("state commit failed: {e}")))?;
105 run_ctx.set_version(committed_version, Some(super::current_unix_millis()));
106 Ok(())
107}
108
109fn map_termination(
110 termination: &TerminationReason,
111) -> (
112 crate::contracts::storage::RunStatus,
113 Option<String>,
114 Option<String>,
115) {
116 let (status, _) = termination.to_run_status();
117 match termination {
118 TerminationReason::NaturalEnd => (status, Some("natural".to_string()), None),
119 TerminationReason::BehaviorRequested => {
120 (status, Some("behavior_requested".to_string()), None)
121 }
122 TerminationReason::Suspended => (status, Some("input_required".to_string()), None),
123 TerminationReason::Cancelled => (status, Some("cancelled".to_string()), None),
124 TerminationReason::Error(message) => {
125 (status, Some("error".to_string()), Some(message.clone()))
126 }
127 TerminationReason::Stopped(stopped) => (
128 status,
129 Some(stopped.code.trim().to_ascii_lowercase()),
130 stopped.detail.clone(),
131 ),
132 }
133}
134
135pub(super) struct PendingDeltaCommitContext<'a> {
136 run_identity: &'a RunIdentity,
137 state_committer: Option<&'a Arc<dyn StateCommitter>>,
138}
139
140impl<'a> PendingDeltaCommitContext<'a> {
141 pub(super) fn new(
142 run_identity: &'a RunIdentity,
143 state_committer: Option<&'a Arc<dyn StateCommitter>>,
144 ) -> Self {
145 Self {
146 run_identity,
147 state_committer,
148 }
149 }
150
151 pub(super) async fn commit(
152 &self,
153 run_ctx: &mut RunContext,
154 reason: CheckpointReason,
155 force: bool,
156 ) -> Result<(), AgentLoopError> {
157 commit_pending_delta(
158 run_ctx,
159 reason,
160 force,
161 self.state_committer,
162 self.run_identity,
163 None,
164 )
165 .await
166 }
167
168 pub(super) async fn commit_run_finished(
169 &self,
170 run_ctx: &mut RunContext,
171 termination: &TerminationReason,
172 ) -> Result<(), AgentLoopError> {
173 commit_pending_delta(
174 run_ctx,
175 CheckpointReason::RunFinished,
176 true,
177 self.state_committer,
178 self.run_identity,
179 Some(termination),
180 )
181 .await
182 }
183}