chess_corners_core/refine/
saddle_point.rs1use super::{CornerRefiner, RefineContext, RefineResult, RefineStatus};
10use serde::{Deserialize, Serialize};
11
12#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
14#[serde(default)]
15pub struct SaddlePointConfig {
16 pub radius: i32,
17 pub det_margin: f32,
18 pub max_offset: f32,
19 pub min_abs_det: f32,
20}
21
22impl Default for SaddlePointConfig {
23 fn default() -> Self {
24 Self {
25 radius: 2,
26 det_margin: 1e-3,
27 max_offset: 1.5,
28 min_abs_det: 1e-4,
29 }
30 }
31}
32
33#[derive(Debug)]
34pub struct SaddlePointRefiner {
35 cfg: SaddlePointConfig,
36 m: [f32; 36],
37 rhs: [f32; 6],
38}
39
40impl SaddlePointRefiner {
41 pub fn new(cfg: SaddlePointConfig) -> Self {
42 Self {
43 cfg,
44 m: [0.0; 36],
45 rhs: [0.0; 6],
46 }
47 }
48
49 fn solve_6x6(&mut self) -> Option<[f32; 6]> {
50 for i in 0..6 {
52 let mut pivot = i;
53 let mut pivot_val = self.m[i * 6 + i].abs();
54 for r in (i + 1)..6 {
55 let v = self.m[r * 6 + i].abs();
56 if v > pivot_val {
57 pivot = r;
58 pivot_val = v;
59 }
60 }
61
62 if pivot_val < 1e-9 {
63 return None;
64 }
65
66 if pivot != i {
67 for c in i..6 {
68 self.m.swap(i * 6 + c, pivot * 6 + c);
69 }
70 self.rhs.swap(i, pivot);
71 }
72
73 let diag = self.m[i * 6 + i];
74 let inv_diag = 1.0 / diag;
75
76 for c in i..6 {
77 self.m[i * 6 + c] *= inv_diag;
78 }
79 self.rhs[i] *= inv_diag;
80
81 for r in 0..6 {
82 if r == i {
83 continue;
84 }
85 let factor = self.m[r * 6 + i];
86 if factor == 0.0 {
87 continue;
88 }
89 for c in i..6 {
90 self.m[r * 6 + c] -= factor * self.m[i * 6 + c];
91 }
92 self.rhs[r] -= factor * self.rhs[i];
93 }
94 }
95
96 let mut out = [0.0f32; 6];
97 out.copy_from_slice(&self.rhs);
98 Some(out)
99 }
100}
101
102impl CornerRefiner for SaddlePointRefiner {
103 #[inline]
104 fn radius(&self) -> i32 {
105 self.cfg.radius
106 }
107
108 fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
109 let img = match ctx.image {
110 Some(view) => view,
111 None => {
112 return RefineResult {
113 x: seed_xy[0],
114 y: seed_xy[1],
115 score: 0.0,
116 status: RefineStatus::Rejected,
117 }
118 }
119 };
120
121 let cx = seed_xy[0].round() as i32;
122 let cy = seed_xy[1].round() as i32;
123 let r = self.cfg.radius;
124
125 if !img.supports_patch(cx, cy, r) {
126 return RefineResult {
127 x: seed_xy[0],
128 y: seed_xy[1],
129 score: 0.0,
130 status: RefineStatus::OutOfBounds,
131 };
132 }
133
134 let mut sum = 0.0f32;
135 let mut count = 0.0f32;
136 for dy in -r..=r {
137 let gy = cy + dy;
138 for dx in -r..=r {
139 let gx = cx + dx;
140 sum += img.sample(gx, gy);
141 count += 1.0;
142 }
143 }
144
145 let mean = if count > 0.0 { sum / count } else { 0.0 };
146
147 self.m.fill(0.0);
148 self.rhs.fill(0.0);
149
150 for dy in -r..=r {
151 let gy = cy + dy;
152 for dx in -r..=r {
153 let gx = cx + dx;
154 let i = img.sample(gx, gy) - mean;
155
156 let x = gx as f32 - seed_xy[0];
157 let y = gy as f32 - seed_xy[1];
158 let phi = [x * x, x * y, y * y, x, y, 1.0];
159
160 for row in 0..6 {
161 self.rhs[row] += phi[row] * i;
162 for col in row..6 {
163 self.m[row * 6 + col] += phi[row] * phi[col];
164 }
165 }
166 }
167 }
168
169 for row in 0..6 {
171 for col in 0..row {
172 self.m[row * 6 + col] = self.m[col * 6 + row];
173 }
174 }
175
176 let coeffs = match self.solve_6x6() {
177 Some(c) => c,
178 None => {
179 return RefineResult {
180 x: seed_xy[0],
181 y: seed_xy[1],
182 score: 0.0,
183 status: RefineStatus::IllConditioned,
184 }
185 }
186 };
187
188 let a = coeffs[0];
189 let b = coeffs[1];
190 let c = coeffs[2];
191 let d = coeffs[3];
192 let e = coeffs[4];
193
194 let det_h = 4.0 * a * c - b * b;
195 if det_h > -self.cfg.det_margin || det_h.abs() < self.cfg.min_abs_det {
196 return RefineResult {
197 x: seed_xy[0],
198 y: seed_xy[1],
199 score: det_h,
200 status: RefineStatus::IllConditioned,
201 };
202 }
203
204 let inv_det = 1.0 / det_h;
205 let ux = -(2.0 * c * d - b * e) * inv_det;
206 let uy = (b * d - 2.0 * a * e) * inv_det;
207
208 let max_off = self.cfg.max_offset.min(r as f32 + 0.5);
209 if ux.abs() > max_off || uy.abs() > max_off {
210 return RefineResult {
211 x: seed_xy[0],
212 y: seed_xy[1],
213 score: det_h,
214 status: RefineStatus::Rejected,
215 };
216 }
217
218 let score = (-det_h).sqrt();
219 RefineResult::accepted([seed_xy[0] + ux, seed_xy[1] + uy], score)
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::super::test_fixtures::synthetic_checkerboard;
    use super::*;
    use crate::imageview::ImageView;

    /// Euclidean distance between two 2-D points.
    fn dist(p: [f32; 2], q: [f32; 2]) -> f32 {
        ((p[0] - q[0]).powi(2) + (p[1] - q[1]).powi(2)).sqrt()
    }

    #[test]
    fn saddle_point_recovers_stationary_point_and_rejects_flat() {
        let mut refiner = SaddlePointRefiner::new(SaddlePointConfig::default());

        // 17×17 synthetic checkerboard whose reference corner is `true_xy`;
        // refinement starts from a seed that is off by (-0.2, +0.4).
        let true_xy = [8.2f32, 8.6f32];
        let board = synthetic_checkerboard(17, (8.2, 8.6), 30, 230);
        let board_view = ImageView::from_u8_slice(17, 17, &board).unwrap();
        let seed = [8.0, 9.0];
        let res = refiner.refine(
            seed,
            RefineContext {
                image: Some(board_view),
                response: None,
            },
        );
        assert_eq!(res.status, RefineStatus::Accepted);

        let seed_err = dist(seed, true_xy);
        let refined_err = dist([res.x, res.y], true_xy);
        assert!(
            refined_err <= seed_err * 1.6 && refined_err < 1.0,
            "refined_err {refined_err} seed_err {seed_err} res {:?}",
            (res.x, res.y)
        );

        // A constant image has no saddle, so the refiner must not accept it.
        let flat = vec![128u8; 25];
        let flat_view = ImageView::from_u8_slice(5, 5, &flat).unwrap();
        let flat_res = refiner.refine(
            [2.0, 2.0],
            RefineContext {
                image: Some(flat_view),
                response: None,
            },
        );
        assert_ne!(flat_res.status, RefineStatus::Accepted);
    }
}