//! chess_corners_ml — ML-based chess corner refiner (lib.rs).
1use anyhow::{anyhow, Context, Result};
2use std::path::{Path, PathBuf};
3use std::sync::OnceLock;
4use tract_onnx::prelude::tract_ndarray::{Array4, Ix2};
5use tract_onnx::prelude::*;
6
/// Where to load the ONNX refiner model from.
#[derive(Clone, Debug)]
pub enum ModelSource {
    /// An on-disk ONNX model file supplied by the caller.
    Path(PathBuf),
    /// The model compiled into the binary; requires the `embed-model` feature.
    EmbeddedDefault,
}
12
/// A loaded, optimized, runnable ONNX corner-refiner model.
pub struct MlModel {
    /// Optimized tract model, ready to run.
    model: TypedRunnableModel<TypedModel>,
    /// Side length (pixels) of the square input patches the model expects.
    patch_size: usize,
    #[allow(dead_code)]
    // Keep SymbolScope alive for dynamic batch resolution.
    symbols: SymbolScope,
}
20
21impl MlModel {
22    pub fn load(source: ModelSource) -> Result<Self> {
23        let (model_path, patch_size) = match source {
24            ModelSource::Path(path) => {
25                let patch_size =
26                    patch_size_from_meta_path(&path).unwrap_or_else(default_patch_size);
27                (path, patch_size)
28            }
29            ModelSource::EmbeddedDefault => {
30                #[cfg(feature = "embed-model")]
31                {
32                    let patch_size = patch_size_from_meta_bytes(EMBED_META_JSON)
33                        .unwrap_or_else(|_| default_patch_size());
34                    let path = embedded_model_path()?;
35                    (path, patch_size)
36                }
37                #[cfg(not(feature = "embed-model"))]
38                {
39                    return Err(anyhow!(
40                        "embedded model support disabled; enable feature \"embed-model\""
41                    ));
42                }
43            }
44        };
45
46        let mut model = tract_onnx::onnx()
47            .model_for_path(&model_path)
48            .with_context(|| format!("load ONNX model from {}", model_path.display()))?;
49        let symbols = SymbolScope::default();
50        let batch = symbols.sym("N");
51        let shape = tvec!(
52            batch.to_dim(),
53            1.to_dim(),
54            (patch_size as i64).to_dim(),
55            (patch_size as i64).to_dim()
56        );
57        model
58            .set_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), shape))
59            .context("set ML refiner input fact")?;
60        let model = model
61            .into_optimized()
62            .context("optimize ONNX model")?
63            .into_runnable()
64            .context("make ONNX model runnable")?;
65
66        Ok(Self {
67            model,
68            patch_size,
69            symbols,
70        })
71    }
72
73    pub fn patch_size(&self) -> usize {
74        self.patch_size
75    }
76
77    pub fn infer_batch(&self, patches: &[f32], batch: usize) -> Result<Vec<[f32; 3]>> {
78        if batch == 0 {
79            return Ok(Vec::new());
80        }
81        let patch_area = self.patch_size * self.patch_size;
82        let expected = batch * patch_area;
83        if patches.len() != expected {
84            return Err(anyhow!(
85                "expected {} floats (batch {} * patch {}x{}), got {}",
86                expected,
87                batch,
88                self.patch_size,
89                self.patch_size,
90                patches.len()
91            ));
92        }
93
94        let input = Array4::from_shape_vec(
95            (batch, 1, self.patch_size, self.patch_size),
96            patches.to_vec(),
97        )
98        .context("reshape input patches")?
99        .into_tensor();
100        let result = self
101            .model
102            .run(tvec!(input.into_tvalue()))
103            .context("run ONNX inference")?;
104        let output = result[0]
105            .to_array_view::<f32>()
106            .context("read ONNX output")?
107            .into_dimensionality::<Ix2>()
108            .context("reshape ONNX output")?;
109
110        if output.ncols() != 3 {
111            return Err(anyhow!(
112                "expected output shape [N,3], got [N,{}]",
113                output.ncols()
114            ));
115        }
116
117        let mut out = Vec::with_capacity(batch);
118        for row in output.outer_iter() {
119            out.push([row[0], row[1], row[2]]);
120        }
121        Ok(out)
122    }
123}
124
125fn patch_size_from_meta_bytes(bytes: &[u8]) -> Result<usize> {
126    let meta: serde_json::Value =
127        serde_json::from_slice(bytes).context("parse ML refiner meta.json")?;
128    let size = meta
129        .get("patch_size")
130        .and_then(|v| v.as_u64())
131        .ok_or_else(|| anyhow!("meta.json missing patch_size"))?;
132    Ok(size as usize)
133}
134
135fn patch_size_from_meta_path(path: &Path) -> Option<usize> {
136    let meta_path = path.parent()?.join("fixtures").join("meta.json");
137    let bytes = std::fs::read(meta_path).ok()?;
138    patch_size_from_meta_bytes(&bytes).ok()
139}
140
/// Patch size to use when no meta.json is available for the model.
fn default_patch_size() -> usize {
    #[cfg(feature = "embed-model")]
    {
        // Prefer the embedded meta.json; 21 is the hard-coded fallback used
        // when it cannot be parsed.
        patch_size_from_meta_bytes(EMBED_META_JSON).unwrap_or(21)
    }
    #[cfg(not(feature = "embed-model"))]
    {
        21
    }
}
151
// File names used when materializing the embedded model into a temp dir.
#[cfg(feature = "embed-model")]
const EMBED_ONNX_NAME: &str = "chess_refiner_v2.onnx";
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA_NAME: &str = "chess_refiner_v2.onnx.data";

// Model graph, external weights, and metadata baked into the binary at
// compile time from the crate's assets directory.
#[cfg(feature = "embed-model")]
const EMBED_ONNX: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v2.onnx"
));
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v2.onnx.data"
));
#[cfg(feature = "embed-model")]
const EMBED_META_JSON: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/fixtures/v2/meta.json"
));
172
/// Materialize the embedded ONNX model (graph + external weights) into the
/// system temp dir and return the path to the `.onnx` file, caching the
/// result for the life of the process.
#[cfg(feature = "embed-model")]
fn embedded_model_path() -> Result<PathBuf> {
    static PATH: OnceLock<PathBuf> = OnceLock::new();
    if let Some(cached) = PATH.get() {
        return Ok(cached.clone());
    }

    let temp_dir = std::env::temp_dir().join("chess_corners_ml");
    std::fs::create_dir_all(&temp_dir).context("create ML model temp dir")?;
    let model_path = temp_dir.join(EMBED_ONNX_NAME);
    let weights_path = temp_dir.join(EMBED_ONNX_DATA_NAME);

    // Only rewrite files when the exact bytes change, so fixed filenames in a
    // shared temp dir never serve stale model artifacts across upgrades.
    write_if_changed(&model_path, EMBED_ONNX).context("write embedded ONNX model")?;
    write_if_changed(&weights_path, EMBED_ONNX_DATA).context("write embedded ONNX data")?;

    // A concurrent caller may have set the cache first; both values are
    // identical, so losing the race is fine.
    let _ = PATH.set(model_path.clone());
    Ok(model_path)
}
193
194/// Write `data` to `path` only if the file doesn't already contain the same bytes.
195#[cfg(feature = "embed-model")]
fn write_if_changed(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    // Cheap length check first; only read the file body when sizes match.
    let already_current = std::fs::metadata(path)
        .ok()
        .filter(|meta| meta.len() == data.len() as u64)
        .and_then(|_| std::fs::read(path).ok())
        .map_or(false, |existing| existing == data);
    if already_current {
        return Ok(());
    }
    std::fs::write(path, data)
}
208
#[cfg(all(test, feature = "embed-model"))]
mod tests {
    use super::write_if_changed;

    /// A same-size content change must still hit disk: the fast path may
    /// only skip the write when the bytes are byte-for-byte identical.
    #[test]
    fn write_if_changed_rewrites_same_size_changed_bytes() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let target = tmp.path().join("model.bin");

        write_if_changed(&target, b"abc").expect("initial write");
        write_if_changed(&target, b"xyz").expect("rewrite same-size bytes");

        let rewritten = std::fs::read(&target).expect("read rewritten bytes");
        assert_eq!(rewritten, b"xyz");
    }
}
224}