tirea_agentos/runtime/stop_policy/
conditions.rs1use std::collections::VecDeque;
2
3use crate::contracts::thread::ToolCall;
4use crate::contracts::{RunContext, StoppedReason};
5
6pub struct StopPolicyStats<'a> {
8 pub step: usize,
10 pub step_tool_call_count: usize,
12 pub total_tool_call_count: usize,
14 pub total_input_tokens: usize,
16 pub total_output_tokens: usize,
18 pub consecutive_errors: usize,
20 pub elapsed: std::time::Duration,
22 pub last_tool_calls: &'a [ToolCall],
24 pub last_text: &'a str,
26 pub tool_call_history: &'a VecDeque<Vec<String>>,
28}
29
30pub struct StopPolicyInput<'a> {
32 pub run_ctx: &'a RunContext,
34 pub stats: StopPolicyStats<'a>,
36}
37
38pub trait StopPolicy: Send + Sync {
40 fn id(&self) -> &str;
42
43 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason>;
45}
46
47pub struct MaxRounds(pub usize);
53
54impl StopPolicy for MaxRounds {
55 fn id(&self) -> &str {
56 "max_rounds"
57 }
58
59 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
60 if input.stats.step >= self.0 {
61 Some(StoppedReason::new("max_rounds_reached"))
62 } else {
63 None
64 }
65 }
66}
67
68pub struct Timeout(pub std::time::Duration);
70
71impl StopPolicy for Timeout {
72 fn id(&self) -> &str {
73 "timeout"
74 }
75
76 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
77 if input.stats.elapsed >= self.0 {
78 Some(StoppedReason::new("timeout_reached"))
79 } else {
80 None
81 }
82 }
83}
84
85pub struct TokenBudget {
87 pub max_total: usize,
89}
90
91impl StopPolicy for TokenBudget {
92 fn id(&self) -> &str {
93 "token_budget"
94 }
95
96 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
97 if self.max_total > 0
98 && (input.stats.total_input_tokens + input.stats.total_output_tokens) >= self.max_total
99 {
100 Some(StoppedReason::new("token_budget_exceeded"))
101 } else {
102 None
103 }
104 }
105}
106
107pub struct ConsecutiveErrors(pub usize);
109
110impl StopPolicy for ConsecutiveErrors {
111 fn id(&self) -> &str {
112 "consecutive_errors"
113 }
114
115 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
116 if self.0 > 0 && input.stats.consecutive_errors >= self.0 {
117 Some(StoppedReason::new("consecutive_errors_exceeded"))
118 } else {
119 None
120 }
121 }
122}
123
124pub struct StopOnTool(pub String);
126
127impl StopPolicy for StopOnTool {
128 fn id(&self) -> &str {
129 "stop_on_tool"
130 }
131
132 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
133 for call in input.stats.last_tool_calls {
134 if call.name == self.0 {
135 return Some(StoppedReason::with_detail("tool_called", self.0.clone()));
136 }
137 }
138 None
139 }
140}
141
142pub struct ContentMatch(pub String);
144
145impl StopPolicy for ContentMatch {
146 fn id(&self) -> &str {
147 "content_match"
148 }
149
150 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
151 if !self.0.is_empty() && input.stats.last_text.contains(&self.0) {
152 Some(StoppedReason::with_detail(
153 "content_matched",
154 self.0.clone(),
155 ))
156 } else {
157 None
158 }
159 }
160}
161
162pub struct LoopDetection {
168 pub window: usize,
170}
171
172impl StopPolicy for LoopDetection {
173 fn id(&self) -> &str {
174 "loop_detection"
175 }
176
177 fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StoppedReason> {
178 let window = self.window.max(2);
179 let history = input.stats.tool_call_history;
180 if history.len() < 2 {
181 return None;
182 }
183
184 let recent: Vec<_> = history.iter().rev().take(window).collect();
185 for pair in recent.windows(2) {
186 if pair[0] == pair[1] {
187 return Some(StoppedReason::new("loop_detected"));
188 }
189 }
190 None
191 }
192}