Add GPTQ Marlin 2:4 sparse structured support #4790
Conversation
Benchmark results on A100 for the Yi-34B Chat model with marlin_24 serialized weights (the actual weight values are not real yet). This is just to show preliminary results and give a feel for how it compares against the original Marlin, GPTQ, and fp16. Original PDF attached.
vllm/config.py (Outdated)

```python
@@ -160,6 +160,9 @@ def _verify_quantization(self) -> None:
        is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                            or quant_cfg.get("is_marlin_format", False))

        is_format_marlin_24 = (
```
We should think about how to clean this up and not have this marlin-specific code in vllm/config.py. One way to do it that doesn't require more registries: have an optional class variable checkpoint_format in the gptq-compatible QuantizationConfigs, and then in this code iterate through QUANTIZATION_METHODS and check whether one of them declares the associated checkpoint format.
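A minimal sketch of how that could look (the checkpoint_format class attribute and the detection loop are illustrative assumptions here, not code from this PR):

```python
# (1) A gptq-compatible quantization config declares the serialized
#     checkpoint format it can load (the class attribute is an assumption):
class MarlinConfig(QuantizationConfig):
    checkpoint_format = "marlin"

# (2) In ModelConfig._verify_quantization, detection then stays generic:
def _verify_quantization(self) -> None:
    ...
    for name, method in QUANTIZATION_METHODS.items():
        fmt = getattr(method, "checkpoint_format", None)
        if fmt is not None and quant_cfg.get("checkpoint_format") == fmt:
            self.quantization = name
            break
```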
Good suggestion. I have changed the code to encapsulate the marlin-specific checkpoint checks into the marlin config classes. Let me know if it looks good now.
I'm thinking about something a little more radical, like replacing this whole code block with

```python
for name, method in QUANTIZATION_METHODS.items():
    if method.supports_checkpoint(quant_cfg):
        self.quantization = name
```

and you would have a default implementation of supports_checkpoint for QuantizationConfig that returns False, and Marlin would implement the method, print the appropriate warnings, and return True if the quantization should be overridden.

That way you can remove all occurrences of marlin from the config.py file, and this mechanism can also be used by other quantization schemes :)
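For concreteness, a rough sketch of that default plus a Marlin override (the method bodies here are illustrative assumptions, not the code that ended up in the PR; logger is assumed to be the module-level logger):

```python
# Simplified stand-in for the base class, just to show the default:
class QuantizationConfig:
    @classmethod
    def supports_checkpoint(cls, quant_cfg: dict) -> bool:
        # Default: a quantization method does not claim foreign checkpoints.
        return False

class MarlinConfig(QuantizationConfig):
    @classmethod
    def supports_checkpoint(cls, quant_cfg: dict) -> bool:
        # Claim the checkpoint only if it is serialized in marlin format.
        is_marlin_format = (quant_cfg.get("checkpoint_format") == "marlin"
                            or quant_cfg.get("is_marlin_format", False))
        if is_marlin_format:
            logger.info("The model checkpoint is in marlin format; "
                        "overriding the quantization method to marlin.")
        return is_marlin_format
```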
I see, I can try it
@pcmoritz I have redone the config.py part that you proposed to change. It looks cleaner now :)
The PR looks good to me (I didn't review the kernel code in detail though). Do you know how much it adds to the binary size? We need to be careful not to increase that too much due to PyPI limitations. I think we have a test for this in the CI.
Nice, thanks for making these changes, this looks a bunch cleaner now! Optional suggestion that would be even cleaner: rename supports_checkpoint to override_quantization_method and have it return the quantization method to use (or None). The # Detect which checkpoint is it block would then become:

```python
# Detect which checkpoint is it
for name, method in QUANTIZATION_METHODS.items():
    quantization_override = method.override_quantization_method(quant_cfg)
    if quantization_override:
        self.quantization = quantization_override
        break
```

This would enable you to shift the following logic into the override_quantization_method implementations:

```python
# Allow override of gptq_marlin to gptq (if set explicitly)
if self.quantization == "gptq" and quant_method == "gptq_marlin":
    logger.warning(
        "Detected that the model can run with gptq_marlin"
        ", however you specified quantization=gptq explicitly,"
        " so forcing gptq. Use quantization=gptq_marlin for"
        " faster inference")
    quant_method = "gptq"

# Choose gptq_marlin if marlin is specified
if self.quantization == "marlin" and quant_method == "gptq_marlin":
    self.quantization = quant_method

# Choose marlin if gptq is specified
if self.quantization == "gptq" and quant_method == "marlin":
    self.quantization = quant_method
```
@pcmoritz This is a good idea. Changed the API to return str or None and moved the gptq-specific override logic to the override funcs.
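As a rough illustration of that shape (the signature, the user_quant argument, and the is_gptq_marlin_compatible helper are assumptions, not necessarily the exact code in the PR), an override function for gptq_marlin might look roughly like:

```python
class GPTQMarlinConfig(QuantizationConfig):
    @classmethod
    def override_quantization_method(cls, quant_cfg, user_quant):
        # is_gptq_marlin_compatible() is a hypothetical helper that checks
        # bits / group_size / sym against what the marlin kernel supports.
        if not cls.is_gptq_marlin_compatible(quant_cfg):
            return None
        if user_quant == "gptq":
            # User explicitly asked for gptq: warn and do not override.
            logger.warning(
                "Detected that the model can run with gptq_marlin, however "
                "you specified quantization=gptq explicitly, so forcing gptq. "
                "Use quantization=gptq_marlin for faster inference")
            return None
        # Otherwise, override to the faster marlin-based method.
        return "gptq_marlin"
```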
"fp8": Fp8Config, | ||
# The order of gptq methods is important for config.py iteration over | ||
# supports_checkpoint(..) |
Nit: This is called override_quantization_method now :)
Wonderful! Small nit and then it looks good to go if the tests pass :)
Cool, fixed the nit and some other little things.
Thanks for the suggestions!
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
2:4 sparsity without quantization is currently not supported in vLLM yet, right?
@yzlnew That is correct, currently GPTQ quantization is required.
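For reference, a minimal usage sketch (the model name is a placeholder, and the quantization method name "gptq_marlin_24" is assumed from this PR; if the checkpoint's quantization config declares the marlin_24 format, the method may also be detected automatically):

```python
from vllm import LLM

# Placeholder checkpoint: a GPTQ model serialized with 2:4 sparse marlin_24 weights.
llm = LLM(model="org/Yi-34B-Chat-GPTQ-marlin-24",
          quantization="gptq_marlin_24")

outputs = llm.generate(["Hello, my name is"])
print(outputs[0].outputs[0].text)
```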
This PR adds a new GPTQ Marlin 2:4 structured-sparse GPU kernel and support for running 2:4 sparse models in vLLM. Currently supported configs are:
The new 2:4 sparse marlin GPU kernel is based on the great work of @LopezCastroRoberto and @dalistarh from @IST-Das. More information will be provided in their upcoming publication.
TODO: