1use super::store::TaskStore;
10use super::types::*;
11use crate::runtime::loop_runner::RunCancellationToken;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::future::Future;
15use std::sync::Arc;
16use std::time::{SystemTime, UNIX_EPOCH};
17use tokio::sync::Mutex;
18
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch the result is `0`; a
/// millisecond count that does not fit in a `u64` saturates at `u64::MAX`.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let millis = since_epoch.as_millis();
    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        millis as u64
    }
}
26
27#[derive(Debug)]
29struct TaskHandle {
30 epoch: u64,
31 owner_thread_id: String,
32 task_type: String,
33 description: String,
34 status: TaskStatus,
35 error: Option<String>,
36 result: Option<Value>,
37 cancel_token: RunCancellationToken,
38 cancellation_requested: bool,
39 created_at_ms: u64,
40 completed_at_ms: Option<u64>,
41 parent_task_id: Option<TaskId>,
42 metadata: Value,
43 persistence_error: Option<String>,
44}
45
46#[derive(Clone)]
51pub struct BackgroundTaskManager {
52 handles: Arc<Mutex<HashMap<TaskId, TaskHandle>>>,
53 task_store: Option<Arc<TaskStore>>,
54}
55
56impl std::fmt::Debug for BackgroundTaskManager {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("BackgroundTaskManager")
59 .field("task_count", &"<locked>")
60 .finish()
61 }
62}
63
64impl Default for BackgroundTaskManager {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl BackgroundTaskManager {
71 pub fn new() -> Self {
72 Self {
73 handles: Arc::new(Mutex::new(HashMap::new())),
74 task_store: None,
75 }
76 }
77
78 pub fn with_task_store(task_store: Option<Arc<TaskStore>>) -> Self {
79 Self {
80 handles: Arc::new(Mutex::new(HashMap::new())),
81 task_store,
82 }
83 }
84
85 pub async fn spawn<F, Fut>(
90 &self,
91 owner_thread_id: &str,
92 task_type: &str,
93 description: &str,
94 task: F,
95 ) -> TaskId
96 where
97 F: FnOnce(RunCancellationToken) -> Fut + Send + 'static,
98 Fut: Future<Output = TaskResult> + Send,
99 {
100 let params = SpawnParams {
101 task_id: new_task_id(),
102 owner_thread_id: owner_thread_id.to_string(),
103 task_type: task_type.to_string(),
104 description: description.to_string(),
105 parent_task_id: None,
106 metadata: Value::Object(serde_json::Map::new()),
107 };
108 self.spawn_impl(params, None, task).await
109 }
110
111 pub async fn spawn_with_id<F, Fut>(
116 &self,
117 params: SpawnParams,
118 cancel_token: RunCancellationToken,
119 task: F,
120 ) -> TaskId
121 where
122 F: FnOnce(RunCancellationToken) -> Fut + Send + 'static,
123 Fut: Future<Output = TaskResult> + Send,
124 {
125 self.spawn_impl(params, Some(cancel_token), task).await
126 }
127
128 async fn spawn_impl<F, Fut>(
129 &self,
130 params: SpawnParams,
131 external_cancel_token: Option<RunCancellationToken>,
132 task: F,
133 ) -> TaskId
134 where
135 F: FnOnce(RunCancellationToken) -> Fut + Send + 'static,
136 Fut: Future<Output = TaskResult> + Send,
137 {
138 let SpawnParams {
139 task_id,
140 owner_thread_id,
141 task_type,
142 description,
143 parent_task_id,
144 metadata,
145 } = params;
146 let cancel_token = external_cancel_token.unwrap_or_default();
147 let now = now_ms();
148
149 let epoch = {
150 let mut handles = self.handles.lock().await;
151 let epoch = handles.get(&task_id).map(|h| h.epoch + 1).unwrap_or(1);
152 handles.insert(
153 task_id.clone(),
154 TaskHandle {
155 epoch,
156 owner_thread_id,
157 task_type,
158 description,
159 status: TaskStatus::Running,
160 error: None,
161 result: None,
162 cancel_token: cancel_token.clone(),
163 cancellation_requested: false,
164 created_at_ms: now,
165 completed_at_ms: None,
166 parent_task_id,
167 metadata,
168 persistence_error: None,
169 },
170 );
171 epoch
172 };
173
174 let handles = self.handles.clone();
175 let task_store = self.task_store.clone();
176 let tid = task_id.clone();
177
178 tokio::spawn(async move {
179 let result = task(cancel_token).await;
180
181 let completed_at = now_ms();
183 {
184 let mut map = handles.lock().await;
185 if let Some(handle) = map.get_mut(&tid) {
186 if handle.epoch != epoch {
187 return; }
189 let status = if handle.cancellation_requested {
190 TaskStatus::Cancelled
191 } else {
192 result.status()
193 };
194 handle.status = status;
195 handle.error = match &result {
196 TaskResult::Failed(e) => Some(e.clone()),
197 _ => None,
198 };
199 handle.result = match &result {
200 TaskResult::Success(v) => Some(v.clone()),
201 _ => None,
202 };
203 handle.completed_at_ms = Some(completed_at);
204 }
205 }
206
207 if let Some(task_store) = task_store {
208 let maybe_summary = {
209 let map = handles.lock().await;
210 map.get(&tid)
211 .filter(|handle| handle.epoch == epoch)
212 .map(|handle| summary_from_handle(&tid, handle))
213 };
214 if let Some(summary) = maybe_summary {
215 let persistence_error = task_store
216 .persist_summary(&summary)
217 .await
218 .err()
219 .map(|e| e.to_string());
220 let mut map = handles.lock().await;
221 if let Some(handle) = map.get_mut(&tid) {
222 if handle.epoch != epoch {
223 return;
224 }
225 handle.persistence_error = persistence_error;
226 }
227 }
228 }
229 });
230
231 task_id
232 }
233
234 pub async fn cancel(&self, owner_thread_id: &str, task_id: &str) -> Result<(), String> {
236 let mut handles = self.handles.lock().await;
237 let Some(handle) = handles.get_mut(task_id) else {
238 return Err(format!("Unknown task_id: {task_id}"));
239 };
240 if handle.owner_thread_id != owner_thread_id {
241 return Err(format!("Unknown task_id: {task_id}"));
242 }
243 if handle.status != TaskStatus::Running {
244 return Err(format!(
245 "Task '{task_id}' is not running (current status: {})",
246 handle.status.as_str()
247 ));
248 }
249 handle.cancellation_requested = true;
250 handle.cancel_token.cancel();
251 Ok(())
252 }
253
254 pub async fn get(&self, owner_thread_id: &str, task_id: &str) -> Option<TaskSummary> {
256 self.retry_persistence(owner_thread_id, Some(task_id)).await;
257 let handles = self.handles.lock().await;
258 let handle = handles.get(task_id)?;
259 if handle.owner_thread_id != owner_thread_id {
260 return None;
261 }
262 Some(summary_from_handle(task_id, handle))
263 }
264
265 pub async fn list(
267 &self,
268 owner_thread_id: &str,
269 status_filter: Option<TaskStatus>,
270 ) -> Vec<TaskSummary> {
271 self.retry_persistence(owner_thread_id, None).await;
272 let handles = self.handles.lock().await;
273 let mut out: Vec<TaskSummary> = handles
274 .iter()
275 .filter(|(_, h)| h.owner_thread_id == owner_thread_id)
276 .filter(|(_, h)| status_filter.is_none_or(|s| s == h.status))
277 .map(|(id, h)| summary_from_handle(id, h))
278 .collect();
279 out.sort_by(|a, b| a.created_at_ms.cmp(&b.created_at_ms));
280 out
281 }
282
283 pub async fn has_running_tasks(&self, owner_thread_id: &str) -> bool {
285 let handles = self.handles.lock().await;
286 handles
287 .values()
288 .any(|h| h.owner_thread_id == owner_thread_id && h.status == TaskStatus::Running)
289 }
290
291 pub async fn gc_terminal(&self, owner_thread_id: &str) -> usize {
293 self.retry_persistence(owner_thread_id, None).await;
294 let mut handles = self.handles.lock().await;
295 let before = handles.len();
296 handles.retain(|_, h| {
297 h.owner_thread_id != owner_thread_id
298 || !h.status.is_terminal()
299 || h.persistence_error.is_some()
300 });
301 before - handles.len()
302 }
303
304 pub async fn contains(&self, owner_thread_id: &str, task_id: &str) -> bool {
306 let handles = self.handles.lock().await;
307 handles
308 .get(task_id)
309 .is_some_and(|h| h.owner_thread_id == owner_thread_id)
310 }
311
312 pub async fn contains_any(&self, task_id: &str) -> bool {
314 let handles = self.handles.lock().await;
315 handles.contains_key(task_id)
316 }
317
318 pub async fn cancel_tree(
324 &self,
325 owner_thread_id: &str,
326 task_id: &str,
327 ) -> Result<Vec<TaskSummary>, String> {
328 let mut handles = self.handles.lock().await;
329 let Some(root) = handles.get(task_id) else {
330 return Err(format!("Unknown task_id: {task_id}"));
331 };
332 if root.owner_thread_id != owner_thread_id {
333 return Err(format!("Unknown task_id: {task_id}"));
334 }
335
336 let tree_ids = collect_descendant_ids(&handles, owner_thread_id, task_id, true);
338 if tree_ids.is_empty() {
339 return Err(format!(
340 "Task '{task_id}' is not running (current status: {})",
341 root.status.as_str()
342 ));
343 }
344
345 let mut cancelled = false;
346 let mut out = Vec::with_capacity(tree_ids.len());
347 for id in tree_ids {
348 if let Some(handle) = handles.get_mut(&id) {
349 if handle.status == TaskStatus::Running {
350 handle.cancellation_requested = true;
351 handle.cancel_token.cancel();
352 cancelled = true;
353 }
354 out.push(summary_from_handle(&id, handle));
355 }
356 }
357
358 if cancelled {
359 Ok(out)
360 } else {
361 Err(format!(
362 "Task '{task_id}' is not running (current status: {})",
363 handles
364 .get(task_id)
365 .map(|h| h.status.as_str())
366 .unwrap_or("unknown")
367 ))
368 }
369 }
370
371 pub async fn update_status(
373 &self,
374 task_id: &str,
375 status: TaskStatus,
376 error: Option<String>,
377 ) -> bool {
378 let mut handles = self.handles.lock().await;
379 if let Some(handle) = handles.get_mut(task_id) {
380 handle.status = status;
381 handle.error = error;
382 if status.is_terminal() {
383 handle.completed_at_ms = Some(now_ms());
384 }
385 true
386 } else {
387 false
388 }
389 }
390
391 pub async fn list_by_type(
393 &self,
394 owner_thread_id: &str,
395 task_type: &str,
396 status_filter: Option<TaskStatus>,
397 ) -> Vec<TaskSummary> {
398 let handles = self.handles.lock().await;
399 let mut out: Vec<TaskSummary> = handles
400 .iter()
401 .filter(|(_, h)| h.owner_thread_id == owner_thread_id)
402 .filter(|(_, h)| h.task_type == task_type)
403 .filter(|(_, h)| status_filter.is_none_or(|s| s == h.status))
404 .map(|(id, h)| summary_from_handle(id, h))
405 .collect();
406 out.sort_by(|a, b| a.created_at_ms.cmp(&b.created_at_ms));
407 out
408 }
409
410 async fn retry_persistence(&self, owner_thread_id: &str, task_id: Option<&str>) {
411 let Some(task_store) = self.task_store.as_ref().cloned() else {
412 return;
413 };
414
415 let candidates: Vec<(TaskId, u64, TaskSummary)> = {
416 let handles = self.handles.lock().await;
417 handles
418 .iter()
419 .filter(|(id, handle)| {
420 handle.owner_thread_id == owner_thread_id
421 && handle.status.is_terminal()
422 && handle.persistence_error.is_some()
423 && task_id.is_none_or(|wanted| wanted == id.as_str())
424 })
425 .map(|(id, handle)| (id.clone(), handle.epoch, summary_from_handle(id, handle)))
426 .collect()
427 };
428
429 for (task_id, epoch, summary) in candidates {
430 let persistence_error = task_store
431 .persist_summary(&summary)
432 .await
433 .err()
434 .map(|e| e.to_string());
435 let mut handles = self.handles.lock().await;
436 if let Some(handle) = handles.get_mut(&task_id) {
437 if handle.epoch != epoch || !handle.status.is_terminal() {
438 continue;
439 }
440 handle.persistence_error = persistence_error;
441 }
442 }
443 }
444}
445
446fn collect_descendant_ids(
448 handles: &HashMap<TaskId, TaskHandle>,
449 owner_thread_id: &str,
450 root_id: &str,
451 include_root: bool,
452) -> Vec<String> {
453 let mut children_by_parent: HashMap<&str, Vec<&str>> = HashMap::new();
455 for (id, h) in handles.iter() {
456 if h.owner_thread_id != owner_thread_id {
457 continue;
458 }
459 if let Some(parent) = h.parent_task_id.as_deref() {
460 children_by_parent
461 .entry(parent)
462 .or_default()
463 .push(id.as_str());
464 }
465 }
466
467 let mut result = Vec::new();
468 let mut stack = vec![root_id];
469 let mut is_root = true;
470 while let Some(current) = stack.pop() {
471 if !is_root || include_root {
472 if handles
474 .get(current)
475 .is_some_and(|h| h.status == TaskStatus::Running)
476 {
477 result.push(current.to_string());
478 }
479 }
480 is_root = false;
481 if let Some(children) = children_by_parent.get(current) {
482 for child in children {
483 stack.push(child);
484 }
485 }
486 }
487 result
488}
489
490fn summary_from_handle(task_id: &str, handle: &TaskHandle) -> TaskSummary {
491 TaskSummary {
492 task_id: task_id.to_string(),
493 task_type: handle.task_type.clone(),
494 description: handle.description.clone(),
495 status: handle.status,
496 error: handle.error.clone(),
497 result: handle.result.clone(),
498 result_ref: None,
499 created_at_ms: handle.created_at_ms,
500 completed_at_ms: handle.completed_at_ms,
501 parent_task_id: handle.parent_task_id.clone(),
502 supports_resume: handle.task_type == "agent_run",
503 attempt: 0,
504 metadata: handle.metadata.clone(),
505 persistence_error: handle.persistence_error.clone(),
506 }
507}
508
509#[cfg(test)]
510mod tests {
511 use super::*;
512 use crate::contracts::storage::{
513 MessagePage, MessageQuery, RunPage, RunQuery, RunRecord, ThreadHead, ThreadListPage,
514 ThreadListQuery, ThreadReader, ThreadStore, ThreadStoreError, ThreadWriter,
515 VersionPrecondition,
516 };
517 use crate::contracts::thread::{Thread, ThreadChangeSet};
518 use std::sync::atomic::{AtomicUsize, Ordering};
519
520 struct FlakyTaskThreadStore {
521 inner: Arc<tirea_store_adapters::MemoryStore>,
522 fail_task_appends: AtomicUsize,
523 }
524
525 impl FlakyTaskThreadStore {
526 fn new(fail_task_appends: usize) -> Arc<Self> {
527 Arc::new(Self {
528 inner: Arc::new(tirea_store_adapters::MemoryStore::new()),
529 fail_task_appends: AtomicUsize::new(fail_task_appends),
530 })
531 }
532
533 fn set_failures(&self, failures: usize) {
534 self.fail_task_appends.store(failures, Ordering::SeqCst);
535 }
536
537 fn remaining_failures(&self) -> usize {
538 self.fail_task_appends.load(Ordering::SeqCst)
539 }
540 }
541
542 #[async_trait::async_trait]
543 impl ThreadReader for FlakyTaskThreadStore {
544 async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
545 self.inner.load(thread_id).await
546 }
547
548 async fn list_threads(
549 &self,
550 query: &ThreadListQuery,
551 ) -> Result<ThreadListPage, ThreadStoreError> {
552 self.inner.list_threads(query).await
553 }
554
555 async fn load_messages(
556 &self,
557 thread_id: &str,
558 query: &MessageQuery,
559 ) -> Result<MessagePage, ThreadStoreError> {
560 self.inner.load_messages(thread_id, query).await
561 }
562
563 async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, ThreadStoreError> {
564 self.inner.load_run(run_id).await
565 }
566
567 async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, ThreadStoreError> {
568 self.inner.list_runs(query).await
569 }
570
571 async fn active_run_for_thread(
572 &self,
573 thread_id: &str,
574 ) -> Result<Option<RunRecord>, ThreadStoreError> {
575 self.inner.active_run_for_thread(thread_id).await
576 }
577 }
578
579 #[async_trait::async_trait]
580 impl ThreadWriter for FlakyTaskThreadStore {
581 async fn create(
582 &self,
583 thread: &Thread,
584 ) -> Result<crate::contracts::storage::Committed, ThreadStoreError> {
585 self.inner.create(thread).await
586 }
587
588 async fn append(
589 &self,
590 thread_id: &str,
591 delta: &ThreadChangeSet,
592 precondition: VersionPrecondition,
593 ) -> Result<crate::contracts::storage::Committed, ThreadStoreError> {
594 if thread_id.starts_with(TASK_THREAD_PREFIX)
595 && self
596 .fail_task_appends
597 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
598 if remaining > 0 {
599 Some(remaining - 1)
600 } else {
601 None
602 }
603 })
604 .is_ok()
605 {
606 return Err(ThreadStoreError::Io(std::io::Error::other(
607 "injected task persistence failure",
608 )));
609 }
610 self.inner.append(thread_id, delta, precondition).await
611 }
612
613 async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
614 self.inner.delete(thread_id).await
615 }
616
617 async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
618 self.inner.save(thread).await
619 }
620 }
621
622 #[tokio::test]
623 async fn spawn_and_complete_success() {
624 let mgr = BackgroundTaskManager::new();
625 let tid = mgr
626 .spawn("thread-1", "shell", "echo hello", |_cancel| async {
627 TaskResult::Success(serde_json::json!({ "exit_code": 0 }))
628 })
629 .await;
630
631 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
633
634 let summary = mgr.get("thread-1", &tid).await.expect("task should exist");
635 assert_eq!(summary.status, TaskStatus::Completed);
636 assert!(summary.result.is_some());
637 assert!(summary.error.is_none());
638 assert!(summary.completed_at_ms.is_some());
639 }
640
641 #[tokio::test]
642 async fn spawn_and_complete_failure() {
643 let mgr = BackgroundTaskManager::new();
644 let tid = mgr
645 .spawn("thread-1", "shell", "bad cmd", |_cancel| async {
646 TaskResult::Failed("command not found".into())
647 })
648 .await;
649
650 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
651
652 let summary = mgr.get("thread-1", &tid).await.unwrap();
653 assert_eq!(summary.status, TaskStatus::Failed);
654 assert_eq!(summary.error.as_deref(), Some("command not found"));
655 }
656
657 #[tokio::test]
658 async fn cancel_running_task() {
659 let mgr = BackgroundTaskManager::new();
660 let tid = mgr
661 .spawn("thread-1", "shell", "long running", |cancel| async move {
662 cancel.cancelled().await;
663 TaskResult::Cancelled
664 })
665 .await;
666
667 let summary = mgr.get("thread-1", &tid).await.unwrap();
669 assert_eq!(summary.status, TaskStatus::Running);
670
671 mgr.cancel("thread-1", &tid).await.unwrap();
673
674 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
675
676 let summary = mgr.get("thread-1", &tid).await.unwrap();
677 assert_eq!(summary.status, TaskStatus::Cancelled);
678 }
679
680 #[tokio::test]
681 async fn cancel_wrong_owner_rejected() {
682 let mgr = BackgroundTaskManager::new();
683 let tid = mgr
684 .spawn("thread-1", "shell", "task", |cancel| async move {
685 cancel.cancelled().await;
686 TaskResult::Cancelled
687 })
688 .await;
689
690 let result = mgr.cancel("thread-other", &tid).await;
691 assert!(result.is_err());
692 }
693
694 #[tokio::test]
695 async fn list_filters_by_owner_and_status() {
696 let mgr = BackgroundTaskManager::new();
697
698 let _t1 = mgr
700 .spawn("thread-1", "shell", "task-a", |cancel| async move {
701 cancel.cancelled().await;
702 TaskResult::Cancelled
703 })
704 .await;
705
706 mgr.spawn("thread-1", "http", "task-b", |_cancel| async {
708 TaskResult::Success(Value::Null)
709 })
710 .await;
711
712 mgr.spawn("thread-2", "shell", "task-c", |cancel| async move {
714 cancel.cancelled().await;
715 TaskResult::Cancelled
716 })
717 .await;
718
719 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
720
721 let all = mgr.list("thread-1", None).await;
722 assert_eq!(all.len(), 2);
723
724 let running = mgr.list("thread-1", Some(TaskStatus::Running)).await;
725 assert_eq!(running.len(), 1);
726 assert_eq!(running[0].task_type, "shell");
727
728 let completed = mgr.list("thread-1", Some(TaskStatus::Completed)).await;
729 assert_eq!(completed.len(), 1);
730 assert_eq!(completed[0].task_type, "http");
731 }
732
733 #[tokio::test]
734 async fn has_running_tasks_reflects_state() {
735 let mgr = BackgroundTaskManager::new();
736
737 assert!(!mgr.has_running_tasks("thread-1").await);
738
739 let tid = mgr
740 .spawn("thread-1", "shell", "task", |cancel| async move {
741 cancel.cancelled().await;
742 TaskResult::Cancelled
743 })
744 .await;
745
746 assert!(mgr.has_running_tasks("thread-1").await);
747
748 mgr.cancel("thread-1", &tid).await.unwrap();
749 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
750
751 assert!(!mgr.has_running_tasks("thread-1").await);
752 }
753
754 #[tokio::test]
755 async fn gc_terminal_removes_completed_tasks() {
756 let mgr = BackgroundTaskManager::new();
757
758 mgr.spawn("thread-1", "shell", "done", |_| async {
759 TaskResult::Success(Value::Null)
760 })
761 .await;
762
763 let _running = mgr
764 .spawn("thread-1", "shell", "still going", |cancel| async move {
765 cancel.cancelled().await;
766 TaskResult::Cancelled
767 })
768 .await;
769
770 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
771
772 let removed = mgr.gc_terminal("thread-1").await;
773 assert_eq!(removed, 1);
774
775 let all = mgr.list("thread-1", None).await;
776 assert_eq!(all.len(), 1);
777 assert_eq!(all[0].status, TaskStatus::Running);
778 }
779
780 #[tokio::test]
781 async fn get_returns_none_for_wrong_owner() {
782 let mgr = BackgroundTaskManager::new();
783 let tid = mgr
784 .spawn("thread-1", "shell", "task", |_| async {
785 TaskResult::Success(Value::Null)
786 })
787 .await;
788
789 assert!(mgr.get("thread-other", &tid).await.is_none());
790 assert!(mgr.get("thread-1", &tid).await.is_some());
791 }
792
793 #[tokio::test]
794 async fn cancel_already_completed_returns_error() {
795 let mgr = BackgroundTaskManager::new();
796 let tid = mgr
797 .spawn("thread-1", "shell", "fast", |_| async {
798 TaskResult::Success(Value::Null)
799 })
800 .await;
801
802 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
803
804 let err = mgr.cancel("thread-1", &tid).await.unwrap_err();
805 assert!(err.contains("not running"));
806 }
807
808 #[tokio::test]
813 async fn many_concurrent_spawns_all_tracked() {
814 let mgr = BackgroundTaskManager::new();
815 let mut ids = Vec::new();
816
817 for i in 0..20 {
818 let desc = format!("task-{i}");
819 let tid = mgr
820 .spawn("thread-1", "shell", &desc, |cancel| async move {
821 cancel.cancelled().await;
822 TaskResult::Cancelled
823 })
824 .await;
825 ids.push(tid);
826 }
827
828 let all = mgr.list("thread-1", None).await;
829 assert_eq!(all.len(), 20);
830
831 let running = mgr.list("thread-1", Some(TaskStatus::Running)).await;
833 assert_eq!(running.len(), 20);
834
835 for tid in &ids {
837 mgr.cancel("thread-1", tid).await.unwrap();
838 }
839 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
840
841 let cancelled = mgr.list("thread-1", Some(TaskStatus::Cancelled)).await;
842 assert_eq!(cancelled.len(), 20);
843 }
844
845 #[tokio::test]
846 async fn concurrent_cancel_and_complete_race() {
847 let mgr = BackgroundTaskManager::new();
849 let tid = mgr
850 .spawn("thread-1", "shell", "fast", |_| async {
851 TaskResult::Success(Value::Null)
852 })
853 .await;
854
855 let cancel_result = mgr.cancel("thread-1", &tid).await;
857 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
858
859 let summary = mgr.get("thread-1", &tid).await.unwrap();
860 assert!(
862 summary.status == TaskStatus::Cancelled || summary.status == TaskStatus::Completed,
863 "unexpected status: {:?}",
864 summary.status
865 );
866
867 if cancel_result.is_ok() {
869 assert_eq!(summary.status, TaskStatus::Cancelled);
870 }
871 }
872
873 #[tokio::test]
874 async fn task_with_tokio_select_respects_cancellation() {
875 let mgr = BackgroundTaskManager::new();
876 let tid = mgr
877 .spawn("thread-1", "shell", "select-based", |cancel| async move {
878 tokio::select! {
879 _ = cancel.cancelled() => TaskResult::Cancelled,
880 _ = tokio::time::sleep(std::time::Duration::from_secs(60)) => {
881 TaskResult::Success(Value::Null)
882 }
883 }
884 })
885 .await;
886
887 assert!(mgr.has_running_tasks("thread-1").await);
888
889 mgr.cancel("thread-1", &tid).await.unwrap();
890 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
891
892 let summary = mgr.get("thread-1", &tid).await.unwrap();
893 assert_eq!(summary.status, TaskStatus::Cancelled);
894 assert!(!mgr.has_running_tasks("thread-1").await);
895 }
896
897 #[tokio::test]
898 async fn task_failure_with_panic_safety() {
899 let mgr = BackgroundTaskManager::new();
901 let tid = mgr
902 .spawn("thread-1", "http", "timeout", |_| async {
903 TaskResult::Failed("connection timed out".into())
904 })
905 .await;
906
907 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
908
909 let summary = mgr.get("thread-1", &tid).await.unwrap();
910 assert_eq!(summary.status, TaskStatus::Failed);
911 assert_eq!(summary.error.as_deref(), Some("connection timed out"));
912 assert!(summary.completed_at_ms.is_some());
913 }
914
915 #[tokio::test]
916 async fn gc_only_affects_specified_thread() {
917 let mgr = BackgroundTaskManager::new();
918
919 mgr.spawn("thread-1", "shell", "done-1", |_| async {
921 TaskResult::Success(Value::Null)
922 })
923 .await;
924 mgr.spawn("thread-2", "shell", "done-2", |_| async {
926 TaskResult::Success(Value::Null)
927 })
928 .await;
929
930 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
931
932 let removed = mgr.gc_terminal("thread-1").await;
934 assert_eq!(removed, 1);
935
936 let t2_tasks = mgr.list("thread-2", None).await;
938 assert_eq!(t2_tasks.len(), 1);
939 }
940
941 #[tokio::test]
942 async fn task_summary_has_timing_info() {
943 let mgr = BackgroundTaskManager::new();
944 let tid = mgr
945 .spawn("thread-1", "shell", "timed", |_| async {
946 TaskResult::Success(Value::Null)
947 })
948 .await;
949
950 let running = mgr.get("thread-1", &tid).await.unwrap();
952 assert!(running.created_at_ms > 0);
953 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
956
957 let completed = mgr.get("thread-1", &tid).await.unwrap();
958 assert!(completed.created_at_ms > 0);
959 assert!(completed.completed_at_ms.is_some());
960 assert!(completed.completed_at_ms.unwrap() >= completed.created_at_ms);
961 }
962
963 #[tokio::test]
964 async fn list_returns_sorted_by_creation_time() {
965 let mgr = BackgroundTaskManager::new();
966
967 let t1 = mgr
968 .spawn("thread-1", "shell", "first", |cancel| async move {
969 cancel.cancelled().await;
970 TaskResult::Cancelled
971 })
972 .await;
973 tokio::time::sleep(std::time::Duration::from_millis(2)).await;
975 let t2 = mgr
976 .spawn("thread-1", "shell", "second", |cancel| async move {
977 cancel.cancelled().await;
978 TaskResult::Cancelled
979 })
980 .await;
981 tokio::time::sleep(std::time::Duration::from_millis(2)).await;
982 let t3 = mgr
983 .spawn("thread-1", "shell", "third", |cancel| async move {
984 cancel.cancelled().await;
985 TaskResult::Cancelled
986 })
987 .await;
988
989 let tasks = mgr.list("thread-1", None).await;
990 assert_eq!(tasks.len(), 3);
991 assert_eq!(tasks[0].task_id, t1);
992 assert_eq!(tasks[1].task_id, t2);
993 assert_eq!(tasks[2].task_id, t3);
994 }
995
996 #[tokio::test]
997 async fn default_impl_without_task_store_still_tracks_terminal_tasks() {
998 let mgr = BackgroundTaskManager::new();
999 mgr.spawn("thread-1", "shell", "task", |_| async {
1000 TaskResult::Success(Value::Null)
1001 })
1002 .await;
1003 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1004 let tasks = mgr.list("thread-1", None).await;
1006 assert_eq!(tasks.len(), 1);
1007 }
1008
1009 #[tokio::test]
1014 async fn spawn_with_id_uses_caller_supplied_id() {
1015 let mgr = BackgroundTaskManager::new();
1016 let token = RunCancellationToken::new();
1017 let tid = mgr
1018 .spawn_with_id(
1019 SpawnParams {
1020 task_id: "my-custom-id".to_string(),
1021 owner_thread_id: "thread-1".to_string(),
1022 task_type: "agent_run".to_string(),
1023 description: "agent:worker".to_string(),
1024 parent_task_id: None,
1025 metadata: serde_json::json!({}),
1026 },
1027 token,
1028 |_cancel: RunCancellationToken| async { TaskResult::Success(Value::Null) },
1029 )
1030 .await;
1031
1032 assert_eq!(tid, "my-custom-id");
1033 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1034
1035 let summary = mgr.get("thread-1", "my-custom-id").await.unwrap();
1036 assert_eq!(summary.task_type, "agent_run");
1037 assert_eq!(summary.description, "agent:worker");
1038 assert_eq!(summary.status, TaskStatus::Completed);
1039 }
1040
1041 #[tokio::test]
1042 async fn spawn_with_id_uses_external_cancel_token() {
1043 let mgr = BackgroundTaskManager::new();
1044 let token = RunCancellationToken::new();
1045 let token_clone = token.clone();
1046
1047 mgr.spawn_with_id(
1048 SpawnParams {
1049 task_id: "cancel-test".to_string(),
1050 owner_thread_id: "thread-1".to_string(),
1051 task_type: "shell".to_string(),
1052 description: "long task".to_string(),
1053 parent_task_id: None,
1054 metadata: serde_json::json!({}),
1055 },
1056 token,
1057 |cancel: RunCancellationToken| async move {
1058 cancel.cancelled().await;
1059 TaskResult::Cancelled
1060 },
1061 )
1062 .await;
1063
1064 let summary = mgr.get("thread-1", "cancel-test").await.unwrap();
1066 assert_eq!(summary.status, TaskStatus::Running);
1067
1068 token_clone.cancel();
1070 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1071
1072 let summary = mgr.get("thread-1", "cancel-test").await.unwrap();
1075 assert_eq!(summary.status, TaskStatus::Cancelled);
1076 }
1077
1078 #[tokio::test]
1079 async fn spawn_with_id_cancel_via_manager_works() {
1080 let mgr = BackgroundTaskManager::new();
1081 let token = RunCancellationToken::new();
1082
1083 mgr.spawn_with_id(
1084 SpawnParams {
1085 task_id: "mgr-cancel".to_string(),
1086 owner_thread_id: "thread-1".to_string(),
1087 task_type: "agent_run".to_string(),
1088 description: "agent:worker".to_string(),
1089 parent_task_id: None,
1090 metadata: serde_json::json!({}),
1091 },
1092 token,
1093 |cancel: RunCancellationToken| async move {
1094 cancel.cancelled().await;
1095 TaskResult::Cancelled
1096 },
1097 )
1098 .await;
1099
1100 mgr.cancel("thread-1", "mgr-cancel").await.unwrap();
1102 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1103
1104 let summary = mgr.get("thread-1", "mgr-cancel").await.unwrap();
1105 assert_eq!(summary.status, TaskStatus::Cancelled);
1106 }
1107
1108 #[tokio::test]
1109 async fn manager_persists_terminal_state_to_task_store_when_configured() {
1110 let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
1111 let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
1112 task_store
1113 .create_task(super::super::store::NewTaskSpec {
1114 task_id: "persisted-task".to_string(),
1115 owner_thread_id: "thread-1".to_string(),
1116 task_type: "shell".to_string(),
1117 description: "echo hi".to_string(),
1118 parent_task_id: None,
1119 supports_resume: false,
1120 metadata: Value::Object(serde_json::Map::new()),
1121 })
1122 .await
1123 .unwrap();
1124
1125 let mgr = BackgroundTaskManager::with_task_store(Some(task_store.clone()));
1126 mgr.spawn_with_id(
1127 SpawnParams {
1128 task_id: "persisted-task".to_string(),
1129 owner_thread_id: "thread-1".to_string(),
1130 task_type: "shell".to_string(),
1131 description: "echo hi".to_string(),
1132 parent_task_id: None,
1133 metadata: serde_json::json!({}),
1134 },
1135 RunCancellationToken::new(),
1136 |_cancel: RunCancellationToken| async {
1137 TaskResult::Success(serde_json::json!({ "stdout": "hi" }))
1138 },
1139 )
1140 .await;
1141
1142 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1143
1144 let persisted = task_store
1145 .load_task("persisted-task")
1146 .await
1147 .unwrap()
1148 .expect("task should persist");
1149 assert_eq!(persisted.status, TaskStatus::Completed);
1150 assert_eq!(
1151 persisted.result,
1152 Some(serde_json::json!({ "stdout": "hi" }))
1153 );
1154 }
1155
1156 #[tokio::test]
1157 async fn manager_exposes_persistence_error_and_gc_retains_terminal_task_until_persisted() {
1158 let storage = FlakyTaskThreadStore::new(0);
1159 let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
1160 task_store
1161 .create_task(super::super::store::NewTaskSpec {
1162 task_id: "flaky-task".to_string(),
1163 owner_thread_id: "thread-1".to_string(),
1164 task_type: "shell".to_string(),
1165 description: "echo hi".to_string(),
1166 parent_task_id: None,
1167 supports_resume: false,
1168 metadata: Value::Object(serde_json::Map::new()),
1169 })
1170 .await
1171 .unwrap();
1172 storage.set_failures(10);
1173
1174 let mgr = BackgroundTaskManager::with_task_store(Some(task_store));
1175 mgr.spawn_with_id(
1176 SpawnParams {
1177 task_id: "flaky-task".to_string(),
1178 owner_thread_id: "thread-1".to_string(),
1179 task_type: "shell".to_string(),
1180 description: "echo hi".to_string(),
1181 parent_task_id: None,
1182 metadata: serde_json::json!({}),
1183 },
1184 RunCancellationToken::new(),
1185 |_cancel: RunCancellationToken| async { TaskResult::Success(Value::Null) },
1186 )
1187 .await;
1188
1189 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1190
1191 let summary = mgr.get("thread-1", "flaky-task").await.unwrap();
1192 assert_eq!(summary.status, TaskStatus::Completed);
1193 assert!(summary.persistence_error.is_some());
1194 assert!(storage.remaining_failures() < 10);
1195
1196 let removed = mgr.gc_terminal("thread-1").await;
1197 assert_eq!(removed, 0);
1198 assert!(mgr.get("thread-1", "flaky-task").await.is_some());
1199 }
1200
1201 #[tokio::test]
1202 async fn manager_retries_failed_persistence_on_get_and_clears_error() {
1203 let storage = FlakyTaskThreadStore::new(0);
1204 let task_store = Arc::new(TaskStore::new(storage.clone() as Arc<dyn ThreadStore>));
1205 task_store
1206 .create_task(super::super::store::NewTaskSpec {
1207 task_id: "retry-task".to_string(),
1208 owner_thread_id: "thread-1".to_string(),
1209 task_type: "shell".to_string(),
1210 description: "echo hi".to_string(),
1211 parent_task_id: None,
1212 supports_resume: false,
1213 metadata: Value::Object(serde_json::Map::new()),
1214 })
1215 .await
1216 .unwrap();
1217 storage.set_failures(1);
1218
1219 let mgr = BackgroundTaskManager::with_task_store(Some(task_store.clone()));
1220 mgr.spawn_with_id(
1221 SpawnParams {
1222 task_id: "retry-task".to_string(),
1223 owner_thread_id: "thread-1".to_string(),
1224 task_type: "shell".to_string(),
1225 description: "echo hi".to_string(),
1226 parent_task_id: None,
1227 metadata: serde_json::json!({}),
1228 },
1229 RunCancellationToken::new(),
1230 |_cancel: RunCancellationToken| async {
1231 TaskResult::Success(serde_json::json!({ "stdout": "done" }))
1232 },
1233 )
1234 .await;
1235
1236 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1237 let before_retry = task_store
1238 .load_task("retry-task")
1239 .await
1240 .unwrap()
1241 .expect("task should exist");
1242 assert_eq!(before_retry.status, TaskStatus::Running);
1243
1244 let summary = mgr.get("thread-1", "retry-task").await.unwrap();
1245 assert!(summary.persistence_error.is_none());
1246
1247 let after_retry = task_store
1248 .load_task("retry-task")
1249 .await
1250 .unwrap()
1251 .expect("task should exist");
1252 assert_eq!(after_retry.status, TaskStatus::Completed);
1253 assert_eq!(
1254 after_retry.result,
1255 Some(serde_json::json!({ "stdout": "done" }))
1256 );
1257 }
1258}