
Unable to load mistralai/Mixtral-8x7B-Instruct-v0.1 using mps #2778

Open · 2 of 4 tasks
chimezie opened this issue May 14, 2024 · 5 comments

chimezie commented May 14, 2024

System Info

- `Accelerate` version: 0.30.1
- Platform: macOS-14.2.1-arm64-arm-64bit
- `accelerate` bash location: /path/to/venv/mmlu-eval/bin/accelerate
- Python version: 3.11.6
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.3.0 (False)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 128.00 GB
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForCausalLM
import torch
AutoModelForCausalLM.from_pretrained('mistralai/Mixtral-8x7B-Instruct-v0.1',
                                     revision='main',
                                     torch_dtype=torch.float16,
                                     trust_remote_code=False,
                                     device_map={'': 'mps'})

This results in:

File /path/to/python3.11/site-packages/transformers/models/auto/auto_factory.py:563, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    561 elif type(config) in cls._model_mapping.keys():
    562     model_class = _get_model_class(config, cls._model_mapping)
--> 563     return model_class.from_pretrained(
    564         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    565     )
    566 raise ValueError(
    567     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    568     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    569 )

File /path/to/python3.11/site-packages/transformers/modeling_utils.py:3531, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3522     if dtype_orig is not None:
   3523         torch.set_default_dtype(dtype_orig)
   3524     (
   3525         model,
   3526         missing_keys,
   3527         unexpected_keys,
   3528         mismatched_keys,
   3529         offload_index,
   3530         error_msgs,
-> 3531     ) = cls._load_pretrained_model(
   3532         model,
   3533         state_dict,
   3534         loaded_state_dict_keys,  # XXX: rename?
   3535         resolved_archive_file,
   3536         pretrained_model_name_or_path,
   3537         ignore_mismatched_sizes=ignore_mismatched_sizes,
   3538         sharded_metadata=sharded_metadata,
   3539         _fast_init=_fast_init,
   3540         low_cpu_mem_usage=low_cpu_mem_usage,
   3541         device_map=device_map,
   3542         offload_folder=offload_folder,
   3543         offload_state_dict=offload_state_dict,
   3544         dtype=torch_dtype,
   3545         hf_quantizer=hf_quantizer,
   3546         keep_in_fp32_modules=keep_in_fp32_modules,
   3547     )
   3549 # make sure token embedding weights are still tied if needed
   3550 model.tie_weights()
File /path/to/python3.11/site-packages/transformers/modeling_utils.py:3958, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_modules)
   3954                 set_module_tensor_to_device(
   3955                     model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
   3956                 )
   3957     else:
-> 3958         new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
   3959             model_to_load,
   3960             state_dict,
   3961             loaded_keys,
   3962             start_prefix,
   3963             expected_keys,
   3964             device_map=device_map,
   3965             offload_folder=offload_folder,
   3966             offload_index=offload_index,
   3967             state_dict_folder=state_dict_folder,
   3968             state_dict_index=state_dict_index,
   3969             dtype=dtype,
   3970             hf_quantizer=hf_quantizer,
   3971             is_safetensors=is_safetensors,
   3972             keep_in_fp32_modules=keep_in_fp32_modules,
   3973             unexpected_keys=unexpected_keys,
   3974         )
   3975         error_msgs += new_error_msgs
   3976 else:

File /path/to/python3.11/site-packages/transformers/modeling_utils.py:812, in _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix, expected_keys, device_map, offload_folder, offload_index, state_dict_folder, state_dict_index, dtype, hf_quantizer, is_safetensors, keep_in_fp32_modules, unexpected_keys)
    801     state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
    802 elif (
    803     not is_quantized
    804     or (not hf_quantizer.requires_parameters_quantization)
   (...)
    810 ):
    811     # For backward compatibility with older versions of `accelerate` and for non-quantized params
--> 812     set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
    813 else:
    814     hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)

File /path/to/python3.11/site-packages/accelerate/utils/modeling.py:400, in set_module_tensor_to_device(module, tensor_name, device, value, dtype, fp16_statistics, tied_params_map)
    398             module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
    399 elif isinstance(value, torch.Tensor):
--> 400     new_value = value.to(device)
    401 else:
    402     new_value = torch.tensor(value, device=device)

RuntimeError: MPS backend out of memory (MPS allocated: 163.01 GB, other allocations: 384.00 KB, max allowed: 163.20 GB). Tried to allocate 250.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

Expected behavior

from_pretrained should return the model without raising an error.
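
The RuntimeError itself names one possible workaround: setting PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to lift the MPS allocation cap. A minimal sketch of that workaround, assuming the variable is read when PyTorch first initializes MPS (the error message warns it may cause system failure):

# Sketch of the workaround suggested by the RuntimeError above.
# Setting PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 disables the upper limit on
# MPS allocations; per the error message, this may cause system failure.
import os
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"  # set before torch touches MPS

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    revision='main',
    torch_dtype=torch.float16,
    trust_remote_code=False,
    device_map={'': 'mps'},
)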

chimezie (Author) commented:

The system has > 100 GB free at the time the code is run.

muellerzr (Collaborator) commented:

cc @SunMarc


SunMarc commented May 15, 2024

Hi @chimezie, does this happen only with Mixtral-8x7B or with all models? From the traceback, the memory was completely used: MPS backend out of memory (MPS allocated: 163.01 GB, other allocations: 384.00 KB, max allowed: 163.20 GB).

chimezie (Author) commented:

This seems to happen only with Mixtral-8x7B. I was able to load Llama 3 8B, Qwen1.5-14B, and internistai/base-7b-v0.2, for example, without any issue.


SunMarc commented May 15, 2024

Mixtral-8x7B is a very big model, around 100GB, but you should be able to load it since you have over 160GB. At which checkpoint does the loading fail? Near the end? You can track the memory consumption using Activity Monitor on your Mac.
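
One way to follow that suggestion from Python rather than Activity Monitor is to poll the MPS allocator in a background thread while from_pretrained runs. A minimal sketch, assuming PyTorch 2.x, where torch.mps.current_allocated_memory() and torch.mps.driver_allocated_memory() are available; the polling thread is illustrative and not part of transformers:

# Sketch: log MPS memory every few seconds while the model loads.
import threading
import time

import torch
from transformers import AutoModelForCausalLM

def log_mps_memory(stop_event, interval=5.0):
    # Poll the MPS allocator until asked to stop.
    while not stop_event.is_set():
        allocated_gb = torch.mps.current_allocated_memory() / 1e9
        driver_gb = torch.mps.driver_allocated_memory() / 1e9
        print(f"MPS allocated: {allocated_gb:.2f} GB (driver: {driver_gb:.2f} GB)")
        time.sleep(interval)

stop = threading.Event()
monitor = threading.Thread(target=log_mps_memory, args=(stop,), daemon=True)
monitor.start()
try:
    model = AutoModelForCausalLM.from_pretrained(
        'mistralai/Mixtral-8x7B-Instruct-v0.1',
        torch_dtype=torch.float16,
        device_map={'': 'mps'},
    )
finally:
    stop.set()
    monitor.join()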
