tirea_agentos/runtime/background_tasks/
manager.rs

1//! In-memory background task handle table and spawner.
2//!
3//! Manages the lifecycle of background tasks: spawn, track, cancel, query.
4//! Epoch-based stale-completion guards prevent double updates. When a durable
5//! [`TaskStore`] is configured, terminal summaries are persisted directly from
6//! the manager and persistence failures remain visible on the live handle until
7//! they are retried successfully.
8
9use super::store::TaskStore;
10use super::types::*;
11use crate::runtime::loop_runner::RunCancellationToken;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::future::Future;
15use std::sync::Arc;
16use std::time::{SystemTime, UNIX_EPOCH};
17use tokio::sync::Mutex;
18
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// A clock reading earlier than the epoch is treated as `0`, and a value that
/// does not fit in `u64` saturates at `u64::MAX`; this never panics.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
26
/// In-memory runtime handle for a single background task.
///
/// One entry per task ID lives in [`BackgroundTaskManager`]'s handle table;
/// public callers only ever see it converted into a `TaskSummary`.
#[derive(Debug)]
struct TaskHandle {
    /// Generation counter: starts at 1 and is incremented each time the same
    /// task ID is spawned again. Completion and persistence results carrying
    /// an older epoch are discarded as stale.
    epoch: u64,
    /// Owning thread ID; owner-scoped queries and cancellation filter on it.
    owner_thread_id: String,
    /// Caller-supplied task kind; `"agent_run"` is reported as resumable.
    task_type: String,
    /// Caller-supplied human-readable description.
    description: String,
    /// Current lifecycle status (`Running` when spawned).
    status: TaskStatus,
    /// Failure message from `TaskResult::Failed`, or one set via `update_status`.
    error: Option<String>,
    /// Payload from `TaskResult::Success`.
    result: Option<Value>,
    /// Token handed to the task closure; cancelled by `cancel` / `cancel_tree`.
    cancel_token: RunCancellationToken,
    /// Set when cancellation was requested through the manager; the task's
    /// completion is then recorded as `Cancelled` whatever it returned.
    cancellation_requested: bool,
    /// Spawn time in milliseconds since the Unix epoch.
    created_at_ms: u64,
    /// Time (ms since the Unix epoch) a terminal status was recorded.
    completed_at_ms: Option<u64>,
    /// Optional parent task, used by `cancel_tree` to find descendants.
    parent_task_id: Option<TaskId>,
    /// Caller-supplied JSON metadata (`{}` for tasks started via `spawn`).
    metadata: Value,
    /// Message from the last failed attempt to persist this task's terminal
    /// summary; `None` otherwise. Handles with `Some` survive `gc_terminal`
    /// so the write can be retried.
    persistence_error: Option<String>,
}
45
/// Thread-scoped background task manager.
///
/// Manages the full lifecycle: spawn → track → cancel → query.
/// Tasks outlive individual runs but are scoped to a thread, i.e. a logical
/// owner identified by an `owner_thread_id` string.
///
/// Cloning is cheap and every clone shares the same task table and store.
#[derive(Clone)]
pub struct BackgroundTaskManager {
    /// Task table keyed by task ID. Behind an `Arc` so spawned completion
    /// futures can update it after the spawning call has returned.
    handles: Arc<Mutex<HashMap<TaskId, TaskHandle>>>,
    /// Optional durable store that receives terminal task summaries.
    task_store: Option<Arc<TaskStore>>,
}
55
56impl std::fmt::Debug for BackgroundTaskManager {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("BackgroundTaskManager")
59            .field("task_count", &"<locked>")
60            .finish()
61    }
62}
63
64impl Default for BackgroundTaskManager {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl BackgroundTaskManager {
71    pub fn new() -> Self {
72        Self {
73            handles: Arc::new(Mutex::new(HashMap::new())),
74            task_store: None,
75        }
76    }
77
78    pub fn with_task_store(task_store: Option<Arc<TaskStore>>) -> Self {
79        Self {
80            handles: Arc::new(Mutex::new(HashMap::new())),
81            task_store,
82        }
83    }
84
85    /// Spawn a background task with the given closure.
86    ///
87    /// Returns the generated `TaskId` immediately. The closure receives a
88    /// `CancellationToken` for cooperative cancellation.
89    pub async fn spawn<F, Fut>(
90        &self,
91        owner_thread_id: &str,
92        task_type: &str,
93        description: &str,
94        task: F,
95    ) -> TaskId
96    where
97        F: FnOnce(RunCancellationToken) -> Fut + Send + 'static,
98        Fut: Future<Output = TaskResult> + Send,
99    {
100        let params = SpawnParams {
101            task_id: new_task_id(),
102            owner_thread_id: owner_thread_id.to_string(),
103            task_type: task_type.to_string(),
104            description: description.to_string(),
105            parent_task_id: None,
106            metadata: Value::Object(serde_json::Map::new()),
107        };
108        self.spawn_impl(params, None, task).await
109    }
110
111    /// Spawn a background task with a caller-supplied ID and cancellation token.
112    ///
113    /// Use this when the caller already owns an ID (e.g. a sub-agent `run_id`)
114    /// and/or a cancellation token that is shared with other subsystems.
115    pub async fn spawn_with_id<F, Fut>(
116        &self,
117        params: SpawnParams,
118        cancel_token: RunCancellationToken,
119        task: F,
120    ) -> TaskId
121    where
122        F: FnOnce(RunCancellationToken) -> Fut + Send + 'static,
123        Fut: Future<Output = TaskResult> + Send,
124    {
125        self.spawn_impl(params, Some(cancel_token), task).await
126    }
127
128    async fn spawn_impl<F, Fut>(
129        &self,
130        params: SpawnParams,
131        external_cancel_token: Option<RunCancellationToken>,
132        task: F,
133    ) -> TaskId
134    where
135        F: FnOnce(RunCancellationToken) -> Fut + Send + 'static,
136        Fut: Future<Output = TaskResult> + Send,
137    {
138        let SpawnParams {
139            task_id,
140            owner_thread_id,
141            task_type,
142            description,
143            parent_task_id,
144            metadata,
145        } = params;
146        let cancel_token = external_cancel_token.unwrap_or_default();
147        let now = now_ms();
148
149        let epoch = {
150            let mut handles = self.handles.lock().await;
151            let epoch = handles.get(&task_id).map(|h| h.epoch + 1).unwrap_or(1);
152            handles.insert(
153                task_id.clone(),
154                TaskHandle {
155                    epoch,
156                    owner_thread_id,
157                    task_type,
158                    description,
159                    status: TaskStatus::Running,
160                    error: None,
161                    result: None,
162                    cancel_token: cancel_token.clone(),
163                    cancellation_requested: false,
164                    created_at_ms: now,
165                    completed_at_ms: None,
166                    parent_task_id,
167                    metadata,
168                    persistence_error: None,
169                },
170            );
171            epoch
172        };
173
174        let handles = self.handles.clone();
175        let task_store = self.task_store.clone();
176        let tid = task_id.clone();
177
178        tokio::spawn(async move {
179            let result = task(cancel_token).await;
180
181            // Update in-memory handle.
182            let completed_at = now_ms();
183            {
184                let mut map = handles.lock().await;
185                if let Some(handle) = map.get_mut(&tid) {
186                    if handle.epoch != epoch {
187                        return; // Stale completion.
188                    }
189                    let status = if handle.cancellation_requested {
190                        TaskStatus::Cancelled
191                    } else {
192                        result.status()
193                    };
194                    handle.status = status;
195                    handle.error = match &result {
196                        TaskResult::Failed(e) => Some(e.clone()),
197                        _ => None,
198                    };
199                    handle.result = match &result {
200                        TaskResult::Success(v) => Some(v.clone()),
201                        _ => None,
202                    };
203                    handle.completed_at_ms = Some(completed_at);
204                }
205            }
206
207            if let Some(task_store) = task_store {
208                let maybe_summary = {
209                    let map = handles.lock().await;
210                    map.get(&tid)
211                        .filter(|handle| handle.epoch == epoch)
212                        .map(|handle| summary_from_handle(&tid, handle))
213                };
214                if let Some(summary) = maybe_summary {
215                    let persistence_error = task_store
216                        .persist_summary(&summary)
217                        .await
218                        .err()
219                        .map(|e| e.to_string());
220                    let mut map = handles.lock().await;
221                    if let Some(handle) = map.get_mut(&tid) {
222                        if handle.epoch != epoch {
223                            return;
224                        }
225                        handle.persistence_error = persistence_error;
226                    }
227                }
228            }
229        });
230
231        task_id
232    }
233
234    /// Cancel a task owned by the given thread. Returns `true` if cancelled.
235    pub async fn cancel(&self, owner_thread_id: &str, task_id: &str) -> Result<(), String> {
236        let mut handles = self.handles.lock().await;
237        let Some(handle) = handles.get_mut(task_id) else {
238            return Err(format!("Unknown task_id: {task_id}"));
239        };
240        if handle.owner_thread_id != owner_thread_id {
241            return Err(format!("Unknown task_id: {task_id}"));
242        }
243        if handle.status != TaskStatus::Running {
244            return Err(format!(
245                "Task '{task_id}' is not running (current status: {})",
246                handle.status.as_str()
247            ));
248        }
249        handle.cancellation_requested = true;
250        handle.cancel_token.cancel();
251        Ok(())
252    }
253
254    /// Get a summary of a single task.
255    pub async fn get(&self, owner_thread_id: &str, task_id: &str) -> Option<TaskSummary> {
256        self.retry_persistence(owner_thread_id, Some(task_id)).await;
257        let handles = self.handles.lock().await;
258        let handle = handles.get(task_id)?;
259        if handle.owner_thread_id != owner_thread_id {
260            return None;
261        }
262        Some(summary_from_handle(task_id, handle))
263    }
264
265    /// List all tasks for a thread, optionally filtered by status.
266    pub async fn list(
267        &self,
268        owner_thread_id: &str,
269        status_filter: Option<TaskStatus>,
270    ) -> Vec<TaskSummary> {
271        self.retry_persistence(owner_thread_id, None).await;
272        let handles = self.handles.lock().await;
273        let mut out: Vec<TaskSummary> = handles
274            .iter()
275            .filter(|(_, h)| h.owner_thread_id == owner_thread_id)
276            .filter(|(_, h)| status_filter.is_none_or(|s| s == h.status))
277            .map(|(id, h)| summary_from_handle(id, h))
278            .collect();
279        out.sort_by(|a, b| a.created_at_ms.cmp(&b.created_at_ms));
280        out
281    }
282
283    /// Check if there are running tasks for a thread.
284    pub async fn has_running_tasks(&self, owner_thread_id: &str) -> bool {
285        let handles = self.handles.lock().await;
286        handles
287            .values()
288            .any(|h| h.owner_thread_id == owner_thread_id && h.status == TaskStatus::Running)
289    }
290
291    /// Remove completed/terminal task entries from the in-memory table.
292    pub async fn gc_terminal(&self, owner_thread_id: &str) -> usize {
293        self.retry_persistence(owner_thread_id, None).await;
294        let mut handles = self.handles.lock().await;
295        let before = handles.len();
296        handles.retain(|_, h| {
297            h.owner_thread_id != owner_thread_id
298                || !h.status.is_terminal()
299                || h.persistence_error.is_some()
300        });
301        before - handles.len()
302    }
303
304    /// Check if a task exists for the given thread.
305    pub async fn contains(&self, owner_thread_id: &str, task_id: &str) -> bool {
306        let handles = self.handles.lock().await;
307        handles
308            .get(task_id)
309            .is_some_and(|h| h.owner_thread_id == owner_thread_id)
310    }
311
312    /// Check if a task exists in any thread.
313    pub async fn contains_any(&self, task_id: &str) -> bool {
314        let handles = self.handles.lock().await;
315        handles.contains_key(task_id)
316    }
317
318    /// Cancel a task and all its descendants (by `parent_task_id` chain).
319    ///
320    /// Returns summaries of all tasks that were cancelled. The root task must
321    /// be owned by `owner_thread_id`; descendants are found by traversing
322    /// `parent_task_id` links within the same owner.
323    pub async fn cancel_tree(
324        &self,
325        owner_thread_id: &str,
326        task_id: &str,
327    ) -> Result<Vec<TaskSummary>, String> {
328        let mut handles = self.handles.lock().await;
329        let Some(root) = handles.get(task_id) else {
330            return Err(format!("Unknown task_id: {task_id}"));
331        };
332        if root.owner_thread_id != owner_thread_id {
333            return Err(format!("Unknown task_id: {task_id}"));
334        }
335
336        // Collect the full descendant tree.
337        let tree_ids = collect_descendant_ids(&handles, owner_thread_id, task_id, true);
338        if tree_ids.is_empty() {
339            return Err(format!(
340                "Task '{task_id}' is not running (current status: {})",
341                root.status.as_str()
342            ));
343        }
344
345        let mut cancelled = false;
346        let mut out = Vec::with_capacity(tree_ids.len());
347        for id in tree_ids {
348            if let Some(handle) = handles.get_mut(&id) {
349                if handle.status == TaskStatus::Running {
350                    handle.cancellation_requested = true;
351                    handle.cancel_token.cancel();
352                    cancelled = true;
353                }
354                out.push(summary_from_handle(&id, handle));
355            }
356        }
357
358        if cancelled {
359            Ok(out)
360        } else {
361            Err(format!(
362                "Task '{task_id}' is not running (current status: {})",
363                handles
364                    .get(task_id)
365                    .map(|h| h.status.as_str())
366                    .unwrap_or("unknown")
367            ))
368        }
369    }
370
371    /// Externally update a task's status (e.g. recovery marking orphans as stopped).
372    pub async fn update_status(
373        &self,
374        task_id: &str,
375        status: TaskStatus,
376        error: Option<String>,
377    ) -> bool {
378        let mut handles = self.handles.lock().await;
379        if let Some(handle) = handles.get_mut(task_id) {
380            handle.status = status;
381            handle.error = error;
382            if status.is_terminal() {
383                handle.completed_at_ms = Some(now_ms());
384            }
385            true
386        } else {
387            false
388        }
389    }
390
391    /// List tasks filtered by `task_type`, optionally filtered by status.
392    pub async fn list_by_type(
393        &self,
394        owner_thread_id: &str,
395        task_type: &str,
396        status_filter: Option<TaskStatus>,
397    ) -> Vec<TaskSummary> {
398        let handles = self.handles.lock().await;
399        let mut out: Vec<TaskSummary> = handles
400            .iter()
401            .filter(|(_, h)| h.owner_thread_id == owner_thread_id)
402            .filter(|(_, h)| h.task_type == task_type)
403            .filter(|(_, h)| status_filter.is_none_or(|s| s == h.status))
404            .map(|(id, h)| summary_from_handle(id, h))
405            .collect();
406        out.sort_by(|a, b| a.created_at_ms.cmp(&b.created_at_ms));
407        out
408    }
409
410    async fn retry_persistence(&self, owner_thread_id: &str, task_id: Option<&str>) {
411        let Some(task_store) = self.task_store.as_ref().cloned() else {
412            return;
413        };
414
415        let candidates: Vec<(TaskId, u64, TaskSummary)> = {
416            let handles = self.handles.lock().await;
417            handles
418                .iter()
419                .filter(|(id, handle)| {
420                    handle.owner_thread_id == owner_thread_id
421                        && handle.status.is_terminal()
422                        && handle.persistence_error.is_some()
423                        && task_id.is_none_or(|wanted| wanted == id.as_str())
424                })
425                .map(|(id, handle)| (id.clone(), handle.epoch, summary_from_handle(id, handle)))
426                .collect()
427        };
428
429        for (task_id, epoch, summary) in candidates {
430            let persistence_error = task_store
431                .persist_summary(&summary)
432                .await
433                .err()
434                .map(|e| e.to_string());
435            let mut handles = self.handles.lock().await;
436            if let Some(handle) = handles.get_mut(&task_id) {
437                if handle.epoch != epoch || !handle.status.is_terminal() {
438                    continue;
439                }
440                handle.persistence_error = persistence_error;
441            }
442        }
443    }
444}
445
446/// Collect task IDs forming the subtree rooted at `root_id` via `parent_task_id` links.
447fn collect_descendant_ids(
448    handles: &HashMap<TaskId, TaskHandle>,
449    owner_thread_id: &str,
450    root_id: &str,
451    include_root: bool,
452) -> Vec<String> {
453    // Build children-by-parent index.
454    let mut children_by_parent: HashMap<&str, Vec<&str>> = HashMap::new();
455    for (id, h) in handles.iter() {
456        if h.owner_thread_id != owner_thread_id {
457            continue;
458        }
459        if let Some(parent) = h.parent_task_id.as_deref() {
460            children_by_parent
461                .entry(parent)
462                .or_default()
463                .push(id.as_str());
464        }
465    }
466
467    let mut result = Vec::new();
468    let mut stack = vec![root_id];
469    let mut is_root = true;
470    while let Some(current) = stack.pop() {
471        if !is_root || include_root {
472            // Only include running tasks (for cancel_tree semantics).
473            if handles
474                .get(current)
475                .is_some_and(|h| h.status == TaskStatus::Running)
476            {
477                result.push(current.to_string());
478            }
479        }
480        is_root = false;
481        if let Some(children) = children_by_parent.get(current) {
482            for child in children {
483                stack.push(child);
484            }
485        }
486    }
487    result
488}
489
490fn summary_from_handle(task_id: &str, handle: &TaskHandle) -> TaskSummary {
491    TaskSummary {
492        task_id: task_id.to_string(),
493        task_type: handle.task_type.clone(),
494        description: handle.description.clone(),
495        status: handle.status,
496        error: handle.error.clone(),
497        result: handle.result.clone(),
498        result_ref: None,
499        created_at_ms: handle.created_at_ms,
500        completed_at_ms: handle.completed_at_ms,
501        parent_task_id: handle.parent_task_id.clone(),
502        supports_resume: handle.task_type == "agent_run",
503        attempt: 0,
504        metadata: handle.metadata.clone(),
505        persistence_error: handle.persistence_error.clone(),
506    }
507}
508
509#[cfg(test)]
510mod tests {
511    use super::*;
512    use crate::contracts::storage::{
513        MessagePage, MessageQuery, RunPage, RunQuery, RunRecord, ThreadHead, ThreadListPage,
514        ThreadListQuery, ThreadReader, ThreadStore, ThreadStoreError, ThreadWriter,
515        VersionPrecondition,
516    };
517    use crate::contracts::thread::{Thread, ThreadChangeSet};
518    use std::sync::atomic::{AtomicUsize, Ordering};
519
520    struct FlakyTaskThreadStore {
521        inner: Arc<tirea_store_adapters::MemoryStore>,
522        fail_task_appends: AtomicUsize,
523    }
524
525    impl FlakyTaskThreadStore {
526        fn new(fail_task_appends: usize) -> Arc<Self> {
527            Arc::new(Self {
528                inner: Arc::new(tirea_store_adapters::MemoryStore::new()),
529                fail_task_appends: AtomicUsize::new(fail_task_appends),
530            })
531        }
532
533        fn set_failures(&self, failures: usize) {
534            self.fail_task_appends.store(failures, Ordering::SeqCst);
535        }
536
537        fn remaining_failures(&self) -> usize {
538            self.fail_task_appends.load(Ordering::SeqCst)
539        }
540    }
541
542    #[async_trait::async_trait]
543    impl ThreadReader for FlakyTaskThreadStore {
544        async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
545            self.inner.load(thread_id).await
546        }
547
548        async fn list_threads(
549            &self,
550            query: &ThreadListQuery,
551        ) -> Result<ThreadListPage, ThreadStoreError> {
552            self.inner.list_threads(query).await
553        }
554
555        async fn load_messages(
556            &self,
557            thread_id: &str,
558            query: &MessageQuery,
559        ) -> Result<MessagePage, ThreadStoreError> {
560            self.inner.load_messages(thread_id, query).await
561        }
562
563        async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, ThreadStoreError> {
564            self.inner.load_run(run_id).await
565        }
566
567        async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, ThreadStoreError> {
568            self.inner.list_runs(query).await
569        }
570
571        async fn active_run_for_thread(
572            &self,
573            thread_id: &str,
574        ) -> Result<Option<RunRecord>, ThreadStoreError> {
575            self.inner.active_run_for_thread(thread_id).await
576        }
577    }
578
579    #[async_trait::async_trait]
580    impl ThreadWriter for FlakyTaskThreadStore {
581        async fn create(
582            &self,
583            thread: &Thread,
584        ) -> Result<crate::contracts::storage::Committed, ThreadStoreError> {
585            self.inner.create(thread).await
586        }
587
588        async fn append(
589            &self,
590            thread_id: &str,
591            delta: &ThreadChangeSet,
592            precondition: VersionPrecondition,
593        ) -> Result<crate::contracts::storage::Committed, ThreadStoreError> {
594            if thread_id.starts_with(TASK_THREAD_PREFIX)
595                && self
596                    .fail_task_appends
597                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
598                        if remaining > 0 {
599                            Some(remaining - 1)
600                        } else {
601                            None
602                        }
603                    })
604                    .is_ok()
605            {
606                return Err(ThreadStoreError::Io(std::io::Error::other(
607                    "injected task persistence failure",
608                )));
609            }
610            self.inner.append(thread_id, delta, precondition).await
611        }
612
613        async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
614            self.inner.delete(thread_id).await
615        }
616
617        async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
618            self.inner.save(thread).await
619        }
620    }
621
622    #[tokio::test]
623    async fn spawn_and_complete_success() {
624        let mgr = BackgroundTaskManager::new();
625        let tid = mgr
626            .spawn("thread-1", "shell", "echo hello", |_cancel| async {
627                TaskResult::Success(serde_json::json!({ "exit_code": 0 }))
628            })
629            .await;
630
631        // Wait for the spawned task to complete.
632        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
633
634        let summary = mgr.get("thread-1", &tid).await.expect("task should exist");
635        assert_eq!(summary.status, TaskStatus::Completed);
636        assert!(summary.result.is_some());
637        assert!(summary.error.is_none());
638        assert!(summary.completed_at_ms.is_some());
639    }
640
641    #[tokio::test]
642    async fn spawn_and_complete_failure() {
643        let mgr = BackgroundTaskManager::new();
644        let tid = mgr
645            .spawn("thread-1", "shell", "bad cmd", |_cancel| async {
646                TaskResult::Failed("command not found".into())
647            })
648            .await;
649
650        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
651
652        let summary = mgr.get("thread-1", &tid).await.unwrap();
653        assert_eq!(summary.status, TaskStatus::Failed);
654        assert_eq!(summary.error.as_deref(), Some("command not found"));
655    }
656
657    #[tokio::test]
658    async fn cancel_running_task() {
659        let mgr = BackgroundTaskManager::new();
660        let tid = mgr
661            .spawn("thread-1", "shell", "long running", |cancel| async move {
662                cancel.cancelled().await;
663                TaskResult::Cancelled
664            })
665            .await;
666
667        // Task should be running.
668        let summary = mgr.get("thread-1", &tid).await.unwrap();
669        assert_eq!(summary.status, TaskStatus::Running);
670
671        // Cancel it.
672        mgr.cancel("thread-1", &tid).await.unwrap();
673
674        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
675
676        let summary = mgr.get("thread-1", &tid).await.unwrap();
677        assert_eq!(summary.status, TaskStatus::Cancelled);
678    }
679
680    #[tokio::test]
681    async fn cancel_wrong_owner_rejected() {
682        let mgr = BackgroundTaskManager::new();
683        let tid = mgr
684            .spawn("thread-1", "shell", "task", |cancel| async move {
685                cancel.cancelled().await;
686                TaskResult::Cancelled
687            })
688            .await;
689
690        let result = mgr.cancel("thread-other", &tid).await;
691        assert!(result.is_err());
692    }
693
694    #[tokio::test]
695    async fn list_filters_by_owner_and_status() {
696        let mgr = BackgroundTaskManager::new();
697
698        // Running task for thread-1.
699        let _t1 = mgr
700            .spawn("thread-1", "shell", "task-a", |cancel| async move {
701                cancel.cancelled().await;
702                TaskResult::Cancelled
703            })
704            .await;
705
706        // Completed task for thread-1.
707        mgr.spawn("thread-1", "http", "task-b", |_cancel| async {
708            TaskResult::Success(Value::Null)
709        })
710        .await;
711
712        // Task for thread-2 (should not appear).
713        mgr.spawn("thread-2", "shell", "task-c", |cancel| async move {
714            cancel.cancelled().await;
715            TaskResult::Cancelled
716        })
717        .await;
718
719        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
720
721        let all = mgr.list("thread-1", None).await;
722        assert_eq!(all.len(), 2);
723
724        let running = mgr.list("thread-1", Some(TaskStatus::Running)).await;
725        assert_eq!(running.len(), 1);
726        assert_eq!(running[0].task_type, "shell");
727
728        let completed = mgr.list("thread-1", Some(TaskStatus::Completed)).await;
729        assert_eq!(completed.len(), 1);
730        assert_eq!(completed[0].task_type, "http");
731    }
732
733    #[tokio::test]
734    async fn has_running_tasks_reflects_state() {
735        let mgr = BackgroundTaskManager::new();
736
737        assert!(!mgr.has_running_tasks("thread-1").await);
738
739        let tid = mgr
740            .spawn("thread-1", "shell", "task", |cancel| async move {
741                cancel.cancelled().await;
742                TaskResult::Cancelled
743            })
744            .await;
745
746        assert!(mgr.has_running_tasks("thread-1").await);
747
748        mgr.cancel("thread-1", &tid).await.unwrap();
749        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
750
751        assert!(!mgr.has_running_tasks("thread-1").await);
752    }
753
754    #[tokio::test]
755    async fn gc_terminal_removes_completed_tasks() {
756        let mgr = BackgroundTaskManager::new();
757
758        mgr.spawn("thread-1", "shell", "done", |_| async {
759            TaskResult::Success(Value::Null)
760        })
761        .await;
762
763        let _running = mgr
764            .spawn("thread-1", "shell", "still going", |cancel| async move {
765                cancel.cancelled().await;
766                TaskResult::Cancelled
767            })
768            .await;
769
770        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
771
772        let removed = mgr.gc_terminal("thread-1").await;
773        assert_eq!(removed, 1);
774
775        let all = mgr.list("thread-1", None).await;
776        assert_eq!(all.len(), 1);
777        assert_eq!(all[0].status, TaskStatus::Running);
778    }
779
780    #[tokio::test]
781    async fn get_returns_none_for_wrong_owner() {
782        let mgr = BackgroundTaskManager::new();
783        let tid = mgr
784            .spawn("thread-1", "shell", "task", |_| async {
785                TaskResult::Success(Value::Null)
786            })
787            .await;
788
789        assert!(mgr.get("thread-other", &tid).await.is_none());
790        assert!(mgr.get("thread-1", &tid).await.is_some());
791    }
792
793    #[tokio::test]
794    async fn cancel_already_completed_returns_error() {
795        let mgr = BackgroundTaskManager::new();
796        let tid = mgr
797            .spawn("thread-1", "shell", "fast", |_| async {
798                TaskResult::Success(Value::Null)
799            })
800            .await;
801
802        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
803
804        let err = mgr.cancel("thread-1", &tid).await.unwrap_err();
805        assert!(err.contains("not running"));
806    }
807
808    // -----------------------------------------------------------------------
809    // Concurrency & edge-case tests
810    // -----------------------------------------------------------------------
811
812    #[tokio::test]
813    async fn many_concurrent_spawns_all_tracked() {
814        let mgr = BackgroundTaskManager::new();
815        let mut ids = Vec::new();
816
817        for i in 0..20 {
818            let desc = format!("task-{i}");
819            let tid = mgr
820                .spawn("thread-1", "shell", &desc, |cancel| async move {
821                    cancel.cancelled().await;
822                    TaskResult::Cancelled
823                })
824                .await;
825            ids.push(tid);
826        }
827
828        let all = mgr.list("thread-1", None).await;
829        assert_eq!(all.len(), 20);
830
831        // All tasks should be running.
832        let running = mgr.list("thread-1", Some(TaskStatus::Running)).await;
833        assert_eq!(running.len(), 20);
834
835        // Cancel all.
836        for tid in &ids {
837            mgr.cancel("thread-1", tid).await.unwrap();
838        }
839        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
840
841        let cancelled = mgr.list("thread-1", Some(TaskStatus::Cancelled)).await;
842        assert_eq!(cancelled.len(), 20);
843    }
844
845    #[tokio::test]
846    async fn concurrent_cancel_and_complete_race() {
847        // Task completes very quickly; cancel may arrive after completion.
848        let mgr = BackgroundTaskManager::new();
849        let tid = mgr
850            .spawn("thread-1", "shell", "fast", |_| async {
851                TaskResult::Success(Value::Null)
852            })
853            .await;
854
855        // Race: try to cancel immediately (may or may not succeed).
856        let cancel_result = mgr.cancel("thread-1", &tid).await;
857        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
858
859        let summary = mgr.get("thread-1", &tid).await.unwrap();
860        // Either cancelled (if we beat the completion) or completed (if not).
861        assert!(
862            summary.status == TaskStatus::Cancelled || summary.status == TaskStatus::Completed,
863            "unexpected status: {:?}",
864            summary.status
865        );
866
867        // If cancel succeeded, status should be Cancelled.
868        if cancel_result.is_ok() {
869            assert_eq!(summary.status, TaskStatus::Cancelled);
870        }
871    }
872
873    #[tokio::test]
874    async fn task_with_tokio_select_respects_cancellation() {
875        let mgr = BackgroundTaskManager::new();
876        let tid = mgr
877            .spawn("thread-1", "shell", "select-based", |cancel| async move {
878                tokio::select! {
879                    _ = cancel.cancelled() => TaskResult::Cancelled,
880                    _ = tokio::time::sleep(std::time::Duration::from_secs(60)) => {
881                        TaskResult::Success(Value::Null)
882                    }
883                }
884            })
885            .await;
886
887        assert!(mgr.has_running_tasks("thread-1").await);
888
889        mgr.cancel("thread-1", &tid).await.unwrap();
890        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
891
892        let summary = mgr.get("thread-1", &tid).await.unwrap();
893        assert_eq!(summary.status, TaskStatus::Cancelled);
894        assert!(!mgr.has_running_tasks("thread-1").await);
895    }
896
897    #[tokio::test]
898    async fn task_failure_with_panic_safety() {
899        // A task that returns a failure result (not a panic).
900        let mgr = BackgroundTaskManager::new();
901        let tid = mgr
902            .spawn("thread-1", "http", "timeout", |_| async {
903                TaskResult::Failed("connection timed out".into())
904            })
905            .await;
906
907        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
908
909        let summary = mgr.get("thread-1", &tid).await.unwrap();
910        assert_eq!(summary.status, TaskStatus::Failed);
911        assert_eq!(summary.error.as_deref(), Some("connection timed out"));
912        assert!(summary.completed_at_ms.is_some());
913    }
914
915    #[tokio::test]
916    async fn gc_only_affects_specified_thread() {
917        let mgr = BackgroundTaskManager::new();
918
919        // Completed task on thread-1.
920        mgr.spawn("thread-1", "shell", "done-1", |_| async {
921            TaskResult::Success(Value::Null)
922        })
923        .await;
924        // Completed task on thread-2.
925        mgr.spawn("thread-2", "shell", "done-2", |_| async {
926            TaskResult::Success(Value::Null)
927        })
928        .await;
929
930        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
931
932        // GC only thread-1.
933        let removed = mgr.gc_terminal("thread-1").await;
934        assert_eq!(removed, 1);
935
936        // Thread-2's task should still be there.
937        let t2_tasks = mgr.list("thread-2", None).await;
938        assert_eq!(t2_tasks.len(), 1);
939    }
940
941    #[tokio::test]
942    async fn task_summary_has_timing_info() {
943        let mgr = BackgroundTaskManager::new();
944        let tid = mgr
945            .spawn("thread-1", "shell", "timed", |_| async {
946                TaskResult::Success(Value::Null)
947            })
948            .await;
949
950        // While running, should have created_at but no completed_at.
951        let running = mgr.get("thread-1", &tid).await.unwrap();
952        assert!(running.created_at_ms > 0);
953        // Note: task may have already completed since it's instant.
954
955        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
956
957        let completed = mgr.get("thread-1", &tid).await.unwrap();
958        assert!(completed.created_at_ms > 0);
959        assert!(completed.completed_at_ms.is_some());
960        assert!(completed.completed_at_ms.unwrap() >= completed.created_at_ms);
961    }
962
963    #[tokio::test]
964    async fn list_returns_sorted_by_creation_time() {
965        let mgr = BackgroundTaskManager::new();
966
967        let t1 = mgr
968            .spawn("thread-1", "shell", "first", |cancel| async move {
969                cancel.cancelled().await;
970                TaskResult::Cancelled
971            })
972            .await;
973        // Small delays to guarantee distinct timestamps.
974        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
975        let t2 = mgr
976            .spawn("thread-1", "shell", "second", |cancel| async move {
977                cancel.cancelled().await;
978                TaskResult::Cancelled
979            })
980            .await;
981        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
982        let t3 = mgr
983            .spawn("thread-1", "shell", "third", |cancel| async move {
984                cancel.cancelled().await;
985                TaskResult::Cancelled
986            })
987            .await;
988
989        let tasks = mgr.list("thread-1", None).await;
990        assert_eq!(tasks.len(), 3);
991        assert_eq!(tasks[0].task_id, t1);
992        assert_eq!(tasks[1].task_id, t2);
993        assert_eq!(tasks[2].task_id, t3);
994    }
995
996    #[tokio::test]
997    async fn default_impl_without_task_store_still_tracks_terminal_tasks() {
998        let mgr = BackgroundTaskManager::new();
999        mgr.spawn("thread-1", "shell", "task", |_| async {
1000            TaskResult::Success(Value::Null)
1001        })
1002        .await;
1003        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1004        // No panic, no error — just verify it works.
1005        let tasks = mgr.list("thread-1", None).await;
1006        assert_eq!(tasks.len(), 1);
1007    }
1008
1009    // -----------------------------------------------------------------------
1010    // spawn_with_id tests
1011    // -----------------------------------------------------------------------
1012
1013    #[tokio::test]
1014    async fn spawn_with_id_uses_caller_supplied_id() {
1015        let mgr = BackgroundTaskManager::new();
1016        let token = RunCancellationToken::new();
1017        let tid = mgr
1018            .spawn_with_id(
1019                SpawnParams {
1020                    task_id: "my-custom-id".to_string(),
1021                    owner_thread_id: "thread-1".to_string(),
1022                    task_type: "agent_run".to_string(),
1023                    description: "agent:worker".to_string(),
1024                    parent_task_id: None,
1025                    metadata: serde_json::json!({}),
1026                },
1027                token,
1028                |_cancel: RunCancellationToken| async { TaskResult::Success(Value::Null) },
1029            )
1030            .await;
1031
1032        assert_eq!(tid, "my-custom-id");
1033        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1034
1035        let summary = mgr.get("thread-1", "my-custom-id").await.unwrap();
1036        assert_eq!(summary.task_type, "agent_run");
1037        assert_eq!(summary.description, "agent:worker");
1038        assert_eq!(summary.status, TaskStatus::Completed);
1039    }
1040
1041    #[tokio::test]
1042    async fn spawn_with_id_uses_external_cancel_token() {
1043        let mgr = BackgroundTaskManager::new();
1044        let token = RunCancellationToken::new();
1045        let token_clone = token.clone();
1046
1047        mgr.spawn_with_id(
1048            SpawnParams {
1049                task_id: "cancel-test".to_string(),
1050                owner_thread_id: "thread-1".to_string(),
1051                task_type: "shell".to_string(),
1052                description: "long task".to_string(),
1053                parent_task_id: None,
1054                metadata: serde_json::json!({}),
1055            },
1056            token,
1057            |cancel: RunCancellationToken| async move {
1058                cancel.cancelled().await;
1059                TaskResult::Cancelled
1060            },
1061        )
1062        .await;
1063
1064        // Task should be running.
1065        let summary = mgr.get("thread-1", "cancel-test").await.unwrap();
1066        assert_eq!(summary.status, TaskStatus::Running);
1067
1068        // Cancel via the external token directly.
1069        token_clone.cancel();
1070        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1071
1072        // Task closure returns Cancelled, but cancellation_requested was not set
1073        // via manager.cancel(), so the status uses result.status() = Cancelled.
1074        let summary = mgr.get("thread-1", "cancel-test").await.unwrap();
1075        assert_eq!(summary.status, TaskStatus::Cancelled);
1076    }
1077
1078    #[tokio::test]
1079    async fn spawn_with_id_cancel_via_manager_works() {
1080        let mgr = BackgroundTaskManager::new();
1081        let token = RunCancellationToken::new();
1082
1083        mgr.spawn_with_id(
1084            SpawnParams {
1085                task_id: "mgr-cancel".to_string(),
1086                owner_thread_id: "thread-1".to_string(),
1087                task_type: "agent_run".to_string(),
1088                description: "agent:worker".to_string(),
1089                parent_task_id: None,
1090                metadata: serde_json::json!({}),
1091            },
1092            token,
1093            |cancel: RunCancellationToken| async move {
1094                cancel.cancelled().await;
1095                TaskResult::Cancelled
1096            },
1097        )
1098        .await;
1099
1100        // Cancel via manager (sets cancellation_requested).
1101        mgr.cancel("thread-1", "mgr-cancel").await.unwrap();
1102        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1103
1104        let summary = mgr.get("thread-1", "mgr-cancel").await.unwrap();
1105        assert_eq!(summary.status, TaskStatus::Cancelled);
1106    }
1107
1108    #[tokio::test]
1109    async fn manager_persists_terminal_state_to_task_store_when_configured() {
1110        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
1111        let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
1112        task_store
1113            .create_task(super::super::store::NewTaskSpec {
1114                task_id: "persisted-task".to_string(),
1115                owner_thread_id: "thread-1".to_string(),
1116                task_type: "shell".to_string(),
1117                description: "echo hi".to_string(),
1118                parent_task_id: None,
1119                supports_resume: false,
1120                metadata: Value::Object(serde_json::Map::new()),
1121            })
1122            .await
1123            .unwrap();
1124
1125        let mgr = BackgroundTaskManager::with_task_store(Some(task_store.clone()));
1126        mgr.spawn_with_id(
1127            SpawnParams {
1128                task_id: "persisted-task".to_string(),
1129                owner_thread_id: "thread-1".to_string(),
1130                task_type: "shell".to_string(),
1131                description: "echo hi".to_string(),
1132                parent_task_id: None,
1133                metadata: serde_json::json!({}),
1134            },
1135            RunCancellationToken::new(),
1136            |_cancel: RunCancellationToken| async {
1137                TaskResult::Success(serde_json::json!({ "stdout": "hi" }))
1138            },
1139        )
1140        .await;
1141
1142        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1143
1144        let persisted = task_store
1145            .load_task("persisted-task")
1146            .await
1147            .unwrap()
1148            .expect("task should persist");
1149        assert_eq!(persisted.status, TaskStatus::Completed);
1150        assert_eq!(
1151            persisted.result,
1152            Some(serde_json::json!({ "stdout": "hi" }))
1153        );
1154    }
1155
1156    #[tokio::test]
1157    async fn manager_exposes_persistence_error_and_gc_retains_terminal_task_until_persisted() {
1158        let storage = FlakyTaskThreadStore::new(0);
1159        let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
1160        task_store
1161            .create_task(super::super::store::NewTaskSpec {
1162                task_id: "flaky-task".to_string(),
1163                owner_thread_id: "thread-1".to_string(),
1164                task_type: "shell".to_string(),
1165                description: "echo hi".to_string(),
1166                parent_task_id: None,
1167                supports_resume: false,
1168                metadata: Value::Object(serde_json::Map::new()),
1169            })
1170            .await
1171            .unwrap();
1172        storage.set_failures(10);
1173
1174        let mgr = BackgroundTaskManager::with_task_store(Some(task_store));
1175        mgr.spawn_with_id(
1176            SpawnParams {
1177                task_id: "flaky-task".to_string(),
1178                owner_thread_id: "thread-1".to_string(),
1179                task_type: "shell".to_string(),
1180                description: "echo hi".to_string(),
1181                parent_task_id: None,
1182                metadata: serde_json::json!({}),
1183            },
1184            RunCancellationToken::new(),
1185            |_cancel: RunCancellationToken| async { TaskResult::Success(Value::Null) },
1186        )
1187        .await;
1188
1189        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1190
1191        let summary = mgr.get("thread-1", "flaky-task").await.unwrap();
1192        assert_eq!(summary.status, TaskStatus::Completed);
1193        assert!(summary.persistence_error.is_some());
1194        assert!(storage.remaining_failures() < 10);
1195
1196        let removed = mgr.gc_terminal("thread-1").await;
1197        assert_eq!(removed, 0);
1198        assert!(mgr.get("thread-1", "flaky-task").await.is_some());
1199    }
1200
1201    #[tokio::test]
1202    async fn manager_retries_failed_persistence_on_get_and_clears_error() {
1203        let storage = FlakyTaskThreadStore::new(0);
1204        let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
1205        task_store
1206            .create_task(super::super::store::NewTaskSpec {
1207                task_id: "retry-task".to_string(),
1208                owner_thread_id: "thread-1".to_string(),
1209                task_type: "shell".to_string(),
1210                description: "echo hi".to_string(),
1211                parent_task_id: None,
1212                supports_resume: false,
1213                metadata: Value::Object(serde_json::Map::new()),
1214            })
1215            .await
1216            .unwrap();
1217        storage.set_failures(1);
1218
1219        let mgr = BackgroundTaskManager::with_task_store(Some(task_store.clone()));
1220        mgr.spawn_with_id(
1221            SpawnParams {
1222                task_id: "retry-task".to_string(),
1223                owner_thread_id: "thread-1".to_string(),
1224                task_type: "shell".to_string(),
1225                description: "echo hi".to_string(),
1226                parent_task_id: None,
1227                metadata: serde_json::json!({}),
1228            },
1229            RunCancellationToken::new(),
1230            |_cancel: RunCancellationToken| async {
1231                TaskResult::Success(serde_json::json!({ "stdout": "done" }))
1232            },
1233        )
1234        .await;
1235
1236        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1237        let before_retry = task_store
1238            .load_task("retry-task")
1239            .await
1240            .unwrap()
1241            .expect("task should exist");
1242        assert_eq!(before_retry.status, TaskStatus::Running);
1243
1244        let summary = mgr.get("thread-1", "retry-task").await.unwrap();
1245        assert!(summary.persistence_error.is_none());
1246
1247        let after_retry = task_store
1248            .load_task("retry-task")
1249            .await
1250            .unwrap()
1251            .expect("task should exist");
1252        assert_eq!(after_retry.status, TaskStatus::Completed);
1253        assert_eq!(
1254            after_retry.result,
1255            Some(serde_json::json!({ "stdout": "done" }))
1256        );
1257    }
1258}