diff --git a/.coveragerc b/.coveragerc
index 3128ad99e..f12d4dc21 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -18,13 +18,11 @@
[run]
branch = True
omit =
- google/cloud/__init__.py
- google/__init__.py
google/cloud/bigtable_admin/__init__.py
google/cloud/bigtable_admin/gapic_version.py
[report]
-fail_under = 100
+fail_under = 99
show_missing = True
exclude_lines =
# Re-enable the standard pragma
@@ -33,11 +31,5 @@ exclude_lines =
def __repr__
# Ignore abstract methods
raise NotImplementedError
- # Ignore setuptools-less fallback
- except pkg_resources.DistributionNotFound:
omit =
- */gapic/*.py
- */proto/*.py
- */core/*.py
*/site-packages/*.py
- google/cloud/__init__.py
diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml
index eb4d9f794..2aefd0e91 100644
--- a/.github/.OwlBot.lock.yaml
+++ b/.github/.OwlBot.lock.yaml
@@ -1,4 +1,4 @@
-# Copyright 2023 Google LLC
+# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,5 +13,5 @@
# limitations under the License.
docker:
image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
- digest: sha256:bacc3af03bff793a03add584537b36b5644342931ad989e3ba1171d3bd5399f5
-# created: 2023-11-23T18:17:28.105124211Z
+ digest: sha256:97b671488ad548ef783a452a9e1276ac10f144d5ae56d98cc4bf77ba504082b4
+# created: 2024-02-06T03:20:16.660474034Z
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 2f1fee904..8e8f088b7 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -5,8 +5,8 @@
# https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners#codeowners-syntax
# Note: This file is autogenerated. To make changes to the codeowner team, please update .repo-metadata.json.
-# @googleapis/yoshi-python @googleapis/api-bigtable are the default owners for changes in this repo
-* @googleapis/yoshi-python @googleapis/api-bigtable
+# @googleapis/yoshi-python @googleapis/api-bigtable @googleapis/api-bigtable-partners are the default owners for changes in this repo
+* @googleapis/yoshi-python @googleapis/api-bigtable @googleapis/api-bigtable-partners
-# @googleapis/python-samples-reviewers @googleapis/api-bigtable are the default owners for samples changes
-/samples/ @googleapis/python-samples-reviewers @googleapis/api-bigtable
+# @googleapis/python-samples-reviewers @googleapis/api-bigtable @googleapis/api-bigtable-partners are the default owners for samples changes
+/samples/ @googleapis/python-samples-reviewers @googleapis/api-bigtable @googleapis/api-bigtable-partners
diff --git a/.github/flakybot.yaml b/.github/flakybot.yaml
new file mode 100644
index 000000000..2159a1bca
--- /dev/null
+++ b/.github/flakybot.yaml
@@ -0,0 +1,15 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+issuePriority: p2
\ No newline at end of file
diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml
index a0d3362c9..a8cc5b33b 100644
--- a/.github/sync-repo-settings.yaml
+++ b/.github/sync-repo-settings.yaml
@@ -31,6 +31,24 @@ branchProtectionRules:
- 'Kokoro'
- 'Kokoro system-3.8'
- 'cla/google'
+- pattern: experimental_v3
+ # Can admins overwrite branch protection.
+ # Defaults to `true`
+ isAdminEnforced: false
+ # Number of approving reviews required to update matching branches.
+ # Defaults to `1`
+ requiredApprovingReviewCount: 1
+ # Are reviews from code owners required to update matching branches.
+ # Defaults to `false`
+ requiresCodeOwnerReviews: false
+ # Require up to date branches
+ requiresStrictStatusChecks: false
+ # List of required status check contexts that must pass for commits to be accepted to matching branches.
+ requiredStatusCheckContexts:
+ - 'Kokoro'
+ - 'Kokoro system-3.8'
+ - 'cla/google'
+ - 'Conformance / Async v3 Client / Python 3.8'
# List of explicit permissions to add (additive only)
permissionRules:
# Team slug to add to repository permissions
diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml
new file mode 100644
index 000000000..68545cbec
--- /dev/null
+++ b/.github/workflows/conformance.yaml
@@ -0,0 +1,56 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# GitHub Actions job to run the cloud-bigtable-clients-test conformance suite
+# against this client library before changes are released.
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+name: Conformance
+jobs:
+ conformance:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ test-version: [ "v0.0.2" ]
+ py-version: [ 3.8 ]
+ client-type: [ "Async v3", "Legacy" ]
+ fail-fast: false
+ name: "${{ matrix.client-type }} Client / Python ${{ matrix.py-version }} / Test Tag ${{ matrix.test-version }}"
+ steps:
+ - uses: actions/checkout@v4
+ name: "Checkout python-bigtable"
+ - uses: actions/checkout@v4
+ name: "Checkout conformance tests"
+ with:
+ repository: googleapis/cloud-bigtable-clients-test
+ ref: ${{ matrix.test-version }}
+ path: cloud-bigtable-clients-test
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.py-version }}
+ - uses: actions/setup-go@v5
+ with:
+ go-version: '>=1.20.2'
+ - run: chmod +x .kokoro/conformance.sh
+ - run: pip install -e .
+ name: "Install python-bigtable from HEAD"
+ - run: go version
+ - run: .kokoro/conformance.sh
+ name: "Run tests"
+ env:
+ CLIENT_TYPE: ${{ matrix.client-type }}
+ PYTHONUNBUFFERED: 1
+
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 221806ced..698fbc5c9 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -8,9 +8,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: "3.9"
- name: Install nox
@@ -24,9 +24,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install nox
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 16d5a9e90..4866193af 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -8,9 +8,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: "3.8"
- name: Install nox
diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml
index c63242630..3915cddd3 100644
--- a/.github/workflows/mypy.yml
+++ b/.github/workflows/mypy.yml
@@ -8,9 +8,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: "3.8"
- name: Install nox
diff --git a/.github/workflows/system_emulated.yml b/.github/workflows/system_emulated.yml
index f1aa7e87c..fa5ef15af 100644
--- a/.github/workflows/system_emulated.yml
+++ b/.github/workflows/system_emulated.yml
@@ -12,15 +12,15 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: Setup GCloud SDK
- uses: google-github-actions/setup-gcloud@v1.1.1
+ uses: google-github-actions/setup-gcloud@v2.1.0
- name: Install / run Nox
run: |
diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml
index a32027b49..87d08602f 100644
--- a/.github/workflows/unittest.yml
+++ b/.github/workflows/unittest.yml
@@ -11,9 +11,9 @@ jobs:
python: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install nox
@@ -26,9 +26,9 @@ jobs:
run: |
nox -s unit-${{ matrix.python }}
- name: Upload coverage results
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
- name: coverage-artifacts
+ name: coverage-artifact-${{ matrix.python }}
path: .coverage-${{ matrix.python }}
cover:
@@ -37,9 +37,9 @@ jobs:
- unit
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: "3.8"
- name: Install coverage
@@ -47,11 +47,11 @@ jobs:
python -m pip install --upgrade setuptools pip wheel
python -m pip install coverage
- name: Download coverage results
- uses: actions/download-artifact@v3
+ uses: actions/download-artifact@v4
with:
- name: coverage-artifacts
path: .coverage-results/
- name: Report coverage results
run: |
- coverage combine .coverage-results/.coverage*
- coverage report --show-missing --fail-under=100
+ find .coverage-results -type f -name '*.zip' -exec unzip {} \;
+ coverage combine .coverage-results/**/.coverage*
+ coverage report --show-missing --fail-under=99
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 000000000..5fa9b1ed5
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "python-api-core"]
+ path = python-api-core
+ url = git@github.com:googleapis/python-api-core.git
+[submodule "gapic-generator-fork"]
+ path = gapic-generator-fork
+ url = git@github.com:googleapis/gapic-generator-python.git
diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh
new file mode 100644
index 000000000..1c0b3ee0d
--- /dev/null
+++ b/.kokoro/conformance.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+## cd to the parent directory, i.e. the root of the git repo
+cd $(dirname $0)/..
+
+PROXY_ARGS=""
+TEST_ARGS=""
+if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then
+ echo "Using legacy client"
+ PROXY_ARGS="--legacy-client"
+ # legacy client does not expose mutate_row. Disable those tests
+ TEST_ARGS="-skip TestMutateRow_"
+fi
+
+# Build and start the proxy in a separate process
+PROXY_PORT=9999
+pushd test_proxy
+nohup python test_proxy.py --port $PROXY_PORT $PROXY_ARGS &
+proxyPID=$!
+popd
+
+# Kill proxy on exit
+function cleanup() {
+ echo "Cleanup testbench";
+ kill $proxyPID
+}
+trap cleanup EXIT
+
+# Run the conformance test
+pushd cloud-bigtable-clients-test/tests
+eval "go test -v -proxy_addr=:$PROXY_PORT $TEST_ARGS"
+RETURN_CODE=$?
+popd
+
+echo "exiting with ${RETURN_CODE}"
+exit ${RETURN_CODE}
diff --git a/.kokoro/presubmit/conformance.cfg b/.kokoro/presubmit/conformance.cfg
new file mode 100644
index 000000000..4f44e8a78
--- /dev/null
+++ b/.kokoro/presubmit/conformance.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "NOX_SESSION"
+ value: "conformance"
+}
diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt
index 8957e2110..8c11c9f3e 100644
--- a/.kokoro/requirements.txt
+++ b/.kokoro/requirements.txt
@@ -93,30 +93,39 @@ colorlog==6.7.0 \
# via
# gcp-docuploader
# nox
-cryptography==41.0.5 \
- --hash=sha256:0c327cac00f082013c7c9fb6c46b7cc9fa3c288ca702c74773968173bda421bf \
- --hash=sha256:0d2a6a598847c46e3e321a7aef8af1436f11c27f1254933746304ff014664d84 \
- --hash=sha256:227ec057cd32a41c6651701abc0328135e472ed450f47c2766f23267b792a88e \
- --hash=sha256:22892cc830d8b2c89ea60148227631bb96a7da0c1b722f2aac8824b1b7c0b6b8 \
- --hash=sha256:392cb88b597247177172e02da6b7a63deeff1937fa6fec3bbf902ebd75d97ec7 \
- --hash=sha256:3be3ca726e1572517d2bef99a818378bbcf7d7799d5372a46c79c29eb8d166c1 \
- --hash=sha256:573eb7128cbca75f9157dcde974781209463ce56b5804983e11a1c462f0f4e88 \
- --hash=sha256:580afc7b7216deeb87a098ef0674d6ee34ab55993140838b14c9b83312b37b86 \
- --hash=sha256:5a70187954ba7292c7876734183e810b728b4f3965fbe571421cb2434d279179 \
- --hash=sha256:73801ac9736741f220e20435f84ecec75ed70eda90f781a148f1bad546963d81 \
- --hash=sha256:7d208c21e47940369accfc9e85f0de7693d9a5d843c2509b3846b2db170dfd20 \
- --hash=sha256:8254962e6ba1f4d2090c44daf50a547cd5f0bf446dc658a8e5f8156cae0d8548 \
- --hash=sha256:88417bff20162f635f24f849ab182b092697922088b477a7abd6664ddd82291d \
- --hash=sha256:a48e74dad1fb349f3dc1d449ed88e0017d792997a7ad2ec9587ed17405667e6d \
- --hash=sha256:b948e09fe5fb18517d99994184854ebd50b57248736fd4c720ad540560174ec5 \
- --hash=sha256:c707f7afd813478e2019ae32a7c49cd932dd60ab2d2a93e796f68236b7e1fbf1 \
- --hash=sha256:d38e6031e113b7421db1de0c1b1f7739564a88f1684c6b89234fbf6c11b75147 \
- --hash=sha256:d3977f0e276f6f5bf245c403156673db103283266601405376f075c849a0b936 \
- --hash=sha256:da6a0ff8f1016ccc7477e6339e1d50ce5f59b88905585f77193ebd5068f1e797 \
- --hash=sha256:e270c04f4d9b5671ebcc792b3ba5d4488bf7c42c3c241a3748e2599776f29696 \
- --hash=sha256:e886098619d3815e0ad5790c973afeee2c0e6e04b4da90b88e6bd06e2a0b1b72 \
- --hash=sha256:ec3b055ff8f1dce8e6ef28f626e0972981475173d7973d63f271b29c8a2897da \
- --hash=sha256:fba1e91467c65fe64a82c689dc6cf58151158993b13eb7a7f3f4b7f395636723
+cryptography==42.0.0 \
+ --hash=sha256:0a68bfcf57a6887818307600c3c0ebc3f62fbb6ccad2240aa21887cda1f8df1b \
+ --hash=sha256:146e971e92a6dd042214b537a726c9750496128453146ab0ee8971a0299dc9bd \
+ --hash=sha256:14e4b909373bc5bf1095311fa0f7fcabf2d1a160ca13f1e9e467be1ac4cbdf94 \
+ --hash=sha256:206aaf42e031b93f86ad60f9f5d9da1b09164f25488238ac1dc488334eb5e221 \
+ --hash=sha256:3005166a39b70c8b94455fdbe78d87a444da31ff70de3331cdec2c568cf25b7e \
+ --hash=sha256:324721d93b998cb7367f1e6897370644751e5580ff9b370c0a50dc60a2003513 \
+ --hash=sha256:33588310b5c886dfb87dba5f013b8d27df7ffd31dc753775342a1e5ab139e59d \
+ --hash=sha256:35cf6ed4c38f054478a9df14f03c1169bb14bd98f0b1705751079b25e1cb58bc \
+ --hash=sha256:3ca482ea80626048975360c8e62be3ceb0f11803180b73163acd24bf014133a0 \
+ --hash=sha256:56ce0c106d5c3fec1038c3cca3d55ac320a5be1b44bf15116732d0bc716979a2 \
+ --hash=sha256:5a217bca51f3b91971400890905a9323ad805838ca3fa1e202a01844f485ee87 \
+ --hash=sha256:678cfa0d1e72ef41d48993a7be75a76b0725d29b820ff3cfd606a5b2b33fda01 \
+ --hash=sha256:69fd009a325cad6fbfd5b04c711a4da563c6c4854fc4c9544bff3088387c77c0 \
+ --hash=sha256:6cf9b76d6e93c62114bd19485e5cb003115c134cf9ce91f8ac924c44f8c8c3f4 \
+ --hash=sha256:74f18a4c8ca04134d2052a140322002fef535c99cdbc2a6afc18a8024d5c9d5b \
+ --hash=sha256:85f759ed59ffd1d0baad296e72780aa62ff8a71f94dc1ab340386a1207d0ea81 \
+ --hash=sha256:87086eae86a700307b544625e3ba11cc600c3c0ef8ab97b0fda0705d6db3d4e3 \
+ --hash=sha256:8814722cffcfd1fbd91edd9f3451b88a8f26a5fd41b28c1c9193949d1c689dc4 \
+ --hash=sha256:8fedec73d590fd30c4e3f0d0f4bc961aeca8390c72f3eaa1a0874d180e868ddf \
+ --hash=sha256:9515ea7f596c8092fdc9902627e51b23a75daa2c7815ed5aa8cf4f07469212ec \
+ --hash=sha256:988b738f56c665366b1e4bfd9045c3efae89ee366ca3839cd5af53eaa1401bce \
+ --hash=sha256:a2a8d873667e4fd2f34aedab02ba500b824692c6542e017075a2efc38f60a4c0 \
+ --hash=sha256:bd7cf7a8d9f34cc67220f1195884151426ce616fdc8285df9054bfa10135925f \
+ --hash=sha256:bdce70e562c69bb089523e75ef1d9625b7417c6297a76ac27b1b8b1eb51b7d0f \
+ --hash=sha256:be14b31eb3a293fc6e6aa2807c8a3224c71426f7c4e3639ccf1a2f3ffd6df8c3 \
+ --hash=sha256:be41b0c7366e5549265adf2145135dca107718fa44b6e418dc7499cfff6b4689 \
+ --hash=sha256:c310767268d88803b653fffe6d6f2f17bb9d49ffceb8d70aed50ad45ea49ab08 \
+ --hash=sha256:c58115384bdcfe9c7f644c72f10f6f42bed7cf59f7b52fe1bf7ae0a622b3a139 \
+ --hash=sha256:c640b0ef54138fde761ec99a6c7dc4ce05e80420262c20fa239e694ca371d434 \
+ --hash=sha256:ca20550bb590db16223eb9ccc5852335b48b8f597e2f6f0878bbfd9e7314eb17 \
+ --hash=sha256:d97aae66b7de41cdf5b12087b5509e4e9805ed6f562406dfcf60e8481a9a28f8 \
+ --hash=sha256:e9326ca78111e4c645f7e49cbce4ed2f3f85e17b61a563328c85a5208cf34440
# via
# gcp-releasetool
# secretstorage
@@ -263,9 +272,9 @@ jeepney==0.8.0 \
# via
# keyring
# secretstorage
-jinja2==3.1.2 \
- --hash=sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852 \
- --hash=sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61
+jinja2==3.1.3 \
+ --hash=sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa \
+ --hash=sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90
# via gcp-releasetool
keyring==24.2.0 \
--hash=sha256:4901caaf597bfd3bbd78c9a0c7c4c29fcd8310dab2cffefe749e916b6527acd6 \
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index a5ab48803..b94f3df9f 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "2.22.0"
+ ".": "2.23.0"
}
\ No newline at end of file
diff --git a/.repo-metadata.json b/.repo-metadata.json
index 3c65ac669..9de4b5f92 100644
--- a/.repo-metadata.json
+++ b/.repo-metadata.json
@@ -75,6 +75,6 @@
}
],
"default_version": "v2",
- "codeowner_team": "@googleapis/api-bigtable",
+ "codeowner_team": "@googleapis/api-bigtable @googleapis/api-bigtable-partners",
"api_shortname": "bigtable"
}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5f86fdd88..ea8a8525d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,19 @@
[1]: https://pypi.org/project/google-cloud-bigtable/#history
+## [2.23.0](https://github.com/googleapis/python-bigtable/compare/v2.22.0...v2.23.0) (2024-02-07)
+
+
+### Features
+
+* Add async data client preview ([7088e39](https://github.com/googleapis/python-bigtable/commit/7088e39c6bac10e5f830e8fa68e181412910ec5a))
+* Adding feature flags for routing cookie and retry info ([#905](https://github.com/googleapis/python-bigtable/issues/905)) ([1859e67](https://github.com/googleapis/python-bigtable/commit/1859e67961629663a8749eea849b5b005fcbc09f))
+
+
+### Bug Fixes
+
+* Fix `ValueError` in `test__validate_universe_domain` ([#929](https://github.com/googleapis/python-bigtable/issues/929)) ([aa76a5a](https://github.com/googleapis/python-bigtable/commit/aa76a5aaa349386d5972d96e1255389e30df8764))
+
## [2.22.0](https://github.com/googleapis/python-bigtable/compare/v2.21.0...v2.22.0) (2023-12-12)
diff --git a/README.rst b/README.rst
index 5f7d5809d..2bc151e95 100644
--- a/README.rst
+++ b/README.rst
@@ -20,6 +20,24 @@ Analytics, Maps, and Gmail.
.. _Client Library Documentation: https://googleapis.dev/python/bigtable/latest
.. _Product Documentation: https://cloud.google.com/bigtable/docs
+
+Preview Async Data Client
+-------------------------
+
+:code:`v2.23.0` includes a preview release of the new :code:`BigtableDataClientAsync` client, accessible at the import path
+:code:`google.cloud.bigtable.data`.
+
+The new client brings a simplified API and increased performance using asyncio, with a corresponding synchronous surface
+coming soon. The new client is focused on the data API (i.e. reading and writing Bigtable data), with admin operations
+remaining in the existing client.
+
+:code:`BigtableDataClientAsync` is currently in preview, and is not recommended for production use.
+
+Feedback and bug reports are welcome at cbt-python-client-v3-feedback@google.com,
+or through the GitHub `issue tracker`_.
+
+.. _issue tracker: https://github.com/googleapis/python-bigtable/issues
+
Quick Start
-----------
@@ -94,14 +112,3 @@ Next Steps
to see other available methods on the client.
- Read the `Product documentation`_ to learn
more about the product and see How-to Guides.
-
-``google-cloud-happybase``
---------------------------
-
-In addition to the core ``google-cloud-bigtable``, we provide a
-`google-cloud-happybase
-`__ library
-with the same interface as the popular `HappyBase
-`__ library. Unlike HappyBase,
-``google-cloud-happybase`` uses ``google-cloud-bigtable`` under the covers,
-rather than Apache HBase.
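
A minimal usage sketch of the preview surface described above, based on the names this PR exports from google.cloud.bigtable.data. It is not part of this diff, and the read_rows method and the ReadRowsQuery(limit=...) constructor are assumed from the full PR rather than shown in this excerpt:

    import asyncio

    from google.cloud.bigtable.data import BigtableDataClientAsync, ReadRowsQuery


    async def main():
        # the client must be created inside a running event loop
        client = BigtableDataClientAsync(project="my-project")
        table = client.get_table("my-instance", "my-table")
        try:
            # read_rows is assumed to return a list of Row objects
            rows = await table.read_rows(ReadRowsQuery(limit=10))
            for row in rows:
                print(row.row_key)
        finally:
            await client.close()


    asyncio.run(main())
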
diff --git a/gapic-generator-fork b/gapic-generator-fork
new file mode 160000
index 000000000..b26cda7d1
--- /dev/null
+++ b/gapic-generator-fork
@@ -0,0 +1 @@
+Subproject commit b26cda7d163d6e0d45c9684f328ca32fb49b799a
diff --git a/google/cloud/bigtable/data/README.rst b/google/cloud/bigtable/data/README.rst
new file mode 100644
index 000000000..7a05cf913
--- /dev/null
+++ b/google/cloud/bigtable/data/README.rst
@@ -0,0 +1,11 @@
+Async Data Client Preview
+=========================
+
+This new client is currently in preview, and is not recommended for production use.
+
+A synchronous API surface and usage examples are coming soon.
+
+Feedback and bug reports are welcome at cbt-python-client-v3-feedback@google.com,
+or through the GitHub `issue tracker`_.
+
+.. _issue tracker: https://github.com/googleapis/python-bigtable/issues
diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py
new file mode 100644
index 000000000..5229f8021
--- /dev/null
+++ b/google/cloud/bigtable/data/__init__.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from google.cloud.bigtable import gapic_version as package_version
+
+from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
+from google.cloud.bigtable.data._async.client import TableAsync
+
+from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync
+
+from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+from google.cloud.bigtable.data.read_rows_query import RowRange
+from google.cloud.bigtable.data.row import Row
+from google.cloud.bigtable.data.row import Cell
+
+from google.cloud.bigtable.data.mutations import Mutation
+from google.cloud.bigtable.data.mutations import RowMutationEntry
+from google.cloud.bigtable.data.mutations import SetCell
+from google.cloud.bigtable.data.mutations import DeleteRangeFromColumn
+from google.cloud.bigtable.data.mutations import DeleteAllFromFamily
+from google.cloud.bigtable.data.mutations import DeleteAllFromRow
+
+from google.cloud.bigtable.data.exceptions import InvalidChunk
+from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
+from google.cloud.bigtable.data.exceptions import FailedQueryShardError
+
+from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
+from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
+from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+
+from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
+from google.cloud.bigtable.data._helpers import RowKeySamples
+from google.cloud.bigtable.data._helpers import ShardedQuery
+
+
+__version__: str = package_version.__version__
+
+__all__ = (
+ "BigtableDataClientAsync",
+ "TableAsync",
+ "RowKeySamples",
+ "ReadRowsQuery",
+ "RowRange",
+ "MutationsBatcherAsync",
+ "Mutation",
+ "RowMutationEntry",
+ "SetCell",
+ "DeleteRangeFromColumn",
+ "DeleteAllFromFamily",
+ "DeleteAllFromRow",
+ "Row",
+ "Cell",
+ "InvalidChunk",
+ "FailedMutationEntryError",
+ "FailedQueryShardError",
+ "RetryExceptionGroup",
+ "MutationsExceptionGroup",
+ "ShardedReadRowsExceptionGroup",
+ "ShardedQuery",
+ "TABLE_DEFAULT",
+)
diff --git a/google/cloud/bigtable/data/_async/__init__.py b/google/cloud/bigtable/data/_async/__init__.py
new file mode 100644
index 000000000..e13c9acb7
--- /dev/null
+++ b/google/cloud/bigtable/data/_async/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
+from google.cloud.bigtable.data._async.client import TableAsync
+
+from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync
+
+
+__all__ = [
+ "BigtableDataClientAsync",
+ "TableAsync",
+ "MutationsBatcherAsync",
+]
diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py
new file mode 100644
index 000000000..7d1144553
--- /dev/null
+++ b/google/cloud/bigtable/data/_async/_mutate_rows.py
@@ -0,0 +1,226 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import Sequence, TYPE_CHECKING
+from dataclasses import dataclass
+import functools
+
+from google.api_core import exceptions as core_exceptions
+from google.api_core import retry as retries
+import google.cloud.bigtable_v2.types.bigtable as types_pb
+import google.cloud.bigtable.data.exceptions as bt_exceptions
+from google.cloud.bigtable.data._helpers import _make_metadata
+from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
+from google.cloud.bigtable.data._helpers import _retry_exception_factory
+
+# mutate_rows requests are limited to this number of mutations
+from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
+
+if TYPE_CHECKING:
+ from google.cloud.bigtable_v2.services.bigtable.async_client import (
+ BigtableAsyncClient,
+ )
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+ from google.cloud.bigtable.data._async.client import TableAsync
+
+
+@dataclass
+class _EntryWithProto:
+ """
+ A dataclass to hold a RowMutationEntry and its corresponding proto representation.
+ """
+
+ entry: RowMutationEntry
+ proto: types_pb.MutateRowsRequest.Entry
+
+
+class _MutateRowsOperationAsync:
+ """
+ MutateRowsOperation manages the logic of sending a set of row mutations,
+ and retrying on failed entries. It manages this using the _run_attempt
+ function, which attempts to mutate all outstanding entries, and raises
+ _MutateRowsIncomplete if any retryable errors are encountered.
+
+ Errors are exposed as a MutationsExceptionGroup, which contains a list of
+ exceptions organized by the related failed mutation entries.
+ """
+
+ def __init__(
+ self,
+ gapic_client: "BigtableAsyncClient",
+ table: "TableAsync",
+ mutation_entries: list["RowMutationEntry"],
+ operation_timeout: float,
+ attempt_timeout: float | None,
+ retryable_exceptions: Sequence[type[Exception]] = (),
+ ):
+ """
+ Args:
+ - gapic_client: the client to use for the mutate_rows call
+ - table: the table associated with the request
+ - mutation_entries: a list of RowMutationEntry objects to send to the server
+ - operation_timeout: the timeout to use for the entire operation, in seconds.
+ - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds.
+ If not specified, the request will run until operation_timeout is reached.
+ - retryable_exceptions: rpc-level exception types that should trigger a retry
+ """
+ # check that mutations are within limits
+ total_mutations = sum(len(entry.mutations) for entry in mutation_entries)
+ if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT:
+ raise ValueError(
+ "mutate_rows requests can contain at most "
+ f"{_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across "
+ f"all entries. Found {total_mutations}."
+ )
+ # create partial function to pass to trigger rpc call
+ metadata = _make_metadata(table.table_name, table.app_profile_id)
+ self._gapic_fn = functools.partial(
+ gapic_client.mutate_rows,
+ table_name=table.table_name,
+ app_profile_id=table.app_profile_id,
+ metadata=metadata,
+ retry=None,
+ )
+ # create predicate for determining which errors are retryable
+ self.is_retryable = retries.if_exception_type(
+ # RPC level errors
+ *retryable_exceptions,
+ # Entry level errors
+ bt_exceptions._MutateRowsIncomplete,
+ )
+ sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
+ self._operation = retries.retry_target_async(
+ self._run_attempt,
+ self.is_retryable,
+ sleep_generator,
+ operation_timeout,
+ exception_factory=_retry_exception_factory,
+ )
+ # initialize state
+ self.timeout_generator = _attempt_timeout_generator(
+ attempt_timeout, operation_timeout
+ )
+ self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries]
+ self.remaining_indices = list(range(len(self.mutations)))
+ self.errors: dict[int, list[Exception]] = {}
+
+ async def start(self):
+ """
+ Start the operation, and run until completion
+
+ Raises:
+ - MutationsExceptionGroup: if any mutations failed
+ """
+ try:
+ # trigger mutate_rows
+ await self._operation
+ except Exception as exc:
+ # exceptions raised by the retry wrapper are added to the exception lists of all unfinalized mutations
+ incomplete_indices = self.remaining_indices.copy()
+ for idx in incomplete_indices:
+ self._handle_entry_error(idx, exc)
+ finally:
+ # raise exception detailing incomplete mutations
+ all_errors: list[Exception] = []
+ for idx, exc_list in self.errors.items():
+ if len(exc_list) == 0:
+ raise core_exceptions.ClientError(
+ f"Mutation {idx} failed with no associated errors"
+ )
+ elif len(exc_list) == 1:
+ cause_exc = exc_list[0]
+ else:
+ cause_exc = bt_exceptions.RetryExceptionGroup(exc_list)
+ entry = self.mutations[idx].entry
+ all_errors.append(
+ bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc)
+ )
+ if all_errors:
+ raise bt_exceptions.MutationsExceptionGroup(
+ all_errors, len(self.mutations)
+ )
+
+ async def _run_attempt(self):
+ """
+ Run a single attempt of the mutate_rows rpc.
+
+ Raises:
+ - _MutateRowsIncomplete: if there are failed mutations eligible for
+ retry after the attempt is complete
+ - GoogleAPICallError: if the gapic rpc fails
+ """
+ request_entries = [self.mutations[idx].proto for idx in self.remaining_indices]
+ # track mutations in this request that have not been finalized yet
+ active_request_indices = {
+ req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices)
+ }
+ self.remaining_indices = []
+ if not request_entries:
+ # no more mutations. return early
+ return
+ # make gapic request
+ try:
+ result_generator = await self._gapic_fn(
+ timeout=next(self.timeout_generator),
+ entries=request_entries,
+ retry=None,
+ )
+ async for result_list in result_generator:
+ for result in result_list.entries:
+ # convert sub-request index to global index
+ orig_idx = active_request_indices[result.index]
+ entry_error = core_exceptions.from_grpc_status(
+ result.status.code,
+ result.status.message,
+ details=result.status.details,
+ )
+ if result.status.code != 0:
+ # mutation failed; update error list (and remaining_indices if retryable)
+ self._handle_entry_error(orig_idx, entry_error)
+ elif orig_idx in self.errors:
+ # mutation succeeded; remove from error list
+ del self.errors[orig_idx]
+ # remove processed entry from active list
+ del active_request_indices[result.index]
+ except Exception as exc:
+ # add this exception to list for each mutation that wasn't
+ # already handled, and update remaining_indices if mutation is retryable
+ for idx in active_request_indices.values():
+ self._handle_entry_error(idx, exc)
+ # bubble up exception to be handled by retry wrapper
+ raise
+ # check if attempt succeeded, or needs to be retried
+ if self.remaining_indices:
+ # unfinished work; raise exception to trigger retry
+ raise bt_exceptions._MutateRowsIncomplete
+
+ def _handle_entry_error(self, idx: int, exc: Exception):
+ """
+ Add an exception to the list of exceptions for a given mutation index,
+ and add the index to the list of remaining indices if the exception is
+ retryable.
+
+ Args:
+ - idx: the index of the mutation that failed
+ - exc: the exception to add to the list
+ """
+ entry = self.mutations[idx].entry
+ self.errors.setdefault(idx, []).append(exc)
+ if (
+ entry.is_idempotent()
+ and self.is_retryable(exc)
+ and idx not in self.remaining_indices
+ ):
+ self.remaining_indices.append(idx)
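
As a usage-level illustration of how the errors assembled above surface to callers: a minimal sketch, not part of this diff, assuming TableAsync exposes a bulk_mutate_rows entry point (not shown in this excerpt) that drives _MutateRowsOperationAsync:

    from google.cloud.bigtable.data import (
        RowMutationEntry,
        SetCell,
        MutationsExceptionGroup,
    )


    async def write_batch(table):
        # one entry per row; idempotent mutations are eligible for retry
        # on transient failures
        entry = RowMutationEntry(
            b"user#123",
            [SetCell("stats", b"visits", 1)],
        )
        try:
            await table.bulk_mutate_rows([entry])  # assumed method name
        except MutationsExceptionGroup as group:
            # each sub-exception is a FailedMutationEntryError wrapping the
            # failed entry and its underlying cause(s)
            for failed in group.exceptions:
                print(failed)
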
diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py
new file mode 100644
index 000000000..9e0fd78e1
--- /dev/null
+++ b/google/cloud/bigtable/data/_async/_read_rows.py
@@ -0,0 +1,343 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+
+from typing import (
+ TYPE_CHECKING,
+ AsyncGenerator,
+ AsyncIterable,
+ Awaitable,
+ Sequence,
+)
+
+from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB
+from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB
+from google.cloud.bigtable_v2.types import RowSet as RowSetPB
+from google.cloud.bigtable_v2.types import RowRange as RowRangePB
+
+from google.cloud.bigtable.data.row import Row, Cell
+from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+from google.cloud.bigtable.data.exceptions import InvalidChunk
+from google.cloud.bigtable.data.exceptions import _RowSetComplete
+from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
+from google.cloud.bigtable.data._helpers import _make_metadata
+from google.cloud.bigtable.data._helpers import _retry_exception_factory
+
+from google.api_core import retry as retries
+from google.api_core.retry import exponential_sleep_generator
+
+if TYPE_CHECKING:
+ from google.cloud.bigtable.data._async.client import TableAsync
+
+
+class _ResetRow(Exception):
+ def __init__(self, chunk):
+ self.chunk = chunk
+
+
+class _ReadRowsOperationAsync:
+ """
+ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream
+ into a stream of Row objects.
+
+ start_operation drives the operation end-to-end: each attempt issues a
+ read_rows rpc, chunk_stream extracts cell chunks from the raw response
+ stream, and merge_rows assembles those chunks into Row objects.
+
+ Retryable stream errors trigger a new attempt, with the request revised
+ to skip rows that have already been yielded.
+ """
+
+ __slots__ = (
+ "attempt_timeout_gen",
+ "operation_timeout",
+ "request",
+ "table",
+ "_predicate",
+ "_metadata",
+ "_last_yielded_row_key",
+ "_remaining_count",
+ )
+
+ def __init__(
+ self,
+ query: ReadRowsQuery,
+ table: "TableAsync",
+ operation_timeout: float,
+ attempt_timeout: float,
+ retryable_exceptions: Sequence[type[Exception]] = (),
+ ):
+ self.attempt_timeout_gen = _attempt_timeout_generator(
+ attempt_timeout, operation_timeout
+ )
+ self.operation_timeout = operation_timeout
+ if isinstance(query, dict):
+ self.request = ReadRowsRequestPB(
+ **query,
+ table_name=table.table_name,
+ app_profile_id=table.app_profile_id,
+ )
+ else:
+ self.request = query._to_pb(table)
+ self.table = table
+ self._predicate = retries.if_exception_type(*retryable_exceptions)
+ self._metadata = _make_metadata(
+ table.table_name,
+ table.app_profile_id,
+ )
+ self._last_yielded_row_key: bytes | None = None
+ self._remaining_count: int | None = self.request.rows_limit or None
+
+ def start_operation(self) -> AsyncGenerator[Row, None]:
+ """
+ Start the read_rows operation, retrying on retryable errors.
+ """
+ return retries.retry_target_stream_async(
+ self._read_rows_attempt,
+ self._predicate,
+ exponential_sleep_generator(0.01, 60, multiplier=2),
+ self.operation_timeout,
+ exception_factory=_retry_exception_factory,
+ )
+
+ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]:
+ """
+ Attempt a single read_rows rpc call.
+ This function is intended to be wrapped by retry logic,
+ which will call this function until it succeeds or
+ a non-retryable error is raised.
+ """
+ # revise request keys and ranges between attempts
+ if self._last_yielded_row_key is not None:
+ # if this is a retry, try to trim down the request to avoid ones we've already processed
+ try:
+ self.request.rows = self._revise_request_rowset(
+ row_set=self.request.rows,
+ last_seen_row_key=self._last_yielded_row_key,
+ )
+ except _RowSetComplete:
+ # if we've already seen all the rows, we're done
+ return self.merge_rows(None)
+ # revise the limit based on number of rows already yielded
+ if self._remaining_count is not None:
+ self.request.rows_limit = self._remaining_count
+ if self._remaining_count == 0:
+ return self.merge_rows(None)
+ # create and return a new row merger
+ gapic_stream = self.table.client._gapic_client.read_rows(
+ self.request,
+ timeout=next(self.attempt_timeout_gen),
+ metadata=self._metadata,
+ retry=None,
+ )
+ chunked_stream = self.chunk_stream(gapic_stream)
+ return self.merge_rows(chunked_stream)
+
+ async def chunk_stream(
+ self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]]
+ ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]:
+ """
+ process chunks out of raw read_rows stream
+ """
+ async for resp in await stream:
+ # extract proto from proto-plus wrapper
+ resp = resp._pb
+
+ # handle last_scanned_row_key packets, sent when server
+ # has scanned past the end of the row range
+ if resp.last_scanned_row_key:
+ if (
+ self._last_yielded_row_key is not None
+ and resp.last_scanned_row_key <= self._last_yielded_row_key
+ ):
+ raise InvalidChunk("last scanned out of order")
+ self._last_yielded_row_key = resp.last_scanned_row_key
+
+ current_key = None
+ # process each chunk in the response
+ for c in resp.chunks:
+ if current_key is None:
+ current_key = c.row_key
+ if current_key is None:
+ raise InvalidChunk("first chunk is missing a row key")
+ elif (
+ self._last_yielded_row_key
+ and current_key <= self._last_yielded_row_key
+ ):
+ raise InvalidChunk("row keys should be strictly increasing")
+
+ yield c
+
+ if c.reset_row:
+ current_key = None
+ elif c.commit_row:
+ # update row state after each commit
+ self._last_yielded_row_key = current_key
+ if self._remaining_count is not None:
+ self._remaining_count -= 1
+ if self._remaining_count < 0:
+ raise InvalidChunk("emit count exceeds row limit")
+ current_key = None
+
+ @staticmethod
+ async def merge_rows(
+ chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None
+ ):
+ """
+ Merge chunks into rows
+ """
+ if chunks is None:
+ return
+ it = chunks.__aiter__()
+ # For each row
+ while True:
+ try:
+ c = await it.__anext__()
+ except StopAsyncIteration:
+ # stream complete
+ return
+ row_key = c.row_key
+
+ if not row_key:
+ raise InvalidChunk("first row chunk is missing key")
+
+ cells = []
+
+ # shared per cell storage
+ family: str | None = None
+ qualifier: bytes | None = None
+
+ try:
+ # for each cell
+ while True:
+ if c.reset_row:
+ raise _ResetRow(c)
+ k = c.row_key
+ f = c.family_name.value
+ q = c.qualifier.value if c.HasField("qualifier") else None
+ if k and k != row_key:
+ raise InvalidChunk("unexpected new row key")
+ if f:
+ family = f
+ if q is not None:
+ qualifier = q
+ else:
+ raise InvalidChunk("new family without qualifier")
+ elif family is None:
+ raise InvalidChunk("missing family")
+ elif q is not None:
+ if family is None:
+ raise InvalidChunk("new qualifier without family")
+ qualifier = q
+ elif qualifier is None:
+ raise InvalidChunk("missing qualifier")
+
+ ts = c.timestamp_micros
+ labels = c.labels if c.labels else []
+ value = c.value
+
+ # merge split cells
+ if c.value_size > 0:
+ buffer = [value]
+ while c.value_size > 0:
+ # throws when premature end
+ c = await it.__anext__()
+
+ t = c.timestamp_micros
+ cl = c.labels
+ k = c.row_key
+ if (
+ c.HasField("family_name")
+ and c.family_name.value != family
+ ):
+ raise InvalidChunk("family changed mid cell")
+ if (
+ c.HasField("qualifier")
+ and c.qualifier.value != qualifier
+ ):
+ raise InvalidChunk("qualifier changed mid cell")
+ if t and t != ts:
+ raise InvalidChunk("timestamp changed mid cell")
+ if cl and cl != labels:
+ raise InvalidChunk("labels changed mid cell")
+ if k and k != row_key:
+ raise InvalidChunk("row key changed mid cell")
+
+ if c.reset_row:
+ raise _ResetRow(c)
+ buffer.append(c.value)
+ value = b"".join(buffer)
+ cells.append(
+ Cell(value, row_key, family, qualifier, ts, list(labels))
+ )
+ if c.commit_row:
+ yield Row(row_key, cells)
+ break
+ c = await it.__anext__()
+ except _ResetRow as e:
+ c = e.chunk
+ if (
+ c.row_key
+ or c.HasField("family_name")
+ or c.HasField("qualifier")
+ or c.timestamp_micros
+ or c.labels
+ or c.value
+ ):
+ raise InvalidChunk("reset row with data")
+ continue
+ except StopAsyncIteration:
+ raise InvalidChunk("premature end of stream")
+
+ @staticmethod
+ def _revise_request_rowset(
+ row_set: RowSetPB,
+ last_seen_row_key: bytes,
+ ) -> RowSetPB:
+ """
+ Revise the rows in the request to avoid ones we've already processed.
+
+ Args:
+ - row_set: the row set from the request
+ - last_seen_row_key: the last row key encountered
+ Raises:
+ - _RowSetComplete: if there are no rows left to process after the revision
+ """
+ # if user is doing a whole table scan, start a new one with the last seen key
+ if row_set is None or (not row_set.row_ranges and row_set.row_keys is not None):
+ last_seen = last_seen_row_key
+ return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)])
+ # remove already-seen keys from the user-specified key list
+ adjusted_keys: list[bytes] = [
+ k for k in row_set.row_keys if k > last_seen_row_key
+ ]
+ # adjust ranges to ignore keys before last seen
+ adjusted_ranges: list[RowRangePB] = []
+ for row_range in row_set.row_ranges:
+ end_key = row_range.end_key_closed or row_range.end_key_open or None
+ if end_key is None or end_key > last_seen_row_key:
+ # end range is after last seen key
+ new_range = RowRangePB(row_range)
+ start_key = row_range.start_key_closed or row_range.start_key_open
+ if start_key is None or start_key <= last_seen_row_key:
+ # replace start key with last seen
+ new_range.start_key_open = last_seen_row_key
+ adjusted_ranges.append(new_range)
+ if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0:
+ # if the query is empty after revision, raise an exception
+ # this will avoid an unwanted full table scan
+ raise _RowSetComplete()
+ return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges)
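
A worked illustration of the request revision above; it is not part of this diff, and simply calls the staticmethod directly on raw RowSet/RowRange protos:

    from google.cloud.bigtable_v2.types import RowSet, RowRange
    from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync

    # the query originally covers three keys plus the range [a, q)
    original = RowSet(
        row_keys=[b"a", b"m", b"z"],
        row_ranges=[RowRange(start_key_closed=b"a", end_key_open=b"q")],
    )
    # suppose rows up to and including b"m" were yielded before a stream error
    revised = _ReadRowsOperationAsync._revise_request_rowset(original, b"m")
    # expected: only b"z" remains in row_keys, and the surviving range is
    # trimmed to start_key_open=b"m" so already-read rows are not re-requested
    print(revised)
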
diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
new file mode 100644
index 000000000..ed14c618d
--- /dev/null
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -0,0 +1,1264 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+
+from typing import (
+ cast,
+ Any,
+ AsyncIterable,
+ Optional,
+ Set,
+ Sequence,
+ TYPE_CHECKING,
+)
+
+import asyncio
+import grpc
+import time
+import warnings
+import sys
+import random
+import os
+
+from functools import partial
+
+from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta
+from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient
+from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO
+from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import (
+ PooledBigtableGrpcAsyncIOTransport,
+ PooledChannel,
+)
+from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest
+from google.cloud.client import ClientWithProject
+from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore
+from google.api_core import retry as retries
+from google.api_core.exceptions import DeadlineExceeded
+from google.api_core.exceptions import ServiceUnavailable
+from google.api_core.exceptions import Aborted
+from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+
+import google.auth.credentials
+import google.auth._default
+from google.api_core import client_options as client_options_lib
+from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT
+from google.cloud.bigtable.data.row import Row
+from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+from google.cloud.bigtable.data.exceptions import FailedQueryShardError
+from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+
+from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry
+from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync
+from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
+from google.cloud.bigtable.data._helpers import _WarmedInstanceKey
+from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+from google.cloud.bigtable.data._helpers import _make_metadata
+from google.cloud.bigtable.data._helpers import _retry_exception_factory
+from google.cloud.bigtable.data._helpers import _validate_timeouts
+from google.cloud.bigtable.data._helpers import _get_retryable_errors
+from google.cloud.bigtable.data._helpers import _get_timeouts
+from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
+from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync
+from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE
+from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule
+from google.cloud.bigtable.data.row_filters import RowFilter
+from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+from google.cloud.bigtable.data.row_filters import RowFilterChain
+
+
+if TYPE_CHECKING:
+ from google.cloud.bigtable.data._helpers import RowKeySamples
+ from google.cloud.bigtable.data._helpers import ShardedQuery
+
+
+class BigtableDataClientAsync(ClientWithProject):
+ def __init__(
+ self,
+ *,
+ project: str | None = None,
+ pool_size: int = 3,
+ credentials: google.auth.credentials.Credentials | None = None,
+ client_options: dict[str, Any]
+ | "google.api_core.client_options.ClientOptions"
+ | None = None,
+ ):
+ """
+ Create a client instance for the Bigtable Data API
+
+ Client should be created within an async context (running event loop)
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ project: the project which the client acts on behalf of.
+ If not passed, falls back to the default inferred
+ from the environment.
+ pool_size: The number of grpc channels to maintain
+ in the internal channel pool.
+ credentials:
+ The OAuth2 Credentials to use for this
+ client. If not passed (and if no ``_http`` object is
+ passed), falls back to the default inferred from the
+ environment.
+ client_options (Optional[Union[dict, google.api_core.client_options.ClientOptions]]):
+ Client options used to set user options
+ on the client. API Endpoint should be set through client_options.
+ Raises:
+ - RuntimeError if called outside of an async context (no running event loop)
+ - ValueError if pool_size is less than 1
+ """
+ # set up transport in registry
+ transport_str = f"pooled_grpc_asyncio_{pool_size}"
+ transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size)
+ BigtableClientMeta._transport_registry[transport_str] = transport
+ # set up client info headers for veneer library
+ client_info = DEFAULT_CLIENT_INFO
+ client_info.client_library_version = self._client_version()
+ # parse client options
+ if type(client_options) is dict:
+ client_options = client_options_lib.from_dict(client_options)
+ client_options = cast(
+ Optional[client_options_lib.ClientOptions], client_options
+ )
+ self._emulator_host = os.getenv(BIGTABLE_EMULATOR)
+ if self._emulator_host is not None:
+ # use insecure channel if emulator is set
+ if credentials is None:
+ credentials = google.auth.credentials.AnonymousCredentials()
+ if project is None:
+ project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT
+ # initialize client
+ ClientWithProject.__init__(
+ self,
+ credentials=credentials,
+ project=project,
+ client_options=client_options,
+ )
+ self._gapic_client = BigtableAsyncClient(
+ transport=transport_str,
+ credentials=credentials,
+ client_options=client_options,
+ client_info=client_info,
+ )
+ self.transport = cast(
+ PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport
+ )
+ # keep track of active instances for warmup on channel refresh
+ self._active_instances: Set[_WarmedInstanceKey] = set()
+ # keep track of table objects associated with each instance
+ # only remove instance from _active_instances when all associated tables remove it
+ self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {}
+ self._channel_init_time = time.monotonic()
+ self._channel_refresh_tasks: list[asyncio.Task[None]] = []
+ if self._emulator_host is not None:
+ # connect to an emulator host
+ warnings.warn(
+ "Connecting to Bigtable emulator at {}".format(self._emulator_host),
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ self.transport._grpc_channel = PooledChannel(
+ pool_size=pool_size,
+ host=self._emulator_host,
+ insecure=True,
+ )
+ # refresh cached stubs to use emulator pool
+ self.transport._stubs = {}
+ self.transport._prep_wrapped_messages(client_info)
+ else:
+ # attempt to start background channel refresh tasks
+ try:
+ self._start_background_channel_refresh()
+ except RuntimeError:
+ warnings.warn(
+ f"{self.__class__.__name__} should be started in an "
+ "asyncio event loop. Channel refresh will not be started",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+
+ @staticmethod
+ def _client_version() -> str:
+ """
+ Helper function to return the client version string for this client
+ """
+ return f"{google.cloud.bigtable.__version__}-data-async"
+
+ def _start_background_channel_refresh(self) -> None:
+ """
+ Starts a background task to ping and warm each channel in the pool
+ Raises:
+ - RuntimeError if not called in an asyncio event loop
+ """
+ if not self._channel_refresh_tasks and not self._emulator_host:
+ # raise RuntimeError if there is no event loop
+ asyncio.get_running_loop()
+ for channel_idx in range(self.transport.pool_size):
+ refresh_task = asyncio.create_task(self._manage_channel(channel_idx))
+ if sys.version_info >= (3, 8):
+ # task names supported in Python 3.8+
+ refresh_task.set_name(
+ f"{self.__class__.__name__} channel refresh {channel_idx}"
+ )
+ self._channel_refresh_tasks.append(refresh_task)
+
+ async def close(self, timeout: float = 2.0):
+ """
+ Cancel all background tasks
+ """
+ for task in self._channel_refresh_tasks:
+ task.cancel()
+ group = asyncio.gather(*self._channel_refresh_tasks, return_exceptions=True)
+ await asyncio.wait_for(group, timeout=timeout)
+ await self.transport.close()
+ self._channel_refresh_tasks = []
+
+ async def _ping_and_warm_instances(
+ self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None
+ ) -> list[BaseException | None]:
+ """
+ Prepares the backend for requests on a channel
+
+ Pings each Bigtable instance registered in `_active_instances` on the client
+
+ Args:
+ - channel: grpc channel to warm
+ - instance_key: if provided, only warm the instance associated with the key
+ Returns:
+ - sequence of results or exceptions from the ping requests
+ """
+ instance_list = (
+ [instance_key] if instance_key is not None else self._active_instances
+ )
+ ping_rpc = channel.unary_unary(
+ "/google.bigtable.v2.Bigtable/PingAndWarm",
+ request_serializer=PingAndWarmRequest.serialize,
+ )
+ # prepare list of coroutines to run
+ tasks = [
+ ping_rpc(
+ request={"name": instance_name, "app_profile_id": app_profile_id},
+ metadata=[
+ (
+ "x-goog-request-params",
+ f"name={instance_name}&app_profile_id={app_profile_id}",
+ )
+ ],
+ wait_for_ready=True,
+ )
+ for (instance_name, table_name, app_profile_id) in instance_list
+ ]
+ # execute coroutines in parallel
+ result_list = await asyncio.gather(*tasks, return_exceptions=True)
+ # return None in place of empty successful responses
+ return [r or None for r in result_list]
+
+ async def _manage_channel(
+ self,
+ channel_idx: int,
+ refresh_interval_min: float = 60 * 35,
+ refresh_interval_max: float = 60 * 45,
+ grace_period: float = 60 * 10,
+ ) -> None:
+ """
+ Background coroutine that periodically refreshes and warms a grpc channel
+
+ The backend will automatically close channels after 60 minutes, so
+ `refresh_interval` + `grace_period` should be < 60 minutes
+
+ Runs continuously until the client is closed
+
+ Args:
+ channel_idx: index of the channel in the transport's channel pool
+ refresh_interval_min: minimum interval before initiating refresh
+ process in seconds. Actual interval will be a random value
+ between `refresh_interval_min` and `refresh_interval_max`
+ refresh_interval_max: maximum interval before initiating refresh
+ process in seconds. Actual interval will be a random value
+ between `refresh_interval_min` and `refresh_interval_max`
+ grace_period: time to allow previous channel to serve existing
+ requests before closing, in seconds
+ """
+ first_refresh = self._channel_init_time + random.uniform(
+ refresh_interval_min, refresh_interval_max
+ )
+ next_sleep = max(first_refresh - time.monotonic(), 0)
+ if next_sleep > 0:
+            # warm the current channel immediately while waiting for the first refresh
+ channel = self.transport.channels[channel_idx]
+ await self._ping_and_warm_instances(channel)
+ # continuously refresh the channel every `refresh_interval` seconds
+ while True:
+ await asyncio.sleep(next_sleep)
+ # prepare new channel for use
+ new_channel = self.transport.grpc_channel._create_channel()
+ await self._ping_and_warm_instances(new_channel)
+ # cycle channel out of use, with long grace window before closure
+ start_timestamp = time.time()
+ await self.transport.replace_channel(
+ channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel
+ )
+ # subtract the time spent waiting for the channel to be replaced
+ next_refresh = random.uniform(refresh_interval_min, refresh_interval_max)
+ next_sleep = next_refresh - (time.time() - start_timestamp)
+
+ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None:
+ """
+ Registers an instance with the client, and warms the channel pool
+ for the instance
+        The client will periodically refresh the grpc channel pool used to make
+        requests, and new channels will be warmed for each registered instance.
+        Channels will not be refreshed unless at least one instance is registered.
+
+ Args:
+ - instance_id: id of the instance to register.
+ - owner: table that owns the instance. Owners will be tracked in
+ _instance_owners, and instances will only be unregistered when all
+ owners call _remove_instance_registration
+ """
+ instance_name = self._gapic_client.instance_path(self.project, instance_id)
+ instance_key = _WarmedInstanceKey(
+ instance_name, owner.table_name, owner.app_profile_id
+ )
+ self._instance_owners.setdefault(instance_key, set()).add(id(owner))
+        if instance_key not in self._active_instances:
+ self._active_instances.add(instance_key)
+ if self._channel_refresh_tasks:
+ # refresh tasks already running
+ # call ping and warm on all existing channels
+ for channel in self.transport.channels:
+ await self._ping_and_warm_instances(channel, instance_key)
+ else:
+ # refresh tasks aren't active. start them as background tasks
+ self._start_background_channel_refresh()
+
+ async def _remove_instance_registration(
+ self, instance_id: str, owner: TableAsync
+ ) -> bool:
+ """
+ Removes an instance from the client's registered instances, to prevent
+ warming new channels for the instance
+
+ If instance_id is not registered, or is still in use by other tables, returns False
+
+ Args:
+ - instance_id: id of the instance to remove
+ - owner: table that owns the instance. Owners will be tracked in
+ _instance_owners, and instances will only be unregistered when all
+ owners call _remove_instance_registration
+ Returns:
+ - True if instance was removed
+ """
+ instance_name = self._gapic_client.instance_path(self.project, instance_id)
+ instance_key = _WarmedInstanceKey(
+ instance_name, owner.table_name, owner.app_profile_id
+ )
+ owner_list = self._instance_owners.get(instance_key, set())
+ try:
+ owner_list.remove(id(owner))
+ if len(owner_list) == 0:
+ self._active_instances.remove(instance_key)
+ return True
+ except KeyError:
+ return False
+
+ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync:
+ """
+ Returns a table instance for making data API requests. All arguments are passed
+ directly to the TableAsync constructor.
+
+ Args:
+ instance_id: The Bigtable instance ID to associate with this client.
+ instance_id is combined with the client's project to fully
+ specify the instance
+ table_id: The ID of the table. table_id is combined with the
+ instance_id and the client's project to fully specify the table
+ app_profile_id: The app profile to associate with requests.
+ https://cloud.google.com/bigtable/docs/app-profiles
+ default_read_rows_operation_timeout: The default timeout for read rows
+ operations, in seconds. If not set, defaults to 600 seconds (10 minutes)
+ default_read_rows_attempt_timeout: The default timeout for individual
+ read rows rpc requests, in seconds. If not set, defaults to 20 seconds
+ default_mutate_rows_operation_timeout: The default timeout for mutate rows
+ operations, in seconds. If not set, defaults to 600 seconds (10 minutes)
+ default_mutate_rows_attempt_timeout: The default timeout for individual
+ mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds
+ default_operation_timeout: The default timeout for all other operations, in
+ seconds. If not set, defaults to 60 seconds
+ default_attempt_timeout: The default timeout for all other individual rpc
+ requests, in seconds. If not set, defaults to 20 seconds
+ default_read_rows_retryable_errors: a list of errors that will be retried
+ if encountered during read_rows and related operations.
+ Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted)
+ default_mutate_rows_retryable_errors: a list of errors that will be retried
+ if encountered during mutate_rows and related operations.
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
+ default_retryable_errors: a list of errors that will be retried if
+ encountered during all other operations.
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
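+
+        Example (an illustrative sketch; "my-instance" and "my-table" are placeholder IDs,
+        and the call must run inside an async event loop):
+
+        ```
+        table = client.get_table("my-instance", "my-table")
+        row = await table.read_row(b"row-key")
+        ```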
+ """
+ return TableAsync(self, instance_id, table_id, *args, **kwargs)
+
+ async def __aenter__(self):
+ self._start_background_channel_refresh()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+ await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb)
+
+
+class TableAsync:
+ """
+ Main Data API surface
+
+    Table object maintains table_id and app_profile_id context, and passes them with
+ each call
+ """
+
+ def __init__(
+ self,
+ client: BigtableDataClientAsync,
+ instance_id: str,
+ table_id: str,
+ app_profile_id: str | None = None,
+ *,
+ default_read_rows_operation_timeout: float = 600,
+ default_read_rows_attempt_timeout: float | None = 20,
+ default_mutate_rows_operation_timeout: float = 600,
+ default_mutate_rows_attempt_timeout: float | None = 60,
+ default_operation_timeout: float = 60,
+ default_attempt_timeout: float | None = 20,
+ default_read_rows_retryable_errors: Sequence[type[Exception]] = (
+ DeadlineExceeded,
+ ServiceUnavailable,
+ Aborted,
+ ),
+ default_mutate_rows_retryable_errors: Sequence[type[Exception]] = (
+ DeadlineExceeded,
+ ServiceUnavailable,
+ ),
+ default_retryable_errors: Sequence[type[Exception]] = (
+ DeadlineExceeded,
+ ServiceUnavailable,
+ ),
+ ):
+ """
+ Initialize a Table instance
+
+ Must be created within an async context (running event loop)
+
+ Args:
+ instance_id: The Bigtable instance ID to associate with this client.
+ instance_id is combined with the client's project to fully
+ specify the instance
+ table_id: The ID of the table. table_id is combined with the
+ instance_id and the client's project to fully specify the table
+ app_profile_id: The app profile to associate with requests.
+ https://cloud.google.com/bigtable/docs/app-profiles
+ default_read_rows_operation_timeout: The default timeout for read rows
+ operations, in seconds. If not set, defaults to 600 seconds (10 minutes)
+ default_read_rows_attempt_timeout: The default timeout for individual
+ read rows rpc requests, in seconds. If not set, defaults to 20 seconds
+ default_mutate_rows_operation_timeout: The default timeout for mutate rows
+ operations, in seconds. If not set, defaults to 600 seconds (10 minutes)
+ default_mutate_rows_attempt_timeout: The default timeout for individual
+ mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds
+ default_operation_timeout: The default timeout for all other operations, in
+ seconds. If not set, defaults to 60 seconds
+ default_attempt_timeout: The default timeout for all other individual rpc
+ requests, in seconds. If not set, defaults to 20 seconds
+ default_read_rows_retryable_errors: a list of errors that will be retried
+ if encountered during read_rows and related operations.
+ Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted)
+ default_mutate_rows_retryable_errors: a list of errors that will be retried
+ if encountered during mutate_rows and related operations.
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
+ default_retryable_errors: a list of errors that will be retried if
+ encountered during all other operations.
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
+ Raises:
+ - RuntimeError if called outside of an async context (no running event loop)
+ """
+ # NOTE: any changes to the signature of this method should also be reflected
+ # in client.get_table()
+ # validate timeouts
+ _validate_timeouts(
+ default_operation_timeout, default_attempt_timeout, allow_none=True
+ )
+ _validate_timeouts(
+ default_read_rows_operation_timeout,
+ default_read_rows_attempt_timeout,
+ allow_none=True,
+ )
+ _validate_timeouts(
+ default_mutate_rows_operation_timeout,
+ default_mutate_rows_attempt_timeout,
+ allow_none=True,
+ )
+
+ self.client = client
+ self.instance_id = instance_id
+ self.instance_name = self.client._gapic_client.instance_path(
+ self.client.project, instance_id
+ )
+ self.table_id = table_id
+ self.table_name = self.client._gapic_client.table_path(
+ self.client.project, instance_id, table_id
+ )
+ self.app_profile_id = app_profile_id
+
+ self.default_operation_timeout = default_operation_timeout
+ self.default_attempt_timeout = default_attempt_timeout
+ self.default_read_rows_operation_timeout = default_read_rows_operation_timeout
+ self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout
+ self.default_mutate_rows_operation_timeout = (
+ default_mutate_rows_operation_timeout
+ )
+ self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout
+
+ self.default_read_rows_retryable_errors = (
+ default_read_rows_retryable_errors or ()
+ )
+ self.default_mutate_rows_retryable_errors = (
+ default_mutate_rows_retryable_errors or ()
+ )
+ self.default_retryable_errors = default_retryable_errors or ()
+
+ # raises RuntimeError if called outside of an async context (no running event loop)
+ try:
+ self._register_instance_task = asyncio.create_task(
+ self.client._register_instance(instance_id, self)
+ )
+ except RuntimeError as e:
+ raise RuntimeError(
+ f"{self.__class__.__name__} must be created within an async event loop context."
+ ) from e
+
+ async def read_rows_stream(
+ self,
+ query: ReadRowsQuery,
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ ) -> AsyncIterable[Row]:
+ """
+ Read a set of rows from the table, based on the specified query.
+ Returns an iterator to asynchronously stream back row data.
+
+ Failed requests within operation_timeout will be retried based on the
+ retryable_errors list until operation_timeout is reached.
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - query: contains details about which rows to return
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_read_rows_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_read_rows_attempt_timeout.
+ If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_read_rows_retryable_errors
+ Returns:
+ - an asynchronous iterator that yields rows returned by the query
+ Raises:
+ - DeadlineExceeded: raised after operation timeout
+ will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+ from any retries that failed
+ - GoogleAPIError: raised if the request encounters an unrecoverable error
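+
+        Example (an illustrative sketch; assumes an existing table object and placeholder row keys):
+
+        ```
+        query = ReadRowsQuery(row_keys=[b"key-1", b"key-2"])
+        stream = await table.read_rows_stream(query)
+        async for row in stream:
+            print(row.row_key)
+        ```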
+ """
+ operation_timeout, attempt_timeout = _get_timeouts(
+ operation_timeout, attempt_timeout, self
+ )
+ retryable_excs = _get_retryable_errors(retryable_errors, self)
+
+ row_merger = _ReadRowsOperationAsync(
+ query,
+ self,
+ operation_timeout=operation_timeout,
+ attempt_timeout=attempt_timeout,
+ retryable_exceptions=retryable_excs,
+ )
+ return row_merger.start_operation()
+
+ async def read_rows(
+ self,
+ query: ReadRowsQuery,
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ ) -> list[Row]:
+ """
+ Read a set of rows from the table, based on the specified query.
+        Returns results as a list of Row objects when the request is complete.
+ For streamed results, use read_rows_stream.
+
+ Failed requests within operation_timeout will be retried based on the
+ retryable_errors list until operation_timeout is reached.
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - query: contains details about which rows to return
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_read_rows_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_read_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_read_rows_retryable_errors.
+ Returns:
+ - a list of Rows returned by the query
+ Raises:
+ - DeadlineExceeded: raised after operation timeout
+ will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+ from any retries that failed
+ - GoogleAPIError: raised if the request encounters an unrecoverable error
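+
+        Example (an illustrative sketch; assumes an existing table object):
+
+        ```
+        rows = await table.read_rows(ReadRowsQuery(limit=10))
+        for row in rows:
+            print(row.row_key)
+        ```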
+ """
+ row_generator = await self.read_rows_stream(
+ query,
+ operation_timeout=operation_timeout,
+ attempt_timeout=attempt_timeout,
+ retryable_errors=retryable_errors,
+ )
+ return [row async for row in row_generator]
+
+ async def read_row(
+ self,
+ row_key: str | bytes,
+ *,
+ row_filter: RowFilter | None = None,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ ) -> Row | None:
+ """
+ Read a single row from the table, based on the specified key.
+
+ Failed requests within operation_timeout will be retried based on the
+ retryable_errors list until operation_timeout is reached.
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+            - row_key: the key of the row to read
+            - row_filter: an optional filter to apply to the contents of the row
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_read_rows_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_read_rows_attempt_timeout.
+ If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_read_rows_retryable_errors.
+ Returns:
+ - a Row object if the row exists, otherwise None
+ Raises:
+ - DeadlineExceeded: raised after operation timeout
+ will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+ from any retries that failed
+ - GoogleAPIError: raised if the request encounters an unrecoverable error
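+
+        Example (an illustrative sketch; assumes an existing table object and a placeholder key):
+
+        ```
+        row = await table.read_row(b"row-key-1")
+        if row is None:
+            print("row not found")
+        ```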
+ """
+ if row_key is None:
+ raise ValueError("row_key must be string or bytes")
+ query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1)
+ results = await self.read_rows(
+ query,
+ operation_timeout=operation_timeout,
+ attempt_timeout=attempt_timeout,
+ retryable_errors=retryable_errors,
+ )
+ if len(results) == 0:
+ return None
+ return results[0]
+
+ async def read_rows_sharded(
+ self,
+ sharded_query: ShardedQuery,
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ ) -> list[Row]:
+ """
+        Runs a sharded query in parallel, then returns the results in a single list.
+ Results will be returned in the order of the input queries.
+
+        This function is intended to be run on the results of a query.shard() call:
+
+ ```
+ table_shard_keys = await table.sample_row_keys()
+ query = ReadRowsQuery(...)
+ shard_queries = query.shard(table_shard_keys)
+ results = await table.read_rows_sharded(shard_queries)
+ ```
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - sharded_query: a sharded query to execute
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_read_rows_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_read_rows_attempt_timeout.
+ If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_read_rows_retryable_errors.
+ Raises:
+ - ShardedReadRowsExceptionGroup: if any of the queries failed
+ - ValueError: if the query_list is empty
+ """
+ if not sharded_query:
+ raise ValueError("empty sharded_query")
+ # reduce operation_timeout between batches
+ operation_timeout, attempt_timeout = _get_timeouts(
+ operation_timeout, attempt_timeout, self
+ )
+ timeout_generator = _attempt_timeout_generator(
+ operation_timeout, operation_timeout
+ )
+ # submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT
+ batched_queries = [
+ sharded_query[i : i + _CONCURRENCY_LIMIT]
+ for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT)
+ ]
+ # run batches and collect results
+ results_list = []
+ error_dict = {}
+ shard_idx = 0
+ for batch in batched_queries:
+ batch_operation_timeout = next(timeout_generator)
+ routine_list = [
+ self.read_rows(
+ query,
+ operation_timeout=batch_operation_timeout,
+ attempt_timeout=min(attempt_timeout, batch_operation_timeout),
+ retryable_errors=retryable_errors,
+ )
+ for query in batch
+ ]
+ batch_result = await asyncio.gather(*routine_list, return_exceptions=True)
+ for result in batch_result:
+ if isinstance(result, Exception):
+ error_dict[shard_idx] = result
+ elif isinstance(result, BaseException):
+ # BaseException not expected; raise immediately
+ raise result
+ else:
+ results_list.extend(result)
+ shard_idx += 1
+ if error_dict:
+ # if any sub-request failed, raise an exception instead of returning results
+ raise ShardedReadRowsExceptionGroup(
+ [
+ FailedQueryShardError(idx, sharded_query[idx], e)
+ for idx, e in error_dict.items()
+ ],
+ results_list,
+ len(sharded_query),
+ )
+ return results_list
+
+ async def row_exists(
+ self,
+ row_key: str | bytes,
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+ ) -> bool:
+ """
+ Return a boolean indicating whether the specified row exists in the table.
+        Uses the filters: chain(limit cells per row = 1, strip value)
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - row_key: the key of the row to check
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_read_rows_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_read_rows_attempt_timeout.
+ If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_read_rows_retryable_errors.
+ Returns:
+ - a bool indicating whether the row exists
+ Raises:
+ - DeadlineExceeded: raised after operation timeout
+ will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+ from any retries that failed
+ - GoogleAPIError: raised if the request encounters an unrecoverable error
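+
+        Example (an illustrative sketch; assumes an existing table object):
+
+        ```
+        if await table.row_exists(b"row-key-1"):
+            print("row found")
+        ```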
+ """
+ if row_key is None:
+ raise ValueError("row_key must be string or bytes")
+
+ strip_filter = StripValueTransformerFilter(flag=True)
+ limit_filter = CellsRowLimitFilter(1)
+ chain_filter = RowFilterChain(filters=[limit_filter, strip_filter])
+ query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter)
+ results = await self.read_rows(
+ query,
+ operation_timeout=operation_timeout,
+ attempt_timeout=attempt_timeout,
+ retryable_errors=retryable_errors,
+ )
+ return len(results) > 0
+
+ async def sample_row_keys(
+ self,
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ ) -> RowKeySamples:
+ """
+ Return a set of RowKeySamples that delimit contiguous sections of the table of
+ approximately equal size
+
+ RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that
+        can be parallelized across multiple backend nodes. read_rows and read_rows_stream
+        requests will call sample_row_keys internally for this purpose when sharding is enabled.
+
+ RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of
+ row_keys, along with offset positions in the table
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+ Defaults to the Table's default_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_attempt_timeout.
+ If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_retryable_errors.
+ Returns:
+            - a set of RowKeySamples that delimit contiguous sections of the table
+ Raises:
+ - DeadlineExceeded: raised after operation timeout
+ will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+ from any retries that failed
+ - GoogleAPIError: raised if the request encounters an unrecoverable error
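+
+        Example (an illustrative sketch; assumes an existing table object):
+
+        ```
+        samples = await table.sample_row_keys()
+        shard_queries = ReadRowsQuery().shard(samples)
+        results = await table.read_rows_sharded(shard_queries)
+        ```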
+ """
+ # prepare timeouts
+ operation_timeout, attempt_timeout = _get_timeouts(
+ operation_timeout, attempt_timeout, self
+ )
+ attempt_timeout_gen = _attempt_timeout_generator(
+ attempt_timeout, operation_timeout
+ )
+ # prepare retryable
+ retryable_excs = _get_retryable_errors(retryable_errors, self)
+ predicate = retries.if_exception_type(*retryable_excs)
+
+ sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
+
+ # prepare request
+ metadata = _make_metadata(self.table_name, self.app_profile_id)
+
+ async def execute_rpc():
+ results = await self.client._gapic_client.sample_row_keys(
+ table_name=self.table_name,
+ app_profile_id=self.app_profile_id,
+ timeout=next(attempt_timeout_gen),
+ metadata=metadata,
+ retry=None,
+ )
+ return [(s.row_key, s.offset_bytes) async for s in results]
+
+ return await retries.retry_target_async(
+ execute_rpc,
+ predicate,
+ sleep_generator,
+ operation_timeout,
+ exception_factory=_retry_exception_factory,
+ )
+
+ def mutations_batcher(
+ self,
+ *,
+ flush_interval: float | None = 5,
+ flush_limit_mutation_count: int | None = 1000,
+ flush_limit_bytes: int = 20 * _MB_SIZE,
+ flow_control_max_mutation_count: int = 100_000,
+ flow_control_max_bytes: int = 100 * _MB_SIZE,
+ batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ batch_retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ ) -> MutationsBatcherAsync:
+ """
+ Returns a new mutations batcher instance.
+
+ Can be used to iteratively add mutations that are flushed as a group,
+ to avoid excess network calls
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+            - flush_interval: Automatically flush every flush_interval seconds. If None,
+              no time-based flushing is performed.
+ - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count
+ mutations are added across all entries. If None, this limit is ignored.
+ - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added.
+ - flow_control_max_mutation_count: Maximum number of inflight mutations.
+ - flow_control_max_bytes: Maximum number of inflight bytes.
+ - batch_operation_timeout: timeout for each mutate_rows operation, in seconds.
+ Defaults to the Table's default_mutate_rows_operation_timeout
+ - batch_attempt_timeout: timeout for each individual request, in seconds.
+ Defaults to the Table's default_mutate_rows_attempt_timeout.
+ If None, defaults to batch_operation_timeout.
+ - batch_retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_mutate_rows_retryable_errors.
+ Returns:
+ - a MutationsBatcherAsync context manager that can batch requests
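+
+        Example (an illustrative sketch; assumes an existing table object and a prepared
+        RowMutationEntry named entry):
+
+        ```
+        async with table.mutations_batcher() as batcher:
+            await batcher.append(entry)
+        ```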
+ """
+ return MutationsBatcherAsync(
+ self,
+ flush_interval=flush_interval,
+ flush_limit_mutation_count=flush_limit_mutation_count,
+ flush_limit_bytes=flush_limit_bytes,
+ flow_control_max_mutation_count=flow_control_max_mutation_count,
+ flow_control_max_bytes=flow_control_max_bytes,
+ batch_operation_timeout=batch_operation_timeout,
+ batch_attempt_timeout=batch_attempt_timeout,
+ batch_retryable_errors=batch_retryable_errors,
+ )
+
+ async def mutate_row(
+ self,
+ row_key: str | bytes,
+ mutations: list[Mutation] | Mutation,
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ ):
+ """
+ Mutates a row atomically.
+
+ Cells already present in the row are left unchanged unless explicitly changed
+ by ``mutation``.
+
+        Idempotent operations (i.e., all mutations have an explicit timestamp) will be
+ retried on server failure. Non-idempotent operations will not.
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - row_key: the row to apply mutations to
+ - mutations: the set of mutations to apply to the row
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_attempt_timeout.
+ If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Only idempotent mutations will be retried. Defaults to the Table's
+ default_retryable_errors.
+ Raises:
+ - DeadlineExceeded: raised after operation timeout
+ will be chained with a RetryExceptionGroup containing all
+ GoogleAPIError exceptions from any retries that failed
+ - GoogleAPIError: raised on non-idempotent operations that cannot be
+ safely retried.
+ - ValueError if invalid arguments are provided
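+
+        Example (an illustrative sketch; assumes an existing table object and a prepared
+        Mutation instance named mutation):
+
+        ```
+        await table.mutate_row(b"row-key-1", mutation)
+        ```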
+ """
+ operation_timeout, attempt_timeout = _get_timeouts(
+ operation_timeout, attempt_timeout, self
+ )
+
+ if not mutations:
+ raise ValueError("No mutations provided")
+ mutations_list = mutations if isinstance(mutations, list) else [mutations]
+
+ if all(mutation.is_idempotent() for mutation in mutations_list):
+ # mutations are all idempotent and safe to retry
+ predicate = retries.if_exception_type(
+ *_get_retryable_errors(retryable_errors, self)
+ )
+ else:
+ # mutations should not be retried
+ predicate = retries.if_exception_type()
+
+ sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
+
+ target = partial(
+ self.client._gapic_client.mutate_row,
+ row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
+ mutations=[mutation._to_pb() for mutation in mutations_list],
+ table_name=self.table_name,
+ app_profile_id=self.app_profile_id,
+ timeout=attempt_timeout,
+ metadata=_make_metadata(self.table_name, self.app_profile_id),
+ retry=None,
+ )
+ return await retries.retry_target_async(
+ target,
+ predicate,
+ sleep_generator,
+ operation_timeout,
+ exception_factory=_retry_exception_factory,
+ )
+
+ async def bulk_mutate_rows(
+ self,
+ mutation_entries: list[RowMutationEntry],
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ ):
+ """
+ Applies mutations for multiple rows in a single batched request.
+
+ Each individual RowMutationEntry is applied atomically, but separate entries
+        may be applied in arbitrary order (even for entries targeting the same row).
+        In total, mutation_entries can contain at most 100000 individual mutations
+        across all entries.
+
+        Idempotent entries (i.e., entries with mutations with explicit timestamps)
+        will be retried on failure. Non-idempotent entries will not be retried, and
+        will be reported in a raised exception group.
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - mutation_entries: the batches of mutations to apply
+ Each entry will be applied atomically, but entries will be applied
+ in arbitrary order
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_mutate_rows_operation_timeout
+ - attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_mutate_rows_attempt_timeout.
+ If None, defaults to operation_timeout.
+ - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_mutate_rows_retryable_errors
+ Raises:
+ - MutationsExceptionGroup if one or more mutations fails
+ Contains details about any failed entries in .exceptions
+ - ValueError if invalid arguments are provided
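+
+        Example (an illustrative sketch; assumes an existing table object and a prepared
+        list of RowMutationEntry objects named entries):
+
+        ```
+        await table.bulk_mutate_rows(entries)
+        ```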
+ """
+ operation_timeout, attempt_timeout = _get_timeouts(
+ operation_timeout, attempt_timeout, self
+ )
+ retryable_excs = _get_retryable_errors(retryable_errors, self)
+
+ operation = _MutateRowsOperationAsync(
+ self.client._gapic_client,
+ self,
+ mutation_entries,
+ operation_timeout,
+ attempt_timeout,
+ retryable_exceptions=retryable_excs,
+ )
+ await operation.start()
+
+ async def check_and_mutate_row(
+ self,
+ row_key: str | bytes,
+ predicate: RowFilter | None,
+ *,
+ true_case_mutations: Mutation | list[Mutation] | None = None,
+ false_case_mutations: Mutation | list[Mutation] | None = None,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ ) -> bool:
+ """
+ Mutates a row atomically based on the output of a predicate filter
+
+ Non-idempotent operation: will not be retried
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - row_key: the key of the row to mutate
+ - predicate: the filter to be applied to the contents of the specified row.
+ Depending on whether or not any results are yielded,
+ either true_case_mutations or false_case_mutations will be executed.
+ If None, checks that the row contains any values at all.
+ - true_case_mutations:
+ Changes to be atomically applied to the specified row if
+ predicate yields at least one cell when
+ applied to row_key. Entries are applied in order,
+ meaning that earlier mutations can be masked by later
+ ones. Must contain at least one entry if
+ false_case_mutations is empty, and at most 100000.
+ - false_case_mutations:
+ Changes to be atomically applied to the specified row if
+ predicate_filter does not yield any cells when
+ applied to row_key. Entries are applied in order,
+ meaning that earlier mutations can be masked by later
+ ones. Must contain at least one entry if
+                true_case_mutations is empty, and at most 100000.
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will not be retried. Defaults to the Table's default_operation_timeout
+ Returns:
+ - bool indicating whether the predicate was true or false
+ Raises:
+ - GoogleAPIError exceptions from grpc call
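+
+        Example (an illustrative sketch; assumes an existing table object, a prepared
+        RowFilter named predicate, and a prepared Mutation named mutation):
+
+        ```
+        matched = await table.check_and_mutate_row(
+            b"row-key-1",
+            predicate,
+            true_case_mutations=mutation,
+        )
+        ```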
+ """
+ operation_timeout, _ = _get_timeouts(operation_timeout, None, self)
+ if true_case_mutations is not None and not isinstance(
+ true_case_mutations, list
+ ):
+ true_case_mutations = [true_case_mutations]
+ true_case_list = [m._to_pb() for m in true_case_mutations or []]
+ if false_case_mutations is not None and not isinstance(
+ false_case_mutations, list
+ ):
+ false_case_mutations = [false_case_mutations]
+ false_case_list = [m._to_pb() for m in false_case_mutations or []]
+ metadata = _make_metadata(self.table_name, self.app_profile_id)
+ result = await self.client._gapic_client.check_and_mutate_row(
+ true_mutations=true_case_list,
+ false_mutations=false_case_list,
+ predicate_filter=predicate._to_pb() if predicate is not None else None,
+ row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
+ table_name=self.table_name,
+ app_profile_id=self.app_profile_id,
+ metadata=metadata,
+ timeout=operation_timeout,
+ retry=None,
+ )
+ return result.predicate_matched
+
+ async def read_modify_write_row(
+ self,
+ row_key: str | bytes,
+ rules: ReadModifyWriteRule | list[ReadModifyWriteRule],
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ ) -> Row:
+ """
+ Reads and modifies a row atomically according to input ReadModifyWriteRules,
+ and returns the contents of all modified cells
+
+ The new value for the timestamp is the greater of the existing timestamp or
+ the current server time.
+
+ Non-idempotent operation: will not be retried
+
+ Warning: BigtableDataClientAsync is currently in preview, and is not
+ yet recommended for production use.
+
+ Args:
+ - row_key: the key of the row to apply read/modify/write rules to
+ - rules: A rule or set of rules to apply to the row.
+ Rules are applied in order, meaning that earlier rules will affect the
+ results of later ones.
+ - operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will not be retried.
+ Defaults to the Table's default_operation_timeout.
+ Returns:
+ - Row: containing cell data that was modified as part of the
+ operation
+ Raises:
+ - GoogleAPIError exceptions from grpc call
+ - ValueError if invalid arguments are provided
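+
+        Example (an illustrative sketch; assumes an existing table object and a prepared
+        ReadModifyWriteRule named rule):
+
+        ```
+        row = await table.read_modify_write_row(b"row-key-1", rule)
+        ```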
+ """
+ operation_timeout, _ = _get_timeouts(operation_timeout, None, self)
+ if operation_timeout <= 0:
+ raise ValueError("operation_timeout must be greater than 0")
+ if rules is not None and not isinstance(rules, list):
+ rules = [rules]
+ if not rules:
+ raise ValueError("rules must contain at least one item")
+ metadata = _make_metadata(self.table_name, self.app_profile_id)
+ result = await self.client._gapic_client.read_modify_write_row(
+ rules=[rule._to_pb() for rule in rules],
+ row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
+ table_name=self.table_name,
+ app_profile_id=self.app_profile_id,
+ metadata=metadata,
+ timeout=operation_timeout,
+ retry=None,
+ )
+ # construct Row from result
+ return Row._from_pb(result.row)
+
+ async def close(self):
+ """
+ Called to close the Table instance and release any resources held by it.
+ """
+ self._register_instance_task.cancel()
+ await self.client._remove_instance_registration(self.instance_id, self)
+
+ async def __aenter__(self):
+ """
+ Implement async context manager protocol
+
+ Ensure registration task has time to run, so that
+ grpc channels will be warmed for the specified instance
+ """
+ await self._register_instance_task
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ """
+ Implement async context manager protocol
+
+ Unregister this instance with the client, so that
+ grpc channels will no longer be warmed
+ """
+ await self.close()
diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py
new file mode 100644
index 000000000..5d5dd535e
--- /dev/null
+++ b/google/cloud/bigtable/data/_async/mutations_batcher.py
@@ -0,0 +1,501 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import Any, Sequence, TYPE_CHECKING
+import asyncio
+import atexit
+import warnings
+from collections import deque
+
+from google.cloud.bigtable.data.mutations import RowMutationEntry
+from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
+from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
+from google.cloud.bigtable.data._helpers import _get_retryable_errors
+from google.cloud.bigtable.data._helpers import _get_timeouts
+from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
+
+from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync
+from google.cloud.bigtable.data._async._mutate_rows import (
+ _MUTATE_ROWS_REQUEST_MUTATION_LIMIT,
+)
+from google.cloud.bigtable.data.mutations import Mutation
+
+if TYPE_CHECKING:
+ from google.cloud.bigtable.data._async.client import TableAsync
+
+# used to make more readable default values
+_MB_SIZE = 1024 * 1024
+
+
+class _FlowControlAsync:
+ """
+ Manages flow control for batched mutations. Mutations are registered against
+ the FlowControl object before being sent, which will block if size or count
+    limits have reached capacity. As mutations are completed, they are removed from
+ the FlowControl object, which will notify any blocked requests that there
+ is additional capacity.
+
+ Flow limits are not hard limits. If a single mutation exceeds the configured
+ limits, it will be allowed as a single batch when the capacity is available.
+ """
+
+ def __init__(
+ self,
+ max_mutation_count: int,
+ max_mutation_bytes: int,
+ ):
+ """
+ Args:
+ - max_mutation_count: maximum number of mutations to send in a single rpc.
+ This corresponds to individual mutations in a single RowMutationEntry.
+ - max_mutation_bytes: maximum number of bytes to send in a single rpc.
+ """
+ self._max_mutation_count = max_mutation_count
+ self._max_mutation_bytes = max_mutation_bytes
+ if self._max_mutation_count < 1:
+ raise ValueError("max_mutation_count must be greater than 0")
+ if self._max_mutation_bytes < 1:
+ raise ValueError("max_mutation_bytes must be greater than 0")
+ self._capacity_condition = asyncio.Condition()
+ self._in_flight_mutation_count = 0
+ self._in_flight_mutation_bytes = 0
+
+ def _has_capacity(self, additional_count: int, additional_size: int) -> bool:
+ """
+ Checks if there is capacity to send a new entry with the given size and count
+
+ FlowControl limits are not hard limits. If a single mutation exceeds
+ the configured flow limits, it will be sent in a single batch when
+ previous batches have completed.
+
+ Args:
+ - additional_count: number of mutations in the pending entry
+ - additional_size: size of the pending entry
+ Returns:
+ - True if there is capacity to send the pending entry, False otherwise
+ """
+ # adjust limits to allow overly large mutations
+ acceptable_size = max(self._max_mutation_bytes, additional_size)
+ acceptable_count = max(self._max_mutation_count, additional_count)
+ # check if we have capacity for new mutation
+ new_size = self._in_flight_mutation_bytes + additional_size
+ new_count = self._in_flight_mutation_count + additional_count
+ return new_size <= acceptable_size and new_count <= acceptable_count
+
+ async def remove_from_flow(
+ self, mutations: RowMutationEntry | list[RowMutationEntry]
+ ) -> None:
+ """
+ Removes mutations from flow control. This method should be called once
+ for each mutation that was sent to add_to_flow, after the corresponding
+ operation is complete.
+
+ Args:
+ - mutations: mutation or list of mutations to remove from flow control
+ """
+ if not isinstance(mutations, list):
+ mutations = [mutations]
+ total_count = sum(len(entry.mutations) for entry in mutations)
+ total_size = sum(entry.size() for entry in mutations)
+ self._in_flight_mutation_count -= total_count
+ self._in_flight_mutation_bytes -= total_size
+ # notify any blocked requests that there is additional capacity
+ async with self._capacity_condition:
+ self._capacity_condition.notify_all()
+
+ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]):
+ """
+ Generator function that registers mutations with flow control. As mutations
+ are accepted into the flow control, they are yielded back to the caller,
+ to be sent in a batch. If the flow control is at capacity, the generator
+ will block until there is capacity available.
+
+ Args:
+            - mutations: list of mutations to break up into batches
+ Yields:
+ - list of mutations that have reserved space in the flow control.
+ Each batch contains at least one mutation.
+ """
+ if not isinstance(mutations, list):
+ mutations = [mutations]
+ start_idx = 0
+ end_idx = 0
+ while end_idx < len(mutations):
+ start_idx = end_idx
+ batch_mutation_count = 0
+ # fill up batch until we hit capacity
+ async with self._capacity_condition:
+ while end_idx < len(mutations):
+ next_entry = mutations[end_idx]
+ next_size = next_entry.size()
+ next_count = len(next_entry.mutations)
+ if (
+ self._has_capacity(next_count, next_size)
+ # make sure not to exceed per-request mutation count limits
+ and (batch_mutation_count + next_count)
+ <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
+ ):
+ # room for new mutation; add to batch
+ end_idx += 1
+ batch_mutation_count += next_count
+ self._in_flight_mutation_bytes += next_size
+ self._in_flight_mutation_count += next_count
+ elif start_idx != end_idx:
+ # we have at least one mutation in the batch, so send it
+ break
+ else:
+ # batch is empty. Block until we have capacity
+ await self._capacity_condition.wait_for(
+ lambda: self._has_capacity(next_count, next_size)
+ )
+ yield mutations[start_idx:end_idx]
+
+
+class MutationsBatcherAsync:
+ """
+ Allows users to send batches using context manager API:
+
+    Combines staged mutations into batched mutate_rows requests internally,
+    to use as few network requests as required
+
+ Flushes:
+ - every flush_interval seconds
+ - after queue reaches flush_count in quantity
+ - after queue reaches flush_size_bytes in storage size
+ - when batcher is closed or destroyed
+
+ async with table.mutations_batcher() as batcher:
+ for i in range(10):
+            await batcher.append(entries[i])  # entries: prepared RowMutationEntry objects
+ """
+
+ def __init__(
+ self,
+ table: "TableAsync",
+ *,
+ flush_interval: float | None = 5,
+ flush_limit_mutation_count: int | None = 1000,
+ flush_limit_bytes: int = 20 * _MB_SIZE,
+ flow_control_max_mutation_count: int = 100_000,
+ flow_control_max_bytes: int = 100 * _MB_SIZE,
+ batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ batch_retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ ):
+ """
+ Args:
+          - table: Table used to perform rpc calls
+ - flush_interval: Automatically flush every flush_interval seconds.
+ If None, no time-based flushing is performed.
+ - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count
+ mutations are added across all entries. If None, this limit is ignored.
+ - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added.
+ - flow_control_max_mutation_count: Maximum number of inflight mutations.
+ - flow_control_max_bytes: Maximum number of inflight bytes.
+ - batch_operation_timeout: timeout for each mutate_rows operation, in seconds.
+ If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout.
+ - batch_attempt_timeout: timeout for each individual request, in seconds.
+ If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout.
+ If None, defaults to batch_operation_timeout.
+ - batch_retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_mutate_rows_retryable_errors.
+ """
+ self._operation_timeout, self._attempt_timeout = _get_timeouts(
+ batch_operation_timeout, batch_attempt_timeout, table
+ )
+ self._retryable_errors: list[type[Exception]] = _get_retryable_errors(
+ batch_retryable_errors, table
+ )
+
+ self.closed: bool = False
+ self._table = table
+ self._staged_entries: list[RowMutationEntry] = []
+ self._staged_count, self._staged_bytes = 0, 0
+ self._flow_control = _FlowControlAsync(
+ flow_control_max_mutation_count, flow_control_max_bytes
+ )
+ self._flush_limit_bytes = flush_limit_bytes
+ self._flush_limit_count = (
+ flush_limit_mutation_count
+ if flush_limit_mutation_count is not None
+ else float("inf")
+ )
+ self._flush_timer = self._start_flush_timer(flush_interval)
+ self._flush_jobs: set[asyncio.Future[None]] = set()
+ # MutationExceptionGroup reports number of successful entries along with failures
+ self._entries_processed_since_last_raise: int = 0
+ self._exceptions_since_last_raise: int = 0
+ # keep track of the first and last _exception_list_limit exceptions
+ self._exception_list_limit: int = 10
+ self._oldest_exceptions: list[Exception] = []
+ self._newest_exceptions: deque[Exception] = deque(
+ maxlen=self._exception_list_limit
+ )
+ # clean up on program exit
+ atexit.register(self._on_exit)
+
+ def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]:
+ """
+ Set up a background task to flush the batcher every interval seconds
+
+ If interval is None, an empty future is returned
+
+ Args:
+ - flush_interval: Automatically flush every flush_interval seconds.
+ If None, no time-based flushing is performed.
+ Returns:
+ - asyncio.Future that represents the background task
+ """
+ if interval is None or self.closed:
+ empty_future: asyncio.Future[None] = asyncio.Future()
+ empty_future.set_result(None)
+ return empty_future
+
+ async def timer_routine(self, interval: float):
+ """
+ Triggers new flush tasks every `interval` seconds
+ """
+ while not self.closed:
+ await asyncio.sleep(interval)
+ # add new flush task to list
+ if not self.closed and self._staged_entries:
+ self._schedule_flush()
+
+ timer_task = asyncio.create_task(timer_routine(self, interval))
+ return timer_task
+
+ async def append(self, mutation_entry: RowMutationEntry):
+ """
+ Add a new set of mutations to the internal queue
+
+ TODO: return a future to track completion of this entry
+
+ Args:
+ - mutation_entry: new entry to add to flush queue
+ Raises:
+ - RuntimeError if batcher is closed
+ - ValueError if an invalid mutation type is added
+ """
+ if self.closed:
+ raise RuntimeError("Cannot append to closed MutationsBatcher")
+ if isinstance(mutation_entry, Mutation): # type: ignore
+ raise ValueError(
+ f"invalid mutation type: {type(mutation_entry).__name__}. Only RowMutationEntry objects are supported by batcher"
+ )
+ self._staged_entries.append(mutation_entry)
+ # start a new flush task if limits exceeded
+ self._staged_count += len(mutation_entry.mutations)
+ self._staged_bytes += mutation_entry.size()
+ if (
+ self._staged_count >= self._flush_limit_count
+ or self._staged_bytes >= self._flush_limit_bytes
+ ):
+ self._schedule_flush()
+ # yield to the event loop to allow flush to run
+ await asyncio.sleep(0)
+
+ def _schedule_flush(self) -> asyncio.Future[None] | None:
+ """Update the flush task to include the latest staged entries"""
+ if self._staged_entries:
+ entries, self._staged_entries = self._staged_entries, []
+ self._staged_count, self._staged_bytes = 0, 0
+ new_task = self._create_bg_task(self._flush_internal, entries)
+ new_task.add_done_callback(self._flush_jobs.remove)
+ self._flush_jobs.add(new_task)
+ return new_task
+ return None
+
+ async def _flush_internal(self, new_entries: list[RowMutationEntry]):
+ """
+ Flushes a set of mutations to the server, and updates internal state
+
+ Args:
+ - new_entries: list of RowMutationEntry objects to flush
+ """
+ # flush new entries
+ in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = []
+ async for batch in self._flow_control.add_to_flow(new_entries):
+ batch_task = self._create_bg_task(self._execute_mutate_rows, batch)
+ in_process_requests.append(batch_task)
+ # wait for all inflight requests to complete
+ found_exceptions = await self._wait_for_batch_results(*in_process_requests)
+ # update exception data to reflect any new errors
+ self._entries_processed_since_last_raise += len(new_entries)
+ self._add_exceptions(found_exceptions)
+
+ async def _execute_mutate_rows(
+ self, batch: list[RowMutationEntry]
+ ) -> list[FailedMutationEntryError]:
+ """
+ Helper to execute mutation operation on a batch
+
+ Args:
+          - batch: list of RowMutationEntry objects to send to server.
+            Operation and attempt timeouts, as well as retryable errors, are taken
+            from the batcher's configuration.
+ Returns:
+ - list of FailedMutationEntryError objects for mutations that failed.
+ FailedMutationEntryError objects will not contain index information
+ """
+ try:
+ operation = _MutateRowsOperationAsync(
+ self._table.client._gapic_client,
+ self._table,
+ batch,
+ operation_timeout=self._operation_timeout,
+ attempt_timeout=self._attempt_timeout,
+ retryable_exceptions=self._retryable_errors,
+ )
+ await operation.start()
+ except MutationsExceptionGroup as e:
+ # strip index information from exceptions, since it is not useful in a batch context
+ for subexc in e.exceptions:
+ subexc.index = None
+ return list(e.exceptions)
+ finally:
+ # mark batch as complete in flow control
+ await self._flow_control.remove_from_flow(batch)
+ return []
+
+ def _add_exceptions(self, excs: list[Exception]):
+ """
+ Add new list of exceptions to internal store. To avoid unbounded memory,
+ the batcher will store the first and last _exception_list_limit exceptions,
+ and discard any in between.
+ """
+ self._exceptions_since_last_raise += len(excs)
+ if excs and len(self._oldest_exceptions) < self._exception_list_limit:
+ # populate oldest_exceptions with found_exceptions
+ addition_count = self._exception_list_limit - len(self._oldest_exceptions)
+ self._oldest_exceptions.extend(excs[:addition_count])
+ excs = excs[addition_count:]
+ if excs:
+ # populate newest_exceptions with remaining found_exceptions
+ self._newest_exceptions.extend(excs[-self._exception_list_limit :])
+
+ def _raise_exceptions(self):
+ """
+ Raise any unreported exceptions from background flush operations
+
+ Raises:
+ - MutationsExceptionGroup with all unreported exceptions
+ """
+ if self._oldest_exceptions or self._newest_exceptions:
+ oldest, self._oldest_exceptions = self._oldest_exceptions, []
+ newest = list(self._newest_exceptions)
+ self._newest_exceptions.clear()
+ entry_count, self._entries_processed_since_last_raise = (
+ self._entries_processed_since_last_raise,
+ 0,
+ )
+ exc_count, self._exceptions_since_last_raise = (
+ self._exceptions_since_last_raise,
+ 0,
+ )
+ raise MutationsExceptionGroup.from_truncated_lists(
+ first_list=oldest,
+ last_list=newest,
+ total_excs=exc_count,
+ entry_count=entry_count,
+ )
+
+ async def __aenter__(self):
+ """For context manager API"""
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ """For context manager API"""
+ await self.close()
+
+ async def close(self):
+ """
+ Flush queue and clean up resources
+ """
+ self.closed = True
+ self._flush_timer.cancel()
+ self._schedule_flush()
+ if self._flush_jobs:
+ await asyncio.gather(*self._flush_jobs, return_exceptions=True)
+ try:
+ await self._flush_timer
+ except asyncio.CancelledError:
+ pass
+ atexit.unregister(self._on_exit)
+ # raise unreported exceptions
+ self._raise_exceptions()
+
+ def _on_exit(self):
+ """
+ Called when program is exited. Raises warning if unflushed mutations remain
+ """
+ if not self.closed and self._staged_entries:
+ warnings.warn(
+ f"MutationsBatcher for table {self._table.table_name} was not closed. "
+ f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server."
+ )
+
+ @staticmethod
+ def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]:
+ """
+ Create a new background task, and return a future
+
+ This method wraps asyncio to make it easier to maintain subclasses
+ with different concurrency models.
+
+ Args:
+ - func: function to execute in background task
+ - *args: positional arguments to pass to func
+ - **kwargs: keyword arguments to pass to func
+ Returns:
+ - Future object representing the background task
+ """
+ return asyncio.create_task(func(*args, **kwargs))
+
+ @staticmethod
+ async def _wait_for_batch_results(
+ *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None],
+ ) -> list[Exception]:
+ """
+ Takes in a list of futures representing _execute_mutate_rows tasks,
+ waits for them to complete, and returns a list of errors encountered.
+
+ Args:
+ - *tasks: futures representing _execute_mutate_rows or _flush_internal tasks
+ Returns:
+ - list of Exceptions encountered by any of the tasks. Errors are expected
+ to be FailedMutationEntryError, representing a failed mutation operation.
+ If a task fails with a different exception, it will be included in the
+ output list. Successful tasks will not be represented in the output list.
+ """
+ if not tasks:
+ return []
+ all_results = await asyncio.gather(*tasks, return_exceptions=True)
+ found_errors = []
+ for result in all_results:
+ if isinstance(result, Exception):
+ # will receive direct Exception objects if request task fails
+ found_errors.append(result)
+ elif isinstance(result, BaseException):
+ # BaseException not expected from grpc calls. Raise immediately
+ raise result
+ elif result:
+ # completed requests will return a list of FailedMutationEntryError
+ for e in result:
+ # strip index information
+ e.index = None
+ found_errors.extend(result)
+ return found_errors
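
A minimal usage sketch of the close()/__aexit__ flow above (not part of the diff; it assumes the batcher is obtained from a table-level mutations_batcher() factory and exposes an append() coroutine, both defined outside the lines shown here):

    from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
    from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell

    async def write_with_batcher(table):
        try:
            async with table.mutations_batcher() as batcher:
                entry = RowMutationEntry(b"row-1", SetCell("family", b"qual", b"value"))
                await batcher.append(entry)
            # __aexit__ calls close(), which flushes staged entries and raises
            # any unreported failures as a MutationsExceptionGroup
        except MutationsExceptionGroup as eg:
            print(f"{len(eg.exceptions)} of {eg.total_entries_attempted} entries failed")
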
diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py
new file mode 100644
index 000000000..a0b13cbaf
--- /dev/null
+++ b/google/cloud/bigtable/data/_helpers.py
@@ -0,0 +1,220 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Helper functions used in various places in the library.
+"""
+from __future__ import annotations
+
+from typing import Sequence, List, Tuple, TYPE_CHECKING
+import time
+import enum
+from collections import namedtuple
+from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+from google.api_core import exceptions as core_exceptions
+from google.api_core.retry import RetryFailureReason
+from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
+
+if TYPE_CHECKING:
+ import grpc
+ from google.cloud.bigtable.data import TableAsync
+
+"""
+Helper functions used in various places in the library.
+"""
+
+# Type alias for the output of sample_keys
+RowKeySamples = List[Tuple[bytes, int]]
+
+# type alias for the output of query.shard()
+ShardedQuery = List[ReadRowsQuery]
+
+# used by read_rows_sharded to limit how many requests are attempted in parallel
+_CONCURRENCY_LIMIT = 10
+
+# used to register instance data with the client for channel warming
+_WarmedInstanceKey = namedtuple(
+ "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"]
+)
+
+
+# enum used on method calls when table defaults should be used
+class TABLE_DEFAULT(enum.Enum):
+ # default for mutate_row, sample_row_keys, check_and_mutate_row, and read_modify_write_row
+ DEFAULT = "DEFAULT"
+ # default for read_rows, read_rows_stream, read_rows_sharded, row_exists, and read_row
+ READ_ROWS = "READ_ROWS_DEFAULT"
+ # default for bulk_mutate_rows and mutations_batcher
+ MUTATE_ROWS = "MUTATE_ROWS_DEFAULT"
+
+
+def _make_metadata(
+ table_name: str, app_profile_id: str | None
+) -> list[tuple[str, str]]:
+ """
+ Create properly formatted gRPC metadata for requests.
+ """
+ params = []
+ params.append(f"table_name={table_name}")
+ if app_profile_id is not None:
+ params.append(f"app_profile_id={app_profile_id}")
+ params_str = "&".join(params)
+ return [("x-goog-request-params", params_str)]
+
+
+def _attempt_timeout_generator(
+ per_request_timeout: float | None, operation_timeout: float
+):
+ """
+ Generator that yields the timeout value for each attempt of a retry loop.
+
+ Will return per_request_timeout until the operation_timeout is approached,
+ at which point it will return the remaining time in the operation_timeout.
+
+ Args:
+ - per_request_timeout: The timeout value to use for each request, in seconds.
+ If None, the operation_timeout will be used for each request.
+      - operation_timeout: The timeout value to use for the entire operation, in seconds.
+    Yields:
+      - The timeout value to use for the next request, in seconds
+ """
+ per_request_timeout = (
+ per_request_timeout if per_request_timeout is not None else operation_timeout
+ )
+ deadline = operation_timeout + time.monotonic()
+ while True:
+ yield max(0, min(per_request_timeout, deadline - time.monotonic()))
+
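
A quick sketch (not part of the diff) of how the generator above behaves: each attempt gets the full per-request budget until the operation deadline approaches, after which only the remaining time is yielded.

    import time

    gen = _attempt_timeout_generator(per_request_timeout=1.0, operation_timeout=2.5)
    print(next(gen))   # ~1.0: full per-attempt budget is available
    time.sleep(2.0)    # simulate two slow attempts
    print(next(gen))   # ~0.5: capped by the remaining operation budget
    time.sleep(1.0)
    print(next(gen))   # 0: the operation deadline has passed
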
+
+def _retry_exception_factory(
+ exc_list: list[Exception],
+ reason: RetryFailureReason,
+ timeout_val: float | None,
+) -> tuple[Exception, Exception | None]:
+ """
+ Build retry error based on exceptions encountered during operation
+
+ Args:
+ - exc_list: list of exceptions encountered during operation
+      - reason: the reason the operation failed, e.g. timeout or non-retryable error
+ - timeout_val: the operation timeout value in seconds, for constructing
+ the error message
+ Returns:
+ - tuple of the exception to raise, and a cause exception if applicable
+ """
+ if reason == RetryFailureReason.TIMEOUT:
+ timeout_val_str = f"of {timeout_val:0.1f}s " if timeout_val is not None else ""
+ # if failed due to timeout, raise deadline exceeded as primary exception
+ source_exc: Exception = core_exceptions.DeadlineExceeded(
+            f"operation_timeout {timeout_val_str}exceeded"
+ )
+ elif exc_list:
+ # otherwise, raise non-retryable error as primary exception
+ source_exc = exc_list.pop()
+ else:
+ source_exc = RuntimeError("failed with unspecified exception")
+ # use the retry exception group as the cause of the exception
+ cause_exc: Exception | None = RetryExceptionGroup(exc_list) if exc_list else None
+ source_exc.__cause__ = cause_exc
+ return source_exc, cause_exc
+
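
A sketch (not part of the diff) of how the factory above is used to collapse a retry history into a single raised exception:

    from google.api_core import exceptions as core_exceptions
    from google.api_core.retry import RetryFailureReason

    attempt_errors = [
        core_exceptions.DeadlineExceeded("attempt 1"),
        core_exceptions.ServiceUnavailable("attempt 2"),
    ]
    source, cause = _retry_exception_factory(
        list(attempt_errors), RetryFailureReason.TIMEOUT, timeout_val=60.0
    )
    # source is a DeadlineExceeded whose message references the 60.0s operation timeout;
    # cause is a RetryExceptionGroup wrapping both attempt errors, and is also
    # attached as source.__cause__
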
+
+def _get_timeouts(
+ operation: float | TABLE_DEFAULT,
+ attempt: float | None | TABLE_DEFAULT,
+ table: "TableAsync",
+) -> tuple[float, float]:
+ """
+ Convert passed in timeout values to floats, using table defaults if necessary.
+
+ attempt will use operation value if None, or if larger than operation.
+
+ Will call _validate_timeouts on the outputs, and raise ValueError if the
+ resulting timeouts are invalid.
+
+ Args:
+ - operation: The timeout value to use for the entire operation, in seconds.
+ - attempt: The timeout value to use for each attempt, in seconds.
+ - table: The table to use for default values.
+ Returns:
+ - A tuple of (operation_timeout, attempt_timeout)
+ """
+ # load table defaults if necessary
+ if operation == TABLE_DEFAULT.DEFAULT:
+ final_operation = table.default_operation_timeout
+ elif operation == TABLE_DEFAULT.READ_ROWS:
+ final_operation = table.default_read_rows_operation_timeout
+ elif operation == TABLE_DEFAULT.MUTATE_ROWS:
+ final_operation = table.default_mutate_rows_operation_timeout
+ else:
+ final_operation = operation
+ if attempt == TABLE_DEFAULT.DEFAULT:
+ attempt = table.default_attempt_timeout
+ elif attempt == TABLE_DEFAULT.READ_ROWS:
+ attempt = table.default_read_rows_attempt_timeout
+ elif attempt == TABLE_DEFAULT.MUTATE_ROWS:
+ attempt = table.default_mutate_rows_attempt_timeout
+
+ if attempt is None:
+ # no timeout specified, use operation timeout for both
+ final_attempt = final_operation
+ else:
+ # cap attempt timeout at operation timeout
+ final_attempt = min(attempt, final_operation) if final_operation else attempt
+
+ _validate_timeouts(final_operation, final_attempt, allow_none=False)
+ return final_operation, final_attempt
+
+
+def _validate_timeouts(
+ operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False
+):
+ """
+ Helper function that will verify that timeout values are valid, and raise
+ an exception if they are not.
+
+ Args:
+ - operation_timeout: The timeout value to use for the entire operation, in seconds.
+ - attempt_timeout: The timeout value to use for each attempt, in seconds.
+ - allow_none: If True, attempt_timeout can be None. If False, None values will raise an exception.
+ Raises:
+ - ValueError if operation_timeout or attempt_timeout are invalid.
+ """
+ if operation_timeout is None:
+ raise ValueError("operation_timeout cannot be None")
+ if operation_timeout <= 0:
+ raise ValueError("operation_timeout must be greater than 0")
+ if not allow_none and attempt_timeout is None:
+ raise ValueError("attempt_timeout must not be None")
+ elif attempt_timeout is not None:
+ if attempt_timeout <= 0:
+ raise ValueError("attempt_timeout must be greater than 0")
+
+
+def _get_retryable_errors(
+ call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT,
+ table: "TableAsync",
+) -> list[type[Exception]]:
+    """
+    Convert passed-in retryable error codes into a list of exception types,
+    using table defaults if necessary.
+
+    Args:
+      - call_codes: retryable error codes, as grpc status codes, ints, or
+          exception types; or a TABLE_DEFAULT value to use the table's defaults
+      - table: the table to pull default values from
+    Returns:
+      - a list of exception types to retry on
+    """
+    # load table defaults if necessary
+ if call_codes == TABLE_DEFAULT.DEFAULT:
+ call_codes = table.default_retryable_errors
+ elif call_codes == TABLE_DEFAULT.READ_ROWS:
+ call_codes = table.default_read_rows_retryable_errors
+ elif call_codes == TABLE_DEFAULT.MUTATE_ROWS:
+ call_codes = table.default_mutate_rows_retryable_errors
+
+ return [
+ e if isinstance(e, type) else type(core_exceptions.from_grpc_status(e, ""))
+ for e in call_codes
+ ]
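
To make the TABLE_DEFAULT plumbing concrete, a sketch (not part of the diff) of how the helpers above resolve defaults; the SimpleNamespace stands in for a TableAsync and is not a real library object:

    from types import SimpleNamespace

    from google.api_core import exceptions as core_exceptions

    fake_table = SimpleNamespace(
        default_operation_timeout=60.0,
        default_attempt_timeout=20.0,
        default_retryable_errors=(
            core_exceptions.DeadlineExceeded,
            core_exceptions.ServiceUnavailable,
        ),
    )

    # sentinel values resolve to the table defaults
    op, attempt = _get_timeouts(TABLE_DEFAULT.DEFAULT, TABLE_DEFAULT.DEFAULT, fake_table)
    assert (op, attempt) == (60.0, 20.0)

    # explicit attempt timeouts are capped at the operation timeout
    op, attempt = _get_timeouts(10.0, 30.0, fake_table)
    assert (op, attempt) == (10.0, 10.0)

    # exception classes pass through; grpc status codes would be converted to types
    errors = _get_retryable_errors(TABLE_DEFAULT.DEFAULT, fake_table)
    assert core_exceptions.DeadlineExceeded in errors
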
diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py
new file mode 100644
index 000000000..3c73ec4e9
--- /dev/null
+++ b/google/cloud/bigtable/data/exceptions.py
@@ -0,0 +1,307 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+import sys
+
+from typing import Any, TYPE_CHECKING
+
+from google.api_core import exceptions as core_exceptions
+from google.cloud.bigtable.data.row import Row
+
+is_311_plus = sys.version_info >= (3, 11)
+
+if TYPE_CHECKING:
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+
+class InvalidChunk(core_exceptions.GoogleAPICallError):
+    """Exception raised for invalid chunk data from the backend."""
+
+
+class _RowSetComplete(Exception):
+ """
+ Internal exception for _ReadRowsOperation
+    Raised in revise_request_rowset when there are no rows left to process at the start of a retry attempt
+ """
+
+ pass
+
+
+class _MutateRowsIncomplete(RuntimeError):
+ """
+ Exception raised when a mutate_rows call has unfinished work.
+ """
+
+ pass
+
+
+class _BigtableExceptionGroup(ExceptionGroup if is_311_plus else Exception): # type: ignore # noqa: F821
+ """
+ Represents one or more exceptions that occur during a bulk Bigtable operation
+
+    In Python 3.11+, this is an unmodified exception group. In Python < 3.11, it is
+    a custom exception with some exception group functionality backported, but it
+    does not implement the full API
+ """
+
+ def __init__(self, message, excs):
+ if is_311_plus:
+ super().__init__(message, excs)
+ else:
+ if len(excs) == 0:
+ raise ValueError("exceptions must be a non-empty sequence")
+ self.exceptions = tuple(excs)
+ # simulate an exception group in Python < 3.11 by adding exception info
+ # to the message
+ first_line = "--+---------------- 1 ----------------"
+ last_line = "+------------------------------------"
+ message_parts = [message + "\n" + first_line]
+ # print error info for each exception in the group
+ for idx, e in enumerate(excs[:15]):
+ # apply index header
+ if idx != 0:
+ message_parts.append(
+ f"+---------------- {str(idx+1).rjust(2)} ----------------"
+ )
+ cause = e.__cause__
+                # if this exception has a cause, print the cause first
+ # used to display root causes of FailedMutationEntryError and FailedQueryShardError
+ # format matches the error output of Python 3.11+
+ if cause is not None:
+ message_parts.extend(
+ f"| {type(cause).__name__}: {cause}".splitlines()
+ )
+ message_parts.append("| ")
+ message_parts.append(
+ "| The above exception was the direct cause of the following exception:"
+ )
+ message_parts.append("| ")
+ # attach error message for this sub-exception
+ # if the subexception is also a _BigtableExceptionGroup,
+ # error messages will be nested
+ message_parts.extend(f"| {type(e).__name__}: {e}".splitlines())
+ # truncate the message if there are more than 15 exceptions
+ if len(excs) > 15:
+ message_parts.append("+---------------- ... ---------------")
+ message_parts.append(f"| and {len(excs) - 15} more")
+ if last_line not in message_parts[-1]:
+ # in the case of nested _BigtableExceptionGroups, the last line
+ # does not need to be added, since one was added by the final sub-exception
+ message_parts.append(last_line)
+ super().__init__("\n ".join(message_parts))
+
+ def __new__(cls, message, excs):
+ if is_311_plus:
+ return super().__new__(cls, message, excs)
+ else:
+ return super().__new__(cls)
+
+ def __str__(self):
+ if is_311_plus:
+ # don't return built-in sub-exception message
+ return self.args[0]
+ return super().__str__()
+
+ def __repr__(self):
+ """
+ repr representation should strip out sub-exception details
+ """
+ if is_311_plus:
+ return super().__repr__()
+ message = self.args[0].split("\n")[0]
+ return f"{self.__class__.__name__}({message!r}, {self.exceptions!r})"
+
+
+class MutationsExceptionGroup(_BigtableExceptionGroup):
+ """
+ Represents one or more exceptions that occur during a bulk mutation operation
+
+ Exceptions will typically be of type FailedMutationEntryError, but other exceptions may
+ be included if they are raised during the mutation operation
+ """
+
+ @staticmethod
+ def _format_message(
+ excs: list[Exception], total_entries: int, exc_count: int | None = None
+ ) -> str:
+ """
+ Format a message for the exception group
+
+ Args:
+ - excs: the exceptions in the group
+ - total_entries: the total number of entries attempted, successful or not
+ - exc_count: the number of exceptions associated with the request
+ if None, this will be len(excs)
+ """
+ exc_count = exc_count if exc_count is not None else len(excs)
+ entry_str = "entry" if exc_count == 1 else "entries"
+ return f"{exc_count} failed {entry_str} from {total_entries} attempted."
+
+ def __init__(
+ self, excs: list[Exception], total_entries: int, message: str | None = None
+ ):
+ """
+ Args:
+ - excs: the exceptions in the group
+ - total_entries: the total number of entries attempted, successful or not
+ - message: the message for the exception group. If None, a default message
+ will be generated
+ """
+ message = (
+ message
+ if message is not None
+ else self._format_message(excs, total_entries)
+ )
+ super().__init__(message, excs)
+ self.total_entries_attempted = total_entries
+
+ def __new__(
+ cls, excs: list[Exception], total_entries: int, message: str | None = None
+ ):
+ """
+ Args:
+ - excs: the exceptions in the group
+ - total_entries: the total number of entries attempted, successful or not
+        - message: the message for the exception group. If None, a default message
+          will be generated
+        """
+ message = (
+ message if message is not None else cls._format_message(excs, total_entries)
+ )
+ instance = super().__new__(cls, message, excs)
+ instance.total_entries_attempted = total_entries
+ return instance
+
+ @classmethod
+ def from_truncated_lists(
+ cls,
+ first_list: list[Exception],
+ last_list: list[Exception],
+ total_excs: int,
+ entry_count: int,
+ ) -> MutationsExceptionGroup:
+ """
+ Create a MutationsExceptionGroup from two lists of exceptions, representing
+ a larger set that has been truncated. The MutationsExceptionGroup will
+ contain the union of the two lists as sub-exceptions, and the error message
+        will describe the number of exceptions that were truncated.
+
+ Args:
+ - first_list: the set of oldest exceptions to add to the ExceptionGroup
+ - last_list: the set of newest exceptions to add to the ExceptionGroup
+ - total_excs: the total number of exceptions associated with the request
+ Should be len(first_list) + len(last_list) + number of dropped exceptions
+ in the middle
+ - entry_count: the total number of entries attempted, successful or not
+ """
+ first_count, last_count = len(first_list), len(last_list)
+ if first_count + last_count >= total_excs:
+ # no exceptions were dropped
+ return cls(first_list + last_list, entry_count)
+ excs = first_list + last_list
+ truncation_count = total_excs - (first_count + last_count)
+ base_message = cls._format_message(excs, entry_count, total_excs)
+ first_message = f"first {first_count}" if first_count else ""
+ last_message = f"last {last_count}" if last_count else ""
+ conjunction = " and " if first_message and last_message else ""
+ message = f"{base_message} ({first_message}{conjunction}{last_message} attached as sub-exceptions; {truncation_count} truncated)"
+ return cls(excs, entry_count, message)
+
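
A sketch (not part of the diff) of the truncated constructor above; ValueError and RuntimeError stand in for arbitrary per-entry failures:

    oldest = [ValueError(f"old {i}") for i in range(2)]
    newest = [RuntimeError(f"new {i}") for i in range(2)]

    # 10 exceptions occurred in total, but only 4 were retained
    group = MutationsExceptionGroup.from_truncated_lists(
        first_list=oldest, last_list=newest, total_excs=10, entry_count=1000
    )
    # str(group) starts with:
    #   "10 failed entries from 1000 attempted. (first 2 and last 2 attached as sub-exceptions; 6 truncated)"
    assert len(group.exceptions) == 4
    assert group.total_entries_attempted == 1000
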
+
+class FailedMutationEntryError(Exception):
+ """
+ Represents a single failed RowMutationEntry in a bulk_mutate_rows request.
+ A collection of FailedMutationEntryErrors will be raised in a MutationsExceptionGroup
+ """
+
+ def __init__(
+ self,
+ failed_idx: int | None,
+ failed_mutation_entry: "RowMutationEntry",
+ cause: Exception,
+ ):
+ idempotent_msg = (
+ "idempotent" if failed_mutation_entry.is_idempotent() else "non-idempotent"
+ )
+ index_msg = f" at index {failed_idx}" if failed_idx is not None else ""
+ message = f"Failed {idempotent_msg} mutation entry{index_msg}"
+ super().__init__(message)
+ self.__cause__ = cause
+ self.index = failed_idx
+ self.entry = failed_mutation_entry
+
+
+class RetryExceptionGroup(_BigtableExceptionGroup):
+ """Represents one or more exceptions that occur during a retryable operation"""
+
+ @staticmethod
+ def _format_message(excs: list[Exception]):
+ if len(excs) == 0:
+ return "No exceptions"
+ plural = "s" if len(excs) > 1 else ""
+ return f"{len(excs)} failed attempt{plural}"
+
+ def __init__(self, excs: list[Exception]):
+ super().__init__(self._format_message(excs), excs)
+
+ def __new__(cls, excs: list[Exception]):
+ return super().__new__(cls, cls._format_message(excs), excs)
+
+
+class ShardedReadRowsExceptionGroup(_BigtableExceptionGroup):
+ """
+ Represents one or more exceptions that occur during a sharded read rows operation
+ """
+
+ @staticmethod
+ def _format_message(excs: list[FailedQueryShardError], total_queries: int):
+ query_str = "query" if total_queries == 1 else "queries"
+ plural_str = "" if len(excs) == 1 else "s"
+ return f"{len(excs)} sub-exception{plural_str} (from {total_queries} {query_str} attempted)"
+
+ def __init__(
+ self,
+ excs: list[FailedQueryShardError],
+ succeeded: list[Row],
+ total_queries: int,
+ ):
+ super().__init__(self._format_message(excs, total_queries), excs)
+ self.successful_rows = succeeded
+
+ def __new__(
+ cls, excs: list[FailedQueryShardError], succeeded: list[Row], total_queries: int
+ ):
+ instance = super().__new__(cls, cls._format_message(excs, total_queries), excs)
+ instance.successful_rows = succeeded
+ return instance
+
+
+class FailedQueryShardError(Exception):
+ """
+ Represents an individual failed query in a sharded read rows operation
+ """
+
+ def __init__(
+ self,
+ failed_index: int,
+ failed_query: "ReadRowsQuery" | dict[str, Any],
+ cause: Exception,
+ ):
+ message = f"Failed query at index {failed_index}"
+ super().__init__(message)
+ self.__cause__ = cause
+ self.index = failed_index
+ self.query = failed_query
diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py
new file mode 100644
index 000000000..b5729d25e
--- /dev/null
+++ b/google/cloud/bigtable/data/mutations.py
@@ -0,0 +1,256 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+from typing import Any
+import time
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from sys import getsizeof
+
+import google.cloud.bigtable_v2.types.bigtable as types_pb
+import google.cloud.bigtable_v2.types.data as data_pb
+
+from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE
+
+
+# special value for SetCell mutation timestamps. If set, server will assign a timestamp
+_SERVER_SIDE_TIMESTAMP = -1
+
+# mutation entries with more than this many mutations will be rejected
+_MUTATE_ROWS_REQUEST_MUTATION_LIMIT = 100_000
+
+
+class Mutation(ABC):
+ """Model class for mutations"""
+
+ @abstractmethod
+ def _to_dict(self) -> dict[str, Any]:
+ raise NotImplementedError
+
+ def _to_pb(self) -> data_pb.Mutation:
+ """
+ Convert the mutation to protobuf
+ """
+ return data_pb.Mutation(**self._to_dict())
+
+ def is_idempotent(self) -> bool:
+ """
+ Check if the mutation is idempotent
+ If false, the mutation will not be retried
+ """
+ return True
+
+ def __str__(self) -> str:
+ return str(self._to_dict())
+
+ def size(self) -> int:
+ """
+ Get the size of the mutation in bytes
+ """
+ return getsizeof(self._to_dict())
+
+ @classmethod
+ def _from_dict(cls, input_dict: dict[str, Any]) -> Mutation:
+ instance: Mutation | None = None
+ try:
+ if "set_cell" in input_dict:
+ details = input_dict["set_cell"]
+ instance = SetCell(
+ details["family_name"],
+ details["column_qualifier"],
+ details["value"],
+ details["timestamp_micros"],
+ )
+ elif "delete_from_column" in input_dict:
+ details = input_dict["delete_from_column"]
+ time_range = details.get("time_range", {})
+ start = time_range.get("start_timestamp_micros", None)
+ end = time_range.get("end_timestamp_micros", None)
+ instance = DeleteRangeFromColumn(
+ details["family_name"], details["column_qualifier"], start, end
+ )
+ elif "delete_from_family" in input_dict:
+ details = input_dict["delete_from_family"]
+ instance = DeleteAllFromFamily(details["family_name"])
+ elif "delete_from_row" in input_dict:
+ instance = DeleteAllFromRow()
+ except KeyError as e:
+ raise ValueError("Invalid mutation dictionary") from e
+ if instance is None:
+ raise ValueError("No valid mutation found")
+ if not issubclass(instance.__class__, cls):
+ raise ValueError("Mutation type mismatch")
+ return instance
+
+
+class SetCell(Mutation):
+ def __init__(
+ self,
+ family: str,
+ qualifier: bytes | str,
+ new_value: bytes | str | int,
+ timestamp_micros: int | None = None,
+ ):
+ """
+ Mutation to set the value of a cell
+
+ Args:
+ - family: The name of the column family to which the new cell belongs.
+ - qualifier: The column qualifier of the new cell.
+ - new_value: The value of the new cell. str or int input will be converted to bytes
+ - timestamp_micros: The timestamp of the new cell. If None, the current timestamp will be used.
+            Timestamps will be sent with millisecond precision. Extra precision will be truncated.
+ If -1, the server will assign a timestamp. Note that SetCell mutations with server-side
+ timestamps are non-idempotent operations and will not be retried.
+ """
+ qualifier = qualifier.encode() if isinstance(qualifier, str) else qualifier
+ if not isinstance(qualifier, bytes):
+ raise TypeError("qualifier must be bytes or str")
+ if isinstance(new_value, str):
+ new_value = new_value.encode()
+ elif isinstance(new_value, int):
+ if abs(new_value) > _MAX_INCREMENT_VALUE:
+ raise ValueError(
+ "int values must be between -2**63 and 2**63 (64-bit signed int)"
+ )
+ new_value = new_value.to_bytes(8, "big", signed=True)
+ if not isinstance(new_value, bytes):
+ raise TypeError("new_value must be bytes, str, or int")
+ if timestamp_micros is None:
+            # use current timestamp, with millisecond precision
+ timestamp_micros = time.time_ns() // 1000
+ timestamp_micros = timestamp_micros - (timestamp_micros % 1000)
+ if timestamp_micros < _SERVER_SIDE_TIMESTAMP:
+ raise ValueError(
+                f"timestamp_micros must be non-negative (or {_SERVER_SIDE_TIMESTAMP} for server-side timestamp)"
+ )
+ self.family = family
+ self.qualifier = qualifier
+ self.new_value = new_value
+ self.timestamp_micros = timestamp_micros
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Convert the mutation to a dictionary representation"""
+ return {
+ "set_cell": {
+ "family_name": self.family,
+ "column_qualifier": self.qualifier,
+ "timestamp_micros": self.timestamp_micros,
+ "value": self.new_value,
+ }
+ }
+
+ def is_idempotent(self) -> bool:
+ """Check if the mutation is idempotent"""
+ return self.timestamp_micros != _SERVER_SIDE_TIMESTAMP
+
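
A short sketch (not part of the diff) of the value encoding and idempotency rules above:

    # str and int values are encoded to bytes; ints become 8-byte big-endian
    cell = SetCell("stats", "views", 42)
    assert cell.new_value == (42).to_bytes(8, "big", signed=True)
    assert cell.is_idempotent()  # client-assigned timestamp, safe to retry

    # -1 requests a server-assigned timestamp, making the mutation non-idempotent
    server_ts = SetCell("stats", "views", 42, timestamp_micros=-1)
    assert not server_ts.is_idempotent()
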
+
+@dataclass
+class DeleteRangeFromColumn(Mutation):
+ family: str
+ qualifier: bytes
+ # None represents 0
+ start_timestamp_micros: int | None = None
+ # None represents infinity
+ end_timestamp_micros: int | None = None
+
+ def __post_init__(self):
+ if (
+ self.start_timestamp_micros is not None
+ and self.end_timestamp_micros is not None
+ and self.start_timestamp_micros > self.end_timestamp_micros
+ ):
+ raise ValueError("start_timestamp_micros must be <= end_timestamp_micros")
+
+ def _to_dict(self) -> dict[str, Any]:
+ timestamp_range = {}
+ if self.start_timestamp_micros is not None:
+ timestamp_range["start_timestamp_micros"] = self.start_timestamp_micros
+ if self.end_timestamp_micros is not None:
+ timestamp_range["end_timestamp_micros"] = self.end_timestamp_micros
+ return {
+ "delete_from_column": {
+ "family_name": self.family,
+ "column_qualifier": self.qualifier,
+ "time_range": timestamp_range,
+ }
+ }
+
+
+@dataclass
+class DeleteAllFromFamily(Mutation):
+ family_to_delete: str
+
+ def _to_dict(self) -> dict[str, Any]:
+ return {
+ "delete_from_family": {
+ "family_name": self.family_to_delete,
+ }
+ }
+
+
+@dataclass
+class DeleteAllFromRow(Mutation):
+ def _to_dict(self) -> dict[str, Any]:
+ return {
+ "delete_from_row": {},
+ }
+
+
+class RowMutationEntry:
+ def __init__(self, row_key: bytes | str, mutations: Mutation | list[Mutation]):
+ if isinstance(row_key, str):
+ row_key = row_key.encode("utf-8")
+ if isinstance(mutations, Mutation):
+ mutations = [mutations]
+ if len(mutations) == 0:
+ raise ValueError("mutations must not be empty")
+ elif len(mutations) > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT:
+ raise ValueError(
+ f"entries must have <= {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations"
+ )
+ self.row_key = row_key
+ self.mutations = tuple(mutations)
+
+ def _to_dict(self) -> dict[str, Any]:
+ return {
+ "row_key": self.row_key,
+ "mutations": [mutation._to_dict() for mutation in self.mutations],
+ }
+
+ def _to_pb(self) -> types_pb.MutateRowsRequest.Entry:
+ return types_pb.MutateRowsRequest.Entry(
+ row_key=self.row_key,
+ mutations=[mutation._to_pb() for mutation in self.mutations],
+ )
+
+ def is_idempotent(self) -> bool:
+ """Check if the mutation is idempotent"""
+ return all(mutation.is_idempotent() for mutation in self.mutations)
+
+ def size(self) -> int:
+ """
+ Get the size of the mutation in bytes
+ """
+ return getsizeof(self._to_dict())
+
+ @classmethod
+ def _from_dict(cls, input_dict: dict[str, Any]) -> RowMutationEntry:
+ return RowMutationEntry(
+ row_key=input_dict["row_key"],
+ mutations=[
+ Mutation._from_dict(mutation) for mutation in input_dict["mutations"]
+ ],
+ )
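
A sketch (not part of the diff) tying the mutation models together into a single entry, as consumed by bulk_mutate_rows and the mutations batcher:

    entry = RowMutationEntry(
        "user#123",
        [
            SetCell("profile", "name", b"Ada"),
            DeleteRangeFromColumn("profile", b"stale_field"),
        ],
    )
    assert entry.row_key == b"user#123"  # str keys are encoded to bytes
    assert entry.is_idempotent()         # True only if every mutation is idempotent
    proto_entry = entry._to_pb()         # MutateRowsRequest.Entry protobuf
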
diff --git a/google/cloud/bigtable/data/read_modify_write_rules.py b/google/cloud/bigtable/data/read_modify_write_rules.py
new file mode 100644
index 000000000..f43dbe79f
--- /dev/null
+++ b/google/cloud/bigtable/data/read_modify_write_rules.py
@@ -0,0 +1,77 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+import abc
+
+import google.cloud.bigtable_v2.types.data as data_pb
+
+# value must fit in 64-bit signed integer
+_MAX_INCREMENT_VALUE = (1 << 63) - 1
+
+
+class ReadModifyWriteRule(abc.ABC):
+ def __init__(self, family: str, qualifier: bytes | str):
+ qualifier = (
+ qualifier if isinstance(qualifier, bytes) else qualifier.encode("utf-8")
+ )
+ self.family = family
+ self.qualifier = qualifier
+
+ @abc.abstractmethod
+ def _to_dict(self) -> dict[str, str | bytes | int]:
+ raise NotImplementedError
+
+ def _to_pb(self) -> data_pb.ReadModifyWriteRule:
+ return data_pb.ReadModifyWriteRule(**self._to_dict())
+
+
+class IncrementRule(ReadModifyWriteRule):
+ def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1):
+ if not isinstance(increment_amount, int):
+ raise TypeError("increment_amount must be an integer")
+ if abs(increment_amount) > _MAX_INCREMENT_VALUE:
+ raise ValueError(
+ "increment_amount must be between -2**63 and 2**63 (64-bit signed int)"
+ )
+ super().__init__(family, qualifier)
+ self.increment_amount = increment_amount
+
+ def _to_dict(self) -> dict[str, str | bytes | int]:
+ return {
+ "family_name": self.family,
+ "column_qualifier": self.qualifier,
+ "increment_amount": self.increment_amount,
+ }
+
+
+class AppendValueRule(ReadModifyWriteRule):
+ def __init__(self, family: str, qualifier: bytes | str, append_value: bytes | str):
+ append_value = (
+ append_value.encode("utf-8")
+ if isinstance(append_value, str)
+ else append_value
+ )
+ if not isinstance(append_value, bytes):
+ raise TypeError("append_value must be bytes or str")
+ super().__init__(family, qualifier)
+ self.append_value = append_value
+
+ def _to_dict(self) -> dict[str, str | bytes | int]:
+ return {
+ "family_name": self.family,
+ "column_qualifier": self.qualifier,
+ "append_value": self.append_value,
+ }
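
A sketch (not part of the diff) of the two rule types above; both convert to ReadModifyWriteRule protobufs, typically passed to the table's read_modify_write_row call (not shown in this file):

    # increment a 64-bit big-endian counter cell
    increment = IncrementRule("stats", "views", increment_amount=1)

    # append bytes to an existing value (str input is encoded as utf-8)
    append = AppendValueRule("log", "events", "|login")

    rules_pb = [rule._to_pb() for rule in (increment, append)]
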
diff --git a/google/cloud/bigtable/data/read_rows_query.py b/google/cloud/bigtable/data/read_rows_query.py
new file mode 100644
index 000000000..362f54c3e
--- /dev/null
+++ b/google/cloud/bigtable/data/read_rows_query.py
@@ -0,0 +1,476 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+from typing import TYPE_CHECKING, Any
+from bisect import bisect_left
+from bisect import bisect_right
+from collections import defaultdict
+from google.cloud.bigtable.data.row_filters import RowFilter
+
+from google.cloud.bigtable_v2.types import RowRange as RowRangePB
+from google.cloud.bigtable_v2.types import RowSet as RowSetPB
+from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB
+
+if TYPE_CHECKING:
+ from google.cloud.bigtable.data import RowKeySamples
+ from google.cloud.bigtable.data import ShardedQuery
+
+
+class RowRange:
+ """
+ Represents a range of keys in a ReadRowsQuery
+ """
+
+ __slots__ = ("_pb",)
+
+ def __init__(
+ self,
+ start_key: str | bytes | None = None,
+ end_key: str | bytes | None = None,
+ start_is_inclusive: bool | None = None,
+ end_is_inclusive: bool | None = None,
+ ):
+ """
+ Args:
+ - start_key: The start key of the range. If empty, the range is unbounded on the left.
+ - end_key: The end key of the range. If empty, the range is unbounded on the right.
+ - start_is_inclusive: Whether the start key is inclusive. If None, the start key is
+ inclusive.
+ - end_is_inclusive: Whether the end key is inclusive. If None, the end key is not inclusive.
+ Raises:
+        - ValueError: if start_key is greater than end_key, or if start_key
+          or end_key is not a string or bytes.
+ """
+ # convert empty key inputs to None for consistency
+ start_key = None if not start_key else start_key
+ end_key = None if not end_key else end_key
+        # apply default inclusiveness if not specified
+ if start_is_inclusive is None:
+ start_is_inclusive = True
+
+ if end_is_inclusive is None:
+ end_is_inclusive = False
+ # ensure that start_key and end_key are bytes
+ if isinstance(start_key, str):
+ start_key = start_key.encode()
+ elif start_key is not None and not isinstance(start_key, bytes):
+ raise ValueError("start_key must be a string or bytes")
+ if isinstance(end_key, str):
+ end_key = end_key.encode()
+ elif end_key is not None and not isinstance(end_key, bytes):
+ raise ValueError("end_key must be a string or bytes")
+ # ensure that start_key is less than or equal to end_key
+ if start_key is not None and end_key is not None and start_key > end_key:
+ raise ValueError("start_key must be less than or equal to end_key")
+
+ init_dict = {}
+ if start_key is not None:
+ if start_is_inclusive:
+ init_dict["start_key_closed"] = start_key
+ else:
+ init_dict["start_key_open"] = start_key
+ if end_key is not None:
+ if end_is_inclusive:
+ init_dict["end_key_closed"] = end_key
+ else:
+ init_dict["end_key_open"] = end_key
+ self._pb = RowRangePB(**init_dict)
+
+ @property
+ def start_key(self) -> bytes | None:
+ """
+ Returns the start key of the range. If None, the range is unbounded on the left.
+ """
+ return self._pb.start_key_closed or self._pb.start_key_open or None
+
+ @property
+ def end_key(self) -> bytes | None:
+ """
+ Returns the end key of the range. If None, the range is unbounded on the right.
+ """
+ return self._pb.end_key_closed or self._pb.end_key_open or None
+
+ @property
+ def start_is_inclusive(self) -> bool:
+ """
+ Returns whether the range is inclusive of the start key.
+ Returns True if the range is unbounded on the left.
+ """
+ return not bool(self._pb.start_key_open)
+
+ @property
+ def end_is_inclusive(self) -> bool:
+ """
+ Returns whether the range is inclusive of the end key.
+ Returns True if the range is unbounded on the right.
+ """
+ return not bool(self._pb.end_key_open)
+
+ def _to_pb(self) -> RowRangePB:
+ """Converts this object to a protobuf"""
+ return self._pb
+
+ @classmethod
+ def _from_pb(cls, data: RowRangePB) -> RowRange:
+ """Creates a RowRange from a protobuf"""
+ instance = cls()
+ instance._pb = data
+ return instance
+
+ @classmethod
+ def _from_dict(cls, data: dict[str, bytes | str]) -> RowRange:
+        """Creates a RowRange from a dict representation"""
+ formatted_data = {
+ k: v.encode() if isinstance(v, str) else v for k, v in data.items()
+ }
+ instance = cls()
+ instance._pb = RowRangePB(**formatted_data)
+ return instance
+
+ def __bool__(self) -> bool:
+ """
+ Empty RowRanges (representing a full table scan) are falsy, because
+ they can be substituted with None. Non-empty RowRanges are truthy.
+ """
+ return bool(
+ self._pb.start_key_closed
+ or self._pb.start_key_open
+ or self._pb.end_key_closed
+ or self._pb.end_key_open
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, RowRange):
+ return NotImplemented
+ return self._pb == other._pb
+
+ def __str__(self) -> str:
+ """
+ Represent range as a string, e.g. "[b'a', b'z)"
+ Unbounded start or end keys are represented as "-inf" or "+inf"
+ """
+ left = "[" if self.start_is_inclusive else "("
+ right = "]" if self.end_is_inclusive else ")"
+ start = repr(self.start_key) if self.start_key is not None else "-inf"
+ end = repr(self.end_key) if self.end_key is not None else "+inf"
+ return f"{left}{start}, {end}{right}"
+
+ def __repr__(self) -> str:
+ args_list = []
+ args_list.append(f"start_key={self.start_key!r}")
+ args_list.append(f"end_key={self.end_key!r}")
+ if self.start_is_inclusive is False:
+ # only show start_is_inclusive if it is different from the default
+ args_list.append(f"start_is_inclusive={self.start_is_inclusive}")
+ if self.end_is_inclusive is True and self.end_key is not None:
+ # only show end_is_inclusive if it is different from the default
+ args_list.append(f"end_is_inclusive={self.end_is_inclusive}")
+ return f"RowRange({', '.join(args_list)})"
+
+
+class ReadRowsQuery:
+ """
+ Class to encapsulate details of a read row request
+ """
+
+    __slots__ = ("_limit", "_filter", "_row_set")
+
+ def __init__(
+ self,
+ row_keys: list[str | bytes] | str | bytes | None = None,
+ row_ranges: list[RowRange] | RowRange | None = None,
+ limit: int | None = None,
+ row_filter: RowFilter | None = None,
+ ):
+ """
+ Create a new ReadRowsQuery
+
+ Args:
+ - row_keys: row keys to include in the query
+ a query can contain multiple keys, but ranges should be preferred
+ - row_ranges: ranges of rows to include in the query
+ - limit: the maximum number of rows to return. None or 0 means no limit
+ default: None (no limit)
+ - row_filter: a RowFilter to apply to the query
+ """
+ if row_keys is None:
+ row_keys = []
+ if row_ranges is None:
+ row_ranges = []
+ if not isinstance(row_ranges, list):
+ row_ranges = [row_ranges]
+ if not isinstance(row_keys, list):
+ row_keys = [row_keys]
+ row_keys = [key.encode() if isinstance(key, str) else key for key in row_keys]
+ self._row_set = RowSetPB(
+ row_keys=row_keys, row_ranges=[r._pb for r in row_ranges]
+ )
+ self.limit = limit or None
+ self.filter = row_filter
+
+ @property
+ def row_keys(self) -> list[bytes]:
+ return list(self._row_set.row_keys)
+
+ @property
+ def row_ranges(self) -> list[RowRange]:
+ return [RowRange._from_pb(r) for r in self._row_set.row_ranges]
+
+ @property
+ def limit(self) -> int | None:
+ return self._limit or None
+
+ @limit.setter
+ def limit(self, new_limit: int | None):
+ """
+ Set the maximum number of rows to return by this query.
+
+ None or 0 means no limit
+
+ Args:
+ - new_limit: the new limit to apply to this query
+ Raises:
+ - ValueError if new_limit is < 0
+ """
+ if new_limit is not None and new_limit < 0:
+ raise ValueError("limit must be >= 0")
+ self._limit = new_limit
+
+ @property
+ def filter(self) -> RowFilter | None:
+ return self._filter
+
+ @filter.setter
+ def filter(self, row_filter: RowFilter | None):
+ """
+ Set a RowFilter to apply to this query
+
+ Args:
+ - row_filter: a RowFilter to apply to this query
+ """
+ self._filter = row_filter
+
+ def add_key(self, row_key: str | bytes):
+ """
+ Add a row key to this query
+
+ A query can contain multiple keys, but ranges should be preferred
+
+ Args:
+ - row_key: a key to add to this query
+ Raises:
+ - ValueError if an input is not a string or bytes
+ """
+ if isinstance(row_key, str):
+ row_key = row_key.encode()
+ elif not isinstance(row_key, bytes):
+ raise ValueError("row_key must be string or bytes")
+ if row_key not in self._row_set.row_keys:
+ self._row_set.row_keys.append(row_key)
+
+ def add_range(
+ self,
+ row_range: RowRange,
+ ):
+ """
+ Add a range of row keys to this query.
+
+ Args:
+ - row_range: a range of row keys to add to this query
+ """
+ if row_range not in self.row_ranges:
+ self._row_set.row_ranges.append(row_range._pb)
+
+ def shard(self, shard_keys: RowKeySamples) -> ShardedQuery:
+ """
+ Split this query into multiple queries that can be evenly distributed
+ across nodes and run in parallel
+
+ Returns:
+ - a ShardedQuery that can be used in sharded_read_rows calls
+ Raises:
+ - AttributeError if the query contains a limit
+ """
+ if self.limit is not None:
+ raise AttributeError("Cannot shard query with a limit")
+ if len(self.row_keys) == 0 and len(self.row_ranges) == 0:
+ # empty query represents full scan
+ # ensure that we have at least one key or range
+ full_scan_query = ReadRowsQuery(
+ row_ranges=RowRange(), row_filter=self.filter
+ )
+ return full_scan_query.shard(shard_keys)
+
+ sharded_queries: dict[int, ReadRowsQuery] = defaultdict(
+ lambda: ReadRowsQuery(row_filter=self.filter)
+ )
+        # the split_points divide our key space into segments
+        # each split_point defines the last key that belongs to a segment
+ # our goal is to break up the query into subqueries that each operate in a single segment
+ split_points = [sample[0] for sample in shard_keys if sample[0]]
+
+ # handle row_keys
+ # use binary search to find the segment that each key belongs to
+ for this_key in list(self.row_keys):
+ # bisect_left: in case of exact match, pick left side (keys are inclusive ends)
+ segment_index = bisect_left(split_points, this_key)
+ sharded_queries[segment_index].add_key(this_key)
+
+ # handle row_ranges
+ for this_range in self.row_ranges:
+ # defer to _shard_range helper
+ for segment_index, added_range in self._shard_range(
+ this_range, split_points
+ ):
+ sharded_queries[segment_index].add_range(added_range)
+ # return list of queries ordered by segment index
+ # pull populated segments out of sharded_queries dict
+ keys = sorted(list(sharded_queries.keys()))
+ # return list of queries
+ return [sharded_queries[k] for k in keys]
+
+ @staticmethod
+ def _shard_range(
+ orig_range: RowRange, split_points: list[bytes]
+ ) -> list[tuple[int, RowRange]]:
+ """
+ Helper function for sharding row_range into subranges that fit into
+ segments of the key-space, determined by split_points
+
+ Args:
+ - orig_range: a row range to split
+ - split_points: a list of row keys that define the boundaries of segments.
+ each point represents the inclusive end of a segment
+ Returns:
+ - a list of tuples, containing a segment index and a new sub-range.
+ """
+ # 1. find the index of the segment the start key belongs to
+ if orig_range.start_key is None:
+ # if range is open on the left, include first segment
+ start_segment = 0
+ else:
+ # use binary search to find the segment the start key belongs to
+ # bisect method determines how we break ties when the start key matches a split point
+ # if inclusive, bisect_left to the left segment, otherwise bisect_right
+ bisect = bisect_left if orig_range.start_is_inclusive else bisect_right
+ start_segment = bisect(split_points, orig_range.start_key)
+
+ # 2. find the index of the segment the end key belongs to
+ if orig_range.end_key is None:
+ # if range is open on the right, include final segment
+ end_segment = len(split_points)
+ else:
+ # use binary search to find the segment the end key belongs to.
+ end_segment = bisect_left(
+ split_points, orig_range.end_key, lo=start_segment
+ )
+        # note: end_segment always uses bisect_left, because split points represent inclusive ends
+        # whether or not the end_key coincides with a split point, the result is the same segment
+ # 3. create new range definitions for each segment this_range spans
+ if start_segment == end_segment:
+ # this_range is contained in a single segment.
+ # Add this_range to that segment's query only
+ return [(start_segment, orig_range)]
+ else:
+ results: list[tuple[int, RowRange]] = []
+ # this_range spans multiple segments. Create a new range for each segment's query
+ # 3a. add new range for first segment this_range spans
+ # first range spans from start_key to the split_point representing the last key in the segment
+ last_key_in_first_segment = split_points[start_segment]
+ start_range = RowRange(
+ start_key=orig_range.start_key,
+ start_is_inclusive=orig_range.start_is_inclusive,
+ end_key=last_key_in_first_segment,
+ end_is_inclusive=True,
+ )
+ results.append((start_segment, start_range))
+ # 3b. add new range for last segment this_range spans
+            # we start the final range at the end key of the previous segment, with start_is_inclusive=False
+ previous_segment = end_segment - 1
+ last_key_before_segment = split_points[previous_segment]
+ end_range = RowRange(
+ start_key=last_key_before_segment,
+ start_is_inclusive=False,
+ end_key=orig_range.end_key,
+ end_is_inclusive=orig_range.end_is_inclusive,
+ )
+ results.append((end_segment, end_range))
+ # 3c. add new spanning range to all segments other than the first and last
+ for this_segment in range(start_segment + 1, end_segment):
+ prev_segment = this_segment - 1
+ prev_end_key = split_points[prev_segment]
+ this_end_key = split_points[prev_segment + 1]
+ new_range = RowRange(
+ start_key=prev_end_key,
+ start_is_inclusive=False,
+ end_key=this_end_key,
+ end_is_inclusive=True,
+ )
+ results.append((this_segment, new_range))
+ return results
+
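
To illustrate the segment arithmetic above, a sketch (not part of the diff) with two split points, giving three segments: (-inf, b"m"], (b"m", b"t"], and (b"t", +inf):

    split_points = [b"m", b"t"]

    # a range spanning all three segments is broken into per-segment sub-ranges:
    # [b'a', b'm'] for segment 0, (b't', b'z') for segment 2, (b'm', b't'] for segment 1
    shards = ReadRowsQuery._shard_range(RowRange(b"a", b"z"), split_points)
    assert [idx for idx, _ in shards] == [0, 2, 1]

    # a range contained in a single segment is passed through unchanged
    assert ReadRowsQuery._shard_range(RowRange(b"n", b"p"), split_points) == [
        (1, RowRange(b"n", b"p"))
    ]
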
+ def _to_pb(self, table) -> ReadRowsRequestPB:
+ """
+        Convert this query into a ReadRowsRequest protobuf
+ """
+ return ReadRowsRequestPB(
+ table_name=table.table_name,
+ app_profile_id=table.app_profile_id,
+ filter=self.filter._to_pb() if self.filter else None,
+ rows_limit=self.limit or 0,
+ rows=self._row_set,
+ )
+
+ def __eq__(self, other):
+ """
+        ReadRowsQuery objects are equal if they have the same row keys, row ranges,
+ filter and limit, or if they both represent a full scan with the
+ same filter and limit
+ """
+ if not isinstance(other, ReadRowsQuery):
+ return False
+ # empty queries are equal
+ if len(self.row_keys) == 0 and len(other.row_keys) == 0:
+ this_range_empty = len(self.row_ranges) == 0 or all(
+ [bool(r) is False for r in self.row_ranges]
+ )
+ other_range_empty = len(other.row_ranges) == 0 or all(
+ [bool(r) is False for r in other.row_ranges]
+ )
+ if this_range_empty and other_range_empty:
+ return self.filter == other.filter and self.limit == other.limit
+ # otherwise, sets should have same sizes
+ if len(self.row_keys) != len(other.row_keys):
+ return False
+ if len(self.row_ranges) != len(other.row_ranges):
+ return False
+ ranges_match = all([row in other.row_ranges for row in self.row_ranges])
+ return (
+ self.row_keys == other.row_keys
+ and ranges_match
+ and self.filter == other.filter
+ and self.limit == other.limit
+ )
+
+ def __repr__(self):
+ return f"ReadRowsQuery(row_keys={list(self.row_keys)}, row_ranges={list(self.row_ranges)}, row_filter={self.filter}, limit={self.limit})"
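
A sketch (not part of the diff) of building and sharding queries with the classes above; the shard_keys value mimics the (key, offset) pairs of RowKeySamples:

    query = ReadRowsQuery(
        row_keys=[b"user#1", "user#2"],  # str keys are encoded to bytes
        row_ranges=RowRange(start_key=b"a", end_key=b"f"),
        limit=100,
    )
    query.add_key(b"user#3")
    assert query.limit == 100

    # sharding requires no limit; an empty query is treated as a full scan and
    # is split into one sub-query per segment
    scan = ReadRowsQuery()
    sub_queries = scan.shard([(b"m", 1024), (b"t", 2048)])
    assert len(sub_queries) == 3
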
diff --git a/google/cloud/bigtable/data/row.py b/google/cloud/bigtable/data/row.py
new file mode 100644
index 000000000..ecf9cea66
--- /dev/null
+++ b/google/cloud/bigtable/data/row.py
@@ -0,0 +1,450 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from collections import OrderedDict
+from typing import Generator, overload, Any
+from functools import total_ordering
+
+from google.cloud.bigtable_v2.types import Row as RowPB
+
+# Type aliases used internally for readability.
+_family_type = str
+_qualifier_type = bytes
+
+
+class Row:
+ """
+ Model class for row data returned from server
+
+ Does not represent all data contained in the row, only data returned by a
+ query.
+ Expected to be read-only to users, and written by backend
+
+ Can be indexed:
+ cells = row["family", "qualifier"]
+ """
+
+ __slots__ = ("row_key", "cells", "_index_data")
+
+ def __init__(
+ self,
+ key: bytes,
+ cells: list[Cell],
+ ):
+ """
+ Initializes a Row object
+
+ Row objects are not intended to be created by users.
+ They are returned by the Bigtable backend.
+ """
+ self.row_key = key
+ self.cells: list[Cell] = cells
+ # index is lazily created when needed
+ self._index_data: OrderedDict[
+ _family_type, OrderedDict[_qualifier_type, list[Cell]]
+ ] | None = None
+
+ @property
+ def _index(
+ self,
+ ) -> OrderedDict[_family_type, OrderedDict[_qualifier_type, list[Cell]]]:
+ """
+ Returns an index of cells associated with each family and qualifier.
+
+ The index is lazily created when needed
+ """
+ if self._index_data is None:
+ self._index_data = OrderedDict()
+ for cell in self.cells:
+ self._index_data.setdefault(cell.family, OrderedDict()).setdefault(
+ cell.qualifier, []
+ ).append(cell)
+ return self._index_data
+
+ @classmethod
+ def _from_pb(cls, row_pb: RowPB) -> Row:
+ """
+ Creates a row from a protobuf representation
+
+ Row objects are not intended to be created by users.
+ They are returned by the Bigtable backend.
+ """
+ row_key: bytes = row_pb.key
+ cell_list: list[Cell] = []
+ for family in row_pb.families:
+ for column in family.columns:
+ for cell in column.cells:
+ new_cell = Cell(
+ value=cell.value,
+ row_key=row_key,
+ family=family.name,
+ qualifier=column.qualifier,
+ timestamp_micros=cell.timestamp_micros,
+ labels=list(cell.labels) if cell.labels else None,
+ )
+ cell_list.append(new_cell)
+ return cls(row_key, cells=cell_list)
+
+ def get_cells(
+ self, family: str | None = None, qualifier: str | bytes | None = None
+ ) -> list[Cell]:
+ """
+ Returns cells sorted in Bigtable native order:
+ - Family lexicographically ascending
+ - Qualifier ascending
+ - Timestamp in reverse chronological order
+
+ If family or qualifier not passed, will include all
+
+ Can also be accessed through indexing:
+ cells = row["family", "qualifier"]
+ cells = row["family"]
+ """
+ if family is None:
+ if qualifier is not None:
+ # get_cells(None, "qualifier") is not allowed
+ raise ValueError("Qualifier passed without family")
+ else:
+ # return all cells on get_cells()
+ return self.cells
+ if qualifier is None:
+ # return all cells in family on get_cells(family)
+ return list(self._get_all_from_family(family))
+ if isinstance(qualifier, str):
+ qualifier = qualifier.encode("utf-8")
+ # return cells in family and qualifier on get_cells(family, qualifier)
+ if family not in self._index:
+ raise ValueError(f"Family '{family}' not found in row '{self.row_key!r}'")
+ if qualifier not in self._index[family]:
+ raise ValueError(
+ f"Qualifier '{qualifier!r}' not found in family '{family}' in row '{self.row_key!r}'"
+ )
+ return self._index[family][qualifier]
+
+ def _get_all_from_family(self, family: str) -> Generator[Cell, None, None]:
+ """
+ Returns all cells in the row for the family_id
+ """
+ if family not in self._index:
+ raise ValueError(f"Family '{family}' not found in row '{self.row_key!r}'")
+ for qualifier in self._index[family]:
+ yield from self._index[family][qualifier]
+
+ def __str__(self) -> str:
+ """
+ Human-readable string representation
+
+ {
+ (family='fam', qualifier=b'col'): [b'value', (+1 more),],
+ (family='fam', qualifier=b'col2'): [b'other'],
+ }
+ """
+ output = ["{"]
+ for family, qualifier in self._get_column_components():
+ cell_list = self[family, qualifier]
+ line = [f" (family={family!r}, qualifier={qualifier!r}): "]
+ if len(cell_list) == 0:
+ line.append("[],")
+ elif len(cell_list) == 1:
+ line.append(f"[{cell_list[0]}],")
+ else:
+ line.append(f"[{cell_list[0]}, (+{len(cell_list)-1} more)],")
+ output.append("".join(line))
+ output.append("}")
+ return "\n".join(output)
+
+ def __repr__(self):
+ cell_str_buffer = ["{"]
+ for family, qualifier in self._get_column_components():
+ cell_list = self[family, qualifier]
+ repr_list = [cell._to_dict() for cell in cell_list]
+ cell_str_buffer.append(f" ('{family}', {qualifier!r}): {repr_list},")
+ cell_str_buffer.append("}")
+ cell_str = "\n".join(cell_str_buffer)
+ output = f"Row(key={self.row_key!r}, cells={cell_str})"
+ return output
+
+ def _to_dict(self) -> dict[str, Any]:
+ """
+ Returns a dictionary representation of the cell in the Bigtable Row
+ proto format
+
+ https://cloud.google.com/bigtable/docs/reference/data/rpc/google.bigtable.v2#row
+ """
+ family_list = []
+ for family_name, qualifier_dict in self._index.items():
+ qualifier_list = []
+ for qualifier_name, cell_list in qualifier_dict.items():
+ cell_dicts = [cell._to_dict() for cell in cell_list]
+ qualifier_list.append(
+ {"qualifier": qualifier_name, "cells": cell_dicts}
+ )
+ family_list.append({"name": family_name, "columns": qualifier_list})
+ return {"key": self.row_key, "families": family_list}
+
+ # Sequence and Mapping methods
+ def __iter__(self):
+ """
+ Allow iterating over all cells in the row
+ """
+ return iter(self.cells)
+
+ def __contains__(self, item):
+ """
+ Implements `in` operator
+
+ Works for both cells in the internal list, and `family` or
+ `(family, qualifier)` pairs associated with the cells
+ """
+ if isinstance(item, _family_type):
+ return item in self._index
+ elif (
+ isinstance(item, tuple)
+ and isinstance(item[0], _family_type)
+ and isinstance(item[1], (bytes, str))
+ ):
+ q = item[1] if isinstance(item[1], bytes) else item[1].encode("utf-8")
+ return item[0] in self._index and q in self._index[item[0]]
+ # check if Cell is in Row
+ return item in self.cells
+
+ @overload
+ def __getitem__(
+ self,
+ index: str | tuple[str, bytes | str],
+ ) -> list[Cell]:
+ # overload signature for type checking
+ pass
+
+ @overload
+ def __getitem__(self, index: int) -> Cell:
+ # overload signature for type checking
+ pass
+
+ @overload
+ def __getitem__(self, index: slice) -> list[Cell]:
+ # overload signature for type checking
+ pass
+
+ def __getitem__(self, index):
+ """
+ Implements [] indexing
+
+ Supports indexing by family, (family, qualifier) pair,
+ numerical index, and index slicing
+ """
+ if isinstance(index, _family_type):
+ return self.get_cells(family=index)
+ elif (
+ isinstance(index, tuple)
+ and isinstance(index[0], _family_type)
+ and isinstance(index[1], (bytes, str))
+ ):
+ return self.get_cells(family=index[0], qualifier=index[1])
+ elif isinstance(index, int) or isinstance(index, slice):
+ # index is int or slice
+ return self.cells[index]
+ else:
+ raise TypeError(
+ "Index must be family_id, (family_id, qualifier), int, or slice"
+ )
+
+ def __len__(self):
+ """
+ Implements `len()` operator
+ """
+ return len(self.cells)
+
+ def _get_column_components(self) -> list[tuple[str, bytes]]:
+ """
+ Returns a list of (family, qualifier) pairs associated with the cells
+
+ Pairs can be used for indexing
+ """
+ return [(f, q) for f in self._index for q in self._index[f]]
+
+ def __eq__(self, other):
+ """
+ Implements `==` operator
+ """
+ # for performance reasons, check row metadata
+ # before checking individual cells
+ if not isinstance(other, Row):
+ return False
+ if self.row_key != other.row_key:
+ return False
+ if len(self.cells) != len(other.cells):
+ return False
+ components = self._get_column_components()
+ other_components = other._get_column_components()
+ if len(components) != len(other_components):
+ return False
+ if components != other_components:
+ return False
+ for family, qualifier in components:
+ if len(self[family, qualifier]) != len(other[family, qualifier]):
+ return False
+ # compare individual cell lists
+ if self.cells != other.cells:
+ return False
+ return True
+
+ def __ne__(self, other) -> bool:
+ """
+ Implements `!=` operator
+ """
+ return not self == other
+
+
+@total_ordering
+class Cell:
+ """
+ Model class for cell data
+
+ Does not represent all data contained in the cell, only data returned by a
+ query.
+ Expected to be read-only to users, and written by backend
+ """
+
+ __slots__ = (
+ "value",
+ "row_key",
+ "family",
+ "qualifier",
+ "timestamp_micros",
+ "labels",
+ )
+
+ def __init__(
+ self,
+ value: bytes,
+ row_key: bytes,
+ family: str,
+ qualifier: bytes | str,
+ timestamp_micros: int,
+ labels: list[str] | None = None,
+ ):
+ """
+ Cell constructor
+
+ Cell objects are not intended to be constructed by users.
+ They are returned by the Bigtable backend.
+ """
+ self.value = value
+ self.row_key = row_key
+ self.family = family
+ if isinstance(qualifier, str):
+ qualifier = qualifier.encode()
+ self.qualifier = qualifier
+ self.timestamp_micros = timestamp_micros
+ self.labels = labels if labels is not None else []
+
+ def __int__(self) -> int:
+ """
+ Allows casting cell to int
+ Interprets value as a 64-bit big-endian signed integer, as expected by
+ ReadModifyWrite increment rule
+ """
+ return int.from_bytes(self.value, byteorder="big", signed=True)
+
+ def _to_dict(self) -> dict[str, Any]:
+ """
+ Returns a dictionary representation of the cell in the Bigtable Cell
+ proto format
+
+ https://cloud.google.com/bigtable/docs/reference/data/rpc/google.bigtable.v2#cell
+ """
+ cell_dict: dict[str, Any] = {
+ "value": self.value,
+ }
+ cell_dict["timestamp_micros"] = self.timestamp_micros
+ if self.labels:
+ cell_dict["labels"] = self.labels
+ return cell_dict
+
+ def __str__(self) -> str:
+ """
+ Allows casting cell to str
+        Prints the encoded byte string, the same as printing the value directly.
+ """
+ return str(self.value)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the cell
+ """
+ return f"Cell(value={self.value!r}, row_key={self.row_key!r}, family='{self.family}', qualifier={self.qualifier!r}, timestamp_micros={self.timestamp_micros}, labels={self.labels})"
+
+ """For Bigtable native ordering"""
+
+ def __lt__(self, other) -> bool:
+ """
+ Implements `<` operator
+ """
+ if not isinstance(other, Cell):
+ return NotImplemented
+ this_ordering = (
+ self.family,
+ self.qualifier,
+ -self.timestamp_micros,
+ self.value,
+ self.labels,
+ )
+ other_ordering = (
+ other.family,
+ other.qualifier,
+ -other.timestamp_micros,
+ other.value,
+ other.labels,
+ )
+ return this_ordering < other_ordering
+
+ def __eq__(self, other) -> bool:
+ """
+ Implements `==` operator
+ """
+ if not isinstance(other, Cell):
+ return NotImplemented
+ return (
+ self.row_key == other.row_key
+ and self.family == other.family
+ and self.qualifier == other.qualifier
+ and self.value == other.value
+ and self.timestamp_micros == other.timestamp_micros
+ and len(self.labels) == len(other.labels)
+ and all([label in other.labels for label in self.labels])
+ )
+
+ def __ne__(self, other) -> bool:
+ """
+ Implements `!=` operator
+ """
+ return not self == other
+
+ def __hash__(self):
+ """
+ Implements `hash()` function to fingerprint cell
+ """
+ return hash(
+ (
+ self.row_key,
+ self.family,
+ self.qualifier,
+ self.value,
+ self.timestamp_micros,
+ tuple(self.labels),
+ )
+ )
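+
+
+# Illustrative sketch (not part of the library surface): demonstrates reading a
+# Cell value back as an integer and the native ordering used by ``sorted()``.
+# The helper name below is hypothetical and exists only for documentation.
+def _example_cell_usage() -> None:
+    newer = Cell(
+        value=(42).to_bytes(8, "big", signed=True),
+        row_key=b"row",
+        family="cf",
+        qualifier=b"col",
+        timestamp_micros=2_000_000,
+    )
+    older = Cell(
+        value=b"old",
+        row_key=b"row",
+        family="cf",
+        qualifier=b"col",
+        timestamp_micros=1_000_000,
+    )
+    # 64-bit big-endian signed decoding, as used by ReadModifyWrite increments.
+    assert int(newer) == 42
+    # Within a column, cells with more recent timestamps sort first.
+    assert sorted([older, newer]) == [newer, older]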
diff --git a/google/cloud/bigtable/data/row_filters.py b/google/cloud/bigtable/data/row_filters.py
new file mode 100644
index 000000000..9f09133d5
--- /dev/null
+++ b/google/cloud/bigtable/data/row_filters.py
@@ -0,0 +1,968 @@
+# Copyright 2016 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Filters for Google Cloud Bigtable Row classes."""
+from __future__ import annotations
+
+import struct
+
+from typing import Any, Sequence, TYPE_CHECKING, overload
+from abc import ABC, abstractmethod
+
+from google.cloud._helpers import _microseconds_from_datetime # type: ignore
+from google.cloud._helpers import _to_bytes # type: ignore
+from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+if TYPE_CHECKING:
+ # import dependencies when type checking
+ from datetime import datetime
+
+_PACK_I64 = struct.Struct(">q").pack
+
+
+class RowFilter(ABC):
+ """Basic filter to apply to cells in a row.
+
+ These values can be combined via :class:`RowFilterChain`,
+ :class:`RowFilterUnion` and :class:`ConditionalRowFilter`.
+
+ .. note::
+
+ This class is a do-nothing base class for all row filters.
+ """
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ Returns: The converted current object.
+ """
+ return data_v2_pb2.RowFilter(**self._to_dict())
+
+ @abstractmethod
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ pass
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}()"
+
+
+class _BoolFilter(RowFilter, ABC):
+ """Row filter that uses a boolean flag.
+
+ :type flag: bool
+ :param flag: An indicator if a setting is turned on or off.
+ """
+
+ def __init__(self, flag: bool):
+ self.flag = flag
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.flag == self.flag
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(flag={self.flag})"
+
+
+class SinkFilter(_BoolFilter):
+ """Advanced row filter to skip parent filters.
+
+ :type flag: bool
+ :param flag: ADVANCED USE ONLY. Hook for introspection into the row filter.
+ Outputs all cells directly to the output of the read rather
+ than to any parent filter. Cannot be used within the
+ ``predicate_filter``, ``true_filter``, or ``false_filter``
+ of a :class:`ConditionalRowFilter`.
+ """
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"sink": self.flag}
+
+
+class PassAllFilter(_BoolFilter):
+ """Row filter equivalent to not filtering at all.
+
+ :type flag: bool
+ :param flag: Matches all cells, regardless of input. Functionally
+ equivalent to leaving ``filter`` unset, but included for
+ completeness.
+ """
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"pass_all_filter": self.flag}
+
+
+class BlockAllFilter(_BoolFilter):
+ """Row filter that doesn't match any cells.
+
+ :type flag: bool
+ :param flag: Does not match any cells, regardless of input. Useful for
+ temporarily disabling just part of a filter.
+ """
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"block_all_filter": self.flag}
+
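+
+# Illustrative sketch (not part of the library surface): every filter produces a
+# plain dict via ``_to_dict()``, and ``_to_pb()`` simply unpacks that dict into
+# the ``RowFilter`` protobuf. The helper name is hypothetical.
+def _example_bool_filter_round_trip() -> None:
+    block = BlockAllFilter(True)
+    assert block._to_dict() == {"block_all_filter": True}
+    # Equivalent to constructing the proto from the same keyword arguments.
+    assert block._to_pb() == data_v2_pb2.RowFilter(block_all_filter=True)
+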
+
+class _RegexFilter(RowFilter, ABC):
+ """Row filter that uses a regular expression.
+
+    The ``regex`` must be a valid RE2 pattern. See Google's
+ `RE2 reference`_ for the accepted syntax.
+
+ .. _RE2 reference: https://github.com/google/re2/wiki/Syntax
+
+ :type regex: bytes or str
+ :param regex:
+ A regular expression (RE2) for some row filter. String values
+ will be encoded as ASCII.
+ """
+
+ def __init__(self, regex: str | bytes):
+ self.regex: bytes = _to_bytes(regex)
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.regex == self.regex
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(regex={self.regex!r})"
+
+
+class RowKeyRegexFilter(_RegexFilter):
+ """Row filter for a row key regular expression.
+
+    The ``regex`` must be a valid RE2 pattern. See Google's
+ `RE2 reference`_ for the accepted syntax.
+
+ .. _RE2 reference: https://github.com/google/re2/wiki/Syntax
+
+ .. note::
+
+        Special care must be taken with the expression used. Since
+ each of these properties can contain arbitrary bytes, the ``\\C``
+ escape sequence must be used if a true wildcard is desired. The ``.``
+ character will not match the new line character ``\\n``, which may be
+ present in a binary value.
+
+ :type regex: bytes
+ :param regex: A regular expression (RE2) to match cells from rows with row
+ keys that satisfy this regex. For a
+ ``CheckAndMutateRowRequest``, this filter is unnecessary
+ since the row key is already specified.
+ """
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"row_key_regex_filter": self.regex}
+
+
+class RowSampleFilter(RowFilter):
+ """Matches all cells from a row with probability p.
+
+ :type sample: float
+ :param sample: The probability of matching a cell (must be in the
+                   open interval ``(0, 1)``; the endpoints are excluded).
+ """
+
+ def __init__(self, sample: float):
+ self.sample: float = sample
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.sample == self.sample
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"row_sample_filter": self.sample}
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(sample={self.sample})"
+
+
+class FamilyNameRegexFilter(_RegexFilter):
+ """Row filter for a family name regular expression.
+
+    The ``regex`` must be a valid RE2 pattern. See Google's
+ `RE2 reference`_ for the accepted syntax.
+
+ .. _RE2 reference: https://github.com/google/re2/wiki/Syntax
+
+ :type regex: str
+ :param regex: A regular expression (RE2) to match cells from columns in a
+ given column family. For technical reasons, the regex must
+ not contain the ``':'`` character, even if it is not being
+ used as a literal.
+ """
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"family_name_regex_filter": self.regex}
+
+
+class ColumnQualifierRegexFilter(_RegexFilter):
+ """Row filter for a column qualifier regular expression.
+
+    The ``regex`` must be a valid RE2 pattern. See Google's
+ `RE2 reference`_ for the accepted syntax.
+
+ .. _RE2 reference: https://github.com/google/re2/wiki/Syntax
+
+ .. note::
+
+        Special care must be taken with the expression used. Since
+ each of these properties can contain arbitrary bytes, the ``\\C``
+ escape sequence must be used if a true wildcard is desired. The ``.``
+ character will not match the new line character ``\\n``, which may be
+ present in a binary value.
+
+ :type regex: bytes
+    :param regex: A regular expression (RE2) to match cells from columns that
+ match this regex (irrespective of column family).
+ """
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"column_qualifier_regex_filter": self.regex}
+
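+
+# Illustrative sketch (not part of the library surface): the regex filters all
+# coerce their pattern to bytes and differ only in which proto field they set.
+# The helper name is hypothetical.
+def _example_regex_filters() -> None:
+    assert RowKeyRegexFilter("user-.*")._to_dict() == {
+        "row_key_regex_filter": b"user-.*"
+    }
+    assert FamilyNameRegexFilter("cf1")._to_dict() == {
+        "family_name_regex_filter": b"cf1"
+    }
+    assert ColumnQualifierRegexFilter(b"col-\\d+")._to_dict() == {
+        "column_qualifier_regex_filter": b"col-\\d+"
+    }
+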
+
+class TimestampRange(object):
+ """Range of time with inclusive lower and exclusive upper bounds.
+
+ :type start: :class:`datetime.datetime`
+ :param start: (Optional) The (inclusive) lower bound of the timestamp
+ range. If omitted, defaults to Unix epoch.
+
+ :type end: :class:`datetime.datetime`
+ :param end: (Optional) The (exclusive) upper bound of the timestamp
+ range. If omitted, no upper bound is used.
+ """
+
+ def __init__(self, start: "datetime" | None = None, end: "datetime" | None = None):
+ self.start: "datetime" | None = start
+ self.end: "datetime" | None = end
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.start == self.start and other.end == self.end
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_pb(self) -> data_v2_pb2.TimestampRange:
+ """Converts the :class:`TimestampRange` to a protobuf.
+
+ Returns: The converted current object.
+ """
+ return data_v2_pb2.TimestampRange(**self._to_dict())
+
+ def _to_dict(self) -> dict[str, int]:
+ """Converts the timestamp range to a dict representation."""
+ timestamp_range_kwargs = {}
+ if self.start is not None:
+ start_time = _microseconds_from_datetime(self.start) // 1000 * 1000
+ timestamp_range_kwargs["start_timestamp_micros"] = start_time
+ if self.end is not None:
+ end_time = _microseconds_from_datetime(self.end)
+ if end_time % 1000 != 0:
+                # if not a whole millisecond value, round up
+ end_time = end_time // 1000 * 1000 + 1000
+ timestamp_range_kwargs["end_timestamp_micros"] = end_time
+ return timestamp_range_kwargs
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(start={self.start}, end={self.end})"
+
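+
+# Illustrative sketch (not part of the library surface): Bigtable timestamps have
+# millisecond granularity, so the start bound is truncated down and the end bound
+# is rounded up to a whole millisecond. Assumes ``_microseconds_from_datetime``
+# returns microseconds since the Unix epoch; the helper name is hypothetical.
+def _example_timestamp_range_rounding() -> None:
+    from datetime import datetime, timezone
+
+    # 123456 microseconds after the Unix epoch.
+    ts = datetime(1970, 1, 1, 0, 0, 0, 123456, tzinfo=timezone.utc)
+    assert TimestampRange(start=ts, end=ts)._to_dict() == {
+        "start_timestamp_micros": 123000,  # truncated down to a whole millisecond
+        "end_timestamp_micros": 124000,  # rounded up to a whole millisecond
+    }
+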
+
+class TimestampRangeFilter(RowFilter):
+ """Row filter that limits cells to a range of time.
+
+ :type range_: :class:`TimestampRange`
+ :param range_: Range of time that cells should match against.
+ """
+
+ def __init__(self, start: "datetime" | None = None, end: "datetime" | None = None):
+ self.range_: TimestampRange = TimestampRange(start, end)
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.range_ == self.range_
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ First converts the ``range_`` on the current object to a protobuf and
+ then uses it in the ``timestamp_range_filter`` field.
+
+ Returns: The converted current object.
+ """
+ return data_v2_pb2.RowFilter(timestamp_range_filter=self.range_._to_pb())
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"timestamp_range_filter": self.range_._to_dict()}
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(start={self.range_.start!r}, end={self.range_.end!r})"
+
+
+class ColumnRangeFilter(RowFilter):
+ """A row filter to restrict to a range of columns.
+
+ Both the start and end column can be included or excluded in the range.
+ By default, we include them both, but this can be changed with optional
+ flags.
+
+ :type family_id: str
+ :param family_id: The column family that contains the columns. Must
+ be of the form ``[_a-zA-Z0-9][-_.a-zA-Z0-9]*``.
+
+ :type start_qualifier: bytes
+ :param start_qualifier: The start of the range of columns. If no value is
+                            used, the backend applies no lower bound to the
+ values.
+
+ :type end_qualifier: bytes
+ :param end_qualifier: The end of the range of columns. If no value is used,
+ the backend applies no upper bound to the values.
+
+ :type inclusive_start: bool
+ :param inclusive_start: Boolean indicating if the start column should be
+ included in the range (or excluded). Defaults
+ to :data:`True` if ``start_qualifier`` is passed and
+ no ``inclusive_start`` was given.
+
+ :type inclusive_end: bool
+ :param inclusive_end: Boolean indicating if the end column should be
+ included in the range (or excluded). Defaults
+ to :data:`True` if ``end_qualifier`` is passed and
+ no ``inclusive_end`` was given.
+
+    :raises: :class:`ValueError` if ``inclusive_start``
+ is set but no ``start_qualifier`` is given or if ``inclusive_end``
+ is set but no ``end_qualifier`` is given
+ """
+
+ def __init__(
+ self,
+ family_id: str,
+ start_qualifier: bytes | None = None,
+ end_qualifier: bytes | None = None,
+ inclusive_start: bool | None = None,
+ inclusive_end: bool | None = None,
+ ):
+ if inclusive_start is None:
+ inclusive_start = True
+ elif start_qualifier is None:
+ raise ValueError(
+ "inclusive_start was specified but no start_qualifier was given."
+ )
+ if inclusive_end is None:
+ inclusive_end = True
+ elif end_qualifier is None:
+ raise ValueError(
+ "inclusive_end was specified but no end_qualifier was given."
+ )
+
+ self.family_id = family_id
+
+ self.start_qualifier = start_qualifier
+ self.inclusive_start = inclusive_start
+
+ self.end_qualifier = end_qualifier
+ self.inclusive_end = inclusive_end
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return (
+ other.family_id == self.family_id
+ and other.start_qualifier == self.start_qualifier
+ and other.end_qualifier == self.end_qualifier
+ and other.inclusive_start == self.inclusive_start
+ and other.inclusive_end == self.inclusive_end
+ )
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ First converts to a :class:`.data_v2_pb2.ColumnRange` and then uses it
+ in the ``column_range_filter`` field.
+
+ Returns: The converted current object.
+ """
+ column_range = data_v2_pb2.ColumnRange(**self._range_to_dict())
+ return data_v2_pb2.RowFilter(column_range_filter=column_range)
+
+ def _range_to_dict(self) -> dict[str, str | bytes]:
+ """Converts the column range range to a dict representation."""
+ column_range_kwargs: dict[str, str | bytes] = {}
+ column_range_kwargs["family_name"] = self.family_id
+ if self.start_qualifier is not None:
+ if self.inclusive_start:
+ key = "start_qualifier_closed"
+ else:
+ key = "start_qualifier_open"
+ column_range_kwargs[key] = _to_bytes(self.start_qualifier)
+ if self.end_qualifier is not None:
+ if self.inclusive_end:
+ key = "end_qualifier_closed"
+ else:
+ key = "end_qualifier_open"
+ column_range_kwargs[key] = _to_bytes(self.end_qualifier)
+ return column_range_kwargs
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"column_range_filter": self._range_to_dict()}
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(family_id='{self.family_id}', start_qualifier={self.start_qualifier!r}, end_qualifier={self.end_qualifier!r}, inclusive_start={self.inclusive_start}, inclusive_end={self.inclusive_end})"
+
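+
+# Illustrative sketch (not part of the library surface): the inclusive flags
+# select between the ``*_closed`` and ``*_open`` fields of the ColumnRange proto.
+# The helper name is hypothetical.
+def _example_column_range_filter() -> None:
+    col_filter = ColumnRangeFilter(
+        "cf1", start_qualifier=b"a", end_qualifier=b"d", inclusive_end=False
+    )
+    assert col_filter._to_dict() == {
+        "column_range_filter": {
+            "family_name": "cf1",
+            "start_qualifier_closed": b"a",  # inclusive by default
+            "end_qualifier_open": b"d",  # exclusive because inclusive_end=False
+        }
+    }
+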
+
+class ValueRegexFilter(_RegexFilter):
+ """Row filter for a value regular expression.
+
+    The ``regex`` must be a valid RE2 pattern. See Google's
+ `RE2 reference`_ for the accepted syntax.
+
+ .. _RE2 reference: https://github.com/google/re2/wiki/Syntax
+
+ .. note::
+
+        Special care must be taken with the expression used. Since
+ each of these properties can contain arbitrary bytes, the ``\\C``
+ escape sequence must be used if a true wildcard is desired. The ``.``
+ character will not match the new line character ``\\n``, which may be
+ present in a binary value.
+
+ :type regex: bytes or str
+ :param regex: A regular expression (RE2) to match cells with values that
+ match this regex. String values will be encoded as ASCII.
+ """
+
+ def _to_dict(self) -> dict[str, bytes]:
+ """Converts the row filter to a dict representation."""
+ return {"value_regex_filter": self.regex}
+
+
+class LiteralValueFilter(ValueRegexFilter):
+ """Row filter for an exact value.
+
+    :type value: bytes or str or int
+    :param value:
+        A literal string, integer, or the equivalent bytes.
+        Integer values will be packed into 8 bytes as a big-endian
+        signed integer.
+ """
+
+ def __init__(self, value: bytes | str | int):
+ if isinstance(value, int):
+ value = _PACK_I64(value)
+ elif isinstance(value, str):
+ value = value.encode("utf-8")
+ value = self._write_literal_regex(value)
+ super(LiteralValueFilter, self).__init__(value)
+
+ @staticmethod
+ def _write_literal_regex(input_bytes: bytes) -> bytes:
+ """
+ Escape re2 special characters from literal bytes.
+
+ Extracted from: re2 QuoteMeta:
+ https://github.com/google/re2/blob/70f66454c255080a54a8da806c52d1f618707f8a/re2/re2.cc#L456
+ """
+ result = bytearray()
+ for byte in input_bytes:
+            # If this is part of a UTF-8 or Latin-1 character, we need
+            # to copy this byte without escaping. Experimentally this is
+            # what works correctly with the regexp library.
+ utf8_latin1_check = (byte & 128) == 0
+ if (
+ (byte < ord("a") or byte > ord("z"))
+ and (byte < ord("A") or byte > ord("Z"))
+ and (byte < ord("0") or byte > ord("9"))
+ and byte != ord("_")
+ and utf8_latin1_check
+ ):
+ if byte == 0:
+ # Special handling for null chars.
+ # Note that this special handling is not strictly required for RE2,
+ # but this quoting is required for other regexp libraries such as
+ # PCRE.
+ # Can't use "\\0" since the next character might be a digit.
+ result.extend([ord("\\"), ord("x"), ord("0"), ord("0")])
+ continue
+ result.append(ord(b"\\"))
+ result.append(byte)
+ return bytes(result)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(value={self.regex!r})"
+
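+
+# Illustrative sketch (not part of the library surface): non-alphanumeric bytes
+# are escaped so the literal is matched verbatim by RE2; integer literals are
+# first packed into 8 big-endian signed bytes and then escaped the same way.
+# The helper name is hypothetical.
+def _example_literal_value_filter() -> None:
+    assert LiteralValueFilter("abc-123")._to_dict() == {
+        "value_regex_filter": b"abc\\-123"  # '-' is escaped; letters/digits are not
+    }
+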
+
+class ValueRangeFilter(RowFilter):
+ """A range of values to restrict to in a row filter.
+
+ Will only match cells that have values in this range.
+
+ Both the start and end value can be included or excluded in the range.
+ By default, we include them both, but this can be changed with optional
+ flags.
+
+ :type start_value: bytes
+ :param start_value: The start of the range of values. If no value is used,
+ the backend applies no lower bound to the values.
+
+ :type end_value: bytes
+ :param end_value: The end of the range of values. If no value is used,
+ the backend applies no upper bound to the values.
+
+ :type inclusive_start: bool
+ :param inclusive_start: Boolean indicating if the start value should be
+ included in the range (or excluded). Defaults
+ to :data:`True` if ``start_value`` is passed and
+ no ``inclusive_start`` was given.
+
+ :type inclusive_end: bool
+ :param inclusive_end: Boolean indicating if the end value should be
+ included in the range (or excluded). Defaults
+ to :data:`True` if ``end_value`` is passed and
+ no ``inclusive_end`` was given.
+
+    :raises: :class:`ValueError` if ``inclusive_start``
+ is set but no ``start_value`` is given or if ``inclusive_end``
+ is set but no ``end_value`` is given
+ """
+
+ def __init__(
+ self,
+ start_value: bytes | int | None = None,
+ end_value: bytes | int | None = None,
+ inclusive_start: bool | None = None,
+ inclusive_end: bool | None = None,
+ ):
+ if inclusive_start is None:
+ inclusive_start = True
+ elif start_value is None:
+ raise ValueError(
+ "inclusive_start was specified but no start_value was given."
+ )
+ if inclusive_end is None:
+ inclusive_end = True
+ elif end_value is None:
+ raise ValueError(
+                "inclusive_end was specified but no end_value was given."
+ )
+ if isinstance(start_value, int):
+ start_value = _PACK_I64(start_value)
+ self.start_value = start_value
+ self.inclusive_start = inclusive_start
+
+ if isinstance(end_value, int):
+ end_value = _PACK_I64(end_value)
+ self.end_value = end_value
+ self.inclusive_end = inclusive_end
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return (
+ other.start_value == self.start_value
+ and other.end_value == self.end_value
+ and other.inclusive_start == self.inclusive_start
+ and other.inclusive_end == self.inclusive_end
+ )
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ First converts to a :class:`.data_v2_pb2.ValueRange` and then uses
+ it to create a row filter protobuf.
+
+ Returns: The converted current object.
+ """
+ value_range = data_v2_pb2.ValueRange(**self._range_to_dict())
+ return data_v2_pb2.RowFilter(value_range_filter=value_range)
+
+ def _range_to_dict(self) -> dict[str, bytes]:
+ """Converts the value range range to a dict representation."""
+ value_range_kwargs = {}
+ if self.start_value is not None:
+ if self.inclusive_start:
+ key = "start_value_closed"
+ else:
+ key = "start_value_open"
+ value_range_kwargs[key] = _to_bytes(self.start_value)
+ if self.end_value is not None:
+ if self.inclusive_end:
+ key = "end_value_closed"
+ else:
+ key = "end_value_open"
+ value_range_kwargs[key] = _to_bytes(self.end_value)
+ return value_range_kwargs
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"value_range_filter": self._range_to_dict()}
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(start_value={self.start_value!r}, end_value={self.end_value!r}, inclusive_start={self.inclusive_start}, inclusive_end={self.inclusive_end})"
+
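+
+# Illustrative sketch (not part of the library surface): integer bounds are
+# packed into 8 big-endian signed bytes, and the inclusive flags pick the
+# ``*_closed`` / ``*_open`` proto fields. The helper name is hypothetical.
+def _example_value_range_filter() -> None:
+    value_filter = ValueRangeFilter(start_value=0, end_value=100, inclusive_end=False)
+    assert value_filter._to_dict() == {
+        "value_range_filter": {
+            "start_value_closed": _PACK_I64(0),
+            "end_value_open": _PACK_I64(100),
+        }
+    }
+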
+
+class _CellCountFilter(RowFilter, ABC):
+ """Row filter that uses an integer count of cells.
+
+ The cell count is used as an offset or a limit for the number
+ of results returned.
+
+ :type num_cells: int
+ :param num_cells: An integer count / offset / limit.
+ """
+
+ def __init__(self, num_cells: int):
+ self.num_cells = num_cells
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.num_cells == self.num_cells
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(num_cells={self.num_cells})"
+
+
+class CellsRowOffsetFilter(_CellCountFilter):
+ """Row filter to skip cells in a row.
+
+ :type num_cells: int
+ :param num_cells: Skips the first N cells of the row.
+ """
+
+ def _to_dict(self) -> dict[str, int]:
+ """Converts the row filter to a dict representation."""
+ return {"cells_per_row_offset_filter": self.num_cells}
+
+
+class CellsRowLimitFilter(_CellCountFilter):
+ """Row filter to limit cells in a row.
+
+ :type num_cells: int
+ :param num_cells: Matches only the first N cells of the row.
+ """
+
+ def _to_dict(self) -> dict[str, int]:
+ """Converts the row filter to a dict representation."""
+ return {"cells_per_row_limit_filter": self.num_cells}
+
+
+class CellsColumnLimitFilter(_CellCountFilter):
+ """Row filter to limit cells in a column.
+
+ :type num_cells: int
+ :param num_cells: Matches only the most recent N cells within each column.
+                      The filter is applied per (family name, column qualifier)
+                      pair, based on the timestamps of the cells.
+ """
+
+ def _to_dict(self) -> dict[str, int]:
+ """Converts the row filter to a dict representation."""
+ return {"cells_per_column_limit_filter": self.num_cells}
+
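+
+# Illustrative sketch (not part of the library surface): the three cell-count
+# filters emit different proto fields from the same integer argument.
+# The helper name is hypothetical.
+def _example_cell_count_filters() -> None:
+    assert CellsRowOffsetFilter(2)._to_dict() == {"cells_per_row_offset_filter": 2}
+    assert CellsRowLimitFilter(10)._to_dict() == {"cells_per_row_limit_filter": 10}
+    # Keep only the most recent version of every column:
+    assert CellsColumnLimitFilter(1)._to_dict() == {"cells_per_column_limit_filter": 1}
+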
+
+class StripValueTransformerFilter(_BoolFilter):
+ """Row filter that transforms cells into empty string (0 bytes).
+
+ :type flag: bool
+ :param flag: If :data:`True`, replaces each cell's value with the empty
+ string. As the name indicates, this is more useful as a
+ transformer than a generic query / filter.
+ """
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"strip_value_transformer": self.flag}
+
+
+class ApplyLabelFilter(RowFilter):
+ """Filter to apply labels to cells.
+
+ Intended to be used as an intermediate filter on a pre-existing filtered
+    result set. This way, if two sets are combined, the label can tell where
+    the cell(s) originated. This allows the client to determine which results
+ were produced from which part of the filter.
+
+ .. note::
+
+ Due to a technical limitation of the backend, it is not currently
+ possible to apply multiple labels to a cell.
+
+ :type label: str
+ :param label: Label to apply to cells in the output row. Values must be
+ at most 15 characters long, and match the pattern
+ ``[a-z0-9\\-]+``.
+ """
+
+ def __init__(self, label: str):
+ self.label = label
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.label == self.label
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_dict(self) -> dict[str, str]:
+ """Converts the row filter to a dict representation."""
+ return {"apply_label_transformer": self.label}
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(label={self.label})"
+
+
+class _FilterCombination(RowFilter, Sequence[RowFilter], ABC):
+ """Chain of row filters.
+
+ Sends rows through several filters in sequence. The filters are "chained"
+ together to process a row. After the first filter is applied, the second
+ is applied to the filtered output and so on for subsequent filters.
+
+ :type filters: list
+ :param filters: List of :class:`RowFilter`
+ """
+
+ def __init__(self, filters: list[RowFilter] | None = None):
+ if filters is None:
+ filters = []
+ self.filters: list[RowFilter] = filters
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.filters == self.filters
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __len__(self) -> int:
+ return len(self.filters)
+
+ @overload
+ def __getitem__(self, index: int) -> RowFilter:
+ # overload signature for type checking
+ pass
+
+ @overload
+ def __getitem__(self, index: slice) -> list[RowFilter]:
+ # overload signature for type checking
+ pass
+
+ def __getitem__(self, index):
+ return self.filters[index]
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(filters={self.filters})"
+
+ def __str__(self) -> str:
+ """
+ Returns a string representation of the filter chain.
+
+ Adds line breaks between each sub-filter for readability.
+ """
+ output = [f"{self.__class__.__name__}(["]
+ for filter_ in self.filters:
+ filter_lines = f"{filter_},".splitlines()
+ output.extend([f" {line}" for line in filter_lines])
+ output.append("])")
+ return "\n".join(output)
+
+
+class RowFilterChain(_FilterCombination):
+ """Chain of row filters.
+
+ Sends rows through several filters in sequence. The filters are "chained"
+ together to process a row. After the first filter is applied, the second
+ is applied to the filtered output and so on for subsequent filters.
+
+ :type filters: list
+ :param filters: List of :class:`RowFilter`
+ """
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ Returns: The converted current object.
+ """
+ chain = data_v2_pb2.RowFilter.Chain(
+ filters=[row_filter._to_pb() for row_filter in self.filters]
+ )
+ return data_v2_pb2.RowFilter(chain=chain)
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"chain": {"filters": [f._to_dict() for f in self.filters]}}
+
+
+class RowFilterUnion(_FilterCombination):
+ """Union of row filters.
+
+ Sends rows through several filters simultaneously, then
+ merges / interleaves all the filtered results together.
+
+ If multiple cells are produced with the same column and timestamp,
+ they will all appear in the output row in an unspecified mutual order.
+
+ :type filters: list
+ :param filters: List of :class:`RowFilter`
+ """
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ Returns: The converted current object.
+ """
+ interleave = data_v2_pb2.RowFilter.Interleave(
+ filters=[row_filter._to_pb() for row_filter in self.filters]
+ )
+ return data_v2_pb2.RowFilter(interleave=interleave)
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"interleave": {"filters": [f._to_dict() for f in self.filters]}}
+
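+
+# Illustrative sketch (not part of the library surface): a union of two chains,
+# each tagging its branch with a label so the origin of every returned cell can
+# be identified. The helper name is hypothetical.
+def _example_filter_combinations() -> None:
+    latest = RowFilterChain(
+        filters=[CellsColumnLimitFilter(1), ApplyLabelFilter("latest")]
+    )
+    stripped = RowFilterChain(
+        filters=[StripValueTransformerFilter(True), ApplyLabelFilter("stripped")]
+    )
+    union = RowFilterUnion(filters=[latest, stripped])
+    assert len(union) == 2 and union[0] is latest  # Sequence-style access
+    assert union._to_dict() == {
+        "interleave": {"filters": [latest._to_dict(), stripped._to_dict()]}
+    }
+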
+
+class ConditionalRowFilter(RowFilter):
+ """Conditional row filter which exhibits ternary behavior.
+
+ Executes one of two filters based on another filter. If the ``predicate_filter``
+ returns any cells in the row, then ``true_filter`` is executed. If not,
+ then ``false_filter`` is executed.
+
+ .. note::
+
+ The ``predicate_filter`` does not execute atomically with the true and false
+ filters, which may lead to inconsistent or unexpected results.
+
+ Additionally, executing a :class:`ConditionalRowFilter` has poor
+ performance on the server, especially when ``false_filter`` is set.
+
+ :type predicate_filter: :class:`RowFilter`
+ :param predicate_filter: The filter to condition on before executing the
+ true/false filters.
+
+ :type true_filter: :class:`RowFilter`
+ :param true_filter: (Optional) The filter to execute if there are any cells
+ matching ``predicate_filter``. If not provided, no results
+ will be returned in the true case.
+
+ :type false_filter: :class:`RowFilter`
+ :param false_filter: (Optional) The filter to execute if there are no cells
+ matching ``predicate_filter``. If not provided, no results
+ will be returned in the false case.
+ """
+
+ def __init__(
+ self,
+ predicate_filter: RowFilter,
+ true_filter: RowFilter | None = None,
+ false_filter: RowFilter | None = None,
+ ):
+ self.predicate_filter = predicate_filter
+ self.true_filter = true_filter
+ self.false_filter = false_filter
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return (
+ other.predicate_filter == self.predicate_filter
+ and other.true_filter == self.true_filter
+ and other.false_filter == self.false_filter
+ )
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ Returns: The converted current object.
+ """
+ condition_kwargs = {"predicate_filter": self.predicate_filter._to_pb()}
+ if self.true_filter is not None:
+ condition_kwargs["true_filter"] = self.true_filter._to_pb()
+ if self.false_filter is not None:
+ condition_kwargs["false_filter"] = self.false_filter._to_pb()
+ condition = data_v2_pb2.RowFilter.Condition(**condition_kwargs)
+ return data_v2_pb2.RowFilter(condition=condition)
+
+ def _condition_to_dict(self) -> dict[str, Any]:
+ """Converts the condition to a dict representation."""
+ condition_kwargs = {"predicate_filter": self.predicate_filter._to_dict()}
+ if self.true_filter is not None:
+ condition_kwargs["true_filter"] = self.true_filter._to_dict()
+ if self.false_filter is not None:
+ condition_kwargs["false_filter"] = self.false_filter._to_dict()
+ return condition_kwargs
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"condition": self._condition_to_dict()}
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(predicate_filter={self.predicate_filter!r}, true_filter={self.true_filter!r}, false_filter={self.false_filter!r})"
+
+ def __str__(self) -> str:
+ output = [f"{self.__class__.__name__}("]
+ for filter_type in ("predicate_filter", "true_filter", "false_filter"):
+ filter_ = getattr(self, filter_type)
+ if filter_ is None:
+ continue
+ # add the new filter set, adding indentations for readability
+ filter_lines = f"{filter_type}={filter_},".splitlines()
+ output.extend(f" {line}" for line in filter_lines)
+ output.append(")")
+ return "\n".join(output)
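+
+
+# Illustrative sketch (not part of the library surface): return rows with their
+# values stripped when they contain a cell in the "meta:deleted" column, and
+# return nothing otherwise. The helper name is hypothetical.
+def _example_conditional_filter() -> None:
+    predicate = RowFilterChain(
+        filters=[
+            FamilyNameRegexFilter("meta"),
+            ColumnQualifierRegexFilter(b"deleted"),
+        ]
+    )
+    conditional = ConditionalRowFilter(
+        predicate_filter=predicate,
+        true_filter=StripValueTransformerFilter(True),
+    )
+    assert conditional._to_dict() == {
+        "condition": {
+            "predicate_filter": predicate._to_dict(),
+            "true_filter": {"strip_value_transformer": True},
+        }
+    }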
diff --git a/google/cloud/bigtable/gapic_version.py b/google/cloud/bigtable/gapic_version.py
index 03d6d0200..f01e1d3a5 100644
--- a/google/cloud/bigtable/gapic_version.py
+++ b/google/cloud/bigtable/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "2.22.0" # {x-release-please-version}
+__version__ = "2.23.0" # {x-release-please-version}
diff --git a/google/cloud/bigtable/py.typed b/google/cloud/bigtable/py.typed
deleted file mode 100644
index 7bd4705d4..000000000
--- a/google/cloud/bigtable/py.typed
+++ /dev/null
@@ -1,2 +0,0 @@
-# Marker file for PEP 561.
-# The google-cloud-bigtable package uses inline types.
diff --git a/google/cloud/bigtable_admin/gapic_version.py b/google/cloud/bigtable_admin/gapic_version.py
index 03d6d0200..f01e1d3a5 100644
--- a/google/cloud/bigtable_admin/gapic_version.py
+++ b/google/cloud/bigtable_admin/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "2.22.0" # {x-release-please-version}
+__version__ = "2.23.0" # {x-release-please-version}
diff --git a/google/cloud/bigtable_admin_v2/gapic_version.py b/google/cloud/bigtable_admin_v2/gapic_version.py
index 03d6d0200..f01e1d3a5 100644
--- a/google/cloud/bigtable_admin_v2/gapic_version.py
+++ b/google/cloud/bigtable_admin_v2/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "2.22.0" # {x-release-please-version}
+__version__ = "2.23.0" # {x-release-please-version}
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/async_client.py b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/async_client.py
index e4c4639af..ab14ddaed 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/async_client.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/async_client.py
@@ -38,9 +38,9 @@
from google.oauth2 import service_account # type: ignore
try:
- OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore
+ OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore
from google.api_core import operation # type: ignore
from google.api_core import operation_async # type: ignore
@@ -67,8 +67,12 @@ class BigtableInstanceAdminAsyncClient:
_client: BigtableInstanceAdminClient
+ # Copy defaults from the synchronous client for use here.
+ # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = BigtableInstanceAdminClient.DEFAULT_ENDPOINT
DEFAULT_MTLS_ENDPOINT = BigtableInstanceAdminClient.DEFAULT_MTLS_ENDPOINT
+ _DEFAULT_ENDPOINT_TEMPLATE = BigtableInstanceAdminClient._DEFAULT_ENDPOINT_TEMPLATE
+ _DEFAULT_UNIVERSE = BigtableInstanceAdminClient._DEFAULT_UNIVERSE
app_profile_path = staticmethod(BigtableInstanceAdminClient.app_profile_path)
parse_app_profile_path = staticmethod(
@@ -193,6 +197,25 @@ def transport(self) -> BigtableInstanceAdminTransport:
"""
return self._client.transport
+ @property
+ def api_endpoint(self):
+ """Return the API endpoint used by the client instance.
+
+ Returns:
+ str: The API endpoint used by the client instance.
+ """
+ return self._client._api_endpoint
+
+ @property
+ def universe_domain(self) -> str:
+ """Return the universe domain used by the client instance.
+
+ Returns:
+            str: The universe domain used by the client instance.
+ """
+ return self._client._universe_domain
+
get_transport_class = functools.partial(
type(BigtableInstanceAdminClient).get_transport_class,
type(BigtableInstanceAdminClient),
@@ -206,7 +229,7 @@ def __init__(
client_options: Optional[ClientOptions] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
- """Instantiates the bigtable instance admin client.
+ """Instantiates the bigtable instance admin async client.
Args:
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -217,23 +240,38 @@ def __init__(
transport (Union[str, ~.BigtableInstanceAdminTransport]): The
transport to use. If set to None, a transport is chosen
automatically.
- client_options (ClientOptions): Custom options for the client. It
- won't take effect if a ``transport`` instance is provided.
- (1) The ``api_endpoint`` property can be used to override the
- default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
- environment variable can also be used to override the endpoint:
+ client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]):
+ Custom options for the client.
+
+ 1. The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client when ``transport`` is
+ not explicitly provided. Only if this property is not set and
+ ``transport`` was not explicitly provided, the endpoint is
+ determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment
+                variable, which can have one of the following values:
"always" (always use the default mTLS endpoint), "never" (always
- use the default regular endpoint) and "auto" (auto switch to the
- default mTLS endpoint if client certificate is present, this is
- the default value). However, the ``api_endpoint`` property takes
- precedence if provided.
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ use the default regular endpoint) and "auto" (auto-switch to the
+ default mTLS endpoint if client certificate is present; this is
+ the default value).
+
+ 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
is "true", then the ``client_cert_source`` property can be used
- to provide client certificate for mutual TLS transport. If
+ to provide a client certificate for mTLS transport. If
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
+ 3. The ``universe_domain`` property can be used to override the
+                default "googleapis.com" universe. Note that the ``api_endpoint``
+ property still takes precedence; and ``universe_domain`` is
+ currently not supported for mTLS.
+
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
@@ -360,6 +398,9 @@ async def create_instance(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -460,6 +501,9 @@ async def get_instance(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -549,6 +593,9 @@ async def list_instances(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -623,6 +670,9 @@ async def update_instance(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -731,6 +781,9 @@ async def partial_update_instance(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -812,6 +865,9 @@ async def delete_instance(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
await rpc(
request,
@@ -920,6 +976,9 @@ async def create_cluster(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1019,6 +1078,9 @@ async def get_cluster(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1110,6 +1172,9 @@ async def list_clusters(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1184,6 +1249,9 @@ async def update_cluster(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1299,6 +1367,9 @@ async def partial_update_cluster(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1380,6 +1451,9 @@ async def delete_cluster(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
await rpc(
request,
@@ -1479,6 +1553,9 @@ async def create_app_profile(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1569,6 +1646,9 @@ async def get_app_profile(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1664,6 +1744,9 @@ async def list_app_profiles(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1776,6 +1859,9 @@ async def update_app_profile(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1857,6 +1943,9 @@ async def delete_app_profile(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
await rpc(
request,
@@ -1973,6 +2062,9 @@ async def get_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2081,6 +2173,9 @@ async def set_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2180,6 +2275,9 @@ async def test_iam_permissions(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2273,6 +2371,9 @@ async def list_hot_tablets(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/client.py b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/client.py
index 52c61ea4f..4c2c2998e 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/client.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/client.py
@@ -28,6 +28,7 @@
Union,
cast,
)
+import warnings
from google.cloud.bigtable_admin_v2 import gapic_version as package_version
@@ -42,9 +43,9 @@
from google.oauth2 import service_account # type: ignore
try:
- OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.Retry, object] # type: ignore
+ OptionalRetry = Union[retries.Retry, object, None] # type: ignore
from google.api_core import operation # type: ignore
from google.api_core import operation_async # type: ignore
@@ -137,11 +138,15 @@ def _get_default_mtls_endpoint(api_endpoint):
return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+ # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = "bigtableadmin.googleapis.com"
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
+ _DEFAULT_ENDPOINT_TEMPLATE = "bigtableadmin.{UNIVERSE_DOMAIN}"
+ _DEFAULT_UNIVERSE = "googleapis.com"
+
@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
@@ -403,7 +408,7 @@ def parse_common_location_path(path: str) -> Dict[str, str]:
def get_mtls_endpoint_and_cert_source(
cls, client_options: Optional[client_options_lib.ClientOptions] = None
):
- """Return the API endpoint and client cert source for mutual TLS.
+ """Deprecated. Return the API endpoint and client cert source for mutual TLS.
The client cert source is determined in the following order:
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
@@ -433,6 +438,11 @@ def get_mtls_endpoint_and_cert_source(
Raises:
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
"""
+
+ warnings.warn(
+ "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.",
+ DeprecationWarning,
+ )
if client_options is None:
client_options = client_options_lib.ClientOptions()
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
@@ -466,6 +476,180 @@ def get_mtls_endpoint_and_cert_source(
return api_endpoint, client_cert_source
+ @staticmethod
+ def _read_environment_variables():
+ """Returns the environment variables used by the client.
+
+ Returns:
+ Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE,
+ GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables.
+
+ Raises:
+ ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not
+ any of ["true", "false"].
+ google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT
+ is not any of ["auto", "never", "always"].
+ """
+ use_client_cert = os.getenv(
+ "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"
+ ).lower()
+ use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower()
+ universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN")
+ if use_client_cert not in ("true", "false"):
+ raise ValueError(
+ "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+ if use_mtls_endpoint not in ("auto", "never", "always"):
+ raise MutualTLSChannelError(
+ "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+ return use_client_cert == "true", use_mtls_endpoint, universe_domain_env
+
+ @staticmethod
+ def _get_client_cert_source(provided_cert_source, use_cert_flag):
+ """Return the client cert source to be used by the client.
+
+ Args:
+ provided_cert_source (bytes): The client certificate source provided.
+ use_cert_flag (bool): A flag indicating whether to use the client certificate.
+
+ Returns:
+ bytes or None: The client cert source to be used by the client.
+ """
+ client_cert_source = None
+ if use_cert_flag:
+ if provided_cert_source:
+ client_cert_source = provided_cert_source
+ elif mtls.has_default_client_cert_source():
+ client_cert_source = mtls.default_client_cert_source()
+ return client_cert_source
+
+ @staticmethod
+ def _get_api_endpoint(
+ api_override, client_cert_source, universe_domain, use_mtls_endpoint
+ ):
+ """Return the API endpoint used by the client.
+
+ Args:
+ api_override (str): The API endpoint override. If specified, this is always
+ the return value of this function and the other arguments are not used.
+ client_cert_source (bytes): The client certificate source used by the client.
+ universe_domain (str): The universe domain used by the client.
+ use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters.
+ Possible values are "always", "auto", or "never".
+
+ Returns:
+ str: The API endpoint to be used by the client.
+ """
+ if api_override is not None:
+ api_endpoint = api_override
+ elif use_mtls_endpoint == "always" or (
+ use_mtls_endpoint == "auto" and client_cert_source
+ ):
+ _default_universe = BigtableInstanceAdminClient._DEFAULT_UNIVERSE
+ if universe_domain != _default_universe:
+ raise MutualTLSChannelError(
+ f"mTLS is not supported in any universe other than {_default_universe}."
+ )
+ api_endpoint = BigtableInstanceAdminClient.DEFAULT_MTLS_ENDPOINT
+ else:
+ api_endpoint = (
+ BigtableInstanceAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=universe_domain
+ )
+ )
+ return api_endpoint
+
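+    # Resolution order (illustrative summary of the logic above): an explicit
+    # ``api_endpoint`` override always wins; otherwise the mTLS endpoint is used
+    # when GOOGLE_API_USE_MTLS_ENDPOINT is "always", or "auto" with a client
+    # certificate (default universe only); otherwise the endpoint template is
+    # filled in with the universe domain, e.g. "bigtableadmin.googleapis.com"
+    # for the default "googleapis.com" universe.
+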
+ @staticmethod
+ def _get_universe_domain(
+ client_universe_domain: Optional[str], universe_domain_env: Optional[str]
+ ) -> str:
+ """Return the universe domain used by the client.
+
+ Args:
+ client_universe_domain (Optional[str]): The universe domain configured via the client options.
+ universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable.
+
+ Returns:
+ str: The universe domain to be used by the client.
+
+ Raises:
+ ValueError: If the universe domain is an empty string.
+ """
+ universe_domain = BigtableInstanceAdminClient._DEFAULT_UNIVERSE
+ if client_universe_domain is not None:
+ universe_domain = client_universe_domain
+ elif universe_domain_env is not None:
+ universe_domain = universe_domain_env
+ if len(universe_domain.strip()) == 0:
+ raise ValueError("Universe Domain cannot be an empty string.")
+ return universe_domain
+
+ @staticmethod
+ def _compare_universes(
+ client_universe: str, credentials: ga_credentials.Credentials
+ ) -> bool:
+ """Returns True iff the universe domains used by the client and credentials match.
+
+ Args:
+ client_universe (str): The universe domain configured via the client options.
+ credentials (ga_credentials.Credentials): The credentials being used in the client.
+
+ Returns:
+ bool: True iff client_universe matches the universe in credentials.
+
+ Raises:
+ ValueError: when client_universe does not match the universe in credentials.
+ """
+
+ default_universe = BigtableInstanceAdminClient._DEFAULT_UNIVERSE
+ credentials_universe = getattr(credentials, "universe_domain", default_universe)
+
+ if client_universe != credentials_universe:
+ raise ValueError(
+ "The configured universe domain "
+ f"({client_universe}) does not match the universe domain "
+ f"found in the credentials ({credentials_universe}). "
+ "If you haven't configured the universe domain explicitly, "
+ f"`{default_universe}` is the default."
+ )
+ return True
+
+ def _validate_universe_domain(self):
+ """Validates client's and credentials' universe domains are consistent.
+
+ Returns:
+ bool: True iff the configured universe domain is valid.
+
+ Raises:
+ ValueError: If the configured universe domain is not valid.
+ """
+ self._is_universe_domain_valid = (
+ self._is_universe_domain_valid
+ or BigtableInstanceAdminClient._compare_universes(
+ self.universe_domain, self.transport._credentials
+ )
+ )
+ return self._is_universe_domain_valid
+
+ @property
+ def api_endpoint(self):
+ """Return the API endpoint used by the client instance.
+
+ Returns:
+ str: The API endpoint used by the client instance.
+ """
+ return self._api_endpoint
+
+ @property
+ def universe_domain(self) -> str:
+ """Return the universe domain used by the client instance.
+
+ Returns:
+ str: The universe domain used by the client instance.
+ """
+ return self._universe_domain
+
def __init__(
self,
*,
@@ -485,22 +669,32 @@ def __init__(
transport (Union[str, BigtableInstanceAdminTransport]): The
transport to use. If set to None, a transport is chosen
automatically.
- client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the
- client. It won't take effect if a ``transport`` instance is provided.
- (1) The ``api_endpoint`` property can be used to override the
- default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
- environment variable can also be used to override the endpoint:
+ client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]):
+ Custom options for the client.
+
+ 1. The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client when ``transport`` is
+ not explicitly provided. Only if this property is not set and
+ ``transport`` was not explicitly provided, the endpoint is
+ determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment
+                variable, which can have one of the following values:
"always" (always use the default mTLS endpoint), "never" (always
- use the default regular endpoint) and "auto" (auto switch to the
- default mTLS endpoint if client certificate is present, this is
- the default value). However, the ``api_endpoint`` property takes
- precedence if provided.
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ use the default regular endpoint) and "auto" (auto-switch to the
+ default mTLS endpoint if client certificate is present; this is
+ the default value).
+
+ 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
is "true", then the ``client_cert_source`` property can be used
- to provide client certificate for mutual TLS transport. If
+ to provide a client certificate for mTLS transport. If
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
+
+ 3. The ``universe_domain`` property can be used to override the
+ default "googleapis.com" universe. Note that the ``api_endpoint``
+ property still takes precedence; and ``universe_domain`` is
+ currently not supported for mTLS.
+
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
@@ -511,17 +705,34 @@ def __init__(
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
"""
- if isinstance(client_options, dict):
- client_options = client_options_lib.from_dict(client_options)
- if client_options is None:
- client_options = client_options_lib.ClientOptions()
- client_options = cast(client_options_lib.ClientOptions, client_options)
+ self._client_options = client_options
+ if isinstance(self._client_options, dict):
+ self._client_options = client_options_lib.from_dict(self._client_options)
+ if self._client_options is None:
+ self._client_options = client_options_lib.ClientOptions()
+ self._client_options = cast(
+ client_options_lib.ClientOptions, self._client_options
+ )
+
+ universe_domain_opt = getattr(self._client_options, "universe_domain", None)
- api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(
- client_options
+ (
+ self._use_client_cert,
+ self._use_mtls_endpoint,
+ self._universe_domain_env,
+ ) = BigtableInstanceAdminClient._read_environment_variables()
+ self._client_cert_source = BigtableInstanceAdminClient._get_client_cert_source(
+ self._client_options.client_cert_source, self._use_client_cert
)
+ self._universe_domain = BigtableInstanceAdminClient._get_universe_domain(
+ universe_domain_opt, self._universe_domain_env
+ )
+ self._api_endpoint = None # updated below, depending on `transport`
+
+ # Initialize the universe domain validation.
+ self._is_universe_domain_valid = False
- api_key_value = getattr(client_options, "api_key", None)
+ api_key_value = getattr(self._client_options, "api_key", None)
if api_key_value and credentials:
raise ValueError(
"client_options.api_key and credentials are mutually exclusive"
@@ -530,20 +741,33 @@ def __init__(
# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
- if isinstance(transport, BigtableInstanceAdminTransport):
+ transport_provided = isinstance(transport, BigtableInstanceAdminTransport)
+ if transport_provided:
# transport is a BigtableInstanceAdminTransport instance.
- if credentials or client_options.credentials_file or api_key_value:
+ if credentials or self._client_options.credentials_file or api_key_value:
raise ValueError(
"When providing a transport instance, "
"provide its credentials directly."
)
- if client_options.scopes:
+ if self._client_options.scopes:
raise ValueError(
"When providing a transport instance, provide its scopes "
"directly."
)
- self._transport = transport
- else:
+ self._transport = cast(BigtableInstanceAdminTransport, transport)
+ self._api_endpoint = self._transport.host
+
+ self._api_endpoint = (
+ self._api_endpoint
+ or BigtableInstanceAdminClient._get_api_endpoint(
+ self._client_options.api_endpoint,
+ self._client_cert_source,
+ self._universe_domain,
+ self._use_mtls_endpoint,
+ )
+ )
+
+ if not transport_provided:
import google.auth._default # type: ignore
if api_key_value and hasattr(
@@ -553,17 +777,17 @@ def __init__(
api_key_value
)
- Transport = type(self).get_transport_class(transport)
+ Transport = type(self).get_transport_class(cast(str, transport))
self._transport = Transport(
credentials=credentials,
- credentials_file=client_options.credentials_file,
- host=api_endpoint,
- scopes=client_options.scopes,
- client_cert_source_for_mtls=client_cert_source_func,
- quota_project_id=client_options.quota_project_id,
+ credentials_file=self._client_options.credentials_file,
+ host=self._api_endpoint,
+ scopes=self._client_options.scopes,
+ client_cert_source_for_mtls=self._client_cert_source,
+ quota_project_id=self._client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
- api_audience=client_options.api_audience,
+ api_audience=self._client_options.api_audience,
)
def create_instance(
@@ -680,6 +904,9 @@ def create_instance(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -770,6 +997,9 @@ def get_instance(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -849,6 +1079,9 @@ def list_instances(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -914,6 +1147,9 @@ def update_instance(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1014,6 +1250,9 @@ def partial_update_instance(
),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1095,6 +1334,9 @@ def delete_instance(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
rpc(
request,
@@ -1203,6 +1445,9 @@ def create_cluster(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1292,6 +1537,9 @@ def get_cluster(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1373,6 +1621,9 @@ def list_clusters(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1438,6 +1689,9 @@ def update_cluster(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1553,6 +1807,9 @@ def partial_update_cluster(
),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1634,6 +1891,9 @@ def delete_cluster(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
rpc(
request,
@@ -1733,6 +1993,9 @@ def create_app_profile(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1813,6 +2076,9 @@ def get_app_profile(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1898,6 +2164,9 @@ def list_app_profiles(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2000,6 +2269,9 @@ def update_app_profile(
),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2081,6 +2353,9 @@ def delete_app_profile(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
rpc(
request,
@@ -2184,6 +2459,9 @@ def get_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2289,6 +2567,9 @@ def set_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2376,6 +2657,9 @@ def test_iam_permissions(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2459,6 +2743,9 @@ def list_hot_tablets(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
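
The hunks above give BigtableInstanceAdminClient explicit universe-domain support: the endpoint is resolved from the "bigtableadmin.{UNIVERSE_DOMAIN}" template and every RPC first runs _validate_universe_domain(). A minimal sketch of the new surface, assuming a google-api-core/google-auth version recent enough to understand universe_domain; the universe value is hypothetical and no RPC is sent:

    from google.api_core.client_options import ClientOptions
    from google.auth.credentials import AnonymousCredentials
    from google.cloud.bigtable_admin_v2 import BigtableInstanceAdminClient

    # Hypothetical non-default universe; assumes the mTLS env vars are unset.
    client = BigtableInstanceAdminClient(
        credentials=AnonymousCredentials(),
        client_options=ClientOptions(universe_domain="example-universe.test"),
    )
    print(client.api_endpoint)     # bigtableadmin.example-universe.test
    print(client.universe_domain)  # example-universe.test
    # Issuing an RPC now would raise ValueError from _validate_universe_domain(),
    # because anonymous credentials report the default "googleapis.com" universe.
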
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/base.py b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/base.py
index d92d25453..aeb07556c 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/base.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/base.py
@@ -71,7 +71,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
@@ -134,6 +134,10 @@ def __init__(
host += ":443"
self._host = host
+ @property
+ def host(self):
+ return self._host
+
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
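
The new read-only host property is what lets the client copy the endpoint back off a caller-supplied transport (self._api_endpoint = self._transport.host above). A small sketch with a made-up hostname; anonymous credentials keep it local and nothing is dialed:

    from google.auth.credentials import AnonymousCredentials
    from google.cloud.bigtable_admin_v2.services.bigtable_instance_admin.transports import (
        BigtableInstanceAdminGrpcTransport,
    )

    transport = BigtableInstanceAdminGrpcTransport(
        host="bigtableadmin.example-universe.test",
        credentials=AnonymousCredentials(),
    )
    print(transport.host)  # "bigtableadmin.example-universe.test:443" (":443" added by the base class)
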
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc.py b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc.py
index eca37957d..c47db6ba5 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc.py
@@ -73,7 +73,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc_asyncio.py b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc_asyncio.py
index 145aa427d..cbd77b381 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc_asyncio.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/grpc_asyncio.py
@@ -118,7 +118,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/rest.py b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/rest.py
index 9d5502b7e..61f425953 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/rest.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_instance_admin/transports/rest.py
@@ -35,9 +35,9 @@
import warnings
try:
- OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.Retry, object] # type: ignore
+ OptionalRetry = Union[retries.Retry, object, None] # type: ignore
from google.cloud.bigtable_admin_v2.types import bigtable_instance_admin
@@ -734,7 +734,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/async_client.py b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/async_client.py
index 5a4435bde..124b3ef09 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/async_client.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/async_client.py
@@ -38,9 +38,9 @@
from google.oauth2 import service_account # type: ignore
try:
- OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore
+ OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore
from google.api_core import operation # type: ignore
from google.api_core import operation_async # type: ignore
@@ -67,8 +67,12 @@ class BigtableTableAdminAsyncClient:
_client: BigtableTableAdminClient
+ # Copy defaults from the synchronous client for use here.
+ # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = BigtableTableAdminClient.DEFAULT_ENDPOINT
DEFAULT_MTLS_ENDPOINT = BigtableTableAdminClient.DEFAULT_MTLS_ENDPOINT
+ _DEFAULT_ENDPOINT_TEMPLATE = BigtableTableAdminClient._DEFAULT_ENDPOINT_TEMPLATE
+ _DEFAULT_UNIVERSE = BigtableTableAdminClient._DEFAULT_UNIVERSE
backup_path = staticmethod(BigtableTableAdminClient.backup_path)
parse_backup_path = staticmethod(BigtableTableAdminClient.parse_backup_path)
@@ -189,6 +193,25 @@ def transport(self) -> BigtableTableAdminTransport:
"""
return self._client.transport
+ @property
+ def api_endpoint(self):
+ """Return the API endpoint used by the client instance.
+
+ Returns:
+ str: The API endpoint used by the client instance.
+ """
+ return self._client._api_endpoint
+
+ @property
+ def universe_domain(self) -> str:
+ """Return the universe domain used by the client instance.
+
+ Returns:
+ str: The universe domain used by the client instance.
+ """
+ return self._client._universe_domain
+
get_transport_class = functools.partial(
type(BigtableTableAdminClient).get_transport_class,
type(BigtableTableAdminClient),
@@ -202,7 +225,7 @@ def __init__(
client_options: Optional[ClientOptions] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
- """Instantiates the bigtable table admin client.
+ """Instantiates the bigtable table admin async client.
Args:
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -213,23 +236,38 @@ def __init__(
transport (Union[str, ~.BigtableTableAdminTransport]): The
transport to use. If set to None, a transport is chosen
automatically.
- client_options (ClientOptions): Custom options for the client. It
- won't take effect if a ``transport`` instance is provided.
- (1) The ``api_endpoint`` property can be used to override the
- default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
- environment variable can also be used to override the endpoint:
+ client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]):
+ Custom options for the client.
+
+ 1. The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client when ``transport`` is
+ not explicitly provided. Only if this property is not set and
+ ``transport`` was not explicitly provided, the endpoint is
+ determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment
+ variable, which can have one of the following values:
"always" (always use the default mTLS endpoint), "never" (always
- use the default regular endpoint) and "auto" (auto switch to the
- default mTLS endpoint if client certificate is present, this is
- the default value). However, the ``api_endpoint`` property takes
- precedence if provided.
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ use the default regular endpoint) and "auto" (auto-switch to the
+ default mTLS endpoint if client certificate is present; this is
+ the default value).
+
+ 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
is "true", then the ``client_cert_source`` property can be used
- to provide client certificate for mutual TLS transport. If
+ to provide a client certificate for mTLS transport. If
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
+
+ 3. The ``universe_domain`` property can be used to override the
+ default "googleapis.com" universe. Note that the ``api_endpoint``
+ property still takes precedence; and ``universe_domain`` is
+ currently not supported for mTLS.
+
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
@@ -331,6 +369,9 @@ async def create_table(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -452,6 +493,9 @@ async def create_table_from_snapshot(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -550,6 +594,9 @@ async def list_tables(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -648,6 +695,9 @@ async def get_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -751,6 +801,9 @@ async def update_table(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -831,6 +884,9 @@ async def delete_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
await rpc(
request,
@@ -911,6 +967,9 @@ async def undelete_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1023,6 +1082,9 @@ async def modify_column_families(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1074,6 +1136,9 @@ async def drop_row_range(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
await rpc(
request,
@@ -1164,6 +1229,9 @@ async def generate_consistency_token(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1266,6 +1334,9 @@ async def check_consistency(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1403,6 +1474,9 @@ async def snapshot_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1523,6 +1597,9 @@ async def get_snapshot(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1639,6 +1716,9 @@ async def list_snapshots(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1735,6 +1815,9 @@ async def delete_snapshot(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
await rpc(
request,
@@ -1844,6 +1927,9 @@ async def create_backup(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1937,6 +2023,9 @@ async def get_backup(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2032,6 +2121,9 @@ async def update_backup(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2103,6 +2195,9 @@ async def delete_backup(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
await rpc(
request,
@@ -2194,6 +2289,9 @@ async def list_backups(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2267,6 +2365,9 @@ async def restore_table(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2402,6 +2503,9 @@ async def copy_backup(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2529,6 +2633,9 @@ async def get_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2637,6 +2744,9 @@ async def set_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -2736,6 +2846,9 @@ async def test_iam_permissions(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
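
The async client changes mirror the sync ones: api_endpoint and universe_domain read through to the wrapped sync client, and each async RPC calls self._client._validate_universe_domain() before sending. A quick sketch with anonymous credentials and the default universe; no RPC is issued:

    from google.auth.credentials import AnonymousCredentials
    from google.cloud.bigtable_admin_v2.services.bigtable_table_admin import (
        BigtableTableAdminAsyncClient,
    )

    client = BigtableTableAdminAsyncClient(credentials=AnonymousCredentials())
    print(client.api_endpoint)     # bigtableadmin.googleapis.com
    print(client.universe_domain)  # googleapis.com
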
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/client.py b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/client.py
index d0c04ed11..09a67e696 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/client.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/client.py
@@ -28,6 +28,7 @@
Union,
cast,
)
+import warnings
from google.cloud.bigtable_admin_v2 import gapic_version as package_version
@@ -42,9 +43,9 @@
from google.oauth2 import service_account # type: ignore
try:
- OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.Retry, object] # type: ignore
+ OptionalRetry = Union[retries.Retry, object, None] # type: ignore
from google.api_core import operation # type: ignore
from google.api_core import operation_async # type: ignore
@@ -137,11 +138,15 @@ def _get_default_mtls_endpoint(api_endpoint):
return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+ # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = "bigtableadmin.googleapis.com"
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
+ _DEFAULT_ENDPOINT_TEMPLATE = "bigtableadmin.{UNIVERSE_DOMAIN}"
+ _DEFAULT_UNIVERSE = "googleapis.com"
+
@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
@@ -405,7 +410,7 @@ def parse_common_location_path(path: str) -> Dict[str, str]:
def get_mtls_endpoint_and_cert_source(
cls, client_options: Optional[client_options_lib.ClientOptions] = None
):
- """Return the API endpoint and client cert source for mutual TLS.
+ """Deprecated. Return the API endpoint and client cert source for mutual TLS.
The client cert source is determined in the following order:
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
@@ -435,6 +440,11 @@ def get_mtls_endpoint_and_cert_source(
Raises:
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
"""
+
+ warnings.warn(
+ "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.",
+ DeprecationWarning,
+ )
if client_options is None:
client_options = client_options_lib.ClientOptions()
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
@@ -468,6 +478,178 @@ def get_mtls_endpoint_and_cert_source(
return api_endpoint, client_cert_source
+ @staticmethod
+ def _read_environment_variables():
+ """Returns the environment variables used by the client.
+
+ Returns:
+ Tuple[bool, str, str]: the parsed values of the GOOGLE_API_USE_CLIENT_CERTIFICATE,
+ GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables.
+
+ Raises:
+ ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not
+ any of ["true", "false"].
+ google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT
+ is not any of ["auto", "never", "always"].
+ """
+ use_client_cert = os.getenv(
+ "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"
+ ).lower()
+ use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower()
+ universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN")
+ if use_client_cert not in ("true", "false"):
+ raise ValueError(
+ "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+ if use_mtls_endpoint not in ("auto", "never", "always"):
+ raise MutualTLSChannelError(
+ "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+ return use_client_cert == "true", use_mtls_endpoint, universe_domain_env
+
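# Illustration only (a private helper, not public API): the method above parses three
# environment variables into (use_client_cert: bool, use_mtls_endpoint: str, universe_domain).
# The values below are hypothetical.
import os
from google.cloud.bigtable_admin_v2 import BigtableTableAdminClient as _C

os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "false"
os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "auto"
os.environ["GOOGLE_CLOUD_UNIVERSE_DOMAIN"] = "example-universe.test"
assert _C._read_environment_variables() == (False, "auto", "example-universe.test")
# Invalid values raise: ValueError for GOOGLE_API_USE_CLIENT_CERTIFICATE,
# MutualTLSChannelError for GOOGLE_API_USE_MTLS_ENDPOINT.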
+ @staticmethod
+ def _get_client_cert_source(provided_cert_source, use_cert_flag):
+ """Return the client cert source to be used by the client.
+
+ Args:
+ provided_cert_source (bytes): The client certificate source provided.
+ use_cert_flag (bool): A flag indicating whether to use the client certificate.
+
+ Returns:
+ bytes or None: The client cert source to be used by the client.
+ """
+ client_cert_source = None
+ if use_cert_flag:
+ if provided_cert_source:
+ client_cert_source = provided_cert_source
+ elif mtls.has_default_client_cert_source():
+ client_cert_source = mtls.default_client_cert_source()
+ return client_cert_source
+
+ @staticmethod
+ def _get_api_endpoint(
+ api_override, client_cert_source, universe_domain, use_mtls_endpoint
+ ):
+ """Return the API endpoint used by the client.
+
+ Args:
+ api_override (str): The API endpoint override. If specified, this is always
+ the return value of this function and the other arguments are not used.
+ client_cert_source (bytes): The client certificate source used by the client.
+ universe_domain (str): The universe domain used by the client.
+ use_mtls_endpoint (str): How to use the mTLS endpoint, which also depends on the other parameters.
+ Possible values are "always", "auto", or "never".
+
+ Returns:
+ str: The API endpoint to be used by the client.
+ """
+ if api_override is not None:
+ api_endpoint = api_override
+ elif use_mtls_endpoint == "always" or (
+ use_mtls_endpoint == "auto" and client_cert_source
+ ):
+ _default_universe = BigtableTableAdminClient._DEFAULT_UNIVERSE
+ if universe_domain != _default_universe:
+ raise MutualTLSChannelError(
+ f"mTLS is not supported in any universe other than {_default_universe}."
+ )
+ api_endpoint = BigtableTableAdminClient.DEFAULT_MTLS_ENDPOINT
+ else:
+ api_endpoint = BigtableTableAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=universe_domain
+ )
+ return api_endpoint
+
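# Illustration only (a private helper, not public API): a sketch of the precedence
# implemented by _get_api_endpoint, using hypothetical values.
from google.cloud.bigtable_admin_v2 import BigtableTableAdminClient as _C

# 1. An explicit api_endpoint override always wins.
assert _C._get_api_endpoint("my-endpoint:443", None, "googleapis.com", "auto") == "my-endpoint:443"
# 2. Otherwise the mTLS endpoint is used for "always", or "auto" with a cert (default universe only).
assert (
    _C._get_api_endpoint(None, lambda: (b"cert", b"key"), "googleapis.com", "auto")
    == "bigtableadmin.mtls.googleapis.com"
)
# 3. Otherwise the endpoint template is filled in with the universe domain.
assert (
    _C._get_api_endpoint(None, None, "example-universe.test", "auto")
    == "bigtableadmin.example-universe.test"
)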
+ @staticmethod
+ def _get_universe_domain(
+ client_universe_domain: Optional[str], universe_domain_env: Optional[str]
+ ) -> str:
+ """Return the universe domain used by the client.
+
+ Args:
+ client_universe_domain (Optional[str]): The universe domain configured via the client options.
+ universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable.
+
+ Returns:
+ str: The universe domain to be used by the client.
+
+ Raises:
+ ValueError: If the universe domain is an empty string.
+ """
+ universe_domain = BigtableTableAdminClient._DEFAULT_UNIVERSE
+ if client_universe_domain is not None:
+ universe_domain = client_universe_domain
+ elif universe_domain_env is not None:
+ universe_domain = universe_domain_env
+ if len(universe_domain.strip()) == 0:
+ raise ValueError("Universe Domain cannot be an empty string.")
+ return universe_domain
+
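# Illustration only (a private helper, not public API): the universe domain is resolved as
# client option > GOOGLE_CLOUD_UNIVERSE_DOMAIN env var > "googleapis.com" default.
from google.cloud.bigtable_admin_v2 import BigtableTableAdminClient as _C

assert _C._get_universe_domain("example-universe.test", "ignored.test") == "example-universe.test"
assert _C._get_universe_domain(None, "env-universe.test") == "env-universe.test"
assert _C._get_universe_domain(None, None) == "googleapis.com"
# An empty string from either source raises ValueError.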
+ @staticmethod
+ def _compare_universes(
+ client_universe: str, credentials: ga_credentials.Credentials
+ ) -> bool:
+ """Returns True iff the universe domains used by the client and credentials match.
+
+ Args:
+ client_universe (str): The universe domain configured via the client options.
+ credentials (ga_credentials.Credentials): The credentials being used in the client.
+
+ Returns:
+ bool: True iff client_universe matches the universe in credentials.
+
+ Raises:
+ ValueError: when client_universe does not match the universe in credentials.
+ """
+
+ default_universe = BigtableTableAdminClient._DEFAULT_UNIVERSE
+ credentials_universe = getattr(credentials, "universe_domain", default_universe)
+
+ if client_universe != credentials_universe:
+ raise ValueError(
+ "The configured universe domain "
+ f"({client_universe}) does not match the universe domain "
+ f"found in the credentials ({credentials_universe}). "
+ "If you haven't configured the universe domain explicitly, "
+ f"`{default_universe}` is the default."
+ )
+ return True
+
+ def _validate_universe_domain(self):
+ """Validates that the client's and credentials' universe domains are consistent.
+
+ Returns:
+ bool: True iff the configured universe domain is valid.
+
+ Raises:
+ ValueError: If the configured universe domain is not valid.
+ """
+ self._is_universe_domain_valid = (
+ self._is_universe_domain_valid
+ or BigtableTableAdminClient._compare_universes(
+ self.universe_domain, self.transport._credentials
+ )
+ )
+ return self._is_universe_domain_valid
+
+ @property
+ def api_endpoint(self):
+ """Return the API endpoint used by the client instance.
+
+ Returns:
+ str: The API endpoint used by the client instance.
+ """
+ return self._api_endpoint
+
+ @property
+ def universe_domain(self) -> str:
+ """Return the universe domain used by the client instance.
+
+ Returns:
+ str: The universe domain used by the client instance.
+ """
+ return self._universe_domain
+
def __init__(
self,
*,
@@ -487,22 +669,32 @@ def __init__(
transport (Union[str, BigtableTableAdminTransport]): The
transport to use. If set to None, a transport is chosen
automatically.
- client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the
- client. It won't take effect if a ``transport`` instance is provided.
- (1) The ``api_endpoint`` property can be used to override the
- default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
- environment variable can also be used to override the endpoint:
+ client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]):
+ Custom options for the client.
+
+ 1. The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client when ``transport`` is
+ not explicitly provided. Only if this property is not set and
+ ``transport`` was not explicitly provided, the endpoint is
+ determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment
+ variable, which can have one of the following values:
"always" (always use the default mTLS endpoint), "never" (always
- use the default regular endpoint) and "auto" (auto switch to the
- default mTLS endpoint if client certificate is present, this is
- the default value). However, the ``api_endpoint`` property takes
- precedence if provided.
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ use the default regular endpoint) and "auto" (auto-switch to the
+ default mTLS endpoint if client certificate is present; this is
+ the default value).
+
+ 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
is "true", then the ``client_cert_source`` property can be used
- to provide client certificate for mutual TLS transport. If
+ to provide a client certificate for mTLS transport. If
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
+
+ 3. The ``universe_domain`` property can be used to override the
+ default "googleapis.com" universe. Note that the ``api_endpoint``
+ property still takes precedence; and ``universe_domain`` is
+ currently not supported for mTLS.
+
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
@@ -513,17 +705,34 @@ def __init__(
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
"""
- if isinstance(client_options, dict):
- client_options = client_options_lib.from_dict(client_options)
- if client_options is None:
- client_options = client_options_lib.ClientOptions()
- client_options = cast(client_options_lib.ClientOptions, client_options)
+ self._client_options = client_options
+ if isinstance(self._client_options, dict):
+ self._client_options = client_options_lib.from_dict(self._client_options)
+ if self._client_options is None:
+ self._client_options = client_options_lib.ClientOptions()
+ self._client_options = cast(
+ client_options_lib.ClientOptions, self._client_options
+ )
+
+ universe_domain_opt = getattr(self._client_options, "universe_domain", None)
- api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(
- client_options
+ (
+ self._use_client_cert,
+ self._use_mtls_endpoint,
+ self._universe_domain_env,
+ ) = BigtableTableAdminClient._read_environment_variables()
+ self._client_cert_source = BigtableTableAdminClient._get_client_cert_source(
+ self._client_options.client_cert_source, self._use_client_cert
+ )
+ self._universe_domain = BigtableTableAdminClient._get_universe_domain(
+ universe_domain_opt, self._universe_domain_env
)
+ self._api_endpoint = None # updated below, depending on `transport`
- api_key_value = getattr(client_options, "api_key", None)
+ # Initialize the universe domain validation.
+ self._is_universe_domain_valid = False
+
+ api_key_value = getattr(self._client_options, "api_key", None)
if api_key_value and credentials:
raise ValueError(
"client_options.api_key and credentials are mutually exclusive"
@@ -532,20 +741,33 @@ def __init__(
# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
- if isinstance(transport, BigtableTableAdminTransport):
+ transport_provided = isinstance(transport, BigtableTableAdminTransport)
+ if transport_provided:
# transport is a BigtableTableAdminTransport instance.
- if credentials or client_options.credentials_file or api_key_value:
+ if credentials or self._client_options.credentials_file or api_key_value:
raise ValueError(
"When providing a transport instance, "
"provide its credentials directly."
)
- if client_options.scopes:
+ if self._client_options.scopes:
raise ValueError(
"When providing a transport instance, provide its scopes "
"directly."
)
- self._transport = transport
- else:
+ self._transport = cast(BigtableTableAdminTransport, transport)
+ self._api_endpoint = self._transport.host
+
+ self._api_endpoint = (
+ self._api_endpoint
+ or BigtableTableAdminClient._get_api_endpoint(
+ self._client_options.api_endpoint,
+ self._client_cert_source,
+ self._universe_domain,
+ self._use_mtls_endpoint,
+ )
+ )
+
+ if not transport_provided:
import google.auth._default # type: ignore
if api_key_value and hasattr(
@@ -555,17 +777,17 @@ def __init__(
api_key_value
)
- Transport = type(self).get_transport_class(transport)
+ Transport = type(self).get_transport_class(cast(str, transport))
self._transport = Transport(
credentials=credentials,
- credentials_file=client_options.credentials_file,
- host=api_endpoint,
- scopes=client_options.scopes,
- client_cert_source_for_mtls=client_cert_source_func,
- quota_project_id=client_options.quota_project_id,
+ credentials_file=self._client_options.credentials_file,
+ host=self._api_endpoint,
+ scopes=self._client_options.scopes,
+ client_cert_source_for_mtls=self._client_cert_source,
+ quota_project_id=self._client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
- api_audience=client_options.api_audience,
+ api_audience=self._client_options.api_audience,
)
def create_table(
@@ -658,6 +880,9 @@ def create_table(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -781,6 +1006,9 @@ def create_table_from_snapshot(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -869,6 +1097,9 @@ def list_tables(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -957,6 +1188,9 @@ def get_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1060,6 +1294,9 @@ def update_table(
),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1140,6 +1377,9 @@ def delete_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
rpc(
request,
@@ -1220,6 +1460,9 @@ def undelete_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1332,6 +1575,9 @@ def modify_column_families(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1384,6 +1630,9 @@ def drop_row_range(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
rpc(
request,
@@ -1468,6 +1717,9 @@ def generate_consistency_token(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1560,6 +1812,9 @@ def check_consistency(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1697,6 +1952,9 @@ def snapshot_table(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1807,6 +2065,9 @@ def get_snapshot(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1913,6 +2174,9 @@ def list_snapshots(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2009,6 +2273,9 @@ def delete_snapshot(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
rpc(
request,
@@ -2118,6 +2385,9 @@ def create_backup(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2201,6 +2471,9 @@ def get_backup(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2296,6 +2569,9 @@ def update_backup(
),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2367,6 +2643,9 @@ def delete_backup(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
rpc(
request,
@@ -2448,6 +2727,9 @@ def list_backups(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2522,6 +2804,9 @@ def restore_table(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2657,6 +2942,9 @@ def copy_backup(
gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2771,6 +3059,9 @@ def get_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2876,6 +3167,9 @@ def set_iam_policy(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -2963,6 +3257,9 @@ def test_iam_permissions(
gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
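
Because get_mtls_endpoint_and_cert_source now emits a DeprecationWarning, callers that only needed the resolved endpoint can read it from a constructed client instead. A hedged migration sketch; the anonymous credentials and endpoint value are illustrative:

    from google.api_core.client_options import ClientOptions
    from google.auth.credentials import AnonymousCredentials
    from google.cloud.bigtable_admin_v2 import BigtableTableAdminClient

    # Before (deprecated):
    #   endpoint, cert_source = BigtableTableAdminClient.get_mtls_endpoint_and_cert_source(options)
    # After: build the client and read the endpoint it actually resolved.
    client = BigtableTableAdminClient(
        credentials=AnonymousCredentials(),
        client_options=ClientOptions(api_endpoint="bigtableadmin.googleapis.com"),
    )
    print(client.api_endpoint)  # bigtableadmin.googleapis.com
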
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/base.py b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/base.py
index c3cf01a96..e0313a946 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/base.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/base.py
@@ -71,7 +71,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
@@ -134,6 +134,10 @@ def __init__(
host += ":443"
self._host = host
+ @property
+ def host(self):
+ return self._host
+
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc.py b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc.py
index d765869cd..b0c33eca9 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc.py
@@ -75,7 +75,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc_asyncio.py b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc_asyncio.py
index b60a7351c..3ae66f84f 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc_asyncio.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/grpc_asyncio.py
@@ -120,7 +120,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/rest.py b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/rest.py
index 41b893eb7..ad171d8f3 100644
--- a/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/rest.py
+++ b/google/cloud/bigtable_admin_v2/services/bigtable_table_admin/transports/rest.py
@@ -35,9 +35,9 @@
import warnings
try:
- OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.Retry, object] # type: ignore
+ OptionalRetry = Union[retries.Retry, object, None] # type: ignore
from google.cloud.bigtable_admin_v2.types import bigtable_table_admin
@@ -831,7 +831,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtableadmin.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_admin_v2/types/bigtable_table_admin.py b/google/cloud/bigtable_admin_v2/types/bigtable_table_admin.py
index 6a3b31a1e..c21ac4d5a 100644
--- a/google/cloud/bigtable_admin_v2/types/bigtable_table_admin.py
+++ b/google/cloud/bigtable_admin_v2/types/bigtable_table_admin.py
@@ -597,6 +597,9 @@ class ModifyColumnFamiliesRequest(proto.Message):
earlier modifications can be masked by later
ones (in the case of repeated updates to the
same family, for example).
+ ignore_warnings (bool):
+ Optional. If true, ignore safety checks when
+ modifying the column families.
"""
class Modification(proto.Message):
@@ -662,6 +665,10 @@ class Modification(proto.Message):
number=2,
message=Modification,
)
+ ignore_warnings: bool = proto.Field(
+ proto.BOOL,
+ number=3,
+ )
class GenerateConsistencyTokenRequest(proto.Message):
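
The new ignore_warnings field can be set directly when building the request; a short sketch in which the table path and column-family id are hypothetical and no RPC is sent:

    from google.cloud.bigtable_admin_v2 import types

    request = types.ModifyColumnFamiliesRequest(
        name="projects/my-project/instances/my-instance/tables/my-table",
        modifications=[
            types.ModifyColumnFamiliesRequest.Modification(id="cf1", drop=True),
        ],
        ignore_warnings=True,  # skip the safety checks, per the new field above
    )
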
diff --git a/google/cloud/bigtable_v2/gapic_version.py b/google/cloud/bigtable_v2/gapic_version.py
index 03d6d0200..f01e1d3a5 100644
--- a/google/cloud/bigtable_v2/gapic_version.py
+++ b/google/cloud/bigtable_v2/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "2.22.0" # {x-release-please-version}
+__version__ = "2.23.0" # {x-release-please-version}
diff --git a/google/cloud/bigtable_v2/services/bigtable/async_client.py b/google/cloud/bigtable_v2/services/bigtable/async_client.py
index 33686a4a8..0421e19bc 100644
--- a/google/cloud/bigtable_v2/services/bigtable/async_client.py
+++ b/google/cloud/bigtable_v2/services/bigtable/async_client.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import functools
from collections import OrderedDict
import functools
import re
@@ -40,9 +41,9 @@
from google.oauth2 import service_account # type: ignore
try:
- OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore
+ OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore
from google.cloud.bigtable_v2.types import bigtable
from google.cloud.bigtable_v2.types import data
@@ -59,8 +60,12 @@ class BigtableAsyncClient:
_client: BigtableClient
+ # Copy defaults from the synchronous client for use here.
+ # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = BigtableClient.DEFAULT_ENDPOINT
DEFAULT_MTLS_ENDPOINT = BigtableClient.DEFAULT_MTLS_ENDPOINT
+ _DEFAULT_ENDPOINT_TEMPLATE = BigtableClient._DEFAULT_ENDPOINT_TEMPLATE
+ _DEFAULT_UNIVERSE = BigtableClient._DEFAULT_UNIVERSE
instance_path = staticmethod(BigtableClient.instance_path)
parse_instance_path = staticmethod(BigtableClient.parse_instance_path)
@@ -161,6 +166,25 @@ def transport(self) -> BigtableTransport:
"""
return self._client.transport
+ @property
+ def api_endpoint(self):
+ """Return the API endpoint used by the client instance.
+
+ Returns:
+ str: The API endpoint used by the client instance.
+ """
+ return self._client._api_endpoint
+
+ @property
+ def universe_domain(self) -> str:
+ """Return the universe domain used by the client instance.
+
+ Returns:
+ str: The universe domain used by the client instance.
+ """
+ return self._client._universe_domain
+
get_transport_class = functools.partial(
type(BigtableClient).get_transport_class, type(BigtableClient)
)
@@ -173,7 +197,7 @@ def __init__(
client_options: Optional[ClientOptions] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
- """Instantiates the bigtable client.
+ """Instantiates the bigtable async client.
Args:
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -184,23 +208,38 @@ def __init__(
transport (Union[str, ~.BigtableTransport]): The
transport to use. If set to None, a transport is chosen
automatically.
- client_options (ClientOptions): Custom options for the client. It
- won't take effect if a ``transport`` instance is provided.
- (1) The ``api_endpoint`` property can be used to override the
- default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
- environment variable can also be used to override the endpoint:
+ client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]):
+ Custom options for the client.
+
+ 1. The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client when ``transport`` is
+ not explicitly provided. Only if this property is not set and
+ ``transport`` was not explicitly provided, the endpoint is
+ determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment
+ variable, which can have one of the following values:
"always" (always use the default mTLS endpoint), "never" (always
- use the default regular endpoint) and "auto" (auto switch to the
- default mTLS endpoint if client certificate is present, this is
- the default value). However, the ``api_endpoint`` property takes
- precedence if provided.
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ use the default regular endpoint) and "auto" (auto-switch to the
+ default mTLS endpoint if client certificate is present; this is
+ the default value).
+
+ 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
is "true", then the ``client_cert_source`` property can be used
- to provide client certificate for mutual TLS transport. If
+ to provide a client certificate for mTLS transport. If
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
+
+ 3. The ``universe_domain`` property can be used to override the
+ default "googleapis.com" universe. Note that the ``api_endpoint``
+ property still takes precedence; and ``universe_domain`` is
+ currently not supported for mTLS.
+
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
@@ -272,7 +311,8 @@ def read_rows(
"the individual field arguments should be set."
)
- request = bigtable.ReadRowsRequest(request)
+ if not isinstance(request, bigtable.ReadRowsRequest):
+ request = bigtable.ReadRowsRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -283,12 +323,9 @@ def read_rows(
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
- rpc = gapic_v1.method_async.wrap_method(
- self._client._transport.read_rows,
- default_timeout=43200.0,
- client_info=DEFAULT_CLIENT_INFO,
- )
-
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.read_rows
+ ]
# Certain fields should be provided within the metadata header;
# add these here.
metadata = tuple(metadata) + (
@@ -297,6 +334,9 @@ def read_rows(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
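
Two behavioural tweaks recur in this hunk and the ones below: a request that is already the typed proto is no longer re-wrapped, and the per-call gapic_v1.method_async.wrap_method(...) is replaced by the transport's precomputed _wrapped_methods table (populated in _prep_wrapped_messages, so the generated retry/timeout defaults still apply). A small sketch of the isinstance guard; the table name is hypothetical and nothing is sent:

    from google.cloud.bigtable_v2.types import ReadRowsRequest

    typed = ReadRowsRequest(table_name="projects/p/instances/i/tables/t")
    # Previously ReadRowsRequest(request) always copied the message; a typed request
    # passed by the caller is now used as-is, while dicts are still coerced:
    coerced = ReadRowsRequest({"table_name": "projects/p/instances/i/tables/t"})
    assert typed == coerced
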
@@ -367,7 +407,8 @@ def sample_row_keys(
"the individual field arguments should be set."
)
- request = bigtable.SampleRowKeysRequest(request)
+ if not isinstance(request, bigtable.SampleRowKeysRequest):
+ request = bigtable.SampleRowKeysRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -378,12 +419,9 @@ def sample_row_keys(
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
- rpc = gapic_v1.method_async.wrap_method(
- self._client._transport.sample_row_keys,
- default_timeout=60.0,
- client_info=DEFAULT_CLIENT_INFO,
- )
-
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.sample_row_keys
+ ]
# Certain fields should be provided within the metadata header;
# add these here.
metadata = tuple(metadata) + (
@@ -392,6 +430,9 @@ def sample_row_keys(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -479,7 +520,8 @@ async def mutate_row(
"the individual field arguments should be set."
)
- request = bigtable.MutateRowRequest(request)
+ if not isinstance(request, bigtable.MutateRowRequest):
+ request = bigtable.MutateRowRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -494,21 +536,9 @@ async def mutate_row(
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
- rpc = gapic_v1.method_async.wrap_method(
- self._client._transport.mutate_row,
- default_retry=retries.AsyncRetry(
- initial=0.01,
- maximum=60.0,
- multiplier=2,
- predicate=retries.if_exception_type(
- core_exceptions.DeadlineExceeded,
- core_exceptions.ServiceUnavailable,
- ),
- deadline=60.0,
- ),
- default_timeout=60.0,
- client_info=DEFAULT_CLIENT_INFO,
- )
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.mutate_row
+ ]
# Certain fields should be provided within the metadata header;
# add these here.
@@ -518,6 +548,9 @@ async def mutate_row(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -601,7 +634,8 @@ def mutate_rows(
"the individual field arguments should be set."
)
- request = bigtable.MutateRowsRequest(request)
+ if not isinstance(request, bigtable.MutateRowsRequest):
+ request = bigtable.MutateRowsRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -614,11 +648,9 @@ def mutate_rows(
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
- rpc = gapic_v1.method_async.wrap_method(
- self._client._transport.mutate_rows,
- default_timeout=600.0,
- client_info=DEFAULT_CLIENT_INFO,
- )
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.mutate_rows
+ ]
# Certain fields should be provided within the metadata header;
# add these here.
@@ -628,6 +660,9 @@ def mutate_rows(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -749,7 +784,8 @@ async def check_and_mutate_row(
"the individual field arguments should be set."
)
- request = bigtable.CheckAndMutateRowRequest(request)
+ if not isinstance(request, bigtable.CheckAndMutateRowRequest):
+ request = bigtable.CheckAndMutateRowRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -768,11 +804,9 @@ async def check_and_mutate_row(
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
- rpc = gapic_v1.method_async.wrap_method(
- self._client._transport.check_and_mutate_row,
- default_timeout=20.0,
- client_info=DEFAULT_CLIENT_INFO,
- )
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.check_and_mutate_row
+ ]
# Certain fields should be provided within the metadata header;
# add these here.
@@ -782,6 +816,9 @@ async def check_and_mutate_row(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -851,7 +888,8 @@ async def ping_and_warm(
"the individual field arguments should be set."
)
- request = bigtable.PingAndWarmRequest(request)
+ if not isinstance(request, bigtable.PingAndWarmRequest):
+ request = bigtable.PingAndWarmRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -862,11 +900,9 @@ async def ping_and_warm(
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
- rpc = gapic_v1.method_async.wrap_method(
- self._client._transport.ping_and_warm,
- default_timeout=None,
- client_info=DEFAULT_CLIENT_INFO,
- )
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.ping_and_warm
+ ]
# Certain fields should be provided within the metadata header;
# add these here.
@@ -874,6 +910,9 @@ async def ping_and_warm(
gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -968,7 +1007,8 @@ async def read_modify_write_row(
"the individual field arguments should be set."
)
- request = bigtable.ReadModifyWriteRowRequest(request)
+ if not isinstance(request, bigtable.ReadModifyWriteRowRequest):
+ request = bigtable.ReadModifyWriteRowRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -983,11 +1023,9 @@ async def read_modify_write_row(
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
- rpc = gapic_v1.method_async.wrap_method(
- self._client._transport.read_modify_write_row,
- default_timeout=20.0,
- client_info=DEFAULT_CLIENT_INFO,
- )
+ rpc = self._client._transport._wrapped_methods[
+ self._client._transport.read_modify_write_row
+ ]
# Certain fields should be provided within the metadata header;
# add these here.
@@ -997,6 +1035,9 @@ async def read_modify_write_row(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = await rpc(
request,
@@ -1076,7 +1117,10 @@ def generate_initial_change_stream_partitions(
"the individual field arguments should be set."
)
- request = bigtable.GenerateInitialChangeStreamPartitionsRequest(request)
+ if not isinstance(
+ request, bigtable.GenerateInitialChangeStreamPartitionsRequest
+ ):
+ request = bigtable.GenerateInitialChangeStreamPartitionsRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -1101,6 +1145,9 @@ def generate_initial_change_stream_partitions(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1174,7 +1221,8 @@ def read_change_stream(
"the individual field arguments should be set."
)
- request = bigtable.ReadChangeStreamRequest(request)
+ if not isinstance(request, bigtable.ReadChangeStreamRequest):
+ request = bigtable.ReadChangeStreamRequest(request)
# If we have keyword arguments corresponding to fields on the
# request, apply these.
@@ -1199,6 +1247,9 @@ def read_change_stream(
),
)
+ # Validate the universe domain.
+ self._client._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
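
The hunks above swap per-call `gapic_v1.method_async.wrap_method(...)` calls for a lookup into the transport's precomputed `_wrapped_methods` table, and validate the universe domain before each RPC. A minimal, self-contained sketch of that precompute-once pattern, with invented names (`SketchTransport`, `wrap_method`) standing in for the gapic machinery rather than reproducing it:

```python
import asyncio
import functools

def wrap_method(func, default_timeout=None):
    # Stand-in for gapic_v1.method_async.wrap_method: attach a default timeout.
    @functools.wraps(func)
    async def wrapped(request, *, timeout=None, **kwargs):
        return await func(request, timeout=timeout or default_timeout, **kwargs)
    return wrapped

class SketchTransport:
    def __init__(self):
        # Wrapped exactly once at construction time (the role played by
        # _prep_wrapped_messages); callers index by the bound method.
        self._wrapped_methods = {
            self.read_rows: wrap_method(self.read_rows, default_timeout=43200.0),
        }

    async def read_rows(self, request, timeout=None):
        return f"read_rows({request!r}, timeout={timeout})"

async def main():
    transport = SketchTransport()
    rpc = transport._wrapped_methods[transport.read_rows]  # lookup, no re-wrapping
    print(await rpc({"table_name": "my-table"}))

asyncio.run(main())
```
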
diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py
index db393faa7..f53f25e90 100644
--- a/google/cloud/bigtable_v2/services/bigtable/client.py
+++ b/google/cloud/bigtable_v2/services/bigtable/client.py
@@ -29,6 +29,7 @@
Union,
cast,
)
+import warnings
from google.cloud.bigtable_v2 import gapic_version as package_version
@@ -43,9 +44,9 @@
from google.oauth2 import service_account # type: ignore
try:
- OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.Retry, object] # type: ignore
+ OptionalRetry = Union[retries.Retry, object, None] # type: ignore
from google.cloud.bigtable_v2.types import bigtable
from google.cloud.bigtable_v2.types import data
@@ -53,6 +54,7 @@
from .transports.base import BigtableTransport, DEFAULT_CLIENT_INFO
from .transports.grpc import BigtableGrpcTransport
from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport
+from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport
from .transports.rest import BigtableRestTransport
@@ -67,6 +69,7 @@ class BigtableClientMeta(type):
_transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]]
_transport_registry["grpc"] = BigtableGrpcTransport
_transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport
+ _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport
_transport_registry["rest"] = BigtableRestTransport
def get_transport_class(
@@ -126,11 +129,15 @@ def _get_default_mtls_endpoint(api_endpoint):
return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+ # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead.
DEFAULT_ENDPOINT = "bigtable.googleapis.com"
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
+ _DEFAULT_ENDPOINT_TEMPLATE = "bigtable.{UNIVERSE_DOMAIN}"
+ _DEFAULT_UNIVERSE = "googleapis.com"
+
@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
@@ -298,7 +305,7 @@ def parse_common_location_path(path: str) -> Dict[str, str]:
def get_mtls_endpoint_and_cert_source(
cls, client_options: Optional[client_options_lib.ClientOptions] = None
):
- """Return the API endpoint and client cert source for mutual TLS.
+ """Deprecated. Return the API endpoint and client cert source for mutual TLS.
The client cert source is determined in the following order:
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
@@ -328,6 +335,11 @@ def get_mtls_endpoint_and_cert_source(
Raises:
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
"""
+
+ warnings.warn(
+ "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.",
+ DeprecationWarning,
+ )
if client_options is None:
client_options = client_options_lib.ClientOptions()
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
@@ -361,6 +373,178 @@ def get_mtls_endpoint_and_cert_source(
return api_endpoint, client_cert_source
+ @staticmethod
+ def _read_environment_variables():
+ """Returns the environment variables used by the client.
+
+ Returns:
+ Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE,
+ GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables.
+
+ Raises:
+ ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not
+ any of ["true", "false"].
+ google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT
+ is not any of ["auto", "never", "always"].
+ """
+ use_client_cert = os.getenv(
+ "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"
+ ).lower()
+ use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower()
+ universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN")
+ if use_client_cert not in ("true", "false"):
+ raise ValueError(
+ "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+ if use_mtls_endpoint not in ("auto", "never", "always"):
+ raise MutualTLSChannelError(
+ "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+ return use_client_cert == "true", use_mtls_endpoint, universe_domain_env
+
+ @staticmethod
+ def _get_client_cert_source(provided_cert_source, use_cert_flag):
+ """Return the client cert source to be used by the client.
+
+ Args:
+ provided_cert_source (bytes): The client certificate source provided.
+ use_cert_flag (bool): A flag indicating whether to use the client certificate.
+
+ Returns:
+ bytes or None: The client cert source to be used by the client.
+ """
+ client_cert_source = None
+ if use_cert_flag:
+ if provided_cert_source:
+ client_cert_source = provided_cert_source
+ elif mtls.has_default_client_cert_source():
+ client_cert_source = mtls.default_client_cert_source()
+ return client_cert_source
+
+ @staticmethod
+ def _get_api_endpoint(
+ api_override, client_cert_source, universe_domain, use_mtls_endpoint
+ ):
+ """Return the API endpoint used by the client.
+
+ Args:
+ api_override (str): The API endpoint override. If specified, this is always
+ the return value of this function and the other arguments are not used.
+ client_cert_source (bytes): The client certificate source used by the client.
+ universe_domain (str): The universe domain used by the client.
+ use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters.
+ Possible values are "always", "auto", or "never".
+
+ Returns:
+ str: The API endpoint to be used by the client.
+ """
+ if api_override is not None:
+ api_endpoint = api_override
+ elif use_mtls_endpoint == "always" or (
+ use_mtls_endpoint == "auto" and client_cert_source
+ ):
+ _default_universe = BigtableClient._DEFAULT_UNIVERSE
+ if universe_domain != _default_universe:
+ raise MutualTLSChannelError(
+ f"mTLS is not supported in any universe other than {_default_universe}."
+ )
+ api_endpoint = BigtableClient.DEFAULT_MTLS_ENDPOINT
+ else:
+ api_endpoint = BigtableClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=universe_domain
+ )
+ return api_endpoint
+
+ @staticmethod
+ def _get_universe_domain(
+ client_universe_domain: Optional[str], universe_domain_env: Optional[str]
+ ) -> str:
+ """Return the universe domain used by the client.
+
+ Args:
+ client_universe_domain (Optional[str]): The universe domain configured via the client options.
+ universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable.
+
+ Returns:
+ str: The universe domain to be used by the client.
+
+ Raises:
+ ValueError: If the universe domain is an empty string.
+ """
+ universe_domain = BigtableClient._DEFAULT_UNIVERSE
+ if client_universe_domain is not None:
+ universe_domain = client_universe_domain
+ elif universe_domain_env is not None:
+ universe_domain = universe_domain_env
+ if len(universe_domain.strip()) == 0:
+ raise ValueError("Universe Domain cannot be an empty string.")
+ return universe_domain
+
+ @staticmethod
+ def _compare_universes(
+ client_universe: str, credentials: ga_credentials.Credentials
+ ) -> bool:
+ """Returns True iff the universe domains used by the client and credentials match.
+
+ Args:
+ client_universe (str): The universe domain configured via the client options.
+ credentials (ga_credentials.Credentials): The credentials being used in the client.
+
+ Returns:
+ bool: True iff client_universe matches the universe in credentials.
+
+ Raises:
+ ValueError: when client_universe does not match the universe in credentials.
+ """
+
+ default_universe = BigtableClient._DEFAULT_UNIVERSE
+ credentials_universe = getattr(credentials, "universe_domain", default_universe)
+
+ if client_universe != credentials_universe:
+ raise ValueError(
+ "The configured universe domain "
+ f"({client_universe}) does not match the universe domain "
+ f"found in the credentials ({credentials_universe}). "
+ "If you haven't configured the universe domain explicitly, "
+ f"`{default_universe}` is the default."
+ )
+ return True
+
+ def _validate_universe_domain(self):
+ """Validates client's and credentials' universe domains are consistent.
+
+ Returns:
+ bool: True iff the configured universe domain is valid.
+
+ Raises:
+ ValueError: If the configured universe domain is not valid.
+ """
+ self._is_universe_domain_valid = (
+ self._is_universe_domain_valid
+ or BigtableClient._compare_universes(
+ self.universe_domain, self.transport._credentials
+ )
+ )
+ return self._is_universe_domain_valid
+
+ @property
+ def api_endpoint(self):
+ """Return the API endpoint used by the client instance.
+
+ Returns:
+ str: The API endpoint used by the client instance.
+ """
+ return self._api_endpoint
+
+ @property
+ def universe_domain(self) -> str:
+ """Return the universe domain used by the client instance.
+
+ Returns:
+ str: The universe domain used by the client instance.
+ """
+ return self._universe_domain
+
def __init__(
self,
*,
@@ -380,22 +564,32 @@ def __init__(
transport (Union[str, BigtableTransport]): The
transport to use. If set to None, a transport is chosen
automatically.
- client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the
- client. It won't take effect if a ``transport`` instance is provided.
- (1) The ``api_endpoint`` property can be used to override the
- default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
- environment variable can also be used to override the endpoint:
+ client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]):
+ Custom options for the client.
+
+ 1. The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client when ``transport`` is
+ not explicitly provided. Only if this property is not set and
+ ``transport`` was not explicitly provided, the endpoint is
+ determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment
+ variable, which can have one of the following values:
"always" (always use the default mTLS endpoint), "never" (always
- use the default regular endpoint) and "auto" (auto switch to the
- default mTLS endpoint if client certificate is present, this is
- the default value). However, the ``api_endpoint`` property takes
- precedence if provided.
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ use the default regular endpoint) and "auto" (auto-switch to the
+ default mTLS endpoint if client certificate is present; this is
+ the default value).
+
+ 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
is "true", then the ``client_cert_source`` property can be used
- to provide client certificate for mutual TLS transport. If
+ to provide a client certificate for mTLS transport. If
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
+
+ 3. The ``universe_domain`` property can be used to override the
+ default "googleapis.com" universe. Note that the ``api_endpoint``
+ property still takes precedence, and ``universe_domain`` is
+ currently not supported for mTLS.
+
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
@@ -406,17 +600,34 @@ def __init__(
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
"""
- if isinstance(client_options, dict):
- client_options = client_options_lib.from_dict(client_options)
- if client_options is None:
- client_options = client_options_lib.ClientOptions()
- client_options = cast(client_options_lib.ClientOptions, client_options)
+ self._client_options = client_options
+ if isinstance(self._client_options, dict):
+ self._client_options = client_options_lib.from_dict(self._client_options)
+ if self._client_options is None:
+ self._client_options = client_options_lib.ClientOptions()
+ self._client_options = cast(
+ client_options_lib.ClientOptions, self._client_options
+ )
+
+ universe_domain_opt = getattr(self._client_options, "universe_domain", None)
- api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(
- client_options
+ (
+ self._use_client_cert,
+ self._use_mtls_endpoint,
+ self._universe_domain_env,
+ ) = BigtableClient._read_environment_variables()
+ self._client_cert_source = BigtableClient._get_client_cert_source(
+ self._client_options.client_cert_source, self._use_client_cert
+ )
+ self._universe_domain = BigtableClient._get_universe_domain(
+ universe_domain_opt, self._universe_domain_env
)
+ self._api_endpoint = None # updated below, depending on `transport`
- api_key_value = getattr(client_options, "api_key", None)
+ # Initialize the universe domain validation.
+ self._is_universe_domain_valid = False
+
+ api_key_value = getattr(self._client_options, "api_key", None)
if api_key_value and credentials:
raise ValueError(
"client_options.api_key and credentials are mutually exclusive"
@@ -425,20 +636,30 @@ def __init__(
# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
- if isinstance(transport, BigtableTransport):
+ transport_provided = isinstance(transport, BigtableTransport)
+ if transport_provided:
# transport is a BigtableTransport instance.
- if credentials or client_options.credentials_file or api_key_value:
+ if credentials or self._client_options.credentials_file or api_key_value:
raise ValueError(
"When providing a transport instance, "
"provide its credentials directly."
)
- if client_options.scopes:
+ if self._client_options.scopes:
raise ValueError(
"When providing a transport instance, provide its scopes "
"directly."
)
- self._transport = transport
- else:
+ self._transport = cast(BigtableTransport, transport)
+ self._api_endpoint = self._transport.host
+
+ self._api_endpoint = self._api_endpoint or BigtableClient._get_api_endpoint(
+ self._client_options.api_endpoint,
+ self._client_cert_source,
+ self._universe_domain,
+ self._use_mtls_endpoint,
+ )
+
+ if not transport_provided:
import google.auth._default # type: ignore
if api_key_value and hasattr(
@@ -448,17 +669,17 @@ def __init__(
api_key_value
)
- Transport = type(self).get_transport_class(transport)
+ Transport = type(self).get_transport_class(cast(str, transport))
self._transport = Transport(
credentials=credentials,
- credentials_file=client_options.credentials_file,
- host=api_endpoint,
- scopes=client_options.scopes,
- client_cert_source_for_mtls=client_cert_source_func,
- quota_project_id=client_options.quota_project_id,
+ credentials_file=self._client_options.credentials_file,
+ host=self._api_endpoint,
+ scopes=self._client_options.scopes,
+ client_cert_source_for_mtls=self._client_cert_source,
+ quota_project_id=self._client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
- api_audience=client_options.api_audience,
+ api_audience=self._client_options.api_audience,
)
def read_rows(
@@ -555,6 +776,9 @@ def read_rows(
gapic_v1.routing_header.to_grpc_metadata(header_params),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -659,6 +883,9 @@ def sample_row_keys(
gapic_v1.routing_header.to_grpc_metadata(header_params),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -784,6 +1011,9 @@ def mutate_row(
gapic_v1.routing_header.to_grpc_metadata(header_params),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -903,6 +1133,9 @@ def mutate_rows(
gapic_v1.routing_header.to_grpc_metadata(header_params),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1066,6 +1299,9 @@ def check_and_mutate_row(
gapic_v1.routing_header.to_grpc_metadata(header_params),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1167,6 +1403,9 @@ def ping_and_warm(
gapic_v1.routing_header.to_grpc_metadata(header_params),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1299,6 +1538,9 @@ def read_modify_write_row(
gapic_v1.routing_header.to_grpc_metadata(header_params),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1405,6 +1647,9 @@ def generate_initial_change_stream_partitions(
),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
@@ -1503,6 +1748,9 @@ def read_change_stream(
),
)
+ # Validate the universe domain.
+ self._validate_universe_domain()
+
# Send the request.
response = rpc(
request,
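
The client changes above route endpoint selection through the new `_get_api_endpoint` / `_get_universe_domain` helpers. A hedged, standalone restatement of that resolution order (explicit `api_endpoint` override first, then the mTLS endpoint when applicable, then the `bigtable.{UNIVERSE_DOMAIN}` template); this mirrors the logic for illustration instead of calling the private helpers:

```python
DEFAULT_UNIVERSE = "googleapis.com"

def resolve_endpoint(api_override, client_cert_source, universe_domain, use_mtls_endpoint):
    # Same precedence as the helpers added above; not the library's own code.
    if api_override is not None:
        return api_override
    if use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source):
        if universe_domain != DEFAULT_UNIVERSE:
            raise ValueError("mTLS is only supported in the default universe")
        return "bigtable.mtls.googleapis.com"
    return f"bigtable.{universe_domain}"

assert resolve_endpoint(None, None, "googleapis.com", "auto") == "bigtable.googleapis.com"
assert resolve_endpoint(None, None, "example-universe.test", "never") == "bigtable.example-universe.test"
assert resolve_endpoint("private.example.com:443", None, "googleapis.com", "always") == "private.example.com:443"
```
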
diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py
index c09443bc2..6a9eb0e58 100644
--- a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py
+++ b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py
@@ -19,6 +19,7 @@
from .base import BigtableTransport
from .grpc import BigtableGrpcTransport
from .grpc_asyncio import BigtableGrpcAsyncIOTransport
+from .pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport
from .rest import BigtableRestTransport
from .rest import BigtableRestInterceptor
@@ -27,12 +28,14 @@
_transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]]
_transport_registry["grpc"] = BigtableGrpcTransport
_transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport
+_transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport
_transport_registry["rest"] = BigtableRestTransport
__all__ = (
"BigtableTransport",
"BigtableGrpcTransport",
"BigtableGrpcAsyncIOTransport",
+ "PooledBigtableGrpcAsyncIOTransport",
"BigtableRestTransport",
"BigtableRestInterceptor",
)
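
With the registry entry above, the pooled transport resolves by name like any other transport string. A small sketch, assuming you only want to inspect the class (the handwritten async data client normally selects and configures it for you):

```python
from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient

transport_cls = BigtableClient.get_transport_class("pooled_grpc_asyncio")
print(transport_cls.__name__)  # PooledBigtableGrpcAsyncIOTransport
```
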
diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/base.py b/google/cloud/bigtable_v2/services/bigtable/transports/base.py
index b580bbca7..7d1475eb9 100644
--- a/google/cloud/bigtable_v2/services/bigtable/transports/base.py
+++ b/google/cloud/bigtable_v2/services/bigtable/transports/base.py
@@ -64,7 +64,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtable.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
@@ -127,6 +127,10 @@ def __init__(
host += ":443"
self._host = host
+ @property
+ def host(self):
+ return self._host
+
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py b/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py
index 8ba04e761..bec9c85f1 100644
--- a/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py
+++ b/google/cloud/bigtable_v2/services/bigtable/transports/grpc.py
@@ -65,7 +65,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtable.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py
index 2c0cbdad6..7765ecce8 100644
--- a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py
+++ b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py
@@ -18,6 +18,8 @@
from google.api_core import gapic_v1
from google.api_core import grpc_helpers_async
+from google.api_core import exceptions as core_exceptions
+from google.api_core import retry as retries
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
@@ -110,7 +112,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtable.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
@@ -512,6 +514,66 @@ def read_change_stream(
)
return self._stubs["read_change_stream"]
+ def _prep_wrapped_messages(self, client_info):
+ # Precompute the wrapped methods.
+ self._wrapped_methods = {
+ self.read_rows: gapic_v1.method_async.wrap_method(
+ self.read_rows,
+ default_timeout=43200.0,
+ client_info=client_info,
+ ),
+ self.sample_row_keys: gapic_v1.method_async.wrap_method(
+ self.sample_row_keys,
+ default_timeout=60.0,
+ client_info=client_info,
+ ),
+ self.mutate_row: gapic_v1.method_async.wrap_method(
+ self.mutate_row,
+ default_retry=retries.Retry(
+ initial=0.01,
+ maximum=60.0,
+ multiplier=2,
+ predicate=retries.if_exception_type(
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ),
+ deadline=60.0,
+ ),
+ default_timeout=60.0,
+ client_info=client_info,
+ ),
+ self.mutate_rows: gapic_v1.method_async.wrap_method(
+ self.mutate_rows,
+ default_timeout=600.0,
+ client_info=client_info,
+ ),
+ self.check_and_mutate_row: gapic_v1.method_async.wrap_method(
+ self.check_and_mutate_row,
+ default_timeout=20.0,
+ client_info=client_info,
+ ),
+ self.ping_and_warm: gapic_v1.method_async.wrap_method(
+ self.ping_and_warm,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ self.read_modify_write_row: gapic_v1.method_async.wrap_method(
+ self.read_modify_write_row,
+ default_timeout=20.0,
+ client_info=client_info,
+ ),
+ self.generate_initial_change_stream_partitions: gapic_v1.method_async.wrap_method(
+ self.generate_initial_change_stream_partitions,
+ default_timeout=60.0,
+ client_info=client_info,
+ ),
+ self.read_change_stream: gapic_v1.method_async.wrap_method(
+ self.read_change_stream,
+ default_timeout=43200.0,
+ client_info=client_info,
+ ),
+ }
+
def close(self):
return self.grpc_channel.close()
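
The `_prep_wrapped_messages` override above bakes per-RPC defaults into the transport (for example, a 43200 s timeout for `read_rows` and a DeadlineExceeded/ServiceUnavailable retry for `mutate_row`). Callers can still override those defaults per call; a hedged sketch against the async GAPIC surface, with placeholder table and mutation values:

```python
from google.api_core import retry_async
from google.cloud.bigtable_v2 import BigtableAsyncClient

async def mutate_with_overrides(client: BigtableAsyncClient, table_name: str) -> None:
    await client.mutate_row(
        table_name=table_name,
        row_key=b"row-1",
        mutations=[],  # a real call would pass data.Mutation messages
        timeout=10.0,  # overrides the 60 s default set above
        retry=retry_async.AsyncRetry(initial=0.05, maximum=2.0, multiplier=2),
    )
```
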
diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py
new file mode 100644
index 000000000..372e5796d
--- /dev/null
+++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py
@@ -0,0 +1,426 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import asyncio
+import warnings
+from functools import partialmethod
+from functools import partial
+from typing import (
+ Awaitable,
+ Callable,
+ Dict,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ List,
+ Type,
+)
+
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers_async
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+
+import grpc # type: ignore
+from grpc.experimental import aio # type: ignore
+
+from google.cloud.bigtable_v2.types import bigtable
+from .base import BigtableTransport, DEFAULT_CLIENT_INFO
+from .grpc_asyncio import BigtableGrpcAsyncIOTransport
+
+
+class PooledMultiCallable:
+ def __init__(self, channel_pool: "PooledChannel", *args, **kwargs):
+ self._init_args = args
+ self._init_kwargs = kwargs
+ self.next_channel_fn = channel_pool.next_channel
+
+
+class PooledUnaryUnaryMultiCallable(PooledMultiCallable, aio.UnaryUnaryMultiCallable):
+ def __call__(self, *args, **kwargs) -> aio.UnaryUnaryCall:
+ return self.next_channel_fn().unary_unary(
+ *self._init_args, **self._init_kwargs
+ )(*args, **kwargs)
+
+
+class PooledUnaryStreamMultiCallable(PooledMultiCallable, aio.UnaryStreamMultiCallable):
+ def __call__(self, *args, **kwargs) -> aio.UnaryStreamCall:
+ return self.next_channel_fn().unary_stream(
+ *self._init_args, **self._init_kwargs
+ )(*args, **kwargs)
+
+
+class PooledStreamUnaryMultiCallable(PooledMultiCallable, aio.StreamUnaryMultiCallable):
+ def __call__(self, *args, **kwargs) -> aio.StreamUnaryCall:
+ return self.next_channel_fn().stream_unary(
+ *self._init_args, **self._init_kwargs
+ )(*args, **kwargs)
+
+
+class PooledStreamStreamMultiCallable(
+ PooledMultiCallable, aio.StreamStreamMultiCallable
+):
+ def __call__(self, *args, **kwargs) -> aio.StreamStreamCall:
+ return self.next_channel_fn().stream_stream(
+ *self._init_args, **self._init_kwargs
+ )(*args, **kwargs)
+
+
+class PooledChannel(aio.Channel):
+ def __init__(
+ self,
+ pool_size: int = 3,
+ host: str = "bigtable.googleapis.com",
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ quota_project_id: Optional[str] = None,
+ default_scopes: Optional[Sequence[str]] = None,
+ scopes: Optional[Sequence[str]] = None,
+ default_host: Optional[str] = None,
+ insecure: bool = False,
+ **kwargs,
+ ):
+ self._pool: List[aio.Channel] = []
+ self._next_idx = 0
+ if insecure:
+ self._create_channel = partial(aio.insecure_channel, host)
+ else:
+ self._create_channel = partial(
+ grpc_helpers_async.create_channel,
+ target=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ quota_project_id=quota_project_id,
+ default_scopes=default_scopes,
+ scopes=scopes,
+ default_host=default_host,
+ **kwargs,
+ )
+ for i in range(pool_size):
+ self._pool.append(self._create_channel())
+
+ def next_channel(self) -> aio.Channel:
+ channel = self._pool[self._next_idx]
+ self._next_idx = (self._next_idx + 1) % len(self._pool)
+ return channel
+
+ def unary_unary(self, *args, **kwargs) -> grpc.aio.UnaryUnaryMultiCallable:
+ return PooledUnaryUnaryMultiCallable(self, *args, **kwargs)
+
+ def unary_stream(self, *args, **kwargs) -> grpc.aio.UnaryStreamMultiCallable:
+ return PooledUnaryStreamMultiCallable(self, *args, **kwargs)
+
+ def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable:
+ return PooledStreamUnaryMultiCallable(self, *args, **kwargs)
+
+ def stream_stream(self, *args, **kwargs) -> grpc.aio.StreamStreamMultiCallable:
+ return PooledStreamStreamMultiCallable(self, *args, **kwargs)
+
+ async def close(self, grace=None):
+ close_fns = [channel.close(grace=grace) for channel in self._pool]
+ return await asyncio.gather(*close_fns)
+
+ async def channel_ready(self):
+ ready_fns = [channel.channel_ready() for channel in self._pool]
+ return asyncio.gather(*ready_fns)
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity:
+ raise NotImplementedError()
+
+ async def wait_for_state_change(self, last_observed_state):
+ raise NotImplementedError()
+
+ async def replace_channel(
+ self, channel_idx, grace=None, swap_sleep=1, new_channel=None
+ ) -> aio.Channel:
+ """
+ Replaces a channel in the pool with a fresh one.
+
+ The `new_channel` will start processing new requests immediately,
+ but the old channel will continue serving existing clients for `grace` seconds.
+
+ Args:
+ channel_idx(int): the channel index in the pool to replace
+ grace(Optional[float]): The time to wait until all active RPCs are
+ finished. If a grace period is not specified (by passing None for
+ grace), all existing RPCs are cancelled immediately.
+ swap_sleep(Optional[float]): The number of seconds to sleep in between
+ replacing channels and closing the old one
+ new_channel(grpc.aio.Channel): a new channel to insert into the pool
+ at `channel_idx`. If `None`, a new channel will be created.
+ """
+ if channel_idx >= len(self._pool) or channel_idx < 0:
+ raise ValueError(
+ f"invalid channel_idx {channel_idx} for pool size {len(self._pool)}"
+ )
+ if new_channel is None:
+ new_channel = self._create_channel()
+ old_channel = self._pool[channel_idx]
+ self._pool[channel_idx] = new_channel
+ await asyncio.sleep(swap_sleep)
+ await old_channel.close(grace=grace)
+ return new_channel
+
+
+class PooledBigtableGrpcAsyncIOTransport(BigtableGrpcAsyncIOTransport):
+ """Pooled gRPC AsyncIO backend transport for Bigtable.
+
+ Service for reading from and writing to existing Bigtable
+ tables.
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+
+ This class allows channel pooling, so multiple channels can be used concurrently
+ when making requests. Channels are rotated in a round-robin fashion.
+ """
+
+ @classmethod
+ def with_fixed_size(cls, pool_size) -> Type["PooledBigtableGrpcAsyncIOTransport"]:
+ """
+ Creates a new class with a fixed channel pool size.
+
+ A fixed channel pool makes compatibility with other transports easier,
+ as the initializer signature is the same.
+ """
+
+ class PooledTransportFixed(cls):
+ __init__ = partialmethod(cls.__init__, pool_size=pool_size)
+
+ PooledTransportFixed.__name__ = f"{cls.__name__}_{pool_size}"
+ PooledTransportFixed.__qualname__ = PooledTransportFixed.__name__
+ return PooledTransportFixed
+
+ @classmethod
+ def create_channel(
+ cls,
+ pool_size: int = 3,
+ host: str = "bigtable.googleapis.com",
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs,
+ ) -> aio.Channel:
+ """Create and return a PooledChannel object, representing a pool of gRPC AsyncIO channels
+ Args:
+ pool_size (int): The number of channels in the pool.
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ PooledChannel: a channel pool object
+ """
+
+ return PooledChannel(
+ pool_size,
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ quota_project_id=quota_project_id,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
+ **kwargs,
+ )
+
+ def __init__(
+ self,
+ *,
+ pool_size: int = 3,
+ host: str = "bigtable.googleapis.com",
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ api_mtls_endpoint: Optional[str] = None,
+ client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+ client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ api_audience: Optional[str] = None,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ pool_size (int): the number of grpc channels to maintain in a pool
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ ValueError: if ``pool_size`` <= 0
+ """
+ if pool_size <= 0:
+ raise ValueError(f"invalid pool_size: {pool_size}")
+ self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+ else:
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+
+ # The base transport sets the host, credentials and scopes
+ BigtableTransport.__init__(
+ self,
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience,
+ )
+ self._quota_project_id = quota_project_id
+ self._grpc_channel = type(self).create_channel(
+ pool_size,
+ self._host,
+ # use the credentials which are saved
+ credentials=self._credentials,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=self._quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
+ @property
+ def pool_size(self) -> int:
+ """The number of grpc channels in the pool."""
+ return len(self._grpc_channel._pool)
+
+ @property
+ def channels(self) -> List[grpc.Channel]:
+ """Access the internal list of grpc channels."""
+ return self._grpc_channel._pool
+
+ async def replace_channel(
+ self, channel_idx, grace=None, swap_sleep=1, new_channel=None
+ ) -> aio.Channel:
+ """
+ Replaces a channel in the pool with a fresh one.
+
+ The `new_channel` will start processing new requests immediately,
+ but the old channel will continue serving existing clients for `grace` seconds.
+
+ Args:
+ channel_idx(int): the channel index in the pool to replace
+ grace(Optional[float]): The time to wait until all active RPCs are
+ finished. If a grace period is not specified (by passing None for
+ grace), all existing RPCs are cancelled immediately.
+ swap_sleep(Optional[float]): The number of seconds to sleep in between
+ replacing channels and closing the old one
+ new_channel(grpc.aio.Channel): a new channel to insert into the pool
+ at `channel_idx`. If `None`, a new channel will be created.
+ """
+ return await self._grpc_channel.replace_channel(
+ channel_idx, grace, swap_sleep, new_channel
+ )
+
+
+__all__ = ("PooledBigtableGrpcAsyncIOTransport",)
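
A hedged sketch of driving the new pooled transport directly; the async data client normally owns this (including periodic channel refresh), and constructing the transport pulls credentials from the environment, so the run line is left commented out:

```python
import asyncio
from google.cloud.bigtable_v2.services.bigtable.transports import (
    PooledBigtableGrpcAsyncIOTransport,
)

async def rotate_first_channel(transport: PooledBigtableGrpcAsyncIOTransport) -> None:
    print("channels in pool:", transport.pool_size)
    # Swap channel 0 for a fresh one; in-flight RPCs on the old channel get
    # `grace` seconds to finish before it is closed.
    await transport.replace_channel(channel_idx=0, grace=1, swap_sleep=0.1)

# asyncio.run(rotate_first_channel(PooledBigtableGrpcAsyncIOTransport(pool_size=3)))
```
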
diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/rest.py b/google/cloud/bigtable_v2/services/bigtable/transports/rest.py
index 31d230f94..17b47cb1c 100644
--- a/google/cloud/bigtable_v2/services/bigtable/transports/rest.py
+++ b/google/cloud/bigtable_v2/services/bigtable/transports/rest.py
@@ -34,9 +34,9 @@
import warnings
try:
- OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
except AttributeError: # pragma: NO COVER
- OptionalRetry = Union[retries.Retry, object] # type: ignore
+ OptionalRetry = Union[retries.Retry, object, None] # type: ignore
from google.cloud.bigtable_v2.types import bigtable
@@ -386,7 +386,7 @@ def __init__(
Args:
host (Optional[str]):
- The hostname to connect to.
+ The hostname to connect to (default: 'bigtable.googleapis.com').
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
diff --git a/google/cloud/bigtable_v2/types/feature_flags.py b/google/cloud/bigtable_v2/types/feature_flags.py
index 92ac5023d..45e673f75 100644
--- a/google/cloud/bigtable_v2/types/feature_flags.py
+++ b/google/cloud/bigtable_v2/types/feature_flags.py
@@ -59,6 +59,14 @@ class FeatureFlags(proto.Message):
Notify the server that the client supports the
last_scanned_row field in ReadRowsResponse for long-running
scans.
+ routing_cookie (bool):
+ Notify the server that the client supports
+ using encoded routing cookie strings to retry
+ requests with.
+ retry_info (bool):
+ Notify the server that the client supports
+ using retry info back off durations to retry
+ requests with.
"""
reverse_scans: bool = proto.Field(
@@ -77,6 +85,14 @@ class FeatureFlags(proto.Message):
proto.BOOL,
number=4,
)
+ routing_cookie: bool = proto.Field(
+ proto.BOOL,
+ number=6,
+ )
+ retry_info: bool = proto.Field(
+ proto.BOOL,
+ number=7,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
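
The two new feature flags are ordinary proto-plus fields; a minimal sketch of setting them on the message defined above:

```python
from google.cloud.bigtable_v2.types import FeatureFlags

flags = FeatureFlags(
    reverse_scans=True,
    routing_cookie=True,  # new field number 6
    retry_info=True,      # new field number 7
)
print(FeatureFlags.serialize(flags))
```
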
diff --git a/noxfile.py b/noxfile.py
index a6fb7d6f3..daf730a9a 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -54,7 +54,9 @@
"pytest",
"google-cloud-testutils",
]
-SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = []
+SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [
+ "pytest-asyncio",
+]
SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = []
SYSTEM_TEST_DEPENDENCIES: List[str] = []
SYSTEM_TEST_EXTRAS: List[str] = []
@@ -134,8 +136,18 @@ def mypy(session):
"mypy", "types-setuptools", "types-protobuf", "types-mock", "types-requests"
)
session.install("google-cloud-testutils")
- # TODO: also verify types on tests, all of google package
- session.run("mypy", "-p", "google", "-p", "tests")
+ session.run(
+ "mypy",
+ "-p",
+ "google.cloud.bigtable.data",
+ "--check-untyped-defs",
+ "--warn-unreachable",
+ "--disallow-any-generics",
+ "--exclude",
+ "tests/system/v2_client",
+ "--exclude",
+ "tests/unit/v2_client",
+ )
@nox.session(python=DEFAULT_PYTHON_VERSION)
@@ -260,6 +272,24 @@ def system_emulated(session):
os.killpg(os.getpgid(p.pid), signal.SIGKILL)
+@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
+def conformance(session):
+ TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git"
+ CLONE_REPO_DIR = "cloud-bigtable-clients-test"
+ # install dependencies
+ constraints_path = str(
+ CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
+ )
+ install_unittest_dependencies(session, "-c", constraints_path)
+ with session.chdir("test_proxy"):
+ # download the conformance test suite
+ clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR)
+ if not os.path.exists(clone_dir):
+ print("downloading copy of test repo")
+ session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR, external=True)
+ session.run("bash", "-e", "run_tests.sh", external=True)
+
+
@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
def system(session):
"""Run the system test suite."""
@@ -311,7 +341,7 @@ def cover(session):
test runs (not system test runs), and then erases coverage data.
"""
session.install("coverage", "pytest-cov")
- session.run("coverage", "report", "--show-missing", "--fail-under=100")
+ session.run("coverage", "report", "--show-missing", "--fail-under=99")
session.run("coverage", "erase")
@@ -322,7 +352,16 @@ def docs(session):
session.install("-e", ".")
session.install(
- "sphinx==4.0.1",
+ # We need to pin to specific versions of the `sphinxcontrib-*` packages
+ # which still support sphinx 4.x.
+ # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344
+ # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345.
+ "sphinxcontrib-applehelp==1.0.4",
+ "sphinxcontrib-devhelp==1.0.2",
+ "sphinxcontrib-htmlhelp==2.0.1",
+ "sphinxcontrib-qthelp==1.0.3",
+ "sphinxcontrib-serializinghtml==1.1.5",
+ "sphinx==4.5.0",
"alabaster",
"recommonmark",
)
@@ -348,6 +387,15 @@ def docfx(session):
session.install("-e", ".")
session.install(
+ # We need to pin to specific versions of the `sphinxcontrib-*` packages
+ # which still support sphinx 4.x.
+ # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344
+ # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345.
+ "sphinxcontrib-applehelp==1.0.4",
+ "sphinxcontrib-devhelp==1.0.2",
+ "sphinxcontrib-htmlhelp==2.0.1",
+ "sphinxcontrib-qthelp==1.0.3",
+ "sphinxcontrib-serializinghtml==1.1.5",
"gcp-sphinx-docfx-yaml",
"alabaster",
"recommonmark",
diff --git a/owlbot.py b/owlbot.py
index 4b06aea77..3fb079396 100644
--- a/owlbot.py
+++ b/owlbot.py
@@ -89,7 +89,10 @@ def get_staging_dirs(
samples=True, # set to True only if there are samples
split_system_tests=True,
microgenerator=True,
- cov_level=100,
+ cov_level=99,
+ system_test_external_dependencies=[
+ "pytest-asyncio",
+ ],
)
s.move(templated_files, excludes=[".coveragerc", "README.rst", ".github/release-please.yml"])
@@ -142,7 +145,35 @@ def system_emulated(session):
escape="()"
)
-# add system_emulated nox session
+conformance_session = """
+@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
+def conformance(session):
+ TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git"
+ CLONE_REPO_DIR = "cloud-bigtable-clients-test"
+ # install dependencies
+ constraints_path = str(
+ CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
+ )
+ install_unittest_dependencies(session, "-c", constraints_path)
+ with session.chdir("test_proxy"):
+ # download the conformance test suite
+ clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR)
+ if not os.path.exists(clone_dir):
+ print("downloading copy of test repo")
+ session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR, external=True)
+ session.run("bash", "-e", "run_tests.sh", external=True)
+
+"""
+
+place_before(
+ "noxfile.py",
+ "@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)\n"
+ "def system(session):",
+ conformance_session,
+ escape="()"
+)
+
+# add system_emulated, mypy, and conformance nox sessions
s.replace("noxfile.py",
"""nox.options.sessions = \[
"unit",
@@ -168,8 +199,18 @@ def mypy(session):
session.install("-e", ".")
session.install("mypy", "types-setuptools", "types-protobuf", "types-mock", "types-requests")
session.install("google-cloud-testutils")
- # TODO: also verify types on tests, all of google package
- session.run("mypy", "-p", "google", "-p", "tests")
+ session.run(
+ "mypy",
+ "-p",
+ "google.cloud.bigtable.data",
+ "--check-untyped-defs",
+ "--warn-unreachable",
+ "--disallow-any-generics",
+ "--exclude",
+ "tests/system/v2_client",
+ "--exclude",
+ "tests/unit/v2_client",
+ )
@nox.session(python=DEFAULT_PYTHON_VERSION)
diff --git a/python-api-core b/python-api-core
new file mode 160000
index 000000000..17ff5f1d8
--- /dev/null
+++ b/python-api-core
@@ -0,0 +1 @@
+Subproject commit 17ff5f1d83a9a6f50a0226fb0e794634bd584f17
diff --git a/samples/beam/requirements-test.txt b/samples/beam/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/beam/requirements-test.txt
+++ b/samples/beam/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/beam/requirements.txt b/samples/beam/requirements.txt
index 9b95d0b52..70b1371ae 100644
--- a/samples/beam/requirements.txt
+++ b/samples/beam/requirements.txt
@@ -1,3 +1,3 @@
-apache-beam==2.46.0
-google-cloud-bigtable==2.17.0
-google-cloud-core==2.3.3
+apache-beam==2.53.0
+google-cloud-bigtable==2.22.0
+google-cloud-core==2.4.1
diff --git a/samples/hello/requirements-test.txt b/samples/hello/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/hello/requirements-test.txt
+++ b/samples/hello/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/hello/requirements.txt b/samples/hello/requirements.txt
index a76d144e6..68419fbcb 100644
--- a/samples/hello/requirements.txt
+++ b/samples/hello/requirements.txt
@@ -1,2 +1,2 @@
-google-cloud-bigtable==2.20.0
-google-cloud-core==2.3.3
+google-cloud-bigtable==2.22.0
+google-cloud-core==2.4.1
diff --git a/samples/hello_happybase/requirements-test.txt b/samples/hello_happybase/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/hello_happybase/requirements-test.txt
+++ b/samples/hello_happybase/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/instanceadmin/requirements-test.txt b/samples/instanceadmin/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/instanceadmin/requirements-test.txt
+++ b/samples/instanceadmin/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/instanceadmin/requirements.txt b/samples/instanceadmin/requirements.txt
index bba9ed8cf..a01a0943c 100644
--- a/samples/instanceadmin/requirements.txt
+++ b/samples/instanceadmin/requirements.txt
@@ -1,2 +1,2 @@
-google-cloud-bigtable==2.20.0
+google-cloud-bigtable==2.22.0
backoff==2.2.1
diff --git a/samples/metricscaler/requirements-test.txt b/samples/metricscaler/requirements-test.txt
index d8ae088dd..8b8270b6c 100644
--- a/samples/metricscaler/requirements-test.txt
+++ b/samples/metricscaler/requirements-test.txt
@@ -1,3 +1,3 @@
-pytest==7.4.0
+pytest==8.0.0
mock==5.1.0
google-cloud-testutils
diff --git a/samples/metricscaler/requirements.txt b/samples/metricscaler/requirements.txt
index c0fce2294..be3b2b222 100644
--- a/samples/metricscaler/requirements.txt
+++ b/samples/metricscaler/requirements.txt
@@ -1,2 +1,2 @@
-google-cloud-bigtable==2.20.0
-google-cloud-monitoring==2.15.1
+google-cloud-bigtable==2.22.0
+google-cloud-monitoring==2.19.0
diff --git a/samples/quickstart/requirements-test.txt b/samples/quickstart/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/quickstart/requirements-test.txt
+++ b/samples/quickstart/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/quickstart/requirements.txt b/samples/quickstart/requirements.txt
index 83e37754e..6dc985893 100644
--- a/samples/quickstart/requirements.txt
+++ b/samples/quickstart/requirements.txt
@@ -1 +1 @@
-google-cloud-bigtable==2.20.0
+google-cloud-bigtable==2.22.0
diff --git a/samples/quickstart_happybase/requirements-test.txt b/samples/quickstart_happybase/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/quickstart_happybase/requirements-test.txt
+++ b/samples/quickstart_happybase/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/snippets/deletes/requirements-test.txt b/samples/snippets/deletes/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/snippets/deletes/requirements-test.txt
+++ b/samples/snippets/deletes/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/snippets/deletes/requirements.txt b/samples/snippets/deletes/requirements.txt
index 85b4e786f..ae10593d2 100644
--- a/samples/snippets/deletes/requirements.txt
+++ b/samples/snippets/deletes/requirements.txt
@@ -1,2 +1,2 @@
-google-cloud-bigtable==2.20.0
+google-cloud-bigtable==2.22.0
snapshottest==0.6.0
\ No newline at end of file
diff --git a/samples/snippets/filters/requirements-test.txt b/samples/snippets/filters/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/snippets/filters/requirements-test.txt
+++ b/samples/snippets/filters/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/snippets/filters/requirements.txt b/samples/snippets/filters/requirements.txt
index 85b4e786f..ae10593d2 100644
--- a/samples/snippets/filters/requirements.txt
+++ b/samples/snippets/filters/requirements.txt
@@ -1,2 +1,2 @@
-google-cloud-bigtable==2.20.0
+google-cloud-bigtable==2.22.0
snapshottest==0.6.0
\ No newline at end of file
diff --git a/samples/snippets/reads/requirements-test.txt b/samples/snippets/reads/requirements-test.txt
index 70613be0c..8075a1ec5 100644
--- a/samples/snippets/reads/requirements-test.txt
+++ b/samples/snippets/reads/requirements-test.txt
@@ -1 +1 @@
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/snippets/reads/requirements.txt b/samples/snippets/reads/requirements.txt
index 85b4e786f..ae10593d2 100644
--- a/samples/snippets/reads/requirements.txt
+++ b/samples/snippets/reads/requirements.txt
@@ -1,2 +1,2 @@
-google-cloud-bigtable==2.20.0
+google-cloud-bigtable==2.22.0
snapshottest==0.6.0
\ No newline at end of file
diff --git a/samples/snippets/writes/requirements-test.txt b/samples/snippets/writes/requirements-test.txt
index cbd0a47de..aaa563abc 100644
--- a/samples/snippets/writes/requirements-test.txt
+++ b/samples/snippets/writes/requirements-test.txt
@@ -1,2 +1,2 @@
backoff==2.2.1
-pytest==7.4.0
+pytest==8.0.0
diff --git a/samples/snippets/writes/requirements.txt b/samples/snippets/writes/requirements.txt
index 90fa5577c..07b0a191d 100644
--- a/samples/snippets/writes/requirements.txt
+++ b/samples/snippets/writes/requirements.txt
@@ -1 +1 @@
-google-cloud-bigtable==2.20.0
\ No newline at end of file
+google-cloud-bigtable==2.22.0
\ No newline at end of file
diff --git a/samples/tableadmin/requirements-test.txt b/samples/tableadmin/requirements-test.txt
index b4ead9993..b4d30f505 100644
--- a/samples/tableadmin/requirements-test.txt
+++ b/samples/tableadmin/requirements-test.txt
@@ -1,2 +1,2 @@
-pytest==7.4.0
-google-cloud-testutils==1.3.3
+pytest==8.0.0
+google-cloud-testutils==1.4.0
diff --git a/samples/tableadmin/requirements.txt b/samples/tableadmin/requirements.txt
index 83e37754e..6dc985893 100644
--- a/samples/tableadmin/requirements.txt
+++ b/samples/tableadmin/requirements.txt
@@ -1 +1 @@
-google-cloud-bigtable==2.20.0
+google-cloud-bigtable==2.22.0
diff --git a/scripts/fixup_bigtable_admin_v2_keywords.py b/scripts/fixup_bigtable_admin_v2_keywords.py
index 6882feaf6..8c3efea10 100644
--- a/scripts/fixup_bigtable_admin_v2_keywords.py
+++ b/scripts/fixup_bigtable_admin_v2_keywords.py
@@ -69,7 +69,7 @@ class bigtable_adminCallTransformer(cst.CSTTransformer):
'list_instances': ('parent', 'page_token', ),
'list_snapshots': ('parent', 'page_size', 'page_token', ),
'list_tables': ('parent', 'view', 'page_size', 'page_token', ),
- 'modify_column_families': ('name', 'modifications', ),
+ 'modify_column_families': ('name', 'modifications', 'ignore_warnings', ),
'partial_update_cluster': ('cluster', 'update_mask', ),
'partial_update_instance': ('instance', 'update_mask', ),
'restore_table': ('parent', 'table_id', 'backup', ),
diff --git a/setup.py b/setup.py
index e9bce0960..8b698a35b 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@
# 'Development Status :: 5 - Production/Stable'
release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
- "google-api-core[grpc] >= 1.34.0, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*",
+ "google-api-core[grpc] >= 2.16.0, <3.0.0dev",
"google-cloud-core >= 1.4.4, <3.0.0dev",
"grpc-google-iam-v1 >= 0.12.4, <1.0.0dev",
"proto-plus >= 1.22.0, <2.0.0dev",
diff --git a/test_proxy/README.md b/test_proxy/README.md
new file mode 100644
index 000000000..08741fd5d
--- /dev/null
+++ b/test_proxy/README.md
@@ -0,0 +1,60 @@
+# CBT Python Test Proxy
+
+The CBT test proxy is intended for running conformance tests against the Cloud Bigtable Python client.
+
+## Option 1: Run Tests with Nox
+
+You can run the conformance tests with a single command by calling `nox -s conformance` from the `test_proxy` directory:
+
+
+```
+cd python-bigtable/test_proxy
+nox -s conformance
+```
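+
+The noxfile also reads a `PROXY_SERVER_PORT` environment variable (defaulting to `50055`), so the conformance run can target a different port, for example:
+
+```
+PROXY_SERVER_PORT=50056 nox -s conformance
+```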
+
+## Option 2: Run processes manually
+
+### Start test proxy
+
+You can use `test_proxy.py` to launch a new test proxy process directly:
+
+```
+cd python-bigtable/test_proxy
+python test_proxy.py
+```
+
+The port can be set with the `--port` flag:
+
+```
+cd python-bigtable/test_proxy
+python test_proxy.py --port 8080
+```
+
+You can run the test proxy against the previous `v2` client by running it with the `--legacy-client` flag:
+
+```
+python test_proxy.py --legacy-client
+```
+
+### Run the test cases
+
+Prerequisites:
+- If you have not already done so, [install golang](https://go.dev/doc/install).
+- Before running tests, [launch an instance of the test proxy](#start-test-proxy)
+in a separate shell session, and make note of the port.
+
+
+Clone and navigate to the Go test library:
+
+```
+git clone https://github.com/googleapis/cloud-bigtable-clients-test.git
+cd cloud-bigtable-clients-test/tests
+```
+
+
+Launch the tests:
+
+```
+go test -v -proxy_addr=:50055
+```
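+
+If you started the proxy on a non-default port, pass the matching address. Go's standard `-run` flag can also be used to select a subset of tests (the test name below is illustrative):
+
+```
+go test -v -proxy_addr=:8080 -run TestReadRows
+```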
+
diff --git a/test_proxy/handlers/client_handler_data.py b/test_proxy/handlers/client_handler_data.py
new file mode 100644
index 000000000..43ff5d634
--- /dev/null
+++ b/test_proxy/handlers/client_handler_data.py
@@ -0,0 +1,214 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the client handler process for proxy_server.py.
+"""
+import os
+
+from google.cloud.environment_vars import BIGTABLE_EMULATOR
+from google.cloud.bigtable.data import BigtableDataClientAsync
+
+
+def error_safe(func):
+ """
+ Catch and pass errors back to the grpc_server_process
+ Also check if client is closed before processing requests
+ """
+ async def wrapper(self, *args, **kwargs):
+ try:
+ if self.closed:
+ raise RuntimeError("client is closed")
+ return await func(self, *args, **kwargs)
+ except (Exception, NotImplementedError) as e:
+ # exceptions should be raised in grpc_server_process
+ return encode_exception(e)
+
+ return wrapper
+
+
+def encode_exception(exc):
+ """
+ Encode an exception or chain of exceptions to pass back to grpc_handler
+ """
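+    # Illustrative example of the encoded shape (hypothetical exception, not taken from the test suite):
+    #   encode_exception(ValueError("bad key")) -> {"error": "ValueError: bad key"}
+    # GoogleAPICallError instances additionally carry a numeric "code" entry.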
+ from google.api_core.exceptions import GoogleAPICallError
+ error_msg = f"{type(exc).__name__}: {exc}"
+ result = {"error": error_msg}
+ if exc.__cause__:
+ result["cause"] = encode_exception(exc.__cause__)
+ if hasattr(exc, "exceptions"):
+ result["subexceptions"] = [encode_exception(e) for e in exc.exceptions]
+ if hasattr(exc, "index"):
+ result["index"] = exc.index
+ if isinstance(exc, GoogleAPICallError):
+ if exc.grpc_status_code is not None:
+ result["code"] = exc.grpc_status_code.value[0]
+ elif exc.code is not None:
+ result["code"] = int(exc.code)
+ else:
+ result["code"] = -1
+ elif result.get("cause", {}).get("code", None):
+        # look for code in cause
+ result["code"] = result["cause"]["code"]
+ elif result.get("subexceptions", None):
+ # look for code in subexceptions
+ for subexc in result["subexceptions"]:
+ if subexc.get("code", None):
+ result["code"] = subexc["code"]
+ return result
+
+
+class TestProxyClientHandler:
+ """
+ Implements the same methods as the grpc server, but handles the client
+ library side of the request.
+
+ Requests received in TestProxyGrpcServer are converted to a dictionary,
+ and supplied to the TestProxyClientHandler methods as kwargs.
+    The client response is then returned to the TestProxyGrpcServer.
+ """
+
+ def __init__(
+ self,
+ data_target=None,
+ project_id=None,
+ instance_id=None,
+ app_profile_id=None,
+ per_operation_timeout=None,
+ **kwargs,
+ ):
+ self.closed = False
+ # use emulator
+ os.environ[BIGTABLE_EMULATOR] = data_target
+ self.client = BigtableDataClientAsync(project=project_id)
+ self.instance_id = instance_id
+ self.app_profile_id = app_profile_id
+ self.per_operation_timeout = per_operation_timeout
+
+ def close(self):
+ # TODO: call self.client.close()
+ self.closed = True
+
+ @error_safe
+ async def ReadRows(self, request, **kwargs):
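+        # `request` arrives as a plain dict: grpc_handler serializes each proto request with json_format.MessageToDict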
+ table_id = request.pop("table_name").split("/")[-1]
+ app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+ table = self.client.get_table(self.instance_id, table_id, app_profile_id)
+ kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+ result_list = await table.read_rows(request, **kwargs)
+ # pack results back into protobuf-parsable format
+ serialized_response = [row._to_dict() for row in result_list]
+ return serialized_response
+
+ @error_safe
+ async def ReadRow(self, row_key, **kwargs):
+ table_id = kwargs.pop("table_name").split("/")[-1]
+ app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None)
+ table = self.client.get_table(self.instance_id, table_id, app_profile_id)
+ kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+ result_row = await table.read_row(row_key, **kwargs)
+ # pack results back into protobuf-parsable format
+ if result_row:
+ return result_row._to_dict()
+ else:
+ return "None"
+
+ @error_safe
+ async def MutateRow(self, request, **kwargs):
+ from google.cloud.bigtable.data.mutations import Mutation
+ table_id = request["table_name"].split("/")[-1]
+ app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+ table = self.client.get_table(self.instance_id, table_id, app_profile_id)
+ kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+ row_key = request["row_key"]
+ mutations = [Mutation._from_dict(d) for d in request["mutations"]]
+ await table.mutate_row(row_key, mutations, **kwargs)
+ return "OK"
+
+ @error_safe
+ async def BulkMutateRows(self, request, **kwargs):
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+ table_id = request["table_name"].split("/")[-1]
+ app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+ table = self.client.get_table(self.instance_id, table_id, app_profile_id)
+ kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+ entry_list = [RowMutationEntry._from_dict(entry) for entry in request["entries"]]
+ await table.bulk_mutate_rows(entry_list, **kwargs)
+ return "OK"
+
+ @error_safe
+ async def CheckAndMutateRow(self, request, **kwargs):
+ from google.cloud.bigtable.data.mutations import Mutation, SetCell
+ table_id = request["table_name"].split("/")[-1]
+ app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+ table = self.client.get_table(self.instance_id, table_id, app_profile_id)
+ kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+ row_key = request["row_key"]
+ # add default values for incomplete dicts, so they can still be parsed to objects
+ true_mutations = []
+ for mut_dict in request.get("true_mutations", []):
+ try:
+ true_mutations.append(Mutation._from_dict(mut_dict))
+ except ValueError:
+ # invalid mutation type. Conformance test may be sending generic empty request
+ mutation = SetCell("", "", "", 0)
+ true_mutations.append(mutation)
+ false_mutations = []
+ for mut_dict in request.get("false_mutations", []):
+ try:
+ false_mutations.append(Mutation._from_dict(mut_dict))
+ except ValueError:
+ # invalid mutation type. Conformance test may be sending generic empty request
+ false_mutations.append(SetCell("", "", "", 0))
+ predicate_filter = request.get("predicate_filter", None)
+ result = await table.check_and_mutate_row(
+ row_key,
+ predicate_filter,
+ true_case_mutations=true_mutations,
+ false_case_mutations=false_mutations,
+ **kwargs,
+ )
+ return result
+
+ @error_safe
+ async def ReadModifyWriteRow(self, request, **kwargs):
+ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
+ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
+ table_id = request["table_name"].split("/")[-1]
+ app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+ table = self.client.get_table(self.instance_id, table_id, app_profile_id)
+ kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+ row_key = request["row_key"]
+ rules = []
+ for rule_dict in request.get("rules", []):
+ qualifier = rule_dict["column_qualifier"]
+ if "append_value" in rule_dict:
+ new_rule = AppendValueRule(rule_dict["family_name"], qualifier, rule_dict["append_value"])
+ else:
+ new_rule = IncrementRule(rule_dict["family_name"], qualifier, rule_dict["increment_amount"])
+ rules.append(new_rule)
+ result = await table.read_modify_write_row(row_key, rules, **kwargs)
+ # pack results back into protobuf-parsable format
+ if result:
+ return result._to_dict()
+ else:
+ return "None"
+
+ @error_safe
+ async def SampleRowKeys(self, request, **kwargs):
+ table_id = request["table_name"].split("/")[-1]
+ app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+ table = self.client.get_table(self.instance_id, table_id, app_profile_id)
+ kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+ result = await table.sample_row_keys(**kwargs)
+ return result
diff --git a/test_proxy/handlers/client_handler_legacy.py b/test_proxy/handlers/client_handler_legacy.py
new file mode 100644
index 000000000..400f618b5
--- /dev/null
+++ b/test_proxy/handlers/client_handler_legacy.py
@@ -0,0 +1,235 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the client handler process for proxy_server.py.
+"""
+import os
+
+from google.cloud.environment_vars import BIGTABLE_EMULATOR
+from google.cloud.bigtable.client import Client
+
+import client_handler_data as client_handler
+
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+class LegacyTestProxyClientHandler(client_handler.TestProxyClientHandler):
+
+ def __init__(
+ self,
+ data_target=None,
+ project_id=None,
+ instance_id=None,
+ app_profile_id=None,
+ per_operation_timeout=None,
+ **kwargs,
+ ):
+ self.closed = False
+ # use emulator
+ os.environ[BIGTABLE_EMULATOR] = data_target
+ self.client = Client(project=project_id)
+ self.instance_id = instance_id
+ self.app_profile_id = app_profile_id
+ self.per_operation_timeout = per_operation_timeout
+
+ def close(self):
+ self.closed = True
+
+ @client_handler.error_safe
+ async def ReadRows(self, request, **kwargs):
+ table_id = request["table_name"].split("/")[-1]
+ # app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+ instance = self.client.instance(self.instance_id)
+ table = instance.table(table_id)
+
+ limit = request.get("rows_limit", None)
+ start_key = request.get("rows", {}).get("row_keys", [None])[0]
+ end_key = request.get("rows", {}).get("row_keys", [None])[-1]
+ end_inclusive = request.get("rows", {}).get("row_ranges", [{}])[-1].get("end_key_closed", True)
+
+ row_list = []
+ for row in table.read_rows(start_key=start_key, end_key=end_key, limit=limit, end_inclusive=end_inclusive):
+ # parse results into proto formatted dict
+ dict_val = {"row_key": row.row_key}
+ for family, family_cells in row.cells.items():
+ family_dict = {"name": family}
+ for qualifier, qualifier_cells in family_cells.items():
+ column_dict = {"qualifier": qualifier}
+ for cell in qualifier_cells:
+ cell_dict = {
+ "value": cell.value,
+ "timestamp_micros": cell.timestamp.timestamp() * 1000000,
+ "labels": cell.labels,
+ }
+ column_dict.setdefault("cells", []).append(cell_dict)
+ family_dict.setdefault("columns", []).append(column_dict)
+ dict_val.setdefault("families", []).append(family_dict)
+ row_list.append(dict_val)
+ return row_list
+
+ @client_handler.error_safe
+ async def ReadRow(self, row_key, **kwargs):
+ table_id = kwargs["table_name"].split("/")[-1]
+ instance = self.client.instance(self.instance_id)
+ table = instance.table(table_id)
+
+ row = table.read_row(row_key)
+ # parse results into proto formatted dict
+ dict_val = {"row_key": row.row_key}
+ for family, family_cells in row.cells.items():
+ family_dict = {"name": family}
+ for qualifier, qualifier_cells in family_cells.items():
+ column_dict = {"qualifier": qualifier}
+ for cell in qualifier_cells:
+ cell_dict = {
+ "value": cell.value,
+ "timestamp_micros": cell.timestamp.timestamp() * 1000000,
+ "labels": cell.labels,
+ }
+ column_dict.setdefault("cells", []).append(cell_dict)
+ family_dict.setdefault("columns", []).append(column_dict)
+ dict_val.setdefault("families", []).append(family_dict)
+ return dict_val
+
+ @client_handler.error_safe
+ async def MutateRow(self, request, **kwargs):
+ from datetime import datetime
+ from google.cloud.bigtable.row import DirectRow
+ table_id = request["table_name"].split("/")[-1]
+ instance = self.client.instance(self.instance_id)
+ table = instance.table(table_id)
+ row_key = request["row_key"]
+ new_row = DirectRow(row_key, table)
+ for m_dict in request.get("mutations", []):
+ details = m_dict.get("set_cell") or m_dict.get("delete_from_column") or m_dict.get("delete_from_family") or m_dict.get("delete_from_row")
+ timestamp = datetime.fromtimestamp(details.get("timestamp_micros")) if details.get("timestamp_micros") else None
+ if m_dict.get("set_cell"):
+ new_row.set_cell(details["family_name"], details["column_qualifier"], details["value"], timestamp=timestamp)
+ elif m_dict.get("delete_from_column"):
+ new_row.delete_cell(details["family_name"], details["column_qualifier"], timestamp=timestamp)
+ elif m_dict.get("delete_from_family"):
+ new_row.delete_cells(details["family_name"], timestamp=timestamp)
+ elif m_dict.get("delete_from_row"):
+ new_row.delete()
+ table.mutate_rows([new_row])
+ return "OK"
+
+ @client_handler.error_safe
+ async def BulkMutateRows(self, request, **kwargs):
+ from google.cloud.bigtable.row import DirectRow
+ from datetime import datetime
+ table_id = request["table_name"].split("/")[-1]
+ instance = self.client.instance(self.instance_id)
+ table = instance.table(table_id)
+ rows = []
+ for entry in request.get("entries", []):
+ row_key = entry["row_key"]
+ new_row = DirectRow(row_key, table)
+ for m_dict in entry.get("mutations"):
+ details = m_dict.get("set_cell") or m_dict.get("delete_from_column") or m_dict.get("delete_from_family") or m_dict.get("delete_from_row")
+ timestamp = datetime.fromtimestamp(details.get("timestamp_micros")) if details.get("timestamp_micros") else None
+ if m_dict.get("set_cell"):
+ new_row.set_cell(details["family_name"], details["column_qualifier"], details["value"], timestamp=timestamp)
+ elif m_dict.get("delete_from_column"):
+ new_row.delete_cell(details["family_name"], details["column_qualifier"], timestamp=timestamp)
+ elif m_dict.get("delete_from_family"):
+ new_row.delete_cells(details["family_name"], timestamp=timestamp)
+ elif m_dict.get("delete_from_row"):
+ new_row.delete()
+ rows.append(new_row)
+ table.mutate_rows(rows)
+ return "OK"
+
+ @client_handler.error_safe
+ async def CheckAndMutateRow(self, request, **kwargs):
+ from google.cloud.bigtable.row import ConditionalRow
+ from google.cloud.bigtable.row_filters import PassAllFilter
+ table_id = request["table_name"].split("/")[-1]
+ instance = self.client.instance(self.instance_id)
+ table = instance.table(table_id)
+
+ predicate_filter = request.get("predicate_filter", PassAllFilter(True))
+ new_row = ConditionalRow(request["row_key"], table, predicate_filter)
+
+ combined_mutations = [{"state": True, **m} for m in request.get("true_mutations", [])]
+ combined_mutations.extend([{"state": False, **m} for m in request.get("false_mutations", [])])
+ for mut_dict in combined_mutations:
+ if "set_cell" in mut_dict:
+ details = mut_dict["set_cell"]
+ new_row.set_cell(
+ details.get("family_name", ""),
+ details.get("column_qualifier", ""),
+ details.get("value", ""),
+ timestamp=details.get("timestamp_micros", None),
+ state=mut_dict["state"],
+ )
+ elif "delete_from_column" in mut_dict:
+ details = mut_dict["delete_from_column"]
+ new_row.delete_cell(
+ details.get("family_name", ""),
+ details.get("column_qualifier", ""),
+ timestamp=details.get("timestamp_micros", None),
+ state=mut_dict["state"],
+ )
+ elif "delete_from_family" in mut_dict:
+ details = mut_dict["delete_from_family"]
+ new_row.delete_cells(
+ details.get("family_name", ""),
+ timestamp=details.get("timestamp_micros", None),
+ state=mut_dict["state"],
+ )
+ elif "delete_from_row" in mut_dict:
+ new_row.delete(state=mut_dict["state"])
+ else:
+ raise RuntimeError(f"Unknown mutation type: {mut_dict}")
+ return new_row.commit()
+
+ @client_handler.error_safe
+ async def ReadModifyWriteRow(self, request, **kwargs):
+ from google.cloud.bigtable.row import AppendRow
+ from google.cloud._helpers import _microseconds_from_datetime
+ table_id = request["table_name"].split("/")[-1]
+ instance = self.client.instance(self.instance_id)
+ table = instance.table(table_id)
+ row_key = request["row_key"]
+ new_row = AppendRow(row_key, table)
+ for rule_dict in request.get("rules", []):
+ qualifier = rule_dict["column_qualifier"]
+ family = rule_dict["family_name"]
+ if "append_value" in rule_dict:
+ new_row.append_cell_value(family, qualifier, rule_dict["append_value"])
+ else:
+ new_row.increment_cell_value(family, qualifier, rule_dict["increment_amount"])
+ raw_result = new_row.commit()
+ result_families = []
+ for family, column_dict in raw_result.items():
+ result_columns = []
+ for column, cell_list in column_dict.items():
+ result_cells = []
+ for cell_tuple in cell_list:
+ cell_dict = {"value": cell_tuple[0], "timestamp_micros": _microseconds_from_datetime(cell_tuple[1])}
+ result_cells.append(cell_dict)
+ result_columns.append({"qualifier": column, "cells": result_cells})
+ result_families.append({"name": family, "columns": result_columns})
+ return {"key": row_key, "families": result_families}
+
+ @client_handler.error_safe
+ async def SampleRowKeys(self, request, **kwargs):
+ table_id = request["table_name"].split("/")[-1]
+ instance = self.client.instance(self.instance_id)
+ table = instance.table(table_id)
+ response = list(table.sample_row_keys())
+ tuple_response = [(s.row_key, s.offset_bytes) for s in response]
+ return tuple_response
diff --git a/test_proxy/handlers/grpc_handler.py b/test_proxy/handlers/grpc_handler.py
new file mode 100644
index 000000000..2c70778dd
--- /dev/null
+++ b/test_proxy/handlers/grpc_handler.py
@@ -0,0 +1,148 @@
+
+import time
+
+import test_proxy_pb2
+import test_proxy_pb2_grpc
+import data_pb2
+import bigtable_pb2
+from google.rpc.status_pb2 import Status
+from google.protobuf import json_format
+
+
+class TestProxyGrpcServer(test_proxy_pb2_grpc.CloudBigtableV2TestProxyServicer):
+ """
+ Implements a grpc server that proxies conformance test requests to the client library
+
+ Due to issues with using protoc-compiled protos and client-library
+ proto-plus objects in the same process, this server defers requests to
+ matching methods in a TestProxyClientHandler instance in a separate
+ process.
+    This happens invisibly in the decorator @delegate_to_client_handler, with the
+ results attached to each request as a client_response kwarg
+ """
+
+ def __init__(self, request_q, queue_pool):
+ self.open_queues = list(range(len(queue_pool)))
+ self.queue_pool = queue_pool
+ self.request_q = request_q
+
+ def delegate_to_client_handler(func, timeout_seconds=300):
+ """
+ Decorator that transparently passes a request to the client
+        handler process, and then attaches the response to the wrapped call
+ """
+
+ def wrapper(self, request, context, **kwargs):
+ deadline = time.time() + timeout_seconds
+ json_dict = json_format.MessageToDict(request)
+ out_idx = self.open_queues.pop()
+ json_dict["proxy_request"] = func.__name__
+ json_dict["response_queue_idx"] = out_idx
+ out_q = self.queue_pool[out_idx]
+ self.request_q.put(json_dict)
+ # wait for response
+ while time.time() < deadline:
+ if not out_q.empty():
+ response = out_q.get()
+ self.open_queues.append(out_idx)
+ if isinstance(response, Exception):
+ raise response
+ else:
+ return func(
+ self,
+ request,
+ context,
+ client_response=response,
+ **kwargs,
+ )
+ time.sleep(1e-4)
+
+ return wrapper
+
+
+ @delegate_to_client_handler
+ def CreateClient(self, request, context, client_response=None):
+ return test_proxy_pb2.CreateClientResponse()
+
+ @delegate_to_client_handler
+ def CloseClient(self, request, context, client_response=None):
+ return test_proxy_pb2.CloseClientResponse()
+
+ @delegate_to_client_handler
+ def RemoveClient(self, request, context, client_response=None):
+ return test_proxy_pb2.RemoveClientResponse()
+
+ @delegate_to_client_handler
+ def ReadRows(self, request, context, client_response=None):
+ status = Status()
+ rows = []
+ if isinstance(client_response, dict) and "error" in client_response:
+ status = Status(code=5, message=client_response["error"])
+ else:
+ rows = [data_pb2.Row(**d) for d in client_response]
+ result = test_proxy_pb2.RowsResult(row=rows, status=status)
+ return result
+
+ @delegate_to_client_handler
+ def ReadRow(self, request, context, client_response=None):
+ status = Status()
+ row = None
+ if isinstance(client_response, dict) and "error" in client_response:
+ status=Status(code=client_response.get("code", 5), message=client_response.get("error"))
+ elif client_response != "None":
+ row = data_pb2.Row(**client_response)
+ result = test_proxy_pb2.RowResult(row=row, status=status)
+ return result
+
+ @delegate_to_client_handler
+ def MutateRow(self, request, context, client_response=None):
+ status = Status()
+ if isinstance(client_response, dict) and "error" in client_response:
+ status = Status(code=client_response.get("code", 5), message=client_response["error"])
+ return test_proxy_pb2.MutateRowResult(status=status)
+
+ @delegate_to_client_handler
+ def BulkMutateRows(self, request, context, client_response=None):
+ status = Status()
+ entries = []
+ if isinstance(client_response, dict) and "error" in client_response:
+ entries = [bigtable_pb2.MutateRowsResponse.Entry(index=exc_dict.get("index",1), status=Status(code=exc_dict.get("code", 5))) for exc_dict in client_response.get("subexceptions", [])]
+ if not entries:
+ # only return failure on the overall request if there are failed entries
+ status = Status(code=client_response.get("code", 5), message=client_response["error"])
+ # TODO: protos were updated. entry is now entries: https://github.com/googleapis/cndb-client-testing-protos/commit/e6205a2bba04acc10d12421a1402870b4a525fb3
+ response = test_proxy_pb2.MutateRowsResult(status=status, entry=entries)
+ return response
+
+ @delegate_to_client_handler
+ def CheckAndMutateRow(self, request, context, client_response=None):
+ if isinstance(client_response, dict) and "error" in client_response:
+ status = Status(code=client_response.get("code", 5), message=client_response["error"])
+ response = test_proxy_pb2.CheckAndMutateRowResult(status=status)
+ else:
+ result = bigtable_pb2.CheckAndMutateRowResponse(predicate_matched=client_response)
+ response = test_proxy_pb2.CheckAndMutateRowResult(result=result, status=Status())
+ return response
+
+ @delegate_to_client_handler
+ def ReadModifyWriteRow(self, request, context, client_response=None):
+ status = Status()
+ row = None
+ if isinstance(client_response, dict) and "error" in client_response:
+ status = Status(code=client_response.get("code", 5), message=client_response.get("error"))
+ elif client_response != "None":
+ row = data_pb2.Row(**client_response)
+ result = test_proxy_pb2.RowResult(row=row, status=status)
+ return result
+
+ @delegate_to_client_handler
+ def SampleRowKeys(self, request, context, client_response=None):
+ status = Status()
+ sample_list = []
+ if isinstance(client_response, dict) and "error" in client_response:
+ status = Status(code=client_response.get("code", 5), message=client_response.get("error"))
+ else:
+ for sample in client_response:
+ sample_list.append(bigtable_pb2.SampleRowKeysResponse(offset_bytes=sample[1], row_key=sample[0]))
+ # TODO: protos were updated. sample is now samples: https://github.com/googleapis/cndb-client-testing-protos/commit/e6205a2bba04acc10d12421a1402870b4a525fb3
+ return test_proxy_pb2.SampleRowKeysResult(status=status, sample=sample_list)
diff --git a/test_proxy/noxfile.py b/test_proxy/noxfile.py
new file mode 100644
index 000000000..bebf247b7
--- /dev/null
+++ b/test_proxy/noxfile.py
@@ -0,0 +1,80 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import os
+import pathlib
+import re
+from colorlog.escape_codes import parse_colors
+
+import nox
+
+
+DEFAULT_PYTHON_VERSION = "3.10"
+
+PROXY_SERVER_PORT=os.environ.get("PROXY_SERVER_PORT", "50055")
+PROXY_CLIENT_VERSION=os.environ.get("PROXY_CLIENT_VERSION", None)
+
+CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
+REPO_ROOT_DIRECTORY = CURRENT_DIRECTORY.parent
+
+nox.options.sessions = ["run_proxy", "conformance_tests"]
+
+TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git"
+CLONE_REPO_DIR = "cloud-bigtable-clients-test"
+
+# Error if a python version is missing
+nox.options.error_on_missing_interpreters = True
+
+
+def default(session):
+ """
+ if nox is run directly, run the test_proxy session
+ """
+ test_proxy(session)
+
+
+@nox.session(python=DEFAULT_PYTHON_VERSION)
+def conformance_tests(session):
+ """
+ download and run the conformance test suite against the test proxy
+ """
+ import subprocess
+ import time
+ # download the conformance test suite
+ clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR)
+ if not os.path.exists(clone_dir):
+ print("downloading copy of test repo")
+ session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR)
+ # start tests
+ with session.chdir(f"{clone_dir}/tests"):
+ session.run("go", "test", "-v", f"-proxy_addr=:{PROXY_SERVER_PORT}")
+
+@nox.session(python=DEFAULT_PYTHON_VERSION)
+def test_proxy(session):
+ """Start up the test proxy"""
+ # Install all dependencies, then install this package into the
+ # virtualenv's dist-packages.
+ # session.install(
+ # "grpcio",
+ # )
+ if PROXY_CLIENT_VERSION is not None:
+ # install released version of the library
+ session.install(f"python-bigtable=={PROXY_CLIENT_VERSION}")
+ else:
+ # install the library from the source
+ session.install("-e", str(REPO_ROOT_DIRECTORY))
+ session.install("-e", str(REPO_ROOT_DIRECTORY / "python-api-core"))
+
+ session.run("python", "test_proxy.py", "--port", PROXY_SERVER_PORT, *session.posargs,)
diff --git a/test_proxy/protos/bigtable_pb2.py b/test_proxy/protos/bigtable_pb2.py
new file mode 100644
index 000000000..936a4ed55
--- /dev/null
+++ b/test_proxy/protos/bigtable_pb2.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: google/bigtable/v2/bigtable.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2
+from google.api import client_pb2 as google_dot_api_dot_client__pb2
+from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2
+from google.api import resource_pb2 as google_dot_api_dot_resource__pb2
+from google.api import routing_pb2 as google_dot_api_dot_routing__pb2
+import data_pb2 as google_dot_bigtable_dot_v2_dot_data__pb2
+import request_stats_pb2 as google_dot_bigtable_dot_v2_dot_request__stats__pb2
+from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2
+from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
+from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2
+from google.rpc import status_pb2 as google_dot_rpc_dot_status__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n!google/bigtable/v2/bigtable.proto\x12\x12google.bigtable.v2\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x18google/api/routing.proto\x1a\x1dgoogle/bigtable/v2/data.proto\x1a&google/bigtable/v2/request_stats.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/wrappers.proto\x1a\x17google/rpc/status.proto\"\x90\x03\n\x0fReadRowsRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x05 \x01(\t\x12(\n\x04rows\x18\x02 \x01(\x0b\x32\x1a.google.bigtable.v2.RowSet\x12-\n\x06\x66ilter\x18\x03 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x12\n\nrows_limit\x18\x04 \x01(\x03\x12P\n\x12request_stats_view\x18\x06 \x01(\x0e\x32\x34.google.bigtable.v2.ReadRowsRequest.RequestStatsView\"f\n\x10RequestStatsView\x12\"\n\x1eREQUEST_STATS_VIEW_UNSPECIFIED\x10\x00\x12\x16\n\x12REQUEST_STATS_NONE\x10\x01\x12\x16\n\x12REQUEST_STATS_FULL\x10\x02\"\xb1\x03\n\x10ReadRowsResponse\x12>\n\x06\x63hunks\x18\x01 \x03(\x0b\x32..google.bigtable.v2.ReadRowsResponse.CellChunk\x12\x1c\n\x14last_scanned_row_key\x18\x02 \x01(\x0c\x12\x37\n\rrequest_stats\x18\x03 \x01(\x0b\x32 .google.bigtable.v2.RequestStats\x1a\x85\x02\n\tCellChunk\x12\x0f\n\x07row_key\x18\x01 \x01(\x0c\x12\x31\n\x0b\x66\x61mily_name\x18\x02 \x01(\x0b\x32\x1c.google.protobuf.StringValue\x12.\n\tqualifier\x18\x03 \x01(\x0b\x32\x1b.google.protobuf.BytesValue\x12\x18\n\x10timestamp_micros\x18\x04 \x01(\x03\x12\x0e\n\x06labels\x18\x05 \x03(\t\x12\r\n\x05value\x18\x06 \x01(\x0c\x12\x12\n\nvalue_size\x18\x07 \x01(\x05\x12\x13\n\treset_row\x18\x08 \x01(\x08H\x00\x12\x14\n\ncommit_row\x18\t \x01(\x08H\x00\x42\x0c\n\nrow_status\"n\n\x14SampleRowKeysRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\">\n\x15SampleRowKeysResponse\x12\x0f\n\x07row_key\x18\x01 \x01(\x0c\x12\x14\n\x0coffset_bytes\x18\x02 \x01(\x03\"\xb6\x01\n\x10MutateRowRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x04 \x01(\t\x12\x14\n\x07row_key\x18\x02 \x01(\x0c\x42\x03\xe0\x41\x02\x12\x34\n\tmutations\x18\x03 \x03(\x0b\x32\x1c.google.bigtable.v2.MutationB\x03\xe0\x41\x02\"\x13\n\x11MutateRowResponse\"\xfe\x01\n\x11MutateRowsRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x03 \x01(\t\x12\x41\n\x07\x65ntries\x18\x02 \x03(\x0b\x32+.google.bigtable.v2.MutateRowsRequest.EntryB\x03\xe0\x41\x02\x1aN\n\x05\x45ntry\x12\x0f\n\x07row_key\x18\x01 \x01(\x0c\x12\x34\n\tmutations\x18\x02 \x03(\x0b\x32\x1c.google.bigtable.v2.MutationB\x03\xe0\x41\x02\"\x8f\x01\n\x12MutateRowsResponse\x12=\n\x07\x65ntries\x18\x01 \x03(\x0b\x32,.google.bigtable.v2.MutateRowsResponse.Entry\x1a:\n\x05\x45ntry\x12\r\n\x05index\x18\x01 \x01(\x03\x12\"\n\x06status\x18\x02 \x01(\x0b\x32\x12.google.rpc.Status\"\xae\x02\n\x18\x43heckAndMutateRowRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x07 \x01(\t\x12\x14\n\x07row_key\x18\x02 \x01(\x0c\x42\x03\xe0\x41\x02\x12\x37\n\x10predicate_filter\x18\x06 
\x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x34\n\x0etrue_mutations\x18\x04 \x03(\x0b\x32\x1c.google.bigtable.v2.Mutation\x12\x35\n\x0f\x66\x61lse_mutations\x18\x05 \x03(\x0b\x32\x1c.google.bigtable.v2.Mutation\"6\n\x19\x43heckAndMutateRowResponse\x12\x19\n\x11predicate_matched\x18\x01 \x01(\x08\"i\n\x12PingAndWarmRequest\x12;\n\x04name\x18\x01 \x01(\tB-\xe0\x41\x02\xfa\x41\'\n%bigtableadmin.googleapis.com/Instance\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\"\x15\n\x13PingAndWarmResponse\"\xc6\x01\n\x19ReadModifyWriteRowRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x04 \x01(\t\x12\x14\n\x07row_key\x18\x02 \x01(\x0c\x42\x03\xe0\x41\x02\x12;\n\x05rules\x18\x03 \x03(\x0b\x32\'.google.bigtable.v2.ReadModifyWriteRuleB\x03\xe0\x41\x02\"B\n\x1aReadModifyWriteRowResponse\x12$\n\x03row\x18\x01 \x01(\x0b\x32\x17.google.bigtable.v2.Row\"\x86\x01\n,GenerateInitialChangeStreamPartitionsRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\"g\n-GenerateInitialChangeStreamPartitionsResponse\x12\x36\n\tpartition\x18\x01 \x01(\x0b\x32#.google.bigtable.v2.StreamPartition\"\x9b\x03\n\x17ReadChangeStreamRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\x12\x36\n\tpartition\x18\x03 \x01(\x0b\x32#.google.bigtable.v2.StreamPartition\x12\x30\n\nstart_time\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x12K\n\x13\x63ontinuation_tokens\x18\x06 \x01(\x0b\x32,.google.bigtable.v2.StreamContinuationTokensH\x00\x12,\n\x08\x65nd_time\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x35\n\x12heartbeat_duration\x18\x07 \x01(\x0b\x32\x19.google.protobuf.DurationB\x0c\n\nstart_from\"\xeb\t\n\x18ReadChangeStreamResponse\x12N\n\x0b\x64\x61ta_change\x18\x01 \x01(\x0b\x32\x37.google.bigtable.v2.ReadChangeStreamResponse.DataChangeH\x00\x12K\n\theartbeat\x18\x02 \x01(\x0b\x32\x36.google.bigtable.v2.ReadChangeStreamResponse.HeartbeatH\x00\x12P\n\x0c\x63lose_stream\x18\x03 \x01(\x0b\x32\x38.google.bigtable.v2.ReadChangeStreamResponse.CloseStreamH\x00\x1a\xf4\x01\n\rMutationChunk\x12X\n\nchunk_info\x18\x01 \x01(\x0b\x32\x44.google.bigtable.v2.ReadChangeStreamResponse.MutationChunk.ChunkInfo\x12.\n\x08mutation\x18\x02 \x01(\x0b\x32\x1c.google.bigtable.v2.Mutation\x1aY\n\tChunkInfo\x12\x1a\n\x12\x63hunked_value_size\x18\x01 \x01(\x05\x12\x1c\n\x14\x63hunked_value_offset\x18\x02 \x01(\x05\x12\x12\n\nlast_chunk\x18\x03 \x01(\x08\x1a\xc6\x03\n\nDataChange\x12J\n\x04type\x18\x01 \x01(\x0e\x32<.google.bigtable.v2.ReadChangeStreamResponse.DataChange.Type\x12\x19\n\x11source_cluster_id\x18\x02 \x01(\t\x12\x0f\n\x07row_key\x18\x03 \x01(\x0c\x12\x34\n\x10\x63ommit_timestamp\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x12\n\ntiebreaker\x18\x05 \x01(\x05\x12J\n\x06\x63hunks\x18\x06 \x03(\x0b\x32:.google.bigtable.v2.ReadChangeStreamResponse.MutationChunk\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\r\n\x05token\x18\t \x01(\t\x12;\n\x17\x65stimated_low_watermark\x18\n \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"P\n\x04Type\x12\x14\n\x10TYPE_UNSPECIFIED\x10\x00\x12\x08\n\x04USER\x10\x01\x12\x16\n\x12GARBAGE_COLLECTION\x10\x02\x12\x10\n\x0c\x43ONTINUATION\x10\x03\x1a\x91\x01\n\tHeartbeat\x12G\n\x12\x63ontinuation_token\x18\x01 
\x01(\x0b\x32+.google.bigtable.v2.StreamContinuationToken\x12;\n\x17\x65stimated_low_watermark\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x1a{\n\x0b\x43loseStream\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12H\n\x13\x63ontinuation_tokens\x18\x02 \x03(\x0b\x32+.google.bigtable.v2.StreamContinuationTokenB\x0f\n\rstream_record2\xd7\x18\n\x08\x42igtable\x12\x9b\x02\n\x08ReadRows\x12#.google.bigtable.v2.ReadRowsRequest\x1a$.google.bigtable.v2.ReadRowsResponse\"\xc1\x01\x82\xd3\xe4\x93\x02>\"9/v2/{table_name=projects/*/instances/*/tables/*}:readRows:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x12\xac\x02\n\rSampleRowKeys\x12(.google.bigtable.v2.SampleRowKeysRequest\x1a).google.bigtable.v2.SampleRowKeysResponse\"\xc3\x01\x82\xd3\xe4\x93\x02@\x12>/v2/{table_name=projects/*/instances/*/tables/*}:sampleRowKeys\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x12\xc1\x02\n\tMutateRow\x12$.google.bigtable.v2.MutateRowRequest\x1a%.google.bigtable.v2.MutateRowResponse\"\xe6\x01\x82\xd3\xe4\x93\x02?\":/v2/{table_name=projects/*/instances/*/tables/*}:mutateRow:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x1ctable_name,row_key,mutations\xda\x41+table_name,row_key,mutations,app_profile_id\x12\xb3\x02\n\nMutateRows\x12%.google.bigtable.v2.MutateRowsRequest\x1a&.google.bigtable.v2.MutateRowsResponse\"\xd3\x01\x82\xd3\xe4\x93\x02@\";/v2/{table_name=projects/*/instances/*/tables/*}:mutateRows:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x12table_name,entries\xda\x41!table_name,entries,app_profile_id0\x01\x12\xad\x03\n\x11\x43heckAndMutateRow\x12,.google.bigtable.v2.CheckAndMutateRowRequest\x1a-.google.bigtable.v2.CheckAndMutateRowResponse\"\xba\x02\x82\xd3\xe4\x93\x02G\"B/v2/{table_name=projects/*/instances/*/tables/*}:checkAndMutateRow:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x42table_name,row_key,predicate_filter,true_mutations,false_mutations\xda\x41Qtable_name,row_key,predicate_filter,true_mutations,false_mutations,app_profile_id\x12\xee\x01\n\x0bPingAndWarm\x12&.google.bigtable.v2.PingAndWarmRequest\x1a\'.google.bigtable.v2.PingAndWarmResponse\"\x8d\x01\x82\xd3\xe4\x93\x02+\"&/v2/{name=projects/*/instances/*}:ping:\x01*\x8a\xd3\xe4\x93\x02\x39\x12%\n\x04name\x12\x1d{name=projects/*/instances/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x04name\xda\x41\x13name,app_profile_id\x12\xdd\x02\n\x12ReadModifyWriteRow\x12-.google.bigtable.v2.ReadModifyWriteRowRequest\x1a..google.bigtable.v2.ReadModifyWriteRowResponse\"\xe7\x01\x82\xd3\xe4\x93\x02H\"C/v2/{table_name=projects/*/instances/*/tables/*}:readModifyWriteRow:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x18table_name,row_key,rules\xda\x41\'table_name,row_key,rules,app_profile_id\x12\xbb\x02\n%GenerateInitialChangeStreamPartitions\x12@.google.bigtable.v2.GenerateInitialChangeStreamPartitionsRequest\x1a\x41.google.bigtable.v2.GenerateInitialChangeStreamPartitionsResponse\"\x8a\x01\x82\xd3\xe4\x93\x02[\"V/v2/{ta
ble_name=projects/*/instances/*/tables/*}:generateInitialChangeStreamPartitions:\x01*\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x12\xe6\x01\n\x10ReadChangeStream\x12+.google.bigtable.v2.ReadChangeStreamRequest\x1a,.google.bigtable.v2.ReadChangeStreamResponse\"u\x82\xd3\xe4\x93\x02\x46\"A/v2/{table_name=projects/*/instances/*/tables/*}:readChangeStream:\x01*\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x1a\xdb\x02\xca\x41\x17\x62igtable.googleapis.com\xd2\x41\xbd\x02https://www.googleapis.com/auth/bigtable.data,https://www.googleapis.com/auth/bigtable.data.readonly,https://www.googleapis.com/auth/cloud-bigtable.data,https://www.googleapis.com/auth/cloud-bigtable.data.readonly,https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/cloud-platform.read-onlyB\xeb\x02\n\x16\x63om.google.bigtable.v2B\rBigtableProtoP\x01Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\xaa\x02\x18Google.Cloud.Bigtable.V2\xca\x02\x18Google\\Cloud\\Bigtable\\V2\xea\x02\x1bGoogle::Cloud::Bigtable::V2\xea\x41P\n%bigtableadmin.googleapis.com/Instance\x12\'projects/{project}/instances/{instance}\xea\x41\\\n\"bigtableadmin.googleapis.com/Table\x12\x36projects/{project}/instances/{instance}/tables/{table}b\x06proto3')
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.bigtable.v2.bigtable_pb2', globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = b'\n\026com.google.bigtable.v2B\rBigtableProtoP\001Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\252\002\030Google.Cloud.Bigtable.V2\312\002\030Google\\Cloud\\Bigtable\\V2\352\002\033Google::Cloud::Bigtable::V2\352AP\n%bigtableadmin.googleapis.com/Instance\022\'projects/{project}/instances/{instance}\352A\\\n\"bigtableadmin.googleapis.com/Table\0226projects/{project}/instances/{instance}/tables/{table}'
+ _READROWSREQUEST.fields_by_name['table_name']._options = None
+ _READROWSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _SAMPLEROWKEYSREQUEST.fields_by_name['table_name']._options = None
+ _SAMPLEROWKEYSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _MUTATEROWREQUEST.fields_by_name['table_name']._options = None
+ _MUTATEROWREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _MUTATEROWREQUEST.fields_by_name['row_key']._options = None
+ _MUTATEROWREQUEST.fields_by_name['row_key']._serialized_options = b'\340A\002'
+ _MUTATEROWREQUEST.fields_by_name['mutations']._options = None
+ _MUTATEROWREQUEST.fields_by_name['mutations']._serialized_options = b'\340A\002'
+ _MUTATEROWSREQUEST_ENTRY.fields_by_name['mutations']._options = None
+ _MUTATEROWSREQUEST_ENTRY.fields_by_name['mutations']._serialized_options = b'\340A\002'
+ _MUTATEROWSREQUEST.fields_by_name['table_name']._options = None
+ _MUTATEROWSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _MUTATEROWSREQUEST.fields_by_name['entries']._options = None
+ _MUTATEROWSREQUEST.fields_by_name['entries']._serialized_options = b'\340A\002'
+ _CHECKANDMUTATEROWREQUEST.fields_by_name['table_name']._options = None
+ _CHECKANDMUTATEROWREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _CHECKANDMUTATEROWREQUEST.fields_by_name['row_key']._options = None
+ _CHECKANDMUTATEROWREQUEST.fields_by_name['row_key']._serialized_options = b'\340A\002'
+ _PINGANDWARMREQUEST.fields_by_name['name']._options = None
+ _PINGANDWARMREQUEST.fields_by_name['name']._serialized_options = b'\340A\002\372A\'\n%bigtableadmin.googleapis.com/Instance'
+ _READMODIFYWRITEROWREQUEST.fields_by_name['table_name']._options = None
+ _READMODIFYWRITEROWREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _READMODIFYWRITEROWREQUEST.fields_by_name['row_key']._options = None
+ _READMODIFYWRITEROWREQUEST.fields_by_name['row_key']._serialized_options = b'\340A\002'
+ _READMODIFYWRITEROWREQUEST.fields_by_name['rules']._options = None
+ _READMODIFYWRITEROWREQUEST.fields_by_name['rules']._serialized_options = b'\340A\002'
+ _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST.fields_by_name['table_name']._options = None
+ _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _READCHANGESTREAMREQUEST.fields_by_name['table_name']._options = None
+ _READCHANGESTREAMREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table'
+ _BIGTABLE._options = None
+ _BIGTABLE._serialized_options = b'\312A\027bigtable.googleapis.com\322A\275\002https://www.googleapis.com/auth/bigtable.data,https://www.googleapis.com/auth/bigtable.data.readonly,https://www.googleapis.com/auth/cloud-bigtable.data,https://www.googleapis.com/auth/cloud-bigtable.data.readonly,https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/cloud-platform.read-only'
+ _BIGTABLE.methods_by_name['ReadRows']._options = None
+ _BIGTABLE.methods_by_name['ReadRows']._serialized_options = b'\202\323\344\223\002>\"9/v2/{table_name=projects/*/instances/*/tables/*}:readRows:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\ntable_name\332A\031table_name,app_profile_id'
+ _BIGTABLE.methods_by_name['SampleRowKeys']._options = None
+ _BIGTABLE.methods_by_name['SampleRowKeys']._serialized_options = b'\202\323\344\223\002@\022>/v2/{table_name=projects/*/instances/*/tables/*}:sampleRowKeys\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\ntable_name\332A\031table_name,app_profile_id'
+ _BIGTABLE.methods_by_name['MutateRow']._options = None
+ _BIGTABLE.methods_by_name['MutateRow']._serialized_options = b'\202\323\344\223\002?\":/v2/{table_name=projects/*/instances/*/tables/*}:mutateRow:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\034table_name,row_key,mutations\332A+table_name,row_key,mutations,app_profile_id'
+ _BIGTABLE.methods_by_name['MutateRows']._options = None
+ _BIGTABLE.methods_by_name['MutateRows']._serialized_options = b'\202\323\344\223\002@\";/v2/{table_name=projects/*/instances/*/tables/*}:mutateRows:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\022table_name,entries\332A!table_name,entries,app_profile_id'
+ _BIGTABLE.methods_by_name['CheckAndMutateRow']._options = None
+ _BIGTABLE.methods_by_name['CheckAndMutateRow']._serialized_options = b'\202\323\344\223\002G\"B/v2/{table_name=projects/*/instances/*/tables/*}:checkAndMutateRow:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332ABtable_name,row_key,predicate_filter,true_mutations,false_mutations\332AQtable_name,row_key,predicate_filter,true_mutations,false_mutations,app_profile_id'
+ _BIGTABLE.methods_by_name['PingAndWarm']._options = None
+ _BIGTABLE.methods_by_name['PingAndWarm']._serialized_options = b'\202\323\344\223\002+\"&/v2/{name=projects/*/instances/*}:ping:\001*\212\323\344\223\0029\022%\n\004name\022\035{name=projects/*/instances/*}\022\020\n\016app_profile_id\332A\004name\332A\023name,app_profile_id'
+ _BIGTABLE.methods_by_name['ReadModifyWriteRow']._options = None
+ _BIGTABLE.methods_by_name['ReadModifyWriteRow']._serialized_options = b'\202\323\344\223\002H\"C/v2/{table_name=projects/*/instances/*/tables/*}:readModifyWriteRow:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\030table_name,row_key,rules\332A\'table_name,row_key,rules,app_profile_id'
+ _BIGTABLE.methods_by_name['GenerateInitialChangeStreamPartitions']._options = None
+ _BIGTABLE.methods_by_name['GenerateInitialChangeStreamPartitions']._serialized_options = b'\202\323\344\223\002[\"V/v2/{table_name=projects/*/instances/*/tables/*}:generateInitialChangeStreamPartitions:\001*\332A\ntable_name\332A\031table_name,app_profile_id'
+ _BIGTABLE.methods_by_name['ReadChangeStream']._options = None
+ _BIGTABLE.methods_by_name['ReadChangeStream']._serialized_options = b'\202\323\344\223\002F\"A/v2/{table_name=projects/*/instances/*/tables/*}:readChangeStream:\001*\332A\ntable_name\332A\031table_name,app_profile_id'
+ _READROWSREQUEST._serialized_start=392
+ _READROWSREQUEST._serialized_end=792
+ _READROWSREQUEST_REQUESTSTATSVIEW._serialized_start=690
+ _READROWSREQUEST_REQUESTSTATSVIEW._serialized_end=792
+ _READROWSRESPONSE._serialized_start=795
+ _READROWSRESPONSE._serialized_end=1228
+ _READROWSRESPONSE_CELLCHUNK._serialized_start=967
+ _READROWSRESPONSE_CELLCHUNK._serialized_end=1228
+ _SAMPLEROWKEYSREQUEST._serialized_start=1230
+ _SAMPLEROWKEYSREQUEST._serialized_end=1340
+ _SAMPLEROWKEYSRESPONSE._serialized_start=1342
+ _SAMPLEROWKEYSRESPONSE._serialized_end=1404
+ _MUTATEROWREQUEST._serialized_start=1407
+ _MUTATEROWREQUEST._serialized_end=1589
+ _MUTATEROWRESPONSE._serialized_start=1591
+ _MUTATEROWRESPONSE._serialized_end=1610
+ _MUTATEROWSREQUEST._serialized_start=1613
+ _MUTATEROWSREQUEST._serialized_end=1867
+ _MUTATEROWSREQUEST_ENTRY._serialized_start=1789
+ _MUTATEROWSREQUEST_ENTRY._serialized_end=1867
+ _MUTATEROWSRESPONSE._serialized_start=1870
+ _MUTATEROWSRESPONSE._serialized_end=2013
+ _MUTATEROWSRESPONSE_ENTRY._serialized_start=1955
+ _MUTATEROWSRESPONSE_ENTRY._serialized_end=2013
+ _CHECKANDMUTATEROWREQUEST._serialized_start=2016
+ _CHECKANDMUTATEROWREQUEST._serialized_end=2318
+ _CHECKANDMUTATEROWRESPONSE._serialized_start=2320
+ _CHECKANDMUTATEROWRESPONSE._serialized_end=2374
+ _PINGANDWARMREQUEST._serialized_start=2376
+ _PINGANDWARMREQUEST._serialized_end=2481
+ _PINGANDWARMRESPONSE._serialized_start=2483
+ _PINGANDWARMRESPONSE._serialized_end=2504
+ _READMODIFYWRITEROWREQUEST._serialized_start=2507
+ _READMODIFYWRITEROWREQUEST._serialized_end=2705
+ _READMODIFYWRITEROWRESPONSE._serialized_start=2707
+ _READMODIFYWRITEROWRESPONSE._serialized_end=2773
+ _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST._serialized_start=2776
+ _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST._serialized_end=2910
+ _GENERATEINITIALCHANGESTREAMPARTITIONSRESPONSE._serialized_start=2912
+ _GENERATEINITIALCHANGESTREAMPARTITIONSRESPONSE._serialized_end=3015
+ _READCHANGESTREAMREQUEST._serialized_start=3018
+ _READCHANGESTREAMREQUEST._serialized_end=3429
+ _READCHANGESTREAMRESPONSE._serialized_start=3432
+ _READCHANGESTREAMRESPONSE._serialized_end=4691
+ _READCHANGESTREAMRESPONSE_MUTATIONCHUNK._serialized_start=3700
+ _READCHANGESTREAMRESPONSE_MUTATIONCHUNK._serialized_end=3944
+ _READCHANGESTREAMRESPONSE_MUTATIONCHUNK_CHUNKINFO._serialized_start=3855
+ _READCHANGESTREAMRESPONSE_MUTATIONCHUNK_CHUNKINFO._serialized_end=3944
+ _READCHANGESTREAMRESPONSE_DATACHANGE._serialized_start=3947
+ _READCHANGESTREAMRESPONSE_DATACHANGE._serialized_end=4401
+ _READCHANGESTREAMRESPONSE_DATACHANGE_TYPE._serialized_start=4321
+ _READCHANGESTREAMRESPONSE_DATACHANGE_TYPE._serialized_end=4401
+ _READCHANGESTREAMRESPONSE_HEARTBEAT._serialized_start=4404
+ _READCHANGESTREAMRESPONSE_HEARTBEAT._serialized_end=4549
+ _READCHANGESTREAMRESPONSE_CLOSESTREAM._serialized_start=4551
+ _READCHANGESTREAMRESPONSE_CLOSESTREAM._serialized_end=4674
+ _BIGTABLE._serialized_start=4694
+ _BIGTABLE._serialized_end=7853
+# @@protoc_insertion_point(module_scope)
diff --git a/test_proxy/protos/bigtable_pb2_grpc.py b/test_proxy/protos/bigtable_pb2_grpc.py
new file mode 100644
index 000000000..9ce87d869
--- /dev/null
+++ b/test_proxy/protos/bigtable_pb2_grpc.py
@@ -0,0 +1,363 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import bigtable_pb2 as google_dot_bigtable_dot_v2_dot_bigtable__pb2
+
+
+class BigtableStub(object):
+ """Service for reading from and writing to existing Bigtable tables.
+ """
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.ReadRows = channel.unary_stream(
+ '/google.bigtable.v2.Bigtable/ReadRows',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsResponse.FromString,
+ )
+ self.SampleRowKeys = channel.unary_stream(
+ '/google.bigtable.v2.Bigtable/SampleRowKeys',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysResponse.FromString,
+ )
+ self.MutateRow = channel.unary_unary(
+ '/google.bigtable.v2.Bigtable/MutateRow',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowResponse.FromString,
+ )
+ self.MutateRows = channel.unary_stream(
+ '/google.bigtable.v2.Bigtable/MutateRows',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsResponse.FromString,
+ )
+ self.CheckAndMutateRow = channel.unary_unary(
+ '/google.bigtable.v2.Bigtable/CheckAndMutateRow',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowResponse.FromString,
+ )
+ self.PingAndWarm = channel.unary_unary(
+ '/google.bigtable.v2.Bigtable/PingAndWarm',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmResponse.FromString,
+ )
+ self.ReadModifyWriteRow = channel.unary_unary(
+ '/google.bigtable.v2.Bigtable/ReadModifyWriteRow',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowResponse.FromString,
+ )
+ self.GenerateInitialChangeStreamPartitions = channel.unary_stream(
+ '/google.bigtable.v2.Bigtable/GenerateInitialChangeStreamPartitions',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsResponse.FromString,
+ )
+ self.ReadChangeStream = channel.unary_stream(
+ '/google.bigtable.v2.Bigtable/ReadChangeStream',
+ request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamRequest.SerializeToString,
+ response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamResponse.FromString,
+ )
+
+
+class BigtableServicer(object):
+ """Service for reading from and writing to existing Bigtable tables.
+ """
+
+ def ReadRows(self, request, context):
+ """Streams back the contents of all requested rows in key order, optionally
+ applying the same Reader filter to each. Depending on their size,
+ rows and cells may be broken up across multiple responses, but
+ atomicity of each row will still be preserved. See the
+ ReadRowsResponse documentation for details.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def SampleRowKeys(self, request, context):
+ """Returns a sample of row keys in the table. The returned row keys will
+ delimit contiguous sections of the table of approximately equal size,
+ which can be used to break up the data for distributed tasks like
+ mapreduces.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def MutateRow(self, request, context):
+ """Mutates a row atomically. Cells already present in the row are left
+ unchanged unless explicitly changed by `mutation`.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def MutateRows(self, request, context):
+ """Mutates multiple rows in a batch. Each individual row is mutated
+ atomically as in MutateRow, but the entire batch is not executed
+ atomically.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def CheckAndMutateRow(self, request, context):
+ """Mutates a row atomically based on the output of a predicate Reader filter.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def PingAndWarm(self, request, context):
+ """Warm up associated instance metadata for this connection.
+ This call is not required but may be useful for connection keep-alive.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ReadModifyWriteRow(self, request, context):
+ """Modifies a row atomically on the server. The method reads the latest
+ existing timestamp and value from the specified columns and writes a new
+ entry based on pre-defined read/modify/write rules. The new value for the
+ timestamp is the greater of the existing timestamp or the current server
+ time. The method returns the new contents of all modified cells.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def GenerateInitialChangeStreamPartitions(self, request, context):
+ """NOTE: This API is intended to be used by Apache Beam BigtableIO.
+ Returns the current list of partitions that make up the table's
+ change stream. The union of partitions will cover the entire keyspace.
+ Partitions can be read with `ReadChangeStream`.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ReadChangeStream(self, request, context):
+ """NOTE: This API is intended to be used by Apache Beam BigtableIO.
+ Reads changes from a table's change stream. Changes will
+ reflect both user-initiated mutations and mutations that are caused by
+ garbage collection.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_BigtableServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'ReadRows': grpc.unary_stream_rpc_method_handler(
+ servicer.ReadRows,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsResponse.SerializeToString,
+ ),
+ 'SampleRowKeys': grpc.unary_stream_rpc_method_handler(
+ servicer.SampleRowKeys,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysResponse.SerializeToString,
+ ),
+ 'MutateRow': grpc.unary_unary_rpc_method_handler(
+ servicer.MutateRow,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowResponse.SerializeToString,
+ ),
+ 'MutateRows': grpc.unary_stream_rpc_method_handler(
+ servicer.MutateRows,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsResponse.SerializeToString,
+ ),
+ 'CheckAndMutateRow': grpc.unary_unary_rpc_method_handler(
+ servicer.CheckAndMutateRow,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowResponse.SerializeToString,
+ ),
+ 'PingAndWarm': grpc.unary_unary_rpc_method_handler(
+ servicer.PingAndWarm,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmResponse.SerializeToString,
+ ),
+ 'ReadModifyWriteRow': grpc.unary_unary_rpc_method_handler(
+ servicer.ReadModifyWriteRow,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowResponse.SerializeToString,
+ ),
+ 'GenerateInitialChangeStreamPartitions': grpc.unary_stream_rpc_method_handler(
+ servicer.GenerateInitialChangeStreamPartitions,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsResponse.SerializeToString,
+ ),
+ 'ReadChangeStream': grpc.unary_stream_rpc_method_handler(
+ servicer.ReadChangeStream,
+ request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamRequest.FromString,
+ response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamResponse.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'google.bigtable.v2.Bigtable', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class Bigtable(object):
+ """Service for reading from and writing to existing Bigtable tables.
+ """
+
+ @staticmethod
+ def ReadRows(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/ReadRows',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def SampleRowKeys(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/SampleRowKeys',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def MutateRow(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/MutateRow',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def MutateRows(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/MutateRows',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def CheckAndMutateRow(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/CheckAndMutateRow',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def PingAndWarm(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/PingAndWarm',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def ReadModifyWriteRow(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/ReadModifyWriteRow',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def GenerateInitialChangeStreamPartitions(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/GenerateInitialChangeStreamPartitions',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def ReadChangeStream(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/ReadChangeStream',
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamRequest.SerializeToString,
+ google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
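
As an aside (an illustrative sketch, not part of the generated module): the vendored BigtableStub can be exercised directly over a gRPC channel. The snippet below assumes the generated modules are importable from test_proxy/protos and that a Bigtable server or emulator is reachable at the placeholder address; the table path is illustrative only.

    import grpc

    import bigtable_pb2
    import bigtable_pb2_grpc

    # Placeholder endpoint; point this at a reachable server or emulator.
    channel = grpc.insecure_channel("localhost:8086")
    stub = bigtable_pb2_grpc.BigtableStub(channel)

    # SampleRowKeys is a unary-stream RPC: one request yields a stream of responses.
    request = bigtable_pb2.SampleRowKeysRequest(
        table_name="projects/p/instances/i/tables/t",
    )
    for sample in stub.SampleRowKeys(request):
        print(sample.row_key, sample.offset_bytes)

    channel.close()
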
diff --git a/test_proxy/protos/data_pb2.py b/test_proxy/protos/data_pb2.py
new file mode 100644
index 000000000..fff212034
--- /dev/null
+++ b/test_proxy/protos/data_pb2.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: google/bigtable/v2/data.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dgoogle/bigtable/v2/data.proto\x12\x12google.bigtable.v2\"@\n\x03Row\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12,\n\x08\x66\x61milies\x18\x02 \x03(\x0b\x32\x1a.google.bigtable.v2.Family\"C\n\x06\x46\x61mily\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x1a.google.bigtable.v2.Column\"D\n\x06\x43olumn\x12\x11\n\tqualifier\x18\x01 \x01(\x0c\x12\'\n\x05\x63\x65lls\x18\x02 \x03(\x0b\x32\x18.google.bigtable.v2.Cell\"?\n\x04\x43\x65ll\x12\x18\n\x10timestamp_micros\x18\x01 \x01(\x03\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x0e\n\x06labels\x18\x03 \x03(\t\"\x8a\x01\n\x08RowRange\x12\x1a\n\x10start_key_closed\x18\x01 \x01(\x0cH\x00\x12\x18\n\x0estart_key_open\x18\x02 \x01(\x0cH\x00\x12\x16\n\x0c\x65nd_key_open\x18\x03 \x01(\x0cH\x01\x12\x18\n\x0e\x65nd_key_closed\x18\x04 \x01(\x0cH\x01\x42\x0b\n\tstart_keyB\t\n\x07\x65nd_key\"L\n\x06RowSet\x12\x10\n\x08row_keys\x18\x01 \x03(\x0c\x12\x30\n\nrow_ranges\x18\x02 \x03(\x0b\x32\x1c.google.bigtable.v2.RowRange\"\xc6\x01\n\x0b\x43olumnRange\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12 \n\x16start_qualifier_closed\x18\x02 \x01(\x0cH\x00\x12\x1e\n\x14start_qualifier_open\x18\x03 \x01(\x0cH\x00\x12\x1e\n\x14\x65nd_qualifier_closed\x18\x04 \x01(\x0cH\x01\x12\x1c\n\x12\x65nd_qualifier_open\x18\x05 \x01(\x0cH\x01\x42\x11\n\x0fstart_qualifierB\x0f\n\rend_qualifier\"N\n\x0eTimestampRange\x12\x1e\n\x16start_timestamp_micros\x18\x01 \x01(\x03\x12\x1c\n\x14\x65nd_timestamp_micros\x18\x02 \x01(\x03\"\x98\x01\n\nValueRange\x12\x1c\n\x12start_value_closed\x18\x01 \x01(\x0cH\x00\x12\x1a\n\x10start_value_open\x18\x02 \x01(\x0cH\x00\x12\x1a\n\x10\x65nd_value_closed\x18\x03 \x01(\x0cH\x01\x12\x18\n\x0e\x65nd_value_open\x18\x04 \x01(\x0cH\x01\x42\r\n\x0bstart_valueB\x0b\n\tend_value\"\xdf\x08\n\tRowFilter\x12\x34\n\x05\x63hain\x18\x01 \x01(\x0b\x32#.google.bigtable.v2.RowFilter.ChainH\x00\x12>\n\ninterleave\x18\x02 \x01(\x0b\x32(.google.bigtable.v2.RowFilter.InterleaveH\x00\x12<\n\tcondition\x18\x03 \x01(\x0b\x32\'.google.bigtable.v2.RowFilter.ConditionH\x00\x12\x0e\n\x04sink\x18\x10 \x01(\x08H\x00\x12\x19\n\x0fpass_all_filter\x18\x11 \x01(\x08H\x00\x12\x1a\n\x10\x62lock_all_filter\x18\x12 \x01(\x08H\x00\x12\x1e\n\x14row_key_regex_filter\x18\x04 \x01(\x0cH\x00\x12\x1b\n\x11row_sample_filter\x18\x0e \x01(\x01H\x00\x12\"\n\x18\x66\x61mily_name_regex_filter\x18\x05 \x01(\tH\x00\x12\'\n\x1d\x63olumn_qualifier_regex_filter\x18\x06 \x01(\x0cH\x00\x12>\n\x13\x63olumn_range_filter\x18\x07 \x01(\x0b\x32\x1f.google.bigtable.v2.ColumnRangeH\x00\x12\x44\n\x16timestamp_range_filter\x18\x08 \x01(\x0b\x32\".google.bigtable.v2.TimestampRangeH\x00\x12\x1c\n\x12value_regex_filter\x18\t \x01(\x0cH\x00\x12<\n\x12value_range_filter\x18\x0f \x01(\x0b\x32\x1e.google.bigtable.v2.ValueRangeH\x00\x12%\n\x1b\x63\x65lls_per_row_offset_filter\x18\n \x01(\x05H\x00\x12$\n\x1a\x63\x65lls_per_row_limit_filter\x18\x0b \x01(\x05H\x00\x12\'\n\x1d\x63\x65lls_per_column_limit_filter\x18\x0c \x01(\x05H\x00\x12!\n\x17strip_value_transformer\x18\r \x01(\x08H\x00\x12!\n\x17\x61pply_label_transformer\x18\x13 \x01(\tH\x00\x1a\x37\n\x05\x43hain\x12.\n\x07\x66ilters\x18\x01 \x03(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x1a<\n\nInterleave\x12.\n\x07\x66ilters\x18\x01 \x03(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x1a\xad\x01\n\tCondition\x12\x37\n\x10predicate_filter\x18\x01 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x32\n\x0btrue_filter\x18\x02 
\x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x33\n\x0c\x66\x61lse_filter\x18\x03 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilterB\x08\n\x06\x66ilter\"\xc9\x04\n\x08Mutation\x12\x38\n\x08set_cell\x18\x01 \x01(\x0b\x32$.google.bigtable.v2.Mutation.SetCellH\x00\x12K\n\x12\x64\x65lete_from_column\x18\x02 \x01(\x0b\x32-.google.bigtable.v2.Mutation.DeleteFromColumnH\x00\x12K\n\x12\x64\x65lete_from_family\x18\x03 \x01(\x0b\x32-.google.bigtable.v2.Mutation.DeleteFromFamilyH\x00\x12\x45\n\x0f\x64\x65lete_from_row\x18\x04 \x01(\x0b\x32*.google.bigtable.v2.Mutation.DeleteFromRowH\x00\x1a\x61\n\x07SetCell\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12\x18\n\x10\x63olumn_qualifier\x18\x02 \x01(\x0c\x12\x18\n\x10timestamp_micros\x18\x03 \x01(\x03\x12\r\n\x05value\x18\x04 \x01(\x0c\x1ay\n\x10\x44\x65leteFromColumn\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12\x18\n\x10\x63olumn_qualifier\x18\x02 \x01(\x0c\x12\x36\n\ntime_range\x18\x03 \x01(\x0b\x32\".google.bigtable.v2.TimestampRange\x1a\'\n\x10\x44\x65leteFromFamily\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x1a\x0f\n\rDeleteFromRowB\n\n\x08mutation\"\x80\x01\n\x13ReadModifyWriteRule\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12\x18\n\x10\x63olumn_qualifier\x18\x02 \x01(\x0c\x12\x16\n\x0c\x61ppend_value\x18\x03 \x01(\x0cH\x00\x12\x1a\n\x10increment_amount\x18\x04 \x01(\x03H\x00\x42\x06\n\x04rule\"B\n\x0fStreamPartition\x12/\n\trow_range\x18\x01 \x01(\x0b\x32\x1c.google.bigtable.v2.RowRange\"W\n\x18StreamContinuationTokens\x12;\n\x06tokens\x18\x01 \x03(\x0b\x32+.google.bigtable.v2.StreamContinuationToken\"`\n\x17StreamContinuationToken\x12\x36\n\tpartition\x18\x01 \x01(\x0b\x32#.google.bigtable.v2.StreamPartition\x12\r\n\x05token\x18\x02 \x01(\tB\xb5\x01\n\x16\x63om.google.bigtable.v2B\tDataProtoP\x01Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\xaa\x02\x18Google.Cloud.Bigtable.V2\xca\x02\x18Google\\Cloud\\Bigtable\\V2\xea\x02\x1bGoogle::Cloud::Bigtable::V2b\x06proto3')
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.bigtable.v2.data_pb2', globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = b'\n\026com.google.bigtable.v2B\tDataProtoP\001Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\252\002\030Google.Cloud.Bigtable.V2\312\002\030Google\\Cloud\\Bigtable\\V2\352\002\033Google::Cloud::Bigtable::V2'
+ _ROW._serialized_start=53
+ _ROW._serialized_end=117
+ _FAMILY._serialized_start=119
+ _FAMILY._serialized_end=186
+ _COLUMN._serialized_start=188
+ _COLUMN._serialized_end=256
+ _CELL._serialized_start=258
+ _CELL._serialized_end=321
+ _ROWRANGE._serialized_start=324
+ _ROWRANGE._serialized_end=462
+ _ROWSET._serialized_start=464
+ _ROWSET._serialized_end=540
+ _COLUMNRANGE._serialized_start=543
+ _COLUMNRANGE._serialized_end=741
+ _TIMESTAMPRANGE._serialized_start=743
+ _TIMESTAMPRANGE._serialized_end=821
+ _VALUERANGE._serialized_start=824
+ _VALUERANGE._serialized_end=976
+ _ROWFILTER._serialized_start=979
+ _ROWFILTER._serialized_end=2098
+ _ROWFILTER_CHAIN._serialized_start=1795
+ _ROWFILTER_CHAIN._serialized_end=1850
+ _ROWFILTER_INTERLEAVE._serialized_start=1852
+ _ROWFILTER_INTERLEAVE._serialized_end=1912
+ _ROWFILTER_CONDITION._serialized_start=1915
+ _ROWFILTER_CONDITION._serialized_end=2088
+ _MUTATION._serialized_start=2101
+ _MUTATION._serialized_end=2686
+ _MUTATION_SETCELL._serialized_start=2396
+ _MUTATION_SETCELL._serialized_end=2493
+ _MUTATION_DELETEFROMCOLUMN._serialized_start=2495
+ _MUTATION_DELETEFROMCOLUMN._serialized_end=2616
+ _MUTATION_DELETEFROMFAMILY._serialized_start=2618
+ _MUTATION_DELETEFROMFAMILY._serialized_end=2657
+ _MUTATION_DELETEFROMROW._serialized_start=2659
+ _MUTATION_DELETEFROMROW._serialized_end=2674
+ _READMODIFYWRITERULE._serialized_start=2689
+ _READMODIFYWRITERULE._serialized_end=2817
+ _STREAMPARTITION._serialized_start=2819
+ _STREAMPARTITION._serialized_end=2885
+ _STREAMCONTINUATIONTOKENS._serialized_start=2887
+ _STREAMCONTINUATIONTOKENS._serialized_end=2974
+ _STREAMCONTINUATIONTOKEN._serialized_start=2976
+ _STREAMCONTINUATIONTOKEN._serialized_end=3072
+# @@protoc_insertion_point(module_scope)
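
For orientation (illustrative only, not part of the generated file): the classes in data_pb2 are ordinary protobuf messages, so filters and mutations can be constructed directly. The family name and qualifier below are placeholders.

    import data_pb2

    # Chain two filters: restrict to one column family and keep only the most
    # recent cell per column.
    row_filter = data_pb2.RowFilter(
        chain=data_pb2.RowFilter.Chain(
            filters=[
                data_pb2.RowFilter(family_name_regex_filter="cf"),
                data_pb2.RowFilter(cells_per_column_limit_filter=1),
            ]
        )
    )

    # A SetCell mutation; timestamp_micros=-1 asks the server to assign the timestamp.
    mutation = data_pb2.Mutation(
        set_cell=data_pb2.Mutation.SetCell(
            family_name="cf",
            column_qualifier=b"q",
            timestamp_micros=-1,
            value=b"value",
        )
    )
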
diff --git a/test_proxy/protos/data_pb2_grpc.py b/test_proxy/protos/data_pb2_grpc.py
new file mode 100644
index 000000000..2daafffeb
--- /dev/null
+++ b/test_proxy/protos/data_pb2_grpc.py
@@ -0,0 +1,4 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
diff --git a/test_proxy/protos/request_stats_pb2.py b/test_proxy/protos/request_stats_pb2.py
new file mode 100644
index 000000000..95fcc6e0f
--- /dev/null
+++ b/test_proxy/protos/request_stats_pb2.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: google/bigtable/v2/request_stats.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&google/bigtable/v2/request_stats.proto\x12\x12google.bigtable.v2\x1a\x1egoogle/protobuf/duration.proto\"\x82\x01\n\x12ReadIterationStats\x12\x17\n\x0frows_seen_count\x18\x01 \x01(\x03\x12\x1b\n\x13rows_returned_count\x18\x02 \x01(\x03\x12\x18\n\x10\x63\x65lls_seen_count\x18\x03 \x01(\x03\x12\x1c\n\x14\x63\x65lls_returned_count\x18\x04 \x01(\x03\"Q\n\x13RequestLatencyStats\x12:\n\x17\x66rontend_server_latency\x18\x01 \x01(\x0b\x32\x19.google.protobuf.Duration\"\xa1\x01\n\x11\x46ullReadStatsView\x12\x44\n\x14read_iteration_stats\x18\x01 \x01(\x0b\x32&.google.bigtable.v2.ReadIterationStats\x12\x46\n\x15request_latency_stats\x18\x02 \x01(\x0b\x32\'.google.bigtable.v2.RequestLatencyStats\"c\n\x0cRequestStats\x12\x45\n\x14\x66ull_read_stats_view\x18\x01 \x01(\x0b\x32%.google.bigtable.v2.FullReadStatsViewH\x00\x42\x0c\n\nstats_viewB\xbd\x01\n\x16\x63om.google.bigtable.v2B\x11RequestStatsProtoP\x01Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\xaa\x02\x18Google.Cloud.Bigtable.V2\xca\x02\x18Google\\Cloud\\Bigtable\\V2\xea\x02\x1bGoogle::Cloud::Bigtable::V2b\x06proto3')
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.bigtable.v2.request_stats_pb2', globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = b'\n\026com.google.bigtable.v2B\021RequestStatsProtoP\001Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\252\002\030Google.Cloud.Bigtable.V2\312\002\030Google\\Cloud\\Bigtable\\V2\352\002\033Google::Cloud::Bigtable::V2'
+ _READITERATIONSTATS._serialized_start=95
+ _READITERATIONSTATS._serialized_end=225
+ _REQUESTLATENCYSTATS._serialized_start=227
+ _REQUESTLATENCYSTATS._serialized_end=308
+ _FULLREADSTATSVIEW._serialized_start=311
+ _FULLREADSTATSVIEW._serialized_end=472
+ _REQUESTSTATS._serialized_start=474
+ _REQUESTSTATS._serialized_end=573
+# @@protoc_insertion_point(module_scope)
diff --git a/test_proxy/protos/request_stats_pb2_grpc.py b/test_proxy/protos/request_stats_pb2_grpc.py
new file mode 100644
index 000000000..2daafffeb
--- /dev/null
+++ b/test_proxy/protos/request_stats_pb2_grpc.py
@@ -0,0 +1,4 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
diff --git a/test_proxy/protos/test_proxy_pb2.py b/test_proxy/protos/test_proxy_pb2.py
new file mode 100644
index 000000000..8c7817b14
--- /dev/null
+++ b/test_proxy/protos/test_proxy_pb2.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: test_proxy.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from google.api import client_pb2 as google_dot_api_dot_client__pb2
+import bigtable_pb2 as google_dot_bigtable_dot_v2_dot_bigtable__pb2
+import data_pb2 as google_dot_bigtable_dot_v2_dot_data__pb2
+from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2
+from google.rpc import status_pb2 as google_dot_rpc_dot_status__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10test_proxy.proto\x12\x19google.bigtable.testproxy\x1a\x17google/api/client.proto\x1a!google/bigtable/v2/bigtable.proto\x1a\x1dgoogle/bigtable/v2/data.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x17google/rpc/status.proto\"\xb8\x01\n\x13\x43reateClientRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x61ta_target\x18\x02 \x01(\t\x12\x12\n\nproject_id\x18\x03 \x01(\t\x12\x13\n\x0binstance_id\x18\x04 \x01(\t\x12\x16\n\x0e\x61pp_profile_id\x18\x05 \x01(\t\x12\x38\n\x15per_operation_timeout\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\"\x16\n\x14\x43reateClientResponse\"\'\n\x12\x43loseClientRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\"\x15\n\x13\x43loseClientResponse\"(\n\x13RemoveClientRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\"\x16\n\x14RemoveClientResponse\"w\n\x0eReadRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x12\n\ntable_name\x18\x04 \x01(\t\x12\x0f\n\x07row_key\x18\x02 \x01(\t\x12-\n\x06\x66ilter\x18\x03 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\"U\n\tRowResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12$\n\x03row\x18\x02 \x01(\x0b\x32\x17.google.bigtable.v2.Row\"u\n\x0fReadRowsRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x34\n\x07request\x18\x02 \x01(\x0b\x32#.google.bigtable.v2.ReadRowsRequest\x12\x19\n\x11\x63\x61ncel_after_rows\x18\x03 \x01(\x05\"V\n\nRowsResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12$\n\x03row\x18\x02 \x03(\x0b\x32\x17.google.bigtable.v2.Row\"\\\n\x10MutateRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x35\n\x07request\x18\x02 \x01(\x0b\x32$.google.bigtable.v2.MutateRowRequest\"5\n\x0fMutateRowResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\"^\n\x11MutateRowsRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x36\n\x07request\x18\x02 \x01(\x0b\x32%.google.bigtable.v2.MutateRowsRequest\"s\n\x10MutateRowsResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12;\n\x05\x65ntry\x18\x02 \x03(\x0b\x32,.google.bigtable.v2.MutateRowsResponse.Entry\"l\n\x18\x43heckAndMutateRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12=\n\x07request\x18\x02 \x01(\x0b\x32,.google.bigtable.v2.CheckAndMutateRowRequest\"|\n\x17\x43heckAndMutateRowResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12=\n\x06result\x18\x02 \x01(\x0b\x32-.google.bigtable.v2.CheckAndMutateRowResponse\"d\n\x14SampleRowKeysRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x39\n\x07request\x18\x02 \x01(\x0b\x32(.google.bigtable.v2.SampleRowKeysRequest\"t\n\x13SampleRowKeysResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12\x39\n\x06sample\x18\x02 \x03(\x0b\x32).google.bigtable.v2.SampleRowKeysResponse\"n\n\x19ReadModifyWriteRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12>\n\x07request\x18\x02 
\x01(\x0b\x32-.google.bigtable.v2.ReadModifyWriteRowRequest2\xa4\t\n\x18\x43loudBigtableV2TestProxy\x12q\n\x0c\x43reateClient\x12..google.bigtable.testproxy.CreateClientRequest\x1a/.google.bigtable.testproxy.CreateClientResponse\"\x00\x12n\n\x0b\x43loseClient\x12-.google.bigtable.testproxy.CloseClientRequest\x1a..google.bigtable.testproxy.CloseClientResponse\"\x00\x12q\n\x0cRemoveClient\x12..google.bigtable.testproxy.RemoveClientRequest\x1a/.google.bigtable.testproxy.RemoveClientResponse\"\x00\x12\\\n\x07ReadRow\x12).google.bigtable.testproxy.ReadRowRequest\x1a$.google.bigtable.testproxy.RowResult\"\x00\x12_\n\x08ReadRows\x12*.google.bigtable.testproxy.ReadRowsRequest\x1a%.google.bigtable.testproxy.RowsResult\"\x00\x12\x66\n\tMutateRow\x12+.google.bigtable.testproxy.MutateRowRequest\x1a*.google.bigtable.testproxy.MutateRowResult\"\x00\x12m\n\x0e\x42ulkMutateRows\x12,.google.bigtable.testproxy.MutateRowsRequest\x1a+.google.bigtable.testproxy.MutateRowsResult\"\x00\x12~\n\x11\x43heckAndMutateRow\x12\x33.google.bigtable.testproxy.CheckAndMutateRowRequest\x1a\x32.google.bigtable.testproxy.CheckAndMutateRowResult\"\x00\x12r\n\rSampleRowKeys\x12/.google.bigtable.testproxy.SampleRowKeysRequest\x1a..google.bigtable.testproxy.SampleRowKeysResult\"\x00\x12r\n\x12ReadModifyWriteRow\x12\x34.google.bigtable.testproxy.ReadModifyWriteRowRequest\x1a$.google.bigtable.testproxy.RowResult\"\x00\x1a\x34\xca\x41\x31\x62igtable-test-proxy-not-accessible.googleapis.comB6\n#com.google.cloud.bigtable.testproxyP\x01Z\r./testproxypbb\x06proto3')
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'test_proxy_pb2', globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = b'\n#com.google.cloud.bigtable.testproxyP\001Z\r./testproxypb'
+ _CLOUDBIGTABLEV2TESTPROXY._options = None
+ _CLOUDBIGTABLEV2TESTPROXY._serialized_options = b'\312A1bigtable-test-proxy-not-accessible.googleapis.com'
+ _CREATECLIENTREQUEST._serialized_start=196
+ _CREATECLIENTREQUEST._serialized_end=380
+ _CREATECLIENTRESPONSE._serialized_start=382
+ _CREATECLIENTRESPONSE._serialized_end=404
+ _CLOSECLIENTREQUEST._serialized_start=406
+ _CLOSECLIENTREQUEST._serialized_end=445
+ _CLOSECLIENTRESPONSE._serialized_start=447
+ _CLOSECLIENTRESPONSE._serialized_end=468
+ _REMOVECLIENTREQUEST._serialized_start=470
+ _REMOVECLIENTREQUEST._serialized_end=510
+ _REMOVECLIENTRESPONSE._serialized_start=512
+ _REMOVECLIENTRESPONSE._serialized_end=534
+ _READROWREQUEST._serialized_start=536
+ _READROWREQUEST._serialized_end=655
+ _ROWRESULT._serialized_start=657
+ _ROWRESULT._serialized_end=742
+ _READROWSREQUEST._serialized_start=744
+ _READROWSREQUEST._serialized_end=861
+ _ROWSRESULT._serialized_start=863
+ _ROWSRESULT._serialized_end=949
+ _MUTATEROWREQUEST._serialized_start=951
+ _MUTATEROWREQUEST._serialized_end=1043
+ _MUTATEROWRESULT._serialized_start=1045
+ _MUTATEROWRESULT._serialized_end=1098
+ _MUTATEROWSREQUEST._serialized_start=1100
+ _MUTATEROWSREQUEST._serialized_end=1194
+ _MUTATEROWSRESULT._serialized_start=1196
+ _MUTATEROWSRESULT._serialized_end=1311
+ _CHECKANDMUTATEROWREQUEST._serialized_start=1313
+ _CHECKANDMUTATEROWREQUEST._serialized_end=1421
+ _CHECKANDMUTATEROWRESULT._serialized_start=1423
+ _CHECKANDMUTATEROWRESULT._serialized_end=1547
+ _SAMPLEROWKEYSREQUEST._serialized_start=1549
+ _SAMPLEROWKEYSREQUEST._serialized_end=1649
+ _SAMPLEROWKEYSRESULT._serialized_start=1651
+ _SAMPLEROWKEYSRESULT._serialized_end=1767
+ _READMODIFYWRITEROWREQUEST._serialized_start=1769
+ _READMODIFYWRITEROWREQUEST._serialized_end=1879
+ _CLOUDBIGTABLEV2TESTPROXY._serialized_start=1882
+ _CLOUDBIGTABLEV2TESTPROXY._serialized_end=3070
+# @@protoc_insertion_point(module_scope)
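
As a quick illustration (a sketch, not part of the generated file): the test-proxy request messages wrap the google.bigtable.v2 requests and address a previously registered client by client_id. The IDs and table path below are placeholders.

    import bigtable_pb2
    import test_proxy_pb2

    proxy_request = test_proxy_pb2.ReadRowsRequest(
        client_id="client-0",
        request=bigtable_pb2.ReadRowsRequest(
            table_name="projects/p/instances/i/tables/t",
            rows_limit=10,
        ),
        # Ask the proxy to cancel the underlying stream after this many rows.
        cancel_after_rows=5,
    )
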
diff --git a/test_proxy/protos/test_proxy_pb2_grpc.py b/test_proxy/protos/test_proxy_pb2_grpc.py
new file mode 100644
index 000000000..60214a584
--- /dev/null
+++ b/test_proxy/protos/test_proxy_pb2_grpc.py
@@ -0,0 +1,433 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import test_proxy_pb2 as test__proxy__pb2
+
+
+class CloudBigtableV2TestProxyStub(object):
+ """Note that all RPCs are unary, even when the equivalent client binding call
+ may be streaming. This is an intentional simplification.
+
+ Most methods have sync (default) and async variants. For async variants,
+ the proxy is expected to perform the async operation, then wait for results
+ before delivering them back to the driver client.
+
+ Operations that may have interesting concurrency characteristics are
+ represented explicitly in the API (see ReadRowsRequest.cancel_after_rows).
+ We include such operations only when they can be meaningfully performed
+ through client bindings.
+
+ Users should generally avoid setting deadlines for requests to the Proxy
+ because operations are not cancelable. If the deadline is set anyway, please
+ understand that the underlying operation will continue to be executed even
+ after the deadline expires.
+ """
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.CreateClient = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CreateClient',
+ request_serializer=test__proxy__pb2.CreateClientRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.CreateClientResponse.FromString,
+ )
+ self.CloseClient = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CloseClient',
+ request_serializer=test__proxy__pb2.CloseClientRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.CloseClientResponse.FromString,
+ )
+ self.RemoveClient = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/RemoveClient',
+ request_serializer=test__proxy__pb2.RemoveClientRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.RemoveClientResponse.FromString,
+ )
+ self.ReadRow = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRow',
+ request_serializer=test__proxy__pb2.ReadRowRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.RowResult.FromString,
+ )
+ self.ReadRows = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRows',
+ request_serializer=test__proxy__pb2.ReadRowsRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.RowsResult.FromString,
+ )
+ self.MutateRow = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/MutateRow',
+ request_serializer=test__proxy__pb2.MutateRowRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.MutateRowResult.FromString,
+ )
+ self.BulkMutateRows = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/BulkMutateRows',
+ request_serializer=test__proxy__pb2.MutateRowsRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.MutateRowsResult.FromString,
+ )
+ self.CheckAndMutateRow = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CheckAndMutateRow',
+ request_serializer=test__proxy__pb2.CheckAndMutateRowRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.CheckAndMutateRowResult.FromString,
+ )
+ self.SampleRowKeys = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/SampleRowKeys',
+ request_serializer=test__proxy__pb2.SampleRowKeysRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.SampleRowKeysResult.FromString,
+ )
+ self.ReadModifyWriteRow = channel.unary_unary(
+ '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadModifyWriteRow',
+ request_serializer=test__proxy__pb2.ReadModifyWriteRowRequest.SerializeToString,
+ response_deserializer=test__proxy__pb2.RowResult.FromString,
+ )
+
+
+class CloudBigtableV2TestProxyServicer(object):
+ """Note that all RPCs are unary, even when the equivalent client binding call
+ may be streaming. This is an intentional simplification.
+
+ Most methods have sync (default) and async variants. For async variants,
+ the proxy is expected to perform the async operation, then wait for results
+ before delivering them back to the driver client.
+
+ Operations that may have interesting concurrency characteristics are
+ represented explicitly in the API (see ReadRowsRequest.cancel_after_rows).
+ We include such operations only when they can be meaningfully performed
+ through client bindings.
+
+ Users should generally avoid setting deadlines for requests to the Proxy
+ because operations are not cancelable. If the deadline is set anyway, please
+ understand that the underlying operation will continue to be executed even
+ after the deadline expires.
+ """
+
+ def CreateClient(self, request, context):
+ """Client management:
+
+ Creates a client in the proxy.
+ Each client has its own dedicated channel(s), and can be used concurrently
+ and independently with other clients.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def CloseClient(self, request, context):
+ """Closes a client in the proxy, making it not accept new requests.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def RemoveClient(self, request, context):
+ """Removes a client in the proxy, making it inaccessible. Client closing
+ should be done by CloseClient() separately.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ReadRow(self, request, context):
+ """Bigtable operations: for each operation, you should use the synchronous or
+ asynchronous variant of the client method based on the `use_async_method`
+ setting of the client instance. For starters, you can choose to implement
+ one variant, and return UNIMPLEMENTED status for the other.
+
+ Reads a row with the client instance.
+ The result row may not be present in the response.
+ Callers should check for it (e.g. calling has_row() in C++).
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ReadRows(self, request, context):
+ """Reads rows with the client instance.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def MutateRow(self, request, context):
+ """Writes a row with the client instance.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def BulkMutateRows(self, request, context):
+ """Writes multiple rows with the client instance.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def CheckAndMutateRow(self, request, context):
+ """Performs a check-and-mutate-row operation with the client instance.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def SampleRowKeys(self, request, context):
+ """Obtains a row key sampling with the client instance.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ReadModifyWriteRow(self, request, context):
+ """Performs a read-modify-write operation with the client.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_CloudBigtableV2TestProxyServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'CreateClient': grpc.unary_unary_rpc_method_handler(
+ servicer.CreateClient,
+ request_deserializer=test__proxy__pb2.CreateClientRequest.FromString,
+ response_serializer=test__proxy__pb2.CreateClientResponse.SerializeToString,
+ ),
+ 'CloseClient': grpc.unary_unary_rpc_method_handler(
+ servicer.CloseClient,
+ request_deserializer=test__proxy__pb2.CloseClientRequest.FromString,
+ response_serializer=test__proxy__pb2.CloseClientResponse.SerializeToString,
+ ),
+ 'RemoveClient': grpc.unary_unary_rpc_method_handler(
+ servicer.RemoveClient,
+ request_deserializer=test__proxy__pb2.RemoveClientRequest.FromString,
+ response_serializer=test__proxy__pb2.RemoveClientResponse.SerializeToString,
+ ),
+ 'ReadRow': grpc.unary_unary_rpc_method_handler(
+ servicer.ReadRow,
+ request_deserializer=test__proxy__pb2.ReadRowRequest.FromString,
+ response_serializer=test__proxy__pb2.RowResult.SerializeToString,
+ ),
+ 'ReadRows': grpc.unary_unary_rpc_method_handler(
+ servicer.ReadRows,
+ request_deserializer=test__proxy__pb2.ReadRowsRequest.FromString,
+ response_serializer=test__proxy__pb2.RowsResult.SerializeToString,
+ ),
+ 'MutateRow': grpc.unary_unary_rpc_method_handler(
+ servicer.MutateRow,
+ request_deserializer=test__proxy__pb2.MutateRowRequest.FromString,
+ response_serializer=test__proxy__pb2.MutateRowResult.SerializeToString,
+ ),
+ 'BulkMutateRows': grpc.unary_unary_rpc_method_handler(
+ servicer.BulkMutateRows,
+ request_deserializer=test__proxy__pb2.MutateRowsRequest.FromString,
+ response_serializer=test__proxy__pb2.MutateRowsResult.SerializeToString,
+ ),
+ 'CheckAndMutateRow': grpc.unary_unary_rpc_method_handler(
+ servicer.CheckAndMutateRow,
+ request_deserializer=test__proxy__pb2.CheckAndMutateRowRequest.FromString,
+ response_serializer=test__proxy__pb2.CheckAndMutateRowResult.SerializeToString,
+ ),
+ 'SampleRowKeys': grpc.unary_unary_rpc_method_handler(
+ servicer.SampleRowKeys,
+ request_deserializer=test__proxy__pb2.SampleRowKeysRequest.FromString,
+ response_serializer=test__proxy__pb2.SampleRowKeysResult.SerializeToString,
+ ),
+ 'ReadModifyWriteRow': grpc.unary_unary_rpc_method_handler(
+ servicer.ReadModifyWriteRow,
+ request_deserializer=test__proxy__pb2.ReadModifyWriteRowRequest.FromString,
+ response_serializer=test__proxy__pb2.RowResult.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'google.bigtable.testproxy.CloudBigtableV2TestProxy', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class CloudBigtableV2TestProxy(object):
+ """Note that all RPCs are unary, even when the equivalent client binding call
+ may be streaming. This is an intentional simplification.
+
+ Most methods have sync (default) and async variants. For async variants,
+ the proxy is expected to perform the async operation, then wait for results
+ before delivering them back to the driver client.
+
+ Operations that may have interesting concurrency characteristics are
+ represented explicitly in the API (see ReadRowsRequest.cancel_after_rows).
+ We include such operations only when they can be meaningfully performed
+ through client bindings.
+
+ Users should generally avoid setting deadlines for requests to the Proxy
+ because operations are not cancelable. If the deadline is set anyway, please
+ understand that the underlying operation will continue to be executed even
+ after the deadline expires.
+ """
+
+ @staticmethod
+ def CreateClient(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CreateClient',
+ test__proxy__pb2.CreateClientRequest.SerializeToString,
+ test__proxy__pb2.CreateClientResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def CloseClient(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CloseClient',
+ test__proxy__pb2.CloseClientRequest.SerializeToString,
+ test__proxy__pb2.CloseClientResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def RemoveClient(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/RemoveClient',
+ test__proxy__pb2.RemoveClientRequest.SerializeToString,
+ test__proxy__pb2.RemoveClientResponse.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def ReadRow(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRow',
+ test__proxy__pb2.ReadRowRequest.SerializeToString,
+ test__proxy__pb2.RowResult.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def ReadRows(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRows',
+ test__proxy__pb2.ReadRowsRequest.SerializeToString,
+ test__proxy__pb2.RowsResult.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def MutateRow(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/MutateRow',
+ test__proxy__pb2.MutateRowRequest.SerializeToString,
+ test__proxy__pb2.MutateRowResult.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def BulkMutateRows(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/BulkMutateRows',
+ test__proxy__pb2.MutateRowsRequest.SerializeToString,
+ test__proxy__pb2.MutateRowsResult.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def CheckAndMutateRow(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CheckAndMutateRow',
+ test__proxy__pb2.CheckAndMutateRowRequest.SerializeToString,
+ test__proxy__pb2.CheckAndMutateRowResult.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def SampleRowKeys(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/SampleRowKeys',
+ test__proxy__pb2.SampleRowKeysRequest.SerializeToString,
+ test__proxy__pb2.SampleRowKeysResult.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def ReadModifyWriteRow(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadModifyWriteRow',
+ test__proxy__pb2.ReadModifyWriteRowRequest.SerializeToString,
+ test__proxy__pb2.RowResult.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
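
To tie the pieces together (an illustrative driver sketch, assuming a proxy is listening on localhost:50055; all IDs, targets, and paths below are placeholders): register a client, issue a ReadRow, then close and remove the client.

    import grpc

    import test_proxy_pb2
    import test_proxy_pb2_grpc

    channel = grpc.insecure_channel("localhost:50055")
    stub = test_proxy_pb2_grpc.CloudBigtableV2TestProxyStub(channel)

    # Register a data client inside the proxy.
    stub.CreateClient(test_proxy_pb2.CreateClientRequest(
        client_id="client-0",
        data_target="localhost:8086",
        project_id="project",
        instance_id="instance",
    ))

    # Every *Result message pairs a google.rpc.Status with its payload, so a
    # failed read surfaces in result.status rather than as a gRPC error.
    result = stub.ReadRow(test_proxy_pb2.ReadRowRequest(
        client_id="client-0",
        table_name="projects/project/instances/instance/tables/table",
        row_key="row-1",
    ))
    print(result.status.code, result.row.key)

    stub.CloseClient(test_proxy_pb2.CloseClientRequest(client_id="client-0"))
    stub.RemoveClient(test_proxy_pb2.RemoveClientRequest(client_id="client-0"))
    channel.close()
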
diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh
new file mode 100755
index 000000000..15b146b03
--- /dev/null
+++ b/test_proxy/run_tests.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# attempt to download golang if not found
+if [[ ! -x "$(command -v go)" ]]; then
+ echo "Downloading golang..."
+ wget https://go.dev/dl/go1.20.2.linux-amd64.tar.gz
+ tar -xzf go1.20.2.linux-amd64.tar.gz
+ export GOROOT=$(pwd)/go
+ export PATH=$GOROOT/bin:$PATH
+ export GOPATH=$HOME/go
+ go version
+fi
+
+# ensure the working dir is the script's folder
+SCRIPT_DIR=$(realpath $(dirname "$0"))
+cd $SCRIPT_DIR
+
+export PROXY_SERVER_PORT=50055
+
+# download test suite
+if [ ! -d "cloud-bigtable-clients-test" ]; then
+ git clone https://github.com/googleapis/cloud-bigtable-clients-test.git
+fi
+
+# start proxy
+python test_proxy.py --port $PROXY_SERVER_PORT &
+PROXY_PID=$!
+function finish {
+ kill $PROXY_PID
+}
+trap finish EXIT
+
+# run tests
+pushd cloud-bigtable-clients-test/tests
+go test -v -proxy_addr=:$PROXY_SERVER_PORT
diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py
new file mode 100644
index 000000000..a0cf2f1f0
--- /dev/null
+++ b/test_proxy/test_proxy.py
@@ -0,0 +1,193 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The Python implementation of the `cloud-bigtable-clients-test` proxy server.
+
+https://github.com/googleapis/cloud-bigtable-clients-test
+
+This server is intended to be used to test the correctness of Bigtable
+clients across languages.
+
+Contributor Note: the proxy implementation is split across TestProxyClientHandler
+and TestProxyGrpcServer. This is because generated protos and proto-plus
+objects cannot be used in the same process, so the multiprocessing module is
+used to allow the two to work together.
+"""
+
+import multiprocessing
+import argparse
+import sys
+import os
+sys.path.append("handlers")
+
+
+def grpc_server_process(request_q, queue_pool, port=50055):
+ """
+    Defines a process that hosts a gRPC server and
+ proxies requests to a client_handler_process
+ """
+ sys.path.append("protos")
+ from concurrent import futures
+
+ import grpc
+ import test_proxy_pb2_grpc
+ import grpc_handler
+
+ # Start gRPC server
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ test_proxy_pb2_grpc.add_CloudBigtableV2TestProxyServicer_to_server(
+ grpc_handler.TestProxyGrpcServer(request_q, queue_pool), server
+ )
+ server.add_insecure_port("[::]:" + port)
+ server.start()
+ print("grpc_server_process started, listening on " + port)
+ server.wait_for_termination()
+
+
+async def client_handler_process_async(request_q, queue_pool, use_legacy_client=False):
+ """
+    Defines a process that receives Bigtable requests from a grpc_server_process,
+ and runs the request using a client library instance
+ """
+ import base64
+ import re
+ import asyncio
+ import warnings
+ import client_handler_data
+ import client_handler_legacy
+ warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Bigtable emulator.*")
+
+ def camel_to_snake(str):
+        return re.sub(r"(?<!^)(?=[A-Z])", "_", str).lower()
 # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
# Then this file should have foo==1.14.0
-google-api-core==1.34.0
-google-cloud-core==1.4.4
+google-api-core==2.16.0
+google-cloud-core==2.0.0
grpc-google-iam-v1==0.12.4
proto-plus==1.22.0
libcst==0.2.5
diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt
index e69de29bb..ee858c3ec 100644
--- a/testing/constraints-3.8.txt
+++ b/testing/constraints-3.8.txt
@@ -0,0 +1,14 @@
+# This constraints file is used to check that lower bounds
+# are correct in setup.py
+# List *all* library dependencies and extras in this file.
+# Pin the version to the lower bound.
+#
+# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
+# Then this file should have foo==1.14.0
+google-api-core==2.16.0
+google-cloud-core==2.0.0
+grpc-google-iam-v1==0.12.4
+proto-plus==1.22.0
+libcst==0.2.5
+protobuf==3.19.5
+pytest-asyncio==0.21.1
diff --git a/tests/system/__init__.py b/tests/system/__init__.py
index 4de65971c..89a37dc92 100644
--- a/tests/system/__init__.py
+++ b/tests/system/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index 910c20970..b8862ea4b 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2011 Google LLC
+# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,199 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""
+Imports pytest fixtures for setting up a table for the data client system tests
+"""
+import sys
import os
-import pytest
-from test_utils.system import unique_resource_id
-
-from google.cloud.bigtable.client import Client
-from google.cloud.environment_vars import BIGTABLE_EMULATOR
-
-from . import _helpers
-
-
-@pytest.fixture(scope="session")
-def in_emulator():
- return os.getenv(BIGTABLE_EMULATOR) is not None
-
-
-@pytest.fixture(scope="session")
-def kms_key_name():
- return os.getenv("KMS_KEY_NAME")
-
-
-@pytest.fixture(scope="session")
-def with_kms_key_name(kms_key_name):
- if kms_key_name is None:
- pytest.skip("Test requires KMS_KEY_NAME environment variable")
- return kms_key_name
-
-
-@pytest.fixture(scope="session")
-def skip_on_emulator(in_emulator):
- if in_emulator:
- pytest.skip("Emulator does not support this feature")
-
-
-@pytest.fixture(scope="session")
-def unique_suffix():
- return unique_resource_id("-")
-
-
-@pytest.fixture(scope="session")
-def location_id():
- return "us-central1-c"
-
-
-@pytest.fixture(scope="session")
-def serve_nodes():
- return 1
-
-
-@pytest.fixture(scope="session")
-def label_key():
- return "python-system"
-
-
-@pytest.fixture(scope="session")
-def instance_labels(label_key):
- return {label_key: _helpers.label_stamp()}
-
-
-@pytest.fixture(scope="session")
-def admin_client():
- return Client(admin=True)
-
-
-@pytest.fixture(scope="session")
-def service_account(admin_client):
- from google.oauth2.service_account import Credentials
-
- if not isinstance(admin_client._credentials, Credentials):
- pytest.skip("These tests require a service account credential")
- return admin_client._credentials
-
-
-@pytest.fixture(scope="session")
-def admin_instance_id(unique_suffix):
- return f"g-c-p{unique_suffix}"
-
-
-@pytest.fixture(scope="session")
-def admin_cluster_id(admin_instance_id):
- return f"{admin_instance_id}-cluster"
-
-
-@pytest.fixture(scope="session")
-def admin_instance(admin_client, admin_instance_id, instance_labels):
- return admin_client.instance(admin_instance_id, labels=instance_labels)
-
-
-@pytest.fixture(scope="session")
-def admin_cluster(admin_instance, admin_cluster_id, location_id, serve_nodes):
- return admin_instance.cluster(
- admin_cluster_id,
- location_id=location_id,
- serve_nodes=serve_nodes,
- )
-
-
-@pytest.fixture(scope="session")
-def admin_cluster_with_autoscaling(
- admin_instance,
- admin_cluster_id,
- location_id,
- min_serve_nodes,
- max_serve_nodes,
- cpu_utilization_percent,
-):
- return admin_instance.cluster(
- admin_cluster_id,
- location_id=location_id,
- min_serve_nodes=min_serve_nodes,
- max_serve_nodes=max_serve_nodes,
- cpu_utilization_percent=cpu_utilization_percent,
- )
-
-
-@pytest.fixture(scope="session")
-def admin_instance_populated(admin_instance, admin_cluster, in_emulator):
- # Emulator does not support instance admin operations (create / delete).
- # See: https://cloud.google.com/bigtable/docs/emulator
- if not in_emulator:
- operation = admin_instance.create(clusters=[admin_cluster])
- operation.result(timeout=240)
-
- yield admin_instance
-
- if not in_emulator:
- _helpers.retry_429(admin_instance.delete)()
-
-
-@pytest.fixture(scope="session")
-def data_client():
- return Client(admin=False)
-
-
-@pytest.fixture(scope="session")
-def data_instance_id(unique_suffix):
- return f"g-c-p-d{unique_suffix}"
-
-
-@pytest.fixture(scope="session")
-def data_cluster_id(data_instance_id):
- return f"{data_instance_id}-cluster"
-
-
-@pytest.fixture(scope="session")
-def data_instance_populated(
- admin_client,
- data_instance_id,
- instance_labels,
- data_cluster_id,
- location_id,
- serve_nodes,
- in_emulator,
-):
- instance = admin_client.instance(data_instance_id, labels=instance_labels)
- # Emulator does not support instance admin operations (create / delete).
- # See: https://cloud.google.com/bigtable/docs/emulator
- if not in_emulator:
- cluster = instance.cluster(
- data_cluster_id,
- location_id=location_id,
- serve_nodes=serve_nodes,
- )
- operation = instance.create(clusters=[cluster])
- operation.result(timeout=240)
-
- yield instance
-
- if not in_emulator:
- _helpers.retry_429(instance.delete)()
-
-
-@pytest.fixture(scope="function")
-def instances_to_delete():
- instances_to_delete = []
-
- yield instances_to_delete
-
- for instance in instances_to_delete:
- _helpers.retry_429(instance.delete)()
-
-
-@pytest.fixture(scope="session")
-def min_serve_nodes(in_emulator):
- return 1
-
-
-@pytest.fixture(scope="session")
-def max_serve_nodes(in_emulator):
- return 8
-
+script_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(script_path)
-@pytest.fixture(scope="session")
-def cpu_utilization_percent(in_emulator):
- return 10
+pytest_plugins = [
+ "data.setup_fixtures",
+]
diff --git a/tests/system/data/__init__.py b/tests/system/data/__init__.py
new file mode 100644
index 000000000..89a37dc92
--- /dev/null
+++ b/tests/system/data/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py
new file mode 100644
index 000000000..77086b7f3
--- /dev/null
+++ b/tests/system/data/setup_fixtures.py
@@ -0,0 +1,171 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains a set of pytest fixtures for setting up and populating a
+Bigtable database for testing purposes.
+"""
+
+import pytest
+import pytest_asyncio
+import os
+import asyncio
+import uuid
+
+
+@pytest.fixture(scope="session")
+def event_loop():
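+    # a single session-scoped loop is shared by the session-scoped async fixtures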
+ loop = asyncio.get_event_loop()
+ yield loop
+ loop.stop()
+ loop.close()
+
+
+@pytest.fixture(scope="session")
+def admin_client():
+ """
+ Client for interacting with Table and Instance admin APIs
+ """
+ from google.cloud.bigtable.client import Client
+
+ client = Client(admin=True)
+ yield client
+
+
+@pytest.fixture(scope="session")
+def instance_id(admin_client, project_id, cluster_config):
+ """
+ Returns BIGTABLE_TEST_INSTANCE if set, otherwise creates a new temporary instance for the test session
+ """
+ from google.cloud.bigtable_admin_v2 import types
+ from google.api_core import exceptions
+ from google.cloud.environment_vars import BIGTABLE_EMULATOR
+
+ # use user-specified instance if available
+ user_specified_instance = os.getenv("BIGTABLE_TEST_INSTANCE")
+ if user_specified_instance:
+ print("Using user-specified instance: {}".format(user_specified_instance))
+ yield user_specified_instance
+ return
+
+ # create a new temporary test instance
+ instance_id = f"python-bigtable-tests-{uuid.uuid4().hex[:6]}"
+ if os.getenv(BIGTABLE_EMULATOR):
+ # don't create instance if in emulator mode
+ yield instance_id
+ else:
+ try:
+ operation = admin_client.instance_admin_client.create_instance(
+ parent=f"projects/{project_id}",
+ instance_id=instance_id,
+ instance=types.Instance(
+ display_name="Test Instance",
+ # labels={"python-system-test": "true"},
+ ),
+ clusters=cluster_config,
+ )
+ operation.result(timeout=240)
+ except exceptions.AlreadyExists:
+ pass
+ yield instance_id
+ admin_client.instance_admin_client.delete_instance(
+ name=f"projects/{project_id}/instances/{instance_id}"
+ )
+
+
+@pytest.fixture(scope="session")
+def column_split_config():
+ """
+    specify the initial row-key splits to use when creating a new test table
+ """
+ return [(num * 1000).to_bytes(8, "big") for num in range(1, 10)]
+
+
+@pytest.fixture(scope="session")
+def table_id(
+ admin_client,
+ project_id,
+ instance_id,
+ column_family_config,
+ init_table_id,
+ column_split_config,
+):
+ """
+ Returns BIGTABLE_TEST_TABLE if set, otherwise creates a new temporary table for the test session
+
+ Args:
+ - admin_client: Client for interacting with the Table Admin API. Supplied by the admin_client fixture.
+ - project_id: The project ID of the GCP project to test against. Supplied by the project_id fixture.
+ - instance_id: The ID of the Bigtable instance to test against. Supplied by the instance_id fixture.
+        - column_family_config: The column families to initialize the table with, if a pre-initialized table is not given with BIGTABLE_TEST_TABLE.
+            Supplied by the column_family_config fixture.
+ - init_table_id: The table ID to give to the test table, if pre-initialized table is not given with BIGTABLE_TEST_TABLE.
+ Supplied by the init_table_id fixture.
+ - column_split_config: A list of row keys to use as initial splits when creating the test table.
+ """
+ from google.api_core import exceptions
+ from google.api_core import retry
+
+    # use user-specified table if available
+ user_specified_table = os.getenv("BIGTABLE_TEST_TABLE")
+ if user_specified_table:
+ print("Using user-specified table: {}".format(user_specified_table))
+ yield user_specified_table
+ return
+
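+    # retry table creation on FailedPrecondition (e.g. while a newly created instance is still initializing)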
+ retry = retry.Retry(
+ predicate=retry.if_exception_type(exceptions.FailedPrecondition)
+ )
+ try:
+ parent_path = f"projects/{project_id}/instances/{instance_id}"
+ print(f"Creating table: {parent_path}/tables/{init_table_id}")
+ admin_client.table_admin_client.create_table(
+ request={
+ "parent": parent_path,
+ "table_id": init_table_id,
+ "table": {"column_families": column_family_config},
+ "initial_splits": [{"key": key} for key in column_split_config],
+ },
+ retry=retry,
+ )
+ except exceptions.AlreadyExists:
+ pass
+ yield init_table_id
+ print(f"Deleting table: {parent_path}/tables/{init_table_id}")
+ try:
+ admin_client.table_admin_client.delete_table(
+ name=f"{parent_path}/tables/{init_table_id}"
+ )
+ except exceptions.NotFound:
+ print(f"Table {init_table_id} not found, skipping deletion")
+
+
+@pytest_asyncio.fixture(scope="session")
+async def client():
+ from google.cloud.bigtable.data import BigtableDataClientAsync
+
+ project = os.getenv("GOOGLE_CLOUD_PROJECT") or None
+ async with BigtableDataClientAsync(project=project, pool_size=4) as client:
+ yield client
+
+
+@pytest.fixture(scope="session")
+def project_id(client):
+ """Returns the project ID from the client."""
+ yield client.project
+
+
+@pytest_asyncio.fixture(scope="session")
+async def table(client, table_id, instance_id):
+ async with client.get_table(instance_id, table_id) as table:
+ yield table
diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py
new file mode 100644
index 000000000..aeb08fc1a
--- /dev/null
+++ b/tests/system/data/test_system.py
@@ -0,0 +1,943 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import pytest_asyncio
+import asyncio
+import uuid
+import os
+from google.api_core import retry
+from google.api_core.exceptions import ClientError
+
+from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE
+from google.cloud.environment_vars import BIGTABLE_EMULATOR
+
+TEST_FAMILY = "test-family"
+TEST_FAMILY_2 = "test-family-2"
+
+
+@pytest.fixture(scope="session")
+def column_family_config():
+ """
+ specify column families to create when creating a new test table
+ """
+ from google.cloud.bigtable_admin_v2 import types
+
+ return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()}
+
+
+@pytest.fixture(scope="session")
+def init_table_id():
+ """
+ The table_id to use when creating a new test table
+ """
+ return f"test-table-{uuid.uuid4().hex}"
+
+
+@pytest.fixture(scope="session")
+def cluster_config(project_id):
+ """
+ Configuration for the clusters to use when creating a new instance
+ """
+ from google.cloud.bigtable_admin_v2 import types
+
+ cluster = {
+ "test-cluster": types.Cluster(
+ location=f"projects/{project_id}/locations/us-central1-b",
+ serve_nodes=1,
+ )
+ }
+ return cluster
+
+
+class TempRowBuilder:
+ """
+ Used to add rows to a table for testing purposes.
+ """
+
+ def __init__(self, table):
+ self.rows = []
+ self.table = table
+
+ async def add_row(
+ self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value"
+ ):
+ if isinstance(value, str):
+ value = value.encode("utf-8")
+ elif isinstance(value, int):
+ value = value.to_bytes(8, byteorder="big", signed=True)
+ request = {
+ "table_name": self.table.table_name,
+ "row_key": row_key,
+ "mutations": [
+ {
+ "set_cell": {
+ "family_name": family,
+ "column_qualifier": qualifier,
+ "value": value,
+ }
+ }
+ ],
+ }
+ await self.table.client._gapic_client.mutate_row(request)
+ self.rows.append(row_key)
+
+ async def delete_rows(self):
+ if self.rows:
+ request = {
+ "table_name": self.table.table_name,
+ "entries": [
+ {"row_key": row, "mutations": [{"delete_from_row": {}}]}
+ for row in self.rows
+ ],
+ }
+ await self.table.client._gapic_client.mutate_rows(request)
+
+
+@pytest.mark.usefixtures("table")
+async def _retrieve_cell_value(table, row_key):
+ """
+ Helper to read an individual row
+ """
+ from google.cloud.bigtable.data import ReadRowsQuery
+
+ row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key))
+ assert len(row_list) == 1
+ row = row_list[0]
+ cell = row.cells[0]
+ return cell.value
+
+
+async def _create_row_and_mutation(
+ table, temp_rows, *, start_value=b"start", new_value=b"new_value"
+):
+ """
+ Helper to create a new row, and a sample set_cell mutation to change its value
+ """
+ from google.cloud.bigtable.data.mutations import SetCell
+
+ row_key = uuid.uuid4().hex.encode()
+ family = TEST_FAMILY
+ qualifier = b"test-qualifier"
+ await temp_rows.add_row(
+ row_key, family=family, qualifier=qualifier, value=start_value
+ )
+ # ensure cell is initialized
+ assert (await _retrieve_cell_value(table, row_key)) == start_value
+
+ mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value)
+ return row_key, mutation
+
+
+@pytest.mark.usefixtures("table")
+@pytest_asyncio.fixture(scope="function")
+async def temp_rows(table):
+ builder = TempRowBuilder(table)
+ yield builder
+ await builder.delete_rows()
+
+
+@pytest.mark.usefixtures("table")
+@pytest.mark.usefixtures("client")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10)
+@pytest.mark.asyncio
+async def test_ping_and_warm_gapic(client, table):
+ """
+ Simple ping rpc test
+    This test ensures channels are able to authenticate with the backend
+ """
+ request = {"name": table.instance_name}
+ await client._gapic_client.ping_and_warm(request)
+
+
+@pytest.mark.usefixtures("table")
+@pytest.mark.usefixtures("client")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_ping_and_warm(client, table):
+ """
+ Test ping and warm from handwritten client
+ """
+ try:
+ channel = client.transport._grpc_channel.pool[0]
+ except Exception:
+ # for sync client
+ channel = client.transport._grpc_channel
+ results = await client._ping_and_warm_instances(channel)
+ assert len(results) == 1
+ assert results[0] is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+async def test_mutation_set_cell(table, temp_rows):
+ """
+ Ensure cells can be set properly
+ """
+ row_key = b"bulk_mutate"
+ new_value = uuid.uuid4().hex.encode()
+ row_key, mutation = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value
+ )
+ await table.mutate_row(row_key, mutation)
+
+ # ensure cell is updated
+ assert (await _retrieve_cell_value(table, row_key)) == new_value
+
+
+@pytest.mark.skipif(
+ bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits"
+)
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_sample_row_keys(client, table, temp_rows, column_split_config):
+ """
+ Sample keys should return a single sample in small test tables
+ """
+ await temp_rows.add_row(b"row_key_1")
+ await temp_rows.add_row(b"row_key_2")
+
+ results = await table.sample_row_keys()
+ assert len(results) == len(column_split_config) + 1
+ # first keys should match the split config
+ for idx in range(len(column_split_config)):
+ assert results[idx][0] == column_split_config[idx]
+ assert isinstance(results[idx][1], int)
+ # last sample should be empty key
+ assert results[-1][0] == b""
+ assert isinstance(results[-1][1], int)
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_bulk_mutations_set_cell(client, table, temp_rows):
+ """
+ Ensure cells can be set properly
+ """
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+
+ new_value = uuid.uuid4().hex.encode()
+ row_key, mutation = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value
+ )
+ bulk_mutation = RowMutationEntry(row_key, [mutation])
+
+ await table.bulk_mutate_rows([bulk_mutation])
+
+ # ensure cell is updated
+ assert (await _retrieve_cell_value(table, row_key)) == new_value
+
+
+@pytest.mark.asyncio
+async def test_bulk_mutations_raise_exception(client, table):
+ """
+ If an invalid mutation is passed, an exception should be raised
+ """
+ from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell
+ from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
+ from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
+
+ row_key = uuid.uuid4().hex.encode()
+ mutation = SetCell(family="nonexistent", qualifier=b"test-qualifier", new_value=b"")
+ bulk_mutation = RowMutationEntry(row_key, [mutation])
+
+ with pytest.raises(MutationsExceptionGroup) as exc:
+ await table.bulk_mutate_rows([bulk_mutation])
+ assert len(exc.value.exceptions) == 1
+ entry_error = exc.value.exceptions[0]
+ assert isinstance(entry_error, FailedMutationEntryError)
+ assert entry_error.index == 0
+ assert entry_error.entry == bulk_mutation
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_mutations_batcher_context_manager(client, table, temp_rows):
+ """
+ test batcher with context manager. Should flush on exit
+ """
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+
+ new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)]
+ row_key, mutation = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value
+ )
+ row_key2, mutation2 = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value2
+ )
+ bulk_mutation = RowMutationEntry(row_key, [mutation])
+ bulk_mutation2 = RowMutationEntry(row_key2, [mutation2])
+
+ async with table.mutations_batcher() as batcher:
+ await batcher.append(bulk_mutation)
+ await batcher.append(bulk_mutation2)
+ # ensure cell is updated
+ assert (await _retrieve_cell_value(table, row_key)) == new_value
+ assert len(batcher._staged_entries) == 0
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_mutations_batcher_timer_flush(client, table, temp_rows):
+ """
+    the batcher should flush automatically after flush_interval seconds
+ """
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+
+ new_value = uuid.uuid4().hex.encode()
+ row_key, mutation = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value
+ )
+ bulk_mutation = RowMutationEntry(row_key, [mutation])
+ flush_interval = 0.1
+ async with table.mutations_batcher(flush_interval=flush_interval) as batcher:
+ await batcher.append(bulk_mutation)
+ await asyncio.sleep(0)
+ assert len(batcher._staged_entries) == 1
+ await asyncio.sleep(flush_interval + 0.1)
+ assert len(batcher._staged_entries) == 0
+ # ensure cell is updated
+ assert (await _retrieve_cell_value(table, row_key)) == new_value
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_mutations_batcher_count_flush(client, table, temp_rows):
+ """
+    the batcher should flush once flush_limit_mutation_count mutations are staged
+ """
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+
+ new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)]
+ row_key, mutation = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value
+ )
+ bulk_mutation = RowMutationEntry(row_key, [mutation])
+ row_key2, mutation2 = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value2
+ )
+ bulk_mutation2 = RowMutationEntry(row_key2, [mutation2])
+
+ async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher:
+ await batcher.append(bulk_mutation)
+ assert len(batcher._flush_jobs) == 0
+ # should be noop; flush not scheduled
+ assert len(batcher._staged_entries) == 1
+ await batcher.append(bulk_mutation2)
+ # task should now be scheduled
+ assert len(batcher._flush_jobs) == 1
+ await asyncio.gather(*batcher._flush_jobs)
+ assert len(batcher._staged_entries) == 0
+ assert len(batcher._flush_jobs) == 0
+ # ensure cells were updated
+ assert (await _retrieve_cell_value(table, row_key)) == new_value
+ assert (await _retrieve_cell_value(table, row_key2)) == new_value2
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_mutations_batcher_bytes_flush(client, table, temp_rows):
+ """
+    the batcher should flush once flush_limit_bytes bytes are staged
+ """
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+
+ new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)]
+ row_key, mutation = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value
+ )
+ bulk_mutation = RowMutationEntry(row_key, [mutation])
+ row_key2, mutation2 = await _create_row_and_mutation(
+ table, temp_rows, new_value=new_value2
+ )
+ bulk_mutation2 = RowMutationEntry(row_key2, [mutation2])
+
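+    # set the byte limit just below the combined size, so the second append triggers a flush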
+ flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1
+
+ async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher:
+ await batcher.append(bulk_mutation)
+ assert len(batcher._flush_jobs) == 0
+ assert len(batcher._staged_entries) == 1
+ await batcher.append(bulk_mutation2)
+ # task should now be scheduled
+ assert len(batcher._flush_jobs) == 1
+ assert len(batcher._staged_entries) == 0
+ # let flush complete
+ await asyncio.gather(*batcher._flush_jobs)
+ # ensure cells were updated
+ assert (await _retrieve_cell_value(table, row_key)) == new_value
+ assert (await _retrieve_cell_value(table, row_key2)) == new_value2
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_mutations_batcher_no_flush(client, table, temp_rows):
+ """
+ test with no flush requirements met
+ """
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+
+ new_value = uuid.uuid4().hex.encode()
+ start_value = b"unchanged"
+ row_key, mutation = await _create_row_and_mutation(
+ table, temp_rows, start_value=start_value, new_value=new_value
+ )
+ bulk_mutation = RowMutationEntry(row_key, [mutation])
+ row_key2, mutation2 = await _create_row_and_mutation(
+ table, temp_rows, start_value=start_value, new_value=new_value
+ )
+ bulk_mutation2 = RowMutationEntry(row_key2, [mutation2])
+
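+    # byte limit is above the combined size, so appending both entries should not trigger a flush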
+ size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1
+ async with table.mutations_batcher(
+ flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1
+ ) as batcher:
+ await batcher.append(bulk_mutation)
+ assert len(batcher._staged_entries) == 1
+ await batcher.append(bulk_mutation2)
+ # flush not scheduled
+ assert len(batcher._flush_jobs) == 0
+ await asyncio.sleep(0.01)
+ assert len(batcher._staged_entries) == 2
+ assert len(batcher._flush_jobs) == 0
+ # ensure cells were not updated
+ assert (await _retrieve_cell_value(table, row_key)) == start_value
+ assert (await _retrieve_cell_value(table, row_key2)) == start_value
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@pytest.mark.parametrize(
+ "start,increment,expected",
+ [
+ (0, 0, 0),
+ (0, 1, 1),
+ (0, -1, -1),
+ (1, 0, 1),
+ (0, -100, -100),
+ (0, 3000, 3000),
+ (10, 4, 14),
+ (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0),
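+        # increments that overflow the signed 64-bit range wrap around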
+ (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE),
+ (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE),
+ ],
+)
+@pytest.mark.asyncio
+async def test_read_modify_write_row_increment(
+ client, table, temp_rows, start, increment, expected
+):
+ """
+ test read_modify_write_row
+ """
+ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
+
+ row_key = b"test-row-key"
+ family = TEST_FAMILY
+ qualifier = b"test-qualifier"
+ await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier)
+
+ rule = IncrementRule(family, qualifier, increment)
+ result = await table.read_modify_write_row(row_key, rule)
+ assert result.row_key == row_key
+ assert len(result) == 1
+ assert result[0].family == family
+ assert result[0].qualifier == qualifier
+ assert int(result[0]) == expected
+ # ensure that reading from server gives same value
+ assert (await _retrieve_cell_value(table, row_key)) == result[0].value
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@pytest.mark.parametrize(
+ "start,append,expected",
+ [
+ (b"", b"", b""),
+ ("", "", b""),
+ (b"abc", b"123", b"abc123"),
+ (b"abc", "123", b"abc123"),
+ ("", b"1", b"1"),
+ (b"abc", "", b"abc"),
+ (b"hello", b"world", b"helloworld"),
+ ],
+)
+@pytest.mark.asyncio
+async def test_read_modify_write_row_append(
+ client, table, temp_rows, start, append, expected
+):
+ """
+ test read_modify_write_row
+ """
+ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
+
+ row_key = b"test-row-key"
+ family = TEST_FAMILY
+ qualifier = b"test-qualifier"
+ await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier)
+
+ rule = AppendValueRule(family, qualifier, append)
+ result = await table.read_modify_write_row(row_key, rule)
+ assert result.row_key == row_key
+ assert len(result) == 1
+ assert result[0].family == family
+ assert result[0].qualifier == qualifier
+ assert result[0].value == expected
+ # ensure that reading from server gives same value
+ assert (await _retrieve_cell_value(table, row_key)) == result[0].value
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_read_modify_write_row_chained(client, table, temp_rows):
+ """
+ test read_modify_write_row with multiple rules
+ """
+ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
+ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
+
+ row_key = b"test-row-key"
+ family = TEST_FAMILY
+ qualifier = b"test-qualifier"
+ start_amount = 1
+ increment_amount = 10
+ await temp_rows.add_row(
+ row_key, value=start_amount, family=family, qualifier=qualifier
+ )
+ rule = [
+ IncrementRule(family, qualifier, increment_amount),
+ AppendValueRule(family, qualifier, "hello"),
+ AppendValueRule(family, qualifier, "world"),
+ AppendValueRule(family, qualifier, "!"),
+ ]
+ result = await table.read_modify_write_row(row_key, rule)
+ assert result.row_key == row_key
+ assert result[0].family == family
+ assert result[0].qualifier == qualifier
+ # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values
+ assert (
+ result[0].value
+ == (start_amount + increment_amount).to_bytes(8, "big", signed=True)
+ + b"helloworld!"
+ )
+ # ensure that reading from server gives same value
+ assert (await _retrieve_cell_value(table, row_key)) == result[0].value
+
+
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@pytest.mark.parametrize(
+ "start_val,predicate_range,expected_result",
+ [
+ (1, (0, 2), True),
+ (-1, (0, 2), False),
+ ],
+)
+@pytest.mark.asyncio
+async def test_check_and_mutate(
+ client, table, temp_rows, start_val, predicate_range, expected_result
+):
+ """
+    test that check_and_mutate_row applies the right mutations and returns the right result
+ """
+ from google.cloud.bigtable.data.mutations import SetCell
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ row_key = b"test-row-key"
+ family = TEST_FAMILY
+ qualifier = b"test-qualifier"
+
+ await temp_rows.add_row(
+ row_key, value=start_val, family=family, qualifier=qualifier
+ )
+
+ false_mutation_value = b"false-mutation-value"
+ false_mutation = SetCell(
+ family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value
+ )
+ true_mutation_value = b"true-mutation-value"
+ true_mutation = SetCell(
+ family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value
+ )
+ predicate = ValueRangeFilter(predicate_range[0], predicate_range[1])
+ result = await table.check_and_mutate_row(
+ row_key,
+ predicate,
+ true_case_mutations=true_mutation,
+ false_case_mutations=false_mutation,
+ )
+ assert result == expected_result
+ # ensure cell is updated
+ expected_value = true_mutation_value if expected_result else false_mutation_value
+ assert (await _retrieve_cell_value(table, row_key)) == expected_value
+
+
+@pytest.mark.skipif(
+ bool(os.environ.get(BIGTABLE_EMULATOR)),
+ reason="emulator doesn't raise InvalidArgument",
+)
+@pytest.mark.usefixtures("client")
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_check_and_mutate_empty_request(client, table):
+ """
+    check_and_mutate with no true or false mutations should raise an error
+ """
+ from google.api_core import exceptions
+
+ with pytest.raises(exceptions.InvalidArgument) as e:
+ await table.check_and_mutate_row(
+ b"row_key", None, true_case_mutations=None, false_case_mutations=None
+ )
+ assert "No mutations provided" in str(e.value)
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_stream(table, temp_rows):
+ """
+ Ensure that the read_rows_stream method works
+ """
+ await temp_rows.add_row(b"row_key_1")
+ await temp_rows.add_row(b"row_key_2")
+
+ # full table scan
+ generator = await table.read_rows_stream({})
+ first_row = await generator.__anext__()
+ second_row = await generator.__anext__()
+ assert first_row.row_key == b"row_key_1"
+ assert second_row.row_key == b"row_key_2"
+ with pytest.raises(StopAsyncIteration):
+ await generator.__anext__()
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows(table, temp_rows):
+ """
+ Ensure that the read_rows method works
+ """
+ await temp_rows.add_row(b"row_key_1")
+ await temp_rows.add_row(b"row_key_2")
+ # full table scan
+ row_list = await table.read_rows({})
+ assert len(row_list) == 2
+ assert row_list[0].row_key == b"row_key_1"
+ assert row_list[1].row_key == b"row_key_2"
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_sharded_simple(table, temp_rows):
+ """
+ Test read rows sharded with two queries
+ """
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+ await temp_rows.add_row(b"a")
+ await temp_rows.add_row(b"b")
+ await temp_rows.add_row(b"c")
+ await temp_rows.add_row(b"d")
+ query1 = ReadRowsQuery(row_keys=[b"a", b"c"])
+ query2 = ReadRowsQuery(row_keys=[b"b", b"d"])
+ row_list = await table.read_rows_sharded([query1, query2])
+ assert len(row_list) == 4
+ assert row_list[0].row_key == b"a"
+ assert row_list[1].row_key == b"c"
+ assert row_list[2].row_key == b"b"
+ assert row_list[3].row_key == b"d"
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_sharded_from_sample(table, temp_rows):
+ """
+ Test end-to-end sharding
+ """
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+ await temp_rows.add_row(b"a")
+ await temp_rows.add_row(b"b")
+ await temp_rows.add_row(b"c")
+ await temp_rows.add_row(b"d")
+
+ table_shard_keys = await table.sample_row_keys()
+ query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")])
+ shard_queries = query.shard(table_shard_keys)
+ row_list = await table.read_rows_sharded(shard_queries)
+ assert len(row_list) == 3
+ assert row_list[0].row_key == b"b"
+ assert row_list[1].row_key == b"c"
+ assert row_list[2].row_key == b"d"
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_sharded_filters_limits(table, temp_rows):
+ """
+ Test read rows sharded with filters and limits
+ """
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ await temp_rows.add_row(b"a")
+ await temp_rows.add_row(b"b")
+ await temp_rows.add_row(b"c")
+ await temp_rows.add_row(b"d")
+
+ label_filter1 = ApplyLabelFilter("first")
+ label_filter2 = ApplyLabelFilter("second")
+ query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1)
+ query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2)
+ row_list = await table.read_rows_sharded([query1, query2])
+ assert len(row_list) == 3
+ assert row_list[0].row_key == b"a"
+ assert row_list[1].row_key == b"b"
+ assert row_list[2].row_key == b"d"
+ assert row_list[0][0].labels == ["first"]
+ assert row_list[1][0].labels == ["second"]
+ assert row_list[2][0].labels == ["second"]
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_range_query(table, temp_rows):
+ """
+    Ensure that read_rows works with a row range query
+ """
+ from google.cloud.bigtable.data import ReadRowsQuery
+ from google.cloud.bigtable.data import RowRange
+
+ await temp_rows.add_row(b"a")
+ await temp_rows.add_row(b"b")
+ await temp_rows.add_row(b"c")
+ await temp_rows.add_row(b"d")
+    # query a row range: [b"b", b"d")
+ query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d"))
+ row_list = await table.read_rows(query)
+ assert len(row_list) == 2
+ assert row_list[0].row_key == b"b"
+ assert row_list[1].row_key == b"c"
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_single_key_query(table, temp_rows):
+ """
+    Ensure that read_rows works with a query for specific row keys
+ """
+ from google.cloud.bigtable.data import ReadRowsQuery
+
+ await temp_rows.add_row(b"a")
+ await temp_rows.add_row(b"b")
+ await temp_rows.add_row(b"c")
+ await temp_rows.add_row(b"d")
+ # retrieve specific keys
+ query = ReadRowsQuery(row_keys=[b"a", b"c"])
+ row_list = await table.read_rows(query)
+ assert len(row_list) == 2
+ assert row_list[0].row_key == b"a"
+ assert row_list[1].row_key == b"c"
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_with_filter(table, temp_rows):
+ """
+ ensure filters are applied
+ """
+ from google.cloud.bigtable.data import ReadRowsQuery
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ await temp_rows.add_row(b"a")
+ await temp_rows.add_row(b"b")
+ await temp_rows.add_row(b"c")
+ await temp_rows.add_row(b"d")
+ # retrieve keys with filter
+ expected_label = "test-label"
+ row_filter = ApplyLabelFilter(expected_label)
+ query = ReadRowsQuery(row_filter=row_filter)
+ row_list = await table.read_rows(query)
+ assert len(row_list) == 4
+ for row in row_list:
+ assert row[0].labels == [expected_label]
+
+
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_read_rows_stream_close(table, temp_rows):
+ """
+ Ensure that the read_rows_stream can be closed
+ """
+ from google.cloud.bigtable.data import ReadRowsQuery
+
+ await temp_rows.add_row(b"row_key_1")
+ await temp_rows.add_row(b"row_key_2")
+ # full table scan
+ query = ReadRowsQuery()
+ generator = await table.read_rows_stream(query)
+ # grab first row
+ first_row = await generator.__anext__()
+ assert first_row.row_key == b"row_key_1"
+ # close stream early
+ await generator.aclose()
+ with pytest.raises(StopAsyncIteration):
+ await generator.__anext__()
+
+
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_read_row(table, temp_rows):
+ """
+ Test read_row (single row helper)
+ """
+ from google.cloud.bigtable.data import Row
+
+ await temp_rows.add_row(b"row_key_1", value=b"value")
+ row = await table.read_row(b"row_key_1")
+ assert isinstance(row, Row)
+ assert row.row_key == b"row_key_1"
+ assert row.cells[0].value == b"value"
+
+
+@pytest.mark.skipif(
+ bool(os.environ.get(BIGTABLE_EMULATOR)),
+ reason="emulator doesn't raise InvalidArgument",
+)
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_read_row_missing(table):
+ """
+ Test read_row when row does not exist
+ """
+ from google.api_core import exceptions
+
+ row_key = "row_key_not_exist"
+ result = await table.read_row(row_key)
+ assert result is None
+ with pytest.raises(exceptions.InvalidArgument) as e:
+ await table.read_row("")
+ assert "Row keys must be non-empty" in str(e)
+
+
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_read_row_w_filter(table, temp_rows):
+ """
+ Test read_row (single row helper)
+ """
+ from google.cloud.bigtable.data import Row
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ await temp_rows.add_row(b"row_key_1", value=b"value")
+ expected_label = "test-label"
+ label_filter = ApplyLabelFilter(expected_label)
+ row = await table.read_row(b"row_key_1", row_filter=label_filter)
+ assert isinstance(row, Row)
+ assert row.row_key == b"row_key_1"
+ assert row.cells[0].value == b"value"
+ assert row.cells[0].labels == [expected_label]
+
+
+@pytest.mark.skipif(
+ bool(os.environ.get(BIGTABLE_EMULATOR)),
+ reason="emulator doesn't raise InvalidArgument",
+)
+@pytest.mark.usefixtures("table")
+@pytest.mark.asyncio
+async def test_row_exists(table, temp_rows):
+    """Test row_exists with rows that exist and don't exist"""
+    from google.api_core import exceptions
+
+ assert await table.row_exists(b"row_key_1") is False
+ await temp_rows.add_row(b"row_key_1")
+ assert await table.row_exists(b"row_key_1") is True
+ assert await table.row_exists("row_key_1") is True
+ assert await table.row_exists(b"row_key_2") is False
+ assert await table.row_exists("row_key_2") is False
+ assert await table.row_exists("3") is False
+ await temp_rows.add_row(b"3")
+ assert await table.row_exists(b"3") is True
+ with pytest.raises(exceptions.InvalidArgument) as e:
+ await table.row_exists("")
+ assert "Row keys must be non-empty" in str(e)
+
+
+@pytest.mark.usefixtures("table")
+@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.parametrize(
+ "cell_value,filter_input,expect_match",
+ [
+ (b"abc", b"abc", True),
+ (b"abc", "abc", True),
+ (b".", ".", True),
+ (".*", ".*", True),
+ (".*", b".*", True),
+ ("a", ".*", False),
+ (b".*", b".*", True),
+ (r"\a", r"\a", True),
+ (b"\xe2\x98\x83", "☃", True),
+ ("☃", "☃", True),
+ (r"\C☃", r"\C☃", True),
+ (1, 1, True),
+ (2, 1, False),
+ (68, 68, True),
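+        # int values are encoded as 8-byte big-endian, so 68 does not match the single byte "D"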
+ ("D", 68, False),
+ (68, "D", False),
+ (-1, -1, True),
+ (2852126720, 2852126720, True),
+ (-1431655766, -1431655766, True),
+ (-1431655766, -1, False),
+ ],
+)
+@pytest.mark.asyncio
+async def test_literal_value_filter(
+ table, temp_rows, cell_value, filter_input, expect_match
+):
+ """
+ Literal value filter does complex escaping on re2 strings.
+ Make sure inputs are properly interpreted by the server
+ """
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+ from google.cloud.bigtable.data import ReadRowsQuery
+
+ f = LiteralValueFilter(filter_input)
+ await temp_rows.add_row(b"row_key_1", value=cell_value)
+ query = ReadRowsQuery(row_filter=f)
+ row_list = await table.read_rows(query)
+ assert len(row_list) == bool(
+ expect_match
+ ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter"
diff --git a/tests/system/v2_client/__init__.py b/tests/system/v2_client/__init__.py
new file mode 100644
index 000000000..4de65971c
--- /dev/null
+++ b/tests/system/v2_client/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/tests/system/_helpers.py b/tests/system/v2_client/_helpers.py
similarity index 100%
rename from tests/system/_helpers.py
rename to tests/system/v2_client/_helpers.py
diff --git a/tests/system/v2_client/conftest.py b/tests/system/v2_client/conftest.py
new file mode 100644
index 000000000..f39fcba88
--- /dev/null
+++ b/tests/system/v2_client/conftest.py
@@ -0,0 +1,209 @@
+# Copyright 2011 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+from test_utils.system import unique_resource_id
+
+from google.cloud.bigtable.client import Client
+from google.cloud.environment_vars import BIGTABLE_EMULATOR
+
+from . import _helpers
+
+
+@pytest.fixture(scope="session")
+def in_emulator():
+ return os.getenv(BIGTABLE_EMULATOR) is not None
+
+
+@pytest.fixture(scope="session")
+def kms_key_name():
+ return os.getenv("KMS_KEY_NAME")
+
+
+@pytest.fixture(scope="session")
+def with_kms_key_name(kms_key_name):
+ if kms_key_name is None:
+ pytest.skip("Test requires KMS_KEY_NAME environment variable")
+ return kms_key_name
+
+
+@pytest.fixture(scope="session")
+def skip_on_emulator(in_emulator):
+ if in_emulator:
+ pytest.skip("Emulator does not support this feature")
+
+
+@pytest.fixture(scope="session")
+def unique_suffix():
+ return unique_resource_id("-")
+
+
+@pytest.fixture(scope="session")
+def location_id():
+ return "us-central1-c"
+
+
+@pytest.fixture(scope="session")
+def serve_nodes():
+ return 3
+
+
+@pytest.fixture(scope="session")
+def label_key():
+ return "python-system"
+
+
+@pytest.fixture(scope="session")
+def instance_labels(label_key):
+ return {label_key: _helpers.label_stamp()}
+
+
+@pytest.fixture(scope="session")
+def admin_client():
+ return Client(admin=True)
+
+
+@pytest.fixture(scope="session")
+def service_account(admin_client):
+ from google.oauth2.service_account import Credentials
+
+ if not isinstance(admin_client._credentials, Credentials):
+ pytest.skip("These tests require a service account credential")
+ return admin_client._credentials
+
+
+@pytest.fixture(scope="session")
+def admin_instance_id(unique_suffix):
+ return f"g-c-p{unique_suffix}"
+
+
+@pytest.fixture(scope="session")
+def admin_cluster_id(admin_instance_id):
+ return f"{admin_instance_id}-cluster"
+
+
+@pytest.fixture(scope="session")
+def admin_instance(admin_client, admin_instance_id, instance_labels):
+ return admin_client.instance(admin_instance_id, labels=instance_labels)
+
+
+@pytest.fixture(scope="session")
+def admin_cluster(admin_instance, admin_cluster_id, location_id, serve_nodes):
+ return admin_instance.cluster(
+ admin_cluster_id,
+ location_id=location_id,
+ serve_nodes=serve_nodes,
+ )
+
+
+@pytest.fixture(scope="session")
+def admin_cluster_with_autoscaling(
+ admin_instance,
+ admin_cluster_id,
+ location_id,
+ min_serve_nodes,
+ max_serve_nodes,
+ cpu_utilization_percent,
+):
+ return admin_instance.cluster(
+ admin_cluster_id,
+ location_id=location_id,
+ min_serve_nodes=min_serve_nodes,
+ max_serve_nodes=max_serve_nodes,
+ cpu_utilization_percent=cpu_utilization_percent,
+ )
+
+
+@pytest.fixture(scope="session")
+def admin_instance_populated(admin_instance, admin_cluster, in_emulator):
+ # Emulator does not support instance admin operations (create / delete).
+ # See: https://cloud.google.com/bigtable/docs/emulator
+ if not in_emulator:
+ operation = admin_instance.create(clusters=[admin_cluster])
+ operation.result(timeout=240)
+
+ yield admin_instance
+
+ if not in_emulator:
+ _helpers.retry_429(admin_instance.delete)()
+
+
+@pytest.fixture(scope="session")
+def data_client():
+ return Client(admin=False)
+
+
+@pytest.fixture(scope="session")
+def data_instance_id(unique_suffix):
+ return f"g-c-p-d{unique_suffix}"
+
+
+@pytest.fixture(scope="session")
+def data_cluster_id(data_instance_id):
+ return f"{data_instance_id}-cluster"
+
+
+@pytest.fixture(scope="session")
+def data_instance_populated(
+ admin_client,
+ data_instance_id,
+ instance_labels,
+ data_cluster_id,
+ location_id,
+ serve_nodes,
+ in_emulator,
+):
+ instance = admin_client.instance(data_instance_id, labels=instance_labels)
+ # Emulator does not support instance admin operations (create / delete).
+ # See: https://cloud.google.com/bigtable/docs/emulator
+ if not in_emulator:
+ cluster = instance.cluster(
+ data_cluster_id,
+ location_id=location_id,
+ serve_nodes=serve_nodes,
+ )
+ operation = instance.create(clusters=[cluster])
+ operation.result(timeout=240)
+
+ yield instance
+
+ if not in_emulator:
+ _helpers.retry_429(instance.delete)()
+
+
+@pytest.fixture(scope="function")
+def instances_to_delete():
+ instances_to_delete = []
+
+ yield instances_to_delete
+
+ for instance in instances_to_delete:
+ _helpers.retry_429(instance.delete)()
+
+
+@pytest.fixture(scope="session")
+def min_serve_nodes(in_emulator):
+ return 1
+
+
+@pytest.fixture(scope="session")
+def max_serve_nodes(in_emulator):
+ return 8
+
+
+@pytest.fixture(scope="session")
+def cpu_utilization_percent(in_emulator):
+ return 10
diff --git a/tests/system/test_data_api.py b/tests/system/v2_client/test_data_api.py
similarity index 100%
rename from tests/system/test_data_api.py
rename to tests/system/v2_client/test_data_api.py
diff --git a/tests/system/test_instance_admin.py b/tests/system/v2_client/test_instance_admin.py
similarity index 100%
rename from tests/system/test_instance_admin.py
rename to tests/system/v2_client/test_instance_admin.py
diff --git a/tests/system/test_table_admin.py b/tests/system/v2_client/test_table_admin.py
similarity index 100%
rename from tests/system/test_table_admin.py
rename to tests/system/v2_client/test_table_admin.py
diff --git a/tests/unit/data/__init__.py b/tests/unit/data/__init__.py
new file mode 100644
index 000000000..89a37dc92
--- /dev/null
+++ b/tests/unit/data/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py
new file mode 100644
index 000000000..e03028c45
--- /dev/null
+++ b/tests/unit/data/_async/test__mutate_rows.py
@@ -0,0 +1,378 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from google.cloud.bigtable_v2.types import MutateRowsResponse
+from google.rpc import status_pb2
+import google.api_core.exceptions as core_exceptions
+
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+ from unittest.mock import AsyncMock # type: ignore
+except ImportError: # pragma: NO COVER
+ import mock # type: ignore
+ from mock import AsyncMock # type: ignore
+
+
+def _make_mutation(count=1, size=1):
+ mutation = mock.Mock()
+ mutation.size.return_value = size
+ mutation.mutations = [mock.Mock()] * count
+ return mutation
+
+
+class TestMutateRowsOperation:
+ def _target_class(self):
+ from google.cloud.bigtable.data._async._mutate_rows import (
+ _MutateRowsOperationAsync,
+ )
+
+ return _MutateRowsOperationAsync
+
+ def _make_one(self, *args, **kwargs):
+ if not args:
+ kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock())
+ kwargs["table"] = kwargs.pop("table", AsyncMock())
+ kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5)
+ kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1)
+ kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ())
+ kwargs["mutation_entries"] = kwargs.pop("mutation_entries", [])
+ return self._target_class()(*args, **kwargs)
+
+ async def _mock_stream(self, mutation_list, error_dict):
+ for idx, entry in enumerate(mutation_list):
+ code = error_dict.get(idx, 0)
+ yield MutateRowsResponse(
+ entries=[
+ MutateRowsResponse.Entry(
+ index=idx, status=status_pb2.Status(code=code)
+ )
+ ]
+ )
+
+ def _make_mock_gapic(self, mutation_list, error_dict=None):
+ mock_fn = AsyncMock()
+ if error_dict is None:
+ error_dict = {}
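+        # return a fresh response stream for every rpc call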
+ mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream(
+ mutation_list, error_dict
+ )
+ return mock_fn
+
+ def test_ctor(self):
+ """
+ test that constructor sets all the attributes correctly
+ """
+ from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto
+ from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
+ from google.api_core.exceptions import DeadlineExceeded
+ from google.api_core.exceptions import Aborted
+
+ client = mock.Mock()
+ table = mock.Mock()
+ entries = [_make_mutation(), _make_mutation()]
+ operation_timeout = 0.05
+ attempt_timeout = 0.01
+ retryable_exceptions = ()
+ instance = self._make_one(
+ client,
+ table,
+ entries,
+ operation_timeout,
+ attempt_timeout,
+ retryable_exceptions,
+ )
+ # running gapic_fn should trigger a client call
+ assert client.mutate_rows.call_count == 0
+ instance._gapic_fn()
+ assert client.mutate_rows.call_count == 1
+ # gapic_fn should call with table details
+ inner_kwargs = client.mutate_rows.call_args[1]
+ assert len(inner_kwargs) == 4
+ assert inner_kwargs["table_name"] == table.table_name
+ assert inner_kwargs["app_profile_id"] == table.app_profile_id
+ assert inner_kwargs["retry"] is None
+ metadata = inner_kwargs["metadata"]
+ assert len(metadata) == 1
+ assert metadata[0][0] == "x-goog-request-params"
+ assert str(table.table_name) in metadata[0][1]
+ assert str(table.app_profile_id) in metadata[0][1]
+ # entries should be passed down
+ entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries]
+ assert instance.mutations == entries_w_pb
+ # timeout_gen should generate per-attempt timeout
+ assert next(instance.timeout_generator) == attempt_timeout
+ # ensure predicate is set
+ assert instance.is_retryable is not None
+ assert instance.is_retryable(DeadlineExceeded("")) is False
+ assert instance.is_retryable(Aborted("")) is False
+ assert instance.is_retryable(_MutateRowsIncomplete("")) is True
+ assert instance.is_retryable(RuntimeError("")) is False
+ assert instance.remaining_indices == list(range(len(entries)))
+ assert instance.errors == {}
+
+ def test_ctor_too_many_entries(self):
+ """
+ should raise an error if an operation is created with more than 100,000 entries
+ """
+ from google.cloud.bigtable.data._async._mutate_rows import (
+ _MUTATE_ROWS_REQUEST_MUTATION_LIMIT,
+ )
+
+ assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000
+
+ client = mock.Mock()
+ table = mock.Mock()
+ entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
+ operation_timeout = 0.05
+ attempt_timeout = 0.01
+ # no errors if at limit
+ self._make_one(client, table, entries, operation_timeout, attempt_timeout)
+        # should raise an error once the limit is crossed
+ with pytest.raises(ValueError) as e:
+ self._make_one(
+ client,
+ table,
+ entries + [_make_mutation()],
+ operation_timeout,
+ attempt_timeout,
+ )
+ assert "mutate_rows requests can contain at most 100000 mutations" in str(
+ e.value
+ )
+ assert "Found 100001" in str(e.value)
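+    # Illustrative sketch (not part of the test, and assuming one mutation per
+    # entry): a caller staying under the limit would split its entries into
+    # batches before constructing operations, e.g.
+    #   limit = _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
+    #   batches = [entries[i : i + limit] for i in range(0, len(entries), limit)]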
+
+ @pytest.mark.asyncio
+ async def test_mutate_rows_operation(self):
+ """
+ Test successful case of mutate_rows_operation
+ """
+ client = mock.Mock()
+ table = mock.Mock()
+ entries = [_make_mutation(), _make_mutation()]
+ operation_timeout = 0.05
+ cls = self._target_class()
+ with mock.patch(
+ f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock()
+ ) as attempt_mock:
+ instance = self._make_one(
+ client, table, entries, operation_timeout, operation_timeout
+ )
+ await instance.start()
+ assert attempt_mock.call_count == 1
+
+ @pytest.mark.parametrize(
+ "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden]
+ )
+ @pytest.mark.asyncio
+ async def test_mutate_rows_attempt_exception(self, exc_type):
+ """
+        exceptions raised from _run_attempt should bubble up directly and be recorded for each entry
+ """
+ client = AsyncMock()
+ table = mock.Mock()
+ entries = [_make_mutation(), _make_mutation()]
+ operation_timeout = 0.05
+ expected_exception = exc_type("test")
+ client.mutate_rows.side_effect = expected_exception
+ found_exc = None
+ try:
+ instance = self._make_one(
+ client, table, entries, operation_timeout, operation_timeout
+ )
+ await instance._run_attempt()
+ except Exception as e:
+ found_exc = e
+ assert client.mutate_rows.call_count == 1
+ assert type(found_exc) is exc_type
+ assert found_exc == expected_exception
+ assert len(instance.errors) == 2
+ assert len(instance.remaining_indices) == 0
+
+ @pytest.mark.parametrize(
+ "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden]
+ )
+ @pytest.mark.asyncio
+ async def test_mutate_rows_exception(self, exc_type):
+ """
+        non-retryable exceptions raised during attempts should be wrapped and raised in a MutationsExceptionGroup
+ """
+ from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
+ from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
+
+ client = mock.Mock()
+ table = mock.Mock()
+ entries = [_make_mutation(), _make_mutation()]
+ operation_timeout = 0.05
+ expected_cause = exc_type("abort")
+ with mock.patch.object(
+ self._target_class(),
+ "_run_attempt",
+ AsyncMock(),
+ ) as attempt_mock:
+ attempt_mock.side_effect = expected_cause
+ found_exc = None
+ try:
+ instance = self._make_one(
+ client, table, entries, operation_timeout, operation_timeout
+ )
+ await instance.start()
+ except MutationsExceptionGroup as e:
+ found_exc = e
+ assert attempt_mock.call_count == 1
+ assert len(found_exc.exceptions) == 2
+ assert isinstance(found_exc.exceptions[0], FailedMutationEntryError)
+ assert isinstance(found_exc.exceptions[1], FailedMutationEntryError)
+ assert found_exc.exceptions[0].__cause__ == expected_cause
+ assert found_exc.exceptions[1].__cause__ == expected_cause
+
+ @pytest.mark.parametrize(
+ "exc_type",
+ [core_exceptions.DeadlineExceeded, RuntimeError],
+ )
+ @pytest.mark.asyncio
+ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type):
+ """
+        If attempts fail with a retryable exception but eventually succeed, no exception should be raised
+ """
+ from google.cloud.bigtable.data._async._mutate_rows import (
+ _MutateRowsOperationAsync,
+ )
+
+ client = mock.Mock()
+ table = mock.Mock()
+ entries = [_make_mutation()]
+ operation_timeout = 1
+ expected_cause = exc_type("retry")
+ num_retries = 2
+ with mock.patch.object(
+ _MutateRowsOperationAsync,
+ "_run_attempt",
+ AsyncMock(),
+ ) as attempt_mock:
+ attempt_mock.side_effect = [expected_cause] * num_retries + [None]
+ instance = self._make_one(
+ client,
+ table,
+ entries,
+ operation_timeout,
+ operation_timeout,
+ retryable_exceptions=(exc_type,),
+ )
+ await instance.start()
+ assert attempt_mock.call_count == num_retries + 1
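+    # attempt_mock fails num_retries times with a retryable exception and then
+    # succeeds, so start() should call it num_retries + 1 times; the exception
+    # type must be listed in retryable_exceptions for the retry wrapper to retry it.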
+
+ @pytest.mark.asyncio
+ async def test_mutate_rows_incomplete_ignored(self):
+ """
+        _MutateRowsIncomplete exceptions should not be added to the error list
+ """
+ from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
+ from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
+ from google.api_core.exceptions import DeadlineExceeded
+
+ client = mock.Mock()
+ table = mock.Mock()
+ entries = [_make_mutation()]
+ operation_timeout = 0.05
+ with mock.patch.object(
+ self._target_class(),
+ "_run_attempt",
+ AsyncMock(),
+ ) as attempt_mock:
+ attempt_mock.side_effect = _MutateRowsIncomplete("ignored")
+ found_exc = None
+ try:
+ instance = self._make_one(
+ client, table, entries, operation_timeout, operation_timeout
+ )
+ await instance.start()
+ except MutationsExceptionGroup as e:
+ found_exc = e
+ assert attempt_mock.call_count > 0
+ assert len(found_exc.exceptions) == 1
+ assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded)
+
+ @pytest.mark.asyncio
+ async def test_run_attempt_single_entry_success(self):
+ """Test mutating a single entry"""
+ mutation = _make_mutation()
+ expected_timeout = 1.3
+ mock_gapic_fn = self._make_mock_gapic({0: mutation})
+ instance = self._make_one(
+ mutation_entries=[mutation],
+ attempt_timeout=expected_timeout,
+ )
+ with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn):
+ await instance._run_attempt()
+ assert len(instance.remaining_indices) == 0
+ assert mock_gapic_fn.call_count == 1
+ _, kwargs = mock_gapic_fn.call_args
+ assert kwargs["timeout"] == expected_timeout
+ assert kwargs["entries"] == [mutation._to_pb()]
+
+ @pytest.mark.asyncio
+ async def test_run_attempt_empty_request(self):
+ """Calling with no mutations should result in no API calls"""
+ mock_gapic_fn = self._make_mock_gapic([])
+ instance = self._make_one(
+ mutation_entries=[],
+ )
+ await instance._run_attempt()
+ assert mock_gapic_fn.call_count == 0
+
+ @pytest.mark.asyncio
+ async def test_run_attempt_partial_success_retryable(self):
+ """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception"""
+ from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
+
+ success_mutation = _make_mutation()
+ success_mutation_2 = _make_mutation()
+ failure_mutation = _make_mutation()
+ mutations = [success_mutation, failure_mutation, success_mutation_2]
+ mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300})
+ instance = self._make_one(
+ mutation_entries=mutations,
+ )
+ instance.is_retryable = lambda x: True
+ with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn):
+ with pytest.raises(_MutateRowsIncomplete):
+ await instance._run_attempt()
+ assert instance.remaining_indices == [1]
+ assert 0 not in instance.errors
+ assert len(instance.errors[1]) == 1
+ assert instance.errors[1][0].grpc_status_code == 300
+ assert 2 not in instance.errors
+
+ @pytest.mark.asyncio
+ async def test_run_attempt_partial_success_non_retryable(self):
+ """Some entries succeed, but one fails. Exception marked as non-retryable. Do not raise incomplete error"""
+ success_mutation = _make_mutation()
+ success_mutation_2 = _make_mutation()
+ failure_mutation = _make_mutation()
+ mutations = [success_mutation, failure_mutation, success_mutation_2]
+ mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300})
+ instance = self._make_one(
+ mutation_entries=mutations,
+ )
+ instance.is_retryable = lambda x: False
+ with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn):
+ await instance._run_attempt()
+ assert instance.remaining_indices == []
+ assert 0 not in instance.errors
+ assert len(instance.errors[1]) == 1
+ assert instance.errors[1][0].grpc_status_code == 300
+ assert 2 not in instance.errors
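+    # Taken together, the two partial-success tests above check the attempt
+    # bookkeeping: a failed entry's error is appended to errors[idx] either way,
+    # but only retryable failures keep the index in remaining_indices (and raise
+    # _MutateRowsIncomplete) so that a later attempt will retry it.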
diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py
new file mode 100644
index 000000000..4e7797c6d
--- /dev/null
+++ b/tests/unit/data/_async/test__read_rows.py
@@ -0,0 +1,391 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+ from unittest.mock import AsyncMock # type: ignore
+except ImportError: # pragma: NO COVER
+ import mock # type: ignore
+ from mock import AsyncMock # type: ignore # noqa F401
+
+TEST_FAMILY = "family_name"
+TEST_QUALIFIER = b"qualifier"
+TEST_TIMESTAMP = 123456789
+TEST_LABELS = ["label1", "label2"]
+
+
+class TestReadRowsOperation:
+ """
+    Tests helper functions in the ReadRowsOperation class.
+    In-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt
+    is tested in test_read_rows_acceptance, test_client_read_rows, and the conformance tests.
+ """
+
+ @staticmethod
+ def _get_target_class():
+ from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+
+ return _ReadRowsOperationAsync
+
+ def _make_one(self, *args, **kwargs):
+ return self._get_target_class()(*args, **kwargs)
+
+ def test_ctor(self):
+ from google.cloud.bigtable.data import ReadRowsQuery
+
+ row_limit = 91
+ query = ReadRowsQuery(limit=row_limit)
+ client = mock.Mock()
+ client.read_rows = mock.Mock()
+ client.read_rows.return_value = None
+ table = mock.Mock()
+ table._client = client
+ table.table_name = "test_table"
+ table.app_profile_id = "test_profile"
+ expected_operation_timeout = 42
+ expected_request_timeout = 44
+ time_gen_mock = mock.Mock()
+ with mock.patch(
+ "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator",
+ time_gen_mock,
+ ):
+ instance = self._make_one(
+ query,
+ table,
+ operation_timeout=expected_operation_timeout,
+ attempt_timeout=expected_request_timeout,
+ )
+ assert time_gen_mock.call_count == 1
+ time_gen_mock.assert_called_once_with(
+ expected_request_timeout, expected_operation_timeout
+ )
+ assert instance._last_yielded_row_key is None
+ assert instance._remaining_count == row_limit
+ assert instance.operation_timeout == expected_operation_timeout
+ assert client.read_rows.call_count == 0
+ assert instance._metadata == [
+ (
+ "x-goog-request-params",
+ "table_name=test_table&app_profile_id=test_profile",
+ )
+ ]
+ assert instance.request.table_name == table.table_name
+ assert instance.request.app_profile_id == table.app_profile_id
+ assert instance.request.rows_limit == row_limit
+
+ @pytest.mark.parametrize(
+ "in_keys,last_key,expected",
+ [
+ (["b", "c", "d"], "a", ["b", "c", "d"]),
+ (["a", "b", "c"], "b", ["c"]),
+ (["a", "b", "c"], "c", []),
+ (["a", "b", "c"], "d", []),
+ (["d", "c", "b", "a"], "b", ["d", "c"]),
+ ],
+ )
+ def test_revise_request_rowset_keys(self, in_keys, last_key, expected):
+ from google.cloud.bigtable_v2.types import RowSet as RowSetPB
+ from google.cloud.bigtable_v2.types import RowRange as RowRangePB
+
+ in_keys = [key.encode("utf-8") for key in in_keys]
+ expected = [key.encode("utf-8") for key in expected]
+ last_key = last_key.encode("utf-8")
+
+ sample_range = RowRangePB(start_key_open=last_key)
+ row_set = RowSetPB(row_keys=in_keys, row_ranges=[sample_range])
+ revised = self._get_target_class()._revise_request_rowset(row_set, last_key)
+ assert revised.row_keys == expected
+ assert revised.row_ranges == [sample_range]
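+    # i.e. once last_key has been yielded, the revised rowset keeps only row keys
+    # strictly greater than last_key, preserving their original order; row ranges
+    # are revised separately (see the next test).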
+
+ @pytest.mark.parametrize(
+ "in_ranges,last_key,expected",
+ [
+ (
+ [{"start_key_open": "b", "end_key_closed": "d"}],
+ "a",
+ [{"start_key_open": "b", "end_key_closed": "d"}],
+ ),
+ (
+ [{"start_key_closed": "b", "end_key_closed": "d"}],
+ "a",
+ [{"start_key_closed": "b", "end_key_closed": "d"}],
+ ),
+ (
+ [{"start_key_open": "a", "end_key_closed": "d"}],
+ "b",
+ [{"start_key_open": "b", "end_key_closed": "d"}],
+ ),
+ (
+ [{"start_key_closed": "a", "end_key_open": "d"}],
+ "b",
+ [{"start_key_open": "b", "end_key_open": "d"}],
+ ),
+ (
+ [{"start_key_closed": "b", "end_key_closed": "d"}],
+ "b",
+ [{"start_key_open": "b", "end_key_closed": "d"}],
+ ),
+ ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []),
+ ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []),
+ ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []),
+ ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]),
+ ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]),
+ (
+ [{"end_key_closed": "z"}],
+ "a",
+ [{"start_key_open": "a", "end_key_closed": "z"}],
+ ),
+ (
+ [{"end_key_open": "z"}],
+ "a",
+ [{"start_key_open": "a", "end_key_open": "z"}],
+ ),
+ ],
+ )
+ def test_revise_request_rowset_ranges(self, in_ranges, last_key, expected):
+ from google.cloud.bigtable_v2.types import RowSet as RowSetPB
+ from google.cloud.bigtable_v2.types import RowRange as RowRangePB
+
+ # convert to protobuf
+ next_key = (last_key + "a").encode("utf-8")
+ last_key = last_key.encode("utf-8")
+ in_ranges = [
+ RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()})
+ for r in in_ranges
+ ]
+ expected = [
+ RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) for r in expected
+ ]
+
+ row_set = RowSetPB(row_ranges=in_ranges, row_keys=[next_key])
+ revised = self._get_target_class()._revise_request_rowset(row_set, last_key)
+ assert revised.row_keys == [next_key]
+ assert revised.row_ranges == expected
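+    # i.e. ranges that end at or before last_key are dropped, ranges that straddle
+    # last_key have their start replaced with start_key_open=last_key, and ranges
+    # that already start after last_key are left untouched.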
+
+ @pytest.mark.parametrize("last_key", ["a", "b", "c"])
+ def test_revise_request_full_table(self, last_key):
+ from google.cloud.bigtable_v2.types import RowSet as RowSetPB
+ from google.cloud.bigtable_v2.types import RowRange as RowRangePB
+
+ # convert to protobuf
+ last_key = last_key.encode("utf-8")
+ row_set = RowSetPB()
+ for selected_set in [row_set, None]:
+ revised = self._get_target_class()._revise_request_rowset(
+ selected_set, last_key
+ )
+ assert revised.row_keys == []
+ assert len(revised.row_ranges) == 1
+ assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key)
+
+ def test_revise_to_empty_rowset(self):
+ """revising to an empty rowset should raise error"""
+ from google.cloud.bigtable.data.exceptions import _RowSetComplete
+ from google.cloud.bigtable_v2.types import RowSet as RowSetPB
+ from google.cloud.bigtable_v2.types import RowRange as RowRangePB
+
+ row_keys = [b"a", b"b", b"c"]
+ row_range = RowRangePB(end_key_open=b"c")
+ row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range])
+ with pytest.raises(_RowSetComplete):
+ self._get_target_class()._revise_request_rowset(row_set, b"d")
+
+ @pytest.mark.parametrize(
+ "start_limit,emit_num,expected_limit",
+ [
+ (10, 0, 10),
+ (10, 1, 9),
+ (10, 10, 0),
+ (None, 10, None),
+ (None, 0, None),
+ (4, 2, 2),
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_revise_limit(self, start_limit, emit_num, expected_limit):
+ """
+        revise_limit should revise the request's limit field
+        - if limit is None (unlimited), it should never be revised
+        - if start_limit - emit_num == 0, the request should end early
+        - if the number emitted exceeds the new limit, an exception
+        should be raised (tested in test_revise_limit_over_limit)
+ """
+ from google.cloud.bigtable.data import ReadRowsQuery
+ from google.cloud.bigtable_v2.types import ReadRowsResponse
+
+ async def awaitable_stream():
+ async def mock_stream():
+ for i in range(emit_num):
+ yield ReadRowsResponse(
+ chunks=[
+ ReadRowsResponse.CellChunk(
+ row_key=str(i).encode(),
+ family_name="b",
+ qualifier=b"c",
+ value=b"d",
+ commit_row=True,
+ )
+ ]
+ )
+
+ return mock_stream()
+
+ query = ReadRowsQuery(limit=start_limit)
+ table = mock.Mock()
+ table.table_name = "table_name"
+ table.app_profile_id = "app_profile_id"
+ instance = self._make_one(query, table, 10, 10)
+ assert instance._remaining_count == start_limit
+ # read emit_num rows
+ async for val in instance.chunk_stream(awaitable_stream()):
+ pass
+ assert instance._remaining_count == expected_limit
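+    # _remaining_count starts at the query's limit and is decremented once per
+    # committed row, with None meaning no limit is being tracked.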
+
+ @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)])
+ @pytest.mark.asyncio
+ async def test_revise_limit_over_limit(self, start_limit, emit_num):
+ """
+        Should raise InvalidChunk if we get into a state where emit_num > start_limit
+        (unless start_limit is None, which represents unlimited)
+ """
+ from google.cloud.bigtable.data import ReadRowsQuery
+ from google.cloud.bigtable_v2.types import ReadRowsResponse
+ from google.cloud.bigtable.data.exceptions import InvalidChunk
+
+ async def awaitable_stream():
+ async def mock_stream():
+ for i in range(emit_num):
+ yield ReadRowsResponse(
+ chunks=[
+ ReadRowsResponse.CellChunk(
+ row_key=str(i).encode(),
+ family_name="b",
+ qualifier=b"c",
+ value=b"d",
+ commit_row=True,
+ )
+ ]
+ )
+
+ return mock_stream()
+
+ query = ReadRowsQuery(limit=start_limit)
+ table = mock.Mock()
+ table.table_name = "table_name"
+ table.app_profile_id = "app_profile_id"
+ instance = self._make_one(query, table, 10, 10)
+ assert instance._remaining_count == start_limit
+ with pytest.raises(InvalidChunk) as e:
+ # read emit_num rows
+ async for val in instance.chunk_stream(awaitable_stream()):
+ pass
+ assert "emit count exceeds row limit" in str(e.value)
+
+ @pytest.mark.asyncio
+ async def test_aclose(self):
+ """
+ should be able to close a stream safely with aclose.
+ Closed generators should raise StopAsyncIteration on next yield
+ """
+
+ async def mock_stream():
+ while True:
+ yield 1
+
+ with mock.patch.object(
+ _ReadRowsOperationAsync, "_read_rows_attempt"
+ ) as mock_attempt:
+ instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1)
+ wrapped_gen = mock_stream()
+ mock_attempt.return_value = wrapped_gen
+ gen = instance.start_operation()
+ # read one row
+ await gen.__anext__()
+ await gen.aclose()
+ with pytest.raises(StopAsyncIteration):
+ await gen.__anext__()
+ # try calling a second time
+ await gen.aclose()
+ # ensure close was propagated to wrapped generator
+ with pytest.raises(StopAsyncIteration):
+ await wrapped_gen.__anext__()
+
+ @pytest.mark.asyncio
+ async def test_retryable_ignore_repeated_rows(self):
+ """
+ Duplicate rows should cause an invalid chunk error
+ """
+ from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+ from google.cloud.bigtable.data.exceptions import InvalidChunk
+ from google.cloud.bigtable_v2.types import ReadRowsResponse
+
+ row_key = b"duplicate"
+
+ async def mock_awaitable_stream():
+ async def mock_stream():
+ while True:
+ yield ReadRowsResponse(
+ chunks=[
+ ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True)
+ ]
+ )
+ yield ReadRowsResponse(
+ chunks=[
+ ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True)
+ ]
+ )
+
+ return mock_stream()
+
+ instance = mock.Mock()
+ instance._last_yielded_row_key = None
+ instance._remaining_count = None
+ stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream())
+ await stream.__anext__()
+ with pytest.raises(InvalidChunk) as exc:
+ await stream.__anext__()
+ assert "row keys should be strictly increasing" in str(exc.value)
+
+
+class MockStream(_ReadRowsOperationAsync):
+ """
+ Mock a _ReadRowsOperationAsync stream for testing
+ """
+
+ def __init__(self, items=None, errors=None, operation_timeout=None):
+ self.transient_errors = errors
+ self.operation_timeout = operation_timeout
+ self.next_idx = 0
+ if items is None:
+ items = list(range(10))
+ self.items = items
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ if self.next_idx >= len(self.items):
+ raise StopAsyncIteration
+ item = self.items[self.next_idx]
+ self.next_idx += 1
+ if isinstance(item, Exception):
+ raise item
+ return item
+
+ async def aclose(self):
+ pass
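+    # Usage example: MockStream([1, 2, ValueError("boom")]) yields 1 and 2, then
+    # raises the error on the next __anext__ call, which lets tests simulate a
+    # stream that fails partway through.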
diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py
new file mode 100644
index 000000000..a0019947d
--- /dev/null
+++ b/tests/unit/data/_async/test_client.py
@@ -0,0 +1,2957 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import grpc
+import asyncio
+import re
+import sys
+
+import pytest
+
+from google.cloud.bigtable.data import mutations
+from google.auth.credentials import AnonymousCredentials
+from google.cloud.bigtable_v2.types import ReadRowsResponse
+from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+from google.api_core import exceptions as core_exceptions
+from google.cloud.bigtable.data.exceptions import InvalidChunk
+from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
+from google.cloud.bigtable.data import TABLE_DEFAULT
+
+from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
+from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
+
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+ from unittest.mock import AsyncMock # type: ignore
+except ImportError: # pragma: NO COVER
+ import mock # type: ignore
+ from mock import AsyncMock # type: ignore
+
+VENEER_HEADER_REGEX = re.compile(
+ r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-data-async gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+"
+)
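+# Example of the user-agent shape this regex expects (version numbers are
+# illustrative only):
+#   "gapic/2.17.0 gax/2.17.0 gccl/2.23.0-data-async gl-python/3.11.4 grpc/1.62.0"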
+
+
+def _make_client(*args, use_emulator=True, **kwargs):
+ import os
+ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
+
+ env_mask = {}
+ # by default, use emulator mode to avoid auth issues in CI
+ # emulator mode must be disabled by tests that check channel pooling/refresh background tasks
+ if use_emulator:
+ env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost"
+ else:
+ # set some default values
+ kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials())
+ kwargs["project"] = kwargs.get("project", "project-id")
+ with mock.patch.dict(os.environ, env_mask):
+ return BigtableDataClientAsync(*args, **kwargs)
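+# Typical usage: _make_client() patches BIGTABLE_EMULATOR_HOST so no real
+# credentials are needed, while tests that exercise channel pooling/refresh call
+# e.g. _make_client(pool_size=3, use_emulator=False) and fall back to
+# AnonymousCredentials with a placeholder project.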
+
+
+class TestBigtableDataClientAsync:
+ def _get_target_class(self):
+ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
+
+ return BigtableDataClientAsync
+
+ def _make_one(self, *args, **kwargs):
+ return _make_client(*args, **kwargs)
+
+ @pytest.mark.asyncio
+ async def test_ctor(self):
+ expected_project = "project-id"
+ expected_pool_size = 11
+ expected_credentials = AnonymousCredentials()
+ client = self._make_one(
+ project="project-id",
+ pool_size=expected_pool_size,
+ credentials=expected_credentials,
+ use_emulator=False,
+ )
+ await asyncio.sleep(0)
+ assert client.project == expected_project
+ assert len(client.transport._grpc_channel._pool) == expected_pool_size
+ assert not client._active_instances
+ assert len(client._channel_refresh_tasks) == expected_pool_size
+ assert client.transport._credentials == expected_credentials
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_ctor_super_inits(self):
+ from google.cloud.bigtable_v2.services.bigtable.async_client import (
+ BigtableAsyncClient,
+ )
+ from google.cloud.client import ClientWithProject
+ from google.api_core import client_options as client_options_lib
+
+ project = "project-id"
+ pool_size = 11
+ credentials = AnonymousCredentials()
+ client_options = {"api_endpoint": "foo.bar:1234"}
+ options_parsed = client_options_lib.from_dict(client_options)
+ transport_str = f"pooled_grpc_asyncio_{pool_size}"
+ with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init:
+ bigtable_client_init.return_value = None
+ with mock.patch.object(
+ ClientWithProject, "__init__"
+ ) as client_project_init:
+ client_project_init.return_value = None
+ try:
+ self._make_one(
+ project=project,
+ pool_size=pool_size,
+ credentials=credentials,
+ client_options=options_parsed,
+ use_emulator=False,
+ )
+ except AttributeError:
+ pass
+ # test gapic superclass init was called
+ assert bigtable_client_init.call_count == 1
+ kwargs = bigtable_client_init.call_args[1]
+ assert kwargs["transport"] == transport_str
+ assert kwargs["credentials"] == credentials
+ assert kwargs["client_options"] == options_parsed
+ # test mixin superclass init was called
+ assert client_project_init.call_count == 1
+ kwargs = client_project_init.call_args[1]
+ assert kwargs["project"] == project
+ assert kwargs["credentials"] == credentials
+ assert kwargs["client_options"] == options_parsed
+
+ @pytest.mark.asyncio
+ async def test_ctor_dict_options(self):
+ from google.cloud.bigtable_v2.services.bigtable.async_client import (
+ BigtableAsyncClient,
+ )
+ from google.api_core.client_options import ClientOptions
+
+ client_options = {"api_endpoint": "foo.bar:1234"}
+ with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init:
+ try:
+ self._make_one(client_options=client_options)
+ except TypeError:
+ pass
+ bigtable_client_init.assert_called_once()
+ kwargs = bigtable_client_init.call_args[1]
+ called_options = kwargs["client_options"]
+ assert called_options.api_endpoint == "foo.bar:1234"
+ assert isinstance(called_options, ClientOptions)
+ with mock.patch.object(
+ self._get_target_class(), "_start_background_channel_refresh"
+ ) as start_background_refresh:
+ client = self._make_one(client_options=client_options, use_emulator=False)
+ start_background_refresh.assert_called_once()
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_veneer_grpc_headers(self):
+ # client_info should be populated with headers to
+ # detect as a veneer client
+ patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method")
+ with patch as gapic_mock:
+ client = self._make_one(project="project-id")
+ wrapped_call_list = gapic_mock.call_args_list
+ assert len(wrapped_call_list) > 0
+ # each wrapped call should have veneer headers
+ for call in wrapped_call_list:
+ client_info = call.kwargs["client_info"]
+ assert client_info is not None, f"{call} has no client_info"
+ wrapped_user_agent_sorted = " ".join(
+ sorted(client_info.to_user_agent().split(" "))
+ )
+ assert VENEER_HEADER_REGEX.match(
+ wrapped_user_agent_sorted
+ ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}"
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_channel_pool_creation(self):
+ pool_size = 14
+ with mock.patch(
+ "google.api_core.grpc_helpers_async.create_channel"
+ ) as create_channel:
+ create_channel.return_value = AsyncMock()
+ client = self._make_one(project="project-id", pool_size=pool_size)
+ assert create_channel.call_count == pool_size
+ await client.close()
+ # channels should be unique
+ client = self._make_one(project="project-id", pool_size=pool_size)
+ pool_list = list(client.transport._grpc_channel._pool)
+ pool_set = set(client.transport._grpc_channel._pool)
+ assert len(pool_list) == len(pool_set)
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_channel_pool_rotation(self):
+ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import (
+ PooledChannel,
+ )
+
+ pool_size = 7
+
+ with mock.patch.object(PooledChannel, "next_channel") as next_channel:
+ client = self._make_one(project="project-id", pool_size=pool_size)
+ assert len(client.transport._grpc_channel._pool) == pool_size
+ next_channel.reset_mock()
+ with mock.patch.object(
+ type(client.transport._grpc_channel._pool[0]), "unary_unary"
+ ) as unary_unary:
+ # calling an rpc `pool_size` times should use a different channel each time
+ channel_next = None
+ for i in range(pool_size):
+ channel_last = channel_next
+ channel_next = client.transport.grpc_channel._pool[i]
+ assert channel_last != channel_next
+ next_channel.return_value = channel_next
+ client.transport.ping_and_warm()
+ assert next_channel.call_count == i + 1
+ unary_unary.assert_called_once()
+ unary_unary.reset_mock()
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_channel_pool_replace(self):
+ with mock.patch.object(asyncio, "sleep"):
+ pool_size = 7
+ client = self._make_one(project="project-id", pool_size=pool_size)
+ for replace_idx in range(pool_size):
+ start_pool = [
+ channel for channel in client.transport._grpc_channel._pool
+ ]
+ grace_period = 9
+ with mock.patch.object(
+ type(client.transport._grpc_channel._pool[0]), "close"
+ ) as close:
+ new_channel = grpc.aio.insecure_channel("localhost:8080")
+ await client.transport.replace_channel(
+ replace_idx, grace=grace_period, new_channel=new_channel
+ )
+ close.assert_called_once_with(grace=grace_period)
+ close.assert_awaited_once()
+ assert client.transport._grpc_channel._pool[replace_idx] == new_channel
+ for i in range(pool_size):
+ if i != replace_idx:
+ assert client.transport._grpc_channel._pool[i] == start_pool[i]
+ else:
+ assert client.transport._grpc_channel._pool[i] != start_pool[i]
+ await client.close()
+
+ @pytest.mark.filterwarnings("ignore::RuntimeWarning")
+ def test__start_background_channel_refresh_sync(self):
+ # should raise RuntimeError if called in a sync context
+ client = self._make_one(project="project-id", use_emulator=False)
+ with pytest.raises(RuntimeError):
+ client._start_background_channel_refresh()
+
+ @pytest.mark.asyncio
+ async def test__start_background_channel_refresh_tasks_exist(self):
+ # if tasks exist, should do nothing
+ client = self._make_one(project="project-id", use_emulator=False)
+ assert len(client._channel_refresh_tasks) > 0
+ with mock.patch.object(asyncio, "create_task") as create_task:
+ client._start_background_channel_refresh()
+ create_task.assert_not_called()
+ await client.close()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("pool_size", [1, 3, 7])
+ async def test__start_background_channel_refresh(self, pool_size):
+ # should create background tasks for each channel
+ client = self._make_one(
+ project="project-id", pool_size=pool_size, use_emulator=False
+ )
+ ping_and_warm = AsyncMock()
+ client._ping_and_warm_instances = ping_and_warm
+ client._start_background_channel_refresh()
+ assert len(client._channel_refresh_tasks) == pool_size
+ for task in client._channel_refresh_tasks:
+ assert isinstance(task, asyncio.Task)
+ await asyncio.sleep(0.1)
+ assert ping_and_warm.call_count == pool_size
+ for channel in client.transport._grpc_channel._pool:
+ ping_and_warm.assert_any_call(channel)
+ await client.close()
+
+ @pytest.mark.asyncio
+ @pytest.mark.skipif(
+ sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher"
+ )
+ async def test__start_background_channel_refresh_tasks_names(self):
+        # refresh tasks should be named after the client and their channel index
+ pool_size = 3
+ client = self._make_one(
+ project="project-id", pool_size=pool_size, use_emulator=False
+ )
+ for i in range(pool_size):
+ name = client._channel_refresh_tasks[i].get_name()
+ assert str(i) in name
+ assert "BigtableDataClientAsync channel refresh " in name
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test__ping_and_warm_instances(self):
+ """
+ test ping and warm with mocked asyncio.gather
+ """
+ client_mock = mock.Mock()
+ with mock.patch.object(asyncio, "gather", AsyncMock()) as gather:
+ # simulate gather by returning the same number of items as passed in
+ gather.side_effect = lambda *args, **kwargs: [None for _ in args]
+ channel = mock.Mock()
+ # test with no instances
+ client_mock._active_instances = []
+ result = await self._get_target_class()._ping_and_warm_instances(
+ client_mock, channel
+ )
+ assert len(result) == 0
+ gather.assert_called_once()
+ gather.assert_awaited_once()
+ assert not gather.call_args.args
+ assert gather.call_args.kwargs == {"return_exceptions": True}
+ # test with instances
+ client_mock._active_instances = [
+ (mock.Mock(), mock.Mock(), mock.Mock())
+ ] * 4
+ gather.reset_mock()
+ channel.reset_mock()
+ result = await self._get_target_class()._ping_and_warm_instances(
+ client_mock, channel
+ )
+ assert len(result) == 4
+ gather.assert_called_once()
+ gather.assert_awaited_once()
+ assert len(gather.call_args.args) == 4
+ # check grpc call arguments
+ grpc_call_args = channel.unary_unary().call_args_list
+ for idx, (_, kwargs) in enumerate(grpc_call_args):
+ (
+ expected_instance,
+ expected_table,
+ expected_app_profile,
+ ) = client_mock._active_instances[idx]
+ request = kwargs["request"]
+ assert request["name"] == expected_instance
+ assert request["app_profile_id"] == expected_app_profile
+ metadata = kwargs["metadata"]
+ assert len(metadata) == 1
+ assert metadata[0][0] == "x-goog-request-params"
+ assert (
+ metadata[0][1]
+ == f"name={expected_instance}&app_profile_id={expected_app_profile}"
+ )
+
+ @pytest.mark.asyncio
+ async def test__ping_and_warm_single_instance(self):
+ """
+ should be able to call ping and warm with single instance
+ """
+ client_mock = mock.Mock()
+ with mock.patch.object(asyncio, "gather", AsyncMock()) as gather:
+ # simulate gather by returning the same number of items as passed in
+ gather.side_effect = lambda *args, **kwargs: [None for _ in args]
+ channel = mock.Mock()
+ # test with large set of instances
+ client_mock._active_instances = [mock.Mock()] * 100
+ test_key = ("test-instance", "test-table", "test-app-profile")
+ result = await self._get_target_class()._ping_and_warm_instances(
+ client_mock, channel, test_key
+ )
+ # should only have been called with test instance
+ assert len(result) == 1
+ # check grpc call arguments
+ grpc_call_args = channel.unary_unary().call_args_list
+ assert len(grpc_call_args) == 1
+ kwargs = grpc_call_args[0][1]
+ request = kwargs["request"]
+ assert request["name"] == "test-instance"
+ assert request["app_profile_id"] == "test-app-profile"
+ metadata = kwargs["metadata"]
+ assert len(metadata) == 1
+ assert metadata[0][0] == "x-goog-request-params"
+ assert (
+ metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile"
+ )
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "refresh_interval, wait_time, expected_sleep",
+ [
+ (0, 0, 0),
+ (0, 1, 0),
+ (10, 0, 10),
+ (10, 5, 5),
+ (10, 10, 0),
+ (10, 15, 0),
+ ],
+ )
+ async def test__manage_channel_first_sleep(
+ self, refresh_interval, wait_time, expected_sleep
+ ):
+ # first sleep time should be `refresh_interval` seconds after client init
+ import time
+
+ with mock.patch.object(time, "monotonic") as time:
+ time.return_value = 0
+ with mock.patch.object(asyncio, "sleep") as sleep:
+ sleep.side_effect = asyncio.CancelledError
+ try:
+ client = self._make_one(project="project-id")
+ client._channel_init_time = -wait_time
+ await client._manage_channel(0, refresh_interval, refresh_interval)
+ except asyncio.CancelledError:
+ pass
+ sleep.assert_called_once()
+ call_time = sleep.call_args[0][0]
+ assert (
+ abs(call_time - expected_sleep) < 0.1
+ ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}"
+ await client.close()
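+    # The expected first sleep works out to max(0, refresh_interval - wait_time):
+    # the client waits out whatever part of the refresh interval has not already
+    # elapsed since _channel_init_time.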
+
+ @pytest.mark.asyncio
+ async def test__manage_channel_ping_and_warm(self):
+ """
+ _manage channel should call ping and warm internally
+ """
+ import time
+
+ client_mock = mock.Mock()
+ client_mock._channel_init_time = time.monotonic()
+ channel_list = [mock.Mock(), mock.Mock()]
+ client_mock.transport.channels = channel_list
+ new_channel = mock.Mock()
+ client_mock.transport.grpc_channel._create_channel.return_value = new_channel
+        # should ping and warm all new channels, and old channels if sleeping
+ with mock.patch.object(asyncio, "sleep"):
+ # stop process after replace_channel is called
+ client_mock.transport.replace_channel.side_effect = asyncio.CancelledError
+ ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock()
+ # should ping and warm old channel then new if sleep > 0
+ try:
+ channel_idx = 1
+ await self._get_target_class()._manage_channel(
+ client_mock, channel_idx, 10
+ )
+ except asyncio.CancelledError:
+ pass
+ # should have called at loop start, and after replacement
+ assert ping_and_warm.call_count == 2
+ # should have replaced channel once
+ assert client_mock.transport.replace_channel.call_count == 1
+ # make sure new and old channels were warmed
+ old_channel = channel_list[channel_idx]
+ assert old_channel != new_channel
+ called_with = [call[0][0] for call in ping_and_warm.call_args_list]
+ assert old_channel in called_with
+ assert new_channel in called_with
+ # should ping and warm instantly new channel only if not sleeping
+ ping_and_warm.reset_mock()
+ try:
+ await self._get_target_class()._manage_channel(client_mock, 0, 0, 0)
+ except asyncio.CancelledError:
+ pass
+ ping_and_warm.assert_called_once_with(new_channel)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "refresh_interval, num_cycles, expected_sleep",
+ [
+ (None, 1, 60 * 35),
+ (10, 10, 100),
+ (10, 1, 10),
+ ],
+ )
+ async def test__manage_channel_sleeps(
+ self, refresh_interval, num_cycles, expected_sleep
+ ):
+ # make sure that sleeps work as expected
+ import time
+ import random
+
+ channel_idx = 1
+ with mock.patch.object(random, "uniform") as uniform:
+ uniform.side_effect = lambda min_, max_: min_
+ with mock.patch.object(time, "time") as time:
+ time.return_value = 0
+ with mock.patch.object(asyncio, "sleep") as sleep:
+ sleep.side_effect = [None for i in range(num_cycles - 1)] + [
+ asyncio.CancelledError
+ ]
+ try:
+ client = self._make_one(project="project-id")
+ if refresh_interval is not None:
+ await client._manage_channel(
+ channel_idx, refresh_interval, refresh_interval
+ )
+ else:
+ await client._manage_channel(channel_idx)
+ except asyncio.CancelledError:
+ pass
+ assert sleep.call_count == num_cycles
+ total_sleep = sum([call[0][0] for call in sleep.call_args_list])
+ assert (
+ abs(total_sleep - expected_sleep) < 0.1
+ ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}"
+ await client.close()
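+    # The (None, 1, 60 * 35) case suggests that, with the jitter mocked out, the
+    # default refresh interval is roughly 35 minutes when none is supplied.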
+
+ @pytest.mark.asyncio
+ async def test__manage_channel_random(self):
+ import random
+
+ with mock.patch.object(asyncio, "sleep") as sleep:
+ with mock.patch.object(random, "uniform") as uniform:
+ uniform.return_value = 0
+ try:
+ uniform.side_effect = asyncio.CancelledError
+ client = self._make_one(project="project-id", pool_size=1)
+ except asyncio.CancelledError:
+ uniform.side_effect = None
+ uniform.reset_mock()
+ sleep.reset_mock()
+ min_val = 200
+ max_val = 205
+ uniform.side_effect = lambda min_, max_: min_
+ sleep.side_effect = [None, None, asyncio.CancelledError]
+ try:
+ await client._manage_channel(0, min_val, max_val)
+ except asyncio.CancelledError:
+ pass
+ assert uniform.call_count == 2
+ uniform_args = [call[0] for call in uniform.call_args_list]
+ for found_min, found_max in uniform_args:
+ assert found_min == min_val
+ assert found_max == max_val
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100])
+ async def test__manage_channel_refresh(self, num_cycles):
+ # make sure that channels are properly refreshed
+ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import (
+ PooledBigtableGrpcAsyncIOTransport,
+ )
+ from google.api_core import grpc_helpers_async
+
+ expected_grace = 9
+ expected_refresh = 0.5
+ channel_idx = 1
+ new_channel = grpc.aio.insecure_channel("localhost:8080")
+
+ with mock.patch.object(
+ PooledBigtableGrpcAsyncIOTransport, "replace_channel"
+ ) as replace_channel:
+ with mock.patch.object(asyncio, "sleep") as sleep:
+ sleep.side_effect = [None for i in range(num_cycles)] + [
+ asyncio.CancelledError
+ ]
+ with mock.patch.object(
+ grpc_helpers_async, "create_channel"
+ ) as create_channel:
+ create_channel.return_value = new_channel
+ client = self._make_one(project="project-id", use_emulator=False)
+ create_channel.reset_mock()
+ try:
+ await client._manage_channel(
+ channel_idx,
+ refresh_interval_min=expected_refresh,
+ refresh_interval_max=expected_refresh,
+ grace_period=expected_grace,
+ )
+ except asyncio.CancelledError:
+ pass
+ assert sleep.call_count == num_cycles + 1
+ assert create_channel.call_count == num_cycles
+ assert replace_channel.call_count == num_cycles
+ for call in replace_channel.call_args_list:
+ args, kwargs = call
+ assert args[0] == channel_idx
+ assert kwargs["grace"] == expected_grace
+ assert kwargs["new_channel"] == new_channel
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test__register_instance(self):
+ """
+ test instance registration
+ """
+ # set up mock client
+ client_mock = mock.Mock()
+ client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}"
+ active_instances = set()
+ instance_owners = {}
+ client_mock._active_instances = active_instances
+ client_mock._instance_owners = instance_owners
+ client_mock._channel_refresh_tasks = []
+ client_mock._start_background_channel_refresh.side_effect = (
+ lambda: client_mock._channel_refresh_tasks.append(mock.Mock)
+ )
+ mock_channels = [mock.Mock() for i in range(5)]
+ client_mock.transport.channels = mock_channels
+ client_mock._ping_and_warm_instances = AsyncMock()
+ table_mock = mock.Mock()
+ await self._get_target_class()._register_instance(
+ client_mock, "instance-1", table_mock
+ )
+ # first call should start background refresh
+ assert client_mock._start_background_channel_refresh.call_count == 1
+ # ensure active_instances and instance_owners were updated properly
+ expected_key = (
+ "prefix/instance-1",
+ table_mock.table_name,
+ table_mock.app_profile_id,
+ )
+ assert len(active_instances) == 1
+ assert expected_key == tuple(list(active_instances)[0])
+ assert len(instance_owners) == 1
+ assert expected_key == tuple(list(instance_owners)[0])
+ # should be a new task set
+ assert client_mock._channel_refresh_tasks
+ # next call should not call _start_background_channel_refresh again
+ table_mock2 = mock.Mock()
+ await self._get_target_class()._register_instance(
+ client_mock, "instance-2", table_mock2
+ )
+ assert client_mock._start_background_channel_refresh.call_count == 1
+ # but it should call ping and warm with new instance key
+ assert client_mock._ping_and_warm_instances.call_count == len(mock_channels)
+ for channel in mock_channels:
+ assert channel in [
+ call[0][0]
+ for call in client_mock._ping_and_warm_instances.call_args_list
+ ]
+ # check for updated lists
+ assert len(active_instances) == 2
+ assert len(instance_owners) == 2
+ expected_key2 = (
+ "prefix/instance-2",
+ table_mock2.table_name,
+ table_mock2.app_profile_id,
+ )
+ assert any(
+ [
+ expected_key2 == tuple(list(active_instances)[i])
+ for i in range(len(active_instances))
+ ]
+ )
+ assert any(
+ [
+ expected_key2 == tuple(list(instance_owners)[i])
+ for i in range(len(instance_owners))
+ ]
+ )
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "insert_instances,expected_active,expected_owner_keys",
+ [
+ ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]),
+ ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]),
+ ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]),
+ (
+ [("1", "t", "p"), ("2", "t", "p")],
+ [("1", "t", "p"), ("2", "t", "p")],
+ [("1", "t", "p"), ("2", "t", "p")],
+ ),
+ ],
+ )
+ async def test__register_instance_state(
+ self, insert_instances, expected_active, expected_owner_keys
+ ):
+ """
+ test that active_instances and instance_owners are updated as expected
+ """
+ # set up mock client
+ client_mock = mock.Mock()
+ client_mock._gapic_client.instance_path.side_effect = lambda a, b: b
+ active_instances = set()
+ instance_owners = {}
+ client_mock._active_instances = active_instances
+ client_mock._instance_owners = instance_owners
+ client_mock._channel_refresh_tasks = []
+ client_mock._start_background_channel_refresh.side_effect = (
+ lambda: client_mock._channel_refresh_tasks.append(mock.Mock)
+ )
+ mock_channels = [mock.Mock() for i in range(5)]
+ client_mock.transport.channels = mock_channels
+ client_mock._ping_and_warm_instances = AsyncMock()
+ table_mock = mock.Mock()
+ # register instances
+ for instance, table, profile in insert_instances:
+ table_mock.table_name = table
+ table_mock.app_profile_id = profile
+ await self._get_target_class()._register_instance(
+ client_mock, instance, table_mock
+ )
+ assert len(active_instances) == len(expected_active)
+ assert len(instance_owners) == len(expected_owner_keys)
+ for expected in expected_active:
+ assert any(
+ [
+ expected == tuple(list(active_instances)[i])
+ for i in range(len(active_instances))
+ ]
+ )
+ for expected in expected_owner_keys:
+ assert any(
+ [
+ expected == tuple(list(instance_owners)[i])
+ for i in range(len(instance_owners))
+ ]
+ )
+
+ @pytest.mark.asyncio
+ async def test__remove_instance_registration(self):
+ client = self._make_one(project="project-id")
+ table = mock.Mock()
+ await client._register_instance("instance-1", table)
+ await client._register_instance("instance-2", table)
+ assert len(client._active_instances) == 2
+ assert len(client._instance_owners.keys()) == 2
+ instance_1_path = client._gapic_client.instance_path(
+ client.project, "instance-1"
+ )
+ instance_1_key = (instance_1_path, table.table_name, table.app_profile_id)
+ instance_2_path = client._gapic_client.instance_path(
+ client.project, "instance-2"
+ )
+ instance_2_key = (instance_2_path, table.table_name, table.app_profile_id)
+ assert len(client._instance_owners[instance_1_key]) == 1
+ assert list(client._instance_owners[instance_1_key])[0] == id(table)
+ assert len(client._instance_owners[instance_2_key]) == 1
+ assert list(client._instance_owners[instance_2_key])[0] == id(table)
+ success = await client._remove_instance_registration("instance-1", table)
+ assert success
+ assert len(client._active_instances) == 1
+ assert len(client._instance_owners[instance_1_key]) == 0
+ assert len(client._instance_owners[instance_2_key]) == 1
+ assert client._active_instances == {instance_2_key}
+ success = await client._remove_instance_registration("fake-key", table)
+ assert not success
+ assert len(client._active_instances) == 1
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test__multiple_table_registration(self):
+ """
+ registering with multiple tables with the same key should
+ add multiple owners to instance_owners, but only keep one copy
+        of the shared key in active_instances
+ """
+ from google.cloud.bigtable.data._async.client import _WarmedInstanceKey
+
+ async with self._make_one(project="project-id") as client:
+ async with client.get_table("instance_1", "table_1") as table_1:
+ instance_1_path = client._gapic_client.instance_path(
+ client.project, "instance_1"
+ )
+ instance_1_key = _WarmedInstanceKey(
+ instance_1_path, table_1.table_name, table_1.app_profile_id
+ )
+ assert len(client._instance_owners[instance_1_key]) == 1
+ assert len(client._active_instances) == 1
+ assert id(table_1) in client._instance_owners[instance_1_key]
+ # duplicate table should register in instance_owners under same key
+ async with client.get_table("instance_1", "table_1") as table_2:
+ assert len(client._instance_owners[instance_1_key]) == 2
+ assert len(client._active_instances) == 1
+ assert id(table_1) in client._instance_owners[instance_1_key]
+ assert id(table_2) in client._instance_owners[instance_1_key]
+ # unique table should register in instance_owners and active_instances
+ async with client.get_table("instance_1", "table_3") as table_3:
+ instance_3_path = client._gapic_client.instance_path(
+ client.project, "instance_1"
+ )
+ instance_3_key = _WarmedInstanceKey(
+ instance_3_path, table_3.table_name, table_3.app_profile_id
+ )
+ assert len(client._instance_owners[instance_1_key]) == 2
+ assert len(client._instance_owners[instance_3_key]) == 1
+ assert len(client._active_instances) == 2
+ assert id(table_1) in client._instance_owners[instance_1_key]
+ assert id(table_2) in client._instance_owners[instance_1_key]
+ assert id(table_3) in client._instance_owners[instance_3_key]
+ # sub-tables should be unregistered, but instance should still be active
+ assert len(client._active_instances) == 1
+ assert instance_1_key in client._active_instances
+ assert id(table_2) not in client._instance_owners[instance_1_key]
+ # both tables are gone. instance should be unregistered
+ assert len(client._active_instances) == 0
+ assert instance_1_key not in client._active_instances
+ assert len(client._instance_owners[instance_1_key]) == 0
+
+ @pytest.mark.asyncio
+ async def test__multiple_instance_registration(self):
+ """
+ registering with multiple instance keys should update the key
+ in instance_owners and active_instances
+ """
+ from google.cloud.bigtable.data._async.client import _WarmedInstanceKey
+
+ async with self._make_one(project="project-id") as client:
+ async with client.get_table("instance_1", "table_1") as table_1:
+ async with client.get_table("instance_2", "table_2") as table_2:
+ instance_1_path = client._gapic_client.instance_path(
+ client.project, "instance_1"
+ )
+ instance_1_key = _WarmedInstanceKey(
+ instance_1_path, table_1.table_name, table_1.app_profile_id
+ )
+ instance_2_path = client._gapic_client.instance_path(
+ client.project, "instance_2"
+ )
+ instance_2_key = _WarmedInstanceKey(
+ instance_2_path, table_2.table_name, table_2.app_profile_id
+ )
+ assert len(client._instance_owners[instance_1_key]) == 1
+ assert len(client._instance_owners[instance_2_key]) == 1
+ assert len(client._active_instances) == 2
+ assert id(table_1) in client._instance_owners[instance_1_key]
+ assert id(table_2) in client._instance_owners[instance_2_key]
+ # instance2 should be unregistered, but instance1 should still be active
+ assert len(client._active_instances) == 1
+ assert instance_1_key in client._active_instances
+ assert len(client._instance_owners[instance_2_key]) == 0
+ assert len(client._instance_owners[instance_1_key]) == 1
+ assert id(table_1) in client._instance_owners[instance_1_key]
+ # both tables are gone. instances should both be unregistered
+ assert len(client._active_instances) == 0
+ assert len(client._instance_owners[instance_1_key]) == 0
+ assert len(client._instance_owners[instance_2_key]) == 0
+
+ @pytest.mark.asyncio
+ async def test_get_table(self):
+ from google.cloud.bigtable.data._async.client import TableAsync
+ from google.cloud.bigtable.data._async.client import _WarmedInstanceKey
+
+ client = self._make_one(project="project-id")
+ assert not client._active_instances
+ expected_table_id = "table-id"
+ expected_instance_id = "instance-id"
+ expected_app_profile_id = "app-profile-id"
+ table = client.get_table(
+ expected_instance_id,
+ expected_table_id,
+ expected_app_profile_id,
+ )
+ await asyncio.sleep(0)
+ assert isinstance(table, TableAsync)
+ assert table.table_id == expected_table_id
+ assert (
+ table.table_name
+ == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}"
+ )
+ assert table.instance_id == expected_instance_id
+ assert (
+ table.instance_name
+ == f"projects/{client.project}/instances/{expected_instance_id}"
+ )
+ assert table.app_profile_id == expected_app_profile_id
+ assert table.client is client
+ instance_key = _WarmedInstanceKey(
+ table.instance_name, table.table_name, table.app_profile_id
+ )
+ assert instance_key in client._active_instances
+ assert client._instance_owners[instance_key] == {id(table)}
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_get_table_arg_passthrough(self):
+ """
+ All arguments passed in get_table should be sent to constructor
+ """
+ async with self._make_one(project="project-id") as client:
+ with mock.patch(
+ "google.cloud.bigtable.data._async.client.TableAsync.__init__",
+ ) as mock_constructor:
+ mock_constructor.return_value = None
+ assert not client._active_instances
+ expected_table_id = "table-id"
+ expected_instance_id = "instance-id"
+ expected_app_profile_id = "app-profile-id"
+ expected_args = (1, "test", {"test": 2})
+ expected_kwargs = {"hello": "world", "test": 2}
+
+ client.get_table(
+ expected_instance_id,
+ expected_table_id,
+ expected_app_profile_id,
+ *expected_args,
+ **expected_kwargs,
+ )
+ mock_constructor.assert_called_once_with(
+ client,
+ expected_instance_id,
+ expected_table_id,
+ expected_app_profile_id,
+ *expected_args,
+ **expected_kwargs,
+ )
+
+ @pytest.mark.asyncio
+ async def test_get_table_context_manager(self):
+ from google.cloud.bigtable.data._async.client import TableAsync
+ from google.cloud.bigtable.data._async.client import _WarmedInstanceKey
+
+ expected_table_id = "table-id"
+ expected_instance_id = "instance-id"
+ expected_app_profile_id = "app-profile-id"
+ expected_project_id = "project-id"
+
+ with mock.patch.object(TableAsync, "close") as close_mock:
+ async with self._make_one(project=expected_project_id) as client:
+ async with client.get_table(
+ expected_instance_id,
+ expected_table_id,
+ expected_app_profile_id,
+ ) as table:
+ await asyncio.sleep(0)
+ assert isinstance(table, TableAsync)
+ assert table.table_id == expected_table_id
+ assert (
+ table.table_name
+ == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}"
+ )
+ assert table.instance_id == expected_instance_id
+ assert (
+ table.instance_name
+ == f"projects/{expected_project_id}/instances/{expected_instance_id}"
+ )
+ assert table.app_profile_id == expected_app_profile_id
+ assert table.client is client
+ instance_key = _WarmedInstanceKey(
+ table.instance_name, table.table_name, table.app_profile_id
+ )
+ assert instance_key in client._active_instances
+ assert client._instance_owners[instance_key] == {id(table)}
+ assert close_mock.call_count == 1
+
+ @pytest.mark.asyncio
+ async def test_multiple_pool_sizes(self):
+ # should be able to create multiple clients with different pool sizes without issue
+ pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
+ for pool_size in pool_sizes:
+ client = self._make_one(
+ project="project-id", pool_size=pool_size, use_emulator=False
+ )
+ assert len(client._channel_refresh_tasks) == pool_size
+ client_duplicate = self._make_one(
+ project="project-id", pool_size=pool_size, use_emulator=False
+ )
+ assert len(client_duplicate._channel_refresh_tasks) == pool_size
+ assert str(pool_size) in str(client.transport)
+ await client.close()
+ await client_duplicate.close()
+
+ @pytest.mark.asyncio
+ async def test_close(self):
+ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import (
+ PooledBigtableGrpcAsyncIOTransport,
+ )
+
+ pool_size = 7
+ client = self._make_one(
+ project="project-id", pool_size=pool_size, use_emulator=False
+ )
+ assert len(client._channel_refresh_tasks) == pool_size
+ tasks_list = list(client._channel_refresh_tasks)
+ for task in client._channel_refresh_tasks:
+ assert not task.done()
+ with mock.patch.object(
+ PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock()
+ ) as close_mock:
+ await client.close()
+ close_mock.assert_called_once()
+ close_mock.assert_awaited()
+ for task in tasks_list:
+ assert task.done()
+ assert task.cancelled()
+ assert client._channel_refresh_tasks == []
+
+ @pytest.mark.asyncio
+ async def test_close_with_timeout(self):
+ pool_size = 7
+ expected_timeout = 19
+ client = self._make_one(project="project-id", pool_size=pool_size)
+ tasks = list(client._channel_refresh_tasks)
+ with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock:
+ await client.close(timeout=expected_timeout)
+ wait_for_mock.assert_called_once()
+ wait_for_mock.assert_awaited()
+ assert wait_for_mock.call_args[1]["timeout"] == expected_timeout
+ client._channel_refresh_tasks = tasks
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_context_manager(self):
+ # context manager should close the client cleanly
+ close_mock = AsyncMock()
+ true_close = None
+ async with self._make_one(project="project-id") as client:
+ true_close = client.close()
+ client.close = close_mock
+ for task in client._channel_refresh_tasks:
+ assert not task.done()
+ assert client.project == "project-id"
+ assert client._active_instances == set()
+ close_mock.assert_not_called()
+ close_mock.assert_called_once()
+ close_mock.assert_awaited()
+ # actually close the client
+ await true_close
+
+ def test_client_ctor_sync(self):
+        # initializing client in a sync context should emit a RuntimeWarning
+
+ with pytest.warns(RuntimeWarning) as warnings:
+ client = _make_client(project="project-id", use_emulator=False)
+ expected_warning = [w for w in warnings if "client.py" in w.filename]
+ assert len(expected_warning) == 1
+ assert (
+ "BigtableDataClientAsync should be started in an asyncio event loop."
+ in str(expected_warning[0].message)
+ )
+ assert client.project == "project-id"
+ assert client._channel_refresh_tasks == []
+
+
+class TestTableAsync:
+ @pytest.mark.asyncio
+ async def test_table_ctor(self):
+ from google.cloud.bigtable.data._async.client import TableAsync
+ from google.cloud.bigtable.data._async.client import _WarmedInstanceKey
+
+ expected_table_id = "table-id"
+ expected_instance_id = "instance-id"
+ expected_app_profile_id = "app-profile-id"
+ expected_operation_timeout = 123
+ expected_attempt_timeout = 12
+ expected_read_rows_operation_timeout = 1.5
+ expected_read_rows_attempt_timeout = 0.5
+ expected_mutate_rows_operation_timeout = 2.5
+ expected_mutate_rows_attempt_timeout = 0.75
+ client = _make_client()
+ assert not client._active_instances
+
+ table = TableAsync(
+ client,
+ expected_instance_id,
+ expected_table_id,
+ expected_app_profile_id,
+ default_operation_timeout=expected_operation_timeout,
+ default_attempt_timeout=expected_attempt_timeout,
+ default_read_rows_operation_timeout=expected_read_rows_operation_timeout,
+ default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout,
+ default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout,
+ default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout,
+ )
+ await asyncio.sleep(0)
+ assert table.table_id == expected_table_id
+ assert table.instance_id == expected_instance_id
+ assert table.app_profile_id == expected_app_profile_id
+ assert table.client is client
+ instance_key = _WarmedInstanceKey(
+ table.instance_name, table.table_name, table.app_profile_id
+ )
+ assert instance_key in client._active_instances
+ assert client._instance_owners[instance_key] == {id(table)}
+ assert table.default_operation_timeout == expected_operation_timeout
+ assert table.default_attempt_timeout == expected_attempt_timeout
+ assert (
+ table.default_read_rows_operation_timeout
+ == expected_read_rows_operation_timeout
+ )
+ assert (
+ table.default_read_rows_attempt_timeout
+ == expected_read_rows_attempt_timeout
+ )
+ assert (
+ table.default_mutate_rows_operation_timeout
+ == expected_mutate_rows_operation_timeout
+ )
+ assert (
+ table.default_mutate_rows_attempt_timeout
+ == expected_mutate_rows_attempt_timeout
+ )
+ # ensure task reaches completion
+ await table._register_instance_task
+ assert table._register_instance_task.done()
+ assert not table._register_instance_task.cancelled()
+ assert table._register_instance_task.exception() is None
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_table_ctor_defaults(self):
+ """
+ should provide default timeout values and app_profile_id
+ """
+ from google.cloud.bigtable.data._async.client import TableAsync
+
+ expected_table_id = "table-id"
+ expected_instance_id = "instance-id"
+ client = _make_client()
+ assert not client._active_instances
+
+ table = TableAsync(
+ client,
+ expected_instance_id,
+ expected_table_id,
+ )
+ await asyncio.sleep(0)
+ assert table.table_id == expected_table_id
+ assert table.instance_id == expected_instance_id
+ assert table.app_profile_id is None
+ assert table.client is client
+ assert table.default_operation_timeout == 60
+ assert table.default_read_rows_operation_timeout == 600
+ assert table.default_mutate_rows_operation_timeout == 600
+ assert table.default_attempt_timeout == 20
+ assert table.default_read_rows_attempt_timeout == 20
+ assert table.default_mutate_rows_attempt_timeout == 60
+ await client.close()
+
+ @pytest.mark.asyncio
+ async def test_table_ctor_invalid_timeout_values(self):
+ """
+ bad timeout values should raise ValueError
+ """
+ from google.cloud.bigtable.data._async.client import TableAsync
+
+ client = _make_client()
+
+ timeout_pairs = [
+ ("default_operation_timeout", "default_attempt_timeout"),
+ (
+ "default_read_rows_operation_timeout",
+ "default_read_rows_attempt_timeout",
+ ),
+ (
+ "default_mutate_rows_operation_timeout",
+ "default_mutate_rows_attempt_timeout",
+ ),
+ ]
+ for operation_timeout, attempt_timeout in timeout_pairs:
+ with pytest.raises(ValueError) as e:
+ TableAsync(client, "", "", **{attempt_timeout: -1})
+ assert "attempt_timeout must be greater than 0" in str(e.value)
+ with pytest.raises(ValueError) as e:
+ TableAsync(client, "", "", **{operation_timeout: -1})
+ assert "operation_timeout must be greater than 0" in str(e.value)
+ await client.close()
+
+ def test_table_ctor_sync(self):
+        # initializing a table in a sync context should raise RuntimeError
+ from google.cloud.bigtable.data._async.client import TableAsync
+
+ client = mock.Mock()
+ with pytest.raises(RuntimeError) as e:
+ TableAsync(client, "instance-id", "table-id")
+ assert e.match("TableAsync must be created within an async event loop context.")
+
+ @pytest.mark.asyncio
+ # iterate over all retryable rpcs
+ @pytest.mark.parametrize(
+ "fn_name,fn_args,retry_fn_path,extra_retryables",
+ [
+ (
+ "read_rows_stream",
+ (ReadRowsQuery(),),
+ "google.api_core.retry.retry_target_stream_async",
+ (),
+ ),
+ (
+ "read_rows",
+ (ReadRowsQuery(),),
+ "google.api_core.retry.retry_target_stream_async",
+ (),
+ ),
+ (
+ "read_row",
+ (b"row_key",),
+ "google.api_core.retry.retry_target_stream_async",
+ (),
+ ),
+ (
+ "read_rows_sharded",
+ ([ReadRowsQuery()],),
+ "google.api_core.retry.retry_target_stream_async",
+ (),
+ ),
+ (
+ "row_exists",
+ (b"row_key",),
+ "google.api_core.retry.retry_target_stream_async",
+ (),
+ ),
+ ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()),
+ (
+ "mutate_row",
+ (b"row_key", [mock.Mock()]),
+ "google.api_core.retry.retry_target_async",
+ (),
+ ),
+ (
+ "bulk_mutate_rows",
+ ([mutations.RowMutationEntry(b"key", [mock.Mock()])],),
+ "google.api_core.retry.retry_target_async",
+ (_MutateRowsIncomplete,),
+ ),
+ ],
+ )
+ # test different inputs for retryable exceptions
+ @pytest.mark.parametrize(
+ "input_retryables,expected_retryables",
+ [
+ (
+ TABLE_DEFAULT.READ_ROWS,
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ core_exceptions.Aborted,
+ ],
+ ),
+ (
+ TABLE_DEFAULT.DEFAULT,
+ [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable],
+ ),
+ (
+ TABLE_DEFAULT.MUTATE_ROWS,
+ [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable],
+ ),
+ ([], []),
+ ([4], [core_exceptions.DeadlineExceeded]),
+ ],
+ )
+ async def test_customizable_retryable_errors(
+ self,
+ input_retryables,
+ expected_retryables,
+ fn_name,
+ fn_args,
+ retry_fn_path,
+ extra_retryables,
+ ):
+ """
+ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed
+ down to the gapic layer.
+ """
+ with mock.patch(retry_fn_path) as retry_fn_mock:
+ async with _make_client() as client:
+ table = client.get_table("instance-id", "table-id")
+ expected_predicate = lambda a: a in expected_retryables # noqa
+ retry_fn_mock.side_effect = RuntimeError("stop early")
+ with mock.patch(
+ "google.api_core.retry.if_exception_type"
+ ) as predicate_builder_mock:
+ predicate_builder_mock.return_value = expected_predicate
+ with pytest.raises(Exception):
+ # we expect an exception from attempting to call the mock
+ test_fn = table.__getattribute__(fn_name)
+ await test_fn(*fn_args, retryable_errors=input_retryables)
+ # passed in errors should be used to build the predicate
+ predicate_builder_mock.assert_called_once_with(
+ *expected_retryables, *extra_retryables
+ )
+ retry_call_args = retry_fn_mock.call_args_list[0].args
+ # output of if_exception_type should be sent in to retry constructor
+ assert retry_call_args[1] is expected_predicate
+
+ @pytest.mark.parametrize(
+ "fn_name,fn_args,gapic_fn",
+ [
+ ("read_rows_stream", (ReadRowsQuery(),), "read_rows"),
+ ("read_rows", (ReadRowsQuery(),), "read_rows"),
+ ("read_row", (b"row_key",), "read_rows"),
+ ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"),
+ ("row_exists", (b"row_key",), "read_rows"),
+ ("sample_row_keys", (), "sample_row_keys"),
+ ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"),
+ (
+ "bulk_mutate_rows",
+ ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],),
+ "mutate_rows",
+ ),
+ ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"),
+ (
+ "read_modify_write_row",
+ (b"row_key", mock.Mock()),
+ "read_modify_write_row",
+ ),
+ ],
+ )
+ @pytest.mark.parametrize("include_app_profile", [True, False])
+ @pytest.mark.asyncio
+ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn):
+ """check that all requests attach proper metadata headers"""
+ from google.cloud.bigtable.data import TableAsync
+
+ profile = "profile" if include_app_profile else None
+ with mock.patch(
+ f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock()
+ ) as gapic_mock:
+ gapic_mock.side_effect = RuntimeError("stop early")
+ async with _make_client() as client:
+ table = TableAsync(client, "instance-id", "table-id", profile)
+ try:
+ test_fn = table.__getattribute__(fn_name)
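+                    # streaming methods only issue the rpc once iterated; non-streaming calls raise immediately and are caught below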
+ maybe_stream = await test_fn(*fn_args)
+ [i async for i in maybe_stream]
+ except Exception:
+ # we expect an exception from attempting to call the mock
+ pass
+ kwargs = gapic_mock.call_args_list[0].kwargs
+ metadata = kwargs["metadata"]
+ goog_metadata = None
+ for key, value in metadata:
+ if key == "x-goog-request-params":
+ goog_metadata = value
+ assert goog_metadata is not None, "x-goog-request-params not found"
+ assert "table_name=" + table.table_name in goog_metadata
+ if include_app_profile:
+ assert "app_profile_id=profile" in goog_metadata
+ else:
+ assert "app_profile_id=" not in goog_metadata
+
+
+class TestReadRows:
+ """
+ Tests for table.read_rows and related methods.
+ """
+
+ def _make_table(self, *args, **kwargs):
+ from google.cloud.bigtable.data._async.client import TableAsync
+
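+        # build a TableAsync around a mocked client so no real channel is created; instance registration becomes a no-op coroutine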
+ client_mock = mock.Mock()
+ client_mock._register_instance.side_effect = (
+ lambda *args, **kwargs: asyncio.sleep(0)
+ )
+ client_mock._remove_instance_registration.side_effect = (
+ lambda *args, **kwargs: asyncio.sleep(0)
+ )
+ kwargs["instance_id"] = kwargs.get(
+ "instance_id", args[0] if args else "instance"
+ )
+ kwargs["table_id"] = kwargs.get(
+ "table_id", args[1] if len(args) > 1 else "table"
+ )
+ client_mock._gapic_client.table_path.return_value = kwargs["table_id"]
+ client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"]
+ return TableAsync(client_mock, *args, **kwargs)
+
+ def _make_stats(self):
+ from google.cloud.bigtable_v2.types import RequestStats
+ from google.cloud.bigtable_v2.types import FullReadStatsView
+ from google.cloud.bigtable_v2.types import ReadIterationStats
+
+ return RequestStats(
+ full_read_stats_view=FullReadStatsView(
+ read_iteration_stats=ReadIterationStats(
+ rows_seen_count=1,
+ rows_returned_count=2,
+ cells_seen_count=3,
+ cells_returned_count=4,
+ )
+ )
+ )
+
+ @staticmethod
+ def _make_chunk(*args, **kwargs):
+ from google.cloud.bigtable_v2 import ReadRowsResponse
+
+ kwargs["row_key"] = kwargs.get("row_key", b"row_key")
+ kwargs["family_name"] = kwargs.get("family_name", "family_name")
+ kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier")
+ kwargs["value"] = kwargs.get("value", b"value")
+ kwargs["commit_row"] = kwargs.get("commit_row", True)
+
+ return ReadRowsResponse.CellChunk(*args, **kwargs)
+
+ @staticmethod
+ async def _make_gapic_stream(
+ chunk_list: list[ReadRowsResponse.CellChunk | Exception],
+ sleep_time=0,
+ ):
+ from google.cloud.bigtable_v2 import ReadRowsResponse
+
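+        # minimal stand-in for the gapic read stream: yields one ReadRowsResponse per chunk,
+        # raises any Exception placed in chunk_list, and can sleep before each chunk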
+ class mock_stream:
+ def __init__(self, chunk_list, sleep_time):
+ self.chunk_list = chunk_list
+ self.idx = -1
+ self.sleep_time = sleep_time
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ self.idx += 1
+ if len(self.chunk_list) > self.idx:
+ if sleep_time:
+ await asyncio.sleep(self.sleep_time)
+ chunk = self.chunk_list[self.idx]
+ if isinstance(chunk, Exception):
+ raise chunk
+ else:
+ return ReadRowsResponse(chunks=[chunk])
+ raise StopAsyncIteration
+
+ def cancel(self):
+ pass
+
+ return mock_stream(chunk_list, sleep_time)
+
+ async def execute_fn(self, table, *args, **kwargs):
+ return await table.read_rows(*args, **kwargs)
+
+ @pytest.mark.asyncio
+ async def test_read_rows(self):
+ query = ReadRowsQuery()
+ chunks = [
+ self._make_chunk(row_key=b"test_1"),
+ self._make_chunk(row_key=b"test_2"),
+ ]
+ async with self._make_table() as table:
+ read_rows = table.client._gapic_client.read_rows
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream(
+ chunks
+ )
+ results = await self.execute_fn(table, query, operation_timeout=3)
+ assert len(results) == 2
+ assert results[0].row_key == b"test_1"
+ assert results[1].row_key == b"test_2"
+
+ @pytest.mark.asyncio
+ async def test_read_rows_stream(self):
+ query = ReadRowsQuery()
+ chunks = [
+ self._make_chunk(row_key=b"test_1"),
+ self._make_chunk(row_key=b"test_2"),
+ ]
+ async with self._make_table() as table:
+ read_rows = table.client._gapic_client.read_rows
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream(
+ chunks
+ )
+ gen = await table.read_rows_stream(query, operation_timeout=3)
+ results = [row async for row in gen]
+ assert len(results) == 2
+ assert results[0].row_key == b"test_1"
+ assert results[1].row_key == b"test_2"
+
+ @pytest.mark.parametrize("include_app_profile", [True, False])
+ @pytest.mark.asyncio
+ async def test_read_rows_query_matches_request(self, include_app_profile):
+ from google.cloud.bigtable.data import RowRange
+ from google.cloud.bigtable.data.row_filters import PassAllFilter
+
+ app_profile_id = "app_profile_id" if include_app_profile else None
+ async with self._make_table(app_profile_id=app_profile_id) as table:
+ read_rows = table.client._gapic_client.read_rows
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([])
+ row_keys = [b"test_1", "test_2"]
+ row_ranges = RowRange("1start", "2end")
+ filter_ = PassAllFilter(True)
+ limit = 99
+ query = ReadRowsQuery(
+ row_keys=row_keys,
+ row_ranges=row_ranges,
+ row_filter=filter_,
+ limit=limit,
+ )
+
+ results = await table.read_rows(query, operation_timeout=3)
+ assert len(results) == 0
+ call_request = read_rows.call_args_list[0][0][0]
+ query_pb = query._to_pb(table)
+ assert call_request == query_pb
+
+ @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1])
+ @pytest.mark.asyncio
+ async def test_read_rows_timeout(self, operation_timeout):
+ async with self._make_table() as table:
+ read_rows = table.client._gapic_client.read_rows
+ query = ReadRowsQuery()
+ chunks = [self._make_chunk(row_key=b"test_1")]
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream(
+ chunks, sleep_time=1
+ )
+ try:
+ await table.read_rows(query, operation_timeout=operation_timeout)
+ except core_exceptions.DeadlineExceeded as e:
+ assert (
+ e.message
+ == f"operation_timeout of {operation_timeout:0.1f}s exceeded"
+ )
+
+ @pytest.mark.parametrize(
+ "per_request_t, operation_t, expected_num",
+ [
+ (0.05, 0.08, 2),
+ (0.05, 0.54, 11),
+ (0.05, 0.14, 3),
+ (0.05, 0.24, 5),
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_read_rows_attempt_timeout(
+ self, per_request_t, operation_t, expected_num
+ ):
+ """
+ Ensures that the attempt_timeout is respected and that the number of
+ requests is as expected.
+
+ operation_timeout does not cancel the request, so we expect the number of
+ requests to be the ceiling of operation_timeout / attempt_timeout.
+ """
+ from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
+
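+        # the final attempt only gets whatever remains of the operation budget after the earlier full-length attempts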
+ expected_last_timeout = operation_t - (expected_num - 1) * per_request_t
+
+ # mocking uniform ensures there are no sleeps between retries
+ with mock.patch("random.uniform", side_effect=lambda a, b: 0):
+ async with self._make_table() as table:
+ read_rows = table.client._gapic_client.read_rows
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream(
+ chunks, sleep_time=per_request_t
+ )
+ query = ReadRowsQuery()
+ chunks = [core_exceptions.DeadlineExceeded("mock deadline")]
+
+ try:
+ await table.read_rows(
+ query,
+ operation_timeout=operation_t,
+ attempt_timeout=per_request_t,
+ )
+ except core_exceptions.DeadlineExceeded as e:
+ retry_exc = e.__cause__
+ if expected_num == 0:
+ assert retry_exc is None
+ else:
+ assert type(retry_exc) is RetryExceptionGroup
+ assert f"{expected_num} failed attempts" in str(retry_exc)
+ assert len(retry_exc.exceptions) == expected_num
+ for sub_exc in retry_exc.exceptions:
+ assert sub_exc.message == "mock deadline"
+ assert read_rows.call_count == expected_num
+ # check timeouts
+ for _, call_kwargs in read_rows.call_args_list[:-1]:
+ assert call_kwargs["timeout"] == per_request_t
+ assert call_kwargs["retry"] is None
+ # last timeout should be adjusted to account for the time spent
+ assert (
+ abs(
+ read_rows.call_args_list[-1][1]["timeout"]
+ - expected_last_timeout
+ )
+ < 0.05
+ )
+
+ @pytest.mark.parametrize(
+ "exc_type",
+ [
+ core_exceptions.Aborted,
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_read_rows_retryable_error(self, exc_type):
+ async with self._make_table() as table:
+ read_rows = table.client._gapic_client.read_rows
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream(
+ [expected_error]
+ )
+ query = ReadRowsQuery()
+ expected_error = exc_type("mock error")
+ try:
+ await table.read_rows(query, operation_timeout=0.1)
+ except core_exceptions.DeadlineExceeded as e:
+ retry_exc = e.__cause__
+ root_cause = retry_exc.exceptions[0]
+ assert type(root_cause) is exc_type
+ assert root_cause == expected_error
+
+ @pytest.mark.parametrize(
+ "exc_type",
+ [
+ core_exceptions.Cancelled,
+ core_exceptions.PreconditionFailed,
+ core_exceptions.NotFound,
+ core_exceptions.PermissionDenied,
+ core_exceptions.Conflict,
+ core_exceptions.InternalServerError,
+ core_exceptions.TooManyRequests,
+ core_exceptions.ResourceExhausted,
+ InvalidChunk,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_read_rows_non_retryable_error(self, exc_type):
+ async with self._make_table() as table:
+ read_rows = table.client._gapic_client.read_rows
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream(
+ [expected_error]
+ )
+ query = ReadRowsQuery()
+ expected_error = exc_type("mock error")
+ try:
+ await table.read_rows(query, operation_timeout=0.1)
+ except exc_type as e:
+ assert e == expected_error
+
+ @pytest.mark.asyncio
+ async def test_read_rows_revise_request(self):
+ """
+        Ensure that _revise_request_rowset is called between retries
+ """
+ from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+ from google.cloud.bigtable.data.exceptions import InvalidChunk
+ from google.cloud.bigtable_v2.types import RowSet
+
+ return_val = RowSet()
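+        # the retried request should be rebuilt with the RowSet returned by _revise_request_rowset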
+ with mock.patch.object(
+ _ReadRowsOperationAsync, "_revise_request_rowset"
+ ) as revise_rowset:
+ revise_rowset.return_value = return_val
+ async with self._make_table() as table:
+ read_rows = table.client._gapic_client.read_rows
+ read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream(
+ chunks
+ )
+ row_keys = [b"test_1", b"test_2", b"test_3"]
+ query = ReadRowsQuery(row_keys=row_keys)
+ chunks = [
+ self._make_chunk(row_key=b"test_1"),
+ core_exceptions.Aborted("mock retryable error"),
+ ]
+ try:
+ await table.read_rows(query)
+ except InvalidChunk:
+ revise_rowset.assert_called()
+ first_call_kwargs = revise_rowset.call_args_list[0].kwargs
+ assert first_call_kwargs["row_set"] == query._to_pb(table).rows
+ assert first_call_kwargs["last_seen_row_key"] == b"test_1"
+ revised_call = read_rows.call_args_list[1].args[0]
+ assert revised_call.rows == return_val
+
+ @pytest.mark.asyncio
+ async def test_read_rows_default_timeouts(self):
+ """
+ Ensure that the default timeouts are set on the read rows operation when not overridden
+ """
+ from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+
+ operation_timeout = 8
+ attempt_timeout = 4
+ with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op:
+ mock_op.side_effect = RuntimeError("mock error")
+ async with self._make_table(
+ default_read_rows_operation_timeout=operation_timeout,
+ default_read_rows_attempt_timeout=attempt_timeout,
+ ) as table:
+ try:
+ await table.read_rows(ReadRowsQuery())
+ except RuntimeError:
+ pass
+ kwargs = mock_op.call_args_list[0].kwargs
+ assert kwargs["operation_timeout"] == operation_timeout
+ assert kwargs["attempt_timeout"] == attempt_timeout
+
+ @pytest.mark.asyncio
+ async def test_read_rows_default_timeout_override(self):
+ """
+ When timeouts are passed, they overwrite default values
+ """
+ from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+
+ operation_timeout = 8
+ attempt_timeout = 4
+ with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op:
+ mock_op.side_effect = RuntimeError("mock error")
+ async with self._make_table(
+ default_operation_timeout=99, default_attempt_timeout=97
+ ) as table:
+ try:
+ await table.read_rows(
+ ReadRowsQuery(),
+ operation_timeout=operation_timeout,
+ attempt_timeout=attempt_timeout,
+ )
+ except RuntimeError:
+ pass
+ kwargs = mock_op.call_args_list[0].kwargs
+ assert kwargs["operation_timeout"] == operation_timeout
+ assert kwargs["attempt_timeout"] == attempt_timeout
+
+ @pytest.mark.asyncio
+ async def test_read_row(self):
+ """Test reading a single row"""
+ async with _make_client() as client:
+ table = client.get_table("instance", "table")
+ row_key = b"test_1"
+ with mock.patch.object(table, "read_rows") as read_rows:
+ expected_result = object()
+ read_rows.side_effect = lambda *args, **kwargs: [expected_result]
+ expected_op_timeout = 8
+ expected_req_timeout = 4
+ row = await table.read_row(
+ row_key,
+ operation_timeout=expected_op_timeout,
+ attempt_timeout=expected_req_timeout,
+ )
+ assert row == expected_result
+ assert read_rows.call_count == 1
+ args, kwargs = read_rows.call_args_list[0]
+ assert kwargs["operation_timeout"] == expected_op_timeout
+ assert kwargs["attempt_timeout"] == expected_req_timeout
+ assert len(args) == 1
+ assert isinstance(args[0], ReadRowsQuery)
+ query = args[0]
+ assert query.row_keys == [row_key]
+ assert query.row_ranges == []
+ assert query.limit == 1
+
+ @pytest.mark.asyncio
+ async def test_read_row_w_filter(self):
+ """Test reading a single row with an added filter"""
+ async with _make_client() as client:
+ table = client.get_table("instance", "table")
+ row_key = b"test_1"
+ with mock.patch.object(table, "read_rows") as read_rows:
+ expected_result = object()
+ read_rows.side_effect = lambda *args, **kwargs: [expected_result]
+ expected_op_timeout = 8
+ expected_req_timeout = 4
+ mock_filter = mock.Mock()
+ expected_filter = {"filter": "mock filter"}
+ mock_filter._to_dict.return_value = expected_filter
+ row = await table.read_row(
+ row_key,
+ operation_timeout=expected_op_timeout,
+ attempt_timeout=expected_req_timeout,
+ row_filter=expected_filter,
+ )
+ assert row == expected_result
+ assert read_rows.call_count == 1
+ args, kwargs = read_rows.call_args_list[0]
+ assert kwargs["operation_timeout"] == expected_op_timeout
+ assert kwargs["attempt_timeout"] == expected_req_timeout
+ assert len(args) == 1
+ assert isinstance(args[0], ReadRowsQuery)
+ query = args[0]
+ assert query.row_keys == [row_key]
+ assert query.row_ranges == []
+ assert query.limit == 1
+ assert query.filter == expected_filter
+
+ @pytest.mark.asyncio
+ async def test_read_row_no_response(self):
+ """should return None if row does not exist"""
+ async with _make_client() as client:
+ table = client.get_table("instance", "table")
+ row_key = b"test_1"
+ with mock.patch.object(table, "read_rows") as read_rows:
+ # return no rows
+ read_rows.side_effect = lambda *args, **kwargs: []
+ expected_op_timeout = 8
+ expected_req_timeout = 4
+ result = await table.read_row(
+ row_key,
+ operation_timeout=expected_op_timeout,
+ attempt_timeout=expected_req_timeout,
+ )
+ assert result is None
+ assert read_rows.call_count == 1
+ args, kwargs = read_rows.call_args_list[0]
+ assert kwargs["operation_timeout"] == expected_op_timeout
+ assert kwargs["attempt_timeout"] == expected_req_timeout
+ assert isinstance(args[0], ReadRowsQuery)
+ query = args[0]
+ assert query.row_keys == [row_key]
+ assert query.row_ranges == []
+ assert query.limit == 1
+
+ @pytest.mark.parametrize(
+ "return_value,expected_result",
+ [
+ ([], False),
+ ([object()], True),
+ ([object(), object()], True),
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_row_exists(self, return_value, expected_result):
+ """Test checking for row existence"""
+ async with _make_client() as client:
+ table = client.get_table("instance", "table")
+ row_key = b"test_1"
+ with mock.patch.object(table, "read_rows") as read_rows:
+                # return the parametrized row list
+ read_rows.side_effect = lambda *args, **kwargs: return_value
+ expected_op_timeout = 1
+ expected_req_timeout = 2
+ result = await table.row_exists(
+ row_key,
+ operation_timeout=expected_op_timeout,
+ attempt_timeout=expected_req_timeout,
+ )
+ assert expected_result == result
+ assert read_rows.call_count == 1
+ args, kwargs = read_rows.call_args_list[0]
+ assert kwargs["operation_timeout"] == expected_op_timeout
+ assert kwargs["attempt_timeout"] == expected_req_timeout
+ assert isinstance(args[0], ReadRowsQuery)
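+                # row_exists should request as little data as possible: one cell per row, with values stripped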
+ expected_filter = {
+ "chain": {
+ "filters": [
+ {"cells_per_row_limit_filter": 1},
+ {"strip_value_transformer": True},
+ ]
+ }
+ }
+ query = args[0]
+ assert query.row_keys == [row_key]
+ assert query.row_ranges == []
+ assert query.limit == 1
+ assert query.filter._to_dict() == expected_filter
+
+
+class TestReadRowsSharded:
+ @pytest.mark.asyncio
+ async def test_read_rows_sharded_empty_query(self):
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with pytest.raises(ValueError) as exc:
+ await table.read_rows_sharded([])
+ assert "empty sharded_query" in str(exc.value)
+
+ @pytest.mark.asyncio
+ async def test_read_rows_sharded_multiple_queries(self):
+ """
+ Test with multiple queries. Should return results from both
+ """
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ table.client._gapic_client, "read_rows"
+ ) as read_rows:
+ read_rows.side_effect = (
+ lambda *args, **kwargs: TestReadRows._make_gapic_stream(
+ [
+ TestReadRows._make_chunk(row_key=k)
+ for k in args[0].rows.row_keys
+ ]
+ )
+ )
+ query_1 = ReadRowsQuery(b"test_1")
+ query_2 = ReadRowsQuery(b"test_2")
+ result = await table.read_rows_sharded([query_1, query_2])
+ assert len(result) == 2
+ assert result[0].row_key == b"test_1"
+ assert result[1].row_key == b"test_2"
+
+ @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24])
+ @pytest.mark.asyncio
+ async def test_read_rows_sharded_multiple_queries_calls(self, n_queries):
+ """
+ Each query should trigger a separate read_rows call
+ """
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(table, "read_rows") as read_rows:
+ query_list = [ReadRowsQuery() for _ in range(n_queries)]
+ await table.read_rows_sharded(query_list)
+ assert read_rows.call_count == n_queries
+
+ @pytest.mark.asyncio
+ async def test_read_rows_sharded_errors(self):
+ """
+ Errors should be exposed as ShardedReadRowsExceptionGroups
+ """
+ from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+ from google.cloud.bigtable.data.exceptions import FailedQueryShardError
+
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(table, "read_rows") as read_rows:
+ read_rows.side_effect = RuntimeError("mock error")
+ query_1 = ReadRowsQuery(b"test_1")
+ query_2 = ReadRowsQuery(b"test_2")
+ with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+ await table.read_rows_sharded([query_1, query_2])
+ exc_group = exc.value
+ assert isinstance(exc_group, ShardedReadRowsExceptionGroup)
+ assert len(exc.value.exceptions) == 2
+ assert isinstance(exc.value.exceptions[0], FailedQueryShardError)
+ assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError)
+ assert exc.value.exceptions[0].index == 0
+ assert exc.value.exceptions[0].query == query_1
+ assert isinstance(exc.value.exceptions[1], FailedQueryShardError)
+ assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError)
+ assert exc.value.exceptions[1].index == 1
+ assert exc.value.exceptions[1].query == query_2
+
+ @pytest.mark.asyncio
+ async def test_read_rows_sharded_concurrent(self):
+ """
+ Ensure sharded requests are concurrent
+ """
+ import time
+
+ async def mock_call(*args, **kwargs):
+ await asyncio.sleep(0.1)
+ return [mock.Mock()]
+
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(table, "read_rows") as read_rows:
+ read_rows.side_effect = mock_call
+ queries = [ReadRowsQuery() for _ in range(10)]
+ start_time = time.monotonic()
+ result = await table.read_rows_sharded(queries)
+ call_time = time.monotonic() - start_time
+ assert read_rows.call_count == 10
+ assert len(result) == 10
+ # if run in sequence, we would expect this to take 1 second
+ assert call_time < 0.2
+
+ @pytest.mark.asyncio
+ async def test_read_rows_sharded_batching(self):
+ """
+ Large queries should be processed in batches to limit concurrency
+ operation timeout should change between batches
+ """
+ from google.cloud.bigtable.data._async.client import TableAsync
+ from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT
+
+ assert _CONCURRENCY_LIMIT == 10 # change this test if this changes
+
+ n_queries = 90
+ expected_num_batches = n_queries // _CONCURRENCY_LIMIT
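+        # 90 queries with a concurrency limit of 10 should be dispatched as 9 sequential gather batches of 10 read_rows calls each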
+ query_list = [ReadRowsQuery() for _ in range(n_queries)]
+
+ table_mock = AsyncMock()
+ start_operation_timeout = 10
+ start_attempt_timeout = 3
+ table_mock.default_read_rows_operation_timeout = start_operation_timeout
+ table_mock.default_read_rows_attempt_timeout = start_attempt_timeout
+ # clock ticks one second on each check
+ with mock.patch("time.monotonic", side_effect=range(0, 100000)):
+ with mock.patch("asyncio.gather", AsyncMock()) as gather_mock:
+ await TableAsync.read_rows_sharded(table_mock, query_list)
+ # should have individual calls for each query
+ assert table_mock.read_rows.call_count == n_queries
+ # should have single gather call for each batch
+ assert gather_mock.call_count == expected_num_batches
+ # ensure that timeouts decrease over time
+ kwargs = [
+ table_mock.read_rows.call_args_list[idx][1]
+ for idx in range(n_queries)
+ ]
+ for batch_idx in range(expected_num_batches):
+ batch_kwargs = kwargs[
+ batch_idx
+ * _CONCURRENCY_LIMIT : (batch_idx + 1)
+ * _CONCURRENCY_LIMIT
+ ]
+ for req_kwargs in batch_kwargs:
+ # each batch should have the same operation_timeout, and it should decrease in each batch
+ expected_operation_timeout = start_operation_timeout - (
+ batch_idx + 1
+ )
+ assert (
+ req_kwargs["operation_timeout"]
+ == expected_operation_timeout
+ )
+ # each attempt_timeout should start with default value, but decrease when operation_timeout reaches it
+ expected_attempt_timeout = min(
+ start_attempt_timeout, expected_operation_timeout
+ )
+ assert req_kwargs["attempt_timeout"] == expected_attempt_timeout
+ # await all created coroutines to avoid warnings
+ for i in range(len(gather_mock.call_args_list)):
+ for j in range(len(gather_mock.call_args_list[i][0])):
+ await gather_mock.call_args_list[i][0][j]
+
+
+class TestSampleRowKeys:
+ async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]):
+ from google.cloud.bigtable_v2.types import SampleRowKeysResponse
+
+ for value in sample_list:
+ yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1])
+
+ @pytest.mark.asyncio
+ async def test_sample_row_keys(self):
+ """
+        Test that the method returns the expected key samples
+ """
+ samples = [
+ (b"test_1", 0),
+ (b"test_2", 100),
+ (b"test_3", 200),
+ ]
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ table.client._gapic_client, "sample_row_keys", AsyncMock()
+ ) as sample_row_keys:
+ sample_row_keys.return_value = self._make_gapic_stream(samples)
+ result = await table.sample_row_keys()
+ assert len(result) == 3
+ assert all(isinstance(r, tuple) for r in result)
+ assert all(isinstance(r[0], bytes) for r in result)
+ assert all(isinstance(r[1], int) for r in result)
+ assert result[0] == samples[0]
+ assert result[1] == samples[1]
+ assert result[2] == samples[2]
+
+ @pytest.mark.asyncio
+ async def test_sample_row_keys_bad_timeout(self):
+ """
+ should raise error if timeout is negative
+ """
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with pytest.raises(ValueError) as e:
+ await table.sample_row_keys(operation_timeout=-1)
+ assert "operation_timeout must be greater than 0" in str(e.value)
+ with pytest.raises(ValueError) as e:
+ await table.sample_row_keys(attempt_timeout=-1)
+ assert "attempt_timeout must be greater than 0" in str(e.value)
+
+ @pytest.mark.asyncio
+ async def test_sample_row_keys_default_timeout(self):
+ """Should fallback to using table default operation_timeout"""
+ expected_timeout = 99
+ async with _make_client() as client:
+ async with client.get_table(
+ "i",
+ "t",
+ default_operation_timeout=expected_timeout,
+ default_attempt_timeout=expected_timeout,
+ ) as table:
+ with mock.patch.object(
+ table.client._gapic_client, "sample_row_keys", AsyncMock()
+ ) as sample_row_keys:
+ sample_row_keys.return_value = self._make_gapic_stream([])
+ result = await table.sample_row_keys()
+ _, kwargs = sample_row_keys.call_args
+ assert abs(kwargs["timeout"] - expected_timeout) < 0.1
+ assert result == []
+ assert kwargs["retry"] is None
+
+ @pytest.mark.asyncio
+ async def test_sample_row_keys_gapic_params(self):
+ """
+ make sure arguments are propagated to gapic call as expected
+ """
+ expected_timeout = 10
+ expected_profile = "test1"
+ instance = "instance_name"
+ table_id = "my_table"
+ async with _make_client() as client:
+ async with client.get_table(
+ instance, table_id, app_profile_id=expected_profile
+ ) as table:
+ with mock.patch.object(
+ table.client._gapic_client, "sample_row_keys", AsyncMock()
+ ) as sample_row_keys:
+ sample_row_keys.return_value = self._make_gapic_stream([])
+ await table.sample_row_keys(attempt_timeout=expected_timeout)
+ args, kwargs = sample_row_keys.call_args
+ assert len(args) == 0
+ assert len(kwargs) == 5
+ assert kwargs["timeout"] == expected_timeout
+ assert kwargs["app_profile_id"] == expected_profile
+ assert kwargs["table_name"] == table.table_name
+ assert kwargs["metadata"] is not None
+ assert kwargs["retry"] is None
+
+ @pytest.mark.parametrize(
+ "retryable_exception",
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_sample_row_keys_retryable_errors(self, retryable_exception):
+ """
+ retryable errors should be retried until timeout
+ """
+ from google.api_core.exceptions import DeadlineExceeded
+ from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
+
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ table.client._gapic_client, "sample_row_keys", AsyncMock()
+ ) as sample_row_keys:
+ sample_row_keys.side_effect = retryable_exception("mock")
+ with pytest.raises(DeadlineExceeded) as e:
+ await table.sample_row_keys(operation_timeout=0.05)
+ cause = e.value.__cause__
+ assert isinstance(cause, RetryExceptionGroup)
+ assert len(cause.exceptions) > 0
+ assert isinstance(cause.exceptions[0], retryable_exception)
+
+ @pytest.mark.parametrize(
+ "non_retryable_exception",
+ [
+ core_exceptions.OutOfRange,
+ core_exceptions.NotFound,
+ core_exceptions.FailedPrecondition,
+ RuntimeError,
+ ValueError,
+ core_exceptions.Aborted,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception):
+ """
+ non-retryable errors should cause a raise
+ """
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ table.client._gapic_client, "sample_row_keys", AsyncMock()
+ ) as sample_row_keys:
+ sample_row_keys.side_effect = non_retryable_exception("mock")
+ with pytest.raises(non_retryable_exception):
+ await table.sample_row_keys()
+
+
+class TestMutateRow:
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mutation_arg",
+ [
+ mutations.SetCell("family", b"qualifier", b"value"),
+ mutations.SetCell(
+ "family", b"qualifier", b"value", timestamp_micros=1234567890
+ ),
+ mutations.DeleteRangeFromColumn("family", b"qualifier"),
+ mutations.DeleteAllFromFamily("family"),
+ mutations.DeleteAllFromRow(),
+ [mutations.SetCell("family", b"qualifier", b"value")],
+ [
+ mutations.DeleteRangeFromColumn("family", b"qualifier"),
+ mutations.DeleteAllFromRow(),
+ ],
+ ],
+ )
+ async def test_mutate_row(self, mutation_arg):
+ """Test mutations with no errors"""
+ expected_attempt_timeout = 19
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_row"
+ ) as mock_gapic:
+ mock_gapic.return_value = None
+ await table.mutate_row(
+ "row_key",
+ mutation_arg,
+ attempt_timeout=expected_attempt_timeout,
+ )
+ assert mock_gapic.call_count == 1
+ kwargs = mock_gapic.call_args_list[0].kwargs
+ assert (
+ kwargs["table_name"]
+ == "projects/project/instances/instance/tables/table"
+ )
+ assert kwargs["row_key"] == b"row_key"
+ formatted_mutations = (
+ [mutation._to_pb() for mutation in mutation_arg]
+ if isinstance(mutation_arg, list)
+ else [mutation_arg._to_pb()]
+ )
+ assert kwargs["mutations"] == formatted_mutations
+ assert kwargs["timeout"] == expected_attempt_timeout
+ # make sure gapic layer is not retrying
+ assert kwargs["retry"] is None
+
+ @pytest.mark.parametrize(
+ "retryable_exception",
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_mutate_row_retryable_errors(self, retryable_exception):
+ from google.api_core.exceptions import DeadlineExceeded
+ from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
+
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_row"
+ ) as mock_gapic:
+ mock_gapic.side_effect = retryable_exception("mock")
+ with pytest.raises(DeadlineExceeded) as e:
+ mutation = mutations.DeleteAllFromRow()
+ assert mutation.is_idempotent() is True
+ await table.mutate_row(
+ "row_key", mutation, operation_timeout=0.01
+ )
+ cause = e.value.__cause__
+ assert isinstance(cause, RetryExceptionGroup)
+ assert isinstance(cause.exceptions[0], retryable_exception)
+
+ @pytest.mark.parametrize(
+ "retryable_exception",
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_mutate_row_non_idempotent_retryable_errors(
+ self, retryable_exception
+ ):
+ """
+ Non-idempotent mutations should not be retried
+ """
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_row"
+ ) as mock_gapic:
+ mock_gapic.side_effect = retryable_exception("mock")
+ with pytest.raises(retryable_exception):
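+                        # timestamp_micros=-1 requests a server-assigned timestamp, making the mutation non-idempotent (never retried)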
+ mutation = mutations.SetCell(
+ "family", b"qualifier", b"value", -1
+ )
+ assert mutation.is_idempotent() is False
+ await table.mutate_row(
+ "row_key", mutation, operation_timeout=0.2
+ )
+
+ @pytest.mark.parametrize(
+ "non_retryable_exception",
+ [
+ core_exceptions.OutOfRange,
+ core_exceptions.NotFound,
+ core_exceptions.FailedPrecondition,
+ RuntimeError,
+ ValueError,
+ core_exceptions.Aborted,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_mutate_row_non_retryable_errors(self, non_retryable_exception):
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_row"
+ ) as mock_gapic:
+ mock_gapic.side_effect = non_retryable_exception("mock")
+ with pytest.raises(non_retryable_exception):
+ mutation = mutations.SetCell(
+ "family",
+ b"qualifier",
+ b"value",
+ timestamp_micros=1234567890,
+ )
+ assert mutation.is_idempotent() is True
+ await table.mutate_row(
+ "row_key", mutation, operation_timeout=0.2
+ )
+
+ @pytest.mark.parametrize("include_app_profile", [True, False])
+ @pytest.mark.asyncio
+ async def test_mutate_row_metadata(self, include_app_profile):
+ """request should attach metadata headers"""
+ profile = "profile" if include_app_profile else None
+ async with _make_client() as client:
+ async with client.get_table("i", "t", app_profile_id=profile) as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_row", AsyncMock()
+ ) as read_rows:
+ await table.mutate_row("rk", mock.Mock())
+ kwargs = read_rows.call_args_list[0].kwargs
+ metadata = kwargs["metadata"]
+ goog_metadata = None
+ for key, value in metadata:
+ if key == "x-goog-request-params":
+ goog_metadata = value
+ assert goog_metadata is not None, "x-goog-request-params not found"
+ assert "table_name=" + table.table_name in goog_metadata
+ if include_app_profile:
+ assert "app_profile_id=profile" in goog_metadata
+ else:
+ assert "app_profile_id=" not in goog_metadata
+
+ @pytest.mark.parametrize("mutations", [[], None])
+ @pytest.mark.asyncio
+ async def test_mutate_row_no_mutations(self, mutations):
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with pytest.raises(ValueError) as e:
+ await table.mutate_row("key", mutations=mutations)
+ assert e.value.args[0] == "No mutations provided"
+
+
+class TestBulkMutateRows:
+ async def _mock_response(self, response_list):
+ from google.cloud.bigtable_v2.types import MutateRowsResponse
+ from google.rpc import status_pb2
+
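+        # build a single MutateRowsResponse whose per-entry statuses mirror response_list:
+        # GoogleAPICallErrors become non-zero grpc status codes, anything else becomes OK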
+ statuses = []
+ for response in response_list:
+ if isinstance(response, core_exceptions.GoogleAPICallError):
+ statuses.append(
+ status_pb2.Status(
+ message=str(response), code=response.grpc_status_code.value[0]
+ )
+ )
+ else:
+ statuses.append(status_pb2.Status(code=0))
+ entries = [
+ MutateRowsResponse.Entry(index=i, status=statuses[i])
+ for i in range(len(response_list))
+ ]
+
+ async def generator():
+ yield MutateRowsResponse(entries=entries)
+
+ return generator()
+
+    @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mutation_arg",
+ [
+ [mutations.SetCell("family", b"qualifier", b"value")],
+ [
+ mutations.SetCell(
+ "family", b"qualifier", b"value", timestamp_micros=1234567890
+ )
+ ],
+ [mutations.DeleteRangeFromColumn("family", b"qualifier")],
+ [mutations.DeleteAllFromFamily("family")],
+ [mutations.DeleteAllFromRow()],
+ [mutations.SetCell("family", b"qualifier", b"value")],
+ [
+ mutations.DeleteRangeFromColumn("family", b"qualifier"),
+ mutations.DeleteAllFromRow(),
+ ],
+ ],
+ )
+ async def test_bulk_mutate_rows(self, mutation_arg):
+ """Test mutations with no errors"""
+ expected_attempt_timeout = 19
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ mock_gapic.return_value = self._mock_response([None])
+ bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg)
+ await table.bulk_mutate_rows(
+ [bulk_mutation],
+ attempt_timeout=expected_attempt_timeout,
+ )
+ assert mock_gapic.call_count == 1
+ kwargs = mock_gapic.call_args[1]
+ assert (
+ kwargs["table_name"]
+ == "projects/project/instances/instance/tables/table"
+ )
+ assert kwargs["entries"] == [bulk_mutation._to_pb()]
+ assert kwargs["timeout"] == expected_attempt_timeout
+ assert kwargs["retry"] is None
+
+ @pytest.mark.asyncio
+ async def test_bulk_mutate_rows_multiple_entries(self):
+ """Test mutations with no errors"""
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ mock_gapic.return_value = self._mock_response([None, None])
+ mutation_list = [mutations.DeleteAllFromRow()]
+ entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list)
+ entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list)
+ await table.bulk_mutate_rows(
+ [entry_1, entry_2],
+ )
+ assert mock_gapic.call_count == 1
+ kwargs = mock_gapic.call_args[1]
+ assert (
+ kwargs["table_name"]
+ == "projects/project/instances/instance/tables/table"
+ )
+ assert kwargs["entries"][0] == entry_1._to_pb()
+ assert kwargs["entries"][1] == entry_2._to_pb()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "exception",
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ],
+ )
+ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable(
+ self, exception
+ ):
+ """
+ Individual idempotent mutations should be retried if they fail with a retryable error
+ """
+ from google.cloud.bigtable.data.exceptions import (
+ RetryExceptionGroup,
+ FailedMutationEntryError,
+ MutationsExceptionGroup,
+ )
+
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ mock_gapic.side_effect = lambda *a, **k: self._mock_response(
+ [exception("mock")]
+ )
+ with pytest.raises(MutationsExceptionGroup) as e:
+ mutation = mutations.DeleteAllFromRow()
+ entry = mutations.RowMutationEntry(b"row_key", [mutation])
+ assert mutation.is_idempotent() is True
+ await table.bulk_mutate_rows([entry], operation_timeout=0.05)
+ assert len(e.value.exceptions) == 1
+ failed_exception = e.value.exceptions[0]
+ assert "non-idempotent" not in str(failed_exception)
+ assert isinstance(failed_exception, FailedMutationEntryError)
+ cause = failed_exception.__cause__
+ assert isinstance(cause, RetryExceptionGroup)
+ assert isinstance(cause.exceptions[0], exception)
+ # last exception should be due to retry timeout
+ assert isinstance(
+ cause.exceptions[-1], core_exceptions.DeadlineExceeded
+ )
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "exception",
+ [
+ core_exceptions.OutOfRange,
+ core_exceptions.NotFound,
+ core_exceptions.FailedPrecondition,
+ core_exceptions.Aborted,
+ ],
+ )
+ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(
+ self, exception
+ ):
+ """
+ Individual idempotent mutations should not be retried if they fail with a non-retryable error
+ """
+ from google.cloud.bigtable.data.exceptions import (
+ FailedMutationEntryError,
+ MutationsExceptionGroup,
+ )
+
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ mock_gapic.side_effect = lambda *a, **k: self._mock_response(
+ [exception("mock")]
+ )
+ with pytest.raises(MutationsExceptionGroup) as e:
+ mutation = mutations.DeleteAllFromRow()
+ entry = mutations.RowMutationEntry(b"row_key", [mutation])
+ assert mutation.is_idempotent() is True
+ await table.bulk_mutate_rows([entry], operation_timeout=0.05)
+ assert len(e.value.exceptions) == 1
+ failed_exception = e.value.exceptions[0]
+ assert "non-idempotent" not in str(failed_exception)
+ assert isinstance(failed_exception, FailedMutationEntryError)
+ cause = failed_exception.__cause__
+ assert isinstance(cause, exception)
+
+ @pytest.mark.parametrize(
+ "retryable_exception",
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_bulk_mutate_idempotent_retryable_request_errors(
+ self, retryable_exception
+ ):
+ """
+ Individual idempotent mutations should be retried if the request fails with a retryable error
+ """
+ from google.cloud.bigtable.data.exceptions import (
+ RetryExceptionGroup,
+ FailedMutationEntryError,
+ MutationsExceptionGroup,
+ )
+
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ mock_gapic.side_effect = retryable_exception("mock")
+ with pytest.raises(MutationsExceptionGroup) as e:
+ mutation = mutations.SetCell(
+ "family", b"qualifier", b"value", timestamp_micros=123
+ )
+ entry = mutations.RowMutationEntry(b"row_key", [mutation])
+ assert mutation.is_idempotent() is True
+ await table.bulk_mutate_rows([entry], operation_timeout=0.05)
+ assert len(e.value.exceptions) == 1
+ failed_exception = e.value.exceptions[0]
+ assert isinstance(failed_exception, FailedMutationEntryError)
+ assert "non-idempotent" not in str(failed_exception)
+ cause = failed_exception.__cause__
+ assert isinstance(cause, RetryExceptionGroup)
+ assert isinstance(cause.exceptions[0], retryable_exception)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "retryable_exception",
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ ],
+ )
+ async def test_bulk_mutate_rows_non_idempotent_retryable_errors(
+ self, retryable_exception
+ ):
+ """Non-Idempotent mutations should never be retried"""
+ from google.cloud.bigtable.data.exceptions import (
+ FailedMutationEntryError,
+ MutationsExceptionGroup,
+ )
+
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ mock_gapic.side_effect = lambda *a, **k: self._mock_response(
+ [retryable_exception("mock")]
+ )
+ with pytest.raises(MutationsExceptionGroup) as e:
+ mutation = mutations.SetCell(
+ "family", b"qualifier", b"value", -1
+ )
+ entry = mutations.RowMutationEntry(b"row_key", [mutation])
+ assert mutation.is_idempotent() is False
+ await table.bulk_mutate_rows([entry], operation_timeout=0.2)
+ assert len(e.value.exceptions) == 1
+ failed_exception = e.value.exceptions[0]
+ assert isinstance(failed_exception, FailedMutationEntryError)
+ assert "non-idempotent" in str(failed_exception)
+ cause = failed_exception.__cause__
+ assert isinstance(cause, retryable_exception)
+
+ @pytest.mark.parametrize(
+ "non_retryable_exception",
+ [
+ core_exceptions.OutOfRange,
+ core_exceptions.NotFound,
+ core_exceptions.FailedPrecondition,
+ RuntimeError,
+ ValueError,
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception):
+ """
+ If the request fails with a non-retryable error, mutations should not be retried
+ """
+ from google.cloud.bigtable.data.exceptions import (
+ FailedMutationEntryError,
+ MutationsExceptionGroup,
+ )
+
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ mock_gapic.side_effect = non_retryable_exception("mock")
+ with pytest.raises(MutationsExceptionGroup) as e:
+ mutation = mutations.SetCell(
+ "family", b"qualifier", b"value", timestamp_micros=123
+ )
+ entry = mutations.RowMutationEntry(b"row_key", [mutation])
+ assert mutation.is_idempotent() is True
+ await table.bulk_mutate_rows([entry], operation_timeout=0.2)
+ assert len(e.value.exceptions) == 1
+ failed_exception = e.value.exceptions[0]
+ assert isinstance(failed_exception, FailedMutationEntryError)
+ assert "non-idempotent" not in str(failed_exception)
+ cause = failed_exception.__cause__
+ assert isinstance(cause, non_retryable_exception)
+
+ @pytest.mark.asyncio
+ async def test_bulk_mutate_error_index(self):
+ """
+ Test partial failure, partial success. Errors should be associated with the correct index
+ """
+ from google.api_core.exceptions import (
+ DeadlineExceeded,
+ ServiceUnavailable,
+ FailedPrecondition,
+ )
+ from google.cloud.bigtable.data.exceptions import (
+ RetryExceptionGroup,
+ FailedMutationEntryError,
+ MutationsExceptionGroup,
+ )
+
+ async with _make_client(project="project") as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "mutate_rows"
+ ) as mock_gapic:
+ # fail with retryable errors, then a non-retryable one
+ mock_gapic.side_effect = [
+ self._mock_response([None, ServiceUnavailable("mock"), None]),
+ self._mock_response([DeadlineExceeded("mock")]),
+ self._mock_response([FailedPrecondition("final")]),
+ ]
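+                    # only entry 1 fails in the first response, so each retry targets just that entry
+                    # until the non-retryable FailedPrecondition ends the retry loop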
+ with pytest.raises(MutationsExceptionGroup) as e:
+ mutation = mutations.SetCell(
+ "family", b"qualifier", b"value", timestamp_micros=123
+ )
+ entries = [
+ mutations.RowMutationEntry(
+ (f"row_key_{i}").encode(), [mutation]
+ )
+ for i in range(3)
+ ]
+ assert mutation.is_idempotent() is True
+ await table.bulk_mutate_rows(entries, operation_timeout=1000)
+ assert len(e.value.exceptions) == 1
+ failed = e.value.exceptions[0]
+ assert isinstance(failed, FailedMutationEntryError)
+ assert failed.index == 1
+ assert failed.entry == entries[1]
+ cause = failed.__cause__
+ assert isinstance(cause, RetryExceptionGroup)
+ assert len(cause.exceptions) == 3
+ assert isinstance(cause.exceptions[0], ServiceUnavailable)
+ assert isinstance(cause.exceptions[1], DeadlineExceeded)
+ assert isinstance(cause.exceptions[2], FailedPrecondition)
+
+ @pytest.mark.asyncio
+ async def test_bulk_mutate_error_recovery(self):
+ """
+ If an error occurs, then resolves, no exception should be raised
+ """
+ from google.api_core.exceptions import DeadlineExceeded
+
+ async with _make_client(project="project") as client:
+ table = client.get_table("instance", "table")
+ with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic:
+                # fail with a retryable error, then succeed on the retry
+ mock_gapic.side_effect = [
+ self._mock_response([DeadlineExceeded("mock")]),
+ self._mock_response([None]),
+ ]
+ mutation = mutations.SetCell(
+ "family", b"qualifier", b"value", timestamp_micros=123
+ )
+ entries = [
+ mutations.RowMutationEntry((f"row_key_{i}").encode(), [mutation])
+ for i in range(3)
+ ]
+ await table.bulk_mutate_rows(entries, operation_timeout=1000)
+
+
+class TestCheckAndMutateRow:
+ @pytest.mark.parametrize("gapic_result", [True, False])
+ @pytest.mark.asyncio
+ async def test_check_and_mutate(self, gapic_result):
+ from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse
+
+ app_profile = "app_profile_id"
+ async with _make_client() as client:
+ async with client.get_table(
+ "instance", "table", app_profile_id=app_profile
+ ) as table:
+ with mock.patch.object(
+ client._gapic_client, "check_and_mutate_row"
+ ) as mock_gapic:
+ mock_gapic.return_value = CheckAndMutateRowResponse(
+ predicate_matched=gapic_result
+ )
+ row_key = b"row_key"
+ predicate = None
+ true_mutations = [mock.Mock()]
+ false_mutations = [mock.Mock(), mock.Mock()]
+ operation_timeout = 0.2
+ found = await table.check_and_mutate_row(
+ row_key,
+ predicate,
+ true_case_mutations=true_mutations,
+ false_case_mutations=false_mutations,
+ operation_timeout=operation_timeout,
+ )
+ assert found == gapic_result
+ kwargs = mock_gapic.call_args[1]
+ assert kwargs["table_name"] == table.table_name
+ assert kwargs["row_key"] == row_key
+ assert kwargs["predicate_filter"] == predicate
+ assert kwargs["true_mutations"] == [
+ m._to_pb() for m in true_mutations
+ ]
+ assert kwargs["false_mutations"] == [
+ m._to_pb() for m in false_mutations
+ ]
+ assert kwargs["app_profile_id"] == app_profile
+ assert kwargs["timeout"] == operation_timeout
+ assert kwargs["retry"] is None
+
+ @pytest.mark.asyncio
+ async def test_check_and_mutate_bad_timeout(self):
+ """Should raise error if operation_timeout < 0"""
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with pytest.raises(ValueError) as e:
+ await table.check_and_mutate_row(
+ b"row_key",
+ None,
+ true_case_mutations=[mock.Mock()],
+ false_case_mutations=[],
+ operation_timeout=-1,
+ )
+ assert str(e.value) == "operation_timeout must be greater than 0"
+
+ @pytest.mark.asyncio
+ async def test_check_and_mutate_single_mutations(self):
+ """if single mutations are passed, they should be internally wrapped in a list"""
+ from google.cloud.bigtable.data.mutations import SetCell
+ from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse
+
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "check_and_mutate_row"
+ ) as mock_gapic:
+ mock_gapic.return_value = CheckAndMutateRowResponse(
+ predicate_matched=True
+ )
+ true_mutation = SetCell("family", b"qualifier", b"value")
+ false_mutation = SetCell("family", b"qualifier", b"value")
+ await table.check_and_mutate_row(
+ b"row_key",
+ None,
+ true_case_mutations=true_mutation,
+ false_case_mutations=false_mutation,
+ )
+ kwargs = mock_gapic.call_args[1]
+ assert kwargs["true_mutations"] == [true_mutation._to_pb()]
+ assert kwargs["false_mutations"] == [false_mutation._to_pb()]
+
+ @pytest.mark.asyncio
+ async def test_check_and_mutate_predicate_object(self):
+ """predicate filter should be passed to gapic request"""
+ from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse
+
+ mock_predicate = mock.Mock()
+ predicate_pb = {"predicate": "dict"}
+ mock_predicate._to_pb.return_value = predicate_pb
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "check_and_mutate_row"
+ ) as mock_gapic:
+ mock_gapic.return_value = CheckAndMutateRowResponse(
+ predicate_matched=True
+ )
+ await table.check_and_mutate_row(
+ b"row_key",
+ mock_predicate,
+ false_case_mutations=[mock.Mock()],
+ )
+ kwargs = mock_gapic.call_args[1]
+ assert kwargs["predicate_filter"] == predicate_pb
+ assert mock_predicate._to_pb.call_count == 1
+ assert kwargs["retry"] is None
+
+ @pytest.mark.asyncio
+ async def test_check_and_mutate_mutations_parsing(self):
+ """mutations objects should be converted to protos"""
+ from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse
+ from google.cloud.bigtable.data.mutations import DeleteAllFromRow
+
+ mutations = [mock.Mock() for _ in range(5)]
+ for idx, mutation in enumerate(mutations):
+ mutation._to_pb.return_value = f"fake {idx}"
+ mutations.append(DeleteAllFromRow())
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "check_and_mutate_row"
+ ) as mock_gapic:
+ mock_gapic.return_value = CheckAndMutateRowResponse(
+ predicate_matched=True
+ )
+ await table.check_and_mutate_row(
+ b"row_key",
+ None,
+ true_case_mutations=mutations[0:2],
+ false_case_mutations=mutations[2:],
+ )
+ kwargs = mock_gapic.call_args[1]
+ assert kwargs["true_mutations"] == ["fake 0", "fake 1"]
+ assert kwargs["false_mutations"] == [
+ "fake 2",
+ "fake 3",
+ "fake 4",
+ DeleteAllFromRow()._to_pb(),
+ ]
+ assert all(
+ mutation._to_pb.call_count == 1 for mutation in mutations[:5]
+ )
+
+
+class TestReadModifyWriteRow:
+ @pytest.mark.parametrize(
+ "call_rules,expected_rules",
+ [
+ (
+ AppendValueRule("f", "c", b"1"),
+ [AppendValueRule("f", "c", b"1")._to_pb()],
+ ),
+ (
+ [AppendValueRule("f", "c", b"1")],
+ [AppendValueRule("f", "c", b"1")._to_pb()],
+ ),
+ (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]),
+ (
+ [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)],
+ [
+ AppendValueRule("f", "c", b"1")._to_pb(),
+ IncrementRule("f", "c", 1)._to_pb(),
+ ],
+ ),
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules):
+ """
+ Test that the gapic call is called with given rules
+ """
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with mock.patch.object(
+ client._gapic_client, "read_modify_write_row"
+ ) as mock_gapic:
+ await table.read_modify_write_row("key", call_rules)
+ assert mock_gapic.call_count == 1
+ found_kwargs = mock_gapic.call_args_list[0][1]
+ assert found_kwargs["rules"] == expected_rules
+ assert found_kwargs["retry"] is None
+
+ @pytest.mark.parametrize("rules", [[], None])
+ @pytest.mark.asyncio
+ async def test_read_modify_write_no_rules(self, rules):
+ async with _make_client() as client:
+ async with client.get_table("instance", "table") as table:
+ with pytest.raises(ValueError) as e:
+ await table.read_modify_write_row("key", rules=rules)
+ assert e.value.args[0] == "rules must contain at least one item"
+
+ @pytest.mark.asyncio
+ async def test_read_modify_write_call_defaults(self):
+ instance = "instance1"
+ table_id = "table1"
+ project = "project1"
+ row_key = "row_key1"
+ async with _make_client(project=project) as client:
+ async with client.get_table(instance, table_id) as table:
+ with mock.patch.object(
+ client._gapic_client, "read_modify_write_row"
+ ) as mock_gapic:
+ await table.read_modify_write_row(row_key, mock.Mock())
+ assert mock_gapic.call_count == 1
+ kwargs = mock_gapic.call_args_list[0][1]
+ assert (
+ kwargs["table_name"]
+ == f"projects/{project}/instances/{instance}/tables/{table_id}"
+ )
+ assert kwargs["app_profile_id"] is None
+ assert kwargs["row_key"] == row_key.encode()
+ assert kwargs["timeout"] > 1
+
+ @pytest.mark.asyncio
+ async def test_read_modify_write_call_overrides(self):
+ row_key = b"row_key1"
+ expected_timeout = 12345
+ profile_id = "profile1"
+ async with _make_client() as client:
+ async with client.get_table(
+ "instance", "table_id", app_profile_id=profile_id
+ ) as table:
+ with mock.patch.object(
+ client._gapic_client, "read_modify_write_row"
+ ) as mock_gapic:
+ await table.read_modify_write_row(
+ row_key,
+ mock.Mock(),
+ operation_timeout=expected_timeout,
+ )
+ assert mock_gapic.call_count == 1
+ kwargs = mock_gapic.call_args_list[0][1]
+ assert kwargs["app_profile_id"] is profile_id
+ assert kwargs["row_key"] == row_key
+ assert kwargs["timeout"] == expected_timeout
+
+ @pytest.mark.asyncio
+ async def test_read_modify_write_string_key(self):
+ row_key = "string_row_key1"
+ async with _make_client() as client:
+ async with client.get_table("instance", "table_id") as table:
+ with mock.patch.object(
+ client._gapic_client, "read_modify_write_row"
+ ) as mock_gapic:
+ await table.read_modify_write_row(row_key, mock.Mock())
+ assert mock_gapic.call_count == 1
+ kwargs = mock_gapic.call_args_list[0][1]
+ assert kwargs["row_key"] == row_key.encode()
+
+ @pytest.mark.asyncio
+ async def test_read_modify_write_row_building(self):
+ """
+ results from gapic call should be used to construct row
+ """
+ from google.cloud.bigtable.data.row import Row
+ from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse
+ from google.cloud.bigtable_v2.types import Row as RowPB
+
+ mock_response = ReadModifyWriteRowResponse(row=RowPB())
+ async with _make_client() as client:
+ async with client.get_table("instance", "table_id") as table:
+ with mock.patch.object(
+ client._gapic_client, "read_modify_write_row"
+ ) as mock_gapic:
+ with mock.patch.object(Row, "_from_pb") as constructor_mock:
+ mock_gapic.return_value = mock_response
+ await table.read_modify_write_row("key", mock.Mock())
+ assert constructor_mock.call_count == 1
+ constructor_mock.assert_called_once_with(mock_response.row)
diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py
new file mode 100644
index 000000000..cca7c9824
--- /dev/null
+++ b/tests/unit/data/_async/test_mutations_batcher.py
@@ -0,0 +1,1184 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import asyncio
+import google.api_core.exceptions as core_exceptions
+from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
+from google.cloud.bigtable.data import TABLE_DEFAULT
+
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+ from unittest.mock import AsyncMock
+except ImportError: # pragma: NO COVER
+ import mock # type: ignore
+ from mock import AsyncMock # type: ignore
+
+
+def _make_mutation(count=1, size=1):
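+    # Build a mock stand-in for a RowMutationEntry with a configurable
+    # mutation count and byte size, so flow-control math can be exercised
+    # without constructing real mutation objects.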
+ mutation = mock.Mock()
+ mutation.size.return_value = size
+ mutation.mutations = [mock.Mock()] * count
+ return mutation
+
+
+class Test_FlowControl:
+ def _make_one(self, max_mutation_count=10, max_mutation_bytes=100):
+ from google.cloud.bigtable.data._async.mutations_batcher import (
+ _FlowControlAsync,
+ )
+
+ return _FlowControlAsync(max_mutation_count, max_mutation_bytes)
+
+ def test_ctor(self):
+ max_mutation_count = 9
+ max_mutation_bytes = 19
+ instance = self._make_one(max_mutation_count, max_mutation_bytes)
+ assert instance._max_mutation_count == max_mutation_count
+ assert instance._max_mutation_bytes == max_mutation_bytes
+ assert instance._in_flight_mutation_count == 0
+ assert instance._in_flight_mutation_bytes == 0
+ assert isinstance(instance._capacity_condition, asyncio.Condition)
+
+ def test_ctor_invalid_values(self):
+ """Test that values are positive, and fit within expected limits"""
+ with pytest.raises(ValueError) as e:
+ self._make_one(0, 1)
+ assert "max_mutation_count must be greater than 0" in str(e.value)
+ with pytest.raises(ValueError) as e:
+ self._make_one(1, 0)
+ assert "max_mutation_bytes must be greater than 0" in str(e.value)
+
+ @pytest.mark.parametrize(
+ "max_count,max_size,existing_count,existing_size,new_count,new_size,expected",
+ [
+ (1, 1, 0, 0, 0, 0, True),
+ (1, 1, 1, 1, 1, 1, False),
+ (10, 10, 0, 0, 0, 0, True),
+ (10, 10, 0, 0, 9, 9, True),
+ (10, 10, 0, 0, 11, 9, True),
+ (10, 10, 0, 1, 11, 9, True),
+ (10, 10, 1, 0, 11, 9, False),
+ (10, 10, 0, 0, 9, 11, True),
+ (10, 10, 1, 0, 9, 11, True),
+ (10, 10, 0, 1, 9, 11, False),
+ (10, 1, 0, 0, 1, 0, True),
+ (1, 10, 0, 0, 0, 8, True),
+ (float("inf"), float("inf"), 0, 0, 1e10, 1e10, True),
+ (8, 8, 0, 0, 1e10, 1e10, True),
+ (12, 12, 6, 6, 5, 5, True),
+ (12, 12, 5, 5, 6, 6, True),
+ (12, 12, 6, 6, 6, 6, True),
+ (12, 12, 6, 6, 7, 7, False),
+            # entries whose new_count or new_size exceed the limits on their own are still allowed
+ (12, 12, 0, 0, 13, 13, True),
+ (12, 12, 12, 0, 0, 13, True),
+ (12, 12, 0, 12, 13, 0, True),
+            # but not if there are already values in flight
+ (12, 12, 1, 1, 13, 13, False),
+ (12, 12, 1, 1, 0, 13, False),
+ (12, 12, 1, 1, 13, 0, False),
+ ],
+ )
+ def test__has_capacity(
+ self,
+ max_count,
+ max_size,
+ existing_count,
+ existing_size,
+ new_count,
+ new_size,
+ expected,
+ ):
+ """
+        _has_capacity should return True if the new mutation will not exceed the max count or size
+ """
+ instance = self._make_one(max_count, max_size)
+ instance._in_flight_mutation_count = existing_count
+ instance._in_flight_mutation_bytes = existing_size
+ assert instance._has_capacity(new_count, new_size) == expected
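+        # The cases above are consistent with a rule of roughly this shape
+        # (a sketch of the presumed check, not the actual implementation):
+        #   acceptable_count = max(max_mutation_count, new_count)
+        #   acceptable_size = max(max_mutation_bytes, new_size)
+        #   has_capacity = (in_flight_count + new_count <= acceptable_count
+        #                   and in_flight_bytes + new_size <= acceptable_size)
+        # so a single oversized entry can still be admitted when in-flight
+        # totals leave room for it.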
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "existing_count,existing_size,added_count,added_size,new_count,new_size",
+ [
+ (0, 0, 0, 0, 0, 0),
+ (2, 2, 1, 1, 1, 1),
+ (2, 0, 1, 0, 1, 0),
+ (0, 2, 0, 1, 0, 1),
+ (10, 10, 0, 0, 10, 10),
+ (10, 10, 5, 5, 5, 5),
+ (0, 0, 1, 1, -1, -1),
+ ],
+ )
+ async def test_remove_from_flow_value_update(
+ self,
+ existing_count,
+ existing_size,
+ added_count,
+ added_size,
+ new_count,
+ new_size,
+ ):
+ """
+        completed mutations should lower the in-flight values
+ """
+ instance = self._make_one()
+ instance._in_flight_mutation_count = existing_count
+ instance._in_flight_mutation_bytes = existing_size
+ mutation = _make_mutation(added_count, added_size)
+ await instance.remove_from_flow(mutation)
+ assert instance._in_flight_mutation_count == new_count
+ assert instance._in_flight_mutation_bytes == new_size
+
+ @pytest.mark.asyncio
+ async def test__remove_from_flow_unlock(self):
+ """capacity condition should notify after mutation is complete"""
+ instance = self._make_one(10, 10)
+ instance._in_flight_mutation_count = 10
+ instance._in_flight_mutation_bytes = 10
+
+ async def task_routine():
+ async with instance._capacity_condition:
+ await instance._capacity_condition.wait_for(
+ lambda: instance._has_capacity(1, 1)
+ )
+
+ task = asyncio.create_task(task_routine())
+ await asyncio.sleep(0.05)
+ # should be blocked due to capacity
+ assert task.done() is False
+ # try changing size
+ mutation = _make_mutation(count=0, size=5)
+ await instance.remove_from_flow([mutation])
+ await asyncio.sleep(0.05)
+ assert instance._in_flight_mutation_count == 10
+ assert instance._in_flight_mutation_bytes == 5
+ assert task.done() is False
+ # try changing count
+ instance._in_flight_mutation_bytes = 10
+ mutation = _make_mutation(count=5, size=0)
+ await instance.remove_from_flow([mutation])
+ await asyncio.sleep(0.05)
+ assert instance._in_flight_mutation_count == 5
+ assert instance._in_flight_mutation_bytes == 10
+ assert task.done() is False
+ # try changing both
+ instance._in_flight_mutation_count = 10
+ mutation = _make_mutation(count=5, size=5)
+ await instance.remove_from_flow([mutation])
+ await asyncio.sleep(0.05)
+ assert instance._in_flight_mutation_count == 5
+ assert instance._in_flight_mutation_bytes == 5
+ # task should be complete
+ assert task.done() is True
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mutations,count_cap,size_cap,expected_results",
+ [
+ # high capacity results in no batching
+ ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]),
+ # low capacity splits up into batches
+ ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]),
+ # test count as limiting factor
+ ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]),
+ # test size as limiting factor
+ ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]),
+            # test with some blockages and some flows
+ (
+ [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)],
+ 5,
+ 5,
+ [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]],
+ ),
+ ],
+ )
+ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results):
+ """
+ Test batching with various flow control settings
+ """
+ mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations]
+ instance = self._make_one(count_cap, size_cap)
+ i = 0
+ async for batch in instance.add_to_flow(mutation_objs):
+ expected_batch = expected_results[i]
+ assert len(batch) == len(expected_batch)
+ for j in range(len(expected_batch)):
+ # check counts
+ assert len(batch[j].mutations) == expected_batch[j][0]
+ # check sizes
+ assert batch[j].size() == expected_batch[j][1]
+ # update lock
+ await instance.remove_from_flow(batch)
+ i += 1
+ assert i == len(expected_results)
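+        # Usage pattern exercised here: consume add_to_flow() as an async
+        # generator of capacity-limited batches, and call remove_from_flow()
+        # after handling each batch so that any blocked batches can proceed.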
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mutations,max_limit,expected_results",
+ [
+ ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]),
+ ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]),
+ ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]),
+ ],
+ )
+ async def test_add_to_flow_max_mutation_limits(
+ self, mutations, max_limit, expected_results
+ ):
+ """
+        Test flow control running up against the max API limit.
+        Should submit requests early, even if the flow control has room for more
+ """
+ with mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT",
+ max_limit,
+ ):
+ mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations]
+ # flow control has no limits except API restrictions
+ instance = self._make_one(float("inf"), float("inf"))
+ i = 0
+ async for batch in instance.add_to_flow(mutation_objs):
+ expected_batch = expected_results[i]
+ assert len(batch) == len(expected_batch)
+ for j in range(len(expected_batch)):
+ # check counts
+ assert len(batch[j].mutations) == expected_batch[j][0]
+ # check sizes
+ assert batch[j].size() == expected_batch[j][1]
+ # update lock
+ await instance.remove_from_flow(batch)
+ i += 1
+ assert i == len(expected_results)
+
+ @pytest.mark.asyncio
+ async def test_add_to_flow_oversize(self):
+ """
+ mutations over the flow control limits should still be accepted
+ """
+ instance = self._make_one(2, 3)
+ large_size_mutation = _make_mutation(count=1, size=10)
+ large_count_mutation = _make_mutation(count=10, size=1)
+ results = [out async for out in instance.add_to_flow([large_size_mutation])]
+ assert len(results) == 1
+ await instance.remove_from_flow(results[0])
+ count_results = [
+ out async for out in instance.add_to_flow(large_count_mutation)
+ ]
+ assert len(count_results) == 1
+
+
+class TestMutationsBatcherAsync:
+ def _get_target_class(self):
+ from google.cloud.bigtable.data._async.mutations_batcher import (
+ MutationsBatcherAsync,
+ )
+
+ return MutationsBatcherAsync
+
+ def _make_one(self, table=None, **kwargs):
+ from google.api_core.exceptions import DeadlineExceeded
+ from google.api_core.exceptions import ServiceUnavailable
+
+ if table is None:
+ table = mock.Mock()
+ table.default_mutate_rows_operation_timeout = 10
+ table.default_mutate_rows_attempt_timeout = 10
+ table.default_mutate_rows_retryable_errors = (
+ DeadlineExceeded,
+ ServiceUnavailable,
+ )
+
+ return self._get_target_class()(table, **kwargs)
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer"
+ )
+ @pytest.mark.asyncio
+ async def test_ctor_defaults(self, flush_timer_mock):
+ flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0))
+ table = mock.Mock()
+ table.default_mutate_rows_operation_timeout = 10
+ table.default_mutate_rows_attempt_timeout = 8
+ table.default_mutate_rows_retryable_errors = [Exception]
+ async with self._make_one(table) as instance:
+ assert instance._table == table
+ assert instance.closed is False
+ assert instance._flush_jobs == set()
+ assert len(instance._staged_entries) == 0
+ assert len(instance._oldest_exceptions) == 0
+ assert len(instance._newest_exceptions) == 0
+ assert instance._exception_list_limit == 10
+ assert instance._exceptions_since_last_raise == 0
+ assert instance._flow_control._max_mutation_count == 100000
+ assert instance._flow_control._max_mutation_bytes == 104857600
+ assert instance._flow_control._in_flight_mutation_count == 0
+ assert instance._flow_control._in_flight_mutation_bytes == 0
+ assert instance._entries_processed_since_last_raise == 0
+ assert (
+ instance._operation_timeout
+ == table.default_mutate_rows_operation_timeout
+ )
+ assert (
+ instance._attempt_timeout == table.default_mutate_rows_attempt_timeout
+ )
+ assert (
+ instance._retryable_errors == table.default_mutate_rows_retryable_errors
+ )
+ await asyncio.sleep(0)
+ assert flush_timer_mock.call_count == 1
+ assert flush_timer_mock.call_args[0][0] == 5
+ assert isinstance(instance._flush_timer, asyncio.Future)
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer",
+ )
+ @pytest.mark.asyncio
+ async def test_ctor_explicit(self, flush_timer_mock):
+ """Test with explicit parameters"""
+ flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0))
+ table = mock.Mock()
+ flush_interval = 20
+ flush_limit_count = 17
+ flush_limit_bytes = 19
+ flow_control_max_mutation_count = 1001
+ flow_control_max_bytes = 12
+ operation_timeout = 11
+ attempt_timeout = 2
+ retryable_errors = [Exception]
+ async with self._make_one(
+ table,
+ flush_interval=flush_interval,
+ flush_limit_mutation_count=flush_limit_count,
+ flush_limit_bytes=flush_limit_bytes,
+ flow_control_max_mutation_count=flow_control_max_mutation_count,
+ flow_control_max_bytes=flow_control_max_bytes,
+ batch_operation_timeout=operation_timeout,
+ batch_attempt_timeout=attempt_timeout,
+ batch_retryable_errors=retryable_errors,
+ ) as instance:
+ assert instance._table == table
+ assert instance.closed is False
+ assert instance._flush_jobs == set()
+ assert len(instance._staged_entries) == 0
+ assert len(instance._oldest_exceptions) == 0
+ assert len(instance._newest_exceptions) == 0
+ assert instance._exception_list_limit == 10
+ assert instance._exceptions_since_last_raise == 0
+ assert (
+ instance._flow_control._max_mutation_count
+ == flow_control_max_mutation_count
+ )
+ assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes
+ assert instance._flow_control._in_flight_mutation_count == 0
+ assert instance._flow_control._in_flight_mutation_bytes == 0
+ assert instance._entries_processed_since_last_raise == 0
+ assert instance._operation_timeout == operation_timeout
+ assert instance._attempt_timeout == attempt_timeout
+ assert instance._retryable_errors == retryable_errors
+ await asyncio.sleep(0)
+ assert flush_timer_mock.call_count == 1
+ assert flush_timer_mock.call_args[0][0] == flush_interval
+ assert isinstance(instance._flush_timer, asyncio.Future)
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer"
+ )
+ @pytest.mark.asyncio
+ async def test_ctor_no_flush_limits(self, flush_timer_mock):
+ """Test with None for flush limits"""
+ flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0))
+ table = mock.Mock()
+ table.default_mutate_rows_operation_timeout = 10
+ table.default_mutate_rows_attempt_timeout = 8
+ table.default_mutate_rows_retryable_errors = ()
+ flush_interval = None
+ flush_limit_count = None
+ flush_limit_bytes = None
+ async with self._make_one(
+ table,
+ flush_interval=flush_interval,
+ flush_limit_mutation_count=flush_limit_count,
+ flush_limit_bytes=flush_limit_bytes,
+ ) as instance:
+ assert instance._table == table
+ assert instance.closed is False
+ assert instance._staged_entries == []
+ assert len(instance._oldest_exceptions) == 0
+ assert len(instance._newest_exceptions) == 0
+ assert instance._exception_list_limit == 10
+ assert instance._exceptions_since_last_raise == 0
+ assert instance._flow_control._in_flight_mutation_count == 0
+ assert instance._flow_control._in_flight_mutation_bytes == 0
+ assert instance._entries_processed_since_last_raise == 0
+ await asyncio.sleep(0)
+ assert flush_timer_mock.call_count == 1
+ assert flush_timer_mock.call_args[0][0] is None
+ assert isinstance(instance._flush_timer, asyncio.Future)
+
+ @pytest.mark.asyncio
+ async def test_ctor_invalid_values(self):
+ """Test that timeout values are positive, and fit within expected limits"""
+ with pytest.raises(ValueError) as e:
+ self._make_one(batch_operation_timeout=-1)
+ assert "operation_timeout must be greater than 0" in str(e.value)
+ with pytest.raises(ValueError) as e:
+ self._make_one(batch_attempt_timeout=-1)
+ assert "attempt_timeout must be greater than 0" in str(e.value)
+
+ def test_default_argument_consistency(self):
+ """
+ We supply default arguments in MutationsBatcherAsync.__init__, and in
+ table.mutations_batcher. Make sure any changes to defaults are applied to
+ both places
+ """
+ from google.cloud.bigtable.data._async.client import TableAsync
+ from google.cloud.bigtable.data._async.mutations_batcher import (
+ MutationsBatcherAsync,
+ )
+ import inspect
+
+ get_batcher_signature = dict(
+ inspect.signature(TableAsync.mutations_batcher).parameters
+ )
+ get_batcher_signature.pop("self")
+ batcher_init_signature = dict(
+ inspect.signature(MutationsBatcherAsync).parameters
+ )
+ batcher_init_signature.pop("table")
+ # both should have same number of arguments
+ assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys())
+ assert len(get_batcher_signature) == 8 # update if expected params change
+ # both should have same argument names
+ assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys())
+ # both should have same default values
+ for arg_name in get_batcher_signature.keys():
+ assert (
+ get_batcher_signature[arg_name].default
+ == batcher_init_signature[arg_name].default
+ )
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush"
+ )
+ @pytest.mark.asyncio
+ async def test__start_flush_timer_w_None(self, flush_mock):
+ """Empty timer should return immediately"""
+ async with self._make_one() as instance:
+ with mock.patch("asyncio.sleep") as sleep_mock:
+ await instance._start_flush_timer(None)
+ assert sleep_mock.call_count == 0
+ assert flush_mock.call_count == 0
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush"
+ )
+ @pytest.mark.asyncio
+ async def test__start_flush_timer_call_when_closed(self, flush_mock):
+ """closed batcher's timer should return immediately"""
+ async with self._make_one() as instance:
+ await instance.close()
+ flush_mock.reset_mock()
+ with mock.patch("asyncio.sleep") as sleep_mock:
+ await instance._start_flush_timer(1)
+ assert sleep_mock.call_count == 0
+ assert flush_mock.call_count == 0
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush"
+ )
+ @pytest.mark.asyncio
+ async def test__flush_timer(self, flush_mock):
+ """Timer should continue to call _schedule_flush in a loop"""
+ expected_sleep = 12
+ async with self._make_one(flush_interval=expected_sleep) as instance:
+ instance._staged_entries = [mock.Mock()]
+ loop_num = 3
+ with mock.patch("asyncio.sleep") as sleep_mock:
+ sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()]
+ try:
+ await instance._flush_timer
+ except asyncio.CancelledError:
+ pass
+ assert sleep_mock.call_count == loop_num + 1
+ sleep_mock.assert_called_with(expected_sleep)
+ assert flush_mock.call_count == loop_num
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush"
+ )
+ @pytest.mark.asyncio
+ async def test__flush_timer_no_mutations(self, flush_mock):
+ """Timer should not flush if no new mutations have been staged"""
+ expected_sleep = 12
+ async with self._make_one(flush_interval=expected_sleep) as instance:
+ loop_num = 3
+ with mock.patch("asyncio.sleep") as sleep_mock:
+ sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()]
+ try:
+ await instance._flush_timer
+ except asyncio.CancelledError:
+ pass
+ assert sleep_mock.call_count == loop_num + 1
+ sleep_mock.assert_called_with(expected_sleep)
+ assert flush_mock.call_count == 0
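+        # Taken together, the timer tests imply a loop of roughly this shape
+        # (a sketch of the assumed behavior, not the actual implementation):
+        #   while not closed:
+        #       await asyncio.sleep(flush_interval)
+        #       if staged_entries and not closed:
+        #           schedule_flush()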
+
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush"
+ )
+ @pytest.mark.asyncio
+ async def test__flush_timer_close(self, flush_mock):
+ """Timer should continue terminate after close"""
+ async with self._make_one() as instance:
+ with mock.patch("asyncio.sleep"):
+ # let task run in background
+ await asyncio.sleep(0.5)
+ assert instance._flush_timer.done() is False
+ # close the batcher
+ await instance.close()
+ await asyncio.sleep(0.1)
+ # task should be complete
+ assert instance._flush_timer.done() is True
+
+ @pytest.mark.asyncio
+ async def test_append_closed(self):
+ """Should raise exception"""
+ with pytest.raises(RuntimeError):
+ instance = self._make_one()
+ await instance.close()
+ await instance.append(mock.Mock())
+
+ @pytest.mark.asyncio
+ async def test_append_wrong_mutation(self):
+ """
+        Appending plain Mutation objects should raise an exception;
+        only RowMutationEntry objects are supported
+ """
+ from google.cloud.bigtable.data.mutations import DeleteAllFromRow
+
+ async with self._make_one() as instance:
+ expected_error = "invalid mutation type: DeleteAllFromRow. Only RowMutationEntry objects are supported by batcher"
+ with pytest.raises(ValueError) as e:
+ await instance.append(DeleteAllFromRow())
+ assert str(e.value) == expected_error
+
+ @pytest.mark.asyncio
+ async def test_append_outside_flow_limits(self):
+ """entries larger than mutation limits are still processed"""
+ async with self._make_one(
+ flow_control_max_mutation_count=1, flow_control_max_bytes=1
+ ) as instance:
+ oversized_entry = _make_mutation(count=0, size=2)
+ await instance.append(oversized_entry)
+ assert instance._staged_entries == [oversized_entry]
+ assert instance._staged_count == 0
+ assert instance._staged_bytes == 2
+ instance._staged_entries = []
+ async with self._make_one(
+ flow_control_max_mutation_count=1, flow_control_max_bytes=1
+ ) as instance:
+ overcount_entry = _make_mutation(count=2, size=0)
+ await instance.append(overcount_entry)
+ assert instance._staged_entries == [overcount_entry]
+ assert instance._staged_count == 2
+ assert instance._staged_bytes == 0
+ instance._staged_entries = []
+
+ @pytest.mark.asyncio
+ async def test_append_flush_runs_after_limit_hit(self):
+ """
+ If the user appends a bunch of entries above the flush limits back-to-back,
+ it should still flush in a single task
+ """
+ from google.cloud.bigtable.data._async.mutations_batcher import (
+ MutationsBatcherAsync,
+ )
+
+ with mock.patch.object(
+ MutationsBatcherAsync, "_execute_mutate_rows"
+ ) as op_mock:
+ async with self._make_one(flush_limit_bytes=100) as instance:
+ # mock network calls
+ async def mock_call(*args, **kwargs):
+ return []
+
+ op_mock.side_effect = mock_call
+ # append a mutation just under the size limit
+ await instance.append(_make_mutation(size=99))
+ # append a bunch of entries back-to-back in a loop
+ num_entries = 10
+ for _ in range(num_entries):
+ await instance.append(_make_mutation(size=1))
+ # let any flush jobs finish
+ await asyncio.gather(*instance._flush_jobs)
+ # should have only flushed once, with large mutation and first mutation in loop
+ assert op_mock.call_count == 1
+ sent_batch = op_mock.call_args[0][0]
+ assert len(sent_batch) == 2
+ # others should still be pending
+ assert len(instance._staged_entries) == num_entries - 1
+
+ @pytest.mark.parametrize(
+ "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush",
+ [
+ (10, 10, 1, 1, False),
+ (10, 10, 9, 9, False),
+ (10, 10, 10, 1, True),
+ (10, 10, 1, 10, True),
+ (10, 10, 10, 10, True),
+ (1, 1, 10, 10, True),
+ (1, 1, 0, 0, False),
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_append(
+ self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush
+ ):
+ """test appending different mutations, and checking if it causes a flush"""
+ async with self._make_one(
+ flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes
+ ) as instance:
+ assert instance._staged_count == 0
+ assert instance._staged_bytes == 0
+ assert instance._staged_entries == []
+ mutation = _make_mutation(count=mutation_count, size=mutation_bytes)
+ with mock.patch.object(instance, "_schedule_flush") as flush_mock:
+ await instance.append(mutation)
+ assert flush_mock.call_count == bool(expect_flush)
+ assert instance._staged_count == mutation_count
+ assert instance._staged_bytes == mutation_bytes
+ assert instance._staged_entries == [mutation]
+ instance._staged_entries = []
+
+ @pytest.mark.asyncio
+ async def test_append_multiple_sequentially(self):
+ """Append multiple mutations"""
+ async with self._make_one(
+ flush_limit_mutation_count=8, flush_limit_bytes=8
+ ) as instance:
+ assert instance._staged_count == 0
+ assert instance._staged_bytes == 0
+ assert instance._staged_entries == []
+ mutation = _make_mutation(count=2, size=3)
+ with mock.patch.object(instance, "_schedule_flush") as flush_mock:
+ await instance.append(mutation)
+ assert flush_mock.call_count == 0
+ assert instance._staged_count == 2
+ assert instance._staged_bytes == 3
+ assert len(instance._staged_entries) == 1
+ await instance.append(mutation)
+ assert flush_mock.call_count == 0
+ assert instance._staged_count == 4
+ assert instance._staged_bytes == 6
+ assert len(instance._staged_entries) == 2
+ await instance.append(mutation)
+ assert flush_mock.call_count == 1
+ assert instance._staged_count == 6
+ assert instance._staged_bytes == 9
+ assert len(instance._staged_entries) == 3
+ instance._staged_entries = []
+
+ @pytest.mark.asyncio
+ async def test_flush_flow_control_concurrent_requests(self):
+ """
+ requests should happen in parallel if flow control breaks up single flush into batches
+ """
+ import time
+
+ num_calls = 10
+ fake_mutations = [_make_mutation(count=1) for _ in range(num_calls)]
+ async with self._make_one(flow_control_max_mutation_count=1) as instance:
+ with mock.patch.object(
+ instance, "_execute_mutate_rows", AsyncMock()
+ ) as op_mock:
+ # mock network calls
+ async def mock_call(*args, **kwargs):
+ await asyncio.sleep(0.1)
+ return []
+
+ op_mock.side_effect = mock_call
+ start_time = time.monotonic()
+ # flush one large batch, that will be broken up into smaller batches
+ instance._staged_entries = fake_mutations
+ instance._schedule_flush()
+ await asyncio.sleep(0.01)
+ # make room for new mutations
+ for i in range(num_calls):
+ await instance._flow_control.remove_from_flow(
+ [_make_mutation(count=1)]
+ )
+ await asyncio.sleep(0.01)
+ # allow flushes to complete
+ await asyncio.gather(*instance._flush_jobs)
+ duration = time.monotonic() - start_time
+ assert len(instance._oldest_exceptions) == 0
+ assert len(instance._newest_exceptions) == 0
+ # if flushes were sequential, total duration would be 1s
+ assert duration < 0.5
+ assert op_mock.call_count == num_calls
+
+ @pytest.mark.asyncio
+ async def test_schedule_flush_no_mutations(self):
+ """schedule flush should return None if no staged mutations"""
+ async with self._make_one() as instance:
+ with mock.patch.object(instance, "_flush_internal") as flush_mock:
+ for i in range(3):
+ assert instance._schedule_flush() is None
+ assert flush_mock.call_count == 0
+
+ @pytest.mark.asyncio
+ async def test_schedule_flush_with_mutations(self):
+ """if new mutations exist, should add a new flush task to _flush_jobs"""
+ async with self._make_one() as instance:
+ with mock.patch.object(instance, "_flush_internal") as flush_mock:
+ for i in range(1, 4):
+ mutation = mock.Mock()
+ instance._staged_entries = [mutation]
+ instance._schedule_flush()
+ assert instance._staged_entries == []
+ # let flush task run
+ await asyncio.sleep(0)
+ assert instance._staged_entries == []
+ assert instance._staged_count == 0
+ assert instance._staged_bytes == 0
+ assert flush_mock.call_count == i
+
+ @pytest.mark.asyncio
+ async def test__flush_internal(self):
+ """
+ _flush_internal should:
+ - await previous flush call
+ - delegate batching to _flow_control
+ - call _execute_mutate_rows on each batch
+ - update self.exceptions and self._entries_processed_since_last_raise
+ """
+ num_entries = 10
+ async with self._make_one() as instance:
+ with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock:
+ with mock.patch.object(
+ instance._flow_control, "add_to_flow"
+ ) as flow_mock:
+ # mock flow control to always return a single batch
+ async def gen(x):
+ yield x
+
+ flow_mock.side_effect = lambda x: gen(x)
+ mutations = [_make_mutation(count=1, size=1)] * num_entries
+ await instance._flush_internal(mutations)
+ assert instance._entries_processed_since_last_raise == num_entries
+ assert execute_mock.call_count == 1
+ assert flow_mock.call_count == 1
+ instance._oldest_exceptions.clear()
+ instance._newest_exceptions.clear()
+
+ @pytest.mark.asyncio
+ async def test_flush_clears_job_list(self):
+ """
+ a job should be added to _flush_jobs when _schedule_flush is called,
+ and removed when it completes
+ """
+ async with self._make_one() as instance:
+ with mock.patch.object(instance, "_flush_internal", AsyncMock()):
+ mutations = [_make_mutation(count=1, size=1)]
+ instance._staged_entries = mutations
+ assert instance._flush_jobs == set()
+ new_job = instance._schedule_flush()
+ assert instance._flush_jobs == {new_job}
+ await new_job
+ assert instance._flush_jobs == set()
+
+ @pytest.mark.parametrize(
+ "num_starting,num_new_errors,expected_total_errors",
+ [
+ (0, 0, 0),
+ (0, 1, 1),
+ (0, 2, 2),
+ (1, 0, 1),
+ (1, 1, 2),
+ (10, 2, 12),
+ (10, 20, 20), # should cap at 20
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test__flush_internal_with_errors(
+ self, num_starting, num_new_errors, expected_total_errors
+ ):
+ """
+ errors returned from _execute_mutate_rows should be added to internal exceptions
+ """
+ from google.cloud.bigtable.data import exceptions
+
+ num_entries = 10
+ expected_errors = [
+ exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError())
+ ] * num_new_errors
+ async with self._make_one() as instance:
+ instance._oldest_exceptions = [mock.Mock()] * num_starting
+ with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock:
+ execute_mock.return_value = expected_errors
+ with mock.patch.object(
+ instance._flow_control, "add_to_flow"
+ ) as flow_mock:
+ # mock flow control to always return a single batch
+ async def gen(x):
+ yield x
+
+ flow_mock.side_effect = lambda x: gen(x)
+ mutations = [_make_mutation(count=1, size=1)] * num_entries
+ await instance._flush_internal(mutations)
+ assert instance._entries_processed_since_last_raise == num_entries
+ assert execute_mock.call_count == 1
+ assert flow_mock.call_count == 1
+ found_exceptions = instance._oldest_exceptions + list(
+ instance._newest_exceptions
+ )
+ assert len(found_exceptions) == expected_total_errors
+ for i in range(num_starting, expected_total_errors):
+ assert found_exceptions[i] == expected_errors[i - num_starting]
+ # errors should have index stripped
+ assert found_exceptions[i].index is None
+ # clear out exceptions
+ instance._oldest_exceptions.clear()
+ instance._newest_exceptions.clear()
+
+ async def _mock_gapic_return(self, num=5):
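+        # Helper: builds an async generator mimicking a successful streaming
+        # MutateRows response, with one OK (code 0) entry per requested index.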
+ from google.cloud.bigtable_v2.types import MutateRowsResponse
+ from google.rpc import status_pb2
+
+ async def gen(num):
+ for i in range(num):
+ entry = MutateRowsResponse.Entry(
+ index=i, status=status_pb2.Status(code=0)
+ )
+ yield MutateRowsResponse(entries=[entry])
+
+ return gen(num)
+
+ @pytest.mark.asyncio
+ async def test_timer_flush_end_to_end(self):
+ """Flush should automatically trigger after flush_interval"""
+        num_mutations = 10
+        mutations = [_make_mutation(count=2, size=2)] * num_mutations
+
+ async with self._make_one(flush_interval=0.05) as instance:
+ instance._table.default_operation_timeout = 10
+ instance._table.default_attempt_timeout = 9
+ with mock.patch.object(
+ instance._table.client._gapic_client, "mutate_rows"
+ ) as gapic_mock:
+ gapic_mock.side_effect = (
+                    lambda *args, **kwargs: self._mock_gapic_return(num_mutations)
+ )
+ for m in mutations:
+ await instance.append(m)
+ assert instance._entries_processed_since_last_raise == 0
+ # let flush trigger due to timer
+ await asyncio.sleep(0.1)
+                assert instance._entries_processed_since_last_raise == num_mutations
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync",
+ )
+ async def test__execute_mutate_rows(self, mutate_rows):
+ mutate_rows.return_value = AsyncMock()
+ start_operation = mutate_rows().start
+ table = mock.Mock()
+ table.table_name = "test-table"
+ table.app_profile_id = "test-app-profile"
+ table.default_mutate_rows_operation_timeout = 17
+ table.default_mutate_rows_attempt_timeout = 13
+ table.default_mutate_rows_retryable_errors = ()
+ async with self._make_one(table) as instance:
+ batch = [_make_mutation()]
+ result = await instance._execute_mutate_rows(batch)
+ assert start_operation.call_count == 1
+ args, kwargs = mutate_rows.call_args
+ assert args[0] == table.client._gapic_client
+ assert args[1] == table
+ assert args[2] == batch
+ kwargs["operation_timeout"] == 17
+ kwargs["attempt_timeout"] == 13
+ assert result == []
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync.start"
+ )
+ async def test__execute_mutate_rows_returns_errors(self, mutate_rows):
+ """Errors from operation should be retruned as list"""
+ from google.cloud.bigtable.data.exceptions import (
+ MutationsExceptionGroup,
+ FailedMutationEntryError,
+ )
+
+ err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error"))
+ err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error"))
+ mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10)
+ table = mock.Mock()
+ table.default_mutate_rows_operation_timeout = 17
+ table.default_mutate_rows_attempt_timeout = 13
+ table.default_mutate_rows_retryable_errors = ()
+ async with self._make_one(table) as instance:
+ batch = [_make_mutation()]
+ result = await instance._execute_mutate_rows(batch)
+ assert len(result) == 2
+ assert result[0] == err1
+ assert result[1] == err2
+ # indices should be set to None
+ assert result[0].index is None
+ assert result[1].index is None
+
+ @pytest.mark.asyncio
+ async def test__raise_exceptions(self):
+ """Raise exceptions and reset error state"""
+ from google.cloud.bigtable.data import exceptions
+
+ expected_total = 1201
+ expected_exceptions = [RuntimeError("mock")] * 3
+ async with self._make_one() as instance:
+ instance._oldest_exceptions = expected_exceptions
+ instance._entries_processed_since_last_raise = expected_total
+ try:
+ instance._raise_exceptions()
+ except exceptions.MutationsExceptionGroup as exc:
+ assert list(exc.exceptions) == expected_exceptions
+ assert str(expected_total) in str(exc)
+ assert instance._entries_processed_since_last_raise == 0
+ instance._oldest_exceptions, instance._newest_exceptions = ([], [])
+ # try calling again
+ instance._raise_exceptions()
+
+ @pytest.mark.asyncio
+ async def test___aenter__(self):
+ """Should return self"""
+ async with self._make_one() as instance:
+ assert await instance.__aenter__() == instance
+
+ @pytest.mark.asyncio
+ async def test___aexit__(self):
+ """aexit should call close"""
+ async with self._make_one() as instance:
+ with mock.patch.object(instance, "close") as close_mock:
+ await instance.__aexit__(None, None, None)
+ assert close_mock.call_count == 1
+
+ @pytest.mark.asyncio
+ async def test_close(self):
+ """Should clean up all resources"""
+ async with self._make_one() as instance:
+ with mock.patch.object(instance, "_schedule_flush") as flush_mock:
+ with mock.patch.object(instance, "_raise_exceptions") as raise_mock:
+ await instance.close()
+ assert instance.closed is True
+ assert instance._flush_timer.done() is True
+ assert instance._flush_jobs == set()
+ assert flush_mock.call_count == 1
+ assert raise_mock.call_count == 1
+
+ @pytest.mark.asyncio
+ async def test_close_w_exceptions(self):
+ """Raise exceptions on close"""
+ from google.cloud.bigtable.data import exceptions
+
+ expected_total = 10
+ expected_exceptions = [RuntimeError("mock")]
+ async with self._make_one() as instance:
+ instance._oldest_exceptions = expected_exceptions
+ instance._entries_processed_since_last_raise = expected_total
+ try:
+ await instance.close()
+ except exceptions.MutationsExceptionGroup as exc:
+ assert list(exc.exceptions) == expected_exceptions
+ assert str(expected_total) in str(exc)
+ assert instance._entries_processed_since_last_raise == 0
+ # clear out exceptions
+ instance._oldest_exceptions, instance._newest_exceptions = ([], [])
+
+ @pytest.mark.asyncio
+ async def test__on_exit(self, recwarn):
+ """Should raise warnings if unflushed mutations exist"""
+ async with self._make_one() as instance:
+ # calling without mutations is noop
+ instance._on_exit()
+ assert len(recwarn) == 0
+ # calling with existing mutations should raise warning
+ num_left = 4
+ instance._staged_entries = [mock.Mock()] * num_left
+ with pytest.warns(UserWarning) as w:
+ instance._on_exit()
+ assert len(w) == 1
+ assert "unflushed mutations" in str(w[0].message).lower()
+ assert str(num_left) in str(w[0].message)
+ # calling while closed is noop
+ instance.closed = True
+ instance._on_exit()
+ assert len(recwarn) == 0
+ # reset staged mutations for cleanup
+ instance._staged_entries = []
+
+ @pytest.mark.asyncio
+ async def test_atexit_registration(self):
+ """Should run _on_exit on program termination"""
+ import atexit
+
+ with mock.patch.object(atexit, "register") as register_mock:
+ assert register_mock.call_count == 0
+ async with self._make_one():
+ assert register_mock.call_count == 1
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync",
+ )
+ async def test_timeout_args_passed(self, mutate_rows):
+ """
+ batch_operation_timeout and batch_attempt_timeout should be used
+ in api calls
+ """
+ mutate_rows.return_value = AsyncMock()
+ expected_operation_timeout = 17
+ expected_attempt_timeout = 13
+ async with self._make_one(
+ batch_operation_timeout=expected_operation_timeout,
+ batch_attempt_timeout=expected_attempt_timeout,
+ ) as instance:
+ assert instance._operation_timeout == expected_operation_timeout
+ assert instance._attempt_timeout == expected_attempt_timeout
+ # make simulated gapic call
+ await instance._execute_mutate_rows([_make_mutation()])
+ assert mutate_rows.call_count == 1
+ kwargs = mutate_rows.call_args[1]
+ assert kwargs["operation_timeout"] == expected_operation_timeout
+ assert kwargs["attempt_timeout"] == expected_attempt_timeout
+
+ @pytest.mark.parametrize(
+ "limit,in_e,start_e,end_e",
+ [
+ (10, 0, (10, 0), (10, 0)),
+ (1, 10, (0, 0), (1, 1)),
+ (10, 1, (0, 0), (1, 0)),
+ (10, 10, (0, 0), (10, 0)),
+ (10, 11, (0, 0), (10, 1)),
+ (3, 20, (0, 0), (3, 3)),
+ (10, 20, (0, 0), (10, 10)),
+ (10, 21, (0, 0), (10, 10)),
+ (2, 1, (2, 0), (2, 1)),
+ (2, 1, (1, 0), (2, 0)),
+ (2, 2, (1, 0), (2, 1)),
+ (3, 1, (3, 1), (3, 2)),
+ (3, 3, (3, 1), (3, 3)),
+ (1000, 5, (999, 0), (1000, 4)),
+ (1000, 5, (0, 0), (5, 0)),
+ (1000, 5, (1000, 0), (1000, 5)),
+ ],
+ )
+ def test__add_exceptions(self, limit, in_e, start_e, end_e):
+ """
+ Test that the _add_exceptions function properly updates the
+ _oldest_exceptions and _newest_exceptions lists
+ Args:
+ - limit: the _exception_list_limit representing the max size of either list
+ - in_e: size of list of exceptions to send to _add_exceptions
+ - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions
+ - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions
+ """
+ from collections import deque
+
+ input_list = [RuntimeError(f"mock {i}") for i in range(in_e)]
+ mock_batcher = mock.Mock()
+ mock_batcher._oldest_exceptions = [
+ RuntimeError(f"starting mock {i}") for i in range(start_e[0])
+ ]
+ mock_batcher._newest_exceptions = deque(
+ [RuntimeError(f"starting mock {i}") for i in range(start_e[1])],
+ maxlen=limit,
+ )
+ mock_batcher._exception_list_limit = limit
+ mock_batcher._exceptions_since_last_raise = 0
+ self._get_target_class()._add_exceptions(mock_batcher, input_list)
+ assert len(mock_batcher._oldest_exceptions) == end_e[0]
+ assert len(mock_batcher._newest_exceptions) == end_e[1]
+ assert mock_batcher._exceptions_since_last_raise == in_e
+ # make sure that the right items ended up in the right spots
+ # should fill the oldest slots first
+ oldest_list_diff = end_e[0] - start_e[0]
+        # new items should be added on top of the starting list
+ newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit)
+ for i in range(oldest_list_diff):
+ assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i]
+ # then, the newest slots should be filled with the last items of the input list
+ for i in range(1, newest_list_diff + 1):
+ assert mock_batcher._newest_exceptions[-i] == input_list[-i]
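+        # Behavior checked above: _oldest_exceptions fills first, up to
+        # _exception_list_limit; once full, additional errors roll into
+        # _newest_exceptions, a bounded deque that keeps only the most recent
+        # `limit` entries, so both the earliest and the latest failures survive.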
+
+ @pytest.mark.asyncio
+ # test different inputs for retryable exceptions
+ @pytest.mark.parametrize(
+ "input_retryables,expected_retryables",
+ [
+ (
+ TABLE_DEFAULT.READ_ROWS,
+ [
+ core_exceptions.DeadlineExceeded,
+ core_exceptions.ServiceUnavailable,
+ core_exceptions.Aborted,
+ ],
+ ),
+ (
+ TABLE_DEFAULT.DEFAULT,
+ [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable],
+ ),
+ (
+ TABLE_DEFAULT.MUTATE_ROWS,
+ [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable],
+ ),
+ ([], []),
+ ([4], [core_exceptions.DeadlineExceeded]),
+ ],
+ )
+ async def test_customizable_retryable_errors(
+ self, input_retryables, expected_retryables
+ ):
+ """
+ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed
+ down to the gapic layer.
+ """
+ from google.cloud.bigtable.data._async.client import TableAsync
+
+ with mock.patch(
+ "google.api_core.retry.if_exception_type"
+ ) as predicate_builder_mock:
+ with mock.patch(
+ "google.api_core.retry.retry_target_async"
+ ) as retry_fn_mock:
+ table = None
+ with mock.patch("asyncio.create_task"):
+ table = TableAsync(mock.Mock(), "instance", "table")
+ async with self._make_one(
+ table, batch_retryable_errors=input_retryables
+ ) as instance:
+ assert instance._retryable_errors == expected_retryables
+ expected_predicate = lambda a: a in expected_retryables # noqa
+ predicate_builder_mock.return_value = expected_predicate
+ retry_fn_mock.side_effect = RuntimeError("stop early")
+ mutation = _make_mutation(count=1, size=1)
+ await instance._execute_mutate_rows([mutation])
+ # passed in errors should be used to build the predicate
+ predicate_builder_mock.assert_called_once_with(
+ *expected_retryables, _MutateRowsIncomplete
+ )
+ retry_call_args = retry_fn_mock.call_args_list[0].args
+ # output of if_exception_type should be sent in to retry constructor
+ assert retry_call_args[1] is expected_predicate
diff --git a/tests/unit/read-rows-acceptance-test.json b/tests/unit/data/read-rows-acceptance-test.json
similarity index 100%
rename from tests/unit/read-rows-acceptance-test.json
rename to tests/unit/data/read-rows-acceptance-test.json
diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py
new file mode 100644
index 000000000..5a9c500ed
--- /dev/null
+++ b/tests/unit/data/test__helpers.py
@@ -0,0 +1,248 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+import grpc
+from google.api_core import exceptions as core_exceptions
+import google.cloud.bigtable.data._helpers as _helpers
+from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
+
+import mock
+
+
+class TestMakeMetadata:
+ @pytest.mark.parametrize(
+ "table,profile,expected",
+ [
+ ("table", "profile", "table_name=table&app_profile_id=profile"),
+ ("table", None, "table_name=table"),
+ ],
+ )
+ def test__make_metadata(self, table, profile, expected):
+ metadata = _helpers._make_metadata(table, profile)
+ assert metadata == [("x-goog-request-params", expected)]
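+        # The metadata is expected as a single x-goog-request-params entry
+        # (the header used for request routing); app_profile_id is appended
+        # only when a profile is provided.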
+
+
+class TestAttemptTimeoutGenerator:
+ @pytest.mark.parametrize(
+ "request_t,operation_t,expected_list",
+ [
+ (1, 3.5, [1, 1, 1, 0.5, 0, 0]),
+ (None, 3.5, [3.5, 2.5, 1.5, 0.5, 0, 0]),
+ (10, 5, [5, 4, 3, 2, 1, 0, 0]),
+ (3, 3, [3, 2, 1, 0, 0, 0, 0]),
+ (0, 3, [0, 0, 0]),
+ (3, 0, [0, 0, 0]),
+ (-1, 3, [0, 0, 0]),
+ (3, -1, [0, 0, 0]),
+ ],
+ )
+ def test_attempt_timeout_generator(self, request_t, operation_t, expected_list):
+ """
+ test different values for timeouts. Clock is incremented by 1 second for each item in expected_list
+ """
+ timestamp_start = 123
+ with mock.patch("time.monotonic") as mock_monotonic:
+ mock_monotonic.return_value = timestamp_start
+ generator = _helpers._attempt_timeout_generator(request_t, operation_t)
+ for val in expected_list:
+ mock_monotonic.return_value += 1
+ assert next(generator) == val
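+        # The expected values match a generator of roughly this form, where a
+        # None attempt timeout falls back to the remaining operation time
+        # (presumed, not copied from the implementation):
+        #   remaining = (start + operation_timeout) - time.monotonic()
+        #   yield max(0, min(attempt_timeout, remaining))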
+
+ @pytest.mark.parametrize(
+ "request_t,operation_t,expected",
+ [
+ (1, 3.5, 1),
+ (None, 3.5, 3.5),
+ (10, 5, 5),
+ (5, 10, 5),
+ (3, 3, 3),
+ (0, 3, 0),
+ (3, 0, 0),
+ (-1, 3, 0),
+ (3, -1, 0),
+ ],
+ )
+ def test_attempt_timeout_frozen_time(self, request_t, operation_t, expected):
+ """test with time.monotonic frozen"""
+ timestamp_start = 123
+ with mock.patch("time.monotonic") as mock_monotonic:
+ mock_monotonic.return_value = timestamp_start
+ generator = _helpers._attempt_timeout_generator(request_t, operation_t)
+ assert next(generator) == expected
+ # value should not change without time.monotonic changing
+ assert next(generator) == expected
+
+ def test_attempt_timeout_w_sleeps(self):
+ """use real sleep values to make sure it matches expectations"""
+ from time import sleep
+
+ operation_timeout = 1
+ generator = _helpers._attempt_timeout_generator(None, operation_timeout)
+ expected_value = operation_timeout
+ sleep_time = 0.1
+ for i in range(3):
+ found_value = next(generator)
+ assert abs(found_value - expected_value) < 0.001
+ sleep(sleep_time)
+ expected_value -= sleep_time
+
+
+class TestValidateTimeouts:
+ def test_validate_timeouts_error_messages(self):
+ with pytest.raises(ValueError) as e:
+ _helpers._validate_timeouts(operation_timeout=1, attempt_timeout=-1)
+ assert "attempt_timeout must be greater than 0" in str(e.value)
+ with pytest.raises(ValueError) as e:
+ _helpers._validate_timeouts(operation_timeout=-1, attempt_timeout=1)
+ assert "operation_timeout must be greater than 0" in str(e.value)
+
+ @pytest.mark.parametrize(
+ "args,expected",
+ [
+ ([1, None, False], False),
+ ([1, None, True], True),
+ ([1, 1, False], True),
+ ([1, 1, True], True),
+ ([1, 1], True),
+ ([1, None], False),
+ ([2, 1], True),
+ ([0, 1], False),
+ ([1, 0], False),
+ ([60, None], False),
+ ([600, None], False),
+ ([600, 600], True),
+ ],
+ )
+ def test_validate_with_inputs(self, args, expected):
+ """
+ test whether an exception is thrown with different inputs
+ """
+ success = False
+ try:
+ _helpers._validate_timeouts(*args)
+ success = True
+ except ValueError:
+ pass
+ assert success == expected
+
+
+class TestGetTimeouts:
+ @pytest.mark.parametrize(
+ "input_times,input_table,expected",
+ [
+ ((2, 1), {}, (2, 1)),
+ ((2, 4), {}, (2, 2)),
+ ((2, None), {}, (2, 2)),
+ (
+ (TABLE_DEFAULT.DEFAULT, TABLE_DEFAULT.DEFAULT),
+ {"operation": 3, "attempt": 2},
+ (3, 2),
+ ),
+ (
+ (TABLE_DEFAULT.READ_ROWS, TABLE_DEFAULT.READ_ROWS),
+ {"read_rows_operation": 3, "read_rows_attempt": 2},
+ (3, 2),
+ ),
+ (
+ (TABLE_DEFAULT.MUTATE_ROWS, TABLE_DEFAULT.MUTATE_ROWS),
+ {"mutate_rows_operation": 3, "mutate_rows_attempt": 2},
+ (3, 2),
+ ),
+ ((10, TABLE_DEFAULT.DEFAULT), {"attempt": None}, (10, 10)),
+ ((10, TABLE_DEFAULT.DEFAULT), {"attempt": 5}, (10, 5)),
+ ((10, TABLE_DEFAULT.DEFAULT), {"attempt": 100}, (10, 10)),
+ ((TABLE_DEFAULT.DEFAULT, 10), {"operation": 12}, (12, 10)),
+ ((TABLE_DEFAULT.DEFAULT, 10), {"operation": 3}, (3, 3)),
+ ],
+ )
+ def test_get_timeouts(self, input_times, input_table, expected):
+ """
+ test input/output mappings for a variety of valid inputs
+ """
+ fake_table = mock.Mock()
+ for key in input_table.keys():
+ # set the default fields in our fake table mock
+ setattr(fake_table, f"default_{key}_timeout", input_table[key])
+ t1, t2 = _helpers._get_timeouts(input_times[0], input_times[1], fake_table)
+ assert t1 == expected[0]
+ assert t2 == expected[1]
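+        # Clamping behavior exercised above: a missing or oversized attempt
+        # timeout is capped at the operation timeout, and TABLE_DEFAULT
+        # sentinels are resolved from the table's default_* attributes.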
+
+ @pytest.mark.parametrize(
+ "input_times,input_table",
+ [
+ ([0, 1], {}),
+ ([1, 0], {}),
+ ([None, 1], {}),
+ ([TABLE_DEFAULT.DEFAULT, 1], {"operation": None}),
+ ([TABLE_DEFAULT.DEFAULT, 1], {"operation": 0}),
+ ([1, TABLE_DEFAULT.DEFAULT], {"attempt": 0}),
+ ],
+ )
+ def test_get_timeouts_invalid(self, input_times, input_table):
+ """
+        test with inputs that should raise an error during the validation step
+ """
+ fake_table = mock.Mock()
+ for key in input_table.keys():
+ # set the default fields in our fake table mock
+ setattr(fake_table, f"default_{key}_timeout", input_table[key])
+ with pytest.raises(ValueError):
+ _helpers._get_timeouts(input_times[0], input_times[1], fake_table)
+
+
+class TestGetRetryableErrors:
+ @pytest.mark.parametrize(
+ "input_codes,input_table,expected",
+ [
+ ((), {}, []),
+ ((Exception,), {}, [Exception]),
+ (TABLE_DEFAULT.DEFAULT, {"default": [Exception]}, [Exception]),
+ (
+ TABLE_DEFAULT.READ_ROWS,
+ {"default_read_rows": (RuntimeError, ValueError)},
+ [RuntimeError, ValueError],
+ ),
+ (
+ TABLE_DEFAULT.MUTATE_ROWS,
+ {"default_mutate_rows": (ValueError,)},
+ [ValueError],
+ ),
+ ((4,), {}, [core_exceptions.DeadlineExceeded]),
+ (
+ [grpc.StatusCode.DEADLINE_EXCEEDED],
+ {},
+ [core_exceptions.DeadlineExceeded],
+ ),
+ (
+ (14, grpc.StatusCode.ABORTED, RuntimeError),
+ {},
+ [
+ core_exceptions.ServiceUnavailable,
+ core_exceptions.Aborted,
+ RuntimeError,
+ ],
+ ),
+ ],
+ )
+ def test_get_retryable_errors(self, input_codes, input_table, expected):
+ """
+ test input/output mappings for a variety of valid inputs
+ """
+ fake_table = mock.Mock()
+ for key in input_table.keys():
+ # set the default fields in our fake table mock
+ setattr(fake_table, f"{key}_retryable_errors", input_table[key])
+ result = _helpers._get_retryable_errors(input_codes, fake_table)
+ assert result == expected
diff --git a/tests/unit/data/test_exceptions.py b/tests/unit/data/test_exceptions.py
new file mode 100644
index 000000000..bc921717e
--- /dev/null
+++ b/tests/unit/data/test_exceptions.py
@@ -0,0 +1,533 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import pytest
+import sys
+
+import google.cloud.bigtable.data.exceptions as bigtable_exceptions
+
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+except ImportError: # pragma: NO COVER
+ import mock # type: ignore
+
+
+class TracebackTests311:
+ """
+ Provides a set of tests that should be run on python 3.11 and above,
+ to verify that the exception traceback looks as expected
+ """
+
+ @pytest.mark.skipif(
+ sys.version_info < (3, 11), reason="requires python3.11 or higher"
+ )
+ def test_311_traceback(self):
+ """
+ Exception customizations should not break rich exception group traceback in python 3.11
+ """
+ import traceback
+
+ sub_exc1 = RuntimeError("first sub exception")
+ sub_exc2 = ZeroDivisionError("second sub exception")
+ sub_group = self._make_one(excs=[sub_exc2])
+ exc_group = self._make_one(excs=[sub_exc1, sub_group])
+
+ expected_traceback = (
+ f" | google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {str(exc_group)}",
+ " +-+---------------- 1 ----------------",
+ " | RuntimeError: first sub exception",
+ " +---------------- 2 ----------------",
+ f" | google.cloud.bigtable.data.exceptions.{type(sub_group).__name__}: {str(sub_group)}",
+ " +-+---------------- 1 ----------------",
+ " | ZeroDivisionError: second sub exception",
+ " +------------------------------------",
+ )
+ exception_caught = False
+ try:
+ raise exc_group
+ except self._get_class():
+ exception_caught = True
+ tb = traceback.format_exc()
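+            # drop the "Traceback (most recent call last)" header and raise-site frames,
+            # keeping only the rendered exception group for comparison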
+ tb_relevant_lines = tuple(tb.splitlines()[3:])
+ assert expected_traceback == tb_relevant_lines
+ assert exception_caught
+
+ @pytest.mark.skipif(
+ sys.version_info < (3, 11), reason="requires python3.11 or higher"
+ )
+ def test_311_traceback_with_cause(self):
+ """
+ traceback should display nicely with sub-exceptions with __cause__ set
+ """
+ import traceback
+
+ sub_exc1 = RuntimeError("first sub exception")
+ cause_exc = ImportError("cause exception")
+ sub_exc1.__cause__ = cause_exc
+ sub_exc2 = ZeroDivisionError("second sub exception")
+ exc_group = self._make_one(excs=[sub_exc1, sub_exc2])
+
+ expected_traceback = (
+ f" | google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {str(exc_group)}",
+ " +-+---------------- 1 ----------------",
+ " | ImportError: cause exception",
+ " | ",
+ " | The above exception was the direct cause of the following exception:",
+ " | ",
+ " | RuntimeError: first sub exception",
+ " +---------------- 2 ----------------",
+ " | ZeroDivisionError: second sub exception",
+ " +------------------------------------",
+ )
+ exception_caught = False
+ try:
+ raise exc_group
+ except self._get_class():
+ exception_caught = True
+ tb = traceback.format_exc()
+ tb_relevant_lines = tuple(tb.splitlines()[3:])
+ assert expected_traceback == tb_relevant_lines
+ assert exception_caught
+
+ @pytest.mark.skipif(
+ sys.version_info < (3, 11), reason="requires python3.11 or higher"
+ )
+ def test_311_exception_group(self):
+ """
+        Python 3.11+ should handle exceptions as native exception groups
+ """
+ exceptions = [RuntimeError("mock"), ValueError("mock")]
+ instance = self._make_one(excs=exceptions)
+ # ensure split works as expected
+ runtime_error, others = instance.split(lambda e: isinstance(e, RuntimeError))
+ assert runtime_error.exceptions[0] == exceptions[0]
+ assert others.exceptions[0] == exceptions[1]
+
+
+class TracebackTests310:
+ """
+ Provides a set of tests that should be run on python 3.10 and under,
+ to verify that the exception traceback looks as expected
+ """
+
+ @pytest.mark.skipif(
+ sys.version_info >= (3, 11), reason="requires python3.10 or lower"
+ )
+ def test_310_traceback(self):
+ """
+ Exception customizations should not break rich exception group traceback in python 3.10
+ """
+ import traceback
+
+ sub_exc1 = RuntimeError("first sub exception")
+ sub_exc2 = ZeroDivisionError("second sub exception")
+ sub_group = self._make_one(excs=[sub_exc2])
+ exc_group = self._make_one(excs=[sub_exc1, sub_group])
+ found_message = str(exc_group).splitlines()[0]
+ found_sub_message = str(sub_group).splitlines()[0]
+
+ expected_traceback = (
+ f"google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {found_message}",
+ "--+---------------- 1 ----------------",
+ " | RuntimeError: first sub exception",
+ " +---------------- 2 ----------------",
+ f" | {type(sub_group).__name__}: {found_sub_message}",
+ " --+---------------- 1 ----------------",
+ " | ZeroDivisionError: second sub exception",
+ " +------------------------------------",
+ )
+ exception_caught = False
+ try:
+ raise exc_group
+ except self._get_class():
+ exception_caught = True
+ tb = traceback.format_exc()
+ tb_relevant_lines = tuple(tb.splitlines()[3:])
+ assert expected_traceback == tb_relevant_lines
+ assert exception_caught
+
+ @pytest.mark.skipif(
+ sys.version_info >= (3, 11), reason="requires python3.10 or lower"
+ )
+ def test_310_traceback_with_cause(self):
+ """
+ traceback should display nicely with sub-exceptions with __cause__ set
+ """
+ import traceback
+
+ sub_exc1 = RuntimeError("first sub exception")
+ cause_exc = ImportError("cause exception")
+ sub_exc1.__cause__ = cause_exc
+ sub_exc2 = ZeroDivisionError("second sub exception")
+ exc_group = self._make_one(excs=[sub_exc1, sub_exc2])
+ found_message = str(exc_group).splitlines()[0]
+
+ expected_traceback = (
+ f"google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {found_message}",
+ "--+---------------- 1 ----------------",
+ " | ImportError: cause exception",
+ " | ",
+ " | The above exception was the direct cause of the following exception:",
+ " | ",
+ " | RuntimeError: first sub exception",
+ " +---------------- 2 ----------------",
+ " | ZeroDivisionError: second sub exception",
+ " +------------------------------------",
+ )
+ exception_caught = False
+ try:
+ raise exc_group
+ except self._get_class():
+ exception_caught = True
+ tb = traceback.format_exc()
+ tb_relevant_lines = tuple(tb.splitlines()[3:])
+ assert expected_traceback == tb_relevant_lines
+ assert exception_caught
+
+
+class TestBigtableExceptionGroup(TracebackTests311, TracebackTests310):
+ """
+    Base class shared by the tests for MutationsExceptionGroup, RetryExceptionGroup, and ShardedReadRowsExceptionGroup
+ """
+
+ def _get_class(self):
+ from google.cloud.bigtable.data.exceptions import _BigtableExceptionGroup
+
+ return _BigtableExceptionGroup
+
+ def _make_one(self, message="test_message", excs=None):
+ if excs is None:
+ excs = [RuntimeError("mock")]
+
+ return self._get_class()(message, excs=excs)
+
+ def test_raise(self):
+ """
+ Create exception in raise statement, which calls __new__ and __init__
+ """
+ test_msg = "test message"
+ test_excs = [Exception(test_msg)]
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(test_msg, test_excs)
+ found_message = str(e.value).splitlines()[
+ 0
+        ] # added to parse out subexceptions in <3.11
+ assert found_message == test_msg
+ assert list(e.value.exceptions) == test_excs
+
+ def test_raise_empty_list(self):
+ """
+ Empty exception lists are not supported
+ """
+ with pytest.raises(ValueError) as e:
+ raise self._make_one(excs=[])
+ assert "non-empty sequence" in str(e.value)
+
+ def test_exception_handling(self):
+ """
+ All versions should inherit from exception
+        and support traditional exception handling
+ """
+        instance = self._make_one()
+        assert isinstance(instance, Exception)
+        was_raised = False
+        try:
+            raise instance
+        except Exception as e:
+            assert isinstance(e, Exception)
+            assert e == instance
+            was_raised = True
+        assert was_raised
+
+
+class TestMutationsExceptionGroup(TestBigtableExceptionGroup):
+ def _get_class(self):
+ from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
+
+ return MutationsExceptionGroup
+
+ def _make_one(self, excs=None, num_entries=3):
+ if excs is None:
+ excs = [RuntimeError("mock")]
+
+ return self._get_class()(excs, num_entries)
+
+ @pytest.mark.parametrize(
+ "exception_list,total_entries,expected_message",
+ [
+ ([Exception()], 1, "1 failed entry from 1 attempted."),
+ ([Exception()], 2, "1 failed entry from 2 attempted."),
+ (
+ [Exception(), RuntimeError()],
+ 2,
+ "2 failed entries from 2 attempted.",
+ ),
+ ],
+ )
+ def test_raise(self, exception_list, total_entries, expected_message):
+ """
+ Create exception in raise statement, which calls __new__ and __init__
+ """
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(exception_list, total_entries)
+ found_message = str(e.value).splitlines()[
+ 0
+        ] # added to parse out subexceptions in <3.11
+ assert found_message == expected_message
+ assert list(e.value.exceptions) == exception_list
+
+ def test_raise_custom_message(self):
+ """
+ should be able to set a custom error message
+ """
+ custom_message = "custom message"
+ exception_list = [Exception()]
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(exception_list, 5, message=custom_message)
+ found_message = str(e.value).splitlines()[
+ 0
+        ] # added to parse out subexceptions in <3.11
+ assert found_message == custom_message
+ assert list(e.value.exceptions) == exception_list
+
+ @pytest.mark.parametrize(
+ "first_list_len,second_list_len,total_excs,entry_count,expected_message",
+ [
+ (3, 0, 3, 4, "3 failed entries from 4 attempted."),
+ (1, 0, 1, 2, "1 failed entry from 2 attempted."),
+ (0, 1, 1, 2, "1 failed entry from 2 attempted."),
+ (2, 2, 4, 4, "4 failed entries from 4 attempted."),
+ (
+ 1,
+ 1,
+ 3,
+ 2,
+ "3 failed entries from 2 attempted. (first 1 and last 1 attached as sub-exceptions; 1 truncated)",
+ ),
+ (
+ 1,
+ 2,
+ 100,
+ 2,
+ "100 failed entries from 2 attempted. (first 1 and last 2 attached as sub-exceptions; 97 truncated)",
+ ),
+ (
+ 2,
+ 1,
+ 4,
+ 9,
+ "4 failed entries from 9 attempted. (first 2 and last 1 attached as sub-exceptions; 1 truncated)",
+ ),
+ (
+ 3,
+ 0,
+ 10,
+ 10,
+ "10 failed entries from 10 attempted. (first 3 attached as sub-exceptions; 7 truncated)",
+ ),
+ (
+ 0,
+ 3,
+ 10,
+ 10,
+ "10 failed entries from 10 attempted. (last 3 attached as sub-exceptions; 7 truncated)",
+ ),
+ ],
+ )
+ def test_from_truncated_lists(
+ self, first_list_len, second_list_len, total_excs, entry_count, expected_message
+ ):
+ """
+ Should be able to make MutationsExceptionGroup using a pair of
+ lists representing a larger truncated list of exceptions
+ """
+ first_list = [Exception()] * first_list_len
+ second_list = [Exception()] * second_list_len
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class().from_truncated_lists(
+ first_list, second_list, total_excs, entry_count
+ )
+ found_message = str(e.value).splitlines()[
+ 0
+        ] # added to parse out subexceptions in <3.11
+ assert found_message == expected_message
+ assert list(e.value.exceptions) == first_list + second_list
+
+
+class TestRetryExceptionGroup(TestBigtableExceptionGroup):
+ def _get_class(self):
+ from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
+
+ return RetryExceptionGroup
+
+ def _make_one(self, excs=None):
+ if excs is None:
+ excs = [RuntimeError("mock")]
+
+ return self._get_class()(excs=excs)
+
+ @pytest.mark.parametrize(
+ "exception_list,expected_message",
+ [
+ ([Exception()], "1 failed attempt"),
+ ([Exception(), RuntimeError()], "2 failed attempts"),
+ (
+ [Exception(), ValueError("test")],
+ "2 failed attempts",
+ ),
+ (
+ [
+ bigtable_exceptions.RetryExceptionGroup(
+ [Exception(), ValueError("test")]
+ )
+ ],
+ "1 failed attempt",
+ ),
+ ],
+ )
+ def test_raise(self, exception_list, expected_message):
+ """
+ Create exception in raise statement, which calls __new__ and __init__
+ """
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(exception_list)
+ found_message = str(e.value).splitlines()[
+ 0
+        ] # added to parse out subexceptions in <3.11
+ assert found_message == expected_message
+ assert list(e.value.exceptions) == exception_list
+
+
+class TestShardedReadRowsExceptionGroup(TestBigtableExceptionGroup):
+ def _get_class(self):
+ from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+
+ return ShardedReadRowsExceptionGroup
+
+ def _make_one(self, excs=None, succeeded=None, num_entries=3):
+ if excs is None:
+ excs = [RuntimeError("mock")]
+ succeeded = succeeded or []
+
+ return self._get_class()(excs, succeeded, num_entries)
+
+ @pytest.mark.parametrize(
+ "exception_list,succeeded,total_entries,expected_message",
+ [
+ ([Exception()], [], 1, "1 sub-exception (from 1 query attempted)"),
+ ([Exception()], [1], 2, "1 sub-exception (from 2 queries attempted)"),
+ (
+ [Exception(), RuntimeError()],
+ [0, 1],
+ 2,
+ "2 sub-exceptions (from 2 queries attempted)",
+ ),
+ ],
+ )
+ def test_raise(self, exception_list, succeeded, total_entries, expected_message):
+ """
+ Create exception in raise statement, which calls __new__ and __init__
+ """
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(exception_list, succeeded, total_entries)
+ found_message = str(e.value).splitlines()[
+ 0
+        ] # added to parse out subexceptions in <3.11
+ assert found_message == expected_message
+ assert list(e.value.exceptions) == exception_list
+ assert e.value.successful_rows == succeeded
+
+
+class TestFailedMutationEntryError:
+ def _get_class(self):
+ from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
+
+ return FailedMutationEntryError
+
+ def _make_one(self, idx=9, entry=mock.Mock(), cause=RuntimeError("mock")):
+ return self._get_class()(idx, entry, cause)
+
+ def test_raise(self):
+ """
+ Create exception in raise statement, which calls __new__ and __init__
+ """
+ test_idx = 2
+ test_entry = mock.Mock()
+ test_exc = ValueError("test")
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(test_idx, test_entry, test_exc)
+ assert str(e.value) == "Failed idempotent mutation entry at index 2"
+ assert e.value.index == test_idx
+ assert e.value.entry == test_entry
+ assert e.value.__cause__ == test_exc
+ assert isinstance(e.value, Exception)
+ assert test_entry.is_idempotent.call_count == 1
+
+ def test_raise_idempotent(self):
+ """
+ Test raise with non idempotent entry
+ """
+ test_idx = 2
+ test_entry = unittest.mock.Mock()
+ test_entry.is_idempotent.return_value = False
+ test_exc = ValueError("test")
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(test_idx, test_entry, test_exc)
+ assert str(e.value) == "Failed non-idempotent mutation entry at index 2"
+ assert e.value.index == test_idx
+ assert e.value.entry == test_entry
+ assert e.value.__cause__ == test_exc
+ assert test_entry.is_idempotent.call_count == 1
+
+ def test_no_index(self):
+ """
+ Instances without an index should display different error string
+ """
+ test_idx = None
+ test_entry = unittest.mock.Mock()
+ test_exc = ValueError("test")
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(test_idx, test_entry, test_exc)
+ assert str(e.value) == "Failed idempotent mutation entry"
+ assert e.value.index == test_idx
+ assert e.value.entry == test_entry
+ assert e.value.__cause__ == test_exc
+ assert isinstance(e.value, Exception)
+ assert test_entry.is_idempotent.call_count == 1
+
+
+class TestFailedQueryShardError:
+ def _get_class(self):
+ from google.cloud.bigtable.data.exceptions import FailedQueryShardError
+
+ return FailedQueryShardError
+
+ def _make_one(self, idx=9, query=mock.Mock(), cause=RuntimeError("mock")):
+ return self._get_class()(idx, query, cause)
+
+ def test_raise(self):
+ """
+ Create exception in raise statement, which calls __new__ and __init__
+ """
+ test_idx = 2
+ test_query = mock.Mock()
+ test_exc = ValueError("test")
+ with pytest.raises(self._get_class()) as e:
+ raise self._get_class()(test_idx, test_query, test_exc)
+ assert str(e.value) == "Failed query at index 2"
+ assert e.value.index == test_idx
+ assert e.value.query == test_query
+ assert e.value.__cause__ == test_exc
+ assert isinstance(e.value, Exception)
diff --git a/tests/unit/data/test_mutations.py b/tests/unit/data/test_mutations.py
new file mode 100644
index 000000000..485c86e42
--- /dev/null
+++ b/tests/unit/data/test_mutations.py
@@ -0,0 +1,708 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+import google.cloud.bigtable.data.mutations as mutations
+
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+except ImportError: # pragma: NO COVER
+ import mock # type: ignore
+
+
+class TestBaseMutation:
+ def _target_class(self):
+ from google.cloud.bigtable.data.mutations import Mutation
+
+ return Mutation
+
+ def test__to_dict(self):
+ """Should be unimplemented in the base class"""
+ with pytest.raises(NotImplementedError):
+ self._target_class()._to_dict(mock.Mock())
+
+ def test_is_idempotent(self):
+ """is_idempotent should assume True"""
+ assert self._target_class().is_idempotent(mock.Mock())
+
+ def test___str__(self):
+ """Str representation of mutations should be to_dict"""
+ self_mock = mock.Mock()
+ str_value = self._target_class().__str__(self_mock)
+ assert self_mock._to_dict.called
+ assert str_value == str(self_mock._to_dict.return_value)
+
+ @pytest.mark.parametrize("test_dict", [{}, {"key": "value"}])
+    def test_size(self, test_dict):
+        """Size should return size of dict representation"""
+        from sys import getsizeof
+
+ self_mock = mock.Mock()
+ self_mock._to_dict.return_value = test_dict
+ size_value = self._target_class().size(self_mock)
+ assert size_value == getsizeof(test_dict)
+
+ @pytest.mark.parametrize(
+ "expected_class,input_dict",
+ [
+ (
+ mutations.SetCell,
+ {
+ "set_cell": {
+ "family_name": "foo",
+ "column_qualifier": b"bar",
+ "value": b"test",
+ "timestamp_micros": 12345,
+ }
+ },
+ ),
+ (
+ mutations.DeleteRangeFromColumn,
+ {
+ "delete_from_column": {
+ "family_name": "foo",
+ "column_qualifier": b"bar",
+ "time_range": {},
+ }
+ },
+ ),
+ (
+ mutations.DeleteRangeFromColumn,
+ {
+ "delete_from_column": {
+ "family_name": "foo",
+ "column_qualifier": b"bar",
+ "time_range": {"start_timestamp_micros": 123456789},
+ }
+ },
+ ),
+ (
+ mutations.DeleteRangeFromColumn,
+ {
+ "delete_from_column": {
+ "family_name": "foo",
+ "column_qualifier": b"bar",
+ "time_range": {"end_timestamp_micros": 123456789},
+ }
+ },
+ ),
+ (
+ mutations.DeleteRangeFromColumn,
+ {
+ "delete_from_column": {
+ "family_name": "foo",
+ "column_qualifier": b"bar",
+ "time_range": {
+ "start_timestamp_micros": 123,
+ "end_timestamp_micros": 123456789,
+ },
+ }
+ },
+ ),
+ (
+ mutations.DeleteAllFromFamily,
+ {"delete_from_family": {"family_name": "foo"}},
+ ),
+ (mutations.DeleteAllFromRow, {"delete_from_row": {}}),
+ ],
+ )
+ def test__from_dict(self, expected_class, input_dict):
+ """Should be able to create instance from dict"""
+ instance = self._target_class()._from_dict(input_dict)
+ assert isinstance(instance, expected_class)
+ found_dict = instance._to_dict()
+ assert found_dict == input_dict
+
+ @pytest.mark.parametrize(
+ "input_dict",
+ [
+ {"set_cell": {}},
+ {
+ "set_cell": {
+ "column_qualifier": b"bar",
+ "value": b"test",
+ "timestamp_micros": 12345,
+ }
+ },
+ {
+ "set_cell": {
+ "family_name": "f",
+ "column_qualifier": b"bar",
+ "value": b"test",
+ }
+ },
+ {"delete_from_family": {}},
+ {"delete_from_column": {}},
+ {"fake-type"},
+ {},
+ ],
+ )
+ def test__from_dict_missing_fields(self, input_dict):
+ """If dict is malformed or fields are missing, should raise ValueError"""
+ with pytest.raises(ValueError):
+ self._target_class()._from_dict(input_dict)
+
+ def test__from_dict_wrong_subclass(self):
+ """You shouldn't be able to instantiate one mutation type using the dict of another"""
+ subclasses = [
+ mutations.SetCell("foo", b"bar", b"test"),
+ mutations.DeleteRangeFromColumn("foo", b"bar"),
+ mutations.DeleteAllFromFamily("foo"),
+ mutations.DeleteAllFromRow(),
+ ]
+ for instance in subclasses:
+ others = [other for other in subclasses if other != instance]
+ for other in others:
+ with pytest.raises(ValueError) as e:
+ type(other)._from_dict(instance._to_dict())
+ assert "Mutation type mismatch" in str(e.value)
+
+
+class TestSetCell:
+ def _target_class(self):
+ from google.cloud.bigtable.data.mutations import SetCell
+
+ return SetCell
+
+ def _make_one(self, *args, **kwargs):
+ return self._target_class()(*args, **kwargs)
+
+ @pytest.mark.parametrize("input_val", [2**64, -(2**64)])
+ def test_ctor_large_int(self, input_val):
+ with pytest.raises(ValueError) as e:
+ self._make_one(family="f", qualifier=b"b", new_value=input_val)
+ assert "int values must be between" in str(e.value)
+
+ @pytest.mark.parametrize("input_val", ["", "a", "abc", "hello world!"])
+ def test_ctor_str_value(self, input_val):
+ found = self._make_one(family="f", qualifier=b"b", new_value=input_val)
+ assert found.new_value == input_val.encode("utf-8")
+
+ def test_ctor(self):
+ """Ensure constructor sets expected values"""
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_value = b"test-value"
+ expected_timestamp = 1234567890
+ instance = self._make_one(
+ expected_family, expected_qualifier, expected_value, expected_timestamp
+ )
+ assert instance.family == expected_family
+ assert instance.qualifier == expected_qualifier
+ assert instance.new_value == expected_value
+ assert instance.timestamp_micros == expected_timestamp
+
+ def test_ctor_str_inputs(self):
+ """Test with string qualifier and value"""
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_value = b"test-value"
+ instance = self._make_one(expected_family, "test-qualifier", "test-value")
+ assert instance.family == expected_family
+ assert instance.qualifier == expected_qualifier
+ assert instance.new_value == expected_value
+
+ @pytest.mark.parametrize("input_val", [-20, -1, 0, 1, 100, int(2**60)])
+ def test_ctor_int_value(self, input_val):
+ found = self._make_one(family="f", qualifier=b"b", new_value=input_val)
+ assert found.new_value == input_val.to_bytes(8, "big", signed=True)
+
+ @pytest.mark.parametrize(
+ "int_value,expected_bytes",
+ [
+ (-42, b"\xff\xff\xff\xff\xff\xff\xff\xd6"),
+ (-2, b"\xff\xff\xff\xff\xff\xff\xff\xfe"),
+ (-1, b"\xff\xff\xff\xff\xff\xff\xff\xff"),
+ (0, b"\x00\x00\x00\x00\x00\x00\x00\x00"),
+ (1, b"\x00\x00\x00\x00\x00\x00\x00\x01"),
+ (2, b"\x00\x00\x00\x00\x00\x00\x00\x02"),
+ (100, b"\x00\x00\x00\x00\x00\x00\x00d"),
+ ],
+ )
+ def test_ctor_int_value_bytes(self, int_value, expected_bytes):
+ """Test with int value"""
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ instance = self._make_one(expected_family, expected_qualifier, int_value)
+ assert instance.family == expected_family
+ assert instance.qualifier == expected_qualifier
+ assert instance.new_value == expected_bytes
+
+ def test_ctor_negative_timestamp(self):
+ """Only positive or -1 timestamps are valid"""
+ with pytest.raises(ValueError) as e:
+ self._make_one("test-family", b"test-qualifier", b"test-value", -2)
+ assert (
+ "timestamp_micros must be positive (or -1 for server-side timestamp)"
+ in str(e.value)
+ )
+
+ @pytest.mark.parametrize(
+ "timestamp_ns,expected_timestamp_micros",
+ [
+ (0, 0),
+ (1, 0),
+ (123, 0),
+ (999, 0),
+ (999_999, 0),
+ (1_000_000, 1000),
+ (1_234_567, 1000),
+ (1_999_999, 1000),
+ (2_000_000, 2000),
+ (1_234_567_890_123, 1_234_567_000),
+ ],
+ )
+ def test_ctor_no_timestamp(self, timestamp_ns, expected_timestamp_micros):
+ """If no timestamp is given, should use current time with millisecond precision"""
+ with mock.patch("time.time_ns", return_value=timestamp_ns):
+ instance = self._make_one("test-family", b"test-qualifier", b"test-value")
+ assert instance.timestamp_micros == expected_timestamp_micros
+
+ def test__to_dict(self):
+ """ensure dict representation is as expected"""
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_value = b"test-value"
+ expected_timestamp = 123456789
+ instance = self._make_one(
+ expected_family, expected_qualifier, expected_value, expected_timestamp
+ )
+ got_dict = instance._to_dict()
+ assert list(got_dict.keys()) == ["set_cell"]
+ got_inner_dict = got_dict["set_cell"]
+ assert got_inner_dict["family_name"] == expected_family
+ assert got_inner_dict["column_qualifier"] == expected_qualifier
+ assert got_inner_dict["timestamp_micros"] == expected_timestamp
+ assert got_inner_dict["value"] == expected_value
+ assert len(got_inner_dict.keys()) == 4
+
+ def test__to_dict_server_timestamp(self):
+ """test with server side timestamp -1 value"""
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_value = b"test-value"
+ expected_timestamp = -1
+ instance = self._make_one(
+ expected_family, expected_qualifier, expected_value, expected_timestamp
+ )
+ got_dict = instance._to_dict()
+ assert list(got_dict.keys()) == ["set_cell"]
+ got_inner_dict = got_dict["set_cell"]
+ assert got_inner_dict["family_name"] == expected_family
+ assert got_inner_dict["column_qualifier"] == expected_qualifier
+ assert got_inner_dict["timestamp_micros"] == expected_timestamp
+ assert got_inner_dict["value"] == expected_value
+ assert len(got_inner_dict.keys()) == 4
+
+ def test__to_pb(self):
+ """ensure proto representation is as expected"""
+ import google.cloud.bigtable_v2.types.data as data_pb
+
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_value = b"test-value"
+ expected_timestamp = 123456789
+ instance = self._make_one(
+ expected_family, expected_qualifier, expected_value, expected_timestamp
+ )
+ got_pb = instance._to_pb()
+ assert isinstance(got_pb, data_pb.Mutation)
+ assert got_pb.set_cell.family_name == expected_family
+ assert got_pb.set_cell.column_qualifier == expected_qualifier
+ assert got_pb.set_cell.timestamp_micros == expected_timestamp
+ assert got_pb.set_cell.value == expected_value
+
+ def test__to_pb_server_timestamp(self):
+ """test with server side timestamp -1 value"""
+ import google.cloud.bigtable_v2.types.data as data_pb
+
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_value = b"test-value"
+ expected_timestamp = -1
+ instance = self._make_one(
+ expected_family, expected_qualifier, expected_value, expected_timestamp
+ )
+ got_pb = instance._to_pb()
+ assert isinstance(got_pb, data_pb.Mutation)
+ assert got_pb.set_cell.family_name == expected_family
+ assert got_pb.set_cell.column_qualifier == expected_qualifier
+ assert got_pb.set_cell.timestamp_micros == expected_timestamp
+ assert got_pb.set_cell.value == expected_value
+
+ @pytest.mark.parametrize(
+ "timestamp,expected_value",
+ [
+ (1234567890, True),
+ (1, True),
+ (0, True),
+ (-1, False),
+ (None, True),
+ ],
+ )
+ def test_is_idempotent(self, timestamp, expected_value):
+ """is_idempotent is based on whether an explicit timestamp is set"""
+ instance = self._make_one(
+ "test-family", b"test-qualifier", b"test-value", timestamp
+ )
+ assert instance.is_idempotent() is expected_value
+
+ def test___str__(self):
+ """Str representation of mutations should be to_dict"""
+ instance = self._make_one(
+ "test-family", b"test-qualifier", b"test-value", 1234567890
+ )
+ str_value = instance.__str__()
+ dict_value = instance._to_dict()
+ assert str_value == str(dict_value)
+
+
+class TestDeleteRangeFromColumn:
+ def _target_class(self):
+ from google.cloud.bigtable.data.mutations import DeleteRangeFromColumn
+
+ return DeleteRangeFromColumn
+
+ def _make_one(self, *args, **kwargs):
+ return self._target_class()(*args, **kwargs)
+
+ def test_ctor(self):
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_start = 1234567890
+ expected_end = 1234567891
+ instance = self._make_one(
+ expected_family, expected_qualifier, expected_start, expected_end
+ )
+ assert instance.family == expected_family
+ assert instance.qualifier == expected_qualifier
+ assert instance.start_timestamp_micros == expected_start
+ assert instance.end_timestamp_micros == expected_end
+
+ def test_ctor_no_timestamps(self):
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ instance = self._make_one(expected_family, expected_qualifier)
+ assert instance.family == expected_family
+ assert instance.qualifier == expected_qualifier
+ assert instance.start_timestamp_micros is None
+ assert instance.end_timestamp_micros is None
+
+ def test_ctor_timestamps_out_of_order(self):
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ expected_start = 10
+ expected_end = 1
+ with pytest.raises(ValueError) as excinfo:
+ self._make_one(
+ expected_family, expected_qualifier, expected_start, expected_end
+ )
+ assert "start_timestamp_micros must be <= end_timestamp_micros" in str(
+ excinfo.value
+ )
+
+ @pytest.mark.parametrize(
+ "start,end",
+ [
+ (0, 1),
+ (None, 1),
+ (0, None),
+ ],
+ )
+ def test__to_dict(self, start, end):
+ """Should be unimplemented in the base class"""
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+
+ instance = self._make_one(expected_family, expected_qualifier, start, end)
+ got_dict = instance._to_dict()
+ assert list(got_dict.keys()) == ["delete_from_column"]
+ got_inner_dict = got_dict["delete_from_column"]
+ assert len(got_inner_dict.keys()) == 3
+ assert got_inner_dict["family_name"] == expected_family
+ assert got_inner_dict["column_qualifier"] == expected_qualifier
+ time_range_dict = got_inner_dict["time_range"]
+ expected_len = int(isinstance(start, int)) + int(isinstance(end, int))
+ assert len(time_range_dict.keys()) == expected_len
+ if start is not None:
+ assert time_range_dict["start_timestamp_micros"] == start
+ if end is not None:
+ assert time_range_dict["end_timestamp_micros"] == end
+
+ def test__to_pb(self):
+ """ensure proto representation is as expected"""
+ import google.cloud.bigtable_v2.types.data as data_pb
+
+ expected_family = "test-family"
+ expected_qualifier = b"test-qualifier"
+ instance = self._make_one(expected_family, expected_qualifier)
+ got_pb = instance._to_pb()
+ assert isinstance(got_pb, data_pb.Mutation)
+ assert got_pb.delete_from_column.family_name == expected_family
+ assert got_pb.delete_from_column.column_qualifier == expected_qualifier
+
+ def test_is_idempotent(self):
+ """is_idempotent is always true"""
+ instance = self._make_one(
+ "test-family", b"test-qualifier", 1234567890, 1234567891
+ )
+ assert instance.is_idempotent() is True
+
+ def test___str__(self):
+ """Str representation of mutations should be to_dict"""
+ instance = self._make_one("test-family", b"test-qualifier")
+ str_value = instance.__str__()
+ dict_value = instance._to_dict()
+ assert str_value == str(dict_value)
+
+
+class TestDeleteAllFromFamily:
+ def _target_class(self):
+ from google.cloud.bigtable.data.mutations import DeleteAllFromFamily
+
+ return DeleteAllFromFamily
+
+ def _make_one(self, *args, **kwargs):
+ return self._target_class()(*args, **kwargs)
+
+ def test_ctor(self):
+ expected_family = "test-family"
+ instance = self._make_one(expected_family)
+ assert instance.family_to_delete == expected_family
+
+ def test__to_dict(self):
+ """Should be unimplemented in the base class"""
+ expected_family = "test-family"
+ instance = self._make_one(expected_family)
+ got_dict = instance._to_dict()
+ assert list(got_dict.keys()) == ["delete_from_family"]
+ got_inner_dict = got_dict["delete_from_family"]
+ assert len(got_inner_dict.keys()) == 1
+ assert got_inner_dict["family_name"] == expected_family
+
+ def test__to_pb(self):
+ """ensure proto representation is as expected"""
+ import google.cloud.bigtable_v2.types.data as data_pb
+
+ expected_family = "test-family"
+ instance = self._make_one(expected_family)
+ got_pb = instance._to_pb()
+ assert isinstance(got_pb, data_pb.Mutation)
+ assert got_pb.delete_from_family.family_name == expected_family
+
+ def test_is_idempotent(self):
+ """is_idempotent is always true"""
+ instance = self._make_one("test-family")
+ assert instance.is_idempotent() is True
+
+ def test___str__(self):
+ """Str representation of mutations should be to_dict"""
+ instance = self._make_one("test-family")
+ str_value = instance.__str__()
+ dict_value = instance._to_dict()
+ assert str_value == str(dict_value)
+
+
+class TestDeleteFromRow:
+ def _target_class(self):
+ from google.cloud.bigtable.data.mutations import DeleteAllFromRow
+
+ return DeleteAllFromRow
+
+ def _make_one(self, *args, **kwargs):
+ return self._target_class()(*args, **kwargs)
+
+ def test_ctor(self):
+ self._make_one()
+
+ def test__to_dict(self):
+ """Should be unimplemented in the base class"""
+ instance = self._make_one()
+ got_dict = instance._to_dict()
+ assert list(got_dict.keys()) == ["delete_from_row"]
+ assert len(got_dict["delete_from_row"].keys()) == 0
+
+ def test__to_pb(self):
+ """ensure proto representation is as expected"""
+ import google.cloud.bigtable_v2.types.data as data_pb
+
+ instance = self._make_one()
+ got_pb = instance._to_pb()
+ assert isinstance(got_pb, data_pb.Mutation)
+ assert "delete_from_row" in str(got_pb)
+
+ def test_is_idempotent(self):
+ """is_idempotent is always true"""
+ instance = self._make_one()
+ assert instance.is_idempotent() is True
+
+ def test___str__(self):
+ """Str representation of mutations should be to_dict"""
+ instance = self._make_one()
+ assert instance.__str__() == "{'delete_from_row': {}}"
+
+
+class TestRowMutationEntry:
+ def _target_class(self):
+ from google.cloud.bigtable.data.mutations import RowMutationEntry
+
+ return RowMutationEntry
+
+ def _make_one(self, row_key, mutations):
+ return self._target_class()(row_key, mutations)
+
+ def test_ctor(self):
+ expected_key = b"row_key"
+ expected_mutations = [mock.Mock()]
+ instance = self._make_one(expected_key, expected_mutations)
+ assert instance.row_key == expected_key
+ assert list(instance.mutations) == expected_mutations
+
+ def test_ctor_over_limit(self):
+ """Should raise error if mutations exceed MAX_MUTATIONS_PER_ENTRY"""
+ from google.cloud.bigtable.data.mutations import (
+ _MUTATE_ROWS_REQUEST_MUTATION_LIMIT,
+ )
+
+ assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000
+ # no errors at limit
+ expected_mutations = [None for _ in range(_MUTATE_ROWS_REQUEST_MUTATION_LIMIT)]
+ self._make_one(b"row_key", expected_mutations)
+ # error if over limit
+ with pytest.raises(ValueError) as e:
+ self._make_one("key", expected_mutations + [mock.Mock()])
+ assert "entries must have <= 100000 mutations" in str(e.value)
+
+ def test_ctor_str_key(self):
+ expected_key = "row_key"
+ expected_mutations = [mock.Mock(), mock.Mock()]
+ instance = self._make_one(expected_key, expected_mutations)
+ assert instance.row_key == b"row_key"
+ assert list(instance.mutations) == expected_mutations
+
+ def test_ctor_single_mutation(self):
+ from google.cloud.bigtable.data.mutations import DeleteAllFromRow
+
+ expected_key = b"row_key"
+ expected_mutations = DeleteAllFromRow()
+ instance = self._make_one(expected_key, expected_mutations)
+ assert instance.row_key == expected_key
+ assert instance.mutations == (expected_mutations,)
+
+ def test__to_dict(self):
+ expected_key = "row_key"
+ mutation_mock = mock.Mock()
+ n_mutations = 3
+ expected_mutations = [mutation_mock for i in range(n_mutations)]
+ for mock_mutations in expected_mutations:
+ mock_mutations._to_dict.return_value = {"test": "data"}
+ instance = self._make_one(expected_key, expected_mutations)
+ expected_result = {
+ "row_key": b"row_key",
+ "mutations": [{"test": "data"}] * n_mutations,
+ }
+ assert instance._to_dict() == expected_result
+ assert mutation_mock._to_dict.call_count == n_mutations
+
+ def test__to_pb(self):
+ from google.cloud.bigtable_v2.types.bigtable import MutateRowsRequest
+ from google.cloud.bigtable_v2.types.data import Mutation
+
+ expected_key = "row_key"
+ mutation_mock = mock.Mock()
+ n_mutations = 3
+ expected_mutations = [mutation_mock for i in range(n_mutations)]
+ for mock_mutations in expected_mutations:
+ mock_mutations._to_pb.return_value = Mutation()
+ instance = self._make_one(expected_key, expected_mutations)
+ pb_result = instance._to_pb()
+ assert isinstance(pb_result, MutateRowsRequest.Entry)
+ assert pb_result.row_key == b"row_key"
+ assert pb_result.mutations == [Mutation()] * n_mutations
+ assert mutation_mock._to_pb.call_count == n_mutations
+
+ @pytest.mark.parametrize(
+ "mutations,result",
+ [
+ ([mock.Mock(is_idempotent=lambda: True)], True),
+ ([mock.Mock(is_idempotent=lambda: False)], False),
+ (
+ [
+ mock.Mock(is_idempotent=lambda: True),
+ mock.Mock(is_idempotent=lambda: False),
+ ],
+ False,
+ ),
+ (
+ [
+ mock.Mock(is_idempotent=lambda: True),
+ mock.Mock(is_idempotent=lambda: True),
+ ],
+ True,
+ ),
+ ],
+ )
+ def test_is_idempotent(self, mutations, result):
+ instance = self._make_one("row_key", mutations)
+ assert instance.is_idempotent() == result
+
+ def test_empty_mutations(self):
+ with pytest.raises(ValueError) as e:
+ self._make_one("row_key", [])
+ assert "must not be empty" in str(e.value)
+
+ @pytest.mark.parametrize("test_dict", [{}, {"key": "value"}])
+    def test_size(self, test_dict):
+        """Size should return size of dict representation"""
+        from sys import getsizeof
+
+ self_mock = mock.Mock()
+ self_mock._to_dict.return_value = test_dict
+ size_value = self._target_class().size(self_mock)
+ assert size_value == getsizeof(test_dict)
+
+ def test__from_dict_mock(self):
+ """
+ test creating instance from entry dict, with mocked mutation._from_dict
+ """
+ expected_key = b"row_key"
+ expected_mutations = [mock.Mock(), mock.Mock()]
+ input_dict = {
+ "row_key": expected_key,
+ "mutations": [{"test": "data"}, {"another": "data"}],
+ }
+ with mock.patch.object(mutations.Mutation, "_from_dict") as inner_from_dict:
+ inner_from_dict.side_effect = expected_mutations
+ instance = self._target_class()._from_dict(input_dict)
+ assert instance.row_key == b"row_key"
+ assert inner_from_dict.call_count == 2
+ assert len(instance.mutations) == 2
+ assert instance.mutations[0] == expected_mutations[0]
+ assert instance.mutations[1] == expected_mutations[1]
+
+ def test__from_dict(self):
+ """
+ test creating end-to-end with a real mutation instance
+ """
+ input_dict = {
+ "row_key": b"row_key",
+ "mutations": [{"delete_from_family": {"family_name": "test_family"}}],
+ }
+ instance = self._target_class()._from_dict(input_dict)
+ assert instance.row_key == b"row_key"
+ assert len(instance.mutations) == 1
+ assert isinstance(instance.mutations[0], mutations.DeleteAllFromFamily)
+ assert instance.mutations[0].family_to_delete == "test_family"
diff --git a/tests/unit/data/test_read_modify_write_rules.py b/tests/unit/data/test_read_modify_write_rules.py
new file mode 100644
index 000000000..1f67da13b
--- /dev/null
+++ b/tests/unit/data/test_read_modify_write_rules.py
@@ -0,0 +1,186 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+except ImportError: # pragma: NO COVER
+ import mock # type: ignore
+
+
+class TestBaseReadModifyWriteRule:
+ def _target_class(self):
+ from google.cloud.bigtable.data.read_modify_write_rules import (
+ ReadModifyWriteRule,
+ )
+
+ return ReadModifyWriteRule
+
+ def test_abstract(self):
+ """should not be able to instantiate"""
+ with pytest.raises(TypeError):
+ self._target_class()(family="foo", qualifier=b"bar")
+
+ def test__to_dict(self):
+ """
+ to_dict not implemented in base class
+ """
+ with pytest.raises(NotImplementedError):
+ self._target_class()._to_dict(mock.Mock())
+
+
+class TestIncrementRule:
+ def _target_class(self):
+ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
+
+ return IncrementRule
+
+ @pytest.mark.parametrize(
+ "args,expected",
+ [
+ (("fam", b"qual", 1), ("fam", b"qual", 1)),
+ (("fam", b"qual", -12), ("fam", b"qual", -12)),
+ (("fam", "qual", 1), ("fam", b"qual", 1)),
+ (("fam", "qual", 0), ("fam", b"qual", 0)),
+ (("", "", 0), ("", b"", 0)),
+ (("f", b"q"), ("f", b"q", 1)),
+ ],
+ )
+ def test_ctor(self, args, expected):
+ instance = self._target_class()(*args)
+ assert instance.family == expected[0]
+ assert instance.qualifier == expected[1]
+ assert instance.increment_amount == expected[2]
+
+ @pytest.mark.parametrize("input_amount", [1.1, None, "1", object(), "", b"", b"1"])
+ def test_ctor_bad_input(self, input_amount):
+ with pytest.raises(TypeError) as e:
+ self._target_class()("fam", b"qual", input_amount)
+ assert "increment_amount must be an integer" in str(e.value)
+
+ @pytest.mark.parametrize(
+ "large_value", [2**64, 2**64 + 1, -(2**64), -(2**64) - 1]
+ )
+ def test_ctor_large_values(self, large_value):
+ with pytest.raises(ValueError) as e:
+ self._target_class()("fam", b"qual", large_value)
+ assert "too large" in str(e.value)
+
+ @pytest.mark.parametrize(
+ "args,expected",
+ [
+ (("fam", b"qual", 1), ("fam", b"qual", 1)),
+ (("fam", b"qual", -12), ("fam", b"qual", -12)),
+ (("fam", "qual", 1), ("fam", b"qual", 1)),
+ (("fam", "qual", 0), ("fam", b"qual", 0)),
+ (("", "", 0), ("", b"", 0)),
+ (("f", b"q"), ("f", b"q", 1)),
+ ],
+ )
+ def test__to_dict(self, args, expected):
+ instance = self._target_class()(*args)
+ expected = {
+ "family_name": expected[0],
+ "column_qualifier": expected[1],
+ "increment_amount": expected[2],
+ }
+ assert instance._to_dict() == expected
+
+ @pytest.mark.parametrize(
+ "args,expected",
+ [
+ (("fam", b"qual", 1), ("fam", b"qual", 1)),
+ (("fam", b"qual", -12), ("fam", b"qual", -12)),
+ (("fam", "qual", 1), ("fam", b"qual", 1)),
+ (("fam", "qual", 0), ("fam", b"qual", 0)),
+ (("", "", 0), ("", b"", 0)),
+ (("f", b"q"), ("f", b"q", 1)),
+ ],
+ )
+ def test__to_pb(self, args, expected):
+ import google.cloud.bigtable_v2.types.data as data_pb
+
+ instance = self._target_class()(*args)
+ pb_result = instance._to_pb()
+ assert isinstance(pb_result, data_pb.ReadModifyWriteRule)
+ assert pb_result.family_name == expected[0]
+ assert pb_result.column_qualifier == expected[1]
+ assert pb_result.increment_amount == expected[2]
+
+
+class TestAppendValueRule:
+ def _target_class(self):
+ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
+
+ return AppendValueRule
+
+ @pytest.mark.parametrize(
+ "args,expected",
+ [
+ (("fam", b"qual", b"val"), ("fam", b"qual", b"val")),
+ (("fam", "qual", b"val"), ("fam", b"qual", b"val")),
+ (("", "", b""), ("", b"", b"")),
+ (("f", "q", "str_val"), ("f", b"q", b"str_val")),
+ (("f", "q", ""), ("f", b"q", b"")),
+ ],
+ )
+ def test_ctor(self, args, expected):
+ instance = self._target_class()(*args)
+ assert instance.family == expected[0]
+ assert instance.qualifier == expected[1]
+ assert instance.append_value == expected[2]
+
+ @pytest.mark.parametrize("input_val", [5, 1.1, None, object()])
+ def test_ctor_bad_input(self, input_val):
+ with pytest.raises(TypeError) as e:
+ self._target_class()("fam", b"qual", input_val)
+ assert "append_value must be bytes or str" in str(e.value)
+
+ @pytest.mark.parametrize(
+ "args,expected",
+ [
+ (("fam", b"qual", b"val"), ("fam", b"qual", b"val")),
+ (("fam", "qual", b"val"), ("fam", b"qual", b"val")),
+ (("", "", b""), ("", b"", b"")),
+ ],
+ )
+ def test__to_dict(self, args, expected):
+ instance = self._target_class()(*args)
+ expected = {
+ "family_name": expected[0],
+ "column_qualifier": expected[1],
+ "append_value": expected[2],
+ }
+ assert instance._to_dict() == expected
+
+ @pytest.mark.parametrize(
+ "args,expected",
+ [
+ (("fam", b"qual", b"val"), ("fam", b"qual", b"val")),
+ (("fam", "qual", b"val"), ("fam", b"qual", b"val")),
+ (("", "", b""), ("", b"", b"")),
+ ],
+ )
+ def test__to_pb(self, args, expected):
+ import google.cloud.bigtable_v2.types.data as data_pb
+
+ instance = self._target_class()(*args)
+ pb_result = instance._to_pb()
+ assert isinstance(pb_result, data_pb.ReadModifyWriteRule)
+ assert pb_result.family_name == expected[0]
+ assert pb_result.column_qualifier == expected[1]
+ assert pb_result.append_value == expected[2]
diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py
new file mode 100644
index 000000000..7cb3c08dc
--- /dev/null
+++ b/tests/unit/data/test_read_rows_acceptance.py
@@ -0,0 +1,331 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import os
+from itertools import zip_longest
+
+import pytest
+import mock
+
+from google.cloud.bigtable_v2 import ReadRowsResponse
+
+from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
+from google.cloud.bigtable.data.exceptions import InvalidChunk
+from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+from google.cloud.bigtable.data.row import Row
+
+from ..v2_client.test_row_merger import ReadRowsTest, TestFile
+
+
+def parse_readrows_acceptance_tests():
+ dirname = os.path.dirname(__file__)
+ filename = os.path.join(dirname, "./read-rows-acceptance-test.json")
+
+ with open(filename) as json_file:
+ test_json = TestFile.from_json(json_file.read())
+ return test_json.read_rows_tests
+
+
+def extract_results_from_row(row: Row):
+ results = []
+ for family, col, cells in row.items():
+ for cell in cells:
+ results.append(
+ ReadRowsTest.Result(
+ row_key=row.row_key,
+ family_name=family,
+ qualifier=col,
+ timestamp_micros=cell.timestamp_ns // 1000,
+ value=cell.value,
+ label=(cell.labels[0] if cell.labels else ""),
+ )
+ )
+ return results
+
+
+@pytest.mark.parametrize(
+ "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description
+)
+@pytest.mark.asyncio
+async def test_row_merger_scenario(test_case: ReadRowsTest):
+    async def _scenario_stream():
+ for chunk in test_case.chunks:
+ yield ReadRowsResponse(chunks=[chunk])
+
+ try:
+ results = []
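+        # a mock stands in for the _ReadRowsOperationAsync instance; chunk_stream only
+        # appears to need these two pieces of per-operation state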
+ instance = mock.Mock()
+ instance._last_yielded_row_key = None
+ instance._remaining_count = None
+ chunker = _ReadRowsOperationAsync.chunk_stream(
+            instance, _coro_wrapper(_scenario_stream())
+ )
+ merger = _ReadRowsOperationAsync.merge_rows(chunker)
+ async for row in merger:
+ for cell in row:
+ cell_result = ReadRowsTest.Result(
+ row_key=cell.row_key,
+ family_name=cell.family,
+ qualifier=cell.qualifier,
+ timestamp_micros=cell.timestamp_micros,
+ value=cell.value,
+ label=cell.labels[0] if cell.labels else "",
+ )
+ results.append(cell_result)
+ except InvalidChunk:
+ results.append(ReadRowsTest.Result(error=True))
+ for expected, actual in zip_longest(test_case.results, results):
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description
+)
+@pytest.mark.asyncio
+async def test_read_rows_scenario(test_case: ReadRowsTest):
+ async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]):
+ from google.cloud.bigtable_v2 import ReadRowsResponse
+
+ class mock_stream:
+ def __init__(self, chunk_list):
+ self.chunk_list = chunk_list
+ self.idx = -1
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ self.idx += 1
+ if len(self.chunk_list) > self.idx:
+ chunk = self.chunk_list[self.idx]
+ return ReadRowsResponse(chunks=[chunk])
+ raise StopAsyncIteration
+
+ def cancel(self):
+ pass
+
+ return mock_stream(chunk_list)
+
+ try:
+ with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}):
+ # use emulator mode to avoid auth issues in CI
+ client = BigtableDataClientAsync()
+ table = client.get_table("instance", "table")
+ results = []
+ with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows:
+            # patched read_rows yields the scenario's chunk stream
+ read_rows.return_value = _make_gapic_stream(test_case.chunks)
+ async for row in await table.read_rows_stream(query={}):
+ for cell in row:
+ cell_result = ReadRowsTest.Result(
+ row_key=cell.row_key,
+ family_name=cell.family,
+ qualifier=cell.qualifier,
+ timestamp_micros=cell.timestamp_micros,
+ value=cell.value,
+ label=cell.labels[0] if cell.labels else "",
+ )
+ results.append(cell_result)
+ except InvalidChunk:
+ results.append(ReadRowsTest.Result(error=True))
+ finally:
+ await client.close()
+ for expected, actual in zip_longest(test_case.results, results):
+ assert actual == expected
+
+
+@pytest.mark.asyncio
+async def test_out_of_order_rows():
+ async def _row_stream():
+ yield ReadRowsResponse(last_scanned_row_key=b"a")
+
+ instance = mock.Mock()
+ instance._remaining_count = None
+ instance._last_yielded_row_key = b"b"
+ chunker = _ReadRowsOperationAsync.chunk_stream(
+ instance, _coro_wrapper(_row_stream())
+ )
+ merger = _ReadRowsOperationAsync.merge_rows(chunker)
+ with pytest.raises(InvalidChunk):
+ async for _ in merger:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_bare_reset():
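+    # a reset_row chunk must not carry any other data; each case below attaches one extra
+    # field to the reset chunk and expects InvalidChunk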
+ first_chunk = ReadRowsResponse.CellChunk(
+ ReadRowsResponse.CellChunk(
+ row_key=b"a", family_name="f", qualifier=b"q", value=b"v"
+ )
+ )
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ first_chunk,
+ ReadRowsResponse.CellChunk(
+ ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a")
+ ),
+ )
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ first_chunk,
+ ReadRowsResponse.CellChunk(
+ ReadRowsResponse.CellChunk(reset_row=True, family_name="f")
+ ),
+ )
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ first_chunk,
+ ReadRowsResponse.CellChunk(
+ ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q")
+ ),
+ )
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ first_chunk,
+ ReadRowsResponse.CellChunk(
+ ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000)
+ ),
+ )
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ first_chunk,
+ ReadRowsResponse.CellChunk(
+ ReadRowsResponse.CellChunk(reset_row=True, labels=["a"])
+ ),
+ )
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ first_chunk,
+ ReadRowsResponse.CellChunk(
+ ReadRowsResponse.CellChunk(reset_row=True, value=b"v")
+ ),
+ )
+
+
+@pytest.mark.asyncio
+async def test_missing_family():
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ ReadRowsResponse.CellChunk(
+ row_key=b"a",
+ qualifier=b"q",
+ timestamp_micros=1000,
+ value=b"v",
+ commit_row=True,
+ )
+ )
+
+
+@pytest.mark.asyncio
+async def test_mid_cell_row_key_change():
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ ReadRowsResponse.CellChunk(
+ row_key=b"a",
+ family_name="f",
+ qualifier=b"q",
+ timestamp_micros=1000,
+ value_size=2,
+ value=b"v",
+ ),
+ ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True),
+ )
+
+
+@pytest.mark.asyncio
+async def test_mid_cell_family_change():
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ ReadRowsResponse.CellChunk(
+ row_key=b"a",
+ family_name="f",
+ qualifier=b"q",
+ timestamp_micros=1000,
+ value_size=2,
+ value=b"v",
+ ),
+ ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True),
+ )
+
+
+@pytest.mark.asyncio
+async def test_mid_cell_qualifier_change():
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ ReadRowsResponse.CellChunk(
+ row_key=b"a",
+ family_name="f",
+ qualifier=b"q",
+ timestamp_micros=1000,
+ value_size=2,
+ value=b"v",
+ ),
+ ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True),
+ )
+
+
+@pytest.mark.asyncio
+async def test_mid_cell_timestamp_change():
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ ReadRowsResponse.CellChunk(
+ row_key=b"a",
+ family_name="f",
+ qualifier=b"q",
+ timestamp_micros=1000,
+ value_size=2,
+ value=b"v",
+ ),
+ ReadRowsResponse.CellChunk(
+ timestamp_micros=2000, value=b"v", commit_row=True
+ ),
+ )
+
+
+@pytest.mark.asyncio
+async def test_mid_cell_labels_change():
+ with pytest.raises(InvalidChunk):
+ await _process_chunks(
+ ReadRowsResponse.CellChunk(
+ row_key=b"a",
+ family_name="f",
+ qualifier=b"q",
+ timestamp_micros=1000,
+ value_size=2,
+ value=b"v",
+ ),
+ ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True),
+ )
+
+
+async def _coro_wrapper(stream):
+ return stream
+
+
+async def _process_chunks(*chunks):
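+    # helper: feed the given chunks through chunk_stream/merge_rows with a minimal mocked
+    # operation, returning the merged rows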
+ async def _row_stream():
+ yield ReadRowsResponse(chunks=chunks)
+
+ instance = mock.Mock()
+ instance._remaining_count = None
+ instance._last_yielded_row_key = None
+ chunker = _ReadRowsOperationAsync.chunk_stream(
+ instance, _coro_wrapper(_row_stream())
+ )
+ merger = _ReadRowsOperationAsync.merge_rows(chunker)
+ results = []
+ async for row in merger:
+ results.append(row)
+ return results
diff --git a/tests/unit/data/test_read_rows_query.py b/tests/unit/data/test_read_rows_query.py
new file mode 100644
index 000000000..ba3b0468b
--- /dev/null
+++ b/tests/unit/data/test_read_rows_query.py
@@ -0,0 +1,589 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+TEST_ROWS = [
+ "row_key_1",
+ b"row_key_2",
+]
+
+
+class TestRowRange:
+ @staticmethod
+ def _get_target_class():
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+ return RowRange
+
+ def _make_one(self, *args, **kwargs):
+ return self._get_target_class()(*args, **kwargs)
+
+ def test_ctor_start_end(self):
+ row_range = self._make_one("test_row", "test_row2")
+ assert row_range.start_key == "test_row".encode()
+ assert row_range.end_key == "test_row2".encode()
+ assert row_range.start_is_inclusive is True
+ assert row_range.end_is_inclusive is False
+
+ def test_ctor_start_only(self):
+ row_range = self._make_one("test_row3")
+ assert row_range.start_key == "test_row3".encode()
+ assert row_range.start_is_inclusive is True
+ assert row_range.end_key is None
+ assert row_range.end_is_inclusive is True
+
+ def test_ctor_end_only(self):
+ row_range = self._make_one(end_key="test_row4")
+ assert row_range.end_key == "test_row4".encode()
+ assert row_range.end_is_inclusive is False
+ assert row_range.start_key is None
+ assert row_range.start_is_inclusive is True
+
+ def test_ctor_empty_strings(self):
+ """
+ empty strings should be treated as None
+ """
+ row_range = self._make_one("", "")
+ assert row_range.start_key is None
+ assert row_range.end_key is None
+ assert row_range.start_is_inclusive is True
+ assert row_range.end_is_inclusive is True
+
+ def test_ctor_inclusive_flags(self):
+ row_range = self._make_one("test_row5", "test_row6", False, True)
+ assert row_range.start_key == "test_row5".encode()
+ assert row_range.end_key == "test_row6".encode()
+ assert row_range.start_is_inclusive is False
+ assert row_range.end_is_inclusive is True
+
+ def test_ctor_defaults(self):
+ row_range = self._make_one()
+ assert row_range.start_key is None
+ assert row_range.end_key is None
+
+ def test_ctor_invalid_keys(self):
+ # test with invalid keys
+ with pytest.raises(ValueError) as exc:
+ self._make_one(1, "2")
+ assert str(exc.value) == "start_key must be a string or bytes"
+ with pytest.raises(ValueError) as exc:
+ self._make_one("1", 2)
+ assert str(exc.value) == "end_key must be a string or bytes"
+ with pytest.raises(ValueError) as exc:
+ self._make_one("2", "1")
+ assert str(exc.value) == "start_key must be less than or equal to end_key"
+
+ @pytest.mark.parametrize(
+ "dict_repr,expected",
+ [
+ ({"start_key_closed": "test_row", "end_key_open": "test_row2"}, True),
+ ({"start_key_closed": b"test_row", "end_key_open": b"test_row2"}, True),
+ ({"start_key_open": "test_row", "end_key_closed": "test_row2"}, True),
+ ({"start_key_open": b"a"}, True),
+ ({"end_key_closed": b"b"}, True),
+ ({"start_key_closed": "a"}, True),
+ ({"end_key_open": b"b"}, True),
+ ({}, False),
+ ],
+ )
+ def test___bool__(self, dict_repr, expected):
+ """
+        Only a row range with both endpoints empty should be falsy
+ """
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
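+        # dict_repr mirrors the RowRange proto field names:
+        # start_key_closed / start_key_open and end_key_closed / end_key_open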
+ row_range = RowRange._from_dict(dict_repr)
+ assert bool(row_range) is expected
+
+ def test__eq__(self):
+ """
+ test that row ranges can be compared for equality
+ """
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+ range1 = RowRange("1", "2")
+ range1_dup = RowRange("1", "2")
+ range2 = RowRange("1", "3")
+ range_w_empty = RowRange(None, "2")
+ assert range1 == range1_dup
+ assert range1 != range2
+ assert range1 != range_w_empty
+ range_1_w_inclusive_start = RowRange("1", "2", start_is_inclusive=True)
+ range_1_w_exclusive_start = RowRange("1", "2", start_is_inclusive=False)
+ range_1_w_inclusive_end = RowRange("1", "2", end_is_inclusive=True)
+ range_1_w_exclusive_end = RowRange("1", "2", end_is_inclusive=False)
+ assert range1 == range_1_w_inclusive_start
+ assert range1 == range_1_w_exclusive_end
+ assert range1 != range_1_w_exclusive_start
+ assert range1 != range_1_w_inclusive_end
+
+ @pytest.mark.parametrize(
+ "dict_repr,expected",
+ [
+ (
+ {"start_key_closed": "test_row", "end_key_open": "test_row2"},
+ "[b'test_row', b'test_row2')",
+ ),
+ (
+ {"start_key_open": "test_row", "end_key_closed": "test_row2"},
+ "(b'test_row', b'test_row2']",
+ ),
+ ({"start_key_open": b"a"}, "(b'a', +inf]"),
+ ({"end_key_closed": b"b"}, "[-inf, b'b']"),
+ ({"end_key_open": b"b"}, "[-inf, b'b')"),
+ ({}, "[-inf, +inf]"),
+ ],
+ )
+ def test___str__(self, dict_repr, expected):
+ """
+ test string representations of row ranges
+ """
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+ row_range = RowRange._from_dict(dict_repr)
+ assert str(row_range) == expected
+
+ @pytest.mark.parametrize(
+ "dict_repr,expected",
+ [
+ (
+ {"start_key_closed": "test_row", "end_key_open": "test_row2"},
+ "RowRange(start_key=b'test_row', end_key=b'test_row2')",
+ ),
+ (
+ {"start_key_open": "test_row", "end_key_closed": "test_row2"},
+ "RowRange(start_key=b'test_row', end_key=b'test_row2', start_is_inclusive=False, end_is_inclusive=True)",
+ ),
+ (
+ {"start_key_open": b"a"},
+ "RowRange(start_key=b'a', end_key=None, start_is_inclusive=False)",
+ ),
+ (
+ {"end_key_closed": b"b"},
+ "RowRange(start_key=None, end_key=b'b', end_is_inclusive=True)",
+ ),
+ ({"end_key_open": b"b"}, "RowRange(start_key=None, end_key=b'b')"),
+ ({}, "RowRange(start_key=None, end_key=None)"),
+ ],
+ )
+ def test___repr__(self, dict_repr, expected):
+ """
+ test repr representations of row ranges
+ """
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+ row_range = RowRange._from_dict(dict_repr)
+ assert repr(row_range) == expected
+
+
+class TestReadRowsQuery:
+ @staticmethod
+ def _get_target_class():
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+ return ReadRowsQuery
+
+ def _make_one(self, *args, **kwargs):
+ return self._get_target_class()(*args, **kwargs)
+
+ def test_ctor_defaults(self):
+ query = self._make_one()
+ assert query.row_keys == list()
+ assert query.row_ranges == list()
+ assert query.filter is None
+ assert query.limit is None
+
+ def test_ctor_explicit(self):
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+ filter_ = RowFilterChain()
+ query = self._make_one(
+ ["row_key_1", "row_key_2"],
+ row_ranges=[RowRange("row_key_3", "row_key_4")],
+ limit=10,
+ row_filter=filter_,
+ )
+ assert len(query.row_keys) == 2
+ assert "row_key_1".encode() in query.row_keys
+ assert "row_key_2".encode() in query.row_keys
+ assert len(query.row_ranges) == 1
+ assert RowRange("row_key_3", "row_key_4") in query.row_ranges
+ assert query.filter == filter_
+ assert query.limit == 10
+
+ def test_ctor_invalid_limit(self):
+ with pytest.raises(ValueError) as exc:
+ self._make_one(limit=-1)
+ assert str(exc.value) == "limit must be >= 0"
+
+ def test_set_filter(self):
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+
+ filter1 = RowFilterChain()
+ query = self._make_one()
+ assert query.filter is None
+ query.filter = filter1
+ assert query.filter == filter1
+ filter2 = RowFilterChain()
+ query.filter = filter2
+ assert query.filter == filter2
+ query.filter = None
+ assert query.filter is None
+ query.filter = RowFilterChain()
+ assert query.filter == RowFilterChain()
+
+ def test_set_limit(self):
+ query = self._make_one()
+ assert query.limit is None
+ query.limit = 10
+ assert query.limit == 10
+ query.limit = 9
+ assert query.limit == 9
+ query.limit = 0
+ assert query.limit is None
+ with pytest.raises(ValueError) as exc:
+ query.limit = -1
+ assert str(exc.value) == "limit must be >= 0"
+ with pytest.raises(ValueError) as exc:
+ query.limit = -100
+ assert str(exc.value) == "limit must be >= 0"
+
+ def test_add_key_str(self):
+ query = self._make_one()
+ assert query.row_keys == list()
+ input_str = "test_row"
+ query.add_key(input_str)
+ assert len(query.row_keys) == 1
+ assert input_str.encode() in query.row_keys
+ input_str2 = "test_row2"
+ query.add_key(input_str2)
+ assert len(query.row_keys) == 2
+ assert input_str.encode() in query.row_keys
+ assert input_str2.encode() in query.row_keys
+
+ def test_add_key_bytes(self):
+ query = self._make_one()
+ assert query.row_keys == list()
+ input_bytes = b"test_row"
+ query.add_key(input_bytes)
+ assert len(query.row_keys) == 1
+ assert input_bytes in query.row_keys
+ input_bytes2 = b"test_row2"
+ query.add_key(input_bytes2)
+ assert len(query.row_keys) == 2
+ assert input_bytes in query.row_keys
+ assert input_bytes2 in query.row_keys
+
+ def test_add_rows_batch(self):
+ query = self._make_one()
+ assert query.row_keys == list()
+ input_batch = ["test_row", b"test_row2", "test_row3"]
+ for k in input_batch:
+ query.add_key(k)
+ assert len(query.row_keys) == 3
+ assert b"test_row" in query.row_keys
+ assert b"test_row2" in query.row_keys
+ assert b"test_row3" in query.row_keys
+ # test adding another batch
+ for k in ["test_row4", b"test_row5"]:
+ query.add_key(k)
+ assert len(query.row_keys) == 5
+ assert input_batch[0].encode() in query.row_keys
+ assert input_batch[1] in query.row_keys
+ assert input_batch[2].encode() in query.row_keys
+ assert b"test_row4" in query.row_keys
+ assert b"test_row5" in query.row_keys
+
+ def test_add_key_invalid(self):
+ query = self._make_one()
+ with pytest.raises(ValueError) as exc:
+ query.add_key(1)
+ assert str(exc.value) == "row_key must be string or bytes"
+ with pytest.raises(ValueError) as exc:
+ query.add_key(["s"])
+ assert str(exc.value) == "row_key must be string or bytes"
+
+ def test_add_range(self):
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+ query = self._make_one()
+ assert query.row_ranges == list()
+ input_range = RowRange(start_key=b"test_row")
+ query.add_range(input_range)
+ assert len(query.row_ranges) == 1
+ assert input_range in query.row_ranges
+ input_range2 = RowRange(start_key=b"test_row2")
+ query.add_range(input_range2)
+ assert len(query.row_ranges) == 2
+ assert input_range in query.row_ranges
+ assert input_range2 in query.row_ranges
+
+ def _parse_query_string(self, query_string):
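+        """
+        Helper to build a ReadRowsQuery from a compact string description.
+
+        Segments are comma-separated. A segment containing "-" becomes a
+        RowRange: a leading "(" marks an exclusive start and a trailing ")"
+        an exclusive end, an empty side means unbounded, and the bracket
+        character is stripped from a non-empty bound before it is used.
+        Any other segment is added as a row key.
+        """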
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery, RowRange
+
+ query = ReadRowsQuery()
+ segments = query_string.split(",")
+ for segment in segments:
+ if "-" in segment:
+ start, end = segment.split("-")
+                s_open, e_open = True, True
+                if start == "":
+                    start = None
+                    s_open = None
+                else:
+                    if start[0] == "(":
+                        s_open = False
+                    start = start[1:]
+                if end == "":
+                    end = None
+                    e_open = None
+                else:
+                    if end[-1] == ")":
+                        e_open = False
+                    end = end[:-1]
+                query.add_range(RowRange(start, end, s_open, e_open))
+ else:
+ query.add_key(segment)
+ return query
+
+ @pytest.mark.parametrize(
+ "query_string,shard_points",
+ [
+ ("a,[p-q)", []),
+ ("0_key,[1_range_start-2_range_end)", ["3_split"]),
+ ("0_key,[1_range_start-2_range_end)", ["2_range_end"]),
+ ("0_key,[1_range_start-2_range_end]", ["2_range_end"]),
+ ("-1_range_end)", ["5_split"]),
+ ("8_key,(1_range_start-2_range_end]", ["1_range_start"]),
+ ("9_row_key,(5_range_start-7_range_end)", ["3_split"]),
+ ("3_row_key,(5_range_start-7_range_end)", ["2_row_key"]),
+ ("4_split,4_split,(3_split-5_split]", ["3_split", "5_split"]),
+ ("(3_split-", ["3_split"]),
+ ],
+ )
+ def test_shard_no_split(self, query_string, shard_points):
+ """
+ Test sharding with a set of queries that should not result in any splits.
+ """
+ initial_query = self._parse_query_string(query_string)
+ row_samples = [(point.encode(), None) for point in shard_points]
+ sharded_queries = initial_query.shard(row_samples)
+ assert len(sharded_queries) == 1
+ assert initial_query == sharded_queries[0]
+
+ def test_shard_full_table_scan_empty_split(self):
+ """
+ Sharding a full table scan with no split should return another full table scan.
+ """
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+ full_scan_query = ReadRowsQuery()
+ split_points = []
+ sharded_queries = full_scan_query.shard(split_points)
+ assert len(sharded_queries) == 1
+ result_query = sharded_queries[0]
+ assert result_query == full_scan_query
+
+ def test_shard_full_table_scan_with_split(self):
+ """
+ Test splitting a full table scan into two queries
+ """
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+ full_scan_query = ReadRowsQuery()
+ split_points = [(b"a", None)]
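+        # shard() takes row key samples as (key, offset) pairs, as produced by
+        # sample_row_keys; the offset is not used by these tests, so None suffices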
+ sharded_queries = full_scan_query.shard(split_points)
+ assert len(sharded_queries) == 2
+ assert sharded_queries[0] == self._parse_query_string("-a]")
+ assert sharded_queries[1] == self._parse_query_string("(a-")
+
+ def test_shard_full_table_scan_with_multiple_split(self):
+ """
+ Test splitting a full table scan into three queries
+ """
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+ full_scan_query = ReadRowsQuery()
+ split_points = [(b"a", None), (b"z", None)]
+ sharded_queries = full_scan_query.shard(split_points)
+ assert len(sharded_queries) == 3
+ assert sharded_queries[0] == self._parse_query_string("-a]")
+ assert sharded_queries[1] == self._parse_query_string("(a-z]")
+ assert sharded_queries[2] == self._parse_query_string("(z-")
+
+ def test_shard_multiple_keys(self):
+ """
+ Test splitting multiple individual keys into separate queries
+ """
+ initial_query = self._parse_query_string("1_beforeSplit,2_onSplit,3_afterSplit")
+ split_points = [(b"2_onSplit", None)]
+ sharded_queries = initial_query.shard(split_points)
+ assert len(sharded_queries) == 2
+ assert sharded_queries[0] == self._parse_query_string("1_beforeSplit,2_onSplit")
+ assert sharded_queries[1] == self._parse_query_string("3_afterSplit")
+
+ def test_shard_keys_empty_left(self):
+ """
+        Test where the left-most segment (before the first split point) has no keys
+ """
+ initial_query = self._parse_query_string("5_test,8_test")
+ split_points = [(b"0_split", None), (b"6_split", None)]
+ sharded_queries = initial_query.shard(split_points)
+ assert len(sharded_queries) == 2
+ assert sharded_queries[0] == self._parse_query_string("5_test")
+ assert sharded_queries[1] == self._parse_query_string("8_test")
+
+ def test_shard_keys_empty_right(self):
+ """
+        Test where the right-most segment (after the last split point) has no keys
+ """
+ initial_query = self._parse_query_string("0_test,2_test")
+ split_points = [(b"1_split", None), (b"5_split", None)]
+ sharded_queries = initial_query.shard(split_points)
+ assert len(sharded_queries) == 2
+ assert sharded_queries[0] == self._parse_query_string("0_test")
+ assert sharded_queries[1] == self._parse_query_string("2_test")
+
+ def test_shard_mixed_split(self):
+ """
+ Test splitting a complex query with multiple split points
+ """
+ initial_query = self._parse_query_string("0,a,c,-a],-b],(c-e],(d-f],(m-")
+ split_points = [(s.encode(), None) for s in ["a", "d", "j", "o"]]
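+        # the split points partition the key space into five segments:
+        # (-inf, a], (a, d], (d, j], (j, o], (o, +inf)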
+ sharded_queries = initial_query.shard(split_points)
+ assert len(sharded_queries) == 5
+ assert sharded_queries[0] == self._parse_query_string("0,a,-a]")
+ assert sharded_queries[1] == self._parse_query_string("c,(a-b],(c-d]")
+ assert sharded_queries[2] == self._parse_query_string("(d-e],(d-f]")
+ assert sharded_queries[3] == self._parse_query_string("(m-o]")
+ assert sharded_queries[4] == self._parse_query_string("(o-")
+
+ def test_shard_unsorted_request(self):
+ """
+        Test with a query whose row keys and row ranges are given in a random order
+ """
+ initial_query = self._parse_query_string(
+ "7_row_key_1,2_row_key_2,[8_range_1_start-9_range_1_end),[3_range_2_start-4_range_2_end)"
+ )
+ split_points = [(b"5-split", None)]
+ sharded_queries = initial_query.shard(split_points)
+ assert len(sharded_queries) == 2
+ assert sharded_queries[0] == self._parse_query_string(
+ "2_row_key_2,[3_range_2_start-4_range_2_end)"
+ )
+ assert sharded_queries[1] == self._parse_query_string(
+ "7_row_key_1,[8_range_1_start-9_range_1_end)"
+ )
+
+ @pytest.mark.parametrize(
+ "query_string,shard_points",
+ [
+ ("a,[p-q)", []),
+ ("0_key,[1_range_start-2_range_end)", ["3_split"]),
+ ("-1_range_end)", ["5_split"]),
+ ("0_key,[1_range_start-2_range_end)", ["2_range_end"]),
+ ("9_row_key,(5_range_start-7_range_end)", ["3_split"]),
+ ("(5_range_start-", ["3_split"]),
+ ("3_split,[3_split-5_split)", ["3_split", "5_split"]),
+ ("[3_split-", ["3_split"]),
+ ("", []),
+ ("", ["3_split"]),
+ ("", ["3_split", "5_split"]),
+ ("1,2,3,4,5,6,7,8,9", ["3_split"]),
+ ],
+ )
+ def test_shard_keeps_filter(self, query_string, shard_points):
+ """
+ sharded queries should keep the filter from the original query
+ """
+ initial_query = self._parse_query_string(query_string)
+ expected_filter = {"test": "filter"}
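+        # the filter is only carried along and compared for equality, so a
+        # plain dict stands in for a real RowFilter here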
+ initial_query.filter = expected_filter
+ row_samples = [(point.encode(), None) for point in shard_points]
+ sharded_queries = initial_query.shard(row_samples)
+ assert len(sharded_queries) > 0
+ for query in sharded_queries:
+ assert query.filter == expected_filter
+
+ def test_shard_limit_exception(self):
+ """
+ queries with a limit should raise an exception when a shard is attempted
+ """
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+ query = ReadRowsQuery(limit=10)
+ with pytest.raises(AttributeError) as e:
+ query.shard([])
+ assert "Cannot shard query with a limit" in str(e.value)
+
+ @pytest.mark.parametrize(
+ "first_args,second_args,expected",
+ [
+ ((), (), True),
+ ((), ("a",), False),
+ (("a",), (), False),
+ (("a",), ("a",), True),
+ ((["a"],), (["a", "b"],), False),
+ ((["a", "b"],), (["a", "b"],), True),
+ ((["a", b"b"],), ([b"a", "b"],), True),
+ (("a",), (b"a",), True),
+ (("a",), ("b",), False),
+ (("a",), ("a", ["b"]), False),
+ (("a", "b"), ("a", ["b"]), True),
+ (("a", ["b"]), ("a", ["b", "c"]), False),
+ (("a", ["b", "c"]), ("a", [b"b", "c"]), True),
+ (("a", ["b", "c"], 1), ("a", ["b", b"c"], 1), True),
+ (("a", ["b"], 1), ("a", ["b"], 2), False),
+ (("a", ["b"], 1, {"a": "b"}), ("a", ["b"], 1, {"a": "b"}), True),
+ (("a", ["b"], 1, {"a": "b"}), ("a", ["b"], 1), False),
+ (
+ (),
+ (None, [None], None, None),
+ True,
+ ), # empty query is equal to empty row range
+ ((), (None, [None], 1, None), False),
+ ((), (None, [None], None, {"a": "b"}), False),
+ ],
+ )
+ def test___eq__(self, first_args, second_args, expected):
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+ from google.cloud.bigtable.data.read_rows_query import RowRange
+
+        # args are positional: (row_keys, row_ranges, limit, row_filter); the
+        # row_ranges entries are start keys, replaced here with RowRange objects
+ if len(first_args) > 1:
+ first_args = list(first_args)
+ first_args[1] = [RowRange(c) for c in first_args[1]]
+ if len(second_args) > 1:
+ second_args = list(second_args)
+ second_args[1] = [RowRange(c) for c in second_args[1]]
+ first = ReadRowsQuery(*first_args)
+ second = ReadRowsQuery(*second_args)
+ assert (first == second) == expected
+
+ def test___repr__(self):
+ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+ instance = self._make_one(row_keys=["a", "b"], row_filter={}, limit=10)
+ # should be able to recreate the instance from the repr
+ repr_str = repr(instance)
+ recreated = eval(repr_str)
+ assert isinstance(recreated, ReadRowsQuery)
+ assert recreated == instance
+
+ def test_empty_row_set(self):
+        """Empty strings should be treated as key inputs"""
+ query = self._make_one(row_keys="")
+ assert query.row_keys == [b""]
diff --git a/tests/unit/data/test_row.py b/tests/unit/data/test_row.py
new file mode 100644
index 000000000..10b5bdb23
--- /dev/null
+++ b/tests/unit/data/test_row.py
@@ -0,0 +1,718 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import time
+
+TEST_VALUE = b"1234"
+TEST_ROW_KEY = b"row"
+TEST_FAMILY_ID = "cf1"
+TEST_QUALIFIER = b"col"
+TEST_TIMESTAMP = time.time_ns() // 1000
+TEST_LABELS = ["label1", "label2"]
+
+
+class TestRow(unittest.TestCase):
+ @staticmethod
+ def _get_target_class():
+ from google.cloud.bigtable.data.row import Row
+
+ return Row
+
+ def _make_one(self, *args, **kwargs):
+ if len(args) == 0:
+ args = (TEST_ROW_KEY, [self._make_cell()])
+ return self._get_target_class()(*args, **kwargs)
+
+ def _make_cell(
+ self,
+ value=TEST_VALUE,
+ row_key=TEST_ROW_KEY,
+ family_id=TEST_FAMILY_ID,
+ qualifier=TEST_QUALIFIER,
+ timestamp=TEST_TIMESTAMP,
+ labels=TEST_LABELS,
+ ):
+ from google.cloud.bigtable.data.row import Cell
+
+ return Cell(value, row_key, family_id, qualifier, timestamp, labels)
+
+ def test_ctor(self):
+ cells = [self._make_cell(), self._make_cell()]
+ row_response = self._make_one(TEST_ROW_KEY, cells)
+ self.assertEqual(list(row_response), cells)
+ self.assertEqual(row_response.row_key, TEST_ROW_KEY)
+
+ def test__from_pb(self):
+ """
+ Construct from protobuf.
+ """
+ from google.cloud.bigtable_v2.types import Row as RowPB
+ from google.cloud.bigtable_v2.types import Family as FamilyPB
+ from google.cloud.bigtable_v2.types import Column as ColumnPB
+ from google.cloud.bigtable_v2.types import Cell as CellPB
+
+ row_key = b"row_key"
+ cells = [
+ CellPB(
+ value=str(i).encode(),
+ timestamp_micros=TEST_TIMESTAMP,
+ labels=TEST_LABELS,
+ )
+ for i in range(2)
+ ]
+ column = ColumnPB(qualifier=TEST_QUALIFIER, cells=cells)
+ families_pb = [FamilyPB(name=TEST_FAMILY_ID, columns=[column])]
+ row_pb = RowPB(key=row_key, families=families_pb)
+ output = self._get_target_class()._from_pb(row_pb)
+ self.assertEqual(output.row_key, row_key)
+ self.assertEqual(len(output), 2)
+ self.assertEqual(output[0].value, b"0")
+ self.assertEqual(output[1].value, b"1")
+ self.assertEqual(output[0].timestamp_micros, TEST_TIMESTAMP)
+ self.assertEqual(output[0].labels, TEST_LABELS)
+ assert output[0].row_key == row_key
+ assert output[0].family == TEST_FAMILY_ID
+ assert output[0].qualifier == TEST_QUALIFIER
+
+ def test__from_pb_sparse(self):
+ """
+ Construct from minimal protobuf.
+ """
+ from google.cloud.bigtable_v2.types import Row as RowPB
+
+ row_key = b"row_key"
+ row_pb = RowPB(key=row_key)
+ output = self._get_target_class()._from_pb(row_pb)
+ self.assertEqual(output.row_key, row_key)
+ self.assertEqual(len(output), 0)
+
+ def test_get_cells(self):
+ cell_list = []
+ for family_id in ["1", "2"]:
+ for qualifier in [b"a", b"b"]:
+ cell = self._make_cell(family_id=family_id, qualifier=qualifier)
+ cell_list.append(cell)
+ # test getting all cells
+ row_response = self._make_one(TEST_ROW_KEY, cell_list)
+ self.assertEqual(row_response.get_cells(), cell_list)
+ # test getting cells in a family
+ output = row_response.get_cells(family="1")
+ self.assertEqual(len(output), 2)
+ self.assertEqual(output[0].family, "1")
+ self.assertEqual(output[1].family, "1")
+ self.assertEqual(output[0], cell_list[0])
+ # test getting cells in a family/qualifier
+ # should accept bytes or str for qualifier
+ for q in [b"a", "a"]:
+ output = row_response.get_cells(family="1", qualifier=q)
+ self.assertEqual(len(output), 1)
+ self.assertEqual(output[0].family, "1")
+ self.assertEqual(output[0].qualifier, b"a")
+ self.assertEqual(output[0], cell_list[0])
+ # calling with just qualifier should raise an error
+ with self.assertRaises(ValueError):
+ row_response.get_cells(qualifier=b"a")
+ # test calling with bad family or qualifier
+ with self.assertRaises(ValueError):
+ row_response.get_cells(family="3", qualifier=b"a")
+ with self.assertRaises(ValueError):
+ row_response.get_cells(family="3")
+ with self.assertRaises(ValueError):
+ row_response.get_cells(family="1", qualifier=b"c")
+
+ def test___repr__(self):
+ cell_str = (
+ "{'value': b'1234', 'timestamp_micros': %d, 'labels': ['label1', 'label2']}"
+ % (TEST_TIMESTAMP)
+ )
+ expected_prefix = "Row(key=b'row', cells="
+ row = self._make_one(TEST_ROW_KEY, [self._make_cell()])
+ self.assertIn(expected_prefix, repr(row))
+ self.assertIn(cell_str, repr(row))
+ expected_full = (
+ "Row(key=b'row', cells={\n ('cf1', b'col'): [{'value': b'1234', 'timestamp_micros': %d, 'labels': ['label1', 'label2']}],\n})"
+ % (TEST_TIMESTAMP)
+ )
+ self.assertEqual(expected_full, repr(row))
+ # try with multiple cells
+ row = self._make_one(TEST_ROW_KEY, [self._make_cell(), self._make_cell()])
+ self.assertIn(expected_prefix, repr(row))
+ self.assertIn(cell_str, repr(row))
+
+ def test___str__(self):
+ cells = [
+ self._make_cell(value=b"1234", family_id="1", qualifier=b"col"),
+ self._make_cell(value=b"5678", family_id="3", qualifier=b"col"),
+ self._make_cell(value=b"1", family_id="3", qualifier=b"col"),
+ self._make_cell(value=b"2", family_id="3", qualifier=b"col"),
+ ]
+
+ row_response = self._make_one(TEST_ROW_KEY, cells)
+ expected = (
+ "{\n"
+ + " (family='1', qualifier=b'col'): [b'1234'],\n"
+ + " (family='3', qualifier=b'col'): [b'5678', (+2 more)],\n"
+ + "}"
+ )
+ self.assertEqual(expected, str(row_response))
+
+ def test_to_dict(self):
+ from google.cloud.bigtable_v2.types import Row
+
+ cell1 = self._make_cell()
+ cell2 = self._make_cell()
+ cell2.value = b"other"
+ row = self._make_one(TEST_ROW_KEY, [cell1, cell2])
+ row_dict = row._to_dict()
+ expected_dict = {
+ "key": TEST_ROW_KEY,
+ "families": [
+ {
+ "name": TEST_FAMILY_ID,
+ "columns": [
+ {
+ "qualifier": TEST_QUALIFIER,
+ "cells": [
+ {
+ "value": TEST_VALUE,
+ "timestamp_micros": TEST_TIMESTAMP,
+ "labels": TEST_LABELS,
+ },
+ {
+ "value": b"other",
+ "timestamp_micros": TEST_TIMESTAMP,
+ "labels": TEST_LABELS,
+ },
+ ],
+ }
+ ],
+ },
+ ],
+ }
+ self.assertEqual(len(row_dict), len(expected_dict))
+ for key, value in expected_dict.items():
+ self.assertEqual(row_dict[key], value)
+        # should be able to construct a Row proto from the dict
+ row_proto = Row(**row_dict)
+ self.assertEqual(row_proto.key, TEST_ROW_KEY)
+ self.assertEqual(len(row_proto.families), 1)
+ family = row_proto.families[0]
+ self.assertEqual(family.name, TEST_FAMILY_ID)
+ self.assertEqual(len(family.columns), 1)
+ column = family.columns[0]
+ self.assertEqual(column.qualifier, TEST_QUALIFIER)
+ self.assertEqual(len(column.cells), 2)
+ self.assertEqual(column.cells[0].value, TEST_VALUE)
+ self.assertEqual(column.cells[0].timestamp_micros, TEST_TIMESTAMP)
+ self.assertEqual(column.cells[0].labels, TEST_LABELS)
+ self.assertEqual(column.cells[1].value, cell2.value)
+ self.assertEqual(column.cells[1].timestamp_micros, TEST_TIMESTAMP)
+ self.assertEqual(column.cells[1].labels, TEST_LABELS)
+
+ def test_iteration(self):
+ from google.cloud.bigtable.data.row import Cell
+
+ # should be able to iterate over the Row as a list
+ cell1 = self._make_cell(value=b"1")
+ cell2 = self._make_cell(value=b"2")
+ cell3 = self._make_cell(value=b"3")
+ row_response = self._make_one(TEST_ROW_KEY, [cell1, cell2, cell3])
+ self.assertEqual(len(row_response), 3)
+ result_list = list(row_response)
+ self.assertEqual(len(result_list), 3)
+ # should be able to iterate over all cells
+ idx = 0
+ for cell in row_response:
+ self.assertIsInstance(cell, Cell)
+ self.assertEqual(cell.value, result_list[idx].value)
+ self.assertEqual(cell.value, str(idx + 1).encode())
+ idx += 1
+
+ def test_contains_cell(self):
+ cell3 = self._make_cell(value=b"3")
+ cell1 = self._make_cell(value=b"1")
+ cell2 = self._make_cell(value=b"2")
+ cell4 = self._make_cell(value=b"4")
+ row_response = self._make_one(TEST_ROW_KEY, [cell3, cell1, cell2])
+ self.assertIn(cell1, row_response)
+ self.assertIn(cell2, row_response)
+ self.assertNotIn(cell4, row_response)
+ cell3_copy = self._make_cell(value=b"3")
+ self.assertIn(cell3_copy, row_response)
+
+ def test_contains_family_id(self):
+ new_family_id = "new_family_id"
+ cell = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell2 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ new_family_id,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ row_response = self._make_one(TEST_ROW_KEY, [cell, cell2])
+ self.assertIn(TEST_FAMILY_ID, row_response)
+ self.assertIn("new_family_id", row_response)
+ self.assertIn(new_family_id, row_response)
+ self.assertNotIn("not_a_family_id", row_response)
+ self.assertNotIn(None, row_response)
+
+ def test_contains_family_qualifier_tuple(self):
+ new_family_id = "new_family_id"
+ new_qualifier = b"new_qualifier"
+ cell = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell2 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ new_family_id,
+ new_qualifier,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ row_response = self._make_one(TEST_ROW_KEY, [cell, cell2])
+ self.assertIn((TEST_FAMILY_ID, TEST_QUALIFIER), row_response)
+ self.assertIn(("new_family_id", "new_qualifier"), row_response)
+ self.assertIn(("new_family_id", b"new_qualifier"), row_response)
+ self.assertIn((new_family_id, new_qualifier), row_response)
+
+ self.assertNotIn(("not_a_family_id", TEST_QUALIFIER), row_response)
+ self.assertNotIn((TEST_FAMILY_ID, "not_a_qualifier"), row_response)
+ self.assertNotIn((TEST_FAMILY_ID, new_qualifier), row_response)
+ self.assertNotIn(("not_a_family_id", "not_a_qualifier"), row_response)
+ self.assertNotIn((None, None), row_response)
+ self.assertNotIn(None, row_response)
+
+ def test_int_indexing(self):
+ # should be able to index into underlying list with an index number directly
+ cell_list = [self._make_cell(value=str(i).encode()) for i in range(10)]
+ sorted(cell_list)
+ row_response = self._make_one(TEST_ROW_KEY, cell_list)
+ self.assertEqual(len(row_response), 10)
+ for i in range(10):
+ self.assertEqual(row_response[i].value, str(i).encode())
+ # backwards indexing should work
+ self.assertEqual(row_response[-i - 1].value, str(9 - i).encode())
+ with self.assertRaises(IndexError):
+ row_response[10]
+ with self.assertRaises(IndexError):
+ row_response[-11]
+
+ def test_slice_indexing(self):
+ # should be able to index with a range of indices
+ cell_list = [self._make_cell(value=str(i).encode()) for i in range(10)]
+ sorted(cell_list)
+ row_response = self._make_one(TEST_ROW_KEY, cell_list)
+ self.assertEqual(len(row_response), 10)
+ self.assertEqual(len(row_response[0:10]), 10)
+ self.assertEqual(row_response[0:10], cell_list)
+ self.assertEqual(len(row_response[0:]), 10)
+ self.assertEqual(row_response[0:], cell_list)
+ self.assertEqual(len(row_response[:10]), 10)
+ self.assertEqual(row_response[:10], cell_list)
+ self.assertEqual(len(row_response[0:10:1]), 10)
+ self.assertEqual(row_response[0:10:1], cell_list)
+ self.assertEqual(len(row_response[0:10:2]), 5)
+ self.assertEqual(row_response[0:10:2], [cell_list[i] for i in range(0, 10, 2)])
+ self.assertEqual(len(row_response[0:10:3]), 4)
+ self.assertEqual(row_response[0:10:3], [cell_list[i] for i in range(0, 10, 3)])
+ self.assertEqual(len(row_response[10:0:-1]), 9)
+ self.assertEqual(len(row_response[10:0:-2]), 5)
+ self.assertEqual(row_response[10:0:-3], cell_list[10:0:-3])
+ self.assertEqual(len(row_response[0:100]), 10)
+
+ def test_family_indexing(self):
+ # should be able to retrieve cells in a family
+ new_family_id = "new_family_id"
+ cell = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell2 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell3 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ new_family_id,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ row_response = self._make_one(TEST_ROW_KEY, [cell, cell2, cell3])
+
+ self.assertEqual(len(row_response[TEST_FAMILY_ID]), 2)
+ self.assertEqual(row_response[TEST_FAMILY_ID][0], cell)
+ self.assertEqual(row_response[TEST_FAMILY_ID][1], cell2)
+ self.assertEqual(len(row_response[new_family_id]), 1)
+ self.assertEqual(row_response[new_family_id][0], cell3)
+ with self.assertRaises(ValueError):
+ row_response["not_a_family_id"]
+ with self.assertRaises(TypeError):
+ row_response[None]
+ with self.assertRaises(TypeError):
+ row_response[b"new_family_id"]
+
+ def test_family_qualifier_indexing(self):
+        # should be able to retrieve cells in a family/qualifier tuple
+ new_family_id = "new_family_id"
+ new_qualifier = b"new_qualifier"
+ cell = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell2 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell3 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ new_family_id,
+ new_qualifier,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ row_response = self._make_one(TEST_ROW_KEY, [cell, cell2, cell3])
+
+ self.assertEqual(len(row_response[TEST_FAMILY_ID, TEST_QUALIFIER]), 2)
+ self.assertEqual(row_response[TEST_FAMILY_ID, TEST_QUALIFIER][0], cell)
+ self.assertEqual(row_response[TEST_FAMILY_ID, TEST_QUALIFIER][1], cell2)
+ self.assertEqual(len(row_response[new_family_id, new_qualifier]), 1)
+ self.assertEqual(row_response[new_family_id, new_qualifier][0], cell3)
+ self.assertEqual(len(row_response["new_family_id", "new_qualifier"]), 1)
+ self.assertEqual(len(row_response["new_family_id", b"new_qualifier"]), 1)
+ with self.assertRaises(ValueError):
+ row_response[new_family_id, "not_a_qualifier"]
+ with self.assertRaises(ValueError):
+ row_response["not_a_family_id", new_qualifier]
+ with self.assertRaises(TypeError):
+ row_response[None, None]
+ with self.assertRaises(TypeError):
+ row_response[b"new_family_id", b"new_qualifier"]
+
+ def test_get_column_components(self):
+ # should be able to retrieve (family,qualifier) tuples as keys
+ new_family_id = "new_family_id"
+ new_qualifier = b"new_qualifier"
+ cell = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell2 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ cell3 = self._make_cell(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ new_family_id,
+ new_qualifier,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ row_response = self._make_one(TEST_ROW_KEY, [cell, cell2, cell3])
+
+ self.assertEqual(len(row_response._get_column_components()), 2)
+ self.assertEqual(
+ row_response._get_column_components(),
+ [(TEST_FAMILY_ID, TEST_QUALIFIER), (new_family_id, new_qualifier)],
+ )
+
+ row_response = self._make_one(TEST_ROW_KEY, [])
+ self.assertEqual(len(row_response._get_column_components()), 0)
+ self.assertEqual(row_response._get_column_components(), [])
+
+ row_response = self._make_one(TEST_ROW_KEY, [cell])
+ self.assertEqual(len(row_response._get_column_components()), 1)
+ self.assertEqual(
+ row_response._get_column_components(), [(TEST_FAMILY_ID, TEST_QUALIFIER)]
+ )
+
+
+class TestCell(unittest.TestCase):
+ @staticmethod
+ def _get_target_class():
+ from google.cloud.bigtable.data.row import Cell
+
+ return Cell
+
+ def _make_one(self, *args, **kwargs):
+ if len(args) == 0:
+ args = (
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ return self._get_target_class()(*args, **kwargs)
+
+ def test_ctor(self):
+ cell = self._make_one(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ self.assertEqual(cell.value, TEST_VALUE)
+ self.assertEqual(cell.row_key, TEST_ROW_KEY)
+ self.assertEqual(cell.family, TEST_FAMILY_ID)
+ self.assertEqual(cell.qualifier, TEST_QUALIFIER)
+ self.assertEqual(cell.timestamp_micros, TEST_TIMESTAMP)
+ self.assertEqual(cell.labels, TEST_LABELS)
+
+ def test_to_dict(self):
+ from google.cloud.bigtable_v2.types import Cell
+
+ cell = self._make_one()
+ cell_dict = cell._to_dict()
+ expected_dict = {
+ "value": TEST_VALUE,
+ "timestamp_micros": TEST_TIMESTAMP,
+ "labels": TEST_LABELS,
+ }
+ self.assertEqual(len(cell_dict), len(expected_dict))
+ for key, value in expected_dict.items():
+ self.assertEqual(cell_dict[key], value)
+ # should be able to construct a Cell proto from the dict
+ cell_proto = Cell(**cell_dict)
+ self.assertEqual(cell_proto.value, TEST_VALUE)
+ self.assertEqual(cell_proto.timestamp_micros, TEST_TIMESTAMP)
+ self.assertEqual(cell_proto.labels, TEST_LABELS)
+
+ def test_to_dict_no_labels(self):
+ from google.cloud.bigtable_v2.types import Cell
+
+ cell_no_labels = self._make_one(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ None,
+ )
+ cell_dict = cell_no_labels._to_dict()
+ expected_dict = {
+ "value": TEST_VALUE,
+ "timestamp_micros": TEST_TIMESTAMP,
+ }
+ self.assertEqual(len(cell_dict), len(expected_dict))
+ for key, value in expected_dict.items():
+ self.assertEqual(cell_dict[key], value)
+ # should be able to construct a Cell proto from the dict
+ cell_proto = Cell(**cell_dict)
+ self.assertEqual(cell_proto.value, TEST_VALUE)
+ self.assertEqual(cell_proto.timestamp_micros, TEST_TIMESTAMP)
+ self.assertEqual(cell_proto.labels, [])
+
+ def test_int_value(self):
+ test_int = 1234
+ bytes_value = test_int.to_bytes(4, "big", signed=True)
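+        # Cell.__int__ decodes the value as a big-endian signed integer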
+ cell = self._make_one(
+ bytes_value,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ self.assertEqual(int(cell), test_int)
+ # ensure string formatting works
+ formatted = "%d" % cell
+ self.assertEqual(formatted, str(test_int))
+ self.assertEqual(int(formatted), test_int)
+
+ def test_int_value_negative(self):
+ test_int = -99999
+ bytes_value = test_int.to_bytes(4, "big", signed=True)
+ cell = self._make_one(
+ bytes_value,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ self.assertEqual(int(cell), test_int)
+ # ensure string formatting works
+ formatted = "%d" % cell
+ self.assertEqual(formatted, str(test_int))
+ self.assertEqual(int(formatted), test_int)
+
+ def test___str__(self):
+ test_value = b"helloworld"
+ cell = self._make_one(
+ test_value,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ self.assertEqual(str(cell), "b'helloworld'")
+ self.assertEqual(str(cell), str(test_value))
+
+ def test___repr__(self):
+ from google.cloud.bigtable.data.row import Cell # type: ignore # noqa: F401
+
+ cell = self._make_one()
+ expected = (
+ "Cell(value=b'1234', row_key=b'row', "
+ + "family='cf1', qualifier=b'col', "
+ + f"timestamp_micros={TEST_TIMESTAMP}, labels=['label1', 'label2'])"
+ )
+ self.assertEqual(repr(cell), expected)
+ # should be able to construct instance from __repr__
+ result = eval(repr(cell))
+ self.assertEqual(result, cell)
+
+ def test___repr___no_labels(self):
+ from google.cloud.bigtable.data.row import Cell # type: ignore # noqa: F401
+
+ cell_no_labels = self._make_one(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ None,
+ )
+ expected = (
+ "Cell(value=b'1234', row_key=b'row', "
+ + "family='cf1', qualifier=b'col', "
+ + f"timestamp_micros={TEST_TIMESTAMP}, labels=[])"
+ )
+ self.assertEqual(repr(cell_no_labels), expected)
+ # should be able to construct instance from __repr__
+ result = eval(repr(cell_no_labels))
+ self.assertEqual(result, cell_no_labels)
+
+ def test_equality(self):
+ cell1 = self._make_one()
+ cell2 = self._make_one()
+ self.assertEqual(cell1, cell2)
+ self.assertTrue(cell1 == cell2)
+ args = (
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ for i in range(0, len(args)):
+            # perturb each argument in turn (args[i] + args[i]) so only that field differs
+ modified_cell = self._make_one(*args[:i], args[i] + args[i], *args[i + 1 :])
+ self.assertNotEqual(cell1, modified_cell)
+ self.assertFalse(cell1 == modified_cell)
+ self.assertTrue(cell1 != modified_cell)
+
+ def test_hash(self):
+ # class should be hashable
+ cell1 = self._make_one()
+ d = {cell1: 1}
+ cell2 = self._make_one()
+ self.assertEqual(d[cell2], 1)
+
+ args = (
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ TEST_FAMILY_ID,
+ TEST_QUALIFIER,
+ TEST_TIMESTAMP,
+ TEST_LABELS,
+ )
+ for i in range(0, len(args)):
+            # perturb each argument in turn (args[i] + args[i]) so only that field differs
+ modified_cell = self._make_one(*args[:i], args[i] + args[i], *args[i + 1 :])
+ with self.assertRaises(KeyError):
+ d[modified_cell]
+
+ def test_ordering(self):
+        # create cells from the highest sort position to the lowest; each new
+        # cell should sort before every cell created before it
+ higher_cells = []
+ i = 0
+        # families; alphabetical order
+ for family in ["z", "y", "x"]:
+ # qualifiers; lowest byte value first
+ for qualifier in [b"z", b"y", b"x"]:
+ # timestamps; newest first
+ for timestamp in [
+ TEST_TIMESTAMP,
+ TEST_TIMESTAMP + 1,
+ TEST_TIMESTAMP + 2,
+ ]:
+ cell = self._make_one(
+ TEST_VALUE,
+ TEST_ROW_KEY,
+ family,
+ qualifier,
+ timestamp,
+ TEST_LABELS,
+ )
+ # cell should be the highest priority encountered so far
+ self.assertEqual(i, len(higher_cells))
+ i += 1
+ for other in higher_cells:
+ self.assertLess(cell, other)
+ higher_cells.append(cell)
+ # final order should be reverse of sorted order
+ expected_order = higher_cells
+ expected_order.reverse()
+ self.assertEqual(expected_order, sorted(higher_cells))
diff --git a/tests/unit/data/test_row_filters.py b/tests/unit/data/test_row_filters.py
new file mode 100644
index 000000000..e90b6f270
--- /dev/null
+++ b/tests/unit/data/test_row_filters.py
@@ -0,0 +1,2039 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+
+def test_abstract_class_constructors():
+ from google.cloud.bigtable.data.row_filters import RowFilter
+ from google.cloud.bigtable.data.row_filters import _BoolFilter
+ from google.cloud.bigtable.data.row_filters import _FilterCombination
+ from google.cloud.bigtable.data.row_filters import _CellCountFilter
+
+ with pytest.raises(TypeError):
+ RowFilter()
+ with pytest.raises(TypeError):
+ _BoolFilter(False)
+ with pytest.raises(TypeError):
+ _FilterCombination([])
+ with pytest.raises(TypeError):
+ _CellCountFilter(0)
+
+
+def test_bool_filter_constructor():
+ for FilterType in _get_bool_filters():
+ flag = True
+ row_filter = FilterType(flag)
+ assert row_filter.flag is flag
+
+
+def test_bool_filter___eq__type_differ():
+ for FilterType in _get_bool_filters():
+ flag = object()
+ row_filter1 = FilterType(flag)
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_bool_filter___eq__same_value():
+ for FilterType in _get_bool_filters():
+ flag = object()
+ row_filter1 = FilterType(flag)
+ row_filter2 = FilterType(flag)
+ assert row_filter1 == row_filter2
+
+
+def test_bool_filter___ne__same_value():
+ for FilterType in _get_bool_filters():
+ flag = object()
+ row_filter1 = FilterType(flag)
+ row_filter2 = FilterType(flag)
+ assert not (row_filter1 != row_filter2)
+
+
+def test_sink_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import SinkFilter
+
+ flag = True
+ row_filter = SinkFilter(flag)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(sink=flag)
+ assert pb_val == expected_pb
+
+
+def test_sink_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import SinkFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ flag = True
+ row_filter = SinkFilter(flag)
+ expected_dict = {"sink": flag}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_sink_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import SinkFilter
+
+ flag = True
+ row_filter = SinkFilter(flag)
+ assert repr(row_filter) == "SinkFilter(flag={})".format(flag)
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_pass_all_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import PassAllFilter
+
+ flag = True
+ row_filter = PassAllFilter(flag)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(pass_all_filter=flag)
+ assert pb_val == expected_pb
+
+
+def test_pass_all_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import PassAllFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ flag = True
+ row_filter = PassAllFilter(flag)
+ expected_dict = {"pass_all_filter": flag}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_pass_all_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import PassAllFilter
+
+ flag = True
+ row_filter = PassAllFilter(flag)
+ assert repr(row_filter) == "PassAllFilter(flag={})".format(flag)
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_block_all_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import BlockAllFilter
+
+ flag = True
+ row_filter = BlockAllFilter(flag)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(block_all_filter=flag)
+ assert pb_val == expected_pb
+
+
+def test_block_all_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import BlockAllFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ flag = True
+ row_filter = BlockAllFilter(flag)
+ expected_dict = {"block_all_filter": flag}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_block_all_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import BlockAllFilter
+
+ flag = True
+ row_filter = BlockAllFilter(flag)
+ assert repr(row_filter) == "BlockAllFilter(flag={})".format(flag)
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_regex_filterconstructor():
+ for FilterType in _get_regex_filters():
+ regex = b"abc"
+ row_filter = FilterType(regex)
+ assert row_filter.regex == regex
+
+
+def test_regex_filterconstructor_non_bytes():
+ for FilterType in _get_regex_filters():
+ regex = "abc"
+ row_filter = FilterType(regex)
+ assert row_filter.regex == b"abc"
+
+
+def test_regex_filter__eq__type_differ():
+ for FilterType in _get_regex_filters():
+ regex = b"def-rgx"
+ row_filter1 = FilterType(regex)
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_regex_filter__eq__same_value():
+ for FilterType in _get_regex_filters():
+ regex = b"trex-regex"
+ row_filter1 = FilterType(regex)
+ row_filter2 = FilterType(regex)
+ assert row_filter1 == row_filter2
+
+
+def test_regex_filter__ne__same_value():
+ for FilterType in _get_regex_filters():
+ regex = b"abc"
+ row_filter1 = FilterType(regex)
+ row_filter2 = FilterType(regex)
+ assert not (row_filter1 != row_filter2)
+
+
+def test_row_key_regex_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import RowKeyRegexFilter
+
+ regex = b"row-key-regex"
+ row_filter = RowKeyRegexFilter(regex)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(row_key_regex_filter=regex)
+ assert pb_val == expected_pb
+
+
+def test_row_key_regex_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import RowKeyRegexFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ regex = b"row-key-regex"
+ row_filter = RowKeyRegexFilter(regex)
+ expected_dict = {"row_key_regex_filter": regex}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_row_key_regex_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import RowKeyRegexFilter
+
+ regex = b"row-key-regex"
+ row_filter = RowKeyRegexFilter(regex)
+ assert repr(row_filter) == "RowKeyRegexFilter(regex={})".format(regex)
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_row_sample_filter_constructor():
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+
+ sample = object()
+ row_filter = RowSampleFilter(sample)
+ assert row_filter.sample is sample
+
+
+def test_row_sample_filter___eq__type_differ():
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+
+ sample = object()
+ row_filter1 = RowSampleFilter(sample)
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_row_sample_filter___eq__same_value():
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+
+ sample = object()
+ row_filter1 = RowSampleFilter(sample)
+ row_filter2 = RowSampleFilter(sample)
+ assert row_filter1 == row_filter2
+
+
+def test_row_sample_filter___ne__():
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+
+ sample = object()
+ other_sample = object()
+ row_filter1 = RowSampleFilter(sample)
+ row_filter2 = RowSampleFilter(other_sample)
+ assert row_filter1 != row_filter2
+
+
+def test_row_sample_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+
+ sample = 0.25
+ row_filter = RowSampleFilter(sample)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(row_sample_filter=sample)
+ assert pb_val == expected_pb
+
+
+def test_row_sample_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+
+ sample = 0.25
+ row_filter = RowSampleFilter(sample)
+ assert repr(row_filter) == "RowSampleFilter(sample={})".format(sample)
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_family_name_regex_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import FamilyNameRegexFilter
+
+ regex = "family-regex"
+ row_filter = FamilyNameRegexFilter(regex)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(family_name_regex_filter=regex)
+ assert pb_val == expected_pb
+
+
+def test_family_name_regex_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import FamilyNameRegexFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ regex = "family-regex"
+ row_filter = FamilyNameRegexFilter(regex)
+ expected_dict = {"family_name_regex_filter": regex.encode()}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_family_name_regex_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import FamilyNameRegexFilter
+
+ regex = "family-regex"
+ row_filter = FamilyNameRegexFilter(regex)
+ expected = "FamilyNameRegexFilter(regex=b'family-regex')"
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_column_qualifier_regex_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import ColumnQualifierRegexFilter
+
+ regex = b"column-regex"
+ row_filter = ColumnQualifierRegexFilter(regex)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(column_qualifier_regex_filter=regex)
+ assert pb_val == expected_pb
+
+
+def test_column_qualifier_regex_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import ColumnQualifierRegexFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ regex = b"column-regex"
+ row_filter = ColumnQualifierRegexFilter(regex)
+ expected_dict = {"column_qualifier_regex_filter": regex}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_column_qualifier_regex_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import ColumnQualifierRegexFilter
+
+ regex = b"column-regex"
+ row_filter = ColumnQualifierRegexFilter(regex)
+ assert repr(row_filter) == "ColumnQualifierRegexFilter(regex={})".format(regex)
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_timestamp_range_constructor():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+
+ start = object()
+ end = object()
+ time_range = TimestampRange(start=start, end=end)
+ assert time_range.start is start
+ assert time_range.end is end
+
+
+def test_timestamp_range___eq__():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+
+ start = object()
+ end = object()
+ time_range1 = TimestampRange(start=start, end=end)
+ time_range2 = TimestampRange(start=start, end=end)
+ assert time_range1 == time_range2
+
+
+def test_timestamp_range___eq__type_differ():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+
+ start = object()
+ end = object()
+ time_range1 = TimestampRange(start=start, end=end)
+ time_range2 = object()
+ assert not (time_range1 == time_range2)
+
+
+def test_timestamp_range___ne__same_value():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+
+ start = object()
+ end = object()
+ time_range1 = TimestampRange(start=start, end=end)
+ time_range2 = TimestampRange(start=start, end=end)
+ assert not (time_range1 != time_range2)
+
+
+def _timestamp_range_to_pb_helper(pb_kwargs, start=None, end=None):
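+    # Build a TimestampRange from microsecond offsets relative to the epoch and
+    # check that its protobuf form matches the expected _TimestampRangePB kwargs.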
+ import datetime
+ from google.cloud._helpers import _EPOCH
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+
+ if start is not None:
+ start = _EPOCH + datetime.timedelta(microseconds=start)
+ if end is not None:
+ end = _EPOCH + datetime.timedelta(microseconds=end)
+ time_range = TimestampRange(start=start, end=end)
+ expected_pb = _TimestampRangePB(**pb_kwargs)
+ time_pb = time_range._to_pb()
+ assert time_pb.start_timestamp_micros == expected_pb.start_timestamp_micros
+ assert time_pb.end_timestamp_micros == expected_pb.end_timestamp_micros
+ assert time_pb == expected_pb
+
+
+def test_timestamp_range_to_pb():
+ start_micros = 30871234
+ end_micros = 12939371234
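+    # the expected pb truncates the start down to the millisecond and rounds
+    # the end up to the next millisecond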
+ start_millis = start_micros // 1000 * 1000
+ assert start_millis == 30871000
+ end_millis = end_micros // 1000 * 1000 + 1000
+ assert end_millis == 12939372000
+ pb_kwargs = {}
+ pb_kwargs["start_timestamp_micros"] = start_millis
+ pb_kwargs["end_timestamp_micros"] = end_millis
+ _timestamp_range_to_pb_helper(pb_kwargs, start=start_micros, end=end_micros)
+
+
+def test_timestamp_range_to_dict():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+ import datetime
+
+ row_filter = TimestampRange(
+ start=datetime.datetime(2019, 1, 1), end=datetime.datetime(2019, 1, 2)
+ )
+ expected_dict = {
+ "start_timestamp_micros": 1546300800000000,
+ "end_timestamp_micros": 1546387200000000,
+ }
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value
+
+
+def test_timestamp_range_to_pb_start_only():
+    # Make sure a start value already at millisecond granularity is unchanged
+ start_micros = 30871000
+ start_millis = start_micros // 1000 * 1000
+ assert start_millis == 30871000
+ pb_kwargs = {}
+ pb_kwargs["start_timestamp_micros"] = start_millis
+ _timestamp_range_to_pb_helper(pb_kwargs, start=start_micros, end=None)
+
+
+def test_timestamp_range_to_dict_start_only():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+ import datetime
+
+ row_filter = TimestampRange(start=datetime.datetime(2019, 1, 1))
+ expected_dict = {"start_timestamp_micros": 1546300800000000}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value
+
+
+def test_timestamp_range_to_pb_end_only():
+    # Make sure an end value already at millisecond granularity is unchanged
+ end_micros = 12939371000
+ end_millis = end_micros // 1000 * 1000
+ assert end_millis == 12939371000
+ pb_kwargs = {}
+ pb_kwargs["end_timestamp_micros"] = end_millis
+ _timestamp_range_to_pb_helper(pb_kwargs, start=None, end=end_micros)
+
+
+def test_timestamp_range_to_dict_end_only():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+ import datetime
+
+ row_filter = TimestampRange(end=datetime.datetime(2019, 1, 2))
+ expected_dict = {"end_timestamp_micros": 1546387200000000}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value
+
+
+def timestamp_range___repr__():
+ from google.cloud.bigtable.data.row_filters import TimestampRange
+
+ start = object()
+ end = object()
+ time_range = TimestampRange(start=start, end=end)
+ assert repr(time_range) == "TimestampRange(start={}, end={})".format(start, end)
+ assert repr(time_range) == str(time_range)
+ assert eval(repr(time_range)) == time_range
+
+
+def test_timestamp_range_filter___eq__type_differ():
+ from google.cloud.bigtable.data.row_filters import TimestampRangeFilter
+
+ range_ = object()
+ row_filter1 = TimestampRangeFilter(range_)
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_timestamp_range_filter___eq__same_value():
+ from google.cloud.bigtable.data.row_filters import TimestampRangeFilter
+
+ range_ = object()
+ row_filter1 = TimestampRangeFilter(range_)
+ row_filter2 = TimestampRangeFilter(range_)
+ assert row_filter1 == row_filter2
+
+
+def test_timestamp_range_filter___ne__():
+ from google.cloud.bigtable.data.row_filters import TimestampRangeFilter
+
+ range_ = object()
+ other_range_ = object()
+ row_filter1 = TimestampRangeFilter(range_)
+ row_filter2 = TimestampRangeFilter(other_range_)
+ assert row_filter1 != row_filter2
+
+
+def test_timestamp_range_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import TimestampRangeFilter
+
+ row_filter = TimestampRangeFilter()
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(timestamp_range_filter=_TimestampRangePB())
+ assert pb_val == expected_pb
+
+
+def test_timestamp_range_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import TimestampRangeFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+ import datetime
+
+ row_filter = TimestampRangeFilter(
+ start=datetime.datetime(2019, 1, 1), end=datetime.datetime(2019, 1, 2)
+ )
+ expected_dict = {
+ "timestamp_range_filter": {
+ "start_timestamp_micros": 1546300800000000,
+ "end_timestamp_micros": 1546387200000000,
+ }
+ }
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_timestamp_range_filter_empty_to_dict():
+ from google.cloud.bigtable.data.row_filters import TimestampRangeFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter = TimestampRangeFilter()
+ expected_dict = {"timestamp_range_filter": {}}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_timestamp_range_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import TimestampRangeFilter
+ import datetime
+
+ start = datetime.datetime(2019, 1, 1)
+ end = datetime.datetime(2019, 1, 2)
+ row_filter = TimestampRangeFilter(start, end)
+ assert (
+ repr(row_filter)
+ == f"TimestampRangeFilter(start={repr(start)}, end={repr(end)})"
+ )
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_column_range_filter_constructor_defaults():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = object()
+ row_filter = ColumnRangeFilter(family_id)
+ assert row_filter.family_id is family_id
+ assert row_filter.start_qualifier is None
+ assert row_filter.end_qualifier is None
+ assert row_filter.inclusive_start
+ assert row_filter.inclusive_end
+
+
+def test_column_range_filter_constructor_explicit():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = object()
+ start_qualifier = object()
+ end_qualifier = object()
+ inclusive_start = object()
+ inclusive_end = object()
+ row_filter = ColumnRangeFilter(
+ family_id,
+ start_qualifier=start_qualifier,
+ end_qualifier=end_qualifier,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ assert row_filter.family_id is family_id
+ assert row_filter.start_qualifier is start_qualifier
+ assert row_filter.end_qualifier is end_qualifier
+ assert row_filter.inclusive_start is inclusive_start
+ assert row_filter.inclusive_end is inclusive_end
+
+
+def test_column_range_filter_constructor_bad_start():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = object()
+ with pytest.raises(ValueError):
+ ColumnRangeFilter(family_id, inclusive_start=True)
+
+
+def test_column_range_filter_constructor_bad_end():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = object()
+ with pytest.raises(ValueError):
+ ColumnRangeFilter(family_id, inclusive_end=True)
+
+
+def test_column_range_filter___eq__():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = object()
+ start_qualifier = object()
+ end_qualifier = object()
+ inclusive_start = object()
+ inclusive_end = object()
+ row_filter1 = ColumnRangeFilter(
+ family_id,
+ start_qualifier=start_qualifier,
+ end_qualifier=end_qualifier,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ row_filter2 = ColumnRangeFilter(
+ family_id,
+ start_qualifier=start_qualifier,
+ end_qualifier=end_qualifier,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ assert row_filter1 == row_filter2
+
+
+def test_column_range_filter___eq__type_differ():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = object()
+ row_filter1 = ColumnRangeFilter(family_id)
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_column_range_filter___ne__():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = object()
+ other_family_id = object()
+ start_qualifier = object()
+ end_qualifier = object()
+ inclusive_start = object()
+ inclusive_end = object()
+ row_filter1 = ColumnRangeFilter(
+ family_id,
+ start_qualifier=start_qualifier,
+ end_qualifier=end_qualifier,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ row_filter2 = ColumnRangeFilter(
+ other_family_id,
+ start_qualifier=start_qualifier,
+ end_qualifier=end_qualifier,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ assert row_filter1 != row_filter2
+
+
+def test_column_range_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = "column-family-id"
+ row_filter = ColumnRangeFilter(family_id)
+ col_range_pb = _ColumnRangePB(family_name=family_id)
+ expected_pb = _RowFilterPB(column_range_filter=col_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_column_range_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ family_id = "column-family-id"
+ row_filter = ColumnRangeFilter(family_id)
+ expected_dict = {"column_range_filter": {"family_name": family_id}}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_column_range_filter_to_pb_inclusive_start():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = "column-family-id"
+ column = b"column"
+ row_filter = ColumnRangeFilter(family_id, start_qualifier=column)
+ col_range_pb = _ColumnRangePB(family_name=family_id, start_qualifier_closed=column)
+ expected_pb = _RowFilterPB(column_range_filter=col_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_column_range_filter_to_pb_exclusive_start():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = "column-family-id"
+ column = b"column"
+ row_filter = ColumnRangeFilter(
+ family_id, start_qualifier=column, inclusive_start=False
+ )
+ col_range_pb = _ColumnRangePB(family_name=family_id, start_qualifier_open=column)
+ expected_pb = _RowFilterPB(column_range_filter=col_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_column_range_filter_to_pb_inclusive_end():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = "column-family-id"
+ column = b"column"
+ row_filter = ColumnRangeFilter(family_id, end_qualifier=column)
+ col_range_pb = _ColumnRangePB(family_name=family_id, end_qualifier_closed=column)
+ expected_pb = _RowFilterPB(column_range_filter=col_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_column_range_filter_to_pb_exclusive_end():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = "column-family-id"
+ column = b"column"
+ row_filter = ColumnRangeFilter(family_id, end_qualifier=column, inclusive_end=False)
+ col_range_pb = _ColumnRangePB(family_name=family_id, end_qualifier_open=column)
+ expected_pb = _RowFilterPB(column_range_filter=col_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_column_range_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import ColumnRangeFilter
+
+ family_id = "column-family-id"
+ start_qualifier = b"column"
+ end_qualifier = b"column2"
+ row_filter = ColumnRangeFilter(family_id, start_qualifier, end_qualifier)
+ expected = "ColumnRangeFilter(family_id='column-family-id', start_qualifier=b'column', end_qualifier=b'column2', inclusive_start=True, inclusive_end=True)"
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_value_regex_filter_to_pb_w_bytes():
+ from google.cloud.bigtable.data.row_filters import ValueRegexFilter
+
+ value = regex = b"value-regex"
+ row_filter = ValueRegexFilter(value)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(value_regex_filter=regex)
+ assert pb_val == expected_pb
+
+
+def test_value_regex_filter_to_dict_w_bytes():
+ from google.cloud.bigtable.data.row_filters import ValueRegexFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ value = regex = b"value-regex"
+ row_filter = ValueRegexFilter(value)
+ expected_dict = {"value_regex_filter": regex}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_value_regex_filter_to_pb_w_str():
+ from google.cloud.bigtable.data.row_filters import ValueRegexFilter
+
+ value = "value-regex"
+ regex = value.encode("ascii")
+ row_filter = ValueRegexFilter(value)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(value_regex_filter=regex)
+ assert pb_val == expected_pb
+
+
+def test_value_regex_filter_to_dict_w_str():
+ from google.cloud.bigtable.data.row_filters import ValueRegexFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ value = "value-regex"
+ regex = value.encode("ascii")
+ row_filter = ValueRegexFilter(value)
+ expected_dict = {"value_regex_filter": regex}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_value_regex_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import ValueRegexFilter
+
+ value = "value-regex"
+ row_filter = ValueRegexFilter(value)
+ expected = "ValueRegexFilter(regex=b'value-regex')"
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_literal_value_filter_to_pb_w_bytes():
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+
+ value = regex = b"value_regex"
+ row_filter = LiteralValueFilter(value)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(value_regex_filter=regex)
+ assert pb_val == expected_pb
+
+
+def test_literal_value_filter_to_dict_w_bytes():
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ value = regex = b"value_regex"
+ row_filter = LiteralValueFilter(value)
+ expected_dict = {"value_regex_filter": regex}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_literal_value_filter_to_pb_w_str():
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+
+ value = "value_regex"
+ regex = value.encode("ascii")
+ row_filter = LiteralValueFilter(value)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(value_regex_filter=regex)
+ assert pb_val == expected_pb
+
+
+def test_literal_value_filter_to_dict_w_str():
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ value = "value_regex"
+ regex = value.encode("ascii")
+ row_filter = LiteralValueFilter(value)
+ expected_dict = {"value_regex_filter": regex}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+@pytest.mark.parametrize(
+ "value,expected_byte_string",
+ [
+        # NUL bytes are spelled out as the four ASCII characters "\x00";
+        # other escaped bytes are simply prefixed with a backslash
+ (0, b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00"),
+ (1, b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\\x01"),
+ (
+ 68,
+ b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00D",
+ ), # bytes that encode to alphanum are not escaped
+ (570, b"\\x00\\x00\\x00\\x00\\x00\\x00\\\x02\\\x3a"),
+ (2852126720, b"\\x00\\x00\\x00\\x00\xaa\\x00\\x00\\x00"),
+ (-1, b"\xff\xff\xff\xff\xff\xff\xff\xff"),
+ (-1096642724096, b"\xff\xff\xff\\x00\xaa\xff\xff\\x00"),
+ ],
+)
+def test_literal_value_filter_w_int(value, expected_byte_string):
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter = LiteralValueFilter(value)
+ # test pb
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(value_regex_filter=expected_byte_string)
+ assert pb_val == expected_pb
+ # test dict
+ expected_dict = {"value_regex_filter": expected_byte_string}
+ assert row_filter._to_dict() == expected_dict
+ assert data_v2_pb2.RowFilter(**expected_dict) == pb_val
+
+
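+# Hypothetical companion check, added purely for illustration (it is not part
+# of the original suite and relies only on the standard library): the integer
+# cases above start from the signed 8-byte big-endian encoding of the value,
+# which is then regex-escaped.
+def test_literal_value_filter_w_int_packing_example():
+    import struct
+
+    packer = struct.Struct(">q")
+    assert packer.pack(0) == b"\x00" * 8
+    assert packer.pack(68) == b"\x00\x00\x00\x00\x00\x00\x00D"
+    assert packer.pack(2852126720) == b"\x00\x00\x00\x00\xaa\x00\x00\x00"
+    assert packer.pack(-1) == b"\xff" * 8
+
+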
+def test_literal_value_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+
+ value = "value_regex"
+ row_filter = LiteralValueFilter(value)
+ expected = "LiteralValueFilter(value=b'value_regex')"
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_value_range_filter_constructor_defaults():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ row_filter = ValueRangeFilter()
+
+ assert row_filter.start_value is None
+ assert row_filter.end_value is None
+ assert row_filter.inclusive_start
+ assert row_filter.inclusive_end
+
+
+def test_value_range_filter_constructor_explicit():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ start_value = object()
+ end_value = object()
+ inclusive_start = object()
+ inclusive_end = object()
+
+ row_filter = ValueRangeFilter(
+ start_value=start_value,
+ end_value=end_value,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+
+ assert row_filter.start_value is start_value
+ assert row_filter.end_value is end_value
+ assert row_filter.inclusive_start is inclusive_start
+ assert row_filter.inclusive_end is inclusive_end
+
+
+def test_value_range_filter_constructor_w_int_values():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+ import struct
+
+ start_value = 1
+ end_value = 10
+
+ row_filter = ValueRangeFilter(start_value=start_value, end_value=end_value)
+
+ expected_start_value = struct.Struct(">q").pack(start_value)
+ expected_end_value = struct.Struct(">q").pack(end_value)
+
+ assert row_filter.start_value == expected_start_value
+ assert row_filter.end_value == expected_end_value
+ assert row_filter.inclusive_start
+ assert row_filter.inclusive_end
+
+
+def test_value_range_filter_constructor_bad_start():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ with pytest.raises(ValueError):
+ ValueRangeFilter(inclusive_start=True)
+
+
+def test_value_range_filter_constructor_bad_end():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ with pytest.raises(ValueError):
+ ValueRangeFilter(inclusive_end=True)
+
+
+def test_value_range_filter___eq__():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ start_value = object()
+ end_value = object()
+ inclusive_start = object()
+ inclusive_end = object()
+ row_filter1 = ValueRangeFilter(
+ start_value=start_value,
+ end_value=end_value,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ row_filter2 = ValueRangeFilter(
+ start_value=start_value,
+ end_value=end_value,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ assert row_filter1 == row_filter2
+
+
+def test_value_range_filter___eq__type_differ():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ row_filter1 = ValueRangeFilter()
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_value_range_filter___ne__():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ start_value = object()
+ other_start_value = object()
+ end_value = object()
+ inclusive_start = object()
+ inclusive_end = object()
+ row_filter1 = ValueRangeFilter(
+ start_value=start_value,
+ end_value=end_value,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ row_filter2 = ValueRangeFilter(
+ start_value=other_start_value,
+ end_value=end_value,
+ inclusive_start=inclusive_start,
+ inclusive_end=inclusive_end,
+ )
+ assert row_filter1 != row_filter2
+
+
+def test_value_range_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ row_filter = ValueRangeFilter()
+ expected_pb = _RowFilterPB(value_range_filter=_ValueRangePB())
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_value_range_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter = ValueRangeFilter()
+ expected_dict = {"value_range_filter": {}}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_value_range_filter_to_pb_inclusive_start():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ value = b"some-value"
+ row_filter = ValueRangeFilter(start_value=value)
+ val_range_pb = _ValueRangePB(start_value_closed=value)
+ expected_pb = _RowFilterPB(value_range_filter=val_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_value_range_filter_to_pb_exclusive_start():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ value = b"some-value"
+ row_filter = ValueRangeFilter(start_value=value, inclusive_start=False)
+ val_range_pb = _ValueRangePB(start_value_open=value)
+ expected_pb = _RowFilterPB(value_range_filter=val_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_value_range_filter_to_pb_inclusive_end():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ value = b"some-value"
+ row_filter = ValueRangeFilter(end_value=value)
+ val_range_pb = _ValueRangePB(end_value_closed=value)
+ expected_pb = _RowFilterPB(value_range_filter=val_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_value_range_filter_to_pb_exclusive_end():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ value = b"some-value"
+ row_filter = ValueRangeFilter(end_value=value, inclusive_end=False)
+ val_range_pb = _ValueRangePB(end_value_open=value)
+ expected_pb = _RowFilterPB(value_range_filter=val_range_pb)
+ assert row_filter._to_pb() == expected_pb
+
+
+def test_value_range_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+ start_value = b"some-value"
+ end_value = b"some-other-value"
+ row_filter = ValueRangeFilter(
+ start_value=start_value, end_value=end_value, inclusive_end=False
+ )
+ expected = "ValueRangeFilter(start_value=b'some-value', end_value=b'some-other-value', inclusive_start=True, inclusive_end=False)"
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_cell_count_constructor():
+    for FilterType in _get_cell_count_filters():
+        num_cells = object()
+        row_filter = FilterType(num_cells)
+        assert row_filter.num_cells is num_cells
+
+
+def test_cell_count___eq__type_differ():
+    for FilterType in _get_cell_count_filters():
+        num_cells = object()
+        row_filter1 = FilterType(num_cells)
+        row_filter2 = object()
+        assert not (row_filter1 == row_filter2)
+
+
+def test_cell_count___eq__same_value():
+    for FilterType in _get_cell_count_filters():
+        num_cells = object()
+        row_filter1 = FilterType(num_cells)
+        row_filter2 = FilterType(num_cells)
+        assert row_filter1 == row_filter2
+
+
+def test_cell_count___ne__same_value():
+    for FilterType in _get_cell_count_filters():
+        num_cells = object()
+        row_filter1 = FilterType(num_cells)
+        row_filter2 = FilterType(num_cells)
+        assert not (row_filter1 != row_filter2)
+
+
+def test_cells_row_offset_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter
+
+ num_cells = 76
+ row_filter = CellsRowOffsetFilter(num_cells)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(cells_per_row_offset_filter=num_cells)
+ assert pb_val == expected_pb
+
+
+def test_cells_row_offset_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ num_cells = 76
+ row_filter = CellsRowOffsetFilter(num_cells)
+ expected_dict = {"cells_per_row_offset_filter": num_cells}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_cells_row_offset_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter
+
+ num_cells = 76
+ row_filter = CellsRowOffsetFilter(num_cells)
+ expected = "CellsRowOffsetFilter(num_cells={})".format(num_cells)
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_cells_row_limit_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+
+ num_cells = 189
+ row_filter = CellsRowLimitFilter(num_cells)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(cells_per_row_limit_filter=num_cells)
+ assert pb_val == expected_pb
+
+
+def test_cells_row_limit_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ num_cells = 189
+ row_filter = CellsRowLimitFilter(num_cells)
+ expected_dict = {"cells_per_row_limit_filter": num_cells}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_cells_row_limit_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+
+ num_cells = 189
+ row_filter = CellsRowLimitFilter(num_cells)
+ expected = "CellsRowLimitFilter(num_cells={})".format(num_cells)
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_cells_column_limit_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import CellsColumnLimitFilter
+
+ num_cells = 10
+ row_filter = CellsColumnLimitFilter(num_cells)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(cells_per_column_limit_filter=num_cells)
+ assert pb_val == expected_pb
+
+
+def test_cells_column_limit_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import CellsColumnLimitFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ num_cells = 10
+ row_filter = CellsColumnLimitFilter(num_cells)
+ expected_dict = {"cells_per_column_limit_filter": num_cells}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_cells_column_limit_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import CellsColumnLimitFilter
+
+ num_cells = 10
+ row_filter = CellsColumnLimitFilter(num_cells)
+ expected = "CellsColumnLimitFilter(num_cells={})".format(num_cells)
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_strip_value_transformer_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ flag = True
+ row_filter = StripValueTransformerFilter(flag)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(strip_value_transformer=flag)
+ assert pb_val == expected_pb
+
+
+def test_strip_value_transformer_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ flag = True
+ row_filter = StripValueTransformerFilter(flag)
+ expected_dict = {"strip_value_transformer": flag}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_strip_value_transformer_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ flag = True
+ row_filter = StripValueTransformerFilter(flag)
+ expected = "StripValueTransformerFilter(flag={})".format(flag)
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_apply_label_filter_constructor():
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ label = object()
+ row_filter = ApplyLabelFilter(label)
+ assert row_filter.label is label
+
+
+def test_apply_label_filter___eq__type_differ():
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ label = object()
+ row_filter1 = ApplyLabelFilter(label)
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_apply_label_filter___eq__same_value():
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ label = object()
+ row_filter1 = ApplyLabelFilter(label)
+ row_filter2 = ApplyLabelFilter(label)
+ assert row_filter1 == row_filter2
+
+
+def test_apply_label_filter___ne__():
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ label = object()
+ other_label = object()
+ row_filter1 = ApplyLabelFilter(label)
+ row_filter2 = ApplyLabelFilter(other_label)
+ assert row_filter1 != row_filter2
+
+
+def test_apply_label_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ label = "label"
+ row_filter = ApplyLabelFilter(label)
+ pb_val = row_filter._to_pb()
+ expected_pb = _RowFilterPB(apply_label_transformer=label)
+ assert pb_val == expected_pb
+
+
+def test_apply_label_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ label = "label"
+ row_filter = ApplyLabelFilter(label)
+ expected_dict = {"apply_label_transformer": label}
+ assert row_filter._to_dict() == expected_dict
+ expected_pb_value = row_filter._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_apply_label_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import ApplyLabelFilter
+
+ label = "label"
+ row_filter = ApplyLabelFilter(label)
+ expected = "ApplyLabelFilter(label={})".format(label)
+ assert repr(row_filter) == expected
+ assert repr(row_filter) == str(row_filter)
+ assert eval(repr(row_filter)) == row_filter
+
+
+def test_filter_combination_constructor_defaults():
+ for FilterType in _get_filter_combination_filters():
+ row_filter = FilterType()
+ assert row_filter.filters == []
+
+
+def test_filter_combination_constructor_explicit():
+ for FilterType in _get_filter_combination_filters():
+ filters = object()
+ row_filter = FilterType(filters=filters)
+ assert row_filter.filters is filters
+
+
+def test_filter_combination___eq__():
+ for FilterType in _get_filter_combination_filters():
+ filters = object()
+ row_filter1 = FilterType(filters=filters)
+ row_filter2 = FilterType(filters=filters)
+ assert row_filter1 == row_filter2
+
+
+def test_filter_combination___eq__type_differ():
+ for FilterType in _get_filter_combination_filters():
+ filters = object()
+ row_filter1 = FilterType(filters=filters)
+ row_filter2 = object()
+ assert not (row_filter1 == row_filter2)
+
+
+def test_filter_combination___ne__():
+ for FilterType in _get_filter_combination_filters():
+ filters = object()
+ other_filters = object()
+ row_filter1 = FilterType(filters=filters)
+ row_filter2 = FilterType(filters=other_filters)
+ assert row_filter1 != row_filter2
+
+
+def test_filter_combination_len():
+ for FilterType in _get_filter_combination_filters():
+ filters = [object(), object()]
+ row_filter = FilterType(filters=filters)
+ assert len(row_filter) == len(filters)
+
+
+def test_filter_combination_iter():
+ for FilterType in _get_filter_combination_filters():
+ filters = [object(), object()]
+ row_filter = FilterType(filters=filters)
+ assert list(iter(row_filter)) == filters
+ for filter_, expected in zip(row_filter, filters):
+ assert filter_ is expected
+
+
+def test_filter_combination___getitem__():
+    for FilterType in _get_filter_combination_filters():
+        filters = [object(), object()]
+        row_filter = FilterType(filters=filters)
+        assert row_filter[0] is filters[0]
+        assert row_filter[1] is filters[1]
+        with pytest.raises(IndexError):
+            row_filter[2]
+        # slicing produces a new list, so compare by value rather than identity
+        assert row_filter[:] == filters[:]
+
+
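+# Usage sketch (illustrative addition; it exercises only behaviour asserted
+# above): RowFilterChain and RowFilterUnion behave like read-only sequences
+# of their component filters.
+def test_filter_combination_sequence_usage_example():
+    from google.cloud.bigtable.data.row_filters import PassAllFilter
+
+    for FilterType in _get_filter_combination_filters():
+        filters = [PassAllFilter(True), PassAllFilter(False)]
+        row_filter = FilterType(filters=filters)
+        assert len(row_filter) == len(filters)
+        assert list(row_filter) == filters
+        assert row_filter[0] is filters[0]
+
+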
+def test_filter_combination___str__():
+ from google.cloud.bigtable.data.row_filters import PassAllFilter
+
+ for FilterType in _get_filter_combination_filters():
+ filters = [PassAllFilter(True), PassAllFilter(False)]
+ row_filter = FilterType(filters=filters)
+ expected = (
+ "([\n PassAllFilter(flag=True),\n PassAllFilter(flag=False),\n])"
+ )
+ assert expected in str(row_filter)
+
+
+def test_row_filter_chain_to_pb():
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_pb = row_filter1._to_pb()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_pb = row_filter2._to_pb()
+
+ row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
+ filter_pb = row_filter3._to_pb()
+
+ expected_pb = _RowFilterPB(
+ chain=_RowFilterChainPB(filters=[row_filter1_pb, row_filter2_pb])
+ )
+ assert filter_pb == expected_pb
+
+
+def test_row_filter_chain_to_dict():
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_dict = row_filter1._to_dict()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_dict = row_filter2._to_dict()
+
+ row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
+ filter_dict = row_filter3._to_dict()
+
+ expected_dict = {"chain": {"filters": [row_filter1_dict, row_filter2_dict]}}
+ assert filter_dict == expected_dict
+ expected_pb_value = row_filter3._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_row_filter_chain_to_pb_nested():
+ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
+ row_filter3_pb = row_filter3._to_pb()
+
+ row_filter4 = CellsRowLimitFilter(11)
+ row_filter4_pb = row_filter4._to_pb()
+
+ row_filter5 = RowFilterChain(filters=[row_filter3, row_filter4])
+ filter_pb = row_filter5._to_pb()
+
+ expected_pb = _RowFilterPB(
+ chain=_RowFilterChainPB(filters=[row_filter3_pb, row_filter4_pb])
+ )
+ assert filter_pb == expected_pb
+
+
+def test_row_filter_chain_to_dict_nested():
+ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter1 = StripValueTransformerFilter(True)
+
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
+ row_filter3_dict = row_filter3._to_dict()
+
+ row_filter4 = CellsRowLimitFilter(11)
+ row_filter4_dict = row_filter4._to_dict()
+
+ row_filter5 = RowFilterChain(filters=[row_filter3, row_filter4])
+ filter_dict = row_filter5._to_dict()
+
+ expected_dict = {"chain": {"filters": [row_filter3_dict, row_filter4_dict]}}
+ assert filter_dict == expected_dict
+ expected_pb_value = row_filter5._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_row_filter_chain___repr__():
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
+ expected = f"RowFilterChain(filters={[row_filter1, row_filter2]})"
+ assert repr(row_filter3) == expected
+ assert eval(repr(row_filter3)) == row_filter3
+
+
+def test_row_filter_chain___str__():
+ from google.cloud.bigtable.data.row_filters import RowFilterChain
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
+ expected = "RowFilterChain([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n])"
+ assert str(row_filter3) == expected
+ # test nested
+ row_filter4 = RowFilterChain(filters=[row_filter3])
+ expected = "RowFilterChain([\n RowFilterChain([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n ]),\n])"
+ assert str(row_filter4) == expected
+
+
+def test_row_filter_union_to_pb():
+ from google.cloud.bigtable.data.row_filters import RowFilterUnion
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_pb = row_filter1._to_pb()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_pb = row_filter2._to_pb()
+
+ row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
+ filter_pb = row_filter3._to_pb()
+
+ expected_pb = _RowFilterPB(
+ interleave=_RowFilterInterleavePB(filters=[row_filter1_pb, row_filter2_pb])
+ )
+ assert filter_pb == expected_pb
+
+
+def test_row_filter_union_to_dict():
+ from google.cloud.bigtable.data.row_filters import RowFilterUnion
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_dict = row_filter1._to_dict()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_dict = row_filter2._to_dict()
+
+ row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
+ filter_dict = row_filter3._to_dict()
+
+ expected_dict = {"interleave": {"filters": [row_filter1_dict, row_filter2_dict]}}
+ assert filter_dict == expected_dict
+ expected_pb_value = row_filter3._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_row_filter_union_to_pb_nested():
+ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+ from google.cloud.bigtable.data.row_filters import RowFilterUnion
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
+ row_filter3_pb = row_filter3._to_pb()
+
+ row_filter4 = CellsRowLimitFilter(11)
+ row_filter4_pb = row_filter4._to_pb()
+
+ row_filter5 = RowFilterUnion(filters=[row_filter3, row_filter4])
+ filter_pb = row_filter5._to_pb()
+
+ expected_pb = _RowFilterPB(
+ interleave=_RowFilterInterleavePB(filters=[row_filter3_pb, row_filter4_pb])
+ )
+ assert filter_pb == expected_pb
+
+
+def test_row_filter_union_to_dict_nested():
+ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
+ from google.cloud.bigtable.data.row_filters import RowFilterUnion
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter1 = StripValueTransformerFilter(True)
+
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
+ row_filter3_dict = row_filter3._to_dict()
+
+ row_filter4 = CellsRowLimitFilter(11)
+ row_filter4_dict = row_filter4._to_dict()
+
+ row_filter5 = RowFilterUnion(filters=[row_filter3, row_filter4])
+ filter_dict = row_filter5._to_dict()
+
+ expected_dict = {"interleave": {"filters": [row_filter3_dict, row_filter4_dict]}}
+ assert filter_dict == expected_dict
+ expected_pb_value = row_filter5._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_row_filter_union___repr__():
+ from google.cloud.bigtable.data.row_filters import RowFilterUnion
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
+ expected = "RowFilterUnion(filters=[StripValueTransformerFilter(flag=True), RowSampleFilter(sample=0.25)])"
+ assert repr(row_filter3) == expected
+ assert eval(repr(row_filter3)) == row_filter3
+
+
+def test_row_filter_union___str__():
+ from google.cloud.bigtable.data.row_filters import RowFilterUnion
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+
+ row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
+ expected = "RowFilterUnion([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n])"
+ assert str(row_filter3) == expected
+ # test nested
+ row_filter4 = RowFilterUnion(filters=[row_filter3])
+ expected = "RowFilterUnion([\n RowFilterUnion([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n ]),\n])"
+ assert str(row_filter4) == expected
+
+
+def test_conditional_row_filter_constructor():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+
+ predicate_filter = object()
+ true_filter = object()
+ false_filter = object()
+ cond_filter = ConditionalRowFilter(
+ predicate_filter, true_filter=true_filter, false_filter=false_filter
+ )
+ assert cond_filter.predicate_filter is predicate_filter
+ assert cond_filter.true_filter is true_filter
+ assert cond_filter.false_filter is false_filter
+
+
+def test_conditional_row_filter___eq__():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+
+ predicate_filter = object()
+ true_filter = object()
+ false_filter = object()
+ cond_filter1 = ConditionalRowFilter(
+ predicate_filter, true_filter=true_filter, false_filter=false_filter
+ )
+ cond_filter2 = ConditionalRowFilter(
+ predicate_filter, true_filter=true_filter, false_filter=false_filter
+ )
+ assert cond_filter1 == cond_filter2
+
+
+def test_conditional_row_filter___eq__type_differ():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+
+ predicate_filter = object()
+ true_filter = object()
+ false_filter = object()
+ cond_filter1 = ConditionalRowFilter(
+ predicate_filter, true_filter=true_filter, false_filter=false_filter
+ )
+ cond_filter2 = object()
+ assert not (cond_filter1 == cond_filter2)
+
+
+def test_conditional_row_filter___ne__():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+
+ predicate_filter = object()
+ other_predicate_filter = object()
+ true_filter = object()
+ false_filter = object()
+ cond_filter1 = ConditionalRowFilter(
+ predicate_filter, true_filter=true_filter, false_filter=false_filter
+ )
+ cond_filter2 = ConditionalRowFilter(
+ other_predicate_filter, true_filter=true_filter, false_filter=false_filter
+ )
+ assert cond_filter1 != cond_filter2
+
+
+def test_conditional_row_filter_to_pb():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_pb = row_filter1._to_pb()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_pb = row_filter2._to_pb()
+
+ row_filter3 = CellsRowOffsetFilter(11)
+ row_filter3_pb = row_filter3._to_pb()
+
+ row_filter4 = ConditionalRowFilter(
+ row_filter1, true_filter=row_filter2, false_filter=row_filter3
+ )
+ filter_pb = row_filter4._to_pb()
+
+ expected_pb = _RowFilterPB(
+ condition=_RowFilterConditionPB(
+ predicate_filter=row_filter1_pb,
+ true_filter=row_filter2_pb,
+ false_filter=row_filter3_pb,
+ )
+ )
+ assert filter_pb == expected_pb
+
+
+def test_conditional_row_filter_to_dict():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_dict = row_filter1._to_dict()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_dict = row_filter2._to_dict()
+
+ row_filter3 = CellsRowOffsetFilter(11)
+ row_filter3_dict = row_filter3._to_dict()
+
+ row_filter4 = ConditionalRowFilter(
+ row_filter1, true_filter=row_filter2, false_filter=row_filter3
+ )
+ filter_dict = row_filter4._to_dict()
+
+ expected_dict = {
+ "condition": {
+ "predicate_filter": row_filter1_dict,
+ "true_filter": row_filter2_dict,
+ "false_filter": row_filter3_dict,
+ }
+ }
+ assert filter_dict == expected_dict
+ expected_pb_value = row_filter4._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_conditional_row_filter_to_pb_true_only():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_pb = row_filter1._to_pb()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_pb = row_filter2._to_pb()
+
+ row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2)
+ filter_pb = row_filter3._to_pb()
+
+ expected_pb = _RowFilterPB(
+ condition=_RowFilterConditionPB(
+ predicate_filter=row_filter1_pb, true_filter=row_filter2_pb
+ )
+ )
+ assert filter_pb == expected_pb
+
+
+def test_conditional_row_filter_to_dict_true_only():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_dict = row_filter1._to_dict()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_dict = row_filter2._to_dict()
+
+ row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2)
+ filter_dict = row_filter3._to_dict()
+
+ expected_dict = {
+ "condition": {
+ "predicate_filter": row_filter1_dict,
+ "true_filter": row_filter2_dict,
+ }
+ }
+ assert filter_dict == expected_dict
+ expected_pb_value = row_filter3._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_conditional_row_filter_to_pb_false_only():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_pb = row_filter1._to_pb()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_pb = row_filter2._to_pb()
+
+ row_filter3 = ConditionalRowFilter(row_filter1, false_filter=row_filter2)
+ filter_pb = row_filter3._to_pb()
+
+ expected_pb = _RowFilterPB(
+ condition=_RowFilterConditionPB(
+ predicate_filter=row_filter1_pb, false_filter=row_filter2_pb
+ )
+ )
+ assert filter_pb == expected_pb
+
+
+def test_conditional_row_filter_to_dict_false_only():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter1_dict = row_filter1._to_dict()
+
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter2_dict = row_filter2._to_dict()
+
+ row_filter3 = ConditionalRowFilter(row_filter1, false_filter=row_filter2)
+ filter_dict = row_filter3._to_dict()
+
+ expected_dict = {
+ "condition": {
+ "predicate_filter": row_filter1_dict,
+ "false_filter": row_filter2_dict,
+ }
+ }
+ assert filter_dict == expected_dict
+ expected_pb_value = row_filter3._to_pb()
+ assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
+
+
+def test_conditional_row_filter___repr__():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2)
+ expected = (
+ "ConditionalRowFilter(predicate_filter=StripValueTransformerFilter("
+ "flag=True), true_filter=RowSampleFilter(sample=0.25), false_filter=None)"
+ )
+ assert repr(row_filter3) == expected
+ assert eval(repr(row_filter3)) == row_filter3
+ # test nested
+ row_filter4 = ConditionalRowFilter(row_filter3, true_filter=row_filter2)
+ expected = "ConditionalRowFilter(predicate_filter=ConditionalRowFilter(predicate_filter=StripValueTransformerFilter(flag=True), true_filter=RowSampleFilter(sample=0.25), false_filter=None), true_filter=RowSampleFilter(sample=0.25), false_filter=None)"
+ assert repr(row_filter4) == expected
+ assert eval(repr(row_filter4)) == row_filter4
+
+
+def test_conditional_row_filter___str__():
+ from google.cloud.bigtable.data.row_filters import ConditionalRowFilter
+ from google.cloud.bigtable.data.row_filters import RowSampleFilter
+ from google.cloud.bigtable.data.row_filters import RowFilterUnion
+ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
+
+ row_filter1 = StripValueTransformerFilter(True)
+ row_filter2 = RowSampleFilter(0.25)
+ row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2)
+ expected = "ConditionalRowFilter(\n predicate_filter=StripValueTransformerFilter(flag=True),\n true_filter=RowSampleFilter(sample=0.25),\n)"
+ assert str(row_filter3) == expected
+ # test nested
+ row_filter4 = ConditionalRowFilter(
+ row_filter3,
+ true_filter=row_filter2,
+ false_filter=RowFilterUnion([row_filter1, row_filter2]),
+ )
+ expected = "ConditionalRowFilter(\n predicate_filter=ConditionalRowFilter(\n predicate_filter=StripValueTransformerFilter(flag=True),\n true_filter=RowSampleFilter(sample=0.25),\n ),\n true_filter=RowSampleFilter(sample=0.25),\n false_filter=RowFilterUnion([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n ]),\n)"
+ assert str(row_filter4) == expected
+
+
+@pytest.mark.parametrize(
+ "input_arg, expected_bytes",
+ [
+ (b"abc", b"abc"),
+ ("abc", b"abc"),
+ (1, b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\\x01"), # null bytes are ascii
+ (b"*", b"\\*"),
+ (".", b"\\."),
+ (b"\\", b"\\\\"),
+ (b"h.*i", b"h\\.\\*i"),
+ (b'""', b'\\"\\"'),
+ (b"[xyz]", b"\\[xyz\\]"),
+ (b"\xe2\x98\xba\xef\xb8\x8f", b"\xe2\x98\xba\xef\xb8\x8f"),
+ ("☃", b"\xe2\x98\x83"),
+ (r"\C☃", b"\\\\C\xe2\x98\x83"),
+ ],
+)
+def test_literal_value__write_literal_regex(input_arg, expected_bytes):
+ from google.cloud.bigtable.data.row_filters import LiteralValueFilter
+
+ filter_ = LiteralValueFilter(input_arg)
+ assert filter_.regex == expected_bytes
+
+
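+# Summary of the escaping convention exercised above (descriptive comment only;
+# the library implementation is not reproduced here): in the cases listed,
+# ASCII alphanumerics and bytes >= 0x80 pass through untouched, a NUL byte is
+# spelled out as the four ASCII characters "\x00", and every other byte
+# (regex metacharacters, quotes and control bytes alike) is prefixed with a
+# backslash. str inputs are UTF-8 encoded and ints are packed as signed 8-byte
+# big-endian values before escaping.
+
+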
+def _ColumnRangePB(*args, **kw):
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ return data_v2_pb2.ColumnRange(*args, **kw)
+
+
+def _RowFilterPB(*args, **kw):
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ return data_v2_pb2.RowFilter(*args, **kw)
+
+
+def _RowFilterChainPB(*args, **kw):
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ return data_v2_pb2.RowFilter.Chain(*args, **kw)
+
+
+def _RowFilterConditionPB(*args, **kw):
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ return data_v2_pb2.RowFilter.Condition(*args, **kw)
+
+
+def _RowFilterInterleavePB(*args, **kw):
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ return data_v2_pb2.RowFilter.Interleave(*args, **kw)
+
+
+def _TimestampRangePB(*args, **kw):
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ return data_v2_pb2.TimestampRange(*args, **kw)
+
+
+def _ValueRangePB(*args, **kw):
+ from google.cloud.bigtable_v2.types import data as data_v2_pb2
+
+ return data_v2_pb2.ValueRange(*args, **kw)
+
+
+def _get_regex_filters():
+ from google.cloud.bigtable.data.row_filters import (
+ RowKeyRegexFilter,
+ FamilyNameRegexFilter,
+ ColumnQualifierRegexFilter,
+ ValueRegexFilter,
+ LiteralValueFilter,
+ )
+
+ return [
+ RowKeyRegexFilter,
+ FamilyNameRegexFilter,
+ ColumnQualifierRegexFilter,
+ ValueRegexFilter,
+ LiteralValueFilter,
+ ]
+
+
+def _get_bool_filters():
+ from google.cloud.bigtable.data.row_filters import (
+ SinkFilter,
+ PassAllFilter,
+ BlockAllFilter,
+ StripValueTransformerFilter,
+ )
+
+ return [
+ SinkFilter,
+ PassAllFilter,
+ BlockAllFilter,
+ StripValueTransformerFilter,
+ ]
+
+
+def _get_cell_count_filters():
+ from google.cloud.bigtable.data.row_filters import (
+ CellsRowLimitFilter,
+ CellsRowOffsetFilter,
+ CellsColumnLimitFilter,
+ )
+
+ return [
+ CellsRowLimitFilter,
+ CellsRowOffsetFilter,
+ CellsColumnLimitFilter,
+ ]
+
+
+def _get_filter_combination_filters():
+ from google.cloud.bigtable.data.row_filters import (
+ RowFilterChain,
+ RowFilterUnion,
+ )
+
+ return [
+ RowFilterChain,
+ RowFilterUnion,
+ ]
diff --git a/tests/unit/gapic/bigtable_admin_v2/test_bigtable_instance_admin.py b/tests/unit/gapic/bigtable_admin_v2/test_bigtable_instance_admin.py
index ddbf0032f..7a24cab54 100644
--- a/tests/unit/gapic/bigtable_admin_v2/test_bigtable_instance_admin.py
+++ b/tests/unit/gapic/bigtable_admin_v2/test_bigtable_instance_admin.py
@@ -29,6 +29,7 @@
import json
import math
import pytest
+from google.api_core import api_core_version
from proto.marshal.rules.dates import DurationRule, TimestampRule
from proto.marshal.rules import wrappers
from requests import Response
@@ -86,6 +87,17 @@ def modify_default_endpoint(client):
)
+# If default endpoint template is localhost, then default mtls endpoint will be the same.
+# This method modifies the default endpoint template so the client can produce a different
+# mtls endpoint for endpoint testing purposes.
+def modify_default_endpoint_template(client):
+ return (
+ "test.{UNIVERSE_DOMAIN}"
+ if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE)
+ else client._DEFAULT_ENDPOINT_TEMPLATE
+ )
+
+
def test__get_default_mtls_endpoint():
api_endpoint = "example.googleapis.com"
api_mtls_endpoint = "example.mtls.googleapis.com"
@@ -116,6 +128,298 @@ def test__get_default_mtls_endpoint():
)
+def test__read_environment_variables():
+ assert BigtableInstanceAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
+ assert BigtableInstanceAdminClient._read_environment_variables() == (
+ True,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}):
+ assert BigtableInstanceAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ BigtableInstanceAdminClient._read_environment_variables()
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ assert BigtableInstanceAdminClient._read_environment_variables() == (
+ False,
+ "never",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ assert BigtableInstanceAdminClient._read_environment_variables() == (
+ False,
+ "always",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
+ assert BigtableInstanceAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ BigtableInstanceAdminClient._read_environment_variables()
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}):
+ assert BigtableInstanceAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ "foo.com",
+ )
+
+
+def test__get_client_cert_source():
+ mock_provided_cert_source = mock.Mock()
+ mock_default_cert_source = mock.Mock()
+
+ assert BigtableInstanceAdminClient._get_client_cert_source(None, False) is None
+ assert (
+ BigtableInstanceAdminClient._get_client_cert_source(
+ mock_provided_cert_source, False
+ )
+ is None
+ )
+ assert (
+ BigtableInstanceAdminClient._get_client_cert_source(
+ mock_provided_cert_source, True
+ )
+ == mock_provided_cert_source
+ )
+
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source", return_value=True
+ ):
+ with mock.patch(
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=mock_default_cert_source,
+ ):
+ assert (
+ BigtableInstanceAdminClient._get_client_cert_source(None, True)
+ is mock_default_cert_source
+ )
+ assert (
+ BigtableInstanceAdminClient._get_client_cert_source(
+ mock_provided_cert_source, "true"
+ )
+ is mock_provided_cert_source
+ )
+
+
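A compact sketch of the selection rule these assertions describe: a caller-provided certificate source wins whenever the flag is truthy, otherwise the discoverable default source is used, and with the flag off nothing is returned. The `_sketch` suffix marks the name as illustrative rather than the client's real helper.

from google.auth.transport import mtls


def _get_client_cert_source_sketch(provided_cert_source, use_cert_flag):
    if not use_cert_flag:
        return None
    if provided_cert_source:
        return provided_cert_source
    if mtls.has_default_client_cert_source():
        # Same google.auth.transport.mtls entry points the test patches above.
        return mtls.default_client_cert_source()
    return None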
+@mock.patch.object(
+ BigtableInstanceAdminClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminClient),
+)
+@mock.patch.object(
+ BigtableInstanceAdminAsyncClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminAsyncClient),
+)
+def test__get_api_endpoint():
+ api_override = "foo.com"
+ mock_client_cert_source = mock.Mock()
+ default_universe = BigtableInstanceAdminClient._DEFAULT_UNIVERSE
+ default_endpoint = BigtableInstanceAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=default_universe
+ )
+ mock_universe = "bar.com"
+ mock_endpoint = BigtableInstanceAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=mock_universe
+ )
+
+ assert (
+ BigtableInstanceAdminClient._get_api_endpoint(
+ api_override, mock_client_cert_source, default_universe, "always"
+ )
+ == api_override
+ )
+ assert (
+ BigtableInstanceAdminClient._get_api_endpoint(
+ None, mock_client_cert_source, default_universe, "auto"
+ )
+ == BigtableInstanceAdminClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableInstanceAdminClient._get_api_endpoint(
+ None, None, default_universe, "auto"
+ )
+ == default_endpoint
+ )
+ assert (
+ BigtableInstanceAdminClient._get_api_endpoint(
+ None, None, default_universe, "always"
+ )
+ == BigtableInstanceAdminClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableInstanceAdminClient._get_api_endpoint(
+ None, mock_client_cert_source, default_universe, "always"
+ )
+ == BigtableInstanceAdminClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableInstanceAdminClient._get_api_endpoint(
+ None, None, mock_universe, "never"
+ )
+ == mock_endpoint
+ )
+ assert (
+ BigtableInstanceAdminClient._get_api_endpoint(
+ None, None, default_universe, "never"
+ )
+ == default_endpoint
+ )
+
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ BigtableInstanceAdminClient._get_api_endpoint(
+ None, mock_client_cert_source, mock_universe, "auto"
+ )
+ assert (
+ str(excinfo.value)
+ == "mTLS is not supported in any universe other than googleapis.com."
+ )
+
+
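Taken together, the assertions encode a precedence order: an explicit API override always wins; "always", or "auto" with a client certificate, yields the mTLS endpoint but only inside the default universe; everything else formats the endpoint template with the resolved universe. A hedged sketch under those assumptions — the class attributes are the ones the test already uses, while the function name is hypothetical:

from google.auth.exceptions import MutualTLSChannelError
from google.cloud.bigtable_admin_v2 import BigtableInstanceAdminClient


def _get_api_endpoint_sketch(api_override, client_cert_source, universe_domain, use_mtls_endpoint):
    if api_override is not None:
        return api_override
    if use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source):
        if universe_domain != BigtableInstanceAdminClient._DEFAULT_UNIVERSE:
            raise MutualTLSChannelError(
                "mTLS is not supported in any universe other than googleapis.com."
            )
        return BigtableInstanceAdminClient.DEFAULT_MTLS_ENDPOINT
    return BigtableInstanceAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
        UNIVERSE_DOMAIN=universe_domain
    )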
+def test__get_universe_domain():
+ client_universe_domain = "foo.com"
+ universe_domain_env = "bar.com"
+
+ assert (
+ BigtableInstanceAdminClient._get_universe_domain(
+ client_universe_domain, universe_domain_env
+ )
+ == client_universe_domain
+ )
+ assert (
+ BigtableInstanceAdminClient._get_universe_domain(None, universe_domain_env)
+ == universe_domain_env
+ )
+ assert (
+ BigtableInstanceAdminClient._get_universe_domain(None, None)
+ == BigtableInstanceAdminClient._DEFAULT_UNIVERSE
+ )
+
+ with pytest.raises(ValueError) as excinfo:
+ BigtableInstanceAdminClient._get_universe_domain("", None)
+ assert str(excinfo.value) == "Universe Domain cannot be an empty string."
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [
+ (
+ BigtableInstanceAdminClient,
+ transports.BigtableInstanceAdminGrpcTransport,
+ "grpc",
+ ),
+ (
+ BigtableInstanceAdminClient,
+ transports.BigtableInstanceAdminRestTransport,
+ "rest",
+ ),
+ ],
+)
+def test__validate_universe_domain(client_class, transport_class, transport_name):
+ client = client_class(
+ transport=transport_class(credentials=ga_credentials.AnonymousCredentials())
+ )
+ assert client._validate_universe_domain() == True
+
+ # Test the case when universe is already validated.
+ assert client._validate_universe_domain() == True
+
+ if transport_name == "grpc":
+ # Test the case where credentials are provided by the
+ # `local_channel_credentials`. The default universes in both match.
+ channel = grpc.secure_channel(
+ "http://localhost/", grpc.local_channel_credentials()
+ )
+ client = client_class(transport=transport_class(channel=channel))
+ assert client._validate_universe_domain() == True
+
+ # Test the case where credentials do not exist: e.g. a transport is provided
+ # with no credentials. Validation should still succeed because there is no
+ # mismatch with non-existent credentials.
+ channel = grpc.secure_channel(
+ "http://localhost/", grpc.local_channel_credentials()
+ )
+ transport = transport_class(channel=channel)
+ transport._credentials = None
+ client = client_class(transport=transport)
+ assert client._validate_universe_domain() == True
+
+ # TODO: This is needed to cater for older versions of google-auth
+ # Make this test unconditional once the minimum supported version of
+ # google-auth becomes 2.23.0 or higher.
+ google_auth_major, google_auth_minor = [
+ int(part) for part in google.auth.__version__.split(".")[0:2]
+ ]
+ if google_auth_major > 2 or (google_auth_major == 2 and google_auth_minor >= 23):
+ credentials = ga_credentials.AnonymousCredentials()
+ credentials._universe_domain = "foo.com"
+ # Test the case when there is a universe mismatch from the credentials.
+ client = client_class(transport=transport_class(credentials=credentials))
+ with pytest.raises(ValueError) as excinfo:
+ client._validate_universe_domain()
+ assert (
+ str(excinfo.value)
+ == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default."
+ )
+
+ # Test the case when there is a universe mismatch from the client.
+ #
+ # TODO: Make this test unconditional once the minimum supported version of
+ # google-api-core becomes 2.15.0 or higher.
+ api_core_major, api_core_minor = [
+ int(part) for part in api_core_version.__version__.split(".")[0:2]
+ ]
+ if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15):
+ client = client_class(
+ client_options={"universe_domain": "bar.com"},
+ transport=transport_class(
+ credentials=ga_credentials.AnonymousCredentials(),
+ ),
+ )
+ with pytest.raises(ValueError) as excinfo:
+ client._validate_universe_domain()
+ assert (
+ str(excinfo.value)
+ == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default."
+ )
+
+ # Test that ValueError is raised if universe_domain is provided via client options and credentials is None
+ with pytest.raises(ValueError):
+ client._compare_universes("foo.bar", None)
+
+
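Both mismatch messages above come from the same comparison: the universe the client was configured with versus the universe carried by the credentials, with `googleapis.com` filling in whenever either side is silent. A plausible shape for that check, assuming a hypothetical `_compare_universes_sketch(client_universe, credentials)` signature matching the call in the last assertion:

def _compare_universes_sketch(client_universe, credentials):
    default_universe = "googleapis.com"
    # Credentials without a universe_domain attribute (older google-auth,
    # or credentials=None) are treated as belonging to the default universe.
    credentials_universe = getattr(credentials, "universe_domain", default_universe)
    if client_universe != credentials_universe:
        raise ValueError(
            f"The configured universe domain ({client_universe}) does not match "
            f"the universe domain found in the credentials ({credentials_universe})."
            " If you haven't configured the universe domain explicitly,"
            " `googleapis.com` is the default."
        )
    return True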
@pytest.mark.parametrize(
"client_class,transport_name",
[
@@ -239,13 +543,13 @@ def test_bigtable_instance_admin_client_get_transport_class():
)
@mock.patch.object(
BigtableInstanceAdminClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableInstanceAdminClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminClient),
)
@mock.patch.object(
BigtableInstanceAdminAsyncClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableInstanceAdminAsyncClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminAsyncClient),
)
def test_bigtable_instance_admin_client_client_options(
client_class, transport_class, transport_name
@@ -287,7 +591,9 @@ def test_bigtable_instance_admin_client_client_options(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -317,15 +623,23 @@ def test_bigtable_instance_admin_client_client_options(
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
# unsupported value.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
- with pytest.raises(MutualTLSChannelError):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
client = client_class(transport=transport_name)
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
# Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
with mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
):
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError) as excinfo:
client = client_class(transport=transport_name)
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
# Check the case quota_project_id is provided
options = client_options.ClientOptions(quota_project_id="octopus")
@@ -335,7 +649,9 @@ def test_bigtable_instance_admin_client_client_options(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id="octopus",
@@ -353,7 +669,9 @@ def test_bigtable_instance_admin_client_client_options(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -406,13 +724,13 @@ def test_bigtable_instance_admin_client_client_options(
)
@mock.patch.object(
BigtableInstanceAdminClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableInstanceAdminClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminClient),
)
@mock.patch.object(
BigtableInstanceAdminAsyncClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableInstanceAdminAsyncClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminAsyncClient),
)
@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
def test_bigtable_instance_admin_client_mtls_env_auto(
@@ -435,7 +753,9 @@ def test_bigtable_instance_admin_client_mtls_env_auto(
if use_client_cert_env == "false":
expected_client_cert_source = None
- expected_host = client.DEFAULT_ENDPOINT
+ expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ )
else:
expected_client_cert_source = client_cert_source_callback
expected_host = client.DEFAULT_MTLS_ENDPOINT
@@ -467,7 +787,9 @@ def test_bigtable_instance_admin_client_mtls_env_auto(
return_value=client_cert_source_callback,
):
if use_client_cert_env == "false":
- expected_host = client.DEFAULT_ENDPOINT
+ expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ )
expected_client_cert_source = None
else:
expected_host = client.DEFAULT_MTLS_ENDPOINT
@@ -501,7 +823,9 @@ def test_bigtable_instance_admin_client_mtls_env_auto(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -591,6 +915,115 @@ def test_bigtable_instance_admin_client_get_mtls_endpoint_and_cert_source(client
assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
assert cert_source == mock_client_cert_source
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+ # unsupported value.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ client_class.get_mtls_endpoint_and_cert_source()
+
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ client_class.get_mtls_endpoint_and_cert_source()
+
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class", [BigtableInstanceAdminClient, BigtableInstanceAdminAsyncClient]
+)
+@mock.patch.object(
+ BigtableInstanceAdminClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminClient),
+)
+@mock.patch.object(
+ BigtableInstanceAdminAsyncClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableInstanceAdminAsyncClient),
+)
+def test_bigtable_instance_admin_client_client_api_endpoint(client_class):
+ mock_client_cert_source = client_cert_source_callback
+ api_override = "foo.com"
+ default_universe = BigtableInstanceAdminClient._DEFAULT_UNIVERSE
+ default_endpoint = BigtableInstanceAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=default_universe
+ )
+ mock_universe = "bar.com"
+ mock_endpoint = BigtableInstanceAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=mock_universe
+ )
+
+ # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true",
+ # use ClientOptions.api_endpoint as the api endpoint regardless.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel"
+ ):
+ options = client_options.ClientOptions(
+ client_cert_source=mock_client_cert_source, api_endpoint=api_override
+ )
+ client = client_class(
+ client_options=options,
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ assert client.api_endpoint == api_override
+
+ # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ client = client_class(credentials=ga_credentials.AnonymousCredentials())
+ assert client.api_endpoint == default_endpoint
+
+ # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always",
+ # use the DEFAULT_MTLS_ENDPOINT as the api endpoint.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ client = client_class(credentials=ga_credentials.AnonymousCredentials())
+ assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
+
+ # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default),
+ # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist,
+ # and ClientOptions.universe_domain="bar.com",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint.
+ options = client_options.ClientOptions()
+ universe_exists = hasattr(options, "universe_domain")
+ if universe_exists:
+ options = client_options.ClientOptions(universe_domain=mock_universe)
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ else:
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ assert client.api_endpoint == (
+ mock_endpoint if universe_exists else default_endpoint
+ )
+ assert client.universe_domain == (
+ mock_universe if universe_exists else default_universe
+ )
+
+ # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint.
+ options = client_options.ClientOptions()
+ if hasattr(options, "universe_domain"):
+ delattr(options, "universe_domain")
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ assert client.api_endpoint == default_endpoint
+
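For anyone tracing what these endpoint checks mean for applications: the user-facing knob is plain `ClientOptions`. A hypothetical usage sketch — the universe name is a placeholder, `universe_domain` requires a google-api-core recent enough to expose it, and Application Default Credentials are assumed:

from google.api_core import client_options
from google.cloud.bigtable_admin_v2 import BigtableInstanceAdminClient

options = client_options.ClientOptions(universe_domain="example-universe.test")
client = BigtableInstanceAdminClient(client_options=options)
# The endpoint is the template populated with the chosen universe,
# e.g. "bigtableadmin.example-universe.test".
print(client.api_endpoint, client.universe_domain)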
@pytest.mark.parametrize(
"client_class,transport_class,transport_name",
@@ -625,7 +1058,9 @@ def test_bigtable_instance_admin_client_client_options_scopes(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=["1", "2"],
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -670,7 +1105,9 @@ def test_bigtable_instance_admin_client_client_options_credentials_file(
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -730,7 +1167,9 @@ def test_bigtable_instance_admin_client_create_channel_credentials_file(
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -4261,7 +4700,7 @@ async def test_list_app_profiles_flattened_error_async():
def test_list_app_profiles_pager(transport_name: str = "grpc"):
client = BigtableInstanceAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
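This hunk, and the matching pager/pages hunks below and in the two sibling test modules, make the same small fix: the tests previously passed the `AnonymousCredentials` class object where the client expects a `google.auth.credentials.Credentials` instance. A minimal illustration of the corrected call, with the transport name as an example value:

from google.auth import credentials as ga_credentials
from google.cloud.bigtable_admin_v2 import BigtableInstanceAdminClient

client = BigtableInstanceAdminClient(
    credentials=ga_credentials.AnonymousCredentials(),  # instance, not the class
    transport="grpc",
)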
@@ -4313,7 +4752,7 @@ def test_list_app_profiles_pager(transport_name: str = "grpc"):
def test_list_app_profiles_pages(transport_name: str = "grpc"):
client = BigtableInstanceAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -4357,7 +4796,7 @@ def test_list_app_profiles_pages(transport_name: str = "grpc"):
@pytest.mark.asyncio
async def test_list_app_profiles_async_pager():
client = BigtableInstanceAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -4409,7 +4848,7 @@ async def test_list_app_profiles_async_pager():
@pytest.mark.asyncio
async def test_list_app_profiles_async_pages():
client = BigtableInstanceAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -5953,7 +6392,7 @@ async def test_list_hot_tablets_flattened_error_async():
def test_list_hot_tablets_pager(transport_name: str = "grpc"):
client = BigtableInstanceAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -6003,7 +6442,7 @@ def test_list_hot_tablets_pager(transport_name: str = "grpc"):
def test_list_hot_tablets_pages(transport_name: str = "grpc"):
client = BigtableInstanceAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -6045,7 +6484,7 @@ def test_list_hot_tablets_pages(transport_name: str = "grpc"):
@pytest.mark.asyncio
async def test_list_hot_tablets_async_pager():
client = BigtableInstanceAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -6095,7 +6534,7 @@ async def test_list_hot_tablets_async_pager():
@pytest.mark.asyncio
async def test_list_hot_tablets_async_pages():
client = BigtableInstanceAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -12348,7 +12787,7 @@ def test_credentials_transport_error():
)
# It is an error to provide an api_key and a credential.
- options = mock.Mock()
+ options = client_options.ClientOptions()
options.api_key = "api_key"
with pytest.raises(ValueError):
client = BigtableInstanceAdminClient(
@@ -13381,7 +13820,9 @@ def test_api_key_credentials(client_class, transport_class):
patched.assert_called_once_with(
credentials=mock_cred,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
diff --git a/tests/unit/gapic/bigtable_admin_v2/test_bigtable_table_admin.py b/tests/unit/gapic/bigtable_admin_v2/test_bigtable_table_admin.py
index b29dc5106..b52ad0606 100644
--- a/tests/unit/gapic/bigtable_admin_v2/test_bigtable_table_admin.py
+++ b/tests/unit/gapic/bigtable_admin_v2/test_bigtable_table_admin.py
@@ -29,6 +29,7 @@
import json
import math
import pytest
+from google.api_core import api_core_version
from proto.marshal.rules.dates import DurationRule, TimestampRule
from proto.marshal.rules import wrappers
from requests import Response
@@ -88,6 +89,17 @@ def modify_default_endpoint(client):
)
+# If the default endpoint template is localhost, then the default mTLS endpoint will be the same.
+# This function modifies the default endpoint template so the client can produce a different
+# mTLS endpoint for endpoint testing purposes.
+def modify_default_endpoint_template(client):
+ return (
+ "test.{UNIVERSE_DOMAIN}"
+ if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE)
+ else client._DEFAULT_ENDPOINT_TEMPLATE
+ )
+
+
def test__get_default_mtls_endpoint():
api_endpoint = "example.googleapis.com"
api_mtls_endpoint = "example.mtls.googleapis.com"
@@ -118,6 +130,286 @@ def test__get_default_mtls_endpoint():
)
+def test__read_environment_variables():
+ assert BigtableTableAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
+ assert BigtableTableAdminClient._read_environment_variables() == (
+ True,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}):
+ assert BigtableTableAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ BigtableTableAdminClient._read_environment_variables()
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ assert BigtableTableAdminClient._read_environment_variables() == (
+ False,
+ "never",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ assert BigtableTableAdminClient._read_environment_variables() == (
+ False,
+ "always",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
+ assert BigtableTableAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ None,
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ BigtableTableAdminClient._read_environment_variables()
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}):
+ assert BigtableTableAdminClient._read_environment_variables() == (
+ False,
+ "auto",
+ "foo.com",
+ )
+
+
+def test__get_client_cert_source():
+ mock_provided_cert_source = mock.Mock()
+ mock_default_cert_source = mock.Mock()
+
+ assert BigtableTableAdminClient._get_client_cert_source(None, False) is None
+ assert (
+ BigtableTableAdminClient._get_client_cert_source(
+ mock_provided_cert_source, False
+ )
+ is None
+ )
+ assert (
+ BigtableTableAdminClient._get_client_cert_source(
+ mock_provided_cert_source, True
+ )
+ == mock_provided_cert_source
+ )
+
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source", return_value=True
+ ):
+ with mock.patch(
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=mock_default_cert_source,
+ ):
+ assert (
+ BigtableTableAdminClient._get_client_cert_source(None, True)
+ is mock_default_cert_source
+ )
+ assert (
+ BigtableTableAdminClient._get_client_cert_source(
+ mock_provided_cert_source, "true"
+ )
+ is mock_provided_cert_source
+ )
+
+
+@mock.patch.object(
+ BigtableTableAdminClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminClient),
+)
+@mock.patch.object(
+ BigtableTableAdminAsyncClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminAsyncClient),
+)
+def test__get_api_endpoint():
+ api_override = "foo.com"
+ mock_client_cert_source = mock.Mock()
+ default_universe = BigtableTableAdminClient._DEFAULT_UNIVERSE
+ default_endpoint = BigtableTableAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=default_universe
+ )
+ mock_universe = "bar.com"
+ mock_endpoint = BigtableTableAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=mock_universe
+ )
+
+ assert (
+ BigtableTableAdminClient._get_api_endpoint(
+ api_override, mock_client_cert_source, default_universe, "always"
+ )
+ == api_override
+ )
+ assert (
+ BigtableTableAdminClient._get_api_endpoint(
+ None, mock_client_cert_source, default_universe, "auto"
+ )
+ == BigtableTableAdminClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableTableAdminClient._get_api_endpoint(None, None, default_universe, "auto")
+ == default_endpoint
+ )
+ assert (
+ BigtableTableAdminClient._get_api_endpoint(
+ None, None, default_universe, "always"
+ )
+ == BigtableTableAdminClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableTableAdminClient._get_api_endpoint(
+ None, mock_client_cert_source, default_universe, "always"
+ )
+ == BigtableTableAdminClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableTableAdminClient._get_api_endpoint(None, None, mock_universe, "never")
+ == mock_endpoint
+ )
+ assert (
+ BigtableTableAdminClient._get_api_endpoint(
+ None, None, default_universe, "never"
+ )
+ == default_endpoint
+ )
+
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ BigtableTableAdminClient._get_api_endpoint(
+ None, mock_client_cert_source, mock_universe, "auto"
+ )
+ assert (
+ str(excinfo.value)
+ == "mTLS is not supported in any universe other than googleapis.com."
+ )
+
+
+def test__get_universe_domain():
+ client_universe_domain = "foo.com"
+ universe_domain_env = "bar.com"
+
+ assert (
+ BigtableTableAdminClient._get_universe_domain(
+ client_universe_domain, universe_domain_env
+ )
+ == client_universe_domain
+ )
+ assert (
+ BigtableTableAdminClient._get_universe_domain(None, universe_domain_env)
+ == universe_domain_env
+ )
+ assert (
+ BigtableTableAdminClient._get_universe_domain(None, None)
+ == BigtableTableAdminClient._DEFAULT_UNIVERSE
+ )
+
+ with pytest.raises(ValueError) as excinfo:
+ BigtableTableAdminClient._get_universe_domain("", None)
+ assert str(excinfo.value) == "Universe Domain cannot be an empty string."
+
+
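The precedence these assertions pin down is: explicit client setting first, then the GOOGLE_CLOUD_UNIVERSE_DOMAIN value, then the built-in default, with an empty string rejected outright. A small sketch under those assumptions; the function name is illustrative, while `_DEFAULT_UNIVERSE` is the attribute the test itself reads.

from google.cloud.bigtable_admin_v2 import BigtableTableAdminClient


def _get_universe_domain_sketch(client_universe_domain, universe_domain_env):
    universe_domain = BigtableTableAdminClient._DEFAULT_UNIVERSE
    if client_universe_domain is not None:
        universe_domain = client_universe_domain
    elif universe_domain_env is not None:
        universe_domain = universe_domain_env
    if len(universe_domain.strip()) == 0:
        raise ValueError("Universe Domain cannot be an empty string.")
    return universe_domain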
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [
+ (BigtableTableAdminClient, transports.BigtableTableAdminGrpcTransport, "grpc"),
+ (BigtableTableAdminClient, transports.BigtableTableAdminRestTransport, "rest"),
+ ],
+)
+def test__validate_universe_domain(client_class, transport_class, transport_name):
+ client = client_class(
+ transport=transport_class(credentials=ga_credentials.AnonymousCredentials())
+ )
+ assert client._validate_universe_domain() == True
+
+ # Test the case when universe is already validated.
+ assert client._validate_universe_domain() == True
+
+ if transport_name == "grpc":
+ # Test the case where credentials are provided by the
+ # `local_channel_credentials`. The default universes in both match.
+ channel = grpc.secure_channel(
+ "http://localhost/", grpc.local_channel_credentials()
+ )
+ client = client_class(transport=transport_class(channel=channel))
+ assert client._validate_universe_domain() == True
+
+ # Test the case where credentials do not exist: e.g. a transport is provided
+ # with no credentials. Validation should still succeed because there is no
+ # mismatch with non-existent credentials.
+ channel = grpc.secure_channel(
+ "http://localhost/", grpc.local_channel_credentials()
+ )
+ transport = transport_class(channel=channel)
+ transport._credentials = None
+ client = client_class(transport=transport)
+ assert client._validate_universe_domain() == True
+
+ # TODO: This is needed to cater for older versions of google-auth
+ # Make this test unconditional once the minimum supported version of
+ # google-auth becomes 2.23.0 or higher.
+ google_auth_major, google_auth_minor = [
+ int(part) for part in google.auth.__version__.split(".")[0:2]
+ ]
+ if google_auth_major > 2 or (google_auth_major == 2 and google_auth_minor >= 23):
+ credentials = ga_credentials.AnonymousCredentials()
+ credentials._universe_domain = "foo.com"
+ # Test the case when there is a universe mismatch from the credentials.
+ client = client_class(transport=transport_class(credentials=credentials))
+ with pytest.raises(ValueError) as excinfo:
+ client._validate_universe_domain()
+ assert (
+ str(excinfo.value)
+ == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default."
+ )
+
+ # Test the case when there is a universe mismatch from the client.
+ #
+ # TODO: Make this test unconditional once the minimum supported version of
+ # google-api-core becomes 2.15.0 or higher.
+ api_core_major, api_core_minor = [
+ int(part) for part in api_core_version.__version__.split(".")[0:2]
+ ]
+ if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15):
+ client = client_class(
+ client_options={"universe_domain": "bar.com"},
+ transport=transport_class(
+ credentials=ga_credentials.AnonymousCredentials(),
+ ),
+ )
+ with pytest.raises(ValueError) as excinfo:
+ client._validate_universe_domain()
+ assert (
+ str(excinfo.value)
+ == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default."
+ )
+
+ # Test that ValueError is raised if universe_domain is provided via client options and credentials is None
+ with pytest.raises(ValueError):
+ client._compare_universes("foo.bar", None)
+
+
@pytest.mark.parametrize(
"client_class,transport_name",
[
@@ -233,13 +525,13 @@ def test_bigtable_table_admin_client_get_transport_class():
)
@mock.patch.object(
BigtableTableAdminClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableTableAdminClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminClient),
)
@mock.patch.object(
BigtableTableAdminAsyncClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableTableAdminAsyncClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminAsyncClient),
)
def test_bigtable_table_admin_client_client_options(
client_class, transport_class, transport_name
@@ -281,7 +573,9 @@ def test_bigtable_table_admin_client_client_options(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -311,15 +605,23 @@ def test_bigtable_table_admin_client_client_options(
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
# unsupported value.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
- with pytest.raises(MutualTLSChannelError):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
client = client_class(transport=transport_name)
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
# Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
with mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
):
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError) as excinfo:
client = client_class(transport=transport_name)
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
# Check the case quota_project_id is provided
options = client_options.ClientOptions(quota_project_id="octopus")
@@ -329,7 +631,9 @@ def test_bigtable_table_admin_client_client_options(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id="octopus",
@@ -347,7 +651,9 @@ def test_bigtable_table_admin_client_client_options(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -400,13 +706,13 @@ def test_bigtable_table_admin_client_client_options(
)
@mock.patch.object(
BigtableTableAdminClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableTableAdminClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminClient),
)
@mock.patch.object(
BigtableTableAdminAsyncClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableTableAdminAsyncClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminAsyncClient),
)
@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
def test_bigtable_table_admin_client_mtls_env_auto(
@@ -429,7 +735,9 @@ def test_bigtable_table_admin_client_mtls_env_auto(
if use_client_cert_env == "false":
expected_client_cert_source = None
- expected_host = client.DEFAULT_ENDPOINT
+ expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ )
else:
expected_client_cert_source = client_cert_source_callback
expected_host = client.DEFAULT_MTLS_ENDPOINT
@@ -461,7 +769,9 @@ def test_bigtable_table_admin_client_mtls_env_auto(
return_value=client_cert_source_callback,
):
if use_client_cert_env == "false":
- expected_host = client.DEFAULT_ENDPOINT
+ expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ )
expected_client_cert_source = None
else:
expected_host = client.DEFAULT_MTLS_ENDPOINT
@@ -495,7 +805,9 @@ def test_bigtable_table_admin_client_mtls_env_auto(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -585,6 +897,115 @@ def test_bigtable_table_admin_client_get_mtls_endpoint_and_cert_source(client_cl
assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
assert cert_source == mock_client_cert_source
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+ # unsupported value.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ client_class.get_mtls_endpoint_and_cert_source()
+
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ client_class.get_mtls_endpoint_and_cert_source()
+
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class", [BigtableTableAdminClient, BigtableTableAdminAsyncClient]
+)
+@mock.patch.object(
+ BigtableTableAdminClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminClient),
+)
+@mock.patch.object(
+ BigtableTableAdminAsyncClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableTableAdminAsyncClient),
+)
+def test_bigtable_table_admin_client_client_api_endpoint(client_class):
+ mock_client_cert_source = client_cert_source_callback
+ api_override = "foo.com"
+ default_universe = BigtableTableAdminClient._DEFAULT_UNIVERSE
+ default_endpoint = BigtableTableAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=default_universe
+ )
+ mock_universe = "bar.com"
+ mock_endpoint = BigtableTableAdminClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=mock_universe
+ )
+
+ # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true",
+ # use ClientOptions.api_endpoint as the api endpoint regardless.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel"
+ ):
+ options = client_options.ClientOptions(
+ client_cert_source=mock_client_cert_source, api_endpoint=api_override
+ )
+ client = client_class(
+ client_options=options,
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ assert client.api_endpoint == api_override
+
+ # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ client = client_class(credentials=ga_credentials.AnonymousCredentials())
+ assert client.api_endpoint == default_endpoint
+
+ # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always",
+ # use the DEFAULT_MTLS_ENDPOINT as the api endpoint.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ client = client_class(credentials=ga_credentials.AnonymousCredentials())
+ assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
+
+ # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default),
+ # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist,
+ # and ClientOptions.universe_domain="bar.com",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint.
+ options = client_options.ClientOptions()
+ universe_exists = hasattr(options, "universe_domain")
+ if universe_exists:
+ options = client_options.ClientOptions(universe_domain=mock_universe)
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ else:
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ assert client.api_endpoint == (
+ mock_endpoint if universe_exists else default_endpoint
+ )
+ assert client.universe_domain == (
+ mock_universe if universe_exists else default_universe
+ )
+
+ # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint.
+ options = client_options.ClientOptions()
+ if hasattr(options, "universe_domain"):
+ delattr(options, "universe_domain")
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ assert client.api_endpoint == default_endpoint
+
@pytest.mark.parametrize(
"client_class,transport_class,transport_name",
@@ -611,7 +1032,9 @@ def test_bigtable_table_admin_client_client_options_scopes(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=["1", "2"],
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -656,7 +1079,9 @@ def test_bigtable_table_admin_client_client_options_credentials_file(
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -716,7 +1141,9 @@ def test_bigtable_table_admin_client_create_channel_credentials_file(
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -1513,7 +1940,7 @@ async def test_list_tables_flattened_error_async():
def test_list_tables_pager(transport_name: str = "grpc"):
client = BigtableTableAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -1563,7 +1990,7 @@ def test_list_tables_pager(transport_name: str = "grpc"):
def test_list_tables_pages(transport_name: str = "grpc"):
client = BigtableTableAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -1605,7 +2032,7 @@ def test_list_tables_pages(transport_name: str = "grpc"):
@pytest.mark.asyncio
async def test_list_tables_async_pager():
client = BigtableTableAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -1655,7 +2082,7 @@ async def test_list_tables_async_pager():
@pytest.mark.asyncio
async def test_list_tables_async_pages():
client = BigtableTableAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -4281,7 +4708,7 @@ async def test_list_snapshots_flattened_error_async():
def test_list_snapshots_pager(transport_name: str = "grpc"):
client = BigtableTableAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -4331,7 +4758,7 @@ def test_list_snapshots_pager(transport_name: str = "grpc"):
def test_list_snapshots_pages(transport_name: str = "grpc"):
client = BigtableTableAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -4373,7 +4800,7 @@ def test_list_snapshots_pages(transport_name: str = "grpc"):
@pytest.mark.asyncio
async def test_list_snapshots_async_pager():
client = BigtableTableAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -4423,7 +4850,7 @@ async def test_list_snapshots_async_pager():
@pytest.mark.asyncio
async def test_list_snapshots_async_pages():
client = BigtableTableAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -5892,7 +6319,7 @@ async def test_list_backups_flattened_error_async():
def test_list_backups_pager(transport_name: str = "grpc"):
client = BigtableTableAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -5942,7 +6369,7 @@ def test_list_backups_pager(transport_name: str = "grpc"):
def test_list_backups_pages(transport_name: str = "grpc"):
client = BigtableTableAdminClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
transport=transport_name,
)
@@ -5984,7 +6411,7 @@ def test_list_backups_pages(transport_name: str = "grpc"):
@pytest.mark.asyncio
async def test_list_backups_async_pager():
client = BigtableTableAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -6034,7 +6461,7 @@ async def test_list_backups_async_pager():
@pytest.mark.asyncio
async def test_list_backups_async_pages():
client = BigtableTableAdminAsyncClient(
- credentials=ga_credentials.AnonymousCredentials,
+ credentials=ga_credentials.AnonymousCredentials(),
)
# Mock the actual call within the gRPC stub, and fake the request.
@@ -14596,7 +15023,7 @@ def test_credentials_transport_error():
)
# It is an error to provide an api_key and a credential.
- options = mock.Mock()
+ options = client_options.ClientOptions()
options.api_key = "api_key"
with pytest.raises(ValueError):
client = BigtableTableAdminClient(
@@ -15641,7 +16068,9 @@ def test_api_key_credentials(client_class, transport_class):
patched.assert_called_once_with(
credentials=mock_cred,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
diff --git a/tests/unit/gapic/bigtable_v2/test_bigtable.py b/tests/unit/gapic/bigtable_v2/test_bigtable.py
index 2319306d7..ab05af426 100644
--- a/tests/unit/gapic/bigtable_v2/test_bigtable.py
+++ b/tests/unit/gapic/bigtable_v2/test_bigtable.py
@@ -29,6 +29,7 @@
import json
import math
import pytest
+from google.api_core import api_core_version
from proto.marshal.rules.dates import DurationRule, TimestampRule
from proto.marshal.rules import wrappers
from requests import Response
@@ -71,6 +72,17 @@ def modify_default_endpoint(client):
)
+# If the default endpoint template is localhost, then the default mTLS endpoint will be the same.
+# This function modifies the default endpoint template so the client can produce a different
+# mTLS endpoint for endpoint testing purposes.
+def modify_default_endpoint_template(client):
+ return (
+ "test.{UNIVERSE_DOMAIN}"
+ if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE)
+ else client._DEFAULT_ENDPOINT_TEMPLATE
+ )
+
+
def test__get_default_mtls_endpoint():
api_endpoint = "example.googleapis.com"
api_mtls_endpoint = "example.mtls.googleapis.com"
@@ -95,6 +107,251 @@ def test__get_default_mtls_endpoint():
assert BigtableClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+def test__read_environment_variables():
+ assert BigtableClient._read_environment_variables() == (False, "auto", None)
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
+ assert BigtableClient._read_environment_variables() == (True, "auto", None)
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}):
+ assert BigtableClient._read_environment_variables() == (False, "auto", None)
+
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ BigtableClient._read_environment_variables()
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ assert BigtableClient._read_environment_variables() == (False, "never", None)
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ assert BigtableClient._read_environment_variables() == (False, "always", None)
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
+ assert BigtableClient._read_environment_variables() == (False, "auto", None)
+
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ BigtableClient._read_environment_variables()
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}):
+ assert BigtableClient._read_environment_variables() == (
+ False,
+ "auto",
+ "foo.com",
+ )
+
+
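As a usage-level counterpart to the matrix above: this is roughly how a caller opts the data-plane client into client certificates. The import path matches this test module's target; the rest is a hypothetical sketch that assumes Application Default Credentials and a discoverable client certificate source.

import os

from google.cloud.bigtable_v2 import BigtableClient

os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"
os.environ.setdefault("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")  # the default policy
# The variables are read at construction time; with a client certificate
# available, the "auto" policy resolves to the mTLS endpoint.
client = BigtableClient()
print(client.api_endpoint)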
+def test__get_client_cert_source():
+ mock_provided_cert_source = mock.Mock()
+ mock_default_cert_source = mock.Mock()
+
+ assert BigtableClient._get_client_cert_source(None, False) is None
+ assert (
+ BigtableClient._get_client_cert_source(mock_provided_cert_source, False) is None
+ )
+ assert (
+ BigtableClient._get_client_cert_source(mock_provided_cert_source, True)
+ == mock_provided_cert_source
+ )
+
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source", return_value=True
+ ):
+ with mock.patch(
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=mock_default_cert_source,
+ ):
+ assert (
+ BigtableClient._get_client_cert_source(None, True)
+ is mock_default_cert_source
+ )
+ assert (
+ BigtableClient._get_client_cert_source(
+ mock_provided_cert_source, "true"
+ )
+ is mock_provided_cert_source
+ )
+
+
+@mock.patch.object(
+ BigtableClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableClient),
+)
+@mock.patch.object(
+ BigtableAsyncClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableAsyncClient),
+)
+def test__get_api_endpoint():
+ api_override = "foo.com"
+ mock_client_cert_source = mock.Mock()
+ default_universe = BigtableClient._DEFAULT_UNIVERSE
+ default_endpoint = BigtableClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=default_universe
+ )
+ mock_universe = "bar.com"
+ mock_endpoint = BigtableClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=mock_universe
+ )
+
+ assert (
+ BigtableClient._get_api_endpoint(
+ api_override, mock_client_cert_source, default_universe, "always"
+ )
+ == api_override
+ )
+ assert (
+ BigtableClient._get_api_endpoint(
+ None, mock_client_cert_source, default_universe, "auto"
+ )
+ == BigtableClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableClient._get_api_endpoint(None, None, default_universe, "auto")
+ == default_endpoint
+ )
+ assert (
+ BigtableClient._get_api_endpoint(None, None, default_universe, "always")
+ == BigtableClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableClient._get_api_endpoint(
+ None, mock_client_cert_source, default_universe, "always"
+ )
+ == BigtableClient.DEFAULT_MTLS_ENDPOINT
+ )
+ assert (
+ BigtableClient._get_api_endpoint(None, None, mock_universe, "never")
+ == mock_endpoint
+ )
+ assert (
+ BigtableClient._get_api_endpoint(None, None, default_universe, "never")
+ == default_endpoint
+ )
+
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ BigtableClient._get_api_endpoint(
+ None, mock_client_cert_source, mock_universe, "auto"
+ )
+ assert (
+ str(excinfo.value)
+ == "mTLS is not supported in any universe other than googleapis.com."
+ )
+
+
+def test__get_universe_domain():
+ client_universe_domain = "foo.com"
+ universe_domain_env = "bar.com"
+
+ assert (
+ BigtableClient._get_universe_domain(client_universe_domain, universe_domain_env)
+ == client_universe_domain
+ )
+ assert (
+ BigtableClient._get_universe_domain(None, universe_domain_env)
+ == universe_domain_env
+ )
+ assert (
+ BigtableClient._get_universe_domain(None, None)
+ == BigtableClient._DEFAULT_UNIVERSE
+ )
+
+ with pytest.raises(ValueError) as excinfo:
+ BigtableClient._get_universe_domain("", None)
+ assert str(excinfo.value) == "Universe Domain cannot be an empty string."
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [
+ (BigtableClient, transports.BigtableGrpcTransport, "grpc"),
+ (BigtableClient, transports.BigtableRestTransport, "rest"),
+ ],
+)
+def test__validate_universe_domain(client_class, transport_class, transport_name):
+ client = client_class(
+ transport=transport_class(credentials=ga_credentials.AnonymousCredentials())
+ )
+ assert client._validate_universe_domain() == True
+
+ # Test the case when universe is already validated.
+ assert client._validate_universe_domain() == True
+
+ if transport_name == "grpc":
+ # Test the case where credentials are provided by the
+ # `local_channel_credentials`. The default universes in both match.
+ channel = grpc.secure_channel(
+ "http://localhost/", grpc.local_channel_credentials()
+ )
+ client = client_class(transport=transport_class(channel=channel))
+ assert client._validate_universe_domain() == True
+
+ # Test the case where credentials do not exist: e.g. a transport is provided
+ # with no credentials. Validation should still succeed because there is no
+ # mismatch with non-existent credentials.
+ channel = grpc.secure_channel(
+ "http://localhost/", grpc.local_channel_credentials()
+ )
+ transport = transport_class(channel=channel)
+ transport._credentials = None
+ client = client_class(transport=transport)
+ assert client._validate_universe_domain() == True
+
+ # TODO: This is needed to cater for older versions of google-auth
+ # Make this test unconditional once the minimum supported version of
+ # google-auth becomes 2.23.0 or higher.
+ google_auth_major, google_auth_minor = [
+ int(part) for part in google.auth.__version__.split(".")[0:2]
+ ]
+ if google_auth_major > 2 or (google_auth_major == 2 and google_auth_minor >= 23):
+ credentials = ga_credentials.AnonymousCredentials()
+ credentials._universe_domain = "foo.com"
+ # Test the case when there is a universe mismatch from the credentials.
+ client = client_class(transport=transport_class(credentials=credentials))
+ with pytest.raises(ValueError) as excinfo:
+ client._validate_universe_domain()
+ assert (
+ str(excinfo.value)
+ == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default."
+ )
+
+ # Test the case when there is a universe mismatch from the client.
+ #
+ # TODO: Make this test unconditional once the minimum supported version of
+ # google-api-core becomes 2.15.0 or higher.
+ api_core_major, api_core_minor = [
+ int(part) for part in api_core_version.__version__.split(".")[0:2]
+ ]
+ if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15):
+ client = client_class(
+ client_options={"universe_domain": "bar.com"},
+ transport=transport_class(
+ credentials=ga_credentials.AnonymousCredentials(),
+ ),
+ )
+ with pytest.raises(ValueError) as excinfo:
+ client._validate_universe_domain()
+ assert (
+ str(excinfo.value)
+ == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default."
+ )
+
+ # Test that ValueError is raised if universe_domain is provided via client options and credentials is None
+ with pytest.raises(ValueError):
+ client._compare_universes("foo.bar", None)
+
+
@pytest.mark.parametrize(
"client_class,transport_name",
[
@@ -201,12 +458,14 @@ def test_bigtable_client_get_transport_class():
],
)
@mock.patch.object(
- BigtableClient, "DEFAULT_ENDPOINT", modify_default_endpoint(BigtableClient)
+ BigtableClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableClient),
)
@mock.patch.object(
BigtableAsyncClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableAsyncClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableAsyncClient),
)
def test_bigtable_client_client_options(client_class, transport_class, transport_name):
# Check that if channel is provided we won't create a new one.
@@ -246,7 +505,9 @@ def test_bigtable_client_client_options(client_class, transport_class, transport
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -276,15 +537,23 @@ def test_bigtable_client_client_options(client_class, transport_class, transport
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
# unsupported value.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
- with pytest.raises(MutualTLSChannelError):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
client = client_class(transport=transport_name)
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
# Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
with mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
):
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError) as excinfo:
client = client_class(transport=transport_name)
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
# Check the case quota_project_id is provided
options = client_options.ClientOptions(quota_project_id="octopus")
@@ -294,7 +563,9 @@ def test_bigtable_client_client_options(client_class, transport_class, transport
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id="octopus",
@@ -312,7 +583,9 @@ def test_bigtable_client_client_options(client_class, transport_class, transport
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -344,12 +617,14 @@ def test_bigtable_client_client_options(client_class, transport_class, transport
],
)
@mock.patch.object(
- BigtableClient, "DEFAULT_ENDPOINT", modify_default_endpoint(BigtableClient)
+ BigtableClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableClient),
)
@mock.patch.object(
BigtableAsyncClient,
- "DEFAULT_ENDPOINT",
- modify_default_endpoint(BigtableAsyncClient),
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableAsyncClient),
)
@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
def test_bigtable_client_mtls_env_auto(
@@ -372,7 +647,9 @@ def test_bigtable_client_mtls_env_auto(
if use_client_cert_env == "false":
expected_client_cert_source = None
- expected_host = client.DEFAULT_ENDPOINT
+ expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ )
else:
expected_client_cert_source = client_cert_source_callback
expected_host = client.DEFAULT_MTLS_ENDPOINT
@@ -404,7 +681,9 @@ def test_bigtable_client_mtls_env_auto(
return_value=client_cert_source_callback,
):
if use_client_cert_env == "false":
- expected_host = client.DEFAULT_ENDPOINT
+ expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ )
expected_client_cert_source = None
else:
expected_host = client.DEFAULT_MTLS_ENDPOINT
@@ -438,7 +717,9 @@ def test_bigtable_client_mtls_env_auto(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -524,6 +805,113 @@ def test_bigtable_client_get_mtls_endpoint_and_cert_source(client_class):
assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
assert cert_source == mock_client_cert_source
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+    # an unsupported value.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError) as excinfo:
+ client_class.get_mtls_endpoint_and_cert_source()
+
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has an unsupported value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ client_class.get_mtls_endpoint_and_cert_source()
+
+ assert (
+ str(excinfo.value)
+ == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+
+
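The two new assertions pin down the exact error messages for malformed environment variables. As a reference, a standalone sketch of the validation they describe — ValueError is used for both branches here to keep the sketch import-free, whereas the client raises MutualTLSChannelError for the mTLS variable:

# Illustrative sketch of the env-var validation asserted above.
import os

def read_mtls_env():
    use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
    if use_client_cert not in ("true", "false"):
        raise ValueError(
            "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be "
            "either `true` or `false`"
        )
    use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
    if use_mtls_endpoint not in ("never", "auto", "always"):
        # The client raises MutualTLSChannelError here; ValueError keeps the sketch self-contained.
        raise ValueError(
            "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be "
            "`never`, `auto` or `always`"
        )
    return use_client_cert == "true", use_mtls_endpoint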
+@pytest.mark.parametrize("client_class", [BigtableClient, BigtableAsyncClient])
+@mock.patch.object(
+ BigtableClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableClient),
+)
+@mock.patch.object(
+ BigtableAsyncClient,
+ "_DEFAULT_ENDPOINT_TEMPLATE",
+ modify_default_endpoint_template(BigtableAsyncClient),
+)
+def test_bigtable_client_client_api_endpoint(client_class):
+ mock_client_cert_source = client_cert_source_callback
+ api_override = "foo.com"
+ default_universe = BigtableClient._DEFAULT_UNIVERSE
+ default_endpoint = BigtableClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=default_universe
+ )
+ mock_universe = "bar.com"
+ mock_endpoint = BigtableClient._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=mock_universe
+ )
+
+ # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true",
+ # use ClientOptions.api_endpoint as the api endpoint regardless.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel"
+ ):
+ options = client_options.ClientOptions(
+ client_cert_source=mock_client_cert_source, api_endpoint=api_override
+ )
+ client = client_class(
+ client_options=options,
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ assert client.api_endpoint == api_override
+
+ # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ client = client_class(credentials=ga_credentials.AnonymousCredentials())
+ assert client.api_endpoint == default_endpoint
+
+ # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always",
+ # use the DEFAULT_MTLS_ENDPOINT as the api endpoint.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ client = client_class(credentials=ga_credentials.AnonymousCredentials())
+ assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
+
+ # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default),
+ # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist,
+ # and ClientOptions.universe_domain="bar.com",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint.
+ options = client_options.ClientOptions()
+ universe_exists = hasattr(options, "universe_domain")
+ if universe_exists:
+ options = client_options.ClientOptions(universe_domain=mock_universe)
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ else:
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ assert client.api_endpoint == (
+ mock_endpoint if universe_exists else default_endpoint
+ )
+ assert client.universe_domain == (
+ mock_universe if universe_exists else default_universe
+ )
+
+ # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never",
+ # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint.
+ options = client_options.ClientOptions()
+ if hasattr(options, "universe_domain"):
+ delattr(options, "universe_domain")
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ client = client_class(
+ client_options=options, credentials=ga_credentials.AnonymousCredentials()
+ )
+ assert client.api_endpoint == default_endpoint
+
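test_bigtable_client_client_api_endpoint walks through the endpoint-resolution precedence: an explicit api_endpoint override wins, then the mTLS environment setting, then the universe domain plugged into the endpoint template. A compact sketch of that precedence, assuming the template and default-universe constants match the names the test references:

# Illustrative precedence only; constants mirror the names used in the test.
_DEFAULT_UNIVERSE = "googleapis.com"
_DEFAULT_ENDPOINT_TEMPLATE = "bigtable.{UNIVERSE_DOMAIN}"
_DEFAULT_MTLS_ENDPOINT = "bigtable.mtls.googleapis.com"

def resolve_api_endpoint(
    api_override=None,
    use_mtls_endpoint="auto",
    client_cert_source=None,
    universe_domain=None,
):
    # 1. An explicit ClientOptions.api_endpoint always wins.
    if api_override is not None:
        return api_override
    # 2. mTLS endpoint when forced, or when "auto" and a client cert is present.
    if use_mtls_endpoint == "always" or (
        use_mtls_endpoint == "auto" and client_cert_source is not None
    ):
        return _DEFAULT_MTLS_ENDPOINT
    # 3. Otherwise, the template populated with the (possibly default) universe.
    return _DEFAULT_ENDPOINT_TEMPLATE.format(
        UNIVERSE_DOMAIN=universe_domain or _DEFAULT_UNIVERSE
    )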
@pytest.mark.parametrize(
"client_class,transport_class,transport_name",
@@ -546,7 +934,9 @@ def test_bigtable_client_client_options_scopes(
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=["1", "2"],
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -581,7 +971,9 @@ def test_bigtable_client_client_options_credentials_file(
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -634,7 +1026,9 @@ def test_bigtable_client_create_channel_credentials_file(
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
@@ -5639,7 +6033,7 @@ def test_credentials_transport_error():
)
# It is an error to provide an api_key and a credential.
- options = mock.Mock()
+ options = client_options.ClientOptions()
options.api_key = "api_key"
with pytest.raises(ValueError):
client = BigtableClient(
@@ -6428,7 +6822,9 @@ def test_api_key_credentials(client_class, transport_class):
patched.assert_called_once_with(
credentials=mock_cred,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=client._DEFAULT_ENDPOINT_TEMPLATE.format(
+ UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE
+ ),
scopes=None,
client_cert_source_for_mtls=None,
quota_project_id=None,
diff --git a/tests/unit/v2_client/__init__.py b/tests/unit/v2_client/__init__.py
new file mode 100644
index 000000000..e8e1c3845
--- /dev/null
+++ b/tests/unit/v2_client/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/tests/unit/_testing.py b/tests/unit/v2_client/_testing.py
similarity index 100%
rename from tests/unit/_testing.py
rename to tests/unit/v2_client/_testing.py
diff --git a/tests/unit/v2_client/read-rows-acceptance-test.json b/tests/unit/v2_client/read-rows-acceptance-test.json
new file mode 100644
index 000000000..011ace2b9
--- /dev/null
+++ b/tests/unit/v2_client/read-rows-acceptance-test.json
@@ -0,0 +1,1665 @@
+{
+ "readRowsTests": [
+ {
+ "description": "invalid - no commit",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - no cell key before commit",
+ "chunks": [
+ {
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - no cell key before value",
+ "chunks": [
+ {
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - new col family must specify qualifier",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "familyName": "B",
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "bare commit implies ts=0",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ },
+ {
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C"
+ }
+ ]
+ },
+ {
+ "description": "simple row with timestamp",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ }
+ ]
+ },
+ {
+ "description": "missing timestamp, implied ts=0",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "value": "value-VAL"
+ }
+ ]
+ },
+ {
+ "description": "empty cell value",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C"
+ }
+ ]
+ },
+ {
+ "description": "two unsplit cells",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "two qualifiers",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "qualifier": "RA==",
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "D",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "two families",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "familyName": "B",
+ "qualifier": "RQ==",
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "B",
+ "qualifier": "E",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "with labels",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "labels": [
+ "L_1"
+ ],
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "98",
+ "labels": [
+ "L_2"
+ ],
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1",
+ "label": "L_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "98",
+ "value": "value-VAL_2",
+ "label": "L_2"
+ }
+ ]
+ },
+ {
+ "description": "split cell, bare commit",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dg==",
+ "valueSize": 9,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUw=",
+ "commitRow": false
+ },
+ {
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C"
+ }
+ ]
+ },
+ {
+ "description": "split cell",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dg==",
+ "valueSize": 9,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUw=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ }
+ ]
+ },
+ {
+ "description": "split four ways",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "labels": [
+ "L"
+ ],
+ "value": "dg==",
+ "valueSize": 9,
+ "commitRow": false
+ },
+ {
+ "value": "YQ==",
+ "valueSize": 9,
+ "commitRow": false
+ },
+ {
+ "value": "bA==",
+ "valueSize": 9,
+ "commitRow": false
+ },
+ {
+ "value": "dWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL",
+ "label": "L"
+ }
+ ]
+ },
+ {
+ "description": "two split cells",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMQ==",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "98",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMg==",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "multi-qualifier splits",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMQ==",
+ "commitRow": false
+ },
+ {
+ "qualifier": "RA==",
+ "timestampMicros": "98",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMg==",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "D",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "multi-qualifier multi-split",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YQ==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "bHVlLVZBTF8x",
+ "commitRow": false
+ },
+ {
+ "qualifier": "RA==",
+ "timestampMicros": "98",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YQ==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "bHVlLVZBTF8y",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "D",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "multi-family split",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMQ==",
+ "commitRow": false
+ },
+ {
+ "familyName": "B",
+ "qualifier": "RQ==",
+ "timestampMicros": "98",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMg==",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "B",
+ "qualifier": "E",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "invalid - no commit between rows",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - no commit after first row",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - last row missing commit",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ },
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - duplicate row key",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "B",
+ "qualifier": "RA==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ },
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - new row missing row key",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ },
+ {
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ },
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "two rows",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ }
+ ]
+ },
+ {
+ "description": "two rows implicit timestamp",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "value": "value-VAL"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ }
+ ]
+ },
+ {
+ "description": "two rows empty value",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ }
+ ]
+ },
+ {
+ "description": "two rows, one with multiple cells",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "B",
+ "qualifier": "RA==",
+ "timestampMicros": "97",
+ "value": "dmFsdWUtVkFMXzM=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "B",
+ "qualifier": "D",
+ "timestampMicros": "97",
+ "value": "value-VAL_3"
+ }
+ ]
+ },
+ {
+ "description": "two rows, multiple cells",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "qualifier": "RA==",
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "B",
+ "qualifier": "RQ==",
+ "timestampMicros": "97",
+ "value": "dmFsdWUtVkFMXzM=",
+ "commitRow": false
+ },
+ {
+ "qualifier": "Rg==",
+ "timestampMicros": "96",
+ "value": "dmFsdWUtVkFMXzQ=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "D",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "B",
+ "qualifier": "E",
+ "timestampMicros": "97",
+ "value": "value-VAL_3"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "B",
+ "qualifier": "F",
+ "timestampMicros": "96",
+ "value": "value-VAL_4"
+ }
+ ]
+ },
+ {
+ "description": "two rows, multiple cells, multiple families",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "familyName": "B",
+ "qualifier": "RQ==",
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "M",
+ "qualifier": "Tw==",
+ "timestampMicros": "97",
+ "value": "dmFsdWUtVkFMXzM=",
+ "commitRow": false
+ },
+ {
+ "familyName": "N",
+ "qualifier": "UA==",
+ "timestampMicros": "96",
+ "value": "dmFsdWUtVkFMXzQ=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK_1",
+ "familyName": "B",
+ "qualifier": "E",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "M",
+ "qualifier": "O",
+ "timestampMicros": "97",
+ "value": "value-VAL_3"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "N",
+ "qualifier": "P",
+ "timestampMicros": "96",
+ "value": "value-VAL_4"
+ }
+ ]
+ },
+ {
+ "description": "two rows, four cells, 2 labels",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "99",
+ "labels": [
+ "L_1"
+ ],
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "B",
+ "qualifier": "RA==",
+ "timestampMicros": "97",
+ "labels": [
+ "L_3"
+ ],
+ "value": "dmFsdWUtVkFMXzM=",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "96",
+ "value": "dmFsdWUtVkFMXzQ=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "99",
+ "value": "value-VAL_1",
+ "label": "L_1"
+ },
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "98",
+ "value": "value-VAL_2"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "B",
+ "qualifier": "D",
+ "timestampMicros": "97",
+ "value": "value-VAL_3",
+ "label": "L_3"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "B",
+ "qualifier": "D",
+ "timestampMicros": "96",
+ "value": "value-VAL_4"
+ }
+ ]
+ },
+ {
+ "description": "two rows with splits, same timestamp",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMQ==",
+ "commitRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dg==",
+ "valueSize": 11,
+ "commitRow": false
+ },
+ {
+ "value": "YWx1ZS1WQUxfMg==",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_1"
+ },
+ {
+ "rowKey": "RK_2",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "invalid - bare reset",
+ "chunks": [
+ {
+ "resetRow": true
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - bad reset, no commit",
+ "chunks": [
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - missing key after reset",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "no data after reset",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ }
+ ]
+ },
+ {
+ "description": "simple reset",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ }
+ ]
+ },
+ {
+ "description": "reset to new val",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "reset to new qual",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "RA==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "D",
+ "timestampMicros": "100",
+ "value": "value-VAL_1"
+ }
+ ]
+ },
+ {
+ "description": "reset with splits",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "98",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "reset two cells",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": false
+ },
+ {
+ "timestampMicros": "97",
+ "value": "dmFsdWUtVkFMXzM=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_2"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "97",
+ "value": "value-VAL_3"
+ }
+ ]
+ },
+ {
+ "description": "two resets",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzM=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_3"
+ }
+ ]
+ },
+ {
+ "description": "reset then two cells",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "Uks=",
+ "familyName": "B",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": false
+ },
+ {
+ "qualifier": "RA==",
+ "timestampMicros": "97",
+ "value": "dmFsdWUtVkFMXzM=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "B",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_2"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "B",
+ "qualifier": "D",
+ "timestampMicros": "97",
+ "value": "value-VAL_3"
+ }
+ ]
+ },
+ {
+ "description": "reset to new row",
+ "chunks": [
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "UktfMg==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzI=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_2",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_2"
+ }
+ ]
+ },
+ {
+ "description": "reset in between chunks",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "labels": [
+ "L"
+ ],
+ "value": "dg==",
+ "valueSize": 10,
+ "commitRow": false
+ },
+ {
+ "value": "YQ==",
+ "valueSize": 10,
+ "commitRow": false
+ },
+ {
+ "resetRow": true
+ },
+ {
+ "rowKey": "UktfMQ==",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFMXzE=",
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK_1",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL_1"
+ }
+ ]
+ },
+ {
+ "description": "invalid - reset with chunk",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "labels": [
+ "L"
+ ],
+ "value": "dg==",
+ "valueSize": 10,
+ "commitRow": false
+ },
+ {
+ "value": "YQ==",
+ "valueSize": 10,
+ "resetRow": true
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "invalid - commit with chunk",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "labels": [
+ "L"
+ ],
+ "value": "dg==",
+ "valueSize": 10,
+ "commitRow": false
+ },
+ {
+ "value": "YQ==",
+ "valueSize": 10,
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "error": true
+ }
+ ]
+ },
+ {
+ "description": "empty cell chunk",
+ "chunks": [
+ {
+ "rowKey": "Uks=",
+ "familyName": "A",
+ "qualifier": "Qw==",
+ "timestampMicros": "100",
+ "value": "dmFsdWUtVkFM",
+ "commitRow": false
+ },
+ {
+ "commitRow": false
+ },
+ {
+ "commitRow": true
+ }
+ ],
+ "results": [
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C",
+ "timestampMicros": "100",
+ "value": "value-VAL"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C"
+ },
+ {
+ "rowKey": "RK",
+ "familyName": "A",
+ "qualifier": "C"
+ }
+ ]
+ }
+ ]
+}
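The acceptance file above is pure data: each entry in readRowsTests pairs a description with the ReadRows chunks to feed a row merger and the cells (or an error marker) expected back, and the chunk byte fields are base64-encoded. A minimal loader sketch, with the path and helper names chosen here purely for illustration:

# Illustrative loader for the acceptance JSON added above.
import base64
import json
import pathlib

def load_read_rows_cases(path="tests/unit/v2_client/read-rows-acceptance-test.json"):
    data = json.loads(pathlib.Path(path).read_text())
    for case in data["readRowsTests"]:
        # Cases like "no data after reset" have chunks but no expected results.
        yield case["description"], case.get("chunks", []), case.get("results", [])

def decode_chunk_bytes(chunk):
    # rowKey, qualifier and value are base64; familyName and labels are plain text.
    decoded = dict(chunk)
    for field in ("rowKey", "qualifier", "value"):
        if field in decoded:
            decoded[field] = base64.b64decode(decoded[field])
    return decoded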
diff --git a/tests/unit/test_app_profile.py b/tests/unit/v2_client/test_app_profile.py
similarity index 100%
rename from tests/unit/test_app_profile.py
rename to tests/unit/v2_client/test_app_profile.py
diff --git a/tests/unit/test_backup.py b/tests/unit/v2_client/test_backup.py
similarity index 100%
rename from tests/unit/test_backup.py
rename to tests/unit/v2_client/test_backup.py
diff --git a/tests/unit/test_batcher.py b/tests/unit/v2_client/test_batcher.py
similarity index 98%
rename from tests/unit/test_batcher.py
rename to tests/unit/v2_client/test_batcher.py
index 741d9f282..fcf606972 100644
--- a/tests/unit/test_batcher.py
+++ b/tests/unit/v2_client/test_batcher.py
@@ -198,7 +198,7 @@ def test_mutations_batcher_response_with_error_codes():
mocked_response = [Status(code=1), Status(code=5)]
- with mock.patch("tests.unit.test_batcher._Table") as mocked_table:
+ with mock.patch("tests.unit.v2_client.test_batcher._Table") as mocked_table:
table = mocked_table.return_value
mutation_batcher = MutationsBatcher(table=table)
diff --git a/tests/unit/test_client.py b/tests/unit/v2_client/test_client.py
similarity index 100%
rename from tests/unit/test_client.py
rename to tests/unit/v2_client/test_client.py
diff --git a/tests/unit/test_cluster.py b/tests/unit/v2_client/test_cluster.py
similarity index 100%
rename from tests/unit/test_cluster.py
rename to tests/unit/v2_client/test_cluster.py
diff --git a/tests/unit/test_column_family.py b/tests/unit/v2_client/test_column_family.py
similarity index 99%
rename from tests/unit/test_column_family.py
rename to tests/unit/v2_client/test_column_family.py
index 80b05d744..e4f74e264 100644
--- a/tests/unit/test_column_family.py
+++ b/tests/unit/v2_client/test_column_family.py
@@ -336,7 +336,7 @@ def _create_test_helper(gc_rule=None):
from google.cloud.bigtable_admin_v2.types import (
bigtable_table_admin as table_admin_v2_pb2,
)
- from tests.unit._testing import _FakeStub
+ from ._testing import _FakeStub
from google.cloud.bigtable_admin_v2.services.bigtable_table_admin import (
BigtableTableAdminClient,
)
@@ -404,7 +404,7 @@ def test_column_family_create_with_gc_rule():
def _update_test_helper(gc_rule=None):
- from tests.unit._testing import _FakeStub
+ from ._testing import _FakeStub
from google.cloud.bigtable_admin_v2.types import (
bigtable_table_admin as table_admin_v2_pb2,
)
@@ -478,7 +478,7 @@ def test_column_family_delete():
from google.cloud.bigtable_admin_v2.types import (
bigtable_table_admin as table_admin_v2_pb2,
)
- from tests.unit._testing import _FakeStub
+ from ._testing import _FakeStub
from google.cloud.bigtable_admin_v2.services.bigtable_table_admin import (
BigtableTableAdminClient,
)
diff --git a/tests/unit/test_encryption_info.py b/tests/unit/v2_client/test_encryption_info.py
similarity index 100%
rename from tests/unit/test_encryption_info.py
rename to tests/unit/v2_client/test_encryption_info.py
diff --git a/tests/unit/test_error.py b/tests/unit/v2_client/test_error.py
similarity index 100%
rename from tests/unit/test_error.py
rename to tests/unit/v2_client/test_error.py
diff --git a/tests/unit/test_instance.py b/tests/unit/v2_client/test_instance.py
similarity index 100%
rename from tests/unit/test_instance.py
rename to tests/unit/v2_client/test_instance.py
diff --git a/tests/unit/test_policy.py b/tests/unit/v2_client/test_policy.py
similarity index 100%
rename from tests/unit/test_policy.py
rename to tests/unit/v2_client/test_policy.py
diff --git a/tests/unit/test_row.py b/tests/unit/v2_client/test_row.py
similarity index 99%
rename from tests/unit/test_row.py
rename to tests/unit/v2_client/test_row.py
index 49bbfc45c..f04802f5c 100644
--- a/tests/unit/test_row.py
+++ b/tests/unit/v2_client/test_row.py
@@ -480,7 +480,7 @@ def test_conditional_row_commit_too_many_mutations():
def test_conditional_row_commit_no_mutations():
- from tests.unit._testing import _FakeStub
+ from ._testing import _FakeStub
project_id = "project-id"
row_key = b"row_key"
@@ -607,7 +607,7 @@ def mock_parse_rmw_row_response(row_response):
def test_append_row_commit_no_rules():
- from tests.unit._testing import _FakeStub
+ from ._testing import _FakeStub
project_id = "project-id"
row_key = b"row_key"
diff --git a/tests/unit/test_row_data.py b/tests/unit/v2_client/test_row_data.py
similarity index 97%
rename from tests/unit/test_row_data.py
rename to tests/unit/v2_client/test_row_data.py
index 9f2c40a54..7c2987b56 100644
--- a/tests/unit/test_row_data.py
+++ b/tests/unit/v2_client/test_row_data.py
@@ -362,6 +362,30 @@ def test__retry_read_rows_exception_deadline_exceeded_wrapped_in_grpc():
assert _retry_read_rows_exception(exception)
+def test_partial_cell_data():
+ from google.cloud.bigtable.row_data import PartialCellData
+
+ expected_key = b"row-key"
+ expected_family_name = b"family-name"
+ expected_qualifier = b"qualifier"
+ expected_timestamp = 1234
+ instance = PartialCellData(
+ expected_key, expected_family_name, expected_qualifier, expected_timestamp
+ )
+ assert instance.row_key == expected_key
+ assert instance.family_name == expected_family_name
+ assert instance.qualifier == expected_qualifier
+ assert instance.timestamp_micros == expected_timestamp
+ assert instance.value == b""
+ assert instance.labels == ()
+ # test updating value
+ added_value = b"added-value"
+ instance.append_value(added_value)
+ assert instance.value == added_value
+ instance.append_value(added_value)
+ assert instance.value == added_value + added_value
+
+
def _make_partial_rows_data(*args, **kwargs):
from google.cloud.bigtable.row_data import PartialRowsData
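The new test_partial_cell_data test documents the accumulation contract of PartialCellData: an empty value and empty labels at construction, with append_value concatenating the split pieces of a cell's value in arrival order. A stand-in class with the same behaviour, for reference only:

# Minimal stand-in mirroring the behaviour the new test asserts; the real
# class is google.cloud.bigtable.row_data.PartialCellData.
class PartialCellDataSketch:
    def __init__(self, row_key, family_name, qualifier, timestamp_micros,
                 labels=(), value=b""):
        self.row_key = row_key
        self.family_name = family_name
        self.qualifier = qualifier
        self.timestamp_micros = timestamp_micros
        self.labels = labels
        self.value = value

    def append_value(self, value_bytes):
        # A cell's value can arrive split across several chunks; append in order.
        self.value += value_bytes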
diff --git a/tests/unit/test_row_filters.py b/tests/unit/v2_client/test_row_filters.py
similarity index 100%
rename from tests/unit/test_row_filters.py
rename to tests/unit/v2_client/test_row_filters.py
diff --git a/tests/unit/test_row_merger.py b/tests/unit/v2_client/test_row_merger.py
similarity index 100%
rename from tests/unit/test_row_merger.py
rename to tests/unit/v2_client/test_row_merger.py
diff --git a/tests/unit/test_row_set.py b/tests/unit/v2_client/test_row_set.py
similarity index 100%
rename from tests/unit/test_row_set.py
rename to tests/unit/v2_client/test_row_set.py
diff --git a/tests/unit/test_table.py b/tests/unit/v2_client/test_table.py
similarity index 100%
rename from tests/unit/test_table.py
rename to tests/unit/v2_client/test_table.py