tirea_agentos/runtime/loop_runner/state_commit.rs

1use super::{AgentLoopError, RunIdentity, StateCommitError, StateCommitter};
2use crate::contracts::storage::{RunOrigin, VersionPrecondition};
3use crate::contracts::thread::CheckpointReason;
4use crate::contracts::{RunContext, RunMeta, TerminationReason, ThreadChangeSet};
5use async_trait::async_trait;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct ChannelStateCommitter {
11    tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>,
12    version: Arc<AtomicU64>,
13}
14
15impl ChannelStateCommitter {
16    pub fn new(tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>) -> Self {
17        Self {
18            tx,
19            version: Arc::new(AtomicU64::new(0)),
20        }
21    }
22}
23
24#[async_trait]
25impl StateCommitter for ChannelStateCommitter {
26    async fn commit(
27        &self,
28        _thread_id: &str,
29        changeset: ThreadChangeSet,
30        _precondition: VersionPrecondition,
31    ) -> Result<u64, StateCommitError> {
32        let next_version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
33        self.tx
34            .send(changeset)
35            .map_err(|e| StateCommitError::new(format!("channel state commit failed: {e}")))?;
36        Ok(next_version)
37    }
38}
39
40pub(super) async fn commit_pending_delta(
41    run_ctx: &mut RunContext,
42    reason: CheckpointReason,
43    force: bool,
44    state_committer: Option<&Arc<dyn StateCommitter>>,
45    run_identity: &RunIdentity,
46    termination: Option<&TerminationReason>,
47) -> Result<(), AgentLoopError> {
48    let Some(committer) = state_committer else {
49        return Ok(());
50    };
51
52    let delta = run_ctx.take_delta();
53    if !force && delta.is_empty() {
54        return Ok(());
55    }
56
57    // On RunFinished, write a full state snapshot to bound the action/patch
58    // replay window to a single run.
59    let snapshot = if reason == CheckpointReason::RunFinished {
60        match run_ctx.snapshot() {
61            Ok(state) => Some(state),
62            Err(e) => {
63                tracing::warn!(error = %e, "failed to compute RunFinished snapshot; continuing without snapshot");
64                None
65            }
66        }
67    } else {
68        None
69    };
70
71    let mut changeset = ThreadChangeSet::from_parts(
72        run_identity.run_id.clone(),
73        run_identity.parent_run_id.clone(),
74        reason,
75        delta.messages,
76        delta.patches,
77        delta.state_actions,
78        snapshot,
79    );
80
81    // Loop always emits run-finished RunMeta. Whether this metadata is used to
82    // materialize/maintain durable run mappings is decided by the outer
83    // orchestration layer's StateCommitter policy.
84    if let Some(termination) = termination {
85        let agent_id = run_identity.agent_id.clone();
86        let origin: RunOrigin = run_identity.origin;
87        let parent_thread_id = None; // Already set on the initial changeset.
88        let (status, termination_code, termination_detail) = map_termination(termination);
89        changeset.run_meta = Some(RunMeta {
90            agent_id,
91            origin,
92            status,
93            parent_thread_id,
94            termination_code,
95            termination_detail,
96            source_mailbox_entry_id: None,
97        });
98    }
99
100    let precondition = VersionPrecondition::Exact(run_ctx.version());
101    let committed_version = committer
102        .commit(run_ctx.thread_id(), changeset, precondition)
103        .await
104        .map_err(|e| AgentLoopError::StateError(format!("state commit failed: {e}")))?;
105    run_ctx.set_version(committed_version, Some(super::current_unix_millis()));
106    Ok(())
107}
108
109fn map_termination(
110    termination: &TerminationReason,
111) -> (
112    crate::contracts::storage::RunStatus,
113    Option<String>,
114    Option<String>,
115) {
116    let (status, _) = termination.to_run_status();
117    match termination {
118        TerminationReason::NaturalEnd => (status, Some("natural".to_string()), None),
119        TerminationReason::BehaviorRequested => {
120            (status, Some("behavior_requested".to_string()), None)
121        }
122        TerminationReason::Suspended => (status, Some("input_required".to_string()), None),
123        TerminationReason::Cancelled => (status, Some("cancelled".to_string()), None),
124        TerminationReason::Error(message) => {
125            (status, Some("error".to_string()), Some(message.clone()))
126        }
127        TerminationReason::Stopped(stopped) => (
128            status,
129            Some(stopped.code.trim().to_ascii_lowercase()),
130            stopped.detail.clone(),
131        ),
132    }
133}
134
135pub(super) struct PendingDeltaCommitContext<'a> {
136    run_identity: &'a RunIdentity,
137    state_committer: Option<&'a Arc<dyn StateCommitter>>,
138}
139
140impl<'a> PendingDeltaCommitContext<'a> {
141    pub(super) fn new(
142        run_identity: &'a RunIdentity,
143        state_committer: Option<&'a Arc<dyn StateCommitter>>,
144    ) -> Self {
145        Self {
146            run_identity,
147            state_committer,
148        }
149    }
150
151    pub(super) async fn commit(
152        &self,
153        run_ctx: &mut RunContext,
154        reason: CheckpointReason,
155        force: bool,
156    ) -> Result<(), AgentLoopError> {
157        commit_pending_delta(
158            run_ctx,
159            reason,
160            force,
161            self.state_committer,
162            self.run_identity,
163            None,
164        )
165        .await
166    }
167
168    pub(super) async fn commit_run_finished(
169        &self,
170        run_ctx: &mut RunContext,
171        termination: &TerminationReason,
172    ) -> Result<(), AgentLoopError> {
173        commit_pending_delta(
174            run_ctx,
175            CheckpointReason::RunFinished,
176            true,
177            self.state_committer,
178            self.run_identity,
179            Some(termination),
180        )
181        .await
182    }
183}