Source code for mmgp.backends.gpjax

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#
from mmgp.regressor import RegressorBase

import numpy as np
from typing import Self

import gpjax as gpx
import optax as ox
import jax.numpy as jnp
import jax


def _parse_gpjax_options(options: dict) -> dict:
    """Function that checks that all the required keys for the GPJax model exist,
    and that their values are admissible.

    If the following mandatory keys are missing, they will be set to their default values:

        - kernel: "Matern52"
        - optim: "adam",
        - num_restarts: 2
        - max_iters: 1000
        - anisotropic: True
        - nugget_est: True
        - batch_size: -1
        - step_size: 1e-2
        - seed: 0

    Args:
        options (dict): A dictionary of options specific to the GPJax model.

    Returns:
        hparams (dict): A cleaned dictionary of options for the GPJax model.
    """

    hparams = {
        "kernel": "Matern52",
        "optim": "adam",
        "num_restarts": 2,
        "max_iters": 1_000,
        "anisotropic": True,
        "nugget_est": True,
        "batch_size": -1,
        "step_size": 1e-2,
        "seed": 0,
    }
    for key in hparams.keys():
        if key in options:
            hparams[key] = options[key]
    return hparams


class Regressor(RegressorBase):
    """Gaussian-process regressor backed by GPJax."""

    def __init__(self, options: dict):
        """
        Args:
            options (dict): A dictionary of options specific to the GPJax
                model. Recognised fields are "kernel", "optim",
                "num_restarts", "max_iters", "anisotropic", "nugget_est",
                "batch_size", "step_size" and "seed"; missing fields take the
                defaults applied by ``_parse_gpjax_options``.
        """
        super(Regressor, self).__init__()
        self.algo = "GPJax"
        self.options = _parse_gpjax_options(options)

    def fit(self, X: np.ndarray, y: np.ndarray) -> Self:
        """Train the regression model on the provided data.

        The model is fitted ``num_restarts`` times and the fit whose last
        recorded objective value is the smallest is kept in ``self.kmodel``.

        Args:
            X (np.ndarray): The input features for training. Must be 2-D,
                since ``X.shape[1]`` is read as the input dimension.
            y (np.ndarray): The target values for training. Must be 2-D,
                since ``y.shape[1]`` is read as the output dimension.

        Returns:
            Self: This regressor, to allow call chaining.

        Raises:
            ImportError: If the requested kernel is not an attribute of
                ``gpjax.kernels`` or the requested optimizer is not an
                attribute of ``optax``.
        """
        self.input_dim = X.shape[1]
        self.output_dim = y.shape[1]

        # Resolve the kernel and optimizer by name up front so that a typo in
        # the options fails before any model is built.
        kernel_name = self.options["kernel"]
        if not hasattr(gpx.kernels, kernel_name):
            raise ImportError(f"Kernel {kernel_name} is not available in gpjax.kernels")
        kernel_cls = getattr(gpx.kernels, kernel_name)

        optim_name = self.options["optim"]
        if not hasattr(ox, optim_name):
            raise ImportError(f"Optimizer {optim_name} is not available in optax")
        optim_cls = getattr(ox, optim_name)

        # NOTE: `rng_key` is annotated as `jax.Array`; the former
        # `jax.random.KeyArray` alias was deprecated and later removed from
        # JAX, and this annotation is evaluated every time `fit` runs.
        def _gpjax_fit(
            dataset: gpx.Dataset,
            anisotropic: bool,
            nugget_est: bool,
            step_size: float,
            max_iters: int,
            batch_size: int,
            rng_key: jax.Array,
        ):
            """Build a GP posterior for ``dataset`` and optimise it once."""
            # One lengthscale per input dimension when anisotropic, a single
            # shared one otherwise.
            if anisotropic:
                lengthscale = jnp.ones(dataset.in_dim)
            else:
                lengthscale = jnp.array([1.0])

            kernel = kernel_cls(lengthscale=lengthscale)
            if nugget_est:
                # Additive white-noise kernel used as the nugget term.
                kernel = kernel + gpx.kernels.White()

            meanf = gpx.mean_functions.Constant(jnp.ones(dataset.out_dim))
            prior = gpx.Prior(mean_function=meanf, kernel=kernel)
            likelihood = gpx.Gaussian(num_datapoints=dataset.n)
            posterior = prior * likelihood

            negative_mll = gpx.objectives.ConjugateMLL(negative=True)
            # Evaluated once eagerly before jitting; the value is discarded.
            # NOTE(review): presumably an early sanity check - confirm.
            negative_mll(posterior, train_data=dataset)
            negative_mll = jax.jit(negative_mll)

            opt_posterior, history = gpx.fit(
                model=posterior,
                objective=negative_mll,
                train_data=dataset,
                optim=optim_cls(learning_rate=step_size),
                num_iters=max_iters,
                batch_size=batch_size,
                safe=True,
                key=rng_key,
                verbose=False,
            )
            return opt_posterior, history

        dataset = gpx.Dataset(X=X, y=y)
        rng_key = jax.random.PRNGKey(self.options["seed"])

        # Best (smallest) final objective value seen so far.
        # NOTE(review): if no restart ends below +inf (e.g. num_restarts == 0
        # or a NaN final value), `self.kmodel` is not (re)assigned.
        f_ub = jnp.inf
        for _ in range(self.options["num_restarts"]):
            _, rng_key = jax.random.split(rng_key)
            kmodel, history = _gpjax_fit(
                dataset,
                self.options["anisotropic"],
                self.options["nugget_est"],
                self.options["step_size"],
                self.options["max_iters"],
                self.options["batch_size"],
                rng_key,
            )

            if history[-1] < f_ub:
                f_ub = history[-1]
                self.kmodel = kmodel
                self.gpjax_dataset = dataset
                self.gpjax_rng_key = rng_key

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions using the trained regression model.

        The predictive distribution is also stored on the instance as
        ``self.gpjax_predictive_dist``.

        Args:
            X (np.ndarray): The input features for making predictions.

        Returns:
            np.ndarray: Predicted target values (mean of the predictive
                distribution).
        """
        # Fixed: the original referenced an undefined name `X_test` here.
        latent_dist = self.kmodel.predict(X, train_data=self.gpjax_dataset)
        self.gpjax_predictive_dist = self.kmodel.likelihood(latent_dist)
        return self.gpjax_predictive_dist.mean()

    def predict_Monte_Carlo_draws(self, X: np.ndarray, size: int = 100) -> np.ndarray:
        """Generate Monte Carlo draws from the trained regression model.

        Not implemented for the GPJax backend yet.

        Args:
            X (np.ndarray): The input features for generating draws.
            size (int, optional): The number of Monte Carlo draws to generate.
                Defaults to 100.

        Returns:
            np.ndarray: Monte Carlo draws from the posterior of the regression
                model.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: sketch of a possible implementation, kept for reference:
        #   _, rng_key = jax.random.split(self.gpjax_rng_key)
        #   self.gpjax_rng_key = rng_key
        #   return self.gpjax_predictive_dist.sample(sample_shape=(size,), key=rng_key)
        raise NotImplementedError