tirea_agentos/runtime/
run_context.rs1use crate::contracts::storage::VersionPrecondition;
2use crate::contracts::thread::ThreadChangeSet;
3use async_trait::async_trait;
4use futures::future::pending;
5use thiserror::Error;
6use tokio_util::sync::CancellationToken;
7
8pub type RunCancellationToken = CancellationToken;
9
/// Outcome of an operation that may be interrupted by cancellation.
///
/// Produced by `await_or_cancel`: either the wrapped future's output or a
/// marker that cancellation was observed instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CancelAware<T> {
    /// The operation completed and yielded this value.
    Value(T),
    /// Cancellation was signalled, so no value is available.
    Cancelled,
}
15
16pub fn is_cancelled(token: Option<&RunCancellationToken>) -> bool {
17 token.is_some_and(RunCancellationToken::is_cancelled)
18}
19
20pub async fn cancelled(token: Option<&RunCancellationToken>) {
21 if let Some(token) = token {
22 token.cancelled().await;
23 } else {
24 pending::<()>().await;
25 }
26}
27
28pub async fn await_or_cancel<T, F>(token: Option<&RunCancellationToken>, fut: F) -> CancelAware<T>
29where
30 F: std::future::Future<Output = T>,
31{
32 if let Some(token) = token {
33 tokio::select! {
34 _ = token.cancelled() => CancelAware::Cancelled,
35 value = fut => CancelAware::Value(value),
36 }
37 } else {
38 CancelAware::Value(fut.await)
39 }
40}
41
42#[derive(Debug, Clone, Error)]
44#[error("{message}")]
45pub struct StateCommitError {
46 pub message: String,
47}
48
49impl StateCommitError {
50 pub fn new(message: impl Into<String>) -> Self {
51 Self {
52 message: message.into(),
53 }
54 }
55}
56
57#[async_trait]
59pub trait StateCommitter: Send + Sync {
60 async fn commit(
64 &self,
65 thread_id: &str,
66 changeset: ThreadChangeSet,
67 precondition: VersionPrecondition,
68 ) -> Result<u64, StateCommitError>;
69}
70
71impl std::fmt::Debug for dyn StateCommitter {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.write_str("<StateCommitter>")
74 }
75}
76
// NOTE(review): the four keys below are not used in this file; the
// descriptions are inferred from their names — confirm against the callers
// that read or write them.

/// Scope key for the caller's thread id (presumed from the name).
pub const TOOL_SCOPE_CALLER_THREAD_ID_KEY: &str = "__agent_tool_caller_thread_id";
/// Scope key for the caller's agent id (presumed from the name).
pub const TOOL_SCOPE_CALLER_AGENT_ID_KEY: &str = "__agent_tool_caller_agent_id";
/// Scope key for the caller's state (presumed from the name).
pub const TOOL_SCOPE_CALLER_STATE_KEY: &str = "__agent_tool_caller_state";
/// Scope key for the caller's messages (presumed from the name).
pub const TOOL_SCOPE_CALLER_MESSAGES_KEY: &str = "__agent_tool_caller_messages";
85
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{timeout, Duration};

    /// With no token, the wrapped future's output is returned as `Value`.
    #[tokio::test]
    async fn await_or_cancel_returns_value_without_token() {
        let outcome = await_or_cancel(None, async { 42usize }).await;
        assert_eq!(outcome, CancelAware::Value(42));
    }

    /// Cancelling the token ends the wait long before the 5 s sleep would.
    #[tokio::test]
    async fn await_or_cancel_returns_cancelled_when_token_cancelled() {
        let token = RunCancellationToken::new();
        let worker = tokio::spawn({
            let task_token = token.clone();
            async move {
                let slow = async {
                    tokio::time::sleep(Duration::from_secs(5)).await;
                    7usize
                };
                await_or_cancel(Some(&task_token), slow).await
            }
        });

        token.cancel();
        let joined = timeout(Duration::from_millis(300), worker)
            .await
            .expect("await_or_cancel should resolve quickly after cancellation");
        let outcome = joined.expect("task should not panic");
        assert_eq!(outcome, CancelAware::Cancelled);
    }

    /// `cancelled` returns once the token it waits on has been cancelled.
    #[tokio::test]
    async fn cancelled_waits_for_token_signal() {
        let token = RunCancellationToken::new();
        let waiter = tokio::spawn({
            let task_token = token.clone();
            async move {
                cancelled(Some(&task_token)).await;
                true
            }
        });

        token.cancel();
        let joined = timeout(Duration::from_millis(300), waiter)
            .await
            .expect("cancelled() should return after token cancellation");
        let finished = joined.expect("task should not panic");
        assert!(finished);
    }
}