1use crate::runtime::activity::ActivityManager;
8use crate::runtime::run::RunIdentity;
9use crate::runtime::{ToolCallResume, ToolCallState};
10use crate::thread::Message;
11use crate::RunPolicy;
12use futures::future::pending;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::sync::{Arc, Mutex};
16use std::time::{SystemTime, UNIX_EPOCH};
17use tirea_state::{
18 get_at_path, parse_path, DocCell, Op, Patch, PatchSink, Path, State, TireaError, TireaResult,
19 TrackedPatch,
20};
21use tokio_util::sync::CancellationToken;
22
23type PatchHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
/// Prefix prepended to a call id to form its progress stream id (`tool_call:<call_id>`).
const TOOL_PROGRESS_STREAM_PREFIX: &str = "tool_call:";

/// Activity type under which tool-call progress payloads are reported.
pub const TOOL_CALL_PROGRESS_ACTIVITY_TYPE: &str = "tool-call-progress";

/// Alias of [`TOOL_CALL_PROGRESS_ACTIVITY_TYPE`]; both names resolve to the same value.
pub const TOOL_PROGRESS_ACTIVITY_TYPE: &str = TOOL_CALL_PROGRESS_ACTIVITY_TYPE;

/// Older activity type name (`"progress"`).
///
/// NOTE(review): not referenced elsewhere in this file; presumably kept for consumers
/// that still emit or read the legacy name — confirm against callers.
pub const TOOL_PROGRESS_ACTIVITY_TYPE_LEGACY: &str = "progress";

/// Value stored in the `type` field of every tool-call progress payload.
pub const TOOL_CALL_PROGRESS_TYPE: &str = "tool-call-progress";

/// Value stored in the `schema` field of every tool-call progress payload.
pub const TOOL_CALL_PROGRESS_SCHEMA: &str = "tool-call-progress.v1";
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum ToolCallProgressStatus {
40 Pending,
41 #[default]
42 Running,
43 Done,
44 Failed,
45 Cancelled,
46}
47
48#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
50pub struct ToolCallProgressState {
51 #[serde(rename = "type")]
53 pub event_type: String,
54 pub schema: String,
56 pub node_id: String,
58 #[serde(default)]
60 pub parent_node_id: Option<String>,
61 #[serde(default)]
63 pub parent_call_id: Option<String>,
64 pub call_id: String,
66 #[serde(default)]
68 pub tool_name: Option<String>,
69 pub status: ToolCallProgressStatus,
71 #[serde(default)]
73 pub progress: Option<f64>,
74 #[serde(default)]
76 pub loaded: Option<f64>,
77 #[serde(default)]
79 pub total: Option<f64>,
80 #[serde(default)]
82 pub message: Option<String>,
83 #[serde(default)]
85 pub run_id: Option<String>,
86 #[serde(default)]
88 pub parent_run_id: Option<String>,
89 #[serde(default)]
91 pub thread_id: Option<String>,
92 pub updated_at_ms: u64,
94}
95
96#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct ToolCallProgressUpdate {
99 #[serde(default)]
100 pub status: ToolCallProgressStatus,
101 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub progress: Option<f64>,
103 #[serde(default, skip_serializing_if = "Option::is_none")]
104 pub loaded: Option<f64>,
105 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub total: Option<f64>,
107 #[serde(default, skip_serializing_if = "Option::is_none")]
108 pub message: Option<String>,
109}
110
111#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
113pub struct ToolProgressState {
114 pub progress: f64,
116 #[serde(default, skip_serializing_if = "Option::is_none")]
118 pub total: Option<f64>,
119 #[serde(default, skip_serializing_if = "Option::is_none")]
121 pub message: Option<String>,
122}
123
124pub trait ToolCallProgressSink: Send + Sync {
130 fn report(
132 &self,
133 stream_id: &str,
134 activity_type: &str,
135 payload: &ToolCallProgressState,
136 ) -> TireaResult<()>;
137}
138
139#[derive(Clone)]
140struct ActivityManagerProgressSink {
141 manager: Arc<dyn ActivityManager>,
142}
143
144impl ActivityManagerProgressSink {
145 fn new(manager: Arc<dyn ActivityManager>) -> Self {
146 Self { manager }
147 }
148}
149
150#[derive(Clone, Debug, Default)]
152pub struct CallerContext {
153 thread_id: Option<String>,
154 run_id: Option<String>,
155 agent_id: Option<String>,
156 messages: Arc<[Arc<Message>]>,
157}
158
159impl CallerContext {
160 pub fn new(
161 thread_id: Option<String>,
162 run_id: Option<String>,
163 agent_id: Option<String>,
164 messages: Vec<Arc<Message>>,
165 ) -> Self {
166 Self {
167 thread_id: thread_id
168 .map(|value| value.trim().to_string())
169 .filter(|value| !value.is_empty()),
170 run_id: run_id
171 .map(|value| value.trim().to_string())
172 .filter(|value| !value.is_empty()),
173 agent_id: agent_id
174 .map(|value| value.trim().to_string())
175 .filter(|value| !value.is_empty()),
176 messages: Arc::<[Arc<Message>]>::from(messages),
177 }
178 }
179
180 pub fn thread_id(&self) -> Option<&str> {
181 self.thread_id.as_deref()
182 }
183
184 pub fn run_id(&self) -> Option<&str> {
185 self.run_id.as_deref()
186 }
187
188 pub fn agent_id(&self) -> Option<&str> {
189 self.agent_id.as_deref()
190 }
191
192 pub fn messages(&self) -> &[Arc<Message>] {
193 self.messages.as_ref()
194 }
195}
196
197impl ToolCallProgressSink for ActivityManagerProgressSink {
198 fn report(
199 &self,
200 stream_id: &str,
201 activity_type: &str,
202 payload: &ToolCallProgressState,
203 ) -> TireaResult<()> {
204 let Value::Object(fields) = serde_json::to_value(payload)? else {
205 return Err(TireaError::invalid_operation(
206 "tool-call-progress payload must serialize as object",
207 ));
208 };
209 for (key, value) in fields {
210 let op = Op::set(Path::root().key(key), value);
211 self.manager.on_activity_op(stream_id, activity_type, &op);
212 }
213 Ok(())
214 }
215}
216
217pub struct ToolCallContext<'a> {
223 doc: &'a DocCell,
224 ops: &'a Mutex<Vec<Op>>,
225 call_id: String,
226 source: String,
227 run_policy: &'a RunPolicy,
228 run_identity: RunIdentity,
229 caller_context: CallerContext,
230 pending_messages: &'a Mutex<Vec<Arc<Message>>>,
231 activity_manager: Arc<dyn ActivityManager>,
232 tool_call_progress_sink: Arc<dyn ToolCallProgressSink>,
233 cancellation_token: Option<&'a CancellationToken>,
234 read_only: bool,
235}
236
237impl<'a> ToolCallContext<'a> {
238 fn tool_call_state_path(call_id: &str) -> Path {
239 Path::root()
240 .key("__tool_call_scope")
241 .key(call_id)
242 .key("tool_call_state")
243 }
244
245 fn apply_op(&self, op: Op) -> TireaResult<()> {
246 if self.read_only {
247 return Err(TireaError::invalid_operation(
248 "tool context is read-only; emit ToolExecutionEffect actions instead",
249 ));
250 }
251 self.doc.apply(&op)?;
252 self.ops.lock().unwrap().push(op);
253 Ok(())
254 }
255
256 pub fn new(
258 doc: &'a DocCell,
259 ops: &'a Mutex<Vec<Op>>,
260 call_id: impl Into<String>,
261 source: impl Into<String>,
262 run_policy: &'a RunPolicy,
263 pending_messages: &'a Mutex<Vec<Arc<Message>>>,
264 activity_manager: Arc<dyn ActivityManager>,
265 ) -> Self {
266 let tool_call_progress_sink: Arc<dyn ToolCallProgressSink> =
267 Arc::new(ActivityManagerProgressSink::new(activity_manager.clone()));
268 Self {
269 doc,
270 ops,
271 call_id: call_id.into(),
272 source: source.into(),
273 run_policy,
274 run_identity: RunIdentity::default(),
275 caller_context: CallerContext::default(),
276 pending_messages,
277 activity_manager,
278 tool_call_progress_sink,
279 cancellation_token: None,
280 read_only: false,
281 }
282 }
283
284 #[must_use]
286 pub fn as_read_only(mut self) -> Self {
287 self.read_only = true;
288 self
289 }
290
291 #[must_use]
293 pub fn with_cancellation_token(mut self, token: &'a CancellationToken) -> Self {
294 self.cancellation_token = Some(token);
295 self
296 }
297
298 #[must_use]
299 pub fn with_run_identity(mut self, run_identity: RunIdentity) -> Self {
300 self.run_identity = run_identity;
301 self
302 }
303
304 #[must_use]
305 pub fn with_caller_context(mut self, caller_context: CallerContext) -> Self {
306 self.caller_context = caller_context;
307 self
308 }
309
310 #[must_use]
315 pub fn with_tool_call_progress_sink(mut self, sink: Arc<dyn ToolCallProgressSink>) -> Self {
316 self.tool_call_progress_sink = sink;
317 self
318 }
319
320 pub fn doc(&self) -> &DocCell {
326 self.doc
327 }
328
329 pub fn call_id(&self) -> &str {
331 &self.call_id
332 }
333
334 pub fn idempotency_key(&self) -> &str {
338 self.call_id()
339 }
340
341 pub fn source(&self) -> &str {
343 &self.source
344 }
345
346 pub fn is_cancelled(&self) -> bool {
348 self.cancellation_token
349 .is_some_and(CancellationToken::is_cancelled)
350 }
351
352 pub async fn cancelled(&self) {
356 if let Some(token) = self.cancellation_token {
357 token.cancelled().await;
358 } else {
359 pending::<()>().await;
360 }
361 }
362
363 pub fn cancellation_token(&self) -> Option<&CancellationToken> {
365 self.cancellation_token
366 }
367
368 pub fn run_policy(&self) -> &RunPolicy {
374 self.run_policy
375 }
376
377 pub fn run_identity(&self) -> &RunIdentity {
378 &self.run_identity
379 }
380
381 pub fn caller_context(&self) -> &CallerContext {
382 &self.caller_context
383 }
384
385 pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
391 let base = parse_path(path);
392 let doc = self.doc;
393 let read_only = self.read_only;
394 let hook: PatchHook<'_> = Arc::new(move |op: &Op| {
395 if read_only {
396 return Err(TireaError::invalid_operation(
397 "tool context is read-only; emit ToolExecutionEffect actions instead",
398 ));
399 }
400 doc.apply(op)?;
401 Ok(())
402 });
403 T::state_ref(doc, base, PatchSink::new_with_hook(self.ops, hook))
404 }
405
406 pub fn state_of<T: State>(&self) -> T::Ref<'_> {
410 assert!(
411 !T::PATH.is_empty(),
412 "State type has no bound path; use state::<T>(path) instead"
413 );
414 self.state::<T>(T::PATH)
415 }
416
417 pub fn call_state<T: State>(&self) -> T::Ref<'_> {
419 let path = format!("tool_calls.{}", self.call_id);
420 self.state::<T>(&path)
421 }
422
423 pub fn tool_call_state_for(&self, call_id: &str) -> TireaResult<Option<ToolCallState>> {
425 if call_id.trim().is_empty() {
426 return Ok(None);
427 }
428 let val = self.doc.snapshot();
429 let path = Self::tool_call_state_path(call_id);
430 let at = get_at_path(&val, &path);
431 match at {
432 Some(v) if !v.is_null() => {
433 let state = ToolCallState::from_value(v)?;
434 Ok(Some(state))
435 }
436 _ => Ok(None),
437 }
438 }
439
440 pub fn tool_call_state(&self) -> TireaResult<Option<ToolCallState>> {
442 self.tool_call_state_for(self.call_id())
443 }
444
445 pub fn set_tool_call_state_for(&self, call_id: &str, state: ToolCallState) -> TireaResult<()> {
447 if call_id.trim().is_empty() {
448 return Err(TireaError::invalid_operation(
449 "tool_call_state requires non-empty call_id",
450 ));
451 }
452 let value = serde_json::to_value(state)?;
453 self.apply_op(Op::set(Self::tool_call_state_path(call_id), value))
454 }
455
456 pub fn set_tool_call_state(&self, state: ToolCallState) -> TireaResult<()> {
458 self.set_tool_call_state_for(self.call_id(), state)
459 }
460
461 pub fn clear_tool_call_state_for(&self, call_id: &str) -> TireaResult<()> {
463 if call_id.trim().is_empty() {
464 return Ok(());
465 }
466 if self.tool_call_state_for(call_id)?.is_some() {
467 self.apply_op(Op::delete(Self::tool_call_state_path(call_id)))?;
468 }
469 Ok(())
470 }
471
472 pub fn clear_tool_call_state(&self) -> TireaResult<()> {
474 self.clear_tool_call_state_for(self.call_id())
475 }
476
477 pub fn resume_input_for(&self, call_id: &str) -> TireaResult<Option<ToolCallResume>> {
479 Ok(self
480 .tool_call_state_for(call_id)?
481 .and_then(|state| state.resume))
482 }
483
484 pub fn resume_input(&self) -> TireaResult<Option<ToolCallResume>> {
486 self.resume_input_for(self.call_id())
487 }
488
489 pub fn add_message(&self, message: Message) {
495 self.pending_messages
496 .lock()
497 .unwrap()
498 .push(Arc::new(message));
499 }
500
501 pub fn add_messages(&self, messages: impl IntoIterator<Item = Message>) {
503 self.pending_messages
504 .lock()
505 .unwrap()
506 .extend(messages.into_iter().map(Arc::new));
507 }
508
509 pub fn activity(
515 &self,
516 stream_id: impl Into<String>,
517 activity_type: impl Into<String>,
518 ) -> ActivityContext {
519 let stream_id = stream_id.into();
520 let activity_type = activity_type.into();
521 let snapshot = self.activity_manager.snapshot(&stream_id);
522
523 ActivityContext::new(
524 snapshot,
525 stream_id,
526 activity_type,
527 self.activity_manager.clone(),
528 )
529 }
530
531 pub fn progress_stream_id(&self) -> String {
533 format!("{TOOL_PROGRESS_STREAM_PREFIX}{}", self.call_id)
534 }
535
536 fn source_tool_name(&self) -> Option<String> {
537 self.source
538 .strip_prefix("tool:")
539 .filter(|name| !name.trim().is_empty())
540 .map(ToOwned::to_owned)
541 }
542
543 fn validate_progress_value(name: &str, value: Option<f64>) -> TireaResult<()> {
544 let Some(value) = value else {
545 return Ok(());
546 };
547 if !value.is_finite() {
548 return Err(TireaError::invalid_operation(format!(
549 "{name} must be a finite number"
550 )));
551 }
552 if value < 0.0 {
553 return Err(TireaError::invalid_operation(format!(
554 "{name} must be non-negative"
555 )));
556 }
557 Ok(())
558 }
559
560 pub fn report_tool_call_progress(&self, update: ToolCallProgressUpdate) -> TireaResult<()> {
565 Self::validate_progress_value("progress value", update.progress)?;
566 Self::validate_progress_value("progress loaded", update.loaded)?;
567 Self::validate_progress_value("progress total", update.total)?;
568
569 let run_id = self.run_identity.run_id_opt().map(ToOwned::to_owned);
570 let parent_run_id = self.run_identity.parent_run_id_opt().map(ToOwned::to_owned);
571 let thread_id = self.caller_context.thread_id().map(ToOwned::to_owned);
572 let parent_call_id = self.run_identity.parent_tool_call_id_opt().and_then(|id| {
573 if id == self.call_id {
574 None
575 } else {
576 Some(id.to_string())
577 }
578 });
579 let parent_node_id = parent_call_id
580 .as_ref()
581 .map(|id| format!("{TOOL_PROGRESS_STREAM_PREFIX}{id}"))
582 .or_else(|| run_id.as_ref().map(|id| format!("run:{id}")));
583 let stream_id = self.progress_stream_id();
584 let payload = ToolCallProgressState {
585 event_type: TOOL_CALL_PROGRESS_TYPE.to_string(),
586 schema: TOOL_CALL_PROGRESS_SCHEMA.to_string(),
587 node_id: stream_id.clone(),
588 parent_node_id,
589 parent_call_id,
590 call_id: self.call_id.clone(),
591 tool_name: self.source_tool_name(),
592 status: update.status,
593 progress: update.progress,
594 loaded: update.loaded,
595 total: update.total,
596 message: update.message,
597 run_id,
598 parent_run_id,
599 thread_id,
600 updated_at_ms: current_unix_millis(),
601 };
602
603 self.tool_call_progress_sink
604 .report(&stream_id, TOOL_CALL_PROGRESS_ACTIVITY_TYPE, &payload)
605 }
606
607 pub fn snapshot(&self) -> Value {
616 self.doc.snapshot()
617 }
618
619 pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
623 let val = self.doc.snapshot();
624 let at = get_at_path(&val, &parse_path(T::PATH)).unwrap_or(&Value::Null);
625 T::from_value(at)
626 }
627
628 pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
632 let val = self.doc.snapshot();
633 let at = get_at_path(&val, &parse_path(path)).unwrap_or(&Value::Null);
634 T::from_value(at)
635 }
636
637 pub fn take_patch(&self) -> TrackedPatch {
643 let ops = std::mem::take(&mut *self.ops.lock().unwrap());
644 TrackedPatch::new(Patch::with_ops(ops)).with_source(self.source.clone())
645 }
646
647 pub fn has_changes(&self) -> bool {
649 !self.ops.lock().unwrap().is_empty()
650 }
651
652 pub fn ops_count(&self) -> usize {
654 self.ops.lock().unwrap().len()
655 }
656}
657
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` when the system clock reads earlier than the epoch, and saturates
/// at `u64::MAX` if the millisecond count does not fit in a `u64`.
fn current_unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
663
664pub struct ActivityContext {
666 doc: DocCell,
667 stream_id: String,
668 activity_type: String,
669 ops: Mutex<Vec<Op>>,
670 manager: Arc<dyn ActivityManager>,
671}
672
673impl ActivityContext {
674 pub(crate) fn new(
675 doc: Value,
676 stream_id: String,
677 activity_type: String,
678 manager: Arc<dyn ActivityManager>,
679 ) -> Self {
680 Self {
681 doc: DocCell::new(doc),
682 stream_id,
683 activity_type,
684 ops: Mutex::new(Vec::new()),
685 manager,
686 }
687 }
688
689 pub fn state_of<T: State>(&self) -> T::Ref<'_> {
693 assert!(
694 !T::PATH.is_empty(),
695 "State type has no bound path; use state::<T>(path) instead"
696 );
697 self.state::<T>(T::PATH)
698 }
699
700 pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
706 let base = parse_path(path);
707 let manager = self.manager.clone();
708 let stream_id = self.stream_id.clone();
709 let activity_type = self.activity_type.clone();
710 let doc = &self.doc;
711 let hook: PatchHook<'_> = Arc::new(move |op: &Op| {
712 doc.apply(op)?;
713 manager.on_activity_op(&stream_id, &activity_type, op);
714 Ok(())
715 });
716 T::state_ref(&self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
717 }
718}
719
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::ResumeDecisionAction;
    use crate::runtime::activity::{ActivityManager, NoOpActivityManager};
    use crate::testing::TestFixtureState;
    use serde_json::json;
    use std::sync::Arc;
    use tirea_state::apply_patch;
    use tokio::time::{timeout, Duration};
    use tokio_util::sync::CancellationToken;

    /// Owns the document, op log, run policy and pending-message queue that a
    /// `ToolCallContext` only borrows, so each test needs a single binding.
    struct Fixture {
        doc: DocCell,
        ops: Mutex<Vec<Op>>,
        policy: RunPolicy,
        pending: Mutex<Vec<Arc<Message>>>,
    }

    impl Fixture {
        /// Fixture over `initial` with the default run policy.
        fn new(initial: Value) -> Self {
            Self::with_policy(initial, RunPolicy::default())
        }

        /// Fixture over `initial` with an explicit run policy.
        fn with_policy(initial: Value, policy: RunPolicy) -> Self {
            Self {
                doc: DocCell::new(initial),
                ops: Mutex::new(Vec::new()),
                policy,
                pending: Mutex::new(Vec::new()),
            }
        }

        /// Context for call `call-1`, source `test`, with a no-op activity manager.
        fn ctx(&self) -> ToolCallContext<'_> {
            self.ctx_with("test", NoOpActivityManager::arc())
        }

        /// Context for call `call-1` with the given source and activity manager.
        fn ctx_with(&self, source: &str, manager: Arc<dyn ActivityManager>) -> ToolCallContext<'_> {
            ToolCallContext::new(
                &self.doc,
                &self.ops,
                "call-1",
                source,
                &self.policy,
                &self.pending,
                manager,
            )
        }
    }

    /// Internal-origin run identity on thread `thread-child` without parents.
    fn run_identity(run_id: &str) -> RunIdentity {
        RunIdentity::new(
            "thread-child".to_string(),
            None,
            run_id.to_string(),
            None,
            "agent".to_string(),
            crate::storage::RunOrigin::Internal,
        )
    }

    /// Caller context with one seed message, run `run-parent` and agent `caller`.
    fn caller_context(thread_id: &str) -> CallerContext {
        CallerContext::new(
            Some(thread_id.to_string()),
            Some("run-parent".to_string()),
            Some("caller".to_string()),
            vec![Arc::new(Message::user("seed"))],
        )
    }

    /// `Running` update carrying only a progress value.
    fn running(progress: f64) -> ToolCallProgressUpdate {
        ToolCallProgressUpdate {
            status: ToolCallProgressStatus::Running,
            progress: Some(progress),
            ..Default::default()
        }
    }

    #[test]
    fn test_identity() {
        let fx = Fixture::new(json!({}));
        let ctx = fx.ctx();

        assert_eq!(ctx.call_id(), "call-1");
        assert_eq!(ctx.idempotency_key(), "call-1");
        assert_eq!(ctx.source(), "test");
    }

    #[test]
    fn test_typed_context_access() {
        let fx = Fixture::with_policy(json!({}), RunPolicy::new());
        let identity = run_identity("run-1").with_parent_tool_call_id("call-parent");
        let ctx = fx
            .ctx()
            .with_run_identity(identity)
            .with_caller_context(caller_context("thread-1"));

        let run = ctx.run_identity();
        assert_eq!(run.parent_tool_call_id_opt(), Some("call-parent"));
        assert_eq!(run.run_id_opt(), Some("run-1"));

        let caller = ctx.caller_context();
        assert_eq!(caller.thread_id(), Some("thread-1"));
        assert_eq!(caller.agent_id(), Some("caller"));
        assert_eq!(caller.messages().len(), 1);
    }

    #[test]
    fn test_state_of_read_write() {
        let fx = Fixture::new(json!({"__test_fixture": {"label": null}}));
        let ctx = fx.ctx();

        let ctrl = ctx.state_of::<TestFixtureState>();
        ctrl.set_label(Some("rate_limit".into()))
            .expect("failed to set label");

        // The write is visible through the same reference.
        let val = ctrl.label().unwrap();
        assert!(val.is_some());
        assert_eq!(val.unwrap(), "rate_limit");

        // And it was recorded in the shared op log.
        assert!(!fx.ops.lock().unwrap().is_empty());
    }

    #[test]
    fn test_write_through_read_cross_ref() {
        let fx = Fixture::new(json!({"__test_fixture": {"label": null}}));
        let ctx = fx.ctx();

        ctx.state_of::<TestFixtureState>()
            .set_label(Some("timeout".into()))
            .expect("failed to set label");

        // A fresh reference observes the write made through the previous one.
        let val = ctx.state_of::<TestFixtureState>().label().unwrap();
        assert_eq!(val.unwrap(), "timeout");
    }

    #[test]
    fn test_take_patch() {
        let fx = Fixture::new(json!({"__test_fixture": {"label": null}}));
        let ctx = fx.ctx();

        ctx.state_of::<TestFixtureState>()
            .set_label(Some("test".into()))
            .expect("failed to set label");

        assert!(ctx.has_changes());
        assert!(ctx.ops_count() > 0);

        let patch = ctx.take_patch();
        assert!(!patch.patch().is_empty());
        assert_eq!(patch.source.as_deref(), Some("test"));

        // Taking the patch drains the op log.
        assert!(!ctx.has_changes());
        assert_eq!(ctx.ops_count(), 0);
    }

    #[test]
    fn test_add_messages() {
        let fx = Fixture::new(json!({}));
        let ctx = fx.ctx();

        ctx.add_message(Message::user("hello"));
        ctx.add_messages(vec![Message::assistant("hi"), Message::user("bye")]);

        assert_eq!(fx.pending.lock().unwrap().len(), 3);
    }

    #[test]
    fn test_call_state() {
        let fx = Fixture::new(json!({"tool_calls": {}}));
        let ctx = fx.ctx();

        ctx.call_state::<TestFixtureState>()
            .set_label(Some("call_scoped".into()))
            .expect("failed to set label");

        assert!(ctx.has_changes());
    }

    #[test]
    fn test_tool_call_state_roundtrip_and_resume_input() {
        let fx = Fixture::new(json!({}));
        let ctx = fx.ctx();

        let resume = crate::runtime::ToolCallResume {
            decision_id: "decision_1".to_string(),
            action: ResumeDecisionAction::Resume,
            result: json!({"approved": true}),
            reason: None,
            updated_at: 123,
        };
        let state = ToolCallState {
            call_id: "call.1".to_string(),
            tool_name: "confirm".to_string(),
            arguments: json!({"value": 1}),
            status: crate::runtime::ToolCallStatus::Resuming,
            resume_token: Some("resume.1".to_string()),
            resume: Some(resume),
            scratch: json!({"k": "v"}),
            updated_at: 124,
        };

        ctx.set_tool_call_state_for("call.1", state.clone())
            .expect("state should be persisted");

        let loaded = ctx
            .tool_call_state_for("call.1")
            .expect("state read should succeed");
        assert_eq!(loaded, Some(state.clone()));

        let resume = ctx
            .resume_input_for("call.1")
            .expect("resume read should succeed");
        assert_eq!(resume, state.resume);
    }

    #[test]
    fn test_clear_tool_call_state_for_removes_entry() {
        let fx = Fixture::new(json!({}));
        let ctx = fx.ctx();

        let state = ToolCallState {
            call_id: "call-1".to_string(),
            tool_name: "echo".to_string(),
            arguments: json!({"x": 1}),
            status: crate::runtime::ToolCallStatus::Running,
            resume_token: None,
            resume: None,
            scratch: Value::Null,
            updated_at: 1,
        };
        ctx.set_tool_call_state_for("call-1", state)
            .expect("state should be set");

        ctx.clear_tool_call_state_for("call-1")
            .expect("clear should succeed");

        let remaining = ctx
            .tool_call_state_for("call-1")
            .expect("state read should succeed");
        assert_eq!(remaining, None);
    }

    #[test]
    fn test_cancellation_token_absent_by_default() {
        let fx = Fixture::new(json!({}));
        let ctx = fx.ctx();

        assert!(!ctx.is_cancelled());
        assert!(ctx.cancellation_token().is_none());
    }

    #[tokio::test]
    async fn test_cancelled_waits_for_attached_token() {
        let fx = Fixture::new(json!({}));
        let token = CancellationToken::new();
        let ctx = fx.ctx().with_cancellation_token(&token);

        // Cancel from another task shortly after the wait begins.
        let token_for_task = token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            token_for_task.cancel();
        });

        timeout(Duration::from_millis(300), ctx.cancelled())
            .await
            .expect("cancelled() should resolve after token cancellation");
    }

    #[tokio::test]
    async fn test_cancelled_without_token_never_resolves() {
        let fx = Fixture::new(json!({}));
        let ctx = fx.ctx();

        let outcome = timeout(Duration::from_millis(30), ctx.cancelled()).await;
        let timed_out = outcome.is_err();
        assert!(timed_out, "cancelled() without token should remain pending");
    }

    /// Activity manager that records every op it is notified of.
    #[derive(Default)]
    struct RecordingActivityManager {
        events: Mutex<Vec<(String, String, Op)>>,
    }

    impl ActivityManager for RecordingActivityManager {
        fn snapshot(&self, _stream_id: &str) -> Value {
            json!({})
        }

        fn on_activity_op(&self, stream_id: &str, activity_type: &str, op: &Op) {
            let entry = (stream_id.to_string(), activity_type.to_string(), op.clone());
            self.events.lock().unwrap().push(entry);
        }
    }

    /// Replays recorded ops, one patch per op, onto an empty object.
    fn rebuild_activity_state(events: &[(String, String, Op)]) -> Value {
        events.iter().fold(json!({}), |value, (_, _, op)| {
            apply_patch(&value, &Patch::with_ops(vec![op.clone()]))
                .expect("activity op should apply")
        })
    }

    /// Progress sink that records every payload it receives.
    #[derive(Default)]
    struct RecordingProgressSink {
        events: Mutex<Vec<(String, String, ToolCallProgressState)>>,
    }

    impl ToolCallProgressSink for RecordingProgressSink {
        fn report(
            &self,
            stream_id: &str,
            activity_type: &str,
            payload: &ToolCallProgressState,
        ) -> TireaResult<()> {
            let entry = (
                stream_id.to_string(),
                activity_type.to_string(),
                payload.clone(),
            );
            self.events.lock().unwrap().push(entry);
            Ok(())
        }
    }

    /// Progress sink that rejects every report.
    struct FailingProgressSink;

    impl ToolCallProgressSink for FailingProgressSink {
        fn report(
            &self,
            _stream_id: &str,
            _activity_type: &str,
            _payload: &ToolCallProgressState,
        ) -> TireaResult<()> {
            Err(TireaError::invalid_operation("sink failed"))
        }
    }

    #[test]
    fn test_report_tool_call_progress_emits_tool_call_progress_activity() {
        let fx = Fixture::new(json!({}));
        let activity_manager = Arc::new(RecordingActivityManager::default());
        let ctx = fx.ctx_with("test", activity_manager.clone());

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            total: Some(10.0),
            message: Some("half way".to_string()),
            ..running(0.5)
        })
        .expect("progress should be emitted");

        let events = activity_manager.events.lock().unwrap();
        assert!(!events.is_empty());
        for (stream_id, activity_type, _) in events.iter() {
            assert_eq!(stream_id, "tool_call:call-1");
            assert_eq!(activity_type, TOOL_CALL_PROGRESS_ACTIVITY_TYPE);
        }

        let state = rebuild_activity_state(&events);
        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
        assert_eq!(state["node_id"], "tool_call:call-1");
        assert_eq!(state["call_id"], "call-1");
        assert_eq!(state["status"], "running");
        assert_eq!(state["progress"], json!(0.5));
        assert_eq!(state["total"], json!(10.0));
        assert_eq!(state["message"], json!("half way"));
    }

    #[test]
    fn test_report_tool_call_progress_rejects_non_finite_values() {
        let fx = Fixture::new(json!({}));
        let ctx = fx.ctx();

        let rejected = vec![
            // NaN progress.
            running(f64::NAN),
            // Infinite total.
            ToolCallProgressUpdate {
                total: Some(f64::INFINITY),
                ..running(0.5)
            },
            // Negative loaded amount.
            ToolCallProgressUpdate {
                loaded: Some(-1.0),
                ..running(0.5)
            },
        ];
        for update in rejected {
            assert!(ctx.report_tool_call_progress(update).is_err());
        }
    }

    #[test]
    fn test_report_tool_call_progress_writes_lineage_and_metadata() {
        let fx = Fixture::with_policy(json!({}), RunPolicy::new());
        let activity_manager = Arc::new(RecordingActivityManager::default());
        let identity = RunIdentity::new(
            "thread-abc".to_string(),
            None,
            "run-123".to_string(),
            Some("run-parent".to_string()),
            "agent".to_string(),
            crate::storage::RunOrigin::Internal,
        )
        .with_parent_tool_call_id("call-parent");
        let caller = CallerContext::new(
            Some("thread-abc".to_string()),
            Some("run-parent".to_string()),
            Some("caller".to_string()),
            vec![],
        );

        let ctx = fx
            .ctx_with("tool:echo", activity_manager.clone())
            .with_run_identity(identity)
            .with_caller_context(caller);

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            status: ToolCallProgressStatus::Done,
            progress: Some(1.0),
            loaded: Some(5.0),
            total: Some(5.0),
            message: Some("done".to_string()),
        })
        .expect("tool call progress should be emitted");

        let events = activity_manager.events.lock().unwrap();
        let state = rebuild_activity_state(&events);
        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
        assert_eq!(state["node_id"], "tool_call:call-1");
        assert_eq!(state["parent_node_id"], "tool_call:call-parent");
        assert_eq!(state["parent_call_id"], "call-parent");
        assert_eq!(state["tool_name"], "echo");
        assert_eq!(state["status"], "done");
        assert_eq!(state["run_id"], "run-123");
        assert_eq!(state["parent_run_id"], "run-parent");
        assert_eq!(state["thread_id"], "thread-abc");
        assert!(state["updated_at_ms"].as_u64().unwrap_or_default() > 0);
    }

    #[test]
    fn test_report_tool_call_progress_without_parent_tool_call_anchors_to_run_node() {
        let fx = Fixture::with_policy(json!({}), RunPolicy::new());
        let activity_manager = Arc::new(RecordingActivityManager::default());
        let ctx = fx
            .ctx_with("tool:echo", activity_manager.clone())
            .with_run_identity(run_identity("run-123"));

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            message: Some("working".to_string()),
            ..running(0.3)
        })
        .expect("tool call progress should be emitted");

        let events = activity_manager.events.lock().unwrap();
        let state = rebuild_activity_state(&events);
        assert_eq!(state["parent_node_id"], "run:run-123");
        assert!(state["parent_call_id"].is_null());
    }

    #[test]
    fn test_report_tool_call_progress_uses_injected_sink_instead_of_activity_manager() {
        let fx = Fixture::new(json!({}));
        let activity_manager = Arc::new(RecordingActivityManager::default());
        let sink = Arc::new(RecordingProgressSink::default());
        let ctx = fx
            .ctx_with("tool:echo", activity_manager.clone())
            .with_tool_call_progress_sink(sink.clone());

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            total: Some(10.0),
            message: Some("working".to_string()),
            ..running(0.2)
        })
        .expect("tool call progress should be reported");

        let sink_events = sink.events.lock().unwrap();
        assert_eq!(sink_events.len(), 1);
        let (stream_id, activity_type, payload) = &sink_events[0];
        assert_eq!(stream_id, "tool_call:call-1");
        assert_eq!(activity_type, TOOL_CALL_PROGRESS_ACTIVITY_TYPE);
        assert_eq!(payload.call_id, "call-1");
        assert_eq!(payload.progress, Some(0.2));

        let activity_events = activity_manager.events.lock().unwrap();
        assert!(
            activity_events.is_empty(),
            "injected sink should bypass default activity manager sink"
        );
    }

    #[test]
    fn test_report_tool_call_progress_propagates_sink_error() {
        let fx = Fixture::new(json!({}));
        let ctx = fx
            .ctx_with("tool:echo", NoOpActivityManager::arc())
            .with_tool_call_progress_sink(Arc::new(FailingProgressSink));

        let result = ctx.report_tool_call_progress(running(0.1));
        assert!(result.is_err());
    }
}