diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorFlux.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorFlux.java index 90bbe9ed0a4..119598a265b 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorFlux.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorFlux.java @@ -18,11 +18,9 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; -import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; -import reactor.util.context.Context; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; @@ -48,9 +46,9 @@ public void subscribe(final Subscriber subscriber) { if (calculateDemand(demand) > 0 && inProgress.compareAndSet(false, true)) { if (batchCursor == null) { int batchSize = calculateBatchSize(sink.requestedFromDownstream()); - Context initialContext = subscriber instanceof CoreSubscriber - ? ((CoreSubscriber) subscriber).currentContext() : null; - batchCursorPublisher.batchCursor(batchSize).subscribe(bc -> { + batchCursorPublisher.batchCursor(batchSize) + .contextWrite(sink.contextView()) + .subscribe(bc -> { batchCursor = bc; inProgress.set(false); @@ -60,7 +58,7 @@ public void subscribe(final Subscriber subscriber) { } else { recurseCursor(); } - }, sink::error, null, initialContext); + }, sink::error); } else { inProgress.set(false); recurseCursor(); @@ -86,6 +84,7 @@ private void recurseCursor(){ } else { batchCursor.setBatchSize(calculateBatchSize(sink.requestedFromDownstream())); Mono.from(batchCursor.next(() -> sink.isCancelled())) + .contextWrite(sink.contextView()) .doOnCancel(this::closeCursor) .subscribe(results -> { if (!results.isEmpty()) { diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorPublisher.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorPublisher.java index cf5a9d9f25b..13ee27f002f 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorPublisher.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/BatchCursorPublisher.java @@ -123,21 +123,17 @@ public TimeoutMode getTimeoutMode() { public Publisher first() { return batchCursor(this::asAsyncFirstReadOperation) - .flatMap(batchCursor -> Mono.create(sink -> { + .flatMap(batchCursor -> { batchCursor.setBatchSize(1); - Mono.from(batchCursor.next()) + return Mono.from(batchCursor.next()) .doOnTerminate(batchCursor::close) - .doOnError(sink::error) - .doOnSuccess(results -> { + .flatMap(results -> { if (results == null || results.isEmpty()) { - sink.success(); - } else { - sink.success(results.get(0)); + return Mono.empty(); } - }) - .contextWrite(sink.contextView()) - .subscribe(); - })); + return Mono.fromCallable(() -> results.get(0)); + }); + }); } @Override diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java index 6d5aca27457..13d9373a3ff 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java @@ -306,6 +306,7 @@ private void collInfo(final MongoCryptContext cryptContext, sink.error(new IllegalStateException("Missing database name")); } else { collectionInfoRetriever.filter(databaseName, cryptContext.getMongoOperation(), operationTimeout) + .contextWrite(sink.contextView()) .doOnSuccess(result -> { if (result != null) { cryptContext.addMongoOperationResult(result); @@ -328,6 +329,7 @@ private void mark(final MongoCryptContext cryptContext, sink.error(wrapInClientException(new IllegalStateException("Missing database name"))); } else { commandMarker.mark(databaseName, cryptContext.getMongoOperation(), operationTimeout) + .contextWrite(sink.contextView()) .doOnSuccess(result -> { cryptContext.addMongoOperationResult(result); cryptContext.completeMongoOperation(); @@ -343,6 +345,7 @@ private void fetchKeys(final MongoCryptContext cryptContext, final MonoSink sink, @Nullable final Timeout operationTimeout) { keyRetriever.find(cryptContext.getMongoOperation(), operationTimeout) + .contextWrite(sink.contextView()) .doOnSuccess(results -> { for (BsonDocument result : results) { cryptContext.addMongoOperationResult(result); @@ -361,11 +364,13 @@ private void decryptKeys(final MongoCryptContext cryptContext, MongoKeyDecryptor keyDecryptor = cryptContext.nextKeyDecryptor(); if (keyDecryptor != null) { keyManagementService.decryptKey(keyDecryptor, operationTimeout) + .contextWrite(sink.contextView()) .doOnSuccess(r -> decryptKeys(cryptContext, databaseName, sink, operationTimeout)) .doOnError(e -> sink.error(wrapInClientException(e))) .subscribe(); } else { Mono.fromRunnable(cryptContext::completeKeyDecryptors) + .contextWrite(sink.contextView()) .doOnSuccess(r -> executeStateMachineWithSink(cryptContext, databaseName, sink, operationTimeout)) .doOnError(e -> sink.error(wrapInClientException(e))) .subscribe(); diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java index a45d369c676..7d9a46cdf3f 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java @@ -40,9 +40,6 @@ import java.util.Date; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; import static com.mongodb.ReadPreference.primary; import static com.mongodb.assertions.Assertions.notNull; @@ -106,7 +103,7 @@ public BsonValue getId() { @Override public void subscribe(final Subscriber s) { - Mono.defer(() -> { + Mono.deferContextual(ctx -> { AtomicBoolean terminated = new AtomicBoolean(false); Timeout timeout = TimeoutContext.startTimeout(timeoutMs); return createCheckAndCreateIndexesMono(timeout) @@ -120,7 +117,7 @@ public void subscribe(final Subscriber s) { return originalError; }) .then(Mono.error(originalError))) - .doOnCancel(() -> createCancellationMono(terminated, timeout).subscribe()) + .doOnCancel(() -> createCancellationMono(terminated, timeout).contextWrite(ctx).subscribe()) .then(); }).subscribe(s); } @@ -149,38 +146,15 @@ public void subscribe(final Subscriber subscriber) { } private Mono createCheckAndCreateIndexesMono(@Nullable final Timeout timeout) { - AtomicBoolean collectionExists = new AtomicBoolean(false); - return Mono.create(sink -> findAllInCollection(filesCollection, timeout).subscribe( - d -> collectionExists.set(true), - sink::error, - () -> { - if (collectionExists.get()) { - sink.success(); - } else { - checkAndCreateIndex(filesCollection.withReadPreference(primary()), FILES_INDEX, timeout) - .doOnSuccess(i -> checkAndCreateIndex(chunksCollection.withReadPreference(primary()), CHUNKS_INDEX, timeout) - .subscribe(unused -> {}, sink::error, sink::success)) - .subscribe(unused -> {}, sink::error); - } - }) - ); - } - - private Mono findAllInCollection(final MongoCollection collection, @Nullable final Timeout timeout) { - return collectionWithTimeoutDeferred(collection - .withDocumentClass(Document.class) - .withReadPreference(primary()), timeout) - .flatMap(wrappedCollection -> { - if (clientSession != null) { - return Mono.from(wrappedCollection.find(clientSession) - .projection(PROJECTION) - .first()); - } else { - return Mono.from(wrappedCollection.find() - .projection(PROJECTION) - .first()); - } - }); + return collectionWithTimeoutDeferred(filesCollection.withDocumentClass(Document.class).withReadPreference(primary()), timeout) + .map(collection -> clientSession != null ? collection.find(clientSession) : collection.find()) + .flatMap(findPublisher -> Mono.from(findPublisher.projection(PROJECTION).first())) + .switchIfEmpty(Mono.defer(() -> + checkAndCreateIndex(filesCollection.withReadPreference(primary()), FILES_INDEX, timeout) + .then(checkAndCreateIndex(chunksCollection.withReadPreference(primary()), CHUNKS_INDEX, timeout)) + .then(Mono.empty()) + )) + .then(); } private Mono hasIndex(final MongoCollection collection, final Document index, @Nullable final Timeout timeout) { @@ -228,40 +202,37 @@ private Mono createIndexMono(final MongoCollection collection, fi } private Mono createSaveChunksMono(final AtomicBoolean terminated, @Nullable final Timeout timeout) { - return Mono.create(sink -> { - AtomicLong lengthInBytes = new AtomicLong(0); - AtomicInteger chunkIndex = new AtomicInteger(0); - new ResizingByteBufferFlux(source, chunkSizeBytes) - .takeUntilOther(createMonoTimer(timeout)) - .flatMap((Function>) byteBuffer -> { - if (terminated.get()) { - return Mono.empty(); - } - byte[] byteArray = new byte[byteBuffer.remaining()]; - if (byteBuffer.hasArray()) { - System.arraycopy(byteBuffer.array(), byteBuffer.position(), byteArray, 0, byteBuffer.remaining()); - } else { - byteBuffer.mark(); - byteBuffer.get(byteArray); - byteBuffer.reset(); - } - Binary data = new Binary(byteArray); - lengthInBytes.addAndGet(data.length()); + return new ResizingByteBufferFlux(source, chunkSizeBytes) + .takeUntilOther(createMonoTimer(timeout)) + .index() + .flatMap(indexAndBuffer -> { + if (terminated.get()) { + return Mono.empty(); + } + Long index = indexAndBuffer.getT1(); + ByteBuffer byteBuffer = indexAndBuffer.getT2(); + byte[] byteArray = new byte[byteBuffer.remaining()]; + if (byteBuffer.hasArray()) { + System.arraycopy(byteBuffer.array(), byteBuffer.position(), byteArray, 0, byteBuffer.remaining()); + } else { + byteBuffer.mark(); + byteBuffer.get(byteArray); + byteBuffer.reset(); + } + Binary data = new Binary(byteArray); - Document chunkDocument = new Document("files_id", fileId) - .append("n", chunkIndex.getAndIncrement()) - .append("data", data); + Document chunkDocument = new Document("files_id", fileId) + .append("n", index.intValue()) + .append("data", data); - if (clientSession == null) { - return collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(chunkDocument); - } else { - return collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(clientSession, - chunkDocument); - } + Publisher insertOnePublisher = clientSession == null + ? collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(chunkDocument) + : collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE) + .insertOne(clientSession, chunkDocument); - }) - .subscribe(null, sink::error, () -> sink.success(lengthInBytes.get())); - }); + return Mono.from(insertOnePublisher).thenReturn(data.length()); + }) + .reduce(0L, Long::sum); } /**