FEAT / Bitsandbytes: Add `dequantize` API for bitsandbytes quantized models #30806
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for adding this new method to the quantizer! This will make fine-tuning with quantized models way easier! I left a few minor comments.
    if cls_name == "Params4bit":
        return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
The user might want to know in which precision the model was dequantized, since they don't have the possibility to control that. I think it would be great to give that information, since there is no default value (as opposed to `from_pretrained`, which loads the model in fp32).
Two ways to get that:
- just check the dtype of the weights at the end (potentially the easiest way), as sketched below
- check what happens in `dequantize_4bit`. In that method, you can see that the output dtype comes from `weight.quant_state.dtype`.
We can potentially add a `torch_dtype` attribute in the future if it makes sense.
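For the first option, a minimal sketch (not code from the PR; the helper name is hypothetical) of how the resulting precision could be inspected:

```python
import torch
from torch import nn

def inspect_dequantized_dtypes(model: nn.Module) -> set[torch.dtype]:
    # Hypothetical helper: collect the dtypes of the dequantized weights so the
    # user knows which precision the model ended up in.
    return {param.dtype for param in model.parameters()}
```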
+1
Nice catch! The output dtype should be correctly inferred here: https://github.com/TimDettmers/bitsandbytes/blob/b891f80ba514833f41f0e9226983b02a9fb5c44b/bitsandbytes/functional.py#L1349 through the compute_dtype, so it should be accurate. I added a warning_once statement to inform users of the dequantized dtype: 1a4a906
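For reference, a rough sketch of what such a one-time warning could look like (the exact message and placement in the PR may differ; `report_dequantized_dtype` is a hypothetical helper):

```python
from torch import nn
from transformers.utils import logging

logger = logging.get_logger(__name__)

def report_dequantized_dtype(model: nn.Module) -> None:
    # Hypothetical helper: warn the user (only once) about which precision
    # the weights ended up in after dequantization.
    dtype = next(model.parameters()).dtype
    logger.warning_once(f"Model was dequantized to {dtype}.")
```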
Thanks for adding this! +1 on all of @SunMarc's comments.
    if cls_name == "Params4bit":
        return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
+1
    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    import bitsandbytes as bnb
This is already imported at the top of the module
Nice catch! Should be fixed now.
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Thanks for adding this feature and iterating!
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
One general comment: if instead you had a private method `_dequantize_and_replace` which handles the recursion, you wouldn't need to return `has_been_replaced` here. When someone calls `dequantize_and_replace`, I don't think `has_been_replaced` is ever used, and it could be confusing, e.g.:

# This is just dequantize_and_replace from before
def _dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    ...
    return model, has_been_replaced


def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    model, has_been_replaced = _dequantize_and_replace(...)
    return model
Makes sense! Will do!
Done in 8b904f7!
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
    )

    if not has_been_replaced:
        logger.warning(
Nice :)
Yeah, this is great, thanks!
Great, thanks @RonanKMcGovern! Let us know how it goes.
What does this PR do?
Fixes #30177
This PR adds a new feature, `dequantize`, in order to de-quantize models for interesting use cases such as the one described in #30177. The API is very simple:
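A minimal usage sketch, assuming the new method is exposed as `model.dequantize()` (the checkpoint name below is only a placeholder; the PR's own docs and tests show the canonical usage):

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load a model quantized with bitsandbytes (4-bit here as an example).
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# De-quantize the weights back into regular torch tensors.
model.dequantize()
```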
Users just need to make sure they have enough GPU RAM to store the dequantized model, otherwise they might face unexpected behaviour.
Added support for 4-bit / 8-bit models, plus tests and docs to educate users on how to use this new API.
cc @amyeroberts @SunMarc