Skip to content

Commit 810b5d4

Browse files
DataflowTemplatedJobStartOperator: fix overwriting of location with the default value when a region is provided. (#31082)
* Fix overwriting of location with the default value when a region is provided.

* Update tests/providers/google/cloud/operators/test_dataflow.py

Co-authored-by: Pankaj Singh <[email protected]>

* Fix incompatible type for passing location to the TemplateJobStartTrigger.

---------

Co-authored-by: Pankaj Singh <[email protected]>
1 parent 00a527f commit 810b5d4

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

airflow/providers/google/cloud/operators/dataflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def __init__(
600600
options: dict[str, Any] | None = None,
601601
dataflow_default_options: dict[str, Any] | None = None,
602602
parameters: dict[str, str] | None = None,
603-
location: str = DEFAULT_DATAFLOW_LOCATION,
603+
location: str | None = None,
604604
gcp_conn_id: str = "google_cloud_default",
605605
poll_sleep: int = 10,
606606
impersonation_chain: str | Sequence[str] | None = None,
@@ -690,7 +690,7 @@ def set_current_job(current_job):
690690
trigger=TemplateJobStartTrigger(
691691
project_id=self.project_id,
692692
job_id=job_id,
693-
location=self.location,
693+
location=self.location if self.location else DEFAULT_DATAFLOW_LOCATION,
694694
gcp_conn_id=self.gcp_conn_id,
695695
poll_sleep=self.poll_sleep,
696696
impersonation_chain=self.impersonation_chain,

tests/providers/google/cloud/hooks/test_dataflow.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,50 @@ def test_start_python_dataflow_with_custom_region_as_parameter(
330330
job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION
331331
)
332332

333+
@mock.patch(DATAFLOW_STRING.format("uuid.uuid4"))
334+
@mock.patch(DATAFLOW_STRING.format("DataflowHook.wait_for_done"))
335+
@mock.patch(DATAFLOW_STRING.format("process_line_and_extract_dataflow_job_id_callback"))
336+
def test_start_python_dataflow_with_no_custom_region_or_region(
337+
self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
338+
):
339+
mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
340+
mock_uuid.return_value = MOCK_UUID
341+
on_new_job_id_callback = MagicMock()
342+
py_requirements = ["pandas", "numpy"]
343+
job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
344+
345+
passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
346+
347+
with pytest.warns(AirflowProviderDeprecationWarning, match="This method is deprecated"):
348+
self.dataflow_hook.start_python_dataflow(
349+
job_name=JOB_NAME,
350+
variables=passed_variables,
351+
dataflow=PY_FILE,
352+
py_options=PY_OPTIONS,
353+
py_interpreter=DEFAULT_PY_INTERPRETER,
354+
py_requirements=py_requirements,
355+
on_new_job_id_callback=on_new_job_id_callback,
356+
)
357+
358+
expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
359+
expected_variables["job_name"] = job_name
360+
expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
361+
362+
mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
363+
mock_beam_start_python_pipeline.assert_called_once_with(
364+
variables=expected_variables,
365+
py_file=PY_FILE,
366+
py_interpreter=DEFAULT_PY_INTERPRETER,
367+
py_options=PY_OPTIONS,
368+
py_requirements=py_requirements,
369+
py_system_site_packages=False,
370+
process_line_callback=mock_callback_on_job_id.return_value,
371+
)
372+
373+
mock_dataflow_wait_for_done.assert_called_once_with(
374+
job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
375+
)
376+
333377
@mock.patch(DATAFLOW_STRING.format("uuid.uuid4"))
334378
@mock.patch(DATAFLOW_STRING.format("DataflowHook.wait_for_done"))
335379
@mock.patch(DATAFLOW_STRING.format("process_line_and_extract_dataflow_job_id_callback"))

tests/providers/google/cloud/operators/test_dataflow.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
},
7777
}
7878
TEST_LOCATION = "custom-location"
79+
TEST_REGION = "custom-region"
7980
TEST_PROJECT = "test-project"
8081
TEST_SQL_JOB_NAME = "test-sql-job-name"
8182
TEST_DATASET = "test-dataset"
@@ -534,6 +535,42 @@ def test_validation_deferrable_params_raises_error(self):
534535
with pytest.raises(ValueError):
535536
DataflowTemplatedJobStartOperator(**init_kwargs)
536537

538+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
539+
def test_start_with_custom_region(self, dataflow_mock):
540+
init_kwargs = {
541+
"task_id": TASK_ID,
542+
"template": TEMPLATE,
543+
"dataflow_default_options": {
544+
"region": TEST_REGION,
545+
},
546+
"poll_sleep": POLL_SLEEP,
547+
"wait_until_finished": True,
548+
"cancel_timeout": CANCEL_TIMEOUT,
549+
}
550+
operator = DataflowTemplatedJobStartOperator(**init_kwargs)
551+
operator.execute(None)
552+
assert dataflow_mock.called
553+
_, kwargs = dataflow_mock.call_args_list[0]
554+
assert kwargs["variables"]["region"] == TEST_REGION
555+
assert kwargs["location"] is None
556+
557+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
558+
def test_start_with_location(self, dataflow_mock):
559+
init_kwargs = {
560+
"task_id": TASK_ID,
561+
"template": TEMPLATE,
562+
"location": TEST_LOCATION,
563+
"poll_sleep": POLL_SLEEP,
564+
"wait_until_finished": True,
565+
"cancel_timeout": CANCEL_TIMEOUT,
566+
}
567+
operator = DataflowTemplatedJobStartOperator(**init_kwargs)
568+
operator.execute(None)
569+
assert dataflow_mock.called
570+
_, kwargs = dataflow_mock.call_args_list[0]
571+
assert not kwargs["variables"]
572+
assert kwargs["location"] == TEST_LOCATION
573+
537574

538575
class TestDataflowStartFlexTemplateOperator:
539576
@pytest.fixture

0 commit comments

Comments (0)