# Copyright (C) 2018 Steven H. Berguin
# This work is licensed under the MIT License.
from __future__ import annotations # needed if python is 3.9
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from matplotlib.figure import Figure, SubFigure
from numpy.typing import NDArray
import matplotlib.pyplot as plt
from jenn.post_processing.metrics import rsquare
def plot_actual_by_predicted(
    y_pred: NDArray,
    y_true: NDArray,
    figsize: tuple[float, float] = (3.25, 3),
    fontsize: int = 9,
    legend_fontsize: int = 7,
    legend_label: str = "data",
    alpha: float = 0.5,
    ax: plt.Axes | None = None,
) -> Figure | SubFigure | None:
    r"""Plot predicted vs. actual value.

    .. note::
        This method uses ravel(). A NumPy array with shape :math:`(n_y, m)` will become :math:`(n_y m,)`.
        This is useful to merge all responses in one plot. Use indexing to handle responses separately,
        e.g. :code:`jenn.plot_actual_by_predicted(y_pred=y_pred[2], y_true=y_true[2])`.

    :param y_pred: predicted values; array of the same shape as ``y_true``
        (flattened before plotting)
    :param y_true: true values; array of the same shape as ``y_pred``
        (flattened before plotting)
    :param figsize: figure size, only used when ``ax`` is not provided
    :param fontsize: text size to use for axis labels
    :param legend_fontsize: text size to use for legend labels; a falsy
        value (e.g. 0 or None) falls back to ``fontsize``
    :param legend_label: name of the dataset shown in the legend, followed
        by its :math:`R^2` value
    :param alpha: transparency of dots (between 0 and 1)
    :param ax: the matplotlib axes on which to plot the data; a new figure
        and axes are created when omitted
    :raises ValueError: if ``y_true`` and ``y_pred`` have different shapes
    :return: the figure that owns ``ax`` (as reported by ``ax.get_figure()``
        when ``ax`` is supplied by the caller)
    """
    if not legend_fontsize:
        legend_fontsize = fontsize

    # Sanity check inputs *before* touching matplotlib, so that a bad call
    # does not leave a freshly created (and never closed) figure behind.
    if y_true.shape != y_pred.shape:
        raise ValueError(
            "y_true and y_pred must have the same shape "
            f"(got {y_true.shape} and {y_pred.shape})"
        )

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, layout="tight")
    else:
        fig = ax.get_figure()

    # Flatten so that all responses are merged into a single scatter plot
    y_pred = y_pred.ravel()
    y_true = y_true.ravel()

    r2 = rsquare(y_pred, y_true).squeeze()
    label = f"{legend_label} (" + r"$R^2$" + f"={r2:.2f})"
    ax.scatter(
        y_true, y_pred, alpha=alpha, color="gray", label=label, edgecolors="black"
    )

    # Add a perfect fit line to show deviations
    ymin = min(y_pred.min(), y_true.min())
    ymax = max(y_pred.max(), y_true.max())
    line = [ymin, ymax]
    ax.plot(line, line, color="r", linestyle=":", linewidth=1, label="perfect fit line")

    # Finish annotating axes
    ax.set_xlabel("Actual", fontsize=fontsize)
    ax.set_ylabel("Predicted", fontsize=fontsize)
    ax.grid(True)
    ax.legend(fontsize=legend_fontsize)
    ax.set_aspect("equal")

    # NOTE(review): this also closes a figure supplied by the caller through
    # `ax`; presumably intended to avoid duplicate rendering in notebooks --
    # confirm this is the desired behavior for caller-owned figures.
    plt.close(fig)
    return fig