Fix record nested value de/serialization #1751
Conversation
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main    #1751   +/-  ##
=======================================
  Coverage   86.61%   86.61%
=======================================
  Files         700      700
  Lines       83427    83509    +82
=======================================
+ Hits        72257    72329    +72
- Misses      11170    11180    +10
```

☔ View full report in Codecov by Sentry.
Thanks for fixing this. This is the correct approach, but to be complete, we need the following as well (start with 2 and 3):

1. Handle other Tensor elements. See https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/pytorch/reader.rs#L91-L124 for all types.
2. Instead of serializing `CandleTensor`, we should just convert the underlying `ParamSerde` and `DataSerialize` structures to `NestedValue` directly. The data (vector) pointer will be copied to `NestedValue::F32s` or `NestedValue::F16s` (a new datatype that preserves encoding). You will need to modify the `serialize_data` function at https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/pytorch/reader.rs#L127-L145.
3. You will need to do something similar for the `serialize` function in https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/pytorch/adapter.rs#L68-L78. The logic overlaps with converting `ParamSerde::new(param_id, DataSerialize::new(data, shape))` to `NestedValue`.

Once you fix points 2 and 3, we can bypass the serialization of tensors. Later, we can remove the serialization artifacts and rely on generating `NestedValue` directly.

The main reason we want points 2 and 3 is that the code still converts each tensor element to `NestedValue::F32`, which is later collected into a `Vec` again (although there is no extra memory overhead beyond one element at a time), and we can eliminate code complexity. We also need to review the code and check where the serializer is used, as I believe I only intended to use it for `Param` objects.
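A minimal sketch of what point 2 could look like, with simplified stand-ins for `DataSerialize` and the new variant (burn's real types are generic and live in the record serde module; none of the definitions below are the actual ones):

```rust
// Illustrative stand-ins only; the real DataSerialize/NestedValue are richer.
#[allow(dead_code)]
enum NestedValue {
    F32(f32),       // existing per-element variant
    F32s(Vec<f32>), // new: the whole buffer, one allocation
}

#[allow(dead_code)]
struct DataSerialize {
    value: Vec<f32>,
    shape: Vec<usize>,
}

impl From<DataSerialize> for NestedValue {
    fn from(data: DataSerialize) -> Self {
        // Moves the Vec's heap pointer; no per-element conversion pass.
        NestedValue::F32s(data.value)
    }
}

fn main() {
    let data = DataSerialize {
        value: vec![1.0, 2.0, 3.0],
        shape: vec![3],
    };
    let _nested = NestedValue::from(data); // shape handling omitted here
}
```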
```rust
NestedValue::U16s(_) => visitor.visit_seq(VecSeqAccess::<A, u16>::new(
    value,
    self.default_for_missing_fields,
)),
NestedValue::F32s(_) => visitor.visit_seq(VecSeqAccess::<A, f32>::new(
    value,
    self.default_for_missing_fields,
)),
```
Shouldn't we add more element types, such as `bf16`, `f16`, `f64`, `u64`, `u32`?
Yeah, we're gonna have to (at some point anyway). I started with the two types for my use cases (BF16 for Llama, and F32 for all other models, like the ResNet family, for testing purposes). They're probably the most common too.

I wanted to get a review of the implementation before going further. As @antimora mentioned, adding a concrete implementation for each type doesn't scale very well, but it's the easiest solution I came up with for now.
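For reference, the extra arms would presumably mirror the existing ones in the diff above. This fragment assumes hypothetical `Bf16s`/`F16s`/`F64s` variants get added to `NestedValue`, with the `half` crate providing the half-precision types; none of this is in the PR yet:

```rust
// Sketch only: these variants and arms do not exist yet.
NestedValue::Bf16s(_) => visitor.visit_seq(VecSeqAccess::<A, half::bf16>::new(
    value,
    self.default_for_missing_fields,
)),
NestedValue::F16s(_) => visitor.visit_seq(VecSeqAccess::<A, half::f16>::new(
    value,
    self.default_for_missing_fields,
)),
NestedValue::F64s(_) => visitor.visit_seq(VecSeqAccess::<A, f64>::new(
    value,
    self.default_for_missing_fields,
)),
```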
I don't see another way: we need the compiler to know the size of the vector's elements at compile time! Maybe we could read the vector as bytes instead and cast them later on?
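A minimal sketch of that bytes idea, assuming little-endian storage (as PyTorch checkpoints typically use) and some hypothetical byte-holding variant like `Bytes(Vec<u8>)` whose contents get decoded on demand:

```rust
// Decode an f32 buffer from raw bytes. With a byte-based variant, the enum's
// size no longer depends on the element type, and the cast happens lazily.
fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
        .collect()
}

fn main() {
    // Round-trip a few values to show the decode is lossless.
    let raw: Vec<u8> = [1.0f32, 2.5, -3.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    assert_eq!(bytes_to_f32s(&raw), vec![1.0, 2.5, -3.0]);
}
```

The trade-off is a decode pass at read time instead of typed storage, which is exactly the "cast them later on" question.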
Hmmm, maybe. I think for now I'll stick with the current method to wrap up this PR, and in the future we could refactor this if needed. At the same time, it's not like the number of types will be unmanageable, so "not scaling" isn't necessarily an issue right now.
Before merging, please file issues for any uncompleted refactors or fixes.
To close this PR I'll handle the other tensor element types, but I've opened a new issue for the other improvements suggested in previous discussions.
If we do #1773 alone, then we can deprecate serialization, so you don't need to add the other accumulating item types.
Sure, we can limit the current PR to the types already implemented. The linked issue should capture all element types when we tackle it.
While working on the Llama-3 implementation I stumbled upon a memory issue when importing PyTorch weights with `PyTorchFileRecorder`.

When I profiled the memory usage for ResNet-152 (the checkpoint is 252MB on disk), I saw a huge peak memory usage for what is supposed to be a relatively small model: up to ~5GB, as pointed out by the `heaptrack` traces below.

Before: [heaptrack trace]

After: [heaptrack trace]
Checklist

- [x] The `run-checks all` script has been executed.

Changes
- Added `U16s` and `F32s` variants for `NestedValue` so weights can be parsed as a vector of primitive types instead of `Vec<NestedValue>`. For example, a vec of `f32`s is now represented as `Vec[v, v, v, ...]` instead of `Vec[NestedValue::F32(v), NestedValue::F32(v), ...]`. The `NestedValue` enum has a size of 56 bytes, so it can grow very rapidly (just imagine a very large number of parameters, as in Llama 8B 🤯).
- Updated the `Serializer` based on the input element type.
- Made `VecSeqAccess`'s `iter` generic and added concrete implementations for vecs of `NestedValue`, `u16` and `f32`.
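To put the 56-byte figure in perspective, a back-of-envelope comparison (the per-element sizes come from the description above; the ~8B parameter count is a rough assumption for a Llama-8B-sized model):

```rust
fn main() {
    let params: u64 = 8_000_000_000; // ~8B parameters (assumed)
    let boxed = params * 56; // every element wrapped in a 56-byte NestedValue
    let raw = params * 4;    // every element stored as a plain f32 via F32s
    println!("boxed: ~{} GB", boxed / 1_000_000_000); // ~448 GB
    println!("raw:   ~{} GB", raw / 1_000_000_000);   // ~32 GB
}
```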
Testing

All unit tests pass, including half-precision record tests in `burn-import/pytorch-tests`.