Source code for jenn.post_processing._histogram

# Copyright (C) 2018 Steven H. Berguin
# This work is licensed under the MIT License.
from __future__ import annotations  # needed if python is 3.9

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from matplotlib.figure import Figure, SubFigure
    from numpy.typing import NDArray

import matplotlib.pyplot as plt


def plot_histogram(
    y_pred: NDArray,
    y_true: NDArray,
    figsize: tuple[float, float] = (3.25, 3),
    fontsize: int = 9,
    legend_fontsize: int = 7,
    legend_label: str = "data",
    alpha: float = 0.75,
    percent: bool = False,
    ax: plt.Axes | None = None,
) -> Figure | SubFigure | None:
    """Plot prediction error distribution.

    Residuals are ``y_pred - y_true`` (or ``(y_pred - y_true) / y_true * 100``
    when ``percent`` is True). They are drawn as a density-normalized
    histogram with 30 bins spanning ``avg ± 6 * std``, together with a solid
    red line at the mean and dotted red lines at ``avg ± std``.

    .. note::
        Both inputs are flattened with ravel(). A NumPy array with shape
        (n_y, m) becomes (n_y * m,).

    :param y_pred: predicted values (array-like, any shape; flattened)
    :param y_true: true values (array-like, same number of elements as y_pred)
    :param figsize: figure size, only used when ``ax`` is None
    :param fontsize: text size to use for axis labels
    :param legend_fontsize: text size to use for legend labels
    :param legend_label: legend label of the histogram bars
    :param alpha: transparency of the histogram bars (between 0 and 1)
    :param percent: show residuals as percentages of y_true
    :param ax: the matplotlib axes on which to plot the data; a new figure
        and axes are created when None
    :return: matplotlib Figure (or SubFigure) that owns the axes
    :raises ValueError: if y_true and y_pred have a different number of
        elements, are empty, or if ``percent`` is True and y_true contains
        zeros
    """
    import numpy as np  # local import: the module imports numpy only for typing

    # Sanity check inputs BEFORE creating a figure, so that a rejected call
    # does not leave an orphan pyplot figure open. Both arrays are flattened,
    # so what must agree is the total element count (comparing len() only
    # checks the first axis and lets e.g. (1, 1) vs (1, 5) broadcast silently).
    y_pred = np.asarray(y_pred).ravel()
    y_true = np.asarray(y_true).ravel()
    if y_true.size != y_pred.size:
        raise ValueError("y_true and y_pred must have same length")
    if y_true.size == 0:
        raise ValueError("y_true and y_pred must not be empty")
    if percent and np.any(y_true == 0):
        # Division by zero would yield inf/nan residuals and a non-finite
        # histogram range.
        raise ValueError("y_true must be nonzero when percent=True")

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize, layout="tight")

    # Compute residuals
    residuals = y_pred - y_true
    if percent:
        residuals = residuals / y_true * 100

    # Compute statistics once; reused for the histogram range and the markers
    avg = residuals.mean()
    std = residuals.std()

    # Make histogram; values outside avg ± 6 std are not binned
    ax.hist(
        residuals,
        alpha=alpha,
        label=legend_label,
        color="gray",
        density=True,
        range=[avg - 6 * std, avg + 6 * std],
        bins=30,
    )

    # Add statistics
    ax.axvline(x=avg, color="r", linestyle="-", linewidth=1, label=f"avg = {avg:.3f}")
    ax.axvline(
        x=avg + std, color="r", linestyle=":", linewidth=1, label=f"std = {std:.3f}"
    )
    ax.axvline(x=avg - std, color="r", linestyle=":", linewidth=1)

    # Finish annotating axes
    ax.set_xlabel("Residuals (%)" if percent else "Residuals", fontsize=fontsize)
    ax.set_ylabel("Probability", fontsize=fontsize)
    ax.grid(True)
    ax.legend(fontsize=legend_fontsize)

    # Detach the figure from pyplot's figure manager (presumably so it is not
    # auto-displayed); the Figure object itself is still returned. Guard
    # against None: plt.close(None) would close pyplot's *current* figure.
    if fig is not None:
        plt.close(fig)
    return fig