tirea_agentos/engine/
token_estimator.rs

1//! Fast heuristic token estimation for context window management.
2//!
3//! Uses character-based heuristics rather than a full tokenizer to avoid
4//! heavy dependencies. Accuracy is ±20% which is sufficient for deciding
5//! when to truncate or compact conversation history.
6
7use crate::contracts::runtime::tool_call::ToolDescriptor;
8use crate::contracts::thread::Message;
9
/// Approximate number of ASCII (and other non-CJK) characters that make up
/// one token for typical LLM tokenizers. Used as a divisor: chars / this.
const CHARS_PER_TOKEN_ASCII: f32 = 4.0;
/// Approximate number of CJK characters that make up one token.
/// Used as a divisor: chars / this.
const CHARS_PER_TOKEN_CJK: f32 = 1.5;
/// Overhead tokens per message (role tag, separators, etc.).
const MESSAGE_OVERHEAD: usize = 4;
/// Overhead tokens per tool call structure (JSON envelope, name, id).
const TOOL_CALL_OVERHEAD: usize = 20;
/// Overhead tokens per tool descriptor (JSON schema envelope).
const TOOL_DESCRIPTOR_OVERHEAD: usize = 20;
20
/// Returns `true` if `c` falls inside one of the CJK-related Unicode ranges
/// listed below, `false` otherwise.
///
/// Besides the Basic Multilingual Plane blocks, this also covers the
/// Supplementary and Tertiary Ideographic Planes (CJK Unified Ideographs
/// Extension B and later) and the Halfwidth and Fullwidth Forms block, so
/// that rare Han characters and fullwidth punctuation such as `，` are not
/// classified as non-CJK.
fn is_cjk(c: char) -> bool {
    matches!(c,
        '\u{4E00}'..='\u{9FFF}'     // CJK Unified Ideographs
        | '\u{3400}'..='\u{4DBF}'   // CJK Extension A
        | '\u{20000}'..='\u{3FFFF}' // Ideographic planes 2-3 (Extension B and later)
        | '\u{F900}'..='\u{FAFF}'   // CJK Compatibility Ideographs
        | '\u{3000}'..='\u{303F}'   // CJK Symbols and Punctuation
        | '\u{3040}'..='\u{309F}'   // Hiragana
        | '\u{30A0}'..='\u{30FF}'   // Katakana
        | '\u{AC00}'..='\u{D7AF}'   // Hangul Syllables
        | '\u{FF00}'..='\u{FFEF}'   // Halfwidth and Fullwidth Forms
    )
}
32
33/// Estimate token count for a text string.
34pub fn estimate_tokens(text: &str) -> usize {
35    if text.is_empty() {
36        return 0;
37    }
38    let mut cjk_chars = 0usize;
39    let mut other_chars = 0usize;
40    for c in text.chars() {
41        if is_cjk(c) {
42            cjk_chars += 1;
43        } else {
44            other_chars += 1;
45        }
46    }
47    let cjk_tokens = (cjk_chars as f32 / CHARS_PER_TOKEN_CJK).ceil() as usize;
48    let ascii_tokens = (other_chars as f32 / CHARS_PER_TOKEN_ASCII).ceil() as usize;
49    (cjk_tokens + ascii_tokens).max(1)
50}
51
52/// Estimate token count for a single message.
53pub fn estimate_message_tokens(msg: &Message) -> usize {
54    let content_tokens = estimate_tokens(&msg.content);
55    let tool_call_tokens: usize = msg
56        .tool_calls
57        .as_ref()
58        .map(|calls| {
59            calls
60                .iter()
61                .map(|c| {
62                    estimate_tokens(&c.name)
63                        + estimate_tokens(&c.arguments.to_string())
64                        + TOOL_CALL_OVERHEAD
65                })
66                .sum()
67        })
68        .unwrap_or(0);
69    content_tokens + tool_call_tokens + MESSAGE_OVERHEAD
70}
71
72/// Estimate total token count for a slice of messages.
73pub fn estimate_messages_tokens(messages: &[Message]) -> usize {
74    messages.iter().map(estimate_message_tokens).sum()
75}
76
77/// Estimate total token count for tool descriptors.
78pub fn estimate_tool_tokens(tools: &[ToolDescriptor]) -> usize {
79    tools
80        .iter()
81        .map(|t| {
82            estimate_tokens(&t.name)
83                + estimate_tokens(&t.description)
84                + estimate_tokens(&t.parameters.to_string())
85                + TOOL_DESCRIPTOR_OVERHEAD
86        })
87        .sum()
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The empty string costs nothing.
    #[test]
    fn estimate_tokens_empty() {
        let tokens = estimate_tokens("");
        assert_eq!(tokens, 0);
    }

    /// Plain ASCII: "Hello world" is 11 chars, expected around 3 tokens.
    #[test]
    fn estimate_tokens_ascii() {
        let tokens = estimate_tokens("Hello world");
        assert!(matches!(tokens, 2..=5), "got {tokens}");
    }

    /// Pure CJK: "你好世界" is 4 CJK chars, expected around 3 tokens.
    #[test]
    fn estimate_tokens_cjk() {
        let tokens = estimate_tokens("你好世界");
        assert!(matches!(tokens, 2..=5), "got {tokens}");
    }

    /// ASCII and CJK interleaved in one string.
    #[test]
    fn estimate_tokens_mixed() {
        let tokens = estimate_tokens("Hello 你好 world 世界");
        assert!(matches!(tokens, 4..=10), "got {tokens}");
    }

    /// A small multi-line code snippet.
    #[test]
    fn estimate_tokens_code_block() {
        let tokens = estimate_tokens("fn main() {\n    let x = compute(42);\n    return x;\n}");
        assert!(matches!(tokens, 8..=20), "got {tokens}");
    }

    /// A content-only message must include the per-message overhead.
    #[test]
    fn estimate_message_tokens_simple() {
        let tokens = estimate_message_tokens(&Message::user("What is 2+2?"));
        assert!(tokens >= 5, "got {tokens}");
    }

    /// Content + tool call name + args + overheads.
    #[test]
    fn estimate_message_tokens_with_tool_calls() {
        use crate::contracts::thread::ToolCall;

        let call = ToolCall::new("call_1", "calculator", json!({"expr": "2+2"}));
        let msg = Message::assistant_with_tool_calls("I'll calculate that.", vec![call]);

        let tokens = estimate_message_tokens(&msg);
        assert!(tokens >= 15, "got {tokens}");
    }

    /// A single descriptor must cost at least the descriptor overhead.
    #[test]
    fn estimate_tool_tokens_basic() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "expression": { "type": "string" }
            },
            "required": ["expression"]
        });
        let descriptor = ToolDescriptor::new("calc", "Calculator", "Evaluate math expressions")
            .with_parameters(parameters);
        let tools = vec![descriptor];

        let tokens = estimate_tool_tokens(&tools);
        assert!(tokens >= 20, "got {tokens}");
    }

    /// The slice-level estimate equals the sum of the per-message estimates.
    #[test]
    fn estimate_messages_tokens_multiple() {
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ];

        let expected = messages
            .iter()
            .fold(0usize, |acc, msg| acc + estimate_message_tokens(msg));
        let total = estimate_messages_tokens(&messages);

        assert_eq!(total, expected);
    }
}