Add f16 support in the wgpu backend #1582
base: main
The burn-wgpu backend currently does not support computations on 16-bit floats. This, for example, limits the ability to run LLMs on top of Burn on widely available hardware. So, add 16-bit float support in burn-wgpu.
Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1582      +/-   ##
==========================================
- Coverage   86.39%   86.35%   -0.05%
==========================================
  Files         688      688
  Lines       78676    78718     +42
==========================================
+ Hits        67974    67977      +3
- Misses      10702    10741     +39
@@ -21,6 +21,7 @@ pub enum Visibility {
 #[allow(missing_docs)]
 pub enum Elem {
     Float,
     Half,
There is no need to add `Half` here, `Float` should cover all float types of all precisions in this context.
fn gpu_elem() -> gpu::Elem {
    gpu::Elem::Half
}
The gpu element would be `Float` here.
let features = match F::gpu_elem() {
    gpu::Elem::Half => vec![wgsl::Feature::ShaderF16],
    _ => vec![],
};
I would check using `F::wgpu_elem() == Elem::F16` instead.
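As a rough illustration of that suggestion, here is a minimal, self-contained sketch of deriving the required shader features from the concrete WGSL element type rather than from a dedicated `Half` variant. Every name in it (`WgslElem`, `Feature`, `FloatElement`, `Half16`, `required_features`) is a hypothetical stand-in chosen for this example, not Burn's actual definitions.

```rust
// Hypothetical stand-ins for the types discussed in this review.
// They are NOT the real burn-wgpu definitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WgslElem {
    F16,
    F32,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Feature {
    ShaderF16,
}

/// Assumed shape of a float element trait: each float type reports the
/// WGSL element it compiles to.
pub trait FloatElement {
    fn wgpu_elem() -> WgslElem;
}

impl FloatElement for f32 {
    fn wgpu_elem() -> WgslElem {
        WgslElem::F32
    }
}

/// Placeholder for a 16-bit float type (for example one provided by a
/// half-precision crate); used here only to pick a trait implementation.
pub struct Half16;

impl FloatElement for Half16 {
    fn wgpu_elem() -> WgslElem {
        WgslElem::F16
    }
}

/// Request the f16 shader feature only when the float element compiles
/// to a 16-bit WGSL type; every other float type needs no extra feature.
pub fn required_features<F: FloatElement>() -> Vec<Feature> {
    if F::wgpu_elem() == WgslElem::F16 {
        vec![Feature::ShaderF16]
    } else {
        Vec::new()
    }
}
```

With this shape, `required_features::<Half16>()` yields the f16 feature and `required_features::<f32>()` yields an empty list, without the intermediate representation needing to know about precision.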
gpu::Elem::Float => F::wgpu_elem(),
gpu::Elem::Half => F::wgpu_elem(),
This line pretty much explains why we don't need `Half` in the `gpu::Elem` enum :)
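To make the point concrete, the sketch below shows one way a single `Float` variant can cover every precision: the mapping from the abstract element to the WGSL element is delegated to the float type parameter, so a second `Half` arm would only repeat the `Float` arm. The types and the `compile_elem` function are hypothetical simplifications written for this example and are not copied from the Burn codebase.

```rust
// Hypothetical, simplified stand-ins; not the real burn-wgpu types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WgslElem {
    F16,
    F32,
    I32,
    U32,
    Bool,
}

/// Abstract element kind in the intermediate representation.
/// Note there is no `Half` variant: `Float` stands for any precision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuElem {
    Float,
    Int,
    UInt,
    Bool,
}

/// Assumed shape of the float element trait.
pub trait FloatElement {
    fn wgpu_elem() -> WgslElem;
}

/// Marker types standing in for 32-bit and 16-bit float elements.
pub struct Single;
pub struct Half16;

impl FloatElement for Single {
    fn wgpu_elem() -> WgslElem {
        WgslElem::F32
    }
}

impl FloatElement for Half16 {
    fn wgpu_elem() -> WgslElem {
        WgslElem::F16
    }
}

/// Resolve an abstract element to a concrete WGSL element. The precision
/// comes from the type parameter `F`, so one `Float` arm is enough.
pub fn compile_elem<F: FloatElement>(elem: GpuElem) -> WgslElem {
    match elem {
        GpuElem::Float => F::wgpu_elem(),
        GpuElem::Int => WgslElem::I32,
        GpuElem::UInt => WgslElem::U32,
        GpuElem::Bool => WgslElem::Bool,
    }
}
```

Here `compile_elem::<Half16>(GpuElem::Float)` resolves to `WgslElem::F16` and `compile_elem::<Single>(GpuElem::Float)` resolves to `WgslElem::F32`, which is the behavior both arms in the diff above already share.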
This PR has been marked as stale because it has not been updated for over a month.
Pull Request Template
Checklist
`run-checks all` script has been executed.
Related Issues/PRs
Split from: #1475
Dawn support: #1583
Changes
The burn-wgpu backend currently does not support computations on 16-bit floats. This, for example, limits the ability to run LLMs on top of Burn on widely available hardware. So, add 16-bit float support in burn-wgpu.
Testing
I applied this change on top of my changes that add the ability to run with Dawn instead of wgpu, and used it to run llama2-burn with f16.