Improvements to pointwise and sampled GW variants #470

Open
patrick-nicodemus opened this issue May 3, 2023 · 1 comment

Comments

@patrick-nicodemus

Currently several algorithms to compute or estimate Gromov-Wasserstein distance are provided, so the user has lots of freedom to experiment with algorithms which are appropriate to their particular distribution size, accuracy requirements, loss function, etc.

However, the pointwise_gromov_wasserstein and sampled_gromov_wasserstein functions are substantially slower than gromov_wasserstein in comparable cases. Our lab works with distributions of size N=100. On a 20-core machine, gromov_wasserstein takes about 17-20 milliseconds, while pointwise_gromov_wasserstein with log=False and max_iter=5 takes between 40 and 80 milliseconds.

Granted, the original paper on Sampled Gromov Wasserstein points out that its advantage is strongest for distributions with N >> 100 and for losses other than the square loss. However, I do not think this explains the performance difference. I suspect a large share of it comes from the user-supplied loss function being called from interpreted Python, inside a list comprehension, at each iteration of the loop.
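
To illustrate the kind of overhead I suspect, here is a minimal NumPy sketch. It is not the library's actual code: the helper names (`abs_loss_scalar`, `loss_matrix_loop`, `loss_matrix_vectorized`) are hypothetical, and the inputs are random stand-ins for one row of each distance matrix. It builds the same 100 x 100 loss matrix twice, once with one Python call per pair and once with a single broadcast operation.

```python
import numpy as np


def abs_loss_scalar(x, y):
    # Stand-in for a user-supplied loss applied to one pair of values.
    return abs(x - y)


def loss_matrix_loop(c1_row, c2_row, loss_fun):
    # One interpreted Python call per (i, j) pair: len(c1_row) * len(c2_row) calls.
    return np.array([[loss_fun(a, b) for b in c2_row] for a in c1_row])


def loss_matrix_vectorized(c1_row, c2_row):
    # Same matrix computed with a single broadcast subtraction.
    return np.abs(c1_row[:, None] - c2_row[None, :])


rng = np.random.default_rng(0)
c1_row = rng.random(100)  # hypothetical row of the first distance matrix
c2_row = rng.random(100)  # hypothetical row of the second distance matrix

L_loop = loss_matrix_loop(c1_row, c2_row, abs_loss_scalar)
L_vec = loss_matrix_vectorized(c1_row, c2_row)
```

Both versions should give the same values; timing them with `timeit` at N=100 would show how much of the gap comes from the per-element calls alone.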

I propose that pointwise_gromov_wasserstein, sampled_gromov_wasserstein and GW_distance_estimation let users select from a fixed list of loss functions, including square loss and absolute value loss, which would be implemented internally in a vectorized way using a performant backend.

@rflamary
Collaborator

rflamary commented May 3, 2023

Hello @patrick-nicodemus, this makes sense. You could give the loss either as a string for a precomputed loss or as a function for more general ones. Feel free to propose a PR, and try to respect the API for GW.
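
One possible shape for the string-or-function idea, as a rough sketch only: a small lookup table of vectorized losses plus a pass-through for callables. The names `NAMED_LOSSES` and `resolve_loss` and the `"abs_loss"` key are hypothetical, chosen for illustration; `"square_loss"` mirrors the string already accepted by gromov_wasserstein. The sketch uses plain NumPy rather than a backend abstraction.

```python
import numpy as np

# Hypothetical table of named losses, each operating on whole arrays at once.
NAMED_LOSSES = {
    "square_loss": lambda a, b: (a - b) ** 2,
    "abs_loss": lambda a, b: np.abs(a - b),
}


def resolve_loss(loss):
    """Return a callable loss from either a known name or a user function."""
    if callable(loss):
        # General case: the user-supplied function is used as-is.
        return loss
    if isinstance(loss, str) and loss in NAMED_LOSSES:
        # Named case: use the vectorized implementation.
        return NAMED_LOSSES[loss]
    raise ValueError(
        f"Unknown loss {loss!r}; expected a callable or one of {sorted(NAMED_LOSSES)}"
    )


a = np.array([1.0, 2.0, 5.0])
b = np.array([3.0, 2.0, 1.0])
square_values = resolve_loss("square_loss")(a, b)  # elementwise (a - b) ** 2
custom_values = resolve_loss(lambda x, y: np.abs(x - y) ** 3)(a, b)
```

With this kind of dispatch the named losses can take a fast path while arbitrary functions keep working as before.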

@rflamary changed the title from "Improvements to GW variants" to "Improvements to pointwise and sampled GW variants" on May 5, 2023