tirea_agentos/runtime/context/plugin.rs

1use async_trait::async_trait;
2use genai::chat::ChatOptions;
3use std::sync::Arc;
4
5use tirea_contract::runtime::behavior::ReadOnlyContext;
6use tirea_contract::runtime::inference::{ContextWindowPolicy, InferenceRequestTransform};
7use tirea_contract::runtime::phase::{ActionSet, AfterToolExecuteAction, BeforeInferenceAction};
8use tirea_contract::runtime::state::{AnyStateAction, StateScope};
9use tirea_contract::runtime::tool_call::{
10    suspended_calls_from_state, tool_call_states_from_state, ToolResult,
11};
12use tirea_contract::thread::Message;
13
14use crate::engine::token_estimator::{estimate_messages_tokens, estimate_tokens};
15use crate::runtime::loop_runner::LlmExecutor;
16
17use super::compaction::{
18    build_artifact_preview, find_compaction_plan, now_ms, render_messages_for_summary,
19    ContextSummarizer, LlmContextSummarizer, SummaryPayload,
20    DEFAULT_ARTIFACT_COMPACT_THRESHOLD_TOKENS, MIN_COMPACTION_GAIN_TOKENS,
21};
22use super::state::{ArtifactRef, CompactBoundary, ContextAction, ContextState};
23use super::transform::ContextTransform;
24use super::{policy_for_model, CONTEXT_PLUGIN_ID};
25
26/// Unified context plugin: logical compression + hard truncation + prompt caching.
27#[derive(Clone)]
28pub struct ContextPlugin {
29    pub(super) policy: ContextWindowPolicy,
30    artifact_compact_threshold_tokens: usize,
31    summarizer: Option<Arc<dyn ContextSummarizer>>,
32}
33
34impl std::fmt::Debug for ContextPlugin {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("ContextPlugin")
37            .field("policy", &self.policy)
38            .field(
39                "artifact_compact_threshold_tokens",
40                &self.artifact_compact_threshold_tokens,
41            )
42            .field("has_summarizer", &self.summarizer.is_some())
43            .finish()
44    }
45}
46
47impl Default for ContextPlugin {
48    fn default() -> Self {
49        Self::new(ContextWindowPolicy::default())
50    }
51}
52
53impl ContextPlugin {
54    pub fn new(policy: ContextWindowPolicy) -> Self {
55        Self {
56            policy,
57            artifact_compact_threshold_tokens: DEFAULT_ARTIFACT_COMPACT_THRESHOLD_TOKENS,
58            summarizer: None,
59        }
60    }
61
62    /// Create with model-specific defaults.
63    pub fn for_model(model: &str) -> Self {
64        Self::new(policy_for_model(model))
65    }
66
67    pub fn with_artifact_compact_threshold_tokens(mut self, threshold: usize) -> Self {
68        self.artifact_compact_threshold_tokens = threshold;
69        self
70    }
71
72    #[cfg(test)]
73    pub(super) fn with_summarizer(mut self, summarizer: Arc<dyn ContextSummarizer>) -> Self {
74        self.summarizer = Some(summarizer);
75        self
76    }
77
78    pub(crate) fn with_llm_summarizer(
79        mut self,
80        model: String,
81        executor: Arc<dyn LlmExecutor>,
82        chat_options: Option<ChatOptions>,
83    ) -> Self {
84        self.summarizer = Some(Arc::new(LlmContextSummarizer::new(
85            model,
86            executor,
87            chat_options,
88        )));
89        self
90    }
91
92    async fn maybe_compact(
93        &self,
94        ctx: &ReadOnlyContext<'_>,
95        state: &ContextState,
96    ) -> Option<(CompactBoundary, Option<ContextAction>)> {
97        let threshold = self.policy.autocompact_threshold?;
98        let raw_messages: Vec<Message> = ctx
99            .messages()
100            .iter()
101            .map(|message| (**message).clone())
102            .collect();
103        let effective_messages = ContextTransform::new(
104            state.clone(),
105            ContextWindowPolicy {
106                max_context_tokens: usize::MAX,
107                ..self.policy.clone()
108            },
109        )
110        .transform(raw_messages, &[])
111        .messages;
112        let effective_tokens = estimate_messages_tokens(&effective_messages);
113        if effective_tokens < threshold {
114            return None;
115        }
116
117        let snapshot = ctx.snapshot();
118        let tool_states = tool_call_states_from_state(&snapshot);
119        let suspended_calls = suspended_calls_from_state(&snapshot);
120        let plan = find_compaction_plan(
121            ctx.messages(),
122            state,
123            &tool_states,
124            &suspended_calls,
125            self.policy.compaction_mode,
126            self.policy.compaction_raw_suffix_messages,
127        )?;
128        if plan.covered_token_count < MIN_COMPACTION_GAIN_TOKENS {
129            return None;
130        }
131        let summarizer = self.summarizer.as_ref()?;
132
133        let mut delta_messages: Vec<Message> = ctx.messages()
134            [plan.start_index..=plan.boundary_index]
135            .iter()
136            .map(|message| (**message).clone())
137            .collect();
138        ContextTransform::new(state.clone(), ContextWindowPolicy::default())
139            .apply_artifact_refs(&mut delta_messages);
140
141        let payload = SummaryPayload {
142            previous_summary: state
143                .latest_boundary()
144                .map(|boundary| boundary.summary.clone()),
145            transcript: render_messages_for_summary(&delta_messages),
146        };
147        let summary = match summarizer.summarize(payload).await {
148            Ok(summary) => summary,
149            Err(error) => {
150                tracing::warn!(
151                    thread_id = %ctx.thread_id(),
152                    error = %error,
153                    "context compaction skipped after summary generation failed"
154                );
155                return None;
156            }
157        };
158
159        let boundary = CompactBoundary {
160            covers_through_message_id: plan.boundary_message_id.clone(),
161            summary,
162            original_token_count: state
163                .latest_boundary()
164                .map(|boundary| boundary.original_token_count)
165                .unwrap_or(0)
166                + plan.covered_token_count,
167            created_at_ms: now_ms(),
168        };
169
170        let prune_action =
171            if plan.covered_message_ids.is_empty() && plan.covered_tool_call_ids.is_empty() {
172                None
173            } else {
174                Some(ContextAction::PruneArtifacts {
175                    message_ids: plan.covered_message_ids,
176                    tool_call_ids: plan.covered_tool_call_ids,
177                })
178            };
179
180        Some((boundary, prune_action))
181    }
182
183    fn maybe_build_artifact_ref(&self, call_id: &str, result: &ToolResult) -> Option<ArtifactRef> {
184        let raw_content = serde_json::to_string(result).unwrap_or_else(|_| {
185            result
186                .message
187                .clone()
188                .unwrap_or_else(|| result.data.to_string())
189        });
190        let token_count = estimate_tokens(&raw_content);
191        if token_count < self.artifact_compact_threshold_tokens {
192            return None;
193        }
194
195        Some(ArtifactRef {
196            message_id: None,
197            tool_call_id: Some(call_id.to_string()),
198            label: result.tool_name.clone(),
199            summary: build_artifact_preview(result),
200            original_size: raw_content.len(),
201            original_token_count: token_count,
202        })
203    }
204}
205
206#[async_trait]
207impl tirea_contract::runtime::AgentBehavior for ContextPlugin {
208    fn id(&self) -> &str {
209        CONTEXT_PLUGIN_ID
210    }
211
212    tirea_contract::declare_plugin_states!(ContextState);
213
214    async fn before_inference(
215        &self,
216        ctx: &ReadOnlyContext<'_>,
217    ) -> ActionSet<BeforeInferenceAction> {
218        let state = ctx
219            .scoped_state_of::<ContextState>(StateScope::Thread)
220            .ok()
221            .unwrap_or_default();
222
223        let mut effective_state = state.clone();
224        let mut actions = ActionSet::empty();
225
226        if let Some((boundary, prune_action)) = self.maybe_compact(ctx, &state).await {
227            let boundary_action = ContextAction::AddBoundary(boundary.clone());
228            effective_state.reduce(boundary_action.clone());
229            actions = actions.and(BeforeInferenceAction::State(AnyStateAction::new::<
230                ContextState,
231            >(boundary_action)));
232            if let Some(prune_action) = prune_action {
233                effective_state.reduce(prune_action.clone());
234                actions = actions.and(BeforeInferenceAction::State(AnyStateAction::new::<
235                    ContextState,
236                >(prune_action)));
237            }
238        }
239
240        // Always register the combined transform: compaction is a no-op when
241        // state is empty, but truncation must always run.
242        actions.and(BeforeInferenceAction::AddRequestTransform(Arc::new(
243            ContextTransform::new(effective_state, self.policy.clone()),
244        )))
245    }
246
247    async fn after_tool_execute(
248        &self,
249        ctx: &ReadOnlyContext<'_>,
250    ) -> ActionSet<AfterToolExecuteAction> {
251        let Some(result) = ctx.tool_result() else {
252            return ActionSet::empty();
253        };
254        let Some(call_id) = ctx.tool_call_id() else {
255            return ActionSet::empty();
256        };
257
258        let Some(artifact) = self.maybe_build_artifact_ref(call_id, result) else {
259            return ActionSet::empty();
260        };
261
262        ActionSet::single(AfterToolExecuteAction::State(AnyStateAction::new::<
263            ContextState,
264        >(
265            ContextAction::AddArtifact(artifact),
266        )))
267    }
268}