I think there is a problem with grouped convolution when using distributed training.
A simple convolution with distributed training works fine, and a grouped convolution on a single device also works fine.
The log I got:
Traceback (most recent call last):
File "/WavLMJax/distributed_grouped_conv_failure.py", line 86, in
train_step(logical_initialized_state, jax.device_put(batch, mesh_sharding(PartitionSpec('data', None, None))), model)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [hlo verifier]: Expected instruction to have shape equal to f32[128,1,384], actual shape is f32[128,48,384]:
%multiply.33 = f32[128,48,384]{2,1,0} multiply(f32[128,1,384]{2,1,0} %dynamic-slice.6, f32[128,1,384]{2,1,0} %dynamic-slice.6), metadata={op_name="jit(train_step)/jit(main)/mul" source_file="/usr/local/lib/python3.10/dist-packages/optax/_src/transform.py" source_line=98}
Failed after pipeline-start
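For context, the single-device case that works can be sketched directly with `jax.lax.conv_general_dilated` and `feature_group_count`; the shapes below are illustrative (384 channels in 48 groups, matching the shapes in the error), not copied from the Colab:

```python
import jax
import jax.numpy as jnp

# Grouped 1D convolution: 384 channels split into 48 groups of 8.
batch, length, channels, groups = 4, 100, 384, 48
x = jnp.ones((batch, channels, length))  # NCW layout (default for conv_general_dilated)
# Kernel layout is OIW: (out_channels, in_channels // groups, kernel_width)
kernel = jnp.ones((channels, channels // groups, 3))

y = jax.lax.conv_general_dilated(
    x, kernel,
    window_strides=(1,),
    padding="SAME",
    feature_group_count=groups,
)
print(y.shape)  # (4, 384, 100)
```

The failure only appears once the same computation runs under a sharded mesh, which is why the HLO verifier error above (a `[128,48,384]` result where `[128,1,384]` was expected) points at a partitioning issue rather than the convolution itself.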
Code:
https://colab.research.google.com/drive/117FrrCLar8TVcXncT8kUsZEykallgEqX?usp=sharing (distributed grouped conv)
https://colab.research.google.com/drive/1xmvMAfz4NzNmp7EAxF8EIysGYAsP_jCV?usp=sharing (single device grouped conv)