qml.labs.dla.run_opt

run_opt(cost, theta, n_epochs=500, optimizer=None, verbose=False, interrupt_tol=None)

Boilerplate JAX optimization: minimize cost with an optax optimizer, starting from theta.

Parameters:
  • cost (callable) – Cost function with a scalar, real-valued output.

  • theta (Iterable) – Initial values for the argument of cost.

  • n_epochs (int) – Maximum number of optimization iterations.

  • optimizer (optax.GradientTransformation) – Optax optimizer. Defaults to optax.adam(learning_rate=0.1).

  • verbose (bool) – Whether to report progress during the optimization.

  • interrupt_tol (float) – If not None, stop the optimization early once the norm of the gradient is smaller than interrupt_tol.

Example

from pennylane.labs.dla import run_opt
import jax
import jax.numpy as jnp
import optax
jax.config.update("jax_enable_x64", True)

def cost(x):
    return x**2

x0 = jnp.array(0.4)

thetas, energy, gradients = run_opt(cost, x0)

When no optimizer is passed, optax.adam(learning_rate=0.1) is used. Other optax optimizers, such as optax.lbfgs, can be passed instead.

>>> optimizer = optax.lbfgs(learning_rate=0.1, memory_size=1000)
>>> thetas, energy, gradients = run_opt(cost, x0, optimizer=optimizer)