tirea_store_adapters/
nats_buffered.rs1use async_nats::jetstream;
27use async_trait::async_trait;
28use futures::StreamExt;
29use std::collections::HashMap;
30use std::sync::Arc;
31use tirea_contract::storage::{
32 Committed, ThreadHead, ThreadListPage, ThreadListQuery, ThreadReader, ThreadStore,
33 ThreadStoreError, ThreadWriter, VersionPrecondition,
34};
35use tirea_contract::{CheckpointReason, Thread, ThreadChangeSet};
36
/// Name of the JetStream stream that buffers per-thread deltas.
const STREAM_NAME: &str = "THREAD_DELTAS";

/// First token of every delta subject (`thread.<thread_id>.deltas`).
const SUBJECT_PREFIX: &str = "thread";

/// Maximum time a drain loop waits for the next buffered message before it
/// stops pulling.
const DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);

/// Builds the subject on which deltas for `thread_id` are buffered:
/// `thread.<thread_id>.deltas`.
///
/// NOTE(review): `thread_id` is inserted verbatim as a single subject token,
/// so ids containing `.`, `*`, `>` or whitespace would not match the stream's
/// `thread.*.deltas` binding — confirm ids are restricted upstream.
fn delta_subject(thread_id: &str) -> String {
    [SUBJECT_PREFIX, thread_id, "deltas"].join(".")
}
47
48pub struct NatsBufferedThreadWriter {
61 inner: Arc<dyn ThreadStore>,
62 jetstream: jetstream::Context,
63}
64
65impl NatsBufferedThreadWriter {
66 pub async fn new(
73 inner: Arc<dyn ThreadStore>,
74 jetstream: jetstream::Context,
75 ) -> Result<Self, async_nats::Error> {
76 jetstream
78 .get_or_create_stream(jetstream::stream::Config {
79 name: STREAM_NAME.to_string(),
80 subjects: vec![format!("{SUBJECT_PREFIX}.*.deltas")],
81 retention: jetstream::stream::RetentionPolicy::WorkQueue,
82 storage: jetstream::stream::StorageType::File,
83 max_age: std::time::Duration::from_secs(24 * 3600), ..Default::default()
85 })
86 .await?;
87
88 Ok(Self { inner, jetstream })
89 }
90
91 pub async fn recover(&self) -> Result<usize, NatsBufferedThreadWriterError> {
97 let stream = self.stream().await?;
98 let consumer_name = format!("recovery_{}", uuid::Uuid::now_v7().simple());
99 let consumer = stream
100 .create_consumer(jetstream::consumer::pull::Config {
101 name: Some(consumer_name.clone()),
102 ack_policy: jetstream::consumer::AckPolicy::Explicit,
103 deliver_policy: jetstream::consumer::DeliverPolicy::All,
104 filter_subject: format!("{SUBJECT_PREFIX}.*.deltas"),
105 ..Default::default()
106 })
107 .await
108 .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
109
110 let mut pending: HashMap<String, Vec<(ThreadChangeSet, jetstream::Message)>> =
111 HashMap::new();
112 let mut messages = consumer
113 .messages()
114 .await
115 .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
116
117 while let Ok(Some(Ok(msg))) = tokio::time::timeout(DRAIN_TIMEOUT, messages.next()).await {
118 let subject = msg.subject.to_string();
119 let parts: Vec<&str> = subject.split('.').collect();
120 if parts.len() != 3 {
121 let _ = msg.double_ack().await;
122 continue;
123 }
124 let thread_id = parts[1].to_string();
125 match serde_json::from_slice::<ThreadChangeSet>(&msg.payload) {
126 Ok(delta) => pending.entry(thread_id).or_default().push((delta, msg)),
127 Err(_) => {
128 let _ = msg.double_ack().await;
129 }
130 }
131 }
132
133 let mut recovered = 0usize;
134 for (thread_id, deltas_with_msgs) in pending {
135 match self
136 .materialize_and_save_thread(&thread_id, deltas_with_msgs)
137 .await
138 {
139 Ok(acked) => recovered += acked,
140 Err(e) => {
141 tracing::error!(
142 thread_id = %thread_id,
143 error = %e,
144 "recovery: failed to materialize thread"
145 );
146 }
147 }
148 }
149
150 let _ = stream.delete_consumer(&consumer_name).await;
151 Ok(recovered)
152 }
153
154 async fn stream(&self) -> Result<jetstream::stream::Stream, NatsBufferedThreadWriterError> {
155 self.jetstream
156 .get_stream(STREAM_NAME)
157 .await
158 .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))
159 }
160
161 async fn materialize_and_save_thread(
162 &self,
163 thread_id: &str,
164 deltas_with_msgs: Vec<(ThreadChangeSet, jetstream::Message)>,
165 ) -> Result<usize, NatsBufferedThreadWriterError> {
166 if deltas_with_msgs.is_empty() {
167 return Ok(0);
168 }
169
170 let mut thread = match self.inner.load(thread_id).await? {
171 Some(head) => head.thread,
172 None => Thread::new(thread_id.to_string()),
173 };
174
175 for (delta, _) in &deltas_with_msgs {
176 delta.apply_to(&mut thread);
177 }
178
179 self.inner.save(&thread).await?;
180
181 let mut acked = 0usize;
182 for (_, msg) in deltas_with_msgs {
183 let _ = msg.double_ack().await;
184 acked += 1;
185 }
186 Ok(acked)
187 }
188
189 async fn flush_thread_buffer(
190 &self,
191 thread_id: &str,
192 ) -> Result<usize, NatsBufferedThreadWriterError> {
193 let stream = self.stream().await?;
194 let consumer_name = format!("flush_{}", uuid::Uuid::now_v7().simple());
195 let consumer = stream
196 .create_consumer(jetstream::consumer::pull::Config {
197 name: Some(consumer_name.clone()),
198 ack_policy: jetstream::consumer::AckPolicy::Explicit,
199 deliver_policy: jetstream::consumer::DeliverPolicy::All,
200 filter_subject: delta_subject(thread_id),
201 ..Default::default()
202 })
203 .await
204 .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
205
206 let mut deltas_with_msgs = Vec::new();
207 let mut messages = consumer
208 .messages()
209 .await
210 .map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
211
212 while let Ok(Some(Ok(msg))) = tokio::time::timeout(DRAIN_TIMEOUT, messages.next()).await {
213 match serde_json::from_slice::<ThreadChangeSet>(&msg.payload) {
214 Ok(delta) => deltas_with_msgs.push((delta, msg)),
215 Err(_) => {
216 let _ = msg.double_ack().await;
217 }
218 }
219 }
220
221 let result = self
222 .materialize_and_save_thread(thread_id, deltas_with_msgs)
223 .await;
224 let _ = stream.delete_consumer(&consumer_name).await;
225 result
226 }
227}
228
229#[async_trait]
230impl ThreadWriter for NatsBufferedThreadWriter {
231 async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
232 self.inner.create(thread).await
233 }
234
235 async fn append(
241 &self,
242 thread_id: &str,
243 delta: &ThreadChangeSet,
244 precondition: VersionPrecondition,
245 ) -> Result<Committed, ThreadStoreError> {
246 let payload = serde_json::to_vec(delta)
247 .map_err(|e| ThreadStoreError::Serialization(e.to_string()))?;
248
249 self.jetstream
250 .publish(delta_subject(thread_id), payload.into())
251 .await
252 .map_err(|e| ThreadStoreError::Io(std::io::Error::other(e)))?
253 .await
254 .map_err(|e| ThreadStoreError::Io(std::io::Error::other(e)))?;
255
256 if delta.reason == CheckpointReason::RunFinished {
257 self.flush_thread_buffer(thread_id)
258 .await
259 .map_err(|e| match e {
260 NatsBufferedThreadWriterError::JetStream(msg) => {
261 ThreadStoreError::Io(std::io::Error::other(msg))
262 }
263 NatsBufferedThreadWriterError::Storage(err) => err,
264 })?;
265 }
266
267 let version = match precondition {
268 VersionPrecondition::Any => 0,
269 VersionPrecondition::Exact(v) => v.saturating_add(1),
270 };
271 Ok(Committed { version })
272 }
273
274 async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
275 self.inner.delete(thread_id).await
276 }
277
278 async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
281 self.inner.save(thread).await?;
283
284 if let Ok(stream) = self.jetstream.get_stream(STREAM_NAME).await {
286 let _ = stream.purge().filter(delta_subject(&thread.id)).await;
287 }
288
289 Ok(())
290 }
291}
292
293#[async_trait]
294impl ThreadReader for NatsBufferedThreadWriter {
295 async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
296 self.inner.load(thread_id).await
297 }
298
299 async fn list_threads(
300 &self,
301 query: &ThreadListQuery,
302 ) -> Result<ThreadListPage, ThreadStoreError> {
303 self.inner.list_threads(query).await
304 }
305}
306
307#[derive(Debug, thiserror::Error)]
308pub enum NatsBufferedThreadWriterError {
309 #[error("jetstream error: {0}")]
310 JetStream(String),
311
312 #[error("storage error: {0}")]
313 Storage(#[from] ThreadStoreError),
314}