1use super::ToolCallContext;
6use crate::runtime::phase::AfterToolExecuteAction;
7use crate::runtime::phase::SuspendTicket;
8use async_trait::async_trait;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::HashMap;
13use thiserror::Error;
14
/// Outcome category attached to every [`ToolResult`].
///
/// Serialized in `snake_case`, i.e. `"success"`, `"warning"`, `"pending"`
/// and `"error"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolStatus {
    /// The tool finished successfully.
    Success,
    /// The tool finished with a warning; `ToolResult::is_success` still
    /// reports `true` for this status.
    Warning,
    /// The tool has not produced a final result yet (set by the
    /// `ToolResult::suspended*` constructors).
    Pending,
    /// The tool failed.
    Error,
}
28
/// Result record produced by a tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Name of the tool that produced this result.
    pub tool_name: String,
    /// Outcome category of the invocation.
    pub status: ToolStatus,
    /// JSON payload; the `error` and `suspended*` constructors set it to
    /// `Value::Null`.
    pub data: Value,
    /// Optional message accompanying the result.
    pub message: Option<String>,
    /// Free-form key/value metadata, filled through `with_metadata`.
    pub metadata: HashMap<String, Value>,
    /// Optional suspension ticket. Omitted from the serialized form when
    /// `None`, and defaults to `None` when absent during deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub suspension: Option<Box<SuspendTicket>>,
}
46
47impl ToolResult {
48 pub fn success(tool_name: impl Into<String>, data: impl Into<Value>) -> Self {
50 Self {
51 tool_name: tool_name.into(),
52 status: ToolStatus::Success,
53 data: data.into(),
54 message: None,
55 metadata: HashMap::new(),
56 suspension: None,
57 }
58 }
59
60 pub fn success_with_message(
62 tool_name: impl Into<String>,
63 data: impl Into<Value>,
64 message: impl Into<String>,
65 ) -> Self {
66 Self {
67 tool_name: tool_name.into(),
68 status: ToolStatus::Success,
69 data: data.into(),
70 message: Some(message.into()),
71 metadata: HashMap::new(),
72 suspension: None,
73 }
74 }
75
76 pub fn error(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
78 Self {
79 tool_name: tool_name.into(),
80 status: ToolStatus::Error,
81 data: Value::Null,
82 message: Some(message.into()),
83 metadata: HashMap::new(),
84 suspension: None,
85 }
86 }
87
88 pub fn error_with_code(
90 tool_name: impl Into<String>,
91 code: impl Into<String>,
92 message: impl Into<String>,
93 ) -> Self {
94 let tool_name = tool_name.into();
95 let code = code.into();
96 let message = message.into();
97 Self {
98 tool_name,
99 status: ToolStatus::Error,
100 data: serde_json::json!({
101 "error": {
102 "code": code,
103 "message": message,
104 }
105 }),
106 message: Some(format!("[{code}] {message}")),
107 metadata: HashMap::new(),
108 suspension: None,
109 }
110 }
111
112 pub fn suspended(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
114 Self {
115 tool_name: tool_name.into(),
116 status: ToolStatus::Pending,
117 data: Value::Null,
118 message: Some(message.into()),
119 metadata: HashMap::new(),
120 suspension: None,
121 }
122 }
123
124 pub fn suspended_with(
126 tool_name: impl Into<String>,
127 message: impl Into<String>,
128 ticket: SuspendTicket,
129 ) -> Self {
130 Self {
131 tool_name: tool_name.into(),
132 status: ToolStatus::Pending,
133 data: Value::Null,
134 message: Some(message.into()),
135 metadata: HashMap::new(),
136 suspension: Some(Box::new(ticket)),
137 }
138 }
139
140 pub fn warning(
142 tool_name: impl Into<String>,
143 data: impl Into<Value>,
144 message: impl Into<String>,
145 ) -> Self {
146 Self {
147 tool_name: tool_name.into(),
148 status: ToolStatus::Warning,
149 data: data.into(),
150 message: Some(message.into()),
151 metadata: HashMap::new(),
152 suspension: None,
153 }
154 }
155
156 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
158 self.metadata.insert(key.into(), value.into());
159 self
160 }
161
162 pub fn with_suspension(mut self, ticket: SuspendTicket) -> Self {
164 self.suspension = Some(Box::new(ticket));
165 self
166 }
167
168 pub fn is_success(&self) -> bool {
170 matches!(self.status, ToolStatus::Success | ToolStatus::Warning)
171 }
172
173 pub fn is_pending(&self) -> bool {
175 matches!(self.status, ToolStatus::Pending)
176 }
177
178 pub fn is_error(&self) -> bool {
180 matches!(self.status, ToolStatus::Error)
181 }
182
183 pub fn suspension(&self) -> Option<SuspendTicket> {
185 self.suspension.as_deref().cloned()
186 }
187
188 pub fn to_json(&self) -> Value {
190 serde_json::to_value(self).unwrap_or(Value::Null)
191 }
192}
193
/// A [`ToolResult`] bundled with the actions a tool wants applied after it
/// has executed.
pub struct ToolExecutionEffect {
    /// The result of the tool invocation.
    pub result: ToolResult,
    /// Actions accumulated through `with_action`; handed out by `into_parts`.
    actions: Vec<AfterToolExecuteAction>,
}
204
205impl ToolExecutionEffect {
206 #[must_use]
207 pub fn new(result: ToolResult) -> Self {
208 Self {
209 result,
210 actions: Vec::new(),
211 }
212 }
213
214 #[must_use]
216 pub fn with_action(mut self, action: impl Into<AfterToolExecuteAction>) -> Self {
217 self.actions.push(action.into());
218 self
219 }
220
221 pub fn into_parts(self) -> (ToolResult, Vec<AfterToolExecuteAction>) {
222 (self.result, self.actions)
223 }
224}
225
226impl From<ToolResult> for ToolExecutionEffect {
227 fn from(result: ToolResult) -> Self {
228 Self::new(result)
229 }
230}
231
/// Errors a tool can return instead of a [`ToolResult`].
///
/// Each variant carries a detail string that is appended to a fixed prefix
/// in the `Display` output.
#[derive(Debug, Error)]
pub enum ToolError {
    /// The supplied arguments were rejected (schema validation,
    /// deserialization, or a tool's own validation).
    #[error("Invalid arguments: {0}")]
    InvalidArguments(String),

    /// The tool ran but failed.
    #[error("Execution failed: {0}")]
    ExecutionFailed(String),

    /// The call was denied.
    #[error("Denied: {0}")]
    Denied(String),

    /// Something the tool needed was not found.
    #[error("Not found: {0}")]
    NotFound(String),

    /// Internal failure, e.g. a tool schema that cannot be compiled.
    #[error("Internal error: {0}")]
    Internal(String),
}
250
/// Static description of a tool: identity, human-readable text and the
/// schema of its arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDescriptor {
    /// Identifier of the tool.
    pub id: String,
    /// Display name of the tool.
    pub name: String,
    /// Description of what the tool does.
    pub description: String,
    /// JSON Schema for the tool's arguments; `ToolDescriptor::new` starts
    /// with an object schema that has no properties.
    pub parameters: Value,
    /// Optional category label.
    pub category: Option<String>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, Value>,
}
267
268impl ToolDescriptor {
269 pub fn new(
271 id: impl Into<String>,
272 name: impl Into<String>,
273 description: impl Into<String>,
274 ) -> Self {
275 Self {
276 id: id.into(),
277 name: name.into(),
278 description: description.into(),
279 parameters: serde_json::json!({"type": "object", "properties": {}}),
280 category: None,
281 metadata: HashMap::new(),
282 }
283 }
284
285 pub fn with_parameters(mut self, schema: Value) -> Self {
287 self.parameters = schema;
288 self
289 }
290
291 pub fn with_category(mut self, category: impl Into<String>) -> Self {
293 self.category = Some(category.into());
294 self
295 }
296
297 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
299 self.metadata.insert(key.into(), value.into());
300 self
301 }
302}
303
304#[async_trait]
350pub trait Tool: Send + Sync {
351 fn descriptor(&self) -> ToolDescriptor;
353
354 fn validate_args(&self, args: &Value) -> Result<(), ToolError> {
359 validate_against_schema(&self.descriptor().parameters, args)
360 }
361
362 async fn execute(
375 &self,
376 args: Value,
377 ctx: &ToolCallContext<'_>,
378 ) -> Result<ToolResult, ToolError>;
379
380 async fn execute_effect(
388 &self,
389 args: Value,
390 _ctx: &ToolCallContext<'_>,
391 ) -> Result<ToolExecutionEffect, ToolError> {
392 let result = self.execute(args, _ctx).await?;
393 Ok(ToolExecutionEffect::from(result))
394 }
395}
396
397pub fn validate_against_schema(schema: &Value, args: &Value) -> Result<(), ToolError> {
402 let validator = jsonschema::Validator::new(schema)
403 .map_err(|e| ToolError::Internal(format!("invalid tool schema: {e}")))?;
404 if validator.is_valid(args) {
405 return Ok(());
406 }
407 let errors: Vec<String> = validator.iter_errors(args).map(|e| e.to_string()).collect();
408 Err(ToolError::InvalidArguments(errors.join("; ")))
409}
410
411#[async_trait]
448pub trait TypedTool: Send + Sync {
449 type Args: for<'de> Deserialize<'de> + JsonSchema + Send;
451
452 fn tool_id(&self) -> &str;
454
455 fn name(&self) -> &str;
457
458 fn description(&self) -> &str;
460
461 fn validate(&self, _args: &Self::Args) -> Result<(), String> {
465 Ok(())
466 }
467
468 async fn execute(
470 &self,
471 args: Self::Args,
472 ctx: &ToolCallContext<'_>,
473 ) -> Result<ToolResult, ToolError>;
474
475 async fn execute_effect(
480 &self,
481 args: Self::Args,
482 ctx: &ToolCallContext<'_>,
483 ) -> Result<ToolExecutionEffect, ToolError> {
484 let result = self.execute(args, ctx).await?;
485 Ok(ToolExecutionEffect::from(result))
486 }
487}
488
489#[async_trait]
490impl<T: TypedTool> Tool for T {
491 fn descriptor(&self) -> ToolDescriptor {
492 let schema = typed_tool_schema::<T::Args>();
493 ToolDescriptor::new(self.tool_id(), self.name(), self.description()).with_parameters(schema)
494 }
495
496 fn validate_args(&self, _args: &Value) -> Result<(), ToolError> {
498 Ok(())
499 }
500
501 async fn execute(
502 &self,
503 args: Value,
504 ctx: &ToolCallContext<'_>,
505 ) -> Result<ToolResult, ToolError> {
506 let typed: T::Args =
507 serde_json::from_value(args).map_err(|e| ToolError::InvalidArguments(e.to_string()))?;
508 self.validate(&typed).map_err(ToolError::InvalidArguments)?;
509 TypedTool::execute(self, typed, ctx).await
510 }
511
512 async fn execute_effect(
513 &self,
514 args: Value,
515 ctx: &ToolCallContext<'_>,
516 ) -> Result<ToolExecutionEffect, ToolError> {
517 let typed: T::Args =
518 serde_json::from_value(args).map_err(|e| ToolError::InvalidArguments(e.to_string()))?;
519 self.validate(&typed).map_err(ToolError::InvalidArguments)?;
520 TypedTool::execute_effect(self, typed, ctx).await
521 }
522}
523
524fn typed_tool_schema<T: JsonSchema>() -> Value {
526 let mut v = serde_json::to_value(schemars::schema_for!(T))
527 .unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));
528 if let Some(obj) = v.as_object_mut() {
530 obj.remove("$schema");
531 }
532 v
533}
534
/// Unit tests for the tool contract types, schema validation and the
/// `TypedTool` -> `Tool` blanket implementation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::phase::SuspendTicket;
    use crate::runtime::state::AnyStateAction;
    use crate::runtime::state::StateSpec;
    use crate::runtime::Suspension;
    use crate::runtime::{PendingToolCall, ToolCallResumeMode};
    use crate::testing::TestFixtureState;
    use serde_json::json;
    use tirea_state::{DocCell, PatchSink, Path as TPath, State, TireaResult};

    // ------------------------------------------------------------------
    // ToolError: Display output of each variant
    // ------------------------------------------------------------------

    #[test]
    fn test_tool_error_invalid_arguments() {
        let err = ToolError::InvalidArguments("missing field".to_string());
        assert_eq!(err.to_string(), "Invalid arguments: missing field");
    }

    #[test]
    fn test_tool_error_execution_failed() {
        let err = ToolError::ExecutionFailed("timeout".to_string());
        assert_eq!(err.to_string(), "Execution failed: timeout");
    }

    #[test]
    fn test_tool_error_denied() {
        let err = ToolError::Denied("no access".to_string());
        assert_eq!(err.to_string(), "Denied: no access");
    }

    #[test]
    fn test_tool_error_not_found() {
        let err = ToolError::NotFound("file.txt".to_string());
        assert_eq!(err.to_string(), "Not found: file.txt");
    }

    #[test]
    fn test_tool_error_internal() {
        let err = ToolError::Internal("unexpected".to_string());
        assert_eq!(err.to_string(), "Internal error: unexpected");
    }

    // ------------------------------------------------------------------
    // ToolStatus: serde representation and derived traits
    // ------------------------------------------------------------------

    #[test]
    fn test_tool_status_serialization() {
        assert_eq!(
            serde_json::to_string(&ToolStatus::Success).unwrap(),
            "\"success\""
        );
        assert_eq!(
            serde_json::to_string(&ToolStatus::Warning).unwrap(),
            "\"warning\""
        );
        assert_eq!(
            serde_json::to_string(&ToolStatus::Pending).unwrap(),
            "\"pending\""
        );
        assert_eq!(
            serde_json::to_string(&ToolStatus::Error).unwrap(),
            "\"error\""
        );
    }

    #[test]
    fn test_tool_status_deserialization() {
        assert_eq!(
            serde_json::from_str::<ToolStatus>("\"success\"").unwrap(),
            ToolStatus::Success
        );
        assert_eq!(
            serde_json::from_str::<ToolStatus>("\"warning\"").unwrap(),
            ToolStatus::Warning
        );
        assert_eq!(
            serde_json::from_str::<ToolStatus>("\"pending\"").unwrap(),
            ToolStatus::Pending
        );
        assert_eq!(
            serde_json::from_str::<ToolStatus>("\"error\"").unwrap(),
            ToolStatus::Error
        );
    }

    #[test]
    fn test_tool_status_equality() {
        assert_eq!(ToolStatus::Success, ToolStatus::Success);
        assert_ne!(ToolStatus::Success, ToolStatus::Error);
    }

    #[test]
    fn test_tool_status_clone() {
        let status = ToolStatus::Warning;
        let cloned = status.clone();
        assert_eq!(status, cloned);
    }

    #[test]
    fn test_tool_status_debug() {
        assert_eq!(format!("{:?}", ToolStatus::Success), "Success");
        assert_eq!(format!("{:?}", ToolStatus::Error), "Error");
    }

    // ------------------------------------------------------------------
    // ToolResult: constructors, predicates, builders and serde
    // ------------------------------------------------------------------

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("my_tool", json!({"value": 42}));
        assert_eq!(result.tool_name, "my_tool");
        assert_eq!(result.status, ToolStatus::Success);
        assert_eq!(result.data, json!({"value": 42}));
        assert!(result.message.is_none());
        assert!(result.metadata.is_empty());
        assert!(result.is_success());
        assert!(!result.is_error());
        assert!(!result.is_pending());
    }

    #[test]
    fn test_tool_result_success_with_message() {
        let result = ToolResult::success_with_message(
            "my_tool",
            json!({"done": true}),
            "Operation complete",
        );
        assert_eq!(result.tool_name, "my_tool");
        assert_eq!(result.status, ToolStatus::Success);
        assert_eq!(result.data, json!({"done": true}));
        assert_eq!(result.message, Some("Operation complete".to_string()));
        assert!(result.is_success());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("my_tool", "Something went wrong");
        assert_eq!(result.tool_name, "my_tool");
        assert_eq!(result.status, ToolStatus::Error);
        assert_eq!(result.data, Value::Null);
        assert_eq!(result.message, Some("Something went wrong".to_string()));
        assert!(!result.is_success());
        assert!(result.is_error());
        assert!(!result.is_pending());
    }

    #[test]
    fn test_tool_result_error_with_code() {
        let result = ToolResult::error_with_code("my_tool", "invalid_arguments", "missing input");
        assert_eq!(result.tool_name, "my_tool");
        assert_eq!(result.status, ToolStatus::Error);
        assert_eq!(
            result.data,
            json!({
                "error": {
                    "code": "invalid_arguments",
                    "message": "missing input"
                }
            })
        );
        assert_eq!(
            result.message,
            Some("[invalid_arguments] missing input".to_string())
        );
        assert!(result.is_error());
    }

    #[test]
    fn test_tool_result_pending() {
        let result = ToolResult::suspended("my_tool", "Waiting for confirmation");
        assert_eq!(result.tool_name, "my_tool");
        assert_eq!(result.status, ToolStatus::Pending);
        assert_eq!(result.data, Value::Null);
        assert_eq!(result.message, Some("Waiting for confirmation".to_string()));
        assert!(!result.is_success());
        assert!(!result.is_error());
        assert!(result.is_pending());
    }

    // Checks that a ticket passed to `suspended_with` comes back unchanged
    // from the `suspension()` accessor.
    #[test]
    fn test_tool_result_with_suspension_roundtrip() {
        let suspension = SuspendTicket::new(
            Suspension::new("call_1", "tool:confirm")
                .with_message("Need confirmation")
                .with_parameters(json!({"message":"hi"})),
            PendingToolCall::new("call_1", "confirm", json!({"message":"hi"})),
            ToolCallResumeMode::ReplayToolCall,
        );
        let result = ToolResult::suspended_with("confirm", "waiting", suspension.clone());

        assert!(result.is_pending());
        assert_eq!(result.suspension(), Some(suspension));
    }

    #[test]
    fn test_tool_result_warning() {
        let result = ToolResult::warning("my_tool", json!({"partial": true}), "Some items skipped");
        assert_eq!(result.tool_name, "my_tool");
        assert_eq!(result.status, ToolStatus::Warning);
        assert_eq!(result.data, json!({"partial": true}));
        assert_eq!(result.message, Some("Some items skipped".to_string()));
        // A warning still counts as success.
        assert!(result.is_success());
        assert!(!result.is_error());
    }

    #[test]
    fn test_tool_result_with_metadata() {
        let result = ToolResult::success("my_tool", json!({}))
            .with_metadata("duration_ms", 150)
            .with_metadata("retry_count", 2);
        assert_eq!(result.metadata.get("duration_ms"), Some(&json!(150)));
        assert_eq!(result.metadata.get("retry_count"), Some(&json!(2)));
    }

    #[test]
    fn test_tool_result_serialization() {
        let result =
            ToolResult::success("my_tool", json!({"key": "value"})).with_metadata("extra", "data");

        let json = serde_json::to_string(&result).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.tool_name, "my_tool");
        assert_eq!(parsed.status, ToolStatus::Success);
        assert_eq!(parsed.data, json!({"key": "value"}));
    }

    #[test]
    fn test_tool_result_clone() {
        let result = ToolResult::success("my_tool", json!({"x": 1}));
        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.status, cloned.status);
    }

    #[test]
    fn test_tool_result_debug() {
        let result = ToolResult::success("test", json!(null));
        let debug = format!("{:?}", result);
        assert!(debug.contains("ToolResult"));
        assert!(debug.contains("test"));
    }

    // ------------------------------------------------------------------
    // ToolDescriptor: constructor, builders and serde
    // ------------------------------------------------------------------

    #[test]
    fn test_tool_descriptor_new() {
        let desc = ToolDescriptor::new("read_file", "Read File", "Reads a file from disk");
        assert_eq!(desc.id, "read_file");
        assert_eq!(desc.name, "Read File");
        assert_eq!(desc.description, "Reads a file from disk");
        assert!(desc.category.is_none());
        assert!(desc.metadata.is_empty());
        // The default parameter schema is an object with no properties.
        assert_eq!(desc.parameters, json!({"type": "object", "properties": {}}));
    }

    #[test]
    fn test_tool_descriptor_with_parameters() {
        let schema = json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" }
            },
            "required": ["path"]
        });
        let desc =
            ToolDescriptor::new("read_file", "Read File", "Read").with_parameters(schema.clone());
        assert_eq!(desc.parameters, schema);
    }

    #[test]
    fn test_tool_descriptor_with_category() {
        let desc =
            ToolDescriptor::new("read_file", "Read File", "Read").with_category("filesystem");
        assert_eq!(desc.category, Some("filesystem".to_string()));
    }

    #[test]
    fn test_tool_descriptor_with_metadata() {
        let desc = ToolDescriptor::new("my_tool", "My Tool", "Description")
            .with_metadata("version", "1.0")
            .with_metadata("author", "test");
        assert_eq!(desc.metadata.get("version"), Some(&json!("1.0")));
        assert_eq!(desc.metadata.get("author"), Some(&json!("test")));
    }

    #[test]
    fn test_tool_descriptor_builder_chain() {
        let desc = ToolDescriptor::new("tool", "Tool", "Desc")
            .with_parameters(json!({"type": "object"}))
            .with_category("test")
            .with_metadata("key", "value");

        assert_eq!(desc.id, "tool");
        assert_eq!(desc.category, Some("test".to_string()));
        assert_eq!(desc.metadata.get("key"), Some(&json!("value")));
    }

    #[test]
    fn test_tool_descriptor_serialization() {
        let desc =
            ToolDescriptor::new("my_tool", "My Tool", "Does things").with_category("utilities");

        let json = serde_json::to_string(&desc).unwrap();
        let parsed: ToolDescriptor = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "my_tool");
        assert_eq!(parsed.name, "My Tool");
        assert_eq!(parsed.category, Some("utilities".to_string()));
    }

    #[test]
    fn test_tool_descriptor_clone() {
        let desc = ToolDescriptor::new("tool", "Tool", "Desc").with_category("cat");
        let cloned = desc.clone();
        assert_eq!(desc.id, cloned.id);
        assert_eq!(desc.category, cloned.category);
    }

    #[test]
    fn test_tool_descriptor_debug() {
        let desc = ToolDescriptor::new("tool", "Tool", "Desc");
        let debug = format!("{:?}", desc);
        assert!(debug.contains("ToolDescriptor"));
        assert!(debug.contains("tool"));
    }

    // ------------------------------------------------------------------
    // validate_against_schema
    // ------------------------------------------------------------------

    #[test]
    fn test_validate_against_schema_valid() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        assert!(validate_against_schema(&schema, &json!({"name": "Alice"})).is_ok());
    }

    #[test]
    fn test_validate_against_schema_missing_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let err = validate_against_schema(&schema, &json!({})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
    }

    #[test]
    fn test_validate_against_schema_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });
        let err = validate_against_schema(&schema, &json!({"count": "not_a_number"})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
    }

    #[test]
    fn test_validate_against_schema_empty_schema_accepts_object() {
        let schema = json!({"type": "object", "properties": {}});
        assert!(validate_against_schema(&schema, &json!({"anything": true})).is_ok());
    }

    #[test]
    fn test_validate_against_schema_multiple_errors_joined() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });
        let err = validate_against_schema(&schema, &json!({})).unwrap_err();
        let msg = err.to_string();
        // Both missing fields must be reported, separated by "; ".
        assert!(
            msg.contains("; "),
            "expected multiple errors joined by '; ', got: {msg}"
        );
        assert!(msg.contains("name"), "expected 'name' in error: {msg}");
        assert!(msg.contains("age"), "expected 'age' in error: {msg}");
    }

    #[test]
    fn test_validate_against_schema_null_args_rejected() {
        let schema = json!({"type": "object", "properties": {}});
        let err = validate_against_schema(&schema, &json!(null)).unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
    }

    #[test]
    fn test_validate_against_schema_invalid_schema_returns_internal() {
        let bad_schema = json!({"type": 123});
        // A schema that fails to compile is reported as Internal, not as
        // InvalidArguments.
        let err = validate_against_schema(&bad_schema, &json!({})).unwrap_err();
        assert!(
            matches!(err, ToolError::Internal(_)),
            "expected Internal error for invalid schema, got: {err}"
        );
    }

    #[test]
    fn test_validate_against_schema_nested_object() {
        let schema = json!({
            "type": "object",
            "properties": {
                "address": {
                    "type": "object",
                    "properties": {
                        "city": { "type": "string" }
                    },
                    "required": ["city"]
                }
            },
            "required": ["address"]
        });
        // Valid nested object.
        assert!(validate_against_schema(&schema, &json!({"address": {"city": "Berlin"}})).is_ok());
        // Nested required field missing.
        let err = validate_against_schema(&schema, &json!({"address": {}})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
        // Nested field of the wrong type.
        let err = validate_against_schema(&schema, &json!({"address": {"city": 42}})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
    }

    // ------------------------------------------------------------------
    // TypedTool: blanket `Tool` implementation
    // ------------------------------------------------------------------

    #[derive(Deserialize, JsonSchema)]
    struct GreetArgs {
        name: String,
    }

    /// Minimal typed tool with a single required string argument.
    struct GreetTool;

    #[async_trait]
    impl TypedTool for GreetTool {
        type Args = GreetArgs;
        fn tool_id(&self) -> &str {
            "greet"
        }
        fn name(&self) -> &str {
            "Greet"
        }
        fn description(&self) -> &str {
            "Greet a user"
        }

        async fn execute(
            &self,
            args: GreetArgs,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "greet",
                json!({"greeting": format!("Hello, {}!", args.name)}),
            ))
        }
    }

    #[test]
    fn test_typed_tool_descriptor_schema() {
        let tool = GreetTool;
        let desc = Tool::descriptor(&tool);
        assert_eq!(desc.id, "greet");
        assert_eq!(desc.name, "Greet");
        assert_eq!(desc.description, "Greet a user");

        let props = desc.parameters.get("properties").unwrap();
        assert!(props.get("name").is_some());
        let required = desc.parameters.get("required").unwrap().as_array().unwrap();
        assert!(required.iter().any(|v| v == "name"));
        // `$schema` is stripped from the generated schema.
        assert!(desc.parameters.get("$schema").is_none());
    }

    #[tokio::test]
    async fn test_typed_tool_execute_success() {
        let tool = GreetTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");
        let result = Tool::execute(&tool, json!({"name": "World"}), &ctx)
            .await
            .unwrap();
        assert!(result.is_success());
        assert_eq!(result.data["greeting"], "Hello, World!");
    }

    #[tokio::test]
    async fn test_typed_tool_execute_deser_failure() {
        let tool = GreetTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");
        let err = Tool::execute(&tool, json!({"name": 123}), &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
    }

    #[derive(Deserialize, JsonSchema)]
    struct PositiveArgs {
        value: i64,
    }

    /// Typed tool whose `validate` rejects non-positive values.
    struct PositiveTool;

    #[async_trait]
    impl TypedTool for PositiveTool {
        type Args = PositiveArgs;
        fn tool_id(&self) -> &str {
            "positive"
        }
        fn name(&self) -> &str {
            "Positive"
        }
        fn description(&self) -> &str {
            "Requires positive value"
        }

        fn validate(&self, args: &PositiveArgs) -> Result<(), String> {
            if args.value <= 0 {
                return Err("value must be positive".into());
            }
            Ok(())
        }

        async fn execute(
            &self,
            args: PositiveArgs,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "positive",
                json!({"value": args.value}),
            ))
        }
    }

    #[tokio::test]
    async fn test_typed_tool_validate_rejection() {
        let tool = PositiveTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");
        let err = Tool::execute(&tool, json!({"value": -1}), &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
        assert!(err.to_string().contains("positive"));
    }

    #[test]
    fn test_typed_tool_as_arc_dyn_tool() {
        let tool: std::sync::Arc<dyn Tool> = std::sync::Arc::new(GreetTool);
        let desc = tool.descriptor();
        assert_eq!(desc.id, "greet");
    }

    #[test]
    fn test_typed_tool_skips_schema_validation() {
        let tool = GreetTool;
        // The blanket impl's `validate_args` always returns Ok, even for
        // arguments that would not deserialize into `GreetArgs`.
        assert!(tool.validate_args(&json!({})).is_ok());
        assert!(tool.validate_args(&json!({"wrong": 123})).is_ok());
        assert!(tool.validate_args(&json!(null)).is_ok());
    }

    #[derive(Deserialize, JsonSchema)]
    struct OptionalArgs {
        required_field: String,
        optional_field: Option<i64>,
    }

    /// Typed tool that echoes one required and one optional argument.
    struct OptionalTool;

    #[async_trait]
    impl TypedTool for OptionalTool {
        type Args = OptionalArgs;
        fn tool_id(&self) -> &str {
            "optional"
        }
        fn name(&self) -> &str {
            "Optional"
        }
        fn description(&self) -> &str {
            "Tool with optional field"
        }

        async fn execute(
            &self,
            args: OptionalArgs,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "optional",
                json!({
                    "required": args.required_field,
                    "optional": args.optional_field,
                }),
            ))
        }
    }

    #[tokio::test]
    async fn test_typed_tool_optional_field_absent() {
        let tool = OptionalTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");
        let result = Tool::execute(&tool, json!({"required_field": "hi"}), &ctx)
            .await
            .unwrap();
        assert!(result.is_success());
        assert_eq!(result.data["optional"], json!(null));
    }

    #[tokio::test]
    async fn test_typed_tool_extra_fields_ignored() {
        let tool = GreetTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");
        // Unknown fields in the arguments do not cause a failure.
        let result = Tool::execute(&tool, json!({"name": "World", "extra": 999}), &ctx)
            .await
            .unwrap();
        assert!(result.is_success());
        assert_eq!(result.data["greeting"], "Hello, World!");
    }

    #[tokio::test]
    async fn test_typed_tool_empty_json_all_required() {
        let tool = GreetTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");
        let err = Tool::execute(&tool, json!({}), &ctx).await.unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_)));
    }

    // ------------------------------------------------------------------
    // execute_effect: default delegation and overrides
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_default_execute_effect_wraps_execute_result() {
        let tool = GreetTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");

        let effect = Tool::execute_effect(&tool, json!({"name": "World"}), &ctx)
            .await
            .expect("execute_effect should succeed");

        assert_eq!(effect.result.tool_name, "greet");
        assert!(effect.result.is_success());
        let (_, actions) = effect.into_parts();
        assert!(actions.is_empty());
    }

    /// Raw `Tool` that writes to state through the context inside `execute`
    /// and relies on the default `execute_effect`.
    struct ContextWriteDefaultTool;

    #[async_trait]
    impl Tool for ContextWriteDefaultTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "context_write_default",
                "ContextWriteDefault",
                "writes state in execute",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            ctx.state_of::<TestFixtureState>()
                .set_label(Some("default_execute_write".to_string()))
                .expect("failed to set label");
            Ok(ToolResult::success(
                "context_write_default",
                json!({"ok": true}),
            ))
        }
    }

    #[tokio::test]
    async fn test_default_execute_effect_leaves_context_writes_in_context() {
        let tool = ContextWriteDefaultTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");

        let effect = Tool::execute_effect(&tool, json!({}), &ctx)
            .await
            .expect("execute_effect should succeed");

        assert!(effect.result.is_success());
        // The default wrapper adds no actions; the state write made in
        // `execute` stays in the context's patch.
        let (_, actions) = effect.into_parts();
        assert!(actions.is_empty());
        assert!(!ctx.take_patch().patch().is_empty());
    }

    /// State type used as the target of actions returned from
    /// `execute_effect`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
    struct ToolEffectState {
        value: i64,
    }

    struct ToolEffectStateRef;

    impl State for ToolEffectState {
        type Ref<'a> = ToolEffectStateRef;
        const PATH: &'static str = "tool_effect";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            ToolEffectStateRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for ToolEffectState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Raw `Tool` that overrides `execute_effect` to return a state action.
    struct EffectOnlyTool;

    #[async_trait]
    impl Tool for EffectOnlyTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect_only", "EffectOnly", "returns state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect_only", json!({"ok": true})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolExecutionEffect, ToolError> {
            Ok(
                ToolExecutionEffect::new(ToolResult::success("effect_only", json!({"ok": true})))
                    .with_action(AnyStateAction::new::<ToolEffectState>(1)),
            )
        }
    }

    #[tokio::test]
    async fn test_tool_can_return_state_actions_via_execute_effect() {
        let tool = EffectOnlyTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");

        let effect = Tool::execute_effect(&tool, json!({}), &ctx)
            .await
            .expect("effect tool should succeed");

        assert!(effect.result.is_success());
        let (_, actions) = effect.into_parts();
        assert_eq!(actions.len(), 1);
        let action = actions.into_iter().next().unwrap();
        match action {
            crate::runtime::phase::AfterToolExecuteAction::State(sa) => {
                assert!(sa.state_type_name().contains("ToolEffectState"));
            }
            _ => panic!("expected State action"),
        }
    }

    #[derive(Deserialize, JsonSchema)]
    struct IncrementArgs {
        amount: i64,
    }

    /// Typed tool that overrides `execute_effect` to return a state action.
    struct TypedEffectTool;

    #[async_trait]
    impl TypedTool for TypedEffectTool {
        type Args = IncrementArgs;
        fn tool_id(&self) -> &str {
            "typed_effect"
        }
        fn name(&self) -> &str {
            "TypedEffect"
        }
        fn description(&self) -> &str {
            "Typed tool with execute_effect"
        }

        async fn execute(
            &self,
            args: IncrementArgs,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "typed_effect",
                json!({"amount": args.amount}),
            ))
        }

        async fn execute_effect(
            &self,
            args: IncrementArgs,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolExecutionEffect, ToolError> {
            Ok(ToolExecutionEffect::new(ToolResult::success(
                "typed_effect",
                json!({"amount": args.amount}),
            ))
            .with_action(AnyStateAction::new::<ToolEffectState>(args.amount)))
        }
    }

    #[tokio::test]
    async fn test_typed_tool_execute_effect_override() {
        let tool = TypedEffectTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");

        let effect = Tool::execute_effect(&tool, json!({"amount": 5}), &ctx)
            .await
            .expect("typed execute_effect should succeed");

        assert!(effect.result.is_success());
        assert_eq!(effect.result.data["amount"], 5);
        let (_, actions) = effect.into_parts();
        assert_eq!(actions.len(), 1);
        let action = actions.into_iter().next().unwrap();
        match action {
            crate::runtime::phase::AfterToolExecuteAction::State(sa) => {
                assert!(sa.state_type_name().contains("ToolEffectState"));
            }
            _ => panic!("expected State action"),
        }
    }

    #[tokio::test]
    async fn test_typed_tool_default_execute_effect_delegates_to_execute() {
        let tool = GreetTool;
        let fixture = crate::testing::TestFixture::new();
        let ctx = fixture.ctx_with("call_1", "test");

        let effect = Tool::execute_effect(&tool, json!({"name": "TypedDefault"}), &ctx)
            .await
            .expect("default execute_effect should succeed");

        assert!(effect.result.is_success());
        assert_eq!(effect.result.data["greeting"], "Hello, TypedDefault!");
        let (_, actions) = effect.into_parts();
        assert!(actions.is_empty());
    }
}