1use super::manager::BackgroundTaskManager;
7use super::{TaskState, TaskStatus, TaskStore, TaskSummary};
8use crate::contracts::runtime::tool_call::{ToolCallContext, ToolError, ToolResult};
9use crate::contracts::storage::ThreadStore;
10use async_trait::async_trait;
11use schemars::JsonSchema;
12use serde::Deserialize;
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use std::sync::Arc;
16
/// Tool id returned by `TaskStatusTool::tool_id`.
pub const TASK_STATUS_TOOL_ID: &str = "task_status";

/// Tool id returned by `TaskCancelTool::tool_id`.
pub const TASK_CANCEL_TOOL_ID: &str = "task_cancel";

/// Tool id returned by `TaskOutputTool::tool_id`.
pub const TASK_OUTPUT_TOOL_ID: &str = "task_output";
20
21fn owner_thread_id(ctx: &ToolCallContext<'_>) -> Option<String> {
22 ctx.caller_context().thread_id().map(str::to_string)
23}
24
25#[derive(Debug, Clone)]
33pub struct TaskStatusTool {
34 manager: Arc<BackgroundTaskManager>,
35 task_store: Option<Arc<TaskStore>>,
36}
37
38impl TaskStatusTool {
39 pub fn new(manager: Arc<BackgroundTaskManager>) -> Self {
40 Self {
41 manager,
42 task_store: None,
43 }
44 }
45
46 pub fn with_task_store(mut self, task_store: Option<Arc<TaskStore>>) -> Self {
47 self.task_store = task_store;
48 self
49 }
50
51 async fn query_one(
52 &self,
53 owner_thread_id: &str,
54 task_id: &str,
55 ) -> Result<Option<TaskSummary>, String> {
56 let persisted = if let Some(store) = &self.task_store {
57 store
58 .load_task_for_owner(owner_thread_id, task_id)
59 .await
60 .map_err(|e| e.to_string())?
61 .map(|task| task.summary())
62 } else {
63 None
64 };
65 let live = self.manager.get(owner_thread_id, task_id).await;
66
67 Ok(match (persisted, live) {
68 (_, Some(live)) => Some(live),
69 (Some(task), None) => Some(task),
70 (None, None) => None,
71 })
72 }
73
74 async fn list_all(&self, owner_thread_id: &str) -> Result<Vec<TaskSummary>, String> {
75 let mut by_id: HashMap<String, TaskSummary> = HashMap::new();
76
77 if let Some(store) = &self.task_store {
78 let tasks = store
79 .list_tasks_for_owner(owner_thread_id)
80 .await
81 .map_err(|e| e.to_string())?;
82 for task in tasks {
83 by_id.insert(task.id.clone(), task.summary());
84 }
85 }
86
87 for summary in self.manager.list(owner_thread_id, None).await {
88 by_id.insert(summary.task_id.clone(), summary);
89 }
90
91 let mut out: Vec<TaskSummary> = by_id.into_values().collect();
92 out.sort_by(|a, b| {
93 a.created_at_ms
94 .cmp(&b.created_at_ms)
95 .then_with(|| a.task_id.cmp(&b.task_id))
96 });
97 Ok(out)
98 }
99}
100
101#[derive(Debug, Deserialize, JsonSchema)]
102pub struct TaskStatusArgs {
103 task_id: Option<String>,
105}
106
107#[async_trait]
108impl crate::contracts::runtime::tool_call::TypedTool for TaskStatusTool {
109 type Args = TaskStatusArgs;
110
111 fn tool_id(&self) -> &str {
112 TASK_STATUS_TOOL_ID
113 }
114 fn name(&self) -> &str {
115 "Task Status"
116 }
117 fn description(&self) -> &str {
118 "Check the status and result of background tasks. \
119 Provide task_id to query a specific task, or omit to list all tasks."
120 }
121
122 async fn execute(
123 &self,
124 args: TaskStatusArgs,
125 ctx: &ToolCallContext<'_>,
126 ) -> Result<ToolResult, ToolError> {
127 let Some(thread_id) = owner_thread_id(ctx) else {
128 return Ok(ToolResult::error(
129 TASK_STATUS_TOOL_ID,
130 "Missing caller thread context",
131 ));
132 };
133
134 let task_id = args.task_id.as_deref().filter(|s| !s.trim().is_empty());
135
136 if let Some(task_id) = task_id {
137 match self.query_one(&thread_id, task_id).await {
138 Ok(Some(summary)) => Ok(ToolResult::success(
139 TASK_STATUS_TOOL_ID,
140 serde_json::to_value(&summary).unwrap_or(Value::Null),
141 )),
142 Ok(None) => Ok(ToolResult::error(
143 TASK_STATUS_TOOL_ID,
144 format!("Unknown task_id: {task_id}"),
145 )),
146 Err(err) => Ok(ToolResult::error(TASK_STATUS_TOOL_ID, err)),
147 }
148 } else {
149 match self.list_all(&thread_id).await {
150 Ok(tasks) => Ok(ToolResult::success(
151 TASK_STATUS_TOOL_ID,
152 json!({
153 "tasks": serde_json::to_value(&tasks).unwrap_or(Value::Null),
154 "total": tasks.len(),
155 }),
156 )),
157 Err(err) => Ok(ToolResult::error(TASK_STATUS_TOOL_ID, err)),
158 }
159 }
160 }
161}
162
163#[derive(Debug, Clone)]
169pub struct TaskCancelTool {
170 manager: Arc<BackgroundTaskManager>,
171 task_store: Option<Arc<TaskStore>>,
172}
173
174impl TaskCancelTool {
175 pub fn new(manager: Arc<BackgroundTaskManager>) -> Self {
176 Self {
177 manager,
178 task_store: None,
179 }
180 }
181
182 pub fn with_task_store(mut self, task_store: Option<Arc<TaskStore>>) -> Self {
183 self.task_store = task_store;
184 self
185 }
186}
187
188#[derive(Debug, Deserialize, JsonSchema)]
189pub struct TaskCancelArgs {
190 task_id: String,
192}
193
194#[async_trait]
195impl crate::contracts::runtime::tool_call::TypedTool for TaskCancelTool {
196 type Args = TaskCancelArgs;
197
198 fn tool_id(&self) -> &str {
199 TASK_CANCEL_TOOL_ID
200 }
201 fn name(&self) -> &str {
202 "Task Cancel"
203 }
204 fn description(&self) -> &str {
205 "Cancel a running background task by task_id. \
206 Descendant tasks are cancelled automatically."
207 }
208
209 fn validate(&self, args: &Self::Args) -> Result<(), String> {
210 if args.task_id.trim().is_empty() {
211 return Err("task_id cannot be empty".to_string());
212 }
213 Ok(())
214 }
215
216 async fn execute(
217 &self,
218 args: TaskCancelArgs,
219 ctx: &ToolCallContext<'_>,
220 ) -> Result<ToolResult, ToolError> {
221 let task_id = &args.task_id;
222
223 let Some(thread_id) = owner_thread_id(ctx) else {
224 return Ok(ToolResult::error(
225 TASK_CANCEL_TOOL_ID,
226 "Missing caller thread context",
227 ));
228 };
229
230 match self.manager.cancel_tree(&thread_id, task_id).await {
231 Ok(cancelled) => {
232 let mut persistence_failures = Vec::new();
233 if let Some(store) = &self.task_store {
234 for summary in &cancelled {
235 if let Err(error) = store.mark_cancel_requested(&summary.task_id).await {
236 tracing::warn!(
237 root_task_id = %task_id,
238 cancelled_task_id = %summary.task_id,
239 owner_thread_id = %thread_id,
240 error = %error,
241 "failed to persist background task cancellation marker"
242 );
243 persistence_failures.push((summary.task_id.clone(), error.to_string()));
244 }
245 }
246 }
247 let ids: Vec<&str> = cancelled.iter().map(|s| s.task_id.as_str()).collect();
248 let mut data = json!({
249 "task_id": task_id,
250 "cancelled": true,
251 "cancelled_ids": ids,
252 "cancelled_count": cancelled.len(),
253 });
254 if persistence_failures.is_empty() {
255 return Ok(ToolResult::success(TASK_CANCEL_TOOL_ID, data));
256 }
257
258 data["persistence_warning"] = json!({
259 "failed_count": persistence_failures.len(),
260 "failures": persistence_failures.iter().map(|(task_id, error)| json!({
261 "task_id": task_id,
262 "error": error,
263 })).collect::<Vec<_>>(),
264 });
265
266 Ok(ToolResult::warning(
267 TASK_CANCEL_TOOL_ID,
268 data,
269 format!(
270 "Cancellation requested, but failed to persist cancellation markers for {} task(s).",
271 persistence_failures.len()
272 ),
273 ))
274 }
275 Err(e) => Ok(ToolResult::error(TASK_CANCEL_TOOL_ID, e)),
276 }
277 }
278}
279
280#[derive(Clone)]
291pub struct TaskOutputTool {
292 manager: Arc<BackgroundTaskManager>,
293 task_store: Option<Arc<TaskStore>>,
294}
295
296impl std::fmt::Debug for TaskOutputTool {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("TaskOutputTool")
299 .field("has_task_store", &self.task_store.is_some())
300 .finish()
301 }
302}
303
304impl TaskOutputTool {
305 pub fn new(
306 manager: Arc<BackgroundTaskManager>,
307 thread_store: Option<Arc<dyn ThreadStore>>,
308 ) -> Self {
309 Self {
310 manager,
311 task_store: thread_store.map(TaskStore::new).map(Arc::new),
312 }
313 }
314
315 pub fn with_task_store(mut self, task_store: Option<Arc<TaskStore>>) -> Self {
316 self.task_store = task_store;
317 self
318 }
319}
320
321#[derive(Debug, Deserialize, JsonSchema)]
322pub struct TaskOutputArgs {
323 task_id: String,
325}
326
327#[async_trait]
328impl crate::contracts::runtime::tool_call::TypedTool for TaskOutputTool {
329 type Args = TaskOutputArgs;
330
331 fn tool_id(&self) -> &str {
332 TASK_OUTPUT_TOOL_ID
333 }
334 fn name(&self) -> &str {
335 "Task Output"
336 }
337 fn description(&self) -> &str {
338 "Read the output of a background task. \
339 For agent runs, returns the last assistant message. \
340 For other tasks, returns the task result."
341 }
342
343 fn validate(&self, args: &Self::Args) -> Result<(), String> {
344 if args.task_id.trim().is_empty() {
345 return Err("task_id cannot be empty".to_string());
346 }
347 Ok(())
348 }
349
350 async fn execute(
351 &self,
352 args: TaskOutputArgs,
353 ctx: &ToolCallContext<'_>,
354 ) -> Result<ToolResult, ToolError> {
355 let task_id = &args.task_id;
356
357 let Some(thread_id) = owner_thread_id(ctx) else {
358 return Ok(ToolResult::error(
359 TASK_OUTPUT_TOOL_ID,
360 "Missing caller thread context",
361 ));
362 };
363
364 let Some(task_store) = &self.task_store else {
365 if let Some(summary) = self.manager.get(&thread_id, task_id).await {
366 return Ok(ToolResult::success(
367 TASK_OUTPUT_TOOL_ID,
368 json!({
369 "task_id": task_id,
370 "task_type": summary.task_type,
371 "status": summary.status.as_str(),
372 "output": summary.result,
373 }),
374 ));
375 }
376 return Ok(ToolResult::error(
377 TASK_OUTPUT_TOOL_ID,
378 format!("Unknown task_id: {task_id}"),
379 ));
380 };
381
382 let Some(task) = task_store
383 .load_task_for_owner(&thread_id, task_id)
384 .await
385 .map_err(|e| ToolError::ExecutionFailed(format!("task store lookup failed: {e}")))?
386 else {
387 return Ok(ToolResult::error(
388 TASK_OUTPUT_TOOL_ID,
389 format!("Unknown task_id: {task_id}"),
390 ));
391 };
392
393 Ok(self.output_from_task(task_id, &task).await)
394 }
395}
396
397impl TaskOutputTool {
398 async fn output_from_task(&self, task_id: &str, task: &TaskState) -> ToolResult {
399 let live = self
400 .manager
401 .get(&task.owner_thread_id, task_id)
402 .await
403 .filter(|summary| summary.status == task.status || task.status == TaskStatus::Running);
404
405 let output = if task.task_type == "agent_run" {
406 match &self.task_store {
407 Some(store) => match store.load_output_text(task).await {
408 Ok(output) => output.map(Value::String),
409 Err(e) => {
410 return ToolResult::error(TASK_OUTPUT_TOOL_ID, e.to_string());
411 }
412 },
413 None => None,
414 }
415 } else {
416 live.and_then(|summary| summary.result)
417 .or_else(|| task.result.clone())
418 };
419
420 ToolResult::success(
421 TASK_OUTPUT_TOOL_ID,
422 json!({
423 "task_id": task_id,
424 "task_type": task.task_type.clone(),
425 "agent_id": task.metadata.get("agent_id").cloned().unwrap_or(Value::Null),
426 "status": task.status.as_str(),
427 "output": output,
428 }),
429 )
430 }
431}
432
433#[cfg(test)]
434mod tests {
435 use super::*;
436 use crate::contracts::runtime::tool_call::{CallerContext, Tool};
437 use crate::contracts::storage::{
438 Committed, MessagePage, MessageQuery, RunPage, RunQuery, RunRecord, ThreadHead,
439 ThreadListPage, ThreadListQuery, ThreadReader, ThreadStore, ThreadStoreError, ThreadWriter,
440 VersionPrecondition,
441 };
442 use crate::contracts::thread::{Thread, ThreadChangeSet};
443 use crate::runtime::background_tasks::SpawnParams;
444 use async_trait::async_trait;
445 use std::sync::atomic::{AtomicUsize, Ordering};
446 use tirea_contract::testing::TestFixture;
447
448 fn fixture_with_thread(thread_id: &str) -> TestFixture {
449 let mut fix = TestFixture::new();
450 fix.caller_context = CallerContext::new(
451 Some(thread_id.to_string()),
452 Some("caller-run".to_string()),
453 Some("caller-agent".to_string()),
454 vec![],
455 );
456 fix
457 }
458
459 struct FailingCancelMarkStore {
460 inner: Arc<tirea_store_adapters::MemoryStore>,
461 fail_task_appends: AtomicUsize,
462 }
463
464 #[async_trait]
465 impl ThreadReader for FailingCancelMarkStore {
466 async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
467 self.inner.load(thread_id).await
468 }
469
470 async fn list_threads(
471 &self,
472 query: &ThreadListQuery,
473 ) -> Result<ThreadListPage, ThreadStoreError> {
474 self.inner.list_threads(query).await
475 }
476
477 async fn load_messages(
478 &self,
479 thread_id: &str,
480 query: &MessageQuery,
481 ) -> Result<MessagePage, ThreadStoreError> {
482 self.inner.load_messages(thread_id, query).await
483 }
484
485 async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, ThreadStoreError> {
486 self.inner.load_run(run_id).await
487 }
488
489 async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, ThreadStoreError> {
490 self.inner.list_runs(query).await
491 }
492
493 async fn active_run_for_thread(
494 &self,
495 thread_id: &str,
496 ) -> Result<Option<RunRecord>, ThreadStoreError> {
497 self.inner.active_run_for_thread(thread_id).await
498 }
499 }
500
501 #[async_trait]
502 impl ThreadWriter for FailingCancelMarkStore {
503 async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
504 self.inner.create(thread).await
505 }
506
507 async fn append(
508 &self,
509 thread_id: &str,
510 delta: &ThreadChangeSet,
511 precondition: VersionPrecondition,
512 ) -> Result<Committed, ThreadStoreError> {
513 if thread_id.starts_with(super::super::TASK_THREAD_PREFIX)
514 && self
515 .fail_task_appends
516 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
517 if remaining > 0 {
518 Some(remaining - 1)
519 } else {
520 None
521 }
522 })
523 .is_ok()
524 {
525 return Err(ThreadStoreError::Io(std::io::Error::other(
526 "injected cancel mark persistence failure",
527 )));
528 }
529
530 self.inner.append(thread_id, delta, precondition).await
531 }
532
533 async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
534 self.inner.delete(thread_id).await
535 }
536
537 async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
538 self.inner.save(thread).await
539 }
540 }
541
542 #[test]
543 fn task_status_descriptor_has_optional_task_id() {
544 let mgr = Arc::new(BackgroundTaskManager::new());
545 let tool = TaskStatusTool::new(mgr);
546 let desc = tool.descriptor();
547 assert_eq!(desc.id, TASK_STATUS_TOOL_ID);
548 let required = desc.parameters.get("required");
550 assert!(required.is_none() || required.unwrap().as_array().unwrap().is_empty());
551 assert!(desc.parameters["properties"].get("task_id").is_some());
552 }
553
554 #[test]
555 fn task_cancel_descriptor_requires_task_id() {
556 let mgr = Arc::new(BackgroundTaskManager::new());
557 let tool = TaskCancelTool::new(mgr);
558 let desc = tool.descriptor();
559 assert_eq!(desc.id, TASK_CANCEL_TOOL_ID);
560 let required = desc.parameters["required"].as_array().unwrap();
561 assert!(required.contains(&json!("task_id")));
562 }
563
564 #[tokio::test]
569 async fn status_tool_missing_thread_context_returns_error() {
570 let mgr = Arc::new(BackgroundTaskManager::new());
571 let tool = TaskStatusTool::new(mgr);
572 let fix = TestFixture::new(); let result = tool.execute(json!({}), &fix.ctx()).await.unwrap();
574 assert!(!result.is_success());
575 assert!(result
576 .message
577 .as_deref()
578 .unwrap_or("")
579 .contains("Missing caller thread context"));
580 }
581
582 #[tokio::test]
583 async fn status_tool_list_all_when_no_tasks() {
584 let mgr = Arc::new(BackgroundTaskManager::new());
585 let tool = TaskStatusTool::new(mgr);
586 let fix = fixture_with_thread("thread-1");
587 let result = tool.execute(json!({}), &fix.ctx()).await.unwrap();
588 assert!(result.is_success());
589 let content: Value = result.data.clone();
590 assert_eq!(content["total"], 0);
591 assert!(content["tasks"].as_array().unwrap().is_empty());
592 }
593
594 #[tokio::test]
595 async fn status_tool_query_single_task() {
596 let mgr = Arc::new(BackgroundTaskManager::new());
597 let tid = mgr
598 .spawn("thread-1", "shell", "echo hi", |_cancel| async {
599 super::super::types::TaskResult::Success(json!({"exit": 0}))
600 })
601 .await;
602 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
603
604 let tool = TaskStatusTool::new(mgr);
605 let fix = fixture_with_thread("thread-1");
606 let result = tool
607 .execute(json!({"task_id": tid}), &fix.ctx())
608 .await
609 .unwrap();
610 assert!(result.is_success());
611 let content: Value = result.data.clone();
612 assert_eq!(content["status"], "completed");
613 assert_eq!(content["result"]["exit"], 0);
614 }
615
616 #[tokio::test]
617 async fn status_tool_query_unknown_task_returns_error() {
618 let mgr = Arc::new(BackgroundTaskManager::new());
619 let tool = TaskStatusTool::new(mgr);
620 let fix = fixture_with_thread("thread-1");
621 let result = tool
622 .execute(json!({"task_id": "bogus"}), &fix.ctx())
623 .await
624 .unwrap();
625 assert!(!result.is_success());
626 assert!(result
627 .message
628 .as_deref()
629 .unwrap_or("")
630 .contains("Unknown task_id"));
631 }
632
633 #[tokio::test]
634 async fn status_tool_list_shows_running_and_completed() {
635 let mgr = Arc::new(BackgroundTaskManager::new());
636 let _running = mgr
638 .spawn("thread-1", "shell", "long", |cancel| async move {
639 cancel.cancelled().await;
640 super::super::types::TaskResult::Cancelled
641 })
642 .await;
643 mgr.spawn("thread-1", "http", "fetch", |_| async {
645 super::super::types::TaskResult::Success(Value::Null)
646 })
647 .await;
648 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
649
650 let tool = TaskStatusTool::new(mgr);
651 let fix = fixture_with_thread("thread-1");
652 let result = tool.execute(json!({}), &fix.ctx()).await.unwrap();
653 assert!(result.is_success());
654 let content: Value = result.data.clone();
655 assert_eq!(content["total"], 2);
656 }
657
658 #[tokio::test]
659 async fn status_tool_thread_isolation() {
660 let mgr = Arc::new(BackgroundTaskManager::new());
661 let tid = mgr
662 .spawn("thread-A", "shell", "private", |_| async {
663 super::super::types::TaskResult::Success(Value::Null)
664 })
665 .await;
666 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
667
668 let tool = TaskStatusTool::new(mgr);
669
670 let fix_b = fixture_with_thread("thread-B");
672 let result = tool
673 .execute(json!({"task_id": tid}), &fix_b.ctx())
674 .await
675 .unwrap();
676 assert!(!result.is_success());
677
678 let fix_a = fixture_with_thread("thread-A");
680 let result = tool
681 .execute(json!({"task_id": tid}), &fix_a.ctx())
682 .await
683 .unwrap();
684 assert!(result.is_success());
685 }
686
687 #[tokio::test]
688 async fn status_tool_reads_persisted_task_without_live_manager() {
689 let mgr = Arc::new(BackgroundTaskManager::new());
690 let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
691 let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
692 task_store
693 .create_task(super::super::NewTaskSpec {
694 task_id: "task-1".to_string(),
695 owner_thread_id: "thread-1".to_string(),
696 task_type: "shell".to_string(),
697 description: "persisted only".to_string(),
698 parent_task_id: None,
699 supports_resume: false,
700 metadata: json!({}),
701 })
702 .await
703 .unwrap();
704 task_store
705 .persist_foreground_result(
706 "task-1",
707 TaskStatus::Completed,
708 None,
709 Some(json!({"stdout":"done"})),
710 )
711 .await
712 .unwrap();
713
714 let tool = TaskStatusTool::new(mgr).with_task_store(Some(task_store));
715 let fix = fixture_with_thread("thread-1");
716 let result = tool
717 .execute(json!({"task_id": "task-1"}), &fix.ctx())
718 .await
719 .unwrap();
720
721 assert!(result.is_success());
722 assert_eq!(result.data["status"], "completed");
723 assert_eq!(result.data["result"]["stdout"], "done");
724 }
725
726 #[tokio::test]
727 async fn status_tool_does_not_read_cached_derived_view() {
728 let mgr = Arc::new(BackgroundTaskManager::new());
729 let tool = TaskStatusTool::new(mgr);
730 let mut fix = TestFixture::new_with_state(json!({
731 "__derived": {
732 "background_tasks": {
733 "tasks": {
734 "ghost": {
735 "task_type": "shell",
736 "description": "ghost task",
737 "status": "running"
738 }
739 },
740 "synced_at_ms": 1
741 }
742 }
743 }));
744 fix.caller_context = CallerContext::new(
745 Some("thread-1".to_string()),
746 Some("caller-run".to_string()),
747 Some("caller-agent".to_string()),
748 vec![],
749 );
750
751 let result = tool
752 .execute(json!({"task_id": "ghost"}), &fix.ctx())
753 .await
754 .unwrap();
755 assert!(!result.is_success());
756 assert!(result
757 .message
758 .as_deref()
759 .unwrap_or("")
760 .contains("Unknown task_id"));
761 }
762
763 #[tokio::test]
768 async fn cancel_tool_missing_thread_context_returns_error() {
769 let mgr = Arc::new(BackgroundTaskManager::new());
770 let tool = TaskCancelTool::new(mgr);
771 let fix = TestFixture::new();
772 let result = tool
773 .execute(json!({"task_id": "some"}), &fix.ctx())
774 .await
775 .unwrap();
776 assert!(!result.is_success());
777 assert!(result
778 .message
779 .as_deref()
780 .unwrap_or("")
781 .contains("Missing caller thread context"));
782 }
783
784 #[tokio::test]
785 async fn cancel_tool_missing_task_id_param() {
786 let mgr = Arc::new(BackgroundTaskManager::new());
787 let tool = TaskCancelTool::new(mgr);
788 let fix = fixture_with_thread("thread-1");
789 let err = tool.execute(json!({}), &fix.ctx()).await.unwrap_err();
790 assert!(matches!(err, ToolError::InvalidArguments(_)));
791 }
792
793 #[tokio::test]
794 async fn cancel_tool_cancels_running_task() {
795 let mgr = Arc::new(BackgroundTaskManager::new());
796 let tid = mgr
797 .spawn("thread-1", "shell", "long", |cancel| async move {
798 cancel.cancelled().await;
799 super::super::types::TaskResult::Cancelled
800 })
801 .await;
802
803 let tool = TaskCancelTool::new(mgr.clone());
804 let fix = fixture_with_thread("thread-1");
805 let result = tool
806 .execute(json!({"task_id": tid}), &fix.ctx())
807 .await
808 .unwrap();
809 assert!(result.is_success());
810 let content: Value = result.data.clone();
811 assert!(content["cancelled"].as_bool().unwrap());
812
813 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
814 let summary = mgr.get("thread-1", &tid).await.unwrap();
815 assert_eq!(summary.status, super::super::types::TaskStatus::Cancelled);
816 }
817
818 #[tokio::test]
819 async fn cancel_tool_cancels_descendants_by_default() {
820 let mgr = Arc::new(BackgroundTaskManager::new());
821 let root_token = crate::runtime::loop_runner::RunCancellationToken::new();
822 let child_token = crate::runtime::loop_runner::RunCancellationToken::new();
823
824 mgr.spawn_with_id(
825 SpawnParams {
826 task_id: "root".to_string(),
827 owner_thread_id: "thread-1".to_string(),
828 task_type: "agent_run".to_string(),
829 description: "agent:root".to_string(),
830 parent_task_id: None,
831 metadata: json!({}),
832 },
833 root_token,
834 |cancel| async move {
835 cancel.cancelled().await;
836 super::super::types::TaskResult::Cancelled
837 },
838 )
839 .await;
840
841 mgr.spawn_with_id(
842 SpawnParams {
843 task_id: "child".to_string(),
844 owner_thread_id: "thread-1".to_string(),
845 task_type: "agent_run".to_string(),
846 description: "agent:child".to_string(),
847 parent_task_id: Some("root".to_string()),
848 metadata: json!({}),
849 },
850 child_token,
851 |cancel| async move {
852 cancel.cancelled().await;
853 super::super::types::TaskResult::Cancelled
854 },
855 )
856 .await;
857
858 let tool = TaskCancelTool::new(mgr.clone());
859 let fix = fixture_with_thread("thread-1");
860 let result = tool
861 .execute(json!({"task_id": "root"}), &fix.ctx())
862 .await
863 .unwrap();
864
865 assert!(result.is_success());
866 assert_eq!(result.data["cancelled_count"], 2);
867 assert!(result.data["cancelled_ids"]
868 .as_array()
869 .unwrap()
870 .iter()
871 .any(|v| v == "root"));
872 assert!(result.data["cancelled_ids"]
873 .as_array()
874 .unwrap()
875 .iter()
876 .any(|v| v == "child"));
877
878 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
879 assert_eq!(
880 mgr.get("thread-1", "root").await.unwrap().status,
881 super::super::types::TaskStatus::Cancelled
882 );
883 assert_eq!(
884 mgr.get("thread-1", "child").await.unwrap().status,
885 super::super::types::TaskStatus::Cancelled
886 );
887 }
888
889 #[tokio::test]
890 async fn cancel_tool_marks_cancel_requested_in_task_store() {
891 let mgr = Arc::new(BackgroundTaskManager::new());
892 let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
893 let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
894 task_store
895 .create_task(super::super::NewTaskSpec {
896 task_id: "task-1".to_string(),
897 owner_thread_id: "thread-1".to_string(),
898 task_type: "shell".to_string(),
899 description: "long task".to_string(),
900 parent_task_id: None,
901 supports_resume: false,
902 metadata: json!({}),
903 })
904 .await
905 .unwrap();
906
907 mgr.spawn_with_id(
908 SpawnParams {
909 task_id: "task-1".to_string(),
910 owner_thread_id: "thread-1".to_string(),
911 task_type: "shell".to_string(),
912 description: "long task".to_string(),
913 parent_task_id: None,
914 metadata: json!({}),
915 },
916 crate::runtime::loop_runner::RunCancellationToken::new(),
917 |cancel| async move {
918 cancel.cancelled().await;
919 super::super::types::TaskResult::Cancelled
920 },
921 )
922 .await;
923
924 let tool = TaskCancelTool::new(mgr).with_task_store(Some(task_store.clone()));
925 let fix = fixture_with_thread("thread-1");
926 let result = tool
927 .execute(json!({"task_id": "task-1"}), &fix.ctx())
928 .await
929 .unwrap();
930 assert!(result.is_success());
931
932 let task = task_store
933 .load_task("task-1")
934 .await
935 .unwrap()
936 .expect("task should exist");
937 assert!(task.cancel_requested_at_ms.is_some());
938 }
939
940 #[tokio::test]
941 async fn cancel_tool_returns_warning_when_cancel_mark_persistence_fails() {
942 let mgr = Arc::new(BackgroundTaskManager::new());
943 let storage = Arc::new(FailingCancelMarkStore {
944 inner: Arc::new(tirea_store_adapters::MemoryStore::new()),
945 fail_task_appends: AtomicUsize::new(0),
946 });
947 let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
948 task_store
949 .create_task(super::super::NewTaskSpec {
950 task_id: "task-1".to_string(),
951 owner_thread_id: "thread-1".to_string(),
952 task_type: "shell".to_string(),
953 description: "long task".to_string(),
954 parent_task_id: None,
955 supports_resume: false,
956 metadata: json!({}),
957 })
958 .await
959 .unwrap();
960
961 mgr.spawn_with_id(
962 SpawnParams {
963 task_id: "task-1".to_string(),
964 owner_thread_id: "thread-1".to_string(),
965 task_type: "shell".to_string(),
966 description: "long task".to_string(),
967 parent_task_id: None,
968 metadata: json!({}),
969 },
970 crate::runtime::loop_runner::RunCancellationToken::new(),
971 |cancel| async move {
972 cancel.cancelled().await;
973 super::super::types::TaskResult::Cancelled
974 },
975 )
976 .await;
977
978 storage.fail_task_appends.store(1, Ordering::SeqCst);
979
980 let tool = TaskCancelTool::new(mgr.clone()).with_task_store(Some(task_store.clone()));
981 let fix = fixture_with_thread("thread-1");
982 let result = tool
983 .execute(json!({"task_id": "task-1"}), &fix.ctx())
984 .await
985 .unwrap();
986
987 assert!(matches!(
988 result.status,
989 crate::contracts::runtime::tool_call::ToolStatus::Warning
990 ));
991 assert_eq!(result.data["cancelled"], json!(true));
992 assert_eq!(result.data["persistence_warning"]["failed_count"], json!(1));
993 assert_eq!(
994 result.data["persistence_warning"]["failures"][0]["task_id"],
995 json!("task-1")
996 );
997 assert!(result
998 .message
999 .as_deref()
1000 .unwrap_or("")
1001 .contains("failed to persist cancellation markers"));
1002
1003 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1004 let summary = mgr.get("thread-1", "task-1").await.unwrap();
1005 assert_eq!(summary.status, super::super::types::TaskStatus::Cancelled);
1006
1007 let task = task_store
1008 .load_task("task-1")
1009 .await
1010 .unwrap()
1011 .expect("task should exist");
1012 assert!(
1013 task.cancel_requested_at_ms.is_none(),
1014 "failed durable mark should not mutate persisted cancel_requested timestamp"
1015 );
1016 }
1017
1018 #[tokio::test]
1019 async fn cancel_tool_marks_cancel_requested_for_descendant_tree() {
1020 let mgr = Arc::new(BackgroundTaskManager::new());
1021 let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
1022 let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
1023
1024 for (task_id, parent_task_id) in [("root", None), ("child", Some("root"))] {
1025 task_store
1026 .create_task(super::super::NewTaskSpec {
1027 task_id: task_id.to_string(),
1028 owner_thread_id: "thread-1".to_string(),
1029 task_type: "agent_run".to_string(),
1030 description: format!("agent:{task_id}"),
1031 parent_task_id: parent_task_id.map(str::to_string),
1032 supports_resume: true,
1033 metadata: json!({}),
1034 })
1035 .await
1036 .unwrap();
1037 }
1038
1039 for (task_id, parent_task_id) in [("root", None), ("child", Some("root"))] {
1040 mgr.spawn_with_id(
1041 SpawnParams {
1042 task_id: task_id.to_string(),
1043 owner_thread_id: "thread-1".to_string(),
1044 task_type: "agent_run".to_string(),
1045 description: format!("agent:{task_id}"),
1046 parent_task_id: parent_task_id.map(str::to_string),
1047 metadata: json!({}),
1048 },
1049 crate::runtime::loop_runner::RunCancellationToken::new(),
1050 |cancel| async move {
1051 cancel.cancelled().await;
1052 super::super::types::TaskResult::Cancelled
1053 },
1054 )
1055 .await;
1056 }
1057
1058 let tool = TaskCancelTool::new(mgr).with_task_store(Some(task_store.clone()));
1059 let fix = fixture_with_thread("thread-1");
1060 let result = tool
1061 .execute(json!({"task_id": "root"}), &fix.ctx())
1062 .await
1063 .unwrap();
1064 assert!(result.is_success());
1065 assert_eq!(result.data["cancelled_count"], 2);
1066
1067 for task_id in ["root", "child"] {
1068 let task = task_store
1069 .load_task(task_id)
1070 .await
1071 .unwrap()
1072 .expect("task should exist");
1073 assert!(
1074 task.cancel_requested_at_ms.is_some(),
1075 "expected durable cancel_requested mark for {task_id}"
1076 );
1077 }
1078 }
1079
1080 #[tokio::test]
1081 async fn cancel_tool_unknown_task_returns_error() {
1082 let mgr = Arc::new(BackgroundTaskManager::new());
1083 let tool = TaskCancelTool::new(mgr);
1084 let fix = fixture_with_thread("thread-1");
1085 let result = tool
1086 .execute(json!({"task_id": "nope"}), &fix.ctx())
1087 .await
1088 .unwrap();
1089 assert!(!result.is_success());
1090 assert!(result
1091 .message
1092 .as_deref()
1093 .unwrap_or("")
1094 .contains("Unknown task_id"));
1095 }
1096
1097 #[tokio::test]
1098 async fn cancel_tool_already_completed_returns_error() {
1099 let mgr = Arc::new(BackgroundTaskManager::new());
1100 let tid = mgr
1101 .spawn("thread-1", "shell", "done", |_| async {
1102 super::super::types::TaskResult::Success(Value::Null)
1103 })
1104 .await;
1105 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1106
1107 let tool = TaskCancelTool::new(mgr);
1108 let fix = fixture_with_thread("thread-1");
1109 let result = tool
1110 .execute(json!({"task_id": tid}), &fix.ctx())
1111 .await
1112 .unwrap();
1113 assert!(!result.is_success());
1114 assert!(result
1115 .message
1116 .as_deref()
1117 .unwrap_or("")
1118 .contains("not running"));
1119 }
1120
1121 #[tokio::test]
1122 async fn cancel_tool_wrong_owner_rejected() {
1123 let mgr = Arc::new(BackgroundTaskManager::new());
1124 let tid = mgr
1125 .spawn("thread-1", "shell", "private", |cancel| async move {
1126 cancel.cancelled().await;
1127 super::super::types::TaskResult::Cancelled
1128 })
1129 .await;
1130
1131 let tool = TaskCancelTool::new(mgr);
1132 let fix = fixture_with_thread("thread-2");
1133 let result = tool
1134 .execute(json!({"task_id": tid}), &fix.ctx())
1135 .await
1136 .unwrap();
1137 assert!(!result.is_success());
1138 }
1139
1140 #[test]
1145 fn output_tool_descriptor_requires_task_id() {
1146 let mgr = Arc::new(BackgroundTaskManager::new());
1147 let tool = TaskOutputTool::new(mgr, None);
1148 let desc = tool.descriptor();
1149 assert_eq!(desc.id, TASK_OUTPUT_TOOL_ID);
1150 let required = desc.parameters["required"].as_array().unwrap();
1151 assert!(required.contains(&json!("task_id")));
1152 }
1153
1154 #[tokio::test]
1155 async fn output_tool_missing_task_id_returns_error() {
1156 let mgr = Arc::new(BackgroundTaskManager::new());
1157 let tool = TaskOutputTool::new(mgr, None);
1158 let fix = fixture_with_thread("thread-1");
1159 let err = tool.execute(json!({}), &fix.ctx()).await.unwrap_err();
1160 assert!(matches!(err, ToolError::InvalidArguments(_)));
1161 }
1162
1163 #[tokio::test]
1164 async fn output_tool_unknown_task_returns_error() {
1165 let mgr = Arc::new(BackgroundTaskManager::new());
1166 let tool = TaskOutputTool::new(mgr, None);
1167 let fix = fixture_with_thread("thread-1");
1168 let result = tool
1169 .execute(json!({"task_id": "nonexistent"}), &fix.ctx())
1170 .await
1171 .unwrap();
1172 assert!(!result.is_success());
1173 assert!(result
1174 .message
1175 .as_deref()
1176 .unwrap_or("")
1177 .contains("Unknown task_id"));
1178 }
1179
1180 #[tokio::test]
1181 async fn output_tool_returns_result_for_non_agent_task() {
1182 let mgr = Arc::new(BackgroundTaskManager::new());
1183 let tid = mgr
1184 .spawn("thread-1", "shell", "echo hi", |_| async {
1185 super::super::types::TaskResult::Success(json!({"exit_code": 0, "stdout": "hi"}))
1186 })
1187 .await;
1188 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1189
1190 let tool = TaskOutputTool::new(mgr, None);
1191 let fix = fixture_with_thread("thread-1");
1192 let result = tool
1193 .execute(json!({"task_id": tid}), &fix.ctx())
1194 .await
1195 .unwrap();
1196 assert!(result.is_success());
1197 assert_eq!(result.data["task_type"], "shell");
1198 assert_eq!(result.data["status"], "completed");
1199 assert_eq!(result.data["output"]["exit_code"], 0);
1200 assert_eq!(result.data["output"]["stdout"], "hi");
1201 }
1202
1203 #[tokio::test]
1204 async fn output_tool_reads_persisted_state_from_task_store() {
1205 let mgr = Arc::new(BackgroundTaskManager::new());
1206 let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
1207 let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
1208 task_store
1209 .create_task(super::super::NewTaskSpec {
1210 task_id: "run-1".to_string(),
1211 owner_thread_id: "thread-1".to_string(),
1212 task_type: "shell".to_string(),
1213 description: "echo test".to_string(),
1214 parent_task_id: None,
1215 supports_resume: false,
1216 metadata: json!({}),
1217 })
1218 .await
1219 .unwrap();
1220 task_store
1221 .persist_foreground_result(
1222 "run-1",
1223 TaskStatus::Completed,
1224 None,
1225 Some(json!({"stdout":"test"})),
1226 )
1227 .await
1228 .unwrap();
1229
1230 let tool = TaskOutputTool::new(mgr, None).with_task_store(Some(task_store));
1231 let fix = fixture_with_thread("thread-1");
1232 let result = tool
1233 .execute(json!({"task_id": "run-1"}), &fix.ctx())
1234 .await
1235 .unwrap();
1236 assert!(result.is_success());
1237 assert_eq!(result.data["task_type"], "shell");
1238 assert_eq!(result.data["status"], "completed");
1239 assert_eq!(result.data["output"]["stdout"], "test");
1240 }
1241
1242 #[tokio::test]
1243 async fn output_tool_without_task_store_cannot_read_persisted_task() {
1244 let mgr = Arc::new(BackgroundTaskManager::new());
1245 let tool = TaskOutputTool::new(mgr, None);
1246 let fix = fixture_with_thread("thread-1");
1247 let result = tool
1248 .execute(json!({"task_id": "run-1"}), &fix.ctx())
1249 .await
1250 .unwrap();
1251 assert!(!result.is_success());
1252 assert!(result
1253 .message
1254 .as_deref()
1255 .unwrap_or("")
1256 .contains("Unknown task_id"));
1257 }
1258
1259 #[tokio::test]
1260 async fn output_tool_does_not_read_cached_derived_view() {
1261 let mgr = Arc::new(BackgroundTaskManager::new());
1262 let tool = TaskOutputTool::new(mgr, None);
1263 let mut fix = TestFixture::new_with_state(json!({
1264 "__derived": {
1265 "background_tasks": {
1266 "tasks": {
1267 "ghost": {
1268 "task_type": "shell",
1269 "description": "ghost task",
1270 "status": "running"
1271 }
1272 },
1273 "synced_at_ms": 1
1274 }
1275 }
1276 }));
1277 fix.caller_context = CallerContext::new(
1278 Some("thread-1".to_string()),
1279 Some("caller-run".to_string()),
1280 Some("caller-agent".to_string()),
1281 vec![],
1282 );
1283
1284 let result = tool
1285 .execute(json!({"task_id": "ghost"}), &fix.ctx())
1286 .await
1287 .unwrap();
1288 assert!(!result.is_success());
1289 assert!(result
1290 .message
1291 .as_deref()
1292 .unwrap_or("")
1293 .contains("Unknown task_id"));
1294 }
1295
1296 #[tokio::test]
1297 async fn output_tool_thread_isolation() {
1298 let mgr = Arc::new(BackgroundTaskManager::new());
1299 let tid = mgr
1300 .spawn("thread-A", "shell", "private", |_| async {
1301 super::super::types::TaskResult::Success(json!("secret"))
1302 })
1303 .await;
1304 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1305
1306 let tool = TaskOutputTool::new(mgr, None);
1307
1308 let fix_b = fixture_with_thread("thread-B");
1310 let result = tool
1311 .execute(json!({"task_id": tid}), &fix_b.ctx())
1312 .await
1313 .unwrap();
1314 assert!(!result.is_success());
1315
1316 let fix_a = fixture_with_thread("thread-A");
1318 let result = tool
1319 .execute(json!({"task_id": tid}), &fix_a.ctx())
1320 .await
1321 .unwrap();
1322 assert!(result.is_success());
1323 }
1324}