Provide child/frame IDs for WebSocket handshake request

AndroidCookiePolicy needs the child ID and the frame ID of a WebSocket
connection to determine whether it allows the connection to attach
third-party cookies. This CL provides this additional information to
the net::URLRequest used for the WebSocket handshake.

BUG=634311

Review-Url: https://codereview.chromium.org/2397393002
Cr-Commit-Position: refs/heads/master@{#427109}
diff --git a/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java b/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java
index 9023991..5ca9b52 100644
--- a/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java
+++ b/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java
@@ -9,6 +9,7 @@
 import android.util.Log;
 import android.util.Pair;
 
+import org.apache.http.Header;
 import org.apache.http.HttpException;
 import org.apache.http.HttpRequest;
 import org.apache.http.HttpResponse;
@@ -33,8 +34,10 @@
 import java.net.URI;
 import java.net.URL;
 import java.net.URLConnection;
+import java.nio.charset.Charset;
 import java.security.KeyManagementException;
 import java.security.KeyStore;
+import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 import java.security.cert.X509Certificate;
 import java.util.ArrayList;
@@ -80,13 +83,15 @@
         final Runnable mResponseAction;
         final boolean mIsNotFound;
         final boolean mIsNoContent;
+        final boolean mForWebSocket;
 
         Response(byte[] responseData, List<Pair<String, String>> responseHeaders,
-                boolean isRedirect, boolean isNotFound, boolean isNoContent,
+                boolean isRedirect, boolean isNotFound, boolean isNoContent, boolean forWebSocket,
                 Runnable responseAction) {
             mIsRedirect = isRedirect;
             mIsNotFound = isNotFound;
             mIsNoContent = isNoContent;
+            mForWebSocket = forWebSocket;
             mResponseData = responseData;
             mResponseHeaders = responseHeaders == null
                     ? new ArrayList<Pair<String, String>>() : responseHeaders;
@@ -195,6 +200,7 @@
     private static final int RESPONSE_STATUS_MOVED_TEMPORARILY = 1;
     private static final int RESPONSE_STATUS_NOT_FOUND = 2;
     private static final int RESPONSE_STATUS_NO_CONTENT = 3;
+    private static final int RESPONSE_STATUS_FOR_WEBSOCKET = 4;
 
     private String setResponseInternal(
             String requestPath, byte[] responseData,
@@ -203,11 +209,12 @@
         final boolean isRedirect = (status == RESPONSE_STATUS_MOVED_TEMPORARILY);
         final boolean isNotFound = (status == RESPONSE_STATUS_NOT_FOUND);
         final boolean isNoContent = (status == RESPONSE_STATUS_NO_CONTENT);
+        final boolean forWebSocket = (status == RESPONSE_STATUS_FOR_WEBSOCKET);
 
         synchronized (mLock) {
-            mResponseMap.put(requestPath, new Response(
-                    responseData, responseHeaders, isRedirect, isNotFound, isNoContent,
-                    responseAction));
+            mResponseMap.put(
+                    requestPath, new Response(responseData, responseHeaders, isRedirect, isNotFound,
+                                         isNoContent, forWebSocket, responseAction));
             mResponseCountMap.put(requestPath, Integer.valueOf(0));
             mLastRequestMap.put(requestPath, null);
         }
@@ -344,6 +351,28 @@
     }
 
     /**
+     * Sets a response to a WebSocket handshake request.
+     *
+     * @param requestPath The path to respond to.
+     * @param responseHeaders Any additional headers that should be returned along with the
+     *                        response (null is acceptable).
+     * @return The full URL including the path that should be requested to get the expected
+     *         response.
+     */
+    public String setResponseForWebSocket(
+            String requestPath, List<Pair<String, String>> responseHeaders) {
+        // Work on a private copy so the caller's list never sees the upgrade headers.
+        List<Pair<String, String>> headers = (responseHeaders == null)
+                ? new ArrayList<Pair<String, String>>()
+                : new ArrayList<Pair<String, String>>(responseHeaders);
+        headers.add(Pair.create("Connection", "Upgrade"));
+        headers.add(Pair.create("Upgrade", "websocket"));
+        final byte[] emptyBody = "".getBytes();
+        return setResponseInternal(
+                requestPath, emptyBody, headers, null, RESPONSE_STATUS_FOR_WEBSOCKET);
+    }
+
+    /**
      * Get the number of requests was made at this path since it was last set.
      */
     public int getRequestCount(String requestPath) {
@@ -481,6 +510,23 @@
                 httpResponse.addHeader(header.first, header.second);
             }
             servedResponseFor(path, request);
+        } else if (response.mForWebSocket) {
+            Header[] keys = request.getHeaders("Sec-WebSocket-Key");
+            try {
+                if (keys.length == 1) {
+                    final String key = keys[0].getValue();
+                    httpResponse = createResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
+                    for (Pair<String, String> header : response.mResponseHeaders) {
+                        httpResponse.addHeader(header.first, header.second);
+                    }
+                    httpResponse.addHeader("Sec-WebSocket-Accept", computeWebSocketAccept(key));
+                } else {
+                    httpResponse = createResponse(HttpStatus.SC_NOT_FOUND);
+                }
+            } catch (NoSuchAlgorithmException e) {
+                httpResponse = createResponse(HttpStatus.SC_NOT_FOUND);
+            }
+            servedResponseFor(path, request);
         } else {
             if (response.mResponseAction != null) response.mResponseAction.run();
 
@@ -552,6 +598,20 @@
         return entity;
     }
 
+    /**
+     * Computes the Sec-WebSocket-Accept value for {@code keyString}: base64(SHA-1(key + GUID)).
+     * @throws NoSuchAlgorithmException if no SHA-1 MessageDigest provider is available.
+     */
+    private static String computeWebSocketAccept(String keyString) throws NoSuchAlgorithmException {
+        byte[] key = keyString.getBytes(Charset.forName("US-ASCII"));
+        byte[] guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(Charset.forName("US-ASCII"));
+        MessageDigest md = MessageDigest.getInstance("SHA-1");
+        md.update(key);
+        byte[] output = md.digest(guid);
+        // NO_WRAP: Base64.DEFAULT appends a '\n', which would corrupt the header value.
+        return Base64.encodeToString(output, Base64.NO_WRAP);
+    }
+
     private static class ServerThread extends Thread {
         private TestWebServer mServer;
         private ServerSocket mSocket;
diff --git a/net/websockets/websocket_channel.cc b/net/websockets/websocket_channel.cc
index 6cbd8b59..f5d639f 100644
--- a/net/websockets/websocket_channel.cc
+++ b/net/websockets/websocket_channel.cc
@@ -175,6 +175,10 @@
  public:
   explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}
 
+  void OnCreateRequest(URLRequest* url_request) override {
+    creator_->OnCreateURLRequest(url_request);  // Forward to the channel.
+  }
+
   void OnSuccess(std::unique_ptr<WebSocketStream> stream) override {
     creator_->OnConnectSuccess(std::move(stream));
     // |this| may have been deleted.
@@ -603,6 +607,10 @@
   SetState(CONNECTING);
 }
 
+void WebSocketChannel::OnCreateURLRequest(URLRequest* url_request) {
+  event_interface_->OnCreateURLRequest(url_request);  // Handshake request.
+}
+
 void WebSocketChannel::OnConnectSuccess(
     std::unique_ptr<WebSocketStream> stream) {
   DCHECK(stream);
diff --git a/net/websockets/websocket_channel.h b/net/websockets/websocket_channel.h
index 43089791..2f68b14 100644
--- a/net/websockets/websocket_channel.h
+++ b/net/websockets/websocket_channel.h
@@ -33,6 +33,7 @@
 
 class NetLogWithSource;
 class IOBuffer;
+class URLRequest;
 class URLRequestContext;
 struct WebSocketHandshakeRequestInfo;
 struct WebSocketHandshakeResponseInfo;
@@ -226,6 +227,9 @@
       const std::string& additional_headers,
       const WebSocketStreamRequestCreationCallback& callback);
 
+  // Called when a URLRequest is created for handshaking.
+  void OnCreateURLRequest(URLRequest* request);
+
   // Success callback from WebSocketStream::CreateAndConnectStream(). Reports
   // success to the event interface. May delete |this|.
   void OnConnectSuccess(std::unique_ptr<WebSocketStream> stream);
diff --git a/net/websockets/websocket_channel_test.cc b/net/websockets/websocket_channel_test.cc
index 10ea232e..d2c477a 100644
--- a/net/websockets/websocket_channel_test.cc
+++ b/net/websockets/websocket_channel_test.cc
@@ -167,6 +167,7 @@
                              std::vector<char>(data, data + buffer_size));
   }
 
+  MOCK_METHOD1(OnCreateURLRequest, void(URLRequest*));
   MOCK_METHOD2(OnAddChannelResponse,
                ChannelState(const std::string&,
                             const std::string&));  // NOLINT
@@ -211,6 +212,7 @@
 // This fake EventInterface is for tests which need a WebSocketEventInterface
 // implementation but are not verifying how it is used.
 class FakeWebSocketEventInterface : public WebSocketEventInterface {
+  void OnCreateURLRequest(URLRequest* /* request */) override {}
   ChannelState OnAddChannelResponse(const std::string& selected_protocol,
                                     const std::string& extensions) override {
     return CHANNEL_ALIVE;
diff --git a/net/websockets/websocket_end_to_end_test.cc b/net/websockets/websocket_end_to_end_test.cc
index 31b557a..ba2a657 100644
--- a/net/websockets/websocket_end_to_end_test.cc
+++ b/net/websockets/websocket_end_to_end_test.cc
@@ -37,6 +37,8 @@
 
 namespace net {
 
+class URLRequest;
+
 namespace {
 
 static const char kEchoServer[] = "echo-with-no-extension";
@@ -66,6 +68,8 @@
   std::string extensions() const;
 
   // Implementation of WebSocketEventInterface.
+  void OnCreateURLRequest(URLRequest* /* request */) override {}
+
   ChannelState OnAddChannelResponse(const std::string& selected_subprotocol,
                                     const std::string& extensions) override;
 
diff --git a/net/websockets/websocket_event_interface.h b/net/websockets/websocket_event_interface.h
index ac29659..db24207f 100644
--- a/net/websockets/websocket_event_interface.h
+++ b/net/websockets/websocket_event_interface.h
@@ -22,6 +22,7 @@
 
 class IOBuffer;
 class SSLInfo;
+class URLRequest;
 struct WebSocketHandshakeRequestInfo;
 struct WebSocketHandshakeResponseInfo;
 
@@ -41,6 +42,9 @@
 
   virtual ~WebSocketEventInterface() {}
 
+  // Called when a URLRequest is created for handshaking.
+  virtual void OnCreateURLRequest(URLRequest* request) = 0;
+
   // Called in response to an AddChannelRequest. This means that a response has
   // been received from the remote server.
   virtual ChannelState OnAddChannelResponse(
diff --git a/net/websockets/websocket_handshake_stream_create_helper_test.cc b/net/websockets/websocket_handshake_stream_create_helper_test.cc
index 6b43527ff..f4d7c26 100644
--- a/net/websockets/websocket_handshake_stream_create_helper_test.cc
+++ b/net/websockets/websocket_handshake_stream_create_helper_test.cc
@@ -65,6 +65,7 @@
  public:
   ~TestConnectDelegate() override {}
 
+  void OnCreateRequest(URLRequest* /* request */) override {}
   void OnSuccess(std::unique_ptr<WebSocketStream> stream) override {}
   void OnFailure(const std::string& failure_message) override {}
   void OnStartOpeningHandshake(
diff --git a/net/websockets/websocket_stream.cc b/net/websockets/websocket_stream.cc
index ef62c24..8cca365 100644
--- a/net/websockets/websocket_stream.cc
+++ b/net/websockets/websocket_stream.cc
@@ -116,6 +116,7 @@
         WebSocketHandshakeStreamBase::CreateHelper::DataKey(),
         handshake_stream_create_helper_);
     url_request_->SetLoadFlags(LOAD_DISABLE_CACHE | LOAD_BYPASS_CACHE);
+    connect_delegate_->OnCreateRequest(url_request_.get());
   }
 
   // Destroying this object destroys the URLRequest, which cancels the request
diff --git a/net/websockets/websocket_stream.h b/net/websockets/websocket_stream.h
index 389e092..e960f3d 100644
--- a/net/websockets/websocket_stream.h
+++ b/net/websockets/websocket_stream.h
@@ -32,6 +32,7 @@
 namespace net {
 
 class NetLogWithSource;
+class URLRequest;
 class URLRequestContext;
 struct WebSocketFrame;
 class WebSocketHandshakeStreamBase;
@@ -71,6 +72,9 @@
   class NET_EXPORT_PRIVATE ConnectDelegate {
    public:
     virtual ~ConnectDelegate();
+    // Called when the URLRequest is created.
+    virtual void OnCreateRequest(URLRequest* url_request) = 0;
+
     // Called on successful connection. The parameter is an object derived from
     // WebSocketStream.
     virtual void OnSuccess(std::unique_ptr<WebSocketStream> stream) = 0;
diff --git a/net/websockets/websocket_stream_create_test_base.cc b/net/websockets/websocket_stream_create_test_base.cc
index 6ca5d70..d3b1668 100644
--- a/net/websockets/websocket_stream_create_test_base.cc
+++ b/net/websockets/websocket_stream_create_test_base.cc
@@ -50,6 +50,10 @@
                       const base::Closure& done_callback)
       : owner_(owner), done_callback_(done_callback) {}
 
+  void OnCreateRequest(URLRequest* url_request) override {
+    owner_->url_request_ = url_request;  // Recorded for test assertions.
+  }
+
   void OnSuccess(std::unique_ptr<WebSocketStream> stream) override {
     stream.swap(owner_->stream_);
     done_callback_.Run();
@@ -92,8 +96,7 @@
 };
 
 WebSocketStreamCreateTestBase::WebSocketStreamCreateTestBase()
-    : has_failed_(false), ssl_fatal_(false) {
-}
+    : has_failed_(false), ssl_fatal_(false), url_request_(nullptr) {}
 
 WebSocketStreamCreateTestBase::~WebSocketStreamCreateTestBase() {
 }
diff --git a/net/websockets/websocket_stream_create_test_base.h b/net/websockets/websocket_stream_create_test_base.h
index be7912f..5805b50 100644
--- a/net/websockets/websocket_stream_create_test_base.h
+++ b/net/websockets/websocket_stream_create_test_base.h
@@ -24,6 +24,7 @@
 
 class HttpRequestHeaders;
 class HttpResponseHeaders;
+class URLRequest;
 class WebSocketStream;
 class WebSocketStreamRequest;
 struct WebSocketHandshakeRequestInfo;
@@ -75,6 +76,7 @@
   SSLInfo ssl_info_;
   bool ssl_fatal_;
   std::vector<std::unique_ptr<SSLSocketDataProvider>> ssl_data_;
+  URLRequest* url_request_;
 
   // This temporarily sets WebSocketEndpointLockManager unlock delay to zero
   // during tests.
diff --git a/net/websockets/websocket_stream_test.cc b/net/websockets/websocket_stream_test.cc
index fcafedc..091c28d 100644
--- a/net/websockets/websocket_stream_test.cc
+++ b/net/websockets/websocket_stream_test.cc
@@ -325,11 +325,13 @@
 
 // Confirm that the basic case works as expected.
 TEST_F(WebSocketStreamCreateTest, SimpleSuccess) {
+  EXPECT_FALSE(url_request_);
   CreateAndConnectStandard("ws://localhost/", "localhost", "/",
                            NoSubProtocols(), LocalhostOrigin(), LocalhostUrl(),
                            "", "", "");
   EXPECT_FALSE(request_info_);
   EXPECT_FALSE(response_info_);
+  EXPECT_TRUE(url_request_);
   WaitUntilConnectDone();
   EXPECT_FALSE(has_failed());
   EXPECT_TRUE(stream_);