 # under the License.
 """This module contains a Dataproc Job sensor."""
 # pylint: disable=C0302
+import time
 import warnings
-from typing import Optional
+from typing import Dict, Optional

+from google.api_core.exceptions import ServerError
 from google.cloud.dataproc_v1.types import JobStatus

 from airflow.exceptions import AirflowException
@@ -42,6 +44,8 @@ class DataprocJobSensor(BaseSensorOperator):
     :type location: str
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
     :type gcp_conn_id: str
+    :param wait_timeout: How many seconds to wait for the job to be ready.
+    :type wait_timeout: int
     """

     template_fields = ('project_id', 'region', 'dataproc_job_id')
@@ -55,6 +59,7 @@ def __init__(
         region: str = None,
         location: Optional[str] = None,
         gcp_conn_id: str = 'google_cloud_default',
+        wait_timeout: Optional[int] = None,
         **kwargs,
     ) -> None:
         if region is None:
@@ -73,12 +78,36 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.dataproc_job_id = dataproc_job_id
         self.region = region
+        self.wait_timeout = wait_timeout
+        self.start_sensor_time = None

-    def poke(self, context: dict) -> bool:
+    def execute(self, context: Dict):
+        self.start_sensor_time = time.monotonic()
+        super().execute(context)
+
+    def _duration(self):
+        return time.monotonic() - self.start_sensor_time
+
+    def poke(self, context: Dict) -> bool:
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
-        job = hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)
-        state = job.status.state
+        if self.wait_timeout:
+            try:
+                job = hook.get_job(
+                    job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id
+                )
+            except ServerError as err:
+                self.log.info(f"DURATION RUN: {self._duration()}")
+                if self._duration() > self.wait_timeout:
+                    raise AirflowException(
+                        f"Timeout: dataproc job {self.dataproc_job_id} "
+                        f"is not ready after {self.wait_timeout}s"
+                    )
+                self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
+                return False
+        else:
+            job = hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)

+        state = job.status.state
         if state == JobStatus.State.ERROR:
             raise AirflowException(f'Job failed:\n{job}')
         elif state in {
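
For context, a minimal sketch of how the new wait_timeout parameter could be used from a DAG once this change lands. The project, region, and job-ID values are placeholders, and poke_interval comes from BaseSensorOperator rather than from this change:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor

with DAG("dataproc_sensor_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    # Wait for an existing Dataproc job; transient Dataproc API server errors
    # are retried for up to wait_timeout seconds before the sensor fails.
    wait_for_job = DataprocJobSensor(
        task_id="wait_for_dataproc_job",
        project_id="my-gcp-project",  # placeholder project
        region="europe-west1",  # placeholder region
        dataproc_job_id="{{ ti.xcom_pull(task_ids='submit_job') }}",  # placeholder job id source
        wait_timeout=300,
        poke_interval=30,
    )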