Implement send/receive for the new type.

This commit is contained in:
John Preston 2019-07-08 17:41:34 +02:00
parent 69b6b48738
commit 2f0331b2e0
7 changed files with 201 additions and 73 deletions

View File

@ -24,6 +24,7 @@ constexpr auto kPacketSizeMax = int(0x01000000 * sizeof(mtpPrime));
constexpr auto kFullConnectionTimeout = 8 * crl::time(1000); constexpr auto kFullConnectionTimeout = 8 * crl::time(1000);
constexpr auto kSmallBufferSize = 256 * 1024; constexpr auto kSmallBufferSize = 256 * 1024;
constexpr auto kMinPacketBuffer = 256; constexpr auto kMinPacketBuffer = 256;
constexpr auto kConnectionStartPrefixSize = 64;
} // namespace } // namespace
@ -277,7 +278,7 @@ void TcpConnection::ensureAvailableInBuffer(int amount) {
void TcpConnection::socketRead() { void TcpConnection::socketRead() {
Expects(_leftBytes > 0 || !_usingLargeBuffer); Expects(_leftBytes > 0 || !_usingLargeBuffer);
if (_socket->isConnected()) { if (!_socket || !_socket->isConnected()) {
LOG(("MTP Error: Socket not connected in socketRead()")); LOG(("MTP Error: Socket not connected in socketRead()"));
emit error(kErrorCodeOther); emit error(kErrorCodeOther);
return; return;
@ -295,11 +296,7 @@ void TcpConnection::socketRead() {
auto &buffer = _usingLargeBuffer ? _largeBuffer : _smallBuffer; auto &buffer = _usingLargeBuffer ? _largeBuffer : _smallBuffer;
const auto full = bytes::make_span(buffer).subspan(_offsetBytes); const auto full = bytes::make_span(buffer).subspan(_offsetBytes);
const auto free = full.subspan(_readBytes); const auto free = full.subspan(_readBytes);
Assert(free.size() >= readLimit); const auto readCount = _socket->read(free.subspan(0, readLimit));
const auto readCount = _socket->read(
reinterpret_cast<char*>(free.data()),
readLimit);
if (readCount > 0) { if (readCount > 0) {
const auto read = free.subspan(0, readCount); const auto read = free.subspan(0, readCount);
aesCtrEncrypt(read, _receiveKey, &_receiveState); aesCtrEncrypt(read, _receiveKey, &_receiveState);
@ -365,7 +362,9 @@ void TcpConnection::socketRead() {
TCP_LOG(("TCP Info: no bytes read, but bytes available was true...")); TCP_LOG(("TCP Info: no bytes read, but bytes available was true..."));
break; break;
} }
} while (_socket->isConnected() && _socket->hasBytesAvailable()); } while (_socket
&& _socket->isConnected()
&& _socket->hasBytesAvailable());
} }
mtpBuffer TcpConnection::parsePacket(bytes::const_span bytes) { mtpBuffer TcpConnection::parsePacket(bytes::const_span bytes) {
@ -423,16 +422,31 @@ bool TcpConnection::requiresExtendedPadding() const {
void TcpConnection::sendData(mtpBuffer &&buffer) { void TcpConnection::sendData(mtpBuffer &&buffer) {
Expects(buffer.size() > 2); Expects(buffer.size() > 2);
if (_status != Status::Finished) { if (!_socket) {
sendBuffer(std::move(buffer)); return;
} }
char connectionStartPrefixBytes[kConnectionStartPrefixSize];
const auto connectionStartPrefix = prepareConnectionStartPrefix(
bytes::make_span(connectionStartPrefixBytes));
// buffer: 2 available int-s + data + available int.
const auto bytes = _protocol->finalizePacket(buffer);
TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
aesCtrEncrypt(bytes, _sendKey, &_sendState);
_socket->write(connectionStartPrefix, bytes);
} }
void TcpConnection::writeConnectionStart() { bytes::const_span TcpConnection::prepareConnectionStartPrefix(
bytes::span buffer) {
Expects(_protocol != nullptr); Expects(_protocol != nullptr);
if (_connectionStarted) {
return {};
}
_connectionStarted = true;
// prepare random part // prepare random part
auto nonceBytes = bytes::vector(64); char nonceBytes[64];
const auto nonce = bytes::make_span(nonceBytes); const auto nonce = bytes::make_span(nonceBytes);
const auto zero = reinterpret_cast<uchar*>(nonce.data()); const auto zero = reinterpret_cast<uchar*>(nonce.data());
@ -481,31 +495,17 @@ void TcpConnection::writeConnectionStart() {
const auto dcId = reinterpret_cast<int16*>(nonce.data() + 60); const auto dcId = reinterpret_cast<int16*>(nonce.data() + 60);
*dcId = _protocolDcId; *dcId = _protocolDcId;
_socket->write(reinterpret_cast<const char*>(nonce.data()), 56); bytes::copy(buffer, nonce.subspan(0, 56));
aesCtrEncrypt(nonce, _sendKey, &_sendState); aesCtrEncrypt(nonce, _sendKey, &_sendState);
_socket->write( bytes::copy(buffer.subspan(56), nonce.subspan(56));
reinterpret_cast<const char*>(nonce.subspan(56).data()),
8); return buffer;
} }
void TcpConnection::sendBuffer(mtpBuffer &&buffer) {
if (!_connectionStarted) {
writeConnectionStart();
_connectionStarted = true;
}
// buffer: 2 available int-s + data + available int.
const auto bytes = _protocol->finalizePacket(buffer);
TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
aesCtrEncrypt(bytes, _sendKey, &_sendState);
_socket->write(
reinterpret_cast<const char*>(bytes.data()),
bytes.size());
}
void TcpConnection::disconnectFromServer() { void TcpConnection::disconnectFromServer() {
if (_status == Status::Finished) return; if (_status == Status::Finished) {
return;
}
_status = Status::Finished; _status = Status::Finished;
_connectedLifetime.destroy(); _connectedLifetime.destroy();
_lifetime.destroy(); _lifetime.destroy();
@ -576,7 +576,7 @@ crl::time TcpConnection::fullConnectTimeout() const {
} }
void TcpConnection::socketPacket(bytes::const_span bytes) { void TcpConnection::socketPacket(bytes::const_span bytes) {
if (_status == Status::Finished) return; Expects(_socket != nullptr);
// old quickack?.. // old quickack?..
const auto data = parsePacket(bytes); const auto data = parsePacket(bytes);
@ -620,7 +620,7 @@ bool TcpConnection::isConnected() const {
} }
int32 TcpConnection::debugState() const { int32 TcpConnection::debugState() const {
return _socket->debugState(); return _socket ? _socket->debugState() : -1;
} }
QString TcpConnection::transport() const { QString TcpConnection::transport() const {
@ -645,7 +645,9 @@ QString TcpConnection::tag() const {
} }
void TcpConnection::socketError() { void TcpConnection::socketError() {
if (_status == Status::Finished) return; if (!_socket) {
return;
}
emit error(kErrorCodeOther); emit error(kErrorCodeOther);
} }

View File

@ -48,7 +48,7 @@ private:
}; };
void socketRead(); void socketRead();
void writeConnectionStart(); bytes::const_span prepareConnectionStartPrefix(bytes::span buffer);
void socketPacket(bytes::const_span bytes); void socketPacket(bytes::const_span bytes);
@ -58,14 +58,11 @@ private:
mtpBuffer parsePacket(bytes::const_span bytes); mtpBuffer parsePacket(bytes::const_span bytes);
void ensureAvailableInBuffer(int amount); void ensureAvailableInBuffer(int amount);
static void handleError(QAbstractSocket::SocketError e, QTcpSocket &sock);
static uint32 fourCharsToUInt(char ch1, char ch2, char ch3, char ch4) { static uint32 fourCharsToUInt(char ch1, char ch2, char ch3, char ch4) {
char ch[4] = { ch1, ch2, ch3, ch4 }; char ch[4] = { ch1, ch2, ch3, ch4 };
return *reinterpret_cast<uint32*>(ch); return *reinterpret_cast<uint32*>(ch);
} }
void sendBuffer(mtpBuffer &&buffer);
std::unique_ptr<AbstractSocket> _socket; std::unique_ptr<AbstractSocket> _socket;
bool _connectionStarted = false; bool _connectionStarted = false;

View File

@ -7,6 +7,8 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/ */
#pragma once #pragma once
#include "base/bytes.h"
namespace MTP { namespace MTP {
namespace internal { namespace internal {
@ -38,8 +40,10 @@ public:
virtual void connectToHost(const QString &address, int port) = 0; virtual void connectToHost(const QString &address, int port) = 0;
[[nodiscard]] virtual bool isConnected() = 0; [[nodiscard]] virtual bool isConnected() = 0;
[[nodiscard]] virtual bool hasBytesAvailable() = 0; [[nodiscard]] virtual bool hasBytesAvailable() = 0;
[[nodiscard]] virtual int64 read(char *buffer, int64 maxLength) = 0; [[nodiscard]] virtual int64 read(bytes::span buffer) = 0;
virtual int64 write(const char *buffer, int64 length) = 0; virtual void write(
bytes::const_span prefix,
bytes::const_span buffer) = 0;
virtual int32 debugState() = 0; virtual int32 debugState() = 0;

View File

@ -53,12 +53,23 @@ bool TcpSocket::hasBytesAvailable() {
return _socket.bytesAvailable() > 0; return _socket.bytesAvailable() > 0;
} }
int64 TcpSocket::read(char *buffer, int64 maxLength) { int64 TcpSocket::read(bytes::span buffer) {
return _socket.read(buffer, maxLength); return _socket.read(
reinterpret_cast<char*>(buffer.data()),
buffer.size());
} }
int64 TcpSocket::write(const char *buffer, int64 length) { void TcpSocket::write(bytes::const_span prefix, bytes::const_span buffer) {
return _socket.write(buffer, length); Expects(!buffer.empty());
if (!prefix.empty()) {
_socket.write(
reinterpret_cast<const char*>(prefix.data()),
prefix.size());
}
_socket.write(
reinterpret_cast<const char*>(buffer.data()),
buffer.size());
} }
int32 TcpSocket::debugState() { int32 TcpSocket::debugState() {

View File

@ -12,15 +12,15 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
namespace MTP { namespace MTP {
namespace internal { namespace internal {
class TcpSocket : public AbstractSocket { class TcpSocket final : public AbstractSocket {
public: public:
TcpSocket(not_null<QThread*> thread, const ProxyData &proxy); TcpSocket(not_null<QThread*> thread, const ProxyData &proxy);
void connectToHost(const QString &address, int port) override; void connectToHost(const QString &address, int port) override;
bool isConnected() override; bool isConnected() override;
bool hasBytesAvailable() override; bool hasBytesAvailable() override;
int64 read(char *buffer, int64 maxLength) override; int64 read(bytes::span buffer) override;
int64 write(const char *buffer, int64 length) override; void write(bytes::const_span prefix, bytes::const_span buffer) override;
int32 debugState() override; int32 debugState() override;

View File

@ -22,7 +22,9 @@ const auto kServerHelloPart1 = qstr("\x16\x03\x03");
const auto kServerHelloPart3 = qstr("\x14\x03\x03\x00\x01\x01\x17\x03\x03"); const auto kServerHelloPart3 = qstr("\x14\x03\x03\x00\x01\x01\x17\x03\x03");
constexpr auto kServerHelloDigestPosition = 11; constexpr auto kServerHelloDigestPosition = 11;
const auto kServerHeader = qstr("\x17\x03\x03"); const auto kServerHeader = qstr("\x17\x03\x03");
constexpr auto kServerDataSkip = 5; constexpr auto kClientPartSize = 2878;
const auto kClientPrefix = qstr("\x14\x03\x03\x00\x01\x01");
const auto kClientHeader = qstr("\x17\x03\x03");
[[nodiscard]] MTPTlsClientHello PrepareClientHelloRules() { [[nodiscard]] MTPTlsClientHello PrepareClientHelloRules() {
auto stack = std::vector<QVector<MTPTlsBlock>>(); auto stack = std::vector<QVector<MTPTlsBlock>>();
@ -258,11 +260,11 @@ void ClientHelloGenerator::writeBlock(const MTPDtlsBlockDomain &data) {
} }
void ClientHelloGenerator::writeBlock(const MTPDtlsBlockScope &data) { void ClientHelloGenerator::writeBlock(const MTPDtlsBlockScope &data) {
const auto already = _result.size();
const auto storage = grow(kLengthSize); const auto storage = grow(kLengthSize);
if (storage.empty()) { if (storage.empty()) {
return; return;
} }
const auto already = _result.size();
writeBlocks(data.ventries().v); writeBlocks(data.ventries().v);
const auto length = qToBigEndian(uint16(_result.size() - already)); const auto length = qToBigEndian(uint16(_result.size() - already));
bytes::copy(storage, bytes::object_as_span(&length)); bytes::copy(storage, bytes::object_as_span(&length));
@ -383,7 +385,7 @@ void TlsSocket::plainConnected() {
static const auto kClientHelloRules = PrepareClientHelloRules(); static const auto kClientHelloRules = PrepareClientHelloRules();
const auto hello = PrepareClientHello( const auto hello = PrepareClientHello(
kClientHelloRules, kClientHelloRules,
"google.com", "www.google.com",
_key); _key);
if (hello.data.isEmpty()) { if (hello.data.isEmpty()) {
LOG(("TLS Error: Could not generate Client Hello!")); LOG(("TLS Error: Could not generate Client Hello!"));
@ -400,14 +402,15 @@ void TlsSocket::plainDisconnected() {
_state = State::NotConnected; _state = State::NotConnected;
_incoming = QByteArray(); _incoming = QByteArray();
_serverHelloLength = 0; _serverHelloLength = 0;
_incomingGoodDataOffset = 0;
_incomingGoodDataLimit = 0;
_disconnected.fire({}); _disconnected.fire({});
} }
void TlsSocket::plainReadyRead() { void TlsSocket::plainReadyRead() {
switch (_state) { switch (_state) {
case State::WaitingHello: return readHello(); case State::WaitingHello: return readHello();
case State::Ready: case State::Connected: return readData();
case State::Working: return readData();
} }
} }
@ -481,8 +484,7 @@ void TlsSocket::checkHelloParts34(int parts123Size) {
} }
void TlsSocket::checkHelloDigest() { void TlsSocket::checkHelloDigest() {
const auto incoming = bytes::make_detached_span(_incoming); const auto fulldata = bytes::make_detached_span(_incoming).subspan(
const auto fulldata = incoming.subspan(
0, 0,
kHelloDigestLength + _serverHelloLength); kHelloDigestLength + _serverHelloLength);
const auto digest = fulldata.subspan( const auto digest = fulldata.subspan(
@ -496,18 +498,70 @@ void TlsSocket::checkHelloDigest() {
handleError(); handleError();
return; return;
} }
if (incoming.size() > fulldata.size()) { shiftIncomingBy(fulldata.size());
bytes::move(incoming, incoming.subspan(fulldata.size())); if (!_incoming.isEmpty()) {
_incoming.chop(fulldata.size()); InvokeQueued(this, [=] {
InvokeQueued(this, [=] { readData(); }); if (!checkNextPacket()) {
} else { handleError();
_incoming.clear(); }
});
} }
_state = State::Ready; _incomingGoodDataOffset = _incomingGoodDataLimit = 0;
_state = State::Connected;
_connected.fire({}); _connected.fire({});
} }
void TlsSocket::readData() { void TlsSocket::readData() {
if (!isConnected()) {
return;
}
_incoming.append(_socket.readAll());
if (!checkNextPacket()) {
handleError();
} else if (hasBytesAvailable()) {
_readyRead.fire({});
}
}
bool TlsSocket::checkNextPacket() {
auto offset = 0;
const auto incoming = bytes::make_span(_incoming);
while (!_incomingGoodDataLimit) {
const auto fullHeader = kServerHeader.size() + kLengthSize;
if (incoming.size() <= offset + fullHeader) {
return true;
}
if (!CheckPart(incoming.subspan(offset), kServerHeader)) {
LOG(("TLS Error: Bad packet header."));
return false;
}
const auto length = ReadPartLength(
incoming,
offset + kServerHeader.size());
if (length > 0) {
if (offset > 0) {
shiftIncomingBy(offset);
}
_incomingGoodDataOffset = fullHeader;
_incomingGoodDataLimit = length;
} else {
offset += kServerHeader.size() + kLengthSize + length;
}
}
return true;
}
void TlsSocket::shiftIncomingBy(int amount) {
Expects(_incomingGoodDataOffset == 0);
Expects(_incomingGoodDataLimit == 0);
const auto incoming = bytes::make_detached_span(_incoming);
if (incoming.size() > amount) {
bytes::move(incoming, incoming.subspan(amount));
_incoming.chop(amount);
} else {
_incoming.clear();
}
} }
void TlsSocket::connectToHost(const QString &address, int port) { void TlsSocket::connectToHost(const QString &address, int port) {
@ -518,19 +572,76 @@ void TlsSocket::connectToHost(const QString &address, int port) {
} }
bool TlsSocket::isConnected() { bool TlsSocket::isConnected() {
return (_socket.state() == QAbstractSocket::ConnectedState); return (_state == State::Connected);
} }
bool TlsSocket::hasBytesAvailable() { bool TlsSocket::hasBytesAvailable() {
return _socket.bytesAvailable(); return (_incomingGoodDataLimit > 0)
&& (_incomingGoodDataOffset < _incoming.size());
} }
int64 TlsSocket::read(char *buffer, int64 maxLength) { int64 TlsSocket::read(bytes::span buffer) {
return _socket.read(buffer, maxLength); auto written = int64(0);
while (_incomingGoodDataLimit) {
const auto available = std::min(
_incomingGoodDataLimit,
_incoming.size() - _incomingGoodDataOffset);
if (available <= 0) {
return written;
}
const auto write = std::min(index_type(available), buffer.size());
if (write <= 0) {
return written;
}
bytes::copy(
buffer,
bytes::make_span(_incoming).subspan(
_incomingGoodDataOffset,
write));
written += write;
buffer = buffer.subspan(write);
_incomingGoodDataLimit -= write;
_incomingGoodDataOffset += write;
if (_incomingGoodDataLimit) {
return written;
}
shiftIncomingBy(base::take(_incomingGoodDataOffset));
if (!checkNextPacket()) {
_state = State::Error;
InvokeQueued(this, [=] { handleError(); });
return written;
}
}
return written;
} }
int64 TlsSocket::write(const char *buffer, int64 length) { void TlsSocket::write(bytes::const_span prefix, bytes::const_span buffer) {
return _socket.write(buffer, length); Expects(!buffer.empty());
if (!isConnected()) {
return;
}
if (!prefix.empty()) {
_socket.write(kClientPrefix.data(), kClientPrefix.size());
}
while (!buffer.empty()) {
const auto write = std::min(
kClientPartSize - prefix.size(),
buffer.size());
_socket.write(kClientHeader.data(), kClientHeader.size());
const auto size = qToBigEndian(uint16(prefix.size() + write));
_socket.write(reinterpret_cast<const char*>(&size), sizeof(size));
if (!prefix.empty()) {
_socket.write(
reinterpret_cast<const char*>(prefix.data()),
prefix.size());
prefix = bytes::const_span();
}
_socket.write(
reinterpret_cast<const char*>(buffer.data()),
write);
buffer = buffer.subspan(write);
}
} }
int32 TlsSocket::debugState() { int32 TlsSocket::debugState() {

View File

@ -12,7 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
namespace MTP { namespace MTP {
namespace internal { namespace internal {
class TlsSocket : public AbstractSocket { class TlsSocket final : public AbstractSocket {
public: public:
TlsSocket( TlsSocket(
not_null<QThread*> thread, not_null<QThread*> thread,
@ -22,8 +22,8 @@ public:
void connectToHost(const QString &address, int port) override; void connectToHost(const QString &address, int port) override;
bool isConnected() override; bool isConnected() override;
bool hasBytesAvailable() override; bool hasBytesAvailable() override;
int64 read(char *buffer, int64 maxLength) override; int64 read(bytes::span buffer) override;
int64 write(const char *buffer, int64 length) override; void write(bytes::const_span prefix, bytes::const_span buffer) override;
int32 debugState() override; int32 debugState() override;
@ -32,8 +32,7 @@ private:
NotConnected, NotConnected,
Connecting, Connecting,
WaitingHello, WaitingHello,
Ready, Connected,
Working,
Error, Error,
}; };
@ -47,11 +46,15 @@ private:
void checkHelloParts34(int parts123Size); void checkHelloParts34(int parts123Size);
void checkHelloDigest(); void checkHelloDigest();
void readData(); void readData();
[[nodiscard]] bool checkNextPacket();
void shiftIncomingBy(int amount);
QTcpSocket _socket; QTcpSocket _socket;
bytes::vector _key; bytes::vector _key;
State _state = State::NotConnected; State _state = State::NotConnected;
QByteArray _incoming; QByteArray _incoming;
int _incomingGoodDataOffset = 0;
int _incomingGoodDataLimit = 0;
int16 _serverHelloLength = 0; int16 _serverHelloLength = 0;
}; };