1use super::manager::BackgroundTaskManager;
6use super::{
7 derived_task_view_from_doc, BackgroundTaskView, BackgroundTaskViewAction,
8 BackgroundTaskViewState, TaskStore, TaskSummary,
9};
10use crate::contracts::runtime::behavior::{AgentBehavior, ReadOnlyContext};
11use crate::contracts::runtime::phase::{
12 ActionSet, AfterToolExecuteAction, BeforeInferenceAction, LifecycleAction,
13};
14use crate::contracts::runtime::state::AnyStateAction;
15use async_trait::async_trait;
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::{SystemTime, UNIX_EPOCH};
20
/// Identifier this behavior reports from [`AgentBehavior::id`].
pub const BACKGROUND_TASKS_PLUGIN_ID: &str = "background_tasks";

/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`; a millisecond count that does not fit in
/// `u64` saturates to `u64::MAX` instead of wrapping.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
30
31pub struct BackgroundTasksPlugin {
37 manager: Arc<BackgroundTaskManager>,
38 task_store: Option<Arc<TaskStore>>,
39}
40
41impl BackgroundTasksPlugin {
42 pub fn new(manager: Arc<BackgroundTaskManager>) -> Self {
43 Self {
44 manager,
45 task_store: None,
46 }
47 }
48
49 pub fn with_task_store(mut self, task_store: Option<Arc<TaskStore>>) -> Self {
50 self.task_store = task_store;
51 self
52 }
53
54 async fn collect_task_summaries(&self, thread_id: &str) -> Vec<TaskSummary> {
55 let mut by_id: HashMap<String, TaskSummary> = HashMap::new();
56
57 if let Some(task_store) = &self.task_store {
58 match task_store.list_tasks_for_owner(thread_id).await {
59 Ok(tasks) => {
60 for task in tasks {
61 by_id.insert(task.id.clone(), task.summary());
62 }
63 }
64 Err(error) => {
65 tracing::warn!(
66 owner_thread_id = %thread_id,
67 error = %error,
68 "failed to list persisted background tasks for derived task view"
69 );
70 }
71 }
72 }
73
74 for task in self.manager.list(thread_id, None).await {
75 by_id.insert(task.task_id.clone(), task);
76 }
77
78 let mut tasks: Vec<_> = by_id.into_values().collect();
79 tasks.sort_by(|a, b| {
80 a.created_at_ms
81 .cmp(&b.created_at_ms)
82 .then_with(|| a.task_id.cmp(&b.task_id))
83 });
84 tasks
85 }
86
87 fn derive_task_view(tasks: &[TaskSummary]) -> HashMap<String, BackgroundTaskView> {
88 tasks
89 .iter()
90 .filter(|task| !task.status.is_terminal())
91 .map(|task| (task.task_id.clone(), BackgroundTaskView::from_summary(task)))
92 .collect()
93 }
94
95 fn sync_action_if_changed(
96 &self,
97 snapshot: &Value,
98 tasks: &HashMap<String, BackgroundTaskView>,
99 ) -> Option<AnyStateAction> {
100 let current = derived_task_view_from_doc(snapshot);
101 if current.tasks == *tasks {
102 return None;
103 }
104
105 Some(AnyStateAction::new::<BackgroundTaskViewState>(
106 BackgroundTaskViewAction::Replace {
107 tasks: tasks.clone(),
108 synced_at_ms: now_ms(),
109 },
110 ))
111 }
112
113 fn render_task_view(tasks: &HashMap<String, BackgroundTaskView>) -> Option<String> {
114 if tasks.is_empty() {
115 return None;
116 }
117
118 let mut entries: Vec<_> = tasks.iter().collect();
119 entries.sort_by(|(left_id, _), (right_id, _)| left_id.cmp(right_id));
120
121 let mut out = String::new();
122 out.push_str("<background_tasks>\n");
123 for (task_id, task) in entries {
124 out.push_str(&format!(
125 "<task id=\"{}\" type=\"{}\" status=\"{}\" description=\"{}\"",
126 task_id,
127 task.task_type,
128 task.status.as_str(),
129 task.description,
130 ));
131 if let Some(parent_task_id) = task.parent_task_id.as_deref() {
132 out.push_str(&format!(" parent_task_id=\"{}\"", parent_task_id));
133 }
134 if let Some(agent_id) = task.agent_id.as_deref() {
135 out.push_str(&format!(" agent_id=\"{}\"", agent_id));
136 }
137 out.push_str("/>\n");
138 }
139 out.push_str("</background_tasks>\n");
140 out.push_str(
141 "Use tool \"task_status\" to check progress, \"task_output\" to read results, or \"task_cancel\" to cancel active tasks.",
142 );
143 Some(out)
144 }
145}
146
147#[async_trait]
148impl AgentBehavior for BackgroundTasksPlugin {
149 fn id(&self) -> &str {
150 BACKGROUND_TASKS_PLUGIN_ID
151 }
152
153 tirea_contract::declare_plugin_states!(BackgroundTaskViewState);
154
155 async fn run_start(&self, ctx: &ReadOnlyContext<'_>) -> ActionSet<LifecycleAction> {
156 let snapshot = ctx.snapshot();
157 let tasks = self.collect_task_summaries(ctx.thread_id()).await;
158 let view = Self::derive_task_view(&tasks);
159
160 self.sync_action_if_changed(&snapshot, &view)
161 .map(LifecycleAction::State)
162 .map(ActionSet::single)
163 .unwrap_or_else(ActionSet::empty)
164 }
165
166 async fn before_inference(
167 &self,
168 ctx: &ReadOnlyContext<'_>,
169 ) -> ActionSet<BeforeInferenceAction> {
170 let view = derived_task_view_from_doc(&ctx.snapshot());
171 Self::render_task_view(&view.tasks)
172 .map(BeforeInferenceAction::AddSystemContext)
173 .map(ActionSet::single)
174 .unwrap_or_else(ActionSet::empty)
175 }
176
177 async fn after_tool_execute(
178 &self,
179 ctx: &ReadOnlyContext<'_>,
180 ) -> ActionSet<AfterToolExecuteAction> {
181 let snapshot = ctx.snapshot();
182 let tasks = self.collect_task_summaries(ctx.thread_id()).await;
183 let view = Self::derive_task_view(&tasks);
184
185 let mut actions = ActionSet::empty();
186 if let Some(state) = self.sync_action_if_changed(&snapshot, &view) {
187 actions = actions.and(AfterToolExecuteAction::State(state));
188 }
189 if let Some(reminder) = Self::render_task_view(&view) {
190 actions = actions.and(AfterToolExecuteAction::AddSystemReminder(reminder));
191 }
192 actions
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::phase::Phase;
    use crate::contracts::runtime::state::{reduce_state_actions, ScopeContext};
    use crate::contracts::storage::{
        Committed, MessagePage, MessageQuery, RunPage, RunQuery, RunRecord, ThreadHead,
        ThreadListPage, ThreadListQuery, ThreadReader, ThreadStore, ThreadStoreError, ThreadWriter,
        VersionPrecondition,
    };
    use crate::contracts::thread::{Thread, ThreadChangeSet};
    use crate::contracts::RunPolicy;
    use async_trait::async_trait;
    use serde_json::{json, Value};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tirea_state::DocCell;

    /// Thread store that fails `list_threads` a configurable number of times and forwards
    /// every other call to an in-memory store.
    ///
    /// NOTE(review): assumes `TaskStore::list_tasks_for_owner` goes through `list_threads`.
    struct FailingTaskListStore {
        inner: Arc<tirea_store_adapters::MemoryStore>,
        /// Number of upcoming `list_threads` calls that should return an injected error.
        fail_task_lists: AtomicUsize,
    }

    #[async_trait]
    impl ThreadReader for FailingTaskListStore {
        async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
            self.inner.load(thread_id).await
        }

        /// Returns an injected I/O error while the failure budget is positive (consuming one
        /// unit per call), then delegates to the inner store.
        async fn list_threads(
            &self,
            query: &ThreadListQuery,
        ) -> Result<ThreadListPage, ThreadStoreError> {
            let should_fail = self
                .fail_task_lists
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
                    remaining.checked_sub(1)
                })
                .is_ok();
            if should_fail {
                return Err(ThreadStoreError::Io(std::io::Error::other(
                    "injected task list failure",
                )));
            }
            self.inner.list_threads(query).await
        }

        async fn load_messages(
            &self,
            thread_id: &str,
            query: &MessageQuery,
        ) -> Result<MessagePage, ThreadStoreError> {
            self.inner.load_messages(thread_id, query).await
        }

        async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, ThreadStoreError> {
            self.inner.load_run(run_id).await
        }

        async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, ThreadStoreError> {
            self.inner.list_runs(query).await
        }

        async fn active_run_for_thread(
            &self,
            thread_id: &str,
        ) -> Result<Option<RunRecord>, ThreadStoreError> {
            self.inner.active_run_for_thread(thread_id).await
        }
    }

    #[async_trait]
    impl ThreadWriter for FailingTaskListStore {
        async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
            self.inner.create(thread).await
        }

        async fn append(
            &self,
            thread_id: &str,
            delta: &ThreadChangeSet,
            precondition: VersionPrecondition,
        ) -> Result<Committed, ThreadStoreError> {
            self.inner.append(thread_id, delta, precondition).await
        }

        async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
            self.inner.delete(thread_id).await
        }

        async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
            self.inner.save(thread).await
        }
    }

    /// Builds a read-only context for `phase` with no messages.
    fn make_ctx<'a>(
        phase: Phase,
        thread_id: &'a str,
        run_policy: &'a RunPolicy,
        doc: &'a DocCell,
    ) -> ReadOnlyContext<'a> {
        ReadOnlyContext::new(phase, thread_id, &[], run_policy, doc)
    }

    /// Reduces `actions` against the current document and applies the resulting patch ops.
    fn apply_state_actions(doc: &DocCell, actions: Vec<AnyStateAction>) {
        if actions.is_empty() {
            return;
        }
        let snapshot = doc.snapshot();
        let patches = reduce_state_actions(actions, &snapshot, "test", &ScopeContext::run())
            .expect("state actions should reduce");
        for patch in patches {
            for op in patch.patch().ops() {
                doc.apply(op).expect("state patch op should apply");
            }
        }
    }

    /// Extracts the state actions from a lifecycle action set.
    fn lifecycle_state_actions(actions: ActionSet<LifecycleAction>) -> Vec<AnyStateAction> {
        actions
            .into_iter()
            .map(|action| match action {
                LifecycleAction::State(action) => action,
            })
            .collect()
    }

    /// Splits after-tool actions into state actions and reminder texts (user messages dropped).
    fn after_tool_parts(
        actions: ActionSet<AfterToolExecuteAction>,
    ) -> (Vec<AnyStateAction>, Vec<String>) {
        let mut state_actions = Vec::new();
        let mut reminders = Vec::new();
        for action in actions {
            match action {
                AfterToolExecuteAction::State(action) => state_actions.push(action),
                AfterToolExecuteAction::AddSystemReminder(text) => reminders.push(text),
                AfterToolExecuteAction::AddUserMessage(_) => {}
            }
        }
        (state_actions, reminders)
    }

    /// Splits before-inference actions into state actions and system-context texts.
    fn before_inference_parts(
        actions: ActionSet<BeforeInferenceAction>,
    ) -> (Vec<AnyStateAction>, Vec<String>) {
        let mut state_actions = Vec::new();
        let mut contexts = Vec::new();
        for action in actions {
            match action {
                BeforeInferenceAction::State(action) => state_actions.push(action),
                BeforeInferenceAction::AddSystemContext(text) => contexts.push(text),
                BeforeInferenceAction::AddSessionContext(_)
                | BeforeInferenceAction::ExcludeTool(_)
                | BeforeInferenceAction::IncludeOnlyTools(_)
                | BeforeInferenceAction::AddRequestTransform(_)
                | BeforeInferenceAction::Terminate(_) => {}
            }
        }
        (state_actions, contexts)
    }

    /// Reads the derived task view currently stored in `doc`.
    fn derived_view(doc: &DocCell) -> BackgroundTaskViewState {
        let snapshot = doc.snapshot();
        derived_task_view_from_doc(&snapshot)
    }

    /// Polls `mgr` until every task owned by `thread_id` has a terminal status.
    ///
    /// Used instead of a fixed sleep so tests do not depend on scheduler timing; panics if the
    /// tasks are still active after a generous timeout.
    async fn wait_until_all_terminal(mgr: &Arc<BackgroundTaskManager>, thread_id: &str) {
        let settle = async {
            loop {
                let all_terminal = mgr
                    .list(thread_id, None)
                    .await
                    .into_iter()
                    .all(|task| task.status.is_terminal());
                if all_terminal {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            }
        };
        tokio::time::timeout(std::time::Duration::from_secs(5), settle)
            .await
            .expect("background tasks should reach a terminal status");
    }

    #[test]
    fn plugin_id_is_background_tasks() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let plugin = BackgroundTasksPlugin::new(mgr);
        assert_eq!(plugin.id(), BACKGROUND_TASKS_PLUGIN_ID);
    }

    /// Smoke test: the state registration hooks run without panicking.
    #[test]
    fn plugin_registers_lattice_and_scope() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let plugin = BackgroundTasksPlugin::new(mgr);

        let mut lattice = tirea_state::LatticeRegistry::new();
        plugin.register_lattice_paths(&mut lattice);

        let mut scope_reg = tirea_contract::runtime::state::StateScopeRegistry::new();
        plugin.register_state_scopes(&mut scope_reg);

        let mut state_action_deserializer_registry =
            tirea_contract::runtime::state::StateActionDeserializerRegistry::new();
        plugin.register_state_action_deserializers(&mut state_action_deserializer_registry);
    }

    #[tokio::test]
    async fn run_start_syncs_derived_view_state_from_task_store() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let thread_store = Arc::new(tirea_store_adapters::MemoryStore::new());
        let task_store = Arc::new(TaskStore::new(thread_store as Arc<dyn ThreadStore>));
        task_store
            .create_task(super::super::NewTaskSpec {
                task_id: "task-1".to_string(),
                owner_thread_id: "thread-1".to_string(),
                task_type: "agent_run".to_string(),
                description: "delegate to writer".to_string(),
                parent_task_id: Some("root".to_string()),
                supports_resume: true,
                metadata: json!({"agent_id":"writer"}),
            })
            .await
            .expect("task should persist");

        let plugin = BackgroundTasksPlugin::new(mgr).with_task_store(Some(task_store));
        let doc = DocCell::new(json!({}));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::RunStart, "thread-1", &rc, &doc);

        let actions = plugin.run_start(&ctx).await;
        apply_state_actions(&doc, lifecycle_state_actions(actions));

        let derived = derived_view(&doc);
        let task = derived.tasks.get("task-1").expect("task view should exist");
        assert_eq!(task.task_type, "agent_run");
        assert_eq!(task.description, "delegate to writer");
        assert_eq!(task.status.as_str(), "running");
        assert_eq!(task.parent_task_id.as_deref(), Some("root"));
        assert_eq!(task.agent_id.as_deref(), Some("writer"));
    }

    #[tokio::test]
    async fn run_start_replaces_stale_derived_view_with_store_snapshot() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let thread_store = Arc::new(tirea_store_adapters::MemoryStore::new());
        let task_store = Arc::new(TaskStore::new(thread_store as Arc<dyn ThreadStore>));
        task_store
            .create_task(super::super::NewTaskSpec {
                task_id: "task-fresh".to_string(),
                owner_thread_id: "thread-1".to_string(),
                task_type: "shell".to_string(),
                description: "fresh task".to_string(),
                parent_task_id: None,
                supports_resume: false,
                metadata: json!({}),
            })
            .await
            .expect("task should persist");

        let plugin = BackgroundTasksPlugin::new(mgr).with_task_store(Some(task_store));
        let doc = DocCell::new(json!({
            "__derived": {
                "background_tasks": {
                    "tasks": {
                        "stale-task": {
                            "task_type": "shell",
                            "description": "stale task",
                            "status": "running"
                        }
                    },
                    "synced_at_ms": 1
                }
            }
        }));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::RunStart, "thread-1", &rc, &doc);

        let actions = plugin.run_start(&ctx).await;
        apply_state_actions(&doc, lifecycle_state_actions(actions));

        let derived = derived_view(&doc);
        assert!(!derived.tasks.contains_key("stale-task"));
        assert!(derived.tasks.contains_key("task-fresh"));
    }

    #[tokio::test]
    async fn run_start_falls_back_to_live_tasks_when_store_listing_fails() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let storage = Arc::new(FailingTaskListStore {
            inner: Arc::new(tirea_store_adapters::MemoryStore::new()),
            fail_task_lists: AtomicUsize::new(1),
        });
        let task_store = Arc::new(TaskStore::new(storage as Arc<dyn ThreadStore>));
        mgr.spawn("thread-1", "shell", "live task", |cancel| async move {
            cancel.cancelled().await;
            super::super::types::TaskResult::Cancelled
        })
        .await;

        let plugin = BackgroundTasksPlugin::new(mgr).with_task_store(Some(task_store));
        let doc = DocCell::new(json!({}));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::RunStart, "thread-1", &rc, &doc);

        let actions = plugin.run_start(&ctx).await;
        apply_state_actions(&doc, lifecycle_state_actions(actions));

        let derived = derived_view(&doc);
        assert_eq!(derived.tasks.len(), 1);
        let task = derived
            .tasks
            .values()
            .next()
            .expect("live manager task should be used when store listing fails");
        assert_eq!(task.description, "live task");
        assert_eq!(task.status.as_str(), "running");
    }

    #[tokio::test]
    async fn before_inference_uses_cached_view() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let plugin = BackgroundTasksPlugin::new(mgr);
        let doc = DocCell::new(json!({
            "__derived": {
                "background_tasks": {
                    "tasks": {
                        "task-1": {
                            "task_type": "agent_run",
                            "description": "delegate to writer",
                            "status": "running",
                            "parent_task_id": "root",
                            "agent_id": "writer"
                        }
                    },
                    "synced_at_ms": 123
                }
            }
        }));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::BeforeInference, "thread-1", &rc, &doc);

        let actions = plugin.before_inference(&ctx).await;
        let (state_actions, contexts) = before_inference_parts(actions);
        assert!(state_actions.is_empty());
        assert_eq!(contexts.len(), 1);
        assert!(contexts[0].contains("<background_tasks>"));
        assert!(contexts[0].contains("task-1"));
        assert!(contexts[0].contains("delegate to writer"));
        assert!(contexts[0].contains("task_cancel"));
    }

    #[tokio::test]
    async fn after_tool_execute_empty_when_no_tasks() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let plugin = BackgroundTasksPlugin::new(mgr);

        let doc = DocCell::new(json!({}));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::AfterToolExecute, "thread-1", &rc, &doc);

        let actions = plugin.after_tool_execute(&ctx).await;
        let (state_actions, reminders) = after_tool_parts(actions);
        assert!(state_actions.is_empty());
        assert!(reminders.is_empty());
    }

    #[tokio::test]
    async fn after_tool_execute_shows_running_tasks_and_updates_view() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        mgr.spawn(
            "thread-1",
            "shell",
            "building project",
            |cancel| async move {
                cancel.cancelled().await;
                super::super::types::TaskResult::Cancelled
            },
        )
        .await;

        let plugin = BackgroundTasksPlugin::new(mgr);
        let doc = DocCell::new(json!({}));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::AfterToolExecute, "thread-1", &rc, &doc);

        let actions = plugin.after_tool_execute(&ctx).await;
        let (state_actions, reminders) = after_tool_parts(actions);
        assert_eq!(reminders.len(), 1);
        assert!(reminders[0].contains("<background_tasks>"));
        assert!(reminders[0].contains("building project"));
        assert!(reminders[0].contains("task_status"));
        assert!(reminders[0].contains("task_output"));
        apply_state_actions(&doc, state_actions);

        let derived = derived_view(&doc);
        assert_eq!(derived.tasks.len(), 1);
        let task = derived
            .tasks
            .values()
            .find(|task| task.description == "building project")
            .expect("running task view should exist");
        assert_eq!(task.task_type, "shell");
        assert_eq!(task.status.as_str(), "running");
    }

    #[tokio::test]
    async fn after_tool_execute_ignores_completed_tasks() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        mgr.spawn("thread-1", "http", "fetch data", |_| async {
            super::super::types::TaskResult::Success(Value::Null)
        })
        .await;
        // Wait deterministically for completion instead of sleeping a fixed interval.
        wait_until_all_terminal(&mgr, "thread-1").await;

        let plugin = BackgroundTasksPlugin::new(mgr);
        let doc = DocCell::new(json!({}));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::AfterToolExecute, "thread-1", &rc, &doc);

        let actions = plugin.after_tool_execute(&ctx).await;
        let (state_actions, reminders) = after_tool_parts(actions);
        assert!(
            reminders.is_empty(),
            "completed tasks should not trigger reminder"
        );
        assert!(
            state_actions.is_empty(),
            "empty cached view should remain unchanged"
        );
    }

    #[tokio::test]
    async fn after_tool_execute_clears_stale_derived_view_when_no_tasks() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        let plugin = BackgroundTasksPlugin::new(mgr);
        let doc = DocCell::new(json!({
            "__derived": {
                "background_tasks": {
                    "tasks": {
                        "task-1": {
                            "task_type": "shell",
                            "description": "stale task",
                            "status": "running"
                        }
                    },
                    "synced_at_ms": 1
                }
            }
        }));
        let rc = RunPolicy::new();
        let ctx = make_ctx(Phase::AfterToolExecute, "thread-1", &rc, &doc);

        let actions = plugin.after_tool_execute(&ctx).await;
        let (state_actions, reminders) = after_tool_parts(actions);
        assert!(reminders.is_empty());
        assert_eq!(state_actions.len(), 1);
        apply_state_actions(&doc, state_actions);
        assert!(derived_view(&doc).tasks.is_empty());
    }

    #[tokio::test]
    async fn after_tool_execute_thread_isolation() {
        let mgr = Arc::new(BackgroundTaskManager::new());
        mgr.spawn("thread-A", "shell", "private task", |cancel| async move {
            cancel.cancelled().await;
            super::super::types::TaskResult::Cancelled
        })
        .await;

        let plugin = BackgroundTasksPlugin::new(mgr);
        let doc = DocCell::new(json!({}));
        let rc = RunPolicy::new();

        let ctx_b = make_ctx(Phase::AfterToolExecute, "thread-B", &rc, &doc);
        let actions = plugin.after_tool_execute(&ctx_b).await;
        let (state_actions, reminders) = after_tool_parts(actions);
        assert!(state_actions.is_empty());
        assert!(reminders.is_empty());

        let ctx_a = make_ctx(Phase::AfterToolExecute, "thread-A", &rc, &doc);
        let actions = plugin.after_tool_execute(&ctx_a).await;
        let (_, reminders) = after_tool_parts(actions);
        assert_eq!(reminders.len(), 1);
    }
}