
Support switching burn-wgpu between WebGPU implementations (wgpu <-> Dawn) #1583

Open · p1-0tr wants to merge 2 commits into main from ps-allow-using-dawn-and-wgpu

Conversation

@p1-0tr commented Apr 7, 2024

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Split from: #1475

Changes

Currently Burn can only be built against the wgpu WebGPU implementation. I found the ability to switch to Dawn useful, e.g. to get float16 support, which helps when trying to run Llama2 on top of Burn.

I've:

  • split out any uses of wgpu into a separate module, and defined traits which provide all operations needed by burn-wgpu (see the sketch after this list),
  • added Dawn as a submodule,
  • modified the build system to be able to build, generate bindings for, and link against Dawn,
  • added a module which wraps the usage of the Dawn bindings to make it look more like wgpu.
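
To give a feel for the shape of that abstraction, here is a minimal sketch of the kind of trait the first bullet describes. The names are illustrative assumptions, not the PR's actual definitions; only the read_buffer signature mirrors the one visible in the review below.

pub trait WebGpuApi {
    // Associated types stand in for the concrete objects each
    // implementation (the wgpu crate or the Dawn bindings) provides.
    type Device;
    type Queue;
    type Buffer;

    /// Create a GPU buffer of `size` bytes on `device`.
    fn create_buffer(device: &Self::Device, size: u64) -> Self::Buffer;

    /// Copy `data` into `buffer` through the queue.
    fn write_buffer(queue: &Self::Queue, buffer: &Self::Buffer, data: &[u8]);

    /// Read the buffer contents back to the host.
    fn read_buffer(buffer: &Self::Buffer, device: &Self::Device) -> Vec<u8>;
}

With a trait like this in place, the rest of burn-wgpu can be written against the abstraction, and the dawn feature flag simply selects which implementation backs it.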

Testing

I ran unit tests for the burn-wgpu crate. Dawn test cases should be generated when the dawn feature is enabled for burn-wgpu; some of them will fail because Dawn performs out-of-bounds accesses if no explicit bounds checks are emitted for I/O arrays (I've pushed a branch with a fix for that here: https://github.com/p1-0tr/burn/tree/ps-allow-using-dawn-and-wgpu-w-bounds).

@p1-0tr mentioned this pull request Apr 7, 2024
@p1-0tr force-pushed the ps-allow-using-dawn-and-wgpu branch 2 times, most recently from c12d895 to b316fde on April 7, 2024
@nathanielsimard (Member) left a comment:

Very good job, I like how you abstracted the WebGPU API. I have a few comments, but it's still in draft, so some might already be obvious to you.

crates/burn-wgpu/Cargo.toml (outdated, resolved review thread)
@@ -0,0 +1,202 @@
#[cfg(all(feature = "dawn", not(target_os = "macos")))]
compile_error!("The 'dawn' backend currently only builds on macos.");
Member:

Is there something we can do to compile dawn for other targets?

Author:

This was the main reason I left the PR in draft. The Linux build should be simple to add; Windows may be more problematic. I'm planning to look into that tomorrow or on Thursday.

crates/burn-wgpu/build.rs (two outdated, resolved review threads)
return Reader::Future(Box::pin(future));
}

#[cfg(not(target_family = "wasm"))]
Reader::Concrete(self.buffer_reader(handle).read(&self.device))
Member:

I didn't find a way to read the buffer synchronously with wgpu when targeting wasm. Did you find a way?

Author:

Ah, no, I must have missed this when resolving a rebase. Thanks :)
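
For context on this exchange, below is a minimal sketch of the usual native wgpu readback pattern; it is generic wgpu usage rather than code from this PR. The blocking device.poll(wgpu::Maintain::Wait) is what makes a synchronous read possible on native targets, and it has no blocking equivalent on wasm, which is why the wasm path has to stay asynchronous.

fn read_buffer_sync(device: &wgpu::Device, buffer: &wgpu::Buffer) -> Vec<u8> {
    // `buffer` must have been created with BufferUsages::MAP_READ.
    let slice = buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = sender.send(result);
    });
    // Block until the mapping callback has fired; only possible on native targets.
    let _ = device.poll(wgpu::Maintain::Wait);
    receiver
        .recv()
        .expect("map_async callback was dropped")
        .expect("failed to map buffer for reading");
    let data = slice.get_mapped_range().to_vec();
    buffer.unmap();
    data
}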

Author:

BTW, is there a way to run unit tests when targeting wasm?

Author:

Should be fixed now :)

Member:

I don't think we run tests on wasm yet, but it should be possible.
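
On the wasm test question, one common approach (not something this PR sets up) is the wasm-bindgen-test crate, which swaps #[test] for an attribute that lets the suite run under a headless browser via wasm-pack:

// Sketch only; assumes a dev-dependency on wasm-bindgen-test.
use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure};

// WebGPU needs a browser context, so run in a browser rather than Node.
wasm_bindgen_test_configure!(run_in_browser);

#[wasm_bindgen_test]
fn smoke_test() {
    assert_eq!(2 + 2, 4);
}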

crates/burn-wgpu/src/compute/webgpu_api.rs (two outdated, resolved review threads)
async fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> Self::Adapter;
#[cfg(not(target_family = "wasm"))]
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> Self::Adapter;
fn read_buffer(buffer: &Self::Buffer, device: &Self::Device) -> Vec<u8>;
Member:

Would it make sense to move that function to the queue, so that it can be closer to write_buffer?

Author:

I ended up moving it to the buffer trait, since it does not use the queue.
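
As a rough illustration of that placement (hypothetical names, not the PR's exact trait), the readback then hangs off the buffer abstraction and only needs a device handle:

// Hypothetical shape: the read lives on the buffer trait because it never
// touches the queue.
pub trait Buffer {
    type Device;

    /// Map the buffer and copy its contents back to the host.
    fn read(&self, device: &Self::Device) -> Vec<u8>;
}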

crates/burn-wgpu/src/compute/wgpu_api_shim.rs (outdated, resolved review thread)
Comment on lines 57 to 58
pub type Wgpu<W = WgpuApi, G = AutoGraphicsApi, F = f32, I = i32> =
burn_fusion::Fusion<JitBackend<WgpuRuntime<W, G, F, I>>>;
Member:

I think I would hardcode the WgpuApi here, otherwise Wgpu can become Dawn, which would be confusing.

pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = burn_fusion::Fusion<JitBackend<WgpuRuntime<WgpuApi, G, F, I>>>;

Maybe we could also simply export one backend type and have the backend type be an associated type of a trait that is also generic over the other settings:

impl BackendSelection for WgpuBackend {
    #[cfg(feature = "fusion")]
    type Backend<G, F, I> = burn_fusion::Fusion<JitBackend<WgpuRuntime<Self, G, F, I>>>;

    #[cfg(not(feature = "fusion"))]
    type Backend<G, F, I> = JitBackend<WgpuRuntime<Self, G, F, I>>;
}

Just an idea.

Elem::I32 => f.write_fmt(format_args!("{number}i")),
// Dawn seems to get tripped up by the 'i' suffix, while wgpu is happy
// with or without it, so emit the literal without it.
Elem::I32 => f.write_fmt(format_args!("{number}")),
Contributor:

sanity question: wgpu is still correctly interpreting this as a signed int, right?

Author:

I've not seen any evidence to the contrary. The standard says that literals without a suffix should be treated as AbstractInt (https://www.w3.org/TR/WGSL/#numeric-literal) with values in the range [-2^63, 2^63) (https://www.w3.org/TR/WGSL/#abstractint). So, I think the answer is yes :)

codecov bot commented Apr 12, 2024

Codecov Report

Attention: Patch coverage is 78.02198%, with 80 lines in your changes missing coverage. Please review.

Project coverage is 86.40%. Comparing base (5bbc5ea) to head (ebd6963).
Report is 5 commits behind head on main.

❗ Current head ebd6963 differs from the pull request's most recent head 3212b1a. Consider uploading reports for commit 3212b1a to get more accurate results.

Files | Patch % | Lines missing
crates/burn-wgpu/src/compute/wgpu_api_shim.rs | 75.58% | 73 ⚠️
crates/burn-wgpu/src/runtime.rs | 68.18% | 7 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1583      +/-   ##
==========================================
- Coverage   86.61%   86.40%   -0.21%     
==========================================
  Files         700      695       -5     
  Lines       83427    80618    -2809     
==========================================
- Hits        72257    69656    -2601     
+ Misses      11170    10962     -208     


@p1-0tr force-pushed the ps-allow-using-dawn-and-wgpu branch from b316fde to ebd6963 on April 18, 2024
@p1-0tr marked this pull request as ready for review April 18, 2024
@antimora added the enhancement label Apr 26, 2024
p1-0tr added 2 commits May 9, 2024 15:27
Currently the burn-wgpu crate is hardcoded to use the wgpu WebGPU
implementation. It would be nice to be able to use other WebGPU
implementations, as new features may land at different times and there
may be potential performance gains to unlock. So, separate any use of
the wgpu crate into a separate module, in preparation for adding the
ability to target other WebGPU implementations.

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Dawn, Google's WebGPU implementation, currently supports features
which wgpu does not, for example using 16-bit floats in shaders. So, add
the ability to build the burn-wgpu backend against Dawn.

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
@p1-0tr force-pushed the ps-allow-using-dawn-and-wgpu branch from ebd6963 to 3212b1a on May 13, 2024
@nathanielsimard (Member) commented:

@p1-0tr Just an update:

As you may have seen, we updated quite a lot in the wgpu server, which solved a lot of issues with different graphics APIs and improved performance a bit. We also changed how the element types are handled in the JIT backend, which paves the way for quantization and our new GPU Rust API. We still want to go further with this PR and eventually merge it. Things are a bit more stable now, so we can better review and integrate it.

Note: I don't think using a git submodule is optimal; I would prefer downloading a tagged version of Dawn to the cache directory using this method; however, this could be done in a follow-up PR.

p1-0tr commented May 14, 2024

> As you may have seen, we updated quite a lot in the wgpu server, which solved a lot of issues with different graphics APIs and improved performance a bit. We also changed how the element types are handled in the JIT backend, which paves the way for quantization and our new GPU Rust API. We still want to go further with this PR and eventually merge it. Things are a bit more stable now, so we can better review and integrate it.

@nathanielsimard thanks for the update. I noticed the changes when doing my last rebase :D

> Note: I don't think using a git submodule is optimal; I would prefer downloading a tagged version of Dawn to the cache directory using this method; however, this could be done in a follow-up PR.

Thanks :) I'll look into that.

Labels: enhancement
4 participants