Skip to content

Commit 18983f0

Browse files
committed
Fixed #13003 -- Ensured that ._state.db is set correctly for select_related() queries. Thanks to Alex Gaynor for the report.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12701 bcc190cf-cafb-0310-a4f2-bffc1f526a37
1 parent 3508a86 commit 18983f0

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

django/db/models/query.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def iterator(self):
267267
for row in compiler.results_iter():
268268
if fill_cache:
269269
obj, _ = get_cached_row(self.model, row,
270-
index_start, max_depth,
270+
index_start, using=self.db, max_depth=max_depth,
271271
requested=requested, offset=len(aggregate_select),
272272
only_load=only_load)
273273
else:
@@ -279,16 +279,16 @@ def iterator(self):
279279
# Omit aggregates in object creation.
280280
obj = self.model(*row[index_start:aggregate_start])
281281

282+
# Store the source database of the object
283+
obj._state.db = self.db
284+
282285
for i, k in enumerate(extra_select):
283286
setattr(obj, k, row[i])
284287

285288
# Add the aggregates to the model
286289
for i, aggregate in enumerate(aggregate_select):
287290
setattr(obj, aggregate, row[i+aggregate_start])
288291

289-
# Store the source database of the object
290-
obj._state.db = self.db
291-
292292
yield obj
293293

294294
def aggregate(self, *args, **kwargs):
@@ -1112,7 +1112,7 @@ def update(self, **kwargs):
11121112
value_annotation = False
11131113

11141114

1115-
def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
1115+
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
11161116
requested=None, offset=0, only_load=None):
11171117
"""
11181118
Helper function that recursively returns an object with the specified
@@ -1126,6 +1126,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
11261126
* row - the row of data returned by the database cursor
11271127
* index_start - the index of the row at which data for this
11281128
object is known to start
1129+
* using - the database alias on which the query is being executed.
11291130
* max_depth - the maximum depth to which a select_related()
11301131
relationship should be explored.
11311132
* cur_depth - the current depth in the select_related() tree.
@@ -1170,6 +1171,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
11701171
obj = klass(**dict(zip(init_list, fields)))
11711172
else:
11721173
obj = klass(*fields)
1174+
11731175
else:
11741176
# Load all fields on klass
11751177
field_count = len(klass._meta.fields)
@@ -1182,6 +1184,10 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
11821184
else:
11831185
obj = klass(*fields)
11841186

1187+
# If an object was retrieved, set the database state.
1188+
if obj:
1189+
obj._state.db = using
1190+
11851191
index_end = index_start + field_count + offset
11861192
# Iterate over each related object, populating any
11871193
# select_related() fields
@@ -1193,8 +1199,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
11931199
else:
11941200
next = None
11951201
# Recursively retrieve the data for the related object
1196-
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
1197-
cur_depth+1, next)
1202+
cached_row = get_cached_row(f.rel.to, row, index_end, using,
1203+
max_depth, cur_depth+1, next)
11981204
# If the recursive descent found an object, populate the
11991205
# descriptor caches relevant to the object
12001206
if cached_row:
@@ -1222,8 +1228,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
12221228
continue
12231229
next = requested[f.related_query_name()]
12241230
# Recursively retrieve the data for the related object
1225-
cached_row = get_cached_row(model, row, index_end, max_depth,
1226-
cur_depth+1, next)
1231+
cached_row = get_cached_row(model, row, index_end, using,
1232+
max_depth, cur_depth+1, next)
12271233
# If the recursive descent found an object, populate the
12281234
# descriptor caches relevant to the object
12291235
if cached_row:

tests/regressiontests/multiple_database/tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,20 @@ def test_raw(self):
641641
val = Book.objects.raw('SELECT id FROM "multiple_database_book"').using('other')
642642
self.assertEqual(map(lambda o: o.pk, val), [dive.pk])
643643

644+
def test_select_related(self):
645+
"Database assignment is retained if an object is retrieved with select_related()"
646+
# Create a book and author on the other database
647+
mark = Person.objects.using('other').create(name="Mark Pilgrim")
648+
dive = Book.objects.using('other').create(title="Dive into Python",
649+
published=datetime.date(2009, 5, 4),
650+
editor=mark)
651+
652+
# Retrieve the Person using select_related()
653+
book = Book.objects.using('other').select_related('editor').get(title="Dive into Python")
654+
655+
# The editor instance should have a db state
656+
self.assertEqual(book.editor._state.db, 'other')
657+
644658
class TestRouter(object):
645659
# A test router. The behaviour is vaguely master/slave, but the
646660
# databases aren't assumed to propagate changes.

0 commit comments

Comments
 (0)