Collect update notifications only when datastore is observed

Previously, DataStore started collecting updates from IPC right after
it is initialized and kept this collection alive as long as the scope
is alive.

This change updates it to only observe the file if there is an active
observer on the DataStore. We achieve this by updating `data` to a
channel flow that will collect on the shared flow when it is used.

Some tests needed to be updated because they made an assumption on the
order of execution which slightly changes with this channel flow but
does not effect correctness.

Fixes: 267792241
Test: SingleProcessDataStoreTest#observeFileOnlyWhenDatastoreIsObserved
Change-Id: Id8221718f75b9755119d5e9f38c4827fa907b867
diff --git a/datastore/datastore-core/src/androidInstrumentedTest/kotlin/androidx/datastore/core/MultiProcessDataStoreSingleProcessTest.kt b/datastore/datastore-core/src/androidInstrumentedTest/kotlin/androidx/datastore/core/MultiProcessDataStoreSingleProcessTest.kt
index 24bd9db..9ecbb53 100644
--- a/datastore/datastore-core/src/androidInstrumentedTest/kotlin/androidx/datastore/core/MultiProcessDataStoreSingleProcessTest.kt
+++ b/datastore/datastore-core/src/androidInstrumentedTest/kotlin/androidx/datastore/core/MultiProcessDataStoreSingleProcessTest.kt
@@ -54,6 +54,7 @@
 import kotlinx.coroutines.runBlocking
 import kotlinx.coroutines.test.TestScope
 import kotlinx.coroutines.test.UnconfinedTestDispatcher
+import kotlinx.coroutines.test.runCurrent
 import kotlinx.coroutines.test.runTest
 import kotlinx.coroutines.withContext
 import kotlinx.coroutines.withTimeout
@@ -602,6 +603,7 @@
             store.data.take(8).toList(collectedBytes)
         }
 
+        runCurrent()
         repeat(7) {
             store.updateData { it.inc() }
         }
@@ -626,6 +628,7 @@
             flowOf8.toList(bytesFromSecondCollect)
         }
 
+        runCurrent()
         repeat(7) {
             store.updateData { it.inc() }
         }
@@ -659,6 +662,7 @@
             flowOf8.take(8).toList(collectedBytes)
         }
 
+        runCurrent()
         repeat(7) {
             store.updateData { it.inc() }
         }
@@ -685,6 +689,7 @@
             }
         }
 
+        runCurrent()
         repeat(15) {
             store.updateData { it.inc() }
         }
diff --git a/datastore/datastore-core/src/commonMain/kotlin/androidx/datastore/core/DataStoreImpl.kt b/datastore/datastore-core/src/commonMain/kotlin/androidx/datastore/core/DataStoreImpl.kt
index da2b8fd..9c038e5 100644
--- a/datastore/datastore-core/src/commonMain/kotlin/androidx/datastore/core/DataStoreImpl.kt
+++ b/datastore/datastore-core/src/commonMain/kotlin/androidx/datastore/core/DataStoreImpl.kt
@@ -22,18 +22,25 @@
 import kotlin.contracts.contract
 import kotlin.coroutines.CoroutineContext
 import kotlin.coroutines.coroutineContext
+import kotlin.time.Duration
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.CompletableDeferred
 import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.Job
+import kotlinx.coroutines.CoroutineStart
 import kotlinx.coroutines.SupervisorJob
 import kotlinx.coroutines.completeWith
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.SharingStarted
+import kotlinx.coroutines.flow.WhileSubscribed
+import kotlinx.coroutines.flow.channelFlow
 import kotlinx.coroutines.flow.conflate
 import kotlinx.coroutines.flow.dropWhile
 import kotlinx.coroutines.flow.emitAll
 import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.flow.map
+import kotlinx.coroutines.flow.onCompletion
+import kotlinx.coroutines.flow.onStart
+import kotlinx.coroutines.flow.shareIn
 import kotlinx.coroutines.flow.takeWhile
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.sync.Mutex
@@ -57,7 +64,36 @@
     private val scope: CoroutineScope = CoroutineScope(ioDispatcher() + SupervisorJob())
 ) : DataStore<T> {
 
-    override val data: Flow<T> = flow {
+    /**
+     * Shared flow responsible for observing [InterProcessCoordinator] for file changes.
+     * Each downstream [data] flow collects on this [kotlinx.coroutines.flow.SharedFlow] to ensure
+     * we observe the [InterProcessCoordinator] when there is an active collection on the [data].
+     */
+    private val updateCollection = flow<Unit> {
+        // deferring 1 flow so we can create coordinator lazily just to match existing behavior.
+        // also wait for initialization to complete before watching update events.
+        readAndInit.awaitComplete()
+        coordinator.updateNotifications.conflate().collect {
+            val currentState = inMemoryCache.currentState
+            if (currentState !is Final) {
+                // update triggered reads should always wait for lock
+                readDataAndUpdateCache(requireLock = true)
+            }
+        }
+    }.shareIn(
+        scope = scope,
+        started = SharingStarted.WhileSubscribed(
+            stopTimeout = Duration.ZERO,
+            replayExpiration = Duration.ZERO
+        ),
+        replay = 0
+    )
+
+    /**
+     * The actual values of DataStore. This is exposed in the API via [data] to be able to combine
+     * its lifetime with IPC update collection ([updateCollection]).
+     */
+    private val internalDataFlow: Flow<T> = flow {
         /**
          * If downstream flow is UnInitialized, no data has been read yet, we need to trigger a new
          * read then start emitting values once we have seen a new value (or exception).
@@ -106,6 +142,20 @@
         )
     }
 
+    override val data: Flow<T> = channelFlow {
+        val updateCollector = launch(start = CoroutineStart.LAZY) {
+            updateCollection.collect {
+                // collect it infinitely so it keeps running as long as the data flow is active.
+            }
+        }
+        internalDataFlow
+            .onStart { updateCollector.start() }
+            .onCompletion { updateCollector.cancel() }
+            .collect {
+                send(it)
+            }
+    }
+
     override suspend fun updateData(transform: suspend (t: T) -> T): T {
         val ack = CompletableDeferred<T>()
         val currentDownStreamFlowState = inMemoryCache.currentState
@@ -123,30 +173,26 @@
 
     private val readAndInit = InitDataStore(initTasksList)
 
-    private lateinit var updateCollector: Job
-
     // TODO(b/269772127): make this private after we allow multiple instances of DataStore on the
     //  same file
-    internal val storageConnection: StorageConnection<T> by lazy {
+    private val storageConnectionDelegate = lazy {
         storage.createConnection()
     }
+    internal val storageConnection by storageConnectionDelegate
     private val coordinator: InterProcessCoordinator by lazy { storageConnection.coordinator }
 
     private val writeActor = SimpleActor<Message.Update<T>>(
         scope = scope,
         onComplete = {
-            // TODO(b/267792241): remove it if updateCollector is better scoped
-            // no more reads so stop listening to file changes
-            if (::updateCollector.isInitialized) {
-                updateCollector.cancel()
-            }
+            // We expect it to always be non-null but we will leave the alternative as a no-op
+            // just in case.
             it?.let {
                 inMemoryCache.tryUpdate(Final(it))
             }
-            // We expect it to always be non-null but we will leave the alternative as a no-op
-            // just in case.
-
-            storageConnection.close()
+            // don't try to close storage connection if it was not created in the first place.
+            if (storageConnectionDelegate.isInitialized()) {
+                storageConnection.close()
+            }
         },
         onUndeliveredElement = { msg, ex ->
             msg.ack.completeExceptionally(
@@ -379,17 +425,6 @@
                 }
             }
             inMemoryCache.tryUpdate(initData)
-            if (!::updateCollector.isInitialized) {
-                updateCollector = scope.launch {
-                    coordinator.updateNotifications.conflate().collect {
-                        val currentState = inMemoryCache.currentState
-                        if (currentState !is Final) {
-                            // update triggered reads should always wait for lock
-                            readDataAndUpdateCache(requireLock = true)
-                        }
-                    }
-                }
-            }
         }
 
         @OptIn(ExperimentalContracts::class)
@@ -464,15 +499,17 @@
  */
 internal abstract class RunOnce {
     private val runMutex = Mutex()
-    private var didRun: Boolean = false
+    private val didRun = CompletableDeferred<Unit>()
     protected abstract suspend fun doRun()
 
+    suspend fun awaitComplete() = didRun.await()
+
     suspend fun runIfNeeded() {
-        if (didRun) return
+        if (didRun.isCompleted) return
         runMutex.withLock {
-            if (didRun) return
+            if (didRun.isCompleted) return
             doRun()
-            didRun = true
+            didRun.complete(Unit)
         }
     }
 }
diff --git a/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/CloseDownstreamOnCloseTest.kt b/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/CloseDownstreamOnCloseTest.kt
index 5ecb216..087891b 100644
--- a/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/CloseDownstreamOnCloseTest.kt
+++ b/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/CloseDownstreamOnCloseTest.kt
@@ -24,11 +24,13 @@
 import kotlin.test.BeforeTest
 import kotlin.test.Test
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
 import kotlinx.coroutines.cancel
 import kotlinx.coroutines.flow.toList
 import kotlinx.coroutines.test.StandardTestDispatcher
 import kotlinx.coroutines.test.TestScope
+import kotlinx.coroutines.test.runCurrent
 import kotlinx.coroutines.test.runTest
 
 abstract class CloseDownstreamOnCloseTest<F : TestFile<F>>(private val testIO: TestIO<F, *>) {
@@ -47,11 +49,13 @@
         ) { testFile }
     }
 
+    @OptIn(ExperimentalCoroutinesApi::class)
     @Test
     fun closeWhileCollecting() = testScope.runTest {
         val collector = async {
             store.data.toList().map { it.toInt() }
         }
+        runCurrent()
         store.updateData { 1 }
         datastoreScope.cancel()
         dispatcher.scheduler.advanceUntilIdle()
diff --git a/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/SingleProcessDataStoreTest.kt b/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/SingleProcessDataStoreTest.kt
index ed9fd38..1a08d82 100644
--- a/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/SingleProcessDataStoreTest.kt
+++ b/datastore/datastore-core/src/commonTest/kotlin/androidx/datastore/core/SingleProcessDataStoreTest.kt
@@ -37,12 +37,19 @@
 import kotlinx.coroutines.awaitCancellation
 import kotlinx.coroutines.cancel
 import kotlinx.coroutines.cancelAndJoin
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.collect
 import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.onCompletion
+import kotlinx.coroutines.flow.onStart
 import kotlinx.coroutines.flow.take
 import kotlinx.coroutines.flow.toList
+import kotlinx.coroutines.flow.update
 import kotlinx.coroutines.job
 import kotlinx.coroutines.launch
+import kotlinx.coroutines.suspendCancellableCoroutine
 import kotlinx.coroutines.test.TestScope
 import kotlinx.coroutines.test.UnconfinedTestDispatcher
 import kotlinx.coroutines.test.runCurrent
@@ -513,6 +520,7 @@
         val flowCollectionJob = async {
             store.data.take(8).toList(collectedBytes)
         }
+        runCurrent()
 
         repeat(7) {
             store.updateData { it.inc() }
@@ -538,6 +546,7 @@
             flowOf8.toList(bytesFromSecondCollect)
         }
 
+        runCurrent()
         repeat(7) {
             store.updateData { it.inc() }
         }
@@ -566,6 +575,7 @@
             flowOf8.take(8).toList(collectedBytes)
         }
 
+        runCurrent()
         repeat(7) {
             store.updateData { it.inc() }
         }
@@ -592,7 +602,7 @@
                 flowCollection2.await()
             }
         }
-
+        runCurrent()
         repeat(15) {
             store.updateData { it.inc() }
         }
@@ -968,6 +978,61 @@
         }
     }
 
+    @Test
+    fun observeFileOnlyWhenDatastoreIsObserved() = runTest {
+        class InterProcessCoordinatorWithInfiniteUpdates(
+            val delegate: InterProcessCoordinator,
+            val observerCount: MutableStateFlow<Int> = MutableStateFlow(0)
+        ) : InterProcessCoordinator by delegate {
+            override val updateNotifications: Flow<Unit>
+                get() {
+                    return flow<Unit> {
+                        // never emit but never finish either so we know when we are being collected
+                        suspendCancellableCoroutine { }
+                    }.onStart {
+                        observerCount.update { it + 1 }
+                    }.onCompletion {
+                        observerCount.update { it - 1 }
+                    }
+                }
+        }
+        val observerCount = MutableStateFlow(0)
+        store = testIO.getStore(
+            serializerConfig,
+            dataStoreScope,
+            {
+                InterProcessCoordinatorWithInfiniteUpdates(
+                    delegate = createSingleProcessCoordinator(testFile.path()),
+                    observerCount = observerCount
+                )
+            }
+        ) { testFile }
+        fun hasObservers(): Boolean {
+            runCurrent()
+            return observerCount.value > 0
+        }
+        assertThat(hasObservers()).isFalse()
+        val collector1 = async {
+            store.data.collect {}
+        }
+        runCurrent()
+        assertThat(hasObservers()).isTrue()
+        val collector2 = async {
+            store.data.collect {}
+        }
+        assertThat(hasObservers()).isTrue()
+        collector1.cancelAndJoin()
+        assertThat(hasObservers()).isTrue()
+        collector2.cancelAndJoin()
+        assertThat(hasObservers()).isFalse()
+        val collector3 = async {
+            store.data.collect {}
+        }
+        assertThat(hasObservers()).isTrue()
+        collector3.cancelAndJoin()
+        assertThat(hasObservers()).isFalse()
+    }
+
     private class TestingCorruptionHandler(
         private val replaceWith: Byte? = null
     ) : CorruptionHandler<Byte> {