tirea_agentos/engine/
tool_execution.rs

1//! Tool execution utilities.
2
3use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunPolicy;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16/// Execute a single tool call.
17///
18/// This function:
19/// 1. Creates a Context from the state snapshot
20/// 2. Executes the tool
21/// 3. Extracts any state changes as a TrackedPatch
22///
23/// # Arguments
24///
25/// * `tool` - The tool to execute (or None if not found)
26/// * `call` - The tool call with id, name, and arguments
27/// * `state` - The current state snapshot (read-only)
28pub async fn execute_single_tool(
29    tool: Option<&dyn Tool>,
30    call: &ToolCall,
31    state: &Value,
32) -> ToolExecution {
33    execute_single_tool_with_run_policy_and_behavior(tool, call, state, None, None).await
34}
35
36/// Execute a single tool call with optional run policy.
37pub async fn execute_single_tool_with_run_policy(
38    tool: Option<&dyn Tool>,
39    call: &ToolCall,
40    state: &Value,
41    run_policy: Option<&RunPolicy>,
42) -> ToolExecution {
43    execute_single_tool_with_run_policy_and_behavior(tool, call, state, run_policy, None).await
44}
45
46/// Execute a single tool call with optional run policy and behavior router.
47pub async fn execute_single_tool_with_run_policy_and_behavior(
48    tool: Option<&dyn Tool>,
49    call: &ToolCall,
50    state: &Value,
51    run_policy: Option<&RunPolicy>,
52    _behavior: Option<&dyn AgentBehavior>,
53) -> ToolExecution {
54    let Some(tool) = tool else {
55        return ToolExecution {
56            call: call.clone(),
57            result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
58            patch: None,
59        };
60    };
61
62    // Create context for this tool call
63    let doc = DocCell::new(state.clone());
64    let ops = Mutex::new(Vec::new());
65    let default_run_policy = RunPolicy::default();
66    let run_policy = run_policy.unwrap_or(&default_run_policy);
67    let pending_messages = Mutex::new(Vec::new());
68    let ctx = ToolCallContext::new(
69        &doc,
70        &ops,
71        &call.id,
72        format!("tool:{}", call.name),
73        run_policy,
74        &pending_messages,
75        tirea_contract::runtime::activity::NoOpActivityManager::arc(),
76    )
77    .as_read_only();
78
79    // Validate arguments against the tool's JSON Schema
80    if let Err(e) = tool.validate_args(&call.arguments) {
81        return ToolExecution {
82            call: call.clone(),
83            result: ToolResult::error(&call.name, e.to_string()),
84            patch: None,
85        };
86    }
87
88    // Execute the tool
89    let effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
90        Ok(effect) => effect,
91        Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
92    };
93    let (result, actions) = effect.into_parts();
94    let state_actions: Vec<AnyStateAction> = actions
95        .into_iter()
96        .filter_map(|a| match a {
97            crate::contracts::runtime::phase::AfterToolExecuteAction::State(sa) => Some(sa),
98            _ => None,
99        })
100        .collect();
101    let tool_scope_ctx = ScopeContext::for_call(&call.id);
102    let action_patches = match reduce_state_actions(
103        state_actions,
104        state,
105        &format!("tool:{}", call.name),
106        &tool_scope_ctx,
107    ) {
108        Ok(patches) => patches,
109        Err(err) => {
110            return ToolExecution {
111                call: call.clone(),
112                result: ToolResult::error(
113                    &call.name,
114                    format!("tool state action reduce failed: {err}"),
115                ),
116                patch: None,
117            };
118        }
119    };
120
121    let mut merged_patch = Patch::new();
122    for tracked in action_patches {
123        merged_patch.extend(tracked.patch().clone());
124    }
125
126    let patch = if merged_patch.is_empty() {
127        None
128    } else {
129        Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
130    };
131
132    ToolExecution {
133        call: call.clone(),
134        result,
135        patch,
136    }
137}
138
139/// Execute tool calls in parallel using the same state snapshot for every call.
140pub async fn execute_tools_parallel(
141    tools: &HashMap<String, Arc<dyn Tool>>,
142    calls: &[ToolCall],
143    state: &Value,
144) -> Vec<ToolExecution> {
145    let tasks = calls.iter().map(|call| {
146        let tool = tools.get(&call.name).cloned();
147        let state = state.clone();
148        async move { execute_single_tool(tool.as_deref(), call, &state).await }
149    });
150    join_all(tasks).await
151}
152
153/// Execute tool calls sequentially, applying each resulting patch before the next call.
154pub async fn execute_tools_sequential(
155    tools: &HashMap<String, Arc<dyn Tool>>,
156    calls: &[ToolCall],
157    state: &Value,
158) -> (Value, Vec<ToolExecution>) {
159    let mut state = state.clone();
160    let mut executions = Vec::with_capacity(calls.len());
161
162    for call in calls {
163        let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
164        if let Some(patch) = exec.patch.as_ref() {
165            if let Ok(next) = apply_patch(&state, patch.patch()) {
166                state = next;
167            }
168        }
169        executions.push(exec);
170    }
171
172    (state, executions)
173}
174
175/// Collect patches from executions.
176pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
177    executions.iter().filter_map(|e| e.patch.clone()).collect()
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    /// Tool that returns its arguments unchanged as the result data.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Minimal state stored at `counter`, reduced by adding an `i64` action.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Tool whose effect carries a single `+2` action on `EffectCounterState`.
    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_shares_snapshot_and_keeps_order() {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "nonexistent", json!({})),
            ToolCall::new("call_3", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 3);
        // Results line up with `calls`: the unknown tool sits in the middle.
        assert!(executions[0].result.is_success());
        assert!(executions[1].result.is_error());
        assert!(executions[1].patch.is_none());
        assert!(executions[2].result.is_success());
        // Both effect calls saw the same snapshot, so each patch alone yields 1 + 2.
        for index in [0, 2] {
            let patch = executions[index].patch.as_ref().expect("effect should emit a patch");
            let next = apply_patch(&state, patch.patch()).expect("patch should apply");
            assert_eq!(next["counter"]["value"], 3);
        }
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_threads_state_between_calls() {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 2);
        assert!(executions.iter().all(|e| e.result.is_success()));
        assert!(executions.iter().all(|e| e.patch.is_some()));
        // The second call saw the first call's result: 1 + 2 + 2.
        assert_eq!(final_state["counter"]["value"], 5);
        // The caller's snapshot is left untouched.
        assert_eq!(state["counter"]["value"], 1);
    }

    #[test]
    fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_default_run_identity_has_no_parent_tool_call() {
        /// Tool that reads the default run identity and returns parent lineage.
        struct RunIdentityReaderTool;

        #[async_trait]
        impl Tool for RunIdentityReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_identity_reader",
                    "RunIdentityReader",
                    "Reads run identity",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let parent_tool_call_id = ctx
                    .run_identity()
                    .parent_tool_call_id_opt()
                    .unwrap_or("none");
                Ok(ToolResult::success(
                    "run_identity_reader",
                    json!({"parent_tool_call_id": parent_tool_call_id}),
                ))
            }
        }

        let tool = RunIdentityReaderTool;
        let call = ToolCall::new("call_1", "run_identity_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["parent_tool_call_id"], "none");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_run_policy_none() {
        /// Tool that checks typed run-policy defaults when none are supplied.
        struct RunPolicyCheckerTool;

        #[async_trait]
        impl Tool for RunPolicyCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_policy_checker",
                    "RunPolicyChecker",
                    "Checks runtime option presence",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Ok(ToolResult::success(
                    "run_policy_checker",
                    json!({
                        "has_run_policy": true,
                        "has_parent_tool_call_id": ctx.run_identity().parent_tool_call_id_opt().is_some()
                    }),
                ))
            }
        }

        let tool = RunPolicyCheckerTool;
        let call = ToolCall::new("call_1", "run_policy_checker", json!({}));
        let state = json!({});

        // Without explicit run policy, ToolCallContext still provides defaults.
        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_run_policy"], true);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);

        // With explicit empty run policy.
        let run_policy = RunPolicy::new();
        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;
        assert_eq!(exec.result.data["has_run_policy"], true);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_run_policy() {
        /// Tool that reads typed policy values from the run policy.
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let allowed_tools = ctx
                    .run_policy()
                    .allowed_tools()
                    .map(|items| items.to_vec())
                    .unwrap_or_default();
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"allowed_tools": allowed_tools}),
                ))
            }
        }

        let mut run_policy = RunPolicy::new();
        run_policy
            .set_allowed_tools_if_absent(Some(&["sensitive".to_string(), "echo".to_string()]));

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;

        assert!(exec.result.is_success());
        assert_eq!(
            exec.result.data["allowed_tools"],
            json!(["sensitive", "echo"])
        );
    }

    // =========================================================================
    // validate_args integration: strict schema blocks invalid args at exec path
    // =========================================================================

    /// Tool with a strict schema — execute should never be reached on invalid args.
    struct StrictSchemaTool {
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // Missing required "name" field
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // "name" should be string, not integer
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }
}
604}