Hi everyone, thanks for the great library and apologies in advance for this basic question.
I'm trying to find the true minimum of a small neural network, and I thought of using a solver from optimistix together with an equinox model. However, I haven't been able to make the two work together.
Here is a minimal snippet which fails:
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)

X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3

y = function(X).reshape(-1, 1)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))
static, params = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)
I'm getting TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>.
What am I doing wrong?
Thank you in advance.