This crate provides a simple example of importing an MNIST ONNX model into Burn. The ONNX file is converted into a Rust source file using `burn-import`, and the weights are stored in, and loaded from, a binary file.
Run the example, passing the index of the MNIST test image to classify (here `15`):

```sh
cargo run -- 15
```

Output:

```text
Finished dev [unoptimized + debuginfo] target(s) in 0.13s
Running `burn/target/debug/onnx-inference 15`
Image index: 15
Success!
Predicted: 5
Actual: 5
```

To see the image online, follow the link below:
https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row=15
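The argument passed after `--` is the test image index, and the predicted digit is presumably the class with the highest of the model's 10 output scores (an argmax). The standard-library-only sketch below illustrates both steps. The helpers `parse_image_index` and `argmax` are hypothetical and are not taken from this crate's source, and the scores in `main` are made-up values, not real model output.

```rust
use std::env;

/// Hypothetical helper: parse the image index from the first CLI argument.
fn parse_image_index(arg: Option<String>) -> Result<usize, String> {
    let raw = arg.ok_or_else(|| "usage: onnx-inference <image-index>".to_string())?;
    raw.parse::<usize>()
        .map_err(|e| format!("invalid image index {raw:?}: {e}"))
}

/// Hypothetical helper: index of the largest score (argmax).
/// Uses `f32::total_cmp` for a total ordering; on ties the last maximum wins.
/// Returns `None` for an empty slice.
fn argmax(scores: &[f32]) -> Option<usize> {
    scores
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}

fn main() {
    match parse_image_index(env::args().nth(1)) {
        Ok(index) => println!("Image index: {index}"),
        Err(msg) => eprintln!("{msg}"),
    }

    // Made-up scores standing in for the model's 10 outputs (digits 0 to 9).
    let scores: [f32; 10] = [-3.1, -2.4, -4.0, -1.9, -5.2, 6.8, -0.7, -3.3, 0.4, -2.0];
    if let Some(predicted) = argmax(&scores) {
        println!("Predicted: {predicted}");
    }
}
```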
Feature flags:

- `embedded-model` (default): Embeds the model weights into the binary. This is useful for small models (e.g. MNIST) but not recommended for very large models, because it significantly increases the binary size and memory consumption at runtime. If this feature is disabled, the model weights are loaded from a binary file at runtime.
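A Cargo feature such as `embedded-model` is a compile-time switch. The standard-library-only sketch below shows the general mechanism of branching on a feature with `cfg!`. The `WeightSource` type and the `model/mnist.bin` path are hypothetical and do not reflect how this crate or `burn-import` actually loads weights. Because the feature is enabled by default, trying the file-based path requires Cargo's `--no-default-features` flag (for example `cargo run --no-default-features -- 15`).

```rust
/// Hypothetical: where the model weights come from, chosen at compile time.
#[derive(Debug, PartialEq)]
enum WeightSource {
    /// Weights compiled into the binary (for example via `include_bytes!`).
    Embedded,
    /// Weights read from a file at runtime; the path is hypothetical.
    File(&'static str),
}

/// `cfg!` evaluates to `true` only when the crate is compiled with a
/// feature named `embedded-model` enabled.
fn weight_source() -> WeightSource {
    if cfg!(feature = "embedded-model") {
        WeightSource::Embedded
    } else {
        WeightSource::File("model/mnist.bin") // hypothetical path
    }
}

fn main() {
    println!("weights: {:?}", weight_source());
}
```

Compiled on its own (with no features defined), this sketch always takes the file branch.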
- Create a `model` directory under `src`.
- Copy the ONNX model to `src/model/mnist.onnx`.
- Add the following to `src/model/mod.rs`:

  ```rust
  pub mod mnist {
      include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
  }
  ```
- Add the module to `lib.rs`:

  ```rust
  pub mod model;

  pub use model::mnist::*;
  ```
- Add the following to `build.rs`:

  ```rust
  use burn_import::onnx::ModelGen;

  fn main() {
      // Generate the model code from the ONNX file.
      ModelGen::new()
          .input("src/model/mnist.onnx")
          .out_dir("model/")
          .run_from_script();
  }
  ```
- Add your model to `src/bin` as a new file; in this specific case, we have called it `mnist.rs`:

  ```rust
  use burn::backend::ndarray::NdArray;
  use burn::tensor;

  use onnx_inference::mnist::Model;

  // The backend used to run the model.
  type Backend = NdArray<f32>;

  fn main() {
      // Get a default device for the model's backend.
      let device = Default::default();

      // Create a new model and load the state.
      let model: Model<Backend> = Model::new(&device).load_state();

      // Create a new input tensor (all zeros for demonstration purposes).
      let input = tensor::Tensor::<Backend, 4>::zeros([1, 1, 28, 28], &device);

      // Run the model.
      let output = model.forward(input);

      // Print the output.
      println!("{:?}", output);
  }
  ```
- Run `cargo build` to generate the model code, weights, and `mnist` binary.
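The `mnist.rs` binary feeds the model a `[1, 1, 28, 28]` tensor: batch size 1, one grayscale channel, and a 28x28 image in NCHW order. The standard-library-only sketch below shows how a real image might be laid out and normalized before being turned into such a tensor. The normalization constants 0.1307 and 0.3081 are an assumption (they are the values used in PyTorch's reference MNIST example) and should match whatever `pytorch/mnist.py` used in training. Handing the buffer to Burn is left out, because the exact tensor constructor depends on the Burn version.

```rust
/// MNIST image dimensions.
const HEIGHT: usize = 28;
const WIDTH: usize = 28;

// Assumed normalization constants (commonly used MNIST mean and std).
const MEAN: f32 = 0.1307;
const STD: f32 = 0.3081;

/// Flat offset of element (n, c, h, w) in a contiguous NCHW buffer
/// of shape [batch, channels, HEIGHT, WIDTH].
fn nchw_offset(n: usize, c: usize, h: usize, w: usize, channels: usize) -> usize {
    ((n * channels + c) * HEIGHT + h) * WIDTH + w
}

/// Convert one 28x28 grayscale image (row-major u8 pixels) into the flat
/// f32 buffer for a [1, 1, 28, 28] tensor: scale to [0, 1], then normalize.
fn to_input_buffer(pixels: &[[u8; WIDTH]; HEIGHT]) -> Vec<f32> {
    // Batch = 1 and channels = 1, so the buffer holds HEIGHT * WIDTH values.
    let mut buf = vec![0.0f32; HEIGHT * WIDTH];
    for (h, row) in pixels.iter().enumerate() {
        for (w, &p) in row.iter().enumerate() {
            let scaled = p as f32 / 255.0;
            buf[nchw_offset(0, 0, h, w, 1)] = (scaled - MEAN) / STD;
        }
    }
    buf
}

fn main() {
    // A blank image with a single white pixel in the middle.
    let mut image = [[0u8; WIDTH]; HEIGHT];
    image[14][14] = 255;
    let buf = to_input_buffer(&image);
    println!("buffer length: {}", buf.len()); // 784
    println!("center value: {:.4}", buf[14 * WIDTH + 14]);
}
```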
The following steps show how to export a PyTorch model to ONNX from the checked-in PyTorch code (see `pytorch/mnist.py`).
- Install dependencies:

  ```sh
  pip install torch torchvision onnx
  ```
- Run the following script to train the MNIST model and export it to ONNX:

  ```sh
  python3 pytorch/mnist.py
  ```

  This will generate `pytorch/mnist.onnx`.
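After re-exporting, the new `mnist.onnx` still has to be copied to `src/model/mnist.onnx` and the crate rebuilt. By default, Cargo reruns a build script whenever any file in the package changes; printing a `cargo:rerun-if-changed=PATH` directive narrows that to the listed file. The standard-library-only sketch below shows the directive a `build.rs` could print before calling `ModelGen`. The `rerun_if_changed` helper is hypothetical, and whether `burn-import` already emits this directive is not covered here.

```rust
use std::path::Path;

/// Hypothetical helper: the Cargo directive that makes the build script
/// rerun only when the given file changes.
fn rerun_if_changed(path: &Path) -> String {
    format!("cargo:rerun-if-changed={}", path.display())
}

fn main() {
    let onnx = Path::new("src/model/mnist.onnx");
    // In a real build.rs, Cargo reads this line from the script's stdout.
    println!("{}", rerun_if_changed(onnx));
    if !onnx.exists() {
        // `cargo:warning=` surfaces a message in the build output.
        println!(
            "cargo:warning={} not found; export it with pytorch/mnist.py first",
            onnx.display()
        );
    }
}
```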