tirea_store_adapters/
memory_store.rs

1use async_trait::async_trait;
2use tirea_contract::storage::{
3    has_active_claim_for_mailbox, paginate_mailbox_entries, paginate_runs_in_memory, MailboxEntry,
4    MailboxInterrupt, MailboxPage, MailboxQuery, MailboxReader, MailboxState, MailboxStoreError,
5    MailboxWriter, RunPage, RunQuery, RunReader, RunRecord, RunStoreError, RunWriter, ThreadHead,
6    ThreadListPage, ThreadListQuery, ThreadReader, ThreadStoreError, ThreadSync, ThreadWriter,
7    VersionPrecondition,
8};
9use tirea_contract::{Committed, Thread, ThreadChangeSet, Version};
10
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a moment before the epoch and
/// saturates at `u64::MAX` if the millisecond count cannot fit in a `u64`.
fn now_unix_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
16
17struct MemoryEntry {
18    thread: Thread,
19    version: Version,
20    deltas: Vec<ThreadChangeSet>,
21}
22
23/// In-memory storage for testing and local development.
24#[derive(Default)]
25pub struct MemoryStore {
26    entries: tokio::sync::RwLock<std::collections::HashMap<String, MemoryEntry>>,
27    runs: tokio::sync::RwLock<std::collections::HashMap<String, RunRecord>>,
28    mailbox: tokio::sync::RwLock<std::collections::HashMap<String, MailboxEntry>>,
29    mailbox_states: tokio::sync::RwLock<std::collections::HashMap<String, MailboxState>>,
30}
31
32impl MemoryStore {
33    /// Create a new in-memory storage.
34    pub fn new() -> Self {
35        Self::default()
36    }
37}
38
39#[async_trait]
40impl MailboxReader for MemoryStore {
41    async fn load_mailbox_entry(
42        &self,
43        entry_id: &str,
44    ) -> Result<Option<MailboxEntry>, MailboxStoreError> {
45        Ok(self.mailbox.read().await.get(entry_id).cloned())
46    }
47
48    async fn load_mailbox_state(
49        &self,
50        mailbox_id: &str,
51    ) -> Result<Option<MailboxState>, MailboxStoreError> {
52        Ok(self.mailbox_states.read().await.get(mailbox_id).cloned())
53    }
54
55    async fn list_mailbox_entries(
56        &self,
57        query: &MailboxQuery,
58    ) -> Result<MailboxPage, MailboxStoreError> {
59        let mailbox = self.mailbox.read().await;
60        let entries: Vec<MailboxEntry> = mailbox.values().cloned().collect();
61        Ok(paginate_mailbox_entries(&entries, query))
62    }
63}
64
65#[async_trait]
66impl MailboxWriter for MemoryStore {
67    async fn enqueue_mailbox_entry(&self, entry: &MailboxEntry) -> Result<(), MailboxStoreError> {
68        let mut mailbox_states = self.mailbox_states.write().await;
69        let state = mailbox_states
70            .entry(entry.mailbox_id.clone())
71            .or_insert(MailboxState {
72                mailbox_id: entry.mailbox_id.clone(),
73                current_generation: entry.generation,
74                updated_at: entry.updated_at,
75            });
76        if state.current_generation != entry.generation {
77            return Err(MailboxStoreError::GenerationMismatch {
78                mailbox_id: entry.mailbox_id.clone(),
79                expected: state.current_generation,
80                actual: entry.generation,
81            });
82        }
83
84        let mut mailbox = self.mailbox.write().await;
85        if let Some(dedupe_key) = entry.dedupe_key.as_deref() {
86            if mailbox.values().any(|existing| {
87                existing.mailbox_id == entry.mailbox_id
88                    && existing.dedupe_key.as_deref() == Some(dedupe_key)
89            }) {
90                return Err(MailboxStoreError::AlreadyExists(dedupe_key.to_string()));
91            }
92        }
93        if mailbox.contains_key(&entry.entry_id) {
94            return Err(MailboxStoreError::AlreadyExists(entry.entry_id.clone()));
95        }
96        mailbox.insert(entry.entry_id.clone(), entry.clone());
97        Ok(())
98    }
99
100    async fn ensure_mailbox_state(
101        &self,
102        mailbox_id: &str,
103        now: u64,
104    ) -> Result<MailboxState, MailboxStoreError> {
105        let mut mailbox_states = self.mailbox_states.write().await;
106        let state = mailbox_states
107            .entry(mailbox_id.to_string())
108            .or_insert(MailboxState {
109                mailbox_id: mailbox_id.to_string(),
110                current_generation: 0,
111                updated_at: now,
112            });
113        state.updated_at = now;
114        Ok(state.clone())
115    }
116
117    async fn claim_mailbox_entries(
118        &self,
119        mailbox_id: Option<&str>,
120        limit: usize,
121        consumer_id: &str,
122        now: u64,
123        lease_duration_ms: u64,
124    ) -> Result<Vec<MailboxEntry>, MailboxStoreError> {
125        let mut mailbox = self.mailbox.write().await;
126        let mut claimable_ids: Vec<String> = mailbox
127            .values()
128            .filter(|entry| entry.is_claimable(now))
129            .filter(|entry| match mailbox_id {
130                Some(mid) => entry.mailbox_id == mid,
131                None => true,
132            })
133            .map(|entry| entry.entry_id.clone())
134            .collect();
135        claimable_ids.sort_by(|left, right| {
136            let left_entry = mailbox.get(left).expect("mailbox entry should exist");
137            let right_entry = mailbox.get(right).expect("mailbox entry should exist");
138            right_entry
139                .priority
140                .cmp(&left_entry.priority)
141                .then_with(|| left_entry.available_at.cmp(&right_entry.available_at))
142                .then_with(|| left_entry.created_at.cmp(&right_entry.created_at))
143                .then_with(|| left.cmp(right))
144        });
145
146        // Track which mailbox IDs we've already claimed in this batch so we
147        // don't claim two entries for the same mailbox in one call.
148        let mut claimed_mailbox_ids = std::collections::HashSet::new();
149
150        let mut claimed = Vec::new();
151        for entry_id in claimable_ids.into_iter() {
152            if claimed.len() >= limit {
153                break;
154            }
155            let Some(entry) = mailbox.get(&entry_id) else {
156                continue;
157            };
158            let mid = entry.mailbox_id.clone();
159
160            // Mailbox-level exclusive claim: skip if this mailbox already has
161            // an active (non-expired) claim.
162            if claimed_mailbox_ids.contains(&mid)
163                || has_active_claim_for_mailbox(mailbox.values(), &mid, now, Some(&entry_id))
164            {
165                continue;
166            }
167
168            let entry = mailbox.get_mut(&entry_id).expect("checked above");
169            entry.status = tirea_contract::MailboxEntryStatus::Claimed;
170            entry.claim_token = Some(uuid::Uuid::now_v7().simple().to_string());
171            entry.claimed_by = Some(consumer_id.to_string());
172            entry.lease_until = Some(now.saturating_add(lease_duration_ms));
173            entry.attempt_count = entry.attempt_count.saturating_add(1);
174            entry.updated_at = now;
175            claimed.push(entry.clone());
176            claimed_mailbox_ids.insert(mid);
177        }
178        Ok(claimed)
179    }
180
181    async fn claim_mailbox_entry(
182        &self,
183        entry_id: &str,
184        consumer_id: &str,
185        now: u64,
186        lease_duration_ms: u64,
187    ) -> Result<Option<MailboxEntry>, MailboxStoreError> {
188        let mut mailbox = self.mailbox.write().await;
189        let Some(entry) = mailbox.get(entry_id) else {
190            return Ok(None);
191        };
192        if entry.status.is_terminal() {
193            return Ok(None);
194        }
195        if entry.status == tirea_contract::MailboxEntryStatus::Claimed
196            && entry.lease_until.is_some_and(|lease| lease > now)
197        {
198            return Ok(None);
199        }
200
201        // Mailbox-level exclusive claim: reject if another entry in the same
202        // mailbox already holds an active lease.
203        if has_active_claim_for_mailbox(mailbox.values(), &entry.mailbox_id, now, Some(entry_id)) {
204            return Ok(None);
205        }
206
207        let entry = mailbox.get_mut(entry_id).expect("checked above");
208        entry.status = tirea_contract::MailboxEntryStatus::Claimed;
209        entry.claim_token = Some(uuid::Uuid::now_v7().simple().to_string());
210        entry.claimed_by = Some(consumer_id.to_string());
211        entry.lease_until = Some(now.saturating_add(lease_duration_ms));
212        entry.attempt_count = entry.attempt_count.saturating_add(1);
213        entry.updated_at = now;
214        Ok(Some(entry.clone()))
215    }
216
217    async fn ack_mailbox_entry(
218        &self,
219        entry_id: &str,
220        claim_token: &str,
221        now: u64,
222    ) -> Result<(), MailboxStoreError> {
223        let mut mailbox = self.mailbox.write().await;
224        let entry = mailbox
225            .get_mut(entry_id)
226            .ok_or_else(|| MailboxStoreError::NotFound(entry_id.to_string()))?;
227        if entry.claim_token.as_deref() != Some(claim_token) {
228            return Err(MailboxStoreError::ClaimConflict(entry_id.to_string()));
229        }
230        entry.status = tirea_contract::MailboxEntryStatus::Accepted;
231        entry.claim_token = None;
232        entry.claimed_by = None;
233        entry.lease_until = None;
234        entry.updated_at = now;
235        Ok(())
236    }
237
238    async fn nack_mailbox_entry(
239        &self,
240        entry_id: &str,
241        claim_token: &str,
242        retry_at: u64,
243        error: &str,
244        now: u64,
245    ) -> Result<(), MailboxStoreError> {
246        let mut mailbox = self.mailbox.write().await;
247        let entry = mailbox
248            .get_mut(entry_id)
249            .ok_or_else(|| MailboxStoreError::NotFound(entry_id.to_string()))?;
250        if entry.claim_token.as_deref() != Some(claim_token) {
251            return Err(MailboxStoreError::ClaimConflict(entry_id.to_string()));
252        }
253        entry.status = tirea_contract::MailboxEntryStatus::Queued;
254        entry.available_at = retry_at;
255        entry.last_error = Some(error.to_string());
256        entry.claim_token = None;
257        entry.claimed_by = None;
258        entry.lease_until = None;
259        entry.updated_at = now;
260        Ok(())
261    }
262
263    async fn dead_letter_mailbox_entry(
264        &self,
265        entry_id: &str,
266        claim_token: &str,
267        error: &str,
268        now: u64,
269    ) -> Result<(), MailboxStoreError> {
270        let mut mailbox = self.mailbox.write().await;
271        let entry = mailbox
272            .get_mut(entry_id)
273            .ok_or_else(|| MailboxStoreError::NotFound(entry_id.to_string()))?;
274        if entry.claim_token.as_deref() != Some(claim_token) {
275            return Err(MailboxStoreError::ClaimConflict(entry_id.to_string()));
276        }
277        entry.status = tirea_contract::MailboxEntryStatus::DeadLetter;
278        entry.last_error = Some(error.to_string());
279        entry.claim_token = None;
280        entry.claimed_by = None;
281        entry.lease_until = None;
282        entry.updated_at = now;
283        Ok(())
284    }
285
286    async fn cancel_mailbox_entry(
287        &self,
288        entry_id: &str,
289        now: u64,
290    ) -> Result<Option<MailboxEntry>, MailboxStoreError> {
291        let mut mailbox = self.mailbox.write().await;
292        let Some(entry) = mailbox.get_mut(entry_id) else {
293            return Ok(None);
294        };
295        if entry.status.is_terminal() {
296            return Ok(Some(entry.clone()));
297        }
298        entry.status = tirea_contract::MailboxEntryStatus::Cancelled;
299        entry.last_error = Some("cancelled".to_string());
300        entry.claim_token = None;
301        entry.claimed_by = None;
302        entry.lease_until = None;
303        entry.updated_at = now;
304        Ok(Some(entry.clone()))
305    }
306
307    async fn supersede_mailbox_entry(
308        &self,
309        entry_id: &str,
310        now: u64,
311        reason: &str,
312    ) -> Result<Option<MailboxEntry>, MailboxStoreError> {
313        let mut mailbox = self.mailbox.write().await;
314        let Some(entry) = mailbox.get_mut(entry_id) else {
315            return Ok(None);
316        };
317        if entry.status.is_terminal() {
318            return Ok(Some(entry.clone()));
319        }
320        entry.status = tirea_contract::MailboxEntryStatus::Superseded;
321        entry.last_error = Some(reason.to_string());
322        entry.claim_token = None;
323        entry.claimed_by = None;
324        entry.lease_until = None;
325        entry.updated_at = now;
326        Ok(Some(entry.clone()))
327    }
328
329    async fn cancel_pending_for_mailbox(
330        &self,
331        mailbox_id: &str,
332        now: u64,
333        exclude_entry_id: Option<&str>,
334    ) -> Result<Vec<MailboxEntry>, MailboxStoreError> {
335        let mut mailbox = self.mailbox.write().await;
336        let mut cancelled = Vec::new();
337        for entry in mailbox.values_mut() {
338            if entry.mailbox_id != mailbox_id || entry.status.is_terminal() {
339                continue;
340            }
341            if exclude_entry_id.is_some_and(|eid| entry.entry_id == eid) {
342                continue;
343            }
344            entry.status = tirea_contract::MailboxEntryStatus::Cancelled;
345            entry.last_error = Some("cancelled".to_string());
346            entry.claim_token = None;
347            entry.claimed_by = None;
348            entry.lease_until = None;
349            entry.updated_at = now;
350            cancelled.push(entry.clone());
351        }
352        Ok(cancelled)
353    }
354
355    async fn interrupt_mailbox(
356        &self,
357        mailbox_id: &str,
358        now: u64,
359    ) -> Result<MailboxInterrupt, MailboxStoreError> {
360        let mut mailbox_states = self.mailbox_states.write().await;
361        let mut mailbox = self.mailbox.write().await;
362
363        let state = mailbox_states
364            .entry(mailbox_id.to_string())
365            .or_insert(MailboxState {
366                mailbox_id: mailbox_id.to_string(),
367                current_generation: 0,
368                updated_at: now,
369            });
370        state.current_generation = state.current_generation.saturating_add(1);
371        state.updated_at = now;
372        let next_generation = state.current_generation;
373        let mailbox_state = state.clone();
374
375        let mut superseded = Vec::new();
376        for entry in mailbox.values_mut() {
377            if entry.mailbox_id != mailbox_id || entry.status.is_terminal() {
378                continue;
379            }
380            if entry.generation >= next_generation {
381                continue;
382            }
383            entry.status = tirea_contract::MailboxEntryStatus::Superseded;
384            entry.last_error = Some("superseded by interrupt".to_string());
385            entry.claim_token = None;
386            entry.claimed_by = None;
387            entry.lease_until = None;
388            entry.updated_at = now;
389            superseded.push(entry.clone());
390        }
391
392        Ok(MailboxInterrupt {
393            mailbox_state,
394            superseded_entries: superseded,
395        })
396    }
397
398    async fn extend_lease(
399        &self,
400        entry_id: &str,
401        claim_token: &str,
402        extension_ms: u64,
403        now: u64,
404    ) -> Result<bool, MailboxStoreError> {
405        let mut mailbox = self.mailbox.write().await;
406        let Some(entry) = mailbox.get_mut(entry_id) else {
407            return Ok(false);
408        };
409        if entry.status != tirea_contract::MailboxEntryStatus::Claimed {
410            return Ok(false);
411        }
412        if entry.claim_token.as_deref() != Some(claim_token) {
413            return Ok(false);
414        }
415        entry.lease_until = Some(now.saturating_add(extension_ms));
416        entry.updated_at = now;
417        Ok(true)
418    }
419
420    async fn purge_terminal_mailbox_entries(
421        &self,
422        older_than: u64,
423    ) -> Result<usize, MailboxStoreError> {
424        let mut mailbox = self.mailbox.write().await;
425        let before = mailbox.len();
426        mailbox.retain(|_, entry| !(entry.status.is_terminal() && entry.updated_at < older_than));
427        Ok(before - mailbox.len())
428    }
429}
430
431#[async_trait]
432impl ThreadWriter for MemoryStore {
433    async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
434        let mut entries = self.entries.write().await;
435        if entries.contains_key(&thread.id) {
436            return Err(ThreadStoreError::AlreadyExists);
437        }
438        entries.insert(
439            thread.id.clone(),
440            MemoryEntry {
441                thread: thread.clone(),
442                version: 0,
443                deltas: Vec::new(),
444            },
445        );
446        Ok(Committed { version: 0 })
447    }
448
449    async fn append(
450        &self,
451        thread_id: &str,
452        delta: &ThreadChangeSet,
453        precondition: VersionPrecondition,
454    ) -> Result<Committed, ThreadStoreError> {
455        let mut entries = self.entries.write().await;
456        let entry = entries
457            .get_mut(thread_id)
458            .ok_or_else(|| ThreadStoreError::NotFound(thread_id.to_string()))?;
459
460        if let VersionPrecondition::Exact(expected) = precondition {
461            if entry.version != expected {
462                return Err(ThreadStoreError::VersionConflict {
463                    expected,
464                    actual: entry.version,
465                });
466            }
467        }
468
469        delta.apply_to(&mut entry.thread);
470        entry.version += 1;
471        entry.deltas.push(delta.clone());
472
473        // Maintain run index from changeset metadata.
474        if !delta.run_id.is_empty() {
475            let now = now_unix_millis();
476            let mut runs = self.runs.write().await;
477            if let Some(meta) = &delta.run_meta {
478                let record = runs.entry(delta.run_id.clone()).or_insert_with(|| {
479                    RunRecord::new(
480                        &delta.run_id,
481                        thread_id,
482                        &meta.agent_id,
483                        meta.origin,
484                        meta.status,
485                        now,
486                    )
487                });
488                record.status = meta.status;
489                record.agent_id.clone_from(&meta.agent_id);
490                record.origin = meta.origin;
491                record.thread_id = thread_id.to_string();
492                if record.parent_run_id.is_none() {
493                    record.parent_run_id.clone_from(&delta.parent_run_id);
494                }
495                if record.parent_thread_id.is_none() {
496                    record.parent_thread_id.clone_from(&meta.parent_thread_id);
497                }
498                record.termination_code.clone_from(&meta.termination_code);
499                record
500                    .termination_detail
501                    .clone_from(&meta.termination_detail);
502                if record.source_mailbox_entry_id.is_none() {
503                    record
504                        .source_mailbox_entry_id
505                        .clone_from(&meta.source_mailbox_entry_id);
506                }
507                record.updated_at = now;
508            } else if let Some(record) = runs.get_mut(&delta.run_id) {
509                record.updated_at = now;
510            }
511        }
512
513        Ok(Committed {
514            version: entry.version,
515        })
516    }
517
518    async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
519        let mut entries = self.entries.write().await;
520        entries.remove(thread_id);
521        Ok(())
522    }
523
524    async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
525        let mut entries = self.entries.write().await;
526        let version = entries.get(&thread.id).map_or(0, |e| e.version + 1);
527        entries.insert(
528            thread.id.clone(),
529            MemoryEntry {
530                thread: thread.clone(),
531                version,
532                deltas: Vec::new(),
533            },
534        );
535        Ok(())
536    }
537}
538
539#[async_trait]
540impl RunReader for MemoryStore {
541    async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, RunStoreError> {
542        Ok(self.runs.read().await.get(run_id).cloned())
543    }
544
545    async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, RunStoreError> {
546        let runs = self.runs.read().await;
547        let records: Vec<RunRecord> = runs.values().cloned().collect();
548        Ok(paginate_runs_in_memory(&records, query))
549    }
550
551    async fn load_current_run(&self, thread_id: &str) -> Result<Option<RunRecord>, RunStoreError> {
552        let runs = self.runs.read().await;
553        Ok(runs
554            .values()
555            .filter(|r| r.thread_id == thread_id && !r.status.is_terminal())
556            .max_by(|a, b| {
557                a.created_at
558                    .cmp(&b.created_at)
559                    .then_with(|| a.updated_at.cmp(&b.updated_at))
560                    .then_with(|| a.run_id.cmp(&b.run_id))
561            })
562            .cloned())
563    }
564}
565
566#[async_trait]
567impl RunWriter for MemoryStore {
568    async fn upsert_run(&self, record: &RunRecord) -> Result<(), RunStoreError> {
569        self.runs
570            .write()
571            .await
572            .insert(record.run_id.clone(), record.clone());
573        Ok(())
574    }
575
576    async fn delete_run(&self, run_id: &str) -> Result<(), RunStoreError> {
577        self.runs.write().await.remove(run_id);
578        Ok(())
579    }
580}
581
582#[async_trait]
583impl ThreadReader for MemoryStore {
584    async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
585        let entries = self.entries.read().await;
586        Ok(entries.get(thread_id).map(|e| ThreadHead {
587            thread: e.thread.clone(),
588            version: e.version,
589        }))
590    }
591
592    async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, ThreadStoreError> {
593        Ok(self.runs.read().await.get(run_id).cloned())
594    }
595
596    async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, ThreadStoreError> {
597        let runs = self.runs.read().await;
598        let records: Vec<RunRecord> = runs.values().cloned().collect();
599        Ok(paginate_runs_in_memory(&records, query))
600    }
601
602    async fn active_run_for_thread(
603        &self,
604        thread_id: &str,
605    ) -> Result<Option<RunRecord>, ThreadStoreError> {
606        let runs = self.runs.read().await;
607        Ok(runs
608            .values()
609            .filter(|r| r.thread_id == thread_id && !r.status.is_terminal())
610            .max_by(|a, b| {
611                a.created_at
612                    .cmp(&b.created_at)
613                    .then_with(|| a.updated_at.cmp(&b.updated_at))
614                    .then_with(|| a.run_id.cmp(&b.run_id))
615            })
616            .cloned())
617    }
618
619    async fn list_threads(
620        &self,
621        query: &ThreadListQuery,
622    ) -> Result<ThreadListPage, ThreadStoreError> {
623        let entries = self.entries.read().await;
624        let mut ids: Vec<String> = entries
625            .iter()
626            .filter(|(_, e)| {
627                if let Some(ref rid) = query.resource_id {
628                    e.thread.resource_id.as_deref() == Some(rid.as_str())
629                } else {
630                    true
631                }
632            })
633            .filter(|(_, e)| {
634                if let Some(ref pid) = query.parent_thread_id {
635                    e.thread.parent_thread_id.as_deref() == Some(pid.as_str())
636                } else {
637                    true
638                }
639            })
640            .map(|(id, _)| id.clone())
641            .collect();
642        ids.sort();
643        let total = ids.len();
644        let limit = query.limit.clamp(1, 200);
645        let offset = query.offset.min(total);
646        let end = (offset + limit + 1).min(total);
647        let slice = &ids[offset..end];
648        let has_more = slice.len() > limit;
649        let items: Vec<String> = slice.iter().take(limit).cloned().collect();
650        Ok(ThreadListPage {
651            items,
652            total,
653            has_more,
654        })
655    }
656}
657
658#[async_trait]
659impl ThreadSync for MemoryStore {
660    async fn load_deltas(
661        &self,
662        thread_id: &str,
663        after_version: Version,
664    ) -> Result<Vec<ThreadChangeSet>, ThreadStoreError> {
665        let entries = self.entries.read().await;
666        let entry = entries
667            .get(thread_id)
668            .ok_or_else(|| ThreadStoreError::NotFound(thread_id.to_string()))?;
669        // Deltas are 1-indexed: delta[0] produced version 1, delta[1] produced version 2, etc.
670        let skip = after_version as usize;
671        Ok(entry.deltas[skip..].to_vec())
672    }
673}