
[Kernel] add bfloat16 support for gptq marlin kernel #4788

Merged

Conversation

jinzhen-lin
Contributor

Some models overflow when using fp16 inference (e.g. Deepseek-V2), so we should add bfloat16 support to the quantization kernels. This PR adds bfloat16 support to the gptq marlin kernel.

Unlike the gptq kernel in #4781, the gptq marlin kernel doesn't use atomicAdd, so bfloat16 performance is close to that of float16.

Related issue: #2149

Main changes:

  • add bfloat16 input/output support to the CUDA kernels
  • dequantize qweight to bfloat16 in the proper way
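To illustrate the overflow issue mentioned above, here is a minimal host-side sketch (not part of this PR; it assumes a CUDA 11+ toolkit for cuda_bf16.h). A value well inside bfloat16's range already saturates float16, whose maximum finite value is 65504:

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

int main() {
  // float16 has a max finite value of 65504, so large intermediate activations
  // round to +inf; bfloat16 keeps float32's 8-bit exponent (max ~3.4e38) at the
  // cost of mantissa precision.
  float big = 1e5f;                          // representative large activation
  __half h = __float2half(big);              // rounds to +inf in fp16
  __nv_bfloat16 b = __float2bfloat16(big);   // still finite in bf16
  printf("fp16: %f  bf16: %f\n", __half2float(h), __bfloat162float(b));
  return 0;
}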

@robertgshaw2-neuralmagic
Collaborator

@alexm-nm can you review this?

Contributor

@alexm-neuralmagic alexm-neuralmagic left a comment

Thanks for doing the work of adding bfloat16 to marlin. Left some comments.

@@ -9,6 +9,10 @@
#include <cuda_runtime.h>
#include <iostream>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
Contributor

Why is it necessary to check here that SM >= 8.0? Shouldn't the #include <cuda_bf16.h> work regardless?
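For context, a common alternative (a sketch, not the code in this PR) is to include the header unconditionally, since it ships with any CUDA 11+ toolkit, and to guard only the device code paths that need SM >= 8.0 instructions:

#include <cuda_bf16.h>  // the include itself compiles on any pass

// __CUDA_ARCH__ is only defined during device compilation, so this branch is
// taken by host code and by sm_80+ device passes; older device passes skip it.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  // nv_bfloat16 arithmetic / conversion intrinsics go here
#endif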

@@ -38,6 +42,7 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async
#else


Contributor

nit: formatting

C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
Contributor

This is a problematic way to add bfloat16 support to marlin, since we should be able to compile the marlin module for both float16 and bfloat16 at the same time. Could you restructure the code to pass a template parameter to the Marlin<...> kernel instead, and use that template parameter for all of the functions that need a templated type? If you don't have time, I can take over and fix it for you. Tell me what works.
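A minimal sketch of the restructuring being requested (names and signatures are illustrative, not the actual Marlin kernel): make the element type a template parameter so both instantiations exist in the same build, and pick one at dispatch time.

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Hypothetical, heavily simplified signature: the real Marlin kernel takes many
// more parameters, but the point is that scalar_t flows into every helper
// (dequant_4bit<scalar_t>, mma<scalar_t>, ...) instead of a single global type.
template <typename scalar_t>  // half or nv_bfloat16
__global__ void marlin_gemm_kernel(const scalar_t* __restrict__ A,
                                   const int* __restrict__ B_quant,
                                   scalar_t* __restrict__ C,
                                   int prob_m, int prob_n, int prob_k) {
  // ... shared pipeline, templated on scalar_t ...
}

// Host-side dispatch selects the instantiation from the tensor dtype, e.g.
//   fp16: marlin_gemm_kernel<half><<<grid, block>>>(...);
//   bf16: marlin_gemm_kernel<nv_bfloat16><<<grid, block>>>(...);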

Contributor Author

OK, I will restructure it soon.

@jinzhen-lin
Contributor Author

@alexm-nm I have restructured the code. Can you review it again?

Contributor

@alexm-neuralmagic alexm-neuralmagic left a comment

@jinzhen-lin this looks much better with the template param! I left some minor comments. Could you also add a test to test_gptq_marlin.py with some models that run with dtype.bfloat16 (so we have correctness verified on every change going forward). Again, thanks for the help!

size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
Contributor

You had an #ifdef above to check for CUDA_ARCH >= 8 where you access nv_bfloat16. I suppose it generates a compilation error if you don't have the ifdef. I think you should have an ifdef here as well to disable the bfloat16 case, so the code compiles for SM < 8.
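One possible shape for this (a sketch under the assumption that the dispatch lives in host code, where __CUDA_ARCH__ is never defined; the helper name is hypothetical): a runtime check on the device's compute capability, as an alternative to compiling the branch out.

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

// Reject bfloat16 requests on pre-Ampere GPUs with a clear error instead of
// relying on an #ifdef in host code.
void check_bf16_supported() {
  const cudaDeviceProp* props = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK(props->major >= 8,
              "bfloat16 marlin kernels require compute capability >= 8.0, got ",
              props->major, ".", props->minor);
}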

};

template <>
class ScalarType<half> {
Contributor

This looks much better! Thanks for doing this.

fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

fp32_intermediates[0] -= 8388736.f;
Contributor

What code is this dequant_8bit based on? Maybe you can document the reference you used.
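For context, the snippet above looks like the well-known magic-number trick for converting packed bytes to fp32 without integer-to-float instructions (the kind of interleaved conversion found in FasterTransformer/CUTLASS dequant code, though that attribution is an assumption here). A standalone sketch of the idea, with illustrative names, assuming nvcc compilation:

// Each __byte_perm splices one byte of q into the low byte of 0x4B000000 (the
// bit pattern of 8388608.f = 2^23), producing a float whose value is exactly
// 2^23 + byte. Subtracting 8388736.f (= 2^23 + 128) removes the offset and
// re-centers the unsigned byte around zero.
__device__ inline float4 u8x4_to_f32x4(unsigned int q) {
  const unsigned int fp32_base = 0x4B000000u;
  unsigned int casted[4];
  casted[0] = __byte_perm(q, fp32_base, 0x7650);  // byte 0 of q
  casted[1] = __byte_perm(q, fp32_base, 0x7652);  // byte 2 of q
  casted[2] = __byte_perm(q, fp32_base, 0x7651);  // byte 1 of q
  casted[3] = __byte_perm(q, fp32_base, 0x7653);  // byte 3 of q (interleaved order)
  float4 out;
  out.x = __uint_as_float(casted[0]) - 8388736.f;
  out.y = __uint_as_float(casted[1]) - 8388736.f;
  out.z = __uint_as_float(casted[2]) - 8388736.f;
  out.w = __uint_as_float(casted[3]) - 8388736.f;
  return out;
}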

@alexm-neuralmagic
Contributor

alexm-neuralmagic commented May 14, 2024

@bnellnm could you do a quick pass on the template changes?

__device__ inline FragB dequant_4bit(int q) {
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
throw std::runtime_error("unsupported");
Contributor

I'm not sure what the standard is but I think most checks in the code use TORCH_CHECK rather than throw.

: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
Contributor

I think it would be safer to make the else clause a static_assert so if a new type were added, this function would not silently compile with an empty body, i.e.

} else {
    static_assert(std::is_same<scalar_t, half>::value);
    asm volatile(...);
}


template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
throw std::runtime_error("unsupported");
Contributor

TORCH_CHECK?

num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
} else {
throw std::runtime_error("gpt_marlin_gemm only supports bfloat16 and float16");
Contributor

TORCH_CHECK here too?
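For reference, the idiom being suggested in these comments (a sketch of the pattern, not the exact line that landed):

#include <torch/extension.h>  // provides TORCH_CHECK

// TORCH_CHECK(false, ...) raises a c10::Error that PyTorch surfaces to Python as
// a RuntimeError carrying the message, instead of an uncaught C++ exception.
void reject_unsupported_dtype(const at::Tensor& a) {
  TORCH_CHECK(false,
              "gptq_marlin_gemm only supports bfloat16 and float16, got ",
              a.scalar_type());
}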

@bnellnm
Contributor

bnellnm commented May 14, 2024

@bnellnm could you do a quick pass on the template changes?

The template changes look good. I had a few minor comments, mostly about the use of TORCH_CHECK over throw (which I think is more "standard").

@alexm-neuralmagic
Contributor

@jinzhen-lin I think your code is in good shape to land after addressing the last comments.

@jinzhen-lin
Contributor Author

@alexm-nm @bnellnm All previous comments have been fixed.

As for the test in test_gptq_marlin.py:

@alexm-neuralmagic
Contributor

alexm-neuralmagic commented May 16, 2024

@jinzhen-lin thanks for adding the tests and addressing all the comments. @robertgshaw2-neuralmagic looks good to me to proceed.

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 99caa49 into vllm-project:main May 16, 2024
55 checks passed
@robertgshaw2-neuralmagic
Collaborator

Thanks all!

tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024