18
18
from __future__ import annotations
19
19
20
20
import json
21
+ import os
22
+ import platform
23
+ import tempfile
21
24
from unittest import mock
22
25
from unittest .mock import PropertyMock
23
26
27
30
28
31
from airflow .exceptions import AirflowException
29
32
from airflow .models import Connection
30
- from airflow .providers .google .cloud .hooks .cloud_sql import CloudSQLDatabaseHook , CloudSQLHook
33
+ from airflow .providers .google .cloud .hooks .cloud_sql import (
34
+ CloudSQLDatabaseHook ,
35
+ CloudSQLHook ,
36
+ CloudSqlProxyRunner ,
37
+ )
31
38
from tests .providers .google .cloud .utils .base_gcp_mock import (
32
39
mock_base_gcp_hook_default_project_id ,
33
40
mock_base_gcp_hook_no_default_project_id ,
@@ -847,8 +854,12 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
847
854
err = ctx .value
848
855
assert "must be a readable file" in str (err )
849
856
857
+ @mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir" )
850
858
@mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection" )
851
- def test_cloudsql_database_hook_validate_socket_path_length_too_long (self , get_connection ):
859
+ def test_cloudsql_database_hook_validate_socket_path_length_too_long (
860
+ self , get_connection , gettempdir_mock
861
+ ):
862
+ gettempdir_mock .return_value = "/tmp"
852
863
connection = Connection ()
853
864
connection .set_extra (
854
865
json .dumps (
@@ -870,8 +881,12 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_c
870
881
err = ctx .value
871
882
assert "The UNIX socket path length cannot exceed" in str (err )
872
883
884
+ @mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir" )
873
885
@mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection" )
874
- def test_cloudsql_database_hook_validate_socket_path_length_not_too_long (self , get_connection ):
886
+ def test_cloudsql_database_hook_validate_socket_path_length_not_too_long (
887
+ self , get_connection , gettempdir_mock
888
+ ):
889
+ gettempdir_mock .return_value = "/tmp"
875
890
connection = Connection ()
876
891
connection .set_extra (
877
892
json .dumps (
@@ -1093,7 +1108,7 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection
1093
1108
hook = CloudSQLDatabaseHook ()
1094
1109
connection = hook .create_connection ()
1095
1110
assert "postgres" == connection .conn_type
1096
- assert "/tmp" in connection .host
1111
+ assert tempfile . gettempdir () in connection .host
1097
1112
assert "example-project:europe-west1:testdb" in connection .host
1098
1113
assert connection .port is None
1099
1114
assert "testdb" == connection .schema
@@ -1166,7 +1181,7 @@ def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
1166
1181
connection = hook .create_connection ()
1167
1182
assert "mysql" == connection .conn_type
1168
1183
assert "localhost" == connection .host
1169
- assert "/tmp" in connection .extra_dejson ["unix_socket" ]
1184
+ assert tempfile . gettempdir () in connection .extra_dejson ["unix_socket" ]
1170
1185
assert "example-project:europe-west1:testdb" in connection .extra_dejson ["unix_socket" ]
1171
1186
assert connection .port is None
1172
1187
assert "testdb" == connection .schema
@@ -1185,3 +1200,53 @@ def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
1185
1200
assert "127.0.0.1" == connection .host
1186
1201
assert 3200 != connection .port
1187
1202
assert "testdb" == connection .schema
1203
+
1204
+
1205
+ def get_processor ():
1206
+ processor = os .uname ().machine
1207
+ if processor == "x86_64" :
1208
+ processor = "amd64"
1209
+ return processor
1210
+
1211
+
1212
+ class TestCloudSqlProxyRunner :
1213
+ @pytest .mark .parametrize (
1214
+ ["version" , "download_url" ],
1215
+ [
1216
+ (
1217
+ "v1.23.0" ,
1218
+ "https://storage.googleapis.com/cloudsql-proxy/v1.23.0/cloud_sql_proxy."
1219
+ f"{ platform .system ().lower ()} .{ get_processor ()} " ,
1220
+ ),
1221
+ (
1222
+ "v1.23.0-preview.1" ,
1223
+ "https://storage.googleapis.com/cloudsql-proxy/v1.23.0-preview.1/cloud_sql_proxy."
1224
+ f"{ platform .system ().lower ()} .{ get_processor ()} " ,
1225
+ ),
1226
+ ],
1227
+ )
1228
+ def test_cloud_sql_proxy_runner_version_ok (self , version , download_url ):
1229
+ runner = CloudSqlProxyRunner (
1230
+ path_prefix = "12345678" ,
1231
+ instance_specification = "project:us-east-1:instance" ,
1232
+ sql_proxy_version = version ,
1233
+ )
1234
+ assert runner ._get_sql_proxy_download_url () == download_url
1235
+
1236
+ @pytest .mark .parametrize (
1237
+ "version" ,
1238
+ [
1239
+ "v1.23." ,
1240
+ "v1.23.0.." ,
1241
+ "v1.23.0\\ " ,
1242
+ "\\ " ,
1243
+ ],
1244
+ )
1245
+ def test_cloud_sql_proxy_runner_version_nok (self , version ):
1246
+ runner = CloudSqlProxyRunner (
1247
+ path_prefix = "12345678" ,
1248
+ instance_specification = "project:us-east-1:instance" ,
1249
+ sql_proxy_version = version ,
1250
+ )
1251
+ with pytest .raises (ValueError , match = "The sql_proxy_version should match the regular expression" ):
1252
+ runner ._get_sql_proxy_download_url ()
0 commit comments