tirea_state/lattice/
or_set.rs1use std::collections::BTreeMap;
2
3use serde::de::DeserializeOwned;
4use serde::{Deserialize, Serialize};
5
6use super::Lattice;
7
8#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
18pub struct ORSet<T: Ord> {
19 entries: BTreeMap<T, u64>,
20 tombstones: BTreeMap<T, u64>,
21 #[serde(skip)]
22 clock: u64,
23}
24
25impl<'de, T: Ord + DeserializeOwned> Deserialize<'de> for ORSet<T> {
26 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
27 where
28 D: serde::Deserializer<'de>,
29 {
30 #[derive(Deserialize)]
31 struct Raw<T: Ord> {
32 entries: BTreeMap<T, u64>,
33 tombstones: BTreeMap<T, u64>,
34 }
35
36 let raw = Raw::deserialize(deserializer)?;
37 let max_entry = raw.entries.values().copied().max().unwrap_or(0);
38 let max_tomb = raw.tombstones.values().copied().max().unwrap_or(0);
39 let clock = max_entry.max(max_tomb);
40
41 Ok(Self {
42 entries: raw.entries,
43 tombstones: raw.tombstones,
44 clock,
45 })
46 }
47}
48
49impl<T: Ord> ORSet<T> {
50 pub fn new() -> Self {
52 Self {
53 entries: BTreeMap::new(),
54 tombstones: BTreeMap::new(),
55 clock: 0,
56 }
57 }
58
59 pub fn insert(&mut self, value: T) {
61 self.clock += 1;
62 self.entries.insert(value, self.clock);
63 }
64
65 pub fn remove(&mut self, value: &T)
67 where
68 T: Clone,
69 {
70 if self.entries.contains_key(value) {
71 self.clock += 1;
72 self.tombstones.insert(value.clone(), self.clock);
73 }
74 }
75
76 pub fn contains(&self, value: &T) -> bool {
78 match self.entries.get(value) {
79 Some(&entry_ts) => {
80 let tomb_ts = self.tombstones.get(value).copied().unwrap_or(0);
81 entry_ts >= tomb_ts
82 }
83 None => false,
84 }
85 }
86
87 pub fn elements(&self) -> Vec<&T> {
89 self.entries
90 .iter()
91 .filter(|(k, &entry_ts)| {
92 let tomb_ts = self.tombstones.get(*k).copied().unwrap_or(0);
93 entry_ts >= tomb_ts
94 })
95 .map(|(k, _)| k)
96 .collect()
97 }
98
99 pub fn len(&self) -> usize {
101 self.elements().len()
102 }
103
104 pub fn is_empty(&self) -> bool {
106 self.len() == 0
107 }
108
109 fn max_observed_ts(&self) -> u64 {
110 let max_entry = self.entries.values().copied().max().unwrap_or(0);
111 let max_tomb = self.tombstones.values().copied().max().unwrap_or(0);
112 max_entry.max(max_tomb)
113 }
114}
115
116impl<T: Ord> Default for ORSet<T> {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl<T: Ord + Clone + PartialEq> Lattice for ORSet<T> {
123 fn merge(&self, other: &Self) -> Self {
124 let mut entries = BTreeMap::new();
125 let mut tombstones = BTreeMap::new();
126
127 for (k, &ts) in &self.entries {
129 entries.insert(k.clone(), ts);
130 }
131 for (k, &ts) in &other.entries {
132 let entry = entries.entry(k.clone()).or_insert(0);
133 *entry = (*entry).max(ts);
134 }
135
136 for (k, &ts) in &self.tombstones {
138 tombstones.insert(k.clone(), ts);
139 }
140 for (k, &ts) in &other.tombstones {
141 let entry = tombstones.entry(k.clone()).or_insert(0);
142 *entry = (*entry).max(ts);
143 }
144
145 let clock = self.max_observed_ts().max(other.max_observed_ts());
146
147 Self {
148 entries,
149 tombstones,
150 clock,
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lattice::assert_lattice_laws;

    #[test]
    fn new_is_empty() {
        let s: ORSet<i32> = ORSet::new();
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn insert_and_contains() {
        let mut s = ORSet::new();
        s.insert(1);
        s.insert(2);
        assert!(s.contains(&1));
        assert!(s.contains(&2));
        assert!(!s.contains(&3));
        assert_eq!(s.len(), 2);
    }

    #[test]
    fn remove_element() {
        let mut s = ORSet::new();
        s.insert(1);
        s.insert(2);
        s.remove(&1);
        assert!(!s.contains(&1));
        assert!(s.contains(&2));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn remove_nonexistent_is_noop() {
        let mut s = ORSet::new();
        s.insert(1);
        s.remove(&2);
        assert_eq!(s.len(), 1);
        assert!(s.contains(&1));
    }

    #[test]
    fn reinsert_after_remove_is_visible() {
        let mut s = ORSet::new();
        s.insert(1);
        s.remove(&1);
        s.insert(1);
        assert!(s.contains(&1));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn add_wins_after_concurrent_remove() {
        let mut a = ORSet::new();
        a.insert("x".to_string());

        let mut b = a.clone();

        a.remove(&"x".to_string());
        assert!(!a.contains(&"x".to_string()));

        b.insert("x".to_string());
        assert!(b.contains(&"x".to_string()));

        let merged = a.merge(&b);
        assert!(
            merged.contains(&"x".to_string()),
            "add-wins: element should be present after merge"
        );
    }

    #[test]
    fn local_insert_after_merge_beats_merged_tombstone() {
        let mut a = ORSet::new();
        a.insert("x".to_string());
        a.remove(&"x".to_string());

        // `b` never saw "x"; merging pulls in a's tombstone and clock.
        let b: ORSet<String> = ORSet::new();
        let mut merged = b.merge(&a);
        assert!(!merged.contains(&"x".to_string()));

        merged.insert("x".to_string());
        assert!(merged.contains(&"x".to_string()));
    }

    #[test]
    fn elements_sorted() {
        let mut s = ORSet::new();
        s.insert(3);
        s.insert(1);
        s.insert(2);
        assert_eq!(s.elements(), vec![&1, &2, &3]);
    }

    #[test]
    fn lattice_laws() {
        let mut a = ORSet::new();
        a.insert(1);
        a.insert(2);

        let mut b = ORSet::new();
        b.insert(2);
        b.insert(3);

        let mut c = ORSet::new();
        c.insert(1);
        c.insert(3);

        assert_lattice_laws(&a, &b, &c);
    }

    #[test]
    fn lattice_laws_with_removes() {
        let mut a = ORSet::new();
        a.insert(1);
        a.insert(2);
        a.remove(&1);

        let mut b = ORSet::new();
        b.insert(2);
        b.insert(3);
        b.remove(&3);

        let mut c = ORSet::new();
        c.insert(1);
        c.insert(3);

        assert_lattice_laws(&a, &b, &c);
    }

    #[test]
    fn merge_with_self_is_identity() {
        let mut a = ORSet::new();
        a.insert(1);
        a.insert(2);
        a.remove(&1);
        assert_eq!(a.merge(&a), a);
    }

    #[test]
    fn merge_union_of_entries() {
        let mut a = ORSet::new();
        a.insert(1);
        a.insert(2);

        let mut b = ORSet::new();
        b.insert(3);
        b.insert(4);

        let merged = a.merge(&b);
        assert_eq!(merged.len(), 4);
    }

    #[test]
    fn merge_empty_sets() {
        let a: ORSet<i32> = ORSet::new();
        let b: ORSet<i32> = ORSet::new();
        let merged = a.merge(&b);
        assert!(merged.is_empty());
    }

    #[test]
    fn merge_one_empty() {
        let a: ORSet<i32> = ORSet::new();
        let mut b = ORSet::new();
        b.insert(1);

        assert_eq!(a.merge(&b).len(), 1);
        assert_eq!(b.merge(&a).len(), 1);
    }

    #[test]
    fn serde_roundtrip() {
        let mut s = ORSet::new();
        s.insert(1);
        s.insert(2);
        s.remove(&1);

        let json = serde_json::to_string(&s).unwrap();
        let back: ORSet<i32> = serde_json::from_str(&json).unwrap();

        assert_eq!(s.elements(), back.elements());
        assert_eq!(s.entries, back.entries);
        assert_eq!(s.tombstones, back.tombstones);
        // The skipped clock is rebuilt, so the whole value round-trips.
        assert_eq!(s, back);
    }

    #[test]
    fn serde_preserves_tombstones() {
        let mut s = ORSet::new();
        s.insert("a".to_string());
        s.remove(&"a".to_string());

        let json = serde_json::to_value(&s).unwrap();
        assert!(json.get("entries").is_some());
        assert!(json.get("tombstones").is_some());

        let back: ORSet<String> = serde_json::from_value(json).unwrap();
        assert!(!back.contains(&"a".to_string()));
    }

    #[test]
    fn deserialized_set_resumes_clock_after_stored_timestamps() {
        let mut s = ORSet::new();
        s.insert("a".to_string());
        s.remove(&"a".to_string());

        let json = serde_json::to_string(&s).unwrap();
        let mut back: ORSet<String> = serde_json::from_str(&json).unwrap();

        // If the skipped clock were not rebuilt on deserialization, this
        // insert would be stamped below the tombstone and stay hidden.
        back.insert("a".to_string());
        assert!(back.contains(&"a".to_string()));
    }
}