scuffle_batching/
dataloader.rs

1//! Types related to the dataloader.
2//!
3//! Dataloaders should only be used for fetching data.
4//! If you need to batch writes, use a [`Batcher`](crate::batch::Batcher) instead.
5use std::collections::{HashMap, HashSet};
6use std::future::Future;
7use std::sync::Arc;
8
/// A trait for fetching data in batches
pub trait DataLoaderFetcher {
    /// The incoming key type
    type Key: Clone + Eq + std::hash::Hash + Send + Sync;
    /// The outgoing value type
    type Value: Clone + Send + Sync;

    /// Load a batch of keys
    ///
    /// Returning `None` is treated as a failed load: every caller waiting on
    /// this batch receives `Err(())` from [`DataLoader::load`] /
    /// [`DataLoader::load_many`]. Keys simply missing from the returned map
    /// are treated as not found (no error).
    fn load(&self, keys: HashSet<Self::Key>) -> impl Future<Output = Option<HashMap<Self::Key, Self::Value>>> + Send;
}
19
/// A builder for a [`DataLoader`]
#[derive(Clone, Copy, Debug)]
#[must_use = "builders must be used to create a dataloader"]
pub struct DataLoaderBuilder<E> {
    // Maximum number of keys per batch (default 1000, see `new`).
    batch_size: usize,
    // Maximum number of concurrently executing batches (default 50).
    concurrency: usize,
    // How long a batch may accumulate keys before being flushed (default 5ms).
    delay: std::time::Duration,
    // Ties the builder to the fetcher type without storing one.
    _phantom: std::marker::PhantomData<E>,
}
29
impl<E> Default for DataLoaderBuilder<E> {
    // Same defaults as `DataLoaderBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
35
36impl<E> DataLoaderBuilder<E> {
37    /// Create a new builder
38    pub const fn new() -> Self {
39        Self {
40            batch_size: 1000,
41            concurrency: 50,
42            delay: std::time::Duration::from_millis(5),
43            _phantom: std::marker::PhantomData,
44        }
45    }
46
47    /// Set the batch size
48    #[inline]
49    pub const fn batch_size(mut self, batch_size: usize) -> Self {
50        self.with_batch_size(batch_size);
51        self
52    }
53
54    /// Set the delay
55    #[inline]
56    pub const fn delay(mut self, delay: std::time::Duration) -> Self {
57        self.with_delay(delay);
58        self
59    }
60
61    /// Set the concurrency
62    #[inline]
63    pub const fn concurrency(mut self, concurrency: usize) -> Self {
64        self.with_concurrency(concurrency);
65        self
66    }
67
68    /// Set the batch size
69    #[inline]
70    pub const fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
71        self.batch_size = batch_size;
72        self
73    }
74
75    /// Set the delay
76    #[inline]
77    pub const fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self {
78        self.delay = delay;
79        self
80    }
81
82    /// Set the concurrency
83    #[inline]
84    pub const fn with_concurrency(&mut self, concurrency: usize) -> &mut Self {
85        self.concurrency = concurrency;
86        self
87    }
88
89    /// Build the dataloader
90    #[inline]
91    pub fn build(self, executor: E) -> DataLoader<E>
92    where
93        E: DataLoaderFetcher + Send + Sync + 'static,
94    {
95        DataLoader::new(executor, self.batch_size, self.concurrency, self.delay)
96    }
97}
98
/// A dataloader used to batch requests to a [`DataLoaderFetcher`]
#[must_use = "dataloaders must be used to load data"]
pub struct DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // Handle to the background flush task spawned in `new`; stored so it is
    // tied to the loader's lifetime (note: never aborted in visible code).
    _auto_spawn: tokio::task::JoinHandle<()>,
    // The user-supplied fetcher, shared with spawned batch tasks.
    executor: Arc<E>,
    // Bounds the number of concurrently executing batches.
    semaphore: Arc<tokio::sync::Semaphore>,
    // The batch currently accumulating keys, if any.
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    // Maximum number of keys per batch (clamped to at least 1 in `new`).
    batch_size: usize,
}
111
impl<E> DataLoader<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    /// Create a new dataloader
    ///
    /// `batch_size` and `concurrency` are both clamped to a minimum of 1.
    /// Spawns a background task ([`batch_loop`]) that flushes the pending
    /// batch once it has been waiting for `delay`.
    ///
    /// NOTE(review): the join handle is stored but never aborted here, so the
    /// background loop appears to keep running after the loader is dropped —
    /// confirm this is intended.
    pub fn new(executor: E, batch_size: usize, concurrency: usize, delay: std::time::Duration) -> Self {
        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
        let current_batch = Arc::new(tokio::sync::Mutex::new(None));
        let executor = Arc::new(executor);

        // Periodically flushes `current_batch` once it has aged past `delay`.
        let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), delay));

        Self {
            executor,
            _auto_spawn: join_handle,
            semaphore,
            current_batch,
            batch_size: batch_size.max(1),
        }
    }

    /// Create a builder for a [`DataLoader`]
    #[inline]
    pub const fn builder() -> DataLoaderBuilder<E> {
        DataLoaderBuilder::new()
    }

    /// Load a single key
    /// Can return an error if the underlying [`DataLoaderFetcher`] returns an
    /// error
    ///
    /// Returns `None` if the key is not found
    pub async fn load(&self, items: E::Key) -> Result<Option<E::Value>, ()> {
        // Delegate to `load_many` with a single-key batch and extract the
        // one value, if it was found.
        Ok(self.load_many(std::iter::once(items)).await?.into_values().next())
    }

    /// Load many keys
    /// Can return an error if the underlying [`DataLoaderFetcher`] returns an
    /// error
    ///
    /// Returns a map of keys to values which may be incomplete if any of the
    /// keys were not found
    pub async fn load_many<I>(&self, items: I) -> Result<HashMap<E::Key, E::Value>, ()>
    where
        I: IntoIterator<Item = E::Key> + Send,
    {
        // One entry per batch this call contributed keys to: the keys we
        // inserted, plus the shared result to await.
        struct BatchWaiting<K, V> {
            keys: HashSet<K>,
            result: Arc<BatchResult<K, V>>,
        }

        let mut waiters = Vec::<BatchWaiting<E::Key, E::Value>>::new();

        // Number of keys inserted; used to presize the result map
        // (may overcount duplicates — harmless capacity overestimate).
        let mut count = 0;

        {
            // Starts true so we always register a waiter for the first batch
            // we touch, whether we created it or joined an existing one.
            let mut new_batch = true;
            let mut batch = self.current_batch.lock().await;

            for item in items {
                if batch.is_none() {
                    batch.replace(Batch::new(self.semaphore.clone()));
                    new_batch = true;
                }

                let batch_mut = batch.as_mut().unwrap();
                batch_mut.items.insert(item.clone());

                if new_batch {
                    new_batch = false;
                    waiters.push(BatchWaiting {
                        keys: HashSet::new(),
                        result: batch_mut.result.clone(),
                    });
                }

                let waiting = waiters.last_mut().unwrap();
                waiting.keys.insert(item);

                count += 1;

                // A full batch is dispatched immediately rather than waiting
                // for the background delay loop to flush it.
                if batch_mut.items.len() >= self.batch_size {
                    tokio::spawn(batch.take().unwrap().spawn(self.executor.clone()));
                }
            }
        }

        // Await every batch we joined, keeping only the keys this call asked
        // for (a batch may also contain other callers' keys).
        let mut results = HashMap::with_capacity(count);
        for waiting in waiters {
            let result = waiting.result.wait().await?;
            results.extend(waiting.keys.into_iter().filter_map(|key| {
                let value = result.get(&key)?.clone();
                Some((key, value))
            }));
        }

        Ok(results)
    }
}
211
/// Background task that flushes the pending batch once it has been waiting
/// for at least `delay` since it was created.
async fn batch_loop<E>(
    executor: Arc<E>,
    current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>,
    delay: std::time::Duration,
) where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // How long to sleep before the next check; shrinks when a batch is
    // partway through its delay window so we wake close to its deadline.
    let mut delay_delta = delay;
    loop {
        tokio::time::sleep(delay_delta).await;

        let mut batch = current_batch.lock().await;
        // No pending batch: go back to sleeping for the full delay.
        let Some(created_at) = batch.as_ref().map(|b| b.created_at) else {
            delay_delta = delay;
            continue;
        };

        let remaining = delay.saturating_sub(created_at.elapsed());
        if remaining == std::time::Duration::ZERO {
            // Batch has aged past `delay`: take it and dispatch the fetch.
            tokio::spawn(batch.take().unwrap().spawn(executor.clone()));
            delay_delta = delay;
        } else {
            // Not old enough yet; sleep only for the time left on its window.
            delay_delta = remaining;
        }
    }
}
238
/// Shared completion state for one batch.
struct BatchResult<K, V> {
    // Set exactly once with the fetcher's output (`None` = fetch failed).
    values: tokio::sync::OnceCell<Option<HashMap<K, V>>>,
    // Cancelled when the batch task finishes (or is dropped), waking all
    // waiters blocked in `wait`.
    token: tokio_util::sync::CancellationToken,
}
243
impl<K, V> BatchResult<K, V> {
    /// Create an empty result that has not yet completed.
    fn new() -> Self {
        Self {
            values: tokio::sync::OnceCell::new(),
            token: tokio_util::sync::CancellationToken::new(),
        }
    }

    /// Wait until the batch completes, then borrow the fetched values.
    ///
    /// Completion is signalled by cancelling `token`. Returns `Err(())` if
    /// the values were never set (the batch task died before storing them)
    /// or the fetcher returned `None`.
    async fn wait(&self) -> Result<&HashMap<K, V>, ()> {
        // Fast path: already completed, no need to register for cancellation.
        if !self.token.is_cancelled() {
            self.token.cancelled().await;
        }

        self.values.get().ok_or(())?.as_ref().ok_or(())
    }
}
260
/// A pending set of keys waiting to be fetched together.
struct Batch<E>
where
    E: DataLoaderFetcher + Send + Sync + 'static,
{
    // Deduplicated keys accumulated for this batch.
    items: HashSet<E::Key>,
    // Shared with every caller waiting on this batch.
    result: Arc<BatchResult<E::Key, E::Value>>,
    // Limits how many batches fetch concurrently (shared with the loader).
    semaphore: Arc<tokio::sync::Semaphore>,
    // Used by `batch_loop` to decide when the batch is old enough to flush.
    created_at: std::time::Instant,
}
270
271impl<E> Batch<E>
272where
273    E: DataLoaderFetcher + Send + Sync + 'static,
274{
275    fn new(semaphore: Arc<tokio::sync::Semaphore>) -> Self {
276        Self {
277            items: HashSet::new(),
278            result: Arc::new(BatchResult::new()),
279            semaphore,
280            created_at: std::time::Instant::now(),
281        }
282    }
283
284    async fn spawn(self, executor: Arc<E>) {
285        let _drop_guard = self.result.token.clone().drop_guard();
286        let _ticket = self.semaphore.acquire_owned().await.unwrap();
287        let result = executor.load(self.items).await;
288
289        #[cfg_attr(all(coverage_nightly, test), coverage(off))]
290        fn unknwown_error<E>(_: E) -> ! {
291            unreachable!(
292                "batch result already set, this is a bug please report it https://github.com/scufflecloud/scuffle/issues"
293            )
294        }
295
296        self.result.values.set(result).map_err(unknwown_error).unwrap();
297    }
298}
299
/// TODO: Windows is disabled because I suspect Windows doesn't measure time
/// precisely enough for the time-sensitive tests.
/// We should fix this and re-enable the tests.
/// Similar issue with macOS, but macOS is disabled because it is too slow
/// in CI and the tests fail due to timeouts.
/// CLOUD-74
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(all(test, not(windows), not(target_os = "macos")))]
mod tests {
    use std::sync::atomic::AtomicUsize;

    use super::*;

    // Test fetcher backed by a fixed map. Sleeps `delay` per batch, counts
    // how many batches it served in `requests`, and asserts no batch ever
    // exceeds `capacity` keys.
    struct TestFetcher<K, V> {
        values: HashMap<K, V>,
        delay: std::time::Duration,
        requests: Arc<AtomicUsize>,
        capacity: usize,
    }

    impl<K, V> DataLoaderFetcher for TestFetcher<K, V>
    where
        K: Clone + Eq + std::hash::Hash + Send + Sync,
        V: Clone + Send + Sync,
    {
        type Key = K;
        type Value = V;

        async fn load(&self, keys: HashSet<Self::Key>) -> Option<HashMap<Self::Key, Self::Value>> {
            // The dataloader must never exceed its configured batch size.
            assert!(keys.len() <= self.capacity);
            tokio::time::sleep(self.delay).await;
            self.requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            Some(
                keys.into_iter()
                    .filter_map(|k| {
                        let value = self.values.get(&k)?.clone();
                        Some((k, value))
                    })
                    .collect(),
            )
        }
    }

    // Sequential loads each trigger one request; a missing key yields
    // Ok(None) rather than an error.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn basic() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let a = loader.load("a").await.unwrap();
        assert_eq!(a, Some(1));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);

        let start = std::time::Instant::now();
        let b = loader.load("b").await.unwrap();
        assert_eq!(b, Some(2));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 2);
        let start = std::time::Instant::now();
        let c = loader.load("c").await.unwrap();
        assert_eq!(c, Some(3));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 3);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "b"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 4);

        let start = std::time::Instant::now();
        let unknown = loader.load("unknown").await.unwrap();
        assert_eq!(unknown, None);
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // 10 keys with batch size 2 and concurrency 10: the 5 batches run in
    // parallel, so the whole call fits inside one fetch delay.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn concurrency_high() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(2).concurrency(10).build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // With concurrency 1 the 5 batches are serialized, so total latency is
    // roughly 5 batches x 5ms fetch delay.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delay_low() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() < std::time::Duration::from_millis(35));
        assert!(start.elapsed() >= std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 5);
    }

    // All 10 keys fit in a single batch of 100, so the batch is only flushed
    // by the background loop after the 10ms delay — exactly one request.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn batch_size() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(10));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }

    // 1134 keys with batch size 100 and concurrency 10: 11 full batches plus
    // one partial (flushed by the delay loop) = 12 requests total.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn high_concurrency() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 100,
        };

        let loader = DataLoaderBuilder::default()
            .batch_size(100)
            .concurrency(10)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(0..1134).await.unwrap();
        assert_eq!(ab, HashMap::from_iter((0..1134).map(|i| (i, i * 2 + 5))));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1134 / 100 + 1);
    }

    // The loader sitting idle before first use must not skew the flush
    // timing — the delay window is measured from batch creation.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let start = std::time::Instant::now();
        let ab = loader
            .load_many(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
            .await
            .unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(25));
    }

    // A single under-capacity key must wait out the full batching delay
    // before the background loop flushes it.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn delayed_start_single() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder()
            .batch_size(2)
            .concurrency(100)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1)]));
        assert!(start.elapsed() >= std::time::Duration::from_millis(15));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Duplicate keys collapse into single batch entries (items is a
    // HashSet), so 6 requested keys fit one batch of 4 = one request.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn deduplication() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 4,
        };

        let loader = DataLoader::builder()
            .batch_size(4)
            .concurrency(1)
            .delay(std::time::Duration::from_millis(10))
            .build(fetcher);

        let start = std::time::Instant::now();
        let ab = loader.load_many(vec!["a", "a", "b", "b", "c", "c"]).await.unwrap();
        assert_eq!(ab, HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
        assert!(start.elapsed() >= std::time::Duration::from_millis(5));
        assert!(start.elapsed() < std::time::Duration::from_millis(20));
    }

    // Two concurrent single-key loads join the same in-flight batch and are
    // served by one request.
    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn already_batch() {
        let requests = Arc::new(AtomicUsize::new(0));

        let fetcher = TestFetcher {
            values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]),
            delay: std::time::Duration::from_millis(5),
            requests: requests.clone(),
            capacity: 2,
        };

        let loader = DataLoader::builder().batch_size(10).concurrency(1).build(fetcher);

        let start = std::time::Instant::now();
        let (a, b) = tokio::join!(loader.load("a"), loader.load("b"));
        assert_eq!(a, Ok(Some(1)));
        assert_eq!(b, Ok(Some(2)));
        assert!(start.elapsed() < std::time::Duration::from_millis(15));
        assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1);
    }
}