"""
Plotting utilities based on `matplotlib <http://matplotlib.org/>`_.
Some default values are configurable via the dictionary ``arim.plot.conf``.
.. py:data:: conf
Dictionary of default values. For some functions, if an argument is not populated,
its values will be populated from this dictionary. Example::
# save the figure (independently of conf['savefig'])
plot_oxz(data, grid, savefig=True, filename='foo')
# do not save the figure (independently of conf['savefig'])
plot_oxz(data, grid, savefig=False, filename='foo')
# save the figure only if conf['savefig'] is True
plot_oxz(data, grid, filename='foo')
.. py:data:: micro_formatter
.. py:data:: milli_formatter
.. py:data:: mega_formatter
Format the labels of an axis in a given unit prefix. Usage::
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(distance_vector, data)
ax.xaxis.set_major_formatter(arim.plot.milli_formatter)
"""
import logging
from warnings import warn
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal
from matplotlib import ticker
from mpl_toolkits import axes_grid1
from . import geometry as g
from . import ut
from .config import Config
from .exceptions import ArimWarning
# Names exported by a star-import of this module.
# NOTE(review): ``plot_psd`` and ``plot_interfaces`` are defined in this module
# but are not listed here -- confirm whether the omission is intentional.
__all__ = [
    "micro_formatter",
    "mega_formatter",
    "milli_formatter",
    "plot_bscan",
    "plot_bscan_pulse_echo",
    "plot_oxz",
    "plot_oxz_many",
    "plot_tfm",
    "plot_directivity_finite_width_2d",
    "draw_rays_on_click",
    "RayPlotter",
    "conf",
    "common_dynamic_db_scale",
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _scaling_formatter(factor, doc):
    """Build a tick formatter showing ``value * factor`` with one decimal.

    ``doc`` is attached to the returned formatter as its ``__doc__``.
    """
    formatter = ticker.FuncFormatter(lambda x, pos: f"{x * factor:.1f}")
    formatter.__doc__ = doc
    return formatter


# Tick formatters: the axis data stays in base units, only the labels are rescaled.
micro_formatter = _scaling_formatter(
    1e6,
    "Format an axis to micro (µ).\nExample: ``ax.xaxis.set_major_formatter(micro_formatter)``",
)
milli_formatter = _scaling_formatter(
    1e3,
    "Format an axis to milli (m).\nExample: ``ax.xaxis.set_major_formatter(milli_formatter)``",
)
mega_formatter = _scaling_formatter(
    1e-6,
    "Format an axis to mega (M).\nExample: ``ax.xaxis.set_major_formatter(mega_formatter)``",
)

# Default values looked up by the plotting functions when the matching
# argument is left to None.
conf = Config(
    [
        ("savefig", False),  # save the figure?
        ("plot_oxz.figsize", None),
        ("plot_oxz_many.figsize", None),
    ]
)
[docs]
def plot_bscan(
    frame,
    timetraces_idx,
    use_dB=True,
    ax=None,
    title="B-scan",
    clim=None,
    interpolation="none",
    draw_cbar=True,
    cmap=None,
    savefig=None,
    filename="bscan",
):
    """Plot a B-scan (timetraces vs time).

    Parameters
    ----------
    frame : Frame
        Must expose ``timetraces``, ``time`` (with ``start`` and ``end``),
        ``tx`` and ``rx``.
    timetraces_idx : slice or tuple or ndarray
        Timetraces to use. Any valid numpy index is accepted; it is applied to
        ``frame.timetraces``, ``frame.tx`` and ``frame.rx``.
    use_dB : bool, optional
        If True (default), the timetraces are converted with ``ut.decibel``
        before plotting.
    ax : matplotlib axis, optional
        Where to draw. Default: create a new figure and axis.
    title : str or None, optional
        Title of the image (default: "B-scan"). None for no title.
    clim : tuple, optional
        Color limits of the image. Default: ``[-40.0, 0.0]`` if ``use_dB`` is
        True, otherwise left to matplotlib.
    interpolation : str, optional
        Image interpolation type (default: "none")
    draw_cbar : bool, optional
        Draw a colorbar next to the image (default: True).
    cmap : str, optional
        Colormap passed to ``imshow``.
    savefig : bool, optional
        Default: use ``conf["savefig"]``
    filename : str, optional
        Default: "bscan"

    Returns
    -------
    ax : matplotlib axis
    im : matplotlib image

    Examples
    --------
    >>> arim.plot.plot_bscan(frame, frame.tx == 0)

    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    if savefig is None:
        savefig = conf["savefig"]

    timetraces = frame.timetraces[timetraces_idx]
    numtimetraces = timetraces.shape[0]
    if use_dB:
        timetraces = ut.decibel(timetraces)
        if clim is None:
            clim = [-40.0, 0.0]

    im = ax.imshow(
        timetraces,
        extent=[frame.time.start, frame.time.end, 0, numtimetraces - 1],
        interpolation=interpolation,
        cmap=cmap,
        origin="lower",
    )
    ax.set_xlabel("Time (µs)")
    ax.set_ylabel("TX/RX index")
    ax.xaxis.set_major_formatter(micro_formatter)
    ax.xaxis.set_minor_formatter(micro_formatter)

    # Use element index instead of timetrace index (may be different)
    tx = frame.tx[timetraces_idx]
    rx = frame.rx[timetraces_idx]

    def _y_formatter(i, pos):
        i = int(i)
        # A negative index would silently wrap around to the end of tx/rx and
        # label a tick that lies outside the data.
        if i < 0:
            return ""
        try:
            return f"({tx[i]}, {rx[i]})"
        except IndexError:
            return ""

    y_formatter = ticker.FuncFormatter(_y_formatter)
    ax.yaxis.set_major_formatter(y_formatter)
    ax.yaxis.set_minor_formatter(y_formatter)

    if draw_cbar:
        fig.colorbar(im, ax=ax)
    if clim is not None:
        im.set_clim(clim)
    if title is not None:
        ax.set_title(title)

    ax.axis("tight")
    if savefig:
        fig.savefig(filename)
    return ax, im
[docs]
def plot_bscan_pulse_echo(
    frame,
    use_dB=True,
    ax=None,
    title="B-scan (pulse-echo)",
    clim=None,
    interpolation="none",
    draw_cbar=True,
    cmap=None,
    savefig=None,
    filename="bscan",
):
    """
    Plot a B-scan. Use the pulse-echo timetraces.

    The pulse-echo timetraces are those for which ``frame.tx == frame.rx``.
    The y-axis is labelled with the corresponding element indices.

    Parameters
    ----------
    frame : Frame
    use_dB : bool, optional
    ax : matplotlib axis, optional
        Where to draw. Default: create a new figure and axis.
    title : str or None, optional
    clim : tuple, optional
    interpolation : str, optional
    draw_cbar : bool, optional
    cmap : str, optional
    savefig : bool, optional
        Default: use ``conf["savefig"]``
    filename : str, optional
        Default: "bscan"

    Returns
    -------
    axis, image

    See Also
    --------
    :func:`plot_bscan`

    """
    # Resolve the default here: plot_bscan is told not to save, so it would
    # never look the default up on our behalf.
    if savefig is None:
        savefig = conf["savefig"]

    pulse_echo = frame.tx == frame.rx
    elements = frame.tx[pulse_echo]

    ax, im = plot_bscan(
        frame,
        pulse_echo,
        use_dB=use_dB,
        ax=ax,
        title=title,
        clim=clim,
        interpolation=interpolation,
        draw_cbar=draw_cbar,
        cmap=cmap,
        savefig=False,  # save later, once the axis has been relabelled
        filename=filename,
    )
    ax.set_ylabel("Element")

    num_elements = len(elements)

    # Use element index instead of timetrace index (may be different)
    def _y_formatter(i, pos):
        i = int(i)
        # Reject negative positions too: they would wrap around in elements[i].
        if 0 <= i < num_elements:
            return str(elements[i])
        return ""

    y_formatter = ticker.FuncFormatter(_y_formatter)
    ax.yaxis.set_major_formatter(y_formatter)
    ax.yaxis.set_minor_formatter(y_formatter)

    if savefig:
        ax.figure.savefig(filename)
    return ax, im
[docs]
def plot_psd(
    frame,
    idx="all",
    to_show="filtered",
    welch_params=None,
    ax=None,
    title="Power spectrum estimation",
    show_legend=True,
    savefig=None,
    filename="psd",
):
    """
    Plot the estimated power spectrum of a timetrace using Welch's method.

    Parameters
    ----------
    frame : Frame
    idx : int or slice or list
        Index or indices of the timetrace to use. If multiple indices are given,
        the arithmetical mean of all PSDs is plotted. Default: use all
    to_show : str
        'filtered' (``frame.timetraces``), 'raw' (``frame.timetraces_raw``) or
        'both'. Case-insensitive. Default: 'filtered'
    welch_params : dict
        Arguments to pass to ``scipy.signal.welch``.
    ax : matplotlib.axes.Axes or None
        Where to draw. Default: create a new figure and axis.
    title : str or None
        Title of the plot. None for no title.
    show_legend : bool
        Default: True
    savefig : bool or None
        Default: use ``conf["savefig"]``
    filename : str
        Default: "psd"

    Returns
    -------
    ax : matplotlib.axes.Axes
    lines : dict
        Maps 'raw' and/or 'filtered' to the value returned by ``ax.plot``.

    Raises
    ------
    ValueError
        If ``to_show`` is not one of the accepted values.

    """
    # Validate before touching matplotlib so that a bad argument does not
    # leave an empty figure behind.
    to_show = to_show.lower()
    if to_show not in ("both", "raw", "filtered"):
        raise ValueError("Valid values for 'to_show' are: filtered, raw, both")
    show_raw = to_show in ("both", "raw")
    show_filtered = to_show in ("both", "filtered")

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    if welch_params is None:
        welch_params = {}
    if savefig is None:
        savefig = conf["savefig"]
    if isinstance(idx, str) and idx == "all":
        idx = slice(None)

    fs = 1 / frame.time.step

    def _plot_mean_psd(timetraces, label):
        # Only the real part of the selected timetraces is analysed.
        freq, pxx = scipy.signal.welch(timetraces[idx].real, fs, **welch_params)
        if pxx.ndim == 2:
            # several timetraces selected: average their PSDs
            pxx = np.mean(pxx, axis=0)
        return ax.plot(freq, pxx, label=label)

    lines = {}
    if show_raw:
        lines["raw"] = _plot_mean_psd(frame.timetraces_raw, "raw")
    if show_filtered:
        lines["filtered"] = _plot_mean_psd(frame.timetraces, "filtered")

    ax.set_xlabel("frequency (MHz)")
    ax.set_ylabel("power spectrum estimation")
    ax.xaxis.set_major_formatter(mega_formatter)
    ax.xaxis.set_minor_formatter(mega_formatter)

    if title is not None:
        ax.set_title(title)
    if show_legend:
        ax.legend(loc="best")
    if savefig:
        fig.savefig(filename)
    return ax, lines
[docs]
def plot_oxz(
    data,
    grid,
    ax=None,
    title=None,
    clim=None,
    interpolation="none",
    draw_cbar=True,
    cmap=None,
    figsize=None,
    savefig=None,
    patches=None,
    filename=None,
    scale="linear",
    ref_db=None,
):
    """
    Plot data in the plane Oxz.

    Parameters
    ----------
    data : ndarray
        Shape: 2D matrix ``(grid.numx, grid.numz)`` or 3D matrix ``(grid.numx, 1, grid.numz)``
        or 1D matrix ``(grid.numx * grid.numz)``
    grid : Grid
    ax : matplotlib.Axis or None
        Axis where to plot. Default: create a new figure and axis.
    title : str or None
    clim : List[Float] or None
    interpolation : str or None
    draw_cbar : boolean
    cmap
    figsize : List[Float] or None
        Default: ``conf['plot_oxz.figsize']``. Ignored (with a warning) if
        ``ax`` is provided.
    savefig : boolean or None
        If True, save the figure. Default: ``conf['savefig']``
    patches : List[matplotlib.patches.Patch] or None
        Patches to draw
    filename : str or None
        Where to save the figure. Required if the figure is saved.
    scale : str or None
        'linear' or 'db' (case-insensitive). Default: 'linear'
    ref_db : float or None
        Value for 0 dB. Used only for scale=db.

    Returns
    -------
    axis
    image

    Raises
    ------
    ValueError
        If ``scale`` or the shape of ``data`` is invalid, or if the figure
        must be saved and no ``filename`` is given.

    Examples
    --------
    ::

        grid = arim.geometry.Grid(-5e-3, 5e-3, 0, 0, 0, 15e-3, .1e-3)
        k = 2 * np.pi / 10e-3
        data = (np.cos(grid.x * 2 * k) * np.sin(grid.z * k))
        ax, im = aplt.plot_oxz(data, grid)

    """
    # Validate every argument before touching matplotlib, so that bad input
    # neither creates a stray figure nor half-draws on a caller-provided axis.
    scale = "linear" if scale is None else scale.lower()
    if scale not in ("linear", "db"):
        raise ValueError(f"invalid scale: {scale}")

    valid_shapes = [
        (grid.numx, 1, grid.numz),
        (grid.numx, grid.numz),
        (grid.numx * grid.numz,),
    ]
    if data.shape not in valid_shapes:
        msg = "invalid data shape (got {}, expected {} or {} or {})".format(
            data.shape, *valid_shapes
        )
        raise ValueError(msg)

    if savefig is None:
        savefig = conf["savefig"]
    if savefig and filename is None:
        raise ValueError("filename must be provided when savefig is true")

    if figsize is None:
        figsize = conf["plot_oxz.figsize"]
    elif ax is not None:
        warn(
            "figsize is ignored because an axis is provided",
            ArimWarning,
            stacklevel=2,
        )

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
    if patches is None:
        patches = []

    # After the rotation, rows run along z and columns along x.
    data = np.rot90(data.reshape((grid.numx, grid.numz)))

    if scale == "db":
        data = ut.decibel(data, ref_db)
    elif ref_db is not None:
        warn("ref_db is ignored for linear plot", ArimWarning, stacklevel=2)

    image = ax.imshow(
        data,
        interpolation=interpolation,
        origin="lower",
        extent=(grid.xmin, grid.xmax, grid.zmax, grid.zmin),
        cmap=cmap,
    )
    if ax.get_xlabel() == "":
        # avoid overwriting labels
        ax.set_xlabel("x (mm)")
    if ax.get_ylabel() == "":
        ax.set_ylabel("z (mm)")
    ax.xaxis.set_major_formatter(milli_formatter)
    ax.xaxis.set_minor_formatter(milli_formatter)
    ax.yaxis.set_major_formatter(milli_formatter)
    ax.yaxis.set_minor_formatter(milli_formatter)

    if draw_cbar:
        fig.colorbar(image, ax=ax)
    if clim is not None:
        image.set_clim(clim)
    if title is not None:
        ax.set_title(title)
    for p in patches:
        ax.add_patch(p)

    # Like the image extent, the y limits are given as (zmax, zmin).
    ax.set_aspect(aspect="equal", adjustable="box")
    ax.axis([grid.xmin, grid.xmax, grid.zmax, grid.zmin])

    if savefig:
        fig.savefig(filename)
    return ax, image
[docs]
def plot_oxz_many(
    data_list,
    grid,
    nrows,
    ncols,
    title_list=None,
    suptitle=None,
    draw_colorbar=True,
    figsize=None,
    savefig=None,
    clim=None,
    filename=None,
    y_title=1.0,
    y_suptitle=1.0,
    axes_pad=0.1,
    **plot_oxz_kwargs,
):
    """
    Plot many Oxz plots on the same figure.

    Parameters
    ----------
    data_list : List[ndarray]
        Data are plotted from top left to bottom right, row per row.
    grid : Grid
    nrows : int
    ncols : int
    title_list : List[str] or None
        One title per data array. Default: no titles.
    suptitle : str or None
    draw_colorbar : boolean
        Default: True
    figsize : List[Float] or None
        Default: ``conf['plot_oxz_many.figsize']``
    savefig: boolean
        Default: ``conf['savefig']``
    clim :
        Color limit. Common for all plots. Default: the overall (nan-ignoring)
        minimum and maximum of all data arrays.
    filename : str or None
        Where to save the figure. Required if the figure is saved.
    y_title : float
        Adjust y location of the titles.
    y_suptitle : float
        Adjust y location of the suptitle.
    axes_pad : float
        Pad between images in inches
    plot_oxz_kwargs
        Extra keyword arguments forwarded to :func:`plot_oxz`.

    Returns
    -------
    axes_grid : axes_grid1.ImageGrid
    im_list

    Raises
    ------
    ValueError
        If the figure must be saved and no ``filename`` is given.

    """
    # Materialise once: the data are traversed several times below.
    data_list = list(data_list)

    if savefig is None:
        savefig = conf["savefig"]
    if savefig and filename is None:
        # fail before creating the figure
        raise ValueError("filename must be provided when savefig is true")
    if figsize is None:
        figsize = conf["plot_oxz_many.figsize"]
    if title_list is None:
        title_list = [None] * len(data_list)

    # must use a common clim (otherwise the figure does not make sense)
    if clim is None and data_list:
        clim = (
            min(np.nanmin(x) for x in data_list),
            max(np.nanmax(x) for x in data_list),
        )

    cbar_mode = "single" if draw_colorbar else None

    fig = plt.figure(figsize=figsize)
    axes_grid = axes_grid1.ImageGrid(
        fig,
        111,
        nrows_ncols=(nrows, ncols),
        axes_pad=axes_pad,
        share_all=True,
        cbar_mode=cbar_mode,
    )

    images = []
    for data, title, ax in zip(data_list, title_list, axes_grid):
        # the current function handles saving fig, drawing the cbar and displaying the title
        # so we prevent plot_oxz to do it.
        ax, im = plot_oxz(
            data,
            grid,
            ax=ax,
            clim=clim,
            draw_cbar=False,
            savefig=False,
            **plot_oxz_kwargs,
            title=None,
        )
        images.append(im)
        if title is not None:
            ax.set_title(title, y=y_title)

    if suptitle is not None:
        fig.suptitle(suptitle, y=y_suptitle, size="x-large")

    # All images share the same clim, so any of them can drive the colorbar.
    # Nothing to draw if no image was plotted.
    if draw_colorbar and images:
        fig.colorbar(images[-1], cax=axes_grid.cbar_axes[0])

    if savefig:
        fig.savefig(filename)
    return axes_grid, images
[docs]
def plot_tfm(tfm, y=0.0, func_res=None, interpolation="bilinear", **plot_oxz_kwargs):
    """
    Plot a TFM result in the plane Oxz.

    The slice of ``tfm.res`` at the index minimising ``|tfm.grid.y - y|`` is
    passed to :func:`plot_oxz`.

    Parameters
    ----------
    tfm : BaseTFM
    y : float
        Wanted y coordinate of the plane. Default: 0.
    interpolation : str
        Cf matplotlib.pyplot.imshow
    func_res : function
        Function to apply on the slice of ``tfm.res`` before plotting it.
        Example: ``lambda x: np.abs(x)``. Default: plot the slice unchanged.
    plot_oxz_kwargs : dict
        Extra keyword arguments forwarded to :func:`plot_oxz`.

    Returns
    -------
    ax
    image

    Raises
    ------
    ValueError
        If ``tfm.res`` is None.

    See Also
    --------
    :func:`plot_oxz`

    """
    grid = tfm.grid
    y_index = np.argmin(np.abs(grid.y - y))

    if tfm.res is None:
        raise ValueError("No result in this TFM object.")

    plane = tfm.res[:, y_index, :]
    data = plane if func_res is None else func_res(plane)
    return plot_oxz(data, grid=grid, interpolation=interpolation, **plot_oxz_kwargs)
[docs]
def plot_directivity_finite_width_2d(element_width, wavelength, ax=None, **kwargs):
    """
    Plot the directivity of an element against the polar angle.

    The directivity is computed with ``ut.directivity_finite_width_2d`` for
    100 angles between -90 and 90 degrees.

    Parameters
    ----------
    element_width
    wavelength
    ax : matplotlib.axes._subplots.AxesSubplot
        Where to draw. Default: create a new figure and axis.
    kwargs
        ``title`` sets the title of the axis and ``label`` overrides the
        default legend entry (the ratio ``element_width / wavelength``).
        Remaining keyword arguments are forwarded to ``ax.plot``.

    Returns
    -------
    ax

    """
    if ax is None:
        _, ax = plt.subplots()

    # 'title' is meant for the axis: pop it so that it is not forwarded to
    # ax.plot, which rejects it.
    title = kwargs.pop(
        "title", "Directivity of an element (uniform sources along a straight line)"
    )

    ratio = element_width / wavelength
    # setdefault lets the caller override the label without a
    # duplicate-keyword TypeError.
    kwargs.setdefault("label", rf"$a/\lambda = {ratio:.2f}$")

    theta = np.linspace(-np.pi / 2, np.pi / 2, 100)
    directivity = ut.directivity_finite_width_2d(theta, element_width, wavelength)

    ax.plot(np.rad2deg(theta), directivity, **kwargs)
    ax.set_xlabel(r"Polar angle $\theta$ (deg)")
    ax.set_ylabel("directivity (1)")
    ax.set_title(title)
    ax.set_xlim([-90, 90])
    ax.set_ylim([0, 1.2])
    ax.xaxis.set_major_locator(ticker.MultipleLocator(30.0))
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(15.0))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))
    ax.legend()
    return ax
[docs]
class RayPlotter:
    """
    Matplotlib mouse-click callback which draws and clears rays.

    Left-click (button 1) draws the ray between the element ``element_index``
    and the grid point closest to the click; right-click (button 3) removes
    the rays drawn by this object.

    Parameters
    ----------
    grid : Grid
    ray : Rays
    element_index : int
    linestyle : str
        Matplotlib format string used for the rays. Default: 'm--'
    tolerance_distance : float
        A warning is logged if the closest grid point is farther than this
        distance from the clicked point. Default: 1e-3

    Attributes
    ----------
    y : float
        y coordinate used for the clicked point. Default: 0
    debug : bool
        If True, print a message after each handled click. Default: False

    """

    def __init__(
        self, grid, ray, element_index, linestyle="m--", tolerance_distance=1e-3
    ):
        self.grid = grid
        self.ray = ray
        self.element_index = element_index
        self.linestyle = linestyle
        self._lines = []  # lines drawn by this object (and only those)
        self.debug = False
        self.y = 0
        self.tolerance_distance = tolerance_distance

    def __call__(self, event):
        """Handle a matplotlib mouse event."""
        if event.xdata is None or event.ydata is None:
            # The click is outside of any axes: there is no data coordinate
            # to work with.
            return
        logger.debug(
            "button=%d, x=%d, y=%d, xdata=%f, ydata=%f",
            event.button,
            event.x,
            event.y,
            event.xdata,
            event.ydata,
        )
        # The rays are always drawn on the first axes of the figure.
        ax = event.canvas.figure.axes[0]
        if event.button == 1:
            self.draw_ray(ax, event.xdata, event.ydata)
        elif event.button == 3:
            self.clear_rays(ax)
        if self.debug:
            print("show_ray_on_clic() finish with no error")

    def draw_ray(self, ax, x, z):
        """Draw the ray between the element and the grid point closest to (x, z)."""
        gridpoints = self.grid.to_1d_points()
        wanted_point = (x, self.y, z)
        point_index = gridpoints.closest_point(*wanted_point)
        obtained_point = gridpoints[point_index]

        distance = g.norm2(*(obtained_point - wanted_point))
        if distance > self.tolerance_distance:
            logger.warning(
                "The closest grid point is far from what you want (dist: %.2f mm)",
                distance * 1000,
            )

        legs = self.ray.get_coordinates_one(self.element_index, point_index)
        line = ax.plot(legs.x, legs.z, self.linestyle)
        self._lines.extend(line)
        logger.debug("Draw a ray")
        ax.figure.canvas.draw_idle()

    def clear_rays(self, ax):
        """Clear all rays"""
        lines_to_clear = [line for line in ax.lines if line in self._lines]
        for line in lines_to_clear:
            # Artist.remove() is the supported way to take a line off its
            # axes; ``ax.lines`` cannot be modified in current Matplotlib.
            line.remove()
            self._lines.remove(line)
        logger.debug("Clear %d ray(s) on figure", len(lines_to_clear))
        ax.figure.canvas.draw_idle()

    def connect(self, ax):
        """Connect to matplotlib event backend"""
        ax.figure.canvas.mpl_connect("button_press_event", self)
[docs]
def draw_rays_on_click(grid, ray, element_index, ax=None, linestyle="m--"):
    """
    Enable interactive drawing of rays on a plot.

    A :class:`RayPlotter` is created and connected to the figure of ``ax``:

    - left-click: draw a ray between the probe element and the mouse point,
    - right-click: clear all rays in the plot.

    Parameters
    ----------
    grid : Grid
    ray : Rays
    element_index : int
    ax : Axis
        Matplotlib axis on which to plot. If None: current axis.
    linestyle : str
        A valid matplotlib linestyle. Default: 'm--'

    Returns
    -------
    ray_plotter : RayPlotter

    """
    target_ax = plt.gca() if ax is None else ax
    plotter = RayPlotter(
        grid=grid,
        ray=ray,
        element_index=element_index,
        linestyle=linestyle,
    )
    plotter.connect(target_ax)
    return plotter
[docs]
def plot_interfaces(
    oriented_points_list,
    ax=None,
    show_probe=True,
    show_last=True,
    show_orientations=False,
    n_arrows=10,
    title="Interfaces",
    savefig=None,
    filename="interfaces",
    markers=None,
    show_legend=True,
    quiver_kwargs=None,
):
    """
    Draw a set of interfaces in the Oxz plane.

    The first interface is assumed to be the probe and the last one the grid.

    Parameters
    ----------
    oriented_points_list : list[OrientedPoints]
    ax : matplotlib.axis.Axis
        Where to draw. Default: create a new figure and axis.
    show_probe : boolean
        Draw the first interface. Default True
    show_last : boolean
        Draw the last interface. Default: True. Useful for hiding the grid.
    show_orientations : boolean
        Plot arrows for the orientations. Default: False
    n_arrows : int
        Number of arrows per interface to plot.
    title : str or None
        Title to display. None for no title.
    savefig : boolean
        If True, the plot will be saved. Default: ``conf['savefig']``.
    filename : str
        Filename of the plot, used if savefig is True. Default: 'interfaces'
    markers : List[str]
        Matplotlib markers for each interfaces. Default: '.' for probe, ',k' for the grid,
        '.' for the rest.
    show_legend : boolean
        Default True
    quiver_kwargs : dict
        Arguments for displaying the arrows (cf. matplotlib function 'quiver')

    Returns
    -------
    ax : matplotlib.axis.Axis

    """
    if savefig is None:
        savefig = conf["savefig"]

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    if quiver_kwargs is None:
        quiver_kwargs = {"width": 0.0003}

    numinterfaces = len(oriented_points_list)
    last_index = numinterfaces - 1
    if markers is None:
        markers = ["."] + ["."] * (numinterfaces - 2) + [",k"]

    for i, (interface, marker) in enumerate(zip(oriented_points_list, markers)):
        hide_probe = i == 0 and not show_probe
        hide_last = i == last_index and not show_last
        if hide_probe or hide_last:
            continue

        points = interface.points
        (line,) = ax.plot(points.x, points.z, marker, label=points.name)

        if not show_orientations:
            continue

        # one arrow every `step` points (at least every point)
        step = len(points) // n_arrows or 1
        orientations = interface.orientations
        ax.quiver(
            points.x[::step],
            points.z[::step],
            orientations.x[::step, 2],
            orientations.z[::step, 2],
            color=line.get_color(),
            units="xy",
            angles="xy",
            **quiver_kwargs,
        )

    # keep the labels already present on the axis
    if ax.get_xlabel() == "":
        ax.set_xlabel("x (mm)")
    if ax.get_ylabel() == "":
        ax.set_ylabel("z (mm)")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(milli_formatter)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_minor_formatter(milli_formatter)

    if title is not None:
        ax.set_title(title)

    # flip the y axis unless it is already decreasing upwards
    bottom, top = ax.get_ylim()
    if bottom < top:
        ax.invert_yaxis()

    if show_legend:
        ax.legend(loc="best")
    ax.axis("equal")

    if savefig:
        fig.savefig(filename)
    return ax
[docs]
def common_dynamic_db_scale(data_list, area=None, db_range=40.0, ref_db=None):
    """
    Generate a common dB reference and per-array color limits.

    The scale is such that:

    - 0 dB corresponds to the maximum value in the area for all data arrays
      (unless ``ref_db`` is given),
    - the clim for each data array are bound by the maximum value in the area.

    Parameters
    ----------
    data_list
        Data arrays. The maximum of the absolute value (ignoring NaN) is
        taken for each of them.
    area
        Index applied to each data array to select the region of interest.
        Default: the whole array.
    db_range : float
        Width of each color range, in dB. Default: 40
    ref_db : float or None
        Value for 0 dB. Default: the largest of the maxima.

    Yields
    ------
    ref_db
    (clim_min, clim_max)

    Examples
    --------
    >>> area = grid.points_in_rectbox(xmin=10, xmax=20)
    >>> common_db_scale_iter = common_dynamic_db_scale(data_list, area)
    >>> for data in data_list:
    ...     ref_db, clim = next(common_db_scale_iter)
    ...     plot_oxz(data, grid, scale='db', ref_db=ref_db, clim=clim)

    """
    region = slice(None) if area is None else area
    peaks = [np.nanmax(np.abs(data[region])) for data in data_list]

    if ref_db is None:
        ref_db = max(peaks)

    for peak_db in ut.decibel(peaks, ref_db):
        yield ref_db, (peak_db - db_range, peak_db)