Efficient NewtonCG Implementation #24

Open · quattro opened this issue Nov 15, 2023 · 3 comments
Labels: question User queries

Comments

@quattro

quattro commented Nov 15, 2023

Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!

I've been porting over some older code of mine to use optimistix rather than hand-rolled inference procedures, and I could use some advice. Currently, I am performing variational inference using a mix of closed-form updates for the variational parameters and gradient-based updates for some hyperparameters. It roughly works like this:

eval_f = jax.value_and_grad(_infer, has_aux=True)
while True:
    # _infer returns the objective plus the closed-form variational update as aux.
    ((value, var_params), gradient) = eval_f(hyper_param, var_params, data)
    # Gradient step on the hyperparameters.
    hyper_param = hyper_param + learning_rate * gradient
    if converged:
        break

I'd like to retool the above to not only report the current value, aux values (i.e. the updated variational parameters), and gradient with respect to the hyperparameters, but also return an HVP function that could be used in a Newton-CG-like step in Optimistix. I know of the new minimize function, but what isn't clear is how to set things up so that it not only reports gradients but also returns an HVP function internally, without having to take two additional passes over the graph (i.e. one for value and grad, plus another two for the HVP: forward + backward).

Is this doable? Apologies if this is somewhat nebulous--I'm happy to clarify.

@quattro changed the title from "NewtonCG" to "Efficient NewtonCG Implementation" on Nov 15, 2023
@patrick-kidger
Owner

> Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!

That's great to hear, thank you!

On HVPs: if I understand you correctly, this is a general JAX question, rather than specifically a question of how to integrate a solve into Optimistix? You're looking to get both the gradient and an HVP without having to treat them both separately (which would be 3 sweeps in total).
This approach will get you 2 sweeps, which I think is optimal:

import jax

def to_jvp(x):
    # One reverse-mode sweep: value and gradient of fn at x.
    return jax.value_and_grad(fn)(x)

# One forward-mode sweep through that gives the HVP in direction v as well.
(f, dfdxi), (_, dfdxidxjvj) = jax.jvp(to_jvp, (x,), (v,))
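
For the setup in the original question, where _infer has an aux output, the same trick should carry over. A rough, untested sketch reusing the names from the question (var_params and data are closed over; v is the direction for the HVP):

import jax

def to_jvp_aux(hyper_param):
    # One reverse sweep: ((value, var_params), gradient), with the closed-form
    # variational update coming back as the aux output.
    return jax.value_and_grad(_infer, has_aux=True)(hyper_param, var_params, data)

primals_out, tangents_out = jax.jvp(to_jvp_aux, (hyper_param,), (v,))
(value, new_var_params), gradient = primals_out
_, hvp_v = tangents_out  # Hessian of value wrt hyper_param, applied to v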

@patrick-kidger added the question (User queries) label Nov 17, 2023
@quattro
Author

quattro commented Nov 17, 2023

Yes exactly! It isn't clear to me how to work this into a Newton-like CG solver, but I'll keep toying around. Thanks as always for your enthusiasm and detailed help. It is greatly appreciated!

@packquickly
Collaborator

Newton CG is one of the algorithms that I think could be somewhat involved to implement. It's further from the existing solvers in Optimistix, so more custom work needs to be done to get it running.

There are two steps I would take if I were implementing it (which I may do in the future):

  1. Implement a custom top-level solver which passes the Hessian-vector product $h(v) = \nabla^2 f(x) \, v$ via lx.FunctionLinearOperator in FunctionInfo.EvalGradHessian (see the sketch after this list). Take a look at the abstract solvers (AbstractGradientDescent, AbstractGaussNewton, AbstractBFGS, etc.) for a sense of how to do this, and check out the docs for FunctionInfo. Implementing a top-level solver from scratch for the first time can be a little tricky, because it requires touching all of the abstractions used in Optimistix; however, the technical bit should be very simple here. If you find it difficult to figure out from the existing implementations, just poke me!
  2. Implement an early-exit version of CG in Lineax. This is needed because, for inexact Newton CG, the CG iterates stop immediately if a direction of negative curvature is encountered, which allows non-PSD matrices to be approximately solved with CG. This should be straightforward: copy the implementation of CG in Lineax and add the curvature condition to the function not_converged (something like tree_dot(y, operator.mv(y)) > 0; this may need tweaking). You'll also have to disable the positive semidefinite checks.
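
For step 1, a minimal sketch of what wrapping the Hessian-vector product as a Lineax operator might look like; fn and x stand in for the objective and the current iterate, and the symmetric tag is an assumption on my part:

import jax
import lineax as lx

def hvp(v):
    # Forward-over-reverse Hessian-vector product of fn at the current iterate x.
    _, tangent_out = jax.jvp(jax.grad(fn), (x,), (v,))
    return tangent_out

input_structure = jax.eval_shape(lambda: x)
hessian_operator = lx.FunctionLinearOperator(hvp, input_structure, tags=lx.symmetric_tag)
# This operator is what the custom solver would place in FunctionInfo.EvalGradHessian.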

This should be pretty much everything though. Just use optx.NewtonDescent with linear_solver=NewEarlyExitCG for the descent and your favorite Optimistix search for the search, and it should all work! You could probably use some other descents as well, such as DoglegDescent or DampedNewtonDescent, by passing the new linear solver into them. I don't know how well those would work, as they're not standard algorithms (neat!)
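
Purely to illustrate how the pieces would plug together once they exist (NewtonCGSolver from step 1 and EarlyExitCG from step 2 are placeholders, not things that ship with Optimistix or Lineax):

import optimistix as optx

descent = optx.NewtonDescent(linear_solver=EarlyExitCG(rtol=1e-3, atol=1e-6))
search = optx.BacktrackingArmijo()
solver = NewtonCGSolver(rtol=1e-4, atol=1e-6, descent=descent, search=search)
# loss_fn(y, args) -> scalar and y0 stand in for the user's objective and initial point.
solution = optx.minimise(loss_fn, solver, y0)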
