Divergence of accuracy in mobilenetv2_12_int8.onnx #609

Open
Johansmm opened this issue May 11, 2023 · 0 comments
Johansmm commented May 11, 2023

Bug Report

Which model does this pertain to?

mobilenetv2-12-int8.onnx

Describe the bug

I am trying to reproduce the reported accuracy of mobilenetv2_12_int8.onnx, using PyTorch to read the ImageNet dataset, ONNX Runtime to run the model, and torchmetrics to compute accuracy. However, the only model for which I see a significant accuracy drop is mobilenetv2_12_int8.onnx, reaching 64.346 % (vs. 68.30 % reported in the table at https://github.com/onnx/models/tree/main/vision/classification/mobilenet#model).
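As an additional sanity check (just a sketch, assuming the float mobilenetv2-12.onnx from the same table is available locally and that real preprocessed images replace the placeholder batch), the quantized and float models can be run on the same input to see how often their top-1 predictions agree:

import numpy as np
import onnxruntime

# Hypothetical local paths; both files come from the model zoo table linked above
FLOAT_MODEL = "mobilenetv2-12.onnx"
INT8_MODEL = "mobilenetv2-12-int8.onnx"


def top1(session, batch):
    # batch: float32 array of shape (N, 3, 224, 224), already normalized
    input_name = session.get_inputs()[0].name
    logits = session.run(None, {input_name: batch})[0]
    return logits.argmax(axis=-1)


float_sess = onnxruntime.InferenceSession(FLOAT_MODEL)
int8_sess = onnxruntime.InferenceSession(INT8_MODEL)

# Placeholder batch; replace with real preprocessed ImageNet images
batch = np.random.rand(8, 3, 224, 224).astype(np.float32)
agreement = (top1(float_sess, batch) == top1(int8_sess, batch)).mean()
print(f"Top-1 agreement on this batch: {agreement:.2%}")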

Reproduction instructions

System Information

OS Platform and Distribution: Linux Ubuntu 20.04.4 LTS
ONNX version: 1.13.1
Backend/Runtime version: ONNX Runtime 1.14.1, PyTorch 2.0.0
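For completeness, a minimal snippet (assuming onnx, onnxruntime and torch are all importable in the same environment) to confirm the versions listed above:

import onnx
import onnxruntime
import torch

# Report the exact package versions used for this issue
print("onnx:", onnx.__version__)
print("onnxruntime:", onnxruntime.__version__)
print("torch:", torch.__version__)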

Provide a code snippet to reproduce your errors.

import os
import io
import tarfile
from PIL import Image
from tqdm import tqdm

import torch
from torchvision import transforms as T
import torchmetrics

import onnxruntime

_TORCH_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


class ImagenetValDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, transform=None):
        images_path = os.path.join(img_dir, 'ILSVRC2012_img_val')
        try:
            self._tf = images_path + '.tar'
            with tarfile.open(self._tf) as tf:
                self._img_names = tf.getnames()
        except Exception as e:
            raise ValueError(f"{img_dir} does not contain 'ILSVRC2012_img_val.tar' "
                             "or the file is corrupted.") from e
        self._img_names = sorted(self._img_names)

        # Read labels
        self._labels = []
        with open(os.path.join(img_dir, 'imagenet_2012_validation_synset_labels.txt')) as f:
            while label := f.readline():
                self._labels.append(label.strip())
        self._label_names = sorted(set(self._labels))
        assert len(self._img_names) == len(self._labels), "Incomplete labels!"

        self.transform = transform

    def _get_image(self, name):
        image = self._tf.extractfile(name)
        image = io.BytesIO(image.read())
        image = Image.open(image).convert('RGB')
        return image

    def __len__(self):
        return len(self._img_names)

    def __getitem__(self, index):
        # Lazily open the tar file so each DataLoader worker gets its own handle (opened only once)
        if isinstance(self._tf, str):
            self._tf = tarfile.open(self._tf)

        # Read image from tar file
        image = self._get_image(self._img_names[index])

        # Apply transformation
        if self.transform is not None:
            image = self.transform(image)

        # Return the image with its label
        label = self._labels[index]
        return image, self._label_names.index(label)

class OnnxInferencePipeline:
    def __init__(self, onnx_path):
        self._ort_session = onnxruntime.InferenceSession(onnx_path)

    @property
    def inputs(self):
        return self._ort_session.get_inputs()[0]

    @staticmethod
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def __call__(self, inputs: torch.Tensor):
        # Generate ort inputs
        ort_inputs = {self.inputs.name: self.to_numpy(inputs)}

        # Run inputs in graph
        ort_outputs = self._ort_session.run(None, ort_inputs)
        return torch.from_numpy(ort_outputs[0]).to(inputs.device)


def get_imagenet_dataset(data_path, batch_size=128, image_size=224, num_workers=0):
    transform = T.Compose([T.Resize(int(image_size * 1.1429)),
                           T.CenterCrop(image_size),
                           T.ToTensor(),
                           T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    imagenet_data = ImagenetValDataset(data_path, transform=transform)
    return torch.utils.data.DataLoader(imagenet_data,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=num_workers)


def evaluate_model(model, dataset):
    print("Starting evaluation...")
    num_classes = len(dataset.dataset._label_names)
    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    for images, gt_labels in (barprog := tqdm(dataset)):
        images = images.to(_TORCH_DEVICE)
        pred_labels = model(images).argmax(-1)
        acc_step = accuracy(pred_labels.cpu(), gt_labels)
        barprog.set_postfix({'acc': acc_step.item()})
    print(f"[INFO] Accuracy: {accuracy.compute()}")


if __name__ == "__main__":
    model_path = "mobilenetv2-12-int8.onnx"
    imagenet_path = "/imagenet/dataset"
    model = OnnxInferencePipeline(model_path)

    # Read dataset
    val_dataset = get_imagenet_dataset(imagenet_path, num_workers=8)

    # Process
    evaluate_model(model, val_dataset)

Notes

ImagenetValDataset needs the ordered list of validation labels (imagenet_2012_validation_synset_labels.txt) to work. If you need it, I can provide it.
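For reference, a small sanity check on that file (a sketch, assuming the file name and location used in the script above): it should contain one synset ID per validation image, in the order of the sorted image names.

# Quick sanity check of the label file (path assumed as in the script above)
with open("imagenet_2012_validation_synset_labels.txt") as f:
    labels = [line.strip() for line in f if line.strip()]

print(len(labels))       # expected: 50000 (one label per validation image)
print(len(set(labels)))  # expected: 1000 distinct synset IDs
print(labels[0])         # a WordNet synset ID, e.g. 'n01751748'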

Johansmm added the bug label May 11, 2023