tirea_contract/runtime/inference/context.rs

1use crate::runtime::inference::transform::InferenceRequestTransform;
2use crate::runtime::tool_call::ToolDescriptor;
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5
6/// Auto-compaction strategy used by [`ContextPlugin`].
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
8#[serde(rename_all = "snake_case")]
9pub enum ContextCompactionMode {
10    /// Replace an older prefix with a summary while preserving a recent raw suffix.
11    #[default]
12    KeepRecentRawSuffix,
13    /// Replace all messages through the latest safe frontier with a summary.
14    CompactToSafeFrontier,
15}
16
/// Fallback value for [`ContextWindowPolicy::compaction_raw_suffix_messages`]:
/// preserve the two most recent raw messages when the field is not supplied.
///
/// A `const fn` (rather than a constant) because `#[serde(default = "...")]`
/// needs a function path.
const fn default_compaction_raw_suffix_messages() -> usize {
    const RAW_SUFFIX_MESSAGES: usize = 2;
    RAW_SUFFIX_MESSAGES
}
20
21/// Context window management policy.
22///
23/// Pure data struct used by [`ContextPlugin`] to configure context
24/// management. Lives in the contract layer so plugin crates can construct
25/// it without depending on `tirea-agentos`.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ContextWindowPolicy {
28    /// Model's total context window size in tokens.
29    pub max_context_tokens: usize,
30    /// Tokens reserved for model output.
31    pub max_output_tokens: usize,
32    /// Minimum number of recent messages to always preserve (never truncated).
33    pub min_recent_messages: usize,
34    /// Whether to enable prompt caching (Anthropic `cache_control: ephemeral`).
35    pub enable_prompt_cache: bool,
36    /// Token count threshold that triggers auto-compaction. `None` disables.
37    /// Used by ContextPlugin to decide when to compact history.
38    #[serde(default, skip_serializing_if = "Option::is_none")]
39    pub autocompact_threshold: Option<usize>,
40    /// Auto-compaction strategy to use when `autocompact_threshold` is reached.
41    #[serde(default)]
42    pub compaction_mode: ContextCompactionMode,
43    /// Number of recent raw messages to preserve in
44    /// [`ContextCompactionMode::KeepRecentRawSuffix`] mode.
45    #[serde(default = "default_compaction_raw_suffix_messages")]
46    pub compaction_raw_suffix_messages: usize,
47}
48
49impl Default for ContextWindowPolicy {
50    fn default() -> Self {
51        Self {
52            max_context_tokens: 200_000,
53            max_output_tokens: 16_384,
54            min_recent_messages: 10,
55            enable_prompt_cache: true,
56            autocompact_threshold: None,
57            compaction_mode: ContextCompactionMode::KeepRecentRawSuffix,
58            compaction_raw_suffix_messages: default_compaction_raw_suffix_messages(),
59        }
60    }
61}
62
63/// Inference-phase extension: system/session context and tool descriptors.
64///
65/// Populated by `AddSystemContext`, `AddSessionContext`, `ExcludeTool`,
66/// `IncludeOnlyTools`, `AddRequestTransform` actions during `BeforeInference`.
67#[derive(Default, Clone)]
68pub struct InferenceContext {
69    /// System context lines appended to the system prompt.
70    pub system_context: Vec<String>,
71    /// Session context messages injected before user messages.
72    pub session_context: Vec<String>,
73    /// Available tool descriptors (can be filtered by actions).
74    pub tools: Vec<ToolDescriptor>,
75    /// Request transforms registered by plugins. Applied in order after
76    /// messages are assembled, before the request is sent to the LLM.
77    pub request_transforms: Vec<Arc<dyn InferenceRequestTransform>>,
78}
79
80impl std::fmt::Debug for InferenceContext {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.debug_struct("InferenceContext")
83            .field("system_context", &self.system_context)
84            .field("session_context", &self.session_context)
85            .field("tools", &self.tools)
86            .field("request_transforms", &self.request_transforms.len())
87            .finish()
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Checks the two compaction settings of `policy` against the expected pair.
    fn assert_compaction(
        policy: &ContextWindowPolicy,
        expected_mode: ContextCompactionMode,
        expected_suffix: usize,
    ) {
        assert_eq!(policy.compaction_mode, expected_mode);
        assert_eq!(policy.compaction_raw_suffix_messages, expected_suffix);
    }

    #[test]
    fn default_policy_uses_suffix_compaction_defaults() {
        assert_compaction(
            &ContextWindowPolicy::default(),
            ContextCompactionMode::KeepRecentRawSuffix,
            2,
        );
    }

    #[test]
    fn policy_deserialization_backfills_new_compaction_fields() {
        // Payload that omits both compaction fields entirely.
        let payload = json!({
            "max_context_tokens": 4096,
            "max_output_tokens": 512,
            "min_recent_messages": 4,
            "enable_prompt_cache": false,
            "autocompact_threshold": 2048
        });

        let policy = serde_json::from_value::<ContextWindowPolicy>(payload).unwrap();
        assert_compaction(&policy, ContextCompactionMode::KeepRecentRawSuffix, 2);
    }

    #[test]
    fn policy_serialization_roundtrip_preserves_frontier_mode() {
        let source = ContextWindowPolicy {
            max_context_tokens: 8192,
            max_output_tokens: 1024,
            min_recent_messages: 6,
            enable_prompt_cache: false,
            autocompact_threshold: Some(4096),
            compaction_mode: ContextCompactionMode::CompactToSafeFrontier,
            compaction_raw_suffix_messages: 5,
        };

        let json_value = serde_json::to_value(&source).unwrap();
        assert_eq!(json_value["compaction_mode"], "compact_to_safe_frontier");
        assert_eq!(json_value["compaction_raw_suffix_messages"], 5);

        let decoded = serde_json::from_value::<ContextWindowPolicy>(json_value).unwrap();
        assert_compaction(&decoded, ContextCompactionMode::CompactToSafeFrontier, 5);
    }
}