// tirea_agentos/runtime/run_context.rs

1use crate::contracts::storage::VersionPrecondition;
2use crate::contracts::thread::ThreadChangeSet;
3use async_trait::async_trait;
4use futures::future::pending;
5use thiserror::Error;
6use tokio_util::sync::CancellationToken;
7
8pub type RunCancellationToken = CancellationToken;
9
/// Outcome of awaiting a future that may be interrupted by run cancellation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CancelAware<T> {
    /// The awaited future finished and produced this value.
    Value(T),
    /// Cancellation was observed instead of a value being produced.
    Cancelled,
}
15
16pub fn is_cancelled(token: Option<&RunCancellationToken>) -> bool {
17    token.is_some_and(RunCancellationToken::is_cancelled)
18}
19
20pub async fn cancelled(token: Option<&RunCancellationToken>) {
21    if let Some(token) = token {
22        token.cancelled().await;
23    } else {
24        pending::<()>().await;
25    }
26}
27
28pub async fn await_or_cancel<T, F>(token: Option<&RunCancellationToken>, fut: F) -> CancelAware<T>
29where
30    F: std::future::Future<Output = T>,
31{
32    if let Some(token) = token {
33        tokio::select! {
34            _ = token.cancelled() => CancelAware::Cancelled,
35            value = fut => CancelAware::Value(value),
36        }
37    } else {
38        CancelAware::Value(fut.await)
39    }
40}
41
42/// Error returned by state commit sinks.
43#[derive(Debug, Clone, Error)]
44#[error("{message}")]
45pub struct StateCommitError {
46    pub message: String,
47}
48
49impl StateCommitError {
50    pub fn new(message: impl Into<String>) -> Self {
51        Self {
52            message: message.into(),
53        }
54    }
55}
56
57/// Sink for committed thread deltas.
58#[async_trait]
59pub trait StateCommitter: Send + Sync {
60    /// Commit a single change set for a thread.
61    ///
62    /// Returns the committed storage version after the write succeeds.
63    async fn commit(
64        &self,
65        thread_id: &str,
66        changeset: ThreadChangeSet,
67        precondition: VersionPrecondition,
68    ) -> Result<u64, StateCommitError>;
69}
70
71impl std::fmt::Debug for dyn StateCommitter {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.write_str("<StateCommitter>")
74    }
75}
76
// Reserved scope keys through which tools can inspect their caller. All keys
// share the `__agent_tool_caller_` prefix.

/// Scope key: id of the calling thread, visible to tools.
pub const TOOL_SCOPE_CALLER_THREAD_ID_KEY: &str = "__agent_tool_caller_thread_id";
/// Scope key: id of the calling agent, visible to tools.
pub const TOOL_SCOPE_CALLER_AGENT_ID_KEY: &str = "__agent_tool_caller_agent_id";
/// Scope key: snapshot of the caller's state, visible to tools.
pub const TOOL_SCOPE_CALLER_STATE_KEY: &str = "__agent_tool_caller_state";
/// Scope key: snapshot of the caller's messages, visible to tools.
pub const TOOL_SCOPE_CALLER_MESSAGES_KEY: &str = "__agent_tool_caller_messages";
85
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{timeout, Duration};

    #[tokio::test]
    async fn await_or_cancel_returns_value_without_token() {
        let out = await_or_cancel(None, async { 42usize }).await;
        assert_eq!(out, CancelAware::Value(42));
    }

    #[tokio::test]
    async fn await_or_cancel_returns_value_when_token_not_cancelled() {
        // Only the value branch can be ready here, so the result is deterministic.
        let token = RunCancellationToken::new();
        let out = await_or_cancel(Some(&token), async { 9usize }).await;
        assert_eq!(out, CancelAware::Value(9));
    }

    #[tokio::test]
    async fn await_or_cancel_returns_cancelled_when_token_cancelled() {
        let token = RunCancellationToken::new();
        let token_for_task = token.clone();
        let handle = tokio::spawn(async move {
            await_or_cancel(Some(&token_for_task), async {
                tokio::time::sleep(Duration::from_secs(5)).await;
                7usize
            })
            .await
        });

        token.cancel();
        let out = timeout(Duration::from_millis(300), handle)
            .await
            .expect("await_or_cancel should resolve quickly after cancellation")
            .expect("task should not panic");
        assert_eq!(out, CancelAware::Cancelled);
    }

    #[tokio::test]
    async fn cancelled_waits_for_token_signal() {
        let token = RunCancellationToken::new();
        let token_for_task = token.clone();
        let handle = tokio::spawn(async move {
            cancelled(Some(&token_for_task)).await;
            true
        });

        token.cancel();
        let done = timeout(Duration::from_millis(300), handle)
            .await
            .expect("cancelled() should return after token cancellation")
            .expect("task should not panic");
        assert!(done);
    }

    #[tokio::test]
    async fn cancelled_without_token_never_resolves() {
        let res = timeout(Duration::from_millis(20), cancelled(None)).await;
        assert!(res.is_err(), "cancelled(None) should stay pending");
    }

    #[test]
    fn is_cancelled_reflects_token_state() {
        assert!(!is_cancelled(None));
        let token = RunCancellationToken::new();
        assert!(!is_cancelled(Some(&token)));
        token.cancel();
        assert!(is_cancelled(Some(&token)));
    }
}
133}