Skip to content

Commit 09f3446

Browse files
Fix GCSObjectExistenceSensor operator to return the same XCOM value in deferrable and non-deferrable mode (#39206)
1 parent 8c556da commit 09f3446

File tree

2 files changed

+44
-12
lines changed
  • airflow/providers/google/cloud/sensors
  • tests/providers/google/cloud/sensors

2 files changed

+44
-12
lines changed

airflow/providers/google/cloud/sensors/gcs.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
self.object = object
9090
self.use_glob = use_glob
9191
self.google_cloud_conn_id = google_cloud_conn_id
92-
self._matches: list[str] = []
92+
self._matches: bool = False
9393
self.impersonation_chain = impersonation_chain
9494
self.retry = retry
9595

@@ -101,17 +101,16 @@ def poke(self, context: Context) -> bool:
101101
gcp_conn_id=self.google_cloud_conn_id,
102102
impersonation_chain=self.impersonation_chain,
103103
)
104-
if self.use_glob:
105-
self._matches = hook.list(self.bucket, match_glob=self.object)
106-
return bool(self._matches)
107-
else:
108-
return hook.exists(self.bucket, self.object, self.retry)
104+
self._matches = (
105+
bool(hook.list(self.bucket, match_glob=self.object))
106+
if self.use_glob
107+
else hook.exists(self.bucket, self.object, self.retry)
108+
)
109+
return self._matches
109110

110-
def execute(self, context: Context) -> None:
111+
def execute(self, context: Context):
111112
"""Airflow runs this method on the worker and defers using the trigger."""
112-
if not self.deferrable:
113-
super().execute(context)
114-
else:
113+
if self.deferrable:
115114
if not self.poke(context=context):
116115
self.defer(
117116
timeout=timedelta(seconds=self.timeout),
@@ -127,8 +126,11 @@ def execute(self, context: Context) -> None:
127126
),
128127
method_name="execute_complete",
129128
)
129+
else:
130+
super().execute(context)
131+
return self._matches
130132

131-
def execute_complete(self, context: Context, event: dict[str, str]) -> str:
133+
def execute_complete(self, context: Context, event: dict[str, str]) -> bool:
132134
"""
133135
Act as a callback for when the trigger fires - returns immediately.
134136
@@ -140,7 +142,7 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> str:
140142
raise AirflowSkipException(event["message"])
141143
raise AirflowException(event["message"])
142144
self.log.info("File %s was found in bucket %s.", self.object, self.bucket)
143-
return event["message"]
145+
return True
144146

145147

146148
@deprecated(

tests/providers/google/cloud/sensors/test_gcs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ def next_time_side_effect():
8080

8181

8282
class TestGoogleCloudStorageObjectSensor:
83+
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
84+
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor.defer")
85+
def test_gcs_object_existence_sensor_return_value(self, mock_defer, mock_hook):
86+
task = GCSObjectExistenceSensor(
87+
task_id="task-id",
88+
bucket=TEST_BUCKET,
89+
object=TEST_OBJECT,
90+
google_cloud_conn_id=TEST_GCP_CONN_ID,
91+
deferrable=True,
92+
)
93+
mock_hook.return_value.list.return_value = True
94+
return_value = task.execute(mock.MagicMock())
95+
assert return_value, True
96+
8397
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
8498
def test_should_pass_argument_to_hook(self, mock_hook):
8599
task = GCSObjectExistenceSensor(
@@ -183,6 +197,22 @@ def test_gcs_object_existence_sensor_execute_complete(self):
183197
task.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
184198
mock_log_info.assert_called_with("File %s was found in bucket %s.", TEST_OBJECT, TEST_BUCKET)
185199

200+
def test_gcs_object_existence_sensor_execute_complete_return_value(self):
201+
"""Asserts that logging occurs as expected when deferrable is set to True"""
202+
task = GCSObjectExistenceSensor(
203+
task_id="task-id",
204+
bucket=TEST_BUCKET,
205+
object=TEST_OBJECT,
206+
google_cloud_conn_id=TEST_GCP_CONN_ID,
207+
deferrable=True,
208+
)
209+
with mock.patch.object(task.log, "info") as mock_log_info:
210+
return_value = task.execute_complete(
211+
context=None, event={"status": "success", "message": "Job completed"}
212+
)
213+
mock_log_info.assert_called_with("File %s was found in bucket %s.", TEST_OBJECT, TEST_BUCKET)
214+
assert return_value, True
215+
186216

187217
class TestGoogleCloudStorageObjectAsyncSensor:
188218
depcrecation_message = (

0 commit comments

Comments
 (0)