tirea_contract/runtime/tool_call/
context.rs

1//! Execution context types for tools and plugins.
2//!
3//! `ToolCallContext` provides state access, run policy, and identity for tool execution.
4//! It replaces direct `&Thread` usage in tool signatures, keeping the persistent
5//! entity (`Thread`) invisible to tools and plugins.
6
7use crate::runtime::activity::ActivityManager;
8use crate::runtime::run::RunIdentity;
9use crate::runtime::{ToolCallResume, ToolCallState};
10use crate::thread::Message;
11use crate::RunPolicy;
12use futures::future::pending;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::sync::{Arc, Mutex};
16use std::time::{SystemTime, UNIX_EPOCH};
17use tirea_state::{
18    get_at_path, parse_path, DocCell, Op, Patch, PatchSink, Path, State, TireaError, TireaResult,
19    TrackedPatch,
20};
21use tokio_util::sync::CancellationToken;
22
/// Write hook passed to `PatchSink::new_with_hook` by the typed state accessors.
///
/// It is invoked with each op produced through a typed state reference. Returning
/// `Err` is how the read-only check in `ToolCallContext::state` refuses writes.
type PatchHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
// ---- Canonical tool-call progress identifiers --------------------------------

/// Activity type used for tool-call progress updates.
pub const TOOL_CALL_PROGRESS_ACTIVITY_TYPE: &str = "tool-call-progress";
/// Canonical payload `type` value for tool-call progress events.
pub const TOOL_CALL_PROGRESS_TYPE: &str = "tool-call-progress";
/// Canonical payload schema version for tool-call progress events.
pub const TOOL_CALL_PROGRESS_SCHEMA: &str = "tool-call-progress.v1";

// ---- Backward-compatible names ------------------------------------------------

/// Legacy public alias of [`TOOL_CALL_PROGRESS_ACTIVITY_TYPE`], kept for backward
/// compatibility.
pub const TOOL_PROGRESS_ACTIVITY_TYPE: &str = TOOL_CALL_PROGRESS_ACTIVITY_TYPE;
/// Legacy activity type that consumers still accept.
pub const TOOL_PROGRESS_ACTIVITY_TYPE_LEGACY: &str = "progress";

// ---- Internal -----------------------------------------------------------------

/// Prefix of per-call progress stream ids; the full id is `tool_call:<call_id>`.
const TOOL_PROGRESS_STREAM_PREFIX: &str = "tool_call:";
35
/// Status marker for a tool-call progress node.
///
/// Serialized in lowercase (`"pending"`, `"running"`, `"done"`, `"failed"`,
/// `"cancelled"`). The `Default` variant is [`ToolCallProgressStatus::Running`],
/// which is also what an omitted `status` deserializes to in
/// [`ToolCallProgressUpdate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ToolCallProgressStatus {
    /// Serialized as `"pending"`.
    Pending,
    /// Serialized as `"running"`; the `Default` variant.
    #[default]
    Running,
    /// Serialized as `"done"`.
    Done,
    /// Serialized as `"failed"`.
    Failed,
    /// Serialized as `"cancelled"`.
    Cancelled,
}
47
/// Canonical tree-node payload for tool-call progress updates.
///
/// When built by `ToolCallContext::report_tool_call_progress`, `type` is
/// [`TOOL_CALL_PROGRESS_TYPE`] and `schema` is [`TOOL_CALL_PROGRESS_SCHEMA`].
///
/// The optional fields carry `#[serde(default)]` but no `skip_serializing_if`.
/// As a result, `None` is serialized as JSON `null`, and the field may be absent
/// when deserializing. The non-optional fields (`type`, `schema`, `node_id`,
/// `call_id`, `status`, `updated_at_ms`) are required on input.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
pub struct ToolCallProgressState {
    /// Payload type identifier.
    #[serde(rename = "type")]
    pub event_type: String,
    /// Payload schema version.
    pub schema: String,
    /// Stable node id (`tool_call:<call_id>` when built by
    /// `report_tool_call_progress`).
    pub node_id: String,
    /// Optional parent node id in the progress tree.
    ///
    /// `report_tool_call_progress` uses the parent call's node id when there is
    /// one. Otherwise it falls back to `run:<run_id>`.
    #[serde(default)]
    pub parent_node_id: Option<String>,
    /// Optional parent tool call id when this node belongs to a nested run.
    #[serde(default)]
    pub parent_call_id: Option<String>,
    /// Tool call id that owns this node.
    pub call_id: String,
    /// Optional tool name.
    #[serde(default)]
    pub tool_name: Option<String>,
    /// Current status.
    pub status: ToolCallProgressStatus,
    /// Normalized progress ratio when available.
    #[serde(default)]
    pub progress: Option<f64>,
    /// Optional absolute loaded counter.
    #[serde(default)]
    pub loaded: Option<f64>,
    /// Optional absolute total counter.
    #[serde(default)]
    pub total: Option<f64>,
    /// Optional human-readable message.
    #[serde(default)]
    pub message: Option<String>,
    /// Current run id.
    #[serde(default)]
    pub run_id: Option<String>,
    /// Parent run id.
    #[serde(default)]
    pub parent_run_id: Option<String>,
    /// Current thread id when available.
    #[serde(default)]
    pub thread_id: Option<String>,
    /// Last update timestamp in unix milliseconds.
    pub updated_at_ms: u64,
}
95
/// Input shape for publishing tool-call progress updates.
///
/// Passed to `ToolCallContext::report_tool_call_progress`, which rejects
/// `progress`, `loaded` or `total` values that are non-finite or negative.
///
/// Omitted fields deserialize to their defaults (`status` becomes `Running`,
/// the rest become `None`). `None` fields are skipped when serializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallProgressUpdate {
    /// Status to publish; defaults to [`ToolCallProgressStatus::Running`].
    #[serde(default)]
    pub status: ToolCallProgressStatus,
    /// Progress ratio, copied into [`ToolCallProgressState::progress`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
    /// Absolute loaded counter, copied into [`ToolCallProgressState::loaded`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub loaded: Option<f64>,
    /// Absolute total counter, copied into [`ToolCallProgressState::total`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total: Option<f64>,
    /// Human-readable message, copied into [`ToolCallProgressState::message`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
110
/// Canonical activity state shape for tool progress updates.
///
/// `progress` is required when deserializing. The optional fields default to
/// `None` when absent and are omitted when serializing.
///
/// NOTE(review): `ToolCallContext::report_tool_call_progress` publishes
/// [`ToolCallProgressState`], not this type. Confirm where (or whether) this
/// shape is still produced.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
pub struct ToolProgressState {
    /// Normalized progress value.
    pub progress: f64,
    /// Optional absolute total if the source has one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total: Option<f64>,
    /// Optional human-readable progress message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
123
/// Sink interface for tool-call progress events.
///
/// Tools report progress through [`ToolCallContext::report_tool_call_progress`], and
/// the context forwards canonical payloads into this sink. The sink implementation
/// decides how payloads are emitted/transported.
///
/// A context holds its sink behind an `Arc`. A custom sink can be installed with
/// `ToolCallContext::with_tool_call_progress_sink`; otherwise payloads are forwarded
/// to the context's activity manager.
pub trait ToolCallProgressSink: Send + Sync {
    /// Consume a canonical tool-call progress payload.
    ///
    /// * `stream_id` – activity stream of the call (`tool_call:<call_id>` when
    ///   invoked by `report_tool_call_progress`).
    /// * `activity_type` – activity type; `report_tool_call_progress` passes
    ///   [`TOOL_CALL_PROGRESS_ACTIVITY_TYPE`].
    /// * `payload` – the fully populated progress node.
    ///
    /// # Errors
    /// Any error is returned unchanged to the caller of `report_tool_call_progress`.
    fn report(
        &self,
        stream_id: &str,
        activity_type: &str,
        payload: &ToolCallProgressState,
    ) -> TireaResult<()>;
}
138
139#[derive(Clone)]
140struct ActivityManagerProgressSink {
141    manager: Arc<dyn ActivityManager>,
142}
143
144impl ActivityManagerProgressSink {
145    fn new(manager: Arc<dyn ActivityManager>) -> Self {
146        Self { manager }
147    }
148}
149
150/// Typed caller metadata exposed to tool executions.
151#[derive(Clone, Debug, Default)]
152pub struct CallerContext {
153    thread_id: Option<String>,
154    run_id: Option<String>,
155    agent_id: Option<String>,
156    messages: Arc<[Arc<Message>]>,
157}
158
159impl CallerContext {
160    pub fn new(
161        thread_id: Option<String>,
162        run_id: Option<String>,
163        agent_id: Option<String>,
164        messages: Vec<Arc<Message>>,
165    ) -> Self {
166        Self {
167            thread_id: thread_id
168                .map(|value| value.trim().to_string())
169                .filter(|value| !value.is_empty()),
170            run_id: run_id
171                .map(|value| value.trim().to_string())
172                .filter(|value| !value.is_empty()),
173            agent_id: agent_id
174                .map(|value| value.trim().to_string())
175                .filter(|value| !value.is_empty()),
176            messages: Arc::<[Arc<Message>]>::from(messages),
177        }
178    }
179
180    pub fn thread_id(&self) -> Option<&str> {
181        self.thread_id.as_deref()
182    }
183
184    pub fn run_id(&self) -> Option<&str> {
185        self.run_id.as_deref()
186    }
187
188    pub fn agent_id(&self) -> Option<&str> {
189        self.agent_id.as_deref()
190    }
191
192    pub fn messages(&self) -> &[Arc<Message>] {
193        self.messages.as_ref()
194    }
195}
196
197impl ToolCallProgressSink for ActivityManagerProgressSink {
198    fn report(
199        &self,
200        stream_id: &str,
201        activity_type: &str,
202        payload: &ToolCallProgressState,
203    ) -> TireaResult<()> {
204        let Value::Object(fields) = serde_json::to_value(payload)? else {
205            return Err(TireaError::invalid_operation(
206                "tool-call-progress payload must serialize as object",
207            ));
208        };
209        for (key, value) in fields {
210            let op = Op::set(Path::root().key(key), value);
211            self.manager.on_activity_op(stream_id, activity_type, &op);
212        }
213        Ok(())
214    }
215}
216
/// Execution context for tool invocations.
///
/// Provides typed state access (read/write), run policy access, identity,
/// message queuing, and activity tracking. Tools receive `&ToolCallContext`
/// instead of `&Thread`.
pub struct ToolCallContext<'a> {
    /// Shared document; state writes are applied here immediately.
    doc: &'a DocCell,
    /// Log of ops written through this context; drained by `take_patch`.
    ops: &'a Mutex<Vec<Op>>,
    /// Tool call id. Also serves as the idempotency key and the suffix of the
    /// progress stream id.
    call_id: String,
    /// Source label attached to extracted patches. A `tool:` prefix supplies the
    /// tool name in progress payloads.
    source: String,
    /// Run policy for the current run.
    run_policy: &'a RunPolicy,
    /// Run identity (`RunIdentity::default()` unless set via `with_run_identity`).
    run_identity: RunIdentity,
    /// Caller metadata (`CallerContext::default()` unless set via `with_caller_context`).
    caller_context: CallerContext,
    /// Messages queued by `add_message` / `add_messages`.
    pending_messages: &'a Mutex<Vec<Arc<Message>>>,
    /// Backs the contexts returned by `activity`.
    activity_manager: Arc<dyn ActivityManager>,
    /// Destination for progress payloads; by default forwards into `activity_manager`.
    tool_call_progress_sink: Arc<dyn ToolCallProgressSink>,
    /// Optional cancellation token; when `None`, `cancelled()` never resolves.
    cancellation_token: Option<&'a CancellationToken>,
    /// When `true`, every state write through this context is rejected.
    read_only: bool,
}
236
237impl<'a> ToolCallContext<'a> {
238    fn tool_call_state_path(call_id: &str) -> Path {
239        Path::root()
240            .key("__tool_call_scope")
241            .key(call_id)
242            .key("tool_call_state")
243    }
244
245    fn apply_op(&self, op: Op) -> TireaResult<()> {
246        if self.read_only {
247            return Err(TireaError::invalid_operation(
248                "tool context is read-only; emit ToolExecutionEffect actions instead",
249            ));
250        }
251        self.doc.apply(&op)?;
252        self.ops.lock().unwrap().push(op);
253        Ok(())
254    }
255
256    /// Create a new tool call context.
257    pub fn new(
258        doc: &'a DocCell,
259        ops: &'a Mutex<Vec<Op>>,
260        call_id: impl Into<String>,
261        source: impl Into<String>,
262        run_policy: &'a RunPolicy,
263        pending_messages: &'a Mutex<Vec<Arc<Message>>>,
264        activity_manager: Arc<dyn ActivityManager>,
265    ) -> Self {
266        let tool_call_progress_sink: Arc<dyn ToolCallProgressSink> =
267            Arc::new(ActivityManagerProgressSink::new(activity_manager.clone()));
268        Self {
269            doc,
270            ops,
271            call_id: call_id.into(),
272            source: source.into(),
273            run_policy,
274            run_identity: RunIdentity::default(),
275            caller_context: CallerContext::default(),
276            pending_messages,
277            activity_manager,
278            tool_call_progress_sink,
279            cancellation_token: None,
280            read_only: false,
281        }
282    }
283
284    /// Mark this context as read-only; state writes via `apply_op` are rejected.
285    #[must_use]
286    pub fn as_read_only(mut self) -> Self {
287        self.read_only = true;
288        self
289    }
290
291    /// Attach cancellation token.
292    #[must_use]
293    pub fn with_cancellation_token(mut self, token: &'a CancellationToken) -> Self {
294        self.cancellation_token = Some(token);
295        self
296    }
297
298    #[must_use]
299    pub fn with_run_identity(mut self, run_identity: RunIdentity) -> Self {
300        self.run_identity = run_identity;
301        self
302    }
303
304    #[must_use]
305    pub fn with_caller_context(mut self, caller_context: CallerContext) -> Self {
306        self.caller_context = caller_context;
307        self
308    }
309
310    /// Override the sink used for tool-call progress payload forwarding.
311    ///
312    /// This allows runtime integrations to decouple progress collection from
313    /// activity transport details.
314    #[must_use]
315    pub fn with_tool_call_progress_sink(mut self, sink: Arc<dyn ToolCallProgressSink>) -> Self {
316        self.tool_call_progress_sink = sink;
317        self
318    }
319
320    // =========================================================================
321    // Identity
322    // =========================================================================
323
324    /// Borrow the underlying document cell.
325    pub fn doc(&self) -> &DocCell {
326        self.doc
327    }
328
329    /// Current call id (typically the `tool_call_id`).
330    pub fn call_id(&self) -> &str {
331        &self.call_id
332    }
333
334    /// Stable idempotency key for the current tool invocation.
335    ///
336    /// Tools should use this value when implementing idempotent side effects.
337    pub fn idempotency_key(&self) -> &str {
338        self.call_id()
339    }
340
341    /// Source identifier used for tracked patches.
342    pub fn source(&self) -> &str {
343        &self.source
344    }
345
346    /// Whether the run cancellation token has already been cancelled.
347    pub fn is_cancelled(&self) -> bool {
348        self.cancellation_token
349            .is_some_and(CancellationToken::is_cancelled)
350    }
351
352    /// Await cancellation for this context.
353    ///
354    /// If no cancellation token is available, this future never resolves.
355    pub async fn cancelled(&self) {
356        if let Some(token) = self.cancellation_token {
357            token.cancelled().await;
358        } else {
359            pending::<()>().await;
360        }
361    }
362
363    /// Borrow the cancellation token when present.
364    pub fn cancellation_token(&self) -> Option<&CancellationToken> {
365        self.cancellation_token
366    }
367
368    // =========================================================================
369    // Run policy / identity
370    // =========================================================================
371
372    /// Borrow the run policy.
373    pub fn run_policy(&self) -> &RunPolicy {
374        self.run_policy
375    }
376
377    pub fn run_identity(&self) -> &RunIdentity {
378        &self.run_identity
379    }
380
381    pub fn caller_context(&self) -> &CallerContext {
382        &self.caller_context
383    }
384
385    // =========================================================================
386    // State access
387    // =========================================================================
388
389    /// Typed state reference at path.
390    pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
391        let base = parse_path(path);
392        let doc = self.doc;
393        let read_only = self.read_only;
394        let hook: PatchHook<'_> = Arc::new(move |op: &Op| {
395            if read_only {
396                return Err(TireaError::invalid_operation(
397                    "tool context is read-only; emit ToolExecutionEffect actions instead",
398                ));
399            }
400            doc.apply(op)?;
401            Ok(())
402        });
403        T::state_ref(doc, base, PatchSink::new_with_hook(self.ops, hook))
404    }
405
406    /// Typed state reference at the type's canonical path.
407    ///
408    /// Panics if `T::PATH` is empty (no bound path via `#[tirea(path = "...")]`).
409    pub fn state_of<T: State>(&self) -> T::Ref<'_> {
410        assert!(
411            !T::PATH.is_empty(),
412            "State type has no bound path; use state::<T>(path) instead"
413        );
414        self.state::<T>(T::PATH)
415    }
416
417    /// Typed state reference for current call (`tool_calls.<call_id>`).
418    pub fn call_state<T: State>(&self) -> T::Ref<'_> {
419        let path = format!("tool_calls.{}", self.call_id);
420        self.state::<T>(&path)
421    }
422
423    /// Read persisted runtime state for a specific tool call.
424    pub fn tool_call_state_for(&self, call_id: &str) -> TireaResult<Option<ToolCallState>> {
425        if call_id.trim().is_empty() {
426            return Ok(None);
427        }
428        let val = self.doc.snapshot();
429        let path = Self::tool_call_state_path(call_id);
430        let at = get_at_path(&val, &path);
431        match at {
432            Some(v) if !v.is_null() => {
433                let state = ToolCallState::from_value(v)?;
434                Ok(Some(state))
435            }
436            _ => Ok(None),
437        }
438    }
439
440    /// Read persisted runtime state for current `call_id`.
441    pub fn tool_call_state(&self) -> TireaResult<Option<ToolCallState>> {
442        self.tool_call_state_for(self.call_id())
443    }
444
445    /// Upsert persisted runtime state for a specific tool call.
446    pub fn set_tool_call_state_for(&self, call_id: &str, state: ToolCallState) -> TireaResult<()> {
447        if call_id.trim().is_empty() {
448            return Err(TireaError::invalid_operation(
449                "tool_call_state requires non-empty call_id",
450            ));
451        }
452        let value = serde_json::to_value(state)?;
453        self.apply_op(Op::set(Self::tool_call_state_path(call_id), value))
454    }
455
456    /// Upsert persisted runtime state for current `call_id`.
457    pub fn set_tool_call_state(&self, state: ToolCallState) -> TireaResult<()> {
458        self.set_tool_call_state_for(self.call_id(), state)
459    }
460
461    /// Remove persisted runtime state for a specific tool call.
462    pub fn clear_tool_call_state_for(&self, call_id: &str) -> TireaResult<()> {
463        if call_id.trim().is_empty() {
464            return Ok(());
465        }
466        if self.tool_call_state_for(call_id)?.is_some() {
467            self.apply_op(Op::delete(Self::tool_call_state_path(call_id)))?;
468        }
469        Ok(())
470    }
471
472    /// Remove persisted runtime state for current `call_id`.
473    pub fn clear_tool_call_state(&self) -> TireaResult<()> {
474        self.clear_tool_call_state_for(self.call_id())
475    }
476
477    /// Read resume payload for a specific tool call.
478    pub fn resume_input_for(&self, call_id: &str) -> TireaResult<Option<ToolCallResume>> {
479        Ok(self
480            .tool_call_state_for(call_id)?
481            .and_then(|state| state.resume))
482    }
483
484    /// Read resume payload for current `call_id`.
485    pub fn resume_input(&self) -> TireaResult<Option<ToolCallResume>> {
486        self.resume_input_for(self.call_id())
487    }
488
489    // =========================================================================
490    // Messages
491    // =========================================================================
492
493    /// Queue a message addition in this operation.
494    pub fn add_message(&self, message: Message) {
495        self.pending_messages
496            .lock()
497            .unwrap()
498            .push(Arc::new(message));
499    }
500
501    /// Queue multiple messages in this operation.
502    pub fn add_messages(&self, messages: impl IntoIterator<Item = Message>) {
503        self.pending_messages
504            .lock()
505            .unwrap()
506            .extend(messages.into_iter().map(Arc::new));
507    }
508
509    // =========================================================================
510    // Activity
511    // =========================================================================
512
513    /// Create an activity context for a stream/type pair.
514    pub fn activity(
515        &self,
516        stream_id: impl Into<String>,
517        activity_type: impl Into<String>,
518    ) -> ActivityContext {
519        let stream_id = stream_id.into();
520        let activity_type = activity_type.into();
521        let snapshot = self.activity_manager.snapshot(&stream_id);
522
523        ActivityContext::new(
524            snapshot,
525            stream_id,
526            activity_type,
527            self.activity_manager.clone(),
528        )
529    }
530
531    /// Stable stream id used by default for this tool call's progress activity.
532    pub fn progress_stream_id(&self) -> String {
533        format!("{TOOL_PROGRESS_STREAM_PREFIX}{}", self.call_id)
534    }
535
536    fn source_tool_name(&self) -> Option<String> {
537        self.source
538            .strip_prefix("tool:")
539            .filter(|name| !name.trim().is_empty())
540            .map(ToOwned::to_owned)
541    }
542
543    fn validate_progress_value(name: &str, value: Option<f64>) -> TireaResult<()> {
544        let Some(value) = value else {
545            return Ok(());
546        };
547        if !value.is_finite() {
548            return Err(TireaError::invalid_operation(format!(
549                "{name} must be a finite number"
550            )));
551        }
552        if value < 0.0 {
553            return Err(TireaError::invalid_operation(format!(
554                "{name} must be non-negative"
555            )));
556        }
557        Ok(())
558    }
559
560    /// Publish a typed tool-call progress node update.
561    ///
562    /// The update is written to `activity(progress_stream_id(), "tool-call-progress")`
563    /// with payload schema `tool-call-progress.v1`.
564    pub fn report_tool_call_progress(&self, update: ToolCallProgressUpdate) -> TireaResult<()> {
565        Self::validate_progress_value("progress value", update.progress)?;
566        Self::validate_progress_value("progress loaded", update.loaded)?;
567        Self::validate_progress_value("progress total", update.total)?;
568
569        let run_id = self.run_identity.run_id_opt().map(ToOwned::to_owned);
570        let parent_run_id = self.run_identity.parent_run_id_opt().map(ToOwned::to_owned);
571        let thread_id = self.caller_context.thread_id().map(ToOwned::to_owned);
572        let parent_call_id = self.run_identity.parent_tool_call_id_opt().and_then(|id| {
573            if id == self.call_id {
574                None
575            } else {
576                Some(id.to_string())
577            }
578        });
579        let parent_node_id = parent_call_id
580            .as_ref()
581            .map(|id| format!("{TOOL_PROGRESS_STREAM_PREFIX}{id}"))
582            .or_else(|| run_id.as_ref().map(|id| format!("run:{id}")));
583        let stream_id = self.progress_stream_id();
584        let payload = ToolCallProgressState {
585            event_type: TOOL_CALL_PROGRESS_TYPE.to_string(),
586            schema: TOOL_CALL_PROGRESS_SCHEMA.to_string(),
587            node_id: stream_id.clone(),
588            parent_node_id,
589            parent_call_id,
590            call_id: self.call_id.clone(),
591            tool_name: self.source_tool_name(),
592            status: update.status,
593            progress: update.progress,
594            loaded: update.loaded,
595            total: update.total,
596            message: update.message,
597            run_id,
598            parent_run_id,
599            thread_id,
600            updated_at_ms: current_unix_millis(),
601        };
602
603        self.tool_call_progress_sink
604            .report(&stream_id, TOOL_CALL_PROGRESS_ACTIVITY_TYPE, &payload)
605    }
606
607    // =========================================================================
608    // State snapshot
609    // =========================================================================
610
611    /// Snapshot the current document state.
612    ///
613    /// Returns the current state including all write-through updates.
614    /// Equivalent to `Thread::rebuild_state()` in transient contexts.
615    pub fn snapshot(&self) -> Value {
616        self.doc.snapshot()
617    }
618
619    /// Typed snapshot at the type's canonical path.
620    ///
621    /// Reads current doc state and deserializes the value at `T::PATH`.
622    pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
623        let val = self.doc.snapshot();
624        let at = get_at_path(&val, &parse_path(T::PATH)).unwrap_or(&Value::Null);
625        T::from_value(at)
626    }
627
628    /// Typed snapshot at an explicit path.
629    ///
630    /// Reads current doc state and deserializes the value at the given path.
631    pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
632        let val = self.doc.snapshot();
633        let at = get_at_path(&val, &parse_path(path)).unwrap_or(&Value::Null);
634        T::from_value(at)
635    }
636
637    // =========================================================================
638    // Patch extraction
639    // =========================================================================
640
641    /// Extract accumulated patch with context source metadata.
642    pub fn take_patch(&self) -> TrackedPatch {
643        let ops = std::mem::take(&mut *self.ops.lock().unwrap());
644        TrackedPatch::new(Patch::with_ops(ops)).with_source(self.source.clone())
645    }
646
647    /// Whether state has pending transient changes.
648    pub fn has_changes(&self) -> bool {
649        !self.ops.lock().unwrap().is_empty()
650    }
651
652    /// Number of queued transient operations.
653    pub fn ops_count(&self) -> usize {
654        self.ops.lock().unwrap().len()
655    }
656}
657
/// Wall-clock time in milliseconds since the Unix epoch.
///
/// Saturates at `u64::MAX`, and yields `0` if the system clock reads earlier
/// than the epoch.
fn current_unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => {
            let millis = elapsed.as_millis();
            if millis > u128::from(u64::MAX) {
                u64::MAX
            } else {
                millis as u64
            }
        }
        Err(_) => 0,
    }
}
663
/// Activity-scoped state context.
///
/// Holds a local copy of one activity stream's state. `ToolCallContext::activity`
/// seeds it from `ActivityManager::snapshot`. Every write made through
/// `state`/`state_of` is applied locally and also reported to the manager.
pub struct ActivityContext {
    /// Local document, seeded from the stream snapshot.
    doc: DocCell,
    /// Stream id passed to `ActivityManager::on_activity_op`.
    stream_id: String,
    /// Activity type passed to `ActivityManager::on_activity_op`.
    activity_type: String,
    /// Op buffer handed to each `PatchSink`.
    // NOTE(review): nothing in this type drains `ops`; confirm whether it is
    // meant to grow for the context's lifetime.
    ops: Mutex<Vec<Op>>,
    /// Receives every write as it happens.
    manager: Arc<dyn ActivityManager>,
}
672
673impl ActivityContext {
674    pub(crate) fn new(
675        doc: Value,
676        stream_id: String,
677        activity_type: String,
678        manager: Arc<dyn ActivityManager>,
679    ) -> Self {
680        Self {
681            doc: DocCell::new(doc),
682            stream_id,
683            activity_type,
684            ops: Mutex::new(Vec::new()),
685            manager,
686        }
687    }
688
689    /// Typed activity state reference at the type's canonical path.
690    ///
691    /// Panics if `T::PATH` is empty.
692    pub fn state_of<T: State>(&self) -> T::Ref<'_> {
693        assert!(
694            !T::PATH.is_empty(),
695            "State type has no bound path; use state::<T>(path) instead"
696        );
697        self.state::<T>(T::PATH)
698    }
699
700    /// Get a typed activity state reference at the specified path.
701    ///
702    /// All modifications are automatically collected and immediately reported
703    /// to the activity manager. Writes are applied to the shared doc for
704    /// immediate read-back.
705    pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
706        let base = parse_path(path);
707        let manager = self.manager.clone();
708        let stream_id = self.stream_id.clone();
709        let activity_type = self.activity_type.clone();
710        let doc = &self.doc;
711        let hook: PatchHook<'_> = Arc::new(move |op: &Op| {
712            doc.apply(op)?;
713            manager.on_activity_op(&stream_id, &activity_type, op);
714            Ok(())
715        });
716        T::state_ref(&self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
717    }
718}
719
720#[cfg(test)]
721mod tests {
722    use super::*;
723    use crate::io::ResumeDecisionAction;
724    use crate::runtime::activity::{ActivityManager, NoOpActivityManager};
725    use crate::testing::TestFixtureState;
726    use serde_json::json;
727    use std::sync::Arc;
728    use tirea_state::apply_patch;
729    use tokio::time::{timeout, Duration};
730    use tokio_util::sync::CancellationToken;
731
732    fn make_ctx<'a>(
733        doc: &'a DocCell,
734        ops: &'a Mutex<Vec<Op>>,
735        run_policy: &'a RunPolicy,
736        pending: &'a Mutex<Vec<Arc<Message>>>,
737    ) -> ToolCallContext<'a> {
738        ToolCallContext::new(
739            doc,
740            ops,
741            "call-1",
742            "test",
743            run_policy,
744            pending,
745            NoOpActivityManager::arc(),
746        )
747    }
748
749    fn run_identity(run_id: &str) -> RunIdentity {
750        RunIdentity::new(
751            "thread-child".to_string(),
752            None,
753            run_id.to_string(),
754            None,
755            "agent".to_string(),
756            crate::storage::RunOrigin::Internal,
757        )
758    }
759
760    fn caller_context(thread_id: &str) -> CallerContext {
761        CallerContext::new(
762            Some(thread_id.to_string()),
763            Some("run-parent".to_string()),
764            Some("caller".to_string()),
765            vec![Arc::new(Message::user("seed"))],
766        )
767    }
768
769    #[test]
770    fn test_identity() {
771        let doc = DocCell::new(json!({}));
772        let ops = Mutex::new(Vec::new());
773        let scope = RunPolicy::default();
774        let pending = Mutex::new(Vec::new());
775
776        let ctx = make_ctx(&doc, &ops, &scope, &pending);
777        assert_eq!(ctx.call_id(), "call-1");
778        assert_eq!(ctx.idempotency_key(), "call-1");
779        assert_eq!(ctx.source(), "test");
780    }
781
782    #[test]
783    fn test_typed_context_access() {
784        let doc = DocCell::new(json!({}));
785        let ops = Mutex::new(Vec::new());
786        let scope = RunPolicy::new();
787        let pending = Mutex::new(Vec::new());
788
789        let ctx = make_ctx(&doc, &ops, &scope, &pending)
790            .with_run_identity(run_identity("run-1").with_parent_tool_call_id("call-parent"))
791            .with_caller_context(caller_context("thread-1"));
792
793        assert_eq!(
794            ctx.run_identity().parent_tool_call_id_opt(),
795            Some("call-parent")
796        );
797        assert_eq!(ctx.run_identity().run_id_opt(), Some("run-1"));
798        assert_eq!(ctx.caller_context().thread_id(), Some("thread-1"));
799        assert_eq!(ctx.caller_context().agent_id(), Some("caller"));
800        assert_eq!(ctx.caller_context().messages().len(), 1);
801    }
802
803    #[test]
804    fn test_state_of_read_write() {
805        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
806        let ops = Mutex::new(Vec::new());
807        let scope = RunPolicy::default();
808        let pending = Mutex::new(Vec::new());
809
810        let ctx = make_ctx(&doc, &ops, &scope, &pending);
811
812        // Write
813        let ctrl = ctx.state_of::<TestFixtureState>();
814        ctrl.set_label(Some("rate_limit".into()))
815            .expect("failed to set label");
816
817        // Read back from same ref
818        let val = ctrl.label().unwrap();
819        assert!(val.is_some());
820        assert_eq!(val.unwrap(), "rate_limit");
821
822        // Ops captured in thread ops
823        assert!(!ops.lock().unwrap().is_empty());
824    }
825
826    #[test]
827    fn test_write_through_read_cross_ref() {
828        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
829        let ops = Mutex::new(Vec::new());
830        let scope = RunPolicy::default();
831        let pending = Mutex::new(Vec::new());
832
833        let ctx = make_ctx(&doc, &ops, &scope, &pending);
834
835        // Write via first ref
836        ctx.state_of::<TestFixtureState>()
837            .set_label(Some("timeout".into()))
838            .expect("failed to set label");
839
840        // Read via second ref
841        let val = ctx.state_of::<TestFixtureState>().label().unwrap();
842        assert_eq!(val.unwrap(), "timeout");
843    }
844
845    #[test]
846    fn test_take_patch() {
847        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
848        let ops = Mutex::new(Vec::new());
849        let scope = RunPolicy::default();
850        let pending = Mutex::new(Vec::new());
851
852        let ctx = make_ctx(&doc, &ops, &scope, &pending);
853
854        ctx.state_of::<TestFixtureState>()
855            .set_label(Some("test".into()))
856            .expect("failed to set label");
857
858        assert!(ctx.has_changes());
859        assert!(ctx.ops_count() > 0);
860
861        let patch = ctx.take_patch();
862        assert!(!patch.patch().is_empty());
863        assert_eq!(patch.source.as_deref(), Some("test"));
864        assert!(!ctx.has_changes());
865        assert_eq!(ctx.ops_count(), 0);
866    }
867
868    #[test]
869    fn test_add_messages() {
870        let doc = DocCell::new(json!({}));
871        let ops = Mutex::new(Vec::new());
872        let scope = RunPolicy::default();
873        let pending = Mutex::new(Vec::new());
874
875        let ctx = make_ctx(&doc, &ops, &scope, &pending);
876
877        ctx.add_message(Message::user("hello"));
878        ctx.add_messages(vec![Message::assistant("hi"), Message::user("bye")]);
879
880        assert_eq!(pending.lock().unwrap().len(), 3);
881    }
882
883    #[test]
884    fn test_call_state() {
885        let doc = DocCell::new(json!({"tool_calls": {}}));
886        let ops = Mutex::new(Vec::new());
887        let scope = RunPolicy::default();
888        let pending = Mutex::new(Vec::new());
889
890        let ctx = make_ctx(&doc, &ops, &scope, &pending);
891
892        let ctrl = ctx.call_state::<TestFixtureState>();
893        ctrl.set_label(Some("call_scoped".into()))
894            .expect("failed to set label");
895
896        assert!(ctx.has_changes());
897    }
898
899    #[test]
900    fn test_tool_call_state_roundtrip_and_resume_input() {
901        let doc = DocCell::new(json!({}));
902        let ops = Mutex::new(Vec::new());
903        let scope = RunPolicy::default();
904        let pending = Mutex::new(Vec::new());
905        let ctx = make_ctx(&doc, &ops, &scope, &pending);
906
907        let state = ToolCallState {
908            call_id: "call.1".to_string(),
909            tool_name: "confirm".to_string(),
910            arguments: json!({"value": 1}),
911            status: crate::runtime::ToolCallStatus::Resuming,
912            resume_token: Some("resume.1".to_string()),
913            resume: Some(crate::runtime::ToolCallResume {
914                decision_id: "decision_1".to_string(),
915                action: ResumeDecisionAction::Resume,
916                result: json!({"approved": true}),
917                reason: None,
918                updated_at: 123,
919            }),
920            scratch: json!({"k": "v"}),
921            updated_at: 124,
922        };
923
924        ctx.set_tool_call_state_for("call.1", state.clone())
925            .expect("state should be persisted");
926
927        let loaded = ctx
928            .tool_call_state_for("call.1")
929            .expect("state read should succeed");
930        assert_eq!(loaded, Some(state.clone()));
931
932        let resume = ctx
933            .resume_input_for("call.1")
934            .expect("resume read should succeed");
935        assert_eq!(resume, state.resume);
936    }
937
938    #[test]
939    fn test_clear_tool_call_state_for_removes_entry() {
940        let doc = DocCell::new(json!({}));
941        let ops = Mutex::new(Vec::new());
942        let scope = RunPolicy::default();
943        let pending = Mutex::new(Vec::new());
944        let ctx = make_ctx(&doc, &ops, &scope, &pending);
945
946        ctx.set_tool_call_state_for(
947            "call-1",
948            ToolCallState {
949                call_id: "call-1".to_string(),
950                tool_name: "echo".to_string(),
951                arguments: json!({"x": 1}),
952                status: crate::runtime::ToolCallStatus::Running,
953                resume_token: None,
954                resume: None,
955                scratch: Value::Null,
956                updated_at: 1,
957            },
958        )
959        .expect("state should be set");
960
961        ctx.clear_tool_call_state_for("call-1")
962            .expect("clear should succeed");
963        assert_eq!(
964            ctx.tool_call_state_for("call-1")
965                .expect("state read should succeed"),
966            None
967        );
968    }
969
970    #[test]
971    fn test_cancellation_token_absent_by_default() {
972        let doc = DocCell::new(json!({}));
973        let ops = Mutex::new(Vec::new());
974        let scope = RunPolicy::default();
975        let pending = Mutex::new(Vec::new());
976        let ctx = make_ctx(&doc, &ops, &scope, &pending);
977
978        assert!(!ctx.is_cancelled());
979        assert!(ctx.cancellation_token().is_none());
980    }
981
982    #[tokio::test]
983    async fn test_cancelled_waits_for_attached_token() {
984        let doc = DocCell::new(json!({}));
985        let ops = Mutex::new(Vec::new());
986        let scope = RunPolicy::default();
987        let pending = Mutex::new(Vec::new());
988        let token = CancellationToken::new();
989
990        let ctx = ToolCallContext::new(
991            &doc,
992            &ops,
993            "call-1",
994            "test",
995            &scope,
996            &pending,
997            NoOpActivityManager::arc(),
998        )
999        .with_cancellation_token(&token);
1000
1001        let token_for_task = token.clone();
1002        tokio::spawn(async move {
1003            tokio::time::sleep(Duration::from_millis(20)).await;
1004            token_for_task.cancel();
1005        });
1006
1007        timeout(Duration::from_millis(300), ctx.cancelled())
1008            .await
1009            .expect("cancelled() should resolve after token cancellation");
1010    }
1011
1012    #[tokio::test]
1013    async fn test_cancelled_without_token_never_resolves() {
1014        let doc = DocCell::new(json!({}));
1015        let ops = Mutex::new(Vec::new());
1016        let scope = RunPolicy::default();
1017        let pending = Mutex::new(Vec::new());
1018        let ctx = make_ctx(&doc, &ops, &scope, &pending);
1019
1020        let timed_out = timeout(Duration::from_millis(30), ctx.cancelled())
1021            .await
1022            .is_err();
1023        assert!(timed_out, "cancelled() without token should remain pending");
1024    }
1025
1026    #[derive(Default)]
1027    struct RecordingActivityManager {
1028        events: Mutex<Vec<(String, String, Op)>>,
1029    }
1030
1031    impl ActivityManager for RecordingActivityManager {
1032        fn snapshot(&self, _stream_id: &str) -> Value {
1033            json!({})
1034        }
1035
1036        fn on_activity_op(&self, stream_id: &str, activity_type: &str, op: &Op) {
1037            self.events.lock().unwrap().push((
1038                stream_id.to_string(),
1039                activity_type.to_string(),
1040                op.clone(),
1041            ));
1042        }
1043    }
1044
1045    fn rebuild_activity_state(events: &[(String, String, Op)]) -> Value {
1046        let mut value = json!({});
1047        for (_, _, op) in events {
1048            value = apply_patch(&value, &Patch::with_ops(vec![op.clone()]))
1049                .expect("activity op should apply");
1050        }
1051        value
1052    }
1053
1054    #[derive(Default)]
1055    struct RecordingProgressSink {
1056        events: Mutex<Vec<(String, String, ToolCallProgressState)>>,
1057    }
1058
1059    impl ToolCallProgressSink for RecordingProgressSink {
1060        fn report(
1061            &self,
1062            stream_id: &str,
1063            activity_type: &str,
1064            payload: &ToolCallProgressState,
1065        ) -> TireaResult<()> {
1066            self.events.lock().unwrap().push((
1067                stream_id.to_string(),
1068                activity_type.to_string(),
1069                payload.clone(),
1070            ));
1071            Ok(())
1072        }
1073    }
1074
1075    struct FailingProgressSink;
1076
1077    impl ToolCallProgressSink for FailingProgressSink {
1078        fn report(
1079            &self,
1080            _stream_id: &str,
1081            _activity_type: &str,
1082            _payload: &ToolCallProgressState,
1083        ) -> TireaResult<()> {
1084            Err(TireaError::invalid_operation("sink failed"))
1085        }
1086    }
1087
1088    #[test]
1089    fn test_report_tool_call_progress_emits_tool_call_progress_activity() {
1090        let doc = DocCell::new(json!({}));
1091        let ops = Mutex::new(Vec::new());
1092        let scope = RunPolicy::default();
1093        let pending = Mutex::new(Vec::new());
1094        let activity_manager = Arc::new(RecordingActivityManager::default());
1095
1096        let ctx = ToolCallContext::new(
1097            &doc,
1098            &ops,
1099            "call-1",
1100            "test",
1101            &scope,
1102            &pending,
1103            activity_manager.clone(),
1104        );
1105
1106        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1107            status: ToolCallProgressStatus::Running,
1108            progress: Some(0.5),
1109            loaded: None,
1110            total: Some(10.0),
1111            message: Some("half way".to_string()),
1112        })
1113        .expect("progress should be emitted");
1114
1115        let events = activity_manager.events.lock().unwrap();
1116        assert!(!events.is_empty());
1117        assert!(events.iter().all(|(stream_id, activity_type, _)| {
1118            stream_id == "tool_call:call-1" && activity_type == TOOL_CALL_PROGRESS_ACTIVITY_TYPE
1119        }));
1120        let state = rebuild_activity_state(&events);
1121        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
1122        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
1123        assert_eq!(state["node_id"], "tool_call:call-1");
1124        assert_eq!(state["call_id"], "call-1");
1125        assert_eq!(state["status"], "running");
1126        assert_eq!(state["progress"], json!(0.5));
1127        assert_eq!(state["total"], json!(10.0));
1128        assert_eq!(state["message"], json!("half way"));
1129    }
1130
1131    #[test]
1132    fn test_report_tool_call_progress_rejects_non_finite_values() {
1133        let doc = DocCell::new(json!({}));
1134        let ops = Mutex::new(Vec::new());
1135        let scope = RunPolicy::default();
1136        let pending = Mutex::new(Vec::new());
1137        let ctx = make_ctx(&doc, &ops, &scope, &pending);
1138
1139        assert!(ctx
1140            .report_tool_call_progress(ToolCallProgressUpdate {
1141                status: ToolCallProgressStatus::Running,
1142                progress: Some(f64::NAN),
1143                loaded: None,
1144                total: None,
1145                message: None,
1146            })
1147            .is_err());
1148        assert!(ctx
1149            .report_tool_call_progress(ToolCallProgressUpdate {
1150                status: ToolCallProgressStatus::Running,
1151                progress: Some(0.5),
1152                loaded: None,
1153                total: Some(f64::INFINITY),
1154                message: None,
1155            })
1156            .is_err());
1157        assert!(ctx
1158            .report_tool_call_progress(ToolCallProgressUpdate {
1159                status: ToolCallProgressStatus::Running,
1160                progress: Some(0.5),
1161                loaded: Some(-1.0),
1162                total: None,
1163                message: None,
1164            })
1165            .is_err());
1166    }
1167
1168    #[test]
1169    fn test_report_tool_call_progress_writes_lineage_and_metadata() {
1170        let doc = DocCell::new(json!({}));
1171        let ops = Mutex::new(Vec::new());
1172        let scope = RunPolicy::new();
1173        let pending = Mutex::new(Vec::new());
1174        let activity_manager = Arc::new(RecordingActivityManager::default());
1175        let run_identity = RunIdentity::new(
1176            "thread-abc".to_string(),
1177            None,
1178            "run-123".to_string(),
1179            Some("run-parent".to_string()),
1180            "agent".to_string(),
1181            crate::storage::RunOrigin::Internal,
1182        )
1183        .with_parent_tool_call_id("call-parent");
1184        let caller_context = CallerContext::new(
1185            Some("thread-abc".to_string()),
1186            Some("run-parent".to_string()),
1187            Some("caller".to_string()),
1188            vec![],
1189        );
1190
1191        let ctx = ToolCallContext::new(
1192            &doc,
1193            &ops,
1194            "call-1",
1195            "tool:echo",
1196            &scope,
1197            &pending,
1198            activity_manager.clone(),
1199        )
1200        .with_run_identity(run_identity)
1201        .with_caller_context(caller_context);
1202
1203        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1204            status: ToolCallProgressStatus::Done,
1205            progress: Some(1.0),
1206            loaded: Some(5.0),
1207            total: Some(5.0),
1208            message: Some("done".to_string()),
1209        })
1210        .expect("tool call progress should be emitted");
1211
1212        let events = activity_manager.events.lock().unwrap();
1213        let state = rebuild_activity_state(&events);
1214        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
1215        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
1216        assert_eq!(state["node_id"], "tool_call:call-1");
1217        assert_eq!(state["parent_node_id"], "tool_call:call-parent");
1218        assert_eq!(state["parent_call_id"], "call-parent");
1219        assert_eq!(state["tool_name"], "echo");
1220        assert_eq!(state["status"], "done");
1221        assert_eq!(state["run_id"], "run-123");
1222        assert_eq!(state["parent_run_id"], "run-parent");
1223        assert_eq!(state["thread_id"], "thread-abc");
1224        assert!(state["updated_at_ms"].as_u64().unwrap_or_default() > 0);
1225    }
1226
1227    #[test]
1228    fn test_report_tool_call_progress_without_parent_tool_call_anchors_to_run_node() {
1229        let doc = DocCell::new(json!({}));
1230        let ops = Mutex::new(Vec::new());
1231        let scope = RunPolicy::new();
1232        let pending = Mutex::new(Vec::new());
1233        let activity_manager = Arc::new(RecordingActivityManager::default());
1234        let run_identity = run_identity("run-123");
1235        let ctx = ToolCallContext::new(
1236            &doc,
1237            &ops,
1238            "call-1",
1239            "tool:echo",
1240            &scope,
1241            &pending,
1242            activity_manager.clone(),
1243        )
1244        .with_run_identity(run_identity);
1245
1246        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1247            status: ToolCallProgressStatus::Running,
1248            progress: Some(0.3),
1249            loaded: None,
1250            total: None,
1251            message: Some("working".to_string()),
1252        })
1253        .expect("tool call progress should be emitted");
1254
1255        let events = activity_manager.events.lock().unwrap();
1256        let state = rebuild_activity_state(&events);
1257        assert_eq!(state["parent_node_id"], "run:run-123");
1258        assert!(state["parent_call_id"].is_null());
1259    }
1260
1261    #[test]
1262    fn test_report_tool_call_progress_uses_injected_sink_instead_of_activity_manager() {
1263        let doc = DocCell::new(json!({}));
1264        let ops = Mutex::new(Vec::new());
1265        let scope = RunPolicy::default();
1266        let pending = Mutex::new(Vec::new());
1267        let activity_manager = Arc::new(RecordingActivityManager::default());
1268        let sink = Arc::new(RecordingProgressSink::default());
1269        let ctx = ToolCallContext::new(
1270            &doc,
1271            &ops,
1272            "call-1",
1273            "tool:echo",
1274            &scope,
1275            &pending,
1276            activity_manager.clone(),
1277        )
1278        .with_tool_call_progress_sink(sink.clone());
1279
1280        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1281            status: ToolCallProgressStatus::Running,
1282            progress: Some(0.2),
1283            loaded: None,
1284            total: Some(10.0),
1285            message: Some("working".to_string()),
1286        })
1287        .expect("tool call progress should be reported");
1288
1289        let sink_events = sink.events.lock().unwrap();
1290        assert_eq!(sink_events.len(), 1);
1291        let (stream_id, activity_type, payload) = &sink_events[0];
1292        assert_eq!(stream_id, "tool_call:call-1");
1293        assert_eq!(activity_type, TOOL_CALL_PROGRESS_ACTIVITY_TYPE);
1294        assert_eq!(payload.call_id, "call-1");
1295        assert_eq!(payload.progress, Some(0.2));
1296
1297        let activity_events = activity_manager.events.lock().unwrap();
1298        assert!(
1299            activity_events.is_empty(),
1300            "injected sink should bypass default activity manager sink"
1301        );
1302    }
1303
1304    #[test]
1305    fn test_report_tool_call_progress_propagates_sink_error() {
1306        let doc = DocCell::new(json!({}));
1307        let ops = Mutex::new(Vec::new());
1308        let scope = RunPolicy::default();
1309        let pending = Mutex::new(Vec::new());
1310        let ctx = ToolCallContext::new(
1311            &doc,
1312            &ops,
1313            "call-1",
1314            "tool:echo",
1315            &scope,
1316            &pending,
1317            NoOpActivityManager::arc(),
1318        )
1319        .with_tool_call_progress_sink(Arc::new(FailingProgressSink));
1320
1321        let result = ctx.report_tool_call_progress(ToolCallProgressUpdate {
1322            status: ToolCallProgressStatus::Running,
1323            progress: Some(0.1),
1324            loaded: None,
1325            total: None,
1326            message: None,
1327        });
1328        assert!(result.is_err());
1329    }
1330}