1use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunPolicy;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16pub async fn execute_single_tool(
29 tool: Option<&dyn Tool>,
30 call: &ToolCall,
31 state: &Value,
32) -> ToolExecution {
33 execute_single_tool_with_run_policy_and_behavior(tool, call, state, None, None).await
34}
35
36pub async fn execute_single_tool_with_run_policy(
38 tool: Option<&dyn Tool>,
39 call: &ToolCall,
40 state: &Value,
41 run_policy: Option<&RunPolicy>,
42) -> ToolExecution {
43 execute_single_tool_with_run_policy_and_behavior(tool, call, state, run_policy, None).await
44}
45
46pub async fn execute_single_tool_with_run_policy_and_behavior(
48 tool: Option<&dyn Tool>,
49 call: &ToolCall,
50 state: &Value,
51 run_policy: Option<&RunPolicy>,
52 _behavior: Option<&dyn AgentBehavior>,
53) -> ToolExecution {
54 let Some(tool) = tool else {
55 return ToolExecution {
56 call: call.clone(),
57 result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
58 patch: None,
59 };
60 };
61
62 let doc = DocCell::new(state.clone());
64 let ops = Mutex::new(Vec::new());
65 let default_run_policy = RunPolicy::default();
66 let run_policy = run_policy.unwrap_or(&default_run_policy);
67 let pending_messages = Mutex::new(Vec::new());
68 let ctx = ToolCallContext::new(
69 &doc,
70 &ops,
71 &call.id,
72 format!("tool:{}", call.name),
73 run_policy,
74 &pending_messages,
75 tirea_contract::runtime::activity::NoOpActivityManager::arc(),
76 )
77 .as_read_only();
78
79 if let Err(e) = tool.validate_args(&call.arguments) {
81 return ToolExecution {
82 call: call.clone(),
83 result: ToolResult::error(&call.name, e.to_string()),
84 patch: None,
85 };
86 }
87
88 let effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
90 Ok(effect) => effect,
91 Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
92 };
93 let (result, actions) = effect.into_parts();
94 let state_actions: Vec<AnyStateAction> = actions
95 .into_iter()
96 .filter_map(|a| match a {
97 crate::contracts::runtime::phase::AfterToolExecuteAction::State(sa) => Some(sa),
98 _ => None,
99 })
100 .collect();
101 let tool_scope_ctx = ScopeContext::for_call(&call.id);
102 let action_patches = match reduce_state_actions(
103 state_actions,
104 state,
105 &format!("tool:{}", call.name),
106 &tool_scope_ctx,
107 ) {
108 Ok(patches) => patches,
109 Err(err) => {
110 return ToolExecution {
111 call: call.clone(),
112 result: ToolResult::error(
113 &call.name,
114 format!("tool state action reduce failed: {err}"),
115 ),
116 patch: None,
117 };
118 }
119 };
120
121 let mut merged_patch = Patch::new();
122 for tracked in action_patches {
123 merged_patch.extend(tracked.patch().clone());
124 }
125
126 let patch = if merged_patch.is_empty() {
127 None
128 } else {
129 Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
130 };
131
132 ToolExecution {
133 call: call.clone(),
134 result,
135 patch,
136 }
137}
138
139pub async fn execute_tools_parallel(
141 tools: &HashMap<String, Arc<dyn Tool>>,
142 calls: &[ToolCall],
143 state: &Value,
144) -> Vec<ToolExecution> {
145 let tasks = calls.iter().map(|call| {
146 let tool = tools.get(&call.name).cloned();
147 let state = state.clone();
148 async move { execute_single_tool(tool.as_deref(), call, &state).await }
149 });
150 join_all(tasks).await
151}
152
153pub async fn execute_tools_sequential(
155 tools: &HashMap<String, Arc<dyn Tool>>,
156 calls: &[ToolCall],
157 state: &Value,
158) -> (Value, Vec<ToolExecution>) {
159 let mut state = state.clone();
160 let mut executions = Vec::with_capacity(calls.len());
161
162 for call in calls {
163 let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
164 if let Some(patch) = exec.patch.as_ref() {
165 if let Ok(next) = apply_patch(&state, patch.patch()) {
166 state = next;
167 }
168 }
169 executions.push(exec);
170 }
171
172 (state, executions)
173}
174
175pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
177 executions.iter().filter_map(|e| e.patch.clone()).collect()
178}
179
#[cfg(test)]
mod tests {
    // Unit tests for single-call execution (lookup, validation, effects, run
    // policy) and for the batch helpers (parallel, sequential, patch collection).
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, Ordering};

    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    /// Returns its arguments unchanged as the success payload.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Minimal reducible state stored at `counter`; each action adds to `value`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Emits a single `+2` counter action through `execute_effect`.
    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    // `collect_patches` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_default_run_identity_has_no_parent_tool_call() {
        struct RunIdentityReaderTool;

        #[async_trait]
        impl Tool for RunIdentityReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_identity_reader",
                    "RunIdentityReader",
                    "Reads run identity",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let parent_tool_call_id = ctx
                    .run_identity()
                    .parent_tool_call_id_opt()
                    .unwrap_or("none");
                Ok(ToolResult::success(
                    "run_identity_reader",
                    json!({"parent_tool_call_id": parent_tool_call_id}),
                ))
            }
        }

        let tool = RunIdentityReaderTool;
        let call = ToolCall::new("call_1", "run_identity_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["parent_tool_call_id"], "none");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_run_policy_none() {
        struct RunPolicyCheckerTool;

        #[async_trait]
        impl Tool for RunPolicyCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_policy_checker",
                    "RunPolicyChecker",
                    "Checks runtime option presence",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                // Report what the tool actually observes from the supplied policy,
                // rather than a hard-coded flag.
                Ok(ToolResult::success(
                    "run_policy_checker",
                    json!({
                        "has_allowed_tools": ctx.run_policy().allowed_tools().is_some(),
                        "has_parent_tool_call_id": ctx.run_identity().parent_tool_call_id_opt().is_some()
                    }),
                ))
            }
        }

        let tool = RunPolicyCheckerTool;
        let call = ToolCall::new("call_1", "run_policy_checker", json!({}));
        let state = json!({});

        // `None` falls back to a default policy, which carries no tool allow-list.
        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;
        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["has_allowed_tools"], false);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);

        let run_policy = RunPolicy::new();
        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;
        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["has_allowed_tools"], false);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_run_policy() {
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let allowed_tools = ctx
                    .run_policy()
                    .allowed_tools()
                    .map(|items| items.to_vec())
                    .unwrap_or_default();
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"allowed_tools": allowed_tools}),
                ))
            }
        }

        let mut run_policy = RunPolicy::new();
        run_policy
            .set_allowed_tools_if_absent(Some(&["sensitive".to_string(), "echo".to_string()]));

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;

        assert!(exec.result.is_success());
        assert_eq!(
            exec.result.data["allowed_tools"],
            json!(["sensitive", "echo"])
        );
    }

    /// Requires a string `name` argument and records whether `execute` ran.
    struct StrictSchemaTool {
        executed: AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_preserves_order_and_reports_missing_tools() {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("echo".to_string(), Arc::new(EchoTool));
        let calls = vec![
            ToolCall::new("call_1", "echo", json!({"msg": "first"})),
            ToolCall::new("call_2", "missing", json!({})),
            ToolCall::new("call_3", "echo", json!({"msg": "third"})),
        ];
        let state = json!({});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 3);
        assert_eq!(executions[0].call.id, "call_1");
        assert_eq!(executions[0].result.data["msg"], "first");
        assert!(executions[1].result.is_error());
        assert!(executions[1].patch.is_none());
        assert_eq!(executions[2].call.id, "call_3");
        assert_eq!(executions[2].result.data["msg"], "third");
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_uses_shared_snapshot() {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        // Each call reduces against the same input state, so every patch
        // applied to that state yields 1 + 2.
        for exec in &executions {
            let patch = exec.patch.as_ref().expect("patch should be emitted");
            let next = apply_patch(&state, patch.patch()).expect("patch should apply");
            assert_eq!(next["counter"]["value"], 3);
        }
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_threads_state_between_calls() {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 2);
        assert!(executions.iter().all(|e| e.patch.is_some()));
        // The second call observes the first call's update: 1 + 2 + 2.
        assert_eq!(final_state["counter"]["value"], 5);
    }
}