# Source code for best.plot

"""Make plots for displaying results of BEST test.

This module produces plots similar to those in

Kruschke, J. (2012) Bayesian estimation supersedes the t
    test. Journal of Experimental Psychology: General.
"""
from typing import Optional, Union

try:
    import matplotlib.pyplot as plt
except ImportError:
    import matplotlib as mpl
    mpl.rcParams['backend'] = 'TkAgg'
    import matplotlib.pyplot as plt

import numpy as np
import matplotlib.lines as mpllines
from matplotlib.transforms import blended_transform_factory
import scipy.stats as st

from .model import BestResults, BestResultsOne, BestResultsTwo

# Probability mass of the highest density interval (HDI) that is drawn and
# annotated under every posterior histogram.
DEFAULT_HDI_MASS = 0.95

# Only the samples inside the HDI holding this much probability mass are
# displayed, so that long tails do not distort the plot.
DISPLAYED_MASS = 0.995

# Light blue used as the default colour of posterior histograms and of the
# posterior predictive curves.
PRETTY_BLUE = '#89d1ea'


def plot_posterior(best_results: BestResults,
                   var_name: str,
                   ax: Optional[plt.Axes] = None,
                   bins: Union[int, list, np.ndarray] = 30,
                   stat: Optional[str] = 'mode',
                   title: Optional[str] = None,
                   label: Optional[str] = None,
                   ref_val: Optional[float] = None,
                   **kwargs) -> plt.Axes:
    """Plot a histogram of posterior samples of a variable

    Parameters
    ----------
    best_results : BestResults
        The result of the analysis.
    var_name : string
        The name of the variable to be plotted. Available variable names are
        described in the :ref:`sec-variables` section.
    ax : Matplotlib Axes, optional
        If not None, the Matplotlib Axes instance to be used.
        Default: create a new axes.
    bins : int or list or NumPy array
        The number or edges of the bins used for the histogram of the data.
        If an integer, the number of bins to use.
        If a sequence, then the edges of the bins, including left edge
        of the first bin and right edge of the last bin.
        Default: 30 bins.
    stat : {'mean', 'mode', None}
        Whether to print the mean or the mode of the variable on the plot.
        If None (or an empty string), no statistic is printed.
        The mean is computed over the displayed samples, i.e. those inside
        the central high-density region shown on the plot.
        Default: 'mode'.
    title : string, optional
        Title of the plot. Default: don't print a title.
    label : string, optional
        Label of the *x* axis. Default: don't print a label.
    ref_val : float, optional
        If not None, print a vertical line at this reference value
        (typically zero).
        Default: None (don't print a reference value)
    **kwargs : dict
        All other keyword arguments are passed to `plt.hist`. They take
        precedence over this function's defaults (``rwidth=0.8``, a light
        blue ``facecolor`` and ``edgecolor='none'``). Passing ``color``
        (or ``fc``) replaces the default face color.

    Returns
    -------
    Matplotlib Axes
        The Axes object containing the plot. Using this return value, the
        plot can be customized afterwards – for details, see the documentation
        of the `Matplotlib Axes API
        <https://matplotlib.org/api/axes_api.html>`_.

    Raises
    ------
    ValueError
        If ``stat`` is truthy but neither ``'mean'`` nor ``'mode'``.

    Examples
    --------
    To plot a histogram of the samples of the mean of the first group in
    avocado green:

    >>> import matplotlib.pyplot as plt
    >>> ax = best.plot_posterior(best_out,
    ...                          'Group 1 mean',
    ...                          color='xkcd:avocado green')
    >>> plt.show()
    """
    # Validate before creating/drawing anything, so a bad argument does not
    # leave a half-drawn figure behind.
    if stat and stat not in ('mode', 'mean'):
        raise ValueError('stat parameter must be either "mean" or "mode" '
                         'or None.')

    samples = best_results.trace[var_name]
    # Keep only the central DISPLAYED_MASS of the posterior so that long tails
    # do not squeeze the bulk of the distribution into a few bins.
    samples_min, samples_max = best_results.hdi(var_name, DISPLAYED_MASS)
    samples = samples[(samples_min <= samples) & (samples <= samples_max)]

    if ax is None:
        _, ax = plt.subplots()

    # x in data coordinates, y in axes coordinates (0 = bottom, 1 = top):
    # annotations sit at a data x-position but a fixed height.
    trans = blended_transform_factory(ax.transData, ax.transAxes)

    # Defaults first, caller's kwargs last, so that any of them can be
    # overridden without a "multiple values for keyword argument" TypeError.
    hist_kwargs = {
        'bins': bins,
        'rwidth': 0.8,
        'facecolor': PRETTY_BLUE,
        'edgecolor': 'none',
    }
    if 'color' in kwargs or 'fc' in kwargs:
        # Matplotlib applies 'facecolor' after 'color', so keeping our default
        # would silently discard the caller's colour ('fc' is an alias that
        # may not be combined with 'facecolor').
        del hist_kwargs['facecolor']
    if 'ec' in kwargs:
        # 'ec' is an alias of 'edgecolor'; both at once is an error.
        del hist_kwargs['edgecolor']
    hist_kwargs.update(kwargs)
    ax.hist(samples, **hist_kwargs)

    if stat:
        if stat == 'mode':
            stat_val = best_results.posterior_mode(var_name)
        else:  # stat == 'mean', guaranteed by the validation above
            stat_val = np.mean(samples)
        ax.text(stat_val, 0.99, '%s = %.3g' % (stat, stat_val),
                transform=trans,
                horizontalalignment='center',
                verticalalignment='top',
                )

    if ref_val is not None:
        ax.axvline(ref_val, linestyle=':')

    # plot HDI as a thick black bar on the x axis, with its bounds and mass
    # written above it
    hdi_min, hdi_max = best_results.hdi(var_name, DEFAULT_HDI_MASS)
    hdi_line, = ax.plot([hdi_min, hdi_max], [0, 0],
                        lw=5.0, color='k')
    hdi_line.set_clip_on(False)  # keep the bar visible on the axis line
    ax.text(hdi_min, 0.04, '%.3g' % hdi_min,
            transform=trans,
            horizontalalignment='center',
            verticalalignment='bottom',
            )
    ax.text(hdi_max, 0.04, '%.3g' % hdi_max,
            transform=trans,
            horizontalalignment='center',
            verticalalignment='bottom',
            )
    # Derive the caption from DEFAULT_HDI_MASS so it cannot drift from the
    # mass actually used (0.95 -> "95% HDI").
    ax.text((hdi_min + hdi_max) / 2, 0.14,
            '%g%% HDI' % (100 * DEFAULT_HDI_MASS),
            transform=trans,
            horizontalalignment='center',
            verticalalignment='bottom',
            )

    # make it pretty
    ax.spines['bottom'].set_position(('outward', 2))
    for loc in ['left', 'top', 'right']:
        ax.spines[loc].set_color('none')  # don't draw
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks([])  # don't draw
    for line in ax.get_xticklines():
        line.set_marker(mpllines.TICKDOWN)

    if label:
        ax.set_xlabel(label)

    if title is not None:
        ax.set_title(title)

    return ax
def plot_normality_posterior(best_results, ax, bins, title):
    """Plot the posterior of the normality parameter on a logarithmic x axis.

    Parameters
    ----------
    best_results : BestResults
        The result of the analysis.
    ax : Matplotlib Axes or None
        The Axes to draw into. If None, a new one is created.
    bins : int
        Number of logarithmically spaced histogram bins.
    title : string or None
        Title of the plot.

    Returns
    -------
    Matplotlib Axes
        The Axes object containing the plot.
    """
    # TODO merge it into plot_posterior, with a log_x: bool = False parameter
    # Then we could also center the "95% HDI" text on the log scale.
    var_name = 'Normality'
    nu_min = best_results.model.nu_min
    nu_max = best_results.hdi(var_name, DISPLAYED_MASS)[-1]
    # Bins evenly spaced in log(nu), from the model's lower bound up to the
    # upper end of the displayed mass.
    norm_bins = np.logspace(np.log10(nu_min), np.log10(nu_max), num=bins + 1)

    # Use the returned Axes: when ax is None, plot_posterior creates one.
    ax = plot_posterior(best_results,
                        var_name,
                        ax=ax,
                        bins=norm_bins,
                        title=title,
                        label=r'$\nu$')
    # Small margin on both sides of the bin range. The lower limit follows the
    # model's nu_min instead of a fixed constant, so the lowest bins stay
    # visible whatever nu_min is.
    ax.set_xlim(nu_min * 0.96, norm_bins[-1] * 1.05)
    ax.semilogx()
    # don't use scientific notation for tick labels; separate formatter
    # instances for the major and minor ticks
    ax.xaxis.set_major_formatter(plt.LogFormatter())
    ax.xaxis.set_minor_formatter(plt.LogFormatter())
    return ax
def plot_data_and_prediction(best_results: BestResults,
                             group_id: int = 1,
                             ax: Optional[plt.Axes] = None,
                             bins: Union[int, list, np.ndarray] = 30,
                             title: Optional[str] = None,
                             hist_kwargs: Optional[dict] = None,
                             prediction_kwargs: Optional[dict] = None
                             ) -> plt.Axes:
    """Plot samples of predictive distributions and a histogram of the data.

    This plot can be used as a *posterior predictive check*, to examine
    how well the model predictions fit the observed data.

    Parameters
    ----------
    best_results
        The result of the analysis.
    group_id : {1, 2}
        Which group to plot (1 or 2).
    ax : Matplotlib Axes, optional
        If not None, the Matplotlib Axes instance to be used.
        Default: create a new plot.
    title : string, optional.
        Title of the plot. Default: no plot title.
    bins : int or list or NumPy array.
        The number or edges of the bins used for the histogram of the data.
        If an integer, the number of bins to use.
        If a sequence, then the edges of the bins, including left edge
        of the first bin and right edge of the last bin.
        Default: 30 bins.
    hist_kwargs : dict, optional
        The keyword arguments to be passed to `plt.hist` for the group data.
    prediction_kwargs : dict, optional
        The keyword arguments to be passed to `plt.plot` for the posterior
        predictive curves.

    Returns
    -------
    Matplotlib Axes
        The Axes object containing the plot. Using this return value, the
        plot can be customized afterwards – for details, see the documentation
        of the `Matplotlib Axes API
        <https://matplotlib.org/api/axes_api.html>`_.

    Raises
    ------
    ValueError
        If ``best_results`` is neither a ``BestResultsOne`` nor a
        ``BestResultsTwo`` instance.

    Examples
    --------
    To print the data of the second group, add a hatch to the histogram, and
    set the limits of the *x* axis to 85 and 115:

    >>> import matplotlib.pyplot as plt
    >>> ax = best.plot_data_and_prediction(
    ...         best_out,
    ...         2,
    ...         hist_kwargs={'hatch':'...'}
    ... )
    >>> ax.set_xlim(85, 115)
    >>> plt.show()

    Notes
    -----
    You can move the histogram in front of the predictive curves by passing
    ``hist_kwargs={'zorder': 10}`` as an argument, or completely behind the
    curves with ``hist_kwargs={'zorder': 0}``.

    If the plot is large enough, it is suggested to put a legend on it, by
    calling ``ax.legend()`` afterwards.
    """
    # None instead of shared mutable {} defaults.
    if hist_kwargs is None:
        hist_kwargs = {}
    if prediction_kwargs is None:
        prediction_kwargs = {}

    if ax is None:
        _, ax = plt.subplots()

    group_data = best_results.observed_data(group_id)
    trace = best_results.trace

    if isinstance(best_results, BestResultsTwo):
        means = trace['Group %d mean' % group_id]
        sigmas = trace['Group %d sigma' % group_id]
        nus = trace['Normality']
    elif isinstance(best_results, BestResultsOne):
        means = trace['Mean']
        sigmas = trace['Sigma']
        nus = trace['Normality']
    else:
        raise ValueError('Unknown type of best_results argument')

    # Draw at most 50 posterior samples; sampling without replacement would
    # raise if the trace were shorter than that.
    n_samps = len(means)
    n_curves = min(50, n_samps)
    idxs = np.random.choice(n_samps, n_curves, replace=False)

    if np.ndim(bins) == 0:
        # A bin count (or a Matplotlib binning strategy name such as 'auto'):
        # span the observed data.
        xmin = np.min(group_data)
        xmax = np.max(group_data)
    else:
        # Explicit bin edges: span the edges.
        xmin = bins[0]
        xmax = bins[-1]
    dx = xmax - xmin
    xmin -= dx * 0.05
    xmax += dx * 0.05

    x = np.linspace(xmin, xmax, 1000)

    line_kwargs = dict(color=PRETTY_BLUE, zorder=1, alpha=0.3,
                       label='Prediction')
    line_kwargs.update(prediction_kwargs)
    for i in idxs:
        v = st.t.pdf(x, nus[i], means[i], sigmas[i])
        ax.plot(x, v, **line_kwargs)
        # Only the first curve carries the label, so ax.legend() shows a
        # single "Prediction" entry instead of one per curve.
        line_kwargs['label'] = '_nolegend_'

    data_kwargs = dict(edgecolor='w',
                       facecolor='xkcd:salmon',
                       density=True,
                       bins=bins,
                       label='Observation')
    data_kwargs.update(hist_kwargs)
    ax.hist(group_data, **data_kwargs)

    # draw a translucent histogram in front of the curves, unless the caller
    # has taken control of the layering or transparency
    if 'zorder' not in hist_kwargs and 'alpha' not in hist_kwargs:
        data_kwargs.update(dict(zorder=3, label=None, alpha=0.3))
        ax.hist(group_data, **data_kwargs)

    ax.text(0.95, 0.95, r'$\mathrm{N}=%d$' % len(group_data),
            transform=ax.transAxes,
            horizontalalignment='right',
            verticalalignment='top'
            )

    for loc in ['top', 'right']:
        ax.spines[loc].set_color('none')  # don't draw
    ax.spines['left'].set_color('gray')
    ax.set_xlabel('Observation')
    ax.set_xlim(xmin, xmax)
    ax.set_ylabel('Probability')
    ax.set_yticks([])
    ax.set_ylim(0)
    if title:
        ax.set_title(title)

    return ax
def _share_x(ax, other):
    """Make the x axis of *ax* shared with that of *other*.

    Uses ``Axes.sharex`` (Matplotlib >= 3.3). Falls back to the old
    ``get_shared_x_axes().join`` API, which newer Matplotlib no longer
    provides.
    """
    try:
        ax.sharex(other)
    except AttributeError:
        other.get_shared_x_axes().join(other, ax)


def _pooled_bin_edges(best_results, var_names, bins):
    """Return histogram bin edges shared by several posterior variables.

    The samples of all variables are pooled, restricted to the union of each
    variable's DISPLAYED_MASS interval (the same trimming plot_posterior
    applies), and split into ``bins`` equal-width bins.
    """
    intervals = [best_results.hdi(name, DISPLAYED_MASS) for name in var_names]
    lo = min(interval[0] for interval in intervals)
    hi = max(interval[1] for interval in intervals)
    pooled = np.concatenate([best_results.trace[name] for name in var_names])
    pooled = pooled[(lo <= pooled) & (pooled <= hi)]
    _, edges = np.histogram(pooled, bins=bins)
    return edges


def plot_all_two(best_results: BestResultsTwo,
                 bins: int = 30,
                 group1_name: str = 'Group 1',
                 group2_name: str = 'Group 2') -> plt.Figure:
    """Plot posteriors of every parameter and observation of a two-group analysis.

    Parameters
    ----------
    best_results : BestResultsTwo
        The result of the analysis.
    bins : int
        The number of bins to be used for the histograms.
        Default: 30.
    group1_name : string
        Name of the first group, to be used in the titles.
        Default: "Group 1".
    group2_name : string
        Name of the second group, to be used in the titles.
        Default: "Group 2".

    Returns
    -------
    plt.Figure
        The created figure. (The separate plots can be accessed via
        ``fig.axes``, where ``fig`` is the return value of this function.)

    Notes
    -----
    :ref:`This section <sec-mean-mode>` explains when is the mean or the mode
    printed.
    """
    assert type(bins) is int, 'bins argument must be an integer.'

    # Common bin edges so the two groups' histograms are directly comparable.
    # Means and SDs are trimmed the same way, so long tails do not make the
    # bins coarse.
    bin_edges_means = _pooled_bin_edges(
        best_results, ('Group 1 mean', 'Group 2 mean'), bins)
    bin_edges_stds = _pooled_bin_edges(
        best_results, ('Group 1 SD', 'Group 2 SD'), bins)

    fig, axes = plt.subplots(5, 2, figsize=(8.2, 11))

    # Left column: group means, group SDs, normality
    _share_x(axes[1, 0], axes[0, 0])

    plot_posterior(best_results,
                   'Group 1 mean',
                   ax=axes[0, 0],
                   bins=bin_edges_means,
                   stat='mean',
                   title='%s mean' % group1_name,
                   label=r'$\mu_1$')

    plot_posterior(best_results,
                   'Group 2 mean',
                   ax=axes[1, 0],
                   bins=bin_edges_means,
                   stat='mean',
                   title='%s mean' % group2_name,
                   label=r'$\mu_2$')

    _share_x(axes[3, 0], axes[2, 0])

    plot_posterior(best_results,
                   'Group 1 SD',
                   ax=axes[2, 0],
                   bins=bin_edges_stds,
                   title='%s std. dev.' % group1_name,
                   label=r'$\mathrm{sd}_1$')

    plot_posterior(best_results,
                   'Group 2 SD',
                   ax=axes[3, 0],
                   bins=bin_edges_stds,
                   title='%s std. dev.' % group2_name,
                   label=r'$\mathrm{sd}_2$')

    plot_normality_posterior(best_results, axes[4, 0], bins, 'Normality')

    # Right column, bottom three rows: comparisons between the groups
    plot_posterior(best_results,
                   'Difference of means',
                   ax=axes[2, 1],
                   bins=bins,
                   title='Difference of means',
                   stat='mean',
                   ref_val=0,
                   label=r'$\mu_1 - \mu_2$')

    plot_posterior(best_results,
                   'Difference of SDs',
                   ax=axes[3, 1],
                   bins=bins,
                   title='Difference of std. dev.s',
                   ref_val=0,
                   label=r'$\mathrm{sd}_1 - \mathrm{sd}_2$')

    plot_posterior(best_results,
                   'Effect size',
                   ax=axes[4, 1],
                   bins=bins,
                   title='Effect size',
                   ref_val=0,
                   label=r'$(\mu_1 - \mu_2) /'
                         r' \sqrt{(\mathrm{sd}_1^2 + \mathrm{sd}_2^2)/2}$')

    # Right column, top two rows: observed data with posterior predictive
    # curves, binned identically for both groups
    group1_data = best_results.observed_data(1)
    group2_data = best_results.observed_data(2)
    obs_vals = np.concatenate((group1_data, group2_data))
    bin_edges = np.linspace(np.min(obs_vals), np.max(obs_vals), bins + 1)

    _share_x(axes[1, 1], axes[0, 1])

    plot_data_and_prediction(best_results,
                             1,
                             ax=axes[0, 1],
                             bins=bin_edges,
                             title='%s data with post. pred.' % group1_name)

    plot_data_and_prediction(best_results,
                             2,
                             ax=axes[1, 1],
                             bins=bin_edges,
                             title='%s data with post. pred.' % group2_name)

    fig.tight_layout()

    return fig
def plot_all_one(best_results: BestResultsOne,
                 bins: int = 30,
                 group_name: Optional[str] = None) -> plt.Figure:
    """Plot posteriors of every parameter and observation of a one-group analysis.

    Parameters
    ----------
    best_results : BestResultsOne
        The result of the analysis.
    bins : int
        The number of bins to be used for the histograms.
        Default: 30.
    group_name : string, optional
        If not None, group name to be used in the title, e.g. if
        ``group_name`` is ``"eTRF day 5"`` then the plot for the mean is
        titled "eTRF day 5 mean". If None, then group name is omitted from
        the titles, resulting in e.g. "Mean".
        Default: None.

    Returns
    -------
    plt.Figure
        The created figure. (The separate plots can be accessed via
        ``fig.axes``, where ``fig`` is the return value of this function.)

    Notes
    -----
    :ref:`This section <sec-mean-mode>` explains when is the mean or the mode
    printed.
    """
    assert type(bins) is int, 'bins argument must be an integer.'

    def maybe_caps(title):
        # Prefix the group name if given, otherwise capitalize the bare title.
        if group_name:
            return group_name + ' ' + title
        else:
            return title.capitalize()

    fig, axes = plt.subplots(3, 2, figsize=(8.2, 6.6))

    plot_posterior(best_results,
                   'Mean',
                   ax=axes[0, 0],
                   bins=bins,
                   stat='mean',
                   title=maybe_caps('mean'),
                   label=r'$\mu$')

    plot_posterior(best_results,
                   'SD',
                   ax=axes[0, 1],
                   bins=bins,
                   title=maybe_caps('std. dev.'),
                   label=r'$\sigma$')

    plot_normality_posterior(best_results, axes[1, 0], bins,
                             maybe_caps('normality'))

    ref_val = best_results.model.ref_val
    # '%g' prints the reference value without rounding it to one decimal
    # (0.05 stays 0.05), and a negative value is shown as "+ |ref_val|"
    # rather than "- -ref_val".
    if ref_val == 0:
        label = r'$\mu / \sigma$'
    elif ref_val > 0:
        label = r'$(\mu - %g) / \sigma$' % ref_val
    else:
        label = r'$(\mu + %g) / \sigma$' % -ref_val

    plot_posterior(best_results,
                   'Effect size',
                   ax=axes[1, 1],
                   bins=bins,
                   title=maybe_caps('effect size'),
                   ref_val=ref_val,
                   label=label)

    plot_data_and_prediction(best_results,
                             1,
                             ax=axes[2, 0],
                             bins=bins,
                             title=maybe_caps('data with post. pred.'))

    # Only five panels are needed; remove the unused sixth one.
    fig.delaxes(axes[2, 1])

    fig.tight_layout()

    return fig
def plot_all(best_results: BestResults, *args, **kwargs) -> plt.Figure:
    """Plot posteriors of every parameter and observation of an analysis.

    Dispatches on the type of ``best_results``: a one-group result is drawn
    with :func:`plot_all_one`, a two-group result with :func:`plot_all_two`.
    Any further positional or keyword arguments are forwarded unchanged.

    Raises
    ------
    ValueError
        If ``best_results`` is of neither result type.
    """
    dispatch = (
        (BestResultsOne, plot_all_one),
        (BestResultsTwo, plot_all_two),
    )
    for result_type, plotter in dispatch:
        if isinstance(best_results, result_type):
            return plotter(best_results, *args, **kwargs)
    raise ValueError('best_results argument is of unknown type')