// tirea_agentos/extensions/mcp.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::composition::ToolRegistry;
5use crate::contracts::runtime::tool_call::Tool;
6
7pub use tirea_extension_mcp::*;
8
9impl ToolRegistry for tirea_extension_mcp::McpToolRegistry {
10    fn len(&self) -> usize {
11        tirea_extension_mcp::McpToolRegistry::len(self)
12    }
13
14    fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
15        tirea_extension_mcp::McpToolRegistry::get(self, id)
16    }
17
18    fn ids(&self) -> Vec<String> {
19        tirea_extension_mcp::McpToolRegistry::ids(self)
20    }
21
22    fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
23        tirea_extension_mcp::McpToolRegistry::snapshot(self)
24    }
25}
26
#[cfg(test)]
mod tests {
    //! Checks that an MCP registry handed to `AgentOs` as a `dyn ToolRegistry`
    //! reflects a changed server tool list once the manager is refreshed.

    use super::*;
    use async_trait::async_trait;
    use mcp::transport::{McpServerConnectionConfig, McpTransportError, TransportTypeId};
    use mcp::McpToolDefinition;
    use serde_json::Value;
    use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
    use tokio::sync::mpsc;

    use crate::composition::{AgentDefinition, AgentDefinitionSpec};
    use crate::runtime::AgentOs;

    /// Fake transport whose advertised tool list can be swapped mid-test,
    /// standing in for an MCP server that changes its tools between listings.
    #[derive(Debug, Clone)]
    struct MutableTransport {
        tools: Arc<Mutex<Vec<McpToolDefinition>>>,
    }

    impl MutableTransport {
        fn new(initial: Vec<McpToolDefinition>) -> Self {
            let tools = Arc::new(Mutex::new(initial));
            Self { tools }
        }

        /// Locks the tool list, recovering the data even if a previous holder
        /// panicked; a poisoned lock is deliberately not fatal in this fixture.
        fn lock_tools(&self) -> MutexGuard<'_, Vec<McpToolDefinition>> {
            self.tools.lock().unwrap_or_else(PoisonError::into_inner)
        }

        /// Overwrites the advertised tools; later `list_tools` calls return `next`.
        fn replace(&self, next: Vec<McpToolDefinition>) {
            *self.lock_tools() = next;
        }
    }

    #[async_trait]
    impl McpToolTransport for MutableTransport {
        /// Returns a copy of whatever tool list is currently stored.
        async fn list_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
            let snapshot = self.lock_tools().clone();
            Ok(snapshot)
        }

        /// Ignores its inputs and always reports success with a single `"ok"`
        /// text block.
        async fn call_tool(
            &self,
            _name: &str,
            _args: Value,
            _progress_tx: Option<mpsc::UnboundedSender<McpProgressUpdate>>,
        ) -> Result<mcp::CallToolResult, McpTransportError> {
            let result = mcp::CallToolResult {
                content: vec![mcp::ToolContent::text("ok")],
                structured_content: None,
                is_error: None,
            };
            Ok(result)
        }

        fn transport_type(&self) -> TransportTypeId {
            TransportTypeId::Stdio
        }
    }

    /// Builds a stdio server config named `name`. The command is never
    /// launched, because the fake transport replaces the real one.
    fn stdio_cfg(name: &str) -> McpServerConnectionConfig {
        McpServerConnectionConfig::stdio(name, "unused", Vec::new())
    }

    #[tokio::test]
    async fn mcp_registry_implements_dynamic_tool_registry() {
        // Tool ids as they appear on the resolved agent for server "mcp_s1".
        const ECHO: &str = "mcp__mcp_s1__echo";
        const SUM: &str = "mcp__mcp_s1__sum";

        let source = Arc::new(MutableTransport::new(vec![McpToolDefinition::new("echo")]));
        let shared = source.clone() as Arc<dyn McpToolTransport>;
        let manager = McpToolRegistryManager::from_transports([(stdio_cfg("mcp_s1"), shared)])
            .await
            .expect("build manager");
        let registry: Arc<dyn ToolRegistry> = Arc::new(manager.registry());

        let os = AgentOs::builder()
            .with_agent_spec(AgentDefinitionSpec::local_with_id(
                "assistant",
                AgentDefinition::new("gpt-4o-mini"),
            ))
            .with_tool_registry(registry)
            .build()
            .expect("build agent os");

        // Initial listing: only `echo` is advertised.
        let before = os.resolve("assistant").expect("resolve first snapshot");
        assert!(before.tools.contains_key(ECHO));
        assert!(!before.tools.contains_key(SUM));

        // Change the server's tool set, then have the manager re-list it.
        source.replace(vec![McpToolDefinition::new("sum")]);
        manager.refresh().await.expect("refresh registry");

        // The registry given to `AgentOs` before the refresh now exposes only `sum`.
        let after = os.resolve("assistant").expect("resolve refreshed snapshot");
        assert!(!after.tools.contains_key(ECHO));
        assert!(after.tools.contains_key(SUM));
    }
}