Source code for flamingpy.codes.surface_code

# Copyright 2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for the measurement-based surface code and related functions."""

import itertools as it
from typing import Tuple, Union, List
import warnings

import numpy as np

from flamingpy.codes import Stabilizer
from flamingpy.codes.graphs import EGraph
from flamingpy.codes.graphs import NxStabilizerGraph, RxStabilizerGraph


def alternating_polarity(edge):
    """Return +1 or -1 depending on the vertices that form ``edge``.

    Help with the assignment of edge weights (i.e. polarity) for
    the RHG graph. This particular alternating pattern ensures that
    every vertex has two +1 and two -1 weight edges incident on it. The
    precise combination depends on the direction of the edge and the
    coordinates of the vertices. This pattern may be helpful to reduce,
    e.g. CV noise.

    Args:
        edge (list-type): a pair of tuples, denoting lattice vertices.

    Returns:
        int: +1 or -1. If the two vertices coincide (or differ only beyond
        the third axis), a warning is issued and +1 is returned.
    """
    point1, point2 = np.array(edge[0]), np.array(edge[1])
    # Axes along which the endpoints differ; the first one is taken as the
    # direction of the edge.
    differing_axes = np.flatnonzero(point2 - point1)
    # Edge direction -> axis of point1 whose parity determines the sign.
    parity_axis = {0: 1, 1: 2, 2: 0}
    if differing_axes.size == 0 or int(differing_axes[0]) not in parity_axis:
        warnings.warn("Vertices must be separated by one unit on the integer lattice.")
        return 1
    # Same value as (-1) ** coordinate, but also valid for negative
    # coordinates and always a plain Python int.
    return -1 if point1[parity_axis[int(differing_axes[0])]] % 2 else 1
def dual_neighbours(p, displace=1):
    """Return the dual neighbours of the primal vertex ``p`` in the RHG lattice.

    Helper for ``RHG_graph``. By default each neighbour sits one unit away
    from ``p``. The axes along which the neighbours lie depend on the parity
    of the z and x coordinates of ``p``.

    Args:
        p (tuple): the coordinates of the primal vertex.
        displace (float): the distance between ``p`` and each neighbour.
            Useful to change when creating macronodes.

    Returns:
        List[Tuple]: the coordinates of the four neighbours.
    """
    x, y, z = p[0], p[1], p[2]
    step = displace
    if not z % 2:
        # Even z: the neighbours surround p in the x-y plane
        # (ordered bottom, left, top, right).
        return [(x, y - step, z), (x - step, y, z), (x, y + step, z), (x + step, y, z)]
    back, front = (x, y, z - step), (x, y, z + step)
    if x % 2:
        # Odd z and odd x: neighbours in the x-z plane.
        return [back, (x - step, y, z), front, (x + step, y, z)]
    # Odd z and even x: neighbours in the y-z plane.
    return [back, (x, y + step, z), front, (x, y - step, z)]
def str_to_bound(bound_name):
    """Return a list of x-y-z boundaries corresponding to ``bound_name``.

    The options are:
        'open_primal': [primal, dual, primal]
        'open_dual': [primal, dual, dual]
        'primal': [primal, primal, primal]
        'dual': [dual, dual, dual]
        'periodic': [periodic, periodic, periodic]
        'periodic_primal': [periodic, periodic, primal]
        'periodic_dual': [periodic, periodic, dual]

    Args:
        bound_name (str): one of the names above.

    Returns:
        np.ndarray: the three boundary types in x, y, z order.

    Raises:
        TypeError: if ``bound_name`` is not a string.
        ValueError: if ``bound_name`` is not one of the supported names.
    """
    if not isinstance(bound_name, str):
        raise TypeError("Boundary type must be string.")
    boundary_mapping_dict = {
        "open_primal": ["primal", "dual", "primal"],
        "open_dual": ["primal", "dual", "dual"],
        "primal": ["primal", "primal", "primal"],
        "dual": ["dual", "dual", "dual"],
        "periodic": ["periodic", "periodic", "periodic"],
        "periodic_primal": ["periodic", "periodic", "primal"],
        "periodic_dual": ["periodic", "periodic", "dual"],
    }
    try:
        boundary_types = boundary_mapping_dict[bound_name]
    except KeyError:
        # Previously an unknown name yielded np.array(None), which only failed
        # later with an unrelated error; fail here with a clear message.
        raise ValueError(
            f"Unknown boundary type '{bound_name}'. "
            f"Options are: {', '.join(boundary_mapping_dict)}."
        ) from None
    return np.array(boundary_types)
def _rhg_boundary_status(point, dims, dual_inds, periodic_inds):
    """Classify ``point`` against the faces of a ``dims``-sized RHG lattice.

    Returns:
        Tuple[set, bool]: the axes along which ``point`` lies on the low
        (zero) face, and whether ``point`` must be left out of the lattice
        because it lies on a dual face (low or high) or on the high face of
        a periodic axis.
    """
    low, high = set(), set()
    for i in range(3):
        if point[i] == 0:
            low.add(i)
        elif point[i] == 2 * dims[i]:
            high.add(i)
    excluded = bool((low | high) & dual_inds or high & periodic_inds)
    return low, excluded


def RHG_graph(
    dims: Union[int, Tuple[int, int, int], List, np.ndarray],
    boundaries="primal",
    polarity=None,
):
    """Return an EGraph of a dims-dimensional RHG lattice.

    Generate a Raussendorf-Harrington-Goyal (RHG) lattice, which can
    be viewed as the measurement-based version or foliation of
    the surface code, with specified dimensions and boundary types.

    Args:
        dims (int or Sequence[int]): the dimensions of the lattice. If
            int, generates a cube corresponding to a code of distance dims.
            If a sequence (dx, dy, dz), assumes distances dx, dy, dz in x,
            y, z directions, respectively. For axes with open boundaries,
            the corresponding distance should be greater than 1.
        boundaries (str or list-type, optional): the boundary types in x,
            y, z. We use the identification primal = smooth and
            dual = rough, to align with surface code terminology.
            Available choices in the order x, y, z are:

                'open_primal': [primal, dual, primal]
                'open_dual': [primal, dual, dual]
                'primal': [primal, primal, primal]
                'dual': [dual, dual, dual]
                'periodic': [periodic, periodic, periodic]
                'periodic_primal': [periodic, periodic, primal]
                'periodic_dual': [periodic, periodic, dual]
                ['{b1}', '{b2}', '{b3}']: b1, b2, b3,

            where each b above can be 'primal', 'dual', or 'periodic'.
            By default, 'primal' is used (i.e. ['primal', 'primal', 'primal']).
        polarity (func): a function that specifies edge weights. It
            must be of the following form:

                polarity(edge) = weight.

            If not supplied, assumes all edges have weight 1.

    Returns:
        EGraph: the RHG lattice. Its ``graph["primal_cubes"]`` entry holds
        the coordinates of the primal six-body stabilizers.

    Raises:
        TypeError: if ``dims`` is neither an integer nor a sequence.
        ValueError: if ``dims`` is a sequence whose size is not three.
    """
    # Checking input (numpy integer scalars are accepted like Python ints).
    if not isinstance(dims, (int, np.integer, tuple, list, np.ndarray)):
        raise TypeError("dims must be an integer or a sequence of three integers.")
    if isinstance(dims, (tuple, list, np.ndarray)) and np.size(dims) != 3:
        raise ValueError("dims must be an integer or a sequence of three integers.")
    if np.issubdtype(type(dims), np.integer):
        dims = (dims,) * 3

    # Dealing with boundaries. The elementwise comparisons further down
    # (boundaries == "dual") need an array, so list-type input is converted.
    if isinstance(boundaries, str):
        boundaries = str_to_bound(boundaries)
    else:
        boundaries = np.asarray(boundaries)

    # Locations of all primal vertices.
    max_dict = {"primal": 1, "dual": 0, "periodic": 0}
    range_max = dims - np.array([max_dict[typ] for typ in boundaries])
    inds = it.product(*(range(range_max[i]) for i in (0, 1, 2)))

    # Primal vertices are combined into lists of six to be later usable
    # by the syndrome identification in SurfaceCode.
    all_six_bodies = [
        [
            (2 * i, 2 * j + 1, 2 * k + 1),
            (2 * i + 1, 2 * j, 2 * k + 1),
            (2 * i + 1, 2 * j + 1, 2 * k),
            (2 * i + 2, 2 * j + 1, 2 * k + 1),
            (2 * i + 1, 2 * j + 2, 2 * k + 1),
            (2 * i + 1, 2 * j + 1, 2 * k + 2),
        ]
        for (i, j, k) in inds
    ]
    denested_six_bodies = {a for b in all_six_bodies for a in b}

    # Tuple indices corresponding to dual and periodic boundaries.
    dual_inds = set((boundaries == "dual").nonzero()[0])
    periodic_inds = set((boundaries == "periodic").nonzero()[0])

    G = EGraph()
    for vertex in denested_six_bodies:
        # Skip vertices that extend beyond the requested boundary
        # conditions. Note this could also be achieved by changing 'inds'.
        _, vertex_excluded = _rhg_boundary_status(vertex, dims, dual_inds, periodic_inds)
        if vertex_excluded:
            continue
        for neighbor in dual_neighbours(vertex):
            # Likewise skip neighbours beyond the boundaries.
            neighbor_low, neighbor_excluded = _rhg_boundary_status(
                neighbor, dims, dual_inds, periodic_inds
            )
            if neighbor_excluded:
                continue
            edge = (vertex, neighbor)
            weight = polarity(edge) if polarity else 1
            G.add_node(vertex, type="primal")
            G.add_node(neighbor, type="dual")
            G.add_edge(vertex, neighbor, weight=weight)
            # Additional edges for periodic boundaries: a neighbour on the
            # low face of a periodic axis is joined to the primal vertex
            # on the opposite (high) side.
            for ind in neighbor_low & periodic_inds:
                high_primal_vertex = list(neighbor)
                high_primal_vertex[ind] = 2 * dims[ind] - 1
                neighbor_other_side = tuple(high_primal_vertex)
                edge = (neighbor, neighbor_other_side)
                weight = polarity(edge) if polarity else 1
                G.add_node(neighbor_other_side, type="primal")
                G.add_edge(neighbor, neighbor_other_side, weight=weight, periodic=True)

    # Store coordinates of primal cubes for later use.
    G.graph["primal_cubes"] = all_six_bodies
    return G
class SurfaceCode:
    """A class for representing the surface code.

    Represent the surface code in its measurement-based description. By
    specifying the distance, error complex, choice of boundaries, and
    polarity, store the graph state corresponding to the code, the set of
    stabilizer elements and the stabilizer graph, as well as the syndrome and
    boundary vertices.

    Args:
        backend (string): The backend to use for the stabilizer graph. Can be
            "rustworkx" (the default) or "networkx". The rustworkx backend
            should be used when speed is a concern. The networkx backend is
            provided for historical reasons.

    Attributes:
        distance (int or Sequence[int]): the code distance of the lattice, as
            passed in. If int, generates a cube corresponding to a code of
            that distance. If a sequence (dx, dy, dz), assumes distances dx,
            dy, dz in x, y, z directions, respectively.
        dims (tup): a tuple of the spatial extent in x, y, z.
        ec (List[str]): a one-element list holding the error complex
            ('primal' or 'dual').
        bound_str (str): the boundary name handed to ``str_to_bound``.
        boundaries (np.ndarray): the boundary types in x, y, z, derived from
            the ``boundaries`` argument. The argument options are:

            * 'open': ['primal', 'dual', 'dual'] for 'primal' EC,
              ['primal', 'dual', 'primal'] for 'dual' EC.
            * 'periodic': 'periodic' in all three directions.
            * 'toric': ['periodic', 'periodic', 'dual'] for 'primal' EC,
              ['periodic', 'periodic', 'primal'] for 'dual' EC.

            For the 'open' and 'toric' boundary choices, we imagine qubits are
            encoded into the x-y plane and propagated in time: the z-axis is
            interpreted as the temporal axis, which is relevant in quantum
            memory simulations.
        polarity (func): a function that specifies edge weights. It must be
            of the following form:

                polarity(edge) = weight.

            If not supplied, assumes all edges have weight 1.
        graph (EGraph): the EGraph corresponding to the code, representing
            the graph state.
        {b}_stab_graph (StabilizerGraph): the stabilizer graph combining
            stabilizers from error complex b ('primal' or 'dual'). The
            particular implementation depends on backend.
        {b}_stabilizers (List[Stabilizer]): the stabilizer elements of the
            code according to the error complex b ('primal'/'dual').
        {b}_syndrome_coords (List[Tuple]): the unique coordinates of the
            syndrome vertices according to the error complex b
            ('primal'/'dual'/'all').
        {b}_syndrome_inds (List[int]): the integer indices of the syndrome
            vertices according to the error complex b ('primal'/'dual'/'all'),
            in the same order as {b}_syndrome_coords.
        {b}_bound_points (List[Tuple]): the coordinates of the syndrome
            vertices on the boundary according to the error complex b
            ('primal'/'dual').
    """

    def __init__(
        self,
        distance: Union[int, Tuple[int, int, int]],
        ec="primal",
        boundaries="open",
        polarity=None,
        backend="rustworkx",
    ):
        self.distance = distance
        if np.issubdtype(type(distance), np.integer):
            self.dims = (distance,) * 3
        elif np.size(distance) == 3:
            self.dims = tuple(distance)
        else:
            # Previously self.dims was silently left unset here, failing later.
            raise ValueError("distance must be an integer or a sequence of three integers.")
        self.ec = [ec]

        if not (boundaries in ("open", "toric", "periodic") and ec in ("primal", "dual")):
            raise ValueError(
                f"The combination `ec={ec}` and `boundaries={boundaries}` is not valid. "
                "Allowed choices for boundaries are 'open', 'toric' or 'periodic' "
                "and `ec` should be either 'primal' or 'dual'."
            )
        # Validate the backend before the (potentially expensive) lattice build.
        if backend not in ("networkx", "rustworkx"):
            raise ValueError("Invalid backend; options are 'networkx' and 'rustworkx'.")

        if boundaries == "open":
            self.bound_str = "open_dual" if ec == "primal" else "open_primal"
        elif boundaries == "toric":
            self.bound_str = "periodic_dual" if ec == "primal" else "periodic_primal"
        else:
            self.bound_str = boundaries
        self.boundaries = str_to_bound(self.bound_str)

        self.polarity = polarity
        self.graph = RHG_graph(self.dims, self.boundaries, polarity=polarity)
        self.graph.index_generator()
        # Defines the stabilizer, syndrome coordinate and syndrome index attributes.
        self.identify_stabilizers()
        # Defines the {ec}_bound_points attributes.
        self.identify_boundary()

        graph_class = NxStabilizerGraph if backend == "networkx" else RxStabilizerGraph
        for error_type in self.ec:
            setattr(self, error_type + "_stab_graph", graph_class(error_type, self))

    def identify_stabilizers(self):
        """Set the stabilizer and syndrome coordinates of self.

        Generate a list of Stabilizer objects containing coordinates of all
        the stabilizer elements according to error complex ec. Furthermore,
        generate a list of all the relevant syndrome coordinates.

        In the end, the {ec}_syndrome_coords, {ec}_syndrome_inds, and
        {ec}_stabilizers attributes (where ec can be 'primal' or 'dual') as
        well as all_syndrome_inds and all_syndrome_coords are set.
        """
        all_six_bodies = self.generate_stabilizer_coords()
        # Set {ec}_syndrome_coords, {ec}_syndrome_inds, and {ec}_stabilizers.
        self.set_ec_stabilizers_and_syndrome(all_six_bodies)
        # Set `all_syndrome_inds` and `all_syndrome_coords` by concatenating
        # the per-error-complex lists.
        for att in ("_syndrome_inds", "_syndrome_coords"):
            combined = [item for ec in self.ec for item in getattr(self, ec + att)]
            setattr(self, "all" + att, combined)

    def set_ec_stabilizers_and_syndrome(self, all_six_bodies):
        """Set syndrome and stabilizers attributes of self.

        Set the {ec}_syndrome_coords, {ec}_syndrome_inds, and
        {ec}_stabilizers attributes. This method generates a list of
        Stabilizer objects containing coordinates of all the stabilizer
        elements according to error complex ec and sets them in self.

        Args:
            all_six_bodies (dict): maps each error complex in ``self.ec`` to
                its list of potential six-body stabilizer coordinates.
        """
        rhg_lattice = self.graph
        for ec in self.ec:
            all_cubes = []
            syndrome_coords = []
            for stabe in all_six_bodies[ec]:
                actual_stabe = list(set(stabe).intersection(rhg_lattice))
                # Fewer than six vertices present: complete the stabilizer
                # across any periodic boundary (modifies actual_stabe in place).
                if len(actual_stabe) < 6:
                    self.periodic_boundary_stabilizers(ec, stabe, actual_stabe)
                cube = Stabilizer(rhg_lattice.subgraph(actual_stabe))
                cube.physical = stabe
                syndrome_coords += actual_stabe
                all_cubes.append(cube)
            # Neighbouring stabilizers share vertices, so deduplicate once and
            # derive the indices from that same list: the indices are then
            # duplicate-free and aligned one-to-one with the coordinates.
            unique_coords = list(set(syndrome_coords))
            setattr(self, ec + "_syndrome_coords", unique_coords)
            setattr(
                self,
                ec + "_syndrome_inds",
                [self.graph.to_indices[point] for point in unique_coords],
            )
            setattr(self, ec + "_stabilizers", all_cubes)

    def periodic_boundary_stabilizers(self, ec, stabe, actual_stabe):
        """Complete stabilizers that wrap around periodic boundaries.

        For every periodic axis, the face vertex of ``stabe`` that falls
        outside the lattice is replaced by its periodic image on the opposite
        side, which is appended to ``actual_stabe``.

        Note this method modifies ``actual_stabe`` in place and returns None.

        Args:
            ec (str): the error complex ('primal' or 'dual').
            stabe (list[tuple]): the six nominal vertices of the stabilizer;
                entries ``ind`` and ``3 + ind`` are the low and high vertices
                along axis ``ind``.
            actual_stabe (list[tuple]): the vertices of ``stabe`` present in
                the lattice.
        """
        dims = np.array(self.dims)
        periodic_inds = np.where(self.boundaries == "periodic")[0]
        for ind in periodic_inds:
            if ec == "dual":
                # The low vertex lies at -1; use the high vertex (at 1) as a
                # template and move it to the last layer 2 * dims - 1.
                highest_point = list(stabe[3 + ind])
                if highest_point[ind] == 1:
                    highest_point[ind] = 2 * dims[ind] - 1
                    virtual_point = tuple(highest_point)
                    actual_stabe += [virtual_point]
            else:
                # The high vertex lies at 2 * dims; use the low vertex (at
                # 2 * dims - 2) as a template and move it to layer 0.
                lowest_point = list(stabe[ind])
                if lowest_point[ind] == 2 * dims[ind] - 2:
                    lowest_point[ind] = 0
                    virtual_point = tuple(lowest_point)
                    actual_stabe += [virtual_point]

    def generate_stabilizer_coords(self):
        """Generate primal and dual stabilizer coordinates.

        Returns:
            dict: the potential six-body stabilizers keyed by the error
            complexes ("primal" and/or "dual") present in ``self.ec``.
        """
        all_six_bodies = {}
        if "primal" in self.ec:
            all_six_bodies["primal"] = self.graph.graph["primal_cubes"]
        if "dual" in self.ec:
            all_six_bodies["dual"] = self.six_body_dual_stabilizers()
        return all_six_bodies

    def six_body_dual_stabilizers(self):
        """Return all potential six-body dual stabilizers of self.

        Returns:
            List[List[Tuple]]: one list of six coordinates per dual cube;
            some coordinates may lie outside the lattice.
        """
        min_dict = {"primal": -1, "dual": 0, "periodic": -1}
        max_dict = {"primal": 1, "dual": 1, "periodic": 1}
        range_min = np.array([min_dict[typ] for typ in self.boundaries])
        range_max = self.dims - np.array([max_dict[typ] for typ in self.boundaries])
        ranges = [range(range_min[i], range_max[i]) for i in (0, 1, 2)]
        inds = it.product(*ranges)
        return [
            [
                (2 * i + 1, 2 * j + 2, 2 * k + 2),
                (2 * i + 2, 2 * j + 1, 2 * k + 2),
                (2 * i + 2, 2 * j + 2, 2 * k + 1),
                (2 * i + 3, 2 * j + 2, 2 * k + 2),
                (2 * i + 2, 2 * j + 3, 2 * k + 2),
                (2 * i + 2, 2 * j + 2, 2 * k + 3),
            ]
            for (i, j, k) in inds
        ]

    def identify_boundary(self):
        """Obtain coordinates of syndrome qubits on the boundary.

        The relevant boundaries are determined by the ec string. In the end,
        the attributes {b}_bound_points are set, where b can be 'primal' or
        'dual'. Any periodic boundary string yields an empty list.
        """
        for ec in self.ec:
            if "periodic" in self.bound_str:
                ec_bound_points = []
            else:
                # First axis whose boundary type matches the error complex.
                bound_ind = np.where(self.boundaries == ec)[0][0]
                ec_bound_points = self._get_ec_bounds(ec, bound_ind)
            setattr(self, ec + "_bound_points", ec_bound_points)

    def _get_ec_bounds(self, ec, bound_ind):
        """Obtain coordinates of syndrome qubits on the relevant boundary.

        The boundary is determined by the ``ec`` string. ``bound_ind``
        specifies the direction where the ``ec`` boundary is located.

        Returns:
            list: the {ec} syndrome coordinates on the low and high slices
            perpendicular to axis ``bound_ind``.
        """
        dims = self.dims
        plane_dict = {0: "x", 1: "y", 2: "z"}
        syndrome_coords = getattr(self, ec + "_syndrome_coords")
        # Primal syndrome vertices sit on even layers, dual ones on odd layers.
        low_index = 0 if ec == "primal" else 1
        high_index = 2 * dims[bound_ind] - 2 if ec == "primal" else 2 * dims[bound_ind] - 1
        low_bound_points = self.graph.slice_coords(plane_dict[bound_ind], low_index)
        final_low_set = set(low_bound_points).intersection(syndrome_coords)
        high_bound_points = self.graph.slice_coords(plane_dict[bound_ind], high_index)
        final_high_set = set(high_bound_points).intersection(syndrome_coords)
        return list(final_low_set) + list(final_high_set)

    def draw(self, backend="matplotlib", **kwargs):
        """Draw the cluster state with the given backend.

        See ``flamingpy.utils.viz.draw_EGraph`` for more details. Use the
        default colour options: black for primal nodes, grey for dual nodes;
        blue for weight +1 edges, red for weight -1 edges when the polarity is
        ``alternating_polarity`` (grey edges otherwise). Keyword arguments
        override these defaults.
        """
        edge_colors = "grey"
        # Callables such as functools.partial objects have no __name__.
        if getattr(self.polarity, "__name__", None) == "alternating_polarity":
            edge_colors = ("weight", {1: "b", -1: "r"})

        default_opts = {
            "color_nodes": ("type", {"primal": "black", "dual": "grey"}),
            "color_edges": edge_colors,
        }
        updated_opts = {**default_opts, **kwargs}
        return self.graph.draw(backend=backend, **updated_opts)

    def draw_stabilizer_graph(self, ec, **kwargs):
        """Draw the stabilizer graph of error complex ``ec``.

        The edge weights are assigned from self before drawing. See
        flamingpy.utils.viz.draw_dec_graph for more details.
        """
        graph = getattr(self, ec + "_stab_graph")
        graph.assign_weights(self)
        return graph.draw(**kwargs)

Contents

Home

Background

Using FlamingPy

Development

Getting Help

Python API