// tirea_agentos/runtime/stop_policy/
plugin.rs1use async_trait::async_trait;
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4
5use crate::composition::StopConditionSpec;
6use crate::contracts::runtime::behavior::{AgentBehavior, ReadOnlyContext};
7use crate::contracts::runtime::phase::{ActionSet, AfterInferenceAction};
8use crate::contracts::runtime::state::AnyStateAction;
9use crate::contracts::runtime::tool_call::ToolResult;
10use crate::contracts::runtime::StreamResult;
11use crate::contracts::thread::{Message, Role, ToolCall};
12use crate::contracts::{RunContext, TerminationReason};
13
14use super::conditions::{StopPolicy, StopPolicyInput, StopPolicyStats};
15use super::state::{StopPolicyRuntimeAction, StopPolicyRuntimeState};
16use super::{
17 ConsecutiveErrors, ContentMatch, LoopDetection, MaxRounds, StopOnTool, Timeout, TokenBudget,
18 STOP_POLICY_PLUGIN_ID,
19};
20
21pub struct StopPolicyPlugin {
25 conditions: Vec<Arc<dyn StopPolicy>>,
26}
27
28impl std::fmt::Debug for StopPolicyPlugin {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("StopPolicyPlugin")
31 .field("conditions_len", &self.conditions.len())
32 .finish()
33 }
34}
35
36impl StopPolicyPlugin {
37 pub fn new(
38 mut stop_conditions: Vec<Arc<dyn StopPolicy>>,
39 stop_condition_specs: Vec<StopConditionSpec>,
40 ) -> Self {
41 stop_conditions.extend(stop_condition_specs.into_iter().map(condition_from_spec));
42 Self {
43 conditions: stop_conditions,
44 }
45 }
46
47 pub fn is_empty(&self) -> bool {
48 self.conditions.is_empty()
49 }
50}
51
52#[async_trait]
53impl AgentBehavior for StopPolicyPlugin {
54 fn id(&self) -> &str {
55 STOP_POLICY_PLUGIN_ID
56 }
57
58 tirea_contract::declare_plugin_states!(StopPolicyRuntimeState);
59
60 async fn after_inference(&self, ctx: &ReadOnlyContext<'_>) -> ActionSet<AfterInferenceAction> {
61 if self.conditions.is_empty() {
62 return ActionSet::empty();
63 }
64
65 let Some(response) = ctx.response() else {
66 return ActionSet::empty();
67 };
68 let now_ms = now_millis();
69 let prompt_tokens = response
70 .usage
71 .as_ref()
72 .and_then(|usage| usage.prompt_tokens)
73 .unwrap_or(0) as usize;
74 let completion_tokens = response
75 .usage
76 .as_ref()
77 .and_then(|usage| usage.completion_tokens)
78 .unwrap_or(0) as usize;
79
80 let runtime = ctx
81 .snapshot_of::<StopPolicyRuntimeState>()
82 .unwrap_or_default();
83 let started_at_ms = runtime.started_at_ms.unwrap_or(now_ms);
84 let total_input_tokens =
85 (runtime.total_input_tokens.value() as usize).saturating_add(prompt_tokens);
86 let total_output_tokens =
87 (runtime.total_output_tokens.value() as usize).saturating_add(completion_tokens);
88
89 let mut actions: ActionSet<AfterInferenceAction> = ActionSet::empty();
90
91 actions = actions.and(AfterInferenceAction::State(AnyStateAction::new::<
93 StopPolicyRuntimeState,
94 >(
95 StopPolicyRuntimeAction::RecordTokens {
96 started_at_ms: if runtime.started_at_ms.is_none() {
97 Some(now_ms)
98 } else {
99 None
100 },
101 prompt_tokens,
102 completion_tokens,
103 },
104 )));
105
106 let run_messages = &ctx.messages()[ctx.initial_message_count()..];
108 let message_stats = derive_stats_from_messages_with_response(run_messages, response);
109 let elapsed = std::time::Duration::from_millis(now_ms.saturating_sub(started_at_ms));
110
111 let run_ctx = RunContext::new(
112 ctx.thread_id().to_string(),
113 ctx.snapshot(),
114 ctx.messages().to_vec(),
115 ctx.run_policy().clone(),
116 );
117 let input = StopPolicyInput {
118 run_ctx: &run_ctx,
119 stats: StopPolicyStats {
120 step: message_stats.step,
121 step_tool_call_count: message_stats.step_tool_call_count,
122 total_tool_call_count: message_stats.total_tool_call_count,
123 total_input_tokens,
124 total_output_tokens,
125 consecutive_errors: message_stats.consecutive_errors,
126 elapsed,
127 last_tool_calls: &message_stats.last_tool_calls,
128 last_text: &message_stats.last_text,
129 tool_call_history: &message_stats.tool_call_history,
130 },
131 };
132 for condition in &self.conditions {
133 if let Some(stopped) = condition.evaluate(&input) {
134 actions = actions.and(AfterInferenceAction::Terminate(TerminationReason::Stopped(
135 stopped,
136 )));
137 break;
138 }
139 }
140 actions
141 }
142}
143
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_millis() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
150
151#[derive(Debug, Clone, Default)]
152pub(super) struct MessageDerivedStopStats {
153 pub(super) step: usize,
154 pub(super) step_tool_call_count: usize,
155 pub(super) total_tool_call_count: usize,
156 pub(super) consecutive_errors: usize,
157 pub(super) last_tool_calls: Vec<ToolCall>,
158 pub(super) last_text: String,
159 pub(super) tool_call_history: VecDeque<Vec<String>>,
160}
161
162pub(super) fn derive_stats_from_messages(messages: &[Arc<Message>]) -> MessageDerivedStopStats {
163 let mut assistant_indices = Vec::new();
164 for (idx, message) in messages.iter().enumerate() {
165 if message.role == Role::Assistant {
166 assistant_indices.push(idx);
167 }
168 }
169
170 let mut stats = MessageDerivedStopStats {
171 step: assistant_indices.len(),
172 ..MessageDerivedStopStats::default()
173 };
174 let mut consecutive_errors = 0usize;
175
176 for (round_idx, &assistant_idx) in assistant_indices.iter().enumerate() {
177 let assistant = &messages[assistant_idx];
178 let tool_calls = assistant.tool_calls.clone().unwrap_or_default();
179
180 if !tool_calls.is_empty() {
181 stats.total_tool_call_count =
182 stats.total_tool_call_count.saturating_add(tool_calls.len());
183 let mut names: Vec<String> = tool_calls.iter().map(|tc| tc.name.clone()).collect();
184 names.sort();
185 if stats.tool_call_history.len() >= 20 {
186 stats.tool_call_history.pop_front();
187 }
188 stats.tool_call_history.push_back(names);
189 }
190
191 if round_idx + 1 == assistant_indices.len() {
192 stats.step_tool_call_count = tool_calls.len();
193 stats.last_tool_calls = tool_calls.clone();
194 stats.last_text = assistant.content.clone();
195 }
196
197 if tool_calls.is_empty() {
198 consecutive_errors = 0;
199 continue;
200 }
201
202 let next_assistant_idx = assistant_indices
203 .get(round_idx + 1)
204 .copied()
205 .unwrap_or(messages.len());
206 let tool_results =
207 collect_round_tool_results(messages, assistant_idx + 1, next_assistant_idx);
208 let round_all_errors = tool_calls
209 .iter()
210 .all(|call| tool_results.get(&call.id).copied().unwrap_or(false));
211 if round_all_errors {
212 consecutive_errors = consecutive_errors.saturating_add(1);
213 } else {
214 consecutive_errors = 0;
215 }
216 }
217
218 stats.consecutive_errors = consecutive_errors;
219 stats
220}
221
222pub(super) fn derive_stats_from_messages_with_response(
223 messages: &[Arc<Message>],
224 response: &StreamResult,
225) -> MessageDerivedStopStats {
226 let mut all_messages = Vec::with_capacity(messages.len() + 1);
227 all_messages.extend(messages.iter().cloned());
228 all_messages.push(Arc::new(Message::assistant_with_tool_calls(
229 response.text.clone(),
230 response.tool_calls.clone(),
231 )));
232 derive_stats_from_messages(&all_messages)
233}
234
235fn collect_round_tool_results(
236 messages: &[Arc<Message>],
237 from: usize,
238 to: usize,
239) -> HashMap<String, bool> {
240 let mut out = HashMap::new();
241 for message in messages.iter().take(to).skip(from) {
242 if message.role != Role::Tool {
243 continue;
244 }
245 let Some(call_id) = message.tool_call_id.as_ref() else {
246 continue;
247 };
248 let is_error = serde_json::from_str::<ToolResult>(&message.content)
249 .map(|result| result.is_error())
250 .unwrap_or(false);
251 out.insert(call_id.clone(), is_error);
252 }
253 out
254}
255
256pub(super) fn condition_from_spec(spec: StopConditionSpec) -> Arc<dyn StopPolicy> {
257 match spec {
258 StopConditionSpec::MaxRounds { rounds } => Arc::new(MaxRounds(rounds)),
259 StopConditionSpec::Timeout { seconds } => {
260 Arc::new(Timeout(std::time::Duration::from_secs(seconds)))
261 }
262 StopConditionSpec::TokenBudget { max_total } => Arc::new(TokenBudget { max_total }),
263 StopConditionSpec::ConsecutiveErrors { max } => Arc::new(ConsecutiveErrors(max)),
264 StopConditionSpec::StopOnTool { tool_name } => Arc::new(StopOnTool(tool_name)),
265 StopConditionSpec::ContentMatch { pattern } => Arc::new(ContentMatch(pattern)),
266 StopConditionSpec::LoopDetection { window } => Arc::new(LoopDetection { window }),
267 }
268}