Provide a method to chunk the weights before importing using PyTorchFileRecorder #1632

Open
wcshds opened this issue Apr 15, 2024 · 5 comments
Labels
enhancement Enhance existing features

Comments

@wcshds
Contributor

wcshds commented Apr 15, 2024

Description

I ran into difficulties when importing the weights of an LSTM from PyTorch. PyTorch concatenates the per-gate weights into single tensors, so the structure of LSTM weights in PyTorch differs significantly from the structure in Burn. Therefore, I think PyTorchFileRecorder needs to provide a method to chunk the weights before importing them into Burn.

In Burn, the weights of an Lstm are stored as eight Linears (an input and a hidden transform for each of the four gates).

@laggui
Member

laggui commented Apr 15, 2024

For reference, the plan is not to support everything as a 1-to-1 mapping with PyTorch. The PyTorchFileRecorder exists to make the transition easier when importing pre-trained weights. There are a couple of ways to work around this, such as modifying the PyTorch model before saving it.

But basic supported layers such as LSTM are a bit different; this might be something we need to add.

@antimora
Collaborator

If you could provide minimal code, we will look into this. The PyTorch screenshot appears to store the weights as an array, which is supported by PyTorchFileRecorder.

@wcshds
Contributor Author

wcshds commented Apr 15, 2024

If you could provide minimal code, we will look into this. The PyTorch screenshot appears to store the weights as an array, which is supported by PyTorchFileRecorder.

Sure.

Rust code:

use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    module::Module,
    nn::{Linear, LinearConfig, Lstm, LstmConfig},
    record::{FullPrecisionSettings, Recorder},
    tensor::{backend::Backend, Tensor},
};
use burn_import::pytorch::PyTorchFileRecorder;

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            lstm: LstmConfig::new(10, 20, true).init(device),
            linear: LinearConfig::new(20, 30).init(device),
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let (_, out) = self.lstm.forward(input, None);
        self.linear.forward(out)
    }
}

fn main() {
    type Backend = NdArray;
    let device = NdArrayDevice::Cpu;

    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("./example.pt".into(), &device)
        .expect("Should decode state successfully");

    let model: Model<Backend> = Model::new(&device).load_record(record);
}

Python code:

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.lstm = nn.LSTM(10, 20)
        self.linear = nn.Linear(20, 30)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(input)
        return self.linear(out)
    
if __name__ == "__main__":
    model = Model()
    input = torch.randn(2, 8, 10)
    res = model(input)
    print(res.shape)
    print([name for name, _ in model.named_parameters()])

    torch.save(model.state_dict(), "./example.pt")

[screenshot: output of the script above, listing the parameter names lstm.weight_ih_l0, lstm.weight_hh_l0, lstm.bias_ih_l0, lstm.bias_hh_l0, linear.weight and linear.bias]

@wcshds
Contributor Author

wcshds commented Apr 15, 2024

@antimora @laggui Therefore, importing the LSTM weights from PyTorch should remap them like this, chunking each concatenated PyTorch tensor into four pieces, one per gate (a rough sketch of the chunking follows the list):

lstm.weight_ih_l0 => lstm.input_gate.input_transform.weight
                  => lstm.forget_gate.input_transform.weight
                  => lstm.cell_gate.input_transform.weight
                  => lstm.output_gate.input_transform.weight

lstm.bias_ih_l0   => lstm.input_gate.input_transform.bias
                  => lstm.forget_gate.input_transform.bias
                  => lstm.cell_gate.input_transform.bias
                  => lstm.output_gate.input_transform.bias

lstm.weight_hh_l0 => lstm.input_gate.hidden_transform.weight
                  => lstm.forget_gate.hidden_transform.weight
                  => lstm.cell_gate.hidden_transform.weight
                  => lstm.output_gate.hidden_transform.weight

lstm.bias_hh_l0   => lstm.input_gate.hidden_transform.bias
                  => lstm.forget_gate.hidden_transform.bias
                  => lstm.cell_gate.hidden_transform.bias
                  => lstm.output_gate.hidden_transform.bias
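For example, here is a rough, untested sketch of how one of the concatenated tensors could be chunked. It assumes Burn's Tensor::chunk is available, that lstm.weight_ih_l0 has shape [4 * d_hidden, d_input] with the gates concatenated along dim 0 in PyTorch's (input, forget, cell, output) order, and that Burn's Linear stores its weight as [d_input, d_output]:

use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    tensor::Tensor,
};

fn main() {
    let device = NdArrayDevice::Cpu;
    let (d_input, d_hidden) = (10, 20);

    // Stand-in for the tensor stored under lstm.weight_ih_l0:
    // shape [4 * d_hidden, d_input], gates concatenated along dim 0
    // in PyTorch's (input, forget, cell, output) order.
    let weight_ih_l0: Tensor<NdArray, 2> = Tensor::zeros([4 * d_hidden, d_input], &device);

    // One chunk per gate; transposed because Burn's Linear stores its weight
    // as [d_input, d_output] while PyTorch stores [d_output, d_input].
    let chunks = weight_ih_l0.chunk(4, 0);
    let input_w = chunks[0].clone().transpose(); // -> lstm.input_gate.input_transform.weight
    let forget_w = chunks[1].clone().transpose(); // -> lstm.forget_gate.input_transform.weight
    let cell_w = chunks[2].clone().transpose(); // -> lstm.cell_gate.input_transform.weight
    let output_w = chunks[3].clone().transpose(); // -> lstm.output_gate.input_transform.weight

    println!("{:?}", input_w.dims()); // [10, 20]
    let _ = (forget_w, cell_w, output_w);
}

The same chunking applies to lstm.weight_hh_l0 and to the two bias vectors (without the transpose for the biases).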

@antimora
Collaborator

Here is what I would recommend to unblock you quickly, since it will be a while before we handle this use case.

Create a corresponding module for the LSTM that matches the source PyTorch attributes, so you can load the source PyTorch record without issues. I recommend the following module structure to match the source:

#[derive(Module, Debug)]
pub struct LstmIntermediate<B: Backend> {
    pub weight_ih_l0: Param<Tensor<B, 2>>,
    pub weight_hh_l0: Param<Tensor<B, 2>>,
    pub bias_ih_l0: Param<Tensor<B, 1>>,
    pub bias_hh_l0: Param<Tensor<B, 1>>,
}

impl<B: Backend> LstmIntermediate<B> {
    fn into_lstm(self) -> Lstm<B> {
        ...
    }
}

Once you have it loaded, you can extract the data and chunk it up. You can put your logic in into_lstm.

After you have made the transformation, just save in Burn's record format so you can load it without the transformation next time.
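An untested sketch of what into_lstm could look like, assuming the Lstm and GateController fields (input_gate.input_transform, ...) are publicly writable, that Tensor::chunk is available, and that Param::from_tensor is the Param constructor in your Burn version (older versions may use Param::from instead):

use burn::{
    module::{Module, Param},
    nn::{Lstm, LstmConfig},
    tensor::{backend::Backend, Tensor},
};

// Mirrors the PyTorch parameter names so PyTorchFileRecorder can load the
// "lstm.*" entries of example.pt as-is.
#[derive(Module, Debug)]
pub struct LstmIntermediate<B: Backend> {
    pub weight_ih_l0: Param<Tensor<B, 2>>, // [4 * d_hidden, d_input]
    pub weight_hh_l0: Param<Tensor<B, 2>>, // [4 * d_hidden, d_hidden]
    pub bias_ih_l0: Param<Tensor<B, 1>>,   // [4 * d_hidden]
    pub bias_hh_l0: Param<Tensor<B, 1>>,   // [4 * d_hidden]
}

impl<B: Backend> LstmIntermediate<B> {
    fn into_lstm(self) -> Lstm<B> {
        let device = self.weight_ih_l0.val().device();
        let d_input = self.weight_ih_l0.val().dims()[1];
        let d_hidden = self.weight_hh_l0.val().dims()[1];

        // PyTorch concatenates the gates along dim 0 in (input, forget, cell, output) order.
        let w_ih = self.weight_ih_l0.val().chunk(4, 0);
        let w_hh = self.weight_hh_l0.val().chunk(4, 0);
        let b_ih = self.bias_ih_l0.val().chunk(4, 0);
        let b_hh = self.bias_hh_l0.val().chunk(4, 0);

        let mut lstm = LstmConfig::new(d_input, d_hidden, true).init(&device);
        let gates = [
            &mut lstm.input_gate,
            &mut lstm.forget_gate,
            &mut lstm.cell_gate,
            &mut lstm.output_gate,
        ];
        for (i, gate) in gates.into_iter().enumerate() {
            // Burn's Linear keeps its weight as [d_input, d_output], hence the transpose.
            gate.input_transform.weight = Param::from_tensor(w_ih[i].clone().transpose());
            gate.input_transform.bias = Some(Param::from_tensor(b_ih[i].clone()));
            gate.hidden_transform.weight = Param::from_tensor(w_hh[i].clone().transpose());
            gate.hidden_transform.bias = Some(Param::from_tensor(b_hh[i].clone()));
        }
        lstm
    }
}

Once into_lstm has produced a proper Lstm<B>, it can be placed into the real Model and saved once with a Burn FileRecorder (e.g. NamedMpkFileRecorder), so subsequent runs load the Burn-native record directly and skip the conversion.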

@antimora antimora added the "enhancement (Enhance existing features)" label Apr 15, 2024
@antimora antimora changed the title from "provide a method to chunk the weights before importing them into Burn" to "Provide a method to chunk the weights before importing them into Burn" Apr 15, 2024
@antimora antimora changed the title from "Provide a method to chunk the weights before importing them into Burn" to "Provide a method to chunk the weights before importing using PyTorchFileRecorder" Apr 15, 2024