1use super::core::{
2 pending_approval_placeholder_message, transition_tool_call_state, ToolCallStateSeed,
3 ToolCallStateTransition,
4};
5use super::parallel_state_merge::merge_parallel_state_patches;
6use super::plugin_runtime::emit_tool_phase;
7use super::{Agent, AgentLoopError, BaseAgent, RunCancellationToken};
8use crate::contracts::runtime::behavior::AgentBehavior;
9use crate::contracts::runtime::phase::{AfterToolExecuteAction, Phase, StepContext};
10use crate::contracts::runtime::state::{reduce_state_actions, AnyStateAction, ScopeContext};
11use crate::contracts::runtime::tool_call::{CallerContext, ToolGate};
12use crate::contracts::runtime::tool_call::{Tool, ToolDescriptor, ToolResult};
13use crate::contracts::runtime::{
14 ActivityManager, PendingToolCall, SuspendTicket, SuspendedCall, ToolCallResumeMode,
15};
16use crate::contracts::runtime::{
17 DecisionReplayPolicy, StreamResult, ToolCallOutcome, ToolCallStatus, ToolExecution,
18 ToolExecutionEffect, ToolExecutionRequest, ToolExecutionResult, ToolExecutor,
19 ToolExecutorError,
20};
21use crate::contracts::thread::Thread;
22use crate::contracts::thread::{Message, MessageMetadata, ToolCall};
23use crate::contracts::{RunContext, Suspension};
24use crate::engine::convert::tool_response;
25use crate::runtime::run_context::{await_or_cancel, is_cancelled, CancelAware};
26use async_trait::async_trait;
27use serde_json::Value;
28use std::collections::HashMap;
29use std::sync::Arc;
30use tirea_state::{apply_patch, Patch, TrackedPatch};
31
/// Result of executing the tool calls requested by one model response.
#[derive(Debug)]
pub enum ExecuteToolsOutcome {
    /// No tool call ended suspended; carries the thread with tool results and patches applied.
    Completed(Thread),
    /// A tool call was suspended awaiting an external decision.
    Suspended {
        /// Thread with the tool results and state patches produced so far.
        thread: Thread,
        /// The first suspended call reported by execution or replay.
        suspended_call: Box<SuspendedCall>,
    },
}
46
47impl ExecuteToolsOutcome {
48 pub fn into_thread(self) -> Thread {
50 match self {
51 Self::Completed(t) | Self::Suspended { thread: t, .. } => t,
52 }
53 }
54
55 pub fn is_suspended(&self) -> bool {
57 matches!(self, Self::Suspended { .. })
58 }
59}
60
/// Summary returned after tool execution results are folded into a `RunContext`.
pub(super) struct AppliedToolResults {
    /// Calls whose outcome was `Suspended` and that carried a `SuspendedCall`.
    pub(super) suspended_calls: Vec<SuspendedCall>,
    /// State snapshot taken after application; `None` when no non-empty patch was added.
    pub(super) state_snapshot: Option<Value>,
}
65
/// Inputs shared by the tool-phase pipeline when executing a batch of tool calls.
#[derive(Clone)]
pub(super) struct ToolPhaseContext<'a> {
    /// Descriptors of the available tools, handed to each `StepContext`.
    pub(super) tool_descriptors: &'a [ToolDescriptor],
    /// Optional behavior passed to `emit_tool_phase` for the before/after tool phases.
    pub(super) agent_behavior: Option<&'a dyn AgentBehavior>,
    /// Activity manager given to the tool's own call context (plugin phases use a no-op one).
    pub(super) activity_manager: Arc<dyn ActivityManager>,
    /// Run policy used as the scope for both plugin and tool call contexts.
    pub(super) run_policy: &'a tirea_contract::RunPolicy,
    /// Identity of the current run, attached to every call context.
    pub(super) run_identity: tirea_contract::runtime::RunIdentity,
    /// Caller information attached to every call context.
    pub(super) caller_context: CallerContext,
    /// Id of the thread the calls belong to.
    pub(super) thread_id: &'a str,
    /// Thread history passed to `StepContext::new`.
    pub(super) thread_messages: &'a [Arc<Message>],
    /// Optional cancellation token forwarded to plugin and tool call contexts.
    pub(super) cancellation_token: Option<&'a RunCancellationToken>,
}
78
79impl<'a> ToolPhaseContext<'a> {
80 pub(super) fn from_request(request: &'a ToolExecutionRequest<'a>) -> Self {
81 Self {
82 tool_descriptors: request.tool_descriptors,
83 agent_behavior: request.agent_behavior,
84 activity_manager: request.activity_manager.clone(),
85 run_policy: request.run_policy,
86 run_identity: request.run_identity.clone(),
87 caller_context: request.caller_context.clone(),
88 thread_id: request.thread_id,
89 thread_messages: request.thread_messages,
90 cancellation_token: request.cancellation_token,
91 }
92 }
93}
94
/// Milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch, and saturates at
/// `u64::MAX` if the millisecond count does not fit in a `u64`.
fn now_unix_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => {
            let millis = elapsed.as_millis();
            if millis > u128::from(u64::MAX) {
                u64::MAX
            } else {
                millis as u64
            }
        }
        Err(_) => 0,
    }
}
100
101fn suspended_call_from_tool_result(call: &ToolCall, result: &ToolResult) -> SuspendedCall {
102 if let Some(mut explicit) = result.suspension() {
103 if explicit.pending.id.trim().is_empty() || explicit.pending.name.trim().is_empty() {
104 explicit.pending =
105 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone());
106 }
107 return SuspendedCall::new(call, explicit);
108 }
109
110 let mut suspension = Suspension::new(&call.id, format!("tool:{}", call.name))
111 .with_parameters(call.arguments.clone());
112 if let Some(message) = result.message.as_ref() {
113 suspension = suspension.with_message(message.clone());
114 }
115
116 SuspendedCall::new(
117 call,
118 SuspendTicket::new(
119 suspension,
120 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone()),
121 ToolCallResumeMode::ReplayToolCall,
122 ),
123 )
124}
125
126fn persist_tool_call_status(
127 step: &StepContext<'_>,
128 call: &ToolCall,
129 status: ToolCallStatus,
130 suspended_call: Option<&SuspendedCall>,
131) -> Result<crate::contracts::runtime::ToolCallState, AgentLoopError> {
132 let current_state = step.ctx().tool_call_state_for(&call.id).map_err(|e| {
133 AgentLoopError::StateError(format!(
134 "failed to read tool call state for '{}' before setting {:?}: {e}",
135 call.id, status
136 ))
137 })?;
138 let previous_status = current_state
139 .as_ref()
140 .map(|state| state.status)
141 .unwrap_or(ToolCallStatus::New);
142 let current_resume_token = current_state
143 .as_ref()
144 .and_then(|state| state.resume_token.clone());
145 let current_resume = current_state
146 .as_ref()
147 .and_then(|state| state.resume.clone());
148
149 let (next_resume_token, next_resume) = match status {
150 ToolCallStatus::Running => {
151 if matches!(previous_status, ToolCallStatus::Resuming) {
152 (current_resume_token.clone(), current_resume.clone())
153 } else {
154 (None, None)
155 }
156 }
157 ToolCallStatus::Suspended => (
158 suspended_call
159 .map(|entry| entry.ticket.pending.id.clone())
160 .or(current_resume_token.clone()),
161 None,
162 ),
163 ToolCallStatus::Succeeded
164 | ToolCallStatus::Failed
165 | ToolCallStatus::Cancelled
166 | ToolCallStatus::New
167 | ToolCallStatus::Resuming => (current_resume_token, current_resume),
168 };
169
170 let Some(runtime_state) = transition_tool_call_state(
171 current_state,
172 ToolCallStateSeed {
173 call_id: &call.id,
174 tool_name: &call.name,
175 arguments: &call.arguments,
176 status: ToolCallStatus::New,
177 resume_token: None,
178 },
179 ToolCallStateTransition {
180 status,
181 resume_token: next_resume_token,
182 resume: next_resume,
183 updated_at: now_unix_millis(),
184 },
185 ) else {
186 return Err(AgentLoopError::StateError(format!(
187 "invalid tool call status transition for '{}': {:?} -> {:?}",
188 call.id, previous_status, status
189 )));
190 };
191
192 step.ctx()
193 .set_tool_call_state_for(&call.id, runtime_state.clone())
194 .map_err(|e| {
195 AgentLoopError::StateError(format!(
196 "failed to persist tool call state for '{}' as {:?}: {e}",
197 call.id, status
198 ))
199 })?;
200
201 Ok(runtime_state)
202}
203
204fn map_tool_executor_error(err: AgentLoopError, thread_id: &str) -> ToolExecutorError {
205 match err {
206 AgentLoopError::Cancelled => ToolExecutorError::Cancelled {
207 thread_id: thread_id.to_string(),
208 },
209 other => ToolExecutorError::Failed {
210 message: other.to_string(),
211 },
212 }
213}
214
/// Flavor of [`ParallelToolExecutor`]; selects its name and decision replay policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelToolExecutionMode {
    /// Named `parallel_batch_approval`; uses `DecisionReplayPolicy::BatchAllSuspended`.
    BatchApproval,
    /// Named `parallel_streaming`; uses `DecisionReplayPolicy::Immediate`.
    Streaming,
}
221
/// Tool executor that runs all calls of a step concurrently against the same state snapshot
/// and requires a parallel patch-conflict check when results are applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelToolExecutor {
    /// Selects the executor name and decision replay policy.
    mode: ParallelToolExecutionMode,
}
227
228impl ParallelToolExecutor {
229 pub const fn batch_approval() -> Self {
230 Self {
231 mode: ParallelToolExecutionMode::BatchApproval,
232 }
233 }
234
235 pub const fn streaming() -> Self {
236 Self {
237 mode: ParallelToolExecutionMode::Streaming,
238 }
239 }
240
241 fn mode_name(self) -> &'static str {
242 match self.mode {
243 ParallelToolExecutionMode::BatchApproval => "parallel_batch_approval",
244 ParallelToolExecutionMode::Streaming => "parallel_streaming",
245 }
246 }
247}
248
249impl Default for ParallelToolExecutor {
250 fn default() -> Self {
251 Self::streaming()
252 }
253}
254
255#[async_trait]
256impl ToolExecutor for ParallelToolExecutor {
257 async fn execute(
258 &self,
259 request: ToolExecutionRequest<'_>,
260 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
261 let thread_id = request.thread_id;
262 let phase_ctx = ToolPhaseContext::from_request(&request);
263 execute_tools_parallel_with_phases(request.tools, request.calls, request.state, phase_ctx)
264 .await
265 .map_err(|e| map_tool_executor_error(e, thread_id))
266 }
267
268 fn name(&self) -> &'static str {
269 self.mode_name()
270 }
271
272 fn requires_parallel_patch_conflict_check(&self) -> bool {
273 true
274 }
275
276 fn decision_replay_policy(&self) -> DecisionReplayPolicy {
277 match self.mode {
278 ParallelToolExecutionMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
279 ParallelToolExecutionMode::Streaming => DecisionReplayPolicy::Immediate,
280 }
281 }
282}
283
/// Tool executor that runs calls one after another, threading state patches from each call
/// into the next and stopping after the first suspended call.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
287
288#[async_trait]
289impl ToolExecutor for SequentialToolExecutor {
290 async fn execute(
291 &self,
292 request: ToolExecutionRequest<'_>,
293 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
294 let thread_id = request.thread_id;
295 let phase_ctx = ToolPhaseContext::from_request(&request);
296 execute_tools_sequential_with_phases(request.tools, request.calls, request.state, phase_ctx)
297 .await
298 .map_err(|e| map_tool_executor_error(e, thread_id))
299 }
300
301 fn name(&self) -> &'static str {
302 "sequential"
303 }
304}
305
306pub(super) fn apply_tool_results_to_session(
307 run_ctx: &mut RunContext,
308 results: &[ToolExecutionResult],
309 metadata: Option<MessageMetadata>,
310 check_parallel_patch_conflicts: bool,
311) -> Result<AppliedToolResults, AgentLoopError> {
312 apply_tool_results_impl(
313 run_ctx,
314 results,
315 metadata,
316 check_parallel_patch_conflicts,
317 None,
318 )
319}
320
321pub(super) fn apply_tool_results_impl(
322 run_ctx: &mut RunContext,
323 results: &[ToolExecutionResult],
324 metadata: Option<MessageMetadata>,
325 check_parallel_patch_conflicts: bool,
326 tool_msg_ids: Option<&HashMap<String, String>>,
327) -> Result<AppliedToolResults, AgentLoopError> {
328 let suspended: Vec<SuspendedCall> = results
330 .iter()
331 .filter_map(|r| {
332 if matches!(r.outcome, ToolCallOutcome::Suspended) {
333 r.suspended_call.clone()
334 } else {
335 None
336 }
337 })
338 .collect();
339
340 let all_serialized_state_actions: Vec<tirea_contract::SerializedStateAction> = results
342 .iter()
343 .flat_map(|r| r.serialized_state_actions.iter().cloned())
344 .collect();
345 if !all_serialized_state_actions.is_empty() {
346 run_ctx.add_serialized_state_actions(all_serialized_state_actions);
347 }
348
349 let base_snapshot = run_ctx
350 .snapshot()
351 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
352 let patches = merge_parallel_state_patches(
353 &base_snapshot,
354 results,
355 check_parallel_patch_conflicts,
356 run_ctx.lattice_registry(),
357 )?;
358 let mut state_changed = !patches.is_empty();
359 run_ctx.add_thread_patches(patches);
360
361 let tool_messages: Vec<Arc<Message>> = results
363 .iter()
364 .flat_map(|r| {
365 let is_suspended = matches!(r.outcome, ToolCallOutcome::Suspended);
366 let mut msgs = if is_suspended {
367 vec![Message::tool(
368 &r.execution.call.id,
369 pending_approval_placeholder_message(&r.execution.call.name),
370 )]
371 } else {
372 let mut tool_msg = tool_response(&r.execution.call.id, &r.execution.result);
373 if let Some(id) = tool_msg_ids.and_then(|ids| ids.get(&r.execution.call.id)) {
374 tool_msg = tool_msg.with_id(id.clone());
375 }
376 vec![tool_msg]
377 };
378 for reminder in &r.reminders {
379 msgs.push(Message::internal_system(format!(
380 "<system-reminder>{}</system-reminder>",
381 reminder
382 )));
383 }
384 if let Some(ref meta) = metadata {
385 for msg in &mut msgs {
386 msg.metadata = Some(meta.clone());
387 }
388 }
389 msgs.into_iter().map(Arc::new).collect::<Vec<_>>()
390 })
391 .collect();
392
393 run_ctx.add_messages(tool_messages);
394
395 let user_messages: Vec<Arc<Message>> = results
397 .iter()
398 .flat_map(|r| {
399 r.user_messages
400 .iter()
401 .map(|s| s.trim())
402 .filter(|s| !s.is_empty())
403 .map(|text| {
404 let mut msg = Message::user(text.to_string());
405 if let Some(ref meta) = metadata {
406 msg.metadata = Some(meta.clone());
407 }
408 Arc::new(msg)
409 })
410 .collect::<Vec<_>>()
411 })
412 .collect();
413 if !user_messages.is_empty() {
414 run_ctx.add_messages(user_messages);
415 }
416 if !suspended.is_empty() {
417 let state = run_ctx
418 .snapshot()
419 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
420 let actions: Vec<AnyStateAction> = suspended
421 .iter()
422 .map(|call| call.clone().into_state_action())
423 .collect();
424 let patches = reduce_state_actions(actions, &state, "agent_loop", &ScopeContext::run())
425 .map_err(|e| {
426 AgentLoopError::StateError(format!("failed to reduce suspended call actions: {e}"))
427 })?;
428 for patch in patches {
429 if !patch.patch().is_empty() {
430 state_changed = true;
431 run_ctx.add_thread_patch(patch);
432 }
433 }
434 let state_snapshot = if state_changed {
435 Some(
436 run_ctx
437 .snapshot()
438 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
439 )
440 } else {
441 None
442 };
443 return Ok(AppliedToolResults {
444 suspended_calls: suspended,
445 state_snapshot,
446 });
447 }
448
449 let state_snapshot = if state_changed {
456 Some(
457 run_ctx
458 .snapshot()
459 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
460 )
461 } else {
462 None
463 };
464
465 Ok(AppliedToolResults {
466 suspended_calls: Vec::new(),
467 state_snapshot,
468 })
469}
470
471fn tool_result_metadata_from_run_ctx(
472 run_ctx: &RunContext,
473 run_id: Option<&str>,
474) -> Option<MessageMetadata> {
475 let run_id = run_id.map(|id| id.to_string()).or_else(|| {
476 run_ctx.messages().iter().rev().find_map(|m| {
477 m.metadata
478 .as_ref()
479 .and_then(|meta| meta.run_id.as_ref().cloned())
480 })
481 });
482
483 let step_index = run_ctx
484 .messages()
485 .iter()
486 .rev()
487 .find_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index));
488
489 if run_id.is_none() && step_index.is_none() {
490 None
491 } else {
492 Some(MessageMetadata { run_id, step_index })
493 }
494}
495
496#[allow(dead_code)]
497pub(super) fn next_step_index(run_ctx: &RunContext) -> u32 {
498 run_ctx
499 .messages()
500 .iter()
501 .filter_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index))
502 .max()
503 .map(|v| v.saturating_add(1))
504 .unwrap_or(0)
505}
506
507pub(super) fn step_metadata(run_id: Option<String>, step_index: u32) -> MessageMetadata {
508 MessageMetadata {
509 run_id,
510 step_index: Some(step_index),
511 }
512}
513
514pub async fn execute_tools(
518 thread: Thread,
519 result: &StreamResult,
520 tools: &HashMap<String, Arc<dyn Tool>>,
521 parallel: bool,
522) -> Result<ExecuteToolsOutcome, AgentLoopError> {
523 let parallel_executor = ParallelToolExecutor::streaming();
524 let sequential_executor = SequentialToolExecutor;
525 let executor: &dyn ToolExecutor = if parallel {
526 ¶llel_executor
527 } else {
528 &sequential_executor
529 };
530 execute_tools_with_agent_and_executor(thread, result, tools, executor, None).await
531}
532
533pub async fn execute_tools_with_config(
535 thread: Thread,
536 result: &StreamResult,
537 tools: &HashMap<String, Arc<dyn Tool>>,
538 agent: &dyn Agent,
539) -> Result<ExecuteToolsOutcome, AgentLoopError> {
540 execute_tools_with_agent_and_executor(
541 thread,
542 result,
543 tools,
544 agent.tool_executor().as_ref(),
545 Some(agent.behavior()),
546 )
547 .await
548}
549
550pub(super) fn caller_context_for_tool_execution(
551 run_ctx: &RunContext,
552 _state: &Value,
553) -> CallerContext {
554 CallerContext::new(
555 Some(run_ctx.thread_id().to_string()),
556 run_ctx.run_identity().run_id_opt().map(ToOwned::to_owned),
557 run_ctx.run_identity().agent_id_opt().map(ToOwned::to_owned),
558 run_ctx.messages().to_vec(),
559 )
560}
561
562pub async fn execute_tools_with_behaviors(
564 thread: Thread,
565 result: &StreamResult,
566 tools: &HashMap<String, Arc<dyn Tool>>,
567 parallel: bool,
568 behavior: Arc<dyn AgentBehavior>,
569) -> Result<ExecuteToolsOutcome, AgentLoopError> {
570 let executor: Arc<dyn ToolExecutor> = if parallel {
571 Arc::new(ParallelToolExecutor::streaming())
572 } else {
573 Arc::new(SequentialToolExecutor)
574 };
575 let agent = BaseAgent::default()
576 .with_behavior(behavior)
577 .with_tool_executor(executor);
578 execute_tools_with_config(thread, result, tools, &agent).await
579}
580
581async fn execute_tools_with_agent_and_executor(
582 thread: Thread,
583 result: &StreamResult,
584 tools: &HashMap<String, Arc<dyn Tool>>,
585 executor: &dyn ToolExecutor,
586 behavior: Option<&dyn AgentBehavior>,
587) -> Result<ExecuteToolsOutcome, AgentLoopError> {
588 let rebuilt_state = thread
590 .rebuild_state()
591 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
592 let mut run_ctx = RunContext::new(
593 &thread.id,
594 rebuilt_state.clone(),
595 thread.messages.clone(),
596 tirea_contract::RunPolicy::default(),
597 );
598
599 let tool_descriptors: Vec<ToolDescriptor> =
600 tools.values().map(|t| t.descriptor().clone()).collect();
601 if let Some(behavior) = behavior {
603 let run_start_patches = super::plugin_runtime::behavior_run_phase_block(
604 &run_ctx,
605 &tool_descriptors,
606 behavior,
607 &[Phase::RunStart],
608 |_| {},
609 |_| (),
610 )
611 .await?
612 .1;
613 if !run_start_patches.is_empty() {
614 run_ctx.add_thread_patches(run_start_patches);
615 }
616 }
617
618 let replay_executor: Arc<dyn ToolExecutor> = match executor.decision_replay_policy() {
619 DecisionReplayPolicy::BatchAllSuspended => Arc::new(ParallelToolExecutor::batch_approval()),
620 DecisionReplayPolicy::Immediate => Arc::new(ParallelToolExecutor::streaming()),
621 };
622 let replay_config = BaseAgent::default().with_tool_executor(replay_executor);
623 let replay = super::drain_resuming_tool_calls_and_replay(
624 &mut run_ctx,
625 tools,
626 &replay_config,
627 &tool_descriptors,
628 )
629 .await?;
630
631 if replay.replayed {
632 let suspended = run_ctx.suspended_calls().values().next().cloned();
633 let delta = run_ctx.take_delta();
634 let mut out_thread = thread;
635 for msg in delta.messages {
636 out_thread = out_thread.with_message((*msg).clone());
637 }
638 out_thread = out_thread.with_patches(delta.patches);
639 return if let Some(first) = suspended {
640 Ok(ExecuteToolsOutcome::Suspended {
641 thread: out_thread,
642 suspended_call: Box::new(first),
643 })
644 } else {
645 Ok(ExecuteToolsOutcome::Completed(out_thread))
646 };
647 }
648
649 if result.tool_calls.is_empty() {
650 let delta = run_ctx.take_delta();
651 let mut out_thread = thread;
652 for msg in delta.messages {
653 out_thread = out_thread.with_message((*msg).clone());
654 }
655 out_thread = out_thread.with_patches(delta.patches);
656 return Ok(ExecuteToolsOutcome::Completed(out_thread));
657 }
658
659 let current_state = run_ctx
660 .snapshot()
661 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
662 let caller_context = caller_context_for_tool_execution(&run_ctx, ¤t_state);
663 let results = executor
664 .execute(ToolExecutionRequest {
665 tools,
666 calls: &result.tool_calls,
667 state: ¤t_state,
668 tool_descriptors: &tool_descriptors,
669 agent_behavior: behavior,
670 activity_manager: tirea_contract::runtime::activity::NoOpActivityManager::arc(),
671 run_policy: run_ctx.run_policy(),
672 run_identity: run_ctx.run_identity().clone(),
673 caller_context,
674 thread_id: run_ctx.thread_id(),
675 thread_messages: run_ctx.messages(),
676 state_version: run_ctx.version(),
677 cancellation_token: None,
678 })
679 .await?;
680
681 let metadata = tool_result_metadata_from_run_ctx(&run_ctx, None);
682 let applied = apply_tool_results_to_session(
683 &mut run_ctx,
684 &results,
685 metadata,
686 executor.requires_parallel_patch_conflict_check(),
687 )?;
688 let suspended = applied.suspended_calls.into_iter().next();
689
690 let delta = run_ctx.take_delta();
692 let mut out_thread = thread;
693 for msg in delta.messages {
694 out_thread = out_thread.with_message((*msg).clone());
695 }
696 out_thread = out_thread.with_patches(delta.patches);
697
698 if let Some(first) = suspended {
699 Ok(ExecuteToolsOutcome::Suspended {
700 thread: out_thread,
701 suspended_call: Box::new(first),
702 })
703 } else {
704 Ok(ExecuteToolsOutcome::Completed(out_thread))
705 }
706}
707
/// Executes every call in `calls` concurrently against the same `state` snapshot.
///
/// Each call runs through `execute_single_tool_with_phases_impl` in its own future, and the
/// futures are driven together with `join_all`, which yields outputs in the same order as
/// `calls`. Patches are not applied between calls here: every call sees the original `state`.
///
/// # Errors
/// - `AgentLoopError::Cancelled` (via `cancelled_error`) if the token is already cancelled on
///   entry, or fires while the calls are in flight.
/// - Otherwise, the first per-call error in call order, after all futures have completed.
pub(super) async fn execute_tools_parallel_with_phases(
    tools: &HashMap<String, Arc<dyn Tool>>,
    calls: &[crate::contracts::thread::ToolCall],
    state: &Value,
    phase_ctx: ToolPhaseContext<'_>,
) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
    use futures::future::join_all;

    if is_cancelled(phase_ctx.cancellation_token) {
        return Err(cancelled_error(phase_ctx.thread_id));
    }

    // Owned copies shared by the per-call futures; each future clones what it needs below
    // so that its `async move` block owns its inputs.
    let run_policy_owned = phase_ctx.run_policy.clone();
    let thread_id = phase_ctx.thread_id.to_string();
    let thread_messages = Arc::new(phase_ctx.thread_messages.to_vec());
    let tool_descriptors = phase_ctx.tool_descriptors.to_vec();
    let agent = phase_ctx.agent_behavior;

    let futures = calls.iter().map(|call| {
        let tool = tools.get(&call.name).cloned();
        let state = state.clone();
        let call = call.clone();
        let tool_descriptors = tool_descriptors.clone();
        let activity_manager = phase_ctx.activity_manager.clone();
        let rt = run_policy_owned.clone();
        let run_identity = phase_ctx.run_identity.clone();
        let caller_context = phase_ctx.caller_context.clone();
        let sid = thread_id.clone();
        let thread_messages = thread_messages.clone();

        async move {
            execute_single_tool_with_phases_impl(
                tool.as_deref(),
                &call,
                &state,
                &ToolPhaseContext {
                    tool_descriptors: &tool_descriptors,
                    agent_behavior: agent,
                    activity_manager,
                    run_policy: &rt,
                    run_identity,
                    caller_context,
                    thread_id: &sid,
                    thread_messages: thread_messages.as_slice(),
                    // Cancellation is enforced around the whole batch by `await_or_cancel`
                    // below, so per-call contexts carry no token.
                    cancellation_token: None,
                },
            )
            .await
        }
    });

    let join_future = join_all(futures);
    let results = match await_or_cancel(phase_ctx.cancellation_token, join_future).await {
        CancelAware::Cancelled => return Err(cancelled_error(&thread_id)),
        CancelAware::Value(results) => results,
    };
    // Short-circuits on the first per-call error in call order.
    let results: Vec<ToolExecutionResult> = results.into_iter().collect::<Result<_, _>>()?;
    Ok(results)
}
769
770pub(super) async fn execute_tools_sequential_with_phases(
772 tools: &HashMap<String, Arc<dyn Tool>>,
773 calls: &[crate::contracts::thread::ToolCall],
774 initial_state: &Value,
775 phase_ctx: ToolPhaseContext<'_>,
776) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
777 use tirea_state::apply_patch;
778
779 if is_cancelled(phase_ctx.cancellation_token) {
780 return Err(cancelled_error(phase_ctx.thread_id));
781 }
782
783 let mut state = initial_state.clone();
784 let mut results = Vec::with_capacity(calls.len());
785
786 for call in calls {
787 let tool = tools.get(&call.name).cloned();
788 let call_phase_ctx = ToolPhaseContext {
789 tool_descriptors: phase_ctx.tool_descriptors,
790 agent_behavior: phase_ctx.agent_behavior,
791 activity_manager: phase_ctx.activity_manager.clone(),
792 run_policy: phase_ctx.run_policy,
793 run_identity: phase_ctx.run_identity.clone(),
794 caller_context: phase_ctx.caller_context.clone(),
795 thread_id: phase_ctx.thread_id,
796 thread_messages: phase_ctx.thread_messages,
797 cancellation_token: None,
798 };
799 let result = match await_or_cancel(
800 phase_ctx.cancellation_token,
801 execute_single_tool_with_phases_impl(tool.as_deref(), call, &state, &call_phase_ctx),
802 )
803 .await
804 {
805 CancelAware::Cancelled => return Err(cancelled_error(phase_ctx.thread_id)),
806 CancelAware::Value(result) => result?,
807 };
808
809 if let Some(ref patch) = result.execution.patch {
811 state = apply_patch(&state, patch.patch()).map_err(|e| {
812 AgentLoopError::StateError(format!(
813 "failed to apply tool patch for call '{}': {}",
814 result.execution.call.id, e
815 ))
816 })?;
817 }
818 for pp in &result.pending_patches {
820 state = apply_patch(&state, pp.patch()).map_err(|e| {
821 AgentLoopError::StateError(format!(
822 "failed to apply plugin patch for call '{}': {}",
823 result.execution.call.id, e
824 ))
825 })?;
826 }
827
828 results.push(result);
829
830 if results
831 .last()
832 .is_some_and(|r| matches!(r.outcome, ToolCallOutcome::Suspended))
833 {
834 break;
835 }
836 }
837
838 Ok(results)
839}
840
/// Test-only entry point that runs one tool call through the full phase pipeline.
#[cfg(test)]
pub(super) async fn execute_single_tool_with_phases(
    tool: Option<&dyn Tool>,
    call: &crate::contracts::thread::ToolCall,
    state: &Value,
    phase_ctx: &ToolPhaseContext<'_>,
) -> Result<ToolExecutionResult, AgentLoopError> {
    let pipeline = execute_single_tool_with_phases_impl(tool, call, state, phase_ctx);
    pipeline.await
}
851
852pub(super) async fn execute_single_tool_with_phases_deferred(
853 tool: Option<&dyn Tool>,
854 call: &crate::contracts::thread::ToolCall,
855 state: &Value,
856 phase_ctx: &ToolPhaseContext<'_>,
857) -> Result<ToolExecutionResult, AgentLoopError> {
858 execute_single_tool_with_phases_impl(tool, call, state, phase_ctx).await
859}
860
861async fn execute_single_tool_with_phases_impl(
862 tool: Option<&dyn Tool>,
863 call: &crate::contracts::thread::ToolCall,
864 state: &Value,
865 phase_ctx: &ToolPhaseContext<'_>,
866) -> Result<ToolExecutionResult, AgentLoopError> {
867 let doc = tirea_state::DocCell::new(state.clone());
869 let ops = std::sync::Mutex::new(Vec::new());
870 let pending_messages = std::sync::Mutex::new(Vec::new());
871 let plugin_scope = phase_ctx.run_policy;
872 let mut plugin_tool_call_ctx = crate::contracts::ToolCallContext::new(
873 &doc,
874 &ops,
875 "plugin_phase",
876 "plugin:tool_phase",
877 plugin_scope,
878 &pending_messages,
879 tirea_contract::runtime::activity::NoOpActivityManager::arc(),
880 )
881 .with_run_identity(phase_ctx.run_identity.clone())
882 .with_caller_context(phase_ctx.caller_context.clone());
883 if let Some(token) = phase_ctx.cancellation_token {
884 plugin_tool_call_ctx = plugin_tool_call_ctx.with_cancellation_token(token);
885 }
886
887 let mut step = StepContext::new(
889 plugin_tool_call_ctx,
890 phase_ctx.thread_id,
891 phase_ctx.thread_messages,
892 phase_ctx.tool_descriptors.to_vec(),
893 );
894 step.gate = Some(ToolGate::from_tool_call(call));
895 emit_tool_phase(
897 Phase::BeforeToolExecute,
898 &mut step,
899 phase_ctx.agent_behavior,
900 &doc,
901 )
902 .await?;
903
904 let (mut execution, outcome, suspended_call, tool_actions) = if step.tool_blocked() {
906 let reason = step
907 .gate
908 .as_ref()
909 .and_then(|g| g.block_reason.clone())
910 .unwrap_or_else(|| "Blocked by plugin".to_string());
911 (
912 ToolExecution {
913 call: call.clone(),
914 result: ToolResult::error(&call.name, reason),
915 patch: None,
916 },
917 ToolCallOutcome::Failed,
918 None,
919 Vec::<AfterToolExecuteAction>::new(),
920 )
921 } else if let Some(plugin_result) = step.tool_result().cloned() {
922 let outcome = ToolCallOutcome::from_tool_result(&plugin_result);
923 (
924 ToolExecution {
925 call: call.clone(),
926 result: plugin_result,
927 patch: None,
928 },
929 outcome,
930 None,
931 Vec::<AfterToolExecuteAction>::new(),
932 )
933 } else {
934 match tool {
935 None => (
936 ToolExecution {
937 call: call.clone(),
938 result: ToolResult::error(
939 &call.name,
940 format!("Tool '{}' not found", call.name),
941 ),
942 patch: None,
943 },
944 ToolCallOutcome::Failed,
945 None,
946 Vec::<AfterToolExecuteAction>::new(),
947 ),
948 Some(tool) => {
949 if let Err(e) = tool.validate_args(&call.arguments) {
950 (
951 ToolExecution {
952 call: call.clone(),
953 result: ToolResult::error(&call.name, e.to_string()),
954 patch: None,
955 },
956 ToolCallOutcome::Failed,
957 None,
958 Vec::<AfterToolExecuteAction>::new(),
959 )
960 } else if step.tool_pending() {
961 let Some(suspend_ticket) =
962 step.gate.as_ref().and_then(|g| g.suspend_ticket.clone())
963 else {
964 return Err(AgentLoopError::StateError(
965 "tool is pending but suspend ticket is missing".to_string(),
966 ));
967 };
968 (
969 ToolExecution {
970 call: call.clone(),
971 result: ToolResult::suspended(
972 &call.name,
973 "Execution suspended; awaiting external decision",
974 ),
975 patch: None,
976 },
977 ToolCallOutcome::Suspended,
978 Some(SuspendedCall::new(call, suspend_ticket)),
979 Vec::<AfterToolExecuteAction>::new(),
980 )
981 } else {
982 persist_tool_call_status(&step, call, ToolCallStatus::Running, None)?;
983 let tool_doc = tirea_state::DocCell::new(state.clone());
985 let tool_ops = std::sync::Mutex::new(Vec::new());
986 let tool_pending_msgs = std::sync::Mutex::new(Vec::new());
987 let mut tool_ctx = crate::contracts::ToolCallContext::new(
988 &tool_doc,
989 &tool_ops,
990 &call.id,
991 format!("tool:{}", call.name),
992 plugin_scope,
993 &tool_pending_msgs,
994 phase_ctx.activity_manager.clone(),
995 )
996 .as_read_only()
997 .with_run_identity(phase_ctx.run_identity.clone())
998 .with_caller_context(phase_ctx.caller_context.clone());
999 if let Some(token) = phase_ctx.cancellation_token {
1000 tool_ctx = tool_ctx.with_cancellation_token(token);
1001 }
1002 let effect = match tool.execute_effect(call.arguments.clone(), &tool_ctx).await
1003 {
1004 Ok(effect) => effect,
1005 Err(e) => {
1006 ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string()))
1007 }
1008 };
1009 let (result, actions) = effect.into_parts();
1010 let outcome = ToolCallOutcome::from_tool_result(&result);
1011
1012 let suspended_call = if matches!(outcome, ToolCallOutcome::Suspended) {
1013 Some(suspended_call_from_tool_result(call, &result))
1014 } else {
1015 None
1016 };
1017
1018 (
1019 ToolExecution {
1020 call: call.clone(),
1021 result,
1022 patch: None,
1023 },
1024 outcome,
1025 suspended_call,
1026 actions,
1027 )
1028 }
1029 }
1030 }
1031 };
1032
1033 if let Some(gate) = step.gate.as_mut() {
1035 gate.result = Some(execution.result.clone());
1036 }
1037
1038 let mut tool_state_actions = Vec::<AnyStateAction>::new();
1041 for action in tool_actions {
1042 match action {
1043 AfterToolExecuteAction::State(sa) => tool_state_actions.push(sa),
1044 AfterToolExecuteAction::AddSystemReminder(text) => {
1045 step.messaging.reminders.push(text);
1046 }
1047 AfterToolExecuteAction::AddUserMessage(text) => {
1048 step.messaging.user_messages.push(text);
1049 }
1050 }
1051 }
1052
1053 emit_tool_phase(
1055 Phase::AfterToolExecute,
1056 &mut step,
1057 phase_ctx.agent_behavior,
1058 &doc,
1059 )
1060 .await?;
1061
1062 let terminal_tool_call_state = match outcome {
1063 ToolCallOutcome::Suspended => Some(persist_tool_call_status(
1064 &step,
1065 call,
1066 ToolCallStatus::Suspended,
1067 suspended_call.as_ref(),
1068 )?),
1069 ToolCallOutcome::Succeeded => Some(persist_tool_call_status(
1070 &step,
1071 call,
1072 ToolCallStatus::Succeeded,
1073 None,
1074 )?),
1075 ToolCallOutcome::Failed => Some(persist_tool_call_status(
1076 &step,
1077 call,
1078 ToolCallStatus::Failed,
1079 None,
1080 )?),
1081 };
1082
1083 if let Some(tool_call_state) = terminal_tool_call_state {
1084 tool_state_actions.push(tool_call_state.into_state_action());
1085 }
1086
1087 if !matches!(outcome, ToolCallOutcome::Suspended) {
1090 let cleanup_path = format!("__tool_call_scope.{}.suspended_call", call.id);
1091 let cleanup_patch = Patch::with_ops(vec![tirea_state::Op::delete(
1092 tirea_state::parse_path(&cleanup_path),
1093 )]);
1094 let tracked = TrackedPatch::new(cleanup_patch).with_source("framework:scope_cleanup");
1095 step.emit_patch(tracked);
1096 }
1097
1098 let mut serialized_state_actions: Vec<tirea_contract::SerializedStateAction> =
1100 tool_state_actions
1101 .iter()
1102 .map(|a| a.to_serialized_state_action())
1103 .collect();
1104
1105 let tool_scope_ctx = ScopeContext::for_call(&call.id);
1106 let execution_patch_parts = reduce_tool_state_actions(
1107 state,
1108 tool_state_actions,
1109 &format!("tool:{}", call.name),
1110 &tool_scope_ctx,
1111 )?;
1112 execution.patch = merge_tracked_patches(&execution_patch_parts, &format!("tool:{}", call.name));
1113
1114 let phase_base_state = if let Some(tool_patch) = execution.patch.as_ref() {
1115 tirea_state::apply_patch(state, tool_patch.patch()).map_err(|e| {
1116 AgentLoopError::StateError(format!(
1117 "failed to apply tool patch for call '{}': {}",
1118 call.id, e
1119 ))
1120 })?
1121 } else {
1122 state.clone()
1123 };
1124 let pending_patches = apply_tracked_patches_checked(
1125 &phase_base_state,
1126 std::mem::take(&mut step.pending_patches),
1127 &call.id,
1128 )?;
1129
1130 let reminders = step.messaging.reminders.clone();
1131 let user_messages = std::mem::take(&mut step.messaging.user_messages);
1132
1133 serialized_state_actions.extend(step.take_pending_serialized_state_actions());
1135
1136 Ok(ToolExecutionResult {
1137 execution,
1138 outcome,
1139 suspended_call,
1140 reminders,
1141 user_messages,
1142 pending_patches,
1143 serialized_state_actions,
1144 })
1145}
1146
1147fn reduce_tool_state_actions(
1148 base_state: &Value,
1149 actions: Vec<AnyStateAction>,
1150 source: &str,
1151 scope_ctx: &ScopeContext,
1152) -> Result<Vec<TrackedPatch>, AgentLoopError> {
1153 reduce_state_actions(actions, base_state, source, scope_ctx).map_err(|e| {
1154 AgentLoopError::StateError(format!("failed to reduce tool state actions: {e}"))
1155 })
1156}
1157
1158fn merge_tracked_patches(patches: &[TrackedPatch], source: &str) -> Option<TrackedPatch> {
1159 let mut merged = Patch::new();
1160 for tracked in patches {
1161 merged.extend(tracked.patch().clone());
1162 }
1163 if merged.is_empty() {
1164 None
1165 } else {
1166 Some(TrackedPatch::new(merged).with_source(source.to_string()))
1167 }
1168}
1169
1170fn apply_tracked_patches_checked(
1171 base_state: &Value,
1172 patches: Vec<TrackedPatch>,
1173 call_id: &str,
1174) -> Result<Vec<TrackedPatch>, AgentLoopError> {
1175 let mut rolling = base_state.clone();
1176 let mut validated = Vec::with_capacity(patches.len());
1177 for tracked in patches {
1178 if tracked.patch().is_empty() {
1179 continue;
1180 }
1181 rolling = apply_patch(&rolling, tracked.patch()).map_err(|e| {
1182 AgentLoopError::StateError(format!(
1183 "failed to apply pending state patch for call '{}': {}",
1184 call_id, e
1185 ))
1186 })?;
1187 validated.push(tracked);
1188 }
1189 Ok(validated)
1190}
1191
1192fn cancelled_error(_thread_id: &str) -> AgentLoopError {
1193 AgentLoopError::Cancelled
1194}
1195
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::Op;

    /// Builds a single-op tracked patch tagged with `source`.
    fn tracked(op: Op, source: &str) -> TrackedPatch {
        TrackedPatch::new(Patch::new().with_op(op)).with_source(source)
    }

    #[test]
    fn apply_tracked_patches_checked_keeps_valid_sequence() {
        let input = vec![
            tracked(Op::set(tirea_state::path!("alpha"), json!(1)), "test:first"),
            tracked(Op::set(tirea_state::path!("beta"), json!(2)), "test:second"),
        ];

        let kept =
            apply_tracked_patches_checked(&json!({}), input, "call_1").expect("patches valid");

        assert_eq!(kept.len(), 2);
        for patch in &kept {
            assert_eq!(patch.patch().ops().len(), 1);
        }
    }

    #[test]
    fn apply_tracked_patches_checked_reports_invalid_sequence() {
        let input = vec![tracked(
            Op::increment(tirea_state::path!("counter"), 1_i64),
            "test:broken",
        )];

        let err = apply_tracked_patches_checked(&json!({}), input, "call_1")
            .expect_err("increment against missing path should fail");

        match err {
            AgentLoopError::StateError(message) => assert!(
                message.contains("failed to apply pending state patch for call 'call_1'"),
                "unexpected message: {message}"
            ),
            other => panic!("expected StateError, got: {other}"),
        }
    }
}