Source code for arviz_plots.plots.dgof_dist_plot

"""dgof_dist plot code."""
from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal

import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dgof_plot import plot_dgof
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import (
    get_visual_kwargs,
    process_group_variables_coords,
    set_grid_layout,
)


def plot_dgof_dist(
    dt,
    *,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    kind=None,
    envelope_prob=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal["dist", "ecdf_lines", "credible_interval", "title", "xlabel", "ylabel"],
        Sequence[str],
    ]
    | None = None,
    visuals: Mapping[
        Literal[
            "dist",
            "ecdf_lines",
            "credible_interval",
            "title",
            "xlabel",
            "ylabel",
        ],
        Mapping[str, Any] | bool,
    ]
    | None = None,
    stats: Mapping[Literal["dist", "ecdf_pit"], Mapping[str, Any] | xr.Dataset] | None = None,
    **pc_kwargs,
):
    """Plot 1D marginal distributions and a Δ-ECDF-PIT diagnostic.

    The marginal distributions are plotted using the specified `kind` (kde, histogram,
    or quantile dot plot). Additionally, a Δ-ECDF-PIT diagnostic is plotted to assess
    the goodness-of-fit of the estimated distributions to the underlying data [1]_.
    If the estimated distributions are accurate, the PIT values should be uniformly
    distributed on [0, 1], resulting in a Δ-ECDF close to zero. Simultaneous confidence
    bands are computed using simulation method described in [2]_.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to be plotted.
    coords : dict, optional
        Coordinates to be used to index data variables.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "dot"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    envelope_prob : float, optional
        Indicates the probability that should be contained within the envelope.
        Defaults to ``rcParams["stats.envelope_prob"]``.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "hist" -> passed to :func:`~arviz_plots.visuals.step_hist`
          * "dot" -> passed to :func:`~arviz_plots.visuals.scatter_xy`

        * credible_interval -> passed to :func:`~arviz_plots.visuals.fill_between_y`
        * ecdf_lines -> passed to :func:`~arviz_plots.visuals.line_xy`
        * title -> passed to :func:`~arviz_plots.visuals.title`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`

    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf and qds for both dist plot and dgof plot
        * ecdf_pit -> passed to :func:`~arviz_stats.ecdf_utils.ecdf_pit`.
          Default is ``{"n_simulations": 1000}``.

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `kind` is not one of "kde", "hist" or "dot".

    Examples
    --------
    Default plot with quantile dot marginals and Δ-ECDF-PIT diagnostic:

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_dgof_dist, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data("centered_eight")
        >>> plot_dgof_dist(dt, var_names=["mu" , "tau"], kind="dot");

    .. minigallery:: plot_dgof_dist

    References
    ----------
    .. [1] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*.
        (2025) arXiv preprint https://arxiv.org/abs/2503.01509
    .. [2] Säilynoja et al. *Graphical test for discrete uniformity and
        its applications in goodness-of-fit evaluation and multiple sample comparison*.
        Statistics and Computing 32(32). (2022) https://doi.org/10.1007/s11222-022-10090-6
    """
    # Fill unset arguments from the global configuration.
    if envelope_prob is None:
        envelope_prob = rcParams["stats.envelope_prob"]
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if kind is None:
        kind = rcParams["plot.density_kind"]
    if stats is None:
        stats = {}
    if visuals is None:
        visuals = {}

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    # An explicit backend wins; otherwise reuse the one of the provided collection.
    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    if kind not in ("kde", "hist", "dot"):
        raise ValueError("kind must be either 'kde', 'hist', 'dot'")

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    if plot_collection is None:
        # Copy nested dicts so the caller's objects are never mutated.
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs.setdefault("cols", ["column"])
        aux_dim_list = [dim for dim in distribution.dims if dim not in sample_dims]
        pc_kwargs.setdefault("rows", ["__variable__"] + aux_dim_list)
        pc_kwargs = set_grid_layout(pc_kwargs, plot_bknd, distribution, num_cols=2)

        # Artificial "column" dimension with one entry per panel type.
        plot_collection = PlotCollection.grid(
            distribution.expand_dims(column=2).assign_coords(column=["dist", "gof"]),
            backend=backend,
            **pc_kwargs,
        )

    # dict() accepts any Mapping, including ones without a ``.copy`` method.
    aes_by_visuals = {} if aes_by_visuals is None else dict(aes_by_visuals)

    if labeller is None:
        labeller = BaseLabeller()

    # dist column
    dist_aes_keys = ("dist", "title")
    visuals_dist = {
        key: False for key in ("credible_interval", "point_estimate", "point_estimate_text")
    }
    visuals_dist["dist"] = get_visual_kwargs(visuals, "dist")
    visuals_dist["title"] = get_visual_kwargs(visuals, "title")
    plot_collection.coords = {"column": "dist"}
    plot_dist(
        dt,
        var_names=var_names,
        filter_vars=filter_vars,
        group=group,
        coords=coords,
        sample_dims=sample_dims,
        kind=kind,
        plot_collection=plot_collection,
        labeller=labeller,
        # Forward the aesthetics of every visual that is forwarded to plot_dist.
        aes_by_visuals={
            key: value for key, value in aes_by_visuals.items() if key in dist_aes_keys
        },
        visuals=visuals_dist,
        stats={"dist": stats.get("dist", {})},
    )
    plot_collection.coords = None

    # gof column
    dgof_visual_keys = ("title", "ecdf_lines", "credible_interval", "xlabel", "ylabel")
    # NOTE(review): "dgof" is not a documented `visuals` key; presumably this call only
    # serves to obtain a fresh dict to fill in below -- confirm.
    visuals_dgof = get_visual_kwargs(visuals, "dgof")
    for key in dgof_visual_keys:
        visuals_dgof[key] = get_visual_kwargs(visuals, key)

    stats_dgof = {"ecdf_pit": stats.get("ecdf_pit", {})}
    stats_dgof["dist"] = stats.get("dist", {})

    plot_collection.coords = {"column": "gof"}
    plot_dgof(
        dt,
        envelope_prob=envelope_prob,
        kind=kind,
        var_names=var_names,
        filter_vars=filter_vars,
        group=group,
        coords=coords,
        sample_dims=sample_dims,
        plot_collection=plot_collection,
        labeller=labeller,
        # Filtering on the undocumented key "dgof" dropped every user-provided mapping;
        # forward the same keys whose visuals are forwarded instead.
        aes_by_visuals={
            key: value for key, value in aes_by_visuals.items() if key in dgof_visual_keys
        },
        visuals=visuals_dgof,
        stats=stats_dgof,
    )
    # Mirror the reset done after the dist column so the returned collection
    # is not left restricted to the "gof" column.
    plot_collection.coords = None

    return plot_collection