diff --git a/qualtran/bloqs/basic_gates/toffoli.py b/qualtran/bloqs/basic_gates/toffoli.py index fbbfdaff6..894dc4920 100644 --- a/qualtran/bloqs/basic_gates/toffoli.py +++ b/qualtran/bloqs/basic_gates/toffoli.py @@ -17,7 +17,7 @@ from attrs import frozen -from qualtran import Bloq, Register, Signature +from qualtran import Bloq, Register, Signature, Soquet from qualtran.bloqs.basic_gates import TGate from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.resource_counting import SympySymbolAllocator @@ -26,6 +26,7 @@ import cirq from qualtran.cirq_interop import CirqQuregT + from qualtran.drawing import WireSymbol from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT @@ -74,3 +75,12 @@ def as_cirq_op( (trg,) = target return cirq.CCNOT(*ctrl[:, 0], trg), {'ctrl': ctrl, 'target': target} + + def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': + from qualtran.drawing import Circle, ModPlus + + if soq.reg.name == 'ctrl': + return Circle(filled=True) + elif soq.reg.name == 'target': + return ModPlus() + raise ValueError(f'Bad wire symbol soquet: {soq}') diff --git a/qualtran/bloqs/basic_gates/x_basis.py b/qualtran/bloqs/basic_gates/x_basis.py index 84e61fec7..66acbe036 100644 --- a/qualtran/bloqs/basic_gates/x_basis.py +++ b/qualtran/bloqs/basic_gates/x_basis.py @@ -19,13 +19,14 @@ import quimb.tensor as qtn from attrs import frozen -from qualtran import Bloq, Register, Side, Signature, SoquetT +from qualtran import Bloq, Register, Side, Signature, Soquet, SoquetT from qualtran.cirq_interop.t_complexity_protocol import TComplexity if TYPE_CHECKING: import cirq from qualtran.cirq_interop import CirqQuregT + from qualtran.drawing import WireSymbol from qualtran.simulation.classical_sim import ClassicalValT _PLUS = np.ones(2, dtype=np.complex128) / np.sqrt(2) @@ -194,3 +195,8 @@ def as_cirq_op( def t_complexity(self): return TComplexity(clifford=1) + + def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': + from qualtran.drawing import ModPlus + + return ModPlus() diff --git a/qualtran/drawing/_show_funcs.py b/qualtran/drawing/_show_funcs.py index 0983cf8f0..e6dc50a8f 100644 --- a/qualtran/drawing/_show_funcs.py +++ b/qualtran/drawing/_show_funcs.py @@ -21,6 +21,7 @@ from .bloq_counts_graph import format_counts_sigma, GraphvizCounts from .graphviz import PrettyGraphDrawer +from .musical_score import draw_musical_score, get_musical_score_data if TYPE_CHECKING: import networkx as nx @@ -29,9 +30,20 @@ from qualtran import Bloq -def show_bloq(bloq: 'Bloq'): - """Display a graph representation of the bloq in IPython.""" - IPython.display.display(PrettyGraphDrawer(bloq).get_svg()) +def show_bloq(bloq: 'Bloq', type: str = 'graph'): # pylint: disable=redefined-builtin + """Display a visual representation of the bloq in IPython. + + Args: + bloq: The bloq to show + type: Either 'graph' or 'musical_score'. By default, display a directed acyclic + graph of the bloq connectivity. Otherwise, draw a musical score diagram. + """ + if type.lower() == 'graph': + IPython.display.display(PrettyGraphDrawer(bloq).get_svg()) + elif type.lower() == 'musical_score': + draw_musical_score(get_musical_score_data(bloq)) + else: + raise ValueError(f"Unknown `show_bloq` type: {type}.") def show_bloqs(bloqs: Sequence['Bloq'], labels: Sequence[str] = None): diff --git a/qualtran/drawing/musical_score.py b/qualtran/drawing/musical_score.py index 1d3448c9b..0b1f04b0f 100644 --- a/qualtran/drawing/musical_score.py +++ b/qualtran/drawing/musical_score.py @@ -634,8 +634,38 @@ def get_musical_score_data(bloq: Bloq, manager: Optional[LineManager] = None) -> return msd -def draw_musical_score(msd: MusicalScoreData): - fig, ax = plt.subplots(figsize=(max(5.0, 0.2 + 0.6 * msd.max_x), 5)) +def draw_musical_score( + msd: MusicalScoreData, + unit_to_inches: float = 0.8, + max_width: float = 8.0, + max_height: float = 8.0, +): + # First, set up data coordinate limits and figure size. + # X coordinates go from -1 to max_x + # with 1 unit of padding it goes from -2 to max_x+1 + xlim = (-2, msd.max_x + 1) + x_extent = msd.max_x + 3.0 + # Y coordinates of non-labels goes from 0 to -max_y; + # with the bloq label above it goes from 0.5 to -max_y + # with 0.5 units of padding it goes from 1 to -(max_y+0.5) + ylim = (-msd.max_y - 0.5, 1) + y_extent = msd.max_y + 1.5 + + # The width and height are proportional. + width = unit_to_inches * x_extent + height = unit_to_inches * y_extent + + # But we cap width and height (but keep it proportional). + if width > height and width > max_width: + scale = max_width / width + width *= scale + height *= scale + elif height > max_height: + scale = max_height / height + height *= scale + width *= scale + + fig, ax = plt.subplots(figsize=(width, height)) for hline in msd.hlines: ax.hlines(-hline.y, hline.seq_x_start, hline.seq_x_end, color='k', zorder=-1) @@ -648,8 +678,8 @@ def draw_musical_score(msd: MusicalScoreData): symb = soq.symb symb.draw(ax, soq.rpos.seq_x, soq.rpos.y) - ax.set_xlim((-2, msd.max_x + 1)) - ax.set_ylim((-msd.max_y - 0.5, 1)) + ax.set_xlim(xlim) + ax.set_ylim(ylim) ax.axis('off') fig.tight_layout() return fig, ax