Skip to main content

chess_corners_core/refine/
saddle_point.rs

1//! Saddle-point quadratic-surface refiner.
2//!
3//! Fits a 2nd-order surface `I(x, y) = a x² + b x y + c y² + d x + e y + f`
4//! to the image patch around the seed and locates the unique
5//! stationary point of the resulting quadratic. The Hessian
6//! `[2a b; b 2c]` must have negative determinant (a saddle) for the
7//! corner to be accepted.
8
9use super::{CornerRefiner, RefineContext, RefineResult, RefineStatus};
10use serde::{Deserialize, Serialize};
11
12/// Configuration for the [`SaddlePointRefiner`].
13#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
14#[serde(default)]
15pub struct SaddlePointConfig {
16    pub radius: i32,
17    pub det_margin: f32,
18    pub max_offset: f32,
19    pub min_abs_det: f32,
20}
21
22impl Default for SaddlePointConfig {
23    fn default() -> Self {
24        Self {
25            radius: 2,
26            det_margin: 1e-3,
27            max_offset: 1.5,
28            min_abs_det: 1e-4,
29        }
30    }
31}
32
33#[derive(Debug)]
34pub struct SaddlePointRefiner {
35    cfg: SaddlePointConfig,
36    m: [f32; 36],
37    rhs: [f32; 6],
38}
39
40impl SaddlePointRefiner {
41    pub fn new(cfg: SaddlePointConfig) -> Self {
42        Self {
43            cfg,
44            m: [0.0; 36],
45            rhs: [0.0; 6],
46        }
47    }
48
49    fn solve_6x6(&mut self) -> Option<[f32; 6]> {
50        // Simple Gauss-Jordan elimination with partial pivoting on the stack.
51        for i in 0..6 {
52            let mut pivot = i;
53            let mut pivot_val = self.m[i * 6 + i].abs();
54            for r in (i + 1)..6 {
55                let v = self.m[r * 6 + i].abs();
56                if v > pivot_val {
57                    pivot = r;
58                    pivot_val = v;
59                }
60            }
61
62            if pivot_val < 1e-9 {
63                return None;
64            }
65
66            if pivot != i {
67                for c in i..6 {
68                    self.m.swap(i * 6 + c, pivot * 6 + c);
69                }
70                self.rhs.swap(i, pivot);
71            }
72
73            let diag = self.m[i * 6 + i];
74            let inv_diag = 1.0 / diag;
75
76            for c in i..6 {
77                self.m[i * 6 + c] *= inv_diag;
78            }
79            self.rhs[i] *= inv_diag;
80
81            for r in 0..6 {
82                if r == i {
83                    continue;
84                }
85                let factor = self.m[r * 6 + i];
86                if factor == 0.0 {
87                    continue;
88                }
89                for c in i..6 {
90                    self.m[r * 6 + c] -= factor * self.m[i * 6 + c];
91                }
92                self.rhs[r] -= factor * self.rhs[i];
93            }
94        }
95
96        let mut out = [0.0f32; 6];
97        out.copy_from_slice(&self.rhs);
98        Some(out)
99    }
100}
101
102impl CornerRefiner for SaddlePointRefiner {
103    #[inline]
104    fn radius(&self) -> i32 {
105        self.cfg.radius
106    }
107
108    fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
109        let img = match ctx.image {
110            Some(view) => view,
111            None => {
112                return RefineResult {
113                    x: seed_xy[0],
114                    y: seed_xy[1],
115                    score: 0.0,
116                    status: RefineStatus::Rejected,
117                }
118            }
119        };
120
121        let cx = seed_xy[0].round() as i32;
122        let cy = seed_xy[1].round() as i32;
123        let r = self.cfg.radius;
124
125        if !img.supports_patch(cx, cy, r) {
126            return RefineResult {
127                x: seed_xy[0],
128                y: seed_xy[1],
129                score: 0.0,
130                status: RefineStatus::OutOfBounds,
131            };
132        }
133
134        let mut sum = 0.0f32;
135        let mut count = 0.0f32;
136        for dy in -r..=r {
137            let gy = cy + dy;
138            for dx in -r..=r {
139                let gx = cx + dx;
140                sum += img.sample(gx, gy);
141                count += 1.0;
142            }
143        }
144
145        let mean = if count > 0.0 { sum / count } else { 0.0 };
146
147        self.m.fill(0.0);
148        self.rhs.fill(0.0);
149
150        for dy in -r..=r {
151            let gy = cy + dy;
152            for dx in -r..=r {
153                let gx = cx + dx;
154                let i = img.sample(gx, gy) - mean;
155
156                let x = gx as f32 - seed_xy[0];
157                let y = gy as f32 - seed_xy[1];
158                let phi = [x * x, x * y, y * y, x, y, 1.0];
159
160                for row in 0..6 {
161                    self.rhs[row] += phi[row] * i;
162                    for col in row..6 {
163                        self.m[row * 6 + col] += phi[row] * phi[col];
164                    }
165                }
166            }
167        }
168
169        // Fill the lower triangle to make elimination logic simpler.
170        for row in 0..6 {
171            for col in 0..row {
172                self.m[row * 6 + col] = self.m[col * 6 + row];
173            }
174        }
175
176        let coeffs = match self.solve_6x6() {
177            Some(c) => c,
178            None => {
179                return RefineResult {
180                    x: seed_xy[0],
181                    y: seed_xy[1],
182                    score: 0.0,
183                    status: RefineStatus::IllConditioned,
184                }
185            }
186        };
187
188        let a = coeffs[0];
189        let b = coeffs[1];
190        let c = coeffs[2];
191        let d = coeffs[3];
192        let e = coeffs[4];
193
194        let det_h = 4.0 * a * c - b * b;
195        if det_h > -self.cfg.det_margin || det_h.abs() < self.cfg.min_abs_det {
196            return RefineResult {
197                x: seed_xy[0],
198                y: seed_xy[1],
199                score: det_h,
200                status: RefineStatus::IllConditioned,
201            };
202        }
203
204        let inv_det = 1.0 / det_h;
205        let ux = -(2.0 * c * d - b * e) * inv_det;
206        let uy = (b * d - 2.0 * a * e) * inv_det;
207
208        let max_off = self.cfg.max_offset.min(r as f32 + 0.5);
209        if ux.abs() > max_off || uy.abs() > max_off {
210            return RefineResult {
211                x: seed_xy[0],
212                y: seed_xy[1],
213                score: det_h,
214                status: RefineStatus::Rejected,
215            };
216        }
217
218        let score = (-det_h).sqrt();
219        RefineResult::accepted([seed_xy[0] + ux, seed_xy[1] + uy], score)
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::super::test_fixtures::synthetic_checkerboard;
    use super::*;
    use crate::imageview::ImageView;

    /// Euclidean distance between two points.
    fn dist(p: [f32; 2], q: [f32; 2]) -> f32 {
        ((p[0] - q[0]).powi(2) + (p[1] - q[1]).powi(2)).sqrt()
    }

    /// A seed under half a pixel from a synthetic checkerboard corner must be
    /// accepted and stay close to the true corner. The same refiner instance
    /// must then refuse a featureless (constant) patch.
    #[test]
    fn saddle_point_recovers_stationary_point_and_rejects_flat() {
        let mut refiner = SaddlePointRefiner::new(SaddlePointConfig::default());

        let board = synthetic_checkerboard(17, (8.2, 8.6), 30, 230);
        let board_view = ImageView::from_u8_slice(17, 17, &board).unwrap();
        let seed = [8.0f32, 9.0f32];
        let res = refiner.refine(
            seed,
            RefineContext {
                image: Some(board_view),
                response: None,
            },
        );
        assert_eq!(res.status, RefineStatus::Accepted);

        let true_xy = [8.2f32, 8.6f32];
        let seed_err = dist(seed, true_xy);
        let refined_err = dist([res.x, res.y], true_xy);
        assert!(
            refined_err <= seed_err * 1.6 && refined_err < 1.0,
            "refined_err {refined_err} seed_err {seed_err} res {:?}",
            (res.x, res.y)
        );

        // A constant image has no saddle, so it must not be accepted.
        let flat = vec![128u8; 25];
        let flat_view = ImageView::from_u8_slice(5, 5, &flat).unwrap();
        let flat_res = refiner.refine(
            [2.0, 2.0],
            RefineContext {
                image: Some(flat_view),
                response: None,
            },
        );
        assert_ne!(flat_res.status, RefineStatus::Accepted);
    }
}
258}