tirea_agentos/runtime/context/
plugin.rs1use async_trait::async_trait;
2use genai::chat::ChatOptions;
3use std::sync::Arc;
4
5use tirea_contract::runtime::behavior::ReadOnlyContext;
6use tirea_contract::runtime::inference::{ContextWindowPolicy, InferenceRequestTransform};
7use tirea_contract::runtime::phase::{ActionSet, AfterToolExecuteAction, BeforeInferenceAction};
8use tirea_contract::runtime::state::{AnyStateAction, StateScope};
9use tirea_contract::runtime::tool_call::{
10 suspended_calls_from_state, tool_call_states_from_state, ToolResult,
11};
12use tirea_contract::thread::Message;
13
14use crate::engine::token_estimator::{estimate_messages_tokens, estimate_tokens};
15use crate::runtime::loop_runner::LlmExecutor;
16
17use super::compaction::{
18 build_artifact_preview, find_compaction_plan, now_ms, render_messages_for_summary,
19 ContextSummarizer, LlmContextSummarizer, SummaryPayload,
20 DEFAULT_ARTIFACT_COMPACT_THRESHOLD_TOKENS, MIN_COMPACTION_GAIN_TOKENS,
21};
22use super::state::{ArtifactRef, CompactBoundary, ContextAction, ContextState};
23use super::transform::ContextTransform;
24use super::{policy_for_model, CONTEXT_PLUGIN_ID};
25
/// Agent behavior that manages a thread's context window.
///
/// After each tool execution it may record a large tool result as an artifact
/// reference. Before each inference it may compact older history into a summary
/// boundary, and it installs a request transform built from the resulting
/// `ContextState`.
#[derive(Clone)]
pub struct ContextPlugin {
    /// Context-window policy (token budget, autocompaction threshold and mode, ...).
    pub(super) policy: ContextWindowPolicy,
    /// Minimum estimated token count at which a tool result is recorded as an
    /// artifact reference; smaller results are left alone.
    artifact_compact_threshold_tokens: usize,
    /// Summarizer used to produce compaction summaries; when `None`, automatic
    /// compaction never happens.
    summarizer: Option<Arc<dyn ContextSummarizer>>,
}
33
34impl std::fmt::Debug for ContextPlugin {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("ContextPlugin")
37 .field("policy", &self.policy)
38 .field(
39 "artifact_compact_threshold_tokens",
40 &self.artifact_compact_threshold_tokens,
41 )
42 .field("has_summarizer", &self.summarizer.is_some())
43 .finish()
44 }
45}
46
47impl Default for ContextPlugin {
48 fn default() -> Self {
49 Self::new(ContextWindowPolicy::default())
50 }
51}
52
53impl ContextPlugin {
54 pub fn new(policy: ContextWindowPolicy) -> Self {
55 Self {
56 policy,
57 artifact_compact_threshold_tokens: DEFAULT_ARTIFACT_COMPACT_THRESHOLD_TOKENS,
58 summarizer: None,
59 }
60 }
61
62 pub fn for_model(model: &str) -> Self {
64 Self::new(policy_for_model(model))
65 }
66
67 pub fn with_artifact_compact_threshold_tokens(mut self, threshold: usize) -> Self {
68 self.artifact_compact_threshold_tokens = threshold;
69 self
70 }
71
72 #[cfg(test)]
73 pub(super) fn with_summarizer(mut self, summarizer: Arc<dyn ContextSummarizer>) -> Self {
74 self.summarizer = Some(summarizer);
75 self
76 }
77
78 pub(crate) fn with_llm_summarizer(
79 mut self,
80 model: String,
81 executor: Arc<dyn LlmExecutor>,
82 chat_options: Option<ChatOptions>,
83 ) -> Self {
84 self.summarizer = Some(Arc::new(LlmContextSummarizer::new(
85 model,
86 executor,
87 chat_options,
88 )));
89 self
90 }
91
92 async fn maybe_compact(
93 &self,
94 ctx: &ReadOnlyContext<'_>,
95 state: &ContextState,
96 ) -> Option<(CompactBoundary, Option<ContextAction>)> {
97 let threshold = self.policy.autocompact_threshold?;
98 let raw_messages: Vec<Message> = ctx
99 .messages()
100 .iter()
101 .map(|message| (**message).clone())
102 .collect();
103 let effective_messages = ContextTransform::new(
104 state.clone(),
105 ContextWindowPolicy {
106 max_context_tokens: usize::MAX,
107 ..self.policy.clone()
108 },
109 )
110 .transform(raw_messages, &[])
111 .messages;
112 let effective_tokens = estimate_messages_tokens(&effective_messages);
113 if effective_tokens < threshold {
114 return None;
115 }
116
117 let snapshot = ctx.snapshot();
118 let tool_states = tool_call_states_from_state(&snapshot);
119 let suspended_calls = suspended_calls_from_state(&snapshot);
120 let plan = find_compaction_plan(
121 ctx.messages(),
122 state,
123 &tool_states,
124 &suspended_calls,
125 self.policy.compaction_mode,
126 self.policy.compaction_raw_suffix_messages,
127 )?;
128 if plan.covered_token_count < MIN_COMPACTION_GAIN_TOKENS {
129 return None;
130 }
131 let summarizer = self.summarizer.as_ref()?;
132
133 let mut delta_messages: Vec<Message> = ctx.messages()
134 [plan.start_index..=plan.boundary_index]
135 .iter()
136 .map(|message| (**message).clone())
137 .collect();
138 ContextTransform::new(state.clone(), ContextWindowPolicy::default())
139 .apply_artifact_refs(&mut delta_messages);
140
141 let payload = SummaryPayload {
142 previous_summary: state
143 .latest_boundary()
144 .map(|boundary| boundary.summary.clone()),
145 transcript: render_messages_for_summary(&delta_messages),
146 };
147 let summary = match summarizer.summarize(payload).await {
148 Ok(summary) => summary,
149 Err(error) => {
150 tracing::warn!(
151 thread_id = %ctx.thread_id(),
152 error = %error,
153 "context compaction skipped after summary generation failed"
154 );
155 return None;
156 }
157 };
158
159 let boundary = CompactBoundary {
160 covers_through_message_id: plan.boundary_message_id.clone(),
161 summary,
162 original_token_count: state
163 .latest_boundary()
164 .map(|boundary| boundary.original_token_count)
165 .unwrap_or(0)
166 + plan.covered_token_count,
167 created_at_ms: now_ms(),
168 };
169
170 let prune_action =
171 if plan.covered_message_ids.is_empty() && plan.covered_tool_call_ids.is_empty() {
172 None
173 } else {
174 Some(ContextAction::PruneArtifacts {
175 message_ids: plan.covered_message_ids,
176 tool_call_ids: plan.covered_tool_call_ids,
177 })
178 };
179
180 Some((boundary, prune_action))
181 }
182
183 fn maybe_build_artifact_ref(&self, call_id: &str, result: &ToolResult) -> Option<ArtifactRef> {
184 let raw_content = serde_json::to_string(result).unwrap_or_else(|_| {
185 result
186 .message
187 .clone()
188 .unwrap_or_else(|| result.data.to_string())
189 });
190 let token_count = estimate_tokens(&raw_content);
191 if token_count < self.artifact_compact_threshold_tokens {
192 return None;
193 }
194
195 Some(ArtifactRef {
196 message_id: None,
197 tool_call_id: Some(call_id.to_string()),
198 label: result.tool_name.clone(),
199 summary: build_artifact_preview(result),
200 original_size: raw_content.len(),
201 original_token_count: token_count,
202 })
203 }
204}
205
206#[async_trait]
207impl tirea_contract::runtime::AgentBehavior for ContextPlugin {
208 fn id(&self) -> &str {
209 CONTEXT_PLUGIN_ID
210 }
211
212 tirea_contract::declare_plugin_states!(ContextState);
213
214 async fn before_inference(
215 &self,
216 ctx: &ReadOnlyContext<'_>,
217 ) -> ActionSet<BeforeInferenceAction> {
218 let state = ctx
219 .scoped_state_of::<ContextState>(StateScope::Thread)
220 .ok()
221 .unwrap_or_default();
222
223 let mut effective_state = state.clone();
224 let mut actions = ActionSet::empty();
225
226 if let Some((boundary, prune_action)) = self.maybe_compact(ctx, &state).await {
227 let boundary_action = ContextAction::AddBoundary(boundary.clone());
228 effective_state.reduce(boundary_action.clone());
229 actions = actions.and(BeforeInferenceAction::State(AnyStateAction::new::<
230 ContextState,
231 >(boundary_action)));
232 if let Some(prune_action) = prune_action {
233 effective_state.reduce(prune_action.clone());
234 actions = actions.and(BeforeInferenceAction::State(AnyStateAction::new::<
235 ContextState,
236 >(prune_action)));
237 }
238 }
239
240 actions.and(BeforeInferenceAction::AddRequestTransform(Arc::new(
243 ContextTransform::new(effective_state, self.policy.clone()),
244 )))
245 }
246
247 async fn after_tool_execute(
248 &self,
249 ctx: &ReadOnlyContext<'_>,
250 ) -> ActionSet<AfterToolExecuteAction> {
251 let Some(result) = ctx.tool_result() else {
252 return ActionSet::empty();
253 };
254 let Some(call_id) = ctx.tool_call_id() else {
255 return ActionSet::empty();
256 };
257
258 let Some(artifact) = self.maybe_build_artifact_ref(call_id, result) else {
259 return ActionSet::empty();
260 };
261
262 ActionSet::single(AfterToolExecuteAction::State(AnyStateAction::new::<
263 ContextState,
264 >(
265 ContextAction::AddArtifact(artifact),
266 )))
267 }
268}