
[Bug] input_type bfloat16 leads to wrong prediction results #16698

Open
jikechao opened this issue Mar 11, 2024 · 0 comments
Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

@jikechao (Contributor)

Actual behavior

[Screenshot of the actual output; image not preserved]

What actually happened

Environment

Any environment details, such as: Operating System, TVM version, etc

Steps to reproduce

import numpy as np
import tvm
import tvm.relay as relay
from tensorflow import keras
from tensorflow.keras import layers, models

# Random input in [0, 1); np.random.random returns float64
input_shape = (1, 12, 4, 6)
input_data = np.random.random(input_shape)

# Keras model: a single Dropout layer with rate=0 (an identity op at inference) on a bfloat16 input
x = layers.Input(shape=input_shape[1:], dtype='bfloat16')
layer = keras.layers.Dropout(rate=0, seed=[1, 1])
y = layer(x)
model = models.Model(x, y)
res_keras = model(input_data)

# Convert to Relay; 'input_1' is the default name Keras gives the first Input layer
shape_dict = {'input_1': input_shape}
mod, params = relay.frontend.from_keras(model, shape_dict)

# Compile and run with the Relay VM on CPU
with tvm.transform.PassContext(opt_level=3):
    tvm_func = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()

# Feed the same data to TVM, cast to bfloat16
res_tvm = tvm_func(tvm.nd.array(input_data.astype('bfloat16'))).numpy()

print('keras infer result:', res_keras)
print('tvm infer result:', res_tvm)
np.testing.assert_allclose(res_keras, res_tvm, atol=1e-3, rtol=1e-3)

Notice that with "float32" instead of "bfloat16" as the input type, TVM produces the correct inference result.
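When judging a bfloat16 mismatch, it can help to know how much ordinary rounding alone accounts for. bfloat16 keeps the top 16 bits of an IEEE float32 (1 sign, 8 exponent, 7 explicit mantissa bits), so its precision is far lower than float32. The following is a minimal plain-NumPy sketch (it does not use TVM or TensorFlow); `bfloat16_round_trip` is a hypothetical helper that assumes round-to-nearest-even and ignores NaN/Inf handling. It reuses the repro's input shape.

```python
import numpy as np


def bfloat16_round_trip(x):
    """Round float32 values to the nearest bfloat16 (ties to even), returned as float32.

    Sketch only: NaN/Inf inputs and overflow near float32 max are not handled.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Lowest kept bit decides the tie-breaking direction (round half to even)
    lsb = (bits >> 16) & np.uint32(1)
    # Add the rounding bias, then zero out the 16 discarded low bits
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


rng = np.random.default_rng(0)
x = rng.random((1, 12, 4, 6)).astype(np.float32)  # same shape as the repro input
xb = bfloat16_round_trip(x)
err = np.abs(xb - x)
# For values in [0.5, 1) the bfloat16 spacing is 2**-8, so round-off can reach 2**-9 (about 0.00195)
print("max abs rounding error:", err.max())
```

With inputs in [0, 1), the round-off alone can exceed the `atol=1e-3` used in the reproduction, so a bfloat16 result compared against a float32 reference may fail even when nothing is wrong. In this repro, however, both outputs come from the same bfloat16-cast input through an identity Dropout (`rate=0`), so they would be expected to agree closely; a mismatch much larger than about 2**-9 suggests a problem beyond ordinary rounding.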

jikechao added the needs-triage and type: bug labels on Mar 11, 2024.