// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromecast/net/small_message_socket.h"

#include <stdint.h>
#include <string.h>

#include <limits>
#include <utility>

#include "base/big_endian.h"
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/check_op.h"
#include "base/location.h"
#include "base/sequenced_task_runner.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "chromecast/net/io_buffer_pool.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"

namespace chromecast {

namespace {

// Maximum number of times to read/write in a loop before reposting on the
// run loop (to allow other tasks to run).
const int kMaxIOLoop = 5;

// Initial capacity, in bytes, given to both the read and the write storage
// buffers when a SmallMessageSocket is constructed.
const int kDefaultBufferSize = 2048;

// Largest value that fits in the 2-byte size prefix. Message sizes strictly
// below this value are encoded in the 2-byte prefix alone; the value itself is
// written as a marker meaning that a 4-byte size field follows the prefix.
constexpr size_t kMax2ByteSize = std::numeric_limits<uint16_t>::max();

}  // namespace

// An IOBuffer that is a movable window onto another IOBuffer. It remembers how
// many bytes of the underlying buffer have already been consumed and keeps
// data() pointing at the first unconsumed byte, so the wrapper itself can be
// passed to socket Read()/Write() calls to continue a partial transfer.
class SmallMessageSocket::BufferWrapper : public ::net::IOBuffer {
 public:
  // Begins wrapping |base|, of which |size| bytes are usable. Nothing is
  // considered consumed yet, so data() points at the start of |base|.
  void SetUnderlyingBuffer(scoped_refptr<IOBuffer> base, size_t size) {
    base_ = std::move(base);
    data_ = base_->data();
    size_ = size;
    used_ = 0;
  }

  // Hands the underlying buffer to the caller, leaving this wrapper without
  // one. The size/consumed counters are left untouched.
  scoped_refptr<IOBuffer> TakeUnderlyingBuffer() {
    scoped_refptr<IOBuffer> released = std::move(base_);
    return released;
  }

  // Drops the reference to the underlying buffer and nulls data(). The
  // size/consumed counters are left untouched.
  void ClearUnderlyingBuffer() {
    base_.reset();
    data_ = nullptr;
  }

  // Marks |bytes| more bytes as consumed and advances data() past them.
  void DidConsume(size_t bytes) {
    used_ += bytes;
    data_ = base_->data() + used_;
  }

  // Start of the underlying buffer, regardless of how much was consumed.
  char* StartOfBuffer() const {
    DCHECK(base_);
    return base_->data();
  }

  // Usable size of the underlying buffer, as given to SetUnderlyingBuffer().
  size_t size() const { return size_; }

  // Number of bytes consumed so far.
  size_t used() const { return used_; }

  // Number of usable bytes that have not been consumed yet.
  size_t remaining() const {
    DCHECK_GE(size_, used_);
    return size_ - used_;
  }

 private:
  // |data_| only ever aliases memory held by |base_|, so it is nulled before
  // the base class destructor runs.
  // NOTE(review): presumably this stops ::net::IOBuffer from releasing memory
  // it does not own - confirm against the IOBuffer destructor.
  ~BufferWrapper() override { data_ = nullptr; }

  scoped_refptr<IOBuffer> base_;
  size_t size_ = 0;
  size_t used_ = 0;
};

// Creates a message socket that takes ownership of |socket| and reports
// events to |delegate|, which must be non-null. The task runner of the
// constructing sequence is captured and used for all posted tasks.
SmallMessageSocket::SmallMessageSocket(Delegate* delegate,
                                       std::unique_ptr<net::Socket> socket)
    : delegate_(delegate),
      socket_(std::move(socket)),
      task_runner_(base::SequencedTaskRunnerHandle::Get()),
      write_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
      write_buffer_(base::MakeRefCounted<BufferWrapper>()),
      read_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
      read_buffer_(base::MakeRefCounted<BufferWrapper>()),
      weak_factory_(this) {
  DCHECK(delegate_);
  // Both storage buffers start at the default capacity; they are grown later
  // when a message does not fit.
  write_storage_->SetCapacity(kDefaultBufferSize);
  read_storage_->SetCapacity(kDefaultBufferSize);
}

// Nothing to release explicitly; all members clean up after themselves.
SmallMessageSocket::~SmallMessageSocket() = default;

// Switches reads over to buffers obtained from |buffer_pool| (must be
// non-null). If a pool is already in use it is simply replaced. Otherwise any
// bytes already read are migrated into a pool-style buffer, unless the call is
// made from inside an OnMessage() callback, in which case the migration is
// left to the message-handling loop.
void SmallMessageSocket::UseBufferPool(
    scoped_refptr<IOBufferPool> buffer_pool) {
  DCHECK(buffer_pool);
  const bool replacing_existing_pool = static_cast<bool>(buffer_pool_);
  buffer_pool_ = std::move(buffer_pool);

  // When an existing pool is being replaced there is no need to copy data out
  // of the current buffer: it stays valid until we are done using it. While a
  // message is being delivered, the pending bytes must not be moved yet.
  if (replacing_existing_pool || in_message_) {
    return;
  }

  ActivateBufferPool(read_storage_->StartOfBuffer(), read_storage_->offset());
}

void SmallMessageSocket::ActivateBufferPool(char* current_data,
                                            size_t current_size) {
  // Copy any already-read data into a new buffer for pool-based operation.
  DCHECK(buffer_pool_);
  DCHECK(!in_message_);

  scoped_refptr<::net::IOBuffer> new_buffer;
  size_t new_buffer_size;
  if (current_size <= buffer_pool_->buffer_size()) {
    new_buffer = buffer_pool_->GetBuffer();
    CHECK(new_buffer);
    new_buffer_size = buffer_pool_->buffer_size();
  } else {
    new_buffer = base::MakeRefCounted<::net::IOBuffer>(current_size * 2);
    new_buffer_size = current_size * 2;
  }
  memcpy(new_buffer->data(), current_data, current_size);

  read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), new_buffer_size);
  read_buffer_->DidConsume(current_size);
}

void SmallMessageSocket::RemoveBufferPool() {
  if (!buffer_pool_) {
    return;
  }

  if (static_cast<size_t>(read_storage_->capacity()) < read_buffer_->used()) {
    read_storage_->SetCapacity(read_buffer_->used());
  }
  memcpy(read_storage_->StartOfBuffer(), read_buffer_->StartOfBuffer(),
         read_buffer_->used());
  read_storage_->set_offset(read_buffer_->used());

  buffer_pool_.reset();
}

// Reserves room in the write storage for a message of |message_size| bytes and
// writes its size header. Returns a pointer to where the caller should place
// the message payload, or nullptr (and records that the sender is blocked) if
// a previous write has not finished yet.
void* SmallMessageSocket::PrepareSend(size_t message_size) {
  if (write_buffer_->remaining() != 0) {
    send_blocked_ = true;
    return nullptr;
  }

  const size_t header_size = SizeDataBytes(message_size);
  const size_t framed_size = header_size + message_size;

  write_storage_->set_offset(0);
  // TODO(lethalantidote): Remove cast once capacity converted to size_t.
  const size_t current_capacity =
      static_cast<size_t>(write_storage_->capacity());
  if (current_capacity < framed_size) {
    write_storage_->SetCapacity(framed_size);
  }

  // The header goes at the very start of the frame; the payload follows it.
  write_buffer_->SetUnderlyingBuffer(write_storage_, framed_size);
  char* const frame_start = write_buffer_->data();
  WriteSizeData(frame_start, message_size);
  return frame_start + header_size;
}

// Starts writing the first |size| bytes of |data| to the socket as-is (no size
// header is added here). Returns false, and records that the sender is
// blocked, if a previous write is still in progress; returns true otherwise.
bool SmallMessageSocket::SendBuffer(scoped_refptr<net::IOBuffer> data,
                                    size_t size) {
  const bool write_in_progress = write_buffer_->remaining() != 0;
  if (write_in_progress) {
    send_blocked_ = true;
  } else {
    write_buffer_->SetUnderlyingBuffer(std::move(data), size);
    Send();
  }
  return !write_in_progress;
}

// static
// Returns how many header bytes are needed to encode |message_size|: a 2-byte
// prefix for sizes below kMax2ByteSize, otherwise the 2-byte marker plus a
// 4-byte size.
size_t SmallMessageSocket::SizeDataBytes(size_t message_size) {
  if (message_size < kMax2ByteSize) {
    return sizeof(uint16_t);
  }
  return sizeof(uint16_t) + sizeof(uint32_t);
}

// static
// Writes the big-endian size header for a message of |message_size| bytes at
// |ptr| and returns the number of header bytes written (2 or 6). Sizes below
// kMax2ByteSize use the 2-byte prefix alone; larger ones write kMax2ByteSize
// as a marker followed by the size as a 32-bit value.
size_t SmallMessageSocket::WriteSizeData(char* ptr, size_t message_size) {
  const bool needs_extended_size = !(message_size < kMax2ByteSize);
  const uint16_t prefix = needs_extended_size
                              ? static_cast<uint16_t>(kMax2ByteSize)
                              : static_cast<uint16_t>(message_size);
  base::WriteBigEndian(ptr, prefix);
  if (!needs_extended_size) {
    return sizeof(uint16_t);
  }

  base::WriteBigEndian(ptr + sizeof(uint16_t),
                       static_cast<uint32_t>(message_size));
  return sizeof(uint16_t) + sizeof(uint32_t);
}

// Writes the unsent part of |write_buffer_| to the socket. Up to kMaxIOLoop
// writes are attempted back to back while they complete immediately; if data
// is still left after that, the remainder is continued from a posted task so
// that other tasks get a chance to run.
void SmallMessageSocket::Send() {
  int attempts_left = kMaxIOLoop;
  while (attempts_left-- > 0) {
    const int write_result =
        socket_->Write(write_buffer_.get(), write_buffer_->remaining(),
                       base::BindOnce(&SmallMessageSocket::OnWriteComplete,
                                      base::Unretained(this)),
                       MISSING_TRAFFIC_ANNOTATION);
    // A false return means there is nothing more to do right now: the write
    // is pending, it failed, or the whole buffer has been sent.
    if (!HandleWriteResult(write_result)) {
      return;
    }
  }

  task_runner_->PostTask(FROM_HERE, base::BindOnce(&SmallMessageSocket::Send,
                                                   weak_factory_.GetWeakPtr()));
}

// Completion callback for an asynchronous socket write. Resumes sending when
// part of the buffer is still unsent.
void SmallMessageSocket::OnWriteComplete(int result) {
  const bool more_to_write = HandleWriteResult(result);
  if (!more_to_write) {
    return;
  }
  Send();
}

// Processes the |result| of a socket write. Returns true only when the write
// succeeded and more bytes of the current buffer remain to be written.
// Returns false when the write is pending, failed (the error is reported to
// the delegate asynchronously), or the buffer has been written completely.
bool SmallMessageSocket::HandleWriteResult(int result) {
  if (result == net::ERR_IO_PENDING) {
    return false;
  }

  if (result > 0) {
    write_buffer_->DidConsume(result);
    if (write_buffer_->remaining() != 0) {
      return true;
    }

    // Everything was written: release the buffer and, if a sender was turned
    // away in the meantime, let the delegate know it may send again.
    write_buffer_->ClearUnderlyingBuffer();
    if (send_blocked_) {
      send_blocked_ = false;
      delegate_->OnSendUnblocked();
    }
    return false;
  }

  // Post a task rather than just calling OnError(), to avoid calling OnError()
  // synchronously.
  task_runner_->PostTask(FROM_HERE,
                         base::BindOnce(&SmallMessageSocket::OnError,
                                        weak_factory_.GetWeakPtr(), result));
  return false;
}

// Forwards |error| to the delegate. Used as the target of the task posted by
// HandleWriteResult() so that write errors are reported asynchronously.
void SmallMessageSocket::OnError(int error) {
  delegate_->OnError(error);
}

// Starts receiving messages asynchronously. The actual work is deferred to a
// posted task so that no delegate method is invoked from within this call.
void SmallMessageSocket::ReceiveMessages() {
  auto deferred_receive =
      base::BindOnce(&SmallMessageSocket::ReceiveMessagesSynchronously,
                     weak_factory_.GetWeakPtr());
  task_runner()->PostTask(FROM_HERE, std::move(deferred_receive));
}

// Delivers any messages that are already buffered (delegate methods may be
// called synchronously from here) and then, if the delegate still wants more,
// starts reading from the socket.
void SmallMessageSocket::ReceiveMessagesSynchronously() {
  // Choose the handler once, up front. The handlers invoke delegate callbacks
  // that may delete this object or add/remove the buffer pool; when they
  // return false no member may be touched afterwards, and the other handler
  // must not be run as a fallback (each handler already hands over to the
  // other one internally when the pool changes mid-callback).
  const bool keep_reading = buffer_pool_ ? HandleCompletedMessageBuffers()
                                         : HandleCompletedMessages();
  if (keep_reading) {
    // Both handlers return false if this object was deleted, so |this| is
    // still valid here.
    Read();
  }
}

void SmallMessageSocket::Read() {
  // Read in a loop for a few times while data is immediately available.
  // This improves average packet receive delay as compared to always posting a
  // new task for each call to Read().
  for (int i = 0; i < kMaxIOLoop; ++i) {
    net::IOBuffer* buffer;
    int size;
    if (buffer_pool_) {
      buffer = read_buffer_.get();
      size = read_buffer_->remaining();
    } else {
      buffer = read_storage_.get();
      size = read_storage_->RemainingCapacity();
    }
    int read_result =
        socket()->Read(buffer, size,
                       base::BindOnce(&SmallMessageSocket::OnReadComplete,
                                      base::Unretained(this)));

    if (!HandleReadResult(read_result)) {
      return;
    }
  }

  task_runner()->PostTask(
      FROM_HERE,
      base::BindOnce(&SmallMessageSocket::Read, weak_factory_.GetWeakPtr()));
}

// Completion callback for an asynchronous socket read. Continues reading when
// the result was handled and more data is wanted.
void SmallMessageSocket::OnReadComplete(int result) {
  const bool keep_reading = HandleReadResult(result);
  if (!keep_reading) {
    return;
  }
  Read();
}

// Processes the |result| of a socket read. On success the newly read bytes are
// accounted for and any complete messages are dispatched; the return value
// then says whether reading should continue. Returns false when the read is
// pending, or after reporting end-of-stream or an error to the delegate.
bool SmallMessageSocket::HandleReadResult(int result) {
  if (result == net::ERR_IO_PENDING) {
    return false;
  }

  if (result > 0) {
    if (!buffer_pool_) {
      read_storage_->set_offset(read_storage_->offset() + result);
      return HandleCompletedMessages();
    }
    read_buffer_->DidConsume(result);
    return HandleCompletedMessageBuffers();
  }

  // Zero bytes and ERR_CONNECTION_CLOSED both mean the peer is done; every
  // other non-positive value is reported as an error.
  const bool stream_ended =
      (result == 0 || result == net::ERR_CONNECTION_CLOSED);
  if (stream_ended) {
    delegate_->OnEndOfStream();
  } else {
    delegate_->OnError(result);
  }
  return false;
}

// static
// Parses the big-endian size header at |ptr|, of which |bytes_read| bytes are
// available. On success, stores the header length in |data_offset| and the
// payload length in |message_size| and returns true. Returns false, leaving
// both outputs untouched, when not enough bytes are available yet.
bool SmallMessageSocket::ReadSize(char* ptr,
                                  size_t bytes_read,
                                  size_t& data_offset,
                                  size_t& message_size) {
  constexpr size_t kShortHeaderSize = sizeof(uint16_t);
  constexpr size_t kLongHeaderSize = sizeof(uint16_t) + sizeof(uint32_t);

  if (bytes_read < kShortHeaderSize) {
    return false;
  }

  uint16_t prefix;
  base::ReadBigEndian(ptr, &prefix);
  if (prefix < kMax2ByteSize) {
    // The prefix is the payload size itself.
    data_offset = kShortHeaderSize;
    message_size = prefix;
    return true;
  }

  // The prefix is the marker value; the real size follows as 32 bits.
  if (bytes_read < kLongHeaderSize) {
    return false;
  }
  uint32_t extended_size;
  base::ReadBigEndian(ptr + kShortHeaderSize, &extended_size);
  data_offset = kLongHeaderSize;
  message_size = extended_size;
  return true;
}

// Non-pool mode: dispatches every complete message currently held in
// |read_storage_| to Delegate::OnMessage(), then compacts any leftover partial
// message to the front of the storage.
//
// Returns true if the caller should keep reading from the socket. Returns
// false if the delegate asked to stop or this object was deleted from inside
// OnMessage(); in the latter case no member may be touched by the caller.
bool SmallMessageSocket::HandleCompletedMessages() {
  DCHECK(!buffer_pool_);
  bool keep_reading = true;
  // |start_ptr| and |bytes_read| describe the bytes that have been read but
  // not yet dispatched; they advance as messages are delivered.
  size_t bytes_read = read_storage_->offset();
  char* start_ptr = read_storage_->StartOfBuffer();
  while (keep_reading) {
    size_t data_offset;
    size_t message_size;
    if (!ReadSize(start_ptr, bytes_read, data_offset, message_size)) {
      break;  // The size header has not been fully received yet.
    }
    // NOTE(review): |message_size| comes from the received header (up to 32
    // bits) and this sum is not range-checked; where size_t is 32 bits it
    // could wrap, and SetCapacity() below receives it unchecked - confirm
    // whether an upper bound is needed.
    size_t total_size = data_offset + message_size;

    if (static_cast<size_t>(read_storage_->capacity()) < total_size) {
      // The message cannot fit in the storage as it is. Move the pending
      // bytes to the front, grow the storage, and go back to reading.
      if (start_ptr != read_storage_->StartOfBuffer()) {
        memmove(read_storage_->StartOfBuffer(), start_ptr, bytes_read);
        read_storage_->set_offset(bytes_read);
      }
      read_storage_->SetCapacity(total_size);
      return true;
    }

    if (bytes_read < total_size) {
      break;  // Haven't received the full message yet.
    }

    // Take a weak pointer in case OnMessage() causes this to be deleted.
    auto self = weak_factory_.GetWeakPtr();
    // |in_message_| tells UseBufferPool() not to move the pending bytes while
    // the delegate is still looking at them.
    in_message_ = true;
    keep_reading = delegate_->OnMessage(start_ptr + data_offset, message_size);
    if (!self) {
      return false;
    }
    in_message_ = false;

    // Step past the message that was just delivered.
    start_ptr += total_size;
    bytes_read -= total_size;

    if (buffer_pool_) {
      // A buffer pool was added within OnMessage().
      ActivateBufferPool(start_ptr, bytes_read);
      return (keep_reading ? HandleCompletedMessageBuffers() : false);
    }
  }

  // Compact: move whatever is left (a partial message, possibly nothing) to
  // the start of the storage so the next read appends after it.
  if (start_ptr != read_storage_->StartOfBuffer()) {
    memmove(read_storage_->StartOfBuffer(), start_ptr, bytes_read);
    read_storage_->set_offset(bytes_read);
  }

  return keep_reading;
}

// Pool mode: dispatches every complete message held in |read_buffer_| to
// Delegate::OnMessageBuffer(), handing over the buffer that contains the
// message and continuing with a fresh buffer for whatever bytes followed it.
//
// Returns true if the caller should keep reading from the socket. Returns
// false if the delegate asked to stop or this object was deleted from inside
// OnMessageBuffer(); in the latter case no member may be touched by the
// caller.
bool SmallMessageSocket::HandleCompletedMessageBuffers() {
  DCHECK(buffer_pool_);
  size_t bytes_read;
  // Loop for as long as the current buffer holds any received bytes.
  while ((bytes_read = read_buffer_->used())) {
    size_t data_offset;
    size_t message_size;
    if (!ReadSize(read_buffer_->StartOfBuffer(), bytes_read, data_offset,
                  message_size)) {
      break;  // The size header has not been fully received yet.
    }
    // NOTE(review): |message_size| comes from the received header (up to 32
    // bits) and this sum is not range-checked; where size_t is 32 bits it
    // could wrap - confirm whether an upper bound is needed.
    size_t total_size = data_offset + message_size;

    if (read_buffer_->size() < total_size) {
      // Current buffer is not big enough. Switch to a standalone buffer sized
      // exactly for this message, keep what was received so far, and go back
      // to reading.
      auto new_buffer = base::MakeRefCounted<::net::IOBuffer>(total_size);
      memcpy(new_buffer->data(), read_buffer_->StartOfBuffer(), bytes_read);
      read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), total_size);
      read_buffer_->DidConsume(bytes_read);
      return true;
    }

    if (bytes_read < total_size) {
      break;  // Haven't received the full message yet.
    }

    // Detach the buffer holding the complete message and set up a replacement
    // for subsequent reads before calling into the delegate.
    auto old_buffer = read_buffer_->TakeUnderlyingBuffer();
    // NOTE(review): a pool buffer is fetched here even when it is replaced by
    // a standalone allocation just below.
    auto new_buffer = buffer_pool_->GetBuffer();
    CHECK(new_buffer);
    size_t new_buffer_size = buffer_pool_->buffer_size();
    // Bytes received beyond the end of this message belong to the next one.
    size_t extra_size = bytes_read - total_size;
    if (extra_size > 0) {
      // Copy extra data to new buffer.
      if (extra_size > buffer_pool_->buffer_size()) {
        new_buffer = base::MakeRefCounted<::net::IOBuffer>(extra_size);
        new_buffer_size = extra_size;
      }
      memcpy(new_buffer->data(), old_buffer->data() + total_size, extra_size);
    }
    read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), new_buffer_size);
    read_buffer_->DidConsume(extra_size);

    // Take a weak pointer in case OnMessageBuffer() causes this to be deleted.
    auto self = weak_factory_.GetWeakPtr();
    // The delegate receives the whole framed message: size header plus
    // payload, |total_size| bytes in all.
    bool keep_reading =
        delegate_->OnMessageBuffer(std::move(old_buffer), total_size);
    if (!self || !keep_reading) {
      return false;
    }
    if (!buffer_pool_) {
      // The buffer pool was removed within OnMessageBuffer().
      return HandleCompletedMessages();
    }
  }

  return true;
}

// Default handling of a pool-mode message: strips the size header from the
// framed message in |buffer| and forwards the payload to OnMessage(). |size|
// is the total number of framed bytes (header plus payload). Returns whatever
// OnMessage() returns, i.e. whether more messages should be read.
bool SmallMessageSocket::Delegate::OnMessageBuffer(
    scoped_refptr<net::IOBuffer> buffer,
    size_t size) {
  // The header length depends on the payload size encoded in the header, not
  // on the total framed size. Deriving it from |size| picks the 6-byte form
  // for payloads just below the 2-byte limit (header + payload reaches the
  // limit although the header is only 2 bytes), so parse the header itself.
  size_t data_offset = 0;
  size_t message_size = 0;
  const bool parsed = SmallMessageSocket::ReadSize(buffer->data(), size,
                                                   data_offset, message_size);
  DCHECK(parsed);
  if (!parsed) {
    // |buffer| does not even hold a complete size header, so there is no
    // payload to deliver. Skip it rather than read outside the buffer.
    return true;
  }

  // ReadSize() only succeeds when |size| covers the whole header, so the
  // subtraction cannot underflow.
  return OnMessage(buffer->data() + data_offset, size - data_offset);
}

}  // namespace chromecast
