tirea_agentos/engine/
token_estimator.rs1use crate::contracts::runtime::tool_call::ToolDescriptor;
8use crate::contracts::thread::Message;
9
/// Characters-per-token ratio applied to every non-CJK character in
/// `estimate_tokens`.
const CHARS_PER_TOKEN_ASCII: f32 = 4.0;
/// Characters-per-token ratio applied to characters that `is_cjk` accepts.
const CHARS_PER_TOKEN_CJK: f32 = 1.5;
/// Fixed token count added once per message by `estimate_message_tokens`.
const MESSAGE_OVERHEAD: usize = 4;
/// Fixed token count added for each tool call attached to a message.
const TOOL_CALL_OVERHEAD: usize = 20;
/// Fixed token count added for each tool descriptor by `estimate_tool_tokens`.
const TOOL_DESCRIPTOR_OVERHEAD: usize = 20;
20
/// Returns `true` when `c` falls inside one of the CJK-related Unicode
/// ranges that the estimator counts at the denser CJK ratio.
fn is_cjk(c: char) -> bool {
    // Inclusive code-point ranges, kept one entry per Unicode block.
    const CJK_RANGES: [(u32, u32); 7] = [
        (0x4E00, 0x9FFF), // CJK Unified Ideographs
        (0x3400, 0x4DBF), // CJK Unified Ideographs Extension A
        (0xF900, 0xFAFF), // CJK Compatibility Ideographs
        (0x3000, 0x303F), // CJK Symbols and Punctuation
        (0x3040, 0x309F), // Hiragana
        (0x30A0, 0x30FF), // Katakana
        (0xAC00, 0xD7AF), // Hangul Syllables
    ];

    let code_point = u32::from(c);
    CJK_RANGES
        .iter()
        .any(|&(first, last)| (first..=last).contains(&code_point))
}
32
33pub fn estimate_tokens(text: &str) -> usize {
35 if text.is_empty() {
36 return 0;
37 }
38 let mut cjk_chars = 0usize;
39 let mut other_chars = 0usize;
40 for c in text.chars() {
41 if is_cjk(c) {
42 cjk_chars += 1;
43 } else {
44 other_chars += 1;
45 }
46 }
47 let cjk_tokens = (cjk_chars as f32 / CHARS_PER_TOKEN_CJK).ceil() as usize;
48 let ascii_tokens = (other_chars as f32 / CHARS_PER_TOKEN_ASCII).ceil() as usize;
49 (cjk_tokens + ascii_tokens).max(1)
50}
51
52pub fn estimate_message_tokens(msg: &Message) -> usize {
54 let content_tokens = estimate_tokens(&msg.content);
55 let tool_call_tokens: usize = msg
56 .tool_calls
57 .as_ref()
58 .map(|calls| {
59 calls
60 .iter()
61 .map(|c| {
62 estimate_tokens(&c.name)
63 + estimate_tokens(&c.arguments.to_string())
64 + TOOL_CALL_OVERHEAD
65 })
66 .sum()
67 })
68 .unwrap_or(0);
69 content_tokens + tool_call_tokens + MESSAGE_OVERHEAD
70}
71
72pub fn estimate_messages_tokens(messages: &[Message]) -> usize {
74 messages.iter().map(estimate_message_tokens).sum()
75}
76
77pub fn estimate_tool_tokens(tools: &[ToolDescriptor]) -> usize {
79 tools
80 .iter()
81 .map(|t| {
82 estimate_tokens(&t.name)
83 + estimate_tokens(&t.description)
84 + estimate_tokens(&t.parameters.to_string())
85 + TOOL_DESCRIPTOR_OVERHEAD
86 })
87 .sum()
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An empty string is the only input that estimates to zero tokens.
    #[test]
    fn estimate_tokens_empty() {
        let tokens = estimate_tokens("");
        assert_eq!(tokens, 0);
    }

    /// Plain ASCII text lands in a small, plausible range.
    #[test]
    fn estimate_tokens_ascii() {
        let plausible = 2..=5;
        let tokens = estimate_tokens("Hello world");
        assert!(plausible.contains(&tokens), "got {tokens}");
    }

    /// Pure CJK text is counted at the denser CJK ratio.
    #[test]
    fn estimate_tokens_cjk() {
        let plausible = 2..=5;
        let tokens = estimate_tokens("你好世界");
        assert!(plausible.contains(&tokens), "got {tokens}");
    }

    /// Mixed text combines the CJK and non-CJK estimates.
    #[test]
    fn estimate_tokens_mixed() {
        let plausible = 4..=10;
        let tokens = estimate_tokens("Hello 你好 world 世界");
        assert!(plausible.contains(&tokens), "got {tokens}");
    }

    /// A short code snippet, including newlines and punctuation.
    #[test]
    fn estimate_tokens_code_block() {
        let plausible = 8..=20;
        let tokens = estimate_tokens("fn main() {\n let x = compute(42);\n return x;\n}");
        assert!(plausible.contains(&tokens), "got {tokens}");
    }

    /// A plain user message costs its content plus the per-message overhead.
    #[test]
    fn estimate_message_tokens_simple() {
        let tokens = estimate_message_tokens(&Message::user("What is 2+2?"));
        assert!(tokens >= 5, "got {tokens}");
    }

    /// Tool calls attached to a message add to its estimate.
    #[test]
    fn estimate_message_tokens_with_tool_calls() {
        use crate::contracts::thread::ToolCall;

        let call = ToolCall::new("call_1", "calculator", json!({"expr": "2+2"}));
        let message = Message::assistant_with_tool_calls("I'll calculate that.", vec![call]);

        let tokens = estimate_message_tokens(&message);
        assert!(tokens >= 15, "got {tokens}");
    }

    /// A single descriptor costs at least the per-descriptor overhead.
    #[test]
    fn estimate_tool_tokens_basic() {
        let schema = json!({
            "type": "object",
            "properties": {
                "expression": { "type": "string" }
            },
            "required": ["expression"]
        });
        let descriptor = ToolDescriptor::new("calc", "Calculator", "Evaluate math expressions")
            .with_parameters(schema);

        let tokens = estimate_tool_tokens(&[descriptor]);
        assert!(tokens >= 20, "got {tokens}");
    }

    /// The batch estimate equals the sum of the per-message estimates.
    #[test]
    fn estimate_messages_tokens_multiple() {
        let messages = [
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ];

        let batch_total = estimate_messages_tokens(&messages);

        let mut expected = 0usize;
        for message in &messages {
            expected += estimate_message_tokens(message);
        }
        assert_eq!(batch_total, expected);
    }
}