MNIST Inference on Web

Live Demo

This crate demonstrates how to run an MNIST-trained model in the browser for inference.

Running

  1. Build

    ./build-for-web.sh {backend}

    The backend can be either ndarray or wgpu. Note that wgpu only works in browsers that support WebGPU.

  2. Run the server

    ./run-server.sh
  3. Open http://localhost:8000/ in the browser.

Design

The inference components of burn with the ndarray backend can be built with #![no_std]. This makes it possible to build and run the model with the wasm32-unknown-unknown target without a special system library such as WASI. (See Cargo.toml for how to include burn dependencies without std.)

For this demo, we use the trained parameters (model.bin) and the model definition (model.rs) from the burn MNIST example.

The inference API is exposed to JavaScript using wasm-bindgen's library and tools.

JavaScript (index.js) is used to transform hand-drawn digits into the format that the inference API accepts. The transformation includes cropping the image, scaling it down, and converting it to grayscale values.
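The demo performs these steps in JavaScript. To make them concrete, here is a hypothetical Rust port, written for illustration only and not taken from index.js. It assumes an opaque RGBA canvas buffer with dark strokes on a white background, an ink threshold of 0.1, square cropping around the strokes, and a simple box filter for scaling to 28x28; all function and constant names are made up.

```rust
/// Side length of the MNIST input image.
const SIZE: usize = 28;
/// Hypothetical cutoff above which a pixel counts as ink.
const INK_THRESHOLD: f32 = 0.1;

/// Convert an opaque RGBA buffer (dark ink on white, an assumption) into ink
/// intensities in [0, 1], where 1.0 is full ink like MNIST's white-on-black digits.
/// The alpha channel is ignored.
fn to_ink(rgba: &[u8], width: usize, height: usize) -> Vec<f32> {
    (0..width * height)
        .map(|i| {
            let [r, g, b] = [rgba[4 * i], rgba[4 * i + 1], rgba[4 * i + 2]];
            let luma = 0.299 * r as f32 + 0.587 * g as f32 + 0.114 * b as f32;
            1.0 - luma / 255.0
        })
        .collect()
}

/// Find the smallest square (left, top, side) containing every ink pixel,
/// or None if nothing was drawn.
fn bounding_square(ink: &[f32], width: usize, height: usize) -> Option<(isize, isize, usize)> {
    let (mut min_x, mut min_y, mut max_x, mut max_y) = (usize::MAX, usize::MAX, 0, 0);
    for y in 0..height {
        for x in 0..width {
            if ink[y * width + x] > INK_THRESHOLD {
                min_x = min_x.min(x);
                max_x = max_x.max(x);
                min_y = min_y.min(y);
                max_y = max_y.max(y);
            }
        }
    }
    if min_x == usize::MAX {
        return None; // empty canvas
    }
    let (bw, bh) = (max_x - min_x + 1, max_y - min_y + 1);
    let side = bw.max(bh);
    // Pad the shorter dimension on both sides so the digit stays roughly centered
    let left = min_x as isize - ((side - bw) / 2) as isize;
    let top = min_y as isize - ((side - bh) / 2) as isize;
    Some((left, top, side))
}

/// Scale the square crop to SIZE x SIZE by averaging the source pixels that
/// fall under each output pixel; pixels outside the canvas count as no ink.
fn downscale(ink: &[f32], width: usize, height: usize, left: isize, top: isize, side: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; SIZE * SIZE];
    for oy in 0..SIZE {
        for ox in 0..SIZE {
            // Source block covered by this output pixel (at least one pixel wide)
            let x0 = ox * side / SIZE;
            let x1 = ((ox + 1) * side / SIZE).max(x0 + 1);
            let y0 = oy * side / SIZE;
            let y1 = ((oy + 1) * side / SIZE).max(y0 + 1);
            let mut sum = 0.0;
            for sy in y0..y1 {
                for sx in x0..x1 {
                    let (x, y) = (left + sx as isize, top + sy as isize);
                    if x >= 0 && y >= 0 && (x as usize) < width && (y as usize) < height {
                        sum += ink[y as usize * width + x as usize];
                    }
                }
            }
            out[oy * SIZE + ox] = sum / ((x1 - x0) * (y1 - y0)) as f32;
        }
    }
    out
}

/// Full pipeline: grayscale, crop, scale down. Returns 784 values in row-major order.
fn preprocess(rgba: &[u8], width: usize, height: usize) -> Option<Vec<f32>> {
    let ink = to_ink(rgba, width, height);
    let (left, top, side) = bounding_square(&ink, width, height)?;
    Some(downscale(&ink, width, height, left, top, side))
}
```

Cropping to a square before scaling keeps the digit's aspect ratio, so a tall, thin "1" is not stretched sideways. The real MNIST data additionally fits each digit into a 20x20 box centered in the 28x28 frame, which this sketch does not attempt.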

Model

Layers:

  1. Input Image (28, 28, 1ch)
  2. Conv2d(3x3, 8ch), BatchNorm2d, Gelu
  3. Conv2d(3x3, 16ch), BatchNorm2d, Gelu
  4. Conv2d(3x3, 24ch), BatchNorm2d, Gelu
  5. Linear(11616, 32), Gelu
  6. Linear(32, 10)
  7. Softmax Output
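The input size of the first Linear layer, 11616, follows from the layer list if each 3x3 convolution uses stride 1 and no padding. That is an assumption here, but it is the one consistent with the number: 28 → 26 → 24 → 22, and 22 × 22 × 24 = 11616. The sketch below checks that arithmetic and also shows a standard, numerically stable softmax with an argmax for turning the 10 outputs into a predicted digit. The function names are illustrative and are not burn's API.

```rust
/// Spatial size after a 3x3 convolution with stride 1 and no padding (assumed).
fn conv3x3_valid(size: usize) -> usize {
    size - 2
}

/// Number of features after the conv stack is flattened:
/// final height * final width * channels of the last conv.
fn flatten_features(input: usize, conv_channels: &[usize]) -> usize {
    let side = conv_channels.iter().fold(input, |s, _| conv3x3_valid(s));
    side * side * conv_channels.last().copied().unwrap_or(1)
}

/// Softmax with the maximum logit subtracted first to avoid overflow in exp().
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Index of the largest probability, i.e. the predicted digit (0 for an empty slice).
fn argmax(probs: &[f32]) -> usize {
    probs
        .iter()
        .enumerate()
        .fold(0, |best, (i, &p)| if p > probs[best] { i } else { best })
}
```

For example, `flatten_features(28, &[8, 16, 24])` evaluates to 11616, matching the layer list.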

The total number of parameters is 376,952.

The model is trained for 4 epochs, and the final test accuracy is 98.67%.

Details on training and hyperparameters can be found in the burn MNIST example.

Comparison

The main difference between this example's approach (compiling a Rust model into Wasm) and other popular tools, such as TensorFlow.js, ONNX Runtime JS, and TVM Web, is the absence of runtime code: the Rust compiler optimizes the build and includes only the burn routines that are actually used. Of the Wasm file's 1,866,491 bytes, 1,509,747 are the model's parameters. The remaining 356,744 bytes contain all the code (including burn's nn components, the data deserialization library, and math operations).
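The byte counts above can be cross-checked with a little arithmetic. The sketch below assumes the parameters are stored as 32-bit floats. Under that assumption, 376,952 parameters need 1,507,808 bytes raw, and the other 1,939 of the 1,509,747 parameter bytes would be serialization overhead. That split is an inference, not something the source states.

```rust
/// Figures quoted in the Comparison section.
const WASM_FILE_BYTES: u64 = 1_866_491;
const PARAM_BYTES: u64 = 1_509_747;
/// Parameter count from the Model section.
const PARAM_COUNT: u64 = 376_952;

/// Bytes left for code once the parameters are subtracted from the file size.
fn code_bytes() -> u64 {
    WASM_FILE_BYTES - PARAM_BYTES
}

/// Share of the Wasm file taken up by parameters, in percent.
fn param_share_percent() -> f64 {
    PARAM_BYTES as f64 * 100.0 / WASM_FILE_BYTES as f64
}

/// Raw storage for `count` parameters at `bytes_per_param` bytes each,
/// ignoring any serialization overhead.
fn raw_param_bytes(count: u64, bytes_per_param: u64) -> u64 {
    count * bytes_per_param
}
```

The parameters account for roughly 80.9% of the file. Calling `raw_param_bytes(PARAM_COUNT, 2)` gives 753,904 bytes, which is the raw size the same parameters would need at 2 bytes each.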

Future Improvements

Several enhancements are planned:

  • #202 - Saving the model's parameters in half precision and loading them back in full precision. This would halve the size of the parameters, which make up most of the wasm file.
  • #243 - A new WebGPU backend would allow computation on the GPU in the browser.
  • #1271 - Wasm SIMD support in ndarray, which could speed up computation on the CPU.
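Item #202 involves storing parameters as 16-bit floats and widening them back to f32 when loading. The sketch below is a hypothetical illustration of that conversion and is not burn's implementation. A real implementation would more likely rely on a crate such as `half`. The sketch packs f32 values into IEEE 754 binary16 bit patterns using round-to-nearest-even, then expands them again. For simplicity it flushes values too small for a normal f16 (below about 6.1e-5 in magnitude) to zero and maps overflow to infinity. All function names are made up.

```rust
/// Encode an f32 as IEEE 754 binary16 bits (round to nearest, ties to even).
/// Simplification: results that would be f16 subnormals are flushed to zero.
fn f32_to_f16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xff) as i32;
    let mant = bits & 0x7f_ffff;
    if exp == 0xff {
        // Infinity stays infinity; any NaN becomes a quiet NaN
        let nan_bit = if mant != 0 { 0x0200 } else { 0 };
        return sign | 0x7c00 | nan_bit;
    }
    let e = exp - 127 + 15; // re-bias the exponent for f16
    if e >= 0x1f {
        return sign | 0x7c00; // too large: infinity
    }
    if e <= 0 {
        return sign; // too small for a normal f16: signed zero
    }
    let mut h = ((e as u32) << 10) | (mant >> 13);
    let rest = mant & 0x1fff; // the 13 mantissa bits that are dropped
    if rest > 0x1000 || (rest == 0x1000 && (h & 1) == 1) {
        h += 1; // a carry into the exponent is still the correct rounding
    }
    sign | h as u16
}

/// Decode binary16 bits back to a full-precision f32 (exact for every input).
fn f16_bits_to_f32(h: u16) -> f32 {
    let sign = ((h & 0x8000) as u32) << 16;
    let exp = ((h >> 10) & 0x1f) as u32;
    let mant = (h & 0x03ff) as u32;
    let bits = match exp {
        0 => {
            // Zero or subnormal: value = mant * 2^-24
            let magnitude = mant as f32 / 16_777_216.0;
            return if sign != 0 { -magnitude } else { magnitude };
        }
        0x1f => sign | 0x7f80_0000 | (mant << 13), // infinity or NaN
        _ => sign | ((exp + 127 - 15) << 23) | (mant << 13),
    };
    f32::from_bits(bits)
}

/// Hypothetical save step: 2 bytes per parameter instead of 4.
fn to_half(params: &[f32]) -> Vec<u16> {
    params.iter().map(|&p| f32_to_f16_bits(p)).collect()
}

/// Hypothetical load step: widen back to f32 for inference.
fn from_half(stored: &[u16]) -> Vec<f32> {
    stored.iter().map(|&h| f16_bits_to_f32(h)).collect()
}
```

Values that are exactly representable in f16, such as -0.25 or 3.5, survive the round trip unchanged. Other values like 0.1 come back with a small rounding error (0.0999755859375), which is the accuracy trade-off behind halving the parameter storage.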

Acknowledgements

Two online MNIST demos inspired this demo and helped build it: MNIST Draw by Marc (@mco-gh) and MNIST Web Demo (no code was copied, but they helped tremendously with the implementation approach).

Resources

  1. Rust 🦀 and WebAssembly
  2. wasm-bindgen