Skip to content

Commit 431ea32

Browse files
turbaszek and ashb authored
Resolve upstream tasks when template field is XComArg (#8805)
* Resolve upstream tasks when template field is XComArg (closes: #8054)
* fixup! Resolve upstream tasks when template field is XComArg
* Resolve task relations in DagRun and DagBag
* Add tests for serialized DAG
* Set dependencies only in bag_dag, refactor tests
* Traverse template_fields attribute
* Use provide_test_dag_bag in all tests
* fixup! Use provide_test_dag_bag in all tests
* Use metaclass + setattr
* Add prepare_for_execution method
* Check signature of __init__ not class
* Apply suggestions from code review (Co-authored-by: Ash Berlin-Taylor <[email protected]>)
* Update airflow/models/baseoperator.py (Co-authored-by: Ash Berlin-Taylor <[email protected]>)
1 parent aee6ab9 commit 431ea32

File tree

7 files changed

+211
-11
lines changed

7 files changed

+211
-11
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
19+
"""Example DAG demonstrating the usage of the XComArgs."""
20+
21+
from airflow import DAG
22+
from airflow.operators.python import PythonOperator
23+
from airflow.utils.dates import days_ago
24+
25+
# Default arguments handed to the example DAG below via ``default_args``.
args = dict(
    owner='airflow',
    start_date=days_ago(2),
)
29+
30+
31+
def dummy(*_args, **_kwargs):
    """
    Placeholder callable used as ``python_callable`` by both example tasks.

    Accepts and ignores any positional and keyword arguments. The var-args are
    underscore-prefixed so they do not shadow the module-level ``args`` dict.

    :return: the constant string ``"pass"``
    """
    return "pass"
34+
35+
36+
# Two-task DAG: ``task2`` receives ``task1.output`` through ``op_kwargs``.
# NOTE(review): presumably this makes task1 an upstream of task2 because
# ``op_kwargs`` is a templated field of PythonOperator -- confirm against
# ``PythonOperator.template_fields``.
with DAG(
    dag_id='example_xcom_args',
    default_args=args,
    schedule_interval=None,
    tags=['example'],
) as dag:
    # First task: runs the placeholder callable with no extra arguments.
    task1 = PythonOperator(task_id='task1', python_callable=dummy)

    # Second task: same callable, fed with the reference to task1's output.
    task2 = PythonOperator(
        task_id='task2',
        python_callable=dummy,
        op_kwargs={"dummy": task1.output},
    )

airflow/models/baseoperator.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""
1919
Base operator for all operators.
2020
"""
21+
import abc
2122
import copy
2223
import functools
2324
import logging
@@ -60,9 +61,29 @@
6061
ScheduleInterval = Union[str, timedelta, relativedelta]
6162

6263

64+
class BaseOperatorMeta(abc.ABCMeta):
    """
    Base metaclass of BaseOperator.

    It hooks instance creation so that post-initialisation steps run after the
    complete ``__init__`` chain of an operator has finished.
    """

    def __call__(cls, *args, **kwargs):
        """
        Called when you call BaseOperator(). In this way we are able to perform an action
        after initializing an operator no matter where the ``super().__init__`` is called
        (before or after assign of new attributes in a custom operator).

        :param args: positional arguments forwarded to the operator constructor
        :param kwargs: keyword arguments forwarded to the operator constructor
        :return: the constructed operator, with XComArg dependencies resolved and
            flagged as instantiated
        """
        # Use cooperative ``super().__call__`` instead of a hard-coded ``type.__call__`` so
        # that a ``__call__`` defined by another metaclass in the MRO is not bypassed.
        obj: "BaseOperator" = super().__call__(*args, **kwargs)
        # Here we set upstream task defined by XComArgs passed to template fields of the operator
        obj.set_xcomargs_dependencies()

        # Mark instance as instantiated. The attribute name is the mangled form of the
        # private ``__instantiated`` class attribute of BaseOperator, see
        # https://docs.python.org/3/tutorial/classes.html#private-variables
        obj._BaseOperator__instantiated = True
        return obj
82+
83+
6384
# pylint: disable=too-many-instance-attributes,too-many-public-methods
6485
@functools.total_ordering
65-
class BaseOperator(Operator, LoggingMixin):
86+
class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
6687
"""
6788
Abstract base class for all operators. Since operators create objects that
6889
become nodes in the dag, BaseOperator contains many recursive methods for
@@ -292,6 +313,12 @@ class derived from this one results in the creation of a task object,
292313
# Defines if the operator supports lineage without manual definitions
293314
supports_lineage = False
294315

316+
# If True then the class constructor was called
317+
__instantiated = False
318+
319+
# Set to True before calling execute method
320+
_lock_for_execution = False
321+
295322
# noinspection PyUnusedLocal
296323
# pylint: disable=too-many-arguments,too-many-locals, too-many-statements
297324
@apply_defaults
@@ -547,6 +574,18 @@ def __lt__(self, other):
547574

548575
return self
549576

577+
def __setattr__(self, key, value):
    """
    Store the attribute, then re-resolve XComArg upstreams when appropriate.

    The extra step runs only when all of the following hold, checked in this order:
    the task is not locked for execution, the constructor has already finished,
    and ``key`` is one of the operator's ``template_fields``.
    """
    super().__setattr__(key, value)

    # Covers assigning an XComArg after initializing an operator, example:
    #   op = BashOperator()
    #   op.bash_command = "sleep 1"
    # While locked for execution any custom behaviour is skipped.
    should_resolve = (
        not self._lock_for_execution
        and self.__instantiated
        and key in self.template_fields
    )
    if should_resolve:
        self.set_xcomargs_dependencies()
588+
550589
def add_inlets(self, inlets: Iterable[Any]):
551590
"""
552591
Sets inlets to this operator
@@ -633,6 +672,56 @@ def deps(self) -> Set[BaseTIDep]:
633672
NotPreviouslySkippedDep(),
634673
}
635674

675+
def prepare_for_execution(self) -> "BaseOperator":
    """
    Return a shallow copy of the task that is locked for execution.

    The lock flag is set on the copy only; it disables the custom action
    performed in ``__setattr__``. The task this is called on is left untouched.
    """
    locked_copy = copy.copy(self)
    locked_copy._lock_for_execution = True  # pylint: disable=protected-access
    return locked_copy
683+
684+
def set_xcomargs_dependencies(self) -> None:
    """
    Resolves upstream dependencies of a task. In this way passing an ``XComArg``
    as value for a template field will result in creating upstream relation between
    two tasks.

    Template field values are searched recursively: through tuples, sets, lists,
    dict values, and through the templated attributes of any nested object that
    itself declares ``template_fields``.

    **Example**: ::

        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(..., html_content=generate_content.output)

        # This is equivalent to
        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(
                ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
            )
            generate_content >> send_email

    """
    from airflow.models.xcom_arg import XComArg

    # ids of nested templated objects already visited; prevents infinite recursion
    # when such objects reference each other.
    seen_oids = set()

    def apply_set_upstream(arg: Any):
        if isinstance(arg, XComArg):
            self.set_upstream(arg.operator)
        elif isinstance(arg, (tuple, set, list)):
            for elem in arg:
                apply_set_upstream(elem)
        elif isinstance(arg, dict):
            for elem in arg.values():
                apply_set_upstream(elem)
        elif hasattr(arg, "template_fields"):
            if id(arg) in seen_oids:
                return
            seen_oids.add(id(arg))
            # ``template_fields`` lists attribute *names*; the values to inspect are
            # the attributes themselves, not the name strings.
            for field_name in arg.template_fields:
                apply_set_upstream(getattr(arg, field_name, None))

    for field in self.template_fields:
        if hasattr(self, field):
            apply_set_upstream(getattr(self, field))
724+
636725
@property
637726
def priority_weight_total(self) -> int:
638727
"""
@@ -1140,7 +1229,7 @@ def set_upstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperat
11401229

11411230
@property
11421231
def output(self):
1143-
"""Returns default XComArg for the operator"""
1232+
"""Returns reference to XCom pushed by current operator"""
11441233
from airflow.models.xcom_arg import XComArg
11451234
return XComArg(operator=self)
11461235

@@ -1205,7 +1294,8 @@ def get_serialized_fields(cls):
12051294
if not cls.__serialized_fields:
12061295
cls.__serialized_fields = frozenset(
12071296
vars(BaseOperator(task_id='test')).keys() - {
1208-
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag'
1297+
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag',
1298+
'_BaseOperator__instantiated',
12091299
} | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'})
12101300

12111301
return cls.__serialized_fields

airflow/models/taskinstance.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
# specific language governing permissions and limitations
1717
# under the License.
1818

19-
import copy
2019
import getpass
2120
import hashlib
2221
import logging
@@ -970,7 +969,7 @@ def _run_raw_task(
970969
if not mark_success:
971970
context = self.get_template_context()
972971

973-
task_copy = copy.copy(task)
972+
task_copy = task.prepare_for_execution()
974973

975974
# Sensors in `poke` mode can block execution of DAGs when running
976975
# with single process executor, thus we change the mode to`reschedule`
@@ -1154,7 +1153,7 @@ def run(
11541153

11551154
def dry_run(self):
11561155
task = self.task
1157-
task_copy = copy.copy(task)
1156+
task_copy = task.prepare_for_execution()
11581157
self.task = task_copy
11591158

11601159
self.render_templates()

airflow/providers/google/cloud/operators/sql_to_gcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from airflow.utils.decorators import apply_defaults
3232

3333

34-
class BaseSQLToGCSOperator(BaseOperator, metaclass=abc.ABCMeta):
34+
class BaseSQLToGCSOperator(BaseOperator):
3535
"""
3636
:param sql: The SQL to execute.
3737
:type sql: str

airflow/serialization/serialized_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
299299
_decorated_fields = {'executor_config'}
300300

301301
_CONSTRUCTOR_PARAMS = {
302-
k: v.default for k, v in signature(BaseOperator).parameters.items()
302+
k: v.default for k, v in signature(BaseOperator.__init__).parameters.items()
303303
if v.default is not v.empty
304304
}
305305

@@ -537,7 +537,7 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument
537537
'access_control': '_access_control',
538538
}
539539
return {
540-
param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
540+
param_to_attr.get(k, k): v.default for k, v in signature(DAG.__init__).parameters.items()
541541
if v.default is not v.empty
542542
}
543543

tests/models/test_baseoperator.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,21 @@
1515
# KIND, either express or implied. See the License for the
1616
# specific language governing permissions and limitations
1717
# under the License.
18-
1918
import unittest
2019
import uuid
2120
from datetime import date, datetime
2221
from unittest import mock
2322

2423
import jinja2
24+
import pytest
2525
from parameterized import parameterized
2626

2727
from airflow.exceptions import AirflowException
2828
from airflow.lineage.entities import File
2929
from airflow.models import DAG
3030
from airflow.models.baseoperator import chain, cross_downstream
3131
from airflow.operators.dummy_operator import DummyOperator
32+
from airflow.utils.decorators import apply_defaults
3233
from tests.models import DEFAULT_DATE
3334
from tests.test_utils.mock_operators import MockNamedTuple, MockOperator
3435

@@ -347,3 +348,61 @@ def test_lineage_composition(self):
347348
task4 = DummyOperator(task_id="op4", dag=dag)
348349
task4 > [inlet, outlet, extra]
349350
self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])
351+
352+
353+
class CustomOp(DummyOperator):
    """Test operator exposing two template fields, ``field`` and ``field2``."""

    template_fields = ("field", "field2")

    @apply_defaults
    def __init__(self, field=None, field2=None, *args, **kwargs):
        # Template fields are assigned after the parent constructor has run.
        super().__init__(*args, **kwargs)
        self.field, self.field2 = field, field2

    def execute(self, context):
        """Reassign a template field, so tests can observe ``__setattr__`` at execute time."""
        self.field = None
364+
365+
366+
class TestXComArgsRelationsAreResolved:
    """Tests that XComArgs assigned to template fields create upstream relations."""

    def test_setattr_performs_no_custom_action_at_execute_time(self):
        """Assigning a template field on an execution copy must not resolve dependencies."""
        op = CustomOp(task_id="test_task")
        op_copy = op.prepare_for_execution()

        with mock.patch(
            "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
        ) as method_mock:
            op_copy.execute({})
        assert method_mock.call_count == 0

    def test_upstream_is_set_when_template_field_is_xcomarg(self):
        """An XComArg passed to the constructor links both tasks in both directions."""
        # Fixed DEFAULT_DATE instead of ``datetime.today()`` keeps the test deterministic.
        with DAG("xcomargs_test", default_args={"start_date": DEFAULT_DATE}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2", field=op1.output)

        assert op1 in op2.upstream_list
        assert op2 in op1.downstream_list

    def test_set_xcomargs_dependencies_works_recursively(self):
        """XComArgs nested inside a list or a dict value are discovered as well."""
        with DAG("xcomargs_test", default_args={"start_date": DEFAULT_DATE}):
            op1 = DummyOperator(task_id="op1")
            op2 = DummyOperator(task_id="op2")
            op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
            op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})

        assert op1 in op3.upstream_list
        assert op2 in op3.upstream_list
        assert op1 in op4.upstream_list
        assert op2 in op4.upstream_list

    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
        """Assigning an XComArg to a template field after construction also sets the upstream."""
        with DAG(dag_id='xcomargs_test', default_args={"start_date": DEFAULT_DATE}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2")
            op2.field = op1.output  # value is set after init

        assert op1 in op2.upstream_list

    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
        """Using an XComArg without any DAG must raise AirflowException."""
        # Build the producer outside ``pytest.raises`` so that only the statement
        # expected to fail is covered by the assertion.
        op1 = DummyOperator(task_id="op1")
        with pytest.raises(AirflowException):
            CustomOp(task_id="op2", field=op1.output)

tests/serialization/test_dag_serialization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,8 @@ def test_no_new_fields_added_to_base_operator(self):
726726
"""
727727
base_operator = BaseOperator(task_id="10")
728728
fields = base_operator.__dict__
729-
self.assertEqual({'_dag': None,
729+
self.assertEqual({'_BaseOperator__instantiated': True,
730+
'_dag': None,
730731
'_downstream_task_ids': set(),
731732
'_inlets': [],
732733
'_log': base_operator.log,

0 commit comments

Comments
 (0)