Skip to content

Commit df132b2

Browse files
Add GKEStartKueueInsideClusterOperator (#37072)
1 parent 2a0f3d2 commit df132b2

File tree

8 files changed

+850
-18
lines changed

8 files changed

+850
-18
lines changed

airflow/providers/cncf/kubernetes/hooks/kubernetes.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from airflow.utils import yaml
3838

3939
if TYPE_CHECKING:
40-
from kubernetes.client.models import V1Pod
40+
from kubernetes.client.models import V1Deployment, V1Pod
4141

4242
LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."
4343

@@ -282,6 +282,10 @@ def api_client(self) -> client.ApiClient:
282282
def core_v1_client(self) -> client.CoreV1Api:
283283
return client.CoreV1Api(api_client=self.api_client)
284284

285+
@cached_property
286+
def apps_v1_client(self) -> client.AppsV1Api:
287+
return client.AppsV1Api(api_client=self.api_client)
288+
285289
@cached_property
286290
def custom_object_client(self) -> client.CustomObjectsApi:
287291
return client.CustomObjectsApi(api_client=self.api_client)
@@ -450,6 +454,24 @@ def get_namespaced_pod_list(
450454
**kwargs,
451455
)
452456

457+
def get_deployment_status(
458+
self,
459+
name: str,
460+
namespace: str = "default",
461+
**kwargs,
462+
) -> V1Deployment:
463+
"""Get status of existing Deployment.
464+
465+
:param name: Name of Deployment to retrieve
466+
:param namespace: Deployment namespace
467+
"""
468+
try:
469+
return self.apps_v1_client.read_namespaced_deployment_status(
470+
name=name, namespace=namespace, pretty=True, **kwargs
471+
)
472+
except Exception as exc:
473+
raise exc
474+
453475

454476
def _get_bool(val) -> bool | None:
455477
"""Convert val to bool if can be done with certainty; if we cannot infer intention we return None."""

airflow/providers/google/cloud/hooks/kubernetes_engine.py

Lines changed: 154 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,7 @@
1515
# KIND, either express or implied. See the License for the
1616
# specific language governing permissions and limitations
1717
# under the License.
18-
"""
19-
This module contains a Google Kubernetes Engine Hook.
20-
21-
.. spelling:word-list::
22-
23-
gapic
24-
enums
25-
"""
18+
"""This module contains a Google Kubernetes Engine Hook."""
2619
from __future__ import annotations
2720

2821
import contextlib
@@ -41,13 +34,15 @@
4134
from google.cloud import container_v1, exceptions # type: ignore[attr-defined]
4235
from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
4336
from google.cloud.container_v1.types import Cluster, Operation
44-
from kubernetes import client
37+
from kubernetes import client, utils
38+
from kubernetes.client.models import V1Deployment
4539
from kubernetes_asyncio import client as async_client
4640
from kubernetes_asyncio.config.kube_config import FileOrData
4741
from urllib3.exceptions import HTTPError
4842

4943
from airflow import version
5044
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
45+
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
5146
from airflow.providers.cncf.kubernetes.kube_client import _enable_tcp_keepalive
5247
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
5348
from airflow.providers.google.common.consts import CLIENT_INFO
@@ -299,6 +294,130 @@ def get_cluster(
299294
timeout=timeout,
300295
)
301296

297+
def check_cluster_autoscaling_ability(self, cluster: Cluster | dict):
298+
"""
299+
Check if the specified Cluster has ability to autoscale.
300+
301+
Cluster should be Autopilot, with Node Auto-provisioning or regular auto-scaled node pools.
302+
Returns True if the Cluster supports autoscaling, otherwise returns False.
303+
304+
:param cluster: The Cluster object.
305+
"""
306+
if isinstance(cluster, Cluster):
307+
cluster_dict_representation = Cluster.to_dict(cluster)
308+
elif not isinstance(cluster, dict):
309+
raise AirflowException("cluster is not instance of Cluster proto or python dict")
310+
else:
311+
cluster_dict_representation = cluster
312+
313+
node_pools_autoscaled = False
314+
for node_pool in cluster_dict_representation["node_pools"]:
315+
try:
316+
if node_pool["autoscaling"]["enabled"] is True:
317+
node_pools_autoscaled = True
318+
break
319+
except KeyError:
320+
self.log.info("No autoscaling enabled in Node pools level.")
321+
break
322+
if (
323+
cluster_dict_representation["autopilot"]["enabled"]
324+
or cluster_dict_representation["autoscaling"]["enable_node_autoprovisioning"]
325+
or node_pools_autoscaled
326+
):
327+
return True
328+
else:
329+
return False
330+
331+
332+
class GKEDeploymentHook(GoogleBaseHook, KubernetesHook):
333+
"""Google Kubernetes Engine Deployment APIs."""
334+
335+
def __init__(
336+
self,
337+
cluster_url: str,
338+
ssl_ca_cert: str,
339+
*args,
340+
**kwargs,
341+
):
342+
super().__init__(*args, **kwargs)
343+
self._cluster_url = cluster_url
344+
self._ssl_ca_cert = ssl_ca_cert
345+
346+
@cached_property
347+
def api_client(self) -> client.ApiClient:
348+
return self.get_conn()
349+
350+
@cached_property
351+
def core_v1_client(self) -> client.CoreV1Api:
352+
return client.CoreV1Api(self.api_client)
353+
354+
@cached_property
355+
def batch_v1_client(self) -> client.BatchV1Api:
356+
return client.BatchV1Api(self.api_client)
357+
358+
@cached_property
359+
def apps_v1_client(self) -> client.AppsV1Api:
360+
return client.AppsV1Api(api_client=self.api_client)
361+
362+
def get_conn(self) -> client.ApiClient:
363+
configuration = self._get_config()
364+
configuration.refresh_api_key_hook = self._refresh_api_key_hook
365+
return client.ApiClient(configuration)
366+
367+
def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
368+
configuration.api_key = {"authorization": self._get_token(self.get_credentials())}
369+
370+
def _get_config(self) -> client.configuration.Configuration:
371+
configuration = client.Configuration(
372+
host=self._cluster_url,
373+
api_key_prefix={"authorization": "Bearer"},
374+
api_key={"authorization": self._get_token(self.get_credentials())},
375+
)
376+
configuration.ssl_ca_cert = FileOrData(
377+
{
378+
"certificate-authority-data": self._ssl_ca_cert,
379+
},
380+
file_key_name="certificate-authority",
381+
).as_file()
382+
return configuration
383+
384+
@staticmethod
385+
def _get_token(creds: google.auth.credentials.Credentials) -> str:
386+
if creds.token is None or creds.expired:
387+
auth_req = google_requests.Request()
388+
creds.refresh(auth_req)
389+
return creds.token
390+
391+
def check_kueue_deployment_running(self, name, namespace):
392+
timeout = 300
393+
polling_period_seconds = 2
394+
395+
while timeout is None or timeout > 0:
396+
try:
397+
deployment = self.get_deployment_status(name=name, namespace=namespace)
398+
deployment_status = V1Deployment.to_dict(deployment)["status"]
399+
replicas = deployment_status["replicas"]
400+
ready_replicas = deployment_status["ready_replicas"]
401+
unavailable_replicas = deployment_status["unavailable_replicas"]
402+
if (
403+
replicas is not None
404+
and ready_replicas is not None
405+
and unavailable_replicas is None
406+
and replicas == ready_replicas
407+
):
408+
return
409+
else:
410+
self.log.info("Waiting until Deployment will be ready...")
411+
time.sleep(polling_period_seconds)
412+
except Exception as e:
413+
self.log.exception("Exception occurred while checking for Deployment status.")
414+
raise e
415+
416+
if timeout is not None:
417+
timeout -= polling_period_seconds
418+
419+
raise AirflowException("Deployment timed out")
420+
302421

303422
class GKEAsyncHook(GoogleBaseAsyncHook):
304423
"""Asynchronous client of GKE."""
@@ -431,6 +550,32 @@ def _get_token(creds: google.auth.credentials.Credentials) -> str:
431550
creds.refresh(auth_req)
432551
return creds.token
433552

553+
def apply_from_yaml_file(
554+
self,
555+
yaml_file: str | None = None,
556+
yaml_objects: list[dict] | None = None,
557+
verbose: bool = False,
558+
namespace: str = "default",
559+
):
560+
"""
561+
Perform an action from a yaml file on a Pod.
562+
563+
:param yaml_file: Contains the path to yaml file.
564+
:param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
565+
:param verbose: If True, print confirmation from create action. Default is False.
566+
:param namespace: Contains the namespace to create all resources inside. The namespace must
567+
preexist otherwise the resource creation will fail.
568+
"""
569+
k8s_client = self.get_conn()
570+
571+
utils.create_from_yaml(
572+
k8s_client=k8s_client,
573+
yaml_objects=yaml_objects,
574+
yaml_file=yaml_file,
575+
verbose=verbose,
576+
namespace=namespace,
577+
)
578+
434579
def get_pod(self, name: str, namespace: str) -> V1Pod:
435580
"""Get a pod object.
436581

0 commit comments

Comments
 (0)