Using optimistix with an equinox model #56

Open
frostedoyster opened this issue May 7, 2024 · 2 comments
Labels
question User queries

Comments

@frostedoyster

Hi everyone, thanks for the great library and apologies in advance for this basic question.
I'm trying to find the true minimum of the loss for a small neural network, and I thought of using a solver from optimistix together with an equinox model. However, I haven't been able to make the two work together.

Here is a minimal snippet which fails:

import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)


X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3

y = function(X).reshape(-1, 1)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

static, params = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)

I'm getting TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>.

What am I doing wrong?
Thank you in advance.

@patrick-kidger
Owner

You have static and params the wrong way around.

(Caveat: I've not tried running the code, this was just what jumped out at me.)

@patrick-kidger patrick-kidger added the question User queries label May 7, 2024
@frostedoyster
Author

Indeed, that was the issue! Thanks a lot for the reply and for the library!
