Skip to content

Commit 38d0c10

Browse files
committed
Changes to .flatten
- Default predicate to flatten everything - If you can't flatten a bloq, don't complain; just don't do anything - Documentation improvements
1 parent 52622f6 commit 38d0c10

File tree

3 files changed

+76
-41
lines changed

3 files changed

+76
-41
lines changed

qualtran/_infra/composite_bloq.ipynb

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@
256256
"cbloq2 = cbloq.copy()\n",
257257
"\n",
258258
"# They're the same!\n",
259-
"display(show_bloq(cbloq))\n",
260-
"display(show_bloq(cbloq2))"
259+
"show_bloq(cbloq)\n",
260+
"show_bloq(cbloq2)"
261261
]
262262
},
263263
{
@@ -364,7 +364,7 @@
364364
"# right-dangling soquets.\n",
365365
"fsoqs = bb.map_soqs(cbloq.final_soqs(), soq_map)\n",
366366
"copy = bb.finalize(**fsoqs)\n",
367-
"copy"
367+
"print(copy)"
368368
]
369369
},
370370
{
@@ -475,6 +475,8 @@
475475
" stuff = bb.add(TestParallelCombo(), reg=stuff)\n",
476476
" return {'stuff': stuff}\n",
477477
"\n",
478+
"# Note! We're using `.as_composite_bloq()` to wrap the Bloq\n",
479+
"# into a compute graph with one node.\n",
478480
"three_p = ThreeParallelBloqs().as_composite_bloq()\n",
479481
"show_bloq(three_p)"
480482
]
@@ -486,8 +488,9 @@
486488
"metadata": {},
487489
"outputs": [],
488490
"source": [
489-
"# Do one flattening iteration\n",
490-
"flat_three_p = three_p.flatten_once(lambda binst: True)\n",
491+
"# Do one flattening operation; here equivalent to `.decompose_bloq()`\n",
492+
"# on the original bloq.\n",
493+
"flat_three_p = three_p.flatten_once()\n",
491494
"show_bloq(flat_three_p)"
492495
]
493496
},
@@ -496,7 +499,7 @@
496499
"id": "6f322361",
497500
"metadata": {},
498501
"source": [
499-
"If we just decomposed all of the subbloqs, our diagram would look very similar except the subbloqs are now `CompositeBloq` container classes instead of the original `TestParallelBloq`."
502+
"Now we have a `CompositeBloq` with three subbloqs. What if we wanted to continue decomposing? A naive approach would be to simply call `.decompose_bloq()` on each subbloqs. We'll do this in the next cell and see what happens."
500503
]
501504
},
502505
{
@@ -523,6 +526,14 @@
523526
"show_bloq(decompose_children)"
524527
]
525528
},
529+
{
530+
"cell_type": "markdown",
531+
"id": "35b8a7d3-d54e-4b72-a50a-4fb047bf797c",
532+
"metadata": {},
533+
"source": [
534+
"This is still a `CompositeBloq` with three subbloqs. The only difference is now the subbloqs are each compute graphs of their own. This likely isn't what we want. Instead, we want to do the equivalent of `flatMap` with the decompose operation: namely, decompose the subbloqs and remove a layer of nesting. This is what the `flatten_once` method achieves."
535+
]
536+
},
526537
{
527538
"cell_type": "code",
528539
"execution_count": null,
@@ -532,10 +543,18 @@
532543
"source": [
533544
"# Actually do a flattening operation on all the subbloqs\n",
534545
"show_bloq(\n",
535-
" flat_three_p.flatten_once(lambda binst: True)\n",
546+
" flat_three_p.flatten_once()\n",
536547
")"
537548
]
538549
},
550+
{
551+
"cell_type": "markdown",
552+
"id": "34e0e027-f327-47a4-9169-de233cdb8359",
553+
"metadata": {},
554+
"source": [
555+
"You can use the optional predicate to control which subbloqs get decomposed and flattened."
556+
]
557+
},
539558
{
540559
"cell_type": "code",
541560
"execution_count": null,
@@ -549,16 +568,25 @@
549568
")"
550569
]
551570
},
571+
{
572+
"cell_type": "markdown",
573+
"id": "f13ff34a-a9f3-43e6-85bd-4a1f90f4e3e1",
574+
"metadata": {},
575+
"source": [
576+
"The `.flatten` method will repeatedly call `flatten_once` until you can't flatten any more."
577+
]
578+
},
552579
{
553580
"cell_type": "code",
554581
"execution_count": null,
555582
"id": "c1e6cbb0",
556583
"metadata": {},
557584
"outputs": [],
558585
"source": [
559-
"# Flatten until you can't flatten any more\n",
586+
"# Note that in this example, we have gone back to the original `three_p` starting composite bloq.\n",
587+
"# This will perform two flattening operations.\n",
560588
"show_bloq(\n",
561-
" three_p.flatten(lambda binst: binst.bloq.supports_decompose_bloq())\n",
589+
" three_p.flatten()\n",
562590
")"
563591
]
564592
}
@@ -579,7 +607,7 @@
579607
"name": "python",
580608
"nbconvert_exporter": "python",
581609
"pygments_lexer": "ipython3",
582-
"version": "3.10.9"
610+
"version": "3.11.8"
583611
}
584612
},
585613
"nbformat": 4,

qualtran/_infra/composite_bloq.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import sympy
4141
from numpy.typing import NDArray
4242

43-
from .bloq import Bloq, DecomposeTypeError
43+
from .bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
4444
from .data_types import check_dtypes_consistent, QAny, QBit, QDType
4545
from .quantum_graph import BloqInstance, Connection, DanglingT, LeftDangle, RightDangle, Soquet
4646
from .registers import Register, Side, Signature
@@ -316,7 +316,9 @@ def copy(self) -> 'CompositeBloq':
316316
fsoqs = _map_soqs(self.final_soqs(), soq_map)
317317
return bb.finalize(**fsoqs)
318318

319-
def flatten_once(self, pred: Callable[[BloqInstance], bool]) -> 'CompositeBloq':
319+
def flatten_once(
320+
self, pred: Callable[[BloqInstance], bool] = lambda binst: True
321+
) -> 'CompositeBloq':
320322
"""Decompose and flatten each subbloq that satisfies `pred`.
321323
322324
This will only flatten "once". That is, we will go through the bloq instances
@@ -326,16 +328,17 @@ def flatten_once(self, pred: Callable[[BloqInstance], bool]) -> 'CompositeBloq':
326328
Args:
327329
pred: A predicate that takes a bloq instance and returns True if it should
328330
be decomposed and flattened or False if it should remain undecomposed.
329-
All bloqs for which this callable returns True must support decomposition.
331+
If the bloq does not have a decomposition, it will remain undecomposed.
332+
By default, flatten everything.
330333
331334
Returns:
332335
A new composite bloq where subbloqs matching `pred` have been decomposed and
333336
flattened.
334337
335338
Raises:
336-
NotImplementedError: If `pred` returns True but the underlying bloq does not
337-
support `decompose_bloq()`.
338-
DidNotFlattenAnythingError: If none of the bloq instances satisfied `pred`.
339+
DidNotFlattenAnythingError: If the operation did not actually flatten anything.
340+
This could be because none of the bloq instances satisfied `pred` or none of
341+
the bloqs have decompositions.
339342
340343
"""
341344
bb, _ = BloqBuilder.from_signature(self.signature)
@@ -348,13 +351,16 @@ def flatten_once(self, pred: Callable[[BloqInstance], bool]) -> 'CompositeBloq':
348351
bb._i = max(binst.i for binst in self.bloq_instances) + 1
349352

350353
soq_map: List[Tuple[SoquetT, SoquetT]] = []
354+
new_out_soqs: Tuple[SoquetT, ...] = ()
351355
did_work = False
352356
for binst, in_soqs, old_out_soqs in self.iter_bloqsoqs():
353357
in_soqs = _map_soqs(in_soqs, soq_map) # update `in_soqs` from old to new.
354-
355358
if pred(binst):
356-
new_out_soqs = bb.add_from(binst.bloq.decompose_bloq(), **in_soqs)
357-
did_work = True
359+
try:
360+
new_out_soqs = bb.add_from(binst.bloq.decompose_bloq(), **in_soqs)
361+
did_work = True
362+
except (DecomposeTypeError, DecomposeNotImplementedError):
363+
pass
358364
else:
359365
# Since we took care to not re-use existing `binst.i` values for flattened
360366
# bloqs, it is safe to call `bb._add_binst` with the old `binst` (and in
@@ -371,18 +377,8 @@ def flatten_once(self, pred: Callable[[BloqInstance], bool]) -> 'CompositeBloq':
371377
fsoqs = _map_soqs(self.final_soqs(), soq_map)
372378
return bb.finalize(**fsoqs)
373379

374-
def adjoint(self) -> 'CompositeBloq':
375-
"""Get a composite bloq which is the adjoint of this composite bloq.
376-
377-
The adjoint of a composite bloq is another composite bloq where the order of
378-
operations is reversed and each subbloq is replaced with its adjoint.
379-
"""
380-
from .adjoint import _adjoint_cbloq
381-
382-
return _adjoint_cbloq(self)
383-
384380
def flatten(
385-
self, pred: Callable[[BloqInstance], bool], max_depth: int = 1_000
381+
self, pred: Callable[[BloqInstance], bool] = lambda binst: True, max_depth: int = 1_000
386382
) -> 'CompositeBloq':
387383
"""Recursively decompose and flatten subbloqs until none satisfy `pred`.
388384
@@ -392,16 +388,13 @@ def flatten(
392388
Args:
393389
pred: A predicate that takes a bloq instance and returns True if it should
394390
be decomposed and flattened or False if it should remain undecomposed.
395-
All bloqs for which this callable returns True must support decomposition.
391+
If the bloq does not have a decomposition, it will remain undecomposed.
392+
By default, flatten as much as possible.
396393
max_depth: To avoid infinite recursion, give up after this many recursive steps.
397394
398395
Returns:
399396
A new composite bloq where all recursive subbloqs matching `pred` have been
400397
decomposed and flattened.
401-
402-
Raises:
403-
NotImplementedError: If `pred` returns True but the underlying bloq does not
404-
support `decompose_bloq()`.
405398
"""
406399
cbloq = self
407400
for _ in range(max_depth):
@@ -414,6 +407,16 @@ def flatten(
414407

415408
return cbloq
416409

410+
def adjoint(self) -> 'CompositeBloq':
411+
"""Get a composite bloq which is the adjoint of this composite bloq.
412+
413+
The adjoint of a composite bloq is another composite bloq where the order of
414+
operations is reversed and each subbloq is replaced with its adjoint.
415+
"""
416+
from .adjoint import _adjoint_cbloq
417+
418+
return _adjoint_cbloq(self)
419+
417420
@staticmethod
418421
def _debug_binst(g: nx.DiGraph, binst: BloqInstance) -> List[str]:
419422
"""Helper method used in `debug_text`"""
@@ -454,6 +457,9 @@ def debug_text(self) -> str:
454457
delimited_gens = ('\n' + '-' * 20 + '\n').join(gen_texts)
455458
return delimited_gens
456459

460+
def __str__(self):
461+
return f'CompositeBloq([{len(self.bloq_instances)} subbloqs...])'
462+
457463

458464
def _create_binst_graph(
459465
cxns: Iterable[Connection], nodes: Iterable[BloqInstance] = ()

qualtran/_infra/composite_bloq_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
BloqInstance,
3131
CompositeBloq,
3232
Connection,
33-
DecomposeTypeError,
3433
LeftDangle,
3534
Register,
3635
RightDangle,
@@ -504,13 +503,15 @@ def test_flatten():
504503
cbloq2 = cbloq.flatten_once(lambda binst: True)
505504
assert len(cbloq2.bloq_instances) == 5 * 2
506505

507-
with pytest.raises(DecomposeTypeError):
508-
# Will keep trying to flatten non-decomposable things
509-
cbloq.flatten(lambda x: True)
510-
511-
cbloq3 = cbloq.flatten(lambda binst: binst.bloq.supports_decompose_bloq())
506+
cbloq3 = cbloq.flatten(lambda binst: True)
512507
assert len(cbloq3.bloq_instances) == 5 * 2
513508

509+
cbloq4 = cbloq.flatten(lambda binst: binst.bloq.supports_decompose_bloq())
510+
assert len(cbloq4.bloq_instances) == 5 * 2
511+
512+
cbloq5 = cbloq.flatten()
513+
assert len(cbloq5.bloq_instances) == 5 * 2
514+
514515

515516
def test_type_error():
516517
bb = BloqBuilder()

0 commit comments

Comments
 (0)