// chess_corners_ml/lib.rs
1//! ONNX-backed ML refiner for ChESS corner candidates.
2//!
3//! This crate provides [`MlModel`], a thin wrapper around a
4//! [tract-onnx](https://docs.rs/tract-onnx) runtime that predicts
5//! subpixel `(dx, dy)` offsets for each corner candidate from a
6//! normalized intensity patch.
7//!
8//! # Intended use
9//!
10//! This crate is not meant to be used directly. It is consumed by the
11//! `chess-corners` facade crate when the `ml-refiner` feature is
12//! enabled. With the feature on, set the active ChESS refiner to
13//! `ChessRefiner::Ml` and call `Detector::detect` to route through
14//! the ML refiner.
15//!
16//! # Embedded model
17//!
18//! When the optional `embed-model` feature is enabled, the ONNX model
19//! and its external data file are compiled into the binary via
20//! `include_bytes!` and extracted to a temporary directory on first
21//! use. The extraction is thread-safe and idempotent (write-then-rename
22//! with byte-match skip).
23//!
24//! # Performance note
25//!
26//! ML refinement is significantly slower than the geometric refiners
27//! (~24 ms vs <1 ms for 77 corners on a 640×480 image). Use it only
28//! when maximum subpixel accuracy is required and throughput allows.
29
30use anyhow::{anyhow, Context, Result};
31use std::path::{Path, PathBuf};
32use std::sync::OnceLock;
33use tract_onnx::prelude::tract_ndarray::{Array4, Ix2};
34use tract_onnx::prelude::*;
35
/// Specifies where [`MlModel::load`] should read the ONNX model from.
#[derive(Clone, Debug)]
pub enum ModelSource {
    /// Load from an explicit filesystem path to the `.onnx` file.
    /// A `fixtures/meta.json` sidecar next to the model's parent directory
    /// is read to determine the patch size; falls back to the compiled-in
    /// default (21 px) when absent.
    Path(PathBuf),
    /// Use the model compiled into the binary via the `embed-model`
    /// Cargo feature. Returns an error when that feature is not enabled.
    EmbeddedDefault,
}
48
/// Loaded and optimised ONNX model for corner refinement.
///
/// The model accepts a batch of `f32` intensity patches with shape
/// `[N, 1, patch_size, patch_size]` (values in `[0, 1]`) and returns
/// `[N, 3]` with columns `[dx, dy, conf_logit]`. Only `dx` and `dy`
/// are currently used; `conf_logit` is ignored.
pub struct MlModel {
    // Compiled, runnable tract graph produced by `MlModel::load`.
    model: TypedRunnableModel<TypedModel>,
    // Side length of the square input patch baked into the input fact.
    patch_size: usize,
    // `SymbolScope` owns the `Symbol` object for the dynamic batch
    // dimension "N". Dropping it before `model` would leave the compiled
    // graph with a dangling reference to the scope's internal table, so
    // this field must be kept alive for the lifetime of `MlModel` even
    // though it is never explicitly read after construction.
    #[allow(dead_code)]
    symbols: SymbolScope,
}
66
67impl MlModel {
68    /// Load and optimise an ONNX model from the given source.
69    ///
70    /// For [`ModelSource::EmbeddedDefault`] the `embed-model` Cargo
71    /// feature must be enabled; an error is returned otherwise.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the model file cannot be read, the ONNX
76    /// graph is malformed, or tract optimisation / compilation fails.
77    pub fn load(source: ModelSource) -> Result<Self> {
78        let (model_path, patch_size) = match source {
79            ModelSource::Path(path) => {
80                let patch_size =
81                    patch_size_from_meta_path(&path).unwrap_or_else(default_patch_size);
82                (path, patch_size)
83            }
84            ModelSource::EmbeddedDefault => {
85                #[cfg(feature = "embed-model")]
86                {
87                    let patch_size = patch_size_from_meta_bytes(EMBED_META_JSON)
88                        .unwrap_or_else(|_| default_patch_size());
89                    let path = embedded_model_path()?;
90                    (path, patch_size)
91                }
92                #[cfg(not(feature = "embed-model"))]
93                {
94                    return Err(anyhow!(
95                        "embedded model support disabled; enable feature \"embed-model\""
96                    ));
97                }
98            }
99        };
100
101        let mut model = tract_onnx::onnx()
102            .model_for_path(&model_path)
103            .with_context(|| format!("load ONNX model from {}", model_path.display()))?;
104        let symbols = SymbolScope::default();
105        let batch = symbols.sym("N");
106        let shape = tvec!(
107            batch.to_dim(),
108            1.to_dim(),
109            (patch_size as i64).to_dim(),
110            (patch_size as i64).to_dim()
111        );
112        model
113            .set_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), shape))
114            .context("set ML refiner input fact")?;
115        let model = model
116            .into_optimized()
117            .context("optimize ONNX model")?
118            .into_runnable()
119            .context("make ONNX model runnable")?;
120
121        Ok(Self {
122            model,
123            patch_size,
124            symbols,
125        })
126    }
127
128    /// Side length (in pixels) of the square intensity patch the model expects.
129    pub fn patch_size(&self) -> usize {
130        self.patch_size
131    }
132
133    /// Run inference on a flat batch of intensity patches.
134    ///
135    /// `patches` must contain exactly `batch * patch_size * patch_size`
136    /// `f32` values in `[N, 1, H, W]` order (values in `[0, 1]`).
137    /// Returns one `[dx, dy, conf_logit]` triple per input patch.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if the slice length does not match
142    /// `batch * patch_size²`, if the ONNX output shape is unexpected,
143    /// or if tract inference fails.
144    pub fn infer_batch(&self, patches: &[f32], batch: usize) -> Result<Vec<[f32; 3]>> {
145        if batch == 0 {
146            return Ok(Vec::new());
147        }
148        let patch_area = self.patch_size * self.patch_size;
149        let expected = batch * patch_area;
150        if patches.len() != expected {
151            return Err(anyhow!(
152                "expected {} floats (batch {} * patch {}x{}), got {}",
153                expected,
154                batch,
155                self.patch_size,
156                self.patch_size,
157                patches.len()
158            ));
159        }
160
161        let input = Array4::from_shape_vec(
162            (batch, 1, self.patch_size, self.patch_size),
163            patches.to_vec(),
164        )
165        .context("reshape input patches")?
166        .into_tensor();
167        let result = self
168            .model
169            .run(tvec!(input.into_tvalue()))
170            .context("run ONNX inference")?;
171        let output = result[0]
172            .to_array_view::<f32>()
173            .context("read ONNX output")?
174            .into_dimensionality::<Ix2>()
175            .context("reshape ONNX output")?;
176
177        if output.ncols() != 3 {
178            return Err(anyhow!(
179                "expected output shape [N,3], got [N,{}]",
180                output.ncols()
181            ));
182        }
183
184        let mut out = Vec::with_capacity(batch);
185        for row in output.outer_iter() {
186            out.push([row[0], row[1], row[2]]);
187        }
188        Ok(out)
189    }
190}
191
192fn patch_size_from_meta_bytes(bytes: &[u8]) -> Result<usize> {
193    let meta: serde_json::Value =
194        serde_json::from_slice(bytes).context("parse ML refiner meta.json")?;
195    let size = meta
196        .get("patch_size")
197        .and_then(|v| v.as_u64())
198        .ok_or_else(|| anyhow!("meta.json missing patch_size"))?;
199    Ok(size as usize)
200}
201
202fn patch_size_from_meta_path(path: &Path) -> Option<usize> {
203    let meta_path = path.parent()?.join("fixtures").join("meta.json");
204    let bytes = std::fs::read(meta_path).ok()?;
205    patch_size_from_meta_bytes(&bytes).ok()
206}
207
/// Fallback patch size used when no `meta.json` sidecar is available.
///
/// With `embed-model` enabled this prefers the embedded metadata,
/// falling back to the compiled-in constant only if it fails to parse;
/// without the feature it is always the constant (21 px).
fn default_patch_size() -> usize {
    // Compiled-in default side length, in pixels.
    const FALLBACK: usize = 21;
    #[cfg(feature = "embed-model")]
    {
        patch_size_from_meta_bytes(EMBED_META_JSON).unwrap_or(FALLBACK)
    }
    #[cfg(not(feature = "embed-model"))]
    {
        FALLBACK
    }
}
218
// File names used inside the shared temp extraction directory. The `.data`
// file holds the model's external tensor data, referenced by name from the
// `.onnx` graph — so both must be extracted side by side.
#[cfg(feature = "embed-model")]
const EMBED_ONNX_NAME: &str = "chess_refiner_v4.onnx";
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA_NAME: &str = "chess_refiner_v4.onnx.data";

// Model bytes compiled into the binary at build time and extracted to a
// temp directory on first use by `embedded_model_path`.
#[cfg(feature = "embed-model")]
const EMBED_ONNX: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v4.onnx"
));
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v4.onnx.data"
));
// Metadata sidecar (patch size, etc.) matching the embedded model version.
#[cfg(feature = "embed-model")]
const EMBED_META_JSON: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/fixtures/v4/meta.json"
));
239
/// Materialise the embedded ONNX model into a shared temp directory and
/// return the path to the extracted `.onnx` file.
///
/// `OnceLock::get_or_init` serialises extraction across threads in this
/// process. Without it, parallel `#[test]` runs all entered
/// `write_if_changed`, the second `std::fs::write` truncated the file to
/// 0 bytes mid-rewrite, and a concurrent `tract_onnx` model load saw an
/// empty `.data` slice and panicked
/// (`range start index 768 out of range for slice of length 0`).
///
/// For cross-process races (e.g. `cargo test -p A` and `cargo test -p B`
/// sharing `/tmp/chess_corners_ml/`), the atomic write-then-rename in
/// `write_if_changed` ensures each file is either at its old contents or
/// at its new contents, never partially written.
#[cfg(feature = "embed-model")]
fn embedded_model_path() -> Result<PathBuf> {
    static EXTRACTED: OnceLock<PathBuf> = OnceLock::new();
    let onnx_path = EXTRACTED.get_or_init(|| {
        let cache_dir = std::env::temp_dir().join("chess_corners_ml");
        std::fs::create_dir_all(&cache_dir).expect("create ML model temp dir");
        let model_file = cache_dir.join(EMBED_ONNX_NAME);
        let data_file = cache_dir.join(EMBED_ONNX_DATA_NAME);
        // Order matters: land `.data` first so tract never observes an
        // `.onnx` that points at a missing or partially-written `.data`.
        write_if_changed(&data_file, EMBED_ONNX_DATA).expect("write embedded ONNX data");
        write_if_changed(&model_file, EMBED_ONNX).expect("write embedded ONNX model");
        model_file
    });
    Ok(onnx_path.clone())
}
268
/// Write `data` to `path` only if the file doesn't already contain
/// the same bytes. Uses write-then-rename so concurrent readers see
/// either the old contents or the new contents — never a truncated /
/// partially-written file. Cheap-out via the byte-match check avoids
/// rewriting unchanged files across re-runs in a shared temp dir.
///
/// The temp file name embeds the process id: a fixed name (the old
/// `with_extension("tmp")`) gave every process the SAME temp path, so
/// two processes racing on a shared temp dir could interleave — one
/// truncating/rewriting the temp file while the other renamed it into
/// place — publishing a partially-written file and defeating the
/// atomicity this function promises. Appending (rather than replacing)
/// the extension also keeps `.onnx` and `.onnx.data` temp names from
/// ever colliding with unrelated siblings.
#[cfg(feature = "embed-model")]
fn write_if_changed(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    // Fast path: identical length and bytes already on disk.
    if let Ok(meta) = std::fs::metadata(path) {
        if meta.len() == data.len() as u64 {
            if let Ok(existing) = std::fs::read(path) {
                if existing == data {
                    return Ok(());
                }
            }
        }
    }
    // Per-process-unique sibling path in the same directory (rename is
    // only atomic within a filesystem).
    let mut tmp_name = path
        .file_name()
        .map(|n| n.to_os_string())
        .unwrap_or_default();
    tmp_name.push(format!(".{}.tmp", std::process::id()));
    let tmp = path.with_file_name(tmp_name);
    std::fs::write(&tmp, data)?;
    std::fs::rename(&tmp, path)
}
289
#[cfg(all(test, feature = "embed-model"))]
mod tests {
    use super::write_if_changed;

    /// A same-length rewrite must not be skipped by the byte-match
    /// cheap-out: only truly identical contents short-circuit.
    #[test]
    fn write_if_changed_rewrites_same_size_changed_bytes() {
        let scratch = tempfile::tempdir().expect("tempdir");
        let target = scratch.path().join("model.bin");

        write_if_changed(&target, b"abc").expect("initial write");
        write_if_changed(&target, b"xyz").expect("rewrite same-size bytes");

        let after = std::fs::read(&target).expect("read rewritten bytes");
        assert_eq!(after, b"xyz");
    }
}
305}