Check keys that receive -404 error codes.

This commit is contained in:
John Preston 2019-11-14 16:34:58 +03:00
parent 1524b4a930
commit 7243fb52ad
15 changed files with 460 additions and 159 deletions

View File

@ -167,35 +167,6 @@ T rand_value() {
return result; return result;
} }
class ReadLockerAttempt {
public:
ReadLockerAttempt(not_null<QReadWriteLock*> lock) : _lock(lock), _locked(_lock->tryLockForRead()) {
}
ReadLockerAttempt(const ReadLockerAttempt &other) = delete;
ReadLockerAttempt &operator=(const ReadLockerAttempt &other) = delete;
ReadLockerAttempt(ReadLockerAttempt &&other) : _lock(other._lock), _locked(base::take(other._locked)) {
}
ReadLockerAttempt &operator=(ReadLockerAttempt &&other) {
_lock = other._lock;
_locked = base::take(other._locked);
return *this;
}
~ReadLockerAttempt() {
if (_locked) {
_lock->unlock();
}
}
operator bool() const {
return _locked;
}
private:
not_null<QReadWriteLock*> _lock;
bool _locked = false;
};
static const QRegularExpression::PatternOptions reMultiline(QRegularExpression::DotMatchesEverythingOption | QRegularExpression::MultilineOption); static const QRegularExpression::PatternOptions reMultiline(QRegularExpression::DotMatchesEverythingOption | QRegularExpression::MultilineOption);
template <typename T> template <typename T>

View File

@ -19,6 +19,9 @@ AuthKey::AuthKey(Type type, DcId dcId, const Data &data)
, _dcId(dcId) , _dcId(dcId)
, _key(data) { , _key(data) {
countKeyId(); countKeyId();
if (type == Type::Generated) {
_lastCheckTime = crl::now();
}
} }
AuthKey::AuthKey(const Data &data) : _type(Type::Local), _key(data) { AuthKey::AuthKey(const Data &data) : _type(Type::Local), _key(data) {
@ -111,6 +114,14 @@ bool AuthKey::equals(const std::shared_ptr<AuthKey> &other) const {
return other ? (_key == other->_key) : false; return other ? (_key == other->_key) : false;
} }
crl::time AuthKey::lastCheckTime() const {
return _lastCheckTime;
}
void AuthKey::setLastCheckTime(crl::time time) {
_lastCheckTime = time;
}
void AuthKey::FillData(Data &authKey, bytes::const_span computedAuthKey) { void AuthKey::FillData(Data &authKey, bytes::const_span computedAuthKey) {
auto computedAuthKeySize = computedAuthKey.size(); auto computedAuthKeySize = computedAuthKey.size();
Assert(computedAuthKeySize <= kSize); Assert(computedAuthKeySize <= kSize);

View File

@ -30,18 +30,21 @@ public:
AuthKey(const AuthKey &other) = delete; AuthKey(const AuthKey &other) = delete;
AuthKey &operator=(const AuthKey &other) = delete; AuthKey &operator=(const AuthKey &other) = delete;
Type type() const; [[nodiscard]] Type type() const;
int dcId() const; [[nodiscard]] int dcId() const;
KeyId keyId() const; [[nodiscard]] KeyId keyId() const;
void prepareAES_oldmtp(const MTPint128 &msgKey, MTPint256 &aesKey, MTPint256 &aesIV, bool send) const; void prepareAES_oldmtp(const MTPint128 &msgKey, MTPint256 &aesKey, MTPint256 &aesIV, bool send) const;
void prepareAES(const MTPint128 &msgKey, MTPint256 &aesKey, MTPint256 &aesIV, bool send) const; void prepareAES(const MTPint128 &msgKey, MTPint256 &aesKey, MTPint256 &aesIV, bool send) const;
const void *partForMsgKey(bool send) const; [[nodiscard]] const void *partForMsgKey(bool send) const;
void write(QDataStream &to) const; void write(QDataStream &to) const;
bytes::const_span data() const; [[nodiscard]] bytes::const_span data() const;
bool equals(const std::shared_ptr<AuthKey> &other) const; [[nodiscard]] bool equals(const std::shared_ptr<AuthKey> &other) const;
[[nodiscard]] crl::time lastCheckTime() const;
void setLastCheckTime(crl::time time);
static void FillData(Data &authKey, bytes::const_span computedAuthKey); static void FillData(Data &authKey, bytes::const_span computedAuthKey);
@ -52,6 +55,7 @@ private:
DcId _dcId = 0; DcId _dcId = 0;
Data _key = { { gsl::byte{} } }; Data _key = { { gsl::byte{} } };
KeyId _keyId = 0; KeyId _keyId = 0;
crl::time _lastCheckTime = 0;
}; };

View File

@ -8,6 +8,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "mtproto/connection.h" #include "mtproto/connection.h"
#include "mtproto/details/mtproto_dc_key_creator.h" #include "mtproto/details/mtproto_dc_key_creator.h"
#include "mtproto/details/mtproto_dc_key_checker.h"
#include "mtproto/session.h" #include "mtproto/session.h"
#include "mtproto/rsa_public_key.h" #include "mtproto/rsa_public_key.h"
#include "mtproto/rpc_sender.h" #include "mtproto/rpc_sender.h"
@ -94,7 +95,8 @@ void wrapInvokeAfter(SecureRequest &to, const SecureRequest &from, const Request
} // namespace } // namespace
Connection::Connection(not_null<Instance*> instance) : _instance(instance) { Connection::Connection(not_null<Instance*> instance)
: _instance(instance) {
} }
void Connection::start(SessionData *sessionData, ShiftedDcId shiftedDcId) { void Connection::start(SessionData *sessionData, ShiftedDcId shiftedDcId) {
@ -457,27 +459,31 @@ void ConnectionPrivate::resetSession() { // recreate all msg_id and msg_seqno
emit sessionResetDone(); emit sessionResetDone();
} }
mtpMsgId ConnectionPrivate::prepareToSend(SecureRequest &request, mtpMsgId currentLastId) { mtpMsgId ConnectionPrivate::prepareToSend(
if (request->size() < 9) return 0; SecureRequest &request,
mtpMsgId msgId = *(mtpMsgId*)(request->constData() + 4); mtpMsgId currentLastId) {
if (msgId) { // resending this request if (request->size() < 9) {
return 0;
}
if (const auto msgId = request.getMsgId()) {
// resending this request
QWriteLocker locker(_sessionData->toResendMutex()); QWriteLocker locker(_sessionData->toResendMutex());
auto &toResend = _sessionData->toResendMap(); auto &toResend = _sessionData->toResendMap();
const auto i = toResend.find(msgId); const auto i = toResend.find(msgId);
if (i != toResend.cend()) { if (i != toResend.cend()) {
toResend.erase(i); toResend.erase(i);
} }
} else { return msgId;
msgId = *(mtpMsgId*)(request->data() + 4) = currentLastId;
*(request->data() + 6) = _sessionData->nextRequestSeqNumber(request.needAck());
} }
return msgId; request.setMsgId(currentLastId);
request.setSeqNo(_sessionData->nextRequestSeqNumber(request.needAck()));
return currentLastId;
} }
mtpMsgId ConnectionPrivate::replaceMsgId(SecureRequest &request, mtpMsgId newId) { mtpMsgId ConnectionPrivate::replaceMsgId(SecureRequest &request, mtpMsgId newId) {
if (request->size() < 9) return 0; if (request->size() < 9) return 0;
mtpMsgId oldMsgId = *(mtpMsgId*)(request->constData() + 4); const auto oldMsgId = request.getMsgId();
if (oldMsgId != newId) { if (oldMsgId != newId) {
if (oldMsgId) { if (oldMsgId) {
QWriteLocker locker(_sessionData->toResendMutex()); QWriteLocker locker(_sessionData->toResendMutex());
@ -530,9 +536,9 @@ mtpMsgId ConnectionPrivate::replaceMsgId(SecureRequest &request, mtpMsgId newId)
} }
} }
} else { } else {
*(request->data() + 6) = _sessionData->nextRequestSeqNumber(request.needAck()); request.setSeqNo(_sessionData->nextRequestSeqNumber(request.needAck()));
} }
*(mtpMsgId*)(request->data() + 4) = newId; request.setMsgId(newId);
} }
return newId; return newId;
} }
@ -562,15 +568,25 @@ void ConnectionPrivate::tryToSend() {
auto needsLayer = !_connectionOptions->inited; auto needsLayer = !_connectionOptions->inited;
auto state = getState(); auto state = getState();
auto prependOnly = (state != ConnectedState); auto sendOnlyFirstPing = (state != ConnectedState);
if (sendOnlyFirstPing && !_pingIdToSend) {
DEBUG_LOG(("MTP Info: dc %1 not sending, waiting for Connected state, state: %2").arg(_shiftedDcId).arg(state));
return; // just do nothing, if is not connected yet
}
auto pingRequest = SecureRequest(); auto pingRequest = SecureRequest();
auto ackRequest = SecureRequest();
auto resendRequest = SecureRequest();
auto stateRequest = SecureRequest();
auto httpWaitRequest = SecureRequest();
auto checkDcKeyRequest = SecureRequest();
if (_shiftedDcId == BareDcId(_shiftedDcId)) { // main session if (_shiftedDcId == BareDcId(_shiftedDcId)) { // main session
if (!prependOnly && !_pingIdToSend && !_pingId && _pingSendAt <= crl::now()) { if (!sendOnlyFirstPing && !_pingIdToSend && !_pingId && _pingSendAt <= crl::now()) {
_pingIdToSend = rand_value<mtpPingId>(); _pingIdToSend = rand_value<mtpPingId>();
} }
} }
if (_pingIdToSend) { if (_pingIdToSend) {
if (prependOnly || _shiftedDcId != BareDcId(_shiftedDcId)) { if (sendOnlyFirstPing || _shiftedDcId != BareDcId(_shiftedDcId)) {
pingRequest = SecureRequest::Serialize(MTPPing( pingRequest = SecureRequest::Serialize(MTPPing(
MTP_long(_pingIdToSend) MTP_long(_pingIdToSend)
)); ));
@ -584,44 +600,28 @@ void ConnectionPrivate::tryToSend() {
"ping_id: %1").arg(_pingIdToSend)); "ping_id: %1").arg(_pingIdToSend));
} }
pingRequest->msDate = crl::now(); // > 0 - can send without container
_pingSendAt = pingRequest->msDate + kPingSendAfter; _pingSendAt = pingRequest->msDate + kPingSendAfter;
pingRequest->requestId = 0; // dont add to haveSent / wereAcked maps if (_shiftedDcId == BareDcId(_shiftedDcId) && !sendOnlyFirstPing) { // main session
if (_shiftedDcId == BareDcId(_shiftedDcId) && !prependOnly) { // main session
_pingSender.callOnce(kPingSendAfterForce); _pingSender.callOnce(kPingSendAfterForce);
} }
_pingId = base::take(_pingIdToSend);
_pingId = _pingIdToSend;
_pingIdToSend = 0;
} else { } else {
if (prependOnly) { DEBUG_LOG(("MTP Info: dc %1 trying to send after ping, state: %2").arg(_shiftedDcId).arg(state));
DEBUG_LOG(("MTP Info: dc %1 not sending, waiting for Connected state, state: %2").arg(_shiftedDcId).arg(state)); }
return; // just do nothing, if is not connected yet
} else { if (!sendOnlyFirstPing) {
DEBUG_LOG(("MTP Info: dc %1 trying to send after ping, state: %2").arg(_shiftedDcId).arg(state)); if (!_ackRequestData.isEmpty()) {
ackRequest = SecureRequest::Serialize(MTPMsgsAck(
MTP_msgs_ack(MTP_vector<MTPlong>(
base::take(_ackRequestData)))));
}
if (!_resendRequestData.isEmpty()) {
resendRequest = SecureRequest::Serialize(MTPMsgResendReq(
MTP_msg_resend_req(MTP_vector<MTPlong>(
base::take(_resendRequestData)))));
} }
}
SecureRequest ackRequest, resendRequest, stateRequest, httpWaitRequest; auto stateReq = QVector<MTPlong>();
if (!prependOnly && !_ackRequestData.isEmpty()) {
ackRequest = SecureRequest::Serialize(MTPMsgsAck(
MTP_msgs_ack(MTP_vector<MTPlong>(_ackRequestData))));
ackRequest->msDate = crl::now(); // > 0 - can send without container
ackRequest->requestId = 0; // dont add to haveSent / wereAcked maps
_ackRequestData.clear();
}
if (!prependOnly && !_resendRequestData.isEmpty()) {
resendRequest = SecureRequest::Serialize(MTPMsgResendReq(
MTP_msg_resend_req(MTP_vector<MTPlong>(_resendRequestData))));
resendRequest->msDate = crl::now(); // > 0 - can send without container
resendRequest->requestId = 0; // dont add to haveSent / wereAcked maps
_resendRequestData.clear();
}
if (!prependOnly) {
QVector<MTPlong> stateReq;
{ {
QWriteLocker locker(_sessionData->stateRequestMutex()); QWriteLocker locker(_sessionData->stateRequestMutex());
auto &ids = _sessionData->stateRequestMap(); auto &ids = _sessionData->stateRequestMap();
@ -636,14 +636,30 @@ void ConnectionPrivate::tryToSend() {
if (!stateReq.isEmpty()) { if (!stateReq.isEmpty()) {
stateRequest = SecureRequest::Serialize(MTPMsgsStateReq( stateRequest = SecureRequest::Serialize(MTPMsgsStateReq(
MTP_msgs_state_req(MTP_vector<MTPlong>(stateReq)))); MTP_msgs_state_req(MTP_vector<MTPlong>(stateReq))));
stateRequest->msDate = crl::now(); // > 0 - can send without container // Add to haveSent / wereAcked maps, but don't add to requestMap.
stateRequest->requestId = GetNextRequestId();// add to haveSent / wereAcked maps, but don't add to requestMap stateRequest->requestId = GetNextRequestId();
} }
if (_connection->usingHttpWait()) { if (_connection->usingHttpWait()) {
httpWaitRequest = SecureRequest::Serialize(MTPHttpWait( httpWaitRequest = SecureRequest::Serialize(MTPHttpWait(
MTP_http_wait(MTP_int(100), MTP_int(30), MTP_int(25000)))); MTP_http_wait(MTP_int(100), MTP_int(30), MTP_int(25000))));
httpWaitRequest->msDate = crl::now(); // > 0 - can send without container }
httpWaitRequest->requestId = 0; // dont add to haveSent / wereAcked maps if (!_keyChecker) {
if (const auto &keyForCheck = _sessionData->getKeyForCheck()) {
_keyChecker = std::make_unique<details::DcKeyChecker>(
_instance,
_shiftedDcId,
keyForCheck);
checkDcKeyRequest = _keyChecker->prepareRequest(
_sessionData->getKey(),
_sessionData->getSessionId());
// This is a special request with msgId used inside the message
// body, so it is prepared already with a msgId and we place
// seqNo for it manually here.
checkDcKeyRequest.setSeqNo(
_sessionData->nextRequestSeqNumber(
checkDcKeyRequest.needAck()));
}
} }
} }
@ -698,8 +714,12 @@ void ConnectionPrivate::tryToSend() {
QWriteLocker locker1(_sessionData->toSendMutex()); QWriteLocker locker1(_sessionData->toSendMutex());
auto toSendDummy = PreRequestMap(); auto toSendDummy = PreRequestMap();
auto &toSend = prependOnly ? toSendDummy : _sessionData->toSendMap(); auto &toSend = sendOnlyFirstPing
if (prependOnly) locker1.unlock(); ? toSendDummy
: _sessionData->toSendMap();
if (sendOnlyFirstPing) {
locker1.unlock();
}
uint32 toSendCount = toSend.size(); uint32 toSendCount = toSend.size();
if (pingRequest) ++toSendCount; if (pingRequest) ++toSendCount;
@ -707,13 +727,28 @@ void ConnectionPrivate::tryToSend() {
if (resendRequest) ++toSendCount; if (resendRequest) ++toSendCount;
if (stateRequest) ++toSendCount; if (stateRequest) ++toSendCount;
if (httpWaitRequest) ++toSendCount; if (httpWaitRequest) ++toSendCount;
if (checkDcKeyRequest) ++toSendCount;
if (!toSendCount) return; // nothing to send if (!toSendCount) {
return; // nothing to send
}
auto first = pingRequest ? pingRequest : (ackRequest ? ackRequest : (resendRequest ? resendRequest : (stateRequest ? stateRequest : (httpWaitRequest ? httpWaitRequest : toSend.cbegin().value())))); const auto first = pingRequest
? pingRequest
: ackRequest
? ackRequest
: resendRequest
? resendRequest
: stateRequest
? stateRequest
: httpWaitRequest
? httpWaitRequest
: checkDcKeyRequest
? checkDcKeyRequest
: toSend.cbegin().value();
if (toSendCount == 1 && first->msDate > 0) { // if can send without container if (toSendCount == 1 && first->msDate > 0) { // if can send without container
toSendRequest = first; toSendRequest = first;
if (!prependOnly) { if (!sendOnlyFirstPing) {
toSend.clear(); toSend.clear();
locker1.unlock(); locker1.unlock();
} }
@ -774,6 +809,7 @@ void ConnectionPrivate::tryToSend() {
if (resendRequest) containerSize += resendRequest.messageSize(); if (resendRequest) containerSize += resendRequest.messageSize();
if (stateRequest) containerSize += stateRequest.messageSize(); if (stateRequest) containerSize += stateRequest.messageSize();
if (httpWaitRequest) containerSize += httpWaitRequest.messageSize(); if (httpWaitRequest) containerSize += httpWaitRequest.messageSize();
if (checkDcKeyRequest) containerSize += checkDcKeyRequest.messageSize();
for (auto i = toSend.begin(), e = toSend.end(); i != e; ++i) { for (auto i = toSend.begin(), e = toSend.end(); i != e; ++i) {
containerSize += i.value().messageSize(); containerSize += i.value().messageSize();
if (needsLayer && i.value()->needsLayer) { if (needsLayer && i.value()->needsLayer) {
@ -815,7 +851,7 @@ void ConnectionPrivate::tryToSend() {
if (pingRequest) { if (pingRequest) {
_pingMsgId = placeToContainer(toSendRequest, bigMsgId, haveSentArr, pingRequest); _pingMsgId = placeToContainer(toSendRequest, bigMsgId, haveSentArr, pingRequest);
needAnyResponse = true; needAnyResponse = true;
} else if (resendRequest || stateRequest) { } else if (resendRequest || stateRequest || checkDcKeyRequest) {
needAnyResponse = true; needAnyResponse = true;
} }
for (auto i = toSend.begin(), e = toSend.end(); i != e; ++i) { for (auto i = toSend.begin(), e = toSend.end(); i != e; ++i) {
@ -869,6 +905,7 @@ void ConnectionPrivate::tryToSend() {
if (resendRequest) placeToContainer(toSendRequest, bigMsgId, haveSentArr, resendRequest); if (resendRequest) placeToContainer(toSendRequest, bigMsgId, haveSentArr, resendRequest);
if (ackRequest) placeToContainer(toSendRequest, bigMsgId, haveSentArr, ackRequest); if (ackRequest) placeToContainer(toSendRequest, bigMsgId, haveSentArr, ackRequest);
if (httpWaitRequest) placeToContainer(toSendRequest, bigMsgId, haveSentArr, httpWaitRequest); if (httpWaitRequest) placeToContainer(toSendRequest, bigMsgId, haveSentArr, httpWaitRequest);
if (checkDcKeyRequest) placeToContainer(toSendRequest, bigMsgId, haveSentArr, checkDcKeyRequest);
mtpMsgId contMsgId = prepareToSend(toSendRequest, bigMsgId); mtpMsgId contMsgId = prepareToSend(toSendRequest, bigMsgId);
*(mtpMsgId*)(haveSentIdsWrap->data() + 4) = contMsgId; *(mtpMsgId*)(haveSentIdsWrap->data() + 4) = contMsgId;
@ -1945,6 +1982,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr
} }
} }
if (_keyChecker && _keyChecker->handleResponse(reqMsgId, response)) {
return HandleResult::Success;
}
auto requestId = wasSent(reqMsgId.v); auto requestId = wasSent(reqMsgId.v);
if (requestId && requestId != mtpRequestId(0xFFFFFFFF)) { if (requestId && requestId != mtpRequestId(0xFFFFFFFF)) {
// Save rpc_result for processing in the main thread. // Save rpc_result for processing in the main thread.
@ -2437,7 +2477,8 @@ void ConnectionPrivate::createDcKey() {
DEBUG_LOG(("AuthKey Info: auth key gen succeed, id: %1, server salt: %2").arg(authKey->keyId()).arg(result->serverSalt)); DEBUG_LOG(("AuthKey Info: auth key gen succeed, id: %1, server salt: %2").arg(authKey->keyId()).arg(result->serverSalt));
_sessionData->owner()->notifyKeyCreated(std::move(authKey)); // slot will call authKeyCreated() // slot will call authKeyCreated().
_sessionData->owner()->notifyKeyCreated(std::move(authKey));
_sessionData->clear(_instance); _sessionData->clear(_instance);
unlockKey(); unlockKey();
} else if (result.error() == Error::UnknownPublicKey) { } else if (result.error() == Error::UnknownPublicKey) {
@ -2539,7 +2580,13 @@ bool ConnectionPrivate::sendSecureRequest(
SecureRequest &&request, SecureRequest &&request,
bool needAnyResponse, bool needAnyResponse,
QReadLocker &lockFinished) { QReadLocker &lockFinished) {
request.addPadding(_connection->requiresExtendedPadding()); #ifdef TDESKTOP_MTPROTO_OLD
const auto oldPadding = true;
#else // TDESKTOP_MTPROTO_OLD
const auto oldPadding = false;
#endif // TDESKTOP_MTPROTO_OLD
request.addPadding(_connection->requiresExtendedPadding(), oldPadding);
uint32 fullSize = request->size(); uint32 fullSize = request->size();
if (fullSize < 9) { if (fullSize < 9) {
return false; return false;
@ -2660,14 +2707,18 @@ mtpRequestId ConnectionPrivate::wasSent(mtpMsgId msgId) const {
void ConnectionPrivate::lockKey() { void ConnectionPrivate::lockKey() {
unlockKey(); unlockKey();
_sessionData->keyMutex()->lockForWrite(); if (const auto mutex = _sessionData->keyMutex()) {
mutex->lockForWrite();
}
_myKeyLock = true; _myKeyLock = true;
} }
void ConnectionPrivate::unlockKey() { void ConnectionPrivate::unlockKey() {
if (_myKeyLock) { if (_myKeyLock) {
_myKeyLock = false; _myKeyLock = false;
_sessionData->keyMutex()->unlock(); if (const auto mutex = _sessionData->keyMutex()) {
mutex->unlock();
}
} }
} }
@ -2683,8 +2734,7 @@ void ConnectionPrivate::stop() {
if (_sessionData) { if (_sessionData) {
if (_myKeyLock) { if (_myKeyLock) {
_sessionData->owner()->notifyKeyCreated(AuthKeyPtr()); // release key lock, let someone else create it _sessionData->owner()->notifyKeyCreated(AuthKeyPtr()); // release key lock, let someone else create it
_sessionData->keyMutex()->unlock(); unlockKey();
_myKeyLock = false;
} }
_sessionData = nullptr; _sessionData = nullptr;
} }

View File

@ -263,6 +263,7 @@ private:
bool _myKeyLock = false; bool _myKeyLock = false;
std::unique_ptr<details::DcKeyCreator> _keyCreator; std::unique_ptr<details::DcKeyCreator> _keyCreator;
std::unique_ptr<details::DcKeyChecker> _keyChecker;
}; };

View File

@ -12,12 +12,12 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
namespace MTP { namespace MTP {
namespace { namespace {
uint32 CountPaddingAmountInInts(uint32 requestSize, bool extended) { uint32 CountPaddingPrimesCount(uint32 requestSize, bool extended, bool old) {
#ifdef TDESKTOP_MTPROTO_OLD if (old) {
return ((8 + requestSize) & 0x03) return ((8 + requestSize) & 0x03)
? (4 - ((8 + requestSize) & 0x03)) ? (4 - ((8 + requestSize) & 0x03))
: 0; : 0;
#else // TDESKTOP_MTPROTO_OLD }
auto result = ((8 + requestSize) & 0x03) auto result = ((8 + requestSize) & 0x03)
? (4 - ((8 + requestSize) & 0x03)) ? (4 - ((8 + requestSize) & 0x03))
: 0; : 0;
@ -33,7 +33,6 @@ uint32 CountPaddingAmountInInts(uint32 requestSize, bool extended) {
} }
return result; return result;
#endif // TDESKTOP_MTPROTO_OLD
} }
} // namespace } // namespace
@ -49,6 +48,7 @@ SecureRequest SecureRequest::Prepare(uint32 size, uint32 reserveSize) {
result->reserve(kMessageBodyPosition + finalSize); result->reserve(kMessageBodyPosition + finalSize);
result->resize(kMessageBodyPosition); result->resize(kMessageBodyPosition);
result->back() = (size << 2); result->back() = (size << 2);
result->msDate = crl::now(); // > 0 - can send without container
return result; return result;
} }
@ -68,11 +68,39 @@ SecureRequest::operator bool() const {
return (_data != nullptr); return (_data != nullptr);
} }
void SecureRequest::addPadding(bool extended) { void SecureRequest::setMsgId(mtpMsgId msgId) {
if (_data->size() <= kMessageBodyPosition) return; Expects(_data != nullptr);
memcpy(_data->data() + kMessageIdPosition, &msgId, sizeof(mtpMsgId));
}
mtpMsgId SecureRequest::getMsgId() const {
Expects(_data != nullptr);
return *(mtpMsgId*)(_data->constData() + kMessageIdPosition);
}
void SecureRequest::setSeqNo(uint32 seqNo) {
Expects(_data != nullptr);
(*_data)[kSeqNoPosition] = mtpPrime(seqNo);
}
uint32 SecureRequest::getSeqNo() const {
Expects(_data != nullptr);
return uint32((*_data)[kSeqNoPosition]);
}
void SecureRequest::addPadding(bool extended, bool old) {
Expects(_data != nullptr);
if (_data->size() <= kMessageBodyPosition) {
return;
}
const auto requestSize = (tl::count_length(*this) >> 2); const auto requestSize = (tl::count_length(*this) >> 2);
const auto padding = CountPaddingAmountInInts(requestSize, extended); const auto padding = CountPaddingPrimesCount(requestSize, extended, old);
const auto fullSize = kMessageBodyPosition + requestSize + padding; const auto fullSize = kMessageBodyPosition + requestSize + padding;
if (uint32(_data->size()) != fullSize) { if (uint32(_data->size()) != fullSize) {
_data->resize(fullSize); _data->resize(fullSize);
@ -85,6 +113,8 @@ void SecureRequest::addPadding(bool extended) {
} }
uint32 SecureRequest::messageSize() const { uint32 SecureRequest::messageSize() const {
Expects(_data != nullptr);
if (_data->size() <= kMessageBodyPosition) { if (_data->size() <= kMessageBodyPosition) {
return 0; return 0;
} }
@ -93,13 +123,17 @@ uint32 SecureRequest::messageSize() const {
} }
bool SecureRequest::isSentContainer() const { bool SecureRequest::isSentContainer() const {
Expects(_data != nullptr);
if (_data->size() <= kMessageBodyPosition) { if (_data->size() <= kMessageBodyPosition) {
return false; return false;
} }
return (!_data->msDate && !(*_data)[kSeqNoPosition]); // msDate = 0, seqNo = 0 return (!_data->msDate && !getSeqNo()); // msDate = 0, seqNo = 0
} }
bool SecureRequest::isStateRequest() const { bool SecureRequest::isStateRequest() const {
Expects(_data != nullptr);
if (_data->size() <= kMessageBodyPosition) { if (_data->size() <= kMessageBodyPosition) {
return false; return false;
} }
@ -108,6 +142,8 @@ bool SecureRequest::isStateRequest() const {
} }
bool SecureRequest::needAck() const { bool SecureRequest::needAck() const {
Expects(_data != nullptr);
if (_data->size() <= kMessageBodyPosition) { if (_data->size() <= kMessageBodyPosition) {
return false; return false;
} }

View File

@ -138,9 +138,9 @@ public:
static constexpr auto kSaltInts = 2; static constexpr auto kSaltInts = 2;
static constexpr auto kSessionIdInts = 2; static constexpr auto kSessionIdInts = 2;
static constexpr auto kMessageIdPosition = kSaltInts + kSessionIdInts;
static constexpr auto kMessageIdInts = 2; static constexpr auto kMessageIdInts = 2;
static constexpr auto kSeqNoPosition = kSaltInts static constexpr auto kSeqNoPosition = kMessageIdPosition
+ kSessionIdInts
+ kMessageIdInts; + kMessageIdInts;
static constexpr auto kSeqNoInts = 1; static constexpr auto kSeqNoInts = 1;
static constexpr auto kMessageLengthPosition = kSeqNoPosition static constexpr auto kMessageLengthPosition = kSeqNoPosition
@ -168,13 +168,19 @@ public:
SecureRequestData &operator*() const; SecureRequestData &operator*() const;
explicit operator bool() const; explicit operator bool() const;
void addPadding(bool extended); void setMsgId(mtpMsgId msgId);
uint32 messageSize() const; [[nodiscard]] mtpMsgId getMsgId() const;
void setSeqNo(uint32 seqNo);
[[nodiscard]] uint32 getSeqNo() const;
void addPadding(bool extended, bool old);
[[nodiscard]] uint32 messageSize() const;
// "request-like" wrap for msgIds vector // "request-like" wrap for msgIds vector
bool isSentContainer() const; [[nodiscard]] bool isSentContainer() const;
bool isStateRequest() const; [[nodiscard]] bool isStateRequest() const;
bool needAck() const; [[nodiscard]] bool needAck() const;
using ResponseType = void; // don't know real response type =( using ResponseType = void; // don't know real response type =(

View File

@ -33,6 +33,11 @@ public:
void setConnectionInited(bool connectionInited = true) { void setConnectionInited(bool connectionInited = true) {
QMutexLocker lock(&initLock); QMutexLocker lock(&initLock);
_connectionInited = connectionInited; _connectionInited = connectionInited;
lock.unlock();
if (connectionInited) {
emit connectionWasInited();
}
} }
signals: signals:

View File

@ -8,24 +8,143 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "mtproto/details/mtproto_dc_key_checker.h" #include "mtproto/details/mtproto_dc_key_checker.h"
#include "mtproto/mtp_instance.h" #include "mtproto/mtp_instance.h"
#include "base/unixtime.h"
#include "base/openssl_help.h"
#include "scheme.h"
#include <QtCore/QPointer> #include <QtCore/QPointer>
namespace MTP::details { namespace MTP::details {
namespace {
constexpr auto kBindKeyExpireTimeout = TimeId(3600);
[[nodiscard]] QByteArray EncryptBindAuthKeyInner(
const AuthKeyPtr &persistentKey,
mtpMsgId realMsgId,
const MTPBindAuthKeyInner &data) {
auto serialized = SecureRequest::Serialize(data);
serialized.setMsgId(realMsgId);
serialized.setSeqNo(0);
serialized.addPadding(false, true);
constexpr auto kMsgIdPosition = SecureRequest::kMessageIdPosition;
constexpr auto kMinMessageSize = 5;
const auto sizeInPrimes = serialized->size();
const auto messageSize = serialized.messageSize();
Assert(messageSize >= kMinMessageSize);
Assert(sizeInPrimes >= kMsgIdPosition + messageSize);
const auto sizeInBytes = sizeInPrimes * sizeof(mtpPrime);
const auto padding = sizeInBytes
- (kMsgIdPosition + messageSize) * sizeof(mtpPrime);
// session_id, salt - just random here.
bytes::set_random(bytes::make_span(*serialized).subspan(
0,
kMsgIdPosition * sizeof(mtpPrime)));
const auto hash = openssl::Sha1(bytes::make_span(*serialized).subspan(
0,
sizeInBytes - padding));
auto msgKey = MTPint128();
bytes::copy(
bytes::object_as_span(&msgKey),
bytes::make_span(hash).subspan(4));
constexpr auto kAuthKeyIdBytes = 2 * sizeof(mtpPrime);
constexpr auto kMessageKeyPosition = kAuthKeyIdBytes;
constexpr auto kMessageKeyBytes = 4 * sizeof(mtpPrime);
constexpr auto kPrefix = (kAuthKeyIdBytes + kMessageKeyBytes);
auto encrypted = QByteArray(kPrefix + sizeInBytes, Qt::Uninitialized);
*reinterpret_cast<uint64*>(encrypted.data()) = persistentKey->keyId();
*reinterpret_cast<MTPint128*>(encrypted.data() + kMessageKeyPosition)
= msgKey;
aesIgeEncrypt_oldmtp(
serialized->constData(),
encrypted.data() + kPrefix,
sizeInBytes,
persistentKey,
msgKey);
return encrypted;
}
} // namespace
DcKeyChecker::DcKeyChecker( DcKeyChecker::DcKeyChecker(
not_null<Instance*> instance, not_null<Instance*> instance,
DcId dcId, ShiftedDcId shiftedDcId,
const AuthKeyPtr &key, const AuthKeyPtr &persistentKey)
FnMut<void()> destroyMe)
: _instance(instance) : _instance(instance)
, _dcId(dcId) , _shiftedDcId(shiftedDcId)
, _key(key) , _persistentKey(persistentKey) {
, _destroyMe(std::move(destroyMe)) { }
SecureRequest DcKeyChecker::prepareRequest(
const AuthKeyPtr &temporaryKey,
uint64 sessionId) {
Expects(_requestMsgId == 0);
const auto nonce = openssl::RandomValue<uint64>();
_requestMsgId = base::unixtime::mtproto_msg_id();
auto result = SecureRequest::Serialize(MTPauth_BindTempAuthKey(
MTP_long(_persistentKey->keyId()),
MTP_long(nonce),
MTP_int(kBindKeyExpireTimeout),
MTP_bytes(EncryptBindAuthKeyInner(
_persistentKey,
_requestMsgId,
MTP_bind_auth_key_inner(
MTP_long(nonce),
MTP_long(temporaryKey->keyId()),
MTP_long(_persistentKey->keyId()),
MTP_long(sessionId),
MTP_int(kBindKeyExpireTimeout))))));
result.setMsgId(_requestMsgId);
return result;
}
bool DcKeyChecker::handleResponse(
MTPlong requestMsgId,
const mtpBuffer &response) {
Expects(!response.isEmpty());
if (!_requestMsgId || requestMsgId.v != _requestMsgId) {
return false;
}
const auto destroyed = [&] {
if (response[0] != mtpc_rpc_error) {
return false;
}
auto error = MTPRpcError();
auto from = response.begin();
const auto end = from + response.size();
if (!error.read(from, end)) {
return false;
}
return error.match([&](const MTPDrpc_error &data) {
return (data.verror_code().v == 400)
&& (data.verror_message().v == "ENCRYPTED_MESSAGE_INVALID");
});
}();
const auto instance = _instance;
const auto shiftedDcId = _shiftedDcId;
const auto keyId = _persistentKey->keyId();
_persistentKey->setLastCheckTime(crl::now());
crl::on_main(instance, [=] { crl::on_main(instance, [=] {
auto destroy = std::move(_destroyMe); instance->killSession(shiftedDcId);
destroy(); if (destroyed) {
instance->keyDestroyedOnServer(BareDcId(shiftedDcId), keyId);
}
}); });
_requestMsgId = 0;
return true;
} }
} // namespace MTP::details } // namespace MTP::details

View File

@ -16,19 +16,28 @@ class Instance;
namespace MTP::details { namespace MTP::details {
enum class DcKeyState {
MaybeExisting,
DefinitelyDestroyed,
};
class DcKeyChecker final { class DcKeyChecker final {
public: public:
DcKeyChecker( DcKeyChecker(
not_null<Instance*> instance, not_null<Instance*> instance,
DcId dcId, ShiftedDcId shiftedDcId,
const AuthKeyPtr &key, const AuthKeyPtr &persistentKey);
FnMut<void()> destroyMe);
[[nodiscard]] SecureRequest prepareRequest(
const AuthKeyPtr &temporaryKey,
uint64 sessionId);
bool handleResponse(MTPlong requestMsgId, const mtpBuffer &response);
private: private:
not_null<Instance*> _instance; const not_null<Instance*> _instance;
DcId _dcId = 0; const ShiftedDcId _shiftedDcId = 0;
AuthKeyPtr _key; const AuthKeyPtr _persistentKey;
FnMut<void()> _destroyMe; mtpMsgId _requestMsgId = 0;
}; };

View File

@ -76,12 +76,17 @@ template <typename PQInnerData>
constexpr auto kSkipPrimes = 6; constexpr auto kSkipPrimes = 6;
constexpr auto kMaxPrimes = 65; // 260 bytes constexpr auto kMaxPrimes = 65; // 260 bytes
const auto p_q_inner_size = tl::count_length(data); using BoxedPQInnerData = std::conditional_t<
tl::is_boxed_v<PQInnerData>,
PQInnerData,
tl::boxed<PQInnerData>>;
const auto boxed = BoxedPQInnerData(data);
const auto p_q_inner_size = tl::count_length(boxed);
const auto sizeInPrimes = (p_q_inner_size >> 2) + kSkipPrimes; const auto sizeInPrimes = (p_q_inner_size >> 2) + kSkipPrimes;
if (sizeInPrimes >= kMaxPrimes) { if (sizeInPrimes >= kMaxPrimes) {
auto tmp = mtpBuffer(); auto tmp = mtpBuffer();
tmp.reserve(sizeInPrimes); tmp.reserve(sizeInPrimes);
data.write(tmp); boxed.write(tmp);
LOG(("AuthKey Error: too large data for RSA encrypt, size %1").arg(sizeInPrimes * sizeof(mtpPrime))); LOG(("AuthKey Error: too large data for RSA encrypt, size %1").arg(sizeInPrimes * sizeof(mtpPrime)));
DEBUG_LOG(("AuthKey Error: bad data for RSA encrypt %1").arg(Logs::mb(&tmp[0], tmp.size() * 4).str())); DEBUG_LOG(("AuthKey Error: bad data for RSA encrypt %1").arg(Logs::mb(&tmp[0], tmp.size() * 4).str()));
return {}; // can't be 255-byte string return {}; // can't be 255-byte string
@ -90,7 +95,7 @@ template <typename PQInnerData>
auto encBuffer = mtpBuffer(); auto encBuffer = mtpBuffer();
encBuffer.reserve(kMaxPrimes); encBuffer.reserve(kMaxPrimes);
encBuffer.resize(kSkipPrimes); encBuffer.resize(kSkipPrimes);
data.write(encBuffer); boxed.write(encBuffer);
encBuffer.resize(kMaxPrimes); encBuffer.resize(kMaxPrimes);
const auto bytes = bytes::make_span(encBuffer); const auto bytes = bytes::make_span(encBuffer);

View File

@ -7,7 +7,6 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/ */
#include "mtproto/mtp_instance.h" #include "mtproto/mtp_instance.h"
#include "mtproto/details/mtproto_dc_key_checker.h"
#include "mtproto/session.h" #include "mtproto/session.h"
#include "mtproto/dc_options.h" #include "mtproto/dc_options.h"
#include "mtproto/dcenter.h" #include "mtproto/dcenter.h"
@ -34,6 +33,7 @@ namespace {
constexpr auto kConfigBecomesOldIn = 2 * 60 * crl::time(1000); constexpr auto kConfigBecomesOldIn = 2 * 60 * crl::time(1000);
constexpr auto kConfigBecomesOldForBlockedIn = 8 * crl::time(1000); constexpr auto kConfigBecomesOldForBlockedIn = 8 * crl::time(1000);
constexpr auto kCheckKeyEach = 60 * crl::time(1000);
} // namespace } // namespace
@ -134,6 +134,7 @@ public:
void performKeyDestroy(ShiftedDcId shiftedDcId); void performKeyDestroy(ShiftedDcId shiftedDcId);
void completedKeyDestroy(ShiftedDcId shiftedDcId); void completedKeyDestroy(ShiftedDcId shiftedDcId);
void checkMainDcKey(); void checkMainDcKey();
void keyDestroyedOnServer(DcId dcId, uint64 keyId);
void clearKilledSessions(); void clearKilledSessions();
void prepareToDestroy(); void prepareToDestroy();
@ -228,8 +229,6 @@ private:
base::Timer _checkDelayedTimer; base::Timer _checkDelayedTimer;
std::unique_ptr<details::DcKeyChecker> _mainDcKeyChecker;
// Debug flag to find out how we end up crashing. // Debug flag to find out how we end up crashing.
bool MustNotCreateSessions = false; bool MustNotCreateSessions = false;
@ -1509,10 +1508,11 @@ void Instance::Private::completedKeyDestroy(ShiftedDcId shiftedDcId) {
} }
void Instance::Private::checkMainDcKey() { void Instance::Private::checkMainDcKey() {
if (_mainDcKeyChecker) { const auto id = mainDcId();
const auto shiftedDcId = ShiftDcId(id, kCheckKeyDcShift);
if (_sessions.find(shiftedDcId) != _sessions.end()) {
return; return;
} }
const auto id = mainDcId();
const auto key = [&] { const auto key = [&] {
QReadLocker lock(&_keysForWriteLock); QReadLocker lock(&_keysForWriteLock);
const auto i = _keysForWrite.find(id); const auto i = _keysForWrite.find(id);
@ -1521,11 +1521,26 @@ void Instance::Private::checkMainDcKey() {
if (!key) { if (!key) {
return; return;
} }
_mainDcKeyChecker = std::make_unique<details::DcKeyChecker>( const auto lastCheckTime = key->lastCheckTime();
_instance, if (lastCheckTime > 0 && lastCheckTime + kCheckKeyEach >= crl::now()) {
id, return;
key, }
[=] { _mainDcKeyChecker = nullptr; }); _instance->sendDcKeyCheck(shiftedDcId, key);
}
void Instance::Private::keyDestroyedOnServer(DcId dcId, uint64 keyId) {
if (dcId == _mainDcId) {
for (const auto &[id, dc] : _dcenters) {
dc->destroyKey();
}
restart();
} else {
const auto i = _dcenters.find(dcId);
if (i != end(_dcenters)) {
i->second->destroyKey();
}
restart(dcId);
}
} }
void Instance::Private::setUpdatesHandler(RPCDoneHandlerPtr onDone) { void Instance::Private::setUpdatesHandler(RPCDoneHandlerPtr onDone) {
@ -1782,6 +1797,10 @@ void Instance::checkIfKeyWasDestroyed(ShiftedDcId shiftedDcId) {
}); });
} }
void Instance::keyDestroyedOnServer(DcId dcId, uint64 keyId) {
_private->keyDestroyedOnServer(dcId, keyId);
}
void Instance::sendRequest( void Instance::sendRequest(
mtpRequestId requestId, mtpRequestId requestId,
SecureRequest &&request, SecureRequest &&request,
@ -1805,6 +1824,11 @@ void Instance::sendAnything(ShiftedDcId shiftedDcId, crl::time msCanWait) {
session->sendAnything(msCanWait); session->sendAnything(msCanWait);
} }
void Instance::sendDcKeyCheck(ShiftedDcId shiftedDcId, const AuthKeyPtr &key) {
const auto session = _private->getSession(shiftedDcId);
session->sendDcKeyCheck(key);
}
Instance::~Instance() { Instance::~Instance() {
_private->prepareToDestroy(); _private->prepareToDestroy();
} }

View File

@ -135,6 +135,7 @@ public:
} }
void sendAnything(ShiftedDcId shiftedDcId = 0, crl::time msCanWait = 0); void sendAnything(ShiftedDcId shiftedDcId = 0, crl::time msCanWait = 0);
void sendDcKeyCheck(ShiftedDcId shiftedDcId, const AuthKeyPtr &key);
void restart(); void restart();
void restart(ShiftedDcId shiftedDcId); void restart(ShiftedDcId shiftedDcId);
@ -174,6 +175,7 @@ public:
bool isKeysDestroyer() const; bool isKeysDestroyer() const;
void scheduleKeyDestroy(ShiftedDcId shiftedDcId); void scheduleKeyDestroy(ShiftedDcId shiftedDcId);
void checkIfKeyWasDestroyed(ShiftedDcId shiftedDcId); void checkIfKeyWasDestroyed(ShiftedDcId shiftedDcId);
void keyDestroyedOnServer(DcId dcId, uint64 keyId);
void requestConfig(); void requestConfig();
void requestConfigIfOld(); void requestConfigIfOld();

View File

@ -7,6 +7,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/ */
#include "mtproto/session.h" #include "mtproto/session.h"
#include "mtproto/details/mtproto_dc_key_checker.h"
#include "mtproto/connection.h" #include "mtproto/connection.h"
#include "mtproto/dcenter.h" #include "mtproto/dcenter.h"
#include "mtproto/auth_key.h" #include "mtproto/auth_key.h"
@ -77,6 +78,10 @@ void SessionData::setKey(const AuthKeyPtr &key) {
} }
} }
void SessionData::setKeyForCheck(const AuthKeyPtr &key) {
_dcKeyForCheck = key;
}
void SessionData::notifyConnectionInited(const ConnectionOptions &options) { void SessionData::notifyConnectionInited(const ConnectionOptions &options) {
QWriteLocker locker(&_lock); QWriteLocker locker(&_lock);
if (options.cloudLangCode == _options.cloudLangCode if (options.cloudLangCode == _options.cloudLangCode
@ -157,7 +162,7 @@ void Session::start() {
} }
void Session::createDcData() { void Session::createDcData() {
if (_dc) { if (_dc || GetDcIdShift(_shiftedDcId) == kCheckKeyDcShift) {
return; return;
} }
_dc = _instance->getDcById(_shiftedDcId); _dc = _instance->getDcById(_shiftedDcId);
@ -212,7 +217,9 @@ void Session::refreshOptions() {
} }
void Session::reInitConnection() { void Session::reInitConnection() {
_dc->setConnectionInited(false); if (_dc) {
_dc->setConnectionInited(false);
}
_data.setConnectionInited(false); _data.setConnectionInited(false);
restart(); restart();
} }
@ -242,6 +249,11 @@ void Session::unpaused() {
} }
} }
void Session::sendDcKeyCheck(const AuthKeyPtr &key) {
_data.setKeyForCheck(key);
needToResumeAndSend();
}
void Session::sendAnything(qint64 msCanWait) { void Session::sendAnything(qint64 msCanWait) {
if (_killed) { if (_killed) {
DEBUG_LOG(("Session Error: can't send anything in a killed session")); DEBUG_LOG(("Session Error: can't send anything in a killed session"));
@ -550,10 +562,12 @@ void Session::sendPrepared(
} }
QReadWriteLock *Session::keyMutex() const { QReadWriteLock *Session::keyMutex() const {
return _dc->keyMutex(); return _dc ? _dc->keyMutex() : nullptr;
} }
void Session::authKeyCreatedForDC() { void Session::authKeyCreatedForDC() {
Expects(_dc != nullptr);
DEBUG_LOG(("AuthKey Info: Session::authKeyCreatedForDC slot, emitting authKeyCreated(), dcWithShift %1").arg(_shiftedDcId)); DEBUG_LOG(("AuthKey Info: Session::authKeyCreatedForDC slot, emitting authKeyCreated(), dcWithShift %1").arg(_shiftedDcId));
_data.setKey(_dc->getKey()); _data.setKey(_dc->getKey());
emit authKeyCreated(); emit authKeyCreated();
@ -561,31 +575,37 @@ void Session::authKeyCreatedForDC() {
void Session::notifyKeyCreated(AuthKeyPtr &&key) { void Session::notifyKeyCreated(AuthKeyPtr &&key) {
DEBUG_LOG(("AuthKey Info: Session::keyCreated(), setting, dcWithShift %1").arg(_shiftedDcId)); DEBUG_LOG(("AuthKey Info: Session::keyCreated(), setting, dcWithShift %1").arg(_shiftedDcId));
_dc->setKey(std::move(key)); if (_dc) {
_dc->setKey(std::move(key));
} else {
_data.setKey(std::move(key));
emit authKeyCreated();
}
} }
void Session::connectionWasInitedForDC() { void Session::connectionWasInitedForDC() {
Expects(_dc != nullptr);
DEBUG_LOG(("MTP Info: Session::connectionWasInitedForDC slot, dcWithShift %1").arg(_shiftedDcId)); DEBUG_LOG(("MTP Info: Session::connectionWasInitedForDC slot, dcWithShift %1").arg(_shiftedDcId));
_data.setConnectionInited(); _data.setConnectionInited();
} }
void Session::notifyDcConnectionInited() { void Session::notifyDcConnectionInited() {
DEBUG_LOG(("MTP Info: emitting MTProtoDC::connectionWasInited(), dcWithShift %1").arg(_shiftedDcId)); DEBUG_LOG(("MTP Info: emitting MTProtoDC::connectionWasInited(), dcWithShift %1").arg(_shiftedDcId));
_dc->setConnectionInited(); if (_dc) {
emit _dc->connectionWasInited(); _dc->setConnectionInited();
} else {
_data.setConnectionInited();
}
} }
void Session::destroyKey() { void Session::destroyKey() {
if (!_dc) { if (const auto key = _data.getKey()) {
return;
}
if (_data.getKey()) {
DEBUG_LOG(("MTP Info: destroying auth_key for dcWithShift %1").arg(_shiftedDcId)); DEBUG_LOG(("MTP Info: destroying auth_key for dcWithShift %1").arg(_shiftedDcId));
if (_data.getKey() == _dc->getKey()) { if (_dc && _dc->getKey() == key) {
_dc->destroyKey(); _dc->destroyKey();
} }
_data.setKey(AuthKeyPtr()); _data.setKey(nullptr);
} }
} }

View File

@ -184,6 +184,11 @@ public:
} }
void setKey(const AuthKeyPtr &key); void setKey(const AuthKeyPtr &key);
const AuthKeyPtr &getKeyForCheck() const {
return _dcKeyForCheck;
}
void setKeyForCheck(const AuthKeyPtr &key);
bool isCheckedKey() const { bool isCheckedKey() const {
QReadLocker locker(&_lock); QReadLocker locker(&_lock);
return _keyChecked; return _keyChecked;
@ -193,7 +198,7 @@ public:
_keyChecked = checked; _keyChecked = checked;
} }
not_null<QReadWriteLock*> keyMutex() const; QReadWriteLock *keyMutex() const;
not_null<QReadWriteLock*> toSendMutex() const { not_null<QReadWriteLock*> toSendMutex() const {
return &_toSendLock; return &_toSendLock;
@ -291,6 +296,7 @@ private:
not_null<Session*> _owner; not_null<Session*> _owner;
AuthKeyPtr _authKey; AuthKeyPtr _authKey;
AuthKeyPtr _dcKeyForCheck;
bool _keyChecked = false; bool _keyChecked = false;
bool _layerInited = false; bool _layerInited = false;
ConnectionOptions _options; ConnectionOptions _options;
@ -345,6 +351,8 @@ public:
int32 getState() const; int32 getState() const;
QString transport() const; QString transport() const;
void sendDcKeyCheck(const AuthKeyPtr &key);
// Nulls msgId and seqNo in request, if newRequest = true. // Nulls msgId and seqNo in request, if newRequest = true.
void sendPrepared( void sendPrepared(
const SecureRequest &request, const SecureRequest &request,
@ -393,6 +401,7 @@ private:
ShiftedDcId _shiftedDcId = 0; ShiftedDcId _shiftedDcId = 0;
std::shared_ptr<Dcenter> _dc; std::shared_ptr<Dcenter> _dc;
AuthKeyPtr _dcKeyForCheck;
crl::time _msSendCall = 0; crl::time _msSendCall = 0;
crl::time _msWait = 0; crl::time _msWait = 0;
@ -404,9 +413,38 @@ private:
}; };
inline not_null<QReadWriteLock*> SessionData::keyMutex() const { inline QReadWriteLock *SessionData::keyMutex() const {
return _owner->keyMutex(); return _owner->keyMutex();
} }
class ReadLockerAttempt {
public:
ReadLockerAttempt(QReadWriteLock *lock) : _lock(lock), _locked(_lock ? _lock->tryLockForRead() : true) {
}
ReadLockerAttempt(const ReadLockerAttempt &other) = delete;
ReadLockerAttempt &operator=(const ReadLockerAttempt &other) = delete;
ReadLockerAttempt(ReadLockerAttempt &&other) : _lock(other._lock), _locked(base::take(other._locked)) {
}
ReadLockerAttempt &operator=(ReadLockerAttempt &&other) {
_lock = other._lock;
_locked = base::take(other._locked);
return *this;
}
~ReadLockerAttempt() {
if (_lock && _locked) {
_lock->unlock();
}
}
operator bool() const {
return _locked;
}
private:
QReadWriteLock *_lock = nullptr;
bool _locked = false;
};
} // namespace internal } // namespace internal
} // namespace MTP } // namespace MTP