Skip to content

Commit 6e4b7ad

Browse files
authored
PlusEqualProduct version of GFMultiplication for GF($2^m$) (#1457)
* PlusEqualProduct version of multiplication for GF(2^m) * Regenerate notebooks and update docstring * Fix mypy
1 parent 801bee5 commit 6e4b7ad

File tree

3 files changed

+92
-23
lines changed

3 files changed

+92
-23
lines changed

qualtran/bloqs/gf_arithmetic/gf2_multiplication.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,13 @@
5656
"gates.\n",
5757
"\n",
5858
"#### Parameters\n",
59-
" - `bitsize`: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of qubits in each of the two input registers $a$ and $b$ that should be multiplied. \n",
59+
" - `bitsize`: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of qubits in each of the two input registers $a$ and $b$ that should be multiplied.\n",
60+
" - `plus_equal_prod`: If True, implements the `PlusEqualProduct` version that applies the map $|x\\rangle |y\\rangle |z\\rangle \\rightarrow |x\\rangle |y\\rangle |x + z\\rangle$. \n",
6061
"\n",
6162
"#### Registers\n",
6263
" - `x`: Input THRU register of size $m$ that stores elements from $GF(2^m)$.\n",
6364
" - `y`: Input THRU register of size $m$ that stores elements from $GF(2^m)$.\n",
64-
" - `result`: Output RIGHT register of size $m$ that stores the product $x * y$ in $GF(2^m)$. \n",
65+
" - `result`: Register of size $m$ that stores the product $x * y$ in $GF(2^m)$. If plus_equal_prod is True - result is a THRU register and stores $result + x * y$. If plus_equal_prod is False - result is a RIGHT register and stores $x * y$. \n",
6566
"\n",
6667
"#### References\n",
6768
" - [On the Design and Optimization of a Quantum Polynomial-Time Attack on Elliptic Curve Cryptography](https://arxiv.org/abs/0710.1093). \n",
@@ -99,7 +100,7 @@
99100
},
100101
"outputs": [],
101102
"source": [
102-
"gf16_multiplication = GF2Multiplication(4)"
103+
"gf16_multiplication = GF2Multiplication(4, plus_equal_prod=True)"
103104
]
104105
},
105106
{
@@ -114,7 +115,7 @@
114115
"import sympy\n",
115116
"\n",
116117
"m = sympy.Symbol('m')\n",
117-
"gf2_multiplication_symbolic = GF2Multiplication(m)"
118+
"gf2_multiplication_symbolic = GF2Multiplication(m, plus_equal_prod=False)"
118119
]
119120
},
120121
{

qualtran/bloqs/gf_arithmetic/gf2_multiplication.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,19 @@ class SynthesizeLRCircuit(Bloq):
5151
"""Synthesize linear reversible circuit using CNOT gates.
5252
5353
Args:
54-
matrix: An n x m matrix describing the linear transformation.
54+
matrix: An n x n matrix describing the linear transformation.
5555
5656
References:
5757
[Efficient Synthesis of Linear Reversible Circuits](https://arxiv.org/abs/quant-ph/0302002)
5858
"""
5959

6060
matrix: Union[Shaped, np.ndarray] = attrs.field(eq=_data_or_shape_to_tuple)
61+
is_adjoint: bool = False
6162

6263
def __attrs_post_init__(self):
6364
assert len(self.matrix.shape) == 2
6465
n, m = self.matrix.shape
65-
assert is_symbolic(n, m) or n >= m
66+
assert is_symbolic(n, m) or n == m
6667

6768
@cached_property
6869
def signature(self) -> 'Signature':
@@ -72,17 +73,23 @@ def signature(self) -> 'Signature':
7273
def on_classical_vals(self, *, q: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
7374
matrix = self.matrix
7475
assert isinstance(matrix, np.ndarray)
76+
if self.is_adjoint:
77+
matrix = np.linalg.inv(matrix)
78+
assert np.allclose(matrix, matrix.astype(int))
79+
matrix = matrix.astype(int)
7580
_, m = matrix.shape
7681
assert isinstance(q, np.ndarray)
77-
q_in = q[:m]
78-
return {'q': (matrix @ q_in) % 2}
82+
return {'q': (matrix @ q) % 2}
7983

8084
def build_call_graph(
8185
self, ssa: 'SympySymbolAllocator'
8286
) -> Union['BloqCountDictT', Set['BloqCountT']]:
8387
n = self.matrix.shape[0]
8488
return {CNOT(): ceil(n**2 / log2(n))}
8589

90+
def adjoint(self) -> 'SynthesizeLRCircuit':
91+
return attrs.evolve(self, is_adjoint=not self.is_adjoint)
92+
8693

8794
@attrs.frozen
8895
class GF2Multiplication(Bloq):
@@ -108,11 +115,15 @@ class GF2Multiplication(Bloq):
108115
Args:
109116
bitsize: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of
110117
qubits in each of the two input registers $a$ and $b$ that should be multiplied.
118+
plus_equal_prod: If True, implements the `PlusEqualProduct` version that applies the
119+
map $|x\rangle |y\rangle |z\rangle \rightarrow |x\rangle |y\rangle |x + z\rangle$.
111120
112121
Registers:
113122
x: Input THRU register of size $m$ that stores elements from $GF(2^m)$.
114123
y: Input THRU register of size $m$ that stores elements from $GF(2^m)$.
115-
result: Output RIGHT register of size $m$ that stores the product $x * y$ in $GF(2^m)$.
124+
result: Register of size $m$ that stores the product $x * y$ in $GF(2^m)$.
125+
If plus_equal_prod is True - result is a THRU register and stores $result + x * y$.
126+
If plus_equal_prod is False - result is a RIGHT register and stores $x * y$.
116127
117128
118129
References:
@@ -124,14 +135,16 @@ class GF2Multiplication(Bloq):
124135
"""
125136

126137
bitsize: SymbolicInt
138+
plus_equal_prod: bool = False
127139

128140
@cached_property
129141
def signature(self) -> 'Signature':
142+
result_side = Side.THRU if self.plus_equal_prod else Side.RIGHT
130143
return Signature(
131144
[
132145
Register('x', dtype=self.qgf),
133146
Register('y', dtype=self.qgf),
134-
Register('result', dtype=self.qgf, side=Side.RIGHT),
147+
Register('result', dtype=self.qgf, side=result_side),
135148
]
136149
)
137150

@@ -143,14 +156,15 @@ def qgf(self) -> QGF:
143156
def reduction_matrix_q(self) -> np.ndarray:
144157
m = int(self.bitsize)
145158
f = self.qgf.gf_type.irreducible_poly
146-
M = np.zeros((m - 1, m))
159+
M = np.zeros((m, m))
147160
alpha = [1] + [0] * m
148161
for i in range(m - 1):
149162
# x ** (m + i) % f
150163
coeffs = (Poly(alpha, GF(2)) % f).coeffs.tolist()[::-1]
151164
coeffs = coeffs + [0] * (m - len(coeffs))
152165
M[i] = coeffs
153166
alpha += [0]
167+
M[m - 1][m - 1] = 1
154168
return np.transpose(M)
155169

156170
@cached_property
@@ -162,14 +176,18 @@ def synthesize_reduction_matrix_q(self) -> SynthesizeLRCircuit:
162176
else SynthesizeLRCircuit(self.reduction_matrix_q)
163177
)
164178

165-
def build_composite_bloq(
166-
self, bb: 'BloqBuilder', *, x: 'Soquet', y: 'Soquet'
167-
) -> Dict[str, 'Soquet']:
179+
def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'Soquet') -> Dict[str, 'Soquet']:
168180
if is_symbolic(self.bitsize):
169181
raise DecomposeTypeError(f"Cannot decompose symbolic {self}")
170-
result = bb.allocate(dtype=self.qgf)
182+
x, y = soqs['x'], soqs['y']
183+
result = soqs['result'] if self.plus_equal_prod else bb.allocate(dtype=self.qgf)
171184
x, y, result = bb.split(x)[::-1], bb.split(y)[::-1], bb.split(result)[::-1]
172185
m = int(self.bitsize)
186+
187+
# Step-0: PlusEqualProduct special case.
188+
if self.plus_equal_prod:
189+
result = bb.add(self.synthesize_reduction_matrix_q.adjoint(), q=result)
190+
173191
# Step-1: Multiply Monomials.
174192
for i in range(m):
175193
for j in range(i + 1, m):
@@ -199,16 +217,21 @@ def build_call_graph(
199217
self, ssa: 'SympySymbolAllocator'
200218
) -> Union['BloqCountDictT', Set['BloqCountT']]:
201219
m = self.bitsize
202-
return {Toffoli(): m**2, self.synthesize_reduction_matrix_q: 1}
220+
plus_equal_prod = (
221+
{self.synthesize_reduction_matrix_q.adjoint(): 1} if self.plus_equal_prod else {}
222+
)
223+
return {Toffoli(): m**2, self.synthesize_reduction_matrix_q: 1} | plus_equal_prod
203224

204-
def on_classical_vals(self, *, x, y) -> Dict[str, 'ClassicalValT']:
205-
assert isinstance(x, self.qgf.gf_type) and isinstance(y, self.qgf.gf_type)
206-
return {'x': x, 'y': y, 'result': x * y}
225+
def on_classical_vals(self, **vals) -> Dict[str, 'ClassicalValT']:
226+
assert all(isinstance(val, self.qgf.gf_type) for val in vals.values())
227+
x, y = vals['x'], vals['y']
228+
result = vals['result'] if self.plus_equal_prod else self.qgf.gf_type(0)
229+
return {'x': x, 'y': y, 'result': result + x * y}
207230

208231

209232
@bloq_example
210233
def _gf16_multiplication() -> GF2Multiplication:
211-
gf16_multiplication = GF2Multiplication(4)
234+
gf16_multiplication = GF2Multiplication(4, plus_equal_prod=True)
212235
return gf16_multiplication
213236

214237

@@ -217,7 +240,7 @@ def _gf2_multiplication_symbolic() -> GF2Multiplication:
217240
import sympy
218241

219242
m = sympy.Symbol('m')
220-
gf2_multiplication_symbolic = GF2Multiplication(m)
243+
gf2_multiplication_symbolic = GF2Multiplication(m, plus_equal_prod=False)
221244
return gf2_multiplication_symbolic
222245

223246

qualtran/bloqs/gf_arithmetic/gf2_multiplication_test.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
1516
import pytest
1617
from galois import GF
1718

19+
from qualtran import QGF
1820
from qualtran.bloqs.gf_arithmetic.gf2_multiplication import (
1921
_gf2_multiplication_symbolic,
2022
_gf16_multiplication,
2123
GF2Multiplication,
24+
SynthesizeLRCircuit,
2225
)
2326
from qualtran.testing import assert_consistent_classical_action
2427

@@ -31,16 +34,58 @@ def test_gf2_multiplication_symbolic(bloq_autotester):
3134
bloq_autotester(_gf2_multiplication_symbolic)
3235

3336

37+
def test_synthesize_lr_circuit():
38+
m = 2
39+
matrix = GF2Multiplication(m).reduction_matrix_q
40+
bloq = SynthesizeLRCircuit(matrix)
41+
bloq_adj = bloq.adjoint()
42+
QGFM, GFM = QGF(2, m), GF(2**m)
43+
for i in GFM.elements:
44+
bloq_out = bloq.call_classically(q=np.array(QGFM.to_bits(i)))[0]
45+
bloq_adj_out = bloq_adj.call_classically(q=bloq_out)[0]
46+
assert isinstance(bloq_adj_out, np.ndarray)
47+
assert i == QGFM.from_bits([*bloq_adj_out])
48+
49+
50+
@pytest.mark.slow
51+
@pytest.mark.parametrize('m', [3, 4, 5])
52+
def test_synthesize_lr_circuit_slow(m):
53+
matrix = GF2Multiplication(m).reduction_matrix_q
54+
bloq = SynthesizeLRCircuit(matrix)
55+
bloq_adj = bloq.adjoint()
56+
QGFM, GFM = QGF(2, m), GF(2**m)
57+
for i in GFM.elements:
58+
bloq_out = bloq.call_classically(q=np.array(QGFM.to_bits(i)))[0]
59+
bloq_adj_out = bloq_adj.call_classically(q=bloq_out)[0]
60+
assert isinstance(bloq_adj_out, np.ndarray)
61+
assert i == QGFM.from_bits([*bloq_adj_out])
62+
63+
64+
def test_gf2_plus_equal_prod_classical_sim_quick():
65+
m = 2
66+
bloq = GF2Multiplication(m, plus_equal_prod=True)
67+
GFM = GF(2**m)
68+
assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements, result=GFM.elements)
69+
70+
71+
@pytest.mark.slow
72+
def test_gf2_plus_equal_prod_classical_sim():
73+
m = 3
74+
bloq = GF2Multiplication(m, plus_equal_prod=True)
75+
GFM = GF(2**m)
76+
assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements, result=GFM.elements)
77+
78+
3479
def test_gf2_multiplication_classical_sim_quick():
3580
m = 2
36-
bloq = GF2Multiplication(m)
81+
bloq = GF2Multiplication(m, plus_equal_prod=False)
3782
GFM = GF(2**m)
3883
assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements)
3984

4085

4186
@pytest.mark.slow
4287
@pytest.mark.parametrize('m', [3, 4, 5])
4388
def test_gf2_multiplication_classical_sim(m):
44-
bloq = GF2Multiplication(m)
89+
bloq = GF2Multiplication(m, plus_equal_prod=False)
4590
GFM = GF(2**m)
4691
assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements)

0 commit comments

Comments
 (0)