|
15 | 15 | # KIND, either express or implied. See the License for the
|
16 | 16 | # specific language governing permissions and limitations
|
17 | 17 | # under the License.
|
18 |
| -""" |
19 |
| -This module contains a Google Kubernetes Engine Hook. |
20 |
| -
|
21 |
| -.. spelling:word-list:: |
22 |
| -
|
23 |
| - gapic |
24 |
| - enums |
25 |
| -""" |
| 18 | +"""This module contains a Google Kubernetes Engine Hook.""" |
26 | 19 | from __future__ import annotations
|
27 | 20 |
|
28 | 21 | import contextlib
|
|
41 | 34 | from google.cloud import container_v1, exceptions # type: ignore[attr-defined]
|
42 | 35 | from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
|
43 | 36 | from google.cloud.container_v1.types import Cluster, Operation
|
44 |
| -from kubernetes import client |
| 37 | +from kubernetes import client, utils |
| 38 | +from kubernetes.client.models import V1Deployment |
45 | 39 | from kubernetes_asyncio import client as async_client
|
46 | 40 | from kubernetes_asyncio.config.kube_config import FileOrData
|
47 | 41 | from urllib3.exceptions import HTTPError
|
48 | 42 |
|
49 | 43 | from airflow import version
|
50 | 44 | from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
| 45 | +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook |
51 | 46 | from airflow.providers.cncf.kubernetes.kube_client import _enable_tcp_keepalive
|
52 | 47 | from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
|
53 | 48 | from airflow.providers.google.common.consts import CLIENT_INFO
|
@@ -299,6 +294,130 @@ def get_cluster(
|
299 | 294 | timeout=timeout,
|
300 | 295 | )
|
301 | 296 |
|
| 297 | + def check_cluster_autoscaling_ability(self, cluster: Cluster | dict): |
| 298 | + """ |
| 299 | + Check if the specified Cluster has ability to autoscale. |
| 300 | +
|
| 301 | + Cluster should be Autopilot, with Node Auto-provisioning or regular auto-scaled node pools. |
| 302 | + Returns True if the Cluster supports autoscaling, otherwise returns False. |
| 303 | +
|
| 304 | + :param cluster: The Cluster object. |
| 305 | + """ |
| 306 | + if isinstance(cluster, Cluster): |
| 307 | + cluster_dict_representation = Cluster.to_dict(cluster) |
| 308 | + elif not isinstance(cluster, dict): |
| 309 | + raise AirflowException("cluster is not instance of Cluster proto or python dict") |
| 310 | + else: |
| 311 | + cluster_dict_representation = cluster |
| 312 | + |
| 313 | + node_pools_autoscaled = False |
| 314 | + for node_pool in cluster_dict_representation["node_pools"]: |
| 315 | + try: |
| 316 | + if node_pool["autoscaling"]["enabled"] is True: |
| 317 | + node_pools_autoscaled = True |
| 318 | + break |
| 319 | + except KeyError: |
| 320 | + self.log.info("No autoscaling enabled in Node pools level.") |
| 321 | + break |
| 322 | + if ( |
| 323 | + cluster_dict_representation["autopilot"]["enabled"] |
| 324 | + or cluster_dict_representation["autoscaling"]["enable_node_autoprovisioning"] |
| 325 | + or node_pools_autoscaled |
| 326 | + ): |
| 327 | + return True |
| 328 | + else: |
| 329 | + return False |
| 330 | + |
| 331 | + |
| 332 | +class GKEDeploymentHook(GoogleBaseHook, KubernetesHook): |
| 333 | + """Google Kubernetes Engine Deployment APIs.""" |
| 334 | + |
| 335 | + def __init__( |
| 336 | + self, |
| 337 | + cluster_url: str, |
| 338 | + ssl_ca_cert: str, |
| 339 | + *args, |
| 340 | + **kwargs, |
| 341 | + ): |
| 342 | + super().__init__(*args, **kwargs) |
| 343 | + self._cluster_url = cluster_url |
| 344 | + self._ssl_ca_cert = ssl_ca_cert |
| 345 | + |
| 346 | + @cached_property |
| 347 | + def api_client(self) -> client.ApiClient: |
| 348 | + return self.get_conn() |
| 349 | + |
| 350 | + @cached_property |
| 351 | + def core_v1_client(self) -> client.CoreV1Api: |
| 352 | + return client.CoreV1Api(self.api_client) |
| 353 | + |
| 354 | + @cached_property |
| 355 | + def batch_v1_client(self) -> client.BatchV1Api: |
| 356 | + return client.BatchV1Api(self.api_client) |
| 357 | + |
| 358 | + @cached_property |
| 359 | + def apps_v1_client(self) -> client.AppsV1Api: |
| 360 | + return client.AppsV1Api(api_client=self.api_client) |
| 361 | + |
| 362 | + def get_conn(self) -> client.ApiClient: |
| 363 | + configuration = self._get_config() |
| 364 | + configuration.refresh_api_key_hook = self._refresh_api_key_hook |
| 365 | + return client.ApiClient(configuration) |
| 366 | + |
| 367 | + def _refresh_api_key_hook(self, configuration: client.configuration.Configuration): |
| 368 | + configuration.api_key = {"authorization": self._get_token(self.get_credentials())} |
| 369 | + |
| 370 | + def _get_config(self) -> client.configuration.Configuration: |
| 371 | + configuration = client.Configuration( |
| 372 | + host=self._cluster_url, |
| 373 | + api_key_prefix={"authorization": "Bearer"}, |
| 374 | + api_key={"authorization": self._get_token(self.get_credentials())}, |
| 375 | + ) |
| 376 | + configuration.ssl_ca_cert = FileOrData( |
| 377 | + { |
| 378 | + "certificate-authority-data": self._ssl_ca_cert, |
| 379 | + }, |
| 380 | + file_key_name="certificate-authority", |
| 381 | + ).as_file() |
| 382 | + return configuration |
| 383 | + |
| 384 | + @staticmethod |
| 385 | + def _get_token(creds: google.auth.credentials.Credentials) -> str: |
| 386 | + if creds.token is None or creds.expired: |
| 387 | + auth_req = google_requests.Request() |
| 388 | + creds.refresh(auth_req) |
| 389 | + return creds.token |
| 390 | + |
| 391 | + def check_kueue_deployment_running(self, name, namespace): |
| 392 | + timeout = 300 |
| 393 | + polling_period_seconds = 2 |
| 394 | + |
| 395 | + while timeout is None or timeout > 0: |
| 396 | + try: |
| 397 | + deployment = self.get_deployment_status(name=name, namespace=namespace) |
| 398 | + deployment_status = V1Deployment.to_dict(deployment)["status"] |
| 399 | + replicas = deployment_status["replicas"] |
| 400 | + ready_replicas = deployment_status["ready_replicas"] |
| 401 | + unavailable_replicas = deployment_status["unavailable_replicas"] |
| 402 | + if ( |
| 403 | + replicas is not None |
| 404 | + and ready_replicas is not None |
| 405 | + and unavailable_replicas is None |
| 406 | + and replicas == ready_replicas |
| 407 | + ): |
| 408 | + return |
| 409 | + else: |
| 410 | + self.log.info("Waiting until Deployment will be ready...") |
| 411 | + time.sleep(polling_period_seconds) |
| 412 | + except Exception as e: |
| 413 | + self.log.exception("Exception occurred while checking for Deployment status.") |
| 414 | + raise e |
| 415 | + |
| 416 | + if timeout is not None: |
| 417 | + timeout -= polling_period_seconds |
| 418 | + |
| 419 | + raise AirflowException("Deployment timed out") |
| 420 | + |
302 | 421 |
|
303 | 422 | class GKEAsyncHook(GoogleBaseAsyncHook):
|
304 | 423 | """Asynchronous client of GKE."""
|
@@ -431,6 +550,32 @@ def _get_token(creds: google.auth.credentials.Credentials) -> str:
|
431 | 550 | creds.refresh(auth_req)
|
432 | 551 | return creds.token
|
433 | 552 |
|
| 553 | + def apply_from_yaml_file( |
| 554 | + self, |
| 555 | + yaml_file: str | None = None, |
| 556 | + yaml_objects: list[dict] | None = None, |
| 557 | + verbose: bool = False, |
| 558 | + namespace: str = "default", |
| 559 | + ): |
| 560 | + """ |
| 561 | + Perform an action from a yaml file on a Pod. |
| 562 | +
|
| 563 | + :param yaml_file: Contains the path to yaml file. |
| 564 | + :param yaml_objects: List of YAML objects; used instead of reading the yaml_file. |
| 565 | + :param verbose: If True, print confirmation from create action. Default is False. |
| 566 | + :param namespace: Contains the namespace to create all resources inside. The namespace must |
| 567 | + preexist otherwise the resource creation will fail. |
| 568 | + """ |
| 569 | + k8s_client = self.get_conn() |
| 570 | + |
| 571 | + utils.create_from_yaml( |
| 572 | + k8s_client=k8s_client, |
| 573 | + yaml_objects=yaml_objects, |
| 574 | + yaml_file=yaml_file, |
| 575 | + verbose=verbose, |
| 576 | + namespace=namespace, |
| 577 | + ) |
| 578 | + |
434 | 579 | def get_pod(self, name: str, namespace: str) -> V1Pod:
|
435 | 580 | """Get a pod object.
|
436 | 581 |
|
|
0 commit comments