// box_image_pyramid/pyramid.rs

//! Image pyramid construction using fixed 2x box-filter downsampling.
//!
//! The API is allocation-friendly: construct a [`PyramidBuffers`] once, then
//! reuse it to build pyramids for successive frames without re-allocating
//! intermediate levels (as long as the frame size does not change). When both
//! the `par_pyramid` and `simd` features are enabled, the 2x box downsample
//! uses portable SIMD for higher throughput; enabling `par_pyramid` together
//! with `rayon` additionally processes output rows in parallel.

8use crate::imageview::{ImageBuffer, ImageView};
9#[cfg(feature = "tracing")]
10use tracing::instrument;
11
12/// Reusable backing storage for pyramid construction.
13///
14/// Typically you construct a [`PyramidBuffers`] once (for example with
15/// [`PyramidBuffers::with_capacity`]) and reuse it across frames by
16/// passing a mutable reference into [`build_pyramid`]. The internal level
17/// buffers are resized on demand to match the requested pyramid shape.
18pub struct PyramidBuffers {
19    levels: Vec<ImageBuffer>,
20}
21
22impl Default for PyramidBuffers {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl PyramidBuffers {
29    /// Create an empty buffer set.
30    pub fn new() -> Self {
31        Self { levels: Vec::new() }
32    }
33
34    /// Create a buffer set with capacity reserved for `num_levels`.
35    pub fn with_capacity(num_levels: u8) -> Self {
36        Self {
37            levels: Vec::with_capacity(num_levels.saturating_sub(1) as usize),
38        }
39    }
40
41    fn ensure_level_shape(&mut self, idx: usize, w: usize, h: usize) {
42        if idx >= self.levels.len() {
43            self.levels.resize_with(idx + 1, || ImageBuffer::new(w, h));
44        }
45
46        let level = &mut self.levels[idx];
47        if level.width != w || level.height != h {
48            *level = ImageBuffer::new(w, h);
49        }
50    }
51}
52
53/// A single pyramid level. The `scale` is relative to the base image.
54#[non_exhaustive]
55pub struct PyramidLevel<'a> {
56    pub img: ImageView<'a>,
57    pub scale: f32, // relative to base (e.g. 1.0, 0.5, 0.25, ...)
58}
59
60/// A top-down pyramid where `levels[0]` is the base (full resolution).
61#[non_exhaustive]
62pub struct Pyramid<'a> {
63    pub levels: Vec<PyramidLevel<'a>>, // levels[0] is base
64}
65
/// Settings that control how [`build_pyramid`] builds a pyramid.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct PyramidParams {
    /// Upper bound on the number of levels produced, counting the base.
    /// A value of `0` yields an empty pyramid.
    pub num_levels: u8,
    /// Smallest allowed width/height of any produced level, including the
    /// base: a level is emitted only if both of its dimensions reach this value.
    pub min_size: usize,
}

impl Default for PyramidParams {
    /// Base image only (`num_levels = 1`) with `min_size = 128`.
    fn default() -> Self {
        PyramidParams {
            min_size: 128,
            num_levels: 1,
        }
    }
}

85/// Build a top-down image pyramid using fixed 2x downsampling.
86///
87/// The base image is always included as level 0. Each subsequent level is a
88/// 2x downsampled copy (box filter) written into `buffers`. Construction stops
89/// when:
90/// - either dimension would fall below `min_size`, or
91/// - `num_levels` is reached.
92#[cfg_attr(
93    feature = "tracing",
94    instrument(
95        level = "info",
96        skip(base, params, buffers),
97        fields(levels = params.num_levels, min_size = params.min_size)
98    )
99)]
100pub fn build_pyramid<'a>(
101    base: ImageView<'a>,
102    params: &PyramidParams,
103    buffers: &'a mut PyramidBuffers,
104) -> Pyramid<'a> {
105    if params.num_levels == 0 || base.width < params.min_size || base.height < params.min_size {
106        return Pyramid { levels: Vec::new() };
107    }
108
109    #[derive(Clone, Copy)]
110    enum LevelSource {
111        Base,
112        Buffer(usize),
113    }
114
115    let mut sources: Vec<(LevelSource, f32)> = Vec::with_capacity(params.num_levels as usize);
116    sources.push((LevelSource::Base, 1.0));
117
118    let mut current_src = LevelSource::Base;
119    let mut current_w = base.width;
120    let mut current_h = base.height;
121    let mut scale = 1.0f32;
122
123    for level_idx in 1..params.num_levels {
124        let w2 = current_w / 2;
125        let h2 = current_h / 2;
126
127        if w2 == 0 || h2 == 0 || w2 < params.min_size || h2 < params.min_size {
128            break;
129        }
130
131        let buf_idx = (level_idx - 1) as usize;
132        buffers.ensure_level_shape(buf_idx, w2, h2);
133
134        let (src_img, dst): (ImageView<'_>, &mut ImageBuffer) = match current_src {
135            LevelSource::Base => (base, &mut buffers.levels[buf_idx]),
136            LevelSource::Buffer(src_idx) => {
137                debug_assert!(src_idx < buf_idx);
138                let (head, tail) = buffers.levels.split_at_mut(buf_idx);
139                (head[src_idx].as_view(), &mut tail[0])
140            }
141        };
142
143        downsample_2x_box(src_img, dst);
144
145        scale *= 0.5;
146        current_src = LevelSource::Buffer(buf_idx);
147        current_w = w2;
148        current_h = h2;
149        sources.push((current_src, scale));
150    }
151
152    let mut levels = Vec::with_capacity(sources.len());
153    for (source, lvl_scale) in sources {
154        let img = match source {
155            LevelSource::Base => base,
156            LevelSource::Buffer(idx) => buffers.levels[idx].as_view(),
157        };
158        levels.push(PyramidLevel {
159            img,
160            scale: lvl_scale,
161        });
162    }
163
164    Pyramid { levels }
165}
166
167/// Fast 2x downsample with a 2x2 box filter into a pre-allocated destination.
168///
169/// Uses SIMD and/or `rayon` specializations when the `par_pyramid`
170/// feature is enabled alongside the relevant flags.
171#[inline]
172fn downsample_2x_box(src: ImageView<'_>, dst: &mut ImageBuffer) {
173    #[cfg(all(feature = "par_pyramid", feature = "rayon", feature = "simd"))]
174    return downsample_2x_box_parallel_simd(src, dst);
175
176    #[cfg(all(feature = "par_pyramid", feature = "rayon", not(feature = "simd")))]
177    return downsample_2x_box_parallel_scalar(src, dst);
178
179    #[cfg(all(feature = "par_pyramid", not(feature = "rayon"), feature = "simd"))]
180    return downsample_2x_box_simd(src, dst);
181
182    #[cfg(all(feature = "par_pyramid", not(feature = "rayon"), not(feature = "simd")))]
183    return downsample_2x_box_scalar(src, dst);
184
185    #[cfg(not(feature = "par_pyramid"))]
186    return downsample_2x_box_scalar(src, dst);
187}
188
189#[inline]
190#[cfg_attr(
191    all(feature = "par_pyramid", any(feature = "rayon", feature = "simd")),
192    allow(dead_code)
193)]
194fn downsample_2x_box_scalar(src: ImageView<'_>, dst: &mut ImageBuffer) {
195    debug_assert_eq!(src.width / 2, dst.width);
196    debug_assert_eq!(src.height / 2, dst.height);
197
198    let src_w = src.width;
199    let dst_w = dst.width;
200    let dst_h = dst.height;
201
202    for y in 0..dst_h {
203        let row0 = (y * 2) * src_w;
204        let row1 = row0 + src_w;
205
206        downsample_row_scalar(
207            &src.data[row0..row0 + src_w],
208            &src.data[row1..row1 + src_w],
209            &mut dst.data[y * dst_w..(y + 1) * dst_w],
210        );
211    }
212}
213
/// Single-threaded 2x box downsample that uses the portable-SIMD row kernel.
///
/// `dst` must be `src.width / 2` x `src.height / 2`; rows are tightly packed.
#[cfg(all(feature = "par_pyramid", not(feature = "rayon"), feature = "simd"))]
fn downsample_2x_box_simd(src: ImageView<'_>, dst: &mut ImageBuffer) {
    debug_assert_eq!(src.width / 2, dst.width);
    debug_assert_eq!(src.height / 2, dst.height);

    let stride = src.width;
    let out_w = dst.width;

    for y_out in 0..dst.height {
        // The two source rows feeding this output row, taken as one span.
        let pair = &src.data[2 * y_out * stride..(2 * y_out + 2) * stride];
        let (top, bottom) = pair.split_at(stride);
        downsample_row_simd(top, bottom, &mut dst.data[y_out * out_w..][..out_w]);
    }
}

/// Row-parallel (rayon) scalar 2x box downsample of `src` into `dst`.
///
/// Each output row is processed as an independent task. A zero-width
/// destination is a no-op, matching the sequential kernels.
#[cfg(all(feature = "par_pyramid", feature = "rayon", not(feature = "simd")))]
fn downsample_2x_box_parallel_scalar(src: ImageView<'_>, dst: &mut ImageBuffer) {
    use rayon::prelude::*;

    debug_assert_eq!(src.width / 2, dst.width);
    debug_assert_eq!(src.height / 2, dst.height);

    let src_w = src.width;
    let dst_w = dst.width;
    if dst_w == 0 {
        // Nothing to write; also avoids `par_chunks_mut(0)`, which panics.
        return;
    }

    dst.data
        .par_chunks_mut(dst_w)
        .enumerate()
        .for_each(|(y_out, dst_row)| {
            let row0_start = 2 * y_out * src_w;
            let row0 = &src.data[row0_start..row0_start + src_w];
            let row1 = &src.data[row0_start + src_w..row0_start + 2 * src_w];
            downsample_row_scalar(row0, row1, dst_row);
        });
}

/// Row-parallel (rayon) 2x box downsample using the portable-SIMD row kernel.
///
/// Each output row is processed as an independent task. A zero-width
/// destination is a no-op, matching the sequential kernels.
#[cfg(all(feature = "par_pyramid", feature = "rayon", feature = "simd"))]
fn downsample_2x_box_parallel_simd(src: ImageView<'_>, dst: &mut ImageBuffer) {
    use rayon::prelude::*;

    debug_assert_eq!(src.width / 2, dst.width);
    debug_assert_eq!(src.height / 2, dst.height);

    let src_w = src.width;
    let dst_w = dst.width;
    if dst_w == 0 {
        // Nothing to write; also avoids `par_chunks_mut(0)`, which panics.
        return;
    }

    dst.data
        .par_chunks_mut(dst_w)
        .enumerate()
        .for_each(|(y_out, dst_row)| {
            let row0_start = 2 * y_out * src_w;
            let row0 = &src.data[row0_start..row0_start + src_w];
            let row1 = &src.data[row0_start + src_w..row0_start + 2 * src_w];
            downsample_row_simd(row0, row1, dst_row);
        });
}

/// Box-filter one output row from two adjacent source rows.
///
/// `dst_row[x]` becomes the rounded mean of `row0[2x]`, `row0[2x + 1]`,
/// `row1[2x]` and `row1[2x + 1]`, computed as `(sum + 2) >> 2` (round half up).
/// Source entries past `2 * dst_row.len()`, such as the last column of an
/// odd-width row, are ignored.
///
/// # Panics
///
/// Panics if either source row has fewer than `2 * dst_row.len()` elements.
#[inline]
fn downsample_row_scalar(row0: &[u8], row1: &[u8], dst_row: &mut [u8]) {
    let needed = 2 * dst_row.len();
    // One up-front length check replaces a bounds check per pixel, and it
    // guarantees the zips below cover every destination pixel.
    assert!(
        row0.len() >= needed && row1.len() >= needed,
        "source rows too short for destination row"
    );

    for ((out, top), bottom) in dst_row
        .iter_mut()
        .zip(row0.chunks_exact(2))
        .zip(row1.chunks_exact(2))
    {
        // Max sum is 4 * 255 + 2 = 1022, which fits in u16.
        let sum = u16::from(top[0])
            + u16::from(top[1])
            + u16::from(bottom[0])
            + u16::from(bottom[1]);
        *out = ((sum + 2) >> 2) as u8;
    }
}

/// Portable-SIMD version of `downsample_row_scalar`, producing 16 output
/// pixels per iteration.
///
/// Results are identical to the scalar kernel: the same `(sum + 2) >> 2`
/// rounding is applied in 16-bit lanes (the maximum sum, `4 * 255 + 2`, fits in
/// `u16`). Output pixels left over after the last full 16-lane chunk are
/// handled by the scalar kernel.
#[cfg(all(feature = "par_pyramid", feature = "simd"))]
fn downsample_row_simd(row0: &[u8], row1: &[u8], dst_row: &mut [u8]) {
    use std::ops::Shr;
    use std::simd::num::SimdUint;
    use std::simd::{u16x16, u8x16};

    const LANES: usize = 16;
    let mut x_out = 0usize;

    while x_out + LANES <= dst_row.len() {
        // Gather the four taps of each 2x2 block into lane-aligned arrays.
        let mut p00 = [0u8; LANES];
        let mut p01 = [0u8; LANES];
        let mut p10 = [0u8; LANES];
        let mut p11 = [0u8; LANES];

        for lane in 0..LANES {
            let sx = 2 * (x_out + lane);
            p00[lane] = row0[sx];
            p01[lane] = row0[sx + 1];
            p10[lane] = row1[sx];
            p11[lane] = row1[sx + 1];
        }

        let p00v = u8x16::from_array(p00).cast::<u16>();
        let p01v = u8x16::from_array(p01).cast::<u16>();
        let p10v = u8x16::from_array(p10).cast::<u16>();
        let p11v = u8x16::from_array(p11).cast::<u16>();

        let sum = p00v + p01v + p10v + p11v;
        let avg = (sum + u16x16::splat(2)).shr(2);
        let out = avg.cast::<u8>();

        dst_row[x_out..x_out + LANES].copy_from_slice(out.as_array());
        x_out += LANES;
    }

    // Tail: the scalar kernel indexes its inputs from 0, so the source rows
    // must be offset to the first unprocessed 2x2 block (`2 * x_out`) to stay
    // aligned with `dst_row[x_out..]`.
    if x_out < dst_row.len() {
        let sx = 2 * x_out;
        downsample_row_scalar(&row0[sx..], &row1[sx..], &mut dst_row[x_out..]);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Reference output produced by the scalar kernel.
    fn reference_downsample(src: &ImageBuffer) -> ImageBuffer {
        let mut dst = ImageBuffer::new(src.width / 2, src.height / 2);
        downsample_2x_box_scalar(src.as_view(), &mut dst);
        dst
    }

    /// Per-pixel reference that shares no code with the kernels under test.
    ///
    /// `data` is a tightly packed `w` x `h` image; the result is
    /// `(w / 2) x (h / 2)` using the same round-half-up `(sum + 2) / 4`.
    fn naive_downsample(w: usize, h: usize, data: &[u8]) -> Vec<u8> {
        let px = |x: usize, y: usize| u16::from(data[y * w + x]);
        let mut out = Vec::with_capacity((w / 2) * (h / 2));
        for y in 0..h / 2 {
            for x in 0..w / 2 {
                let (sx, sy) = (2 * x, 2 * y);
                let sum = px(sx, sy) + px(sx + 1, sy) + px(sx, sy + 1) + px(sx + 1, sy + 1);
                out.push(((sum + 2) / 4) as u8);
            }
        }
        out
    }

    fn gray_to_buffer(w: u32, h: u32, data: Vec<u8>) -> ImageBuffer {
        ImageBuffer {
            width: w as usize,
            height: h as usize,
            data,
        }
    }

    fn make_checker(w: u32, h: u32, a: u8, b: u8) -> ImageBuffer {
        let mut data = vec![0u8; (w * h) as usize];
        for y in 0..h {
            for x in 0..w {
                data[(y * w + x) as usize] = if (x + y) % 2 == 0 { a } else { b };
            }
        }
        gray_to_buffer(w, h, data)
    }

    /// Deterministic, non-uniform `w` x `h` test image.
    fn make_pattern(w: usize, h: usize) -> ImageBuffer {
        let mut img = ImageBuffer::new(w, h);
        for (i, p) in img.data.iter_mut().enumerate() {
            *p = ((i * 37 + 11) % 256) as u8;
        }
        img
    }

    #[test]
    fn downsample_matches_reference() {
        let mut src = ImageBuffer::new(8, 8);
        for (i, p) in src.data.iter_mut().enumerate() {
            *p = (i % 251) as u8;
        }
        let mut dst = ImageBuffer::new(4, 4);
        downsample_2x_box(src.as_view(), &mut dst);
        let expected = reference_downsample(&src);
        assert_eq!(dst.data, expected.data);
    }

    #[test]
    fn downsample_matches_reference_on_checker() {
        let src = make_checker(16, 14, 0, 255);
        let mut dst = ImageBuffer::new(src.width / 2, src.height / 2);
        downsample_2x_box(src.as_view(), &mut dst);
        let expected = reference_downsample(&src);
        assert_eq!(dst.data, expected.data);
    }

    #[test]
    fn scalar_reference_matches_naive() {
        let src = make_pattern(10, 6);
        assert_eq!(reference_downsample(&src).data, naive_downsample(10, 6, &src.data));
    }

    #[test]
    fn downsample_handles_width_not_multiple_of_simd_lanes() {
        // 43 source columns -> 21 output columns: one full 16-lane SIMD chunk
        // plus a 5-pixel scalar tail, and an ignored trailing source column.
        let src = make_pattern(43, 7);
        let mut dst = ImageBuffer::new(21, 3);
        downsample_2x_box(src.as_view(), &mut dst);
        assert_eq!(dst.data, naive_downsample(43, 7, &src.data));
    }

    #[test]
    fn build_pyramid_single_level() {
        let img = ImageBuffer::new(64, 64);
        let params = PyramidParams {
            num_levels: 1,
            min_size: 16,
        };
        let mut buffers = PyramidBuffers::new();
        let pyramid = build_pyramid(img.as_view(), &params, &mut buffers);
        assert_eq!(pyramid.levels.len(), 1);
        assert_eq!(pyramid.levels[0].scale, 1.0);
    }

    #[test]
    fn build_pyramid_multiple_levels() {
        let img = ImageBuffer::new(128, 128);
        let params = PyramidParams {
            num_levels: 4,
            min_size: 16,
        };
        let mut buffers = PyramidBuffers::new();
        let pyramid = build_pyramid(img.as_view(), &params, &mut buffers);
        assert_eq!(pyramid.levels.len(), 4);
        assert_eq!(pyramid.levels[0].img.width, 128);
        assert_eq!(pyramid.levels[1].img.width, 64);
        assert_eq!(pyramid.levels[2].img.width, 32);
        assert_eq!(pyramid.levels[3].img.width, 16);
        assert!((pyramid.levels[3].scale - 0.125).abs() < 1e-6);
    }

    #[test]
    fn build_pyramid_stops_at_min_size() {
        let img = ImageBuffer::new(64, 64);
        let params = PyramidParams {
            num_levels: 10,
            min_size: 32,
        };
        let mut buffers = PyramidBuffers::new();
        let pyramid = build_pyramid(img.as_view(), &params, &mut buffers);
        assert_eq!(pyramid.levels.len(), 2); // 64 -> 32, stops before 16
    }

    #[test]
    fn build_pyramid_empty_when_base_below_min_size_or_zero_levels() {
        let img = ImageBuffer::new(32, 64);
        let mut buffers = PyramidBuffers::new();

        let too_small = PyramidParams {
            num_levels: 3,
            min_size: 48,
        };
        assert!(build_pyramid(img.as_view(), &too_small, &mut buffers)
            .levels
            .is_empty());

        let no_levels = PyramidParams {
            num_levels: 0,
            min_size: 1,
        };
        assert!(build_pyramid(img.as_view(), &no_levels, &mut buffers)
            .levels
            .is_empty());
    }

    #[test]
    fn build_pyramid_odd_dimensions_round_down() {
        let img = make_pattern(37, 29);
        let params = PyramidParams {
            num_levels: 3,
            min_size: 4,
        };
        let mut buffers = PyramidBuffers::new();
        let pyramid = build_pyramid(img.as_view(), &params, &mut buffers);

        let level1 = naive_downsample(37, 29, &img.data);
        let level2 = naive_downsample(18, 14, &level1);

        assert_eq!(pyramid.levels.len(), 3);
        assert_eq!((pyramid.levels[1].img.width, pyramid.levels[1].img.height), (18, 14));
        assert_eq!((pyramid.levels[2].img.width, pyramid.levels[2].img.height), (9, 7));
        assert_eq!(pyramid.levels[1].img.data.to_vec(), level1);
        assert_eq!(pyramid.levels[2].img.data.to_vec(), level2);
        assert!((pyramid.levels[2].scale - 0.25).abs() < 1e-6);
    }

    #[test]
    fn buffers_can_be_reused_across_frame_sizes() {
        let params = PyramidParams {
            num_levels: 3,
            min_size: 4,
        };
        let mut buffers = PyramidBuffers::with_capacity(params.num_levels);

        let big = make_pattern(64, 48);
        {
            let pyramid = build_pyramid(big.as_view(), &params, &mut buffers);
            assert_eq!(pyramid.levels.len(), 3);
            assert_eq!((pyramid.levels[2].img.width, pyramid.levels[2].img.height), (16, 12));
        }

        // A smaller frame must reshape the already-populated level buffers.
        let small = make_pattern(20, 18);
        let pyramid = build_pyramid(small.as_view(), &params, &mut buffers);
        let level1 = naive_downsample(20, 18, &small.data);
        let level2 = naive_downsample(10, 9, &level1);

        assert_eq!(pyramid.levels.len(), 3);
        assert_eq!((pyramid.levels[1].img.width, pyramid.levels[1].img.height), (10, 9));
        assert_eq!((pyramid.levels[2].img.width, pyramid.levels[2].img.height), (5, 4));
        assert_eq!(pyramid.levels[1].img.data.to_vec(), level1);
        assert_eq!(pyramid.levels[2].img.data.to_vec(), level2);
    }
}