Skip to content

Commit 5e6f8eb

Browse files
authored
Check that cloud sql provider version is valid (#29497)
Additional chek on cloud sql version should be done to avoid downloading non-existing binary.
1 parent cf81455 commit 5e6f8eb

File tree

2 files changed

+91
-15
lines changed

2 files changed

+91
-15
lines changed

airflow/providers/google/cloud/hooks/cloud_sql.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
# Time to sleep between active checks of the operation results
6060
TIME_TO_SLEEP_IN_SECONDS = 20
6161

62+
CLOUD_SQL_PROXY_VERSION_REGEX = re.compile(r"^v?(\d+\.\d+\.\d+)(-\w*.?\d?)?$")
63+
6264

6365
class CloudSqlOperationStatus:
6466
"""Helper class with operation statuses."""
@@ -449,16 +451,7 @@ def _download_sql_proxy_if_needed(self) -> None:
449451
if os.path.isfile(self.sql_proxy_path):
450452
self.log.info("cloud-sql-proxy is already present")
451453
return
452-
system = platform.system().lower()
453-
processor = os.uname().machine
454-
if processor == "x86_64":
455-
processor = "amd64"
456-
if not self.sql_proxy_version:
457-
download_url = CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor)
458-
else:
459-
download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format(
460-
self.sql_proxy_version, system, processor
461-
)
454+
download_url = self._get_sql_proxy_download_url()
462455
proxy_path_tmp = self.sql_proxy_path + ".tmp"
463456
self.log.info("Downloading cloud_sql_proxy from %s to %s", download_url, proxy_path_tmp)
464457
# httpx has a breaking API change (follow_redirects vs allow_redirects)
@@ -482,6 +475,24 @@ def _download_sql_proxy_if_needed(self) -> None:
482475
os.chmod(self.sql_proxy_path, 0o744) # Set executable bit
483476
self.sql_proxy_was_downloaded = True
484477

478+
def _get_sql_proxy_download_url(self):
479+
system = platform.system().lower()
480+
processor = os.uname().machine
481+
if processor == "x86_64":
482+
processor = "amd64"
483+
if not self.sql_proxy_version:
484+
download_url = CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor)
485+
else:
486+
if not CLOUD_SQL_PROXY_VERSION_REGEX.match(self.sql_proxy_version):
487+
raise ValueError(
488+
"The sql_proxy_version should match the regular expression "
489+
f"{CLOUD_SQL_PROXY_VERSION_REGEX.pattern}"
490+
)
491+
download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format(
492+
self.sql_proxy_version, system, processor
493+
)
494+
return download_url
495+
485496
def _get_credential_parameters(self) -> list[str]:
486497
extras = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id).extra_dejson
487498
key_path = get_field(extras, "key_path")

tests/providers/google/cloud/hooks/test_cloud_sql.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from __future__ import annotations
1919

2020
import json
21+
import os
22+
import platform
23+
import tempfile
2124
from unittest import mock
2225
from unittest.mock import PropertyMock
2326

@@ -27,7 +30,11 @@
2730

2831
from airflow.exceptions import AirflowException
2932
from airflow.models import Connection
30-
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
33+
from airflow.providers.google.cloud.hooks.cloud_sql import (
34+
CloudSQLDatabaseHook,
35+
CloudSQLHook,
36+
CloudSqlProxyRunner,
37+
)
3138
from tests.providers.google.cloud.utils.base_gcp_mock import (
3239
mock_base_gcp_hook_default_project_id,
3340
mock_base_gcp_hook_no_default_project_id,
@@ -847,8 +854,12 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
847854
err = ctx.value
848855
assert "must be a readable file" in str(err)
849856

857+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir")
850858
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
851-
def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_connection):
859+
def test_cloudsql_database_hook_validate_socket_path_length_too_long(
860+
self, get_connection, gettempdir_mock
861+
):
862+
gettempdir_mock.return_value = "/tmp"
852863
connection = Connection()
853864
connection.set_extra(
854865
json.dumps(
@@ -870,8 +881,12 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_c
870881
err = ctx.value
871882
assert "The UNIX socket path length cannot exceed" in str(err)
872883

884+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir")
873885
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
874-
def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(self, get_connection):
886+
def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(
887+
self, get_connection, gettempdir_mock
888+
):
889+
gettempdir_mock.return_value = "/tmp"
875890
connection = Connection()
876891
connection.set_extra(
877892
json.dumps(
@@ -1093,7 +1108,7 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection
10931108
hook = CloudSQLDatabaseHook()
10941109
connection = hook.create_connection()
10951110
assert "postgres" == connection.conn_type
1096-
assert "/tmp" in connection.host
1111+
assert tempfile.gettempdir() in connection.host
10971112
assert "example-project:europe-west1:testdb" in connection.host
10981113
assert connection.port is None
10991114
assert "testdb" == connection.schema
@@ -1166,7 +1181,7 @@ def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
11661181
connection = hook.create_connection()
11671182
assert "mysql" == connection.conn_type
11681183
assert "localhost" == connection.host
1169-
assert "/tmp" in connection.extra_dejson["unix_socket"]
1184+
assert tempfile.gettempdir() in connection.extra_dejson["unix_socket"]
11701185
assert "example-project:europe-west1:testdb" in connection.extra_dejson["unix_socket"]
11711186
assert connection.port is None
11721187
assert "testdb" == connection.schema
@@ -1185,3 +1200,53 @@ def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
11851200
assert "127.0.0.1" == connection.host
11861201
assert 3200 != connection.port
11871202
assert "testdb" == connection.schema
1203+
1204+
1205+
def get_processor():
1206+
processor = os.uname().machine
1207+
if processor == "x86_64":
1208+
processor = "amd64"
1209+
return processor
1210+
1211+
1212+
class TestCloudSqlProxyRunner:
1213+
@pytest.mark.parametrize(
1214+
["version", "download_url"],
1215+
[
1216+
(
1217+
"v1.23.0",
1218+
"https://storage.googleapis.com/cloudsql-proxy/v1.23.0/cloud_sql_proxy."
1219+
f"{platform.system().lower()}.{get_processor()}",
1220+
),
1221+
(
1222+
"v1.23.0-preview.1",
1223+
"https://storage.googleapis.com/cloudsql-proxy/v1.23.0-preview.1/cloud_sql_proxy."
1224+
f"{platform.system().lower()}.{get_processor()}",
1225+
),
1226+
],
1227+
)
1228+
def test_cloud_sql_proxy_runner_version_ok(self, version, download_url):
1229+
runner = CloudSqlProxyRunner(
1230+
path_prefix="12345678",
1231+
instance_specification="project:us-east-1:instance",
1232+
sql_proxy_version=version,
1233+
)
1234+
assert runner._get_sql_proxy_download_url() == download_url
1235+
1236+
@pytest.mark.parametrize(
1237+
"version",
1238+
[
1239+
"v1.23.",
1240+
"v1.23.0..",
1241+
"v1.23.0\\",
1242+
"\\",
1243+
],
1244+
)
1245+
def test_cloud_sql_proxy_runner_version_nok(self, version):
1246+
runner = CloudSqlProxyRunner(
1247+
path_prefix="12345678",
1248+
instance_specification="project:us-east-1:instance",
1249+
sql_proxy_version=version,
1250+
)
1251+
with pytest.raises(ValueError, match="The sql_proxy_version should match the regular expression"):
1252+
runner._get_sql_proxy_download_url()

0 commit comments

Comments
 (0)