Trainer Guide

The Trainer represents the final stage of the intended ife-surrogate workflow. While the Kernel (such as RBF) defines the spatial relationship of the data and the Model (such as WidebandGP) manages the data and likelihood structures, the Trainer is responsible for the numerical execution of the learning process.

Decoupling the Trainer from the Model allows us to apply different optimization philosophies, from gradient descent to global heuristic searches, without changing the underlying GP definition.

The Three-Step Workflow

To train a surrogate, we suggest the following pipeline:

  1. Kernel Definition: Select and compose kernels (e.g., Matern12, RBF, or compositions built with kernel operators like ProductKernel) to define the prior covariance.

  2. Model Assembly: Initialize a specific Model class with a Kernel and training data. The model defines the objective function (the negative Marginal Log-Likelihood) that the trainer will minimize.

  3. Execution via Trainer: Pass the Model into a specialized Trainer class.
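To make the objective from step 2 concrete, the following is a minimal sketch of a negative marginal log-likelihood for a GP with an RBF kernel, written in plain JAX. It is not the ife-surrogate implementation: the function names `rbf_kernel` and `negative_mll`, the log-space parameterization, and the jitter value are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

def rbf_kernel(x1, x2, lengthscale, variance):
    # Squared-exponential covariance between inputs of shape (n, d) and (m, d).
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)

def negative_mll(log_params, x, y):
    # Hypothetical parameterization: log-space values keep all three positive.
    lengthscale, variance, noise = jnp.exp(log_params)
    n = x.shape[0]
    # Small jitter (assumed value) is added for numerical stability.
    K = rbf_kernel(x, x, lengthscale, variance) + (noise + 1e-6) * jnp.eye(n)
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    # 0.5 * y^T K^{-1} y + 0.5 * log|K| + 0.5 * n * log(2 * pi)
    data_fit = 0.5 * y @ alpha
    complexity = jnp.sum(jnp.log(jnp.diag(L)))
    return data_fit + complexity + 0.5 * n * jnp.log(2.0 * jnp.pi)

# Toy 1-D data set, purely illustrative.
x = jnp.linspace(0.0, 1.0, 20)[:, None]
y = jnp.sin(2.0 * jnp.pi * x[:, 0])
log_params = jnp.log(jnp.array([0.3, 1.0, 0.1]))

# A gradient-based trainer would consume exactly this (loss, gradient) pair.
loss, grads = jax.value_and_grad(negative_mll)(log_params, x, y)
```

A gradient-free trainer only needs the scalar `loss`, which is why the same objective can be handed to either optimization style.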

Depending on the complexity of the loss landscape, one can choose from various trainer implementations:

  • OptaxTrainer: Leverages the optax library for gradient-based optimization. Ideal for high-dimensional parameter spaces where Adam, SGD, or multi-stage learning rate schedules are required.

  • SwarmTrainer: A global optimization approach using Particle Swarm Optimization (PSO). This is particularly useful for non-convex likelihood surfaces where gradient descent might get stuck in local minima.

Each trainer exposes specific hyperparameters (such as swarm size, inertia weights, or optimizer strategies) allowing for fine-grained control over the convergence behavior.

Note

Ensure your data is appropriately scaled before passing it to the Trainer to improve numerical stability during the JAX optimization process.
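One common way to scale data is to standardize each column to zero mean and unit variance. The sketch below assumes this convention; the helper name `standardize` and the example values are hypothetical, and the library may provide or expect a different scaling.

```python
import jax.numpy as jnp

def standardize(a, eps=1e-8):
    # Column-wise zero mean and unit variance; eps guards against constant columns.
    mean = jnp.mean(a, axis=0)
    std = jnp.std(a, axis=0)
    return (a - mean) / (std + eps), mean, std

# Two input columns with very different magnitudes (illustrative values).
x_raw = jnp.array([[1.0e9, 0.1],
                   [2.0e9, 0.2],
                   [3.0e9, 0.4]])
x_scaled, x_mean, x_std = standardize(x_raw)
```

Keeping `x_mean` and `x_std` makes it possible to apply the same transform to new inputs at prediction time.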


OptaxTrainer

The OptaxTrainer class relies heavily on optax, a comprehensive gradient-based optimization library built on JAX. The optax documentation page on optimizers is particularly helpful here.

Usage:


SwarmTrainer

Technical Overview:

The foundational mathematical model for swarm training is Particle Swarm Optimization (PSO).

Let \(S\) be a swarm of \(N\) particles (candidate solutions or models). Each particle \(i\) has a position vector \(x_i \in \mathbb{R}^d\) representing the model weights or hyperparameters, and a velocity vector \(v_i\).

The objective is to minimize a loss function \(f(x)\).

At iteration \(t\), each particle tracks:

  • Current Position: \(x_i^{(t)}\)

  • Current Velocity: \(v_i^{(t)}\)

  • Personal Best: The best position found by this specific particle so far.

    \[p_{i,\text{best}} = \arg\min_{\tau = 1 \ldots t} f\!\left(x_i^{(\tau)}\right)\]

  • Global Best: The best position found by the entire swarm.

    \[g_{\text{best}} = \arg\min_{j = 1 \ldots N} f\!\left(p_{j,\text{best}}\right)\]

The particles update their trajectory based on inertia, cognitive influence (memory), and social influence (collective knowledge). The velocity update equation is:

\[v_i^{(t+1)} = \underbrace{w \cdot v_i^{(t)}}_{\text{Inertia}} + \underbrace{c_1 r_1 \cdot \left(p_{i,\text{best}} - x_i^{(t)}\right)}_{\text{Cognitive}} + \underbrace{c_2 r_2 \cdot \left(g_{\text{best}} - x_i^{(t)}\right)}_{\text{Social}}\]

where:

  • \(w\) is the inertia weight (controls exploration vs. exploitation)

  • \(c_1, c_2\) are acceleration coefficients

  • \(r_1, r_2 \sim U(0,1)\) are random factors, redrawn at every update

The position is then updated as:

\[x_i^{(t+1)} = x_i^{(t)} + v_i^{(t+1)}\]
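The update rules above can be sketched in a few lines of JAX. This is an illustrative implementation, not the SwarmTrainer code: the function `pso_minimize`, its default values for \(w\), \(c_1\), \(c_2\), the initialization range, and the `sphere` test function are all assumptions. The random factors are drawn per particle and per dimension, a common variant of the scalar \(r_1, r_2\) in the equation.

```python
import jax
import jax.numpy as jnp

def pso_minimize(f, key, n_particles=30, dim=2, n_iters=200,
                 w=0.7, c1=1.5, c2=1.5):
    # Initial positions are uniform in [-5, 5]^dim (assumed range); velocities start at zero.
    key, k_pos = jax.random.split(key)
    x = jax.random.uniform(k_pos, (n_particles, dim), minval=-5.0, maxval=5.0)
    v = jnp.zeros_like(x)
    batched_f = jax.vmap(f)

    p_best = x                                # personal best positions
    p_best_val = batched_f(x)                 # personal best losses
    g_best = p_best[jnp.argmin(p_best_val)]   # global best position

    for _ in range(n_iters):
        key, k1, k2 = jax.random.split(key, 3)
        r1 = jax.random.uniform(k1, x.shape)
        r2 = jax.random.uniform(k2, x.shape)
        # Velocity update: inertia + cognitive + social terms.
        v = w * v + c1 * r1 * (p_best - x) + c2 * r2 * (g_best - x)
        # Position update.
        x = x + v
        val = batched_f(x)
        # Keep the better of the old and new position for each particle.
        improved = val < p_best_val
        p_best = jnp.where(improved[:, None], x, p_best)
        p_best_val = jnp.where(improved, val, p_best_val)
        g_best = p_best[jnp.argmin(p_best_val)]

    return g_best, jnp.min(p_best_val)

def sphere(x):
    # Convex test function with its minimum at (1, 1).
    return jnp.sum((x - 1.0) ** 2)

best_x, best_val = pso_minimize(sphere, jax.random.PRNGKey(0))
```

Because only function values are used, the objective passed as `f` does not need to be differentiable.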

Usage: