tirea_agentos/runtime/stop_policy/
plugin.rs

1use async_trait::async_trait;
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4
5use crate::composition::StopConditionSpec;
6use crate::contracts::runtime::behavior::{AgentBehavior, ReadOnlyContext};
7use crate::contracts::runtime::phase::{ActionSet, AfterInferenceAction};
8use crate::contracts::runtime::state::AnyStateAction;
9use crate::contracts::runtime::tool_call::ToolResult;
10use crate::contracts::runtime::StreamResult;
11use crate::contracts::thread::{Message, Role, ToolCall};
12use crate::contracts::{RunContext, TerminationReason};
13
14use super::conditions::{StopPolicy, StopPolicyInput, StopPolicyStats};
15use super::state::{StopPolicyRuntimeAction, StopPolicyRuntimeState};
16use super::{
17    ConsecutiveErrors, ContentMatch, LoopDetection, MaxRounds, StopOnTool, Timeout, TokenBudget,
18    STOP_POLICY_PLUGIN_ID,
19};
20
21/// Plugin adapter that evaluates configured stop policies at `AfterInference`.
22///
23/// This keeps stop-domain semantics out of the core loop.
24pub struct StopPolicyPlugin {
25    conditions: Vec<Arc<dyn StopPolicy>>,
26}
27
28impl std::fmt::Debug for StopPolicyPlugin {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("StopPolicyPlugin")
31            .field("conditions_len", &self.conditions.len())
32            .finish()
33    }
34}
35
36impl StopPolicyPlugin {
37    pub fn new(
38        mut stop_conditions: Vec<Arc<dyn StopPolicy>>,
39        stop_condition_specs: Vec<StopConditionSpec>,
40    ) -> Self {
41        stop_conditions.extend(stop_condition_specs.into_iter().map(condition_from_spec));
42        Self {
43            conditions: stop_conditions,
44        }
45    }
46
47    pub fn is_empty(&self) -> bool {
48        self.conditions.is_empty()
49    }
50}
51
52#[async_trait]
53impl AgentBehavior for StopPolicyPlugin {
54    fn id(&self) -> &str {
55        STOP_POLICY_PLUGIN_ID
56    }
57
58    tirea_contract::declare_plugin_states!(StopPolicyRuntimeState);
59
60    async fn after_inference(&self, ctx: &ReadOnlyContext<'_>) -> ActionSet<AfterInferenceAction> {
61        if self.conditions.is_empty() {
62            return ActionSet::empty();
63        }
64
65        let Some(response) = ctx.response() else {
66            return ActionSet::empty();
67        };
68        let now_ms = now_millis();
69        let prompt_tokens = response
70            .usage
71            .as_ref()
72            .and_then(|usage| usage.prompt_tokens)
73            .unwrap_or(0) as usize;
74        let completion_tokens = response
75            .usage
76            .as_ref()
77            .and_then(|usage| usage.completion_tokens)
78            .unwrap_or(0) as usize;
79
80        let runtime = ctx
81            .snapshot_of::<StopPolicyRuntimeState>()
82            .unwrap_or_default();
83        let started_at_ms = runtime.started_at_ms.unwrap_or(now_ms);
84        let total_input_tokens =
85            (runtime.total_input_tokens.value() as usize).saturating_add(prompt_tokens);
86        let total_output_tokens =
87            (runtime.total_output_tokens.value() as usize).saturating_add(completion_tokens);
88
89        let mut actions: ActionSet<AfterInferenceAction> = ActionSet::empty();
90
91        // Emit state patch for token recording
92        actions = actions.and(AfterInferenceAction::State(AnyStateAction::new::<
93            StopPolicyRuntimeState,
94        >(
95            StopPolicyRuntimeAction::RecordTokens {
96                started_at_ms: if runtime.started_at_ms.is_none() {
97                    Some(now_ms)
98                } else {
99                    None
100                },
101                prompt_tokens,
102                completion_tokens,
103            },
104        )));
105
106        // Only count messages from the current run to avoid cross-run accumulation.
107        let run_messages = &ctx.messages()[ctx.initial_message_count()..];
108        let message_stats = derive_stats_from_messages_with_response(run_messages, response);
109        let elapsed = std::time::Duration::from_millis(now_ms.saturating_sub(started_at_ms));
110
111        let run_ctx = RunContext::new(
112            ctx.thread_id().to_string(),
113            ctx.snapshot(),
114            ctx.messages().to_vec(),
115            ctx.run_policy().clone(),
116        );
117        let input = StopPolicyInput {
118            run_ctx: &run_ctx,
119            stats: StopPolicyStats {
120                step: message_stats.step,
121                step_tool_call_count: message_stats.step_tool_call_count,
122                total_tool_call_count: message_stats.total_tool_call_count,
123                total_input_tokens,
124                total_output_tokens,
125                consecutive_errors: message_stats.consecutive_errors,
126                elapsed,
127                last_tool_calls: &message_stats.last_tool_calls,
128                last_text: &message_stats.last_text,
129                tool_call_history: &message_stats.tool_call_history,
130            },
131        };
132        for condition in &self.conditions {
133            if let Some(stopped) = condition.evaluate(&input) {
134                actions = actions.and(AfterInferenceAction::Terminate(TerminationReason::Stopped(
135                    stopped,
136                )));
137                break;
138            }
139        }
140        actions
141    }
142}
143
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and saturates at
/// `u64::MAX` instead of silently wrapping if the millisecond count does not fit.
fn now_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => {
            // `as_millis` returns `u128`; clamp before narrowing so the cast cannot truncate.
            since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
        }
        Err(_) => 0,
    }
}
150
151#[derive(Debug, Clone, Default)]
152pub(super) struct MessageDerivedStopStats {
153    pub(super) step: usize,
154    pub(super) step_tool_call_count: usize,
155    pub(super) total_tool_call_count: usize,
156    pub(super) consecutive_errors: usize,
157    pub(super) last_tool_calls: Vec<ToolCall>,
158    pub(super) last_text: String,
159    pub(super) tool_call_history: VecDeque<Vec<String>>,
160}
161
162pub(super) fn derive_stats_from_messages(messages: &[Arc<Message>]) -> MessageDerivedStopStats {
163    let mut assistant_indices = Vec::new();
164    for (idx, message) in messages.iter().enumerate() {
165        if message.role == Role::Assistant {
166            assistant_indices.push(idx);
167        }
168    }
169
170    let mut stats = MessageDerivedStopStats {
171        step: assistant_indices.len(),
172        ..MessageDerivedStopStats::default()
173    };
174    let mut consecutive_errors = 0usize;
175
176    for (round_idx, &assistant_idx) in assistant_indices.iter().enumerate() {
177        let assistant = &messages[assistant_idx];
178        let tool_calls = assistant.tool_calls.clone().unwrap_or_default();
179
180        if !tool_calls.is_empty() {
181            stats.total_tool_call_count =
182                stats.total_tool_call_count.saturating_add(tool_calls.len());
183            let mut names: Vec<String> = tool_calls.iter().map(|tc| tc.name.clone()).collect();
184            names.sort();
185            if stats.tool_call_history.len() >= 20 {
186                stats.tool_call_history.pop_front();
187            }
188            stats.tool_call_history.push_back(names);
189        }
190
191        if round_idx + 1 == assistant_indices.len() {
192            stats.step_tool_call_count = tool_calls.len();
193            stats.last_tool_calls = tool_calls.clone();
194            stats.last_text = assistant.content.clone();
195        }
196
197        if tool_calls.is_empty() {
198            consecutive_errors = 0;
199            continue;
200        }
201
202        let next_assistant_idx = assistant_indices
203            .get(round_idx + 1)
204            .copied()
205            .unwrap_or(messages.len());
206        let tool_results =
207            collect_round_tool_results(messages, assistant_idx + 1, next_assistant_idx);
208        let round_all_errors = tool_calls
209            .iter()
210            .all(|call| tool_results.get(&call.id).copied().unwrap_or(false));
211        if round_all_errors {
212            consecutive_errors = consecutive_errors.saturating_add(1);
213        } else {
214            consecutive_errors = 0;
215        }
216    }
217
218    stats.consecutive_errors = consecutive_errors;
219    stats
220}
221
222pub(super) fn derive_stats_from_messages_with_response(
223    messages: &[Arc<Message>],
224    response: &StreamResult,
225) -> MessageDerivedStopStats {
226    let mut all_messages = Vec::with_capacity(messages.len() + 1);
227    all_messages.extend(messages.iter().cloned());
228    all_messages.push(Arc::new(Message::assistant_with_tool_calls(
229        response.text.clone(),
230        response.tool_calls.clone(),
231    )));
232    derive_stats_from_messages(&all_messages)
233}
234
235fn collect_round_tool_results(
236    messages: &[Arc<Message>],
237    from: usize,
238    to: usize,
239) -> HashMap<String, bool> {
240    let mut out = HashMap::new();
241    for message in messages.iter().take(to).skip(from) {
242        if message.role != Role::Tool {
243            continue;
244        }
245        let Some(call_id) = message.tool_call_id.as_ref() else {
246            continue;
247        };
248        let is_error = serde_json::from_str::<ToolResult>(&message.content)
249            .map(|result| result.is_error())
250            .unwrap_or(false);
251        out.insert(call_id.clone(), is_error);
252    }
253    out
254}
255
256pub(super) fn condition_from_spec(spec: StopConditionSpec) -> Arc<dyn StopPolicy> {
257    match spec {
258        StopConditionSpec::MaxRounds { rounds } => Arc::new(MaxRounds(rounds)),
259        StopConditionSpec::Timeout { seconds } => {
260            Arc::new(Timeout(std::time::Duration::from_secs(seconds)))
261        }
262        StopConditionSpec::TokenBudget { max_total } => Arc::new(TokenBudget { max_total }),
263        StopConditionSpec::ConsecutiveErrors { max } => Arc::new(ConsecutiveErrors(max)),
264        StopConditionSpec::StopOnTool { tool_name } => Arc::new(StopOnTool(tool_name)),
265        StopConditionSpec::ContentMatch { pattern } => Arc::new(ContentMatch(pattern)),
266        StopConditionSpec::LoopDetection { window } => Arc::new(LoopDetection { window }),
267    }
268}