tirea_state/lattice/
or_map.rs

1use std::collections::BTreeMap;
2
3use serde::de::DeserializeOwned;
4use serde::{Deserialize, Serialize};
5
6use super::Lattice;
7
8/// Internal entry pairing a value with its insertion timestamp.
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub(crate) struct Entry<V> {
11    value: V,
12    timestamp: u64,
13}
14
15/// An observed-remove map (OR-Map) with put-wins semantics and recursive value merge.
16///
17/// Each key maps to an [`Entry`] containing a value `V: Lattice` and an insertion timestamp.
18/// Removals record a tombstone timestamp. A key is considered present when its entry timestamp
19/// is greater than or equal to its tombstone timestamp (put-wins: concurrent put and
20/// remove resolve in favor of put).
21///
22/// When both replicas have a present entry for the same key, the values are merged using
23/// the `V::merge` lattice operation, providing recursive conflict resolution.
24///
25/// The internal clock advances automatically on mutation and is bumped to
26/// `max(self, other)` on merge.
27#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
28pub struct ORMap<K: Ord, V: Lattice> {
29    entries: BTreeMap<K, Entry<V>>,
30    tombstones: BTreeMap<K, u64>,
31    #[serde(skip)]
32    clock: u64,
33}
34
35impl<'de, K, V> Deserialize<'de> for ORMap<K, V>
36where
37    K: Ord + DeserializeOwned,
38    V: Lattice + DeserializeOwned,
39{
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: serde::Deserializer<'de>,
43    {
44        #[derive(Deserialize)]
45        struct Raw<K: Ord, V> {
46            entries: BTreeMap<K, Entry<V>>,
47            tombstones: BTreeMap<K, u64>,
48        }
49
50        let raw = Raw::<K, V>::deserialize(deserializer)?;
51        let max_entry = raw.entries.values().map(|e| e.timestamp).max().unwrap_or(0);
52        let max_tomb = raw.tombstones.values().copied().max().unwrap_or(0);
53        let clock = max_entry.max(max_tomb);
54
55        Ok(Self {
56            entries: raw.entries,
57            tombstones: raw.tombstones,
58            clock,
59        })
60    }
61}
62
63impl<K: Ord, V: Lattice> ORMap<K, V> {
64    /// Create an empty OR-Map.
65    pub fn new() -> Self {
66        Self {
67            entries: BTreeMap::new(),
68            tombstones: BTreeMap::new(),
69            clock: 0,
70        }
71    }
72
73    /// Put a key-value pair. Overwrites any existing value for the key.
74    pub fn put(&mut self, key: K, value: V) {
75        self.clock += 1;
76        self.entries.insert(
77            key,
78            Entry {
79                value,
80                timestamp: self.clock,
81            },
82        );
83    }
84
85    /// Remove a key by recording a tombstone at the current clock.
86    pub fn remove(&mut self, key: &K)
87    where
88        K: Clone,
89    {
90        if self.entries.contains_key(key) {
91            self.clock += 1;
92            self.tombstones.insert(key.clone(), self.clock);
93        }
94    }
95
96    /// Returns a reference to the value for the key, if present.
97    pub fn get(&self, key: &K) -> Option<&V> {
98        self.entries.get(key).and_then(|entry| {
99            let tomb_ts = self.tombstones.get(key).copied().unwrap_or(0);
100            if entry.timestamp >= tomb_ts {
101                Some(&entry.value)
102            } else {
103                None
104            }
105        })
106    }
107
108    /// Returns `true` if the key is present.
109    pub fn contains_key(&self, key: &K) -> bool {
110        self.get(key).is_some()
111    }
112
113    /// Returns a sorted vector of references to all present keys.
114    pub fn keys(&self) -> Vec<&K> {
115        self.entries
116            .iter()
117            .filter(|(k, entry)| {
118                let tomb_ts = self.tombstones.get(*k).copied().unwrap_or(0);
119                entry.timestamp >= tomb_ts
120            })
121            .map(|(k, _)| k)
122            .collect()
123    }
124
125    /// Returns a sorted vector of `(key, value)` pairs for all present entries.
126    pub fn entries(&self) -> Vec<(&K, &V)> {
127        self.entries
128            .iter()
129            .filter(|(k, entry)| {
130                let tomb_ts = self.tombstones.get(*k).copied().unwrap_or(0);
131                entry.timestamp >= tomb_ts
132            })
133            .map(|(k, entry)| (k, &entry.value))
134            .collect()
135    }
136
137    /// Returns the number of present entries.
138    pub fn len(&self) -> usize {
139        self.keys().len()
140    }
141
142    /// Returns `true` if no entries are present.
143    pub fn is_empty(&self) -> bool {
144        self.len() == 0
145    }
146
147    fn max_observed_ts(&self) -> u64 {
148        let max_entry = self
149            .entries
150            .values()
151            .map(|e| e.timestamp)
152            .max()
153            .unwrap_or(0);
154        let max_tomb = self.tombstones.values().copied().max().unwrap_or(0);
155        max_entry.max(max_tomb)
156    }
157}
158
159impl<K: Ord, V: Lattice> Default for ORMap<K, V> {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165impl<K: Ord + Clone + PartialEq, V: Lattice> Lattice for ORMap<K, V> {
166    fn merge(&self, other: &Self) -> Self {
167        let mut entries: BTreeMap<K, Entry<V>> = BTreeMap::new();
168        let mut tombstones: BTreeMap<K, u64> = BTreeMap::new();
169
170        // Merge tombstones: per-key max
171        for (k, &ts) in &self.tombstones {
172            tombstones.insert(k.clone(), ts);
173        }
174        for (k, &ts) in &other.tombstones {
175            let entry = tombstones.entry(k.clone()).or_insert(0);
176            *entry = (*entry).max(ts);
177        }
178
179        // Merge entries
180        // Collect all keys that have entries in either side
181        let mut all_entry_keys: std::collections::BTreeSet<&K> = std::collections::BTreeSet::new();
182        for k in self.entries.keys() {
183            all_entry_keys.insert(k);
184        }
185        for k in other.entries.keys() {
186            all_entry_keys.insert(k);
187        }
188
189        for k in all_entry_keys {
190            let self_entry = self.entries.get(k);
191            let other_entry = other.entries.get(k);
192
193            let merged_entry = match (self_entry, other_entry) {
194                (Some(a), Some(b)) => {
195                    // Both sides have the entry; merge values and take max timestamp
196                    let max_ts = a.timestamp.max(b.timestamp);
197                    Entry {
198                        value: a.value.merge(&b.value),
199                        timestamp: max_ts,
200                    }
201                }
202                (Some(a), None) => a.clone(),
203                (None, Some(b)) => b.clone(),
204                (None, None) => unreachable!(),
205            };
206
207            // Always keep entries so future merges can compare timestamps
208            entries.insert(k.clone(), merged_entry);
209        }
210
211        let clock = self.max_observed_ts().max(other.max_observed_ts());
212
213        Self {
214            entries,
215            tombstones,
216            clock,
217        }
218    }
219}
220
#[cfg(test)]
mod tests {
    // Unit tests for ORMap: local operations, merge semantics, lattice laws and serde.
    use super::*;
    use crate::lattice::{assert_lattice_laws, GCounter, MaxReg};

    #[test]
    fn new_is_empty() {
        let m: ORMap<String, MaxReg<i64>> = ORMap::new();
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
    }

    #[test]
    fn put_and_get() {
        let mut m = ORMap::new();
        m.put("a".to_string(), MaxReg::new(10i64));
        m.put("b".to_string(), MaxReg::new(20i64));

        assert_eq!(m.get(&"a".to_string()).map(|r| *r.value()), Some(10));
        assert_eq!(m.get(&"b".to_string()).map(|r| *r.value()), Some(20));
        assert!(m.get(&"c".to_string()).is_none());
        assert_eq!(m.len(), 2);
    }

    #[test]
    fn put_overwrites() {
        let mut m = ORMap::new();
        m.put("a".to_string(), MaxReg::new(10i64));
        m.put("a".to_string(), MaxReg::new(20i64));
        assert_eq!(m.get(&"a".to_string()).map(|r| *r.value()), Some(20));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn remove_key() {
        let mut m = ORMap::new();
        m.put("a".to_string(), MaxReg::new(10i64));
        m.put("b".to_string(), MaxReg::new(20i64));
        m.remove(&"a".to_string());

        assert!(!m.contains_key(&"a".to_string()));
        assert!(m.contains_key(&"b".to_string()));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn remove_nonexistent_is_noop() {
        let mut m: ORMap<String, MaxReg<i64>> = ORMap::new();
        m.put("a".to_string(), MaxReg::new(10));
        m.remove(&"b".to_string()); // no-op
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn keys_and_entries_sorted() {
        let mut m = ORMap::new();
        m.put("c".to_string(), MaxReg::new(3i64));
        m.put("a".to_string(), MaxReg::new(1i64));
        m.put("b".to_string(), MaxReg::new(2i64));

        let keys: Vec<_> = m.keys();
        assert_eq!(keys, vec!["a", "b", "c"]);

        let entries: Vec<_> = m.entries();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0].0, "a");
    }

    #[test]
    fn put_wins_after_concurrent_remove() {
        let mut a: ORMap<String, MaxReg<i64>> = ORMap::new();
        a.put("x".to_string(), MaxReg::new(1));

        let mut b = a.clone();

        // A removes
        a.remove(&"x".to_string());
        assert!(!a.contains_key(&"x".to_string()));

        // B re-puts (concurrent with A's remove)
        b.put("x".to_string(), MaxReg::new(2));
        assert!(b.contains_key(&"x".to_string()));

        // Merge: put-wins
        let merged = a.merge(&b);
        assert!(
            merged.contains_key(&"x".to_string()),
            "put-wins: key should be present after merge"
        );
    }

    #[test]
    fn merge_recursive_value_merge() {
        // Both replicas have the same key with different MaxReg values
        let mut a = ORMap::new();
        a.put("k".to_string(), MaxReg::new(10i64));

        let mut b = ORMap::new();
        b.put("k".to_string(), MaxReg::new(20i64));

        let merged = a.merge(&b);
        // MaxReg merge => max(10, 20) = 20
        assert_eq!(merged.get(&"k".to_string()).map(|r| *r.value()), Some(20));
    }

    #[test]
    fn merge_recursive_value_merge_gcounter() {
        let mut c1 = GCounter::new();
        c1.increment("node-a", 5);

        let mut c2 = GCounter::new();
        c2.increment("node-b", 3);

        let mut a = ORMap::new();
        a.put("counter".to_string(), c1);

        let mut b = ORMap::new();
        b.put("counter".to_string(), c2);

        let merged = a.merge(&b);
        let counter = merged.get(&"counter".to_string()).unwrap();
        // GCounter merge: per-node max => {node-a: 5, node-b: 3}, value = 8
        assert_eq!(counter.value(), 8);
    }

    #[test]
    fn merge_takes_max_clock() {
        let mut a: ORMap<String, MaxReg<i64>> = ORMap::new();
        a.put("x".to_string(), MaxReg::new(1));
        a.put("y".to_string(), MaxReg::new(2));
        a.remove(&"x".to_string());

        let mut b: ORMap<String, MaxReg<i64>> = ORMap::new();
        b.put("z".to_string(), MaxReg::new(3));

        // The merged clock must reach the furthest-ahead replica so later writes are newer.
        let mut merged = b.merge(&a);
        assert_eq!(merged.clock, 3);
        merged.put("w".to_string(), MaxReg::new(4));
        assert_eq!(merged.clock, 4);
    }

    #[test]
    fn lattice_laws_maxreg() {
        let mut a: ORMap<String, MaxReg<i64>> = ORMap::new();
        a.put("x".to_string(), MaxReg::new(1));

        let mut b: ORMap<String, MaxReg<i64>> = ORMap::new();
        b.put("y".to_string(), MaxReg::new(2));

        let mut c: ORMap<String, MaxReg<i64>> = ORMap::new();
        c.put("x".to_string(), MaxReg::new(3));
        c.put("z".to_string(), MaxReg::new(4));

        assert_lattice_laws(&a, &b, &c);
    }

    #[test]
    fn lattice_laws_with_removes() {
        let mut a: ORMap<String, MaxReg<i64>> = ORMap::new();
        a.put("x".to_string(), MaxReg::new(1));
        a.put("y".to_string(), MaxReg::new(2));
        a.remove(&"x".to_string());

        let mut b: ORMap<String, MaxReg<i64>> = ORMap::new();
        b.put("y".to_string(), MaxReg::new(5));
        b.put("z".to_string(), MaxReg::new(3));

        let mut c: ORMap<String, MaxReg<i64>> = ORMap::new();
        c.put("z".to_string(), MaxReg::new(10));

        assert_lattice_laws(&a, &b, &c);
    }

    #[test]
    fn merge_empty_maps() {
        let a: ORMap<String, MaxReg<i64>> = ORMap::new();
        let b: ORMap<String, MaxReg<i64>> = ORMap::new();
        let merged = a.merge(&b);
        assert!(merged.is_empty());
    }

    #[test]
    fn merge_one_empty() {
        let a: ORMap<String, MaxReg<i64>> = ORMap::new();
        let mut b: ORMap<String, MaxReg<i64>> = ORMap::new();
        b.put("k".to_string(), MaxReg::new(1));

        assert_eq!(a.merge(&b).len(), 1);
        assert_eq!(b.merge(&a).len(), 1);
    }

    #[test]
    fn serde_roundtrip() {
        let mut m = ORMap::new();
        m.put("a".to_string(), MaxReg::new(10i64));
        m.put("b".to_string(), MaxReg::new(20i64));
        m.remove(&"a".to_string());

        let json = serde_json::to_string(&m).unwrap();
        let back: ORMap<String, MaxReg<i64>> = serde_json::from_str(&json).unwrap();

        // Visible state should match
        assert_eq!(m.entries(), back.entries());
        assert_eq!(m.entries, back.entries);
        assert_eq!(m.tombstones, back.tombstones);
    }

    #[test]
    fn serde_preserves_tombstones() {
        let mut m: ORMap<String, MaxReg<i64>> = ORMap::new();
        m.put("a".to_string(), MaxReg::new(1));
        m.remove(&"a".to_string());

        let json = serde_json::to_value(&m).unwrap();
        assert!(json.get("entries").is_some());
        assert!(json.get("tombstones").is_some());

        let back: ORMap<String, MaxReg<i64>> = serde_json::from_value(json).unwrap();
        assert!(!back.contains_key(&"a".to_string()));
    }

    #[test]
    fn serde_restores_clock() {
        let mut m: ORMap<String, MaxReg<i64>> = ORMap::new();
        m.put("a".to_string(), MaxReg::new(1));
        m.put("b".to_string(), MaxReg::new(2));
        m.remove(&"a".to_string());

        let json = serde_json::to_string(&m).unwrap();
        let mut back: ORMap<String, MaxReg<i64>> = serde_json::from_str(&json).unwrap();

        // The clock is not serialized; it must be rebuilt from the stored timestamps so the
        // deserialized map equals the original, clock included.
        assert_eq!(back, m);

        // A later put must get a timestamp newer than anything already stored.
        back.put("c".to_string(), MaxReg::new(3));
        assert_eq!(back.clock, 4);
        assert!(back.contains_key(&"c".to_string()));
    }
}
425}