
PicklingError: Can't pickle <function Embedding.forward at XXXXXXX> it's not the same object as torch.nn.modules.sparse.Embedding.forward #2749

Open
arpit2665 opened this issue May 7, 2024 · 6 comments


arpit2665 commented May 7, 2024

System Info

I am trying to share LLMs at inference time between multiple forked processes using torch's ForkingPickler class. This works when the model is loaded in FP16 (without any quantization) with device_map = {"": 0}, but the model cannot be shared with the other forked processes when device_map = 'auto'. Could you please help with this issue?

Below is the list of installed libraries
python==3.11.7
torch==2.0.1
transformers==4.37.1
bitsandbytes==0.42.0
accelerate==0.23.0
cuda version = 11.0

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoModelForCausalLM
from torch.multiprocessing.reductions import ForkingPickler

base_model_name = 'mistral-7b-instruct'

# Loading the LLM in FP16 (without any quantization) and device_map = {"": 0}
base_model_wo_quant = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.float16, device_map={"": 0}, use_safetensors=True
)

# Able to share with the forked processes
_ = base_model_wo_quant.share_memory()
ForkingPickler.dumps(base_model_wo_quant)

# Loading the LLM in FP16 (without any quantization) and device_map = 'auto'
base_model_wo_quant_auto = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.float16, device_map='auto', use_safetensors=True
)

_ = base_model_wo_quant_auto.share_memory()
# Fails to share with the forked processes, raising the error below
ForkingPickler.dumps(base_model_wo_quant_auto)

# PicklingError: Can't pickle <function Embedding.forward at XXXXXXX>:
# it's not the same object as torch.nn.modules.sparse.Embedding.forward

Expected behavior

The model is expected to be shareable with the forked processes even when device_map = 'auto'.

muellerzr (Collaborator) commented:

cc @SunMarc

SunMarc (Member) commented May 7, 2024

Hi @arpit2665, thanks for reporting. This is indeed expected, since we slightly modify the forward when device_map="auto" is used. Could you try removing the hooks before pickling? To do that, you can use remove_hook_from_module(model, recurse=True), which you get by doing from accelerate.hooks import remove_hook_from_module. Could you also explain what you are trying to achieve with ForkingPickler?
I closed a PR recently; maybe you can try whether the solution proposed there works! #2613
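The mechanism behind this error can be illustrated without torch or accelerate: pickle serializes a function by looking it up via its __module__ and __qualname__, and raises "it's not the same object as ..." when the lookup returns a different object than the one being pickled. Below is a minimal, hypothetical sketch (Embedding, hooked_forward, and the instance-level override are stand-ins, not accelerate's actual code) of how a hook that masquerades as the original forward breaks pickling, and how dropping the override, roughly what remove_hook_from_module does for the attributes it patched, restores it:

```python
import functools
import pickle

# Hypothetical stand-in for torch.nn.Embedding.
class Embedding:
    def forward(self):
        return "original forward"

def hooked_forward(self):
    # device-placement logic would live here in a real hook
    return Embedding.forward(self)

# Wrapper machinery copies __module__/__qualname__ from the original,
# so hooked_forward now *claims* to be Embedding.forward without being it.
functools.update_wrapper(hooked_forward, Embedding.forward)

emb = Embedding()
emb.forward = functools.partial(hooked_forward, emb)  # instance-level override

try:
    pickle.dumps(emb)
except pickle.PicklingError as exc:
    print(exc)  # "... it's not the same object as ... Embedding.forward"

# Removing the override makes the instance picklable again; attribute
# lookup falls back to the genuine class method.
del emb.forward
data = pickle.dumps(emb)
print(pickle.loads(data).forward())  # original forward
```

This is only an analogy for the failure mode, but it matches the reported traceback: the function in the model's attributes carries the original's qualified name while being a different object.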

arpit2665 (Author) commented May 7, 2024

Hi @SunMarc - Thanks for your reply. I tried both of the suggested solutions (each individually), and the error message has now changed to

PicklingError: Can't pickle <function Module.to at XXXXXXX>: it's not the same object as torch.nn.modules.module.Module.to

Could you also explain what you are trying to achieve with ForkingPickler?
I am trying to pickle the model and share it with the forked processes using ForkingPickler. It's very similar to the problem statement mentioned in #2613.

Please suggest if there's any solution or workaround to this problem.

arpit2665 (Author) commented:

Hi @SunMarc - Is there any solution or workaround to this problem?

SunMarc (Member) commented May 13, 2024

I tried both of the suggested solutions (each individually), and the error message has now changed to

PicklingError: Can't pickle <function Module.to at XXXXXXX>: it's not the same object as torch.nn.modules.module.Module.to

Hi @arpit2665, this is most likely due to these lines. Could you try whether commenting them out solves the issue? If so, and if you can provide a quick benchmark of how much ForkingPickler speeds up generation, I will see if I can come up with a better solution!
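The new message points at the same mechanism, now applied to Module.to: dispatched models also get a patched .to(), and that patch again masquerades as the original method. A stdlib-only sketch under the same assumptions as before (Module, patched_to, and the instance-level patch are hypothetical stand-ins), showing that every masquerading attribute has to be removed before pickling can succeed:

```python
import functools
import pickle

# Hypothetical stand-in for torch.nn.Module and its .to() method.
class Module:
    def to(self, device):
        return self

def patched_to(self, device):
    # guard logic for offloaded models would live here
    return Module.to(self, device)

# The patch inherits the original's __module__/__qualname__, so pickle's
# lookup finds the genuine Module.to and rejects the impostor.
functools.update_wrapper(patched_to, Module.to)

m = Module()
m.to = functools.partial(patched_to, m)  # instance-level patch

try:
    pickle.dumps(m)
except pickle.PicklingError as exc:
    print(exc)  # "... it's not the same object as ... Module.to"

# Removing this patch as well makes the object picklable again.
del m.to
restored = pickle.loads(pickle.dumps(m))
print(type(restored).__name__)  # Module
```

This suggests why removing the forward hooks alone only moved the error: pickling fails on the first remaining patched attribute it encounters.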


github-actions bot commented Jun 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
