1use crate::imageview::ImageView;
8use crate::ResponseMap;
9
/// Outcome reported by a corner refiner alongside the returned position.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum RefineStatus {
    /// A refined position was produced.
    Accepted,
    /// No usable refinement. The refiners in this file report this when a
    /// required input is absent or the computed offset exceeds the allowed
    /// maximum.
    Rejected,
    /// The window around the seed does not fit inside the input.
    OutOfBounds,
    /// The local fit was numerically degenerate (weak gradients, a singular
    /// system, or a surface that is not saddle-shaped).
    IllConditioned,
}
19
20#[derive(Copy, Clone, Debug)]
22pub struct RefineResult {
23 pub x: f32,
25 pub y: f32,
27 pub score: f32,
28 pub status: RefineStatus,
29}
30
31impl RefineResult {
32 #[inline]
33 pub fn accepted(xy: [f32; 2], score: f32) -> Self {
34 Self {
35 x: xy[0],
36 y: xy[1],
37 score,
38 status: RefineStatus::Accepted,
39 }
40 }
41}
42
43#[derive(Copy, Clone, Debug, Default)]
45pub struct RefineContext<'a> {
46 pub image: Option<ImageView<'a>>,
47 pub response: Option<&'a ResponseMap>,
48}
49
/// Interface shared by the sub-pixel corner refiners defined in this file.
pub trait CornerRefiner {
    /// Margin, in pixels, that the refiner needs around the rounded seed.
    ///
    /// The implementations in this file return the configured window radius,
    /// plus one where neighbouring pixels are read to form gradients.
    fn radius(&self) -> i32;
    /// Refines the position `seed_xy` (`[x, y]`) using the inputs in `ctx`.
    ///
    /// Takes `&mut self` so implementations may reuse internal scratch
    /// buffers. The implementations in this file hand the seed back unchanged
    /// whenever no refined position is computed.
    fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult;
}
56
57#[derive(Clone, Debug)]
59#[non_exhaustive]
60pub enum RefinerKind {
61 CenterOfMass(CenterOfMassConfig),
62 Forstner(ForstnerConfig),
63 SaddlePoint(SaddlePointConfig),
64}
65
66impl Default for RefinerKind {
67 fn default() -> Self {
68 Self::CenterOfMass(CenterOfMassConfig::default())
69 }
70}
71
72#[derive(Debug)]
74pub enum Refiner {
75 CenterOfMass(CenterOfMassRefiner),
76 Forstner(ForstnerRefiner),
77 SaddlePoint(SaddlePointRefiner),
78}
79
80impl Refiner {
81 pub fn from_kind(kind: RefinerKind) -> Self {
82 match kind {
83 RefinerKind::CenterOfMass(cfg) => Refiner::CenterOfMass(CenterOfMassRefiner::new(cfg)),
84 RefinerKind::Forstner(cfg) => Refiner::Forstner(ForstnerRefiner::new(cfg)),
85 RefinerKind::SaddlePoint(cfg) => Refiner::SaddlePoint(SaddlePointRefiner::new(cfg)),
86 }
87 }
88}
89
90impl CornerRefiner for Refiner {
91 #[inline]
92 fn radius(&self) -> i32 {
93 match self {
94 Refiner::CenterOfMass(r) => r.radius(),
95 Refiner::Forstner(r) => r.radius(),
96 Refiner::SaddlePoint(r) => r.radius(),
97 }
98 }
99
100 #[inline]
101 fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
102 match self {
103 Refiner::CenterOfMass(r) => r.refine(seed_xy, ctx),
104 Refiner::Forstner(r) => r.refine(seed_xy, ctx),
105 Refiner::SaddlePoint(r) => r.refine(seed_xy, ctx),
106 }
107 }
108}
109
/// Settings for [`CenterOfMassRefiner`].
#[derive(Clone, Copy, Debug)]
pub struct CenterOfMassConfig {
    /// Half-size of the square window; the centroid is taken over
    /// `(2 * radius + 1)^2` response samples.
    pub radius: i32,
}

impl Default for CenterOfMassConfig {
    /// A radius of 2, i.e. a 5x5 window.
    fn default() -> Self {
        CenterOfMassConfig { radius: 2 }
    }
}
121
122#[derive(Debug)]
123pub struct CenterOfMassRefiner {
124 cfg: CenterOfMassConfig,
125}
126
127impl CenterOfMassRefiner {
128 pub fn new(cfg: CenterOfMassConfig) -> Self {
129 Self { cfg }
130 }
131}
132
133impl CornerRefiner for CenterOfMassRefiner {
134 #[inline]
135 fn radius(&self) -> i32 {
136 self.cfg.radius
137 }
138
139 fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
140 let resp = match ctx.response {
141 Some(r) => r,
142 None => {
143 return RefineResult {
144 x: seed_xy[0],
145 y: seed_xy[1],
146 score: 0.0,
147 status: RefineStatus::Rejected,
148 }
149 }
150 };
151
152 let x = seed_xy[0].round() as i32;
153 let y = seed_xy[1].round() as i32;
154 let r = self.cfg.radius;
155
156 let mut sx = 0.0;
157 let mut sy = 0.0;
158 let mut sw = 0.0;
159
160 let w = resp.w as i32;
161 let h = resp.h as i32;
162
163 if x < r || y < r || x >= w - r || y >= h - r {
164 return RefineResult {
165 x: seed_xy[0],
166 y: seed_xy[1],
167 score: 0.0,
168 status: RefineStatus::OutOfBounds,
169 };
170 }
171
172 for dy in -r..=r {
173 let yy = (y + dy).clamp(0, h - 1) as usize;
174 for dx in -r..=r {
175 let xx = (x + dx).clamp(0, w - 1) as usize;
176 let w_px = resp.at(xx, yy).max(0.0);
177 sx += (xx as f32) * w_px;
178 sy += (yy as f32) * w_px;
179 sw += w_px;
180 }
181 }
182
183 if sw > 0.0 {
184 RefineResult::accepted([sx / sw, sy / sw], sw)
185 } else {
186 RefineResult {
187 x: seed_xy[0],
188 y: seed_xy[1],
189 score: 0.0,
190 status: RefineStatus::Accepted,
191 }
192 }
193 }
194}
195
/// Settings for [`ForstnerRefiner`].
#[derive(Clone, Copy, Debug)]
pub struct ForstnerConfig {
    /// Half-size of the square patch over which gradients are accumulated.
    pub radius: i32,
    /// Smallest accepted trace of the accumulated gradient matrix; patches
    /// below it are reported as ill-conditioned.
    pub min_trace: f32,
    /// The determinant of the gradient matrix must exceed this value.
    pub min_det: f32,
    /// Largest accepted ratio between the two eigenvalues of the gradient
    /// matrix.
    pub max_condition_number: f32,
    /// Largest accepted per-axis offset from the seed; the refiner further
    /// caps it at `radius + 0.5`.
    pub max_offset: f32,
}

impl Default for ForstnerConfig {
    /// 5x5 patch (`radius = 2`) with the thresholds listed in the body.
    fn default() -> Self {
        ForstnerConfig {
            radius: 2,
            min_trace: 25.0,
            min_det: 1e-3,
            max_condition_number: 50.0,
            max_offset: 1.5,
        }
    }
}
246
247#[derive(Debug)]
248pub struct ForstnerRefiner {
249 cfg: ForstnerConfig,
250}
251
252impl ForstnerRefiner {
253 pub fn new(cfg: ForstnerConfig) -> Self {
254 Self { cfg }
255 }
256}
257
258impl CornerRefiner for ForstnerRefiner {
259 #[inline]
260 fn radius(&self) -> i32 {
261 self.cfg.radius + 1
263 }
264
265 fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
266 let img = match ctx.image {
267 Some(view) => view,
268 None => {
269 return RefineResult {
270 x: seed_xy[0],
271 y: seed_xy[1],
272 score: 0.0,
273 status: RefineStatus::Rejected,
274 }
275 }
276 };
277
278 let cx = seed_xy[0].round() as i32;
279 let cy = seed_xy[1].round() as i32;
280 let patch_r = self.cfg.radius;
281
282 if !img.supports_patch(cx, cy, patch_r + 1) {
283 return RefineResult {
284 x: seed_xy[0],
285 y: seed_xy[1],
286 score: 0.0,
287 status: RefineStatus::OutOfBounds,
288 };
289 }
290
291 let mut a00 = 0.0;
292 let mut a01 = 0.0;
293 let mut a11 = 0.0;
294 let mut bx = 0.0;
295 let mut by = 0.0;
296
297 for dy in -patch_r..=patch_r {
298 let gy = cy + dy;
299 for dx in -patch_r..=patch_r {
300 let gx = cx + dx;
301
302 let ix_plus = img.sample(gx + 1, gy);
303 let ix_minus = img.sample(gx - 1, gy);
304 let iy_plus = img.sample(gx, gy + 1);
305 let iy_minus = img.sample(gx, gy - 1);
306
307 let gx_f = 0.5 * (ix_plus - ix_minus);
308 let gy_f = 0.5 * (iy_plus - iy_minus);
309
310 let px = gx as f32 - seed_xy[0];
311 let py = gy as f32 - seed_xy[1];
312 let gxgx = gx_f * gx_f;
313 let gxgy = gx_f * gy_f;
314 let gygy = gy_f * gy_f;
315 let dist2 = px * px + py * py;
316 let w = 1.0 / (1.0 + 0.5 * dist2);
317
318 a00 += w * gxgx;
319 a01 += w * gxgy;
320 a11 += w * gygy;
321
322 bx += w * (gxgx * px + gxgy * py);
324 by += w * (gxgy * px + gygy * py);
325 }
326 }
327
328 let trace = a00 + a11;
329 let det = a00 * a11 - a01 * a01;
330 if trace < self.cfg.min_trace || det <= self.cfg.min_det {
331 return RefineResult {
332 x: seed_xy[0],
333 y: seed_xy[1],
334 score: det,
335 status: RefineStatus::IllConditioned,
336 };
337 }
338
339 let discr = (trace * trace - 4.0 * det).max(0.0).sqrt();
340 let lambda_min = 0.5 * (trace - discr);
341 let lambda_max = 0.5 * (trace + discr);
342
343 if lambda_min <= 0.0 {
344 return RefineResult {
345 x: seed_xy[0],
346 y: seed_xy[1],
347 score: det,
348 status: RefineStatus::IllConditioned,
349 };
350 }
351
352 let cond = lambda_max / lambda_min;
353 if !cond.is_finite() || cond > self.cfg.max_condition_number {
354 return RefineResult {
355 x: seed_xy[0],
356 y: seed_xy[1],
357 score: det,
358 status: RefineStatus::IllConditioned,
359 };
360 }
361
362 let inv_det = 1.0 / det;
363 let ux = (a11 * bx - a01 * by) * inv_det;
364 let uy = (-a01 * bx + a00 * by) * inv_det;
365
366 let max_off = self.cfg.max_offset.min(self.cfg.radius as f32 + 0.5);
367 if ux.abs() > max_off || uy.abs() > max_off {
368 return RefineResult {
369 x: seed_xy[0],
370 y: seed_xy[1],
371 score: det,
372 status: RefineStatus::Rejected,
373 };
374 }
375
376 let score = det / (trace * trace + 1e-6);
377 RefineResult::accepted([seed_xy[0] + ux, seed_xy[1] + uy], score)
378 }
379}
380
/// Settings for [`SaddlePointRefiner`].
#[derive(Clone, Copy, Debug)]
pub struct SaddlePointConfig {
    /// Half-size of the square patch the quadratic surface is fitted to.
    pub radius: i32,
    /// The Hessian determinant of the fit must not exceed `-det_margin`,
    /// i.e. the surface has to be saddle-shaped by at least this margin.
    pub det_margin: f32,
    /// Largest accepted per-axis offset from the seed; the refiner further
    /// caps it at `radius + 0.5`.
    pub max_offset: f32,
    /// Smallest accepted magnitude of the Hessian determinant.
    pub min_abs_det: f32,
}

impl Default for SaddlePointConfig {
    /// 5x5 patch (`radius = 2`) with the thresholds listed in the body.
    fn default() -> Self {
        SaddlePointConfig {
            radius: 2,
            det_margin: 1e-3,
            max_offset: 1.5,
            min_abs_det: 1e-4,
        }
    }
}
400
401#[derive(Debug)]
402pub struct SaddlePointRefiner {
403 cfg: SaddlePointConfig,
404 m: [f32; 36],
405 rhs: [f32; 6],
406}
407
408impl SaddlePointRefiner {
409 pub fn new(cfg: SaddlePointConfig) -> Self {
410 Self {
411 cfg,
412 m: [0.0; 36],
413 rhs: [0.0; 6],
414 }
415 }
416
417 fn solve_6x6(&mut self) -> Option<[f32; 6]> {
418 for i in 0..6 {
420 let mut pivot = i;
421 let mut pivot_val = self.m[i * 6 + i].abs();
422 for r in (i + 1)..6 {
423 let v = self.m[r * 6 + i].abs();
424 if v > pivot_val {
425 pivot = r;
426 pivot_val = v;
427 }
428 }
429
430 if pivot_val < 1e-9 {
431 return None;
432 }
433
434 if pivot != i {
435 for c in i..6 {
436 self.m.swap(i * 6 + c, pivot * 6 + c);
437 }
438 self.rhs.swap(i, pivot);
439 }
440
441 let diag = self.m[i * 6 + i];
442 let inv_diag = 1.0 / diag;
443
444 for c in i..6 {
445 self.m[i * 6 + c] *= inv_diag;
446 }
447 self.rhs[i] *= inv_diag;
448
449 for r in 0..6 {
450 if r == i {
451 continue;
452 }
453 let factor = self.m[r * 6 + i];
454 if factor == 0.0 {
455 continue;
456 }
457 for c in i..6 {
458 self.m[r * 6 + c] -= factor * self.m[i * 6 + c];
459 }
460 self.rhs[r] -= factor * self.rhs[i];
461 }
462 }
463
464 let mut out = [0.0f32; 6];
465 out.copy_from_slice(&self.rhs);
466 Some(out)
467 }
468}
469
470impl CornerRefiner for SaddlePointRefiner {
471 #[inline]
472 fn radius(&self) -> i32 {
473 self.cfg.radius
474 }
475
476 fn refine(&mut self, seed_xy: [f32; 2], ctx: RefineContext<'_>) -> RefineResult {
477 let img = match ctx.image {
478 Some(view) => view,
479 None => {
480 return RefineResult {
481 x: seed_xy[0],
482 y: seed_xy[1],
483 score: 0.0,
484 status: RefineStatus::Rejected,
485 }
486 }
487 };
488
489 let cx = seed_xy[0].round() as i32;
490 let cy = seed_xy[1].round() as i32;
491 let r = self.cfg.radius;
492
493 if !img.supports_patch(cx, cy, r) {
494 return RefineResult {
495 x: seed_xy[0],
496 y: seed_xy[1],
497 score: 0.0,
498 status: RefineStatus::OutOfBounds,
499 };
500 }
501
502 let mut sum = 0.0f32;
503 let mut count = 0.0f32;
504 for dy in -r..=r {
505 let gy = cy + dy;
506 for dx in -r..=r {
507 let gx = cx + dx;
508 sum += img.sample(gx, gy);
509 count += 1.0;
510 }
511 }
512
513 let mean = if count > 0.0 { sum / count } else { 0.0 };
514
515 self.m.fill(0.0);
516 self.rhs.fill(0.0);
517
518 for dy in -r..=r {
519 let gy = cy + dy;
520 for dx in -r..=r {
521 let gx = cx + dx;
522 let i = img.sample(gx, gy) - mean;
523
524 let x = gx as f32 - seed_xy[0];
525 let y = gy as f32 - seed_xy[1];
526 let phi = [x * x, x * y, y * y, x, y, 1.0];
527
528 for row in 0..6 {
529 self.rhs[row] += phi[row] * i;
530 for col in row..6 {
531 self.m[row * 6 + col] += phi[row] * phi[col];
532 }
533 }
534 }
535 }
536
537 for row in 0..6 {
539 for col in 0..row {
540 self.m[row * 6 + col] = self.m[col * 6 + row];
541 }
542 }
543
544 let coeffs = match self.solve_6x6() {
545 Some(c) => c,
546 None => {
547 return RefineResult {
548 x: seed_xy[0],
549 y: seed_xy[1],
550 score: 0.0,
551 status: RefineStatus::IllConditioned,
552 }
553 }
554 };
555
556 let a = coeffs[0];
557 let b = coeffs[1];
558 let c = coeffs[2];
559 let d = coeffs[3];
560 let e = coeffs[4];
561
562 let det_h = 4.0 * a * c - b * b;
563 if det_h > -self.cfg.det_margin || det_h.abs() < self.cfg.min_abs_det {
564 return RefineResult {
565 x: seed_xy[0],
566 y: seed_xy[1],
567 score: det_h,
568 status: RefineStatus::IllConditioned,
569 };
570 }
571
572 let inv_det = 1.0 / det_h;
573 let ux = -(2.0 * c * d - b * e) * inv_det;
574 let uy = (b * d - 2.0 * a * e) * inv_det;
575
576 let max_off = self.cfg.max_offset.min(r as f32 + 0.5);
577 if ux.abs() > max_off || uy.abs() > max_off {
578 return RefineResult {
579 x: seed_xy[0],
580 y: seed_xy[1],
581 score: det_h,
582 status: RefineStatus::Rejected,
583 };
584 }
585
586 let score = (-det_h).sqrt();
587 RefineResult::accepted([seed_xy[0] + ux, seed_xy[1] + uy], score)
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance between two 2-D points.
    fn distance(a: [f32; 2], b: [f32; 2]) -> f32 {
        ((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)).sqrt()
    }

    /// Renders a `size` x `size` four-quadrant pattern whose crossing lies at
    /// `offset`, then smooths the interior with a 3x3 box filter (border
    /// pixels keep their unblurred value).
    fn synthetic_checkerboard(size: usize, offset: (f32, f32), dark: u8, bright: u8) -> Vec<u8> {
        let (ox, oy) = offset;
        let sharp: Vec<u8> = (0..size * size)
            .map(|idx| {
                let xf = (idx % size) as f32 - ox;
                let yf = (idx / size) as f32 - oy;
                let dark_quad = (xf >= 0.0 && yf >= 0.0) || (xf < 0.0 && yf < 0.0);
                if dark_quad {
                    dark
                } else {
                    bright
                }
            })
            .collect();

        let mut blurred = sharp.clone();
        for y in 1..(size - 1) {
            for x in 1..(size - 1) {
                let total: u32 = ((y - 1)..=(y + 1))
                    .flat_map(|ny| ((x - 1)..=(x + 1)).map(move |nx| ny * size + nx))
                    .map(|idx| u32::from(sharp[idx]))
                    .sum();
                blurred[y * size + x] = (total / 9) as u8;
            }
        }
        blurred
    }

    #[test]
    fn center_of_mass_matches_expected_centroid() {
        let mut resp = ResponseMap {
            w: 7,
            h: 7,
            data: vec![0.0; 49],
        };
        // (x, y, response) of the four non-zero cells.
        for (x, y, value) in [(3, 3, 10.0), (4, 3, 5.0), (3, 4, 5.0), (4, 4, 2.0)] {
            resp.data[y * 7 + x] = value;
        }

        let mut refiner = CenterOfMassRefiner::new(CenterOfMassConfig { radius: 1 });
        let res = refiner.refine(
            [3.0, 3.0],
            RefineContext {
                image: None,
                response: Some(&resp),
            },
        );
        assert_eq!(res.status, RefineStatus::Accepted);

        // Reference centroid over the same 3x3 window, in the same order.
        let (mut sx, mut sy, mut sw) = (0.0, 0.0, 0.0);
        for yy in 2..=4usize {
            for xx in 2..=4usize {
                let w_px = resp.at(xx, yy).max(0.0);
                sx += xx as f32 * w_px;
                sy += yy as f32 * w_px;
                sw += w_px;
            }
        }
        let expected = [sx / sw, sy / sw];
        assert!((res.x - expected[0]).abs() < 1e-4);
        assert!((res.y - expected[1]).abs() < 1e-4);
    }

    #[test]
    fn forstner_refines_toward_true_offset() {
        let true_xy = [7.35f32, 7.8f32];
        let img = synthetic_checkerboard(15, (7.35, 7.8), 40, 220);
        let view = ImageView::from_u8_slice(15, 15, &img).unwrap();
        let mut refiner = ForstnerRefiner::new(ForstnerConfig::default());

        let res = refiner.refine(
            [7.0, 8.0],
            RefineContext {
                image: Some(view),
                response: None,
            },
        );
        assert_eq!(res.status, RefineStatus::Accepted);

        let seed_err = distance([7.0, 8.0], true_xy);
        let refined_err = distance([res.x, res.y], true_xy);
        assert!(
            refined_err <= seed_err * 1.6 && refined_err < 1.0,
            "refined_err {refined_err} seed_err {seed_err} res {:?}",
            (res.x, res.y)
        );
    }

    #[test]
    fn saddle_point_recovers_stationary_point_and_rejects_flat() {
        let true_xy = [8.2f32, 8.6f32];
        let img = synthetic_checkerboard(17, (8.2, 8.6), 30, 230);
        let view = ImageView::from_u8_slice(17, 17, &img).unwrap();
        let mut refiner = SaddlePointRefiner::new(SaddlePointConfig::default());

        let res = refiner.refine(
            [8.0, 9.0],
            RefineContext {
                image: Some(view),
                response: None,
            },
        );
        assert_eq!(res.status, RefineStatus::Accepted);

        let seed_err = distance([8.0, 9.0], true_xy);
        let refined_err = distance([res.x, res.y], true_xy);
        assert!(
            refined_err <= seed_err * 1.6 && refined_err < 1.0,
            "refined_err {refined_err} seed_err {seed_err} res {:?}",
            (res.x, res.y)
        );

        // A constant patch has no saddle and must not be accepted.
        let flat = vec![128u8; 25];
        let flat_view = ImageView::from_u8_slice(5, 5, &flat).unwrap();
        let flat_res = refiner.refine(
            [2.0, 2.0],
            RefineContext {
                image: Some(flat_view),
                response: None,
            },
        );
        assert_ne!(flat_res.status, RefineStatus::Accepted);
    }
}