Source code for jenn.post_processing._contours

# Copyright (C) 2018 Steven H. Berguin
# This work is licensed under the MIT License.
from __future__ import annotations  # needed if python is 3.9

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Callable

    from matplotlib.figure import Figure, SubFigure

import matplotlib.pyplot as plt
import numpy as np


def plot_contours(  # noqa: C901
    func: Callable,
    x_min: np.ndarray,
    x_max: np.ndarray,
    x0: np.ndarray | None = None,
    x1_index: int = 0,
    x2_index: int = 1,
    y_index: int = 0,
    x_train: np.ndarray | None = None,
    x_test: np.ndarray | None = None,
    figsize: tuple[float, float] = (3.25, 3),
    fontsize: int = 9,
    alpha: float = 0.5,
    title: str = "",
    x1_label: str | None = None,
    x2_label: str | None = None,
    y_label: str | None = None,
    levels: int = 20,
    resolution: int = 100,
    show_colorbar: bool = False,
    ax: plt.Axes | None = None,
) -> Figure | SubFigure | None:
    """Plot contours of a scalar function of two variables, y = f(x1, x2).

    .. note::
        This method takes in a function of signature form y=f(x) and maps
        it onto a function of signature form y=f(x1, x2) such that the
        contours can be plotted. All entries of x other than ``x1_index``
        and ``x2_index`` are held fixed at ``x0``.

    :param func: the function to be evaluated, y = f(x). It is called once
        per grid node with a copy of ``x0`` whose entries ``x1_index`` and
        ``x2_index`` are overwritten; its output is flattened and entry
        ``y_index`` is plotted.
    :param x_min: lower bounds on x
    :param x_max: upper bounds on x
    :param x0: point at which the non-plotted inputs are held fixed.
        Defaults to the midpoint of the bounds, reshaped to a column vector.
        Integer/bool arrays are cast to float so grid coordinates written
        into it are not truncated.
    :param x1_index: index of x to use for factor #1 (horizontal axis)
    :param x2_index: index of x to use for factor #2 (vertical axis)
    :param y_index: index of y to be plotted
    :param x_train: option to overlay training data if provided; rows
        ``x1_index`` and ``x2_index`` are plotted
    :param x_test: option to overlay test data if provided; rows
        ``x1_index`` and ``x2_index`` are plotted
    :param figsize: figure size (only used when ``ax`` is None)
    :param fontsize: text size of the title and axis labels
    :param alpha: transparency of the contour lines (between 0 and 1)
    :param title: title of figure
    :param x1_label: factor #1 label (defaults to ``"x{x1_index}"``)
    :param x2_label: factor #2 label (defaults to ``"x{x2_index}"``)
    :param y_label: response label, shown on the colorbar
        (defaults to ``"y{y_index}"``)
    :param levels: number of contour levels
    :param resolution: number of grid points along each factor
    :param show_colorbar: show the colorbar
    :param ax: the matplotlib axes on which to plot the data; a new figure
        of size ``figsize`` is created when None
    :return: the matplotlib figure (or subfigure) that owns the axes
    """
    # Accept array-likes; with plain lists ``x_min + x_max`` would concatenate.
    x_min = np.asarray(x_min)
    x_max = np.asarray(x_max)
    if x0 is None:
        x0 = 0.5 * (x_min + x_max).reshape((-1, 1))
    if x1_label is None:
        x1_label = f"x{x1_index}"
    if x2_label is None:
        x2_label = f"x{x2_index}"
    if y_label is None:
        y_label = f"y{y_index}"

    # Domain: regular resolution x resolution grid over the two factors
    m = resolution
    x1 = np.linspace(x_min[x1_index], x_max[x1_index], m)
    x2 = np.linspace(x_min[x2_index], x_max[x2_index], m)
    x1, x2 = np.meshgrid(x1, x2)

    # Response
    y = _evaluate_on_grid(func, x0, x1, x2, x1_index, x2_index, y_index)

    # Plot
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure(figsize=figsize)
        ax = plt.gca()
    cs = ax.contour(x1, x2, y, levels, cmap="RdGy", alpha=alpha)
    if show_colorbar:
        # Bind the colorbar to ``ax`` explicitly: ``plt.colorbar`` works on
        # pyplot's *current* figure, which need not own a caller-given ``ax``.
        cbar = fig.colorbar(cs, ax=ax, shrink=1, location="right")
        cbar.set_label(y_label)  # Label for the colorbar

    # Overlay data on the same two factors as the contour axes. Legend
    # entries are tied to the scatter handles explicitly: bare labels passed
    # to ``ax.legend`` are paired with artists by drawing order, so the
    # contour lines (drawn first) could take them.
    handles = []
    if x_train is not None:
        handles.append(
            ax.scatter(
                x_train[x1_index],
                x_train[x2_index],
                marker=".",
                c="k",
                alpha=1,
                label="train",
            )
        )
    if x_test is not None:
        handles.append(
            ax.scatter(
                x_test[x1_index],
                x_test[x2_index],
                marker="+",
                c="r",
                alpha=1,
                label="test",
            )
        )
    if handles:
        ax.legend(handles=handles, loc=1)
    ax.set_title(title, fontsize=fontsize)
    ax.set_xlabel(x1_label, fontsize=fontsize)
    ax.set_ylabel(x2_label, fontsize=fontsize)

    # ``plt.close`` only accepts a top-level Figure (a SubFigure raises
    # TypeError). NOTE(review): presumably the close exists so notebooks do
    # not display the figure twice (pyplot auto-display + returned object).
    if isinstance(fig, plt.Figure):
        plt.close(fig)
    return fig


def _evaluate_on_grid(
    func: Callable,
    x0: np.ndarray,
    x1_grid: np.ndarray,
    x2_grid: np.ndarray,
    x1_index: int,
    x2_index: int,
    y_index: int,
) -> np.ndarray:
    """Return output ``y_index`` of ``func`` evaluated at every node of a 2D grid.

    Each node is a fresh copy of ``x0`` with entries ``x1_index`` and
    ``x2_index`` set to that node's coordinates; ``x0`` itself is not modified.
    """
    x0 = np.asarray(x0)
    if x0.dtype.kind in "biu":
        # Integer/bool base points would silently truncate grid coordinates.
        x0 = x0.astype(float)
    y = np.zeros(x1_grid.shape)
    # np.ndindex walks row-major, i.e. the same call order as nested i/j loops.
    for i, j in np.ndindex(x1_grid.shape):
        x = x0.copy()
        x[x1_index] = x1_grid[i, j]
        x[x2_index] = x2_grid[i, j]
        # ravel accepts a column vector, a flat array, or a scalar output.
        y[i, j] = np.ravel(func(x))[y_index]
    return y