Skip to content

Commit 6671058

Browse files
Gagarofelixxm
authored andcommitted
Fixed #30581 -- Added support for Meta.constraints validation.
Thanks Simon Charette, Keryn Knight, and Mariusz Felisiak for reviews.
1 parent 441103a commit 6671058

File tree

17 files changed

+852
-88
lines changed

17 files changed

+852
-88
lines changed

django/contrib/postgres/constraints.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import warnings
22

33
from django.contrib.postgres.indexes import OpClass
4-
from django.db import NotSupportedError
4+
from django.core.exceptions import ValidationError
5+
from django.db import DEFAULT_DB_ALIAS, NotSupportedError
56
from django.db.backends.ddl_references import Expressions, Statement, Table
67
from django.db.models import BaseConstraint, Deferrable, F, Q
7-
from django.db.models.expressions import ExpressionList
8+
from django.db.models.expressions import Exists, ExpressionList
89
from django.db.models.indexes import IndexExpression
10+
from django.db.models.lookups import PostgresOperatorLookup
911
from django.db.models.sql import Query
1012
from django.utils.deprecation import RemovedInDjango50Warning
1113

@@ -32,6 +34,7 @@ def __init__(
3234
deferrable=None,
3335
include=None,
3436
opclasses=(),
37+
violation_error_message=None,
3538
):
3639
if index_type and index_type.lower() not in {"gist", "spgist"}:
3740
raise ValueError(
@@ -78,7 +81,7 @@ def __init__(
7881
category=RemovedInDjango50Warning,
7982
stacklevel=2,
8083
)
81-
super().__init__(name=name)
84+
super().__init__(name=name, violation_error_message=violation_error_message)
8285

8386
def _get_expressions(self, schema_editor, query):
8487
expressions = []
@@ -197,3 +200,44 @@ def __repr__(self):
197200
"" if not self.include else " include=%s" % repr(self.include),
198201
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
199202
)
203+
204+
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
205+
queryset = model._default_manager.using(using)
206+
replacement_map = instance._get_field_value_map(
207+
meta=model._meta, exclude=exclude
208+
)
209+
lookups = []
210+
for idx, (expression, operator) in enumerate(self.expressions):
211+
if isinstance(expression, str):
212+
expression = F(expression)
213+
if isinstance(expression, F):
214+
if exclude and expression.name in exclude:
215+
return
216+
rhs_expression = replacement_map.get(expression.name, expression)
217+
else:
218+
rhs_expression = expression.replace_references(replacement_map)
219+
if exclude:
220+
for expr in rhs_expression.flatten():
221+
if isinstance(expr, F) and expr.name in exclude:
222+
return
223+
# Remove OpClass because it only has sense during the constraint
224+
# creation.
225+
if isinstance(expression, OpClass):
226+
expression = expression.get_source_expressions()[0]
227+
if isinstance(rhs_expression, OpClass):
228+
rhs_expression = rhs_expression.get_source_expressions()[0]
229+
lookup = PostgresOperatorLookup(lhs=expression, rhs=rhs_expression)
230+
lookup.postgres_operator = operator
231+
lookups.append(lookup)
232+
queryset = queryset.filter(*lookups)
233+
model_class_pk = instance._get_pk_val(model._meta)
234+
if not instance._state.adding and model_class_pk is not None:
235+
queryset = queryset.exclude(pk=model_class_pk)
236+
if not self.condition:
237+
if queryset.exists():
238+
raise ValidationError(self.get_violation_error_message())
239+
else:
240+
if (self.condition & Exists(queryset.filter(self.condition))).check(
241+
replacement_map, using=using
242+
):
243+
raise ValidationError(self.get_violation_error_message())

django/db/models/base.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from django.db.models.constants import LOOKUP_SEP
2929
from django.db.models.constraints import CheckConstraint, UniqueConstraint
3030
from django.db.models.deletion import CASCADE, Collector
31+
from django.db.models.expressions import RawSQL
3132
from django.db.models.fields.related import (
3233
ForeignObjectRel,
3334
OneToOneField,
@@ -1189,6 +1190,16 @@ def _get_next_or_previous_in_order(self, is_next):
11891190
setattr(self, cachename, obj)
11901191
return getattr(self, cachename)
11911192

1193+
def _get_field_value_map(self, meta, exclude=None):
1194+
if exclude is None:
1195+
exclude = set()
1196+
meta = meta or self._meta
1197+
return {
1198+
field.name: Value(getattr(self, field.attname), field)
1199+
for field in meta.local_concrete_fields
1200+
if field.name not in exclude
1201+
}
1202+
11921203
def prepare_database_save(self, field):
11931204
if self.pk is None:
11941205
raise ValueError(
@@ -1221,7 +1232,7 @@ def validate_unique(self, exclude=None):
12211232
if errors:
12221233
raise ValidationError(errors)
12231234

1224-
def _get_unique_checks(self, exclude=None):
1235+
def _get_unique_checks(self, exclude=None, include_meta_constraints=False):
12251236
"""
12261237
Return a list of checks to perform. Since validate_unique() could be
12271238
called from a ModelForm, some fields may have been excluded; we can't
@@ -1234,13 +1245,15 @@ def _get_unique_checks(self, exclude=None):
12341245
unique_checks = []
12351246

12361247
unique_togethers = [(self.__class__, self._meta.unique_together)]
1237-
constraints = [(self.__class__, self._meta.total_unique_constraints)]
1248+
constraints = []
1249+
if include_meta_constraints:
1250+
constraints = [(self.__class__, self._meta.total_unique_constraints)]
12381251
for parent_class in self._meta.get_parent_list():
12391252
if parent_class._meta.unique_together:
12401253
unique_togethers.append(
12411254
(parent_class, parent_class._meta.unique_together)
12421255
)
1243-
if parent_class._meta.total_unique_constraints:
1256+
if include_meta_constraints and parent_class._meta.total_unique_constraints:
12441257
constraints.append(
12451258
(parent_class, parent_class._meta.total_unique_constraints)
12461259
)
@@ -1251,10 +1264,11 @@ def _get_unique_checks(self, exclude=None):
12511264
# Add the check if the field isn't excluded.
12521265
unique_checks.append((model_class, tuple(check)))
12531266

1254-
for model_class, model_constraints in constraints:
1255-
for constraint in model_constraints:
1256-
if not any(name in exclude for name in constraint.fields):
1257-
unique_checks.append((model_class, constraint.fields))
1267+
if include_meta_constraints:
1268+
for model_class, model_constraints in constraints:
1269+
for constraint in model_constraints:
1270+
if not any(name in exclude for name in constraint.fields):
1271+
unique_checks.append((model_class, constraint.fields))
12581272

12591273
# These are checks for the unique_for_<date/year/month>.
12601274
date_checks = []
@@ -1410,10 +1424,35 @@ def unique_error_message(self, model_class, unique_check):
14101424
params=params,
14111425
)
14121426

1413-
def full_clean(self, exclude=None, validate_unique=True):
1427+
def get_constraints(self):
1428+
constraints = [(self.__class__, self._meta.constraints)]
1429+
for parent_class in self._meta.get_parent_list():
1430+
if parent_class._meta.constraints:
1431+
constraints.append((parent_class, parent_class._meta.constraints))
1432+
return constraints
1433+
1434+
def validate_constraints(self, exclude=None):
1435+
constraints = self.get_constraints()
1436+
using = router.db_for_write(self.__class__, instance=self)
1437+
1438+
errors = {}
1439+
for model_class, model_constraints in constraints:
1440+
for constraint in model_constraints:
1441+
try:
1442+
constraint.validate(model_class, self, exclude=exclude, using=using)
1443+
except ValidationError as e:
1444+
if e.code == "unique" and len(constraint.fields) == 1:
1445+
errors.setdefault(constraint.fields[0], []).append(e)
1446+
else:
1447+
errors = e.update_error_dict(errors)
1448+
if errors:
1449+
raise ValidationError(errors)
1450+
1451+
def full_clean(self, exclude=None, validate_unique=True, validate_constraints=True):
14141452
"""
1415-
Call clean_fields(), clean(), and validate_unique() on the model.
1416-
Raise a ValidationError for any errors that occur.
1453+
Call clean_fields(), clean(), validate_unique(), and
1454+
validate_constraints() on the model. Raise a ValidationError for any
1455+
errors that occur.
14171456
"""
14181457
errors = {}
14191458
if exclude is None:
@@ -1443,6 +1482,16 @@ def full_clean(self, exclude=None, validate_unique=True):
14431482
except ValidationError as e:
14441483
errors = e.update_error_dict(errors)
14451484

1485+
# Run constraints checks, but only for fields that passed validation.
1486+
if validate_constraints:
1487+
for name in errors:
1488+
if name != NON_FIELD_ERRORS and name not in exclude:
1489+
exclude.add(name)
1490+
try:
1491+
self.validate_constraints(exclude=exclude)
1492+
except ValidationError as e:
1493+
errors = e.update_error_dict(errors)
1494+
14461495
if errors:
14471496
raise ValidationError(errors)
14481497

@@ -2339,8 +2388,28 @@ def _check_constraints(cls, databases):
23392388
connection.features.supports_table_check_constraints
23402389
or "supports_table_check_constraints"
23412390
not in cls._meta.required_db_features
2342-
) and isinstance(constraint.check, Q):
2343-
references.update(cls._get_expr_references(constraint.check))
2391+
):
2392+
if isinstance(constraint.check, Q):
2393+
references.update(
2394+
cls._get_expr_references(constraint.check)
2395+
)
2396+
if any(
2397+
isinstance(expr, RawSQL)
2398+
for expr in constraint.check.flatten()
2399+
):
2400+
errors.append(
2401+
checks.Warning(
2402+
f"Check constraint {constraint.name!r} contains "
2403+
f"RawSQL() expression and won't be validated "
2404+
f"during the model full_clean().",
2405+
hint=(
2406+
"Silence this warning if you don't care about "
2407+
"it."
2408+
),
2409+
obj=cls,
2410+
id="models.W045",
2411+
),
2412+
)
23442413
for field_name, *lookups in references:
23452414
# pk is an alias that won't be found by opts.get_field.
23462415
if field_name != "pk":

django/db/models/constraints.py

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
from enum import Enum
22

3-
from django.db.models.expressions import ExpressionList, F
3+
from django.core.exceptions import FieldError, ValidationError
4+
from django.db import connections
5+
from django.db.models.expressions import Exists, ExpressionList, F
46
from django.db.models.indexes import IndexExpression
7+
from django.db.models.lookups import Exact
58
from django.db.models.query_utils import Q
69
from django.db.models.sql.query import Query
10+
from django.db.utils import DEFAULT_DB_ALIAS
11+
from django.utils.translation import gettext_lazy as _
712

813
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
914

1015

1116
class BaseConstraint:
12-
def __init__(self, name):
17+
violation_error_message = _("Constraint “%(name)s” is violated.")
18+
19+
def __init__(self, name, violation_error_message=None):
1320
self.name = name
21+
if violation_error_message is not None:
22+
self.violation_error_message = violation_error_message
1423

1524
@property
1625
def contains_expressions(self):
@@ -25,6 +34,12 @@ def create_sql(self, model, schema_editor):
2534
def remove_sql(self, model, schema_editor):
2635
raise NotImplementedError("This method must be implemented by a subclass.")
2736

37+
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
38+
raise NotImplementedError("This method must be implemented by a subclass.")
39+
40+
def get_violation_error_message(self):
41+
return self.violation_error_message % {"name": self.name}
42+
2843
def deconstruct(self):
2944
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
3045
path = path.replace("django.db.models.constraints", "django.db.models")
@@ -36,13 +51,13 @@ def clone(self):
3651

3752

3853
class CheckConstraint(BaseConstraint):
39-
def __init__(self, *, check, name):
54+
def __init__(self, *, check, name, violation_error_message=None):
4055
self.check = check
4156
if not getattr(check, "conditional", False):
4257
raise TypeError(
4358
"CheckConstraint.check must be a Q instance or boolean expression."
4459
)
45-
super().__init__(name)
60+
super().__init__(name, violation_error_message=violation_error_message)
4661

4762
def _get_check_sql(self, model, schema_editor):
4863
query = Query(model=model, alias_cols=False)
@@ -62,6 +77,14 @@ def create_sql(self, model, schema_editor):
6277
def remove_sql(self, model, schema_editor):
6378
return schema_editor._delete_check_sql(model, self.name)
6479

80+
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
81+
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
82+
try:
83+
if not Q(self.check).check(against, using=using):
84+
raise ValidationError(self.get_violation_error_message())
85+
except FieldError:
86+
pass
87+
6588
def __repr__(self):
6689
return "<%s: check=%s name=%s>" % (
6790
self.__class__.__qualname__,
@@ -99,6 +122,7 @@ def __init__(
99122
deferrable=None,
100123
include=None,
101124
opclasses=(),
125+
violation_error_message=None,
102126
):
103127
if not name:
104128
raise ValueError("A unique constraint must be named.")
@@ -148,7 +172,7 @@ def __init__(
148172
F(expression) if isinstance(expression, str) else expression
149173
for expression in expressions
150174
)
151-
super().__init__(name)
175+
super().__init__(name, violation_error_message=violation_error_message)
152176

153177
@property
154178
def contains_expressions(self):
@@ -265,3 +289,61 @@ def deconstruct(self):
265289
if self.opclasses:
266290
kwargs["opclasses"] = self.opclasses
267291
return path, self.expressions, kwargs
292+
293+
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
294+
queryset = model._default_manager.using(using)
295+
if self.fields:
296+
lookup_kwargs = {}
297+
for field_name in self.fields:
298+
if exclude and field_name in exclude:
299+
return
300+
field = model._meta.get_field(field_name)
301+
lookup_value = getattr(instance, field.attname)
302+
if lookup_value is None or (
303+
lookup_value == ""
304+
and connections[using].features.interprets_empty_strings_as_nulls
305+
):
306+
# A composite constraint containing NULL value cannot cause
307+
# a violation since NULL != NULL in SQL.
308+
return
309+
lookup_kwargs[field.name] = lookup_value
310+
queryset = queryset.filter(**lookup_kwargs)
311+
else:
312+
# Ignore constraints with excluded fields.
313+
if exclude:
314+
for expression in self.expressions:
315+
for expr in expression.flatten():
316+
if isinstance(expr, F) and expr.name in exclude:
317+
return
318+
replacement_map = instance._get_field_value_map(
319+
meta=model._meta, exclude=exclude
320+
)
321+
expressions = [
322+
Exact(expr, expr.replace_references(replacement_map))
323+
for expr in self.expressions
324+
]
325+
queryset = queryset.filter(*expressions)
326+
model_class_pk = instance._get_pk_val(model._meta)
327+
if not instance._state.adding and model_class_pk is not None:
328+
queryset = queryset.exclude(pk=model_class_pk)
329+
if not self.condition:
330+
if queryset.exists():
331+
if self.expressions:
332+
raise ValidationError(self.get_violation_error_message())
333+
# When fields are defined, use the unique_error_message() for
334+
# backward compatibility.
335+
for model, constraints in instance.get_constraints():
336+
for constraint in constraints:
337+
if constraint is self:
338+
raise ValidationError(
339+
instance.unique_error_message(model, self.fields)
340+
)
341+
else:
342+
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
343+
try:
344+
if (self.condition & Exists(queryset.filter(self.condition))).check(
345+
against, using=using
346+
):
347+
raise ValidationError(self.get_violation_error_message())
348+
except FieldError:
349+
pass

0 commit comments

Comments
 (0)