# A little helper module for plotting of broadbean objects
from typing import cast
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
from broadbean import BluePrint, Element, Sequence
from broadbean.sequence import SequenceConsistencyError
# The objects we can/want to plot: any of these broadbean types is accepted
# by the plotting helpers in this module.
BBObject = Sequence | BluePrint | Element
def getSIScalingAndPrefix(minmax: tuple[float, float]) -> tuple[float, str]:
    """
    Pick an SI prefix for displaying a signal that spans ``minmax``.

    The prefix is chosen from the largest absolute value in ``minmax``.
    Values of 1 or more get no prefix. Successively smaller magnitudes get
    'm', 'micro ' and 'n' ('n' is used for anything below 1e-6). An
    all-zero range is treated as if its magnitude were 1.
    E.g. (-2e-3, 1e-6) will return (1e3, 'm').

    Args:
        minmax: The (min, max) value of the signal

    Returns:
        A tuple of the scaling (inverse of the prefix) and the prefix
        string.
    """
    peak: float = max(abs(value) for value in minmax)
    decade = np.log10(peak if peak != 0 else 1)
    # Check the finest prefix first; the first threshold that applies wins.
    if decade < -6:
        return (1e9, "n")
    if decade < -3:
        return (1e6, "micro ")
    if decade < 0:
        return (1e3, "m")
    return (1, "")
def _plot_object_validator(obj_to_plot: BBObject) -> None:
    """
    Check that ``obj_to_plot`` is in a plottable state.

    - A Sequence must pass ``checkConsistency(verbose=True)``.
    - An Element has ``validateDurations()`` called on it.
    - A BluePrint must have its ``SR`` set (checked with ``assert``).
    - Objects of any other type are not checked here.

    Raises:
        SequenceConsistencyError: if a Sequence fails its consistency check.
        AssertionError: if a BluePrint has no SR (only when asserts are on).
    """
    if isinstance(obj_to_plot, Sequence):
        if not obj_to_plot.checkConsistency(verbose=True):
            raise SequenceConsistencyError
        return
    if isinstance(obj_to_plot, Element):
        obj_to_plot.validateDurations()
        return
    if isinstance(obj_to_plot, BluePrint):
        assert obj_to_plot.SR is not None
def _plot_object_forger(obj_to_plot: BBObject, **forger_kwargs) -> dict[int, dict]:
    """
    Forge any plottable broadbean object as a sequence.

    Wrapping rules:
    - A BluePrint is added to a new Element on channel 1, which is then
      placed at position 1 of a new Sequence.
    - An Element is placed at position 1 of a new Sequence.
    - In both wrapping cases the object's sample rate is copied to the new
      Sequence.
    - A Sequence is forged as-is.

    Args:
        obj_to_plot: The BluePrint, Element or Sequence to forge.
        **forger_kwargs: Extra keyword arguments for ``Sequence.forge``.
            ``includetime=True`` is always passed, so it must not be given
            here.

    Returns:
        The forged sequence, a dict keyed by sequence position.

    Raises:
        TypeError: if ``obj_to_plot`` is not a BluePrint, Element or Sequence.
    """
    if isinstance(obj_to_plot, BluePrint):
        elem = Element()
        elem.addBluePrint(1, obj_to_plot)
        seq = Sequence()
        seq.addElement(1, elem)
        seq.setSR(obj_to_plot.SR)
    elif isinstance(obj_to_plot, Element):
        seq = Sequence()
        seq.addElement(1, obj_to_plot)
        seq.setSR(obj_to_plot._meta["SR"])
    elif isinstance(obj_to_plot, Sequence):
        seq = obj_to_plot
    else:
        # Without this branch `seq` would be unbound below and the caller
        # would get an obscure UnboundLocalError instead of a clear message.
        raise TypeError(
            "Can only plot a BluePrint, Element or Sequence, "
            f"not {type(obj_to_plot).__name__}"
        )
    return seq.forge(includetime=True, **forger_kwargs)
def _plot_summariser(seq: dict[int, dict]) -> dict[int, dict[str, np.ndarray]]:
"""
Return a plotting summary of a subsequence.
Args:
seq: The 'content' value of a forged sequence where a
subsequence resides
Returns:
A dict that looks like a forged element, but all waveforms
are just two points, np.array([min, max])
"""
output = {}
# we assume correctness, all postions specify the same channels
chans = seq[1]["data"].keys()
minmax = dict(zip(chans, [(0, 0)] * len(chans)))
for element in seq.values():
arr_dict = element["data"]
for chan in chans:
wfm = arr_dict[chan]["wfm"]
if wfm.min() < minmax[chan][0]:
minmax[chan] = (wfm.min(), minmax[chan][1])
if wfm.max() > minmax[chan][1]:
minmax[chan] = (minmax[chan][0], wfm.max())
output[chan] = {
"wfm": np.array(minmax[chan]),
"m1": np.zeros(2),
"m2": np.zeros(2),
"time": np.linspace(0, 1, 2),
}
return output
def _channel_extrema(seq: dict[int, dict], chans) -> list[tuple[float, float]]:
    """Return the (min, max) waveform value of each channel over all positions."""
    extrema = []
    for chan in chans:
        low, high = np.inf, -np.inf
        for pos in range(1, len(seq) + 1):
            entry = seq[pos]
            if entry["type"] == "element":
                wfms = [entry["content"][1]["data"][chan]["wfm"]]
            elif entry["type"] == "subsequence":
                wfms = [
                    sub_elem["data"][chan]["wfm"]
                    for sub_elem in entry["content"].values()
                ]
            else:
                wfms = []
            for wfm in wfms:
                low = min(low, wfm.min())
                high = max(high, wfm.max())
        extrema.append((low, high))
    return extrema


def _plot_marker_track(ax, time, marker, y, color) -> None:
    """Draw a marker as a faint line at height y, drawn solid where it is non-zero."""
    marker = np.asarray(marker)
    # Use float arrays: NaN (which hides the "off" stretches) cannot be stored
    # in an integer array, so integer marker data would otherwise raise here.
    marker_on = np.ones_like(marker, dtype=float)
    marker_on[marker == 0] = np.nan
    marker_off = np.ones_like(marker, dtype=float)
    ax.plot(time, y * marker_off, color=color, alpha=0.2, lw=2)
    ax.plot(time, y * marker_on, color=color, alpha=0.6, lw=2)


def _sequencing_title(seq_info: dict) -> str:
    """Build the compact sequencing summary shown above a sequence position."""
    title = ""
    if seq_info["twait"] == 1:  # trigger wait
        title += "T "
    if seq_info["nrep"] > 1:  # nreps
        title += "\u21bb{} ".format(seq_info["nrep"])
    if seq_info["nrep"] == 0:  # shown as an infinity sign
        title += "\u221e "
    if seq_info["jump_input"] != 0:
        if seq_info["jump_input"] == -1:
            title += "E\u2192 "
        else:
            title += "E{} ".format(seq_info["jump_input"])
    if seq_info["goto"] > 0:
        title += "\u21b1{}".format(seq_info["goto"])
    return title


# the Grand Unified Plotter
def plotter(obj_to_plot: BBObject, **forger_kwargs) -> None:
    """
    The one plot function to be called. Turns whatever it gets
    into a sequence, forges it, and plots that.

    Figure layout:
    - One row per channel and one column per sequence position.
    - All panels of a channel share the same voltage scaling and y-limits.
    - Subsequence positions are drawn as a min/max envelope and labelled
      "SUBSEQ".

    When plotting a Sequence, the top row's titles show each position's
    sequencing settings (trigger wait, repetitions, jump input, goto).

    Args:
        obj_to_plot: The BluePrint, Element or Sequence to plot.
        **forger_kwargs: Passed on to ``Sequence.forge``.
    """
    # TODO: Take axes as input
    # strategy:
    # * Validate
    # * Forge
    # * Plot
    _plot_object_validator(obj_to_plot)
    seq = _plot_object_forger(obj_to_plot, **forger_kwargs)

    # Get the dimensions.
    chans = seq[1]["content"][1]["data"].keys()
    seqlen = len(seq)

    # Then figure out the figure scalings
    chanminmax = _channel_extrema(seq, chans)

    fig, axs = plt.subplots(len(chans), seqlen, squeeze=False)

    # ...and do the plotting
    for chanind, chan in enumerate(chans):
        # The entire channel shares a y-axis, so scaling and limits are
        # computed once per channel.
        (voltagescaling, voltageprefix) = getSIScalingAndPrefix(chanminmax[chanind])
        voltageunit = voltageprefix + "V"
        ymax = voltagescaling * chanminmax[chanind][1]
        ymin = voltagescaling * chanminmax[chanind][0]
        yrange = ymax - ymin
        if yrange == 0:
            # Constant channel (e.g. all zeros): avoid singular y-limits and
            # keep the marker tracks visually separated from the waveform.
            yrange = abs(ymax) if ymax != 0 else 1.0
        ylims = (ymin - 0.05 * yrange, ymax + 0.2 * yrange)

        for pos in range(seqlen):
            entry = seq[pos + 1]
            is_element = entry["type"] == "element"
            is_subseq = entry["type"] == "subsequence"
            ax = cast(matplotlib.axes.Axes, axs[chanind, pos])
            # reduce the tickmark density (must be called before scaling)
            ax.locator_params(tight=True, nbins=4, prune="lower")

            if is_element:
                content = entry["content"][1]["data"][chan]
                wfm = content["wfm"]
                m1 = content.get("m1", np.zeros_like(wfm))
                m2 = content.get("m2", np.zeros_like(wfm))
                time = content["time"]
                newdurs = content.get("newdurations", [])
            else:
                # wfm is the two-point array [min, max] of the subsequence
                wfm = _plot_summariser(entry["content"])[chan]["wfm"]
                newdurs = []
                ax.annotate(
                    "SUBSEQ",
                    xy=(0.5, 0.5),
                    xycoords="axes fraction",
                    horizontalalignment="center",
                )
                time = np.linspace(0, 1, 2)  # needed for the time scaling

            # Figure out the axes' time scaling (same thresholds as voltage)
            (timescaling, timeprefix) = getSIScalingAndPrefix((0, time.max()))
            timeunit = timeprefix + "s"

            if is_element:
                ax.plot(
                    timescaling * time,
                    voltagescaling * wfm,
                    lw=3,
                    color=(0.6, 0.4, 0.3),
                    alpha=0.4,
                )
            ax.set_ylim(ylims)

            if is_element:
                # TODO: make this work for more than two markers
                # marker 1 (red, on top)
                _plot_marker_track(
                    ax, timescaling * time, m1, ymax + 0.15 * yrange, (0.6, 0.1, 0.1)
                )
                # marker 2 (blue, below the red)
                _plot_marker_track(
                    ax, timescaling * time, m2, ymax + 0.10 * yrange, (0.1, 0.1, 0.6)
                )

            # If subsequence, plot lines indicating min and max value.
            # They must use the channel's voltage scaling to match the y-limits.
            if is_subseq:
                for level in wfm:  # min, then max
                    ax.plot(
                        time,
                        np.ones_like(time) * (voltagescaling * level),
                        color=(0.12, 0.12, 0.12),
                        alpha=0.2,
                        lw=2,
                    )
                ax.set_xticks([])

            # time step lines
            ylow, yhigh = ax.get_ylim()
            for dur in np.cumsum(newdurs):
                ax.plot(
                    [timescaling * dur, timescaling * dur],
                    [ylow, yhigh],
                    color=(0.312, 0.2, 0.33),
                    alpha=0.3,
                )

            # labels
            if pos == 0:
                ax.set_ylabel(f"({voltageunit})")
            if pos == seqlen - 1 and not isinstance(obj_to_plot, BluePrint):
                newax = ax.twinx()
                newax.set_yticks([])
                if isinstance(chan, int):
                    new_ylabel = f"Ch. {chan}"
                else:
                    # str channel names are used as-is; anything else is
                    # stringified instead of leaving the label unbound.
                    new_ylabel = str(chan)
                newax.set_ylabel(new_ylabel)
            if is_subseq:
                ax.set_xlabel("Time N/A")
            else:
                ax.set_xlabel(f"({timeunit})")

            # remove excess space from the plot
            if chanind + 1 != len(chans):
                ax.set_xticks([])
            if pos != 0:
                ax.set_yticks([])

            # display sequencer information
            if chanind == 0 and isinstance(obj_to_plot, Sequence):
                ax.set_title(_sequencing_title(entry["sequencing"]))

    fig.subplots_adjust(hspace=0, wspace=0)