tirea_store_adapters/
nats_buffered.rs

1//! NATS JetStream-buffered storage decorator.
2//!
3//! Wraps an inner [`ThreadWriter`] (typically PostgreSQL) and routes delta
4//! writes through NATS JetStream instead of hitting the database per-delta.
5//!
6//! # Run-end flush strategy
7//!
8//! During a run the [`AgentOs::run_stream`] checkpoint background task calls
9//! `append()` for each delta.  This storage publishes those deltas to a
10//! JetStream subject `thread.{thread_id}.deltas` so they are durably buffered
11//! in NATS.  No database writes happen during the run.
12//!
13//! When the run emits `CheckpointReason::RunFinished`, `append()` triggers a
14//! flush for that thread: buffered deltas are materialized and persisted to the
15//! inner storage via a single `save()`. The buffered NATS messages are then
16//! acknowledged.
17//!
18//! `save()` remains available for explicit run-end flush when callers already
19//! have a final materialized state.
20//!
21//! # Crash recovery
22//!
23//! On startup, call [`NatsBufferedThreadWriter::recover`] to replay any unacked
24//! deltas left over from interrupted runs.
25
26use async_nats::jetstream;
27use async_trait::async_trait;
28use futures::StreamExt;
29use std::collections::HashMap;
30use std::sync::Arc;
31use tirea_contract::storage::{
32    Committed, ThreadHead, ThreadListPage, ThreadListQuery, ThreadReader, ThreadStore,
33    ThreadStoreError, ThreadWriter, VersionPrecondition,
34};
35use tirea_contract::{CheckpointReason, Thread, ThreadChangeSet};
36
/// Name of the JetStream stream that buffers per-thread deltas.
const STREAM_NAME: &str = "THREAD_DELTAS";

/// Leading token of every delta subject; subjects have the shape
/// `thread.{thread_id}.deltas`.
const SUBJECT_PREFIX: &str = "thread";
/// How long a drain loop waits for the next message before it stops pulling.
const DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);

/// Builds the JetStream subject that carries the deltas of `thread_id`.
///
/// The id is inserted verbatim as the middle subject token.
/// NOTE(review): an id containing `.`, `*`, `>` or whitespace would not yield a
/// valid three-token subject — presumably thread ids are UUID-like; confirm.
fn delta_subject(thread_id: &str) -> String {
    [SUBJECT_PREFIX, thread_id, "deltas"].join(".")
}
47
48/// A [`ThreadWriter`] decorator that buffers deltas in NATS JetStream and
49/// flushes the final thread to the inner storage at run end.
50///
51/// # Query consistency (CQRS)
52///
53/// [`load`](ThreadReader::load) always reads from the inner (durable) storage.
54/// During an active run, queries return the **last-flushed snapshot** — they do
55/// not include deltas that are buffered in NATS but not yet flushed.
56///
57/// Real-time data for in-progress runs is delivered through the SSE/NATS event
58/// stream.  Callers that need up-to-date messages during a run should consume
59/// the event stream rather than polling the query API.
60pub struct NatsBufferedThreadWriter {
61    inner: Arc<dyn ThreadStore>,
62    jetstream: jetstream::Context,
63}
64
65impl NatsBufferedThreadWriter {
66    /// Create a new buffered storage.
67    ///
68    /// `inner` is the durable backend (e.g. PostgreSQL) used for `create`,
69    /// `load`, `delete`, and the final `save` at run end.
70    ///
71    /// `jetstream` is an already-connected JetStream context.
72    pub async fn new(
73        inner: Arc<dyn ThreadStore>,
74        jetstream: jetstream::Context,
75    ) -> Result<Self, async_nats::Error> {
76        // Ensure the stream exists (idempotent).
77        jetstream
78            .get_or_create_stream(jetstream::stream::Config {
79                name: STREAM_NAME.to_string(),
80                subjects: vec![format!("{SUBJECT_PREFIX}.*.deltas")],
81                retention: jetstream::stream::RetentionPolicy::WorkQueue,
82                storage: jetstream::stream::StorageType::File,
83                max_age: std::time::Duration::from_secs(24 * 3600), // 24h TTL
84                ..Default::default()
85            })
86            .await?;
87
88        Ok(Self { inner, jetstream })
89    }
90
91    /// Recover incomplete runs after a crash.
92    ///
93    /// Replays any unacked deltas from the JetStream stream, applies them to
94    /// the corresponding threads loaded from the inner storage, and saves the
95    /// result.  Acked messages are then purged.
96    pub async fn recover(&self) -> Result<usize, NatsBufferedThreadWriterError> {
97        let stream = self.stream().await?;
98        let consumer_name = format!("recovery_{}", uuid::Uuid::now_v7().simple());
99        let consumer = stream
100            .create_consumer(jetstream::consumer::pull::Config {
101                name: Some(consumer_name.clone()),
102                ack_policy: jetstream::consumer::AckPolicy::Explicit,
103                deliver_policy: jetstream::consumer::DeliverPolicy::All,
104                filter_subject: format!("{SUBJECT_PREFIX}.*.deltas"),
105                ..Default::default()
106            })
107            .await
108            .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
109
110        let mut pending: HashMap<String, Vec<(ThreadChangeSet, jetstream::Message)>> =
111            HashMap::new();
112        let mut messages = consumer
113            .messages()
114            .await
115            .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
116
117        while let Ok(Some(Ok(msg))) = tokio::time::timeout(DRAIN_TIMEOUT, messages.next()).await {
118            let subject = msg.subject.to_string();
119            let parts: Vec<&str> = subject.split('.').collect();
120            if parts.len() != 3 {
121                let _ = msg.double_ack().await;
122                continue;
123            }
124            let thread_id = parts[1].to_string();
125            match serde_json::from_slice::<ThreadChangeSet>(&msg.payload) {
126                Ok(delta) => pending.entry(thread_id).or_default().push((delta, msg)),
127                Err(_) => {
128                    let _ = msg.double_ack().await;
129                }
130            }
131        }
132
133        let mut recovered = 0usize;
134        for (thread_id, deltas_with_msgs) in pending {
135            match self
136                .materialize_and_save_thread(&thread_id, deltas_with_msgs)
137                .await
138            {
139                Ok(acked) => recovered += acked,
140                Err(e) => {
141                    tracing::error!(
142                        thread_id = %thread_id,
143                        error = %e,
144                        "recovery: failed to materialize thread"
145                    );
146                }
147            }
148        }
149
150        let _ = stream.delete_consumer(&consumer_name).await;
151        Ok(recovered)
152    }
153
154    async fn stream(&self) -> Result<jetstream::stream::Stream, NatsBufferedThreadWriterError> {
155        self.jetstream
156            .get_stream(STREAM_NAME)
157            .await
158            .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))
159    }
160
161    async fn materialize_and_save_thread(
162        &self,
163        thread_id: &str,
164        deltas_with_msgs: Vec<(ThreadChangeSet, jetstream::Message)>,
165    ) -> Result<usize, NatsBufferedThreadWriterError> {
166        if deltas_with_msgs.is_empty() {
167            return Ok(0);
168        }
169
170        let mut thread = match self.inner.load(thread_id).await? {
171            Some(head) => head.thread,
172            None => Thread::new(thread_id.to_string()),
173        };
174
175        for (delta, _) in &deltas_with_msgs {
176            delta.apply_to(&mut thread);
177        }
178
179        self.inner.save(&thread).await?;
180
181        let mut acked = 0usize;
182        for (_, msg) in deltas_with_msgs {
183            let _ = msg.double_ack().await;
184            acked += 1;
185        }
186        Ok(acked)
187    }
188
189    async fn flush_thread_buffer(
190        &self,
191        thread_id: &str,
192    ) -> Result<usize, NatsBufferedThreadWriterError> {
193        let stream = self.stream().await?;
194        let consumer_name = format!("flush_{}", uuid::Uuid::now_v7().simple());
195        let consumer = stream
196            .create_consumer(jetstream::consumer::pull::Config {
197                name: Some(consumer_name.clone()),
198                ack_policy: jetstream::consumer::AckPolicy::Explicit,
199                deliver_policy: jetstream::consumer::DeliverPolicy::All,
200                filter_subject: delta_subject(thread_id),
201                ..Default::default()
202            })
203            .await
204            .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
205
206        let mut deltas_with_msgs = Vec::new();
207        let mut messages = consumer
208            .messages()
209            .await
210            .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
211
212        while let Ok(Some(Ok(msg))) = tokio::time::timeout(DRAIN_TIMEOUT, messages.next()).await {
213            match serde_json::from_slice::<ThreadChangeSet>(&msg.payload) {
214                Ok(delta) => deltas_with_msgs.push((delta, msg)),
215                Err(_) => {
216                    let _ = msg.double_ack().await;
217                }
218            }
219        }
220
221        let result = self
222            .materialize_and_save_thread(thread_id, deltas_with_msgs)
223            .await;
224        let _ = stream.delete_consumer(&consumer_name).await;
225        result
226    }
227}
228
229#[async_trait]
230impl ThreadWriter for NatsBufferedThreadWriter {
231    async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
232        self.inner.create(thread).await
233    }
234
235    /// Publish delta to NATS JetStream instead of writing to database.
236    ///
237    /// The delta is durably stored in JetStream and will be purged after the
238    /// run-end `save()` succeeds.  If publishing fails the error is mapped to
239    /// [`ThreadStoreError::Io`].
240    async fn append(
241        &self,
242        thread_id: &str,
243        delta: &ThreadChangeSet,
244        precondition: VersionPrecondition,
245    ) -> Result<Committed, ThreadStoreError> {
246        let payload = serde_json::to_vec(delta)
247            .map_err(|e| ThreadStoreError::Serialization(e.to_string()))?;
248
249        self.jetstream
250            .publish(delta_subject(thread_id), payload.into())
251            .await
252            .map_err(|e| ThreadStoreError::Io(std::io::Error::other(e)))?
253            .await
254            .map_err(|e| ThreadStoreError::Io(std::io::Error::other(e)))?;
255
256        if delta.reason == CheckpointReason::RunFinished {
257            self.flush_thread_buffer(thread_id)
258                .await
259                .map_err(|e| match e {
260                    NatsBufferedThreadWriterError::JetStream(msg) => {
261                        ThreadStoreError::Io(std::io::Error::other(msg))
262                    }
263                    NatsBufferedThreadWriterError::Storage(err) => err,
264                })?;
265        }
266
267        let version = match precondition {
268            VersionPrecondition::Any => 0,
269            VersionPrecondition::Exact(v) => v.saturating_add(1),
270        };
271        Ok(Committed { version })
272    }
273
274    async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
275        self.inner.delete(thread_id).await
276    }
277
278    /// Run-end flush: saves the final materialized thread to the inner storage
279    /// and purges the corresponding NATS JetStream messages.
280    async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
281        // Write to durable storage.
282        self.inner.save(thread).await?;
283
284        // Best-effort purge of buffered deltas for this thread.
285        if let Ok(stream) = self.jetstream.get_stream(STREAM_NAME).await {
286            let _ = stream.purge().filter(delta_subject(&thread.id)).await;
287        }
288
289        Ok(())
290    }
291}
292
293#[async_trait]
294impl ThreadReader for NatsBufferedThreadWriter {
295    async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
296        self.inner.load(thread_id).await
297    }
298
299    async fn list_threads(
300        &self,
301        query: &ThreadListQuery,
302    ) -> Result<ThreadListPage, ThreadStoreError> {
303        self.inner.list_threads(query).await
304    }
305}
306
307#[derive(Debug, thiserror::Error)]
308pub enum NatsBufferedThreadWriterError {
309    #[error("jetstream error: {0}")]
310    JetStream(String),
311
312    #[error("storage error: {0}")]
313    Storage(#[from] ThreadStoreError),
314}