1use crate::{
2 LoadSkillResourceTool, SkillActivateTool, SkillDiscoveryPlugin, SkillRegistry, SkillScriptTool,
3};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tirea_contract::runtime::tool_call::Tool;
7
8#[derive(Debug, thiserror::Error)]
10pub enum SkillSubsystemError {
11 #[error("tool id already registered: {0}")]
12 ToolIdConflict(String),
13}
14
15#[derive(Clone)]
48pub struct SkillSubsystem {
49 registry: Arc<dyn SkillRegistry>,
50}
51
52impl std::fmt::Debug for SkillSubsystem {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("SkillSubsystem").finish_non_exhaustive()
55 }
56}
57
58impl SkillSubsystem {
59 pub fn new(registry: Arc<dyn SkillRegistry>) -> Self {
60 Self { registry }
61 }
62
63 pub fn registry(&self) -> &Arc<dyn SkillRegistry> {
64 &self.registry
65 }
66
67 pub fn discovery_plugin(&self) -> SkillDiscoveryPlugin {
69 SkillDiscoveryPlugin::new(self.registry.clone())
70 }
71
72 pub fn tools(&self) -> HashMap<String, Arc<dyn Tool>> {
79 let mut out: HashMap<String, Arc<dyn Tool>> = HashMap::new();
80 let _ = self.extend_tools(&mut out);
82 out
83 }
84
85 pub fn extend_tools(
89 &self,
90 tools: &mut HashMap<String, Arc<dyn Tool>>,
91 ) -> Result<(), SkillSubsystemError> {
92 let registry = self.registry.clone();
93 let tool_defs: Vec<Arc<dyn Tool>> = vec![
94 Arc::new(SkillActivateTool::new(registry.clone())),
95 Arc::new(LoadSkillResourceTool::new(registry.clone())),
96 Arc::new(SkillScriptTool::new(registry)),
97 ];
98
99 for t in tool_defs {
100 let id = t.descriptor().id.clone();
101 if tools.contains_key(&id) {
102 return Err(SkillSubsystemError::ToolIdConflict(id));
103 }
104 tools.insert(id, t);
105 }
106
107 Ok(())
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FsSkill, InMemorySkillRegistry, Skill, SKILL_ACTIVATE_TOOL_ID, SKILL_LOAD_RESOURCE_TOOL_ID,
        SKILL_SCRIPT_TOOL_ID,
    };
    use async_trait::async_trait;
    use serde_json::json;
    use serde_json::Value;
    use std::fs;
    use tempfile::TempDir;
    use tirea_contract::runtime::behavior::{AgentBehavior, ReadOnlyContext};
    use tirea_contract::runtime::phase::{AfterToolExecuteAction, BeforeInferenceAction, Phase};
    use tirea_contract::runtime::state::{reduce_state_actions, ScopeContext};
    use tirea_contract::runtime::tool_call::{
        ToolCallContext, ToolDescriptor, ToolError, ToolResult,
    };
    use tirea_contract::testing::TestFixture;
    use tirea_contract::thread::Thread;
    use tirea_contract::thread::{Message, ToolCall};
    use tirea_contract::RunPolicy;
    use tirea_state::TrackedPatch;

    /// Wraps `skills` in an in-memory registry behind the `SkillRegistry` trait object.
    fn make_registry(skills: Vec<Arc<dyn Skill>>) -> Arc<dyn SkillRegistry> {
        Arc::new(InMemorySkillRegistry::from_skills(skills))
    }

    /// Outcome of running a single tool call outside the agent loop.
    struct LocalToolExecution {
        result: ToolResult,
        /// Merged state patch emitted by the tool; `None` when it changed nothing.
        patch: Option<TrackedPatch>,
    }

    /// Runs `call` against `tool` on a fixture seeded with `state`, then reduces the
    /// tool's state actions into at most one merged patch.
    ///
    /// A missing tool or an execution error is reported as an error `ToolResult` with
    /// no patch, so callers can assert on the result uniformly.
    async fn execute_single_tool(
        tool: Option<&dyn Tool>,
        call: &ToolCall,
        state: &Value,
    ) -> LocalToolExecution {
        let Some(tool) = tool else {
            return LocalToolExecution {
                result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
                patch: None,
            };
        };

        let fix = TestFixture::new_with_state(state.clone());
        let tool_ctx = fix.ctx_with(&call.id, format!("tool:{}", call.name));
        let effect = match tool.execute_effect(call.arguments.clone(), &tool_ctx).await {
            Ok(e) => e,
            Err(e) => {
                return LocalToolExecution {
                    result: ToolResult::error(&call.name, e.to_string()),
                    patch: None,
                };
            }
        };
        let (result, actions) = effect.into_parts();
        // Only state actions contribute to the patch; other after-tool actions are ignored.
        let state_actions: Vec<_> = actions
            .into_iter()
            .filter_map(|a| match a {
                AfterToolExecuteAction::State(sa) => Some(sa),
                _ => None,
            })
            .collect();
        let scope_ctx = ScopeContext::run();
        let patches = reduce_state_actions(
            state_actions,
            state,
            &format!("tool:{}", call.name),
            &scope_ctx,
        )
        .unwrap();
        // Fold all patches into the first one and drop the result if it is empty.
        let patch = patches
            .into_iter()
            .reduce(|mut acc, p| {
                acc.patch.merge(p.into_patch());
                acc
            })
            .filter(|p| !p.patch().is_empty());
        LocalToolExecution { result, patch }
    }

    /// Stand-in tool that deliberately reuses `SKILL_ACTIVATE_TOOL_ID` to provoke a conflict.
    #[derive(Debug)]
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(SKILL_ACTIVATE_TOOL_ID, "x", "x").with_parameters(json!({}))
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(SKILL_ACTIVATE_TOOL_ID, json!({})))
        }
    }

    /// Builds a subsystem over a temp directory holding a single skill `s1`.
    ///
    /// The `TempDir` is returned so the caller keeps the files alive for the test.
    fn make_subsystem() -> (TempDir, SkillSubsystem) {
        let td = TempDir::new().unwrap();
        let root = td.path().join("skills");
        fs::create_dir_all(root.join("s1")).unwrap();
        fs::write(
            root.join("s1").join("SKILL.md"),
            "---\nname: s1\ndescription: ok\n---\nBody\n",
        )
        .unwrap();

        let result = FsSkill::discover(root).unwrap();
        let sys = SkillSubsystem::new(make_registry(FsSkill::into_arc_skills(result.skills)));
        (td, sys)
    }

    #[test]
    fn subsystem_extend_tools_detects_conflict() {
        let (_td, sys) = make_subsystem();
        let mut tools = HashMap::<String, Arc<dyn Tool>>::new();
        tools.insert(SKILL_ACTIVATE_TOOL_ID.to_string(), Arc::new(DummyTool));
        let err = sys.extend_tools(&mut tools).unwrap_err();
        assert!(err.to_string().contains("tool id already registered"));
        // Pin the variant and the offending id, not just the message text.
        assert!(matches!(
            &err,
            SkillSubsystemError::ToolIdConflict(id) if id == SKILL_ACTIVATE_TOOL_ID
        ));
    }

    #[test]
    fn subsystem_tools_returns_expected_ids() {
        let (_td, sys) = make_subsystem();
        let tools = sys.tools();
        assert!(tools.contains_key(SKILL_ACTIVATE_TOOL_ID));
        assert!(tools.contains_key(SKILL_LOAD_RESOURCE_TOOL_ID));
        assert!(tools.contains_key(SKILL_SCRIPT_TOOL_ID));
        assert_eq!(tools.len(), 3);
    }

    #[test]
    fn subsystem_extend_tools_inserts_tools_into_existing_map() {
        let (_td, sys) = make_subsystem();
        let mut tools = HashMap::<String, Arc<dyn Tool>>::new();
        tools.insert("other".to_string(), Arc::new(DummyOtherTool));
        sys.extend_tools(&mut tools).unwrap();
        assert!(tools.contains_key("other"));
        assert!(tools.contains_key(SKILL_ACTIVATE_TOOL_ID));
        assert!(tools.contains_key(SKILL_LOAD_RESOURCE_TOOL_ID));
        assert!(tools.contains_key(SKILL_SCRIPT_TOOL_ID));
        assert_eq!(tools.len(), 4);
    }

    /// Unrelated tool with id `"other"`, used to check that existing entries survive.
    #[derive(Debug)]
    struct DummyOtherTool;

    #[async_trait]
    impl Tool for DummyOtherTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("other", "x", "x").with_parameters(json!({}))
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("other", json!({})))
        }
    }

    #[tokio::test]
    async fn subsystem_plugin_injects_catalog_and_activated_skill() {
        let td = TempDir::new().unwrap();
        let root = td.path().join("skills");
        fs::create_dir_all(root.join("docx").join("references")).unwrap();
        fs::write(
            root.join("docx").join("references").join("DOCX-JS.md"),
            "Use docx-js for new documents.",
        )
        .unwrap();

        // `fs::write` closes the file before discovery reads it back.
        fs::write(
            root.join("docx").join("SKILL.md"),
            "---\nname: docx\ndescription: DOCX guidance\n---\nUse docx-js for new documents.\n\n",
        )
        .unwrap();

        let result = FsSkill::discover(root).unwrap();
        let sys = SkillSubsystem::new(make_registry(FsSkill::into_arc_skills(result.skills)));
        let tools = sys.tools();

        // Step 1: activate the skill and fold its state patch into the thread.
        let thread = Thread::with_initial_state("s", json!({})).with_message(Message::user("hi"));
        let state = thread.rebuild_state().unwrap();
        let call = ToolCall::new("call_1", SKILL_ACTIVATE_TOOL_ID, json!({"skill": "docx"}));
        let activate_tool = tools.get(SKILL_ACTIVATE_TOOL_ID).unwrap().as_ref();
        let exec = execute_single_tool(Some(activate_tool), &call, &state).await;
        assert!(exec.result.is_success());
        let thread = thread.with_patch(exec.patch.unwrap());

        // Step 2: load a resource of the activated skill; it may or may not patch state.
        let state = thread.rebuild_state().unwrap();
        let call = ToolCall::new(
            "call_2",
            SKILL_LOAD_RESOURCE_TOOL_ID,
            json!({"skill": "docx", "path": "references/DOCX-JS.md"}),
        );
        let load_resource_tool = tools.get(SKILL_LOAD_RESOURCE_TOOL_ID).unwrap().as_ref();
        let exec = execute_single_tool(Some(load_resource_tool), &call, &state).await;
        assert!(exec.result.is_success());
        let thread = if let Some(patch) = exec.patch {
            thread.with_patch(patch)
        } else {
            thread
        };

        // Step 3: run the discovery plugin before inference and count system context injections.
        let plugin: Arc<dyn AgentBehavior> = Arc::new(sys.discovery_plugin());
        let state = thread.rebuild_state().unwrap();
        let fix = TestFixture::new_with_state(state);
        let run_policy = RunPolicy::default();
        let fixture_ctx = fix.ctx();
        let ctx = ReadOnlyContext::new(
            Phase::BeforeInference,
            &thread.id,
            &thread.messages,
            &run_policy,
            fixture_ctx.doc(),
        );
        let actions = plugin.before_inference(&ctx).await;

        let system_context_count = actions
            .as_slice()
            .iter()
            .filter(|a| matches!(a, BeforeInferenceAction::AddSystemContext(_)))
            .count();
        assert_eq!(system_context_count, 1);
    }
}