Source code for flamingpy.utils.viz

# Copyright 2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to draw various graphs and generate plots using Matplotlib.

Plots are configured via the ``plot_params`` dictionary. These parameters
are associated with Matplotlib's rc settings and are applied within the
plotting functions using the ``rc_context`` context manager. This approach
avoids having to modify the global Matplotlib ``rcParams``.

To modify the plot parameters, use, for example,

.. code-block:: python

    from flamingpy.utils.viz import plot_params as fp_plot_params
    fp_plot_params["font.size"] = 20
"""

# pylint: disable=too-many-statements,singleton-comparison, too-many-lines

import itertools as it
import math

import numpy as np
import networkx as nx
import plotly.graph_objects as go
import matplotlib as mpl
from matplotlib.patches import Patch
from matplotlib.ticker import Formatter
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from flamingpy.codes import Stabilizer
from flamingpy.cv import gkp

plot_params = {
    "font.size": 10,
    "font.family": "serif",
    "axes.labelsize": 11,
    "axes.titlesize": 13,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "grid.color": "lightgray",
    "lines.markersize": 5,
    "lines.linewidth": 4,
    "figure.figsize": (8, 6),
}


@mpl.rc_context(plot_params)
def plot_integer_part(xs, ns, alpha, show=True):
    """Plot the integer part of real numbers mod alpha.

    Args:
        xs (array-like): the real numbers. The first and last entries set the
            x-range, so ``xs`` is presumably sorted in ascending order.
        ns (array-like): the integer parts corresponding to ``xs``.
        alpha (float): the modulus; x-ticks are placed at multiples of it.
        show (bool): if True, call ``plt.show()``.

    Returns:
        tuple: the Matplotlib Figure and Axes.
    """
    fig = plt.figure()
    ax = plt.gca()

    # Snap the tick range outward to whole multiples of alpha.
    lower = alpha * (xs[0] // alpha)
    upper = alpha * (xs[-1] // alpha) + alpha
    tick_count = int((upper - lower) // alpha) + 1

    ax.xaxis.set_major_formatter(PiFormatter())
    plt.plot(xs, ns, ".")
    plt.title("Integer Part")
    plt.xlabel("$x$")
    plt.xticks(np.linspace(lower, upper, tick_count))
    plt.ylabel(r"$\mathrm{int}(x)$")
    if show:
        plt.show()
    return fig, ax
@mpl.rc_context(plot_params)
def plot_fractional_part(xs, fs, alpha, show=True):
    """Plot the fractional part of real numbers mod alpha.

    Args:
        xs (array-like): the real numbers. The first and last entries set the
            x-range, so ``xs`` is presumably sorted in ascending order.
        fs (array-like): the fractional parts corresponding to ``xs``.
        alpha (float): the modulus.
        show (bool): if True, call ``plt.show()``.

    Returns:
        tuple: the Matplotlib Figure and Axes.
    """
    fig = plt.figure()
    ax = plt.gca()

    left = alpha * (xs[0] // alpha)
    right = alpha * (xs[-1] // alpha) + alpha
    x_ticks = np.linspace(left, right, int((right - left) // alpha) + 1)
    # Seven evenly spaced ticks over one period centred on zero.
    y_ticks = np.linspace(-alpha / 2, alpha / 2, num=7)

    # Each axis gets its own formatter instance.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(PiFormatter())

    plt.plot(xs, fs, ".")
    plt.title("Fractional Part")
    plt.xticks(x_ticks)
    plt.xlabel("$x$")
    plt.yticks(y_ticks)
    plt.ylabel(r"$\mathrm{frac}(x)$")
    if show:
        plt.show()
    return fig, ax
@mpl.rc_context(plot_params)
def plot_GKP_bins(outcomes, bit_values, alpha, show=True):
    """Plot binned real numbers mod alpha.

    Args:
        outcomes (array-like): the real numbers that were binned. The first and
            last entries set the x-range, so they are presumably sorted.
        bit_values (array-like): the bit value assigned to each outcome.
        alpha (float): the bin spacing; x-ticks are placed at its multiples.
        show (bool): if True, call ``plt.show()``.

    Returns:
        tuple: the Matplotlib Figure and Axes.
    """
    fig = plt.figure()
    ax = plt.gca()

    start = alpha * (outcomes[0] // alpha)
    stop = alpha * (outcomes[-1] // alpha) + alpha
    tick_count = int((stop - start) // alpha) + 1

    ax.xaxis.set_major_formatter(PiFormatter())
    plt.plot(outcomes, bit_values, ".")
    plt.title("Binned values")
    plt.xticks(np.linspace(start, stop, tick_count))
    plt.xlabel("Outcomes")
    plt.yticks([0, 1], [0, 1])
    plt.ylabel("Bit values")
    if show:
        plt.show()
    return fig, ax
@mpl.rc_context(plot_params)
def plot_Z_err_cond(hom_val, error, alpha, use_hom_val, show=True):
    """Plot conditional phase probabilities for GKP states.

    Args:
        hom_val (array-like): homodyne outcomes. When ``use_hom_val`` is True
            the first and last entries set the x-range (presumably sorted).
        error (array-like): the error value for each outcome.
        alpha (float): the spacing used for the x-ticks.
        use_hom_val (bool): if True, plot against the full homodyne value;
            otherwise plot against the fractional part returned by
            ``gkp.GKP_binner`` (the "central peak"), on ``[-alpha/2, alpha/2]``.
        show (bool): if True, call ``plt.show()``.

    Returns:
        tuple: the Matplotlib Figure and Axes.
    """
    fig = plt.figure()
    ax = plt.gca()
    _, frac = gkp.GKP_binner(hom_val, return_fraction=True)

    val = hom_val if use_hom_val else frac

    # bounds for the plot
    if use_hom_val:
        xmin, xmax = alpha * np.array([hom_val[0] // alpha, hom_val[-1] // alpha + 1])
    else:
        xmin, xmax = -alpha / 2, alpha / 2
    # (A leftover debugging print of the bounds used to be emitted here.)
    newxticks = np.linspace(xmin, xmax, int((xmax - xmin) // alpha) + 1)
    ax.xaxis.set_major_formatter(PiFormatter())
    plt.plot(val, error, ".")
    plt.xticks(newxticks)
    plt.xlabel("Homodyne value")
    plt.ylabel("Error")
    plt.title(
        "Conditional phase probabilities: "
        + ("Full homodyne value" if use_hom_val else "Central peak")
    )
    if show:
        plt.show()
    return fig, ax
def draw_EGraph(
    egraph,
    backend="matplotlib",
    **kwargs,
):
    """Draw an EGraph using either matplotlib or plotly as backend.

    Args:
        egraph (EGraph): The EGraph to draw.
        backend (str): The backend to use, either "matplotlib" or "plotly".

    Keyword args:
        color_nodes (bool, string or tuple): Options are:

            True: color the nodes based on the 'color' attribute attached
                to the node. If unavailable, color nodes black.
            string: color all nodes with the color specified by the string.
            tuple[str, dict]: color nodes based on an attribute and a
                dictionary mapping attribute values to colors, for example
                ``("state", {"GKP": "b", "p": "r"})`` looks at the "state"
                attribute of each node and colors it according to the
                dictionary. Note that this must be a tuple, not a list.
        color_edges (bool, string or tuple):

            True: color the edges based on the 'color' attribute attached
                to the edge. If unavailable, color edges grey.
            string: color all edges with the color specified by the string.
            tuple[str, dict]: color edges based on an attribute and a color
                dictionary. For example, if the edge attribute "weight" can
                be +1 or -1, the tuple should be of the form
                ``("weight", {-1: minus_color, +1: plus_color})``.
        label (NoneType, string or iterable): plot values next to each node
            associated with the node attribute label. For example, to plot
            bit values, set label to "bit_val". If set to 'index', it will
            plot the integer indices of the nodes. With the matplotlib
            backend, if the attribute is not set for some of the nodes, a
            message is printed stating for how many nodes it is missing.
        title (bool or string): if True, display a title derived from the
            label; for default labels, the attribute name is converted to
            plain English and capitalized. A string is used as the title.
        legend (bool): if True and color_nodes is a tuple(str, dict), display
            a color legend with node attributes.
        show_axes (bool): if False, turn off the axes.

    See draw_EGraph_matplotlib or draw_EGraph_plotly for backend-specific
    keyword arguments.

    Returns:
        The backend's result: a (Figure, Axes) tuple for matplotlib, or a
        plotly figure for plotly.

    Raises:
        ValueError: if ``backend`` is not one of the supported names.
    """
    # Checked in order; the first matching backend name wins.
    known_backends = (
        ("matplotlib", draw_EGraph_matplotlib),
        ("plotly", draw_EGraph_plotly),
    )
    for backend_name, drawer in known_backends:
        if backend == backend_name:
            return drawer(egraph, **kwargs)
    raise ValueError(f"Unknown backend: {backend}")
def _get_title(title=None, label="index"): """Get the title for the EGraph plot, depending on the label if title is a boolean. Args: title (string, boolean or NoneType): variable to determine the returned title. If ``title`` is a string, it will simply return the string. Else, if ``title is None``, it will return None. If the ``title`` is a boolean set to ``True``, it will return a title based on ``label``. In all other cases, the function will return None (i.e. there will be no title on the figure). label (string, list or tuple): Only relevant if ``title == True``. In that case, there are three options - if the label is set to p_phase, p_phase_cond, hom_val_p, hom_val_q, bit_val, weight or index, the title will be the label converted to a plane English word. - if the label is another string, the title will simply be that string. - if the label is a list or tuple of strings, the title will be the list or tuple unpacked separated by a comma. """ # Return nothing is not title if not title: return None # Unpack list or tuple... if isinstance(label, (tuple, list)): if len(label) > 1: return ", ".join(label) # or convert to a single string label = label[0] # Return title directly... if isinstance(title, str): return title # ... or base it on label value if isinstance(label, str): if isinstance(title, bool): title_dict = { "p_phase": "Phase error probabilities", "p_phase_cond": "Conditional phase error probabilities", "hom_val_p": "p-homodyne outcomes", "hom_val_q": "q-homodyne outcomes", "bit_val": "Bit values", "weight": "Weights", "index": "Indices", } return title_dict.get(label, label) return None
@mpl.rc_context(plot_params)
def draw_EGraph_matplotlib(
    egraph,
    color_nodes=False,
    color_edges=False,
    label=None,
    title=None,
    legend=False,
    show_axes=True,
    **kwargs,
):
    """Draw the graph state represented by the EGraph with Matplotlib.

    Args:
        See draw_EGraph for general keyword arguments, see keyword arguments
        below for matplotlib-specific arguments.

    Keyword args:
        dimensions (tuple): Dimensions of the region that should be plotted.
            Should be of the form:

                ``([xmin, xmax], [ymin, ymax], [zmin, zmax])``

            If None, sets the dimensions to the smallest integer box
            containing all the nodes.

    Returns:
        tuple: Matplotlib Figure and Axes.
    """
    dimensions = kwargs.get("dimensions", None)
    if dimensions is None:
        # One (floor(min), ceil(max)) pair per coordinate axis of the nodes.
        dimensions = [
            (int(np.floor(min(coords))), int(np.ceil(max(coords))))
            for coords in zip(*egraph.nodes)
        ]
    xlim, ylim, zlim = [list(bounds) for bounds in dimensions]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    plot_title = _get_title(title, label)
    if plot_title:
        ax.set_title(plot_title)
        ax.title.set_size(plot_params.get("axes.titlesize"))

    # plot graph
    ax = _plot_EGraph_nodes(ax, egraph, color_nodes, label, plot_title, legend)
    ax = _plot_EGraph_edges(ax, egraph, color_edges)

    # The drawn y-axis shows the z coordinate and vice versa, so that z
    # goes into the page.
    plt.xticks(range(xlim[0], xlim[1] + 1))
    plt.yticks(range(zlim[0], zlim[1] + 1))
    ax.set_zticks(range(ylim[0], ylim[1] + 1))

    # Widen degenerate (zero-width) ranges so the limits are valid.
    for bounds in (xlim, ylim, zlim):
        if bounds[0] == bounds[1]:
            bounds[0] -= 1
            bounds[1] += 1

    plt.xlim(xlim)
    plt.ylim(zlim)
    ax.set_zlim(ylim)

    for set_label, axis_name in (
        (ax.set_xlabel, "x"),
        (ax.set_ylabel, "z"),
        (ax.set_zlabel, "y"),
    ):
        set_label(axis_name, labelpad=15)

    if not show_axes:
        ax.axis("off")
    plt.draw()
    return fig, ax
def draw_EGraph_plotly(
    egraph,
    color_nodes=False,
    color_edges=False,
    label="all",
    title=None,
    legend=False,
    show_axes=True,
    **kwargs,
):
    """Draw the graph state represented by the EGraph with plotly.

    NOTE: Plotly gives you a wide range of options for how and where to
    display your figures. Refer to the `Plotly documentation
    <https://plotly.com/python/renderers/>`_ for more information.

    Args:
        See draw_EGraph for general keyword arguments, see keyword arguments
        below for plotly-specific arguments.

    Keyword args:
        showbackground (bool): if True, shows the background of the graph.
            Default is False.
        showspikes (bool): if True, shows coordinate-lines when hovering over
            a node. Default is False.
        showgrid (bool): if True, shows the grid of the graph. Default is True.
        width (int): width of the graph. Default is 750.
        height (int): height of the graph. Default is 750.

    Edges are drawn black unless ``color_edges`` is truthy, in which case
    their colors are chosen by ``_get_edge_color``.

    Returns:
        figure: Plotly figure (``plotly.graph_objects.Figure``).
    """
    # Layout and axis
    axis = dict(
        showbackground=kwargs.get("showbackground", False),
        showline=True,
        showspikes=kwargs.get("showspikes", False),
        zeroline=False,
        showgrid=kwargs.get("showgrid", True),
        showticklabels=show_axes,
        tickmode="linear",
        tick0=0,
        dtick=1,
    )

    layout = go.Layout(
        width=kwargs.get("width", 750),
        height=kwargs.get("height", 750),
        showlegend=legend,
        hovermode="closest",
        scene=dict(
            # Axis titles become empty strings when show_axes is False.
            xaxis={**dict(title="x" * show_axes), **axis},
            yaxis={**dict(title="y" * show_axes), **axis},
            zaxis={**dict(title="z" * show_axes), **axis},
        ),
        title=_get_title(title, label),
    )

    # Figure object
    fig = go.Figure(layout=layout)

    # Nodes
    nodes = np.array(egraph.nodes)
    x_nodes, y_nodes, z_nodes = nodes[:, 0], nodes[:, 1], nodes[:, 2]
    node_colors = [_get_node_color(egraph, node, color_nodes) for node in egraph.nodes]
    node_info = [_get_node_info(egraph, node, information=label) for node in egraph.nodes]
    fig.add_traces(
        go.Scatter3d(
            name="nodes",
            x=x_nodes,
            y=y_nodes,
            z=z_nodes,
            mode="markers",
            marker=dict(
                symbol="circle",
                size=5,
                color=node_colors,
                line=dict(color="black", width=0.5),
            ),
            hovertext=node_info,
            hoverinfo="text",
        )
    )

    # Edges. The previous implementation hard-coded a black line and put the
    # edge colors in ``marker.color``, which has no effect in "lines" mode,
    # so ``color_edges`` was ignored. Group the edges by color and add one
    # line trace per color instead.
    edges_by_color = {}
    for edge in egraph.edges:
        color = _get_edge_color(egraph, edge, color_edges) if color_edges else "black"
        xs, ys, zs = edges_by_color.setdefault(color, ([], [], []))
        (x0, y0, z0), (x1, y1, z1) = edge[0], edge[1]
        # None separates consecutive segments within one trace.
        xs.extend([x0, x1, None])
        ys.extend([y0, y1, None])
        zs.extend([z0, z1, None])

    for trace_number, (color, (xs, ys, zs)) in enumerate(edges_by_color.items()):
        fig.add_traces(
            go.Scatter3d(
                name="edges",
                x=xs,
                y=ys,
                z=zs,
                line=dict(color=color, width=1),
                mode="lines",
                hoverinfo="none",
                # A single legend entry toggles all edge traces together.
                legendgroup="edges",
                showlegend=trace_number == 0,
            )
        )
    return fig
def _plot_EGraph_edges(ax, egraph, color_edges): """Draw the edges of the graph state represented by the EGraph. Args: ax (matplotlib.axes.Axes): the axes to draw the lines in color_edges (bool or string or dict): True: color the edges based on the 'color' attribute attached to the node. If unavailable, color nodes grey. string: color all edges with the color specified by the stirng dict: color edges based on attribute and defined colour string by providing a tuple with [attribute, color_dictionary], for example: if the edge attribute "weight" can be +1 or -1, the tuple should be of the form: ``("weight", {-1: minus_color, +1: plus_color})``. Returns: A Matplotib Axes object. """ # Plotting edges. for edge in egraph.edges: color = _get_edge_color(egraph, edge, color_edges) x1, z1, y1 = edge[0] x2, z2, y2 = edge[1] ax.plot([x1, x2], [y1, y2], [z1, z2], color=color, linewidth=0.5) return ax def _plot_EGraph_nodes(ax, egraph, color_nodes, label, name, legend): """Draw the nodes of the graph state represented by the EGraph. Args: ax (matplotlib.axes.Axes): the axes to draw the points in color_nodes (bool or string or dict): Options are: True: color the nodes based on the 'color' attribute attached to the node. If unavailable, color nodes black. string: color all nodes with the color specified by the string tuple[str, dict]: color nodes based on attribute and defined colour string by providing a tuple with (attribute, color_dictionary), for example: ``("state", {"GKP": "b", "p": "r"})`` will look at the "state" attribute of the node, and colour according to the dictionary. label (NoneType or string): plot values next to each node associated with the node attribute label. For example, to plot bit values, set label to "bit_val". If set to 'index', it will plot the integer indices of the nodes. name (bool): attribute name to display as title. legend (bool): if True and color_nodes argument is a tuple(str, dict), display the a color legend with node attributes. 
Returns: A Matplotib Axes object. """ if label: n_uncomputed = 0 if label == "index": indices = egraph.index_generator() # Plotting points. y and z are swapped in the loops so that # z goes into the page; however, the axes labels are correct. for node in egraph.nodes: x, z, y = node color = _get_node_color(egraph, node, color_nodes) ax.scatter(x, y, z, c=color) if label: value = egraph.nodes[node].get(label) if label != "index" else indices[node] if value is not None: # Raise negative sign above node. sign = "^{-}" if value < 0 else " " if not isinstance(value, int): value = r"${}{:.2g}$".format(sign, np.abs(value)) ax.text( x + 0.05, y, z, value, color="MediumBlue", zorder=2, ) else: n_uncomputed += 1 if label and n_uncomputed > 0: message = "{} at {} node(s) have not yet been computed." print(message.format(name.lower(), n_uncomputed)) # Plotting nodes legend if isinstance(color_nodes, tuple) and legend: # these two lines are just a handy way to create a legend for # the node colors and attributes by generating handles of empty lines # with the label and color of the node property handles = [ mlines.Line2D([], [], marker="o", linewidth=0, color=color, label=node_property) for node_property, color in color_nodes[1].items() ] ax.legend(handles=handles) return ax def _get_node_color(egraph, node, color_nodes): """Color nodes based on ``color_nodes`` arg: - if `color_nodes` is a string use the string as color, - using the attribute and color dict if `color_nodes` is a tuple(str,dict), - or based on color attribute (when available) if `color_nodes` is bool and True; - black otherwise. 
""" default_color = "black" if isinstance(color_nodes, str): color = color_nodes elif isinstance(color_nodes, tuple): node_attribute, color_dict = color_nodes if not (isinstance(node_attribute, str) and isinstance(color_dict, dict)): raise ValueError( "Inappropiate value for `color_nodes` argument:" "Check that it complies with the type `tuple(str, dict)`," "where the string corresponds to a valid node attribute," "the dictionary keys to valid attribute values and" "dictionary values to valid matplotlib color strings." ) node_property = egraph.nodes[node].get(node_attribute) color = color_dict.get(node_property, default_color) elif color_nodes == True: color = egraph.nodes[node].get("color", default_color) else: color = default_color return color def _get_node_info(egraph, node, information="coordinates"): """Information to be displayed when hovering over a node based on ``information`` Arguments: egraph (EGraph): the EGraph with the node of interest. node (tuple): the node to get the information from. information (str, iterable or NoneType): the information to be displayed: - if ``information`` is a string, the value of the node attribute ``information``, - if ``information`` is an iterable, a list of the values of the node attributes in ``information``, - if ``information`` is None, return None (nothing will be displayed). - if ``information`` contains ``"index"``, include the index from ``egraph.to_indices[node]``. - if ``information`` is "all" it will display the coordinates, index and all the information avaible in the node. 
""" # information dictionary info_dict = egraph.nodes[node].copy() # list all available information if information == "all": info_list = list(info_dict.keys()) info_list.sort() information = ["index"] + info_list # add index if desired if "index" in information: info_dict["index"] = egraph.to_indices[node] # returning relevant information if information == "coordinates": return str(node) if information is None: return None if isinstance(information, str): node_property = info_dict.get(information) return f"{information}: {node_property}" if isinstance(information, (tuple, list)): node_info = str(node) for key in information: node_property = info_dict.get(key, None) if node_property is not None: node_info += "<br />" + f"{key}: {node_property}" return node_info raise ValueError( "Inappropiate value for `information` argument:" "Check that it complies with the type `str`," "`tuple` or `list`, or has value `None`." ) def _get_edge_color(egraph, edge, color_edges): """Return the color of an EGraph edge. Args: color_edges (bool or string or dict): True: color the edges based on the 'color' attribute attached to the node. If unavailable, color nodes grey. string: color all edges with the color specified by the stirng dict: color edges based on attribute and defined colour string by providing a tuple with [attribute, color_dictionary], for example: if the edge attribute "weight" can be +1 or -1, the tuple should be of the form: ``("weight", {-1: minus_color, +1: plus_color})``. 
Returns: A color (string) """ # Color edges based on `color_edges` choices (see docstring) if isinstance(color_edges, str): color = color_edges elif isinstance(color_edges, tuple): edge_attribute, color_dict = color_edges if not (isinstance(edge_attribute, str) and isinstance(color_dict, dict)): raise ValueError( "Inappropiate value for `color_edges` argument:" "Check that it complies with the type `tuple(str, dict)`," "where the string corresponds to a valid edge attribute," "the dictionary keys to valid attribute values and" "dictionary values to valid matplotlib color strings." ) edge_property = egraph.edges[edge].get(edge_attribute) color = color_dict.get(edge_property) elif color_edges == True: color = egraph.edges[edge].get("color") or "grey" else: color = "grey" return color
@mpl.rc_context(plot_params)
def plot_mat_heat_map(mat, show=True, title=None):
    """Plot the heat map of a matrix.

    Args:
        mat (numpy.ndarray or sparse matrix): the matrix to plot. Anything
            that is not an ndarray is densified with ``mat.toarray()``.
        show (bool): if True, call ``plt.show()``.
        title (str, optional): a title for the plot.

    Returns:
        tuple: the Figure and a list ``[ax, colorbar_axes]``.
    """
    fig = plt.figure()
    ax = plt.gca()
    dense = mat if isinstance(mat, np.ndarray) else mat.toarray()
    # Per the Matplotlib docs, fignum=0 draws into the current axes rather
    # than opening a new figure.
    plt.matshow(dense, fignum=0)
    if title:
        plt.title(title)
    cbar = plt.colorbar()
    label_size = plot_params.get("axes.labelsize", 10) * 1.2
    cbar.set_label("value", rotation=270, fontsize=label_size, labelpad=20)
    if show:
        plt.show()
    return fig, [ax, cbar.ax]
@mpl.rc_context(plot_params)
def draw_dec_graph(graph, label_edges=False, node_labels=None, title=""):
    """Draw a stabilizer or matching graph with a color legend.

    This requires that the graph is implemented with the NetworkX backend.
    The graph passed in is not modified: for stabilizer graphs, the 'low' and
    'high' nodes are removed from a copy before drawing.

    Args:
        graph (NxStabilizerGraph or NxMatchingGraph): the graph to draw.
        label_edges (bool, optional): if `True`, label the edges
            of the graph with their weight. Defaults to False.
        node_labels (dict of node to label, optional): if provided, the nodes
            will be identified with the given labels. Else, there will be no
            label for the nodes.
        title (string, optional): add the given title to the plot.

    Returns:
        tuple: the Matplotlib Figure and a list ``[ax, cax]`` holding the
        graph axes and the colorbar axes.

    Raises:
        ValueError: if ``graph.graph`` is not a ``networkx.Graph``.
    """
    if not isinstance(graph.graph, nx.Graph):
        raise ValueError("The graph must be implemented with the networkx backend.")

    nx_graph = graph.graph
    # Remove 'low' and 'high' nodes, which are not important for
    # visualization. Work on a copy: removing them in place would silently
    # alter the caller's stabilizer graph, which may still be needed.
    if graph.__class__.__name__ == "NxStabilizerGraph":
        nx_graph = nx_graph.copy()
        nx_graph.remove_nodes_from({"low", "high"})
    graph = nx_graph

    layout = nx.circular_layout(graph)
    fig, ax = plt.subplots()
    if title != "":
        ax.set_title(title)

    # Color edges based on weight, and draw a colorbar.
    ax.axis("off")
    cmap, norm = draw_curved_edges(graph, layout, ax)
    nx.draw_networkx_nodes(graph, pos=layout, node_color="#202020", ax=ax)

    # Draw node labels (only for nodes present in the drawn graph).
    if node_labels is not None:
        node_labels = {node: label for node, label in node_labels.items() if node in graph}
        nx.draw_networkx_labels(
            graph, pos=layout, labels=node_labels, font_color="white", font_weight=100, ax=ax
        )

    if label_edges:
        weight_dict = {edge: "{:.2f}".format(graph.edges[edge]["weight"]) for edge in graph.edges}
        nx.draw_networkx_edge_labels(
            graph,
            layout,
            edge_labels=weight_dict,
            font_size=plot_params.get("font.size", 7) * 0.75,
            clip_on=False,
            alpha=0.7,
            verticalalignment="center_baseline",
            bbox={"alpha": 0},
        )

    cax, kw = mpl.colorbar.make_axes(ax, location="right", fraction=0.15)
    cbar = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm, **kw)
    cbar.ax.tick_params(labelsize=plot_params.get("axes.labelsize", 10), rotation=270)
    cbar.set_label(
        "weight", rotation=270, fontsize=plot_params.get("axes.labelsize", 10), labelpad=20
    )

    axs = [ax, cax]
    return fig, axs
def draw_curved_edges(graph, layout, ax, rad=0.5):
    """Draw curved edges using matplotlib directly instead of networkx.

    This requires that the graph is implemented with the NetworkX backend.
    Each edge is colored by its "weight" attribute using the "Spectral"
    colormap, normalized between the smallest and largest weight.

    Args:
        graph (NxGraph): The graph to draw.
        layout (dict): A dictionary of positions keyed by node.
        ax (matplotlib.pyplot.Axis): The axis on which arrows should be drawn.
        rad (float, optional): Curvature of the arrows in radians.

    Returns:
        tuple: the colormap and the normalization used for the edge colors.
    """
    edges = graph.edges
    edge_weights = [edges[edge]["weight"] for edge in edges]

    # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed
    # in 3.9; use the colormap registry when it is available.
    registry = getattr(mpl, "colormaps", None)
    cmap = registry["Spectral"] if registry is not None else mpl.cm.get_cmap("Spectral")

    if edge_weights:
        norm = mpl.colors.Normalize(vmin=np.min(edge_weights), vmax=np.max(edge_weights))
    else:
        # No edges: np.min/np.max would fail on an empty list, but callers
        # still need a valid norm to build a colorbar.
        norm = mpl.colors.Normalize(vmin=0, vmax=1)

    linewidth = plot_params.get("lines.linewidth", 1) / 2
    for edge in graph.edges():
        source, target = edge
        arrowprops = dict(
            arrowstyle="-",
            color=cmap(norm(edges[edge]["weight"])),
            connectionstyle=f"arc3,rad={rad}",
            linestyle="-",
            linewidth=linewidth,
            alpha=0.8,
        )
        ax.annotate("", xy=layout[source], xytext=layout[target], arrowprops=arrowprops)
    return cmap, norm
@mpl.rc_context(plot_params)
def syndrome_plot(code, ec, index_dict=None, drawing_opts=None):
    """Draw the syndrome plot for a code.

    A comprehensive graphing tool for drawing the error syndrome of code.
    Labelling options are specified with the help of drawing_opts, and can
    include:

        'show_nodes' -> the underlying graph displayed
        'color_nodes', 'color_edges', 'label', 'legend', 'show_axes' ->
            forwarded to draw_EGraph_matplotlib when 'show_nodes' is True
        'label_stabilizers' -> indices of the stabilizers
        'label_boundary' -> indices of the boundary points
        'legend' -> legends for the nodes and stabilizers
        'show_title' -> a "<Ec> syndrome" title

    Stabilizers are shown as transparent voxels, green for even parity and
    red for odd. For now, stabilizers on periodic boundaries are not drawn in
    a special way, stabilizers on dual boundaries are unshifted from the
    primal stabilizer location, and incomplete stabilizers are still
    represented as complete cubes.

    Args:
        code (SurfaceCode): the qubit QEC code
        ec (string): the error complex ('primal' or 'dual')
        index_dict (dict): the stabilizer-to-index mapping
        drawing_opts (dict): a dictionary of drawing options, with
            all possibilities described above.

    Returns:
        tuple: figure and axes
    """
    cubes = getattr(code, ec + "_stabilizers")
    # Default drawing options.
    draw_dict = {
        "show_nodes": False,
        "color_nodes": ("state", {"p": None, "GKP": None}),
        "color_edges": "k",
        "label": None,
        "legend": True,
        "show_title": True,
        "show_axes": True,
        "label_stabilizers": True,
        "label_boundary": False,
    }
    # Combine default dictionary with supplied dictionary, duplicates
    # favor supplied dictionary.
    if drawing_opts is None:
        drawing_opts = {}
    drawing_opts = {**draw_dict, **drawing_opts}

    # Shape and font properties from the original graph.
    shape = np.array(code.dims)

    # If show_nodes is True, get the axes object and legend from
    # draw_EGraph (this also plots the graph in the console).
    if drawing_opts["show_nodes"]:
        egraph_args = [
            "color_nodes",
            "color_edges",
            "label",
            "legend",
            "show_axes",
        ]
        egraph_opts = {k: drawing_opts[k] for k in egraph_args}
        fig, ax = draw_EGraph_matplotlib(code.graph, **egraph_opts)
        leg = ax.get_legend()
    # If show_nodes is False, create a new figure with size
    # determined by the dimensions of the lattice.
    else:
        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")
        plt.xticks(range(0, 2 * shape[0] - 1))
        plt.yticks(range(0, 2 * shape[1] - 1))
        ax.set_zticks(range(0, 2 * shape[2] - 1))
        ax.set_xlabel("x", labelpad=20)
        ax.set_ylabel("z", labelpad=20)
        ax.set_zlabel("y", labelpad=20)
        leg = None

    # Illustrate stabilizers with voxels colored green for even
    # parity and red for odd parity.
    positions, colors, sizes = [], [], []
    for cube in cubes:
        # Obtain smallest, largest, and middle coordinates for each
        # cube.
        xmin, xmax = np.array(cube.xlims())
        ymin, ymax = np.array(cube.ylims())
        zmin, zmax = np.array(cube.zlims())
        xmid, ymid, zmid = np.array(cube.midpoint())
        # Fill in the color arrays depending on parity.
        if cube.parity:
            color = "#FF000015"
        else:
            color = "#00FF0015"

        # gap defines the distance between adjacent cubes
        gap = 0.15
        min_arr = np.array([xmin, ymin, zmin])
        max_arr = np.array([xmax, ymax, zmax])
        positions.append(min_arr + gap)
        sizes.append(np.abs(max_arr - min_arr) - 2 * gap)
        colors.append(color)

        if drawing_opts["label_stabilizers"] and index_dict:
            if cube in index_dict:
                ax.text(xmid, ymid, zmid, index_dict[cube])

    # draw cubes; skip when there are none, since concatenating an empty
    # list of faces inside _plot_cubes_at would raise.
    if positions:
        pc = _plot_cubes_at(positions, colors=colors, sizes=sizes)
        ax.add_collection3d(pc)

    # setting plot limits to give some room to the boxes
    ymin = 0 if ec == "primal" else -2
    ax.set_xlim(0, 2 * shape[0])
    ax.set_ylim(ymin, 2 * shape[1])
    ax.set_zlim(0, 2 * shape[2])

    if drawing_opts["label_boundary"] and index_dict:
        bound_points = getattr(code, ec + "_bound_points")
        for point in bound_points:
            ax.scatter(*point, s=5, c="k")
            ax.text(*point, index_dict[point])

    # Define a legend for red/green cubes.
    legend_elements = [
        Patch(facecolor="#00FF0050", label="even parity"),
        Patch(facecolor="#FF000050", label="odd parity"),
    ]
    if drawing_opts["legend"]:
        ax.legend(handles=legend_elements, loc="upper left")

    # Since draw_EGraph legend has been overwritten, re-add
    # it to the plot.
    if leg:
        ax.add_artist(leg)
    if drawing_opts["show_title"]:
        ax.set_title(ec.capitalize() + " syndrome")

    return fig, ax
def _plot_cubes_at(positions, sizes=None, colors=None, **kwargs): """Plot cubes with their origin located at ``positions``. Note cubes are located by displacing them from the origin. In that sense, the location is defined by the corner matching the origin of the coordinate system before displacement. Args: positions (Iterable): An interable of dimension ``(N,3)`` containing the position of the corner of the cube. sizes (Iterable): An interable of dimension ``(N,3)`` containing the size of the cube in the coordinate directions. colors (Iterable): An iterable of size ``N`` containing the colors of the cube. This can be any of the option allowed by matplolib. Keyword args: **kwargs: all other parameters are forwarded to ```Poly3DColletion`` <https://matplotlib.org/stable/api/_as_gen/mpl_toolkits.mplot3d.art3d.Poly3DCollection.html>`_. Returs: Poly3DCollection: A collection of 3D polygons defining the cubes. """ g = [_cuboid_data(p, size=s) for p, s in zip(positions, sizes)] return Poly3DCollection(np.concatenate(g), facecolors=np.repeat(colors, 6, axis=0), **kwargs) def _cuboid_data(origin, size=(1, 1, 1)): """Return an array with the corners of a cube of size 1.""" X = np.array( [ [[0, 1, 0], [0, 0, 0], [1, 0, 0], [1, 1, 0]], [[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]], [[1, 0, 1], [1, 0, 0], [1, 1, 0], [1, 1, 1]], [[0, 0, 1], [0, 0, 0], [0, 1, 0], [0, 1, 1]], [[0, 1, 0], [0, 1, 1], [1, 1, 1], [1, 1, 0]], [[0, 1, 1], [0, 0, 1], [1, 0, 1], [1, 1, 1]], ] ).astype(float) # scale the sides of the cube for i in range(3): X[:, :, i] *= size[i] # displace the cube origin X += np.array(origin) return X
@mpl.rc_context(plot_params)
def add_recovery_drawing(ax, **kwargs):
    """Plot the recovery on an existing (3D) syndrome plot.

    Keyword args:
        show_title (bool): if truthy, set the title "Syndrome and recovery".
        recovery_set (iterable): points to mark with black circles.
        matching (iterable): matched pairs of nodes; each pair's path in the
            matching graph is drawn, except pairs joining two virtual points.
        matching_graph: the matching graph, providing ``virtual_points`` and
            ``edge_path``; required when ``matching`` is given.

    Returns:
        The axes ``ax``.
    """
    if kwargs.get("show_title"):
        ax.set_title("Syndrome and recovery")
    recovery_set = kwargs.get("recovery_set")
    if recovery_set:
        for point in recovery_set:
            ax.plot(*point, "o", c="k", markersize=6)
    matching = kwargs.get("matching")
    G_match = kwargs.get("matching_graph")
    if matching:
        virtual_points = G_match.virtual_points
        linewidth = plot_params.get("lines.linewidth", 1) * 0.9
        for pair in matching:
            # Equivalent to ``pair in it.product(virtual_points, virtual_points)``
            # without enumerating the whole product for every pair.
            both_virtual = (
                isinstance(pair, tuple)
                and len(pair) == 2
                and pair[0] in virtual_points
                and pair[1] in virtual_points
            )
            if both_virtual:
                continue
            xlist, ylist, zlist = [], [], []
            for node in G_match.edge_path(pair):
                if isinstance(node, Stabilizer):
                    x, y, z = node.midpoint()
                else:
                    x, y, z = node
                # Draw on the given axes, not whatever pyplot considers current.
                ax.plot(x, y, z, marker="2", markersize=15, c="k")
                xlist.append(x)
                ylist.append(y)
                zlist.append(z)
            ax.plot(
                xlist,
                ylist,
                zlist,
                "o-",
                c=np.random.rand(3),
                linewidth=linewidth,
            )
    return ax
def draw_decoding(code, ec, dec_objects=None, drawing_opts=None):
    """Draw the stabilizer and matching graphs, and plot the syndrome.

    Args:
        code (SurfaceCode): the code whose stabilizer graph and syndrome are drawn.
        ec (str): the error complex, 'primal' or 'dual'.
        dec_objects (dict, optional): decoding objects. The "matching_graph"
            entry (if any) is drawn, and all entries are forwarded to
            ``add_recovery_drawing`` when ``drawing_opts["show_recovery"]`` is
            set (e.g. "matching", "recovery_set"). Defaults to no objects.
        drawing_opts (dict, optional): drawing options. The keys read here are
            "label_stabilizers", "label_boundary", "label_edges", "show_title"
            and "show_recovery"; the whole dict is passed to ``syndrome_plot``.

    Returns:
        tuple: ``((fig1, ax1), (fig2, ax2), (fig3, ax3))`` for the stabilizer
        graph, the matching graph (``(None, None)`` if absent or empty) and the
        syndrome plot.
    """
    if drawing_opts is None:
        drawing_opts = {}
    if dec_objects is None:
        dec_objects = {}

    # Drawing the stabilizer graph
    G_stabilizer = getattr(code, ec + "_stab_graph")
    G_match = dec_objects.get("matching_graph")
    # An integer label for each node in the stabilizer and matching
    # graphs. This is useful to identify the nodes in the plots.
    if drawing_opts.get("label_stabilizers") or drawing_opts.get("label_boundary"):
        # Node labels for the stabilizer graph, skipping the auxiliary
        # "low" / "high" nodes wherever they appear (rather than assuming they
        # are the first two nodes).
        real_nodes = [
            node
            for node in G_stabilizer.nodes()
            if not (isinstance(node, str) and node in ("low", "high"))
        ]
        node_labels = {node: index for index, node in enumerate(real_nodes)}
        # Update node labels to work with the matching graph---needs to be done
        # because virtual boundary nodes are of the form ((x, y, z), i).
        if G_match:
            for virtual_node in set(G_match.graph.nodes()) - set(G_stabilizer.nodes()):
                index = node_labels[virtual_node[0]]
                node_labels[virtual_node] = index
    else:
        node_labels = None

    label_edges = bool(drawing_opts.get("label_edges"))
    show_title = bool(drawing_opts.get("show_title"))
    fig1, ax1 = code.draw_stabilizer_graph(
        ec,
        title=ec.capitalize() + " stabilizer graph" if show_title else "",
        label_edges=label_edges,
        node_labels=node_labels,
    )

    # Drawing the matching graph
    fig2, ax2 = (None, None)
    if G_match:
        if len(G_match.graph):
            fig2, ax2 = G_match.draw(
                title=ec.capitalize() + " matching graph" if show_title else "",
                label_edges=label_edges,
                node_labels=node_labels,
            )
        else:
            print("\nMatching graph empty!\n")

    # Drawing the syndrome
    fig3, ax3 = syndrome_plot(code, ec, drawing_opts=drawing_opts, index_dict=node_labels)
    if drawing_opts.get("show_recovery"):
        ax3 = add_recovery_drawing(ax3, show_title=drawing_opts.get("show_title"), **dec_objects)

    return (fig1, ax1), (fig2, ax2), (fig3, ax3)
def to_pi_string(x, tex: bool = True, d=2):
    r"""Convert x, a multiple of sqrt(pi)/2, to a pretty string.

    Multiples of sqrt(pi)/2 are written symbolically, e.g. ``\sqrt{\pi}/2``,
    ``-\sqrt{\pi}`` or ``3\sqrt{\pi}/2``, and zero is written as ``"0"``. If
    tex is True, symbolic results are wrapped in LaTeX $ signs. If x is not a
    multiple of sqrt(pi)/2, return the unmodified string of x with `d`
    digits after the decimal point.
    """
    unit = np.sqrt(np.pi) / 2
    if not np.isclose(math.remainder(x, unit), 0):
        return f"{x:.{d}f}"
    if not round(x):
        return "0"

    n_units = round(x / unit)
    odd = n_units % 2
    # Odd multiples keep their count over a "/2"; even ones are halved to
    # whole multiples of sqrt(pi).
    coeff = int(n_units * (1 if odd else 0.5))
    if coeff == 1:
        coeff_str = ""
    elif coeff == -1:
        coeff_str = "-"
    else:
        coeff_str = str(coeff)

    delimiter = "$" if tex else ""
    return delimiter + coeff_str + r"\sqrt{\pi}" + ("/2" if odd else "") + delimiter
class PiFormatter(Formatter):
    """Matplotlib tick formatter that writes multiples of sqrt(pi)/2 symbolically.

    Tick values that are not such multiples are printed as plain decimals;
    the conversion is delegated to ``to_pi_string``.
    """

    def __init__(self, tex: bool = True, d: int = 2):
        """Initialize the formatter.

        Args:
            tex: Whether to wrap symbolic labels in LaTeX ``$`` delimiters.
            d: Number of decimals for ticks that are not multiples of sqrt(pi)/2.
        """
        self.tex, self.d = tex, d

    def __call__(self, x, pos=None):
        """Return the label for tick value ``x``; ``pos`` is not used."""
        return to_pi_string(x, d=self.d, tex=self.tex)

Contents

Home

Background

Using FlamingPy

Development

Getting Help

Python API