1use anyhow::{anyhow, Context, Result};
31use std::path::{Path, PathBuf};
32use std::sync::OnceLock;
33use tract_onnx::prelude::tract_ndarray::{Array4, Ix2};
34use tract_onnx::prelude::*;
35
/// Where the ONNX refiner model should be loaded from.
#[derive(Clone, Debug)]
pub enum ModelSource {
    /// Load the model from an explicit `.onnx` file path on disk.
    Path(PathBuf),
    /// Use the model bytes compiled into the binary. Requires the
    /// "embed-model" feature; `MlModel::load` returns an error otherwise.
    EmbeddedDefault,
}
48
/// A loaded, optimized, runnable ONNX refiner model.
pub struct MlModel {
    // Optimized tract execution plan produced in `load`.
    model: TypedRunnableModel<TypedModel>,
    // Side length of the square input patches the model expects.
    patch_size: usize,
    // Scope that owns the symbolic batch dimension "N" used for the input
    // fact; not read after `load` (hence dead_code), but retained with the
    // model — presumably to keep the symbol alive; confirm before removing.
    #[allow(dead_code)]
    symbols: SymbolScope,
}
66
67impl MlModel {
68 pub fn load(source: ModelSource) -> Result<Self> {
78 let (model_path, patch_size) = match source {
79 ModelSource::Path(path) => {
80 let patch_size =
81 patch_size_from_meta_path(&path).unwrap_or_else(default_patch_size);
82 (path, patch_size)
83 }
84 ModelSource::EmbeddedDefault => {
85 #[cfg(feature = "embed-model")]
86 {
87 let patch_size = patch_size_from_meta_bytes(EMBED_META_JSON)
88 .unwrap_or_else(|_| default_patch_size());
89 let path = embedded_model_path()?;
90 (path, patch_size)
91 }
92 #[cfg(not(feature = "embed-model"))]
93 {
94 return Err(anyhow!(
95 "embedded model support disabled; enable feature \"embed-model\""
96 ));
97 }
98 }
99 };
100
101 let mut model = tract_onnx::onnx()
102 .model_for_path(&model_path)
103 .with_context(|| format!("load ONNX model from {}", model_path.display()))?;
104 let symbols = SymbolScope::default();
105 let batch = symbols.sym("N");
106 let shape = tvec!(
107 batch.to_dim(),
108 1.to_dim(),
109 (patch_size as i64).to_dim(),
110 (patch_size as i64).to_dim()
111 );
112 model
113 .set_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), shape))
114 .context("set ML refiner input fact")?;
115 let model = model
116 .into_optimized()
117 .context("optimize ONNX model")?
118 .into_runnable()
119 .context("make ONNX model runnable")?;
120
121 Ok(Self {
122 model,
123 patch_size,
124 symbols,
125 })
126 }
127
128 pub fn patch_size(&self) -> usize {
130 self.patch_size
131 }
132
133 pub fn infer_batch(&self, patches: &[f32], batch: usize) -> Result<Vec<[f32; 3]>> {
145 if batch == 0 {
146 return Ok(Vec::new());
147 }
148 let patch_area = self.patch_size * self.patch_size;
149 let expected = batch * patch_area;
150 if patches.len() != expected {
151 return Err(anyhow!(
152 "expected {} floats (batch {} * patch {}x{}), got {}",
153 expected,
154 batch,
155 self.patch_size,
156 self.patch_size,
157 patches.len()
158 ));
159 }
160
161 let input = Array4::from_shape_vec(
162 (batch, 1, self.patch_size, self.patch_size),
163 patches.to_vec(),
164 )
165 .context("reshape input patches")?
166 .into_tensor();
167 let result = self
168 .model
169 .run(tvec!(input.into_tvalue()))
170 .context("run ONNX inference")?;
171 let output = result[0]
172 .to_array_view::<f32>()
173 .context("read ONNX output")?
174 .into_dimensionality::<Ix2>()
175 .context("reshape ONNX output")?;
176
177 if output.ncols() != 3 {
178 return Err(anyhow!(
179 "expected output shape [N,3], got [N,{}]",
180 output.ncols()
181 ));
182 }
183
184 let mut out = Vec::with_capacity(batch);
185 for row in output.outer_iter() {
186 out.push([row[0], row[1], row[2]]);
187 }
188 Ok(out)
189 }
190}
191
192fn patch_size_from_meta_bytes(bytes: &[u8]) -> Result<usize> {
193 let meta: serde_json::Value =
194 serde_json::from_slice(bytes).context("parse ML refiner meta.json")?;
195 let size = meta
196 .get("patch_size")
197 .and_then(|v| v.as_u64())
198 .ok_or_else(|| anyhow!("meta.json missing patch_size"))?;
199 Ok(size as usize)
200}
201
202fn patch_size_from_meta_path(path: &Path) -> Option<usize> {
203 let meta_path = path.parent()?.join("fixtures").join("meta.json");
204 let bytes = std::fs::read(meta_path).ok()?;
205 patch_size_from_meta_bytes(&bytes).ok()
206}
207
/// Fallback patch size used when no readable `meta.json` is available.
///
/// With the "embed-model" feature enabled this prefers the size declared by
/// the embedded metadata, falling back to 21; without the feature it is a
/// plain constant 21.
fn default_patch_size() -> usize {
    #[cfg(feature = "embed-model")]
    {
        patch_size_from_meta_bytes(EMBED_META_JSON).unwrap_or(21)
    }
    #[cfg(not(feature = "embed-model"))]
    {
        21
    }
}
218
// File names used when materializing the embedded model into a temp dir.
// The ".data" name must match the external-tensor reference inside the
// ONNX file, which tract resolves relative to the model path.
#[cfg(feature = "embed-model")]
const EMBED_ONNX_NAME: &str = "chess_refiner_v4.onnx";
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA_NAME: &str = "chess_refiner_v4.onnx.data";

// Model graph, external weight data, and metadata compiled into the binary.
#[cfg(feature = "embed-model")]
const EMBED_ONNX: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v4.onnx"
));
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v4.onnx.data"
));
#[cfg(feature = "embed-model")]
const EMBED_META_JSON: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/fixtures/v4/meta.json"
));
239
#[cfg(feature = "embed-model")]
/// Materialize the embedded ONNX model (and its external weight file) into
/// a stable temp-dir location, returning the `.onnx` path.
///
/// The `.data` file is written before the `.onnx` file so anything that can
/// see the model file never observes missing external weights. The resolved
/// path is cached per process.
///
/// # Errors
/// Fails when the temp directory or either file cannot be written.
fn embedded_model_path() -> Result<PathBuf> {
    static PATH: OnceLock<PathBuf> = OnceLock::new();
    // Fast path: this process already materialized the files.
    if let Some(path) = PATH.get() {
        return Ok(path.clone());
    }
    // Do the fallible I/O *outside* get_or_init so failures surface as
    // errors to the caller; the previous version expect()-panicked inside
    // the init closure despite returning Result.
    let dir = std::env::temp_dir().join("chess_corners_ml");
    std::fs::create_dir_all(&dir)
        .with_context(|| format!("create ML model temp dir {}", dir.display()))?;
    let onnx_path = dir.join(EMBED_ONNX_NAME);
    let data_path = dir.join(EMBED_ONNX_DATA_NAME);
    write_if_changed(&data_path, EMBED_ONNX_DATA).context("write embedded ONNX data")?;
    write_if_changed(&onnx_path, EMBED_ONNX).context("write embedded ONNX model")?;
    // Concurrent first callers may race to here, but they all compute the
    // same path and write_if_changed is idempotent, so first-in wins safely.
    Ok(PATH.get_or_init(|| onnx_path).clone())
}
268
269#[cfg(feature = "embed-model")]
/// Write `data` to `path`, skipping the write entirely when the file already
/// holds exactly those bytes.
///
/// Data is first written to a sibling `<file_name>.tmp` file and then renamed
/// into place, so readers never observe a partially written target.
fn write_if_changed(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    // Cheap length comparison first; only read the file back when it matches.
    if let Ok(meta) = std::fs::metadata(path) {
        if meta.len() == data.len() as u64 {
            if let Ok(existing) = std::fs::read(path) {
                if existing == data {
                    return Ok(());
                }
            }
        }
    }
    // Append ".tmp" to the whole file name rather than using
    // with_extension("tmp"): the latter *replaces* the final extension, so
    // "model.onnx" and "model.data" in the same dir would both map to
    // "model.tmp" and concurrent writers could clobber each other.
    let mut tmp_name = path
        .file_name()
        .map(|n| n.to_os_string())
        .unwrap_or_default();
    tmp_name.push(".tmp");
    let tmp = path.with_file_name(tmp_name);
    std::fs::write(&tmp, data)?;
    std::fs::rename(&tmp, path)
}
289
#[cfg(all(test, feature = "embed-model"))]
mod tests {
    use super::write_if_changed;

    // Guards the size-equality fast path in write_if_changed: a file whose
    // length matches but whose bytes differ must still be rewritten, not
    // skipped by the metadata shortcut.
    #[test]
    fn write_if_changed_rewrites_same_size_changed_bytes() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("model.bin");

        write_if_changed(&path, b"abc").expect("initial write");
        write_if_changed(&path, b"xyz").expect("rewrite same-size bytes");

        let bytes = std::fs::read(&path).expect("read rewritten bytes");
        assert_eq!(bytes, b"xyz");
    }
}