
Deformable DETR #81

Open

fduxvtech opened this issue Nov 4, 2021 · 3 comments

@fduxvtech

Hello,

I am planning to create a custom model using the mmdetection framework, and I am interested in using deformable attention; I have also written about it in the main project.

Are you planning to implement the deformable attention plugin, or would you be available to discuss my implementation (it might become a pull request for the project)? I only need some clarification on the enqueue function and on how TensorRT handles the batch size (two different versions of enqueue are available: one with an explicit batch size and one with tensor descriptions).

@grimoire (Owner) commented Nov 5, 2021

If your custom ops are derived from IPluginV2DynamicExt, you can use

int32_t enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

In this method, the batch size can be read as inputDesc[0].dims.d[0] (or from whatever axis carries it, depending on how you define the computation).
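For illustration, a minimal sketch of that lookup (MyPlugin is a placeholder name for a class derived from IPluginV2DynamicExt; only the batch-size handling is shown):

#include <NvInfer.h>

int32_t MyPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                          const nvinfer1::PluginTensorDesc* outputDesc,
                          const void* const* inputs, void* const* outputs,
                          void* workspace, cudaStream_t stream) noexcept {
    // The descriptors carry the fully resolved runtime shapes, so the
    // batch size is read off whichever axis you assigned it to.
    const int32_t batchSize = inputDesc[0].dims.d[0];
    // ... launch the CUDA kernel with batchSize ...
    return 0;
}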

Another option is IPluginV2Ext, which does not support dynamic shapes, but might be faster. The enqueue method is

int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
        cudaStream_t stream) noexcept override;

Batch size is the first parameter.
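The counterpart sketch for this static path (again with placeholder names; the cached members would come from configurePlugin or deserialization):

int32_t MyPlugin::enqueue(int32_t batchSize, void const* const* inputs,
                          void* const* outputs, void* workspace,
                          cudaStream_t stream) noexcept {
    // Only batchSize varies at runtime; every other dimension was fixed
    // at build time and cached as a member (e.g. a hypothetical mSpatialSize).
    // launchKernel(stream, inputs, outputs, batchSize, mSpatialSize, ...);
    return 0;
}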

TensorRT OSS is a good starting point to learn how to write a custom plugin.

@fduxvtech (Author)

Does the batch size need to be defined for each input? Let me explain my doubt. The kernel launch op is:

template <typename scalar_t>
void DeformAttentionForwardCUDAKernelLauncher(
    cudaStream_t stream,
    const scalar_t *data_value,
    const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc,
    const scalar_t *data_attn_weight,
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *data_col)

The class members of the plugin are

const int spatial_size;
const int num_heads;
const int channels;
const int num_levels;
const int num_query;
const int num_point;

I pass these in the constructor for serialization and deserialization (they are inferred from the inputs).
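For reference, a rough sketch of how the (de)serialization of these six members could look (DeformAttnPlugin is a placeholder class name; TensorRT 8 signatures):

size_t DeformAttnPlugin::getSerializationSize() const noexcept {
    // six cached ints: spatial_size, num_heads, channels,
    // num_levels, num_query, num_point
    return 6 * sizeof(int);
}

void DeformAttnPlugin::serialize(void* buffer) const noexcept {
    int* p = reinterpret_cast<int*>(buffer);
    p[0] = spatial_size; p[1] = num_heads; p[2] = channels;
    p[3] = num_levels;   p[4] = num_query; p[5] = num_point;
}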

The inputs pointer will contain the following five inputs:

inputs[0] // data_value
inputs[1] // data_spatial_shapes
inputs[2] // data_level_start_index
inputs[3] // data_sampling_loc
inputs[4] // data_attn_weight

while the output is

outputs[0] // data_col

Now, according to the PyTorch operation, the dimensionalities are as follows:

  • data_value has shape (batch_size, num_keys, num_heads, floor(embed_dims / num_heads))
  • data_spatial_shapes has shape (num_levels, 2), where the last dimension represents (h, w)
  • data_level_start_index has shape (num_levels,)
  • data_sampling_loc has shape (batch_size, num_query, num_heads, num_levels, num_point, 2), where the last dimension represents (x, y)
  • data_attn_weight has shape (batch_size, num_query, num_heads, num_levels, num_point)

So I have all the info I need. My doubt is about how TensorRT handles the batch size: I remember something along the lines of plugins requiring the batch size to be specified for all inputs, in which case I would need to reshape both data_spatial_shapes and data_level_start_index in the PyTorch call, and that means reworking a codebase I barely know.
I hope you can clarify this doubt about TensorRT, since you are experienced in writing plugins; also, the plugin should work with both 7.x and 8.x.

@grimoire (Owner) commented Nov 5, 2021

OK. If you are using IPluginV2DynamicExt, each PluginTensorDesc represents the shape and data type of one input/output tensor. The shapes are exactly the same as they are in PyTorch: inputDesc[i].dims.d[j] is equal to input[i].shape[j] in PyTorch. This is the way I create most plugins.
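Concretely, for your launcher that means every parameter can be derived from the descriptors inside enqueue, and data_spatial_shapes and data_level_start_index keep their batch-less shapes, since each descriptor carries its own independent shape. A sketch (float-only, no error handling or type dispatch; DeformAttnPlugin is a placeholder name):

int32_t DeformAttnPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                  const nvinfer1::PluginTensorDesc* outputDesc,
                                  const void* const* inputs, void* const* outputs,
                                  void* workspace, cudaStream_t stream) noexcept {
    // data_value: (batch_size, num_keys, num_heads, channels)
    const int batch_size   = inputDesc[0].dims.d[0];
    const int spatial_size = inputDesc[0].dims.d[1];
    const int num_heads    = inputDesc[0].dims.d[2];
    const int channels     = inputDesc[0].dims.d[3];
    // data_spatial_shapes: (num_levels, 2) -- no batch axis needed
    const int num_levels   = inputDesc[1].dims.d[0];
    // data_sampling_loc: (batch_size, num_query, num_heads, num_levels, num_point, 2)
    const int num_query    = inputDesc[3].dims.d[1];
    const int num_point    = inputDesc[3].dims.d[4];

    DeformAttentionForwardCUDAKernelLauncher<float>(
        stream,
        static_cast<const float*>(inputs[0]),
        static_cast<const int64_t*>(inputs[1]),
        static_cast<const int64_t*>(inputs[2]),
        static_cast<const float*>(inputs[3]),
        static_cast<const float*>(inputs[4]),
        batch_size, spatial_size, num_heads, channels,
        num_levels, num_query, num_point,
        static_cast<float*>(outputs[0]));
    return 0;
}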

As for IPluginV2Ext: since the shapes are static, you can save everything you need in configurePlugin and use it in enqueue, for example as sketched below. Note that the shapes there do not include the batch axis. I rarely use it, since dynamic shape is an important feature for me, and I am not sure what will happen if different inputs have a different first axis.
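A sketch of that caching step (placeholder member names; in implicit-batch mode the Dims here exclude the batch axis, and how the batch-less inputs behave is exactly the open question mentioned above):

void DeformAttnPlugin::configurePlugin(
    const nvinfer1::Dims* inputDims, int32_t nbInputs,
    const nvinfer1::Dims* outputDims, int32_t nbOutputs,
    const nvinfer1::DataType* inputTypes, const nvinfer1::DataType* outputTypes,
    const bool* inputIsBroadcast, const bool* outputIsBroadcast,
    nvinfer1::PluginFormat floatFormat, int32_t maxBatchSize) noexcept {
    // data_value arrives as (num_keys, num_heads, channels), batch stripped
    mSpatialSize = inputDims[0].d[0];
    mNumHeads    = inputDims[0].d[1];
    mChannels    = inputDims[0].d[2];
    // data_spatial_shapes: (num_levels, 2)
    mNumLevels   = inputDims[1].d[0];
    // data_sampling_loc: (num_query, num_heads, num_levels, num_point, 2)
    mNumQuery    = inputDims[3].d[0];
    mNumPoint    = inputDims[3].d[3];
}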
