tirea_state/lattice/
or_set.rs

1use std::collections::BTreeMap;
2
3use serde::de::DeserializeOwned;
4use serde::{Deserialize, Serialize};
5
6use super::Lattice;
7
8/// An observed-remove set (OR-Set) with add-wins semantics.
9///
10/// Each element is tracked with a timestamp. Removals record a tombstone timestamp.
11/// An element is considered present when its entry timestamp is greater than or equal
12/// to its tombstone timestamp (add-wins: concurrent add and remove resolve in favor
13/// of add).
14///
15/// The internal clock advances automatically on mutation and is bumped to
16/// `max(self, other)` on merge.
17#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
18pub struct ORSet<T: Ord> {
19    entries: BTreeMap<T, u64>,
20    tombstones: BTreeMap<T, u64>,
21    #[serde(skip)]
22    clock: u64,
23}
24
25impl<'de, T: Ord + DeserializeOwned> Deserialize<'de> for ORSet<T> {
26    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
27    where
28        D: serde::Deserializer<'de>,
29    {
30        #[derive(Deserialize)]
31        struct Raw<T: Ord> {
32            entries: BTreeMap<T, u64>,
33            tombstones: BTreeMap<T, u64>,
34        }
35
36        let raw = Raw::deserialize(deserializer)?;
37        let max_entry = raw.entries.values().copied().max().unwrap_or(0);
38        let max_tomb = raw.tombstones.values().copied().max().unwrap_or(0);
39        let clock = max_entry.max(max_tomb);
40
41        Ok(Self {
42            entries: raw.entries,
43            tombstones: raw.tombstones,
44            clock,
45        })
46    }
47}
48
49impl<T: Ord> ORSet<T> {
50    /// Create an empty OR-Set.
51    pub fn new() -> Self {
52        Self {
53            entries: BTreeMap::new(),
54            tombstones: BTreeMap::new(),
55            clock: 0,
56        }
57    }
58
59    /// Insert an element. If previously removed, the add-wins semantics re-add it.
60    pub fn insert(&mut self, value: T) {
61        self.clock += 1;
62        self.entries.insert(value, self.clock);
63    }
64
65    /// Remove an element by recording a tombstone at the current clock.
66    pub fn remove(&mut self, value: &T)
67    where
68        T: Clone,
69    {
70        if self.entries.contains_key(value) {
71            self.clock += 1;
72            self.tombstones.insert(value.clone(), self.clock);
73        }
74    }
75
76    /// Returns `true` if the element is present (entry timestamp >= tombstone timestamp).
77    pub fn contains(&self, value: &T) -> bool {
78        match self.entries.get(value) {
79            Some(&entry_ts) => {
80                let tomb_ts = self.tombstones.get(value).copied().unwrap_or(0);
81                entry_ts >= tomb_ts
82            }
83            None => false,
84        }
85    }
86
87    /// Returns a sorted vector of references to all present elements.
88    pub fn elements(&self) -> Vec<&T> {
89        self.entries
90            .iter()
91            .filter(|(k, &entry_ts)| {
92                let tomb_ts = self.tombstones.get(*k).copied().unwrap_or(0);
93                entry_ts >= tomb_ts
94            })
95            .map(|(k, _)| k)
96            .collect()
97    }
98
99    /// Returns the number of present elements.
100    pub fn len(&self) -> usize {
101        self.elements().len()
102    }
103
104    /// Returns `true` if no elements are present.
105    pub fn is_empty(&self) -> bool {
106        self.len() == 0
107    }
108
109    fn max_observed_ts(&self) -> u64 {
110        let max_entry = self.entries.values().copied().max().unwrap_or(0);
111        let max_tomb = self.tombstones.values().copied().max().unwrap_or(0);
112        max_entry.max(max_tomb)
113    }
114}
115
116impl<T: Ord> Default for ORSet<T> {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl<T: Ord + Clone + PartialEq> Lattice for ORSet<T> {
123    fn merge(&self, other: &Self) -> Self {
124        let mut entries = BTreeMap::new();
125        let mut tombstones = BTreeMap::new();
126
127        // Merge entries: per-element max timestamp
128        for (k, &ts) in &self.entries {
129            entries.insert(k.clone(), ts);
130        }
131        for (k, &ts) in &other.entries {
132            let entry = entries.entry(k.clone()).or_insert(0);
133            *entry = (*entry).max(ts);
134        }
135
136        // Merge tombstones: per-element max timestamp
137        for (k, &ts) in &self.tombstones {
138            tombstones.insert(k.clone(), ts);
139        }
140        for (k, &ts) in &other.tombstones {
141            let entry = tombstones.entry(k.clone()).or_insert(0);
142            *entry = (*entry).max(ts);
143        }
144
145        let clock = self.max_observed_ts().max(other.max_observed_ts());
146
147        Self {
148            entries,
149            tombstones,
150            clock,
151        }
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lattice::assert_lattice_laws;

    #[test]
    fn new_is_empty() {
        let s: ORSet<i32> = ORSet::new();
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn insert_and_contains() {
        let mut s = ORSet::new();
        s.insert(1);
        s.insert(2);
        assert!(s.contains(&1));
        assert!(s.contains(&2));
        assert!(!s.contains(&3));
        assert_eq!(s.len(), 2);
    }

    #[test]
    fn remove_element() {
        let mut s = ORSet::new();
        s.insert(1);
        s.insert(2);
        s.remove(&1);
        assert!(!s.contains(&1));
        assert!(s.contains(&2));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn remove_nonexistent_is_noop() {
        let mut s = ORSet::new();
        s.insert(1);
        s.remove(&2); // no-op
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn reinsert_after_remove_is_present() {
        // A local insert that follows a local remove gets a newer timestamp
        // than the tombstone, so the element comes back.
        let mut s = ORSet::new();
        s.insert(1);
        s.remove(&1);
        assert!(!s.contains(&1));

        s.insert(1);
        assert!(s.contains(&1));
        assert_eq!(s.elements(), vec![&1]);
    }

    #[test]
    fn removing_everything_leaves_empty() {
        let mut s = ORSet::new();
        s.insert(1);
        s.insert(2);
        s.remove(&1);
        s.remove(&2);
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
        assert!(s.elements().is_empty());
    }

    #[test]
    fn add_wins_after_concurrent_remove() {
        // Simulate: replica A removes x, replica B adds x concurrently
        let mut a = ORSet::new();
        a.insert("x".to_string());

        let mut b = a.clone();

        // A removes
        a.remove(&"x".to_string());
        assert!(!a.contains(&"x".to_string()));

        // B re-adds (concurrent with A's remove)
        b.insert("x".to_string());
        assert!(b.contains(&"x".to_string()));

        // Merge: add-wins
        let merged = a.merge(&b);
        assert!(
            merged.contains(&"x".to_string()),
            "add-wins: element should be present after merge"
        );
    }

    #[test]
    fn merge_propagates_remove() {
        // A remove that the other replica has not countered must survive the merge.
        let mut a = ORSet::new();
        a.insert(1);

        let mut b = a.clone();
        b.remove(&1);

        assert!(!a.merge(&b).contains(&1));
        assert!(!b.merge(&a).contains(&1));
    }

    #[test]
    fn elements_sorted() {
        let mut s = ORSet::new();
        s.insert(3);
        s.insert(1);
        s.insert(2);
        assert_eq!(s.elements(), vec![&1, &2, &3]);
    }

    #[test]
    fn lattice_laws() {
        let mut a = ORSet::new();
        a.insert(1);
        a.insert(2);

        let mut b = ORSet::new();
        b.insert(2);
        b.insert(3);

        let mut c = ORSet::new();
        c.insert(1);
        c.insert(3);

        assert_lattice_laws(&a, &b, &c);
    }

    #[test]
    fn lattice_laws_with_removes() {
        let mut a = ORSet::new();
        a.insert(1);
        a.insert(2);
        a.remove(&1);

        let mut b = ORSet::new();
        b.insert(2);
        b.insert(3);
        b.remove(&3);

        let mut c = ORSet::new();
        c.insert(1);
        c.insert(3);

        assert_lattice_laws(&a, &b, &c);
    }

    #[test]
    fn merge_union_of_entries() {
        let mut a = ORSet::new();
        a.insert(1);
        a.insert(2);

        let mut b = ORSet::new();
        b.insert(3);
        b.insert(4);

        let merged = a.merge(&b);
        assert_eq!(merged.len(), 4);
    }

    #[test]
    fn merge_empty_sets() {
        let a: ORSet<i32> = ORSet::new();
        let b: ORSet<i32> = ORSet::new();
        let merged = a.merge(&b);
        assert!(merged.is_empty());
    }

    #[test]
    fn merge_one_empty() {
        let a: ORSet<i32> = ORSet::new();
        let mut b = ORSet::new();
        b.insert(1);

        assert_eq!(a.merge(&b).len(), 1);
        assert_eq!(b.merge(&a).len(), 1);
    }

    #[test]
    fn serde_roundtrip() {
        let mut s = ORSet::new();
        s.insert(1);
        s.insert(2);
        s.remove(&1);

        let json = serde_json::to_string(&s).unwrap();
        let back: ORSet<i32> = serde_json::from_str(&json).unwrap();

        // After deserialization, the "visible" state should match
        assert_eq!(s.elements(), back.elements());
        assert_eq!(s.entries, back.entries);
        assert_eq!(s.tombstones, back.tombstones);

        // The clock is skipped on serialization and must be rebuilt from the
        // stored timestamps, so the whole value compares equal again.
        assert_eq!(s.clock, back.clock);
        assert_eq!(s, back);
    }

    #[test]
    fn serde_preserves_tombstones() {
        let mut s = ORSet::new();
        s.insert("a".to_string());
        s.remove(&"a".to_string());

        let json = serde_json::to_value(&s).unwrap();
        // Should have both entries and tombstones in the JSON
        assert!(json.get("entries").is_some());
        assert!(json.get("tombstones").is_some());

        let back: ORSet<String> = serde_json::from_value(json).unwrap();
        assert!(!back.contains(&"a".to_string()));
    }

    #[test]
    fn insert_after_deserialize_beats_old_tombstone() {
        // If the clock restarted at zero after a roundtrip, this insert would be
        // stamped below the existing tombstone and the element would stay hidden.
        let mut s = ORSet::new();
        s.insert(1);
        s.remove(&1);

        let json = serde_json::to_string(&s).unwrap();
        let mut back: ORSet<i32> = serde_json::from_str(&json).unwrap();
        assert!(!back.contains(&1));

        back.insert(1);
        assert!(back.contains(&1));
    }
}