I am trying to reproduce the reported accuracy of mobilenetv2-12-int8.onnx, using PyTorch to read the ImageNet dataset, ONNX Runtime to run the model, and torchmetrics to compute accuracy. However, this is the only model for which I see a significant accuracy drop: it reaches 64.346 % (vs. the 68.30 % reported in the table at https://github.com/onnx/models/tree/main/vision/classification/mobilenet#model).
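To make the comparison concrete, here is a minimal sketch of the metric in question, assuming the table reports top-1 accuracy over the ImageNet validation set. The helper `top1_accuracy` and the toy data are hypothetical and are not part of the reproduction script; only the two percentages come from this report.

```python
def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("logits and labels must be non-empty and equally long")
    correct = 0
    for scores, label in zip(logits, labels):
        # argmax over the class scores of one sample
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Toy data (hypothetical): 4 samples, 3 classes
toy_logits = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
toy_labels = [1, 0, 1, 0]
toy_acc = top1_accuracy(toy_logits, toy_labels)

# Gap between the reported and the measured accuracy, in percentage points
reported, measured = 68.30, 64.346
gap_points = reported - measured
print(f"toy accuracy: {toy_acc:.2f}, gap: {gap_points:.3f} points")
```

The gap works out to roughly 3.95 percentage points, which is larger than I would expect from small preprocessing differences alone.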
Reproduction instructions
System Information
OS Platform and Distribution: Linux Ubuntu 20.04.4 LTS
ONNX version: 1.13.1
Backend/Runtime version: ONNX Runtime 1.14.1, PyTorch 2.0.0
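For anyone comparing environments, the versions above can be collected with a short script. This is a sketch using only the standard library; the function name `collect_versions` is hypothetical, and the distribution names in the default list are assumptions (for example, a GPU build of ONNX Runtime is typically installed as `onnxruntime-gpu` rather than `onnxruntime`).

```python
import platform
from importlib import metadata


def collect_versions(packages=("onnx", "onnxruntime", "onnxruntime-gpu",
                               "torch", "torchvision", "torchmetrics")):
    """Return OS, Python, and installed package versions as a dict of strings."""
    info = {"os": platform.platform(), "python": platform.python_version()}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed under this distribution name
            info[name] = "not installed"
    return info


versions = collect_versions()
for key, value in versions.items():
    print(f"{key}: {value}")
```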
Provide a code snippet to reproduce your errors.
import os
import io
import tarfile

from PIL import Image
from tqdm import tqdm
import torch
from torchvision import transforms as T
import torchmetrics
import onnxruntime

_TORCH_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


class ImagenetValDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, transform=None):
        images_path = os.path.join(img_dir, 'ILSVRC2012_img_val')
        try:
            self._tf = images_path + '.tar'
            with tarfile.open(self._tf) as tf:
                self._img_names = tf.getnames()
        except Exception as e:
            raise ValueError(f"{img_dir} have not 'ILSVRC2012_img_val.tar' "
                             "file or it is corrupted.") from e
        self._img_names = sorted(self._img_names)

        # Read labels
        self._labels = []
        with open(os.path.join(img_dir, 'imagenet_2012_validation_synset_labels.txt')) as f:
            while label := f.readline():
                self._labels.append(label.strip())
        self._label_names = sorted(set(self._labels))
        assert len(self._img_names) == len(self._labels), "Incomplete labels!"
        self.transform = transform

    def _get_image(self, name):
        image = self._tf.extractfile(name)
        image = io.BytesIO(image.read())
        image = Image.open(image).convert('RGB')
        return image

    def __len__(self):
        return len(self._img_names)

    def __getitem__(self, index):
        # Read tar file here to proper parallelization (just one time)
        if isinstance(self._tf, str):
            self._tf = tarfile.open(self._tf)
        # Read image from tar file
        image = self._get_image(self._img_names[index])
        # Apply transformation
        if self.transform is not None:
            image = self.transform(image)
        # Return image with his label
        label = self._labels[index]
        return image, self._label_names.index(label)


class OnnxInferencePipeline:
    def __init__(self, onnx_path):
        self._ort_session = onnxruntime.InferenceSession(onnx_path)

    @property
    def inputs(self):
        return self._ort_session.get_inputs()[0]

    @staticmethod
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def __call__(self, inputs: torch.Tensor):
        # Generate ort inputs
        ort_inputs = {self.inputs.name: self.to_numpy(inputs)}
        # Run inputs in graph
        ort_outputs = self._ort_session.run(None, ort_inputs)
        return torch.from_numpy(ort_outputs[0]).to(inputs.device)


def get_imagenet_dataset(data_path, batch_size=128, image_size=224, num_workers=0):
    transform = T.Compose([T.Resize(int(image_size * 1.1429)),
                           T.CenterCrop(image_size),
                           T.ToTensor(),
                           T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    imagenet_data = ImagenetValDataset(data_path, transform=transform)
    return torch.utils.data.DataLoader(imagenet_data,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=num_workers)


def evaluate_model(model, dataset):
    print("Starting evaluation...")
    num_classes = len(dataset.dataset._label_names)
    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
    for images, gt_labels in (barprog := tqdm(dataset)):
        images = images.to(_TORCH_DEVICE)
        pred_labels = model(images).argmax(-1)
        acc_step = accuracy(pred_labels.cpu(), gt_labels)
        barprog.set_postfix({'acc': acc_step.item()})
    print(f"[INFO] Accuracy: {accuracy.compute()}")


if __name__ == "__main__":
    model_path = "mobilenetv2-12-int8.onnx"
    imagenet_path = "/imagenet/dataset"
    model = OnnxInferencePipeline(model_path)
    # Read dataset
    val_dataset = get_imagenet_dataset(imagenet_path, num_workers=8)
    # Process
    evaluate_model(model, val_dataset)
Notes
ImagenetValDataset needs the ordered list of validation labels (imagenet_2012_validation_synset_labels.txt, one synset label per validation image) to work. I can provide it if needed.
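As a rough illustration of how that labels file is used, the sketch below maps synset IDs to class indices by sorted order, which is what the dataset class does with `sorted(set(...))` and `.index(...)`. The helper `build_label_index` and the four sample lines are hypothetical; the real file has one line per validation image. It also rests on the assumption that the model's output indices follow the sorted synset order.

```python
def build_label_index(synset_labels):
    """Map each synset ID to a class index according to sorted order."""
    label_names = sorted(set(synset_labels))
    return {name: idx for idx, name in enumerate(label_names)}


# Hypothetical excerpt of the labels file: one synset ID per validation image,
# in the same order as the sorted image file names
toy_lines = ["n01751748\n", "n09193705\n", "n01440764\n", "n01751748\n"]
toy_labels = [line.strip() for line in toy_lines]

label_to_index = build_label_index(toy_labels)
targets = [label_to_index[label] for label in toy_labels]
print(targets)
```

A dictionary lookup gives the same indices as `list.index` on the sorted label names, just without the linear scan per sample.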
Bug Report
Which model does this pertain to?
mobilenetv2-12-int8.onnx