Skip to content

Commit 6446e29

Browse files
authored
fix: support async vector search from a collection (#949)
1 parent d4956f4 commit 6446e29

File tree

5 files changed

+206
-6
lines changed

5 files changed

+206
-6
lines changed

google/cloud/firestore_v1/async_collection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
async_aggregation,
2424
async_document,
2525
async_query,
26+
async_vector_query,
2627
transaction,
2728
)
2829
from google.cloud.firestore_v1.base_collection import (
@@ -81,6 +82,14 @@ def _aggregation_query(self) -> async_aggregation.AsyncAggregationQuery:
8182
"""
8283
return async_aggregation.AsyncAggregationQuery(self._query())
8384

85+
def _vector_query(self) -> async_vector_query.AsyncVectorQuery:
86+
"""AsyncVectorQuery factory.
87+
88+
Returns:
89+
:class:`~google.cloud.firestore_v1.async_vector_query.AsyncVectorQuery`
90+
"""
91+
return async_vector_query.AsyncVectorQuery(self._query())
92+
8493
async def _chunkify(self, chunk_size: int):
8594
async for page in self._query()._chunkify(chunk_size):
8695
yield page

tests/system/test_system.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,29 @@ def on_snapshot(docs, changes, read_time):
177177
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
178178
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
179179
def test_vector_search_collection(client, database):
180+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
181+
collection_id = "vector_search"
182+
collection = client.collection(collection_id)
183+
184+
vector_query = collection.find_nearest(
185+
vector_field="embedding",
186+
query_vector=Vector([1.0, 2.0, 3.0]),
187+
distance_measure=DistanceMeasure.EUCLIDEAN,
188+
limit=1,
189+
)
190+
returned = vector_query.get()
191+
assert isinstance(returned, list)
192+
assert len(returned) == 1
193+
assert returned[0].to_dict() == {
194+
"embedding": Vector([1.0, 2.0, 3.0]),
195+
"color": "red",
196+
}
197+
198+
199+
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
200+
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
201+
def test_vector_search_collection_with_filter(client, database):
202+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
180203
collection_id = "vector_search"
181204
collection = client.collection(collection_id)
182205

@@ -198,6 +221,29 @@ def test_vector_search_collection(client, database):
198221
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
199222
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
200223
def test_vector_search_collection_group(client, database):
224+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
225+
collection_id = "vector_search"
226+
collection_group = client.collection_group(collection_id)
227+
228+
vector_query = collection_group.find_nearest(
229+
vector_field="embedding",
230+
query_vector=Vector([1.0, 2.0, 3.0]),
231+
distance_measure=DistanceMeasure.EUCLIDEAN,
232+
limit=1,
233+
)
234+
returned = vector_query.get()
235+
assert isinstance(returned, list)
236+
assert len(returned) == 1
237+
assert returned[0].to_dict() == {
238+
"embedding": Vector([1.0, 2.0, 3.0]),
239+
"color": "red",
240+
}
241+
242+
243+
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
244+
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
245+
def test_vector_search_collection_group_with_filter(client, database):
246+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
201247
collection_id = "vector_search"
202248
collection_group = client.collection_group(collection_id)
203249

tests/system/test_system_async.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,28 @@ async def test_document_update_w_int_field(client, cleanup, database):
342342
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
343343
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
344344
async def test_vector_search_collection(client, database):
345+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
346+
collection_id = "vector_search"
347+
collection = client.collection(collection_id)
348+
vector_query = collection.find_nearest(
349+
vector_field="embedding",
350+
query_vector=Vector([1.0, 2.0, 3.0]),
351+
limit=1,
352+
distance_measure=DistanceMeasure.EUCLIDEAN,
353+
)
354+
returned = await vector_query.get()
355+
assert isinstance(returned, list)
356+
assert len(returned) == 1
357+
assert returned[0].to_dict() == {
358+
"embedding": Vector([1.0, 2.0, 3.0]),
359+
"color": "red",
360+
}
361+
362+
363+
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
364+
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
365+
async def test_vector_search_collection_with_filter(client, database):
366+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
345367
collection_id = "vector_search"
346368
collection = client.collection(collection_id)
347369
vector_query = collection.where("color", "==", "red").find_nearest(
@@ -362,6 +384,29 @@ async def test_vector_search_collection(client, database):
362384
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
363385
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
364386
async def test_vector_search_collection_group(client, database):
387+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
388+
collection_id = "vector_search"
389+
collection_group = client.collection_group(collection_id)
390+
391+
vector_query = collection_group.find_nearest(
392+
vector_field="embedding",
393+
query_vector=Vector([1.0, 2.0, 3.0]),
394+
distance_measure=DistanceMeasure.EUCLIDEAN,
395+
limit=1,
396+
)
397+
returned = await vector_query.get()
398+
assert isinstance(returned, list)
399+
assert len(returned) == 1
400+
assert returned[0].to_dict() == {
401+
"embedding": Vector([1.0, 2.0, 3.0]),
402+
"color": "red",
403+
}
404+
405+
406+
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
407+
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
408+
async def test_vector_search_collection_group_with_filter(client, database):
409+
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
365410
collection_id = "vector_search"
366411
collection_group = client.collection_group(collection_id)
367412

tests/system/util/bootstrap_vector_index.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""A script to bootstrap vector data and vector index for system tests."""
15+
from google.api_core.client_options import ClientOptions
1516
from google.cloud.client import ClientWithProject # type: ignore
1617

1718
from google.cloud.firestore import Client
@@ -60,6 +61,21 @@ def _init_admin_api(self):
6061
return firestore_admin_client.FirestoreAdminClient(transport=self._transport)
6162

6263
def create_vector_index(self, parent):
64+
self._firestore_admin_api.create_index(
65+
parent=parent,
66+
index=Index(
67+
query_scope=Index.QueryScope.COLLECTION,
68+
fields=[
69+
Index.IndexField(
70+
field_path="embedding",
71+
vector_config=Index.IndexField.VectorConfig(
72+
dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex()
73+
),
74+
),
75+
],
76+
),
77+
)
78+
6379
self._firestore_admin_api.create_index(
6480
parent=parent,
6581
index=Index(
@@ -79,6 +95,21 @@ def create_vector_index(self, parent):
7995
),
8096
)
8197

98+
self._firestore_admin_api.create_index(
99+
parent=parent,
100+
index=Index(
101+
query_scope=Index.QueryScope.COLLECTION_GROUP,
102+
fields=[
103+
Index.IndexField(
104+
field_path="embedding",
105+
vector_config=Index.IndexField.VectorConfig(
106+
dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex()
107+
),
108+
),
109+
],
110+
),
111+
)
112+
82113
self._firestore_admin_api.create_index(
83114
parent=parent,
84115
index=Index(
@@ -103,13 +134,16 @@ def create_vector_documents(client, collection_id):
103134
document1 = client.document(collection_id, "doc1")
104135
document2 = client.document(collection_id, "doc2")
105136
document3 = client.document(collection_id, "doc3")
106-
document1.create({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"})
107-
document2.create({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"})
108-
document3.create({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"})
137+
document1.set({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"})
138+
document2.set({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"})
139+
document3.set({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"})
109140

110141

111142
def main():
112-
client = Client(project=PROJECT_ID, database=DATABASE_ID)
143+
client_options = ClientOptions(api_endpoint=TARGET_HOSTNAME)
144+
client = Client(
145+
project=PROJECT_ID, database=DATABASE_ID, client_options=client_options
146+
)
113147
create_vector_documents(client=client, collection_id=COLLECTION_ID)
114148
admin_client = FirestoreAdminClient(project=PROJECT_ID)
115149
admin_client.create_vector_index(

tests/unit/v1/test_async_vector_query.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,72 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit):
4545
return expected_pb
4646

4747

48+
@pytest.mark.parametrize(
49+
"distance_measure, expected_distance",
50+
[
51+
(
52+
DistanceMeasure.EUCLIDEAN,
53+
StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
54+
),
55+
(DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE),
56+
(
57+
DistanceMeasure.DOT_PRODUCT,
58+
StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT,
59+
),
60+
],
61+
)
62+
@pytest.mark.asyncio
63+
async def test_async_vector_query(distance_measure, expected_distance):
64+
# Create a minimal fake GAPIC.
65+
firestore_api = AsyncMock(spec=["run_query"])
66+
client = make_async_client()
67+
client._firestore_api_internal = firestore_api
68+
69+
# Make a **real** collection reference as parent.
70+
parent = client.collection("dee")
71+
parent_path, expected_prefix = parent._parent_info()
72+
73+
data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
74+
response_pb1 = _make_query_response(
75+
name="{}/test_doc".format(expected_prefix), data=data
76+
)
77+
78+
kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)
79+
80+
# Execute the vector query and check the response.
81+
firestore_api.run_query.return_value = AsyncIter([response_pb1])
82+
83+
vector_async_query = parent.find_nearest(
84+
vector_field="embedding",
85+
query_vector=Vector([1.0, 2.0, 3.0]),
86+
distance_measure=distance_measure,
87+
limit=5,
88+
)
89+
90+
returned = await vector_async_query.get(transaction=_transaction(client), **kwargs)
91+
assert isinstance(returned, list)
92+
assert len(returned) == 1
93+
assert returned[0].to_dict() == data
94+
95+
expected_pb = _expected_pb(
96+
parent=parent,
97+
vector_field="embedding",
98+
vector=Vector([1.0, 2.0, 3.0]),
99+
distance_type=expected_distance,
100+
limit=5,
101+
)
102+
103+
firestore_api.run_query.assert_called_once_with(
104+
request={
105+
"parent": parent_path,
106+
"structured_query": expected_pb,
107+
"transaction": _TXN_ID,
108+
},
109+
metadata=client._rpc_metadata,
110+
**kwargs,
111+
)
112+
113+
48114
@pytest.mark.parametrize(
49115
"distance_measure, expected_distance",
50116
[
@@ -84,14 +150,14 @@ async def test_async_vector_query_with_filter(distance_measure, expected_distanc
84150
# Execute the vector query and check the response.
85151
firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2])
86152

87-
vector_async__query = query.where("snooze", "==", 10).find_nearest(
153+
vector_async_query = query.where("snooze", "==", 10).find_nearest(
88154
vector_field="embedding",
89155
query_vector=Vector([1.0, 2.0, 3.0]),
90156
distance_measure=distance_measure,
91157
limit=5,
92158
)
93159

94-
returned = await vector_async__query.get(transaction=_transaction(client), **kwargs)
160+
returned = await vector_async_query.get(transaction=_transaction(client), **kwargs)
95161
assert isinstance(returned, list)
96162
assert len(returned) == 2
97163
assert returned[0].to_dict() == data

0 commit comments

Comments
 (0)