
Model summary in PyTorch similar to `model.summary()` in Keras


indigoviolet/pytorch-summary

 
 


PyTorch model summaries

Initially based on pytorch-summary.

Improvements:

  • modernizes and simplifies the code
  • adds input shape tracking
  • supports arbitrary input objects to the model (via the extract_input_tensor argument to summary())
  • adds a Conv2d complexity estimate
  • adds tracking of start, jump, and receptive field (based on pytorch-receptive-field and this Medium article)
  • adds a gradient-based receptive-field computation (see this notebook)
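The start/jump/receptive-field bookkeeping can be illustrated with the usual layer-by-layer recurrences: jump is multiplied by the stride, the receptive field grows by (k - 1) * jump, and start shifts by ((k - 1)/2 - p) * jump. The sketch below is a minimal, framework-free version for one spatial axis, assuming no dilation and pixel centers at 0.5; RFState, step, and receptive_field are hypothetical names, not this library's API.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class RFState:
    start: float  # center (in input pixels) of the first output unit's receptive field
    jump: float   # distance in input pixels between adjacent output units
    rf: float     # receptive field size in input pixels


def step(state: RFState, kernel: int, stride: int, padding: int) -> RFState:
    """Propagate (start, jump, rf) through one conv/pool layer along one axis."""
    return RFState(
        start=state.start + ((kernel - 1) / 2 - padding) * state.jump,
        jump=state.jump * stride,
        rf=state.rf + (kernel - 1) * state.jump,
    )


def receptive_field(layers: List[Tuple[int, int, int]]) -> RFState:
    """layers: (kernel, stride, padding) for each layer, applied in order."""
    state = RFState(start=0.5, jump=1.0, rf=1.0)  # the input: each pixel sees only itself
    for kernel, stride, padding in layers:
        state = step(state, kernel, stride, padding)
    return state


# Two 3x3 convs (stride 1, padding 1), then a 2x2 max pool with stride 2
print(receptive_field([(3, 1, 1), (3, 1, 1), (2, 2, 0)]))
# RFState(start=1.0, jump=2.0, rf=6.0)
```

These recurrences assume a simple chain of layers, with each layer feeding only the next.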

Demo

See the demo notebook

Tests

pytest --nbmake demo.ipynb

Grad Receptive Field

For complex models, the gradient-based computation is likely to be accurate where the analytical computation is not, for example when the outputs of two modules are combined.
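The gradient-based idea can be sketched as follows: backpropagate from a single output location and measure the bounding box of input pixels that receive a nonzero gradient. This is a minimal sketch, not the library's implementation. estimate_grad_rf and TwoBranch are hypothetical names, and setting every Conv2d weight to a positive constant (so gradient contributions cannot cancel) is an assumption, since the README only says that Conv2d initialization is changed.

```python
import copy

import torch
import torch.nn as nn


def estimate_grad_rf(model: nn.Module, input_shape) -> tuple:
    """Estimate the (height, width) receptive field of the central output location
    from the extent of nonzero input gradients."""
    model = copy.deepcopy(model).eval()  # work on a copy; eval() disables Dropout
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # Assumption: constant positive weights (normalized by fan-in) so that
            # gradient contributions cannot cancel each other out
            nn.init.constant_(m.weight, 1.0 / m.weight[0].numel())
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    x = torch.ones(input_shape, requires_grad=True)
    out = model(x)  # must be an (N, C, H, W) tensor for this to work
    cy, cx = out.shape[2] // 2, out.shape[3] // 2
    # Backpropagate from one output location, summed over channels
    out[0, :, cy, cx].sum().backward()

    nz = x.grad[0].abs().sum(dim=0).nonzero()  # (row, col) of pixels with nonzero grad
    lo, hi = nz.min(dim=0).values, nz.max(dim=0).values
    return int(hi[0] - lo[0] + 1), int(hi[1] - lo[1] + 1)


class TwoBranch(nn.Module):
    """Hypothetical model whose output sums two modules' outputs."""

    def __init__(self):
        super().__init__()
        self.small = nn.Conv2d(1, 1, 3, padding=1)
        self.large = nn.Conv2d(1, 1, 7, padding=3)

    def forward(self, x):
        return self.small(x) + self.large(x)


print(estimate_grad_rf(TwoBranch(), (1, 1, 32, 32)))  # (7, 7)
```

For TwoBranch, a case like the one described above where two modules' outputs are combined, the estimate is the union of the 3x3 and 7x7 windows, i.e. (7, 7).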

Notes

  • grad_receptive_field=True modifies the model in several ways (on a copy, so the original model is left unchanged):
    • changes the initialization of Conv2d layers
    • replaces MaxPool2d with AvgPool2d
    • turns off Dropout and BatchNorm2d
  • Treats any class whose name ends in Conv2d as a Conv2d, and likewise for BatchNorm2d, MaxPool2d, and Dropout. This handles custom module classes that don’t directly derive from the torch.nn classes.
  • Computing the receptive field (RF) for a layer requires that layer to output a tensor.
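The copy-then-modify step and the class-name suffix matching described in these notes might look roughly like this. It is a sketch under assumptions: is_kind and prepare_for_grad_rf are hypothetical names, "turning off" Dropout and BatchNorm2d is approximated here by swapping in nn.Identity, custom MaxPool2d-like classes are assumed to expose kernel_size, stride, and padding, and Conv2d re-initialization is left out.

```python
import copy

import torch.nn as nn


def is_kind(module: nn.Module, suffix: str) -> bool:
    """Match by class-name suffix, so a custom QuantConv2d counts as a Conv2d
    even if it does not subclass nn.Conv2d."""
    return type(module).__name__.endswith(suffix)


def prepare_for_grad_rf(model: nn.Module) -> nn.Module:
    model = copy.deepcopy(model)  # leave the caller's model untouched
    for parent in list(model.modules()):
        for name, child in list(parent.named_children()):
            if is_kind(child, "MaxPool2d"):
                # Max pooling sends gradient only to the argmax; averaging spreads
                # it over the whole window, so every covered pixel gets a gradient
                setattr(parent, name, nn.AvgPool2d(child.kernel_size, child.stride, child.padding))
            elif is_kind(child, "Dropout") or is_kind(child, "BatchNorm2d"):
                setattr(parent, name, nn.Identity())  # "turn off" the layer
    return model
```

Matching on the class name rather than with isinstance is what lets wrappers that merely mimic a torch.nn layer be handled the same way.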

Conv2d complexity

Computed as num_input_filters * num_output_filters * H * W
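A direct transcription of the formula is below. The README does not say whether H and W are the kernel or the feature-map dimensions, so conv2d_complexity (a hypothetical helper) takes them as arguments; the hook-based complexities helper (also hypothetical) assumes they are the output feature-map height and width.

```python
import torch
import torch.nn as nn


def conv2d_complexity(conv: nn.Conv2d, h: int, w: int) -> int:
    """num_input_filters * num_output_filters * H * W."""
    return conv.in_channels * conv.out_channels * h * w


def complexities(model: nn.Module, x: torch.Tensor) -> dict:
    """Per-Conv2d complexity, assuming H and W are the output feature-map size."""
    results = {}
    handles = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            def hook(mod, inp, out, name=name):
                results[name] = conv2d_complexity(mod, out.shape[-2], out.shape[-1])
            handles.append(m.register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for handle in handles:
        handle.remove()  # don't leave hooks attached to the model
    return results


model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, stride=2, padding=1)
)
print(complexities(model, torch.zeros(1, 3, 32, 32)))  # {'0': 49152, '2': 131072}
```

For the first layer this is 3 * 16 * 32 * 32 = 49152; the strided second layer halves the output to 16x16, giving 16 * 32 * 16 * 16 = 131072.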

Roadmap
