Skip to content

Commit 1f33f21

Browse files
committed
Fixed #36165 -- Made PostgreSQL's SchemaEditor._delete_index_sql() respect the "sql" argument.
This is a follow up of bd366ca. Thank you Daniel Finch for the report.
1 parent e2a8f4d commit 1f33f21

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

django/db/backends/postgresql/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def remove_index(self, model, index, concurrently=False):
322322
self.execute(index.remove_sql(model, self, concurrently=concurrently))
323323

324324
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
325-
sql = (
325+
sql = sql or (
326326
self.sql_delete_index_concurrently
327327
if concurrently
328328
else self.sql_delete_index

tests/postgres_tests/test_indexes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,30 @@ def create_sql(self, model, schema_editor, using="gin", **kwargs):
665665
str(index.create_sql(CharFieldModel, editor)),
666666
)
667667

668+
def test_custom_sql(self):
669+
class CustomSQLIndex(PostgresIndex):
670+
sql_create_index = "SELECT 1"
671+
sql_delete_index = "SELECT 2"
672+
673+
def create_sql(self, model, schema_editor, using="", **kwargs):
674+
kwargs.setdefault("sql", self.sql_create_index)
675+
return super().create_sql(model, schema_editor, using, **kwargs)
676+
677+
def remove_sql(self, model, schema_editor, **kwargs):
678+
kwargs.setdefault("sql", self.sql_delete_index)
679+
return super().remove_sql(model, schema_editor, **kwargs)
680+
681+
index = CustomSQLIndex(fields=["field"], name="custom_sql_idx")
682+
683+
operations = [
684+
(index.create_sql, CustomSQLIndex.sql_create_index),
685+
(index.remove_sql, CustomSQLIndex.sql_delete_index),
686+
]
687+
for operation, expected in operations:
688+
with self.subTest(operation=operation.__name__):
689+
with connection.schema_editor() as editor:
690+
self.assertEqual(expected, str(operation(CharFieldModel, editor)))
691+
668692
def test_op_class(self):
669693
index_name = "test_op_class"
670694
index = Index(

0 commit comments

Comments
 (0)