Skip to content

Commit e117a1a

Browse files
feat: Add support for increasing partitions in python (#74)
* Add support for increasing partitions in python * updates to address comments * addressing comments
1 parent b5ffc42 commit e117a1a

File tree

8 files changed

+453
-12
lines changed

8 files changed

+453
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ docs.metadata
5050

5151
# Virtual environment
5252
env/
53+
venv/
5354
coverage.xml
5455
sponge_log.xml
5556

google/cloud/pubsublite/internal/wire/make_publisher.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import AsyncIterator, Mapping, Optional, MutableMapping
15+
from typing import AsyncIterator, Mapping, Optional
1616

1717
from google.cloud.pubsub_v1.types import BatchSettings
1818

@@ -25,8 +25,13 @@
2525
GapicConnectionFactory,
2626
)
2727
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
28+
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
29+
PartitionCountWatcherImpl,
30+
)
31+
from google.cloud.pubsublite.internal.wire.partition_count_watching_publisher import (
32+
PartitionCountWatchingPublisher,
33+
)
2834
from google.cloud.pubsublite.internal.wire.publisher import Publisher
29-
from google.cloud.pubsublite.internal.wire.routing_publisher import RoutingPublisher
3035
from google.cloud.pubsublite.internal.wire.single_partition_publisher import (
3136
SinglePartitionPublisher,
3237
)
@@ -37,14 +42,14 @@
3742
from google.api_core.client_options import ClientOptions
3843
from google.auth.credentials import Credentials
3944

40-
4145
DEFAULT_BATCHING_SETTINGS = BatchSettings(
4246
max_bytes=(
4347
3 * 1024 * 1024
4448
), # 3 MiB to stay 1 MiB below GRPC's 4 MiB per-message limit.
4549
max_messages=1000,
4650
max_latency=0.05, # 50 ms
4751
)
52+
DEFAULT_PARTITION_POLL_PERIOD = 600 # ten minutes
4853

4954

5055
def make_publisher(
@@ -87,21 +92,24 @@ def make_publisher(
8792
credentials=credentials, transport=transport, client_options=client_options
8893
) # type: ignore
8994

90-
clients: MutableMapping[Partition, Publisher] = {}
91-
92-
partition_count = admin_client.get_topic_partition_count(topic)
93-
for partition in range(partition_count):
94-
partition = Partition(partition)
95-
95+
def publisher_factory(partition: Partition):
9696
def connection_factory(requests: AsyncIterator[PublishRequest]):
9797
final_metadata = merge_metadata(
9898
metadata, topic_routing_metadata(topic, partition)
9999
)
100100
return client.publish(requests, metadata=list(final_metadata.items()))
101101

102-
clients[partition] = SinglePartitionPublisher(
102+
return SinglePartitionPublisher(
103103
InitialPublishRequest(topic=str(topic), partition=partition.value),
104104
per_partition_batching_settings,
105105
GapicConnectionFactory(connection_factory),
106106
)
107-
return RoutingPublisher(DefaultRoutingPolicy(partition_count), clients)
107+
108+
def policy_factory(partition_count: int):
109+
return DefaultRoutingPolicy(partition_count)
110+
111+
return PartitionCountWatchingPublisher(
112+
PartitionCountWatcherImpl(admin_client, topic, DEFAULT_PARTITION_POLL_PERIOD),
113+
publisher_factory,
114+
policy_factory,
115+
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import abstractmethod
16+
from typing import AsyncContextManager
17+
18+
19+
class PartitionCountWatcher(AsyncContextManager):
20+
@abstractmethod
21+
async def get_partition_count(self) -> int:
22+
raise NotImplementedError()
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import logging
15+
from concurrent.futures.thread import ThreadPoolExecutor
16+
import asyncio
17+
18+
from google.cloud.pubsublite import AdminClientInterface
19+
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
20+
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
21+
PartitionCountWatcher,
22+
)
23+
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
24+
from google.cloud.pubsublite.types import TopicPath
25+
from google.api_core.exceptions import GoogleAPICallError
26+
27+
28+
class PartitionCountWatcherImpl(PartitionCountWatcher, PermanentFailable):
29+
_admin: AdminClientInterface
30+
_topic_path: TopicPath
31+
_duration: float
32+
_any_success: bool
33+
_thread: ThreadPoolExecutor
34+
_queue: asyncio.Queue
35+
_poll_partition_loop: asyncio.Future
36+
37+
def __init__(
38+
self, admin: AdminClientInterface, topic_path: TopicPath, duration: float
39+
):
40+
super().__init__()
41+
self._admin = admin
42+
self._topic_path = topic_path
43+
self._duration = duration
44+
self._any_success = False
45+
46+
async def __aenter__(self):
47+
self._thread = ThreadPoolExecutor(max_workers=1)
48+
self._queue = asyncio.Queue(maxsize=1)
49+
self._poll_partition_loop = asyncio.ensure_future(
50+
self.run_poller(self._poll_partition_loop)
51+
)
52+
53+
async def __aexit__(self, exc_type, exc_val, exc_tb):
54+
self._poll_partition_loop.cancel()
55+
await wait_ignore_cancelled(self._poll_partition_loop)
56+
self._thread.shutdown(wait=False)
57+
58+
def _get_partition_count_sync(self) -> int:
59+
return self._admin.get_topic_partition_count(self._topic_path)
60+
61+
async def _poll_partition_loop(self):
62+
try:
63+
partition_count = await asyncio.get_event_loop().run_in_executor(
64+
self._thread, self._get_partition_count_sync
65+
)
66+
self._any_success = True
67+
await self._queue.put(partition_count)
68+
except GoogleAPICallError as e:
69+
if not self._any_success:
70+
raise e
71+
logging.exception("Failed to retrieve partition count")
72+
await asyncio.sleep(self._duration)
73+
74+
async def get_partition_count(self) -> int:
75+
return await self.await_unless_failed(self._queue.get())
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import asyncio
15+
import sys
16+
from typing import Callable, Dict
17+
18+
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
19+
from google.cloud.pubsublite.internal.wire.partition_count_watcher import (
20+
PartitionCountWatcher,
21+
)
22+
from google.cloud.pubsublite.internal.wire.publisher import Publisher
23+
from google.cloud.pubsublite.internal.wire.routing_policy import RoutingPolicy
24+
from google.cloud.pubsublite.types import PublishMetadata, Partition
25+
from google.cloud.pubsublite_v1 import PubSubMessage
26+
27+
28+
class PartitionCountWatchingPublisher(Publisher):
29+
_publishers: Dict[Partition, Publisher]
30+
_publisher_factory: Callable[[Partition], Publisher]
31+
_policy_factory: Callable[[int], RoutingPolicy]
32+
_watcher: PartitionCountWatcher
33+
_partition_count_poller: asyncio.Future
34+
35+
def __init__(
36+
self,
37+
watcher: PartitionCountWatcher,
38+
publisher_factory: Callable[[Partition], Publisher],
39+
policy_factory: Callable[[int], RoutingPolicy],
40+
):
41+
self._publishers = {}
42+
self._publisher_factory = publisher_factory
43+
self._policy_factory = policy_factory
44+
self._watcher = watcher
45+
46+
async def __aenter__(self):
47+
try:
48+
await self._watcher.__aenter__()
49+
await self._poll_partition_count_action()
50+
except Exception:
51+
await self._watcher.__aexit__(*sys.exc_info())
52+
raise
53+
self._partition_count_poller = asyncio.ensure_future(
54+
self._watch_partition_count()
55+
)
56+
return self
57+
58+
async def __aexit__(self, exc_type, exc_val, exc_tb):
59+
self._partition_count_poller.cancel()
60+
await wait_ignore_cancelled(self._partition_count_poller)
61+
await self._watcher.__aexit__(exc_type, exc_val, exc_tb)
62+
for publisher in self._publishers.values():
63+
await publisher.__aexit__(exc_type, exc_val, exc_tb)
64+
65+
async def _poll_partition_count_action(self):
66+
partition_count = await self._watcher.get_partition_count()
67+
await self._handle_partition_count_update(partition_count)
68+
69+
async def _watch_partition_count(self):
70+
while True:
71+
await self._poll_partition_count_action()
72+
73+
async def _handle_partition_count_update(self, partition_count: int):
74+
current_count = len(self._publishers)
75+
if current_count == partition_count:
76+
return
77+
if current_count > partition_count:
78+
return
79+
80+
new_publishers = {
81+
Partition(index): self._publisher_factory(Partition(index))
82+
for index in range(current_count, partition_count)
83+
}
84+
await asyncio.gather(*[p.__aenter__() for p in new_publishers.values()])
85+
routing_policy = self._policy_factory(partition_count)
86+
87+
self._publishers.update(new_publishers)
88+
self._routing_policy = routing_policy
89+
90+
async def publish(self, message: PubSubMessage) -> PublishMetadata:
91+
partition = self._routing_policy.route(message)
92+
assert partition in self._publishers
93+
publisher = self._publishers[partition]
94+
return await publisher.publish(message)

google/cloud/pubsublite/testing/test_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
import asyncio
16-
from typing import List, Union, Any, TypeVar, Generic, Optional
16+
import threading
17+
from typing import List, Union, Any, TypeVar, Generic, Optional, Callable
1718

1819
from asynctest import CoroutineMock
1920

@@ -62,3 +63,16 @@ def wire_queues(mock: CoroutineMock) -> QueuePair:
6263

6364
class Box(Generic[T]):
6465
val: Optional[T]
66+
67+
68+
def run_on_thread(func: Callable[[], T]) -> T:
69+
box = Box()
70+
71+
def set_box():
72+
box.val = func()
73+
74+
# Initialize watcher on another thread with a different event loop.
75+
thread = threading.Thread(target=set_box)
76+
thread.start()
77+
thread.join()
78+
return box.val
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import asyncio
15+
import queue
16+
from asynctest.mock import MagicMock
17+
import pytest
18+
19+
from google.cloud.pubsublite import AdminClientInterface
20+
from google.cloud.pubsublite.internal.wire.partition_count_watcher_impl import (
21+
PartitionCountWatcherImpl,
22+
)
23+
from google.cloud.pubsublite.internal.wire.publisher import Publisher
24+
from google.cloud.pubsublite.testing.test_utils import run_on_thread
25+
from google.cloud.pubsublite.types import Partition, TopicPath
26+
from google.api_core.exceptions import GoogleAPICallError
27+
28+
pytestmark = pytest.mark.asyncio
29+
30+
31+
@pytest.fixture()
32+
def mock_publishers():
33+
return {Partition(i): MagicMock(spec=Publisher) for i in range(10)}
34+
35+
36+
@pytest.fixture()
37+
def topic():
38+
return TopicPath.parse("projects/1/locations/us-central1-a/topics/topic")
39+
40+
41+
@pytest.fixture()
42+
def mock_admin():
43+
admin = MagicMock(spec=AdminClientInterface)
44+
return admin
45+
46+
47+
@pytest.fixture()
48+
def watcher(mock_admin, topic):
49+
return run_on_thread(lambda: PartitionCountWatcherImpl(mock_admin, topic, 0.001))
50+
51+
52+
async def test_init(watcher, mock_admin, topic):
53+
mock_admin.get_topic_partition_count.return_value = 2
54+
async with watcher:
55+
pass
56+
57+
58+
async def test_get_count_first_failure(watcher, mock_admin, topic):
59+
mock_admin.get_topic_partition_count.side_effect = GoogleAPICallError("error")
60+
with pytest.raises(GoogleAPICallError):
61+
async with watcher:
62+
await watcher.get_partition_count()
63+
64+
65+
async def test_get_multiple_counts(watcher, mock_admin, topic):
66+
q = queue.Queue()
67+
mock_admin.get_topic_partition_count.side_effect = q.get
68+
async with watcher:
69+
task1 = asyncio.ensure_future(watcher.get_partition_count())
70+
task2 = asyncio.ensure_future(watcher.get_partition_count())
71+
assert not task1.done()
72+
assert not task2.done()
73+
q.put(3)
74+
assert await task1 == 3
75+
assert not task2.done()
76+
q.put(4)
77+
assert await task2 == 4
78+
79+
80+
async def test_subsequent_failures_ignored(watcher, mock_admin, topic):
81+
q = queue.Queue()
82+
83+
def side_effect():
84+
value = q.get()
85+
if isinstance(value, Exception):
86+
raise value
87+
return value
88+
89+
mock_admin.get_topic_partition_count.side_effect = lambda x: side_effect()
90+
async with watcher:
91+
q.put(3)
92+
assert await watcher.get_partition_count() == 3
93+
q.put(GoogleAPICallError("error"))
94+
q.put(4)
95+
assert await watcher.get_partition_count() == 4

0 commit comments

Comments
 (0)