
Incompatibility of least_squares and custom_vjp #50

Open
ahwillia opened this issue Mar 17, 2024 · 2 comments
Labels: feature New feature

Comments

@ahwillia

I'm running into some trouble applying optimistix.least_squares(fn, LevenbergMarquardt(...), x0) to certain problems. From the error message below, my understanding is that the root cause is that forward-mode autodiff cannot be applied to a jax.custom_vjp function. In my case I am using diffrax to solve an ODE within fn(...), which I think is what causes the problem.

Is my basic understanding correct? Are there specific constraints / assumptions that fn(...) must follow for optimistix.least_squares to work (e.g. cannot use jax.custom_vjp)? Is there any way around this?

The error I get is:

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
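
For what it's worth, the same failure reproduces without diffrax: any function wrapped in jax.custom_vjp rejects forward-mode differentiation. A minimal sketch, with an illustrative function f:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Forward pass: return the output plus residuals for the backward pass.
    return jnp.sin(x), x

def f_bwd(x, g):
    # Backward pass: cotangent with respect to the input.
    return (g * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)

jax.jacrev(f)(1.0)  # fine: reverse-mode uses the custom VJP
jax.jacfwd(f)(1.0)  # raises the TypeError above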

The full code to reproduce the error is below. By the way, I get the same error when trying to use jaxopt.LevenbergMarquardt on this problem.

# === imports === #
import jax
import jax.numpy as jnp
import diffrax
from diffrax import ODETerm, Dopri5, SaveAt
import optimistix
from optimistix import LevenbergMarquardt

jax.config.update("jax_enable_x64", True)  # 64-bit floats for the ODE solve

# === functions defining flow field and residuals === #
def geodesic_vector_field(P):
    # P is a callable mapping a point x to the metric matrix at x.
    jacP = jax.jacobian(P)

    def vector_field(t, state, args):
        # Geodesic ODE in first-order form: state = (position, velocity).
        x, v = state
        Pdx = jacP(x)
        q1 = 0.5 * jnp.einsum("jki,j,k->i", Pdx, v, v)
        q2 = jnp.einsum("ilp,l,p->i", Pdx, v, v)
        dxdt = v
        dvdt = jnp.linalg.solve(P(x), q1 + q2)
        return (dxdt, dvdt)

    return vector_field

def exponential_map(x0, v0, term, solver):
    # Integrate the geodesic from t=0 to t=1 and return the endpoint position.
    sol = diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.1, y0=(x0, v0),
        saveat=SaveAt(t0=False, t1=True),
    )
    return sol.ys[0].ravel()

def shooting_method_resids(x0, x1, term, solver):
    # Shooting-method residuals: the gap between the target x1 and where the
    # geodesic launched from (x0, v0) actually lands.
    return jax.jit(
        lambda v0, args: (x1 - exponential_map(x0, v0, term, solver)).ravel()
    )

# === try solving the boundary value problem === #
term = ODETerm(geodesic_vector_field(lambda x: jnp.eye(2)))  # flat metric
solver = Dopri5()

optimistix.least_squares(
    shooting_method_resids(jnp.zeros(2), jnp.ones(2), term, solver),
    LevenbergMarquardt(rtol=1e-3, atol=1e-3),
    -1 * jnp.ones(2),  # initial guess for the launch velocity v0
)
patrick-kidger added a commit that referenced this issue Mar 18, 2024
In particular this is useful when the underlying function only supports reverse-mode autodifferentiation due to a `jax.custom_vjp`, see #50
@patrick-kidger (Owner)

Yup, you're completely correct in your diagnosis: Diffrax uses a jax.custom_vjp for autodifferentiation through diffeqsolve, and custom_vjp functions don't support forward-mode autodiff, which is what optx.LevenbergMarquardt uses to compute its Jacobians.

We have essentially two possible fixes: offer a way for Diffrax to use forward-mode autodifferentiation, or offer a way for Optimistix to use reverse-mode.

For now I've just added the latter, in #51. Try using Optimistix from that branch and see if it solves your problem! You'll need to pass optx.least_squares(..., options=dict(jac="bwd")).
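
Applied to the reproduction above, and assuming the branch from #51 is installed, the call would look something like:

import optimistix as optx

sol = optx.least_squares(
    shooting_method_resids(jnp.zeros(2), jnp.ones(2), term, solver),
    optx.LevenbergMarquardt(rtol=1e-3, atol=1e-3),
    -1 * jnp.ones(2),
    options=dict(jac="bwd"),  # compute Jacobians with reverse-mode (vjp) instead of forward-mode
)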

(I'd like to add better forward-mode support to Diffrax, but the best way of doing this really depends on JAX directly adding support for jvp-of-custom_vjp. I have a draft of that here, but it still seems to be buggy, so I haven't gotten around to finishing it.)

@patrick-kidger patrick-kidger added the feature New feature label Mar 18, 2024
@ahwillia (Author)

Amazing, works as intended (at least for the simple example I've tried)!

patrick-kidger added commits that referenced this issue Mar 30, 2024 and May 1, 2024