tirea_agentos/runtime/background_tasks/
tools.rs

1//! Built-in tools for querying and managing background tasks.
2//!
3//! These tools provide a unified interface for the LLM to check status,
4//! read output, and cancel any background task regardless of type.
5
6use super::manager::BackgroundTaskManager;
7use super::{TaskState, TaskStatus, TaskStore, TaskSummary};
8use crate::contracts::runtime::tool_call::{ToolCallContext, ToolError, ToolResult};
9use crate::contracts::storage::ThreadStore;
10use async_trait::async_trait;
11use schemars::JsonSchema;
12use serde::Deserialize;
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use std::sync::Arc;
16
/// Tool id under which [`TaskStatusTool`] is registered.
pub const TASK_STATUS_TOOL_ID: &str = "task_status";
/// Tool id under which [`TaskCancelTool`] is registered.
pub const TASK_CANCEL_TOOL_ID: &str = "task_cancel";
/// Tool id under which [`TaskOutputTool`] is registered.
pub const TASK_OUTPUT_TOOL_ID: &str = "task_output";
20
21fn owner_thread_id(ctx: &ToolCallContext<'_>) -> Option<String> {
22    ctx.caller_context().thread_id().map(str::to_string)
23}
24
25// ---------------------------------------------------------------------------
26// task_status
27// ---------------------------------------------------------------------------
28
29/// Query background task status and result.
30///
31/// Supports querying a single task by `task_id` or listing all tasks.
32#[derive(Debug, Clone)]
33pub struct TaskStatusTool {
34    manager: Arc<BackgroundTaskManager>,
35    task_store: Option<Arc<TaskStore>>,
36}
37
38impl TaskStatusTool {
39    pub fn new(manager: Arc<BackgroundTaskManager>) -> Self {
40        Self {
41            manager,
42            task_store: None,
43        }
44    }
45
46    pub fn with_task_store(mut self, task_store: Option<Arc<TaskStore>>) -> Self {
47        self.task_store = task_store;
48        self
49    }
50
51    async fn query_one(
52        &self,
53        owner_thread_id: &str,
54        task_id: &str,
55    ) -> Result<Option<TaskSummary>, String> {
56        let persisted = if let Some(store) = &self.task_store {
57            store
58                .load_task_for_owner(owner_thread_id, task_id)
59                .await
60                .map_err(|e| e.to_string())?
61                .map(|task| task.summary())
62        } else {
63            None
64        };
65        let live = self.manager.get(owner_thread_id, task_id).await;
66
67        Ok(match (persisted, live) {
68            (_, Some(live)) => Some(live),
69            (Some(task), None) => Some(task),
70            (None, None) => None,
71        })
72    }
73
74    async fn list_all(&self, owner_thread_id: &str) -> Result<Vec<TaskSummary>, String> {
75        let mut by_id: HashMap<String, TaskSummary> = HashMap::new();
76
77        if let Some(store) = &self.task_store {
78            let tasks = store
79                .list_tasks_for_owner(owner_thread_id)
80                .await
81                .map_err(|e| e.to_string())?;
82            for task in tasks {
83                by_id.insert(task.id.clone(), task.summary());
84            }
85        }
86
87        for summary in self.manager.list(owner_thread_id, None).await {
88            by_id.insert(summary.task_id.clone(), summary);
89        }
90
91        let mut out: Vec<TaskSummary> = by_id.into_values().collect();
92        out.sort_by(|a, b| {
93            a.created_at_ms
94                .cmp(&b.created_at_ms)
95                .then_with(|| a.task_id.cmp(&b.task_id))
96        });
97        Ok(out)
98    }
99}
100
// Arguments accepted by the `task_status` tool.
// NOTE(review): `///` doc comments on this struct and its fields are presumably
// read by the `JsonSchema` derive as schema descriptions shown to the model, so
// struct-level notes are kept as plain `//` comments to avoid changing the schema.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskStatusArgs {
    /// Task ID to query. Omit to list all tasks.
    task_id: Option<String>,
}
106
107#[async_trait]
108impl crate::contracts::runtime::tool_call::TypedTool for TaskStatusTool {
109    type Args = TaskStatusArgs;
110
111    fn tool_id(&self) -> &str {
112        TASK_STATUS_TOOL_ID
113    }
114    fn name(&self) -> &str {
115        "Task Status"
116    }
117    fn description(&self) -> &str {
118        "Check the status and result of background tasks. \
119         Provide task_id to query a specific task, or omit to list all tasks."
120    }
121
122    async fn execute(
123        &self,
124        args: TaskStatusArgs,
125        ctx: &ToolCallContext<'_>,
126    ) -> Result<ToolResult, ToolError> {
127        let Some(thread_id) = owner_thread_id(ctx) else {
128            return Ok(ToolResult::error(
129                TASK_STATUS_TOOL_ID,
130                "Missing caller thread context",
131            ));
132        };
133
134        let task_id = args.task_id.as_deref().filter(|s| !s.trim().is_empty());
135
136        if let Some(task_id) = task_id {
137            match self.query_one(&thread_id, task_id).await {
138                Ok(Some(summary)) => Ok(ToolResult::success(
139                    TASK_STATUS_TOOL_ID,
140                    serde_json::to_value(&summary).unwrap_or(Value::Null),
141                )),
142                Ok(None) => Ok(ToolResult::error(
143                    TASK_STATUS_TOOL_ID,
144                    format!("Unknown task_id: {task_id}"),
145                )),
146                Err(err) => Ok(ToolResult::error(TASK_STATUS_TOOL_ID, err)),
147            }
148        } else {
149            match self.list_all(&thread_id).await {
150                Ok(tasks) => Ok(ToolResult::success(
151                    TASK_STATUS_TOOL_ID,
152                    json!({
153                        "tasks": serde_json::to_value(&tasks).unwrap_or(Value::Null),
154                        "total": tasks.len(),
155                    }),
156                )),
157                Err(err) => Ok(ToolResult::error(TASK_STATUS_TOOL_ID, err)),
158            }
159        }
160    }
161}
162
163// ---------------------------------------------------------------------------
164// task_cancel
165// ---------------------------------------------------------------------------
166
167/// Cancel a running background task and any descendant tasks.
168#[derive(Debug, Clone)]
169pub struct TaskCancelTool {
170    manager: Arc<BackgroundTaskManager>,
171    task_store: Option<Arc<TaskStore>>,
172}
173
174impl TaskCancelTool {
175    pub fn new(manager: Arc<BackgroundTaskManager>) -> Self {
176        Self {
177            manager,
178            task_store: None,
179        }
180    }
181
182    pub fn with_task_store(mut self, task_store: Option<Arc<TaskStore>>) -> Self {
183        self.task_store = task_store;
184        self
185    }
186}
187
// Arguments accepted by the `task_cancel` tool. Emptiness is checked in the
// tool's `validate`, not here.
// NOTE(review): `///` comments on this struct are presumably picked up by the
// `JsonSchema` derive as schema descriptions, so notes stay as `//` comments.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskCancelArgs {
    /// The task ID to cancel.
    task_id: String,
}
193
194#[async_trait]
195impl crate::contracts::runtime::tool_call::TypedTool for TaskCancelTool {
196    type Args = TaskCancelArgs;
197
198    fn tool_id(&self) -> &str {
199        TASK_CANCEL_TOOL_ID
200    }
201    fn name(&self) -> &str {
202        "Task Cancel"
203    }
204    fn description(&self) -> &str {
205        "Cancel a running background task by task_id. \
206         Descendant tasks are cancelled automatically."
207    }
208
209    fn validate(&self, args: &Self::Args) -> Result<(), String> {
210        if args.task_id.trim().is_empty() {
211            return Err("task_id cannot be empty".to_string());
212        }
213        Ok(())
214    }
215
216    async fn execute(
217        &self,
218        args: TaskCancelArgs,
219        ctx: &ToolCallContext<'_>,
220    ) -> Result<ToolResult, ToolError> {
221        let task_id = &args.task_id;
222
223        let Some(thread_id) = owner_thread_id(ctx) else {
224            return Ok(ToolResult::error(
225                TASK_CANCEL_TOOL_ID,
226                "Missing caller thread context",
227            ));
228        };
229
230        match self.manager.cancel_tree(&thread_id, task_id).await {
231            Ok(cancelled) => {
232                let mut persistence_failures = Vec::new();
233                if let Some(store) = &self.task_store {
234                    for summary in &cancelled {
235                        if let Err(error) = store.mark_cancel_requested(&summary.task_id).await {
236                            tracing::warn!(
237                                root_task_id = %task_id,
238                                cancelled_task_id = %summary.task_id,
239                                owner_thread_id = %thread_id,
240                                error = %error,
241                                "failed to persist background task cancellation marker"
242                            );
243                            persistence_failures.push((summary.task_id.clone(), error.to_string()));
244                        }
245                    }
246                }
247                let ids: Vec<&str> = cancelled.iter().map(|s| s.task_id.as_str()).collect();
248                let mut data = json!({
249                    "task_id": task_id,
250                    "cancelled": true,
251                    "cancelled_ids": ids,
252                    "cancelled_count": cancelled.len(),
253                });
254                if persistence_failures.is_empty() {
255                    return Ok(ToolResult::success(TASK_CANCEL_TOOL_ID, data));
256                }
257
258                data["persistence_warning"] = json!({
259                    "failed_count": persistence_failures.len(),
260                    "failures": persistence_failures.iter().map(|(task_id, error)| json!({
261                        "task_id": task_id,
262                        "error": error,
263                    })).collect::<Vec<_>>(),
264                });
265
266                Ok(ToolResult::warning(
267                    TASK_CANCEL_TOOL_ID,
268                    data,
269                    format!(
270                        "Cancellation requested, but failed to persist cancellation markers for {} task(s).",
271                        persistence_failures.len()
272                    ),
273                ))
274            }
275            Err(e) => Ok(ToolResult::error(TASK_CANCEL_TOOL_ID, e)),
276        }
277    }
278}
279
280// ---------------------------------------------------------------------------
281// task_output
282// ---------------------------------------------------------------------------
283
284/// Read the output of a background task.
285///
286/// For `agent_run` tasks, returns the last assistant message from the sub-agent
287/// thread. For other task types, returns the task result from the manager.
288/// Reads durable task state from task threads and overlays live in-memory
289/// results when available.
290#[derive(Clone)]
291pub struct TaskOutputTool {
292    manager: Arc<BackgroundTaskManager>,
293    task_store: Option<Arc<TaskStore>>,
294}
295
296impl std::fmt::Debug for TaskOutputTool {
297    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298        f.debug_struct("TaskOutputTool")
299            .field("has_task_store", &self.task_store.is_some())
300            .finish()
301    }
302}
303
304impl TaskOutputTool {
305    pub fn new(
306        manager: Arc<BackgroundTaskManager>,
307        thread_store: Option<Arc<dyn ThreadStore>>,
308    ) -> Self {
309        Self {
310            manager,
311            task_store: thread_store.map(TaskStore::new).map(Arc::new),
312        }
313    }
314
315    pub fn with_task_store(mut self, task_store: Option<Arc<TaskStore>>) -> Self {
316        self.task_store = task_store;
317        self
318    }
319}
320
// Arguments accepted by the `task_output` tool. Emptiness is checked in the
// tool's `validate`, not here.
// NOTE(review): `///` comments on this struct are presumably picked up by the
// `JsonSchema` derive as schema descriptions, so notes stay as `//` comments.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskOutputArgs {
    /// The task ID to read output from.
    task_id: String,
}
326
327#[async_trait]
328impl crate::contracts::runtime::tool_call::TypedTool for TaskOutputTool {
329    type Args = TaskOutputArgs;
330
331    fn tool_id(&self) -> &str {
332        TASK_OUTPUT_TOOL_ID
333    }
334    fn name(&self) -> &str {
335        "Task Output"
336    }
337    fn description(&self) -> &str {
338        "Read the output of a background task. \
339         For agent runs, returns the last assistant message. \
340         For other tasks, returns the task result."
341    }
342
343    fn validate(&self, args: &Self::Args) -> Result<(), String> {
344        if args.task_id.trim().is_empty() {
345            return Err("task_id cannot be empty".to_string());
346        }
347        Ok(())
348    }
349
350    async fn execute(
351        &self,
352        args: TaskOutputArgs,
353        ctx: &ToolCallContext<'_>,
354    ) -> Result<ToolResult, ToolError> {
355        let task_id = &args.task_id;
356
357        let Some(thread_id) = owner_thread_id(ctx) else {
358            return Ok(ToolResult::error(
359                TASK_OUTPUT_TOOL_ID,
360                "Missing caller thread context",
361            ));
362        };
363
364        let Some(task_store) = &self.task_store else {
365            if let Some(summary) = self.manager.get(&thread_id, task_id).await {
366                return Ok(ToolResult::success(
367                    TASK_OUTPUT_TOOL_ID,
368                    json!({
369                        "task_id": task_id,
370                        "task_type": summary.task_type,
371                        "status": summary.status.as_str(),
372                        "output": summary.result,
373                    }),
374                ));
375            }
376            return Ok(ToolResult::error(
377                TASK_OUTPUT_TOOL_ID,
378                format!("Unknown task_id: {task_id}"),
379            ));
380        };
381
382        let Some(task) = task_store
383            .load_task_for_owner(&thread_id, task_id)
384            .await
385            .map_err(|e| ToolError::ExecutionFailed(format!("task store lookup failed: {e}")))?
386        else {
387            return Ok(ToolResult::error(
388                TASK_OUTPUT_TOOL_ID,
389                format!("Unknown task_id: {task_id}"),
390            ));
391        };
392
393        Ok(self.output_from_task(task_id, &task).await)
394    }
395}
396
397impl TaskOutputTool {
398    async fn output_from_task(&self, task_id: &str, task: &TaskState) -> ToolResult {
399        let live = self
400            .manager
401            .get(&task.owner_thread_id, task_id)
402            .await
403            .filter(|summary| summary.status == task.status || task.status == TaskStatus::Running);
404
405        let output = if task.task_type == "agent_run" {
406            match &self.task_store {
407                Some(store) => match store.load_output_text(task).await {
408                    Ok(output) => output.map(Value::String),
409                    Err(e) => {
410                        return ToolResult::error(TASK_OUTPUT_TOOL_ID, e.to_string());
411                    }
412                },
413                None => None,
414            }
415        } else {
416            live.and_then(|summary| summary.result)
417                .or_else(|| task.result.clone())
418        };
419
420        ToolResult::success(
421            TASK_OUTPUT_TOOL_ID,
422            json!({
423                "task_id": task_id,
424                "task_type": task.task_type.clone(),
425                "agent_id": task.metadata.get("agent_id").cloned().unwrap_or(Value::Null),
426                "status": task.status.as_str(),
427                "output": output,
428            }),
429        )
430    }
431}
432
433#[cfg(test)]
434mod tests {
435    use super::*;
436    use crate::contracts::runtime::tool_call::{CallerContext, Tool};
437    use crate::contracts::storage::{
438        Committed, MessagePage, MessageQuery, RunPage, RunQuery, RunRecord, ThreadHead,
439        ThreadListPage, ThreadListQuery, ThreadReader, ThreadStore, ThreadStoreError, ThreadWriter,
440        VersionPrecondition,
441    };
442    use crate::contracts::thread::{Thread, ThreadChangeSet};
443    use crate::runtime::background_tasks::SpawnParams;
444    use async_trait::async_trait;
445    use std::sync::atomic::{AtomicUsize, Ordering};
446    use tirea_contract::testing::TestFixture;
447
448    fn fixture_with_thread(thread_id: &str) -> TestFixture {
449        let mut fix = TestFixture::new();
450        fix.caller_context = CallerContext::new(
451            Some(thread_id.to_string()),
452            Some("caller-run".to_string()),
453            Some("caller-agent".to_string()),
454            vec![],
455        );
456        fix
457    }
458
459    struct FailingCancelMarkStore {
460        inner: Arc<tirea_store_adapters::MemoryStore>,
461        fail_task_appends: AtomicUsize,
462    }
463
464    #[async_trait]
465    impl ThreadReader for FailingCancelMarkStore {
466        async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
467            self.inner.load(thread_id).await
468        }
469
470        async fn list_threads(
471            &self,
472            query: &ThreadListQuery,
473        ) -> Result<ThreadListPage, ThreadStoreError> {
474            self.inner.list_threads(query).await
475        }
476
477        async fn load_messages(
478            &self,
479            thread_id: &str,
480            query: &MessageQuery,
481        ) -> Result<MessagePage, ThreadStoreError> {
482            self.inner.load_messages(thread_id, query).await
483        }
484
485        async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, ThreadStoreError> {
486            self.inner.load_run(run_id).await
487        }
488
489        async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, ThreadStoreError> {
490            self.inner.list_runs(query).await
491        }
492
493        async fn active_run_for_thread(
494            &self,
495            thread_id: &str,
496        ) -> Result<Option<RunRecord>, ThreadStoreError> {
497            self.inner.active_run_for_thread(thread_id).await
498        }
499    }
500
501    #[async_trait]
502    impl ThreadWriter for FailingCancelMarkStore {
503        async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
504            self.inner.create(thread).await
505        }
506
507        async fn append(
508            &self,
509            thread_id: &str,
510            delta: &ThreadChangeSet,
511            precondition: VersionPrecondition,
512        ) -> Result<Committed, ThreadStoreError> {
513            if thread_id.starts_with(super::super::TASK_THREAD_PREFIX)
514                && self
515                    .fail_task_appends
516                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
517                        if remaining > 0 {
518                            Some(remaining - 1)
519                        } else {
520                            None
521                        }
522                    })
523                    .is_ok()
524            {
525                return Err(ThreadStoreError::Io(std::io::Error::other(
526                    "injected cancel mark persistence failure",
527                )));
528            }
529
530            self.inner.append(thread_id, delta, precondition).await
531        }
532
533        async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
534            self.inner.delete(thread_id).await
535        }
536
537        async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
538            self.inner.save(thread).await
539        }
540    }
541
542    #[test]
543    fn task_status_descriptor_has_optional_task_id() {
544        let mgr = Arc::new(BackgroundTaskManager::new());
545        let tool = TaskStatusTool::new(mgr);
546        let desc = tool.descriptor();
547        assert_eq!(desc.id, TASK_STATUS_TOOL_ID);
548        // task_id is not in "required" — it's Option<String>.
549        let required = desc.parameters.get("required");
550        assert!(required.is_none() || required.unwrap().as_array().unwrap().is_empty());
551        assert!(desc.parameters["properties"].get("task_id").is_some());
552    }
553
554    #[test]
555    fn task_cancel_descriptor_requires_task_id() {
556        let mgr = Arc::new(BackgroundTaskManager::new());
557        let tool = TaskCancelTool::new(mgr);
558        let desc = tool.descriptor();
559        assert_eq!(desc.id, TASK_CANCEL_TOOL_ID);
560        let required = desc.parameters["required"].as_array().unwrap();
561        assert!(required.contains(&json!("task_id")));
562    }
563
564    // -----------------------------------------------------------------------
565    // TaskStatusTool execute() tests
566    // -----------------------------------------------------------------------
567
568    #[tokio::test]
569    async fn status_tool_missing_thread_context_returns_error() {
570        let mgr = Arc::new(BackgroundTaskManager::new());
571        let tool = TaskStatusTool::new(mgr);
572        let fix = TestFixture::new(); // no __agent_tool_caller_thread_id
573        let result = tool.execute(json!({}), &fix.ctx()).await.unwrap();
574        assert!(!result.is_success());
575        assert!(result
576            .message
577            .as_deref()
578            .unwrap_or("")
579            .contains("Missing caller thread context"));
580    }
581
582    #[tokio::test]
583    async fn status_tool_list_all_when_no_tasks() {
584        let mgr = Arc::new(BackgroundTaskManager::new());
585        let tool = TaskStatusTool::new(mgr);
586        let fix = fixture_with_thread("thread-1");
587        let result = tool.execute(json!({}), &fix.ctx()).await.unwrap();
588        assert!(result.is_success());
589        let content: Value = result.data.clone();
590        assert_eq!(content["total"], 0);
591        assert!(content["tasks"].as_array().unwrap().is_empty());
592    }
593
594    #[tokio::test]
595    async fn status_tool_query_single_task() {
596        let mgr = Arc::new(BackgroundTaskManager::new());
597        let tid = mgr
598            .spawn("thread-1", "shell", "echo hi", |_cancel| async {
599                super::super::types::TaskResult::Success(json!({"exit": 0}))
600            })
601            .await;
602        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
603
604        let tool = TaskStatusTool::new(mgr);
605        let fix = fixture_with_thread("thread-1");
606        let result = tool
607            .execute(json!({"task_id": tid}), &fix.ctx())
608            .await
609            .unwrap();
610        assert!(result.is_success());
611        let content: Value = result.data.clone();
612        assert_eq!(content["status"], "completed");
613        assert_eq!(content["result"]["exit"], 0);
614    }
615
616    #[tokio::test]
617    async fn status_tool_query_unknown_task_returns_error() {
618        let mgr = Arc::new(BackgroundTaskManager::new());
619        let tool = TaskStatusTool::new(mgr);
620        let fix = fixture_with_thread("thread-1");
621        let result = tool
622            .execute(json!({"task_id": "bogus"}), &fix.ctx())
623            .await
624            .unwrap();
625        assert!(!result.is_success());
626        assert!(result
627            .message
628            .as_deref()
629            .unwrap_or("")
630            .contains("Unknown task_id"));
631    }
632
633    #[tokio::test]
634    async fn status_tool_list_shows_running_and_completed() {
635        let mgr = Arc::new(BackgroundTaskManager::new());
636        // Running task.
637        let _running = mgr
638            .spawn("thread-1", "shell", "long", |cancel| async move {
639                cancel.cancelled().await;
640                super::super::types::TaskResult::Cancelled
641            })
642            .await;
643        // Completed task.
644        mgr.spawn("thread-1", "http", "fetch", |_| async {
645            super::super::types::TaskResult::Success(Value::Null)
646        })
647        .await;
648        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
649
650        let tool = TaskStatusTool::new(mgr);
651        let fix = fixture_with_thread("thread-1");
652        let result = tool.execute(json!({}), &fix.ctx()).await.unwrap();
653        assert!(result.is_success());
654        let content: Value = result.data.clone();
655        assert_eq!(content["total"], 2);
656    }
657
658    #[tokio::test]
659    async fn status_tool_thread_isolation() {
660        let mgr = Arc::new(BackgroundTaskManager::new());
661        let tid = mgr
662            .spawn("thread-A", "shell", "private", |_| async {
663                super::super::types::TaskResult::Success(Value::Null)
664            })
665            .await;
666        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
667
668        let tool = TaskStatusTool::new(mgr);
669
670        // Thread-B cannot see thread-A's task.
671        let fix_b = fixture_with_thread("thread-B");
672        let result = tool
673            .execute(json!({"task_id": tid}), &fix_b.ctx())
674            .await
675            .unwrap();
676        assert!(!result.is_success());
677
678        // Thread-A can see it.
679        let fix_a = fixture_with_thread("thread-A");
680        let result = tool
681            .execute(json!({"task_id": tid}), &fix_a.ctx())
682            .await
683            .unwrap();
684        assert!(result.is_success());
685    }
686
687    #[tokio::test]
688    async fn status_tool_reads_persisted_task_without_live_manager() {
689        let mgr = Arc::new(BackgroundTaskManager::new());
690        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
691        let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
692        task_store
693            .create_task(super::super::NewTaskSpec {
694                task_id: "task-1".to_string(),
695                owner_thread_id: "thread-1".to_string(),
696                task_type: "shell".to_string(),
697                description: "persisted only".to_string(),
698                parent_task_id: None,
699                supports_resume: false,
700                metadata: json!({}),
701            })
702            .await
703            .unwrap();
704        task_store
705            .persist_foreground_result(
706                "task-1",
707                TaskStatus::Completed,
708                None,
709                Some(json!({"stdout":"done"})),
710            )
711            .await
712            .unwrap();
713
714        let tool = TaskStatusTool::new(mgr).with_task_store(Some(task_store));
715        let fix = fixture_with_thread("thread-1");
716        let result = tool
717            .execute(json!({"task_id": "task-1"}), &fix.ctx())
718            .await
719            .unwrap();
720
721        assert!(result.is_success());
722        assert_eq!(result.data["status"], "completed");
723        assert_eq!(result.data["result"]["stdout"], "done");
724    }
725
726    #[tokio::test]
727    async fn status_tool_does_not_read_cached_derived_view() {
728        let mgr = Arc::new(BackgroundTaskManager::new());
729        let tool = TaskStatusTool::new(mgr);
730        let mut fix = TestFixture::new_with_state(json!({
731            "__derived": {
732                "background_tasks": {
733                    "tasks": {
734                        "ghost": {
735                            "task_type": "shell",
736                            "description": "ghost task",
737                            "status": "running"
738                        }
739                    },
740                    "synced_at_ms": 1
741                }
742            }
743        }));
744        fix.caller_context = CallerContext::new(
745            Some("thread-1".to_string()),
746            Some("caller-run".to_string()),
747            Some("caller-agent".to_string()),
748            vec![],
749        );
750
751        let result = tool
752            .execute(json!({"task_id": "ghost"}), &fix.ctx())
753            .await
754            .unwrap();
755        assert!(!result.is_success());
756        assert!(result
757            .message
758            .as_deref()
759            .unwrap_or("")
760            .contains("Unknown task_id"));
761    }
762
763    // -----------------------------------------------------------------------
764    // TaskCancelTool execute() tests
765    // -----------------------------------------------------------------------
766
767    #[tokio::test]
768    async fn cancel_tool_missing_thread_context_returns_error() {
769        let mgr = Arc::new(BackgroundTaskManager::new());
770        let tool = TaskCancelTool::new(mgr);
771        let fix = TestFixture::new();
772        let result = tool
773            .execute(json!({"task_id": "some"}), &fix.ctx())
774            .await
775            .unwrap();
776        assert!(!result.is_success());
777        assert!(result
778            .message
779            .as_deref()
780            .unwrap_or("")
781            .contains("Missing caller thread context"));
782    }
783
784    #[tokio::test]
785    async fn cancel_tool_missing_task_id_param() {
786        let mgr = Arc::new(BackgroundTaskManager::new());
787        let tool = TaskCancelTool::new(mgr);
788        let fix = fixture_with_thread("thread-1");
789        let err = tool.execute(json!({}), &fix.ctx()).await.unwrap_err();
790        assert!(matches!(err, ToolError::InvalidArguments(_)));
791    }
792
793    #[tokio::test]
794    async fn cancel_tool_cancels_running_task() {
795        let mgr = Arc::new(BackgroundTaskManager::new());
796        let tid = mgr
797            .spawn("thread-1", "shell", "long", |cancel| async move {
798                cancel.cancelled().await;
799                super::super::types::TaskResult::Cancelled
800            })
801            .await;
802
803        let tool = TaskCancelTool::new(mgr.clone());
804        let fix = fixture_with_thread("thread-1");
805        let result = tool
806            .execute(json!({"task_id": tid}), &fix.ctx())
807            .await
808            .unwrap();
809        assert!(result.is_success());
810        let content: Value = result.data.clone();
811        assert!(content["cancelled"].as_bool().unwrap());
812
813        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
814        let summary = mgr.get("thread-1", &tid).await.unwrap();
815        assert_eq!(summary.status, super::super::types::TaskStatus::Cancelled);
816    }
817
818    #[tokio::test]
819    async fn cancel_tool_cancels_descendants_by_default() {
820        let mgr = Arc::new(BackgroundTaskManager::new());
821        let root_token = crate::runtime::loop_runner::RunCancellationToken::new();
822        let child_token = crate::runtime::loop_runner::RunCancellationToken::new();
823
824        mgr.spawn_with_id(
825            SpawnParams {
826                task_id: "root".to_string(),
827                owner_thread_id: "thread-1".to_string(),
828                task_type: "agent_run".to_string(),
829                description: "agent:root".to_string(),
830                parent_task_id: None,
831                metadata: json!({}),
832            },
833            root_token,
834            |cancel| async move {
835                cancel.cancelled().await;
836                super::super::types::TaskResult::Cancelled
837            },
838        )
839        .await;
840
841        mgr.spawn_with_id(
842            SpawnParams {
843                task_id: "child".to_string(),
844                owner_thread_id: "thread-1".to_string(),
845                task_type: "agent_run".to_string(),
846                description: "agent:child".to_string(),
847                parent_task_id: Some("root".to_string()),
848                metadata: json!({}),
849            },
850            child_token,
851            |cancel| async move {
852                cancel.cancelled().await;
853                super::super::types::TaskResult::Cancelled
854            },
855        )
856        .await;
857
858        let tool = TaskCancelTool::new(mgr.clone());
859        let fix = fixture_with_thread("thread-1");
860        let result = tool
861            .execute(json!({"task_id": "root"}), &fix.ctx())
862            .await
863            .unwrap();
864
865        assert!(result.is_success());
866        assert_eq!(result.data["cancelled_count"], 2);
867        assert!(result.data["cancelled_ids"]
868            .as_array()
869            .unwrap()
870            .iter()
871            .any(|v| v == "root"));
872        assert!(result.data["cancelled_ids"]
873            .as_array()
874            .unwrap()
875            .iter()
876            .any(|v| v == "child"));
877
878        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
879        assert_eq!(
880            mgr.get("thread-1", "root").await.unwrap().status,
881            super::super::types::TaskStatus::Cancelled
882        );
883        assert_eq!(
884            mgr.get("thread-1", "child").await.unwrap().status,
885            super::super::types::TaskStatus::Cancelled
886        );
887    }
888
889    #[tokio::test]
890    async fn cancel_tool_marks_cancel_requested_in_task_store() {
891        let mgr = Arc::new(BackgroundTaskManager::new());
892        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
893        let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
894        task_store
895            .create_task(super::super::NewTaskSpec {
896                task_id: "task-1".to_string(),
897                owner_thread_id: "thread-1".to_string(),
898                task_type: "shell".to_string(),
899                description: "long task".to_string(),
900                parent_task_id: None,
901                supports_resume: false,
902                metadata: json!({}),
903            })
904            .await
905            .unwrap();
906
907        mgr.spawn_with_id(
908            SpawnParams {
909                task_id: "task-1".to_string(),
910                owner_thread_id: "thread-1".to_string(),
911                task_type: "shell".to_string(),
912                description: "long task".to_string(),
913                parent_task_id: None,
914                metadata: json!({}),
915            },
916            crate::runtime::loop_runner::RunCancellationToken::new(),
917            |cancel| async move {
918                cancel.cancelled().await;
919                super::super::types::TaskResult::Cancelled
920            },
921        )
922        .await;
923
924        let tool = TaskCancelTool::new(mgr).with_task_store(Some(task_store.clone()));
925        let fix = fixture_with_thread("thread-1");
926        let result = tool
927            .execute(json!({"task_id": "task-1"}), &fix.ctx())
928            .await
929            .unwrap();
930        assert!(result.is_success());
931
932        let task = task_store
933            .load_task("task-1")
934            .await
935            .unwrap()
936            .expect("task should exist");
937        assert!(task.cancel_requested_at_ms.is_some());
938    }
939
940    #[tokio::test]
941    async fn cancel_tool_returns_warning_when_cancel_mark_persistence_fails() {
942        let mgr = Arc::new(BackgroundTaskManager::new());
943        let storage = Arc::new(FailingCancelMarkStore {
944            inner: Arc::new(tirea_store_adapters::MemoryStore::new()),
945            fail_task_appends: AtomicUsize::new(0),
946        });
947        let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
948        task_store
949            .create_task(super::super::NewTaskSpec {
950                task_id: "task-1".to_string(),
951                owner_thread_id: "thread-1".to_string(),
952                task_type: "shell".to_string(),
953                description: "long task".to_string(),
954                parent_task_id: None,
955                supports_resume: false,
956                metadata: json!({}),
957            })
958            .await
959            .unwrap();
960
961        mgr.spawn_with_id(
962            SpawnParams {
963                task_id: "task-1".to_string(),
964                owner_thread_id: "thread-1".to_string(),
965                task_type: "shell".to_string(),
966                description: "long task".to_string(),
967                parent_task_id: None,
968                metadata: json!({}),
969            },
970            crate::runtime::loop_runner::RunCancellationToken::new(),
971            |cancel| async move {
972                cancel.cancelled().await;
973                super::super::types::TaskResult::Cancelled
974            },
975        )
976        .await;
977
978        storage.fail_task_appends.store(1, Ordering::SeqCst);
979
980        let tool = TaskCancelTool::new(mgr.clone()).with_task_store(Some(task_store.clone()));
981        let fix = fixture_with_thread("thread-1");
982        let result = tool
983            .execute(json!({"task_id": "task-1"}), &fix.ctx())
984            .await
985            .unwrap();
986
987        assert!(matches!(
988            result.status,
989            crate::contracts::runtime::tool_call::ToolStatus::Warning
990        ));
991        assert_eq!(result.data["cancelled"], json!(true));
992        assert_eq!(result.data["persistence_warning"]["failed_count"], json!(1));
993        assert_eq!(
994            result.data["persistence_warning"]["failures"][0]["task_id"],
995            json!("task-1")
996        );
997        assert!(result
998            .message
999            .as_deref()
1000            .unwrap_or("")
1001            .contains("failed to persist cancellation markers"));
1002
1003        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1004        let summary = mgr.get("thread-1", "task-1").await.unwrap();
1005        assert_eq!(summary.status, super::super::types::TaskStatus::Cancelled);
1006
1007        let task = task_store
1008            .load_task("task-1")
1009            .await
1010            .unwrap()
1011            .expect("task should exist");
1012        assert!(
1013            task.cancel_requested_at_ms.is_none(),
1014            "failed durable mark should not mutate persisted cancel_requested timestamp"
1015        );
1016    }
1017
1018    #[tokio::test]
1019    async fn cancel_tool_marks_cancel_requested_for_descendant_tree() {
1020        let mgr = Arc::new(BackgroundTaskManager::new());
1021        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
1022        let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
1023
1024        for (task_id, parent_task_id) in [("root", None), ("child", Some("root"))] {
1025            task_store
1026                .create_task(super::super::NewTaskSpec {
1027                    task_id: task_id.to_string(),
1028                    owner_thread_id: "thread-1".to_string(),
1029                    task_type: "agent_run".to_string(),
1030                    description: format!("agent:{task_id}"),
1031                    parent_task_id: parent_task_id.map(str::to_string),
1032                    supports_resume: true,
1033                    metadata: json!({}),
1034                })
1035                .await
1036                .unwrap();
1037        }
1038
1039        for (task_id, parent_task_id) in [("root", None), ("child", Some("root"))] {
1040            mgr.spawn_with_id(
1041                SpawnParams {
1042                    task_id: task_id.to_string(),
1043                    owner_thread_id: "thread-1".to_string(),
1044                    task_type: "agent_run".to_string(),
1045                    description: format!("agent:{task_id}"),
1046                    parent_task_id: parent_task_id.map(str::to_string),
1047                    metadata: json!({}),
1048                },
1049                crate::runtime::loop_runner::RunCancellationToken::new(),
1050                |cancel| async move {
1051                    cancel.cancelled().await;
1052                    super::super::types::TaskResult::Cancelled
1053                },
1054            )
1055            .await;
1056        }
1057
1058        let tool = TaskCancelTool::new(mgr).with_task_store(Some(task_store.clone()));
1059        let fix = fixture_with_thread("thread-1");
1060        let result = tool
1061            .execute(json!({"task_id": "root"}), &fix.ctx())
1062            .await
1063            .unwrap();
1064        assert!(result.is_success());
1065        assert_eq!(result.data["cancelled_count"], 2);
1066
1067        for task_id in ["root", "child"] {
1068            let task = task_store
1069                .load_task(task_id)
1070                .await
1071                .unwrap()
1072                .expect("task should exist");
1073            assert!(
1074                task.cancel_requested_at_ms.is_some(),
1075                "expected durable cancel_requested mark for {task_id}"
1076            );
1077        }
1078    }
1079
1080    #[tokio::test]
1081    async fn cancel_tool_unknown_task_returns_error() {
1082        let mgr = Arc::new(BackgroundTaskManager::new());
1083        let tool = TaskCancelTool::new(mgr);
1084        let fix = fixture_with_thread("thread-1");
1085        let result = tool
1086            .execute(json!({"task_id": "nope"}), &fix.ctx())
1087            .await
1088            .unwrap();
1089        assert!(!result.is_success());
1090        assert!(result
1091            .message
1092            .as_deref()
1093            .unwrap_or("")
1094            .contains("Unknown task_id"));
1095    }
1096
1097    #[tokio::test]
1098    async fn cancel_tool_already_completed_returns_error() {
1099        let mgr = Arc::new(BackgroundTaskManager::new());
1100        let tid = mgr
1101            .spawn("thread-1", "shell", "done", |_| async {
1102                super::super::types::TaskResult::Success(Value::Null)
1103            })
1104            .await;
1105        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1106
1107        let tool = TaskCancelTool::new(mgr);
1108        let fix = fixture_with_thread("thread-1");
1109        let result = tool
1110            .execute(json!({"task_id": tid}), &fix.ctx())
1111            .await
1112            .unwrap();
1113        assert!(!result.is_success());
1114        assert!(result
1115            .message
1116            .as_deref()
1117            .unwrap_or("")
1118            .contains("not running"));
1119    }
1120
1121    #[tokio::test]
1122    async fn cancel_tool_wrong_owner_rejected() {
1123        let mgr = Arc::new(BackgroundTaskManager::new());
1124        let tid = mgr
1125            .spawn("thread-1", "shell", "private", |cancel| async move {
1126                cancel.cancelled().await;
1127                super::super::types::TaskResult::Cancelled
1128            })
1129            .await;
1130
1131        let tool = TaskCancelTool::new(mgr);
1132        let fix = fixture_with_thread("thread-2");
1133        let result = tool
1134            .execute(json!({"task_id": tid}), &fix.ctx())
1135            .await
1136            .unwrap();
1137        assert!(!result.is_success());
1138    }
1139
1140    // -----------------------------------------------------------------------
1141    // TaskOutputTool execute() tests
1142    // -----------------------------------------------------------------------
1143
1144    #[test]
1145    fn output_tool_descriptor_requires_task_id() {
1146        let mgr = Arc::new(BackgroundTaskManager::new());
1147        let tool = TaskOutputTool::new(mgr, None);
1148        let desc = tool.descriptor();
1149        assert_eq!(desc.id, TASK_OUTPUT_TOOL_ID);
1150        let required = desc.parameters["required"].as_array().unwrap();
1151        assert!(required.contains(&json!("task_id")));
1152    }
1153
1154    #[tokio::test]
1155    async fn output_tool_missing_task_id_returns_error() {
1156        let mgr = Arc::new(BackgroundTaskManager::new());
1157        let tool = TaskOutputTool::new(mgr, None);
1158        let fix = fixture_with_thread("thread-1");
1159        let err = tool.execute(json!({}), &fix.ctx()).await.unwrap_err();
1160        assert!(matches!(err, ToolError::InvalidArguments(_)));
1161    }
1162
1163    #[tokio::test]
1164    async fn output_tool_unknown_task_returns_error() {
1165        let mgr = Arc::new(BackgroundTaskManager::new());
1166        let tool = TaskOutputTool::new(mgr, None);
1167        let fix = fixture_with_thread("thread-1");
1168        let result = tool
1169            .execute(json!({"task_id": "nonexistent"}), &fix.ctx())
1170            .await
1171            .unwrap();
1172        assert!(!result.is_success());
1173        assert!(result
1174            .message
1175            .as_deref()
1176            .unwrap_or("")
1177            .contains("Unknown task_id"));
1178    }
1179
1180    #[tokio::test]
1181    async fn output_tool_returns_result_for_non_agent_task() {
1182        let mgr = Arc::new(BackgroundTaskManager::new());
1183        let tid = mgr
1184            .spawn("thread-1", "shell", "echo hi", |_| async {
1185                super::super::types::TaskResult::Success(json!({"exit_code": 0, "stdout": "hi"}))
1186            })
1187            .await;
1188        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1189
1190        let tool = TaskOutputTool::new(mgr, None);
1191        let fix = fixture_with_thread("thread-1");
1192        let result = tool
1193            .execute(json!({"task_id": tid}), &fix.ctx())
1194            .await
1195            .unwrap();
1196        assert!(result.is_success());
1197        assert_eq!(result.data["task_type"], "shell");
1198        assert_eq!(result.data["status"], "completed");
1199        assert_eq!(result.data["output"]["exit_code"], 0);
1200        assert_eq!(result.data["output"]["stdout"], "hi");
1201    }
1202
1203    #[tokio::test]
1204    async fn output_tool_reads_persisted_state_from_task_store() {
1205        let mgr = Arc::new(BackgroundTaskManager::new());
1206        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
1207        let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
1208        task_store
1209            .create_task(super::super::NewTaskSpec {
1210                task_id: "run-1".to_string(),
1211                owner_thread_id: "thread-1".to_string(),
1212                task_type: "shell".to_string(),
1213                description: "echo test".to_string(),
1214                parent_task_id: None,
1215                supports_resume: false,
1216                metadata: json!({}),
1217            })
1218            .await
1219            .unwrap();
1220        task_store
1221            .persist_foreground_result(
1222                "run-1",
1223                TaskStatus::Completed,
1224                None,
1225                Some(json!({"stdout":"test"})),
1226            )
1227            .await
1228            .unwrap();
1229
1230        let tool = TaskOutputTool::new(mgr, None).with_task_store(Some(task_store));
1231        let fix = fixture_with_thread("thread-1");
1232        let result = tool
1233            .execute(json!({"task_id": "run-1"}), &fix.ctx())
1234            .await
1235            .unwrap();
1236        assert!(result.is_success());
1237        assert_eq!(result.data["task_type"], "shell");
1238        assert_eq!(result.data["status"], "completed");
1239        assert_eq!(result.data["output"]["stdout"], "test");
1240    }
1241
1242    #[tokio::test]
1243    async fn output_tool_without_task_store_cannot_read_persisted_task() {
1244        let mgr = Arc::new(BackgroundTaskManager::new());
1245        let tool = TaskOutputTool::new(mgr, None);
1246        let fix = fixture_with_thread("thread-1");
1247        let result = tool
1248            .execute(json!({"task_id": "run-1"}), &fix.ctx())
1249            .await
1250            .unwrap();
1251        assert!(!result.is_success());
1252        assert!(result
1253            .message
1254            .as_deref()
1255            .unwrap_or("")
1256            .contains("Unknown task_id"));
1257    }
1258
1259    #[tokio::test]
1260    async fn output_tool_does_not_read_cached_derived_view() {
1261        let mgr = Arc::new(BackgroundTaskManager::new());
1262        let tool = TaskOutputTool::new(mgr, None);
1263        let mut fix = TestFixture::new_with_state(json!({
1264            "__derived": {
1265                "background_tasks": {
1266                    "tasks": {
1267                        "ghost": {
1268                            "task_type": "shell",
1269                            "description": "ghost task",
1270                            "status": "running"
1271                        }
1272                    },
1273                    "synced_at_ms": 1
1274                }
1275            }
1276        }));
1277        fix.caller_context = CallerContext::new(
1278            Some("thread-1".to_string()),
1279            Some("caller-run".to_string()),
1280            Some("caller-agent".to_string()),
1281            vec![],
1282        );
1283
1284        let result = tool
1285            .execute(json!({"task_id": "ghost"}), &fix.ctx())
1286            .await
1287            .unwrap();
1288        assert!(!result.is_success());
1289        assert!(result
1290            .message
1291            .as_deref()
1292            .unwrap_or("")
1293            .contains("Unknown task_id"));
1294    }
1295
1296    #[tokio::test]
1297    async fn output_tool_thread_isolation() {
1298        let mgr = Arc::new(BackgroundTaskManager::new());
1299        let tid = mgr
1300            .spawn("thread-A", "shell", "private", |_| async {
1301                super::super::types::TaskResult::Success(json!("secret"))
1302            })
1303            .await;
1304        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1305
1306        let tool = TaskOutputTool::new(mgr, None);
1307
1308        // Thread-B cannot see thread-A's task.
1309        let fix_b = fixture_with_thread("thread-B");
1310        let result = tool
1311            .execute(json!({"task_id": tid}), &fix_b.ctx())
1312            .await
1313            .unwrap();
1314        assert!(!result.is_success());
1315
1316        // Thread-A can see it.
1317        let fix_a = fixture_with_thread("thread-A");
1318        let result = tool
1319            .execute(json!({"task_id": tid}), &fix_a.ctx())
1320            .await
1321            .unwrap();
1322        assert!(result.is_success());
1323    }
1324}