Skip to content

Commit

Permalink
Improve musical score drawing (#621)
Browse files Browse the repository at this point in the history
* Improve musical score drawing

* format/lint
  • Loading branch information
mpharrigan authored Feb 2, 2024
1 parent 10671a5 commit 79c324d
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 9 deletions.
12 changes: 11 additions & 1 deletion qualtran/bloqs/basic_gates/toffoli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from attrs import frozen

from qualtran import Bloq, Register, Signature
from qualtran import Bloq, Register, Signature, Soquet
from qualtran.bloqs.basic_gates import TGate
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.resource_counting import SympySymbolAllocator
Expand All @@ -26,6 +26,7 @@
import cirq

from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

Expand Down Expand Up @@ -74,3 +75,12 @@ def as_cirq_op(

(trg,) = target
return cirq.CCNOT(*ctrl[:, 0], trg), {'ctrl': ctrl, 'target': target}

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
from qualtran.drawing import Circle, ModPlus

if soq.reg.name == 'ctrl':
return Circle(filled=True)
elif soq.reg.name == 'target':
return ModPlus()
raise ValueError(f'Bad wire symbol soquet: {soq}')
8 changes: 7 additions & 1 deletion qualtran/bloqs/basic_gates/x_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
import quimb.tensor as qtn
from attrs import frozen

from qualtran import Bloq, Register, Side, Signature, SoquetT
from qualtran import Bloq, Register, Side, Signature, Soquet, SoquetT
from qualtran.cirq_interop.t_complexity_protocol import TComplexity

if TYPE_CHECKING:
import cirq

from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol
from qualtran.simulation.classical_sim import ClassicalValT

_PLUS = np.ones(2, dtype=np.complex128) / np.sqrt(2)
Expand Down Expand Up @@ -194,3 +195,8 @@ def as_cirq_op(

def t_complexity(self):
return TComplexity(clifford=1)

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
from qualtran.drawing import ModPlus

return ModPlus()
18 changes: 15 additions & 3 deletions qualtran/drawing/_show_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .bloq_counts_graph import format_counts_sigma, GraphvizCounts
from .graphviz import PrettyGraphDrawer
from .musical_score import draw_musical_score, get_musical_score_data

if TYPE_CHECKING:
import networkx as nx
Expand All @@ -29,9 +30,20 @@
from qualtran import Bloq


def show_bloq(bloq: 'Bloq'):
"""Display a graph representation of the bloq in IPython."""
IPython.display.display(PrettyGraphDrawer(bloq).get_svg())
def show_bloq(bloq: 'Bloq', type: str = 'graph'): # pylint: disable=redefined-builtin
"""Display a visual representation of the bloq in IPython.
Args:
bloq: The bloq to show
type: Either 'graph' or 'musical_score'. By default, display a directed acyclic
graph of the bloq connectivity. Otherwise, draw a musical score diagram.
"""
if type.lower() == 'graph':
IPython.display.display(PrettyGraphDrawer(bloq).get_svg())
elif type.lower() == 'musical_score':
draw_musical_score(get_musical_score_data(bloq))
else:
raise ValueError(f"Unknown `show_bloq` type: {type}.")


def show_bloqs(bloqs: Sequence['Bloq'], labels: Sequence[str] = None):
Expand Down
38 changes: 34 additions & 4 deletions qualtran/drawing/musical_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,38 @@ def get_musical_score_data(bloq: Bloq, manager: Optional[LineManager] = None) ->
return msd


def draw_musical_score(msd: MusicalScoreData):
fig, ax = plt.subplots(figsize=(max(5.0, 0.2 + 0.6 * msd.max_x), 5))
def draw_musical_score(
msd: MusicalScoreData,
unit_to_inches: float = 0.8,
max_width: float = 8.0,
max_height: float = 8.0,
):
# First, set up data coordinate limits and figure size.
# X coordinates go from -1 to max_x
# with 1 unit of padding it goes from -2 to max_x+1
xlim = (-2, msd.max_x + 1)
x_extent = msd.max_x + 3.0
# Y coordinates of non-labels goes from 0 to -max_y;
# with the bloq label above it goes from 0.5 to -max_y
# with 0.5 units of padding it goes from 1 to -(max_y+0.5)
ylim = (-msd.max_y - 0.5, 1)
y_extent = msd.max_y + 1.5

# The width and height are proportional.
width = unit_to_inches * x_extent
height = unit_to_inches * y_extent

# But we cap width and height (but keep it proportional).
if width > height and width > max_width:
scale = max_width / width
width *= scale
height *= scale
elif height > max_height:
scale = max_height / height
height *= scale
width *= scale

fig, ax = plt.subplots(figsize=(width, height))

for hline in msd.hlines:
ax.hlines(-hline.y, hline.seq_x_start, hline.seq_x_end, color='k', zorder=-1)
Expand All @@ -648,8 +678,8 @@ def draw_musical_score(msd: MusicalScoreData):
symb = soq.symb
symb.draw(ax, soq.rpos.seq_x, soq.rpos.y)

ax.set_xlim((-2, msd.max_x + 1))
ax.set_ylim((-msd.max_y - 0.5, 1))
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.axis('off')
fig.tight_layout()
return fig, ax
Expand Down

0 comments on commit 79c324d

Please sign in to comment.