diff --git a/Telegram/SourceFiles/mtproto/connection.cpp b/Telegram/SourceFiles/mtproto/connection.cpp index 7bcec5115..2a3072383 100644 --- a/Telegram/SourceFiles/mtproto/connection.cpp +++ b/Telegram/SourceFiles/mtproto/connection.cpp @@ -13,6 +13,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL #include "mtproto/mtproto_rsa_public_key.h" #include "mtproto/mtproto_rpc_sender.h" #include "mtproto/dc_options.h" +#include "mtproto/dcenter.h" #include "mtproto/connection_abstract.h" #include "zlib.h" #include "core/application.h" @@ -78,6 +79,12 @@ using namespace details; return idsStr + "]"; } +[[nodiscard]] TemporaryKeyType TemporaryKeyTypeByDcType(DcType type) { + return (type == DcType::MediaCluster) + ? TemporaryKeyType::MediaCluster + : TemporaryKeyType::Regular; +} + void wrapInvokeAfter(SecureRequest &to, const SecureRequest &from, const RequestMap &haveSent, int32 skipBeforeRequest = 0) { const auto afterId = *(mtpMsgId*)(from->after->data() + 4); const auto i = afterId ? haveSent.constFind(afterId) : haveSent.cend(); @@ -164,7 +171,7 @@ void ConnectionPrivate::appendTestConnection( const QString &ip, int port, const bytes::vector &protocolSecret) { - QWriteLocker lock(&stateConnMutex); + QWriteLocker lock(&_stateMutex); const auto priority = (qthelp::is_ipv6(ip) ? 0 : 1) + (protocol == DcOptions::Variants::Tcp ? 1 : 0) @@ -216,7 +223,7 @@ int16 ConnectionPrivate::getProtocolDcId() const { const auto testedDcId = cTestMode() ? (kTestModeDcIdShift + simpleDcId) : simpleDcId; - return (_dcType == DcType::MediaDownload) + return (_currentDcType == DcType::MediaCluster) ? -testedDcId : testedDcId; } @@ -279,9 +286,11 @@ ConnectionPrivate::ConnectionPrivate( ShiftedDcId shiftedDcId) : QObject(nullptr) , _instance(instance) -, _state(DisconnectedState) -, _shiftedDcId(shiftedDcId) , _owner(owner) +, _shiftedDcId(shiftedDcId) +, _realDcType(_instance->dcOptions()->dcType(_shiftedDcId)) +, _currentDcType(_realDcType) +, _state(DisconnectedState) , _retryTimer(thread, [=] { retryByTimer(); }) , _oldConnectionTimer(thread, [=] { markConnectionOld(); }) , _waitForConnectedTimer(thread, [=] { waitConnectedFailed(); }) @@ -329,7 +338,7 @@ int32 ConnectionPrivate::getShiftedDcId() const { } int32 ConnectionPrivate::getState() const { - QReadLocker lock(&stateConnMutex); + QReadLocker lock(&_stateMutex); int32 result = _state; if (_state < 0) { if (_retryTimer.isActive()) { @@ -343,7 +352,7 @@ int32 ConnectionPrivate::getState() const { } QString ConnectionPrivate::transport() const { - QReadLocker lock(&stateConnMutex); + QReadLocker lock(&_stateMutex); if (!_connection || (_state < 0)) { return QString(); } @@ -354,13 +363,13 @@ QString ConnectionPrivate::transport() const { bool ConnectionPrivate::setState(int32 state, int32 ifState) { if (ifState != Connection::UpdateAlways) { - QReadLocker lock(&stateConnMutex); + QReadLocker lock(&_stateMutex); if (_state != ifState) { return false; } } - QWriteLocker lock(&stateConnMutex); + QWriteLocker lock(&_stateMutex); if (_state == state) { return false; } @@ -409,6 +418,15 @@ uint32 ConnectionPrivate::nextRequestSeqNumber(bool needAck) { return result * 2 + (needAck ? 1 : 0); } +bool ConnectionPrivate::realDcTypeChanged() { + const auto now = _instance->dcOptions()->dcType(_shiftedDcId); + if (_realDcType == now) { + return false; + } + _realDcType = now; + return true; +} + bool ConnectionPrivate::markSessionAsStarted() { if (_sessionMarkedAsStarted) { return false; @@ -643,10 +661,10 @@ void ConnectionPrivate::tryToSend() { const auto systemLangCode = _connectionOptions->systemLangCode; const auto cloudLangCode = _connectionOptions->cloudLangCode; const auto langPackName = _connectionOptions->langPackName; - const auto deviceModel = (_dcType == DcType::Cdn) + const auto deviceModel = (_currentDcType == DcType::Cdn) ? "n/a" : _instance->deviceModel(); - const auto systemVersion = (_dcType == DcType::Cdn) + const auto systemVersion = (_currentDcType == DcType::Cdn) ? "n/a" : _instance->systemVersion(); #if defined OS_MAC_STORE @@ -928,40 +946,38 @@ void ConnectionPrivate::connectToServer(bool afterConfig) { DEBUG_LOG(("MTP Error: " "connectToServer() called for finished connection!")); return; + } else if (afterConfig && (!_testConnections.empty() || _connection)) { + return; + } + + destroyAllConnections(); + + if (realDcTypeChanged() && _keyCreator) { + destroyTemporaryKey(); + return; } _connectionOptions = std::make_unique( _sessionData->connectionOptions()); - tryAcquireKeyCreation(); - const auto bareDc = BareDcId(_shiftedDcId); - _dcType = _instance->dcOptions()->dcType(_shiftedDcId); - // Use media_only addresses only if key for this dc is already created. - if (_dcType == DcType::MediaDownload && _keyCreator) { - _dcType = DcType::Regular; - } else if (_dcType == DcType::Cdn && !_instance->isKeysDestroyer()) { + _currentDcType = tryAcquireKeyCreation(); + if (_currentDcType == DcType::Cdn && !_instance->isKeysDestroyer()) { if (!_instance->dcOptions()->hasCDNKeysForDc(bareDc)) { requestCDNConfig(); return; } } - - if (afterConfig && (!_testConnections.empty() || _connection)) { - return; - } - - destroyAllConnections(); if (_connectionOptions->proxy.type == ProxyData::Type::Mtproto) { // host, port, secret for mtproto proxy are taken from proxy. appendTestConnection(DcOptions::Variants::Tcp, {}, 0, {}); } else { using Variants = DcOptions::Variants; - const auto special = (_dcType == DcType::Temporary); + const auto special = (_currentDcType == DcType::Temporary); const auto variants = _instance->dcOptions()->lookup( bareDc, - _dcType, + _currentDcType, _connectionOptions->proxy.type != ProxyData::Type::None); const auto useIPv4 = special ? true : _connectionOptions->useIPv4; const auto useIPv6 = special ? false : _connectionOptions->useIPv6; @@ -1351,7 +1367,11 @@ void ConnectionPrivate::handleReceived() { auto from = decryptedInts + kEncryptedHeaderIntsCount; auto end = from + (messageLength / kIntSize); auto sfrom = decryptedInts + 4U; // msg_id + seq_no + length + message - MTP_LOG(_shiftedDcId, ("Recv: ") + details::DumpToText(sfrom, end) + QString(" (keyId:%1)").arg(_encryptionKey->keyId())); + MTP_LOG(_shiftedDcId, ("Recv: ") + + details::DumpToText(sfrom, end) + + QString(" (protocolDcId:%1,key:%2)" + ).arg(getProtocolDcId() + ).arg(_encryptionKey->keyId())); if (_receivedMessageIds.registerMsgId(msgId, needAck)) { res = handleOneReceived(from, end, msgId, serverTime, serverSalt, badTime); @@ -1860,9 +1880,11 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr response); switch (result) { case DcKeyBindState::Success: - _sessionData->releaseKeyCreationOnDone( + if (!_sessionData->releaseKeyCreationOnDone( _encryptionKey, - base::take(_keyCreator)->bindPersistentKey()); + base::take(_keyCreator)->bindPersistentKey())) { + return HandleResult::DestroyTemporaryKey; + } _sessionData->queueNeedToResumeAndSend(); return HandleResult::Success; case DcKeyBindState::DefinitelyDestroyed: @@ -1967,7 +1989,7 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr return HandleResult::ResetSession; } - if (_dcType == DcType::Regular) { + if (_currentDcType == DcType::Regular) { mtpBuffer update(end - from); if (end > from) memcpy(update.data(), from, (end - from) * sizeof(mtpPrime)); @@ -1986,7 +2008,8 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr LOG(("Message Error: unknown constructor 0x%1").arg(cons, 0, 16)); } } else { - LOG(("Message Error: unexpected updates in dcType: %1").arg(static_cast(_dcType))); + LOG(("Message Error: unexpected updates in dcType: %1" + ).arg(static_cast(_currentDcType))); } return HandleResult::Success; @@ -2311,17 +2334,27 @@ void ConnectionPrivate::checkAuthKey() { } else if (_instance->isKeysDestroyer()) { applyAuthKey(_sessionData->getPersistentKey()); } else { - applyAuthKey(_sessionData->getTemporaryKey()); + applyAuthKey(_sessionData->getTemporaryKey( + TemporaryKeyTypeByDcType(_currentDcType))); } } void ConnectionPrivate::updateAuthKey() { - if (_instance->isKeysDestroyer() || _keyCreator) { + if (_instance->isKeysDestroyer() || _keyCreator || !_connection) { return; } - DEBUG_LOG(("AuthKey Info: Connection updating key from Session, dc %1").arg(_shiftedDcId)); - applyAuthKey(_sessionData->getTemporaryKey()); + DEBUG_LOG(("AuthKey Info: Connection updating key from Session, dc %1" + ).arg(_shiftedDcId)); + const auto myKeyType = TemporaryKeyTypeByDcType(_currentDcType); + applyAuthKey(_sessionData->getTemporaryKey(myKeyType)); + + if (_connection + && !_encryptionKey + && myKeyType == TemporaryKeyType::MediaCluster + && _sessionData->getTemporaryKey(TemporaryKeyType::Regular)) { + restart(); + } } void ConnectionPrivate::setCurrentKeyId(uint64 newKeyId) { @@ -2395,11 +2428,17 @@ bool ConnectionPrivate::destroyOldEnoughPersistentKey() { return true; } -void ConnectionPrivate::tryAcquireKeyCreation() { - if (_instance->isKeysDestroyer() - || _keyCreator - || !_sessionData->acquireKeyCreation()) { - return; +DcType ConnectionPrivate::tryAcquireKeyCreation() { + if (_keyCreator) { + return _currentDcType; + } else if (_instance->isKeysDestroyer()) { + return _realDcType; + } + + const auto keyType = TemporaryKeyTypeByDcType(_realDcType); + const auto acquired = _sessionData->acquireKeyCreation(keyType); + if (acquired == CreatingKeyType::None) { + return _realDcType; } using Result = DcKeyResult; @@ -2409,7 +2448,7 @@ void ConnectionPrivate::tryAcquireKeyCreation() { if (!result) { releaseKeyCreationOnFail(); if (result.error() == Error::UnknownPublicKey) { - if (_dcType == DcType::Cdn) { + if (_realDcType == DcType::Cdn) { LOG(("Warning: CDN public RSA key not found")); requestCDNConfig(); return; @@ -2457,11 +2496,14 @@ void ConnectionPrivate::tryAcquireKeyCreation() { }; auto request = DcKeyRequest(); - request.persistentNeeded = !_sessionData->getPersistentKey(); + request.persistentNeeded = (acquired == CreatingKeyType::Persistent); request.temporaryExpiresIn = kTemporaryExpiresIn; _keyCreator = std::make_unique( request, std::move(delegate)); + const auto forceUseRegular = (_realDcType == DcType::MediaCluster) + && (acquired != CreatingKeyType::TemporaryMediaCluster); + return forceUseRegular ? DcType::Regular : _realDcType; } void ConnectionPrivate::authKeyChecked() { @@ -2548,7 +2590,11 @@ bool ConnectionPrivate::sendSecureRequest( memcpy(request->data() + 2, &_sessionId, 2 * sizeof(mtpPrime)); auto from = request->constData() + 4; - MTP_LOG(_shiftedDcId, ("Send: ") + details::DumpToText(from, from + messageSize) + QString(" (keyId:%1)").arg(_encryptionKey->keyId())); + MTP_LOG(_shiftedDcId, ("Send: ") + + details::DumpToText(from, from + messageSize) + + QString(" (protocolDcId:%1,key:%2)" + ).arg(getProtocolDcId() + ).arg(_encryptionKey->keyId())); #ifdef TDESKTOP_MTPROTO_OLD uint32 padding = fullSize - 4 - messageSize; diff --git a/Telegram/SourceFiles/mtproto/connection.h b/Telegram/SourceFiles/mtproto/connection.h index fa8ec88cd..b53a827ad 100644 --- a/Telegram/SourceFiles/mtproto/connection.h +++ b/Telegram/SourceFiles/mtproto/connection.h @@ -181,7 +181,7 @@ private: crl::time msCanWait = 0, bool forceContainer = false); - void tryAcquireKeyCreation(); + [[nodiscard]] DcType tryAcquireKeyCreation(); void resetSession(); void checkAuthKey(); void authKeyChecked(); @@ -193,20 +193,22 @@ private: void setCurrentKeyId(uint64 newKeyId); void changeSessionId(); - void setSessionSalt(uint64 salt); [[nodiscard]] bool markSessionAsStarted(); [[nodiscard]] uint32 nextRequestSeqNumber(bool needAck); - const not_null _instance; - DcType _dcType = DcType::Regular; + [[nodiscard]] bool realDcTypeChanged(); - mutable QReadWriteLock stateConnMutex; + const not_null _instance; + const not_null _owner; + const ShiftedDcId _shiftedDcId = 0; + DcType _realDcType = DcType(); + DcType _currentDcType = DcType(); + + mutable QReadWriteLock _stateMutex; int32 _state = DisconnectedState; bool _needSessionReset = false; - ShiftedDcId _shiftedDcId = 0; - not_null _owner; ConnectionPointer _connection; std::vector _testConnections; crl::time _startedConnectingAt = 0; diff --git a/Telegram/SourceFiles/mtproto/dc_options.cpp b/Telegram/SourceFiles/mtproto/dc_options.cpp index a072048d4..7e22e4c5b 100644 --- a/Telegram/SourceFiles/mtproto/dc_options.cpp +++ b/Telegram/SourceFiles/mtproto/dc_options.cpp @@ -540,8 +540,9 @@ DcType DcOptions::dcType(ShiftedDcId shiftedDcId) const { if (_cdnDcIds.find(BareDcId(shiftedDcId)) != _cdnDcIds.cend()) { return DcType::Cdn; } - if (isDownloadDcId(shiftedDcId)) { - return DcType::MediaDownload; + const auto dcId = BareDcId(shiftedDcId); + if (isDownloadDcId(shiftedDcId) && hasMediaOnlyOptionsFor(dcId)) { + return DcType::MediaCluster; } return DcType::Regular; } @@ -599,40 +600,54 @@ auto DcOptions::lookup( bool throughProxy) const -> Variants { using Flag = Flag; auto result = Variants(); - { - ReadLocker lock(this); - const auto i = _data.find(dcId); - if (i == end(_data)) { - return result; + + ReadLocker lock(this); + const auto i = _data.find(dcId); + if (i == end(_data)) { + return result; + } + for (const auto &endpoint : i->second) { + const auto flags = endpoint.flags; + if (type == DcType::Cdn && !(flags & Flag::f_cdn)) { + continue; + } else if (type != DcType::MediaCluster + && (flags & Flag::f_media_only)) { + continue; + } else if (!ValidateSecret(endpoint.secret)) { + continue; } - for (const auto &endpoint : i->second) { - const auto flags = endpoint.flags; - if (type == DcType::Cdn && !(flags & Flag::f_cdn)) { - continue; - } else if (type != DcType::MediaDownload - && (flags & Flag::f_media_only)) { - continue; - } else if (!ValidateSecret(endpoint.secret)) { - continue; - } - const auto address = (flags & Flag::f_ipv6) - ? Variants::IPv6 - : Variants::IPv4; - result.data[address][Variants::Tcp].push_back(endpoint); - if (!(flags & (Flag::f_tcpo_only | Flag::f_secret))) { - result.data[address][Variants::Http].push_back(endpoint); - } - } - if (type == DcType::MediaDownload) { - FilterIfHasWithFlag(result, Flag::f_media_only); - } - if (throughProxy) { - FilterIfHasWithFlag(result, Flag::f_static); + const auto address = (flags & Flag::f_ipv6) + ? Variants::IPv6 + : Variants::IPv4; + result.data[address][Variants::Tcp].push_back(endpoint); + if (!(flags & (Flag::f_tcpo_only | Flag::f_secret))) { + result.data[address][Variants::Http].push_back(endpoint); } } + if (type == DcType::MediaCluster) { + FilterIfHasWithFlag(result, Flag::f_media_only); + } + if (throughProxy) { + FilterIfHasWithFlag(result, Flag::f_static); + } return result; } +bool DcOptions::hasMediaOnlyOptionsFor(DcId dcId) const { + ReadLocker lock(this); + const auto i = _data.find(dcId); + if (i == end(_data)) { + return false; + } + for (const auto &endpoint : i->second) { + const auto flags = endpoint.flags; + if (flags & Flag::f_media_only) { + return true; + } + } + return false; +} + void DcOptions::FilterIfHasWithFlag(Variants &variants, Flag flag) { const auto is = [&](const Endpoint &endpoint) { return (endpoint.flags & flag) != 0; diff --git a/Telegram/SourceFiles/mtproto/dc_options.h b/Telegram/SourceFiles/mtproto/dc_options.h index 1e910f9a6..70b7a79dd 100644 --- a/Telegram/SourceFiles/mtproto/dc_options.h +++ b/Telegram/SourceFiles/mtproto/dc_options.h @@ -22,7 +22,7 @@ namespace MTP { enum class DcType { Regular, Temporary, - MediaDownload, + MediaCluster, Cdn, }; class DcOptions { @@ -89,8 +89,11 @@ public: }; std::vector data[AddressTypeCount][ProtocolCount]; }; - Variants lookup(DcId dcId, DcType type, bool throughProxy) const; - DcType dcType(ShiftedDcId shiftedDcId) const; + [[nodiscard]] Variants lookup( + DcId dcId, + DcType type, + bool throughProxy) const; + [[nodiscard]] DcType dcType(ShiftedDcId shiftedDcId) const; void setCDNConfig(const MTPDcdnConfig &config); [[nodiscard]] bool hasCDNKeysForDc(DcId dcId) const; @@ -121,6 +124,8 @@ private: const std::map> &b); static void FilterIfHasWithFlag(Variants &variants, Flag flag); + [[nodiscard]] bool hasMediaOnlyOptionsFor(DcId dcId) const; + void processFromList(const QVector &options, bool overwrite); void computeCdnDcIds(); diff --git a/Telegram/SourceFiles/mtproto/dcenter.cpp b/Telegram/SourceFiles/mtproto/dcenter.cpp index 22e212698..663d72897 100644 --- a/Telegram/SourceFiles/mtproto/dcenter.cpp +++ b/Telegram/SourceFiles/mtproto/dcenter.cpp @@ -21,6 +21,32 @@ namespace { constexpr auto kEnumerateDcTimeout = 8000; // 8 seconds timeout for help_getConfig to work (then move to other dc) constexpr auto kSpecialRequestTimeoutMs = 6000; // 4 seconds timeout for it to work in a specially requested dc. +int IndexByType(TemporaryKeyType type) { + switch (type) { + case TemporaryKeyType::Regular: return 0; + case TemporaryKeyType::MediaCluster: return 1; + } + Unexpected("Type value in IndexByType."); +} + +int IndexByType(CreatingKeyType type) { + switch (type) { + case CreatingKeyType::Persistent: + case CreatingKeyType::TemporaryRegular: return 0; + case CreatingKeyType::TemporaryMediaCluster: return 1; + } + Unexpected("Creating type value in IndexByType."); +} + +const char *NameOfType(CreatingKeyType type) { + switch (type) { + case CreatingKeyType::Persistent: return "persistent"; + case CreatingKeyType::TemporaryRegular: return "regular"; + case CreatingKeyType::TemporaryMediaCluster: return "media"; + } + Unexpected("Type value in NameOfType."); +} + } // namespace Dcenter::Dcenter(DcId dcId, AuthKeyPtr &&key) @@ -32,9 +58,9 @@ DcId Dcenter::id() const { return _id; } -AuthKeyPtr Dcenter::getTemporaryKey() const { +AuthKeyPtr Dcenter::getTemporaryKey(TemporaryKeyType type) const { QReadLocker lock(&_mutex); - return _temporaryKey; + return _temporaryKeys[IndexByType(type)]; } AuthKeyPtr Dcenter::getPersistentKey() const { @@ -44,12 +70,14 @@ AuthKeyPtr Dcenter::getPersistentKey() const { bool Dcenter::destroyTemporaryKey(uint64 keyId) { QWriteLocker lock(&_mutex); - if (!_temporaryKey || _temporaryKey->keyId() != keyId) { - return false; + for (auto &key : _temporaryKeys) { + if (key && key->keyId() == keyId) { + key = nullptr; + _connectionInited = false; + return true; + } } - _temporaryKey = nullptr; - _connectionInited = false; - return true; + return false; } bool Dcenter::destroyConfirmedForgottenKey(uint64 keyId) { @@ -57,7 +85,9 @@ bool Dcenter::destroyConfirmedForgottenKey(uint64 keyId) { if (!_persistentKey || _persistentKey->keyId() != keyId) { return false; } - _temporaryKey = nullptr; + for (auto &key : _temporaryKeys) { + key = nullptr; + } _persistentKey = nullptr; _connectionInited = false; return true; @@ -73,40 +103,66 @@ void Dcenter::setConnectionInited(bool connectionInited) { _connectionInited = connectionInited; } -bool Dcenter::acquireKeyCreation() { +CreatingKeyType Dcenter::acquireKeyCreation(TemporaryKeyType type) { QReadLocker lock(&_mutex); - if (_temporaryKey != nullptr) { - return false; + if (type == TemporaryKeyType::MediaCluster) { + int a = 0; + } + const auto index = IndexByType(type); + auto &key = _temporaryKeys[index]; + if (key != nullptr) { + return CreatingKeyType::None; } auto expected = false; - return _creatingKey.compare_exchange_strong(expected, true); + const auto regular = IndexByType(TemporaryKeyType::Regular); + if (type == TemporaryKeyType::MediaCluster && _temporaryKeys[regular]) { + return !_creatingKeys[index].compare_exchange_strong(expected, true) + ? CreatingKeyType::None + : CreatingKeyType::TemporaryMediaCluster; + } + return !_creatingKeys[regular].compare_exchange_strong(expected, true) + ? CreatingKeyType::None + : !_persistentKey + ? CreatingKeyType::Persistent + : CreatingKeyType::TemporaryRegular; } -void Dcenter::releaseKeyCreationOnFail() { - Expects(_creatingKey); - Expects(_temporaryKey == nullptr); - - _creatingKey = false; -} - -void Dcenter::releaseKeyCreationOnDone( +bool Dcenter::releaseKeyCreationOnDone( + CreatingKeyType type, const AuthKeyPtr &temporaryKey, - const AuthKeyPtr &persistentKey) { - Expects(_creatingKey); - Expects(_temporaryKey == nullptr); + const AuthKeyPtr &persistentKeyUsedForBind) { + Expects(_creatingKeys[IndexByType(type)]); + Expects(_temporaryKeys[IndexByType(type)] == nullptr); + Expects(temporaryKey != nullptr); QWriteLocker lock(&_mutex); - DEBUG_LOG(("AuthKey Info: Dcenter::releaseKeyCreationOnDone(%1, %2), " - "emitting authKeyChanged, dc %3" - ).arg(temporaryKey ? temporaryKey->keyId() : 0 - ).arg(persistentKey ? persistentKey->keyId() : 0 - ).arg(_id)); - _temporaryKey = temporaryKey; - if (persistentKey) { - _persistentKey = persistentKey; + if (type != CreatingKeyType::Persistent + && _persistentKey != persistentKeyUsedForBind) { + return false; } + if (type == CreatingKeyType::Persistent) { + _persistentKey = persistentKeyUsedForBind; + } else if (_persistentKey != persistentKeyUsedForBind) { + return false; + } + _temporaryKeys[IndexByType(type)] = temporaryKey; + _creatingKeys[IndexByType(type)] = false; _connectionInited = false; - _creatingKey = false; + + DEBUG_LOG(("AuthKey Info: Dcenter::releaseKeyCreationOnDone(%1, %2, %3)." + ).arg(NameOfType(type) + ).arg(temporaryKey ? temporaryKey->keyId() : 0 + ).arg(persistentKeyUsedForBind + ? persistentKeyUsedForBind->keyId() + : 0)); + return true; +} + +void Dcenter::releaseKeyCreationOnFail(CreatingKeyType type) { + Expects(_creatingKeys[IndexByType(type)]); + Expects(_temporaryKeys[IndexByType(type)] == nullptr); + + _creatingKeys[IndexByType(type)] = false; } } // namespace internal diff --git a/Telegram/SourceFiles/mtproto/dcenter.h b/Telegram/SourceFiles/mtproto/dcenter.h index 96326c425..0b4636014 100644 --- a/Telegram/SourceFiles/mtproto/dcenter.h +++ b/Telegram/SourceFiles/mtproto/dcenter.h @@ -15,6 +15,18 @@ using AuthKeyPtr = std::shared_ptr; namespace internal { +enum class TemporaryKeyType { + Regular, + MediaCluster +}; + +enum class CreatingKeyType { + None, + Persistent, + TemporaryRegular, + TemporaryMediaCluster +}; + class Dcenter : public QObject { public: // Main thread. @@ -22,29 +34,30 @@ public: // Thread-safe. [[nodiscard]] DcId id() const; - - [[nodiscard]] AuthKeyPtr getTemporaryKey() const; [[nodiscard]] AuthKeyPtr getPersistentKey() const; + [[nodiscard]] AuthKeyPtr getTemporaryKey(TemporaryKeyType type) const; + [[nodiscard]] CreatingKeyType acquireKeyCreation(TemporaryKeyType type); + bool releaseKeyCreationOnDone( + CreatingKeyType type, + const AuthKeyPtr &temporaryKey, + const AuthKeyPtr &persistentKeyUsedForBind); + void releaseKeyCreationOnFail(CreatingKeyType type); bool destroyTemporaryKey(uint64 keyId); bool destroyConfirmedForgottenKey(uint64 keyId); - void releaseKeyCreationOnDone( - const AuthKeyPtr &temporaryKey, - const AuthKeyPtr &persistentKey); [[nodiscard]] bool connectionInited() const; void setConnectionInited(bool connectionInited = true); - [[nodiscard]] bool acquireKeyCreation(); - void releaseKeyCreationOnFail(); - private: + static constexpr auto kTemporaryKeysCount = 2; + const DcId _id = 0; mutable QReadWriteLock _mutex; - AuthKeyPtr _temporaryKey; + AuthKeyPtr _temporaryKeys[kTemporaryKeysCount]; AuthKeyPtr _persistentKey; bool _connectionInited = false; - std::atomic _creatingKey = false; + std::atomic _creatingKeys[kTemporaryKeysCount] = { false }; }; diff --git a/Telegram/SourceFiles/mtproto/session.cpp b/Telegram/SourceFiles/mtproto/session.cpp index 5fb0b9e66..2dc955cb6 100644 --- a/Telegram/SourceFiles/mtproto/session.cpp +++ b/Telegram/SourceFiles/mtproto/session.cpp @@ -164,9 +164,9 @@ bool SessionData::connectionInited() const { return _owner ? _owner->connectionInited() : false; } -AuthKeyPtr SessionData::getTemporaryKey() const { +AuthKeyPtr SessionData::getTemporaryKey(TemporaryKeyType type) const { QMutexLocker lock(&_ownerMutex); - return _owner ? _owner->getTemporaryKey() : nullptr; + return _owner ? _owner->getTemporaryKey(type) : nullptr; } AuthKeyPtr SessionData::getPersistentKey() const { @@ -174,18 +174,20 @@ AuthKeyPtr SessionData::getPersistentKey() const { return _owner ? _owner->getPersistentKey() : nullptr; } -bool SessionData::acquireKeyCreation() { +CreatingKeyType SessionData::acquireKeyCreation(TemporaryKeyType type) { QMutexLocker lock(&_ownerMutex); - return _owner ? _owner->acquireKeyCreation() : false; + return _owner ? _owner->acquireKeyCreation(type) : CreatingKeyType::None; } -void SessionData::releaseKeyCreationOnDone( +bool SessionData::releaseKeyCreationOnDone( const AuthKeyPtr &temporaryKey, - const AuthKeyPtr &persistentKey) { + const AuthKeyPtr &persistentKeyUsedForBind) { QMutexLocker lock(&_ownerMutex); - if (_owner) { - _owner->releaseKeyCreationOnDone(temporaryKey, persistentKey); - } + return _owner + ? _owner->releaseKeyCreationOnDone( + temporaryKey, + persistentKeyUsedForBind) + : false; } void SessionData::releaseKeyCreationOnFail() { @@ -529,42 +531,56 @@ void Session::sendPrepared( } } -bool Session::acquireKeyCreation() { - Expects(!_myKeyCreation); +CreatingKeyType Session::acquireKeyCreation(TemporaryKeyType type) { + Expects(_myKeyCreation == CreatingKeyType::None); - if (!_dc->acquireKeyCreation()) { - return false; - } - _myKeyCreation = true; - return true; + _myKeyCreation = _dc->acquireKeyCreation(type); + return _myKeyCreation; } -void Session::releaseKeyCreationOnDone( +bool Session::releaseKeyCreationOnDone( const AuthKeyPtr &temporaryKey, - const AuthKeyPtr &persistentKey) { - Expects(_myKeyCreation); + const AuthKeyPtr &persistentKeyUsedForBind) { + Expects(_myKeyCreation != CreatingKeyType::None); + Expects(persistentKeyUsedForBind != nullptr); + + const auto wasKeyCreation = std::exchange( + _myKeyCreation, + CreatingKeyType::None); + const auto result = _dc->releaseKeyCreationOnDone( + wasKeyCreation, + temporaryKey, + persistentKeyUsedForBind); + + if (!result) { + DEBUG_LOG(("AuthKey Info: Persistent key changed " + "while binding temporary, dcWithShift %1" + ).arg(_shiftedDcId)); + return false; + } DEBUG_LOG(("AuthKey Info: Session key bound, setting, dcWithShift %1" ).arg(_shiftedDcId)); - _dc->releaseKeyCreationOnDone(temporaryKey, persistentKey); - _myKeyCreation = false; const auto dcId = _dc->id(); const auto instance = _instance; InvokeQueued(instance, [=] { - if (persistentKey) { - instance->dcPersistentKeyChanged(dcId, persistentKey); + if (wasKeyCreation == CreatingKeyType::Persistent) { + instance->dcPersistentKeyChanged(dcId, persistentKeyUsedForBind); } else { instance->dcTemporaryKeyChanged(dcId); } }); + return true; } void Session::releaseKeyCreationOnFail() { - Expects(_myKeyCreation); + Expects(_myKeyCreation != CreatingKeyType::None); - _dc->releaseKeyCreationOnFail(); - _myKeyCreation = false; + const auto wasKeyCreation = std::exchange( + _myKeyCreation, + CreatingKeyType::None); + _dc->releaseKeyCreationOnFail(wasKeyCreation); } void Session::notifyDcConnectionInited() { @@ -587,8 +603,8 @@ int32 Session::getDcWithShift() const { return _shiftedDcId; } -AuthKeyPtr Session::getTemporaryKey() const { - return _dc->getTemporaryKey(); +AuthKeyPtr Session::getTemporaryKey(TemporaryKeyType type) const { + return _dc->getTemporaryKey(type); } AuthKeyPtr Session::getPersistentKey() const { @@ -643,7 +659,7 @@ void Session::tryToReceive() { } Session::~Session() { - if (_myKeyCreation) { + if (_myKeyCreation != CreatingKeyType::None) { releaseKeyCreationOnFail(); } Assert(_connection == nullptr); diff --git a/Telegram/SourceFiles/mtproto/session.h b/Telegram/SourceFiles/mtproto/session.h index dc32e130d..d7dd23b3d 100644 --- a/Telegram/SourceFiles/mtproto/session.h +++ b/Telegram/SourceFiles/mtproto/session.h @@ -24,6 +24,9 @@ namespace internal { class Dcenter; class Connection; +enum class TemporaryKeyType; +enum class CreatingKeyType; + using PreRequestMap = QMap; using RequestMap = QMap; @@ -163,11 +166,11 @@ public: [[nodiscard]] bool connectionInited() const; [[nodiscard]] AuthKeyPtr getPersistentKey() const; - [[nodiscard]] AuthKeyPtr getTemporaryKey() const; - [[nodiscard]] bool acquireKeyCreation(); - void releaseKeyCreationOnDone( + [[nodiscard]] AuthKeyPtr getTemporaryKey(TemporaryKeyType type) const; + [[nodiscard]] CreatingKeyType acquireKeyCreation(TemporaryKeyType type); + [[nodiscard]] bool releaseKeyCreationOnDone( const AuthKeyPtr &temporaryKey, - const AuthKeyPtr &persistentKey); + const AuthKeyPtr &persistentKeyUsedForBind); void releaseKeyCreationOnFail(); void destroyTemporaryKey(uint64 keyId); void resend( @@ -229,7 +232,7 @@ public: // Thread-safe. [[nodiscard]] ShiftedDcId getDcWithShift() const; [[nodiscard]] AuthKeyPtr getPersistentKey() const; - [[nodiscard]] AuthKeyPtr getTemporaryKey() const; + [[nodiscard]] AuthKeyPtr getTemporaryKey(TemporaryKeyType type) const; [[nodiscard]] bool connectionInited() const; void resend( mtpMsgId msgId, @@ -245,10 +248,10 @@ public: bool newRequest = true); // Connection thread. - [[nodiscard]] bool acquireKeyCreation(); - void releaseKeyCreationOnDone( + [[nodiscard]] CreatingKeyType acquireKeyCreation(TemporaryKeyType type); + [[nodiscard]] bool releaseKeyCreationOnDone( const AuthKeyPtr &temporaryKey, - const AuthKeyPtr &persistentKey); + const AuthKeyPtr &persistentKeyUsedForBind); void releaseKeyCreationOnFail(); void destroyTemporaryKey(uint64 keyId); @@ -289,7 +292,7 @@ private: bool _needToReceive = false; AuthKeyPtr _dcKeyForCheck; - bool _myKeyCreation = false; + CreatingKeyType _myKeyCreation = CreatingKeyType(); crl::time _msSendCall = 0; crl::time _msWait = 0;