
[Bugfix] fix rope error when load models with different dtypes #4835

Merged
4 commits merged into vllm-project:main on May 17, 2024

Conversation

jinzhen-lin (Contributor):

Currently, if we load models with different dtypes in the same process, we get an error like:

File ~/.miniconda3/lib/python3.8/site-packages/vllm/_custom_ops.py:89, in rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
     81 def rotary_embedding(
     82     positions: torch.Tensor,
     83     query: torch.Tensor,
   (...)
     87     is_neox: bool,
     88 ) -> None:
---> 89     vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
     90                               is_neox)

RuntimeError: expected scalar type BFloat16 but found Half

To reproduce:

import torch
from vllm import LLM

model_fp16 = LLM("Qwen/Qwen1.5-0.5B", dtype=torch.half, gpu_memory_utilization=0.4)
model_bf16 = LLM("Qwen/Qwen1.5-0.5B", dtype=torch.bfloat16, gpu_memory_utilization=0.4)

The bug is caused by the rope cache: models with different dtypes end up sharing the same cached rope module. This PR adds the dtype to the cache key to fix this.
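To make the failure mode concrete, here is a minimal sketch of the caching pattern described above. All names here (_ROPE_CACHE, ToyRotaryEmbedding, get_rope_cached) are illustrative, not vLLM's actual internals; the point is only that keying the cache on dtype prevents a cos/sin cache built for fp16 from being handed to a bf16 model.

import torch
from typing import Dict, Tuple

# Hypothetical module-level cache, keyed the way the PR description suggests.
_ROPE_CACHE: Dict[Tuple, "ToyRotaryEmbedding"] = {}

class ToyRotaryEmbedding(torch.nn.Module):
    # Toy rope module: its cos/sin cache is materialized in a single dtype.
    def __init__(self, head_size: int, max_position: int, base: float,
                 dtype: torch.dtype):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
        freqs = torch.outer(torch.arange(max_position).float(), inv_freq)
        # A cache built for fp16 fails when the kernel later receives bf16 tensors.
        cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(dtype)
        self.register_buffer("cos_sin_cache", cache)

def get_rope_cached(head_size: int, max_position: int, base: float,
                    dtype: torch.dtype) -> ToyRotaryEmbedding:
    # Including dtype in the key keeps fp16 and bf16 models from sharing a module.
    key = (head_size, max_position, base, dtype)
    if key not in _ROPE_CACHE:
        _ROPE_CACHE[key] = ToyRotaryEmbedding(head_size, max_position, base, dtype)
    return _ROPE_CACHE[key]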

@@ -474,7 +474,7 @@ def get_rope(
     else:
         rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           rope_scaling_args)
+           rope_scaling_args, torch.get_default_dtype())
Collaborator:

can we pass the dtype as an argument instead?

jinzhen-lin (Contributor, Author):

done.

@@ -463,7 +468,10 @@ def get_rope(
     base: int,
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
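To illustrate the new parameter, here is a hedged usage sketch. It assumes get_rope is imported from vLLM's rotary_embedding module and that the returned module exposes its cos_sin_cache buffer; the argument values are made up, not taken from any real model config.

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

# Illustrative arguments only; a real model passes values from its config.
rope = get_rope(
    head_size=64,
    rotary_dim=64,
    max_position=2048,
    base=10000,
    is_neox_style=True,
    rope_scaling=None,
    dtype=torch.bfloat16,  # pin the rope cache to the model's dtype
)
print(rope.cos_sin_cache.dtype)  # expected: torch.bfloat16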
Collaborator:

QQ: is it difficult to always require passing the dtype instead?

jinzhen-lin (Contributor, Author):

I noticed that the linear module in vLLM sets param_dtype as an optional argument, so I think it may be better to keep this consistent.
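For reference, the optional-argument pattern being discussed usually reduces to a default-dtype fallback inside the function. A minimal sketch of that pattern (not the exact merged code):

import torch
from typing import Optional

def resolve_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    # Callers that know the model dtype pass it explicitly; everyone else
    # falls back to the process-wide default dtype.
    return dtype if dtype is not None else torch.get_default_dtype()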

rkooo567 (Collaborator) left a comment:

import torch
from vllm import LLM

model_fp16 = LLM("Qwen/Qwen1.5-0.5B", dtype=torch.half, gpu_memory_utilization=0.4)
model_bf16 = LLM("Qwen/Qwen1.5-0.5B", dtype=torch.bfloat16, gpu_memory_utilization=0.4)

Can you add this as a regression test? And then it lgtm

jinzhen-lin (Contributor, Author):

> import torch
> from vllm import LLM
>
> model_fp16 = LLM("Qwen/Qwen1.5-0.5B", dtype=torch.half, gpu_memory_utilization=0.4)
> model_bf16 = LLM("Qwen/Qwen1.5-0.5B", dtype=torch.bfloat16, gpu_memory_utilization=0.4)
>
> Can you add this as a regression test? And then it lgtm

I added a rope module cache test instead of a model test, is that ok?
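For readers following along, a regression test along these lines might look roughly like the sketch below. It is not the test added in this PR; it assumes the get_rope import path used in the earlier sketch and that the cos_sin_cache buffer is stored in the requested dtype.

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

def test_rope_cache_keyed_by_dtype():
    # The same rope configuration built under two dtypes must not return
    # one shared cached module.
    rope_fp16 = get_rope(64, 64, 2048, 10000, True, None, torch.float16)
    rope_bf16 = get_rope(64, 64, 2048, 10000, True, None, torch.bfloat16)
    assert rope_fp16 is not rope_bf16
    assert rope_fp16.cos_sin_cache.dtype == torch.float16
    assert rope_bf16.cos_sin_cache.dtype == torch.bfloat16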

rkooo567 (Collaborator) left a comment:

Yeah test lgtm!

rkooo567 merged commit 33e0823 into vllm-project:main on May 17, 2024
55 checks passed
tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request on May 25, 2024