Provide child/frame IDs for WebSocket handshake request
AndroidCookiePolicy needs the child ID and the frame ID of a WebSocket
connection to determine if it allows the connection to attach
third-party cookies. This CL provides the additional information to the
WebSocket handshake net::URLRequest.
BUG=634311
Review-Url: https://codereview.chromium.org/2397393002
Cr-Commit-Position: refs/heads/master@{#427109}
diff --git a/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java b/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java
index 9023991..5ca9b52 100644
--- a/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java
+++ b/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java
@@ -9,6 +9,7 @@
import android.util.Log;
import android.util.Pair;
+import org.apache.http.Header;
import org.apache.http.HttpException;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
@@ -33,8 +34,10 @@
import java.net.URI;
import java.net.URL;
import java.net.URLConnection;
+import java.nio.charset.Charset;
import java.security.KeyManagementException;
import java.security.KeyStore;
+import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
@@ -80,13 +83,15 @@
final Runnable mResponseAction;
final boolean mIsNotFound;
final boolean mIsNoContent;
+ final boolean mForWebSocket;
Response(byte[] responseData, List<Pair<String, String>> responseHeaders,
- boolean isRedirect, boolean isNotFound, boolean isNoContent,
+ boolean isRedirect, boolean isNotFound, boolean isNoContent, boolean forWebSocket,
Runnable responseAction) {
mIsRedirect = isRedirect;
mIsNotFound = isNotFound;
mIsNoContent = isNoContent;
+ mForWebSocket = forWebSocket;
mResponseData = responseData;
mResponseHeaders = responseHeaders == null
? new ArrayList<Pair<String, String>>() : responseHeaders;
@@ -195,6 +200,7 @@
private static final int RESPONSE_STATUS_MOVED_TEMPORARILY = 1;
private static final int RESPONSE_STATUS_NOT_FOUND = 2;
private static final int RESPONSE_STATUS_NO_CONTENT = 3;
+ private static final int RESPONSE_STATUS_FOR_WEBSOCKET = 4;
private String setResponseInternal(
String requestPath, byte[] responseData,
@@ -203,11 +209,12 @@
final boolean isRedirect = (status == RESPONSE_STATUS_MOVED_TEMPORARILY);
final boolean isNotFound = (status == RESPONSE_STATUS_NOT_FOUND);
final boolean isNoContent = (status == RESPONSE_STATUS_NO_CONTENT);
+ final boolean forWebSocket = (status == RESPONSE_STATUS_FOR_WEBSOCKET);
synchronized (mLock) {
- mResponseMap.put(requestPath, new Response(
- responseData, responseHeaders, isRedirect, isNotFound, isNoContent,
- responseAction));
+ mResponseMap.put(
+ requestPath, new Response(responseData, responseHeaders, isRedirect, isNotFound,
+ isNoContent, forWebSocket, responseAction));
mResponseCountMap.put(requestPath, Integer.valueOf(0));
mLastRequestMap.put(requestPath, null);
}
@@ -344,6 +351,28 @@
}
/**
+     * Registers the response sent for a WebSocket opening handshake on {@code requestPath}.
+     *
+     * <p>The headers "Connection: Upgrade" and "Upgrade: websocket" are appended to a copy of
+     * {@code responseHeaders}; the caller's list is left untouched.
+     *
+     * @param requestPath The path to respond to.
+     * @param responseHeaders Extra headers to send with the handshake response (may be null).
+     * @return The full URL including the path that should be requested to get this response.
+     */
+    public String setResponseForWebSocket(
+            String requestPath, List<Pair<String, String>> responseHeaders) {
+        List<Pair<String, String>> headers = new ArrayList<Pair<String, String>>();
+        if (responseHeaders != null) {
+            headers.addAll(responseHeaders);
+        }
+        headers.add(Pair.create("Connection", "Upgrade"));
+        headers.add(Pair.create("Upgrade", "websocket"));
+        return setResponseInternal(
+                requestPath, "".getBytes(), headers, null, RESPONSE_STATUS_FOR_WEBSOCKET);
+    }
+
+ /**
* Get the number of requests was made at this path since it was last set.
*/
public int getRequestCount(String requestPath) {
@@ -481,6 +510,23 @@
httpResponse.addHeader(header.first, header.second);
}
servedResponseFor(path, request);
+ } else if (response.mForWebSocket) {
+ Header[] keys = request.getHeaders("Sec-WebSocket-Key");
+ try {
+ if (keys.length == 1) {
+ final String key = keys[0].getValue();
+ httpResponse = createResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
+ for (Pair<String, String> header : response.mResponseHeaders) {
+ httpResponse.addHeader(header.first, header.second);
+ }
+ httpResponse.addHeader("Sec-WebSocket-Accept", computeWebSocketAccept(key));
+ } else {
+ httpResponse = createResponse(HttpStatus.SC_NOT_FOUND);
+ }
+ } catch (NoSuchAlgorithmException e) {
+ httpResponse = createResponse(HttpStatus.SC_NOT_FOUND);
+ }
+ servedResponseFor(path, request);
} else {
if (response.mResponseAction != null) response.mResponseAction.run();
@@ -552,6 +598,20 @@
return entity;
}
+    /**
+     * Computes the Sec-WebSocket-Accept header value for a handshake key (RFC 6455, 4.2.2).
+     * @param keyString Value of the client's Sec-WebSocket-Key request header.
+     * @return base64(SHA-1(keyString + GUID)), without any line terminator.
+     * @throws NoSuchAlgorithmException if no SHA-1 implementation is available.
+     */
+    private static String computeWebSocketAccept(String keyString) throws NoSuchAlgorithmException {
+        String challenge = keyString + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+        MessageDigest md = MessageDigest.getInstance("SHA-1");
+        byte[] output = md.digest(challenge.getBytes(Charset.forName("US-ASCII")));
+        // NO_WRAP: Base64.DEFAULT appends '\n', which would corrupt the response header.
+        return Base64.encodeToString(output, Base64.NO_WRAP);
+    }
+
private static class ServerThread extends Thread {
private TestWebServer mServer;
private ServerSocket mSocket;
diff --git a/net/websockets/websocket_channel.cc b/net/websockets/websocket_channel.cc
index 6cbd8b59..f5d639f 100644
--- a/net/websockets/websocket_channel.cc
+++ b/net/websockets/websocket_channel.cc
@@ -175,6 +175,10 @@
public:
explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}
+  // Forwards the handshake URLRequest to the channel that created |this|.
+  void OnCreateRequest(URLRequest* request) override {
+    creator_->OnCreateURLRequest(request);
+  }
+
void OnSuccess(std::unique_ptr<WebSocketStream> stream) override {
creator_->OnConnectSuccess(std::move(stream));
// |this| may have been deleted.
@@ -603,6 +607,10 @@
SetState(CONNECTING);
}
+// Invoked by ConnectDelegate::OnCreateRequest(); hands the handshake request
+// to |event_interface_| unchanged.
+void WebSocketChannel::OnCreateURLRequest(URLRequest* request) {
+ event_interface_->OnCreateURLRequest(request);
+}
+
void WebSocketChannel::OnConnectSuccess(
std::unique_ptr<WebSocketStream> stream) {
DCHECK(stream);
diff --git a/net/websockets/websocket_channel.h b/net/websockets/websocket_channel.h
index 43089791..2f68b14 100644
--- a/net/websockets/websocket_channel.h
+++ b/net/websockets/websocket_channel.h
@@ -33,6 +33,7 @@
class NetLogWithSource;
class IOBuffer;
+class URLRequest;
class URLRequestContext;
struct WebSocketHandshakeRequestInfo;
struct WebSocketHandshakeResponseInfo;
@@ -226,6 +227,9 @@
const std::string& additional_headers,
const WebSocketStreamRequestCreationCallback& callback);
+ // Called when a URLRequest is created for handshaking.
+ void OnCreateURLRequest(URLRequest* request);
+
// Success callback from WebSocketStream::CreateAndConnectStream(). Reports
// success to the event interface. May delete |this|.
void OnConnectSuccess(std::unique_ptr<WebSocketStream> stream);
diff --git a/net/websockets/websocket_channel_test.cc b/net/websockets/websocket_channel_test.cc
index 10ea232e..d2c477a 100644
--- a/net/websockets/websocket_channel_test.cc
+++ b/net/websockets/websocket_channel_test.cc
@@ -167,6 +167,7 @@
std::vector<char>(data, data + buffer_size));
}
+ MOCK_METHOD1(OnCreateURLRequest, void(URLRequest*));
MOCK_METHOD2(OnAddChannelResponse,
ChannelState(const std::string&,
const std::string&)); // NOLINT
@@ -211,6 +212,7 @@
// This fake EventInterface is for tests which need a WebSocketEventInterface
// implementation but are not verifying how it is used.
class FakeWebSocketEventInterface : public WebSocketEventInterface {
+ void OnCreateURLRequest(URLRequest* request) override {}  // No-op.
ChannelState OnAddChannelResponse(const std::string& selected_protocol,
const std::string& extensions) override {
return CHANNEL_ALIVE;
diff --git a/net/websockets/websocket_end_to_end_test.cc b/net/websockets/websocket_end_to_end_test.cc
index 31b557a..ba2a657 100644
--- a/net/websockets/websocket_end_to_end_test.cc
+++ b/net/websockets/websocket_end_to_end_test.cc
@@ -37,6 +37,8 @@
namespace net {
+class URLRequest;
+
namespace {
static const char kEchoServer[] = "echo-with-no-extension";
@@ -66,6 +68,8 @@
std::string extensions() const;
// Implementation of WebSocketEventInterface.
+ void OnCreateURLRequest(URLRequest* request) override {}  // No-op.
+
ChannelState OnAddChannelResponse(const std::string& selected_subprotocol,
const std::string& extensions) override;
diff --git a/net/websockets/websocket_event_interface.h b/net/websockets/websocket_event_interface.h
index ac29659..db24207f 100644
--- a/net/websockets/websocket_event_interface.h
+++ b/net/websockets/websocket_event_interface.h
@@ -22,6 +22,7 @@
class IOBuffer;
class SSLInfo;
+class URLRequest;
struct WebSocketHandshakeRequestInfo;
struct WebSocketHandshakeResponseInfo;
@@ -41,6 +42,9 @@
virtual ~WebSocketEventInterface() {}
+ // Called when a URLRequest is created for handshaking.
+ virtual void OnCreateURLRequest(URLRequest* request) = 0;
+
// Called in response to an AddChannelRequest. This means that a response has
// been received from the remote server.
virtual ChannelState OnAddChannelResponse(
diff --git a/net/websockets/websocket_handshake_stream_create_helper_test.cc b/net/websockets/websocket_handshake_stream_create_helper_test.cc
index 6b43527ff..f4d7c26 100644
--- a/net/websockets/websocket_handshake_stream_create_helper_test.cc
+++ b/net/websockets/websocket_handshake_stream_create_helper_test.cc
@@ -65,6 +65,7 @@
public:
~TestConnectDelegate() override {}
+ void OnCreateRequest(URLRequest* request) override {}  // No-op.
void OnSuccess(std::unique_ptr<WebSocketStream> stream) override {}
void OnFailure(const std::string& failure_message) override {}
void OnStartOpeningHandshake(
diff --git a/net/websockets/websocket_stream.cc b/net/websockets/websocket_stream.cc
index ef62c24..8cca365 100644
--- a/net/websockets/websocket_stream.cc
+++ b/net/websockets/websocket_stream.cc
@@ -116,6 +116,7 @@
WebSocketHandshakeStreamBase::CreateHelper::DataKey(),
handshake_stream_create_helper_);
url_request_->SetLoadFlags(LOAD_DISABLE_CACHE | LOAD_BYPASS_CACHE);
+ connect_delegate_->OnCreateRequest(url_request_.get());
}
// Destroying this object destroys the URLRequest, which cancels the request
diff --git a/net/websockets/websocket_stream.h b/net/websockets/websocket_stream.h
index 389e092..e960f3d 100644
--- a/net/websockets/websocket_stream.h
+++ b/net/websockets/websocket_stream.h
@@ -32,6 +32,7 @@
namespace net {
class NetLogWithSource;
+class URLRequest;
class URLRequestContext;
struct WebSocketFrame;
class WebSocketHandshakeStreamBase;
@@ -71,6 +72,9 @@
class NET_EXPORT_PRIVATE ConnectDelegate {
public:
virtual ~ConnectDelegate();
+ // Called when the URLRequest is created.
+ virtual void OnCreateRequest(URLRequest* url_request) = 0;
+
// Called on successful connection. The parameter is an object derived from
// WebSocketStream.
virtual void OnSuccess(std::unique_ptr<WebSocketStream> stream) = 0;
diff --git a/net/websockets/websocket_stream_create_test_base.cc b/net/websockets/websocket_stream_create_test_base.cc
index 6ca5d70..d3b1668 100644
--- a/net/websockets/websocket_stream_create_test_base.cc
+++ b/net/websockets/websocket_stream_create_test_base.cc
@@ -50,6 +50,10 @@
const base::Closure& done_callback)
: owner_(owner), done_callback_(done_callback) {}
+ void OnCreateRequest(URLRequest* request) override {
+ owner_->url_request_ = request;  // Non-owning pointer.
+ }
+
void OnSuccess(std::unique_ptr<WebSocketStream> stream) override {
stream.swap(owner_->stream_);
done_callback_.Run();
@@ -92,8 +96,7 @@
};
+// |url_request_| is set later by the connect delegate's OnCreateRequest().
WebSocketStreamCreateTestBase::WebSocketStreamCreateTestBase()
- : has_failed_(false), ssl_fatal_(false) {
-}
+ : has_failed_(false), ssl_fatal_(false), url_request_(nullptr) {}
WebSocketStreamCreateTestBase::~WebSocketStreamCreateTestBase() {
}
diff --git a/net/websockets/websocket_stream_create_test_base.h b/net/websockets/websocket_stream_create_test_base.h
index be7912f..5805b50 100644
--- a/net/websockets/websocket_stream_create_test_base.h
+++ b/net/websockets/websocket_stream_create_test_base.h
@@ -24,6 +24,7 @@
class HttpRequestHeaders;
class HttpResponseHeaders;
+class URLRequest;
class WebSocketStream;
class WebSocketStreamRequest;
struct WebSocketHandshakeRequestInfo;
@@ -75,6 +76,7 @@
SSLInfo ssl_info_;
bool ssl_fatal_;
std::vector<std::unique_ptr<SSLSocketDataProvider>> ssl_data_;
+ URLRequest* url_request_;
// This temporarily sets WebSocketEndpointLockManager unlock delay to zero
// during tests.
diff --git a/net/websockets/websocket_stream_test.cc b/net/websockets/websocket_stream_test.cc
index fcafedc..091c28d 100644
--- a/net/websockets/websocket_stream_test.cc
+++ b/net/websockets/websocket_stream_test.cc
@@ -325,11 +325,13 @@
// Confirm that the basic case works as expected.
TEST_F(WebSocketStreamCreateTest, SimpleSuccess) {
+ EXPECT_FALSE(url_request_);
CreateAndConnectStandard("ws://localhost/", "localhost", "/",
NoSubProtocols(), LocalhostOrigin(), LocalhostUrl(),
"", "", "");
EXPECT_FALSE(request_info_);
EXPECT_FALSE(response_info_);
+ EXPECT_TRUE(url_request_);
WaitUntilConnectDone();
EXPECT_FALSE(has_failed());
EXPECT_TRUE(stream_);