1use super::types::{
2 new_task_id, task_thread_id, TaskAction, TaskId, TaskResultRef, TaskState, TaskStatus,
3 TaskSummary, TASK_THREAD_KIND_METADATA_KEY, TASK_THREAD_KIND_METADATA_VALUE,
4};
5use crate::contracts::runtime::state::{reduce_state_actions, AnyStateAction, ScopeContext};
6use crate::contracts::storage::{
7 ThreadListQuery, ThreadStore, ThreadStoreError, VersionPrecondition,
8};
9use crate::contracts::thread::{CheckpointReason, Message, Role, Thread, ThreadChangeSet};
10use serde_json::{json, Value};
11use std::sync::Arc;
12use thiserror::Error;
13use tirea_state::State;
14
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. A millisecond count that does not fit in a
/// `u64` saturates at `u64::MAX` instead of wrapping.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Default::default(),
    };
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
23
/// Input for [`TaskStore::create_task`]: everything needed to register a new task.
#[derive(Debug, Clone)]
pub struct NewTaskSpec {
    /// Task id. The task's backing thread id is derived from it via `task_thread_id`.
    pub task_id: TaskId,
    /// Thread that owns the task. It becomes the task thread's parent thread and is also
    /// recorded under the `owner_thread_id` metadata key.
    pub owner_thread_id: String,
    /// Free-form task type. `TaskStore` special-cases `"agent_run"`, whose output is
    /// referenced from an execution thread rather than stored inline.
    pub task_type: String,
    /// Human-readable description copied into the task state.
    pub description: String,
    /// Optional parent task, used to build task subtrees (see `descendant_ids_for_owner`).
    pub parent_task_id: Option<TaskId>,
    /// Copied verbatim into the task state's `supports_resume` flag.
    pub supports_resume: bool,
    /// Arbitrary task metadata. For `"agent_run"` tasks, `TaskStore` reads the `thread_id`
    /// key to locate the execution thread holding the output.
    pub metadata: Value,
}
34
/// Errors produced by [`TaskStore`] operations.
#[derive(Debug, Error)]
pub enum TaskStoreError {
    /// The underlying thread store failed (load, create, list or append).
    #[error(transparent)]
    ThreadStore(#[from] ThreadStoreError),
    /// Rebuilding or reducing thread state failed.
    #[error(transparent)]
    State(#[from] tirea_state::TireaError),
    /// The task thread does not exist, or it has no value at `TaskState::PATH`.
    /// Carries the task thread id.
    #[error("task thread '{0}' is missing durable task state")]
    MissingTaskState(String),
    /// A task was accessed on behalf of a thread that does not own it.
    /// NOTE(review): not constructed by `TaskStore` itself in this file
    /// (`load_task_for_owner` returns `Ok(None)` on mismatch); presumably raised by callers.
    #[error(
        "task '{task_id}' belongs to owner '{actual_owner_thread_id}' instead of '{expected_owner_thread_id}'"
    )]
    OwnerMismatch {
        task_id: String,
        expected_owner_thread_id: String,
        actual_owner_thread_id: String,
    },
    /// The value stored at `TaskState::PATH` could not be decoded into a `TaskState`.
    #[error("task thread '{thread_id}' contains invalid durable task state")]
    InvalidTaskState {
        thread_id: String,
        #[source]
        error: tirea_state::TireaError,
    },
}
58
/// Durable registry of background tasks.
///
/// Each task is persisted as its own thread in the wrapped [`ThreadStore`]. The task state
/// lives at `TaskState::PATH` and every mutation is appended as a state action. Cloning is
/// cheap: clones share the same underlying store through the `Arc`.
#[derive(Clone)]
pub struct TaskStore {
    /// Backing thread storage shared by all clones.
    threads: Arc<dyn ThreadStore>,
}
63
64impl std::fmt::Debug for TaskStore {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("TaskStore").finish()
67 }
68}
69
70impl TaskStore {
71 pub fn new(threads: Arc<dyn ThreadStore>) -> Self {
72 Self { threads }
73 }
74
75 pub fn thread_id_for(task_id: &str) -> String {
76 task_thread_id(task_id)
77 }
78
79 pub fn alloc_task_id() -> TaskId {
80 new_task_id()
81 }
82
83 pub async fn create_task(&self, spec: NewTaskSpec) -> Result<TaskState, TaskStoreError> {
84 let task_id = spec.task_id.clone();
85 let thread_id = task_thread_id(&task_id);
86 let created_at_ms = now_ms();
87 let mut thread =
88 Thread::new(thread_id.clone()).with_parent_thread_id(spec.owner_thread_id.clone());
89 thread.metadata.extra.insert(
90 TASK_THREAD_KIND_METADATA_KEY.to_string(),
91 json!(TASK_THREAD_KIND_METADATA_VALUE),
92 );
93 thread
94 .metadata
95 .extra
96 .insert("task_id".to_string(), json!(spec.task_id));
97 thread
98 .metadata
99 .extra
100 .insert("owner_thread_id".to_string(), json!(spec.owner_thread_id));
101 self.threads.create(&thread).await?;
102
103 let task = TaskState {
104 id: spec.task_id,
105 task_type: spec.task_type,
106 description: spec.description,
107 owner_thread_id: spec.owner_thread_id,
108 parent_task_id: spec.parent_task_id,
109 status: TaskStatus::Running,
110 error: None,
111 result: None,
112 result_ref: None,
113 checkpoint: None,
114 supports_resume: spec.supports_resume,
115 attempt: 1,
116 created_at_ms,
117 updated_at_ms: created_at_ms,
118 completed_at_ms: None,
119 cancel_requested_at_ms: None,
120 metadata: spec.metadata,
121 };
122 self.append_task_action(
123 &thread_id,
124 task_id.as_str(),
125 TaskAction::Register {
126 task: Box::new(task.clone()),
127 },
128 Some(Message::internal_system(format!(
129 "background task {} registered as running",
130 task.id
131 ))),
132 )
133 .await?;
134 Ok(task)
135 }
136
137 pub async fn load_task(&self, task_id: &str) -> Result<Option<TaskState>, TaskStoreError> {
138 let Some(head) = self.threads.load(&task_thread_id(task_id)).await? else {
139 return Ok(None);
140 };
141 Ok(Some(Self::task_state_from_thread(&head.thread)?))
142 }
143
144 pub async fn load_task_for_owner(
145 &self,
146 owner_thread_id: &str,
147 task_id: &str,
148 ) -> Result<Option<TaskState>, TaskStoreError> {
149 let Some(task) = self.load_task(task_id).await? else {
150 return Ok(None);
151 };
152 if task.owner_thread_id != owner_thread_id {
153 return Ok(None);
154 }
155 Ok(Some(task))
156 }
157
158 pub async fn list_tasks_for_owner(
159 &self,
160 owner_thread_id: &str,
161 ) -> Result<Vec<TaskState>, TaskStoreError> {
162 let mut offset = 0usize;
163 let mut out = Vec::new();
164 loop {
165 let page = self
166 .threads
167 .list_threads(&ThreadListQuery {
168 offset,
169 limit: 200,
170 resource_id: None,
171 parent_thread_id: Some(owner_thread_id.to_string()),
172 })
173 .await?;
174 for thread_id in &page.items {
175 let Some(head) = self.threads.load(thread_id).await? else {
176 continue;
177 };
178 if !Self::is_task_thread(&head.thread) {
179 continue;
180 }
181 let task = Self::task_state_from_thread(&head.thread)?;
182 if task.owner_thread_id == owner_thread_id {
183 out.push(task);
184 }
185 }
186 if !page.has_more {
187 break;
188 }
189 offset += page.items.len();
190 }
191 out.sort_by(|a, b| a.created_at_ms.cmp(&b.created_at_ms));
192 Ok(out)
193 }
194
195 pub async fn start_task_attempt(&self, task_id: &str) -> Result<TaskState, TaskStoreError> {
196 let thread_id = task_thread_id(task_id);
197 let task = self
198 .load_task(task_id)
199 .await?
200 .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.clone()))?;
201 let next_attempt = task.attempt.max(1) + 1;
202 self.append_task_action(
203 &thread_id,
204 task_id,
205 TaskAction::StartAttempt {
206 attempt: next_attempt,
207 updated_at_ms: now_ms(),
208 },
209 Some(Message::internal_system(format!(
210 "background task {} resumed (attempt {})",
211 task_id, next_attempt
212 ))),
213 )
214 .await?;
215 self.load_task(task_id)
216 .await?
217 .ok_or(TaskStoreError::MissingTaskState(thread_id))
218 }
219
220 pub async fn mark_cancel_requested(&self, task_id: &str) -> Result<(), TaskStoreError> {
221 let thread_id = task_thread_id(task_id);
222 self.append_task_action(
223 &thread_id,
224 task_id,
225 TaskAction::MarkCancelRequested {
226 requested_at_ms: now_ms(),
227 },
228 Some(Message::internal_system(format!(
229 "background task {} cancellation requested",
230 task_id
231 ))),
232 )
233 .await
234 }
235
236 pub async fn set_checkpoint(
237 &self,
238 task_id: &str,
239 checkpoint: Value,
240 ) -> Result<(), TaskStoreError> {
241 let thread_id = task_thread_id(task_id);
242 self.append_task_action(
243 &thread_id,
244 task_id,
245 TaskAction::SetCheckpoint {
246 checkpoint,
247 updated_at_ms: now_ms(),
248 },
249 None,
250 )
251 .await
252 }
253
254 pub async fn persist_summary(&self, summary: &TaskSummary) -> Result<(), TaskStoreError> {
255 let thread_id = task_thread_id(&summary.task_id);
256 let task = self
257 .load_task(&summary.task_id)
258 .await?
259 .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.clone()))?;
260 let result_ref = if summary.task_type == "agent_run"
261 && matches!(summary.status, TaskStatus::Completed | TaskStatus::Stopped)
262 {
263 self.resolve_agent_output_ref(&task).await?
264 } else {
265 None
266 };
267
268 self.append_task_action(
269 &thread_id,
270 &summary.task_id,
271 TaskAction::SetStatus {
272 status: summary.status,
273 error: summary.error.clone(),
274 result: if summary.task_type == "agent_run" {
275 None
276 } else {
277 summary.result.clone()
278 },
279 result_ref,
280 completed_at_ms: summary.completed_at_ms.or_else(|| Some(now_ms())),
281 updated_at_ms: now_ms(),
282 },
283 Some(Message::internal_system(format!(
284 "background task {} finished with status {}",
285 summary.task_id,
286 summary.status.as_str()
287 ))),
288 )
289 .await
290 }
291
292 pub async fn persist_foreground_result(
293 &self,
294 task_id: &str,
295 status: TaskStatus,
296 error: Option<String>,
297 result: Option<Value>,
298 ) -> Result<(), TaskStoreError> {
299 let thread_id = task_thread_id(task_id);
300 let task = self
301 .load_task(task_id)
302 .await?
303 .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.clone()))?;
304 let result_ref = if task.task_type == "agent_run"
305 && matches!(status, TaskStatus::Completed | TaskStatus::Stopped)
306 {
307 self.resolve_agent_output_ref(&task).await?
308 } else {
309 None
310 };
311
312 self.append_task_action(
313 &thread_id,
314 task_id,
315 TaskAction::SetStatus {
316 status,
317 error,
318 result: if task.task_type == "agent_run" {
319 None
320 } else {
321 result
322 },
323 result_ref,
324 completed_at_ms: Some(now_ms()),
325 updated_at_ms: now_ms(),
326 },
327 Some(Message::internal_system(format!(
328 "background task {} persisted terminal status {}",
329 task_id,
330 status.as_str()
331 ))),
332 )
333 .await
334 }
335
336 pub async fn load_output_text(
337 &self,
338 task: &TaskState,
339 ) -> Result<Option<String>, TaskStoreError> {
340 let Some(result_ref) = task.result_ref.as_ref() else {
341 if task.task_type == "agent_run" {
342 if let Some(thread_id) = task.metadata.get("thread_id").and_then(Value::as_str) {
343 return self.load_thread_message_text(thread_id, None).await;
344 }
345 }
346 return Ok(None);
347 };
348 match result_ref {
349 TaskResultRef::ThreadMessage {
350 thread_id,
351 message_id,
352 } => {
353 self.load_thread_message_text(thread_id, message_id.as_deref())
354 .await
355 }
356 TaskResultRef::External { uri } => Ok(Some(uri.clone())),
357 }
358 }
359
360 pub async fn descendant_ids_for_owner(
361 &self,
362 owner_thread_id: &str,
363 root_task_id: &str,
364 ) -> Result<Vec<TaskId>, TaskStoreError> {
365 let tasks = self.list_tasks_for_owner(owner_thread_id).await?;
366 let mut by_parent: std::collections::HashMap<String, Vec<String>> =
367 std::collections::HashMap::new();
368 for task in &tasks {
369 if let Some(parent) = task.parent_task_id.as_ref() {
370 by_parent
371 .entry(parent.clone())
372 .or_default()
373 .push(task.id.clone());
374 }
375 }
376 let mut out = Vec::new();
377 let mut stack = vec![root_task_id.to_string()];
378 while let Some(current) = stack.pop() {
379 if tasks.iter().any(|task| task.id == current) {
380 out.push(current.clone());
381 }
382 if let Some(children) = by_parent.get(¤t) {
383 for child in children {
384 stack.push(child.clone());
385 }
386 }
387 }
388 Ok(out)
389 }
390
391 fn is_task_thread(thread: &Thread) -> bool {
392 thread
393 .metadata
394 .extra
395 .get(TASK_THREAD_KIND_METADATA_KEY)
396 .and_then(Value::as_str)
397 == Some(TASK_THREAD_KIND_METADATA_VALUE)
398 }
399
400 fn task_state_from_thread(thread: &Thread) -> Result<TaskState, TaskStoreError> {
401 let snapshot = thread.rebuild_state()?;
402 let Some(value) = snapshot.get(TaskState::PATH) else {
403 return Err(TaskStoreError::MissingTaskState(thread.id.clone()));
404 };
405 TaskState::from_value(value).map_err(|error| TaskStoreError::InvalidTaskState {
406 thread_id: thread.id.clone(),
407 error,
408 })
409 }
410
411 async fn resolve_agent_output_ref(
412 &self,
413 task: &TaskState,
414 ) -> Result<Option<TaskResultRef>, TaskStoreError> {
415 let Some(thread_id) = task.metadata.get("thread_id").and_then(Value::as_str) else {
416 return Ok(None);
417 };
418 let Some(head) = self.threads.load(thread_id).await? else {
419 return Ok(Some(TaskResultRef::ThreadMessage {
420 thread_id: thread_id.to_string(),
421 message_id: None,
422 }));
423 };
424 let message_id = head
425 .thread
426 .messages
427 .iter()
428 .rev()
429 .find(|m| m.role == Role::Assistant)
430 .and_then(|m| m.id.clone());
431 Ok(Some(TaskResultRef::ThreadMessage {
432 thread_id: thread_id.to_string(),
433 message_id,
434 }))
435 }
436
437 async fn load_thread_message_text(
438 &self,
439 thread_id: &str,
440 message_id: Option<&str>,
441 ) -> Result<Option<String>, TaskStoreError> {
442 let Some(head) = self.threads.load(thread_id).await? else {
443 return Ok(None);
444 };
445 let msg = if let Some(message_id) = message_id {
446 head.thread
447 .messages
448 .iter()
449 .find(|m| m.id.as_deref() == Some(message_id))
450 .map(|m| m.content.clone())
451 } else {
452 head.thread
453 .messages
454 .iter()
455 .rev()
456 .find(|m| m.role == Role::Assistant)
457 .map(|m| m.content.clone())
458 };
459 Ok(msg)
460 }
461
462 async fn append_task_action(
463 &self,
464 thread_id: &str,
465 task_id: &str,
466 action: TaskAction,
467 audit_message: Option<Message>,
468 ) -> Result<(), TaskStoreError> {
469 let head = self
470 .threads
471 .load(thread_id)
472 .await?
473 .ok_or_else(|| TaskStoreError::MissingTaskState(thread_id.to_string()))?;
474 let mut snapshot = head.thread.rebuild_state()?;
475 if snapshot.get(TaskState::PATH).is_none() {
476 let default_task = serde_json::to_value(TaskState::default())
477 .map_err(tirea_state::TireaError::from)?;
478 match snapshot.as_object_mut() {
479 Some(obj) => {
480 obj.insert(TaskState::PATH.to_string(), default_task);
481 }
482 None => {
483 snapshot = json!({ TaskState::PATH: default_task });
484 }
485 }
486 }
487 let state_action = AnyStateAction::new::<TaskState>(action);
488 let serialized = vec![state_action.to_serialized_state_action()];
489 let patches = reduce_state_actions(
490 vec![state_action],
491 &snapshot,
492 "background_task",
493 &ScopeContext::run(),
494 )?;
495 let changeset = ThreadChangeSet::from_parts(
496 task_id.to_string(),
497 None,
498 CheckpointReason::ToolResultsCommitted,
499 audit_message.into_iter().map(std::sync::Arc::new).collect(),
500 patches,
501 serialized,
502 None,
503 );
504 self.threads
505 .append(
506 thread_id,
507 &changeset,
508 VersionPrecondition::Exact(head.version),
509 )
510 .await?;
511 Ok(())
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::storage::{ThreadReader, ThreadStore, ThreadWriter};
    use tirea_store_adapters::MemoryStore;

    /// Fresh in-memory backend plus a `TaskStore` sharing it.
    fn memory_backed_store() -> (Arc<MemoryStore>, TaskStore) {
        let backend = Arc::new(MemoryStore::new());
        let store = TaskStore::new(backend.clone() as Arc<dyn ThreadStore>);
        (backend, store)
    }

    /// Spec with no parent, no resume support and empty metadata. Tests override individual
    /// fields with struct-update syntax.
    fn base_spec(task_id: &str, owner: &str, task_type: &str, description: &str) -> NewTaskSpec {
        NewTaskSpec {
            task_id: task_id.to_string(),
            owner_thread_id: owner.to_string(),
            task_type: task_type.to_string(),
            description: description.to_string(),
            parent_task_id: None,
            supports_resume: false,
            metadata: json!({}),
        }
    }

    #[tokio::test]
    async fn create_task_persists_task_thread_state() {
        let (backend, store) = memory_backed_store();

        let created = store
            .create_task(NewTaskSpec {
                parent_task_id: Some("root".to_string()),
                metadata: json!({"kind": "test"}),
                ..base_spec("task-1", "owner-1", "shell", "echo hi")
            })
            .await
            .expect("task should persist");
        assert_eq!(created.id, "task-1");
        assert_eq!(created.status, TaskStatus::Running);
        assert_eq!(created.parent_task_id.as_deref(), Some("root"));

        let reloaded = store
            .load_task("task-1")
            .await
            .expect("load should succeed")
            .expect("task should exist");
        assert_eq!(reloaded.id, "task-1");
        assert_eq!(reloaded.owner_thread_id, "owner-1");
        assert_eq!(reloaded.metadata["kind"], json!("test"));

        let raw_head = backend
            .load(&task_thread_id("task-1"))
            .await
            .expect("thread load should succeed")
            .expect("task thread should exist");
        assert_eq!(raw_head.thread.parent_thread_id.as_deref(), Some("owner-1"));
    }

    #[tokio::test]
    async fn list_tasks_for_owner_ignores_non_task_children() {
        let (backend, store) = memory_backed_store();

        backend
            .create(&Thread::new("child-thread").with_parent_thread_id("owner-1"))
            .await
            .expect("non-task child thread should persist");

        store
            .create_task(base_spec("task-1", "owner-1", "shell", "owner one"))
            .await
            .expect("owner task should persist");
        store
            .create_task(base_spec("task-2", "owner-2", "shell", "owner two"))
            .await
            .expect("other owner task should persist");

        let listed = store
            .list_tasks_for_owner("owner-1")
            .await
            .expect("list should succeed");

        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, "task-1");
    }

    #[tokio::test]
    async fn mark_cancel_requested_persists_timestamp() {
        let (_, store) = memory_backed_store();

        let created = store
            .create_task(base_spec("task-1", "owner-1", "shell", "cancel me"))
            .await
            .expect("task should persist");

        store
            .mark_cancel_requested("task-1")
            .await
            .expect("cancel request should persist");

        let reloaded = store
            .load_task("task-1")
            .await
            .expect("load should succeed")
            .expect("task should exist");
        assert!(reloaded.cancel_requested_at_ms.is_some());
        assert!(reloaded.updated_at_ms >= created.updated_at_ms);
    }

    #[tokio::test]
    async fn persist_summary_for_agent_run_captures_output_ref() {
        let (backend, store) = memory_backed_store();

        let execution_thread = Thread::new("exec-1")
            .with_message(Message::assistant("draft").with_id("msg-1".to_string()))
            .with_message(Message::assistant("final").with_id("msg-2".to_string()));
        backend
            .create(&execution_thread)
            .await
            .expect("execution thread should persist");

        let created = store
            .create_task(NewTaskSpec {
                supports_resume: true,
                metadata: json!({"thread_id": "exec-1", "agent_id": "writer"}),
                ..base_spec("task-1", "owner-1", "agent_run", "delegate")
            })
            .await
            .expect("task should persist");

        let mut finished = created.summary();
        finished.status = TaskStatus::Completed;
        finished.completed_at_ms = Some(created.created_at_ms + 1);

        store
            .persist_summary(&finished)
            .await
            .expect("summary should persist");

        let reloaded = store
            .load_task("task-1")
            .await
            .expect("load should succeed")
            .expect("task should exist");
        assert_eq!(reloaded.status, TaskStatus::Completed);

        let expected_ref = TaskResultRef::ThreadMessage {
            thread_id: "exec-1".to_string(),
            message_id: Some("msg-2".to_string()),
        };
        assert_eq!(reloaded.result_ref, Some(expected_ref));

        let output = store
            .load_output_text(&reloaded)
            .await
            .expect("output should load");
        assert_eq!(output.as_deref(), Some("final"));
    }

    #[tokio::test]
    async fn descendant_ids_for_owner_returns_only_owner_subtree() {
        let (_, store) = memory_backed_store();

        let layout = [
            ("root", "owner-1", None),
            ("child", "owner-1", Some("root")),
            ("grandchild", "owner-1", Some("child")),
            ("other-root", "owner-2", None),
            ("other-child", "owner-2", Some("root")),
        ];
        for (task_id, owner, parent) in layout {
            store
                .create_task(NewTaskSpec {
                    parent_task_id: parent.map(str::to_string),
                    supports_resume: true,
                    ..base_spec(task_id, owner, "agent_run", task_id)
                })
                .await
                .expect("task should persist");
        }

        let mut subtree = store
            .descendant_ids_for_owner("owner-1", "root")
            .await
            .expect("descendants should load");
        subtree.sort();

        assert_eq!(subtree, vec!["child", "grandchild", "root"]);
    }

    #[tokio::test]
    async fn load_task_reports_invalid_task_state() {
        let (backend, store) = memory_backed_store();
        let broken_thread_id = task_thread_id("broken-task");

        let corrupt_state = json!({
            TaskState::PATH: {
                "id": "broken-task",
                "status": 123
            }
        });
        let mut broken = Thread::with_initial_state(broken_thread_id.clone(), corrupt_state)
            .with_parent_thread_id("owner-1");
        broken.metadata.extra.insert(
            TASK_THREAD_KIND_METADATA_KEY.to_string(),
            json!(TASK_THREAD_KIND_METADATA_VALUE),
        );
        backend
            .create(&broken)
            .await
            .expect("broken task thread should persist");

        let failure = store
            .load_task("broken-task")
            .await
            .expect_err("invalid task state should error");

        if let TaskStoreError::InvalidTaskState { thread_id, .. } = failure {
            assert_eq!(thread_id, broken_thread_id);
        } else {
            panic!("expected InvalidTaskState, got {failure:?}");
        }
    }
}