qml.labs.dla.run_opt
run_opt(cost, theta, n_epochs=500, optimizer=None, verbose=False, interrupt_tol=None)
Boilerplate JAX optimization loop.
Parameters:

cost (callable) – Cost function with scalar-valued real output
theta (Iterable) – Initial values for the argument of cost
n_epochs (int) – Number of optimization iterations
optimizer (optax.GradientTransformation) – optax optimizer. Default is optax.adam(learning_rate=0.1).
verbose (bool) – Whether progress is output during optimization
interrupt_tol (float) – If not None, interrupt the optimization if the norm of the gradient is smaller than interrupt_tol.
Example
from pennylane.labs.dla import run_opt
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

def cost(x):
    return x**2

x0 = jnp.array(0.4)

thetas, energy, gradients = run_opt(cost, x0)
When no optimizer is passed, we use optax.adam(learning_rate=0.1). We can also use other optimizers, like optax.lbfgs.

>>> optimizer = optax.lbfgs(learning_rate=0.1, memory_size=1000)
>>> thetas, energy, gradients = run_opt(cost, x0, optimizer=optimizer)
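The remaining keyword arguments control the run itself. As a short sketch (reusing cost and x0 from above), the call below limits the run to 200 iterations, prints progress, and interrupts the optimization once the gradient norm falls below 1e-6, per the parameter descriptions above:

>>> # n_epochs caps the iteration count, verbose prints progress,
>>> # interrupt_tol stops early when the gradient norm is small enough
>>> thetas, energy, gradients = run_opt(
...     cost, x0, n_epochs=200, verbose=True, interrupt_tol=1e-6
... )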