Fix WebSocketServer extension parser.
This CL makes the WebSocket server in net/server use the net/websockets
parser for parsing Sec-WebSocket-Extensions in the extension negotiation.
The new implementation validates the extension negotiation offer more
strictly than before. Specifically,
- A malformed Sec-WebSocket-Extensions header value now causes the
  connection to fail.
  - Previously it was just ignored.
- Malformed permessage-deflate parameters are declined.
  - Previously such an offer was partially accepted: invalid parameter
    values were ignored and the remaining parameters were applied.
BUG=523228
Review URL: https://codereview.chromium.org/1340523002
Cr-Commit-Position: refs/heads/master@{#351040}
diff --git a/net/net.gypi b/net/net.gypi
index eca3a91f..433cea8 100644
--- a/net/net.gypi
+++ b/net/net.gypi
@@ -1734,6 +1734,7 @@
'websockets/websocket_end_to_end_test.cc',
'websockets/websocket_errors_test.cc',
'websockets/websocket_extension_parser_test.cc',
+ 'websockets/websocket_extension_test.cc',
'websockets/websocket_frame_parser_test.cc',
'websockets/websocket_frame_test.cc',
'websockets/websocket_handshake_challenge_test.cc',
diff --git a/net/server/http_server.cc b/net/server/http_server.cc
index 3abd44df..f3560e8 100644
--- a/net/server/http_server.cc
+++ b/net/server/http_server.cc
@@ -236,11 +236,8 @@
connection->socket()->GetPeerAddress(&request.peer);
if (request.HasHeaderValue("connection", "upgrade")) {
- scoped_ptr<WebSocket> websocket =
- WebSocket::CreateWebSocket(this, connection, request);
- if (!websocket) // Not enough data was received.
- break;
- connection->SetWebSocket(websocket.Pass());
+ connection->SetWebSocket(
+ make_scoped_ptr(new WebSocket(this, connection)));
read_buf->DidConsume(pos);
delegate_->OnWebSocketRequest(connection->id(), request);
if (HasClosedConnection(connection))
diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc
index c9637450..79ffcec 100644
--- a/net/server/web_socket.cc
+++ b/net/server/web_socket.cc
@@ -4,6 +4,8 @@
#include "net/server/web_socket.h"
+#include <vector>
+
#include "base/base64.h"
#include "base/logging.h"
#include "base/sha1.h"
@@ -15,70 +17,87 @@
#include "net/server/http_server_request_info.h"
#include "net/server/http_server_response_info.h"
#include "net/server/web_socket_encoder.h"
+#include "net/websockets/websocket_deflate_parameters.h"
+#include "net/websockets/websocket_extension.h"
+#include "net/websockets/websocket_handshake_constants.h"
namespace net {
-WebSocket::WebSocket(HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request)
- : server_(server), connection_(connection), closed_(false) {
- std::string request_extensions =
- request.GetHeaderValue("sec-websocket-extensions");
- encoder_.reset(WebSocketEncoder::CreateServer(request_extensions,
- &response_extensions_));
- if (!response_extensions_.empty()) {
- response_extensions_ =
- "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n";
- }
+namespace {
+
+std::string ExtensionsHeaderString(
+ const std::vector<WebSocketExtension>& extensions) {
+ if (extensions.empty())
+ return std::string();
+
+ std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString();
+ for (size_t i = 1; i < extensions.size(); ++i)
+ result += ", " + extensions[i].ToString();
+ return result + "\r\n";
}
+// Builds the HTTP 101 response that completes the server side of the opening
+// handshake. |accept_hash| is emitted as the Sec-WebSocket-Accept value, and
+// |extensions| are the negotiated extensions echoed back to the client; an
+// empty vector omits the Sec-WebSocket-Extensions header entirely.
+std::string ValidResponseString(
+ const std::string& accept_hash,
+ const std::vector<WebSocketExtension>& extensions) {
+ return base::StringPrintf(
+ "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: %s\r\n"
+ "%s"
+ "\r\n",
+ accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str());
+}
+
+} // namespace
+
+WebSocket::WebSocket(HttpServer* server, HttpConnection* connection)
+ : server_(server), connection_(connection), closed_(false) {}
+
WebSocket::~WebSocket() {}
-scoped_ptr<WebSocket> WebSocket::CreateWebSocket(
- HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request) {
+// Validates the client's opening handshake in |request|. Depending on the
+// request, it does one of three things:
+// - sends a 500 error when the version is unsupported or the key is missing;
+// - closes the connection when Sec-WebSocket-Extensions is malformed;
+// - otherwise sends the 101 response, including any negotiated
+//   permessage-deflate parameters.
+void WebSocket::Accept(const HttpServerRequestInfo& request) {
std::string version = request.GetHeaderValue("sec-websocket-version");
if (version != "8" && version != "13") {
- server->SendResponse(
- connection->id(),
- HttpServerResponseInfo::CreateFor500(
- "Invalid request format. The version is not valid."));
- return nullptr;
+ SendErrorResponse("Invalid request format. The version is not valid.");
+ return;
}
std::string key = request.GetHeaderValue("sec-websocket-key");
if (key.empty()) {
- server->SendResponse(
- connection->id(),
- HttpServerResponseInfo::CreateFor500(
- "Invalid request format. Sec-WebSocket-Key is empty or isn't "
- "specified."));
- return nullptr;
+ SendErrorResponse(
+ "Invalid request format. Sec-WebSocket-Key is empty or isn't "
+ "specified.");
+ return;
}
- return make_scoped_ptr(new WebSocket(server, connection, request));
-}
-
-void WebSocket::Accept(const HttpServerRequestInfo& request) {
- static const char* const kWebSocketGuid =
- "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
- std::string key = request.GetHeaderValue("sec-websocket-key");
- std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid);
std::string encoded_hash;
- base::Base64Encode(base::SHA1HashString(data), &encoded_hash);
+ base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid),
+ &encoded_hash);
- server_->SendRaw(
- connection_->id(),
- base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Sec-WebSocket-Accept: %s\r\n"
- "%s"
- "\r\n",
- encoded_hash.c_str(), response_extensions_.c_str()));
+ std::vector<WebSocketExtension> response_extensions;
+ auto i = request.headers.find("sec-websocket-extensions");
+ if (i == request.headers.end()) {
+ encoder_ = WebSocketEncoder::CreateServer();
+ } else {
+ WebSocketDeflateParameters params;
+ encoder_ = WebSocketEncoder::CreateServer(i->second, &params);
+ if (!encoder_) {
+ // The header value could not be parsed; the connection must be failed
+ // instead of silently ignoring the offer.
+ Fail();
+ return;
+ }
+ if (encoder_->deflate_enabled()) {
+ DCHECK(params.IsValidAsResponse());
+ response_extensions.push_back(params.AsExtension());
+ }
+ }
+ server_->SendRaw(connection_->id(),
+ ValidResponseString(encoded_hash, response_extensions));
}
WebSocket::ParseResult WebSocket::Read(std::string* message) {
+ if (closed_)
+ return FRAME_CLOSE;
+
HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf();
base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize());
int bytes_consumed = 0;
@@ -98,4 +117,17 @@
server_->SendRaw(connection_->id(), encoded);
}
+void WebSocket::Fail() {
+ closed_ = true;
+ // TODO(yhirano): The server SHOULD log the problem.
+ server_->Close(connection_->id());
+}
+
+void WebSocket::SendErrorResponse(const std::string& message) {
+ if (closed_)
+ return;
+ closed_ = true;
+ server_->Send500(connection_->id(), message);
+}
+
} // namespace net
diff --git a/net/server/web_socket.h b/net/server/web_socket.h
index d9509d8..5309544 100644
--- a/net/server/web_socket.h
+++ b/net/server/web_socket.h
@@ -27,10 +27,7 @@
FRAME_ERROR
};
- static scoped_ptr<WebSocket> CreateWebSocket(
- HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request);
+ WebSocket(HttpServer* server, HttpConnection* connection);
void Accept(const HttpServerRequestInfo& request);
ParseResult Read(std::string* message);
@@ -38,14 +35,12 @@
~WebSocket();
private:
- WebSocket(HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request);
+ void Fail();
+ void SendErrorResponse(const std::string& message);
HttpServer* const server_;
HttpConnection* const connection_;
scoped_ptr<WebSocketEncoder> encoder_;
- std::string response_extensions_;
bool closed_;
DISALLOW_COPY_AND_ASSIGN(WebSocket);
diff --git a/net/server/web_socket_encoder.cc b/net/server/web_socket_encoder.cc
index 1a5431af..b1b93ee 100644
--- a/net/server/web_socket_encoder.cc
+++ b/net/server/web_socket_encoder.cc
@@ -4,10 +4,14 @@
#include "net/server/web_socket_encoder.h"
+#include <vector>
+
#include "base/logging.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "net/base/io_buffer.h"
+#include "net/websockets/websocket_deflate_parameters.h"
+#include "net/websockets/websocket_extension.h"
#include "net/websockets/websocket_extension_parser.h"
namespace net {
@@ -180,151 +184,111 @@
} // anonymous namespace
// static
-WebSocketEncoder* WebSocketEncoder::CreateServer(
- const std::string& request_extensions,
- std::string* response_extensions) {
- bool deflate;
- bool has_client_window_bits;
- int client_window_bits;
- int server_window_bits;
- bool client_no_context_takeover;
- bool server_no_context_takeover;
- ParseExtensions(request_extensions, &deflate, &has_client_window_bits,
- &client_window_bits, &server_window_bits,
- &client_no_context_takeover, &server_no_context_takeover);
+scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
+ return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
+}
- if (deflate) {
- *response_extensions = base::StringPrintf(
- "permessage-deflate; server_max_window_bits=%d%s", server_window_bits,
- server_no_context_takeover ? "; server_no_context_takeover" : "");
- if (has_client_window_bits) {
- base::StringAppendF(response_extensions, "; client_max_window_bits=%d",
- client_window_bits);
- } else {
- DCHECK_EQ(client_window_bits, 15);
- }
- return new WebSocketEncoder(true /* is_server */, server_window_bits,
- client_window_bits, server_no_context_takeover);
- } else {
- *response_extensions = std::string();
- return new WebSocketEncoder(true /* is_server */);
+// static
+scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
+ const std::string& extensions,
+ WebSocketDeflateParameters* deflate_parameters) {
+ WebSocketExtensionParser parser;
+ if (!parser.Parse(extensions)) {
+ // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
+ // connection.
+ return nullptr;
}
+
+ for (const auto& extension : parser.extensions()) {
+ std::string failure_message;
+ WebSocketDeflateParameters offer;
+ if (!offer.Initialize(extension, &failure_message) ||
+ !offer.IsValidAsRequest(&failure_message)) {
+ // We decline unknown / malformed extensions.
+ continue;
+ }
+
+ WebSocketDeflateParameters response = offer;
+ if (offer.is_client_max_window_bits_specified() &&
+ !offer.has_client_max_window_bits_value()) {
+ // We need to choose one value for the response.
+ response.SetClientMaxWindowBits(15);
+ }
+ DCHECK(response.IsValidAsResponse());
+ DCHECK(offer.IsCompatibleWith(response));
+ auto deflater = make_scoped_ptr(
+ new WebSocketDeflater(response.server_context_take_over_mode()));
+ auto inflater = make_scoped_ptr(
+ new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize));
+ if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
+ !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
+ // For some reason we cannot accept the parameters.
+ continue;
+ }
+ *deflate_parameters = response;
+ return make_scoped_ptr(
+ new WebSocketEncoder(FOR_SERVER, deflater.Pass(), inflater.Pass()));
+ }
+
+ // We cannot find an acceptable offer.
+ return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
}
// static
WebSocketEncoder* WebSocketEncoder::CreateClient(
const std::string& response_extensions) {
- bool deflate;
- bool has_client_window_bits;
- int client_window_bits;
- int server_window_bits;
- bool client_no_context_takeover;
- bool server_no_context_takeover;
- ParseExtensions(response_extensions, &deflate, &has_client_window_bits,
- &client_window_bits, &server_window_bits,
- &client_no_context_takeover, &server_no_context_takeover);
-
- if (deflate) {
- return new WebSocketEncoder(false /* is_server */, client_window_bits,
- server_window_bits, client_no_context_takeover);
- } else {
- return new WebSocketEncoder(false /* is_server */);
- }
-}
-
-// static
-void WebSocketEncoder::ParseExtensions(const std::string& header_value,
- bool* deflate,
- bool* has_client_window_bits,
- int* client_window_bits,
- int* server_window_bits,
- bool* client_no_context_takeover,
- bool* server_no_context_takeover) {
- *deflate = false;
- *has_client_window_bits = false;
- *client_window_bits = 15;
- *server_window_bits = 15;
- *client_no_context_takeover = false;
- *server_no_context_takeover = false;
-
- if (header_value.empty())
- return;
+ // TODO(yhirano): Add a way to return an error.
WebSocketExtensionParser parser;
- if (!parser.Parse(header_value))
- return;
- const std::vector<WebSocketExtension>& extensions = parser.extensions();
- // TODO(tyoshino): Fail if this method is used for parsing a response and
- // there are multiple permessage-deflate extensions or there are any unknown
- // extensions.
- for (const auto& extension : extensions) {
- if (extension.name() != "permessage-deflate") {
- continue;
- }
-
- const std::vector<WebSocketExtension::Parameter>& parameters =
- extension.parameters();
- for (const auto& param : parameters) {
- const std::string& name = param.name();
- // TODO(tyoshino): Fail the connection when an invalid value is given.
- if (name == "client_max_window_bits") {
- *has_client_window_bits = true;
- if (param.HasValue()) {
- int bits = 0;
- if (base::StringToInt(param.value(), &bits) && bits >= 8 &&
- bits <= 15) {
- *client_window_bits = bits;
- }
- }
- }
- if (name == "server_max_window_bits" && param.HasValue()) {
- int bits = 0;
- if (base::StringToInt(param.value(), &bits) && bits >= 8 && bits <= 15)
- *server_window_bits = bits;
- }
- if (name == "client_no_context_takeover")
- *client_no_context_takeover = true;
- if (name == "server_no_context_takeover")
- *server_no_context_takeover = true;
- }
- *deflate = true;
-
- break;
+ if (!parser.Parse(response_extensions)) {
+ // Parse error. Note that there are two cases here.
+ // 1) There is no Sec-WebSocket-Extensions header.
+ // 2) There is a malformed Sec-WebSocket-Extensions header.
+ // We should return a deflate-disabled encoder for the former case and
+ // fail the connection for the latter case.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
}
-}
+ if (parser.extensions().size() != 1) {
+ // Only a single permessage-deflate extension is supported.
+ // TODO(yhirano): Fail the connection.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
+ }
+ const auto& extension = parser.extensions()[0];
+ WebSocketDeflateParameters params;
+ std::string failure_message;
+ if (!params.Initialize(extension, &failure_message) ||
+ !params.IsValidAsResponse(&failure_message)) {
+ // TODO(yhirano): Fail the connection.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
+ }
-WebSocketEncoder::WebSocketEncoder(bool is_server) : is_server_(is_server) {
-}
-
-WebSocketEncoder::WebSocketEncoder(bool is_server,
- int deflate_bits,
- int inflate_bits,
- bool no_context_takeover)
- : is_server_(is_server) {
- deflater_.reset(new WebSocketDeflater(
- no_context_takeover ? WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT
- : WebSocketDeflater::TAKE_OVER_CONTEXT));
- inflater_.reset(
+ auto deflater = make_scoped_ptr(
+ new WebSocketDeflater(params.client_context_take_over_mode()));
+ auto inflater = make_scoped_ptr(
new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize));
-
- if (!deflater_->Initialize(deflate_bits) ||
- !inflater_->Initialize(inflate_bits)) {
- // Disable deflate support.
- deflater_.reset();
- inflater_.reset();
+ if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
+ !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
+ // TODO(yhirano): Fail the connection.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
}
+
+ return new WebSocketEncoder(FOR_CLIENT, deflater.Pass(), inflater.Pass());
}
-WebSocketEncoder::~WebSocketEncoder() {
-}
+WebSocketEncoder::WebSocketEncoder(Type type,
+ scoped_ptr<WebSocketDeflater> deflater,
+ scoped_ptr<WebSocketInflater> inflater)
+ : type_(type), deflater_(deflater.Pass()), inflater_(inflater.Pass()) {}
+
+WebSocketEncoder::~WebSocketEncoder() {}
WebSocket::ParseResult WebSocketEncoder::DecodeFrame(
const base::StringPiece& frame,
int* bytes_consumed,
std::string* output) {
bool compressed;
- WebSocket::ParseResult result =
- DecodeFrameHybi17(frame, is_server_, bytes_consumed, output, &compressed);
+ WebSocket::ParseResult result = DecodeFrameHybi17(
+ frame, type_ == FOR_SERVER, bytes_consumed, output, &compressed);
if (result == WebSocket::FRAME_OK && compressed) {
if (!Inflate(output))
result = WebSocket::FRAME_ERROR;
diff --git a/net/server/web_socket_encoder.h b/net/server/web_socket_encoder.h
index 23f0d9c..1eb749f 100644
--- a/net/server/web_socket_encoder.h
+++ b/net/server/web_socket_encoder.h
@@ -16,61 +16,50 @@
namespace net {
-class WebSocketEncoder {
+class WebSocketDeflateParameters;
+
+class WebSocketEncoder final {
public:
+ static const char kClientExtensions[];
+
~WebSocketEncoder();
- static WebSocketEncoder* CreateServer(const std::string& request_extensions,
- std::string* response_extensions);
-
- static const char kClientExtensions[];
+ // Creates and returns an encoder for a server without extensions.
+ static scoped_ptr<WebSocketEncoder> CreateServer();
+ // Creates and returns an encoder.
+ // |extensions| is the value of a Sec-WebSocket-Extensions header.
+ // Returns nullptr when there is an error.
+ static scoped_ptr<WebSocketEncoder> CreateServer(
+ const std::string& extensions,
+ WebSocketDeflateParameters* params);
+ // TODO(yhirano): Return a scoped_ptr instead of a raw pointer.
static WebSocketEncoder* CreateClient(const std::string& response_extensions);
WebSocket::ParseResult DecodeFrame(const base::StringPiece& frame,
int* bytes_consumed,
std::string* output);
-
void EncodeFrame(const std::string& frame,
int masking_key,
std::string* output);
- private:
- explicit WebSocketEncoder(bool is_server);
- WebSocketEncoder(bool is_server,
- int deflate_bits,
- int inflate_bits,
- bool no_context_takeover);
+ bool deflate_enabled() const { return deflater_; }
- // Parses a value in the Sec-WebSocket-Extensions header. If it contains a
- // single element of the permessage-deflate extension, stores the result of
- // parsing the parameters of the extension into the given variables.
- // Otherwise, returns with *deflate set to false.
- //
- // - If the client_max_window_bits parameter is missing, *client_window_bits
- // defaults to 15.
- // - If the client_max_window_bits parameter has an invalid value,
- // *client_window_bits will be set to 0.
- // - If the server_max_window_bits parameter is missing, *server_window_bits
- // defaults to 15.
- // - If the server_max_window_bits parameter has an invalid value,
- // *client_window_bits will be set to 0.
- //
- // TODO(tyoshino): Consider using a struct than taking a lot of pointers for
- // output.
- static void ParseExtensions(const std::string& header_value,
- bool* deflate,
- bool* has_client_window_bits,
- int* client_window_bits,
- int* server_window_bits,
- bool* client_no_context_takeover,
- bool* server_no_context_takeover);
+ private:
+ enum Type {
+ FOR_SERVER,
+ FOR_CLIENT,
+ };
+
+ WebSocketEncoder(Type type,
+ scoped_ptr<WebSocketDeflater> deflater,
+ scoped_ptr<WebSocketInflater> inflater);
bool Inflate(std::string* message);
bool Deflate(const std::string& message, std::string* output);
+ Type type_;
scoped_ptr<WebSocketDeflater> deflater_;
scoped_ptr<WebSocketInflater> inflater_;
- bool is_server_;
DISALLOW_COPY_AND_ASSIGN(WebSocketEncoder);
};
diff --git a/net/server/web_socket_encoder_unittest.cc b/net/server/web_socket_encoder_unittest.cc
index 7bca8764..9991bd7 100644
--- a/net/server/web_socket_encoder_unittest.cc
+++ b/net/server/web_socket_encoder_unittest.cc
@@ -3,30 +3,76 @@
// found in the LICENSE file.
#include "net/server/web_socket_encoder.h"
+
+#include "net/websockets/websocket_deflate_parameters.h"
+#include "net/websockets/websocket_extension.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
+// An empty Sec-WebSocket-Extensions value does not parse, so CreateServer()
+// must report failure by returning nullptr.
+TEST(WebSocketEncoderHandshakeTest, EmptyRequestShouldBeRejected) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server =
+ WebSocketEncoder::CreateServer("", &params);
+
+ EXPECT_FALSE(server);
+}
+
TEST(WebSocketEncoderHandshakeTest,
CreateServerWithoutClientMaxWindowBitsParameter) {
- std::string response_extensions;
- scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer(
- "permessage-deflate", &response_extensions));
- // The response must not include client_max_window_bits if the client didn't
- // declare that it accepts the parameter.
- EXPECT_EQ("permessage-deflate; server_max_window_bits=15",
- response_extensions);
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server =
+ WebSocketEncoder::CreateServer("permessage-deflate", &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate", params.AsExtension().ToString());
}
TEST(WebSocketEncoderHandshakeTest,
CreateServerWithServerNoContextTakeoverParameter) {
- std::string response_extensions;
- scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer(
- "permessage-deflate; server_no_context_takeover", &response_extensions));
- EXPECT_EQ(
- "permessage-deflate; server_max_window_bits=15"
- "; server_no_context_takeover",
- response_extensions);
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer(
+ "permessage-deflate; server_no_context_takeover", &params);
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate; server_no_context_takeover",
+ params.AsExtension().ToString());
+}
+
+TEST(WebSocketEncoderHandshakeTest, FirstExtensionShouldBeChosen) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer(
+ "permessage-deflate; server_no_context_takeover,"
+ "permessage-deflate; server_max_window_bits=15",
+ &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate; server_no_context_takeover",
+ params.AsExtension().ToString());
+}
+
+TEST(WebSocketEncoderHandshakeTest, FirstValidExtensionShouldBeChosen) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer(
+ "permessage-deflate; Xserver_no_context_takeover,"
+ "permessage-deflate; server_max_window_bits=15",
+ &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate; server_max_window_bits=15",
+ params.AsExtension().ToString());
+}
+
+TEST(WebSocketEncoderHandshakeTest, AllExtensionsAreUnknownOrMalformed) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server =
+ WebSocketEncoder::CreateServer("unknown, permessage-deflate; x", &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_FALSE(server->deflate_enabled());
}
class WebSocketEncoderTest : public testing::Test {
@@ -35,7 +81,7 @@
void SetUp() override {
std::string response_extensions;
- server_.reset(WebSocketEncoder::CreateServer("", &response_extensions));
+ server_ = WebSocketEncoder::CreateServer();
EXPECT_EQ(std::string(), response_extensions);
client_.reset(WebSocketEncoder::CreateClient(""));
}
@@ -50,17 +96,29 @@
WebSocketEncoderCompressionTest() : WebSocketEncoderTest() {}
void SetUp() override {
- std::string response_extensions;
- server_.reset(WebSocketEncoder::CreateServer(
- "permessage-deflate; client_max_window_bits", &response_extensions));
- EXPECT_EQ(
- "permessage-deflate; server_max_window_bits=15; "
- "client_max_window_bits=15",
- response_extensions);
- client_.reset(WebSocketEncoder::CreateClient(response_extensions));
+ WebSocketDeflateParameters params;
+ server_ = WebSocketEncoder::CreateServer(
+ "permessage-deflate; client_max_window_bits", &params);
+ ASSERT_TRUE(server_);
+ EXPECT_TRUE(server_->deflate_enabled());
+ // A valueless client_max_window_bits offer is answered with an explicit 15.
+ EXPECT_EQ("permessage-deflate; client_max_window_bits=15",
+ params.AsExtension().ToString());
+ client_.reset(
+ WebSocketEncoder::CreateClient(params.AsExtension().ToString()));
}
};
+TEST_F(WebSocketEncoderTest, DeflateDisabledEncoder) {
+ scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer());
+ scoped_ptr<WebSocketEncoder> client(WebSocketEncoder::CreateClient(""));
+
+ ASSERT_TRUE(server);
+ ASSERT_TRUE(client);
+
+ EXPECT_FALSE(server->deflate_enabled());
+ EXPECT_FALSE(client->deflate_enabled());
+}
+
TEST_F(WebSocketEncoderTest, ClientToServer) {
std::string frame("ClientToServer");
int mask = 123456;
diff --git a/net/websockets/websocket_deflate_parameters.h b/net/websockets/websocket_deflate_parameters.h
index d0c73bff..a377e52 100644
--- a/net/websockets/websocket_deflate_parameters.h
+++ b/net/websockets/websocket_deflate_parameters.h
@@ -95,6 +95,13 @@
client_max_window_bits_ = WindowBits(bits, true, true);
}
+ int PermissiveServerMaxWindowBits() const {
+ return server_max_window_bits_.PermissiveBits();
+ }
+ int PermissiveClientMaxWindowBits() const {
+ return client_max_window_bits_.PermissiveBits();
+ }
+
// Return true if |bits| is valid as a max_window_bits value.
static bool IsValidWindowBits(int bits) { return 8 <= bits && bits <= 15; }
@@ -104,6 +111,10 @@
WindowBits(int16_t bits, bool is_specified, bool has_value)
: bits(bits), is_specified(is_specified), has_value(has_value) {}
+ int PermissiveBits() const {
+ return (is_specified && has_value) ? bits : 15;
+ }
+
int16_t bits;
// True when "window bits" parameter appears in the parameters.
bool is_specified;
diff --git a/net/websockets/websocket_extension.cc b/net/websockets/websocket_extension.cc
index bf7a54b..c5e8e17 100644
--- a/net/websockets/websocket_extension.cc
+++ b/net/websockets/websocket_extension.cc
@@ -8,6 +8,7 @@
#include <string>
#include "base/logging.h"
+#include "net/http/http_util.h"
namespace net {
@@ -18,6 +19,8 @@
const std::string& value)
: name_(name), value_(value) {
DCHECK(!value.empty());
+ // |extension-param| must be a token.
+ DCHECK(HttpUtil::IsToken(value));
}
bool WebSocketExtension::Parameter::Equals(const Parameter& other) const {
@@ -45,4 +48,22 @@
return this_parameters == other_parameters;
}
+std::string WebSocketExtension::ToString() const {
+ if (name_.empty())
+ return std::string();
+
+ std::string result = name_;
+
+ for (const auto& param : parameters_) {
+ result += "; " + param.name();
+ if (!param.HasValue())
+ continue;
+
+ // |extension-param| must be a token and we don't need to quote it.
+ DCHECK(HttpUtil::IsToken(param.value()));
+ result += "=" + param.value();
+ }
+ return result;
+}
+
} // namespace net
diff --git a/net/websockets/websocket_extension.h b/net/websockets/websocket_extension.h
index 5af4023..20f2922 100644
--- a/net/websockets/websocket_extension.h
+++ b/net/websockets/websocket_extension.h
@@ -44,6 +44,7 @@
const std::string& name() const { return name_; }
const std::vector<Parameter>& parameters() const { return parameters_; }
bool Equals(const WebSocketExtension& other) const;
+ std::string ToString() const;
// The default copy constructor and the assignment operator are defined:
// we need them.
diff --git a/net/websockets/websocket_extension_test.cc b/net/websockets/websocket_extension_test.cc
new file mode 100644
index 0000000..86819b66
--- /dev/null
+++ b/net/websockets/websocket_extension_test.cc
@@ -0,0 +1,60 @@
+// Copyright 2015 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/websockets/websocket_extension.h"
+
+#include <string>
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+TEST(WebSocketExtensionTest, EqualityTest1) {
+ WebSocketExtension e1("hello");
+ WebSocketExtension e2("world");
+ EXPECT_FALSE(e1.Equals(e2));
+ EXPECT_FALSE(e2.Equals(e1));
+}
+
+TEST(WebSocketExtensionTest, EqualityTest2) {
+ WebSocketExtension e1("world");
+ WebSocketExtension e2("world");
+ e1.Add(WebSocketExtension::Parameter("foo", "bar"));
+ e2.Add(WebSocketExtension::Parameter("foo"));
+ EXPECT_FALSE(e1.Equals(e2));
+ EXPECT_FALSE(e2.Equals(e1));
+}
+
+TEST(WebSocketExtensionTest, EqualityTest3) {
+ WebSocketExtension e1("world");
+ WebSocketExtension e2("world");
+ e1.Add(WebSocketExtension::Parameter("foo", "bar"));
+ e1.Add(WebSocketExtension::Parameter("bar", "baz"));
+ e2.Add(WebSocketExtension::Parameter("bar", "baz"));
+ e2.Add(WebSocketExtension::Parameter("foo", "bar"));
+ EXPECT_TRUE(e1.Equals(e2));
+ EXPECT_TRUE(e2.Equals(e1));
+}
+
+TEST(WebSocketExtensionTest, EmptyToString) {
+ EXPECT_EQ("", WebSocketExtension().ToString());
+}
+
+TEST(WebSocketExtensionTest, SimpleToString) {
+ EXPECT_EQ("foo", WebSocketExtension("foo").ToString());
+}
+
+TEST(WebSocketExtensionTest, ToString) {
+ const std::string expected = "foo; bar; baz=hoge";
+
+ WebSocketExtension e("foo");
+ e.Add(WebSocketExtension::Parameter("bar"));
+ e.Add(WebSocketExtension::Parameter("baz", "hoge"));
+ EXPECT_EQ(expected, e.ToString());
+}
+
+} // namespace
+
+} // namespace net
diff --git a/net/websockets/websocket_handshake_constants.h b/net/websockets/websocket_handshake_constants.h
index d68a28a..f52f513 100644
--- a/net/websockets/websocket_handshake_constants.h
+++ b/net/websockets/websocket_handshake_constants.h
@@ -13,6 +13,7 @@
#define NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_CONSTANTS_H_
#include "base/basictypes.h"
+#include "net/base/net_export.h"
// This file plases constants inside the ::net::websockets namespace to avoid
// risk of collisions with other symbols in libnet.
@@ -52,7 +53,7 @@
// "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" as defined in section 4.1 of
// RFC6455.
-extern const char kWebSocketGuid[];
+extern const char NET_EXPORT kWebSocketGuid[];
// Colon-prefixed lowercase headers for SPDY3.