"""welltestpy subpackage providing plotting routines."""
# pylint: disable=C0103
import copy
import functools as ft
import warnings
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import gaussian_kde
def _get_fig_ax(
    fig=None,
    ax=None,
    ax_name="rectilinear",
    sub_args=None,
    sub_kwargs=None,
    **fig_kwargs,
):  # pragma: no cover
    """
    Get a matplotlib figure and (optionally) an axes to plot on.

    Parameters
    ----------
    fig : Figure, optional
        Figure to use. A new one is created (with ``fig_kwargs``) if needed.
    ax : Axes or bool, optional
        * ``None`` or an Axes: return ``(fig, ax)``; if ``ax`` is ``None``
          a new subplot is added to ``fig``.
        * ``False``: only return a figure (no axes).
        * ``True``: return ``(fig, fig.gca())`` for the given ``fig``.
    ax_name : str, optional
        Projection name the axes must have. Default: "rectilinear".
    sub_args : tuple, optional
        Positional arguments for ``fig.add_subplot``. Default: ``(111,)``.
    sub_kwargs : dict, optional
        Keyword arguments for ``fig.add_subplot``. The given dict is not
        modified.
    **fig_kwargs
        Keyword arguments forwarded to ``plt.figure`` when a figure is
        created.

    Returns
    -------
    Figure or tuple of (Figure, Axes)
        Only the figure if ``ax is False``, otherwise ``(fig, ax)``.

    Raises
    ------
    ValueError
        If the axes has the wrong projection, belongs to another figure,
        or if ``ax=True`` is given without a figure.
    """
    # ax_case: 0 -> None (create one) or given, 1 -> False, 2 -> True
    ax_case = (1 + int(ax)) if isinstance(ax, bool) else 0
    sub_args = (111,) if sub_args is None else sub_args
    # copy, so setting the projection does not alter the caller's dict
    sub_kwargs = {} if sub_kwargs is None else dict(sub_kwargs)
    sub_kwargs["projection"] = ax_name
    if ax_case == 0:
        if fig is None:
            fig = plt.figure(**fig_kwargs) if ax is None else ax.get_figure()
        if ax is None:
            ax = fig.add_subplot(*sub_args, **sub_kwargs)
        # explicit checks instead of asserts (which vanish under -O)
        if ax.name != ax_name:
            raise ValueError(
                "_get_fig_ax: axes projection '{}' but '{}' required.".format(
                    ax.name, ax_name
                )
            )
        if ax.get_figure() is not fig:
            raise ValueError("_get_fig_ax: axes does not belong to figure.")
        return fig, ax
    # if ax=False we only want a figure
    if ax_case == 1:
        return plt.figure(**fig_kwargs) if fig is None else fig
    # if ax=True we want the current axis of the given figure
    if fig is None:
        raise ValueError("_get_fig_ax: 'ax=True' needs a given figure.")
    return fig, fig.gca()
def _sort_lgd(ax, **kwargs):
    """
    Show the legend of ``ax`` with its entries sorted by label.

    Parameters
    ----------
    ax : Axes
        Axes whose labelled artists should be listed in the legend.
    **kwargs
        Keyword arguments forwarded to ``ax.legend``.

    Returns
    -------
    Legend or None
        The created legend, or ``None`` if ``ax`` has no labelled artists.
    """
    handles, labels = ax.get_legend_handles_labels()
    if not labels:
        # nothing to show; unpacking an empty zip below would raise
        return None
    # sort by label only (stable); handles themselves are not comparable
    pairs = sorted(zip(labels, handles), key=lambda t: t[0])
    labels, handles = zip(*pairs)
    return ax.legend(handles, labels, **kwargs)
def dashes(i=1, max_n=12, width=1):
    """
    Create a dash sequence for matplotlib lines.

    The sequence consists of ``i`` dots (on/off pairs of length ``width``)
    followed by a closing gap, chosen such that the total length of the
    pattern does not depend on ``i``.

    Parameters
    ----------
    i : int, optional
        Number of dots. The default is 1.
    max_n : int, optional
        Maximal number of dots. The default is 12.
    width : float, optional
        Linewidth. The default is 1.

    Returns
    -------
    list
        Dash sequence (alternating on/off lengths) for matplotlib.
    """
    dot_pair = [width, width]
    closing_gap = max_n * 2 * width - 2 * i * width
    return dot_pair * i + [closing_gap, width]
def fadeline(ax, x, y, label=None, color=None, steps=20, **kwargs):
    """Fading line for matplotlib.

    This is a workaround to produce a fading line: the line is split into
    ``steps`` segments whose alpha decreases linearly from 1.0 to about 0.1.

    Parameters
    ----------
    ax : axis
        Axis to plot on.
    x : :class:`list`
        start and end value of x components of the line
    y : :class:`list`
        start and end value of y components of the line
    label : :class:`str`, optional
        label for the legend (only attached to the first segment).
        Default: ``None``
    color : MPL color, optional
        color of the line
        Default: ``None``
    steps : :class:`int`, optional
        steps of fading
        Default: ``20``
    **kwargs
        keyword arguments that are forwarded to `plt.plot`
        ("label" and "alpha" are ignored, they are set per segment)
    """
    x_nodes = np.linspace(x[0], x[1], steps + 1)
    y_nodes = np.linspace(y[0], y[1], steps + 1)
    seg_kw = {
        key: val for key, val in kwargs.items() if key not in ("label", "alpha")
    }
    seg_kw["color"] = color
    seg_kw["solid_capstyle"] = "butt"
    segments = zip(x_nodes[:-1], x_nodes[1:], y_nodes[:-1], y_nodes[1:])
    for idx, (x_a, x_b, y_a, y_b) in enumerate(segments):
        # only the first segment shows up in the legend
        seg_kw["label"] = None if idx else label
        seg_kw["alpha"] = (steps - idx) * (1.0 / steps) * 0.9 + 0.1
        ax.plot([x_a, x_b], [y_a, y_b], **seg_kw)
def campaign_plot(campaign, select_test=None, fig=None, style="WTP", **kwargs):
    """
    Plot an overview of the tests within the campaign.

    Parameters
    ----------
    campaign : :class:`Campaign`
        The campaign to be plotted.
    select_test : list of str, optional
        Names of the tests to be added to the plot. Any iterable of test
        names is accepted and it is not modified. The default is None,
        meaning all tests of the campaign.
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".
    **kwargs
        Keyword arguments forwarded to the tests plotting routines.

    Returns
    -------
    fig : Figure
        The created matplotlib figure.
    """
    # sorted() builds a new list, so a caller-provided list is not mutated
    test_source = campaign.tests.keys() if select_test is None else select_test
    tests = sorted(test_source)
    nroftests = len(tests)
    style = copy.deepcopy(plt.rcParams) if style is None else style
    keep_fs = False
    if style == "WTP":
        style = "ggplot"
        font_size = plt.rcParams.get("font.size", 10.0)
        keep_fs = True
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update({"font.size": font_size})
        fig = _get_fig_ax(fig, ax=False, dpi=75, figsize=[8, 3 * nroftests])
        for n, t in enumerate(tests):
            ax = fig.add_subplot(nroftests, 1, n + 1)
            # call the plotting routine of the test
            campaign.tests[t].plot(wells=campaign.wells, ax=ax, **kwargs)
        fig.tight_layout()
        fig.show()
    return fig
def campaign_well_plot(
    campaign, plot_tests=True, plot_well_names=True, fig=None, style="WTP"
):
    """
    Plot of the well constellation within the campaign.

    Parameters
    ----------
    campaign : :class:`Campaign`
        The campaign to be plotted.
    plot_tests : bool, optional
        Whether to draw a fading line from the pumping well to every
        observation well of each test. The default is True.
    plot_well_names : bool, optional
        Whether to annotate the wells with their names. The default is True.
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".

    Returns
    -------
    ax : Axes
        The created matplotlib axes.
    """
    well_names = list(campaign.wells)
    positions = [
        [campaign.wells[name].pos[0], campaign.wells[name].pos[1]]
        for name in well_names
    ]
    fig = plot_well_pos(
        [positions],
        well_names,
        plot_well_names=plot_well_names,
        fig=fig,
        style=style,
    )
    if style is None:
        style = copy.deepcopy(plt.rcParams)
    keep_fs = style == "WTP"
    if keep_fs:
        style = "ggplot"
        font_size = plt.rcParams.get("font.size", 10.0)
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update({"font.size": font_size})
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        fig, ax = _get_fig_ax(fig, ax=True)
        if plot_tests:
            for t_idx, test_name in enumerate(sorted(campaign.tests)):
                test = campaign.tests[test_name]
                pump_name = test.pumpingwell
                for o_idx, obs_name in enumerate(test.observations):
                    pump_pos = campaign.wells[pump_name].pos
                    obs_pos = campaign.wells[obs_name].pos
                    # only the first line of a test gets a legend entry
                    fadeline(
                        ax=ax,
                        x=[pump_pos[0], obs_pos[0]],
                        y=[pump_pos[1], obs_pos[1]],
                        label=None if o_idx else "{}".format(test_name),
                        color=colors[(t_idx + 2) % len(colors)],
                        linewidth=3,
                        zorder=10,
                    )
        # get equal axis (for realism)
        ax.axis("equal")
        ax.legend(title="Test at", loc="upper left", bbox_to_anchor=(1, 1))
        fig.tight_layout()
        fig.show()
    return ax
def plot_pump_test(
    pump_test, wells, exclude=None, fig=None, ax=None, style="WTP", **kwargs
):
    """Plot a pumping test.

    Parameters
    ----------
    pump_test: :class:`PumpingTest`
        Pumping test class that should be plotted.
    wells : :class:`dict`
        Dictionary containing the well classes sorted by name.
    exclude: :class:`list`, optional
        List of wells that should be excluded from the plot.
        Default: ``None``
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    ax : :class:`Axes`
        Matplotlib axes to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".
    **kwargs
        Plot settings evaluated here:

        * ``abs_val`` (bool): plot absolute values of the observations.
        * ``title``: set to ``False`` to suppress the title.
        * ``xscale`` / ``yscale`` (str): axis scales of ``ax``.

    Returns
    -------
    ax : Axes
        The created matplotlib axes.

    Raises
    ------
    ValueError
        If the state of the pumping test is not "mixed", "transient"
        or "steady".

    Notes
    -----
    This will be used by the Campaign class.
    Transient observations are drawn as lines, steady observations as
    points over the radius (on a second x-axis in the "mixed" case).
    """
    style = copy.deepcopy(plt.rcParams) if style is None else style
    keep_fs = False
    if style == "WTP":
        style = "ggplot"
        font_size = plt.rcParams.get("font.size", 10.0)
        keep_fs = True
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update({"font.size": font_size})
        clrs = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        clr_n = len(clrs)
        fig, ax = _get_fig_ax(fig, ax)
        exclude = set() if exclude is None else set(exclude)
        p_name = pump_test.pumpingwell
        test_wells = set(pump_test.observationwells)
        plot_wells = list((set(wells) & test_wells) - exclude)
        # sort by radius
        plot_wells.sort(key=lambda name: wells[name] - wells[p_name])
        state = pump_test.state(wells=plot_wells)
        steady_guide_x = []
        steady_guide_y = []
        # label prefix for absolute values
        abslab = "abs. " if kwargs.get("abs_val") else ""
        if state == "mixed":
            ax1 = ax
            ax2 = ax1.twiny()
        elif state == "transient":
            ax1 = ax
            ax2 = None
        elif state == "steady":
            ax1 = None
            ax2 = ax
        else:
            raise ValueError("plot_pump_test: unknown state of pumping test.")
        for i, k in enumerate(plot_wells):
            obs = pump_test.observations[k]
            if k != p_name:
                dist = wells[k] - wells[p_name]
            else:
                dist = wells[p_name].radius
            # one color scheme for transient and steady observations
            color = clrs[i % clr_n]
            label = obs.name + " r={:1.2f}".format(dist)
            if obs.state == "transient":
                displace = obs.value[0]
                if abslab:
                    displace = np.abs(displace)
                ax1.plot(
                    obs.value[1],
                    displace,
                    linewidth=2,
                    color=color,
                    label=label,
                )
                ax1.set_xlabel(obs.labels[0])
                ax1.set_ylabel(abslab + "{}".format(obs.labels[1]))
            else:
                displace = np.abs(obs.value) if abslab else obs.value
                steady_guide_x.append(dist)
                steady_guide_y.append(displace)
                ax2.scatter(dist, displace, color=color, label=label)
                ax2.set_xlabel("r in {}".format(wells[k].coordinates.units))
                ax2.set_ylabel(abslab + "{}".format(obs.labels))
        if state != "transient":
            # faint guide line through the steady points, ordered by radius
            steady_guide_x = np.array(steady_guide_x, dtype=float)
            steady_guide_y = np.array(steady_guide_y, dtype=float)
            arg = np.argsort(steady_guide_x)
            steady_guide_x = steady_guide_x[arg]
            steady_guide_y = steady_guide_y[arg]
            ax2.plot(steady_guide_x, steady_guide_y, color="k", alpha=0.1)
        if kwargs.get("title") is not False:
            ax.set_title(repr(pump_test))
        if "xscale" in kwargs:
            ax.set_xscale(kwargs["xscale"])
        if "yscale" in kwargs:
            ax.set_yscale(kwargs["yscale"])
        ax.legend(
            title="Pumping test '{}'".format(pump_test.name),
            loc="upper left",
            bbox_to_anchor=(1, 1),
        )
        if state == "mixed":  # add a second legend
            ax2.legend(loc="upper right", fancybox=True, framealpha=0.75)
    return ax
####
def plot_well_pos(
    well_const,
    names=None,
    title="",
    filename=None,
    plot_well_names=True,
    ticks_set="auto",
    fig=None,
    style="WTP",
):
    """
    Plot all well constellations and label the points with the names.

    Parameters
    ----------
    well_const : list
        List of well constellations (each a list of ``[x, y]`` positions).
    names : list of str, optional
        Names for the wells. The default is None.
    title : str, optional
        Plot title. The default is "".
    filename : str, optional
        Filename if the result should be saved (as PDF).
        The default is None.
    plot_well_names : bool, optional
        Whether to plot the well-names. The default is True.
    ticks_set : int or str, optional
        Tick spacing in the plot. The default is "auto".
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".

    Returns
    -------
    fig : Figure
        The created matplotlib figure.

    Raises
    ------
    ValueError
        If no well constellation is given.
    """
    # calculate Column- and Row-count for quadratic shape of the plot
    # total number of plots
    total_n = len(well_const)
    if total_n == 0:
        raise ValueError("plot_well_pos: no well constellation given.")
    # columns near the square-root but tendentially wider than tall
    col_n = int(np.ceil(np.sqrt(total_n)))
    # enough rows to catch all plots
    row_n = int(np.ceil(total_n / col_n))
    # generate names for points if undefined
    if names is None:
        names = ["p" + str(i) for i in range(len(well_const[0]))]
    # generate common borders for all plots
    xmax = -np.inf
    xmin = np.inf
    ymax = -np.inf
    ymin = np.inf
    for wells in well_const:
        for pos in wells:
            xmax = max(pos[0], xmax)
            xmin = min(pos[0], xmin)
            ymax = max(pos[1], ymax)
            ymin = min(pos[1], ymin)
    # add some space around the points in the plot
    space = 0.1 * max(abs(xmax - xmin), abs(ymax - ymin))
    if space == 0:
        # all wells at one point (e.g. a single well): use a unit margin,
        # otherwise the axis limits collapse and log10(0) breaks auto-ticking
        space = 1.0
    xspace = yspace = space
    if ticks_set == "auto":
        # bit hacky auto-ticking to be more pleasant for the eyes
        tick_list = [1, 2, 5, 10]
        tk_space = space * 10 / 7  # assume about 7 ticks
        scaling = np.log10(tk_space)
        if np.log10(0.4) < scaling < 1:
            # if space is less 10, choose nearest value in tick_list (by log)
            ticks_set = min(tick_list, key=lambda x: abs(np.log(x / tk_space)))
        else:
            # k * 10 ** n as ticks (0.1, 0.2, ..., 10, 20, ..., 100, 200, ...)
            space_pot = 10 ** int(np.floor(scaling))
            ticks_set = space_pot * int(np.around(tk_space / space_pot))
    style = copy.deepcopy(plt.rcParams) if style is None else style
    keep_fs = False
    if style == "WTP":
        style = "ggplot"
        font_size = plt.rcParams.get("font.size", 10.0)
        keep_fs = True
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update({"font.size": font_size})
        fig = _get_fig_ax(
            fig, ax=False, dpi=100, figsize=[9 * col_n, 5 * row_n]
        )
        for i, wells in enumerate(well_const):
            ax = fig.add_subplot(row_n, col_n, i + 1)
            ax.set_xlim([xmin - xspace, xmax + xspace])
            ax.set_ylim([ymin - yspace, ymax + yspace])
            ax.set_aspect("equal")
            for j, name in enumerate(names):
                ax.scatter(wells[j][0], wells[j][1], color="k", zorder=100)
                if plot_well_names:
                    ax.annotate(
                        " " + name, (wells[j][0], wells[j][1]), zorder=100
                    )
            ax.xaxis.set_major_locator(ticker.MultipleLocator(ticks_set))
            ax.yaxis.set_major_locator(ticker.MultipleLocator(ticks_set))
            ax.set_xlabel("x $[m]$")
            ax.set_ylabel("y $[m]$")
            if total_n > 1:
                ax.set_title("Result {}".format(i))
        if title:
            fig.suptitle(title)
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        if filename is not None:
            fig.savefig(filename, format="pdf")
    return fig
######
# Estimation plotting
def plotfit_transient(
    setup,
    data,
    para,
    rad,
    time,
    radnames,
    extra,
    plotname=None,
    fig=None,
    ax=None,
    style="WTP",
):
    """
    Plot of transient estimation fitting.

    Draws the fitted type curves over radius and time in a 3D axes together
    with the observed data points and their residuals.

    Parameters
    ----------
    setup : object
        Estimation setup providing ``val_fix``, ``para_names``,
        ``get_sim_kwargs`` and ``func``. ``setup.val_fix`` is not modified.
    data : :class:`numpy.ndarray`
        Observed values; column ``ri`` (``data[:, ri]``) belongs to the
        ``ri``-th radius.
    para : dict
        Estimated parameter values by name.
    rad : :class:`numpy.ndarray`
        Radii of the observations.
    time : :class:`numpy.ndarray`
        Time points of the observations.
    radnames : :class:`numpy.ndarray`
        Names per radius; ``radnames[ri, 0]`` groups the colors and only
        rows with equal entries get legend labels.
    extra : dict
        Keyword names of ``setup.func`` for the keys ``"time"`` and
        ``"rad"``.
    plotname : str, optional
        Save the plot as PDF under this name. The default is None.
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    ax : Axes, optional
        3D matplotlib axes to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".

    Returns
    -------
    ax : Axes
        The used matplotlib axes.
    """
    style = copy.deepcopy(plt.rcParams) if style is None else style
    keep_fs = False
    if style == "WTP":
        style1 = "ggplot"
        style2 = "default"
        font_size = plt.rcParams.get("font.size", 10.0)
        # font type fix
        pdf_ft = plt.rcParams.get("pdf.fonttype", 42)
        ps_ft = plt.rcParams.get("ps.fonttype", 42)
        keep_fs = True
    else:
        style1 = style2 = style
    # colors are taken from style1, the plot itself is drawn in style2
    with plt.style.context(style1):
        clrs = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        clr_n = len(clrs)
    with plt.style.context(style2):
        if keep_fs:
            # font type fix (reset in default)
            plt.rcParams.update({"pdf.fonttype": pdf_ft, "ps.fonttype": ps_ft})
            plt.rcParams.update({"font.size": font_size})
        fig, ax = _get_fig_ax(fig, ax, ax_name=Axes3D.name, figsize=(7.5, 7))
        # work on a copy: the setup's fixed values must not be altered
        val_fix = dict(setup.val_fix)
        for kwarg in ["time", "rad"]:
            val_fix.pop(extra[kwarg], None)
        para_ordered = np.empty(len(setup.para_names))
        for i, name in enumerate(setup.para_names):
            para_ordered[i] = para[name]
        para_kw = setup.get_sim_kwargs(para_ordered)
        val_fix.update(para_kw)
        plot_f = ft.partial(setup.func, **val_fix)
        rad_min = rad.min()
        rad_max = rad.max()
        radarr = np.linspace(rad_min, rad_max, 100)
        timarr = np.linspace(time.min(), time.max(), 100)
        t_gen = np.ones_like(radarr)
        r_gen = np.ones_like(time)
        r_gen1 = np.ones_like(timarr)
        xydir = np.zeros_like(time)
        test_name = sorted(np.unique(radnames[:, 0]))
        __, rad_un_idx = np.unique(rad, return_index=True)
        # span used for the alpha fade (zero if only one radius is given)
        rad_span = rad_max - rad_min
        for ri, r_i in enumerate(rad):
            r1 = r_i * r_gen
            r11 = r_i * r_gen1
            h = plot_f(**{extra["time"]: time, extra["rad"]: r_i}).reshape(-1)
            h1 = data[:, ri]
            h2 = plot_f(**{extra["time"]: timarr, extra["rad"]: r_i}).reshape(
                -1
            )
            color = clrs[(test_name.index(radnames[ri, 0]) + 2) % clr_n]
            # fade the type curves with increasing radius (0.6 -> 0.3)
            if rad_span > 0:
                alpha = 0.3 * (1 - (r_i - rad_min) / rad_span) + 0.3
            else:
                alpha = 0.6
            zord = 100 * (len(rad) - ri)
            if radnames[ri, 0] == radnames[ri, 1]:
                label = radnames[ri, 0]
                label_eff = "type curve"
                eff_zord = zord + 100  # first line should be on top
            else:
                label = None
                label_eff = None
                eff_zord = 1
            # draw the type curve only once per distinct radius
            if ri in rad_un_idx:
                ax.plot(
                    r11,
                    timarr,
                    h2,
                    zorder=eff_zord,
                    color="k",
                    alpha=alpha,
                    label=label_eff,
                )
            # residuals: vertical segments from model value to observation
            ax.quiver(
                r1,
                time,
                h,
                xydir,
                xydir,
                h1 - h,
                alpha=0.6,
                arrow_length_ratio=0.0,
                color=color,
                zorder=zord + 30,
            )
            ax.scatter(
                r1,
                time,
                h1,
                depthshade=False,
                zorder=zord + 60,
                color=color,
                label=label,
            )
        # dashed type curves over the radius for every observed time
        for te in time:
            t11 = te * t_gen
            h = plot_f(**{extra["time"]: te, extra["rad"]: radarr}).reshape(-1)
            ax.plot(radarr, t11, h, color="k", alpha=0.1, linestyle="--")
        ax.view_init(elev=30, azim=130)
        ax.set_xlabel(r"$r$ in $\left[\mathrm{m}\right]$", labelpad=20)
        ax.set_ylabel(r"$t$ in $\left[\mathrm{s}\right]$", labelpad=20)
        ax.set_zlabel(r"$\tilde{h}$ in $\left[\mathrm{m}\right]$", labelpad=10)
        _sort_lgd(
            ax,
            loc="lower center",
            markerscale=2,
            bbox_to_anchor=(0.5, -0.1),
            ncol=5,
            columnspacing=1.0,
            handletextpad=0.5,
            handlelength=1.0,
        )
        fig.tight_layout()
        fig.subplots_adjust(top=1, left=0, right=0.9)
        if plotname is not None:
            fig.savefig(plotname, format="pdf")
    return ax
def plotfit_steady(
    setup,
    data,
    para,
    rad,
    radnames,
    extra,
    plotname=None,
    ax_ins=True,
    fig=None,
    ax=None,
    style="WTP",
):
    """
    Plot of steady estimation fitting.

    Draws the fitted type curve over the radius together with the observed
    values and their residuals, optionally with a log-radius inset.

    Parameters
    ----------
    setup : object
        Estimation setup providing ``val_fix``, ``para_names``,
        ``get_sim_kwargs`` and ``func``. ``setup.val_fix`` is not modified.
    data : array-like
        Observed values; ``data[ri]`` belongs to the ``ri``-th radius.
    para : dict
        Estimated parameter values by name.
    rad : :class:`numpy.ndarray`
        Radii of the observations.
    radnames : :class:`numpy.ndarray`
        Names per radius; ``radnames[ri, 0]`` groups the colors and only
        rows with equal entries get legend labels.
    extra : dict
        Keyword name of ``setup.func`` for the key ``"rad"``.
    plotname : str, optional
        Save the plot as PDF under this name. The default is None.
    ax_ins : bool, optional
        Whether to add an inset with logarithmic radius. The default is True.
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    ax : Axes, optional
        Matplotlib axes to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".

    Returns
    -------
    ax : Axes
        The used matplotlib axes.
    """
    # work on a copy: the setup's fixed values must not be altered
    val_fix = dict(setup.val_fix)
    val_fix.pop(extra["rad"], None)
    para_ordered = np.empty(len(setup.para_names))
    for i, name in enumerate(setup.para_names):
        para_ordered[i] = para[name]
    para_kw = setup.get_sim_kwargs(para_ordered)
    val_fix.update(para_kw)
    plot_f = ft.partial(setup.func, **val_fix)
    radarr = np.linspace(rad.min(), rad.max(), 100)
    # fitted type curve, drawn in the main axes and in the inset
    h_curve = plot_f(**{extra["rad"]: radarr})
    test_name = sorted(np.unique(radnames[:, 0]))
    style = copy.deepcopy(plt.rcParams) if style is None else style
    keep_fs = False
    if style == "WTP":
        style = "ggplot"
        font_size = plt.rcParams.get("font.size", 10.0)
        keep_fs = True
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update({"font.size": font_size})
        clrs = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        clr_n = len(clrs)
        fig, ax = _get_fig_ax(fig, ax, figsize=(9, 6))
        if ax_ins:
            axins = ax.inset_axes([0.4, 0.07, 0.57, 0.5])
            axins.plot(radarr, h_curve, alpha=0.6, color="k", zorder=200)
            axins.set_xscale("log")
            axins.set_facecolor("w")
            axins.text(
                0.975,
                0.025,
                "log-radius plot",
                ha="right",
                va="bottom",
                bbox=dict(boxstyle="round", ec="k", fc="w"),
                transform=axins.transAxes,
            )
        for ri, r_i in enumerate(rad):
            h = np.asarray(plot_f(**{extra["rad"]: r_i})).reshape(-1)
            h1 = np.asarray(data[ri]).reshape(-1)
            color = clrs[(test_name.index(radnames[ri, 0]) + 2) % clr_n]
            if radnames[ri, 0] == radnames[ri, 1]:
                label = "test at '{}'".format(radnames[ri, 0])
            else:
                label = None
            # residual line from model value to observation
            ax.plot([r_i, r_i], [h, h1], alpha=0.6, color=color, zorder=100)
            ax.scatter(r_i, data[ri], color=color, label=label, zorder=300)
            if ax_ins:
                axins.plot(
                    [r_i, r_i], [h, h1], alpha=0.6, color=color, zorder=100
                )
                axins.scatter(r_i, data[ri], color=color, zorder=300)
        ax.plot(
            radarr,
            h_curve,
            alpha=0.6,
            color="k",
            zorder=200,
            label="fitted type curve",
        )
        ax.set_xlabel(r"$r$ in $\left[\mathrm{m}\right]$")
        ax.set_ylabel(r"$\tilde{h}$ in $\left[\mathrm{m}\right]$")
        _sort_lgd(ax, loc="upper left", bbox_to_anchor=(1, 1), markerscale=2)
        fig.tight_layout()
        if plotname is not None:
            fig.savefig(plotname, format="pdf")
    return ax
def plotparainteract(result, paranames, plotname=None, fig=None, style="WTP"):
    """
    Plot of parameter interaction.

    Draws a scatter matrix of all fields of ``result`` whose name starts
    with ``"par"``, with kernel density estimates on the diagonal.

    Parameters
    ----------
    result : structured array
        Sampling result holding the ``par*`` fields.
    paranames : list of str
        Axis labels for the parameters (in the order of the ``par*`` fields).
    plotname : str, optional
        Save the plot as PDF under this name. The default is None.
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP" (matplotlib's "default" style
        while keeping the current font size and PDF/PS font types).

    Returns
    -------
    fig : Figure
        The used matplotlib figure.
    """
    if style is None:
        style = copy.deepcopy(plt.rcParams)
    keep_fs = style == "WTP"
    if keep_fs:
        style = "default"
        # remember settings that the "default" style would reset
        font_types = {
            "pdf.fonttype": plt.rcParams.get("pdf.fonttype", 42),
            "ps.fonttype": plt.rcParams.get("ps.fonttype", 42),
        }
        font_size = plt.rcParams.get("font.size", 10.0)
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update(font_types)
            plt.rcParams.update({"font.size": font_size})
        par_fields = [
            field for field in result.dtype.names if field.startswith("par")
        ]
        columns = [result[:][field] for field in par_fields]
        fig = _scatter_matrix(columns, paranames, fig)
        fig.tight_layout()
        fig.subplots_adjust(hspace=0, wspace=0, bottom=0.1)
        if plotname is not None:
            fig.savefig(plotname, format="pdf")
    return fig
def plotparatrace(
    result,
    parameternames=None,
    parameterlabels=None,
    xticks=None,
    stdvalues=None,
    plotname=None,
    fig=None,
    style="WTP",
):
    """
    Plot of parameter trace.

    Parameters
    ----------
    result : structured array
        Sampling result containing a field ``"par" + name`` for every
        parameter name.
    parameternames : list of str, optional
        Names of the parameters to plot (without the ``"par"`` prefix).
        Default: all fields of ``result`` starting with ``"par"``.
    parameterlabels : list of str, optional
        Y-axis labels of the parameters. Default: ``parameternames``.
    xticks : array-like, optional
        Positions of the x-ticks. Default: 11 evenly spaced ticks.
    stdvalues : dict, optional
        Best values by parameter name, drawn as dashed lines.
        The default is None.
    plotname : str, optional
        Save the plot as PDF under this name. The default is None.
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".

    Returns
    -------
    fig : Figure
        The used matplotlib figure.
    """
    if parameternames is None:
        # same field selection as used for the parameter interaction plot
        parameternames = [
            name[3:] for name in result.dtype.names if name.startswith("par")
        ]
    if parameterlabels is None:
        parameterlabels = parameternames
    rep = len(result)
    rows = len(parameternames)
    style = copy.deepcopy(plt.rcParams) if style is None else style
    keep_fs = False
    if style == "WTP":
        style = "ggplot"
        font_size = plt.rcParams.get("font.size", 10.0)
        keep_fs = True
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update({"font.size": font_size})
        clrs = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        fig = _get_fig_ax(fig, ax=False, figsize=(15, 3 * rows))
        axes = []
        for j, name in enumerate(parameternames):
            ax = fig.add_subplot(rows, 1, 1 + j)
            axes.append(ax)
            data = result["par" + name]
            ax.plot(data, "-", color=clrs[0])
            if stdvalues is not None:
                best = stdvalues[name]
                ax.plot(
                    [best] * rep,
                    "--",
                    label="best value: {:04.2f}".format(best),
                    color="k",
                    alpha=0.7,
                )
                ax.legend()
            if xticks is None:
                xticks = np.linspace(0, 1, 11) * len(data)
            ax.set_xlim(0, rep)
            ax.margins(y=0.2)
            ax.xaxis.set_ticks(xticks)
            ax.set_ylabel(parameterlabels[j], fontsize="large")
            ax.set_xlabel("Iterations", fontsize="large")
        fig.align_ylabels(axes)
        fig.tight_layout()
        if plotname is not None:
            fig.savefig(plotname, format="pdf", bbox_inches="tight")
    return fig
def plotsensitivity(
    paralabels, sensitivities, plotname=None, fig=None, ax=None, style="WTP"
):
    """
    Plot of sensitivity results.

    Shows the total sensitivity shares as a ring (donut) chart.

    Parameters
    ----------
    paralabels : list of str
        Legend labels of the parameters.
    sensitivities : dict
        Sensitivity results; the entry ``"ST"`` is plotted.
    plotname : str, optional
        Save the plot as PDF under this name. The default is None.
    fig : Figure, optional
        Matplotlib figure to plot on. The default is None.
    ax : Axes, optional
        Matplotlib axes to plot on. The default is None.
    style : str, optional
        Plot style. The default is "WTP".

    Returns
    -------
    ax : Axes
        The used matplotlib axes.
    """
    if style is None:
        style = copy.deepcopy(plt.rcParams)
    keep_fs = style == "WTP"
    if keep_fs:
        style = "ggplot"
        font_size = plt.rcParams.get("font.size", 10.0)
    with plt.style.context(style):
        if keep_fs:
            plt.rcParams.update({"font.size": font_size})
        fig, ax = _get_fig_ax(fig, ax)
        ring_props = {"linewidth": 1, "edgecolor": "w", "width": 0.5}
        wedges, _texts = ax.pie(
            sensitivities["ST"], wedgeprops=ring_props, startangle=90
        )
        legend = ax.legend(
            wedges,
            paralabels,
            title="Parameters",
            loc="center left",
            bbox_to_anchor=(1, 0, 0.5, 1),
        )
        ax.axis("equal")
        fig.suptitle("FAST total sensitivity shares", fontsize="large")
        fig.tight_layout()
        if plotname is not None:
            # include the outside legend in the saved bounding box
            fig.savefig(
                plotname,
                format="pdf",
                bbox_extra_artists=(legend,),
                bbox_inches="tight",
            )
    return ax
def _scatter_matrix(data, label, fig=None):
    """
    Draw a scatter matrix with kernel density estimates on the diagonal.

    Parameters
    ----------
    data : array-like
        One row of samples per parameter.
    label : list of str
        Axis labels, one per parameter.
    fig : Figure, optional
        Matplotlib figure to draw on. Default: a new 8x8 figure.

    Returns
    -------
    fig : Figure
        The used matplotlib figure.
    """
    data = np.array(data, ndmin=2, dtype=float)
    n = len(data)
    axes = np.empty(n**2, dtype=object)
    for i in range(n**2):
        fig, axes[i] = _get_fig_ax(fig, figsize=(8, 8), sub_args=(n, n, i + 1))
    axes = axes.reshape(n, n)
    boundaries_list = []
    for dat in data:
        rmin, rmax = np.min(dat), np.max(dat)
        rdelta = (rmax - rmin) * 0.025
        if rdelta == 0:
            # constant samples: avoid identical axis limits
            rdelta = 0.025 * max(abs(rmin), 1.0)
        boundaries_list.append((rmin - rdelta, rmax + rdelta))
    for i, a in enumerate(data):
        for j, b in enumerate(data):
            ax = axes[i, j]
            if i == j:
                if a.min() < a.max():
                    ind = np.linspace(a.min(), a.max(), 1000)
                    ax.plot(ind, gaussian_kde(a).evaluate(ind))
                else:
                    # a KDE is undefined for constant samples (singular
                    # covariance), so just mark the single value
                    ax.axvline(a[0])
            else:
                ax.scatter(b, a, marker=".", alpha=0.2, edgecolors="none")
                ax.set_ylim(boundaries_list[i])
            ax.set_xlim(boundaries_list[j])
            ax.set_xlabel(label[j])
            ax.set_ylabel(label[i])
            if j != 0:
                ax.yaxis.set_visible(False)
            if i != n - 1:
                ax.xaxis.set_visible(False)
    # reset labels of first kde plot to match scatter plots
    if n > 1:
        lim1 = boundaries_list[0]
        locs = axes[0, 1].yaxis.get_majorticklocs()
        locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
        # map the scatter tick positions onto the kde axis range
        adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
        lim0 = axes[0, 0].get_ylim()
        adj = adj * (lim0[1] - lim0[0]) + lim0[0]
        axes[0, 0].yaxis.set_ticks(adj)
        locs = locs.astype(int) if np.all(locs == locs.astype(int)) else locs
        axes[0, 0].yaxis.set_ticklabels(locs)
    fig.align_ylabels(axes[:, 0])
    fig.align_xlabels(axes[-1, :])
    for ax in axes[-1, :]:
        plt.setp(ax.get_xticklabels(), rotation=90)
    return fig