mirror of https://github.com/procxx/kepka.git
Implement send/receive for the new type.
This commit is contained in:
parent
69b6b48738
commit
2f0331b2e0
|
@ -24,6 +24,7 @@ constexpr auto kPacketSizeMax = int(0x01000000 * sizeof(mtpPrime));
|
|||
constexpr auto kFullConnectionTimeout = 8 * crl::time(1000);
|
||||
constexpr auto kSmallBufferSize = 256 * 1024;
|
||||
constexpr auto kMinPacketBuffer = 256;
|
||||
constexpr auto kConnectionStartPrefixSize = 64;
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -277,7 +278,7 @@ void TcpConnection::ensureAvailableInBuffer(int amount) {
|
|||
void TcpConnection::socketRead() {
|
||||
Expects(_leftBytes > 0 || !_usingLargeBuffer);
|
||||
|
||||
if (_socket->isConnected()) {
|
||||
if (!_socket || !_socket->isConnected()) {
|
||||
LOG(("MTP Error: Socket not connected in socketRead()"));
|
||||
emit error(kErrorCodeOther);
|
||||
return;
|
||||
|
@ -295,11 +296,7 @@ void TcpConnection::socketRead() {
|
|||
auto &buffer = _usingLargeBuffer ? _largeBuffer : _smallBuffer;
|
||||
const auto full = bytes::make_span(buffer).subspan(_offsetBytes);
|
||||
const auto free = full.subspan(_readBytes);
|
||||
Assert(free.size() >= readLimit);
|
||||
|
||||
const auto readCount = _socket->read(
|
||||
reinterpret_cast<char*>(free.data()),
|
||||
readLimit);
|
||||
const auto readCount = _socket->read(free.subspan(0, readLimit));
|
||||
if (readCount > 0) {
|
||||
const auto read = free.subspan(0, readCount);
|
||||
aesCtrEncrypt(read, _receiveKey, &_receiveState);
|
||||
|
@ -365,7 +362,9 @@ void TcpConnection::socketRead() {
|
|||
TCP_LOG(("TCP Info: no bytes read, but bytes available was true..."));
|
||||
break;
|
||||
}
|
||||
} while (_socket->isConnected() && _socket->hasBytesAvailable());
|
||||
} while (_socket
|
||||
&& _socket->isConnected()
|
||||
&& _socket->hasBytesAvailable());
|
||||
}
|
||||
|
||||
mtpBuffer TcpConnection::parsePacket(bytes::const_span bytes) {
|
||||
|
@ -423,16 +422,31 @@ bool TcpConnection::requiresExtendedPadding() const {
|
|||
void TcpConnection::sendData(mtpBuffer &&buffer) {
|
||||
Expects(buffer.size() > 2);
|
||||
|
||||
if (_status != Status::Finished) {
|
||||
sendBuffer(std::move(buffer));
|
||||
if (!_socket) {
|
||||
return;
|
||||
}
|
||||
char connectionStartPrefixBytes[kConnectionStartPrefixSize];
|
||||
const auto connectionStartPrefix = prepareConnectionStartPrefix(
|
||||
bytes::make_span(connectionStartPrefixBytes));
|
||||
|
||||
// buffer: 2 available int-s + data + available int.
|
||||
const auto bytes = _protocol->finalizePacket(buffer);
|
||||
TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
|
||||
aesCtrEncrypt(bytes, _sendKey, &_sendState);
|
||||
_socket->write(connectionStartPrefix, bytes);
|
||||
}
|
||||
|
||||
void TcpConnection::writeConnectionStart() {
|
||||
bytes::const_span TcpConnection::prepareConnectionStartPrefix(
|
||||
bytes::span buffer) {
|
||||
Expects(_protocol != nullptr);
|
||||
|
||||
if (_connectionStarted) {
|
||||
return {};
|
||||
}
|
||||
_connectionStarted = true;
|
||||
|
||||
// prepare random part
|
||||
auto nonceBytes = bytes::vector(64);
|
||||
char nonceBytes[64];
|
||||
const auto nonce = bytes::make_span(nonceBytes);
|
||||
|
||||
const auto zero = reinterpret_cast<uchar*>(nonce.data());
|
||||
|
@ -481,31 +495,17 @@ void TcpConnection::writeConnectionStart() {
|
|||
const auto dcId = reinterpret_cast<int16*>(nonce.data() + 60);
|
||||
*dcId = _protocolDcId;
|
||||
|
||||
_socket->write(reinterpret_cast<const char*>(nonce.data()), 56);
|
||||
bytes::copy(buffer, nonce.subspan(0, 56));
|
||||
aesCtrEncrypt(nonce, _sendKey, &_sendState);
|
||||
_socket->write(
|
||||
reinterpret_cast<const char*>(nonce.subspan(56).data()),
|
||||
8);
|
||||
bytes::copy(buffer.subspan(56), nonce.subspan(56));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void TcpConnection::sendBuffer(mtpBuffer &&buffer) {
|
||||
if (!_connectionStarted) {
|
||||
writeConnectionStart();
|
||||
_connectionStarted = true;
|
||||
}
|
||||
|
||||
// buffer: 2 available int-s + data + available int.
|
||||
const auto bytes = _protocol->finalizePacket(buffer);
|
||||
TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
|
||||
aesCtrEncrypt(bytes, _sendKey, &_sendState);
|
||||
_socket->write(
|
||||
reinterpret_cast<const char*>(bytes.data()),
|
||||
bytes.size());
|
||||
}
|
||||
|
||||
|
||||
void TcpConnection::disconnectFromServer() {
|
||||
if (_status == Status::Finished) return;
|
||||
if (_status == Status::Finished) {
|
||||
return;
|
||||
}
|
||||
_status = Status::Finished;
|
||||
_connectedLifetime.destroy();
|
||||
_lifetime.destroy();
|
||||
|
@ -576,7 +576,7 @@ crl::time TcpConnection::fullConnectTimeout() const {
|
|||
}
|
||||
|
||||
void TcpConnection::socketPacket(bytes::const_span bytes) {
|
||||
if (_status == Status::Finished) return;
|
||||
Expects(_socket != nullptr);
|
||||
|
||||
// old quickack?..
|
||||
const auto data = parsePacket(bytes);
|
||||
|
@ -620,7 +620,7 @@ bool TcpConnection::isConnected() const {
|
|||
}
|
||||
|
||||
int32 TcpConnection::debugState() const {
|
||||
return _socket->debugState();
|
||||
return _socket ? _socket->debugState() : -1;
|
||||
}
|
||||
|
||||
QString TcpConnection::transport() const {
|
||||
|
@ -645,7 +645,9 @@ QString TcpConnection::tag() const {
|
|||
}
|
||||
|
||||
void TcpConnection::socketError() {
|
||||
if (_status == Status::Finished) return;
|
||||
if (!_socket) {
|
||||
return;
|
||||
}
|
||||
|
||||
emit error(kErrorCodeOther);
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ private:
|
|||
};
|
||||
|
||||
void socketRead();
|
||||
void writeConnectionStart();
|
||||
bytes::const_span prepareConnectionStartPrefix(bytes::span buffer);
|
||||
|
||||
void socketPacket(bytes::const_span bytes);
|
||||
|
||||
|
@ -58,14 +58,11 @@ private:
|
|||
|
||||
mtpBuffer parsePacket(bytes::const_span bytes);
|
||||
void ensureAvailableInBuffer(int amount);
|
||||
static void handleError(QAbstractSocket::SocketError e, QTcpSocket &sock);
|
||||
static uint32 fourCharsToUInt(char ch1, char ch2, char ch3, char ch4) {
|
||||
char ch[4] = { ch1, ch2, ch3, ch4 };
|
||||
return *reinterpret_cast<uint32*>(ch);
|
||||
}
|
||||
|
||||
void sendBuffer(mtpBuffer &&buffer);
|
||||
|
||||
std::unique_ptr<AbstractSocket> _socket;
|
||||
bool _connectionStarted = false;
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
|
|||
*/
|
||||
#pragma once
|
||||
|
||||
#include "base/bytes.h"
|
||||
|
||||
namespace MTP {
|
||||
namespace internal {
|
||||
|
||||
|
@ -38,8 +40,10 @@ public:
|
|||
virtual void connectToHost(const QString &address, int port) = 0;
|
||||
[[nodiscard]] virtual bool isConnected() = 0;
|
||||
[[nodiscard]] virtual bool hasBytesAvailable() = 0;
|
||||
[[nodiscard]] virtual int64 read(char *buffer, int64 maxLength) = 0;
|
||||
virtual int64 write(const char *buffer, int64 length) = 0;
|
||||
[[nodiscard]] virtual int64 read(bytes::span buffer) = 0;
|
||||
virtual void write(
|
||||
bytes::const_span prefix,
|
||||
bytes::const_span buffer) = 0;
|
||||
|
||||
virtual int32 debugState() = 0;
|
||||
|
||||
|
|
|
@ -53,12 +53,23 @@ bool TcpSocket::hasBytesAvailable() {
|
|||
return _socket.bytesAvailable() > 0;
|
||||
}
|
||||
|
||||
int64 TcpSocket::read(char *buffer, int64 maxLength) {
|
||||
return _socket.read(buffer, maxLength);
|
||||
int64 TcpSocket::read(bytes::span buffer) {
|
||||
return _socket.read(
|
||||
reinterpret_cast<char*>(buffer.data()),
|
||||
buffer.size());
|
||||
}
|
||||
|
||||
int64 TcpSocket::write(const char *buffer, int64 length) {
|
||||
return _socket.write(buffer, length);
|
||||
void TcpSocket::write(bytes::const_span prefix, bytes::const_span buffer) {
|
||||
Expects(!buffer.empty());
|
||||
|
||||
if (!prefix.empty()) {
|
||||
_socket.write(
|
||||
reinterpret_cast<const char*>(prefix.data()),
|
||||
prefix.size());
|
||||
}
|
||||
_socket.write(
|
||||
reinterpret_cast<const char*>(buffer.data()),
|
||||
buffer.size());
|
||||
}
|
||||
|
||||
int32 TcpSocket::debugState() {
|
||||
|
|
|
@ -12,15 +12,15 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
|
|||
namespace MTP {
|
||||
namespace internal {
|
||||
|
||||
class TcpSocket : public AbstractSocket {
|
||||
class TcpSocket final : public AbstractSocket {
|
||||
public:
|
||||
TcpSocket(not_null<QThread*> thread, const ProxyData &proxy);
|
||||
|
||||
void connectToHost(const QString &address, int port) override;
|
||||
bool isConnected() override;
|
||||
bool hasBytesAvailable() override;
|
||||
int64 read(char *buffer, int64 maxLength) override;
|
||||
int64 write(const char *buffer, int64 length) override;
|
||||
int64 read(bytes::span buffer) override;
|
||||
void write(bytes::const_span prefix, bytes::const_span buffer) override;
|
||||
|
||||
int32 debugState() override;
|
||||
|
||||
|
|
|
@ -22,7 +22,9 @@ const auto kServerHelloPart1 = qstr("\x16\x03\x03");
|
|||
const auto kServerHelloPart3 = qstr("\x14\x03\x03\x00\x01\x01\x17\x03\x03");
|
||||
constexpr auto kServerHelloDigestPosition = 11;
|
||||
const auto kServerHeader = qstr("\x17\x03\x03");
|
||||
constexpr auto kServerDataSkip = 5;
|
||||
constexpr auto kClientPartSize = 2878;
|
||||
const auto kClientPrefix = qstr("\x14\x03\x03\x00\x01\x01");
|
||||
const auto kClientHeader = qstr("\x17\x03\x03");
|
||||
|
||||
[[nodiscard]] MTPTlsClientHello PrepareClientHelloRules() {
|
||||
auto stack = std::vector<QVector<MTPTlsBlock>>();
|
||||
|
@ -258,11 +260,11 @@ void ClientHelloGenerator::writeBlock(const MTPDtlsBlockDomain &data) {
|
|||
}
|
||||
|
||||
void ClientHelloGenerator::writeBlock(const MTPDtlsBlockScope &data) {
|
||||
const auto already = _result.size();
|
||||
const auto storage = grow(kLengthSize);
|
||||
if (storage.empty()) {
|
||||
return;
|
||||
}
|
||||
const auto already = _result.size();
|
||||
writeBlocks(data.ventries().v);
|
||||
const auto length = qToBigEndian(uint16(_result.size() - already));
|
||||
bytes::copy(storage, bytes::object_as_span(&length));
|
||||
|
@ -383,7 +385,7 @@ void TlsSocket::plainConnected() {
|
|||
static const auto kClientHelloRules = PrepareClientHelloRules();
|
||||
const auto hello = PrepareClientHello(
|
||||
kClientHelloRules,
|
||||
"google.com",
|
||||
"www.google.com",
|
||||
_key);
|
||||
if (hello.data.isEmpty()) {
|
||||
LOG(("TLS Error: Could not generate Client Hello!"));
|
||||
|
@ -400,14 +402,15 @@ void TlsSocket::plainDisconnected() {
|
|||
_state = State::NotConnected;
|
||||
_incoming = QByteArray();
|
||||
_serverHelloLength = 0;
|
||||
_incomingGoodDataOffset = 0;
|
||||
_incomingGoodDataLimit = 0;
|
||||
_disconnected.fire({});
|
||||
}
|
||||
|
||||
void TlsSocket::plainReadyRead() {
|
||||
switch (_state) {
|
||||
case State::WaitingHello: return readHello();
|
||||
case State::Ready:
|
||||
case State::Working: return readData();
|
||||
case State::Connected: return readData();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -481,8 +484,7 @@ void TlsSocket::checkHelloParts34(int parts123Size) {
|
|||
}
|
||||
|
||||
void TlsSocket::checkHelloDigest() {
|
||||
const auto incoming = bytes::make_detached_span(_incoming);
|
||||
const auto fulldata = incoming.subspan(
|
||||
const auto fulldata = bytes::make_detached_span(_incoming).subspan(
|
||||
0,
|
||||
kHelloDigestLength + _serverHelloLength);
|
||||
const auto digest = fulldata.subspan(
|
||||
|
@ -496,18 +498,70 @@ void TlsSocket::checkHelloDigest() {
|
|||
handleError();
|
||||
return;
|
||||
}
|
||||
if (incoming.size() > fulldata.size()) {
|
||||
bytes::move(incoming, incoming.subspan(fulldata.size()));
|
||||
_incoming.chop(fulldata.size());
|
||||
InvokeQueued(this, [=] { readData(); });
|
||||
} else {
|
||||
_incoming.clear();
|
||||
shiftIncomingBy(fulldata.size());
|
||||
if (!_incoming.isEmpty()) {
|
||||
InvokeQueued(this, [=] {
|
||||
if (!checkNextPacket()) {
|
||||
handleError();
|
||||
}
|
||||
});
|
||||
}
|
||||
_state = State::Ready;
|
||||
_incomingGoodDataOffset = _incomingGoodDataLimit = 0;
|
||||
_state = State::Connected;
|
||||
_connected.fire({});
|
||||
}
|
||||
|
||||
void TlsSocket::readData() {
|
||||
if (!isConnected()) {
|
||||
return;
|
||||
}
|
||||
_incoming.append(_socket.readAll());
|
||||
if (!checkNextPacket()) {
|
||||
handleError();
|
||||
} else if (hasBytesAvailable()) {
|
||||
_readyRead.fire({});
|
||||
}
|
||||
}
|
||||
|
||||
bool TlsSocket::checkNextPacket() {
|
||||
auto offset = 0;
|
||||
const auto incoming = bytes::make_span(_incoming);
|
||||
while (!_incomingGoodDataLimit) {
|
||||
const auto fullHeader = kServerHeader.size() + kLengthSize;
|
||||
if (incoming.size() <= offset + fullHeader) {
|
||||
return true;
|
||||
}
|
||||
if (!CheckPart(incoming.subspan(offset), kServerHeader)) {
|
||||
LOG(("TLS Error: Bad packet header."));
|
||||
return false;
|
||||
}
|
||||
const auto length = ReadPartLength(
|
||||
incoming,
|
||||
offset + kServerHeader.size());
|
||||
if (length > 0) {
|
||||
if (offset > 0) {
|
||||
shiftIncomingBy(offset);
|
||||
}
|
||||
_incomingGoodDataOffset = fullHeader;
|
||||
_incomingGoodDataLimit = length;
|
||||
} else {
|
||||
offset += kServerHeader.size() + kLengthSize + length;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void TlsSocket::shiftIncomingBy(int amount) {
|
||||
Expects(_incomingGoodDataOffset == 0);
|
||||
Expects(_incomingGoodDataLimit == 0);
|
||||
|
||||
const auto incoming = bytes::make_detached_span(_incoming);
|
||||
if (incoming.size() > amount) {
|
||||
bytes::move(incoming, incoming.subspan(amount));
|
||||
_incoming.chop(amount);
|
||||
} else {
|
||||
_incoming.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void TlsSocket::connectToHost(const QString &address, int port) {
|
||||
|
@ -518,19 +572,76 @@ void TlsSocket::connectToHost(const QString &address, int port) {
|
|||
}
|
||||
|
||||
bool TlsSocket::isConnected() {
|
||||
return (_socket.state() == QAbstractSocket::ConnectedState);
|
||||
return (_state == State::Connected);
|
||||
}
|
||||
|
||||
bool TlsSocket::hasBytesAvailable() {
|
||||
return _socket.bytesAvailable();
|
||||
return (_incomingGoodDataLimit > 0)
|
||||
&& (_incomingGoodDataOffset < _incoming.size());
|
||||
}
|
||||
|
||||
int64 TlsSocket::read(char *buffer, int64 maxLength) {
|
||||
return _socket.read(buffer, maxLength);
|
||||
int64 TlsSocket::read(bytes::span buffer) {
|
||||
auto written = int64(0);
|
||||
while (_incomingGoodDataLimit) {
|
||||
const auto available = std::min(
|
||||
_incomingGoodDataLimit,
|
||||
_incoming.size() - _incomingGoodDataOffset);
|
||||
if (available <= 0) {
|
||||
return written;
|
||||
}
|
||||
const auto write = std::min(index_type(available), buffer.size());
|
||||
if (write <= 0) {
|
||||
return written;
|
||||
}
|
||||
bytes::copy(
|
||||
buffer,
|
||||
bytes::make_span(_incoming).subspan(
|
||||
_incomingGoodDataOffset,
|
||||
write));
|
||||
written += write;
|
||||
buffer = buffer.subspan(write);
|
||||
_incomingGoodDataLimit -= write;
|
||||
_incomingGoodDataOffset += write;
|
||||
if (_incomingGoodDataLimit) {
|
||||
return written;
|
||||
}
|
||||
shiftIncomingBy(base::take(_incomingGoodDataOffset));
|
||||
if (!checkNextPacket()) {
|
||||
_state = State::Error;
|
||||
InvokeQueued(this, [=] { handleError(); });
|
||||
return written;
|
||||
}
|
||||
}
|
||||
return written;
|
||||
}
|
||||
|
||||
int64 TlsSocket::write(const char *buffer, int64 length) {
|
||||
return _socket.write(buffer, length);
|
||||
void TlsSocket::write(bytes::const_span prefix, bytes::const_span buffer) {
|
||||
Expects(!buffer.empty());
|
||||
|
||||
if (!isConnected()) {
|
||||
return;
|
||||
}
|
||||
if (!prefix.empty()) {
|
||||
_socket.write(kClientPrefix.data(), kClientPrefix.size());
|
||||
}
|
||||
while (!buffer.empty()) {
|
||||
const auto write = std::min(
|
||||
kClientPartSize - prefix.size(),
|
||||
buffer.size());
|
||||
_socket.write(kClientHeader.data(), kClientHeader.size());
|
||||
const auto size = qToBigEndian(uint16(prefix.size() + write));
|
||||
_socket.write(reinterpret_cast<const char*>(&size), sizeof(size));
|
||||
if (!prefix.empty()) {
|
||||
_socket.write(
|
||||
reinterpret_cast<const char*>(prefix.data()),
|
||||
prefix.size());
|
||||
prefix = bytes::const_span();
|
||||
}
|
||||
_socket.write(
|
||||
reinterpret_cast<const char*>(buffer.data()),
|
||||
write);
|
||||
buffer = buffer.subspan(write);
|
||||
}
|
||||
}
|
||||
|
||||
int32 TlsSocket::debugState() {
|
||||
|
|
|
@ -12,7 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
|
|||
namespace MTP {
|
||||
namespace internal {
|
||||
|
||||
class TlsSocket : public AbstractSocket {
|
||||
class TlsSocket final : public AbstractSocket {
|
||||
public:
|
||||
TlsSocket(
|
||||
not_null<QThread*> thread,
|
||||
|
@ -22,8 +22,8 @@ public:
|
|||
void connectToHost(const QString &address, int port) override;
|
||||
bool isConnected() override;
|
||||
bool hasBytesAvailable() override;
|
||||
int64 read(char *buffer, int64 maxLength) override;
|
||||
int64 write(const char *buffer, int64 length) override;
|
||||
int64 read(bytes::span buffer) override;
|
||||
void write(bytes::const_span prefix, bytes::const_span buffer) override;
|
||||
|
||||
int32 debugState() override;
|
||||
|
||||
|
@ -32,8 +32,7 @@ private:
|
|||
NotConnected,
|
||||
Connecting,
|
||||
WaitingHello,
|
||||
Ready,
|
||||
Working,
|
||||
Connected,
|
||||
Error,
|
||||
};
|
||||
|
||||
|
@ -47,11 +46,15 @@ private:
|
|||
void checkHelloParts34(int parts123Size);
|
||||
void checkHelloDigest();
|
||||
void readData();
|
||||
[[nodiscard]] bool checkNextPacket();
|
||||
void shiftIncomingBy(int amount);
|
||||
|
||||
QTcpSocket _socket;
|
||||
bytes::vector _key;
|
||||
State _state = State::NotConnected;
|
||||
QByteArray _incoming;
|
||||
int _incomingGoodDataOffset = 0;
|
||||
int _incomingGoodDataLimit = 0;
|
||||
int16 _serverHelloLength = 0;
|
||||
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue