Skip to content

Commit

Permalink
Adding option to share value axis in plots
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Jul 22, 2024
1 parent ad756b7 commit dc27b83
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
32 changes: 28 additions & 4 deletions src/dxh.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from dolfinx.geometry import compute_collisions as compute_collisions_points


from matplotlib.colors import Normalize
from matplotlib.tri import Triangulation
from mpi4py import MPI

Expand Down Expand Up @@ -170,6 +171,7 @@ def plot_1d_functions(
points: Optional[NDArray[np.float64]] = None,
axis_size: tuple[float, float] = (5.0, 5.0),
arrangement: Literal["horizontal", "vertical", "stacked"] = "horizontal",
share_value_axis: bool = False,
) -> plt.Figure:
"""
Plot one or more finite element functions on 1D domains using Matplotlib.
Expand All @@ -188,6 +190,8 @@ def plot_1d_functions(
:py:const:`"stacked"` corresponding to respectively plotting functions on
separate axes in a single row, plotting functions on separate axes in a
single column or plotting functions all on a single axis.
share_value_axis: Whether to use a common vertical axis scale (representing
function value) across all subplots.
Returns:
Matplotlib figure object with plotted function(s).
Expand All @@ -206,7 +210,7 @@ def plot_1d_functions(
else:
msg = f"Value {arrangement} for arrangement invalid"
raise ValueError(msg)
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, sharey=share_value_axis)
axes = np.atleast_1d(axes)
for i, (label, function) in enumerate(label_and_functions):
ax = axes[0] if arrangement == "stacked" else axes[i]
Expand Down Expand Up @@ -234,6 +238,7 @@ def plot_2d_functions(
show_colorbar: bool = True,
triangulation_color: Union[str, tuple[float, float, float], None] = None,
arrangement: Literal["horizontal", "vertical"] = "horizontal",
share_value_axis: bool = False,
) -> plt.Figure:
"""
Plot one or more finite element functions on 2D domains using Matplotlib.
Expand Down Expand Up @@ -263,6 +268,8 @@ def plot_2d_functions(
heatmap.
arrangement: Whether to arrange multiple axes vertically in a single column
rather than default of horizontally in a single row.
share_value_axis: Whether to use a common vertical axis scale and/or colormap
normalization (representing function value) across all subplots.
Returns:
Matplotlib figure object with plotted function(s).
Expand All @@ -279,30 +286,47 @@ def plot_2d_functions(
else:
msg = f"Value of arrangement argument {arrangement} is invalid"
raise ValueError(msg)
subplot_kw = {"projection": "3d"} if plot_type == "surface" else {}
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, subplot_kw=subplot_kw)
for ax, (label, function) in zip(np.atleast_1d(axes), label_and_functions):
labels_triangulations_and_function_values = []
min_value, max_value = float("inf"), -float("inf")
for label, function in label_and_functions:
mesh = function.function_space.mesh
if mesh.topology.dim != 2:
msg = "Only two-dimensional spatial domains are supported"
raise ValueError(msg)
triangulation = get_matplotlib_triangulation_from_mesh(mesh)
function_values = evaluate_function_at_points(function, mesh.geometry.x)
min_value = min(min_value, function_values.min())
max_value = max(max_value, function_values.max())
labels_triangulations_and_function_values.append(
(label, triangulation, function_values),
)
subplot_kw = {"projection": "3d"} if plot_type == "surface" else {}
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, subplot_kw=subplot_kw)
normalize = Normalize(vmin=min_value, vmax=max_value) if share_value_axis else None
for ax, (label, triangulation, function_values) in zip(
np.atleast_1d(axes),
labels_triangulations_and_function_values,
strict=True,
):
if plot_type == "surface":
artist = ax.plot_trisurf(
triangulation,
function_values,
cmap=colormap,
norm=normalize,
shade=False,
edgecolor=triangulation_color,
linewidth=None if triangulation_color is None else 0.2,
)
if share_value_axis:
ax.set_zlim(min_value, max_value)
elif plot_type == "pcolor":
artist = ax.tripcolor(
triangulation,
function_values,
shading="gouraud",
cmap=colormap,
norm=normalize,
)
if triangulation_color is not None:
ax.triplot(triangulation, color=triangulation_color, linewidth=1.0)
Expand Down
20 changes: 17 additions & 3 deletions tests/test_dxh.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,14 @@ def _interpolate_functions(function_space, functions):
@pytest.mark.parametrize("degree", [1, 2])
@pytest.mark.parametrize("points", [None, np.linspace(0, 1, 3)])
@pytest.mark.parametrize("arrangement", ["vertical", "horizontal", "stacked"])
def test_plot_1d_functions(number_cells_per_axis, points, degree, arrangement):
@pytest.mark.parametrize("share_value_axis", [True, False])
def test_plot_1d_functions(
number_cells_per_axis,
points,
degree,
arrangement,
share_value_axis,
):
mesh = _create_unit_mesh(1, number_cells_per_axis)
function_space = dolfinx.fem.FunctionSpace(mesh, ("Lagrange", degree))
functions_dict = _interpolate_functions(
Expand All @@ -254,6 +261,7 @@ def test_plot_1d_functions(number_cells_per_axis, points, degree, arrangement):
functions_argument,
points=points,
arrangement=arrangement,
share_value_axis=share_value_axis,
)
assert isinstance(fig, plt.Figure)
number_functions = (
Expand Down Expand Up @@ -298,6 +306,7 @@ def test_plot_1d_functions_invalid_dimension():
[None, "white", "#fff", (1.0, 1.0, 1.0)],
)
@pytest.mark.parametrize("arrangement", ["horizontal", "vertical"])
@pytest.mark.parametrize("share_value_axis", [True, False])
def test_plot_2d_functions(
number_cells_per_axis,
degree,
Expand All @@ -306,6 +315,7 @@ def test_plot_2d_functions(
show_colorbar,
triangulation_color,
arrangement,
share_value_axis,
):
mesh = _create_unit_mesh(2, number_cells_per_axis)
function_space = dolfinx.fem.FunctionSpace(mesh, ("Lagrange", degree))
Expand All @@ -325,14 +335,18 @@ def test_plot_2d_functions(
colormap=colormap,
triangulation_color=triangulation_color,
arrangement=arrangement,
share_value_axis=share_value_axis,
)
assert isinstance(fig, plt.Figure)
number_functions = (
1
if isinstance(functions_argument, dolfinx.fem.Function)
else len(functions_argument)
) * (2 if show_colorbar else 1)
assert len(fig.get_axes()) == number_functions
)
expected_number_axes = (
2 * number_functions if show_colorbar else number_functions
)
assert len(fig.get_axes()) == expected_number_axes
plt.close(fig)


Expand Down

0 comments on commit dc27b83

Please sign in to comment.