Skip to main content

chess_corners_core/
refine.rs

1//! Pluggable subpixel refinement backends for ChESS corners.
2//!
3//! The default pipeline uses a 5×5 center-of-mass refinement on the response
4//! map (matching the legacy behavior). Alternative refiners operate on the
5//! original image intensity patch and provide more discriminative scores and
6//! acceptance logic.
7use crate::imageview::ImageView;
8use crate::ResponseMap;
9
/// Outcome of a single refinement attempt.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum RefineStatus {
    /// The refined position passed every check of the backend.
    Accepted,
    /// The backend declined the candidate: its required input was missing
    /// from the context, or the estimated offset exceeded the allowed limit.
    Rejected,
    /// The patch needed around the rounded seed does not fit inside the input.
    OutOfBounds,
    /// The local fit was degenerate (too little structure, a singular or
    /// badly conditioned system, or a surface that is not saddle-shaped).
    IllConditioned,
}
19
20/// Result of refining a single corner candidate.
21#[derive(Copy, Clone, Debug)]
22pub struct RefineResult {
23    /// Refined subpixel x coordinate.
24    pub x: f32,
25    /// Refined subpixel y coordinate.
26    pub y: f32,
27    pub score: f32,
28    pub status: RefineStatus,
29}
30
31impl RefineResult {
32    #[inline]
33    pub fn accepted(xy: [f32; 2], score: f32) -> Self {
34        Self {
35            x: xy[0],
36            y: xy[1],
37            score,
38            status: RefineStatus::Accepted,
39        }
40    }
41}
42
43/// Inputs shared by refinement methods.
44#[derive(Copy, Clone, Debug, Default)]
45pub struct RefineContext<'a> {
46    pub image: Option<ImageView<'a>>,
47    pub response: Option<&'a ResponseMap>,
48}
49
50/// Trait implemented by pluggable refinement backends.
51pub trait CornerRefiner {
52    /// Half-width of the patch the refiner needs around the seed.
53    fn radius(&self) -> i32;
54    fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult;
55}
56
57/// User-facing enum selecting a refinement backend.
58#[derive(Clone, Debug)]
59#[non_exhaustive]
60pub enum RefinerKind {
61    CenterOfMass(CenterOfMassConfig),
62    Forstner(ForstnerConfig),
63    SaddlePoint(SaddlePointConfig),
64}
65
66impl Default for RefinerKind {
67    fn default() -> Self {
68        Self::CenterOfMass(CenterOfMassConfig::default())
69    }
70}
71
72/// Runtime refiner with reusable scratch buffers.
73#[derive(Debug)]
74pub enum Refiner {
75    CenterOfMass(CenterOfMassRefiner),
76    Forstner(ForstnerRefiner),
77    SaddlePoint(SaddlePointRefiner),
78}
79
80impl Refiner {
81    pub fn from_kind(kind: RefinerKind) -> Self {
82        match kind {
83            RefinerKind::CenterOfMass(cfg) => Refiner::CenterOfMass(CenterOfMassRefiner::new(cfg)),
84            RefinerKind::Forstner(cfg) => Refiner::Forstner(ForstnerRefiner::new(cfg)),
85            RefinerKind::SaddlePoint(cfg) => Refiner::SaddlePoint(SaddlePointRefiner::new(cfg)),
86        }
87    }
88}
89
90impl CornerRefiner for Refiner {
91    #[inline]
92    fn radius(&self) -> i32 {
93        match self {
94            Refiner::CenterOfMass(r) => r.radius(),
95            Refiner::Forstner(r) => r.radius(),
96            Refiner::SaddlePoint(r) => r.radius(),
97        }
98    }
99
100    #[inline]
101    fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
102        match self {
103            Refiner::CenterOfMass(r) => r.refine(seed_xy, ctx),
104            Refiner::Forstner(r) => r.refine(seed_xy, ctx),
105            Refiner::SaddlePoint(r) => r.refine(seed_xy, ctx),
106        }
107    }
108}
109
/// Configuration of the legacy center-of-mass refinement on the response map.
#[derive(Clone, Copy, Debug)]
pub struct CenterOfMassConfig {
    /// Half-width of the averaging window; the default of 2 gives a 5×5 window.
    pub radius: i32,
}

impl Default for CenterOfMassConfig {
    fn default() -> Self {
        CenterOfMassConfig { radius: 2 }
    }
}
121
122#[derive(Debug)]
123pub struct CenterOfMassRefiner {
124    cfg: CenterOfMassConfig,
125}
126
127impl CenterOfMassRefiner {
128    pub fn new(cfg: CenterOfMassConfig) -> Self {
129        Self { cfg }
130    }
131}
132
133impl CornerRefiner for CenterOfMassRefiner {
134    #[inline]
135    fn radius(&self) -> i32 {
136        self.cfg.radius
137    }
138
139    fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
140        let resp = match ctx.response {
141            Some(r) => r,
142            None => {
143                return RefineResult {
144                    x: seed_xy[0],
145                    y: seed_xy[1],
146                    score: 0.0,
147                    status: RefineStatus::Rejected,
148                }
149            }
150        };
151
152        let x = seed_xy[0].round() as i32;
153        let y = seed_xy[1].round() as i32;
154        let r = self.cfg.radius;
155
156        let mut sx = 0.0;
157        let mut sy = 0.0;
158        let mut sw = 0.0;
159
160        let w = resp.w as i32;
161        let h = resp.h as i32;
162
163        if x < r || y < r || x >= w - r || y >= h - r {
164            return RefineResult {
165                x: seed_xy[0],
166                y: seed_xy[1],
167                score: 0.0,
168                status: RefineStatus::OutOfBounds,
169            };
170        }
171
172        for dy in -r..=r {
173            let yy = (y + dy).clamp(0, h - 1) as usize;
174            for dx in -r..=r {
175                let xx = (x + dx).clamp(0, w - 1) as usize;
176                let w_px = resp.at(xx, yy).max(0.0);
177                sx += (xx as f32) * w_px;
178                sy += (yy as f32) * w_px;
179                sw += w_px;
180            }
181        }
182
183        if sw > 0.0 {
184            RefineResult::accepted([sx / sw, sy / sw], sw)
185        } else {
186            RefineResult {
187                x: seed_xy[0],
188                y: seed_xy[1],
189                score: 0.0,
190                status: RefineStatus::Accepted,
191            }
192        }
193    }
194}
195
/// Configuration of the Förstner-style gradient-based refiner.
///
/// The Förstner operator estimates a subpixel corner location by solving a
/// weighted least-squares system built from the image-gradient structure
/// tensor within a local window. The thresholds below decide when that
/// system is well-conditioned enough to yield a reliable estimate.
///
/// Reference: Förstner, W. & Gülch, E. (1987). "A fast operator for
/// detection and precise location of distinct points, corners and centres
/// of circular features."
#[derive(Clone, Copy, Debug)]
pub struct ForstnerConfig {
    /// Half-size of the local gradient window (full window is `2*radius+1`).
    /// A radius of 2 gives a 5×5 patch, which is large enough to capture the
    /// gradient structure around a corner while staying local. Gradients are
    /// central differences, so one extra pixel of margin is read on each side.
    pub radius: i32,
    /// Minimum trace of the structure tensor (sum of its eigenvalues).
    /// Rejects flat regions where gradient energy is too low.
    ///
    /// The trace is a distance-weighted *sum* of squared gradient magnitudes
    /// over the whole window. The weights are `1 / (1 + d²/2)` and total about
    /// 10 for a 5×5 window around an integer seed. The trace therefore grows
    /// with window size and image contrast (in the units returned by
    /// `ImageView::sample`). The default 25.0 is a low floor: a uniform
    /// gradient of only ~1.6 units per pixel already reaches it.
    pub min_trace: f32,
    /// Minimum determinant of the structure tensor (product of eigenvalues).
    /// Candidates with `det <= min_det` are rejected. This guards against
    /// singular or near-singular systems whose least-squares solution is
    /// numerically unstable. The determinant scales with the fourth power of
    /// the gradient magnitude, so 1e-3 only rejects truly degenerate cases.
    pub min_det: f32,
    /// Maximum ratio of the larger to the smaller eigenvalue. A high
    /// condition number indicates an edge rather than a corner (one dominant
    /// gradient direction). The default 50.0 is permissive; Harris/Förstner
    /// literature typically uses values in the 10–100 range, depending on
    /// noise level and corner sharpness.
    pub max_condition_number: f32,
    /// Maximum displacement (in pixels) between the seed and the refined
    /// location, applied separately to the x and y components. The effective
    /// limit is also capped at `radius + 0.5`, so the estimate stays inside
    /// the window. Larger offsets suggest the seed was mislocated and the fit
    /// is extrapolating rather than interpolating; such results are rejected.
    pub max_offset: f32,
}

impl Default for ForstnerConfig {
    fn default() -> Self {
        Self {
            radius: 2,
            min_trace: 25.0,
            min_det: 1e-3,
            max_condition_number: 50.0,
            max_offset: 1.5,
        }
    }
}
246
247#[derive(Debug)]
248pub struct ForstnerRefiner {
249    cfg: ForstnerConfig,
250}
251
252impl ForstnerRefiner {
253    pub fn new(cfg: ForstnerConfig) -> Self {
254        Self { cfg }
255    }
256}
257
258impl CornerRefiner for ForstnerRefiner {
259    #[inline]
260    fn radius(&self) -> i32 {
261        // Gradients sample one pixel beyond the interior, so reserve an extra pixel.
262        self.cfg.radius + 1
263    }
264
265    fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
266        let img = match ctx.image {
267            Some(view) => view,
268            None => {
269                return RefineResult {
270                    x: seed_xy[0],
271                    y: seed_xy[1],
272                    score: 0.0,
273                    status: RefineStatus::Rejected,
274                }
275            }
276        };
277
278        let cx = seed_xy[0].round() as i32;
279        let cy = seed_xy[1].round() as i32;
280        let patch_r = self.cfg.radius;
281
282        if !img.supports_patch(cx, cy, patch_r + 1) {
283            return RefineResult {
284                x: seed_xy[0],
285                y: seed_xy[1],
286                score: 0.0,
287                status: RefineStatus::OutOfBounds,
288            };
289        }
290
291        let mut a00 = 0.0;
292        let mut a01 = 0.0;
293        let mut a11 = 0.0;
294        let mut bx = 0.0;
295        let mut by = 0.0;
296
297        for dy in -patch_r..=patch_r {
298            let gy = cy + dy;
299            for dx in -patch_r..=patch_r {
300                let gx = cx + dx;
301
302                let ix_plus = img.sample(gx + 1, gy);
303                let ix_minus = img.sample(gx - 1, gy);
304                let iy_plus = img.sample(gx, gy + 1);
305                let iy_minus = img.sample(gx, gy - 1);
306
307                let gx_f = 0.5 * (ix_plus - ix_minus);
308                let gy_f = 0.5 * (iy_plus - iy_minus);
309
310                let px = gx as f32 - seed_xy[0];
311                let py = gy as f32 - seed_xy[1];
312                let gxgx = gx_f * gx_f;
313                let gxgy = gx_f * gy_f;
314                let gygy = gy_f * gy_f;
315                let dist2 = px * px + py * py;
316                let w = 1.0 / (1.0 + 0.5 * dist2);
317
318                a00 += w * gxgx;
319                a01 += w * gxgy;
320                a11 += w * gygy;
321
322                // b = Σ w g gᵀ p  (derivation from minimizing first-moment error)
323                bx += w * (gxgx * px + gxgy * py);
324                by += w * (gxgy * px + gygy * py);
325            }
326        }
327
328        let trace = a00 + a11;
329        let det = a00 * a11 - a01 * a01;
330        if trace < self.cfg.min_trace || det <= self.cfg.min_det {
331            return RefineResult {
332                x: seed_xy[0],
333                y: seed_xy[1],
334                score: det,
335                status: RefineStatus::IllConditioned,
336            };
337        }
338
339        let discr = (trace * trace - 4.0 * det).max(0.0).sqrt();
340        let lambda_min = 0.5 * (trace - discr);
341        let lambda_max = 0.5 * (trace + discr);
342
343        if lambda_min <= 0.0 {
344            return RefineResult {
345                x: seed_xy[0],
346                y: seed_xy[1],
347                score: det,
348                status: RefineStatus::IllConditioned,
349            };
350        }
351
352        let cond = lambda_max / lambda_min;
353        if !cond.is_finite() || cond > self.cfg.max_condition_number {
354            return RefineResult {
355                x: seed_xy[0],
356                y: seed_xy[1],
357                score: det,
358                status: RefineStatus::IllConditioned,
359            };
360        }
361
362        let inv_det = 1.0 / det;
363        let ux = (a11 * bx - a01 * by) * inv_det;
364        let uy = (-a01 * bx + a00 * by) * inv_det;
365
366        let max_off = self.cfg.max_offset.min(self.cfg.radius as f32 + 0.5);
367        if ux.abs() > max_off || uy.abs() > max_off {
368            return RefineResult {
369                x: seed_xy[0],
370                y: seed_xy[1],
371                score: det,
372                status: RefineStatus::Rejected,
373            };
374        }
375
376        let score = det / (trace * trace + 1e-6);
377        RefineResult::accepted([seed_xy[0] + ux, seed_xy[1] + uy], score)
378    }
379}
380
/// Configuration of the quadratic saddle-point surface refiner.
///
/// The refiner least-squares fits `a·x² + b·xy + c·y² + d·x + e·y + f` to the
/// mean-subtracted intensities of a `(2*radius+1)²` patch. It then moves the
/// seed to the surface's stationary point when the fit is saddle-shaped.
#[derive(Clone, Copy, Debug)]
pub struct SaddlePointConfig {
    /// Half-size of the fitted patch; the default of 2 gives a 5×5 patch.
    pub radius: i32,
    /// Required margin below zero for the Hessian determinant `4ac − b²`.
    /// Fits with `4ac − b² > -det_margin` are not treated as saddles.
    pub det_margin: f32,
    /// Maximum per-axis offset (in pixels) from the seed to the stationary
    /// point; additionally capped at `radius + 0.5`.
    pub max_offset: f32,
    /// Minimum magnitude of the Hessian determinant; smaller fits are treated
    /// as degenerate. With the defaults this is already implied by
    /// `det_margin`, because `det_margin > min_abs_det`.
    pub min_abs_det: f32,
}

impl Default for SaddlePointConfig {
    fn default() -> Self {
        SaddlePointConfig {
            radius: 2,
            max_offset: 1.5,
            det_margin: 1e-3,
            min_abs_det: 1e-4,
        }
    }
}
400
401#[derive(Debug)]
402pub struct SaddlePointRefiner {
403    cfg: SaddlePointConfig,
404    m: [f32; 36],
405    rhs: [f32; 6],
406}
407
408impl SaddlePointRefiner {
409    pub fn new(cfg: SaddlePointConfig) -> Self {
410        Self {
411            cfg,
412            m: [0.0; 36],
413            rhs: [0.0; 6],
414        }
415    }
416
417    fn solve_6x6(&mut self) -> Option<[f32; 6]> {
418        // Simple Gauss-Jordan elimination with partial pivoting on the stack.
419        for i in 0..6 {
420            let mut pivot = i;
421            let mut pivot_val = self.m[i * 6 + i].abs();
422            for r in (i + 1)..6 {
423                let v = self.m[r * 6 + i].abs();
424                if v > pivot_val {
425                    pivot = r;
426                    pivot_val = v;
427                }
428            }
429
430            if pivot_val < 1e-9 {
431                return None;
432            }
433
434            if pivot != i {
435                for c in i..6 {
436                    self.m.swap(i * 6 + c, pivot * 6 + c);
437                }
438                self.rhs.swap(i, pivot);
439            }
440
441            let diag = self.m[i * 6 + i];
442            let inv_diag = 1.0 / diag;
443
444            for c in i..6 {
445                self.m[i * 6 + c] *= inv_diag;
446            }
447            self.rhs[i] *= inv_diag;
448
449            for r in 0..6 {
450                if r == i {
451                    continue;
452                }
453                let factor = self.m[r * 6 + i];
454                if factor == 0.0 {
455                    continue;
456                }
457                for c in i..6 {
458                    self.m[r * 6 + c] -= factor * self.m[i * 6 + c];
459                }
460                self.rhs[r] -= factor * self.rhs[i];
461            }
462        }
463
464        let mut out = [0.0f32; 6];
465        out.copy_from_slice(&self.rhs);
466        Some(out)
467    }
468}
469
470impl CornerRefiner for SaddlePointRefiner {
471    #[inline]
472    fn radius(&self) -> i32 {
473        self.cfg.radius
474    }
475
476    fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
477        let img = match ctx.image {
478            Some(view) => view,
479            None => {
480                return RefineResult {
481                    x: seed_xy[0],
482                    y: seed_xy[1],
483                    score: 0.0,
484                    status: RefineStatus::Rejected,
485                }
486            }
487        };
488
489        let cx = seed_xy[0].round() as i32;
490        let cy = seed_xy[1].round() as i32;
491        let r = self.cfg.radius;
492
493        if !img.supports_patch(cx, cy, r) {
494            return RefineResult {
495                x: seed_xy[0],
496                y: seed_xy[1],
497                score: 0.0,
498                status: RefineStatus::OutOfBounds,
499            };
500        }
501
502        let mut sum = 0.0f32;
503        let mut count = 0.0f32;
504        for dy in -r..=r {
505            let gy = cy + dy;
506            for dx in -r..=r {
507                let gx = cx + dx;
508                sum += img.sample(gx, gy);
509                count += 1.0;
510            }
511        }
512
513        let mean = if count > 0.0 { sum / count } else { 0.0 };
514
515        self.m.fill(0.0);
516        self.rhs.fill(0.0);
517
518        for dy in -r..=r {
519            let gy = cy + dy;
520            for dx in -r..=r {
521                let gx = cx + dx;
522                let i = img.sample(gx, gy) - mean;
523
524                let x = gx as f32 - seed_xy[0];
525                let y = gy as f32 - seed_xy[1];
526                let phi = [x * x, x * y, y * y, x, y, 1.0];
527
528                for row in 0..6 {
529                    self.rhs[row] += phi[row] * i;
530                    for col in row..6 {
531                        self.m[row * 6 + col] += phi[row] * phi[col];
532                    }
533                }
534            }
535        }
536
537        // Fill the lower triangle to make elimination logic simpler.
538        for row in 0..6 {
539            for col in 0..row {
540                self.m[row * 6 + col] = self.m[col * 6 + row];
541            }
542        }
543
544        let coeffs = match self.solve_6x6() {
545            Some(c) => c,
546            None => {
547                return RefineResult {
548                    x: seed_xy[0],
549                    y: seed_xy[1],
550                    score: 0.0,
551                    status: RefineStatus::IllConditioned,
552                }
553            }
554        };
555
556        let a = coeffs[0];
557        let b = coeffs[1];
558        let c = coeffs[2];
559        let d = coeffs[3];
560        let e = coeffs[4];
561
562        let det_h = 4.0 * a * c - b * b;
563        if det_h > -self.cfg.det_margin || det_h.abs() < self.cfg.min_abs_det {
564            return RefineResult {
565                x: seed_xy[0],
566                y: seed_xy[1],
567                score: det_h,
568                status: RefineStatus::IllConditioned,
569            };
570        }
571
572        let inv_det = 1.0 / det_h;
573        let ux = -(2.0 * c * d - b * e) * inv_det;
574        let uy = (b * d - 2.0 * a * e) * inv_det;
575
576        let max_off = self.cfg.max_offset.min(r as f32 + 0.5);
577        if ux.abs() > max_off || uy.abs() > max_off {
578            return RefineResult {
579                x: seed_xy[0],
580                y: seed_xy[1],
581                score: det_h,
582                status: RefineStatus::Rejected,
583            };
584        }
585
586        let score = (-det_h).sqrt();
587        RefineResult::accepted([seed_xy[0] + ux, seed_xy[1] + uy], score)
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `size`×`size` two-tone checkerboard whose quadrant boundary
    /// sits at `offset`. A 3×3 box blur is applied to interior pixels (border
    /// pixels keep their raw value) so the refiners see finite gradients.
    fn synthetic_checkerboard(size: usize, offset: (f32, f32), dark: u8, bright: u8) -> Vec<u8> {
        let mut img = vec![0u8; size * size];
        let ox = offset.0;
        let oy = offset.1;
        for y in 0..size {
            for x in 0..size {
                let xf = x as f32 - ox;
                let yf = y as f32 - oy;
                let dark_quad = (xf >= 0.0 && yf >= 0.0) || (xf < 0.0 && yf < 0.0);
                img[y * size + x] = if dark_quad { dark } else { bright };
            }
        }
        // Mild blur to provide gradients.
        let mut blurred = img.clone();
        for y in 1..(size - 1) {
            for x in 1..(size - 1) {
                let mut acc = 0u32;
                for ky in -1..=1 {
                    for kx in -1..=1 {
                        acc +=
                            img[(y as i32 + ky) as usize * size + (x as i32 + kx) as usize] as u32;
                    }
                }
                blurred[y * size + x] = (acc / 9) as u8;
            }
        }
        blurred
    }

    /// Euclidean distance between two points.
    fn dist(a: [f32; 2], b: [f32; 2]) -> f32 {
        ((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)).sqrt()
    }

    #[test]
    fn center_of_mass_matches_expected_centroid() {
        let mut resp = ResponseMap {
            w: 7,
            h: 7,
            data: vec![0.0; 49],
        };
        // Put asymmetric weights so the centroid is easy to predict.
        resp.data[3 * 7 + 3] = 10.0;
        resp.data[3 * 7 + 4] = 5.0;
        resp.data[4 * 7 + 3] = 5.0;
        resp.data[4 * 7 + 4] = 2.0;

        let mut refiner = CenterOfMassRefiner::new(CenterOfMassConfig { radius: 1 });
        let ctx = RefineContext {
            image: None,
            response: Some(&resp),
        };
        let res = refiner.refine([3.0, 3.0], ctx);
        assert_eq!(res.status, RefineStatus::Accepted);
        // Compute expected centroid explicitly.
        let mut sx = 0.0;
        let mut sy = 0.0;
        let mut sw = 0.0;
        for dy in -1..=1 {
            for dx in -1..=1 {
                let xx = (3 + dx) as usize;
                let yy = (3 + dy) as usize;
                let w_px = resp.at(xx, yy).max(0.0);
                sx += xx as f32 * w_px;
                sy += yy as f32 * w_px;
                sw += w_px;
            }
        }
        let expected = [sx / sw, sy / sw];
        assert!((res.x - expected[0]).abs() < 1e-4);
        assert!((res.y - expected[1]).abs() < 1e-4);
    }

    #[test]
    fn forstner_refines_toward_true_offset() {
        let img = synthetic_checkerboard(15, (7.35, 7.8), 40, 220);
        let view = ImageView::from_u8_slice(15, 15, &img).unwrap();
        let ctx = RefineContext {
            image: Some(view),
            response: None,
        };
        let mut refiner = ForstnerRefiner::new(ForstnerConfig::default());
        let res = refiner.refine([7.0, 8.0], ctx);
        assert_eq!(res.status, RefineStatus::Accepted);
        let true_xy = [7.35f32, 7.8f32];
        let seed_err = dist([7.0, 8.0], true_xy);
        let refined_err = dist([res.x, res.y], true_xy);
        assert!(
            refined_err <= seed_err * 1.6 && refined_err < 1.0,
            "refined_err {refined_err} seed_err {seed_err} res {:?}",
            (res.x, res.y)
        );
    }

    #[test]
    fn saddle_point_recovers_stationary_point_and_rejects_flat() {
        let img = synthetic_checkerboard(17, (8.2, 8.6), 30, 230);
        let view = ImageView::from_u8_slice(17, 17, &img).unwrap();
        let ctx = RefineContext {
            image: Some(view),
            response: None,
        };
        let mut refiner = SaddlePointRefiner::new(SaddlePointConfig::default());
        let res = refiner.refine([8.0, 9.0], ctx);
        assert_eq!(res.status, RefineStatus::Accepted);
        let true_xy = [8.2f32, 8.6f32];
        let seed_err = dist([8.0, 9.0], true_xy);
        let refined_err = dist([res.x, res.y], true_xy);
        assert!(
            refined_err <= seed_err * 1.6 && refined_err < 1.0,
            "refined_err {refined_err} seed_err {seed_err} res {:?}",
            (res.x, res.y)
        );

        let flat = vec![128u8; 25];
        let flat_view = ImageView::from_u8_slice(5, 5, &flat).unwrap();
        let flat_ctx = RefineContext {
            image: Some(flat_view),
            response: None,
        };
        let flat_res = refiner.refine([2.0, 2.0], flat_ctx);
        assert_ne!(flat_res.status, RefineStatus::Accepted);
    }

    #[test]
    fn missing_inputs_are_rejected_at_seed() {
        let seed = [3.0f32, 4.0];
        let ctx = RefineContext::default();
        let results = [
            CenterOfMassRefiner::new(CenterOfMassConfig::default()).refine(seed, ctx),
            ForstnerRefiner::new(ForstnerConfig::default()).refine(seed, ctx),
            SaddlePointRefiner::new(SaddlePointConfig::default()).refine(seed, ctx),
        ];
        for res in results.iter() {
            assert_eq!(res.status, RefineStatus::Rejected);
            assert_eq!((res.x, res.y), (seed[0], seed[1]));
        }
    }

    #[test]
    fn seeds_near_the_border_are_out_of_bounds() {
        let resp = ResponseMap {
            w: 7,
            h: 7,
            data: vec![1.0; 49],
        };
        let resp_ctx = RefineContext {
            image: None,
            response: Some(&resp),
        };
        let mut com = CenterOfMassRefiner::new(CenterOfMassConfig { radius: 1 });
        assert_eq!(com.refine([0.0, 3.0], resp_ctx).status, RefineStatus::OutOfBounds);
        assert_eq!(com.refine([3.0, 6.0], resp_ctx).status, RefineStatus::OutOfBounds);

        let img = synthetic_checkerboard(15, (7.35, 7.8), 40, 220);
        let view = ImageView::from_u8_slice(15, 15, &img).unwrap();
        let img_ctx = RefineContext {
            image: Some(view),
            response: None,
        };
        let mut forstner = ForstnerRefiner::new(ForstnerConfig::default());
        let res = forstner.refine([0.0, 0.0], img_ctx);
        assert_eq!(res.status, RefineStatus::OutOfBounds);
        assert_eq!((res.x, res.y), (0.0, 0.0));

        let mut saddle = SaddlePointRefiner::new(SaddlePointConfig::default());
        assert_eq!(saddle.refine([14.0, 14.0], img_ctx).status, RefineStatus::OutOfBounds);
    }

    #[test]
    fn refiner_dispatch_reports_backend_radius() {
        assert_eq!(Refiner::from_kind(RefinerKind::default()).radius(), 2);
        let forstner = Refiner::from_kind(RefinerKind::Forstner(ForstnerConfig::default()));
        // Förstner reserves one extra pixel for its central differences.
        assert_eq!(forstner.radius(), 3);
        let saddle = Refiner::from_kind(RefinerKind::SaddlePoint(SaddlePointConfig::default()));
        assert_eq!(saddle.radius(), 2);
    }
}