Use CompletionOnceCallback in WebSocketStream.

Bug: 807724
Change-Id: I71ddcf48ef586baa4d784a76f3d281988e6c2fc4
Reviewed-on: https://chromium-review.googlesource.com/1128180
Reviewed-by: Adam Rice <[email protected]>
Commit-Queue: Bence Béky <[email protected]>
Cr-Commit-Position: refs/heads/master@{#575521}
diff --git a/net/websockets/websocket_basic_stream.cc b/net/websockets/websocket_basic_stream.cc
index 31267ba..d9d45b65 100644
--- a/net/websockets/websocket_basic_stream.cc
+++ b/net/websockets/websocket_basic_stream.cc
@@ -115,55 +115,21 @@
 
 int WebSocketBasicStream::ReadFrames(
     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
-    const CompletionCallback& callback) {
-  DCHECK(frames->empty());
-  // If there is data left over after parsing the HTTP headers, attempt to parse
-  // it as WebSocket frames.
-  if (http_read_buffer_.get()) {
-    DCHECK_GE(http_read_buffer_->offset(), 0);
-    // We cannot simply copy the data into read_buffer_, as it might be too
-    // large.
-    scoped_refptr<GrowableIOBuffer> buffered_data;
-    buffered_data.swap(http_read_buffer_);
-    DCHECK(!http_read_buffer_);
-    std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
-    if (!parser_.Decode(buffered_data->StartOfBuffer(),
-                        buffered_data->offset(),
-                        &frame_chunks))
-      return WebSocketErrorToNetError(parser_.websocket_error());
-    if (!frame_chunks.empty()) {
-      int result = ConvertChunksToFrames(&frame_chunks, frames);
-      if (result != ERR_IO_PENDING)
-        return result;
-    }
-  }
+    CompletionOnceCallback callback) {
+  read_callback_ = std::move(callback);  // Run only on async completion.
 
-  // Run until socket stops giving us data or we get some frames.
-  while (true) {
-    // base::Unretained(this) here is safe because net::Socket guarantees not to
-    // call any callbacks after Disconnect(), which we call from the
-    // destructor. The caller of ReadFrames() is required to keep |frames|
-    // valid.
-    int result = connection_->Read(
-        read_buffer_.get(), read_buffer_->size(),
-        base::Bind(&WebSocketBasicStream::OnReadComplete,
-                   base::Unretained(this), base::Unretained(frames), callback));
-    if (result == ERR_IO_PENDING)
-      return result;
-    result = HandleReadResult(result, frames);
-    if (result != ERR_IO_PENDING)
-      return result;
-    DCHECK(frames->empty());
-  }
+  return ReadEverything(frames);
 }
 
 int WebSocketBasicStream::WriteFrames(
     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
-    const CompletionCallback& callback) {
+    CompletionOnceCallback callback) {
   // This function always concatenates all frames into a single buffer.
   // TODO(ricea): Investigate whether it would be better in some cases to
   // perform multiple writes with smaller buffers.
-  //
+
+  write_callback_ = std::move(callback);  // Run only on async completion.
+
   // First calculate the size of the buffer we need to allocate.
   int total_size = CalculateSerializedSizeAndTurnOnMaskBit(frames);
   auto combined_buffer = base::MakeRefCounted<IOBufferWithSize>(total_size);
@@ -196,7 +162,7 @@
                                << remaining_size << " bytes left over.";
   auto drainable_buffer = base::MakeRefCounted<DrainableIOBuffer>(
       combined_buffer.get(), total_size);
-  return WriteEverything(drainable_buffer, callback);
+  return WriteEverything(drainable_buffer);
 }
 
 void WebSocketBasicStream::Close() {
@@ -225,17 +191,68 @@
   return stream;
 }
 
+int WebSocketBasicStream::ReadEverything(
+    std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
+  DCHECK(frames->empty());
+
+  // If there is data left over after parsing the HTTP headers, attempt to parse
+  // it as WebSocket frames.
+  if (http_read_buffer_.get()) {
+    DCHECK_GE(http_read_buffer_->offset(), 0);
+    // We cannot simply copy the data into read_buffer_, as it might be too
+    // large.
+    scoped_refptr<GrowableIOBuffer> buffered_data;
+    buffered_data.swap(http_read_buffer_);
+    DCHECK(!http_read_buffer_);
+    std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
+    if (!parser_.Decode(buffered_data->StartOfBuffer(), buffered_data->offset(),
+                        &frame_chunks))
+      return WebSocketErrorToNetError(parser_.websocket_error());
+    if (!frame_chunks.empty()) {
+      int result = ConvertChunksToFrames(&frame_chunks, frames);
+      if (result != ERR_IO_PENDING)
+        return result;
+    }
+  }
+
+  // Run until socket stops giving us data or we get some frames.
+  while (true) {
+    // base::Unretained(this) here is safe because net::Socket guarantees not to
+    // call any callbacks after Disconnect(), which we call from the destructor.
+    // The caller of ReadFrames() must keep |frames| valid until it completes.
+    int result = connection_->Read(
+        read_buffer_.get(), read_buffer_->size(),
+        base::BindOnce(&WebSocketBasicStream::OnReadComplete,
+                       base::Unretained(this), base::Unretained(frames)));
+    if (result == ERR_IO_PENDING)
+      return result;
+    result = HandleReadResult(result, frames);
+    if (result != ERR_IO_PENDING)
+      return result;
+    DCHECK(frames->empty());
+  }
+}
+
+void WebSocketBasicStream::OnReadComplete(
+    std::vector<std::unique_ptr<WebSocketFrame>>* frames,
+    int result) {
+  result = HandleReadResult(result, frames);
+  if (result == ERR_IO_PENDING)  // No frames yet; read more.
+    result = ReadEverything(frames);
+  if (result != ERR_IO_PENDING)
+    std::move(read_callback_).Run(result);
+}
+
 int WebSocketBasicStream::WriteEverything(
-    const scoped_refptr<DrainableIOBuffer>& buffer,
-    const CompletionCallback& callback) {
+    const scoped_refptr<DrainableIOBuffer>& buffer) {
   while (buffer->BytesRemaining() > 0) {
     // The use of base::Unretained() here is safe because on destruction we
     // disconnect the socket, preventing any further callbacks.
-    int result =
-        connection_->Write(buffer.get(), buffer->BytesRemaining(),
-                           base::Bind(&WebSocketBasicStream::OnWriteComplete,
-                                      base::Unretained(this), buffer, callback),
-                           kTrafficAnnotation);
+    int result = connection_->Write(
+        buffer.get(), buffer->BytesRemaining(),
+        base::BindOnce(&WebSocketBasicStream::OnWriteComplete,
+                       base::Unretained(this), buffer),
+        kTrafficAnnotation);
     if (result > 0) {
       UMA_HISTOGRAM_COUNTS_100000("Net.WebSocket.DataUse.Upstream", result);
       buffer->DidConsume(result);
@@ -248,11 +265,10 @@
 
 void WebSocketBasicStream::OnWriteComplete(
     const scoped_refptr<DrainableIOBuffer>& buffer,
-    const CompletionCallback& callback,
     int result) {
   if (result < 0) {
     DCHECK_NE(ERR_IO_PENDING, result);
-    callback.Run(result);
+    std::move(write_callback_).Run(result);
     return;
   }
 
@@ -260,9 +276,9 @@
   UMA_HISTOGRAM_COUNTS_100000("Net.WebSocket.DataUse.Upstream", result);
 
   buffer->DidConsume(result);
-  result = WriteEverything(buffer, callback);
+  result = WriteEverything(buffer);  // Write the remaining bytes.
   if (result != ERR_IO_PENDING)
-    callback.Run(result);
+    std::move(write_callback_).Run(result);
 }
 
 int WebSocketBasicStream::HandleReadResult(
@@ -441,15 +457,4 @@
   incomplete_control_frame_body_->set_offset(new_offset);
 }
 
-void WebSocketBasicStream::OnReadComplete(
-    std::vector<std::unique_ptr<WebSocketFrame>>* frames,
-    const CompletionCallback& callback,
-    int result) {
-  result = HandleReadResult(result, frames);
-  if (result == ERR_IO_PENDING)
-    result = ReadFrames(frames, callback);
-  if (result != ERR_IO_PENDING)
-    callback.Run(result);
-}
-
 }  // namespace net