An efficient method for the conversion from internal to Cartesian coordinates that utilizes the platform-agnostic JAX Python library.

PeptoneLtd/nerfax

NeRFax

In this work we implement NeRF and, to the best of our knowledge, the first fully parallel implementation of pNeRF in an emerging framework, JAX. We demonstrate speedups in the range 35-175x over the fastest public implementation for single-chain proteins. Using the framework's ability to trivially parallelise functions, we also show a >10,000x speedup relative to running mp-NeRF serially for a biomolecular condensate of 1,000 chains of 163 residues each.

Benchmarks

Single chain

Runtime of different computational methods for single chains

Speedup, relative to the CPU mp_nerf implementation, of different computational methods for single chains

This can be reproduced with notebooks/benchmark_single_chain_reconstruction.ipynb.

Multiple chains: Biomolecular condensate reconstruction

Leveraging JAX's automatic vectorization, the reconstruction was parallelized across chains and runs in 3.4 ms on GPU. The torch implementation has no parallel multi-chain mode, so the chains have to be computed serially; extrapolating its runtime to all chains gives ~60 seconds, which makes the JAX version approximately 17,000x faster (60 s / 3.4 ms). This can be reproduced with notebooks/benchmark_multiple_chain_reconstruction.ipynb.
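As a rough illustration of how automatic vectorization applies here, the sketch below builds one chain with a sequential NeRF placement loop and then maps it over a batch of chains with `jax.vmap`. It is not the NeRFax API: `place_atom`, `build_chain`, `build_condensate`, the seed frame, and the toy internal coordinates are all hypothetical names and values chosen for the example.

```python
import jax
import jax.numpy as jnp


def place_atom(a, b, c, r, theta, phi):
    """Place atom d from the three previous atoms a, b, c (NeRF step).

    r is the c-d bond length, theta the b-c-d bond angle and phi the
    a-b-c-d torsion, both in radians.
    """
    bc = c - b
    bc = bc / jnp.linalg.norm(bc)
    n = jnp.cross(b - a, bc)
    n = n / jnp.linalg.norm(n)
    m = jnp.cross(n, bc)
    # Position of d in the local frame spanned by (bc, m, n).
    local = jnp.stack([
        -r * jnp.cos(theta),
        r * jnp.sin(theta) * jnp.cos(phi),
        r * jnp.sin(theta) * jnp.sin(phi),
    ])
    return c + local[0] * bc + local[1] * m + local[2] * n


def build_chain(internal):
    """internal: (n_atoms, 3) rows of (bond length, bond angle, torsion)."""
    # Hypothetical seed frame; any three non-collinear points would do.
    seed = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])

    def step(frame, params):
        d = place_atom(frame[0], frame[1], frame[2],
                       params[0], params[1], params[2])
        # Slide the three-atom window forward by one atom.
        return jnp.stack([frame[1], frame[2], d]), d

    _, coords = jax.lax.scan(step, seed, internal)
    return coords


# vmap adds a leading "chain" axis; jit compiles the batched function.
build_condensate = jax.jit(jax.vmap(build_chain))

# Toy input: 4 identical chains of 12 atoms each (assumed values).
internal = jnp.tile(jnp.array([1.5, 2.0, 3.0]), (4, 12, 1))
coords = build_condensate(internal)
print(coords.shape)
```

The per-chain function is written once for a single chain; `jax.vmap` supplies the batch dimension, so no explicit loop over chains is needed.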

Installation

Pip

git clone https://github.com/PeptoneLtd/nerfax.git && pip install ./nerfax[optional]

Note: to run on GPU, a GPU version of JAX must be installed; please follow the JAX GPU compatibility instructions.
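One way to check which backend JAX picked up after installation is to query it directly. This is a minimal sketch using `jax.default_backend()` and `jax.devices()`; the variable names are illustrative only.

```python
import jax

# Name of the backend JAX will use by default, e.g. "cpu" or "gpu".
backend = jax.default_backend()
# Devices visible to JAX on that backend.
devices = jax.devices()

print(backend)
print(devices)
```

If the reported backend is "cpu" on a machine with a GPU, the CPU-only build of JAX is most likely the one installed.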

Docker image

We also provide a Dockerfile which can be used to install NeRFax. The Dockerfile includes the GPU version of JAX.
