tirea_extension_skills/
subsystem.rs

1use crate::{
2    LoadSkillResourceTool, SkillActivateTool, SkillDiscoveryPlugin, SkillRegistry, SkillScriptTool,
3};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tirea_contract::runtime::tool_call::Tool;
7
/// Errors returned when wiring the skills subsystem into an agent.
#[derive(Debug, thiserror::Error)]
pub enum SkillSubsystemError {
    /// A skills tool could not be registered because its descriptor id is already a key in
    /// the target tool map. Carries the conflicting id.
    #[error("tool id already registered: {0}")]
    ToolIdConflict(String),
}
14
/// High-level facade for wiring skills into an agent.
///
/// Callers should prefer this over manually instantiating the tools/plugins so:
/// - tool ids stay consistent
/// - the discovery plugin and every skills tool read from the same registry
///
/// Cloning is cheap: it only clones the `Arc` handle to the registry.
///
/// # Example
///
/// ```ignore
/// use tirea::extensions::skills::{
///     FsSkill, InMemorySkillRegistry, SkillSubsystem,
/// };
/// use std::sync::Arc;
///
/// // 1) Discover skills and build a registry.
/// let result = FsSkill::discover("skills").unwrap();
/// let registry = Arc::new(
///     InMemorySkillRegistry::from_skills(FsSkill::into_arc_skills(result.skills)),
/// );
///
/// // 2) Wire into subsystem.
/// let skills = SkillSubsystem::new(registry);
///
/// // 3) Register tools (skill activation + reference/script utilities).
/// //    `extend_tools` fails if any skills tool id is already in the map.
/// let mut tools = std::collections::HashMap::new();
/// skills.extend_tools(&mut tools).unwrap();
///
/// // 4) Register the discovery plugin: injects skills catalog before inference.
/// let config = BaseAgent::new("gpt-4o-mini").with_plugin(Arc::new(skills.discovery_plugin()));
/// # let _ = config;
/// # let _ = tools;
/// ```
#[derive(Clone)]
pub struct SkillSubsystem {
    // Registry handed (as a cloned `Arc`) to the discovery plugin and to every skills tool.
    registry: Arc<dyn SkillRegistry>,
}
51
52impl std::fmt::Debug for SkillSubsystem {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("SkillSubsystem").finish_non_exhaustive()
55    }
56}
57
58impl SkillSubsystem {
59    pub fn new(registry: Arc<dyn SkillRegistry>) -> Self {
60        Self { registry }
61    }
62
63    pub fn registry(&self) -> &Arc<dyn SkillRegistry> {
64        &self.registry
65    }
66
67    /// Build the discovery plugin (injects skills catalog before inference).
68    pub fn discovery_plugin(&self) -> SkillDiscoveryPlugin {
69        SkillDiscoveryPlugin::new(self.registry.clone())
70    }
71
72    /// Construct the skills tools map.
73    ///
74    /// Tool ids:
75    /// - `SKILL_ACTIVATE_TOOL_ID`
76    /// - `SKILL_LOAD_RESOURCE_TOOL_ID`
77    /// - `SKILL_SCRIPT_TOOL_ID`
78    pub fn tools(&self) -> HashMap<String, Arc<dyn Tool>> {
79        let mut out: HashMap<String, Arc<dyn Tool>> = HashMap::new();
80        // These inserts cannot conflict inside an empty map.
81        let _ = self.extend_tools(&mut out);
82        out
83    }
84
85    /// Add skills tools to an existing tool map.
86    ///
87    /// Returns an error if any tool id is already present.
88    pub fn extend_tools(
89        &self,
90        tools: &mut HashMap<String, Arc<dyn Tool>>,
91    ) -> Result<(), SkillSubsystemError> {
92        let registry = self.registry.clone();
93        let tool_defs: Vec<Arc<dyn Tool>> = vec![
94            Arc::new(SkillActivateTool::new(registry.clone())),
95            Arc::new(LoadSkillResourceTool::new(registry.clone())),
96            Arc::new(SkillScriptTool::new(registry)),
97        ];
98
99        for t in tool_defs {
100            let id = t.descriptor().id.clone();
101            if tools.contains_key(&id) {
102                return Err(SkillSubsystemError::ToolIdConflict(id));
103            }
104            tools.insert(id, t);
105        }
106
107        Ok(())
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FsSkill, InMemorySkillRegistry, Skill, SKILL_ACTIVATE_TOOL_ID, SKILL_LOAD_RESOURCE_TOOL_ID,
        SKILL_SCRIPT_TOOL_ID,
    };
    use async_trait::async_trait;
    use serde_json::json;
    use serde_json::Value;
    use std::fs;
    use tempfile::TempDir;
    use tirea_contract::runtime::behavior::{AgentBehavior, ReadOnlyContext};
    use tirea_contract::runtime::phase::{AfterToolExecuteAction, BeforeInferenceAction, Phase};
    use tirea_contract::runtime::state::{reduce_state_actions, ScopeContext};
    use tirea_contract::runtime::tool_call::{
        ToolCallContext, ToolDescriptor, ToolError, ToolResult,
    };
    use tirea_contract::testing::TestFixture;
    use tirea_contract::thread::{Message, Thread, ToolCall};
    use tirea_contract::RunPolicy;
    use tirea_state::TrackedPatch;

    fn make_registry(skills: Vec<Arc<dyn Skill>>) -> Arc<dyn SkillRegistry> {
        Arc::new(InMemorySkillRegistry::from_skills(skills))
    }

    /// Outcome of running a single tool call outside the agent loop.
    struct LocalToolExecution {
        result: ToolResult,
        // Merged patch from the tool's state actions; `None` when the tool was missing,
        // failed, or produced an empty patch.
        patch: Option<TrackedPatch>,
    }

    /// Execute `call` against `tool` with `state` as the initial document, reducing any
    /// state actions the tool emits into one merged patch.
    async fn execute_single_tool(
        tool: Option<&dyn Tool>,
        call: &ToolCall,
        state: &Value,
    ) -> LocalToolExecution {
        let Some(tool) = tool else {
            return LocalToolExecution {
                result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
                patch: None,
            };
        };

        let fix = TestFixture::new_with_state(state.clone());
        let tool_ctx = fix.ctx_with(&call.id, format!("tool:{}", call.name));
        let effect = match tool.execute_effect(call.arguments.clone(), &tool_ctx).await {
            Ok(e) => e,
            Err(e) => {
                return LocalToolExecution {
                    result: ToolResult::error(&call.name, e.to_string()),
                    patch: None,
                };
            }
        };
        let (result, actions) = effect.into_parts();
        // Only state actions contribute to the patch; other after-execute actions are ignored.
        let state_actions: Vec<_> = actions
            .into_iter()
            .filter_map(|a| match a {
                AfterToolExecuteAction::State(sa) => Some(sa),
                _ => None,
            })
            .collect();
        let scope_ctx = ScopeContext::run();
        let patches = reduce_state_actions(
            state_actions,
            state,
            &format!("tool:{}", call.name),
            &scope_ctx,
        )
        .unwrap();
        let patch = patches
            .into_iter()
            .reduce(|mut acc, p| {
                acc.patch.merge(p.into_patch());
                acc
            })
            .filter(|p| !p.patch().is_empty());
        LocalToolExecution { result, patch }
    }

    /// Tool whose id deliberately collides with the built-in skill activation tool.
    #[derive(Debug)]
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(SKILL_ACTIVATE_TOOL_ID, "x", "x").with_parameters(json!({}))
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(SKILL_ACTIVATE_TOOL_ID, json!({})))
        }
    }

    /// Tool with an id that does not collide with any skills tool.
    #[derive(Debug)]
    struct DummyOtherTool;

    #[async_trait]
    impl Tool for DummyOtherTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("other", "x", "x").with_parameters(json!({}))
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("other", json!({})))
        }
    }

    /// Build a subsystem over a temp dir holding a single skill `s1`.
    /// The `TempDir` is returned so it outlives the subsystem.
    fn make_subsystem() -> (TempDir, SkillSubsystem) {
        let td = TempDir::new().unwrap();
        let root = td.path().join("skills");
        fs::create_dir_all(root.join("s1")).unwrap();
        fs::write(
            root.join("s1").join("SKILL.md"),
            "---\nname: s1\ndescription: ok\n---\nBody\n",
        )
        .unwrap();

        let result = FsSkill::discover(root).unwrap();
        let sys = SkillSubsystem::new(make_registry(FsSkill::into_arc_skills(result.skills)));
        (td, sys)
    }

    #[test]
    fn subsystem_extend_tools_detects_conflict() {
        let (_td, sys) = make_subsystem();
        let mut tools = HashMap::<String, Arc<dyn Tool>>::new();
        tools.insert(SKILL_ACTIVATE_TOOL_ID.to_string(), Arc::new(DummyTool));
        let err = sys.extend_tools(&mut tools).unwrap_err();
        assert!(matches!(
            &err,
            SkillSubsystemError::ToolIdConflict(id) if id == SKILL_ACTIVATE_TOOL_ID
        ));
        assert!(err.to_string().contains("tool id already registered"));
        // Only the pre-existing entry remains; no skills tool was added alongside the error.
        assert_eq!(tools.len(), 1);
    }

    #[test]
    fn subsystem_tools_returns_expected_ids() {
        let (_td, sys) = make_subsystem();
        let tools = sys.tools();
        assert!(tools.contains_key(SKILL_ACTIVATE_TOOL_ID));
        assert!(tools.contains_key(SKILL_LOAD_RESOURCE_TOOL_ID));
        assert!(tools.contains_key(SKILL_SCRIPT_TOOL_ID));
        assert_eq!(tools.len(), 3);
    }

    #[test]
    fn subsystem_tools_are_keyed_by_descriptor_id() {
        let (_td, sys) = make_subsystem();
        for (id, tool) in sys.tools() {
            assert_eq!(id, tool.descriptor().id);
        }
    }

    #[test]
    fn subsystem_extend_tools_inserts_tools_into_existing_map() {
        let (_td, sys) = make_subsystem();
        let mut tools = HashMap::<String, Arc<dyn Tool>>::new();
        tools.insert("other".to_string(), Arc::new(DummyOtherTool));
        sys.extend_tools(&mut tools).unwrap();
        assert!(tools.contains_key("other"));
        assert!(tools.contains_key(SKILL_ACTIVATE_TOOL_ID));
        assert!(tools.contains_key(SKILL_LOAD_RESOURCE_TOOL_ID));
        assert!(tools.contains_key(SKILL_SCRIPT_TOOL_ID));
        assert_eq!(tools.len(), 4);
    }

    #[tokio::test]
    async fn subsystem_plugin_injects_catalog_and_activated_skill() {
        let td = TempDir::new().unwrap();
        let root = td.path().join("skills");
        fs::create_dir_all(root.join("docx").join("references")).unwrap();
        fs::write(
            root.join("docx").join("references").join("DOCX-JS.md"),
            "Use docx-js for new documents.",
        )
        .unwrap();
        fs::write(
            root.join("docx").join("SKILL.md"),
            "---\nname: docx\ndescription: DOCX guidance\n---\nUse docx-js for new documents.\n\n",
        )
        .unwrap();

        let result = FsSkill::discover(root).unwrap();
        let sys = SkillSubsystem::new(make_registry(FsSkill::into_arc_skills(result.skills)));
        let tools = sys.tools();

        // Activate the skill via the registered activation tool.
        let thread = Thread::with_initial_state("s", json!({})).with_message(Message::user("hi"));
        let state = thread.rebuild_state().unwrap();
        let call = ToolCall::new("call_1", SKILL_ACTIVATE_TOOL_ID, json!({"skill": "docx"}));
        let activate_tool = tools.get(SKILL_ACTIVATE_TOOL_ID).unwrap().as_ref();
        let exec = execute_single_tool(Some(activate_tool), &call, &state).await;
        assert!(exec.result.is_success());
        let thread = thread.with_patch(exec.patch.unwrap());

        // Load a reference file belonging to the activated skill.
        let state = thread.rebuild_state().unwrap();
        let call = ToolCall::new(
            "call_2",
            SKILL_LOAD_RESOURCE_TOOL_ID,
            json!({"skill": "docx", "path": "references/DOCX-JS.md"}),
        );
        let load_resource_tool = tools.get(SKILL_LOAD_RESOURCE_TOOL_ID).unwrap().as_ref();
        let exec = execute_single_tool(Some(load_resource_tool), &call, &state).await;
        assert!(exec.result.is_success());
        let thread = match exec.patch {
            Some(patch) => thread.with_patch(patch),
            None => thread,
        };

        // Run the discovery plugin and verify discovery catalog is injected.
        let plugin: Arc<dyn AgentBehavior> = Arc::new(sys.discovery_plugin());
        let state = thread.rebuild_state().unwrap();
        let fix = TestFixture::new_with_state(state);
        let run_policy = RunPolicy::default();
        let fixture_ctx = fix.ctx();
        let ctx = ReadOnlyContext::new(
            Phase::BeforeInference,
            &thread.id,
            &thread.messages,
            &run_policy,
            fixture_ctx.doc(),
        );
        let actions = plugin.before_inference(&ctx).await;

        let system_context_count = actions
            .as_slice()
            .iter()
            .filter(|a| matches!(a, BeforeInferenceAction::AddSystemContext(_)))
            .count();
        assert_eq!(system_context_count, 1);
    }
}