Skip to content

Commit 0b2e1a8

Browse files
author
Łukasz Wyszomirski
authored
Added wait mechanizm to the DataprocJobSensor to avoid 509 errors when Job is not available (#19740)
1 parent 56bdfe7 commit 0b2e1a8

File tree

2 files changed

+77
-4
lines changed

2 files changed

+77
-4
lines changed

airflow/providers/google/cloud/sensors/dataproc.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
# under the License.
1818
"""This module contains a Dataproc Job sensor."""
1919
# pylint: disable=C0302
20+
import time
2021
import warnings
21-
from typing import Optional
22+
from typing import Dict, Optional
2223

24+
from google.api_core.exceptions import ServerError
2325
from google.cloud.dataproc_v1.types import JobStatus
2426

2527
from airflow.exceptions import AirflowException
@@ -42,6 +44,8 @@ class DataprocJobSensor(BaseSensorOperator):
4244
:type location: str
4345
:param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
4446
:type gcp_conn_id: str
47+
:param wait_timeout: How many seconds wait for job to be ready.
48+
:type wait_timeout: int
4549
"""
4650

4751
template_fields = ('project_id', 'region', 'dataproc_job_id')
@@ -55,6 +59,7 @@ def __init__(
5559
region: str = None,
5660
location: Optional[str] = None,
5761
gcp_conn_id: str = 'google_cloud_default',
62+
wait_timeout: Optional[int] = None,
5863
**kwargs,
5964
) -> None:
6065
if region is None:
@@ -73,12 +78,36 @@ def __init__(
7378
self.gcp_conn_id = gcp_conn_id
7479
self.dataproc_job_id = dataproc_job_id
7580
self.region = region
81+
self.wait_timeout = wait_timeout
82+
self.start_sensor_time = None
7683

77-
def poke(self, context: dict) -> bool:
84+
def execute(self, context: Dict):
85+
self.start_sensor_time = time.monotonic()
86+
super().execute(context)
87+
88+
def _duration(self):
89+
return time.monotonic() - self.start_sensor_time
90+
91+
def poke(self, context: Dict) -> bool:
7892
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
79-
job = hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)
80-
state = job.status.state
93+
if self.wait_timeout:
94+
try:
95+
job = hook.get_job(
96+
job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id
97+
)
98+
except ServerError as err:
99+
self.log.info(f"DURATION RUN: {self._duration()}")
100+
if self._duration() > self.wait_timeout:
101+
raise AirflowException(
102+
f"Timeout: dataproc job {self.dataproc_job_id} "
103+
f"is not ready after {self.wait_timeout}s"
104+
)
105+
self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
106+
return False
107+
else:
108+
job = hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)
81109

110+
state = job.status.state
82111
if state == JobStatus.State.ERROR:
83112
raise AirflowException(f'Job failed:\n{job}')
84113
elif state in {

tests/providers/google/cloud/sensors/test_dataproc.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
import unittest
1919
from unittest import mock
20+
from unittest.mock import Mock
2021

2122
import pytest
23+
from google.api_core.exceptions import ServerError
2224
from google.cloud.dataproc_v1.types import JobStatus
2325

2426
from airflow import AirflowException
@@ -164,3 +166,45 @@ def test_location_deprecation_warning(self, mock_hook):
164166
timeout=TIMEOUT,
165167
)
166168
sensor.poke(context={})
169+
170+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
171+
def test_wait_timeout(self, mock_hook):
172+
job_id = "job_id"
173+
mock_hook.return_value.get_job.side_effect = ServerError("Job are not ready")
174+
175+
sensor = DataprocJobSensor(
176+
task_id=TASK_ID,
177+
region=GCP_LOCATION,
178+
project_id=GCP_PROJECT,
179+
dataproc_job_id=job_id,
180+
gcp_conn_id=GCP_CONN_ID,
181+
timeout=TIMEOUT,
182+
wait_timeout=300,
183+
)
184+
185+
sensor._duration = Mock()
186+
sensor._duration.return_value = 200
187+
188+
result = sensor.poke(context={})
189+
assert not result
190+
191+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
192+
def test_wait_timeout_raise_exception(self, mock_hook):
193+
job_id = "job_id"
194+
mock_hook.return_value.get_job.side_effect = ServerError("Job are not ready")
195+
196+
sensor = DataprocJobSensor(
197+
task_id=TASK_ID,
198+
region=GCP_LOCATION,
199+
project_id=GCP_PROJECT,
200+
dataproc_job_id=job_id,
201+
gcp_conn_id=GCP_CONN_ID,
202+
timeout=TIMEOUT,
203+
wait_timeout=300,
204+
)
205+
206+
sensor._duration = Mock()
207+
sensor._duration.return_value = 301
208+
209+
with pytest.raises(AirflowException, match="Timeout: dataproc job job_id is not ready after 300s"):
210+
sensor.poke(context={})

0 commit comments

Comments
 (0)