Source code for jenn.utils.plot

"""Plotting.
============

This module provides optional but helpful utilities to 
assess goodness of fit and visualize trends. 

.. code-block:: python

    #################
    # Example Usage #
    #################

    import jenn 

    # Assuming the following are available: 
    x_train, y_train, dydx_train = _ # user provided
    x_test, y_test, dydx_test = _    # user provided
    nn = _                           # previously trained NeuralNet

    # Show goodness of fit of the partials 
    i = 0  # index of the response to plot
    jenn.utils.plot.goodness_of_fit(
        y_true=dydx_test[i], 
        y_pred=nn.predict_partials(x_test)[i], 
        title="Partial Derivative: dy/dx (NN)"
    )

    # Example: visualize local trends
    jenn.utils.plot.sensitivity_profiles(
        f=[nn.predict], 
        x_min=x_train.min(axis=1), 
        x_max=x_train.max(axis=1), 
        x_true=x_train, 
        y_true=y_train, 
        resolution=100, 
        legend=['nn'], 
        xlabels=['x'], 
        ylabels=['y'],
    )
"""  # noqa: W291

from collections.abc import Callable
from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np

from .metrics import r_square

# Named line styles that can be passed as matplotlib's ``linestyle`` argument.
# Values are either a built-in style name or a dash tuple.
# Plotting helpers in this module iterate over the values in insertion order
# to give each model its own line style, so the ordering below is significant.
LINE_STYLES = {
    "solid": "solid",  # Same as (0, ()) or '-'
    "dotted": "dotted",  # Same as (0, (1, 1)) or ':'
    "dashdot": "dashdot",  # Same as '-.'
    "dashed": "dashed",  # Same as '--'
    "loosely dotted": (0, (1, 10)),
    # "dotted": (0, (1, 1)),
    "densely dotted": (0, (1, 1)),
    "long dash with offset": (5, (10, 3)),
    "loosely dashed": (0, (5, 10)),
    # "dashed": (0, (5, 5)),
    "densely dashed": (0, (5, 1)),
    "loosely dashdotted": (0, (3, 10, 1, 10)),
    "dashdotted": (0, (3, 5, 1, 5)),
    "densely dashdotted": (0, (3, 1, 1, 1)),
    "dashdotdotted": (0, (3, 5, 1, 5, 1, 5)),
    "loosely dashdotdotted": (0, (3, 10, 1, 10, 1, 10)),
    "densely dashdotdotted": (0, (3, 1, 1, 1, 1, 1)),
}


def actual_by_predicted(
    y_pred: np.ndarray,
    y_true: np.ndarray,
    ax: Union[plt.Axes, None] = None,  # noqa: ANN401
    figsize: Tuple[float, float] = (3.25, 3),
    title: str = "",
    fontsize: int = 9,
    alpha: float = 1.0,
) -> plt.Figure:  # noqa: ANN401
    """Create actual by predicted plot for a single response.

    Predicted values are drawn on the x-axis and actual values on the
    y-axis, together with the ``y = x`` line of a perfect fit.

    :param y_pred: predicted values, array of shape (m,); a 2D array
        with a singleton dimension, e.g. (1, m) or (m, 1), is flattened
    :param y_true: true values, same shape rules as ``y_pred``
    :param ax: the matplotlib axes on which to plot the data; a new
        figure and axes are created when omitted
    :param figsize: figure size (only used when ``ax`` is omitted)
    :param title: title of figure
    :param fontsize: text size
    :param alpha: transparency of dots (between 0 and 1)
    :return: matplotlib figure instance holding the plot
    :raises ValueError: if ``y_pred`` or ``y_true`` is not a vector
    """
    # Both arrays obey the same rule: 1D, or 2D with a singleton dimension.
    for name, array in (("y_pred", y_pred), ("y_true", y_true)):
        if array.ndim > 1 and not (array.ndim == 2 and 1 in array.shape):
            raise ValueError(
                f"Expected one dimensional array, "
                f"but {name} has {array.ndim} dimensions."
            )

    created_here = ax is None
    if created_here:
        fig = plt.figure(figsize=figsize, layout="tight")
        spec = fig.add_gridspec(ncols=1, nrows=1)
        ax = fig.add_subplot(spec[0, 0])
    else:
        # Report the figure that actually holds the plot instead of
        # creating (and returning) a blank one.
        fig = ax.get_figure()

    actual = y_true.ravel()
    predicted = y_pred.ravel()

    # x = predicted, y = actual, so the data agree with the axis labels.
    dots = ax.scatter(predicted, actual, color="k", alpha=alpha)
    lower = min(actual.min(), predicted.min())
    upper = max(actual.max(), predicted.max())
    fit_line = ax.plot([lower, upper], [lower, upper], color="r", linestyle=":")[0]

    ax.set_xlabel("Predicted", fontsize=fontsize)
    ax.set_ylabel("Actual", fontsize=fontsize)
    ax.set_title(title, fontsize=fontsize)
    ax.grid(True)
    # Explicit handles keep the labels attached to the right artists even
    # when the supplied axes already contain other artists.
    ax.legend(
        [dots, fit_line],
        ["predictions", "perfect fit line"],
        fontsize=fontsize,
    )
    if created_here:
        # Only detach figures created here from pyplot's registry; a
        # caller-supplied figure is left under the caller's control.
        plt.close(fig)
    return fig
def contours(
    func: Callable,
    lb: Tuple[float, float],
    ub: Tuple[float, float],
    x_train: Union[np.ndarray, None] = None,
    x_test: Union[np.ndarray, None] = None,
    figsize: Tuple[float, float] = (3.25, 3),
    fontsize: int = 9,
    alpha: float = 0.5,
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
    levels: int = 20,
    resolution: int = 100,
    ax: Union[plt.Axes, None] = None,  # noqa: ANN401
) -> Union[None, plt.Figure]:  # noqa: ANN401
    """Plot contours of a scalar function of two variables.

    :param func: callable evaluated one grid point at a time with an
        array of shape (2, 1); the first element of the flattened
        result is used as the response
    :param lb: lower bounds of factor #1 and factor #2
    :param ub: upper bounds of factor #1 and factor #2
    :param x_train: optional training points; rows 0 and 1 are used as
        the coordinates of factor #1 and factor #2
    :param x_test: optional test points, same layout as ``x_train``
    :param figsize: figure size (only used when ``ax`` is omitted)
    :param fontsize: text size
    :param alpha: transparency of contour lines (between 0 and 1)
    :param title: title of figure
    :param xlabel: factor #1 label
    :param ylabel: factor #2 label
    :param levels: number of contour levels
    :param resolution: number of grid points along each factor
    :param ax: the matplotlib axes on which to plot the data
    :return: matplotlib figure instance
    """
    # Domain
    m = resolution
    x1, x2 = np.meshgrid(
        np.linspace(lb[0], ub[0], m),
        np.linspace(lb[1], ub[1], m),
    )

    # Response (evaluated point by point because func is only assumed
    # to accept a single (2, 1) input)
    y = np.zeros((m, m))
    for i in range(m):
        for j in range(m):
            point = np.array([[x1[i, j]], [x2[i, j]]])
            y[i, j] = func(point).ravel()[0]

    # Plot
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure(figsize=figsize)
        ax = fig.gca()
    ax.contour(x1, x2, y, levels, cmap="RdGy", alpha=alpha)

    # Collect explicit legend handles: labels passed on their own are
    # matched positionally against every artist already on the axes.
    handles = []
    labels = []
    if x_train is not None:
        handles.append(
            ax.scatter(x_train[0], x_train[1], marker=".", c="k", alpha=1)
        )
        labels.append("train")
    if x_test is not None:
        handles.append(
            ax.scatter(x_test[0], x_test[1], marker="+", c="r", alpha=1)
        )
        labels.append("test")
    if handles:
        ax.legend(handles, labels, loc=1)

    ax.set_title(title, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    plt.close(fig)
    return fig
def convergence(
    histories: List[Dict[str, Dict[str, List[float]]]],
    figsize: Tuple[float, float] = (3.25, 3),
    fontsize: int = 9,
    alpha: float = 1.0,
    title: str = "",
    legend: Union[List[str], None] = None,
) -> Union[plt.Figure, None]:  # noqa: ANN401
    """Plot training history.

    What is plotted for each history depends on its content:

    * several epochs: average cost per epoch (mean of the batch means)
    * one epoch, several batches: average cost per batch
    * one epoch, one batch: cost per iteration

    Histories that contain no epoch, or a single epoch without batches,
    are skipped.

    :param histories: training history for each model, as nested
        dictionaries ``{epoch: {batch: [cost, ...]}}``
    :param figsize: figure size
    :param fontsize: text size
    :param alpha: transparency of lines (between 0 and 1)
    :param title: title of figure
    :param legend: label for each model, in the order of ``histories``
    :return: matplotlib figure instance, or None if ``histories`` is empty
    """
    if not histories:
        return None

    fig = plt.figure(figsize=figsize, layout="tight")
    fig.suptitle(title)
    ax = fig.gca()

    # Wrap around the available styles so that any number of histories
    # can be drawn without exhausting them.
    styles = list(LINE_STYLES.values())
    handles = []
    labels = []
    for k, history in enumerate(histories):
        epochs = list(history.values())
        if not epochs or (len(epochs) == 1 and not epochs[0]):
            continue  # nothing recorded for this model

        if len(epochs) > 1:
            values = []
            for batches in epochs:
                batch_means = [np.mean(costs) for costs in batches.values()]
                values.append(sum(batch_means) / len(batch_means))
            xlabel, ylabel = "epoch", "avg cost"
        elif len(epochs[0]) > 1:
            values = [np.mean(costs) for costs in epochs[0].values()]
            xlabel, ylabel = "batch", "avg cost"
        else:
            # Single epoch and single batch: use whatever keys are
            # present rather than assuming "epoch_0" / "batch_0".
            values = next(iter(epochs[0].values()))
            xlabel, ylabel = "iteration", "cost"

        line = ax.plot(
            range(len(values)),
            values,
            alpha=alpha,
            color="k",
            linewidth=2,
            linestyle=styles[k % len(styles)],
        )[0]
        ax.set_xlabel(xlabel, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize)
        if legend and k < len(legend):
            handles.append(line)
            labels.append(legend[k])

    if handles:
        ax.legend(handles, labels)
    ax.set_yscale("log")
    plt.close(fig)
    return fig
def residuals_by_predicted(
    y_pred: np.ndarray,
    y_true: np.ndarray,
    percent_residuals: bool = False,
    ax: Union[plt.Axes, None] = None,  # noqa: ANN401
    figsize: Tuple[float, float] = (3.25, 3),
    title: str = "",
    fontsize: int = 9,
    alpha: float = 1.0,
) -> plt.Figure:  # noqa: ANN401
    """Create residual by predicted plot for a single response.

    Residuals (``y_pred - y_true``) are drawn against the predicted
    values, together with horizontal lines at the mean residual, at
    one standard deviation either side of it, and at zero.

    :param y_pred: predicted values, array of shape (m,); a 2D array
        with a singleton dimension, e.g. (1, m) or (m, 1), is flattened
    :param y_true: true values, same shape rules as ``y_pred``
    :param percent_residuals: shows residuals as percentages if True
    :param ax: the matplotlib axes on which to plot the data; a new
        figure and axes are created when omitted
    :param figsize: figure size (only used when ``ax`` is omitted)
    :param title: title of figure
    :param fontsize: text size
    :param alpha: transparency of dots (between 0 and 1)
    :return: matplotlib figure instance holding the plot
    :raises ValueError: if ``y_pred`` or ``y_true`` is not a vector
    """
    # Both arrays obey the same rule: 1D, or 2D with a singleton dimension.
    for name, array in (("y_pred", y_pred), ("y_true", y_true)):
        if array.ndim > 1 and not (array.ndim == 2 and 1 in array.shape):
            raise ValueError(
                f"Expected one dimensional array, "
                f"but {name} has {array.ndim} dimensions."
            )

    created_here = ax is None
    if created_here:
        fig = plt.figure(figsize=figsize, layout="tight")
        spec = fig.add_gridspec(ncols=1, nrows=1)
        ax = fig.add_subplot(spec[0, 0])
    else:
        # Report the figure that actually holds the plot instead of
        # creating (and returning) a blank one.
        fig = ax.get_figure()

    y_true = y_true.ravel()
    y_pred = y_pred.ravel()
    residuals = y_pred - y_true
    if percent_residuals:
        # The small offset avoids division by zero where y_true == 0.
        residuals = 100 * (residuals / (y_true + 1e-12))
    avg_error = residuals.mean()
    std_error = residuals.std()

    avg_line = ax.axhline(y=avg_error, color="k", linestyle="-", linewidth=2)
    std_line = ax.axhline(
        y=avg_error + std_error, color="k", linestyle=":", linewidth=2
    )
    ax.axhline(y=avg_error - std_error, color="k", linestyle=":", linewidth=2)
    ax.scatter(y_pred, residuals, color="k", alpha=alpha)
    ax.axhline(y=0, color="r", linestyle=":")

    ax.set_title(title, fontsize=fontsize)
    if percent_residuals:
        ax.set_ylabel("Residuals (%)", fontsize=fontsize)
    else:
        ax.set_ylabel("Residuals", fontsize=fontsize)
    ax.set_xlabel("Predicted", fontsize=fontsize)
    ax.grid(True)
    # Explicit handles keep the labels attached to the mean and std
    # lines even when the supplied axes already contain other artists.
    ax.legend(
        [avg_line, std_line],
        [f"avg = {avg_error:.3f}", f"std = {std_error:.3f}"],
        fontsize=fontsize,
    )
    if created_here:
        # Only detach figures created here from pyplot's registry; a
        # caller-supplied figure is left under the caller's control.
        plt.close(fig)
    return fig
def goodness_of_fit(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    percent_residuals: bool = False,
    figsize: Tuple[float, float] = (6.5, 3),
    fontsize: int = 9,
    alpha: float = 1.0,
    title: str = "",
) -> plt.Figure:  # noqa: ANN401
    """Draw the two goodness-of-fit panels side by side.

    The left panel is the 'actual by predicted' plot and the right
    panel is the 'residual by predicted' plot. The R-squared value is
    appended to the figure title.

    :param y_true: true values, array of shape (m,)
    :param y_pred: predicted values, array of shape (m,)
    :param percent_residuals: shows residuals as percentages if True
    :param figsize: figure size
    :param fontsize: text size
    :param alpha: transparency of dots (between 0 and 1)
    :param title: title of figure
    :return: matplotlib figure instance
    """
    heading = "" if title is None else title
    r2 = r_square(y_pred, y_true).squeeze()

    fig = plt.figure(figsize=figsize, layout="tight")
    fig.suptitle(heading + f" (R-Squared = {r2:.3f})")
    panels = fig.add_gridspec(ncols=2, nrows=1)

    # Settings common to both panels
    shared = {
        "y_pred": y_pred,
        "y_true": y_true,
        "fontsize": fontsize,
        "alpha": alpha,
    }

    # Left panel
    actual_by_predicted(ax=fig.add_subplot(panels[0, 0]), **shared)

    # Right panel
    residuals_by_predicted(
        ax=fig.add_subplot(panels[0, 1]),
        percent_residuals=percent_residuals,
        **shared,
    )

    plt.close(fig)
    return fig
def sensitivity_profile(
    ax: Union[plt.Axes, None],  # noqa: ANN401
    x0: np.ndarray,
    y0: np.ndarray,
    x_pred: np.ndarray,
    y_pred: Union[np.ndarray, List[np.ndarray]],
    x_true: Union[np.ndarray, None] = None,
    y_true: Union[np.ndarray, None] = None,
    alpha: float = 1.0,
    xlabel: str = "x",
    ylabel: str = "y",
    legend: Union[List[str], None] = None,
    figsize: Tuple[float, float] = (6.5, 3),
    fontsize: int = 9,
    show_cursor: bool = True,
) -> plt.Figure:  # noqa: ANN401
    """Plot sensitivity profile for a single input, single output.

    :param ax: the matplotlib axes on which to plot the data; a new
        figure and axes are created when None
    :param x0: point at which the profile is centered, array of shape (1,)
    :param y0: model(s) evaluated at x0, one value per model
    :param x_pred: input values for prediction, an array of shape (m,)
    :param y_pred: predicted output values for each model, an array of
        shape (m,) or a list of such arrays
    :param x_true: input values of actual data (flattened before plotting)
    :param y_true: output values of actual data (flattened before plotting)
    :param alpha: transparency of dots (between 0 and 1)
    :param xlabel: label of x-axis
    :param ylabel: label of y-axis
    :param legend: legend name of each model, in the order of ``y_pred``;
        the list is not modified
    :param figsize: figure size (only used when ``ax`` is None)
    :param fontsize: text size
    :param show_cursor: show x0 as a red dot (or not)
    :return: matplotlib figure instance holding the plot
    """
    created_here = ax is None
    if created_here:
        fig = plt.figure(figsize=figsize, layout="tight")
        spec = fig.add_gridspec(ncols=1, nrows=1)
        ax = fig.add_subplot(spec[0, 0])
    else:
        # Report the figure that actually holds the plot instead of
        # creating (and returning) a blank one.
        fig = ax.get_figure()

    curves = y_pred if isinstance(y_pred, list) else [y_pred]
    # Work on a copy: the caller may reuse the same list across subplots.
    names = list(legend) if legend else []

    x0 = x0.ravel()
    y0 = y0.ravel()
    x_pred = x_pred.ravel()

    # Wrap around the available styles so that any number of models
    # can be drawn without exhausting them.
    styles = list(LINE_STYLES.values())
    lines = [
        ax.plot(
            x_pred,
            curve.ravel(),
            color="k",
            linestyle=styles[k % len(styles)],
            linewidth=2,
        )[0]
        for k, curve in enumerate(curves)
    ]

    # Pair each name with its own model line; the "data" label is tied
    # to the scatter explicitly so it never lands on a model line.
    n_named = min(len(names), len(lines))
    handles = lines[:n_named]
    labels = names[:n_named]
    if x_true is not None and y_true is not None:
        dots = ax.scatter(
            x_true.ravel(), y_true.ravel(), color="k", alpha=alpha
        )
        handles.append(dots)
        labels.append("data")
    if handles:
        ax.legend(handles, labels, fontsize=fontsize)

    if show_cursor:
        for value in y0:
            ax.scatter(x0, value, color="r")

    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.grid(True)
    if created_here:
        # Only detach figures created here from pyplot's registry; a
        # caller-supplied figure is left under the caller's control.
        plt.close(fig)
    return fig
def sensitivity_profiles(
    f: Union[Callable, List[Callable]],
    x_min: np.ndarray,
    x_max: np.ndarray,
    x0: Union[np.ndarray, None] = None,
    x_true: Union[np.ndarray, None] = None,
    y_true: Union[np.ndarray, None] = None,
    figsize: Tuple[float, float] = (3.25, 3),
    fontsize: int = 9,
    alpha: float = 1.0,
    title: str = "",
    xlabels: Union[List[str], None] = None,
    ylabels: Union[List[str], None] = None,
    legend: Union[List[str], None] = None,
    resolution: int = 100,
    show_cursor: bool = True,
) -> plt.Figure:  # noqa: ANN401
    """Plot grid of all outputs vs. all inputs evaluated at x0.

    Each column of the grid sweeps one input between its bounds while
    the other inputs are held at ``x0``; each row shows one output.

    :param f: callable function(s) for evaluating y_pred = f_pred(x),
        mapping an array of shape (n_x, k) to an array of shape (n_y, k)
    :param x_min: lower bound of each input, n_x values
    :param x_max: upper bound of each input, n_x values
    :param x0: point at which to evaluate profiles, array of shape
        (n_x, 1); defaults to the midpoint between ``x_min`` and ``x_max``
    :param x_true: inputs at which y_true is evaluated, array of shape (n_x, m)
    :param y_true: true values, array of shape (n_y, m)
    :param figsize: size of each subplot
    :param fontsize: text size
    :param alpha: transparency of dots (between 0 and 1)
    :param title: title of figure
    :param xlabels: x-axis labels, one per input
    :param ylabels: y-axis labels, one per output
    :param legend: legend labels for each model; the list is not modified
    :param resolution: number of points along each profile line
    :param show_cursor: show x0 as a red dot (or not)
    :return: matplotlib figure instance
    """
    funcs = f if isinstance(f, list) else [f]

    x_min = np.asarray(x_min).ravel()
    x_max = np.asarray(x_max).ravel()
    if x0 is None:
        x0 = 0.5 * (x_min + x_max)
    x0 = np.asarray(x0).reshape((-1, 1))
    if not np.issubdtype(x0.dtype, np.floating):
        # An integer x0 would truncate the swept values written below.
        x0 = x0.astype(float)
    y0 = np.concatenate([func(x0) for func in funcs], axis=1)

    n_x = x0.shape[0]
    n_y = y0.shape[0]
    xlabels = xlabels if xlabels else [f"x_{i}" for i in range(n_x)]
    ylabels = ylabels if ylabels else [f"y_{i}" for i in range(n_y)]

    # figsize is per subplot, so scale both dimensions by the grid shape.
    width, height = figsize
    fig = plt.figure(figsize=(n_x * width, n_y * height), layout="tight")
    fig.suptitle(title)
    spec = fig.add_gridspec(ncols=n_x, nrows=n_y)

    for i in range(n_x):
        # Sweep input i across its range, all other inputs fixed at x0.
        x_pred = np.tile(x0, (1, resolution))
        x_pred[i] = np.linspace(x_min[i], x_max[i], resolution)
        y_preds = [func(x_pred) for func in funcs]
        for j in range(n_y):
            sensitivity_profile(
                ax=fig.add_subplot(spec[j, i]),
                x0=x0[i],
                y0=y0[j],
                x_pred=x_pred[i],
                y_pred=[y_pred[j] for y_pred in y_preds],
                x_true=x_true[i] if x_true is not None else None,
                # Outputs are indexed by row j, not by input column i.
                y_true=y_true[j] if y_true is not None else None,
                fontsize=fontsize,
                alpha=alpha,
                xlabel=xlabels[i],
                ylabel=ylabels[j],
                # Fresh copy per subplot so no call can affect the next.
                legend=list(legend) if legend is not None else None,
                show_cursor=show_cursor,
            )
    plt.close(fig)
    return fig