
Err` value: IrError(OutputNodeNotFound("/linear/MatMul_output_0")) on linear model #191

Open
maxwellflitton opened this issue Oct 3, 2023 · 8 comments

@maxwellflitton

Describe the bug
I've trained a simple linear model in PyTorch and exported it to ONNX. Calling it from the ONNX library works fine. However, when trying to run it from wonnx I get the error Err value: IrError(OutputNodeNotFound("/linear/MatMul_output_0")). Looking at the model in Netron everything seems to make sense, and in my session config I set the output to the name 5 since that is the graph's output. I don't know why wonnx errors here when onnx works fine.

To Reproduce
Train a simple linear model with the following code:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

squarefoot = np.array([1000, 1200, 1500, 1800, 2000, 2200, 2500, 2800, 3000, 3200], dtype=np.float32)
num_floors = np.array([1, 1, 1.5, 1.5, 2, 2, 2.5, 2.5, 3, 3], dtype=np.float32)
house_price = np.array([200000, 230000, 280000, 320000, 350000, 380000, 420000, 470000, 500000, 520000], dtype=np.float32)

squarefoot_mean = squarefoot.mean()
squarefoot_std = squarefoot.std()
num_floors_mean = num_floors.mean()
num_floors_std = num_floors.std()
house_price_mean = house_price.mean()
house_price_std = house_price.std()

# Normalize the data (optional, but recommended for better convergence)
squarefoot = (squarefoot - squarefoot.mean()) / squarefoot.std()
num_floors = (num_floors - num_floors.mean()) / num_floors.std()
house_price = (house_price - house_price.mean()) / house_price.std()

# Convert numpy arrays to PyTorch tensors
squarefoot_tensor = torch.from_numpy(squarefoot)
num_floors_tensor = torch.from_numpy(num_floors)
house_price_tensor = torch.from_numpy(house_price)


# Define the linear regression model
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(2, 1)  # 2 input features, 1 output

    def forward(self, x):
        return self.linear(x)

# Initialize the model
model = LinearRegressionModel()

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Stack the two features into a single (N, 2) input tensor
X = torch.stack([squarefoot_tensor, num_floors_tensor], dim=1)

# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    # Forward pass
    y_pred = model(X)

    # Compute the loss
    loss = criterion(y_pred.squeeze(), house_price_tensor)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the progress
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# A single test sample with two features (square footage, number of floors)
test_inputs = torch.tensor([2800.0, 3.0], dtype=torch.float32)

# Test the model
with torch.no_grad():
    predicted_prices = model(test_inputs)
    predicted_prices = predicted_prices.squeeze().numpy()
    print("Predicted Prices:", predicted_prices)

I then perform an ONNX export with the following code:

# export to ONNX and save file
torch.onnx.export(model, test_inputs, "./linear_test.onnx")
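
For reference, here is a minimal sketch (not from the original report) of the same export with explicit tensor names, which avoids relying on auto-generated names like onnx::MatMul_0 and 5; the names and opset below are just illustrative:

# Hypothetical variant: give the graph stable input/output names at export time
torch.onnx.export(
    model,
    test_inputs,
    "./linear_test.onnx",
    input_names=["input"],    # illustrative names, not the ones in the issue
    output_names=["output"],
    opset_version=13,
)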

I then load the model in rust with the following code:

use std::collections::HashMap;
use ndarray::{ArrayD, CowArray};
use std::sync::Arc;
use wonnx::Session;
use wonnx::utils::{InputTensor, OutputTensor, tensor};
use wonnx::SessionConfig;

use std::fs::File;
use std::io::{Read, Result};

pub async fn load_model() {
    let mut file = File::open("./linear_test.onnx").unwrap();

    let mut buffer = Vec::new();

    file.read_to_end(&mut buffer).unwrap();
    let config = SessionConfig::new().with_outputs(Some(vec!["5".to_string()]));
    let session = Session::from_bytes_with_config(&buffer, &config).await.unwrap();
    let mut inputs = HashMap::new();
    inputs.insert("onnx::MatMul_0".to_string(), InputTensor::F32(vec![1000.0, 2.0].into()));
    let outputs = session.run(&inputs).await.unwrap();
    println!("file: {:?}", outputs);
}

and I get the error with the following line:

let session = Session::from_bytes_with_config(&buffer, &config).await.unwrap();
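
To double-check which names the exported graph actually uses (the input onnx::MatMul_0, the output 5, and whether intermediates such as /linear/MatMul_output_0 carry shape information), here is a small sketch with the onnx Python package, assuming the file from the export step above:

# Inspect graph inputs, outputs, and intermediate value_info entries
import onnx

m = onnx.load("./linear_test.onnx")
print("inputs: ", [i.name for i in m.graph.input])
print("outputs:", [o.name for o in m.graph.output])
print("value_info:", [v.name for v in m.graph.value_info])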

Expected behavior
For the model to load and run a simple inference without error.

Screenshots
When inspecting the ONNX file, all the weights seem to match up. Or am I missing something here?

(Four screenshots, taken 2023-10-03, of the exported ONNX graph and its weights as shown in the model viewer.)

Desktop

  • OS: macOS Ventura 13.4
  • Chip: Apple M2 Max
  • RAM: 96 GB
  • Hard drive: 3.62 TB available of 4 TB
  • Model: 16-inch, 2023
@maxwellflitton (Author)

Is there any update on this? Can anyone help? I've tried it again and am still getting the same error.

@pixelspark (Collaborator)

@maxwellflitton the error indicates an output from an earlier node cannot be found in the ONNX file. Can you share your ONNX file?

@notdanilo

I have a similar issue. Here is the ONNX file: elections.zip (based on https://huggingface.github.io/candle/training/simplified.html).

In my case, it's IrError(OutputNodeNotFound("/ln1/Gemm_output_0")).

@notdanilo

I did some debugging. The error is coming from here:

wonnx/wonnx/src/ir.rs

Lines 24 to 26 in 7880ed8

if !value_shapes.contains_key(output_name.as_str()) {
    return Err(IrError::OutputNodeNotFound(output_name.to_string()));
}

The shapes are acquired here:

wonnx/wonnx/src/ir.rs

Lines 198 to 208 in 7880ed8

let mut value_shapes: HashMap<&'model str, Shape> = HashMap::new();
for vi in model.get_graph().get_value_info() {
    value_shapes.insert(vi.get_name(), vi.get_shape()?);
}
for vi in model.get_graph().get_output() {
    let output_name = vi.get_name();
    if !output_name.is_empty() {
        value_shapes.insert(output_name, vi.get_shape()?);
    }
}

But I am assuming our ONNX files don't have shape definitions for these intermediate outputs because they need to be inferred. Is that assumption correct? I don't know the ONNX standard well.

@notdanilo

The ONNX file needs to be pre-processed to infer the shapes and save them back into the file.
@maxwellflitton https://github.com/webonnx/wonnx?tab=readme-ov-file#shape-inference
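
For example, a minimal sketch of that pre-processing in Python with the onnx package (file names are placeholders), as an alternative to nnx prepare:

import onnx
from onnx import shape_inference

model = onnx.load("./linear_test.onnx")
# Annotate intermediate values (e.g. /linear/MatMul_output_0) with inferred shapes
inferred = shape_inference.infer_shapes(model)
onnx.save(inferred, "./linear_test_inferred.onnx")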

@Dainerx commented Mar 12, 2024

@notdanilo I am facing the same issue: I get SessionError(IrError(OutputNodeNotFound("onnx::Reshape_1988"))) with a model exported using PyTorch 1.12.

When I try to infer shapes and save them back to a file with

nnx prepare my_model.onnx my_model_prepared.onnx --set batch_size=1 --set sequence_length=255 -i

I get:

Error: Could not infer shapes: unsupported: Split

Any ideas why this is happening?

@notdanilo

@Dainerx

Split shape inference isn't supported in nnx.
https://github.com/webonnx/wonnx?tab=readme-ov-file#supported-operators-ref-onnx-ir
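
To see which operator types a model actually contains, and compare them against that table, here is a quick sketch, assuming your my_model.onnx from above:

# List the distinct ONNX op types used by the graph
import onnx

model = onnx.load("my_model.onnx")
print(sorted({node.op_type for node in model.graph.node}))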

Try this:

pip install onnx-simplifier
python -m onnxsim <input.onnx> <output.onnx>
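
If you prefer to keep it in a script, onnx-simplifier also exposes a Python API; a minimal sketch with placeholder file names:

# Simplify the model and check that the simplified graph still validates
import onnx
from onnxsim import simplify

model = onnx.load("my_model.onnx")
simplified, ok = simplify(model)
assert ok, "simplified model failed the check"
onnx.save(simplified, "my_model_simplified.onnx")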

@astnmsn commented Mar 27, 2024

Running into a similar issue with the Equal and Not operators after running the model through onnxsim. It looks like they are unsupported by nnx.

Any pointers on where to get started if I wanted to send a PR? I know where this lives in the wonnx codebase; I was just hoping there might be a reference that guided the implementation of the other ops.
