Provide a method to chunk the weights before importing using PyTorchFileRecorder #1632
Comments
As a reference, the plan is not to support everything such that it can be a 1-to-1 map with PyTorch. But for basic supported layers such as LSTM, it's a bit different. This might be something we need to add.
If you could provide a minimal code example, we will look into this. The PyTorch screenshot appears to store the weights as an array, which is supported by PyTorchFileRecorder.
Sure. Rust code:

use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    module::Module,
    nn::{Linear, LinearConfig, Lstm, LstmConfig},
    record::{FullPrecisionSettings, Recorder},
    tensor::{backend::Backend, Tensor},
};
use burn_import::pytorch::PyTorchFileRecorder;

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            lstm: LstmConfig::new(10, 20, true).init(device),
            linear: LinearConfig::new(20, 30).init(device),
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let (_, out) = self.lstm.forward(input, None);
        self.linear.forward(out)
    }
}

fn main() {
    type Backend = NdArray;
    let device = NdArrayDevice::Cpu;
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("./example.pt".into(), &device)
        .expect("Should decode state successfully");
    let model: Model<Backend> = Model::new(&device).load_record(record);
}

Python code:

import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(10, 20)
        self.linear = nn.Linear(20, 30)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(input)
        return self.linear(out)


if __name__ == "__main__":
    model = Model()
    input = torch.randn(2, 8, 10)
    res = model(input)
    print(res.shape)
    print([name for name, _ in model.named_parameters()])
    torch.save(model.state_dict(), "./example.pt")
@antimora @laggui Therefore, importing an LSTM's weights from PyTorch requires a remap like this:

lstm.weight_ih_l0
    => lstm.input_gate.input_transform.weight
    => lstm.forget_gate.input_transform.weight
    => lstm.cell_gate.input_transform.weight
    => lstm.output_gate.input_transform.weight
lstm.bias_ih_l0
    => lstm.input_gate.input_transform.bias
    => lstm.forget_gate.input_transform.bias
    => lstm.cell_gate.input_transform.bias
    => lstm.output_gate.input_transform.bias
lstm.weight_hh_l0
    => lstm.input_gate.hidden_transform.weight
    => lstm.forget_gate.hidden_transform.weight
    => lstm.cell_gate.hidden_transform.weight
    => lstm.output_gate.hidden_transform.weight
lstm.bias_hh_l0
    => lstm.input_gate.hidden_transform.bias
    => lstm.forget_gate.hidden_transform.bias
    => lstm.cell_gate.hidden_transform.bias
    => lstm.output_gate.hidden_transform.bias
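The remap above can be sketched in plain Python. This is only an illustration, not part of PyTorchFileRecorder: `split_lstm_state`, `GATES`, and `TARGETS` are hypothetical names, tensors are stood in for by nested lists, and the gate order (input, forget, cell, output) follows how PyTorch documents the stacking of `weight_ih_l[k]` along the first dimension.

```python
# Gate order PyTorch uses when stacking LSTM parameters along dim 0.
GATES = ("input_gate", "forget_gate", "cell_gate", "output_gate")

# PyTorch parameter suffix -> (target sub-module, target parameter),
# following the mapping listed in the thread.
TARGETS = {
    "weight_ih_l0": ("input_transform", "weight"),
    "bias_ih_l0": ("input_transform", "bias"),
    "weight_hh_l0": ("hidden_transform", "weight"),
    "bias_hh_l0": ("hidden_transform", "bias"),
}


def split_lstm_state(state, prefix="lstm"):
    """Split each stacked LSTM entry into four per-gate entries (hypothetical helper)."""
    out = {}
    for suffix, (transform, param) in TARGETS.items():
        stacked = state[f"{prefix}.{suffix}"]
        if len(stacked) % 4 != 0:
            raise ValueError(f"{suffix}: length {len(stacked)} is not a multiple of 4")
        hidden = len(stacked) // 4
        for i, gate in enumerate(GATES):
            key = f"{prefix}.{gate}.{transform}.{param}"
            out[key] = stacked[i * hidden:(i + 1) * hidden]
    return out


# Toy state with hidden_size = 2 and input_size = 3; each value encodes its row index.
toy = {
    "lstm.weight_ih_l0": [[float(r)] * 3 for r in range(8)],
    "lstm.weight_hh_l0": [[float(r)] * 2 for r in range(8)],
    "lstm.bias_ih_l0": [float(r) for r in range(8)],
    "lstm.bias_hh_l0": [float(r) for r in range(8)],
}
split = split_lstm_state(toy)
```

With real tensors, the slicing step would correspond to something like `torch.chunk(tensor, 4, dim=0)` applied before saving the state dict.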
Here is what I would recommend to unblock you quickly. It will be a while before we handle this use case.

Create a corresponding module for the PyTorch LSTM type that matches the source attributes, so you can load the source PyTorch record without issues. I recommend the following module structure to match the source:

#[derive(Module, Debug)]
pub struct LstmIntermediate<B: Backend> {
    pub weight_ih_l0: Param<Tensor<B, 2>>,
    pub weight_hh_l0: Param<Tensor<B, 2>>,
    pub bias_ih_l0: Param<Tensor<B, 1>>,
    pub bias_hh_l0: Param<Tensor<B, 1>>,
}

impl<B: Backend> LstmIntermediate<B> {
    fn into_lstm(self) -> Lstm<B> {
        ...
    }
}

Once you have it loaded, you can extract the data and chunk it up. You can put your logic in a conversion method such as into_lstm above. After you have made the transformation, just save the result in Burn's record format so you can load it without transformation next time.
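As a rough illustration of the "extract and chunk" step, the sketch below splits a row-stacked weight into four gate blocks and transposes each one. The transpose rests on an assumption that should be checked against the Burn version in use: that Burn's Linear stores its weight as [d_input, d_output], whereas PyTorch stores each gate block as [hidden_size, input_size]. `chunk_rows` and `transpose` are hypothetical helpers working on nested lists rather than real tensors.

```python
def chunk_rows(stacked, parts=4):
    """Split a row-stacked matrix (list of rows) into `parts` equal blocks."""
    if len(stacked) % parts != 0:
        raise ValueError("row count must be divisible by the number of gates")
    size = len(stacked) // parts
    return [stacked[i * size:(i + 1) * size] for i in range(parts)]


def transpose(matrix):
    """Return the transpose of a list-of-rows matrix."""
    return [list(col) for col in zip(*matrix)]


# hidden_size = 2, input_size = 3: weight_ih_l0 has 4 * 2 = 8 rows of 3 values.
weight_ih_l0 = [[10 * r + c for c in range(3)] for r in range(8)]

# One [input_size, hidden_size] matrix per gate, in PyTorch's i, f, g, o order.
# ASSUMPTION: the target Linear layout is [d_input, d_output].
per_gate = [transpose(block) for block in chunk_rows(weight_ih_l0)]
```

Biases are one-dimensional, so they would only need the chunking step, not the transpose.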
Description
I encountered difficulties when importing the weights of an LSTM from PyTorch. PyTorch tends to concatenate different weights together, so the structure of LSTM weights in PyTorch differs significantly from the structure in Burn, where the weights of an LSTM are stored in eight Linear layers. Therefore, I think PyTorchFileRecorder should provide a method to chunk the weights before importing them into Burn.
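To make the structural difference concrete, here is a small shape-bookkeeping sketch for a single-layer LSTM. The function names are hypothetical, and the Burn side is described only by (fan-in, fan-out) pairs for the eight Linear layers mentioned above, not by Burn's actual storage layout. It checks that both layouts hold the same number of parameters.

```python
from math import prod

GATES = ("input_gate", "forget_gate", "cell_gate", "output_gate")


def pytorch_lstm_shapes(input_size, hidden_size):
    """Shapes of the four stacked tensors PyTorch stores for layer 0."""
    return {
        "weight_ih_l0": (4 * hidden_size, input_size),
        "weight_hh_l0": (4 * hidden_size, hidden_size),
        "bias_ih_l0": (4 * hidden_size,),
        "bias_hh_l0": (4 * hidden_size,),
    }


def per_gate_linears(input_size, hidden_size):
    """(fan_in, fan_out) of the eight per-gate Linear layers."""
    layers = {}
    for gate in GATES:
        layers[f"{gate}.input_transform"] = (input_size, hidden_size)
        layers[f"{gate}.hidden_transform"] = (hidden_size, hidden_size)
    return layers


def count_stacked(shapes):
    """Total number of scalars across the stacked tensors."""
    return sum(prod(shape) for shape in shapes.values())


def count_linears(layers):
    """Total number of scalars across the Linear layers (weight plus bias)."""
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in layers.values())


# Sizes from the example model in this thread: nn.LSTM(10, 20).
stacked = pytorch_lstm_shapes(10, 20)
linears = per_gate_linears(10, 20)
same_size = count_stacked(stacked) == count_linears(linears)
```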