blob: 5b02b18cb61d936319ec387fef82bca69e666996 [file] [log] [blame]
[email protected]adb225d2013-08-30 13:14:431// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_basic_stream.h"
6
7#include <algorithm>
8#include <limits>
9#include <string>
10#include <vector>
11
12#include "base/basictypes.h"
13#include "base/bind.h"
14#include "base/logging.h"
15#include "net/base/io_buffer.h"
16#include "net/base/net_errors.h"
17#include "net/socket/client_socket_handle.h"
18#include "net/websockets/websocket_errors.h"
19#include "net/websockets/websocket_frame.h"
20#include "net/websockets/websocket_frame_parser.h"
21
22namespace net {
23
namespace {

// Size, in bytes, of |read_buffer_|, i.e. the most data requested from the
// socket by a single Read() call.
// TODO(ricea): See if there is a better number or algorithm to fulfill our
// requirements:
//  1. We would like to use minimal memory on low-bandwidth or idle connections
//  2. We would like to read as close to line speed as possible on
//     high-bandwidth connections
//  3. We can't afford to cause jank on the IO thread by copying large buffers
//     around
//  4. We would like to hit any sweet-spots that might exist in terms of network
//     packet sizes / encryption block sizes / IPC alignment issues, etc.
const int kReadBufferSize = 32 << 10;  // 32 KiB.

}  // namespace
39
40WebSocketBasicStream::WebSocketBasicStream(
41 scoped_ptr<ClientSocketHandle> connection)
42 : read_buffer_(new IOBufferWithSize(kReadBufferSize)),
43 connection_(connection.Pass()),
44 generate_websocket_masking_key_(&GenerateWebSocketMaskingKey) {
45 DCHECK(connection_->is_initialized());
46}
47
48WebSocketBasicStream::~WebSocketBasicStream() { Close(); }
49
50int WebSocketBasicStream::ReadFrames(
51 ScopedVector<WebSocketFrameChunk>* frame_chunks,
52 const CompletionCallback& callback) {
53 DCHECK(frame_chunks->empty());
54 // If there is data left over after parsing the HTTP headers, attempt to parse
55 // it as WebSocket frames.
56 if (http_read_buffer_) {
57 DCHECK_GE(http_read_buffer_->offset(), 0);
58 // We cannot simply copy the data into read_buffer_, as it might be too
59 // large.
60 scoped_refptr<GrowableIOBuffer> buffered_data;
61 buffered_data.swap(http_read_buffer_);
62 DCHECK(http_read_buffer_.get() == NULL);
63 if (!parser_.Decode(buffered_data->StartOfBuffer(),
64 buffered_data->offset(),
65 frame_chunks))
66 return WebSocketErrorToNetError(parser_.websocket_error());
67 if (!frame_chunks->empty())
68 return OK;
69 }
70
71 // Run until socket stops giving us data or we get some chunks.
72 while (true) {
73 // base::Unretained(this) here is safe because net::Socket guarantees not to
74 // call any callbacks after Disconnect(), which we call from the
75 // destructor. The caller of ReadFrames() is required to keep |frame_chunks|
76 // valid.
77 int result = connection_->socket()
78 ->Read(read_buffer_.get(),
79 read_buffer_->size(),
80 base::Bind(&WebSocketBasicStream::OnReadComplete,
81 base::Unretained(this),
82 base::Unretained(frame_chunks),
83 callback));
84 if (result == ERR_IO_PENDING)
85 return result;
86 result = HandleReadResult(result, frame_chunks);
87 if (result != ERR_IO_PENDING)
88 return result;
89 }
90}
91
92int WebSocketBasicStream::WriteFrames(
93 ScopedVector<WebSocketFrameChunk>* frame_chunks,
94 const CompletionCallback& callback) {
95 // This function always concatenates all frames into a single buffer.
96 // TODO(ricea): Investigate whether it would be better in some cases to
97 // perform multiple writes with smaller buffers.
98 //
99 // First calculate the size of the buffer we need to allocate.
100 typedef ScopedVector<WebSocketFrameChunk>::const_iterator Iterator;
101 const int kMaximumTotalSize = std::numeric_limits<int>::max();
102 int total_size = 0;
103 for (Iterator it = frame_chunks->begin(); it != frame_chunks->end(); ++it) {
104 WebSocketFrameChunk* chunk = *it;
105 DCHECK(chunk->header)
106 << "Only complete frames are supported by WebSocketBasicStream";
107 DCHECK(chunk->final_chunk)
108 << "Only complete frames are supported by WebSocketBasicStream";
109 // Force the masked bit on.
110 chunk->header->masked = true;
111 // We enforce flow control so the renderer should never be able to force us
112 // to cache anywhere near 2GB of frames.
113 int chunk_size =
114 chunk->data->size() + GetWebSocketFrameHeaderSize(*(chunk->header));
115 CHECK_GE(kMaximumTotalSize - total_size, chunk_size)
116 << "Aborting to prevent overflow";
117 total_size += chunk_size;
118 }
119 scoped_refptr<IOBufferWithSize> combined_buffer(
120 new IOBufferWithSize(total_size));
121 char* dest = combined_buffer->data();
122 int remaining_size = total_size;
123 for (Iterator it = frame_chunks->begin(); it != frame_chunks->end(); ++it) {
124 WebSocketFrameChunk* chunk = *it;
125 WebSocketMaskingKey mask = generate_websocket_masking_key_();
126 int result = WriteWebSocketFrameHeader(
127 *(chunk->header), &mask, dest, remaining_size);
128 DCHECK(result != ERR_INVALID_ARGUMENT)
129 << "WriteWebSocketFrameHeader() says that " << remaining_size
130 << " is not enough to write the header in. This should not happen.";
131 CHECK_GE(result, 0) << "Potentially security-critical check failed";
132 dest += result;
133 remaining_size -= result;
134
135 const char* const frame_data = chunk->data->data();
136 const int frame_size = chunk->data->size();
137 CHECK_GE(remaining_size, frame_size);
138 std::copy(frame_data, frame_data + frame_size, dest);
139 MaskWebSocketFramePayload(mask, 0, dest, frame_size);
140 dest += frame_size;
141 remaining_size -= frame_size;
142 }
143 DCHECK_EQ(0, remaining_size) << "Buffer size calculation was wrong; "
144 << remaining_size << " bytes left over.";
145 scoped_refptr<DrainableIOBuffer> drainable_buffer(
146 new DrainableIOBuffer(combined_buffer, total_size));
147 return WriteEverything(drainable_buffer, callback);
148}
149
150void WebSocketBasicStream::Close() { connection_->socket()->Disconnect(); }
151
152std::string WebSocketBasicStream::GetSubProtocol() const {
153 return sub_protocol_;
154}
155
156std::string WebSocketBasicStream::GetExtensions() const { return extensions_; }
157
158int WebSocketBasicStream::SendHandshakeRequest(
159 const GURL& url,
160 const HttpRequestHeaders& headers,
161 HttpResponseInfo* response_info,
162 const CompletionCallback& callback) {
163 // TODO(ricea): Implement handshake-related functionality.
164 NOTREACHED();
165 return ERR_NOT_IMPLEMENTED;
166}
167
168int WebSocketBasicStream::ReadHandshakeResponse(
169 const CompletionCallback& callback) {
170 NOTREACHED();
171 return ERR_NOT_IMPLEMENTED;
172}
173
174/*static*/
175scoped_ptr<WebSocketBasicStream>
176WebSocketBasicStream::CreateWebSocketBasicStreamForTesting(
177 scoped_ptr<ClientSocketHandle> connection,
178 const scoped_refptr<GrowableIOBuffer>& http_read_buffer,
179 const std::string& sub_protocol,
180 const std::string& extensions,
181 WebSocketMaskingKeyGeneratorFunction key_generator_function) {
182 scoped_ptr<WebSocketBasicStream> stream(
183 new WebSocketBasicStream(connection.Pass()));
184 if (http_read_buffer) {
185 stream->http_read_buffer_ = http_read_buffer;
186 }
187 stream->sub_protocol_ = sub_protocol;
188 stream->extensions_ = extensions;
189 stream->generate_websocket_masking_key_ = key_generator_function;
190 return stream.Pass();
191}
192
193int WebSocketBasicStream::WriteEverything(
194 const scoped_refptr<DrainableIOBuffer>& buffer,
195 const CompletionCallback& callback) {
196 while (buffer->BytesRemaining() > 0) {
197 // The use of base::Unretained() here is safe because on destruction we
198 // disconnect the socket, preventing any further callbacks.
199 int result = connection_->socket()
200 ->Write(buffer.get(),
201 buffer->BytesRemaining(),
202 base::Bind(&WebSocketBasicStream::OnWriteComplete,
203 base::Unretained(this),
204 buffer,
205 callback));
206 if (result > 0) {
207 buffer->DidConsume(result);
208 } else {
209 return result;
210 }
211 }
212 return OK;
213}
214
215void WebSocketBasicStream::OnWriteComplete(
216 const scoped_refptr<DrainableIOBuffer>& buffer,
217 const CompletionCallback& callback,
218 int result) {
219 if (result < 0) {
220 DCHECK(result != ERR_IO_PENDING);
221 callback.Run(result);
222 return;
223 }
224
225 DCHECK(result != 0);
226 buffer->DidConsume(result);
227 result = WriteEverything(buffer, callback);
228 if (result != ERR_IO_PENDING)
229 callback.Run(result);
230}
231
232int WebSocketBasicStream::HandleReadResult(
233 int result,
234 ScopedVector<WebSocketFrameChunk>* frame_chunks) {
235 DCHECK_NE(ERR_IO_PENDING, result);
236 DCHECK(frame_chunks->empty());
237 if (result < 0)
238 return result;
239 if (result == 0)
240 return ERR_CONNECTION_CLOSED;
241 if (!parser_.Decode(read_buffer_->data(), result, frame_chunks))
242 return WebSocketErrorToNetError(parser_.websocket_error());
243 if (!frame_chunks->empty())
244 return OK;
245 return ERR_IO_PENDING;
246}
247
248void WebSocketBasicStream::OnReadComplete(
249 ScopedVector<WebSocketFrameChunk>* frame_chunks,
250 const CompletionCallback& callback,
251 int result) {
252 result = HandleReadResult(result, frame_chunks);
253 if (result == ERR_IO_PENDING)
254 result = ReadFrames(frame_chunks, callback);
255 if (result != ERR_IO_PENDING)
256 callback.Run(result);
257}
258
259} // namespace net