Skip to content

Commit 6afd477

Browse files
feat: implement assigning subscriber (#23)
* feat: Implement SinglePartitionSubscriber. This handles mapping a single partition to a Cloud Pub/Sub Like asynchronous subscriber. * feat: Add DefaultNackHandler. * feat: Add AssigningSubscriber. This handles changing partition assignments and creates AsyncSubscribers per-partition.
1 parent bb76d90 commit 6afd477

File tree

6 files changed

+241
-2
lines changed

6 files changed

+241
-2
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from asyncio import Future, Queue, ensure_future
2+
from typing import Callable, NamedTuple, Dict, Set
3+
4+
from google.cloud.pubsub_v1.subscriber.message import Message
5+
6+
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
7+
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
8+
from google.cloud.pubsublite.internal.wire.assigner import Assigner
9+
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
10+
from google.cloud.pubsublite.partition import Partition
11+
12+
_PartitionSubscriberFactory = Callable[[Partition], AsyncSubscriber]
13+
14+
15+
class _RunningSubscriber(NamedTuple):
16+
subscriber: AsyncSubscriber
17+
poller: Future
18+
19+
20+
class AssigningSubscriber(AsyncSubscriber, PermanentFailable):
21+
_assigner: Assigner
22+
_subscriber_factory: _PartitionSubscriberFactory
23+
24+
_subscribers: Dict[Partition, _RunningSubscriber]
25+
_messages: "Queue[Message]"
26+
_assign_poller: Future
27+
28+
def __init__(self, assigner: Assigner, subscriber_factory: _PartitionSubscriberFactory):
29+
super().__init__()
30+
self._assigner = assigner
31+
self._subscriber_factory = subscriber_factory
32+
self._subscribers = {}
33+
self._messages = Queue()
34+
35+
async def read(self) -> Message:
36+
return await self.await_unless_failed(self._messages.get())
37+
38+
async def _subscribe_action(self, subscriber: AsyncSubscriber):
39+
message = await subscriber.read()
40+
await self._messages.put(message)
41+
42+
async def _start_subscriber(self, partition: Partition):
43+
new_subscriber = self._subscriber_factory(partition)
44+
await new_subscriber.__aenter__()
45+
poller = ensure_future(self.run_poller(lambda: self._subscribe_action(new_subscriber)))
46+
self._subscribers[partition] = _RunningSubscriber(new_subscriber, poller)
47+
48+
async def _stop_subscriber(self, running: _RunningSubscriber):
49+
running.poller.cancel()
50+
await wait_ignore_cancelled(running.poller)
51+
await running.subscriber.__aexit__(None, None, None)
52+
53+
async def _assign_action(self):
54+
assignment: Set[Partition] = await self._assigner.get_assignment()
55+
added_partitions = assignment - self._subscribers.keys()
56+
removed_partitions = self._subscribers.keys() - assignment
57+
for partition in added_partitions:
58+
await self._start_subscriber(partition)
59+
for partition in removed_partitions:
60+
await self._stop_subscriber(self._subscribers[partition])
61+
del self._subscribers[partition]
62+
63+
async def __aenter__(self):
64+
await self._assigner.__aenter__()
65+
self._assign_poller = ensure_future(self.run_poller(self._assign_action))
66+
return self
67+
68+
async def __aexit__(self, exc_type, exc_value, traceback):
69+
self._assign_poller.cancel()
70+
await wait_ignore_cancelled(self._assign_poller)
71+
await self._assigner.__aexit__(exc_type, exc_value, traceback)
72+
for running in self._subscribers.values():
73+
await self._stop_subscriber(running)

google/cloud/pubsublite/cloudpubsub/internal/managed_event_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self):
1515
def __enter__(self):
1616
self._thread.start()
1717

18-
def __exit__(self, __exc_type, __exc_value, __traceback):
18+
def __exit__(self, exc_type, exc_value, traceback):
1919
self._loop.call_soon_threadsafe(self._loop.stop)
2020
self._thread.join()
2121

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from asyncio import CancelledError
2+
from typing import Awaitable
3+
4+
5+
async def wait_ignore_cancelled(awaitable: Awaitable):
6+
try:
7+
await awaitable
8+
except CancelledError:
9+
pass

google/cloud/pubsublite/internal/wire/permanent_failable.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import Awaitable, TypeVar, Optional
2+
from typing import Awaitable, TypeVar, Optional, Callable
33

44
from google.api_core.exceptions import GoogleAPICallError
55

@@ -31,6 +31,19 @@ async def await_unless_failed(self, awaitable: Awaitable[T]) -> T:
3131
task.cancel()
3232
raise self._failure_task.exception()
3333

34+
async def run_poller(self, poll_action: Callable[[], Awaitable[None]]):
35+
"""
36+
Run a polling loop, which runs poll_action forever unless this is failed.
37+
Args:
38+
poll_action: A callable returning an awaitable to run in a loop. Note that async functions which return once
39+
satisfy this.
40+
"""
41+
try:
42+
while True:
43+
await self.await_unless_failed(poll_action())
44+
except GoogleAPICallError as e:
45+
self.fail(e)
46+
3447
def fail(self, err: GoogleAPICallError):
3548
if not self._failure_task.done():
3649
self._failure_task.set_exception(err)

google/cloud/pubsublite/testing/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
from typing import List, Union, Any, TypeVar, Generic, Optional
33

4+
from asynctest import CoroutineMock
5+
46
T = TypeVar("T")
57

68

@@ -27,5 +29,20 @@ async def waiter(*args, **kwargs):
2729
return waiter
2830

2931

32+
class QueuePair:
33+
called: asyncio.Queue
34+
results: asyncio.Queue
35+
36+
def __init__(self):
37+
self.called = asyncio.Queue()
38+
self.results = asyncio.Queue()
39+
40+
41+
def wire_queues(mock: CoroutineMock) -> QueuePair:
42+
queues = QueuePair()
43+
mock.side_effect = make_queue_waiter(queues.called, queues.results)
44+
return queues
45+
46+
3047
class Box(Generic[T]):
3148
val: Optional[T]
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import asyncio
2+
from typing import Callable, Set
3+
4+
from asynctest.mock import MagicMock, call
5+
import pytest
6+
from google.api_core.exceptions import FailedPrecondition
7+
from google.cloud.pubsub_v1.subscriber.message import Message
8+
from google.pubsub_v1 import PubsubMessage
9+
10+
from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import AssigningSubscriber
11+
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
12+
from google.cloud.pubsublite.internal.wire.assigner import Assigner
13+
from google.cloud.pubsublite.partition import Partition
14+
from google.cloud.pubsublite.testing.test_utils import make_queue_waiter, wire_queues
15+
16+
# All test coroutines will be treated as marked.
17+
pytestmark = pytest.mark.asyncio
18+
19+
20+
def mock_async_context_manager(cm):
21+
cm.__aenter__.return_value = cm
22+
return cm
23+
24+
25+
@pytest.fixture()
26+
def assigner():
27+
return mock_async_context_manager(MagicMock(spec=Assigner))
28+
29+
30+
@pytest.fixture()
31+
def subscriber_factory():
32+
return MagicMock(spec=Callable[[Partition], AsyncSubscriber])
33+
34+
35+
@pytest.fixture()
36+
def subscriber(assigner, subscriber_factory):
37+
return AssigningSubscriber(assigner, subscriber_factory)
38+
39+
40+
async def test_init(subscriber, assigner):
41+
assign_queues = wire_queues(assigner.get_assignment)
42+
async with subscriber:
43+
assigner.__aenter__.assert_called_once()
44+
await assign_queues.called.get()
45+
assigner.get_assignment.assert_called_once()
46+
assigner.__aexit__.assert_called_once()
47+
48+
49+
async def test_initial_assignment(subscriber, assigner, subscriber_factory):
50+
assign_queues = wire_queues(assigner.get_assignment)
51+
async with subscriber:
52+
await assign_queues.called.get()
53+
sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
54+
sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
55+
subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2
56+
await assign_queues.results.put({Partition(1), Partition(2)})
57+
await assign_queues.called.get()
58+
subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True)
59+
sub1.__aenter__.assert_called_once()
60+
sub2.__aenter__.assert_called_once()
61+
sub1.__aexit__.assert_called_once()
62+
sub2.__aexit__.assert_called_once()
63+
64+
65+
async def test_assigner_failure(subscriber, assigner, subscriber_factory):
66+
assign_queues = wire_queues(assigner.get_assignment)
67+
async with subscriber:
68+
await assign_queues.called.get()
69+
await assign_queues.results.put(FailedPrecondition("bad assign"))
70+
with pytest.raises(FailedPrecondition):
71+
await subscriber.read()
72+
73+
74+
async def test_assignment_change(subscriber, assigner, subscriber_factory):
75+
assign_queues = wire_queues(assigner.get_assignment)
76+
async with subscriber:
77+
await assign_queues.called.get()
78+
sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
79+
sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
80+
sub3 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
81+
subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(
82+
1) else sub2 if partition == Partition(2) else sub3
83+
await assign_queues.results.put({Partition(1), Partition(2)})
84+
await assign_queues.called.get()
85+
subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2))], any_order=True)
86+
sub1.__aenter__.assert_called_once()
87+
sub2.__aenter__.assert_called_once()
88+
await assign_queues.results.put({Partition(1), Partition(3)})
89+
await assign_queues.called.get()
90+
subscriber_factory.assert_has_calls([call(Partition(1)), call(Partition(2)), call(Partition(3))], any_order=True)
91+
sub3.__aenter__.assert_called_once()
92+
sub2.__aexit__.assert_called_once()
93+
sub1.__aexit__.assert_called_once()
94+
sub2.__aexit__.assert_called_once()
95+
sub3.__aexit__.assert_called_once()
96+
97+
98+
async def test_subscriber_failure(subscriber, assigner, subscriber_factory):
99+
assign_queues = wire_queues(assigner.get_assignment)
100+
async with subscriber:
101+
await assign_queues.called.get()
102+
sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
103+
sub1_queues = wire_queues(sub1.read)
104+
subscriber_factory.return_value = sub1
105+
await assign_queues.results.put({Partition(1)})
106+
await sub1_queues.called.get()
107+
await sub1_queues.results.put(FailedPrecondition("sub failed"))
108+
with pytest.raises(FailedPrecondition):
109+
await subscriber.read()
110+
111+
112+
async def test_delivery_from_multiple(subscriber, assigner, subscriber_factory):
113+
assign_queues = wire_queues(assigner.get_assignment)
114+
async with subscriber:
115+
await assign_queues.called.get()
116+
sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
117+
sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
118+
sub1_queues = wire_queues(sub1.read)
119+
sub2_queues = wire_queues(sub2.read)
120+
subscriber_factory.side_effect = lambda partition: sub1 if partition == Partition(1) else sub2
121+
await assign_queues.results.put({Partition(1), Partition(2)})
122+
await sub1_queues.results.put(Message(PubsubMessage(message_id="1")._pb, "", 0, None))
123+
await sub2_queues.results.put(Message(PubsubMessage(message_id="2")._pb, "", 0, None))
124+
message_ids: Set[str] = set()
125+
message_ids.add((await subscriber.read()).message_id)
126+
message_ids.add((await subscriber.read()).message_id)
127+
assert message_ids == {"1", "2"}

0 commit comments

Comments
 (0)