tirea_agentos/extensions/
mcp.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::composition::ToolRegistry;
5use crate::contracts::runtime::tool_call::Tool;
6
7pub use tirea_extension_mcp::*;
8
9impl ToolRegistry for tirea_extension_mcp::McpToolRegistry {
10 fn len(&self) -> usize {
11 tirea_extension_mcp::McpToolRegistry::len(self)
12 }
13
14 fn get(&self, id: &str) -> Option<Arc<dyn Tool>> {
15 tirea_extension_mcp::McpToolRegistry::get(self, id)
16 }
17
18 fn ids(&self) -> Vec<String> {
19 tirea_extension_mcp::McpToolRegistry::ids(self)
20 }
21
22 fn snapshot(&self) -> HashMap<String, Arc<dyn Tool>> {
23 tirea_extension_mcp::McpToolRegistry::snapshot(self)
24 }
25}
26
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use mcp::transport::{McpServerConnectionConfig, McpTransportError, TransportTypeId};
    use mcp::McpToolDefinition;
    use serde_json::Value;
    use std::sync::{Arc, Mutex};
    use tokio::sync::mpsc;

    use crate::composition::{AgentDefinition, AgentDefinitionSpec};
    use crate::runtime::AgentOs;

    /// In-memory MCP transport whose advertised tool list can be swapped at
    /// runtime, letting a test simulate a server that changes its tools.
    #[derive(Debug, Clone)]
    struct MutableTransport {
        /// Tool definitions served by `list_tools`; clones share the same list.
        tools: Arc<Mutex<Vec<McpToolDefinition>>>,
    }

    impl MutableTransport {
        fn new(tools: Vec<McpToolDefinition>) -> Self {
            let tools = Arc::new(Mutex::new(tools));
            Self { tools }
        }

        /// Runs `f` against the tool list while the lock is held.
        ///
        /// A poisoned mutex is recovered (its data is used anyway) instead of
        /// being surfaced. An earlier panic while the lock was held therefore
        /// does not cause a second, unrelated failure here.
        fn with_tools<R>(&self, f: impl FnOnce(&mut Vec<McpToolDefinition>) -> R) -> R {
            let mut guard = self
                .tools
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            f(&mut *guard)
        }

        /// Overwrites the advertised tool list with `tools`.
        fn replace(&self, tools: Vec<McpToolDefinition>) {
            self.with_tools(|current| *current = tools);
        }
    }

    #[async_trait]
    impl McpToolTransport for MutableTransport {
        /// Returns a copy of the tool list as it is at call time.
        async fn list_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
            // The lock is taken and released inside `with_tools`; nothing is
            // held across an await point.
            Ok(self.with_tools(|current| current.clone()))
        }

        /// Ignores its inputs and always answers with a single `"ok"` text item.
        async fn call_tool(
            &self,
            _name: &str,
            _args: Value,
            _progress_tx: Option<mpsc::UnboundedSender<McpProgressUpdate>>,
        ) -> Result<mcp::CallToolResult, McpTransportError> {
            let reply = mcp::CallToolResult {
                content: vec![mcp::ToolContent::text("ok")],
                is_error: None,
                structured_content: None,
            };
            Ok(reply)
        }

        fn transport_type(&self) -> TransportTypeId {
            TransportTypeId::Stdio
        }
    }

    /// Stdio connection config named `name`. The command is a placeholder,
    /// because these tests supply their own transport instead of spawning one.
    fn stdio_config(name: &str) -> McpServerConnectionConfig {
        McpServerConnectionConfig::stdio(name, "unused", vec![])
    }

    /// Checks that after `manager.refresh()`, a later `AgentOs::resolve` sees
    /// the new tool set from the MCP registry.
    #[tokio::test]
    async fn mcp_registry_implements_dynamic_tool_registry() {
        let transport = Arc::new(MutableTransport::new(vec![McpToolDefinition::new("echo")]));
        // Method-form `clone()` so the unsized coercion to `dyn` happens at the `let`.
        let shared_transport: Arc<dyn McpToolTransport> = transport.clone();

        let manager =
            McpToolRegistryManager::from_transports([(stdio_config("mcp_s1"), shared_transport)])
                .await
                .expect("build manager");
        let dynamic_registry: Arc<dyn ToolRegistry> = Arc::new(manager.registry());

        let assistant_spec =
            AgentDefinitionSpec::local_with_id("assistant", AgentDefinition::new("gpt-4o-mini"));
        let os = AgentOs::builder()
            .with_agent_spec(assistant_spec)
            .with_tool_registry(dynamic_registry)
            .build()
            .expect("build agent os");

        let echo_id = "mcp__mcp_s1__echo";
        let sum_id = "mcp__mcp_s1__sum";

        let before = os.resolve("assistant").expect("resolve first snapshot");
        assert!(before.tools.contains_key(echo_id));
        assert!(!before.tools.contains_key(sum_id));

        transport.replace(vec![McpToolDefinition::new("sum")]);
        manager.refresh().await.expect("refresh registry");

        let after = os.resolve("assistant").expect("resolve refreshed snapshot");
        assert!(!after.tools.contains_key(echo_id));
        assert!(after.tools.contains_key(sum_id));
    }
}