
Fix record nested value de/serialization #1751

Merged
merged 3 commits into main from fix/record/nested-value-types on May 22, 2024

Conversation

@laggui (Member) commented on May 9, 2024

While working on the Llama-3 implementation I stumbled upon a memory issue when importing pytorch weights with PyTorchFileRecorder.

When I profiled the memory usage for ResNet-152 (checkpoint is 252MB on disk), I saw huge peak memory usage for what should be a relatively small model: up to ~5GB, as shown by the heaptrack traces below.

Before: [heaptrack trace screenshot]

After: [heaptrack trace screenshot]

Checklist

  • Confirmed that the run-checks all script has been executed.

Changes

Added U16s and F32s variants to NestedValue so weights can be parsed as a vector of primitive types instead of a Vec<NestedValue>. For example, a vec of f32s is now represented as Vec[v, v, v, ...] instead of Vec[NestedValue::F32(v), NestedValue::F32(v), ...]. The NestedValue enum has a size of 56 bytes, so memory usage can grow very rapidly (just imagine a very large number of parameters, as in Llama 8B 🤯).

  • Handle different vec types in Serializer based on the input element type
  • Make VecSeqAccess's iter generic and add concrete implementations for vec of NestedValue, u16 and f32
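The memory argument can be sketched with a toy enum. This is a simplified stand-in whose variant names mirror the PR, not burn's real NestedValue (which has many more variants and is 56 bytes, so the gap is even larger in practice):

```rust
use std::mem::size_of;

// Toy stand-in for burn's NestedValue; a simplified sketch, not the real enum.
#[allow(dead_code)]
enum NestedValue {
    F32(f32),
    U16(u16),
    Vec(Vec<NestedValue>), // old path: one enum value per element
    F32s(Vec<f32>),        // new path: one enum value for the whole vector
    U16s(Vec<u16>),
}

// Approximate heap footprint of n f32 weights wrapped element-by-element.
fn old_repr_bytes(n: usize) -> usize {
    n * size_of::<NestedValue>()
}

// Approximate footprint with a single F32s variant holding packed f32s.
fn new_repr_bytes(n: usize) -> usize {
    size_of::<NestedValue>() + n * size_of::<f32>()
}

fn main() {
    let n = 60_000_000; // roughly ResNet-152's parameter count
    println!(
        "old ≈ {} MB, new ≈ {} MB",
        old_repr_bytes(n) >> 20,
        new_repr_bytes(n) >> 20
    );
}
```

Even with this toy enum, wrapping every element multiplies the per-weight cost by the enum size instead of 4 bytes per f32, which matches the peak-memory blowup seen in the heaptrack trace.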

Testing

All unit tests pass, including half precision record tests in burn-import/pytorch-tests.

@laggui laggui marked this pull request as draft May 9, 2024 19:17
@laggui (Member, Author) commented on May 9, 2024

Was running the checks locally and the test record::serde::ser::tests::test_param_serde just failed... will investigate and fix.

/edit:

Previously the fmt::Debug output captured by the test had the vector as Vec([F32(1.0), F32(1.0), F32(1.0), ...] len=4), but now the values are no longer wrapped in NestedValue::F32, so it is just Vec([1.0, 1.0, 1.0, ...] len=4) instead.

That means the characters F32( and ) are dropped for each of the 3 displayed elements (15 characters total), which gives the new expected length: 149 - 15 = 134 ✅
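The character arithmetic can be sanity-checked with shortened stand-in strings (these are not the full 149-character Debug output, but the difference per element is the same):

```rust
// Truncated stand-ins for the Debug output before and after the change.
const BEFORE: &str = "Vec([F32(1.0), F32(1.0), F32(1.0), ...] len=4)";
const AFTER: &str = "Vec([1.0, 1.0, 1.0, ...] len=4)";

fn main() {
    // Each of the 3 shown elements drops "F32(" and ")": 3 * 5 = 15 chars.
    assert_eq!(BEFORE.len() - AFTER.len(), 15);
    println!("difference: {}", BEFORE.len() - AFTER.len());
}
```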

@laggui laggui marked this pull request as ready for review May 9, 2024 19:43
codecov bot commented on May 9, 2024

Codecov Report

Attention: Patch coverage is 87.61905%, with 13 lines in your changes missing coverage. Please review.

Project coverage is 86.61%. Comparing base (5bbc5ea) to head (40b3c71).
Report is 1 commit behind head on main.

Files Patch % Lines
crates/burn-core/src/record/serde/de.rs 88.73% 8 Missing ⚠️
crates/burn-core/src/record/serde/data.rs 83.33% 3 Missing ⚠️
crates/burn-core/src/record/serde/ser.rs 87.50% 2 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1751   +/-   ##
=======================================
  Coverage   86.61%   86.61%           
=======================================
  Files         700      700           
  Lines       83427    83509   +82     
=======================================
+ Hits        72257    72329   +72     
- Misses      11170    11180   +10     


@antimora (Collaborator) left a comment

Thanks for fixing this. This is the correct approach, but to be complete, we need the following as well (start with 2 and 3):

  1. Handle other Tensor elements. See https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/pytorch/reader.rs#L91-L124 for all types.

  2. Instead of serializing CandleTensor, we should just convert the underlying ParamSerde and DataSerialize structures to NestedValue directly. The data (vector) pointer will be copied to NestedValue::F32s or NestedValue::F16s (a new datatype that preserves encoding). You will need to modify the serialize_data function at https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/pytorch/reader.rs#L127-L145.

  3. You will need to do something similar to the serialize function in https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/pytorch/adapter.rs#L68-L78. The logic overlaps with converting ParamSerde::new(param_id, DataSerialize::new(data, shape)) to NestedValue.

Once you fix points 2 and 3, we can bypass the serialization of tensors. Later, we can remove serialization artifacts and rely on generating NestedValue directly.

The main reason we want to do points 2 and 3 is that the code still converts each tensor element to NestedValue::F32, which is later collected into a Vec again, although there is no extra memory overhead except for one element at a time. Additionally, we can eliminate code complexity. We need to review the code and see if there are places where the serializer is used, as I believe I only intended to use it for Param objects.
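A rough sketch of what bypassing per-element serialization could look like (hypothetical names; burn's actual ParamSerde / DataSerialize conversion differs):

```rust
// Simplified sketch of the per-element vs. direct conversion paths
// described above. Names are hypothetical, not burn's API.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum NestedValue {
    F32(f32),
    F32s(Vec<f32>),
    Vec(Vec<NestedValue>),
}

// What the code still does for tensors: wrap every element, then collect.
fn per_element(data: Vec<f32>) -> NestedValue {
    NestedValue::Vec(data.into_iter().map(NestedValue::F32).collect())
}

// The proposed direct path: the Vec's heap allocation moves as-is.
fn direct(data: Vec<f32>) -> NestedValue {
    NestedValue::F32s(data)
}

fn main() {
    let weights = vec![0.5_f32, 1.5, -2.0];
    assert_eq!(direct(weights.clone()), NestedValue::F32s(weights));
}
```

The direct path never materializes an enum value per element, which is the simplification points 2 and 3 are after.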

crates/burn-core/src/record/serde/de.rs (resolved review thread)
Comment on lines +295 to +302
NestedValue::U16s(_) => visitor.visit_seq(VecSeqAccess::<A, u16>::new(
value,
self.default_for_missing_fields,
)),
NestedValue::F32s(_) => visitor.visit_seq(VecSeqAccess::<A, f32>::new(
value,
self.default_for_missing_fields,
)),
A reviewer (Member) commented:
Shouldn't we add more element types such as bf16, f16, f64, u64, u32?

@laggui (Member, Author) commented on May 15, 2024

Yeah we're gonna have to (at some point anyway). I started with the two types for my use cases (BF16 for Llama and F32 for all other models like the ResNet family for testing purposes). They're probably the most common too.

Wanted to get a review of the implementation before going further. As mentioned by @antimora, adding a concrete implementation for each type doesn't scale very well, but it's the easiest solution I came up with for now.

The reviewer (Member) replied:

I don't see another way: we need the compiler to know the size of the vector's elements at compile time! Maybe we could read the vector as bytes instead and cast them later on?
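The bytes idea could look roughly like this (a sketch assuming little-endian f32 storage, with hypothetical helper names, not burn's API):

```rust
// Sketch of "read the vector as bytes, cast later": a single Vec<u8>
// variant could hold any element type, with the cast deferred until
// the concrete element type is known.
fn f32s_to_bytes(values: &[f32]) -> Vec<u8> {
    values.iter().flat_map(|v| v.to_le_bytes()).collect()
}

fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    let original = vec![1.0_f32, -2.5, 3.25];
    let roundtrip = bytes_to_f32s(&f32s_to_bytes(&original));
    assert_eq!(roundtrip, original);
}
```

This would keep a single byte-backed variant instead of one concrete variant per element type, at the cost of tracking the element type and endianness alongside the bytes.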

@laggui (Member, Author) replied:

Hmmm, maybe. I think for now I'll stick to the current method to wrap up this PR, and in the future we can refactor if needed. At the same time, the number of types should remain manageable, so "not scaling" isn't necessarily an issue right now.

@antimora (Collaborator) commented:
Before merging, please file issues for any uncompleted refactors or fixes.

@laggui (Member, Author) commented on May 16, 2024

To close this PR I'll handle the other tensor element types, but I've opened a new issue regarding the other improvements suggested in previous discussions.

@antimora (Collaborator) commented:
If we do #1773 alone, then we can deprecate serialization, so you don't need to add the other accumulating item types.

@laggui (Member, Author) commented on May 21, 2024

> If we do #1773 alone, then we can deprecate serialization. So you don't need to do other accumulating item types.

Sure, we can limit the current PR to Vec<u16> (for f16 and bf16) and Vec<f32> (for pretty much all other parameter weights).

The linked issue should capture all element types when we tackle it.

@nathanielsimard nathanielsimard merged commit 550086a into main May 22, 2024
15 checks passed
@nathanielsimard nathanielsimard deleted the fix/record/nested-value-types branch May 22, 2024 13:15