Problem with input shape #213

Open
whocares0101 opened this issue Apr 18, 2024 · 0 comments

I'm trying to run SmilingWolf's wd-convnext-tagger-v3 ONNX model in PyTorch on CUDA, after converting it with onnx2torch.
Is this a bug, or am I doing something wrong?

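```python
import onnx
import onnx2torch
import torch
from PIL import Image
from torchvision import transforms

# https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3/blob/main/config.json
# input size is [3, 448, 448]

# ForceRGB, ResizeKeepAspectRatio, RGB2BGR and SquarePad are custom transforms
# (not part of torchvision); their definitions are not shown here.
image_transform = transforms.Compose([
    ForceRGB(),
    ResizeKeepAspectRatio(448),
    RGB2BGR(),
    SquarePad(255),
    transforms.ToTensor(),  # HWC image -> CHW float tensor, uint8 pixels scaled to [0, 1]
])

# https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3
fn          = 'wd14_tagger_model_onnx_v3/model.onnx'
onnx_model  = onnx.load(fn)
torch_model = onnx2torch.convert(onnx_model)
torch_model.cuda()
torch_model.eval()

image = Image.open('hor.jpg')
data  = image_transform(image)
print(data.size())  # torch.Size([3, 448, 448])

with torch.no_grad():
    result = torch_model(data.unsqueeze(0).cuda())  # input shape [1, 3, 448, 448]
# RuntimeError: Given groups=1, weight of size [128, 3, 4, 4], expected input[1, 448, 3, 448] to have 3 channels, but got 448 channels instead
```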
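The shapes in the error are consistent with the converted graph transposing its input from channels-last (NHWC) to channels-first (NCHW) before the first convolution: applying such a transpose to the `[1, 3, 448, 448]` tensor passed above yields exactly the `[1, 448, 3, 448]` in the message. The pure-Python sketch below reproduces that shape arithmetic under the assumption (not verified against `model.onnx`) that the graph uses `perm = (0, 3, 1, 2)`; `transpose_shape` and `NHWC_TO_NCHW` are hypothetical names used only for illustration.

```python
def transpose_shape(shape, perm):
    """Output shape of an ONNX-style Transpose: out[i] = shape[perm[i]]."""
    if sorted(perm) != list(range(len(shape))):
        raise ValueError(f"perm {perm} is not a permutation of {len(shape)} axes")
    return tuple(shape[p] for p in perm)


# Assumption: the graph starts with a channels-last -> channels-first
# transpose (NHWC -> NCHW), i.e. perm = (0, 3, 1, 2).
NHWC_TO_NCHW = (0, 3, 1, 2)

# What the script passes: data.unsqueeze(0) has shape [1, 3, 448, 448] (NCHW).
print(transpose_shape((1, 3, 448, 448), NHWC_TO_NCHW))  # (1, 448, 3, 448)

# Channels-last input, e.g. data.permute(1, 2, 0).unsqueeze(0) -> [1, 448, 448, 3].
print(transpose_shape((1, 448, 448, 3), NHWC_TO_NCHW))  # (1, 3, 448, 448)
```

If that assumption holds, passing `data.permute(1, 2, 0).unsqueeze(0).cuda()` instead of `data.unsqueeze(0).cuda()` should give the first convolution its 3 channels. Separately, `transforms.ToTensor()` rescales pixels to [0, 1]; whether this model expects that range or raw 0–255 values is worth checking, since a wrong range would not raise an error but could silently degrade the predicted tags.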

Environment: Python 3.10.11 with the following packages installed:

```
absl-py                      2.1.0
astunparse                   1.6.3
cachetools                   5.3.3
certifi                      2024.2.2
charset-normalizer           3.3.2
colorama                     0.4.6
filelock                     3.13.4
flatbuffers                  24.3.25
fsspec                       2024.3.1
gast                         0.4.0
google-auth                  2.29.0
google-auth-oauthlib         0.4.6
google-pasta                 0.2.0
grpcio                       1.62.1
h5py                         3.10.0
huggingface-hub              0.22.2
idna                         3.6
Jinja2                       3.1.3
keras                        2.10.0
Keras-Preprocessing          1.1.2
libclang                     18.1.1
Markdown                     3.6
MarkupSafe                   2.1.5
mpmath                       1.3.0
networkx                     3.3
numpy                        1.26.4
oauthlib                     3.2.2
onnx                         1.16.0
onnx2torch                   1.5.14
opencv-python                4.9.0.80
opt-einsum                   3.3.0
packaging                    24.0
pillow                       10.3.0
pip                          23.0.1
protobuf                     5.26.1
pyasn1                       0.6.0
pyasn1_modules               0.4.0
PyYAML                       6.0.1
requests                     2.31.0
requests-oauthlib            2.0.0
rsa                          4.9
safetensors                  0.4.3
setuptools                   65.5.0
six                          1.16.0
sympy                        1.12
tensorboard                  2.10.1
tensorboard-data-server      0.6.1
tensorboard-plugin-wit       1.8.1
tensorflow                   2.10.1
tensorflow-estimator         2.10.0
tensorflow-io-gcs-filesystem 0.31.0
termcolor                    2.4.0
torch                        2.2.2+cu118
torchaudio                   2.2.2+cu118
torchvision                  0.17.2+cu118
tqdm                         4.66.2
typing_extensions            4.11.0
urllib3                      2.2.1
Werkzeug                     3.0.2
wheel                        0.43.0
wrapt                        1.16.0
```