tirea_agentos/runtime/background_tasks/
store.rs

1use super::types::{
2    new_task_id, task_thread_id, TaskAction, TaskId, TaskResultRef, TaskState, TaskStatus,
3    TaskSummary, TASK_THREAD_KIND_METADATA_KEY, TASK_THREAD_KIND_METADATA_VALUE,
4};
5use crate::contracts::runtime::state::{reduce_state_actions, AnyStateAction, ScopeContext};
6use crate::contracts::storage::{
7    ThreadListQuery, ThreadStore, ThreadStoreError, VersionPrecondition,
8};
9use crate::contracts::thread::{CheckpointReason, Message, Role, Thread, ThreadChangeSet};
10use serde_json::{json, Value};
11use std::sync::Arc;
12use thiserror::Error;
13use tirea_state::State;
14
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`, and a millisecond count
/// that does not fit in `u64` saturates at `u64::MAX`.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
23
/// Input describing a task to be registered through [`TaskStore::create_task`].
#[derive(Debug, Clone)]
pub struct NewTaskSpec {
    /// Identifier of the task; the task thread id is derived from it.
    pub task_id: TaskId,
    /// Thread that owns the task; stored as the task thread's parent thread id.
    pub owner_thread_id: String,
    /// Task kind; `"agent_run"` is treated specially when results are persisted.
    pub task_type: String,
    /// Description copied into the stored task state.
    pub description: String,
    /// Parent task, if any; used to build the task tree for descendant lookups.
    pub parent_task_id: Option<TaskId>,
    /// Copied into the stored task state's `supports_resume` field.
    pub supports_resume: bool,
    /// Arbitrary JSON metadata; for `"agent_run"` tasks a string `thread_id`
    /// entry is read to locate the task's output.
    pub metadata: Value,
}
34
/// Errors returned by [`TaskStore`] operations.
#[derive(Debug, Error)]
pub enum TaskStoreError {
    /// Failure reported by the underlying thread store.
    #[error(transparent)]
    ThreadStore(#[from] ThreadStoreError),
    /// Failure while rebuilding or reducing thread state.
    #[error(transparent)]
    State(#[from] tirea_state::TireaError),
    /// The task thread does not exist, or its state has no entry at `TaskState::PATH`.
    /// The payload is the task thread id.
    #[error("task thread '{0}' is missing durable task state")]
    MissingTaskState(String),
    /// A task is owned by a different thread than the caller expected.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error(
        "task '{task_id}' belongs to owner '{actual_owner_thread_id}' instead of '{expected_owner_thread_id}'"
    )]
    OwnerMismatch {
        task_id: String,
        expected_owner_thread_id: String,
        actual_owner_thread_id: String,
    },
    /// The value stored at `TaskState::PATH` could not be decoded into a `TaskState`.
    #[error("task thread '{thread_id}' contains invalid durable task state")]
    InvalidTaskState {
        thread_id: String,
        #[source]
        error: tirea_state::TireaError,
    },
}
58
/// Persists background task state on dedicated task threads inside a [`ThreadStore`].
///
/// Cloning copies only the `Arc` handle to the shared store.
#[derive(Clone)]
pub struct TaskStore {
    /// Thread store that holds the task threads.
    threads: Arc<dyn ThreadStore>,
}
63
64impl std::fmt::Debug for TaskStore {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("TaskStore").finish()
67    }
68}
69
70impl TaskStore {
    /// Creates a task store backed by the given thread store.
    pub fn new(threads: Arc<dyn ThreadStore>) -> Self {
        Self { threads }
    }
74
    /// Returns the id of the thread that stores state for `task_id`
    /// (delegates to `task_thread_id`).
    pub fn thread_id_for(task_id: &str) -> String {
        task_thread_id(task_id)
    }
78
    /// Allocates a new task id (delegates to `new_task_id`).
    pub fn alloc_task_id() -> TaskId {
        new_task_id()
    }
82
83    pub async fn create_task(&self, spec: NewTaskSpec) -> Result<TaskState, TaskStoreError> {
84        let task_id = spec.task_id.clone();
85        let thread_id = task_thread_id(&task_id);
86        let created_at_ms = now_ms();
87        let mut thread =
88            Thread::new(thread_id.clone()).with_parent_thread_id(spec.owner_thread_id.clone());
89        thread.metadata.extra.insert(
90            TASK_THREAD_KIND_METADATA_KEY.to_string(),
91            json!(TASK_THREAD_KIND_METADATA_VALUE),
92        );
93        thread
94            .metadata
95            .extra
96            .insert("task_id".to_string(), json!(spec.task_id));
97        thread
98            .metadata
99            .extra
100            .insert("owner_thread_id".to_string(), json!(spec.owner_thread_id));
101        self.threads.create(&thread).await?;
102
103        let task = TaskState {
104            id: spec.task_id,
105            task_type: spec.task_type,
106            description: spec.description,
107            owner_thread_id: spec.owner_thread_id,
108            parent_task_id: spec.parent_task_id,
109            status: TaskStatus::Running,
110            error: None,
111            result: None,
112            result_ref: None,
113            checkpoint: None,
114            supports_resume: spec.supports_resume,
115            attempt: 1,
116            created_at_ms,
117            updated_at_ms: created_at_ms,
118            completed_at_ms: None,
119            cancel_requested_at_ms: None,
120            metadata: spec.metadata,
121        };
122        self.append_task_action(
123            &thread_id,
124            task_id.as_str(),
125            TaskAction::Register {
126                task: Box::new(task.clone()),
127            },
128            Some(Message::internal_system(format!(
129                "background task {} registered as running",
130                task.id
131            ))),
132        )
133        .await?;
134        Ok(task)
135    }
136
137    pub async fn load_task(&self, task_id: &str) -> Result<Option<TaskState>, TaskStoreError> {
138        let Some(head) = self.threads.load(&task_thread_id(task_id)).await? else {
139            return Ok(None);
140        };
141        Ok(Some(Self::task_state_from_thread(&head.thread)?))
142    }
143
144    pub async fn load_task_for_owner(
145        &self,
146        owner_thread_id: &str,
147        task_id: &str,
148    ) -> Result<Option<TaskState>, TaskStoreError> {
149        let Some(task) = self.load_task(task_id).await? else {
150            return Ok(None);
151        };
152        if task.owner_thread_id != owner_thread_id {
153            return Ok(None);
154        }
155        Ok(Some(task))
156    }
157
158    pub async fn list_tasks_for_owner(
159        &self,
160        owner_thread_id: &str,
161    ) -> Result<Vec<TaskState>, TaskStoreError> {
162        let mut offset = 0usize;
163        let mut out = Vec::new();
164        loop {
165            let page = self
166                .threads
167                .list_threads(&ThreadListQuery {
168                    offset,
169                    limit: 200,
170                    resource_id: None,
171                    parent_thread_id: Some(owner_thread_id.to_string()),
172                })
173                .await?;
174            for thread_id in &page.items {
175                let Some(head) = self.threads.load(thread_id).await? else {
176                    continue;
177                };
178                if !Self::is_task_thread(&head.thread) {
179                    continue;
180                }
181                let task = Self::task_state_from_thread(&head.thread)?;
182                if task.owner_thread_id == owner_thread_id {
183                    out.push(task);
184                }
185            }
186            if !page.has_more {
187                break;
188            }
189            offset += page.items.len();
190        }
191        out.sort_by(|a, b| a.created_at_ms.cmp(&b.created_at_ms));
192        Ok(out)
193    }
194
195    pub async fn start_task_attempt(&self, task_id: &str) -> Result<TaskState, TaskStoreError> {
196        let thread_id = task_thread_id(task_id);
197        let task = self
198            .load_task(task_id)
199            .await?
200            .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.clone()))?;
201        let next_attempt = task.attempt.max(1) + 1;
202        self.append_task_action(
203            &thread_id,
204            task_id,
205            TaskAction::StartAttempt {
206                attempt: next_attempt,
207                updated_at_ms: now_ms(),
208            },
209            Some(Message::internal_system(format!(
210                "background task {} resumed (attempt {})",
211                task_id, next_attempt
212            ))),
213        )
214        .await?;
215        self.load_task(task_id)
216            .await?
217            .ok_or(TaskStoreError::MissingTaskState(thread_id))
218    }
219
220    pub async fn mark_cancel_requested(&self, task_id: &str) -> Result<(), TaskStoreError> {
221        let thread_id = task_thread_id(task_id);
222        self.append_task_action(
223            &thread_id,
224            task_id,
225            TaskAction::MarkCancelRequested {
226                requested_at_ms: now_ms(),
227            },
228            Some(Message::internal_system(format!(
229                "background task {} cancellation requested",
230                task_id
231            ))),
232        )
233        .await
234    }
235
236    pub async fn set_checkpoint(
237        &self,
238        task_id: &str,
239        checkpoint: Value,
240    ) -> Result<(), TaskStoreError> {
241        let thread_id = task_thread_id(task_id);
242        self.append_task_action(
243            &thread_id,
244            task_id,
245            TaskAction::SetCheckpoint {
246                checkpoint,
247                updated_at_ms: now_ms(),
248            },
249            None,
250        )
251        .await
252    }
253
254    pub async fn persist_summary(&self, summary: &TaskSummary) -> Result<(), TaskStoreError> {
255        let thread_id = task_thread_id(&summary.task_id);
256        let task = self
257            .load_task(&summary.task_id)
258            .await?
259            .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.clone()))?;
260        let result_ref = if summary.task_type == "agent_run"
261            && matches!(summary.status, TaskStatus::Completed | TaskStatus::Stopped)
262        {
263            self.resolve_agent_output_ref(&task).await?
264        } else {
265            None
266        };
267
268        self.append_task_action(
269            &thread_id,
270            &summary.task_id,
271            TaskAction::SetStatus {
272                status: summary.status,
273                error: summary.error.clone(),
274                result: if summary.task_type == "agent_run" {
275                    None
276                } else {
277                    summary.result.clone()
278                },
279                result_ref,
280                completed_at_ms: summary.completed_at_ms.or_else(|| Some(now_ms())),
281                updated_at_ms: now_ms(),
282            },
283            Some(Message::internal_system(format!(
284                "background task {} finished with status {}",
285                summary.task_id,
286                summary.status.as_str()
287            ))),
288        )
289        .await
290    }
291
292    pub async fn persist_foreground_result(
293        &self,
294        task_id: &str,
295        status: TaskStatus,
296        error: Option<String>,
297        result: Option<Value>,
298    ) -> Result<(), TaskStoreError> {
299        let thread_id = task_thread_id(task_id);
300        let task = self
301            .load_task(task_id)
302            .await?
303            .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.clone()))?;
304        let result_ref = if task.task_type == "agent_run"
305            && matches!(status, TaskStatus::Completed | TaskStatus::Stopped)
306        {
307            self.resolve_agent_output_ref(&task).await?
308        } else {
309            None
310        };
311
312        self.append_task_action(
313            &thread_id,
314            task_id,
315            TaskAction::SetStatus {
316                status,
317                error,
318                result: if task.task_type == "agent_run" {
319                    None
320                } else {
321                    result
322                },
323                result_ref,
324                completed_at_ms: Some(now_ms()),
325                updated_at_ms: now_ms(),
326            },
327            Some(Message::internal_system(format!(
328                "background task {} persisted terminal status {}",
329                task_id,
330                status.as_str()
331            ))),
332        )
333        .await
334    }
335
336    pub async fn load_output_text(
337        &self,
338        task: &TaskState,
339    ) -> Result<Option<String>, TaskStoreError> {
340        let Some(result_ref) = task.result_ref.as_ref() else {
341            if task.task_type == "agent_run" {
342                if let Some(thread_id) = task.metadata.get("thread_id").and_then(Value::as_str) {
343                    return self.load_thread_message_text(thread_id, None).await;
344                }
345            }
346            return Ok(None);
347        };
348        match result_ref {
349            TaskResultRef::ThreadMessage {
350                thread_id,
351                message_id,
352            } => {
353                self.load_thread_message_text(thread_id, message_id.as_deref())
354                    .await
355            }
356            TaskResultRef::External { uri } => Ok(Some(uri.clone())),
357        }
358    }
359
360    pub async fn descendant_ids_for_owner(
361        &self,
362        owner_thread_id: &str,
363        root_task_id: &str,
364    ) -> Result<Vec<TaskId>, TaskStoreError> {
365        let tasks = self.list_tasks_for_owner(owner_thread_id).await?;
366        let mut by_parent: std::collections::HashMap<String, Vec<String>> =
367            std::collections::HashMap::new();
368        for task in &tasks {
369            if let Some(parent) = task.parent_task_id.as_ref() {
370                by_parent
371                    .entry(parent.clone())
372                    .or_default()
373                    .push(task.id.clone());
374            }
375        }
376        let mut out = Vec::new();
377        let mut stack = vec![root_task_id.to_string()];
378        while let Some(current) = stack.pop() {
379            if tasks.iter().any(|task| task.id == current) {
380                out.push(current.clone());
381            }
382            if let Some(children) = by_parent.get(&current) {
383                for child in children {
384                    stack.push(child.clone());
385                }
386            }
387        }
388        Ok(out)
389    }
390
391    fn is_task_thread(thread: &Thread) -> bool {
392        thread
393            .metadata
394            .extra
395            .get(TASK_THREAD_KIND_METADATA_KEY)
396            .and_then(Value::as_str)
397            == Some(TASK_THREAD_KIND_METADATA_VALUE)
398    }
399
400    fn task_state_from_thread(thread: &Thread) -> Result<TaskState, TaskStoreError> {
401        let snapshot = thread.rebuild_state()?;
402        let Some(value) = snapshot.get(TaskState::PATH) else {
403            return Err(TaskStoreError::MissingTaskState(thread.id.clone()));
404        };
405        TaskState::from_value(value).map_err(|error| TaskStoreError::InvalidTaskState {
406            thread_id: thread.id.clone(),
407            error,
408        })
409    }
410
411    async fn resolve_agent_output_ref(
412        &self,
413        task: &TaskState,
414    ) -> Result<Option<TaskResultRef>, TaskStoreError> {
415        let Some(thread_id) = task.metadata.get("thread_id").and_then(Value::as_str) else {
416            return Ok(None);
417        };
418        let Some(head) = self.threads.load(thread_id).await? else {
419            return Ok(Some(TaskResultRef::ThreadMessage {
420                thread_id: thread_id.to_string(),
421                message_id: None,
422            }));
423        };
424        let message_id = head
425            .thread
426            .messages
427            .iter()
428            .rev()
429            .find(|m| m.role == Role::Assistant)
430            .and_then(|m| m.id.clone());
431        Ok(Some(TaskResultRef::ThreadMessage {
432            thread_id: thread_id.to_string(),
433            message_id,
434        }))
435    }
436
437    async fn load_thread_message_text(
438        &self,
439        thread_id: &str,
440        message_id: Option<&str>,
441    ) -> Result<Option<String>, TaskStoreError> {
442        let Some(head) = self.threads.load(thread_id).await? else {
443            return Ok(None);
444        };
445        let msg = if let Some(message_id) = message_id {
446            head.thread
447                .messages
448                .iter()
449                .find(|m| m.id.as_deref() == Some(message_id))
450                .map(|m| m.content.clone())
451        } else {
452            head.thread
453                .messages
454                .iter()
455                .rev()
456                .find(|m| m.role == Role::Assistant)
457                .map(|m| m.content.clone())
458        };
459        Ok(msg)
460    }
461
462    async fn append_task_action(
463        &self,
464        thread_id: &str,
465        task_id: &str,
466        action: TaskAction,
467        audit_message: Option<Message>,
468    ) -> Result<(), TaskStoreError> {
469        let head = self
470            .threads
471            .load(thread_id)
472            .await?
473            .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.to_string()))?;
474        let mut snapshot = head.thread.rebuild_state()?;
475        if snapshot.get(TaskState::PATH).is_none() {
476            let default_task = serde_json::to_value(TaskState::default())
477                .map_err(tirea_state::TireaError::from)?;
478            match snapshot.as_object_mut() {
479                Some(obj) => {
480                    obj.insert(TaskState::PATH.to_string(), default_task);
481                }
482                None => {
483                    snapshot = json!({ TaskState::PATH: default_task });
484                }
485            }
486        }
487        let state_action = AnyStateAction::new::<TaskState>(action);
488        let serialized = vec![state_action.to_serialized_state_action()];
489        let patches = reduce_state_actions(
490            vec![state_action],
491            &snapshot,
492            "background_task",
493            &ScopeContext::run(),
494        )?;
495        let changeset = ThreadChangeSet::from_parts(
496            task_id.to_string(),
497            None,
498            CheckpointReason::ToolResultsCommitted,
499            audit_message.into_iter().map(std::sync::Arc::new).collect(),
500            patches,
501            serialized,
502            None,
503        );
504        self.threads
505            .append(
506                thread_id,
507                &changeset,
508                VersionPrecondition::Exact(head.version),
509            )
510            .await?;
511        Ok(())
512    }
513}
514
515#[cfg(test)]
516mod tests {
517    use super::*;
518    use crate::contracts::storage::{ThreadReader, ThreadStore, ThreadWriter};
519
520    #[tokio::test]
521    async fn create_task_persists_task_thread_state() {
522        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
523        let store = TaskStore::new(storage.clone() as Arc<dyn ThreadStore>);
524
525        let task = store
526            .create_task(NewTaskSpec {
527                task_id: "task-1".to_string(),
528                owner_thread_id: "owner-1".to_string(),
529                task_type: "shell".to_string(),
530                description: "echo hi".to_string(),
531                parent_task_id: Some("root".to_string()),
532                supports_resume: false,
533                metadata: json!({"kind":"test"}),
534            })
535            .await
536            .expect("task should persist");
537
538        assert_eq!(task.id, "task-1");
539        assert_eq!(task.status, TaskStatus::Running);
540        assert_eq!(task.parent_task_id.as_deref(), Some("root"));
541
542        let loaded = store
543            .load_task("task-1")
544            .await
545            .expect("load should succeed")
546            .expect("task should exist");
547        assert_eq!(loaded.id, "task-1");
548        assert_eq!(loaded.owner_thread_id, "owner-1");
549        assert_eq!(loaded.metadata["kind"], json!("test"));
550
551        let head = storage
552            .load(&task_thread_id("task-1"))
553            .await
554            .expect("thread load should succeed")
555            .expect("task thread should exist");
556        assert_eq!(head.thread.parent_thread_id.as_deref(), Some("owner-1"));
557    }
558
559    #[tokio::test]
560    async fn list_tasks_for_owner_ignores_non_task_children() {
561        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
562        let store = TaskStore::new(storage.clone() as Arc<dyn ThreadStore>);
563
564        storage
565            .create(&Thread::new("child-thread").with_parent_thread_id("owner-1"))
566            .await
567            .expect("non-task child thread should persist");
568
569        store
570            .create_task(NewTaskSpec {
571                task_id: "task-1".to_string(),
572                owner_thread_id: "owner-1".to_string(),
573                task_type: "shell".to_string(),
574                description: "owner one".to_string(),
575                parent_task_id: None,
576                supports_resume: false,
577                metadata: json!({}),
578            })
579            .await
580            .expect("owner task should persist");
581        store
582            .create_task(NewTaskSpec {
583                task_id: "task-2".to_string(),
584                owner_thread_id: "owner-2".to_string(),
585                task_type: "shell".to_string(),
586                description: "owner two".to_string(),
587                parent_task_id: None,
588                supports_resume: false,
589                metadata: json!({}),
590            })
591            .await
592            .expect("other owner task should persist");
593
594        let tasks = store
595            .list_tasks_for_owner("owner-1")
596            .await
597            .expect("list should succeed");
598
599        assert_eq!(tasks.len(), 1);
600        assert_eq!(tasks[0].id, "task-1");
601    }
602
603    #[tokio::test]
604    async fn mark_cancel_requested_persists_timestamp() {
605        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
606        let store = TaskStore::new(storage as Arc<dyn ThreadStore>);
607
608        let task = store
609            .create_task(NewTaskSpec {
610                task_id: "task-1".to_string(),
611                owner_thread_id: "owner-1".to_string(),
612                task_type: "shell".to_string(),
613                description: "cancel me".to_string(),
614                parent_task_id: None,
615                supports_resume: false,
616                metadata: json!({}),
617            })
618            .await
619            .expect("task should persist");
620
621        store
622            .mark_cancel_requested("task-1")
623            .await
624            .expect("cancel request should persist");
625
626        let loaded = store
627            .load_task("task-1")
628            .await
629            .expect("load should succeed")
630            .expect("task should exist");
631        assert!(loaded.cancel_requested_at_ms.is_some());
632        assert!(loaded.updated_at_ms >= task.updated_at_ms);
633    }
634
635    #[tokio::test]
636    async fn persist_summary_for_agent_run_captures_output_ref() {
637        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
638        let store = TaskStore::new(storage.clone() as Arc<dyn ThreadStore>);
639
640        storage
641            .create(
642                &Thread::new("exec-1")
643                    .with_message(Message::assistant("draft").with_id("msg-1".to_string()))
644                    .with_message(Message::assistant("final").with_id("msg-2".to_string())),
645            )
646            .await
647            .expect("execution thread should persist");
648
649        let task = store
650            .create_task(NewTaskSpec {
651                task_id: "task-1".to_string(),
652                owner_thread_id: "owner-1".to_string(),
653                task_type: "agent_run".to_string(),
654                description: "delegate".to_string(),
655                parent_task_id: None,
656                supports_resume: true,
657                metadata: json!({"thread_id":"exec-1","agent_id":"writer"}),
658            })
659            .await
660            .expect("task should persist");
661
662        let mut summary = task.summary();
663        summary.status = TaskStatus::Completed;
664        summary.completed_at_ms = Some(task.created_at_ms + 1);
665
666        store
667            .persist_summary(&summary)
668            .await
669            .expect("summary should persist");
670
671        let loaded = store
672            .load_task("task-1")
673            .await
674            .expect("load should succeed")
675            .expect("task should exist");
676        assert_eq!(loaded.status, TaskStatus::Completed);
677        assert_eq!(
678            loaded.result_ref,
679            Some(TaskResultRef::ThreadMessage {
680                thread_id: "exec-1".to_string(),
681                message_id: Some("msg-2".to_string()),
682            })
683        );
684        assert_eq!(
685            store
686                .load_output_text(&loaded)
687                .await
688                .expect("output should load")
689                .as_deref(),
690            Some("final")
691        );
692    }
693
694    #[tokio::test]
695    async fn descendant_ids_for_owner_returns_only_owner_subtree() {
696        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
697        let store = TaskStore::new(storage as Arc<dyn ThreadStore>);
698
699        for (task_id, owner, parent) in [
700            ("root", "owner-1", None),
701            ("child", "owner-1", Some("root")),
702            ("grandchild", "owner-1", Some("child")),
703            ("other-root", "owner-2", None),
704            ("other-child", "owner-2", Some("root")),
705        ] {
706            store
707                .create_task(NewTaskSpec {
708                    task_id: task_id.to_string(),
709                    owner_thread_id: owner.to_string(),
710                    task_type: "agent_run".to_string(),
711                    description: task_id.to_string(),
712                    parent_task_id: parent.map(str::to_string),
713                    supports_resume: true,
714                    metadata: json!({}),
715                })
716                .await
717                .expect("task should persist");
718        }
719
720        let mut descendants = store
721            .descendant_ids_for_owner("owner-1", "root")
722            .await
723            .expect("descendants should load");
724        descendants.sort();
725
726        assert_eq!(descendants, vec!["child", "grandchild", "root"]);
727    }
728
729    #[tokio::test]
730    async fn load_task_reports_invalid_task_state() {
731        let storage = Arc::new(tirea_store_adapters::MemoryStore::new());
732        let store = TaskStore::new(storage.clone() as Arc<dyn ThreadStore>);
733        let thread_id = task_thread_id("broken-task");
734
735        let mut thread = Thread::with_initial_state(
736            thread_id.clone(),
737            json!({
738                TaskState::PATH: {
739                    "id": "broken-task",
740                    "status": 123
741                }
742            }),
743        )
744        .with_parent_thread_id("owner-1");
745        thread.metadata.extra.insert(
746            TASK_THREAD_KIND_METADATA_KEY.to_string(),
747            json!(TASK_THREAD_KIND_METADATA_VALUE),
748        );
749        storage
750            .create(&thread)
751            .await
752            .expect("broken task thread should persist");
753
754        let err = store
755            .load_task("broken-task")
756            .await
757            .expect_err("invalid task state should error");
758
759        match err {
760            TaskStoreError::InvalidTaskState {
761                thread_id: err_thread_id,
762                ..
763            } => assert_eq!(err_thread_id, thread_id),
764            other => panic!("expected InvalidTaskState, got {other:?}"),
765        }
766    }
767}