Fix temporary keys for media cluster.

This commit is contained in:
John Preston 2019-11-21 13:37:39 +03:00
parent 54b5342af4
commit 276adec0ff
8 changed files with 318 additions and 162 deletions

View File

@ -13,6 +13,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "mtproto/mtproto_rsa_public_key.h" #include "mtproto/mtproto_rsa_public_key.h"
#include "mtproto/mtproto_rpc_sender.h" #include "mtproto/mtproto_rpc_sender.h"
#include "mtproto/dc_options.h" #include "mtproto/dc_options.h"
#include "mtproto/dcenter.h"
#include "mtproto/connection_abstract.h" #include "mtproto/connection_abstract.h"
#include "zlib.h" #include "zlib.h"
#include "core/application.h" #include "core/application.h"
@ -78,6 +79,12 @@ using namespace details;
return idsStr + "]"; return idsStr + "]";
} }
[[nodiscard]] TemporaryKeyType TemporaryKeyTypeByDcType(DcType type) {
return (type == DcType::MediaCluster)
? TemporaryKeyType::MediaCluster
: TemporaryKeyType::Regular;
}
void wrapInvokeAfter(SecureRequest &to, const SecureRequest &from, const RequestMap &haveSent, int32 skipBeforeRequest = 0) { void wrapInvokeAfter(SecureRequest &to, const SecureRequest &from, const RequestMap &haveSent, int32 skipBeforeRequest = 0) {
const auto afterId = *(mtpMsgId*)(from->after->data() + 4); const auto afterId = *(mtpMsgId*)(from->after->data() + 4);
const auto i = afterId ? haveSent.constFind(afterId) : haveSent.cend(); const auto i = afterId ? haveSent.constFind(afterId) : haveSent.cend();
@ -164,7 +171,7 @@ void ConnectionPrivate::appendTestConnection(
const QString &ip, const QString &ip,
int port, int port,
const bytes::vector &protocolSecret) { const bytes::vector &protocolSecret) {
QWriteLocker lock(&stateConnMutex); QWriteLocker lock(&_stateMutex);
const auto priority = (qthelp::is_ipv6(ip) ? 0 : 1) const auto priority = (qthelp::is_ipv6(ip) ? 0 : 1)
+ (protocol == DcOptions::Variants::Tcp ? 1 : 0) + (protocol == DcOptions::Variants::Tcp ? 1 : 0)
@ -216,7 +223,7 @@ int16 ConnectionPrivate::getProtocolDcId() const {
const auto testedDcId = cTestMode() const auto testedDcId = cTestMode()
? (kTestModeDcIdShift + simpleDcId) ? (kTestModeDcIdShift + simpleDcId)
: simpleDcId; : simpleDcId;
return (_dcType == DcType::MediaDownload) return (_currentDcType == DcType::MediaCluster)
? -testedDcId ? -testedDcId
: testedDcId; : testedDcId;
} }
@ -279,9 +286,11 @@ ConnectionPrivate::ConnectionPrivate(
ShiftedDcId shiftedDcId) ShiftedDcId shiftedDcId)
: QObject(nullptr) : QObject(nullptr)
, _instance(instance) , _instance(instance)
, _state(DisconnectedState)
, _shiftedDcId(shiftedDcId)
, _owner(owner) , _owner(owner)
, _shiftedDcId(shiftedDcId)
, _realDcType(_instance->dcOptions()->dcType(_shiftedDcId))
, _currentDcType(_realDcType)
, _state(DisconnectedState)
, _retryTimer(thread, [=] { retryByTimer(); }) , _retryTimer(thread, [=] { retryByTimer(); })
, _oldConnectionTimer(thread, [=] { markConnectionOld(); }) , _oldConnectionTimer(thread, [=] { markConnectionOld(); })
, _waitForConnectedTimer(thread, [=] { waitConnectedFailed(); }) , _waitForConnectedTimer(thread, [=] { waitConnectedFailed(); })
@ -329,7 +338,7 @@ int32 ConnectionPrivate::getShiftedDcId() const {
} }
int32 ConnectionPrivate::getState() const { int32 ConnectionPrivate::getState() const {
QReadLocker lock(&stateConnMutex); QReadLocker lock(&_stateMutex);
int32 result = _state; int32 result = _state;
if (_state < 0) { if (_state < 0) {
if (_retryTimer.isActive()) { if (_retryTimer.isActive()) {
@ -343,7 +352,7 @@ int32 ConnectionPrivate::getState() const {
} }
QString ConnectionPrivate::transport() const { QString ConnectionPrivate::transport() const {
QReadLocker lock(&stateConnMutex); QReadLocker lock(&_stateMutex);
if (!_connection || (_state < 0)) { if (!_connection || (_state < 0)) {
return QString(); return QString();
} }
@ -354,13 +363,13 @@ QString ConnectionPrivate::transport() const {
bool ConnectionPrivate::setState(int32 state, int32 ifState) { bool ConnectionPrivate::setState(int32 state, int32 ifState) {
if (ifState != Connection::UpdateAlways) { if (ifState != Connection::UpdateAlways) {
QReadLocker lock(&stateConnMutex); QReadLocker lock(&_stateMutex);
if (_state != ifState) { if (_state != ifState) {
return false; return false;
} }
} }
QWriteLocker lock(&stateConnMutex); QWriteLocker lock(&_stateMutex);
if (_state == state) { if (_state == state) {
return false; return false;
} }
@ -409,6 +418,15 @@ uint32 ConnectionPrivate::nextRequestSeqNumber(bool needAck) {
return result * 2 + (needAck ? 1 : 0); return result * 2 + (needAck ? 1 : 0);
} }
bool ConnectionPrivate::realDcTypeChanged() {
const auto now = _instance->dcOptions()->dcType(_shiftedDcId);
if (_realDcType == now) {
return false;
}
_realDcType = now;
return true;
}
bool ConnectionPrivate::markSessionAsStarted() { bool ConnectionPrivate::markSessionAsStarted() {
if (_sessionMarkedAsStarted) { if (_sessionMarkedAsStarted) {
return false; return false;
@ -643,10 +661,10 @@ void ConnectionPrivate::tryToSend() {
const auto systemLangCode = _connectionOptions->systemLangCode; const auto systemLangCode = _connectionOptions->systemLangCode;
const auto cloudLangCode = _connectionOptions->cloudLangCode; const auto cloudLangCode = _connectionOptions->cloudLangCode;
const auto langPackName = _connectionOptions->langPackName; const auto langPackName = _connectionOptions->langPackName;
const auto deviceModel = (_dcType == DcType::Cdn) const auto deviceModel = (_currentDcType == DcType::Cdn)
? "n/a" ? "n/a"
: _instance->deviceModel(); : _instance->deviceModel();
const auto systemVersion = (_dcType == DcType::Cdn) const auto systemVersion = (_currentDcType == DcType::Cdn)
? "n/a" ? "n/a"
: _instance->systemVersion(); : _instance->systemVersion();
#if defined OS_MAC_STORE #if defined OS_MAC_STORE
@ -928,40 +946,38 @@ void ConnectionPrivate::connectToServer(bool afterConfig) {
DEBUG_LOG(("MTP Error: " DEBUG_LOG(("MTP Error: "
"connectToServer() called for finished connection!")); "connectToServer() called for finished connection!"));
return; return;
} else if (afterConfig && (!_testConnections.empty() || _connection)) {
return;
}
destroyAllConnections();
if (realDcTypeChanged() && _keyCreator) {
destroyTemporaryKey();
return;
} }
_connectionOptions = std::make_unique<ConnectionOptions>( _connectionOptions = std::make_unique<ConnectionOptions>(
_sessionData->connectionOptions()); _sessionData->connectionOptions());
tryAcquireKeyCreation();
const auto bareDc = BareDcId(_shiftedDcId); const auto bareDc = BareDcId(_shiftedDcId);
_dcType = _instance->dcOptions()->dcType(_shiftedDcId);
// Use media_only addresses only if key for this dc is already created. _currentDcType = tryAcquireKeyCreation();
if (_dcType == DcType::MediaDownload && _keyCreator) { if (_currentDcType == DcType::Cdn && !_instance->isKeysDestroyer()) {
_dcType = DcType::Regular;
} else if (_dcType == DcType::Cdn && !_instance->isKeysDestroyer()) {
if (!_instance->dcOptions()->hasCDNKeysForDc(bareDc)) { if (!_instance->dcOptions()->hasCDNKeysForDc(bareDc)) {
requestCDNConfig(); requestCDNConfig();
return; return;
} }
} }
if (afterConfig && (!_testConnections.empty() || _connection)) {
return;
}
destroyAllConnections();
if (_connectionOptions->proxy.type == ProxyData::Type::Mtproto) { if (_connectionOptions->proxy.type == ProxyData::Type::Mtproto) {
// host, port, secret for mtproto proxy are taken from proxy. // host, port, secret for mtproto proxy are taken from proxy.
appendTestConnection(DcOptions::Variants::Tcp, {}, 0, {}); appendTestConnection(DcOptions::Variants::Tcp, {}, 0, {});
} else { } else {
using Variants = DcOptions::Variants; using Variants = DcOptions::Variants;
const auto special = (_dcType == DcType::Temporary); const auto special = (_currentDcType == DcType::Temporary);
const auto variants = _instance->dcOptions()->lookup( const auto variants = _instance->dcOptions()->lookup(
bareDc, bareDc,
_dcType, _currentDcType,
_connectionOptions->proxy.type != ProxyData::Type::None); _connectionOptions->proxy.type != ProxyData::Type::None);
const auto useIPv4 = special ? true : _connectionOptions->useIPv4; const auto useIPv4 = special ? true : _connectionOptions->useIPv4;
const auto useIPv6 = special ? false : _connectionOptions->useIPv6; const auto useIPv6 = special ? false : _connectionOptions->useIPv6;
@ -1351,7 +1367,11 @@ void ConnectionPrivate::handleReceived() {
auto from = decryptedInts + kEncryptedHeaderIntsCount; auto from = decryptedInts + kEncryptedHeaderIntsCount;
auto end = from + (messageLength / kIntSize); auto end = from + (messageLength / kIntSize);
auto sfrom = decryptedInts + 4U; // msg_id + seq_no + length + message auto sfrom = decryptedInts + 4U; // msg_id + seq_no + length + message
MTP_LOG(_shiftedDcId, ("Recv: ") + details::DumpToText(sfrom, end) + QString(" (keyId:%1)").arg(_encryptionKey->keyId())); MTP_LOG(_shiftedDcId, ("Recv: ")
+ details::DumpToText(sfrom, end)
+ QString(" (protocolDcId:%1,key:%2)"
).arg(getProtocolDcId()
).arg(_encryptionKey->keyId()));
if (_receivedMessageIds.registerMsgId(msgId, needAck)) { if (_receivedMessageIds.registerMsgId(msgId, needAck)) {
res = handleOneReceived(from, end, msgId, serverTime, serverSalt, badTime); res = handleOneReceived(from, end, msgId, serverTime, serverSalt, badTime);
@ -1860,9 +1880,11 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr
response); response);
switch (result) { switch (result) {
case DcKeyBindState::Success: case DcKeyBindState::Success:
_sessionData->releaseKeyCreationOnDone( if (!_sessionData->releaseKeyCreationOnDone(
_encryptionKey, _encryptionKey,
base::take(_keyCreator)->bindPersistentKey()); base::take(_keyCreator)->bindPersistentKey())) {
return HandleResult::DestroyTemporaryKey;
}
_sessionData->queueNeedToResumeAndSend(); _sessionData->queueNeedToResumeAndSend();
return HandleResult::Success; return HandleResult::Success;
case DcKeyBindState::DefinitelyDestroyed: case DcKeyBindState::DefinitelyDestroyed:
@ -1967,7 +1989,7 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr
return HandleResult::ResetSession; return HandleResult::ResetSession;
} }
if (_dcType == DcType::Regular) { if (_currentDcType == DcType::Regular) {
mtpBuffer update(end - from); mtpBuffer update(end - from);
if (end > from) memcpy(update.data(), from, (end - from) * sizeof(mtpPrime)); if (end > from) memcpy(update.data(), from, (end - from) * sizeof(mtpPrime));
@ -1986,7 +2008,8 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr
LOG(("Message Error: unknown constructor 0x%1").arg(cons, 0, 16)); LOG(("Message Error: unknown constructor 0x%1").arg(cons, 0, 16));
} }
} else { } else {
LOG(("Message Error: unexpected updates in dcType: %1").arg(static_cast<int>(_dcType))); LOG(("Message Error: unexpected updates in dcType: %1"
).arg(static_cast<int>(_currentDcType)));
} }
return HandleResult::Success; return HandleResult::Success;
@ -2311,17 +2334,27 @@ void ConnectionPrivate::checkAuthKey() {
} else if (_instance->isKeysDestroyer()) { } else if (_instance->isKeysDestroyer()) {
applyAuthKey(_sessionData->getPersistentKey()); applyAuthKey(_sessionData->getPersistentKey());
} else { } else {
applyAuthKey(_sessionData->getTemporaryKey()); applyAuthKey(_sessionData->getTemporaryKey(
TemporaryKeyTypeByDcType(_currentDcType)));
} }
} }
void ConnectionPrivate::updateAuthKey() { void ConnectionPrivate::updateAuthKey() {
if (_instance->isKeysDestroyer() || _keyCreator) { if (_instance->isKeysDestroyer() || _keyCreator || !_connection) {
return; return;
} }
DEBUG_LOG(("AuthKey Info: Connection updating key from Session, dc %1").arg(_shiftedDcId)); DEBUG_LOG(("AuthKey Info: Connection updating key from Session, dc %1"
applyAuthKey(_sessionData->getTemporaryKey()); ).arg(_shiftedDcId));
const auto myKeyType = TemporaryKeyTypeByDcType(_currentDcType);
applyAuthKey(_sessionData->getTemporaryKey(myKeyType));
if (_connection
&& !_encryptionKey
&& myKeyType == TemporaryKeyType::MediaCluster
&& _sessionData->getTemporaryKey(TemporaryKeyType::Regular)) {
restart();
}
} }
void ConnectionPrivate::setCurrentKeyId(uint64 newKeyId) { void ConnectionPrivate::setCurrentKeyId(uint64 newKeyId) {
@ -2395,11 +2428,17 @@ bool ConnectionPrivate::destroyOldEnoughPersistentKey() {
return true; return true;
} }
void ConnectionPrivate::tryAcquireKeyCreation() { DcType ConnectionPrivate::tryAcquireKeyCreation() {
if (_instance->isKeysDestroyer() if (_keyCreator) {
|| _keyCreator return _currentDcType;
|| !_sessionData->acquireKeyCreation()) { } else if (_instance->isKeysDestroyer()) {
return; return _realDcType;
}
const auto keyType = TemporaryKeyTypeByDcType(_realDcType);
const auto acquired = _sessionData->acquireKeyCreation(keyType);
if (acquired == CreatingKeyType::None) {
return _realDcType;
} }
using Result = DcKeyResult; using Result = DcKeyResult;
@ -2409,7 +2448,7 @@ void ConnectionPrivate::tryAcquireKeyCreation() {
if (!result) { if (!result) {
releaseKeyCreationOnFail(); releaseKeyCreationOnFail();
if (result.error() == Error::UnknownPublicKey) { if (result.error() == Error::UnknownPublicKey) {
if (_dcType == DcType::Cdn) { if (_realDcType == DcType::Cdn) {
LOG(("Warning: CDN public RSA key not found")); LOG(("Warning: CDN public RSA key not found"));
requestCDNConfig(); requestCDNConfig();
return; return;
@ -2457,11 +2496,14 @@ void ConnectionPrivate::tryAcquireKeyCreation() {
}; };
auto request = DcKeyRequest(); auto request = DcKeyRequest();
request.persistentNeeded = !_sessionData->getPersistentKey(); request.persistentNeeded = (acquired == CreatingKeyType::Persistent);
request.temporaryExpiresIn = kTemporaryExpiresIn; request.temporaryExpiresIn = kTemporaryExpiresIn;
_keyCreator = std::make_unique<BoundKeyCreator>( _keyCreator = std::make_unique<BoundKeyCreator>(
request, request,
std::move(delegate)); std::move(delegate));
const auto forceUseRegular = (_realDcType == DcType::MediaCluster)
&& (acquired != CreatingKeyType::TemporaryMediaCluster);
return forceUseRegular ? DcType::Regular : _realDcType;
} }
void ConnectionPrivate::authKeyChecked() { void ConnectionPrivate::authKeyChecked() {
@ -2548,7 +2590,11 @@ bool ConnectionPrivate::sendSecureRequest(
memcpy(request->data() + 2, &_sessionId, 2 * sizeof(mtpPrime)); memcpy(request->data() + 2, &_sessionId, 2 * sizeof(mtpPrime));
auto from = request->constData() + 4; auto from = request->constData() + 4;
MTP_LOG(_shiftedDcId, ("Send: ") + details::DumpToText(from, from + messageSize) + QString(" (keyId:%1)").arg(_encryptionKey->keyId())); MTP_LOG(_shiftedDcId, ("Send: ")
+ details::DumpToText(from, from + messageSize)
+ QString(" (protocolDcId:%1,key:%2)"
).arg(getProtocolDcId()
).arg(_encryptionKey->keyId()));
#ifdef TDESKTOP_MTPROTO_OLD #ifdef TDESKTOP_MTPROTO_OLD
uint32 padding = fullSize - 4 - messageSize; uint32 padding = fullSize - 4 - messageSize;

View File

@ -181,7 +181,7 @@ private:
crl::time msCanWait = 0, crl::time msCanWait = 0,
bool forceContainer = false); bool forceContainer = false);
void tryAcquireKeyCreation(); [[nodiscard]] DcType tryAcquireKeyCreation();
void resetSession(); void resetSession();
void checkAuthKey(); void checkAuthKey();
void authKeyChecked(); void authKeyChecked();
@ -193,20 +193,22 @@ private:
void setCurrentKeyId(uint64 newKeyId); void setCurrentKeyId(uint64 newKeyId);
void changeSessionId(); void changeSessionId();
void setSessionSalt(uint64 salt);
[[nodiscard]] bool markSessionAsStarted(); [[nodiscard]] bool markSessionAsStarted();
[[nodiscard]] uint32 nextRequestSeqNumber(bool needAck); [[nodiscard]] uint32 nextRequestSeqNumber(bool needAck);
const not_null<Instance*> _instance; [[nodiscard]] bool realDcTypeChanged();
DcType _dcType = DcType::Regular;
mutable QReadWriteLock stateConnMutex; const not_null<Instance*> _instance;
const not_null<Connection*> _owner;
const ShiftedDcId _shiftedDcId = 0;
DcType _realDcType = DcType();
DcType _currentDcType = DcType();
mutable QReadWriteLock _stateMutex;
int32 _state = DisconnectedState; int32 _state = DisconnectedState;
bool _needSessionReset = false; bool _needSessionReset = false;
ShiftedDcId _shiftedDcId = 0;
not_null<Connection*> _owner;
ConnectionPointer _connection; ConnectionPointer _connection;
std::vector<TestConnection> _testConnections; std::vector<TestConnection> _testConnections;
crl::time _startedConnectingAt = 0; crl::time _startedConnectingAt = 0;

View File

@ -540,8 +540,9 @@ DcType DcOptions::dcType(ShiftedDcId shiftedDcId) const {
if (_cdnDcIds.find(BareDcId(shiftedDcId)) != _cdnDcIds.cend()) { if (_cdnDcIds.find(BareDcId(shiftedDcId)) != _cdnDcIds.cend()) {
return DcType::Cdn; return DcType::Cdn;
} }
if (isDownloadDcId(shiftedDcId)) { const auto dcId = BareDcId(shiftedDcId);
return DcType::MediaDownload; if (isDownloadDcId(shiftedDcId) && hasMediaOnlyOptionsFor(dcId)) {
return DcType::MediaCluster;
} }
return DcType::Regular; return DcType::Regular;
} }
@ -599,40 +600,54 @@ auto DcOptions::lookup(
bool throughProxy) const -> Variants { bool throughProxy) const -> Variants {
using Flag = Flag; using Flag = Flag;
auto result = Variants(); auto result = Variants();
{
ReadLocker lock(this); ReadLocker lock(this);
const auto i = _data.find(dcId); const auto i = _data.find(dcId);
if (i == end(_data)) { if (i == end(_data)) {
return result; return result;
}
for (const auto &endpoint : i->second) {
const auto flags = endpoint.flags;
if (type == DcType::Cdn && !(flags & Flag::f_cdn)) {
continue;
} else if (type != DcType::MediaCluster
&& (flags & Flag::f_media_only)) {
continue;
} else if (!ValidateSecret(endpoint.secret)) {
continue;
} }
for (const auto &endpoint : i->second) { const auto address = (flags & Flag::f_ipv6)
const auto flags = endpoint.flags; ? Variants::IPv6
if (type == DcType::Cdn && !(flags & Flag::f_cdn)) { : Variants::IPv4;
continue; result.data[address][Variants::Tcp].push_back(endpoint);
} else if (type != DcType::MediaDownload if (!(flags & (Flag::f_tcpo_only | Flag::f_secret))) {
&& (flags & Flag::f_media_only)) { result.data[address][Variants::Http].push_back(endpoint);
continue;
} else if (!ValidateSecret(endpoint.secret)) {
continue;
}
const auto address = (flags & Flag::f_ipv6)
? Variants::IPv6
: Variants::IPv4;
result.data[address][Variants::Tcp].push_back(endpoint);
if (!(flags & (Flag::f_tcpo_only | Flag::f_secret))) {
result.data[address][Variants::Http].push_back(endpoint);
}
}
if (type == DcType::MediaDownload) {
FilterIfHasWithFlag(result, Flag::f_media_only);
}
if (throughProxy) {
FilterIfHasWithFlag(result, Flag::f_static);
} }
} }
if (type == DcType::MediaCluster) {
FilterIfHasWithFlag(result, Flag::f_media_only);
}
if (throughProxy) {
FilterIfHasWithFlag(result, Flag::f_static);
}
return result; return result;
} }
bool DcOptions::hasMediaOnlyOptionsFor(DcId dcId) const {
ReadLocker lock(this);
const auto i = _data.find(dcId);
if (i == end(_data)) {
return false;
}
for (const auto &endpoint : i->second) {
const auto flags = endpoint.flags;
if (flags & Flag::f_media_only) {
return true;
}
}
return false;
}
void DcOptions::FilterIfHasWithFlag(Variants &variants, Flag flag) { void DcOptions::FilterIfHasWithFlag(Variants &variants, Flag flag) {
const auto is = [&](const Endpoint &endpoint) { const auto is = [&](const Endpoint &endpoint) {
return (endpoint.flags & flag) != 0; return (endpoint.flags & flag) != 0;

View File

@ -22,7 +22,7 @@ namespace MTP {
enum class DcType { enum class DcType {
Regular, Regular,
Temporary, Temporary,
MediaDownload, MediaCluster,
Cdn, Cdn,
}; };
class DcOptions { class DcOptions {
@ -89,8 +89,11 @@ public:
}; };
std::vector<Endpoint> data[AddressTypeCount][ProtocolCount]; std::vector<Endpoint> data[AddressTypeCount][ProtocolCount];
}; };
Variants lookup(DcId dcId, DcType type, bool throughProxy) const; [[nodiscard]] Variants lookup(
DcType dcType(ShiftedDcId shiftedDcId) const; DcId dcId,
DcType type,
bool throughProxy) const;
[[nodiscard]] DcType dcType(ShiftedDcId shiftedDcId) const;
void setCDNConfig(const MTPDcdnConfig &config); void setCDNConfig(const MTPDcdnConfig &config);
[[nodiscard]] bool hasCDNKeysForDc(DcId dcId) const; [[nodiscard]] bool hasCDNKeysForDc(DcId dcId) const;
@ -121,6 +124,8 @@ private:
const std::map<DcId, std::vector<Endpoint>> &b); const std::map<DcId, std::vector<Endpoint>> &b);
static void FilterIfHasWithFlag(Variants &variants, Flag flag); static void FilterIfHasWithFlag(Variants &variants, Flag flag);
[[nodiscard]] bool hasMediaOnlyOptionsFor(DcId dcId) const;
void processFromList(const QVector<MTPDcOption> &options, bool overwrite); void processFromList(const QVector<MTPDcOption> &options, bool overwrite);
void computeCdnDcIds(); void computeCdnDcIds();

View File

@ -21,6 +21,32 @@ namespace {
constexpr auto kEnumerateDcTimeout = 8000; // 8 seconds timeout for help_getConfig to work (then move to other dc) constexpr auto kEnumerateDcTimeout = 8000; // 8 seconds timeout for help_getConfig to work (then move to other dc)
constexpr auto kSpecialRequestTimeoutMs = 6000; // 4 seconds timeout for it to work in a specially requested dc. constexpr auto kSpecialRequestTimeoutMs = 6000; // 4 seconds timeout for it to work in a specially requested dc.
int IndexByType(TemporaryKeyType type) {
switch (type) {
case TemporaryKeyType::Regular: return 0;
case TemporaryKeyType::MediaCluster: return 1;
}
Unexpected("Type value in IndexByType.");
}
int IndexByType(CreatingKeyType type) {
switch (type) {
case CreatingKeyType::Persistent:
case CreatingKeyType::TemporaryRegular: return 0;
case CreatingKeyType::TemporaryMediaCluster: return 1;
}
Unexpected("Creating type value in IndexByType.");
}
const char *NameOfType(CreatingKeyType type) {
switch (type) {
case CreatingKeyType::Persistent: return "persistent";
case CreatingKeyType::TemporaryRegular: return "regular";
case CreatingKeyType::TemporaryMediaCluster: return "media";
}
Unexpected("Type value in NameOfType.");
}
} // namespace } // namespace
Dcenter::Dcenter(DcId dcId, AuthKeyPtr &&key) Dcenter::Dcenter(DcId dcId, AuthKeyPtr &&key)
@ -32,9 +58,9 @@ DcId Dcenter::id() const {
return _id; return _id;
} }
AuthKeyPtr Dcenter::getTemporaryKey() const { AuthKeyPtr Dcenter::getTemporaryKey(TemporaryKeyType type) const {
QReadLocker lock(&_mutex); QReadLocker lock(&_mutex);
return _temporaryKey; return _temporaryKeys[IndexByType(type)];
} }
AuthKeyPtr Dcenter::getPersistentKey() const { AuthKeyPtr Dcenter::getPersistentKey() const {
@ -44,12 +70,14 @@ AuthKeyPtr Dcenter::getPersistentKey() const {
bool Dcenter::destroyTemporaryKey(uint64 keyId) { bool Dcenter::destroyTemporaryKey(uint64 keyId) {
QWriteLocker lock(&_mutex); QWriteLocker lock(&_mutex);
if (!_temporaryKey || _temporaryKey->keyId() != keyId) { for (auto &key : _temporaryKeys) {
return false; if (key && key->keyId() == keyId) {
key = nullptr;
_connectionInited = false;
return true;
}
} }
_temporaryKey = nullptr; return false;
_connectionInited = false;
return true;
} }
bool Dcenter::destroyConfirmedForgottenKey(uint64 keyId) { bool Dcenter::destroyConfirmedForgottenKey(uint64 keyId) {
@ -57,7 +85,9 @@ bool Dcenter::destroyConfirmedForgottenKey(uint64 keyId) {
if (!_persistentKey || _persistentKey->keyId() != keyId) { if (!_persistentKey || _persistentKey->keyId() != keyId) {
return false; return false;
} }
_temporaryKey = nullptr; for (auto &key : _temporaryKeys) {
key = nullptr;
}
_persistentKey = nullptr; _persistentKey = nullptr;
_connectionInited = false; _connectionInited = false;
return true; return true;
@ -73,40 +103,66 @@ void Dcenter::setConnectionInited(bool connectionInited) {
_connectionInited = connectionInited; _connectionInited = connectionInited;
} }
bool Dcenter::acquireKeyCreation() { CreatingKeyType Dcenter::acquireKeyCreation(TemporaryKeyType type) {
QReadLocker lock(&_mutex); QReadLocker lock(&_mutex);
if (_temporaryKey != nullptr) { if (type == TemporaryKeyType::MediaCluster) {
return false; int a = 0;
}
const auto index = IndexByType(type);
auto &key = _temporaryKeys[index];
if (key != nullptr) {
return CreatingKeyType::None;
} }
auto expected = false; auto expected = false;
return _creatingKey.compare_exchange_strong(expected, true); const auto regular = IndexByType(TemporaryKeyType::Regular);
if (type == TemporaryKeyType::MediaCluster && _temporaryKeys[regular]) {
return !_creatingKeys[index].compare_exchange_strong(expected, true)
? CreatingKeyType::None
: CreatingKeyType::TemporaryMediaCluster;
}
return !_creatingKeys[regular].compare_exchange_strong(expected, true)
? CreatingKeyType::None
: !_persistentKey
? CreatingKeyType::Persistent
: CreatingKeyType::TemporaryRegular;
} }
void Dcenter::releaseKeyCreationOnFail() { bool Dcenter::releaseKeyCreationOnDone(
Expects(_creatingKey); CreatingKeyType type,
Expects(_temporaryKey == nullptr);
_creatingKey = false;
}
void Dcenter::releaseKeyCreationOnDone(
const AuthKeyPtr &temporaryKey, const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKey) { const AuthKeyPtr &persistentKeyUsedForBind) {
Expects(_creatingKey); Expects(_creatingKeys[IndexByType(type)]);
Expects(_temporaryKey == nullptr); Expects(_temporaryKeys[IndexByType(type)] == nullptr);
Expects(temporaryKey != nullptr);
QWriteLocker lock(&_mutex); QWriteLocker lock(&_mutex);
DEBUG_LOG(("AuthKey Info: Dcenter::releaseKeyCreationOnDone(%1, %2), " if (type != CreatingKeyType::Persistent
"emitting authKeyChanged, dc %3" && _persistentKey != persistentKeyUsedForBind) {
).arg(temporaryKey ? temporaryKey->keyId() : 0 return false;
).arg(persistentKey ? persistentKey->keyId() : 0
).arg(_id));
_temporaryKey = temporaryKey;
if (persistentKey) {
_persistentKey = persistentKey;
} }
if (type == CreatingKeyType::Persistent) {
_persistentKey = persistentKeyUsedForBind;
} else if (_persistentKey != persistentKeyUsedForBind) {
return false;
}
_temporaryKeys[IndexByType(type)] = temporaryKey;
_creatingKeys[IndexByType(type)] = false;
_connectionInited = false; _connectionInited = false;
_creatingKey = false;
DEBUG_LOG(("AuthKey Info: Dcenter::releaseKeyCreationOnDone(%1, %2, %3)."
).arg(NameOfType(type)
).arg(temporaryKey ? temporaryKey->keyId() : 0
).arg(persistentKeyUsedForBind
? persistentKeyUsedForBind->keyId()
: 0));
return true;
}
void Dcenter::releaseKeyCreationOnFail(CreatingKeyType type) {
Expects(_creatingKeys[IndexByType(type)]);
Expects(_temporaryKeys[IndexByType(type)] == nullptr);
_creatingKeys[IndexByType(type)] = false;
} }
} // namespace internal } // namespace internal

View File

@ -15,6 +15,18 @@ using AuthKeyPtr = std::shared_ptr<AuthKey>;
namespace internal { namespace internal {
enum class TemporaryKeyType {
Regular,
MediaCluster
};
enum class CreatingKeyType {
None,
Persistent,
TemporaryRegular,
TemporaryMediaCluster
};
class Dcenter : public QObject { class Dcenter : public QObject {
public: public:
// Main thread. // Main thread.
@ -22,29 +34,30 @@ public:
// Thread-safe. // Thread-safe.
[[nodiscard]] DcId id() const; [[nodiscard]] DcId id() const;
[[nodiscard]] AuthKeyPtr getTemporaryKey() const;
[[nodiscard]] AuthKeyPtr getPersistentKey() const; [[nodiscard]] AuthKeyPtr getPersistentKey() const;
[[nodiscard]] AuthKeyPtr getTemporaryKey(TemporaryKeyType type) const;
[[nodiscard]] CreatingKeyType acquireKeyCreation(TemporaryKeyType type);
bool releaseKeyCreationOnDone(
CreatingKeyType type,
const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKeyUsedForBind);
void releaseKeyCreationOnFail(CreatingKeyType type);
bool destroyTemporaryKey(uint64 keyId); bool destroyTemporaryKey(uint64 keyId);
bool destroyConfirmedForgottenKey(uint64 keyId); bool destroyConfirmedForgottenKey(uint64 keyId);
void releaseKeyCreationOnDone(
const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKey);
[[nodiscard]] bool connectionInited() const; [[nodiscard]] bool connectionInited() const;
void setConnectionInited(bool connectionInited = true); void setConnectionInited(bool connectionInited = true);
[[nodiscard]] bool acquireKeyCreation();
void releaseKeyCreationOnFail();
private: private:
static constexpr auto kTemporaryKeysCount = 2;
const DcId _id = 0; const DcId _id = 0;
mutable QReadWriteLock _mutex; mutable QReadWriteLock _mutex;
AuthKeyPtr _temporaryKey; AuthKeyPtr _temporaryKeys[kTemporaryKeysCount];
AuthKeyPtr _persistentKey; AuthKeyPtr _persistentKey;
bool _connectionInited = false; bool _connectionInited = false;
std::atomic<bool> _creatingKey = false; std::atomic<bool> _creatingKeys[kTemporaryKeysCount] = { false };
}; };

View File

@ -164,9 +164,9 @@ bool SessionData::connectionInited() const {
return _owner ? _owner->connectionInited() : false; return _owner ? _owner->connectionInited() : false;
} }
AuthKeyPtr SessionData::getTemporaryKey() const { AuthKeyPtr SessionData::getTemporaryKey(TemporaryKeyType type) const {
QMutexLocker lock(&_ownerMutex); QMutexLocker lock(&_ownerMutex);
return _owner ? _owner->getTemporaryKey() : nullptr; return _owner ? _owner->getTemporaryKey(type) : nullptr;
} }
AuthKeyPtr SessionData::getPersistentKey() const { AuthKeyPtr SessionData::getPersistentKey() const {
@ -174,18 +174,20 @@ AuthKeyPtr SessionData::getPersistentKey() const {
return _owner ? _owner->getPersistentKey() : nullptr; return _owner ? _owner->getPersistentKey() : nullptr;
} }
bool SessionData::acquireKeyCreation() { CreatingKeyType SessionData::acquireKeyCreation(TemporaryKeyType type) {
QMutexLocker lock(&_ownerMutex); QMutexLocker lock(&_ownerMutex);
return _owner ? _owner->acquireKeyCreation() : false; return _owner ? _owner->acquireKeyCreation(type) : CreatingKeyType::None;
} }
void SessionData::releaseKeyCreationOnDone( bool SessionData::releaseKeyCreationOnDone(
const AuthKeyPtr &temporaryKey, const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKey) { const AuthKeyPtr &persistentKeyUsedForBind) {
QMutexLocker lock(&_ownerMutex); QMutexLocker lock(&_ownerMutex);
if (_owner) { return _owner
_owner->releaseKeyCreationOnDone(temporaryKey, persistentKey); ? _owner->releaseKeyCreationOnDone(
} temporaryKey,
persistentKeyUsedForBind)
: false;
} }
void SessionData::releaseKeyCreationOnFail() { void SessionData::releaseKeyCreationOnFail() {
@ -529,42 +531,56 @@ void Session::sendPrepared(
} }
} }
bool Session::acquireKeyCreation() { CreatingKeyType Session::acquireKeyCreation(TemporaryKeyType type) {
Expects(!_myKeyCreation); Expects(_myKeyCreation == CreatingKeyType::None);
if (!_dc->acquireKeyCreation()) { _myKeyCreation = _dc->acquireKeyCreation(type);
return false; return _myKeyCreation;
}
_myKeyCreation = true;
return true;
} }
void Session::releaseKeyCreationOnDone( bool Session::releaseKeyCreationOnDone(
const AuthKeyPtr &temporaryKey, const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKey) { const AuthKeyPtr &persistentKeyUsedForBind) {
Expects(_myKeyCreation); Expects(_myKeyCreation != CreatingKeyType::None);
Expects(persistentKeyUsedForBind != nullptr);
const auto wasKeyCreation = std::exchange(
_myKeyCreation,
CreatingKeyType::None);
const auto result = _dc->releaseKeyCreationOnDone(
wasKeyCreation,
temporaryKey,
persistentKeyUsedForBind);
if (!result) {
DEBUG_LOG(("AuthKey Info: Persistent key changed "
"while binding temporary, dcWithShift %1"
).arg(_shiftedDcId));
return false;
}
DEBUG_LOG(("AuthKey Info: Session key bound, setting, dcWithShift %1" DEBUG_LOG(("AuthKey Info: Session key bound, setting, dcWithShift %1"
).arg(_shiftedDcId)); ).arg(_shiftedDcId));
_dc->releaseKeyCreationOnDone(temporaryKey, persistentKey);
_myKeyCreation = false;
const auto dcId = _dc->id(); const auto dcId = _dc->id();
const auto instance = _instance; const auto instance = _instance;
InvokeQueued(instance, [=] { InvokeQueued(instance, [=] {
if (persistentKey) { if (wasKeyCreation == CreatingKeyType::Persistent) {
instance->dcPersistentKeyChanged(dcId, persistentKey); instance->dcPersistentKeyChanged(dcId, persistentKeyUsedForBind);
} else { } else {
instance->dcTemporaryKeyChanged(dcId); instance->dcTemporaryKeyChanged(dcId);
} }
}); });
return true;
} }
void Session::releaseKeyCreationOnFail() { void Session::releaseKeyCreationOnFail() {
Expects(_myKeyCreation); Expects(_myKeyCreation != CreatingKeyType::None);
_dc->releaseKeyCreationOnFail(); const auto wasKeyCreation = std::exchange(
_myKeyCreation = false; _myKeyCreation,
CreatingKeyType::None);
_dc->releaseKeyCreationOnFail(wasKeyCreation);
} }
void Session::notifyDcConnectionInited() { void Session::notifyDcConnectionInited() {
@ -587,8 +603,8 @@ int32 Session::getDcWithShift() const {
return _shiftedDcId; return _shiftedDcId;
} }
AuthKeyPtr Session::getTemporaryKey() const { AuthKeyPtr Session::getTemporaryKey(TemporaryKeyType type) const {
return _dc->getTemporaryKey(); return _dc->getTemporaryKey(type);
} }
AuthKeyPtr Session::getPersistentKey() const { AuthKeyPtr Session::getPersistentKey() const {
@ -643,7 +659,7 @@ void Session::tryToReceive() {
} }
Session::~Session() { Session::~Session() {
if (_myKeyCreation) { if (_myKeyCreation != CreatingKeyType::None) {
releaseKeyCreationOnFail(); releaseKeyCreationOnFail();
} }
Assert(_connection == nullptr); Assert(_connection == nullptr);

View File

@ -24,6 +24,9 @@ namespace internal {
class Dcenter; class Dcenter;
class Connection; class Connection;
enum class TemporaryKeyType;
enum class CreatingKeyType;
using PreRequestMap = QMap<mtpRequestId, SecureRequest>; using PreRequestMap = QMap<mtpRequestId, SecureRequest>;
using RequestMap = QMap<mtpMsgId, SecureRequest>; using RequestMap = QMap<mtpMsgId, SecureRequest>;
@ -163,11 +166,11 @@ public:
[[nodiscard]] bool connectionInited() const; [[nodiscard]] bool connectionInited() const;
[[nodiscard]] AuthKeyPtr getPersistentKey() const; [[nodiscard]] AuthKeyPtr getPersistentKey() const;
[[nodiscard]] AuthKeyPtr getTemporaryKey() const; [[nodiscard]] AuthKeyPtr getTemporaryKey(TemporaryKeyType type) const;
[[nodiscard]] bool acquireKeyCreation(); [[nodiscard]] CreatingKeyType acquireKeyCreation(TemporaryKeyType type);
void releaseKeyCreationOnDone( [[nodiscard]] bool releaseKeyCreationOnDone(
const AuthKeyPtr &temporaryKey, const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKey); const AuthKeyPtr &persistentKeyUsedForBind);
void releaseKeyCreationOnFail(); void releaseKeyCreationOnFail();
void destroyTemporaryKey(uint64 keyId); void destroyTemporaryKey(uint64 keyId);
void resend( void resend(
@ -229,7 +232,7 @@ public:
// Thread-safe. // Thread-safe.
[[nodiscard]] ShiftedDcId getDcWithShift() const; [[nodiscard]] ShiftedDcId getDcWithShift() const;
[[nodiscard]] AuthKeyPtr getPersistentKey() const; [[nodiscard]] AuthKeyPtr getPersistentKey() const;
[[nodiscard]] AuthKeyPtr getTemporaryKey() const; [[nodiscard]] AuthKeyPtr getTemporaryKey(TemporaryKeyType type) const;
[[nodiscard]] bool connectionInited() const; [[nodiscard]] bool connectionInited() const;
void resend( void resend(
mtpMsgId msgId, mtpMsgId msgId,
@ -245,10 +248,10 @@ public:
bool newRequest = true); bool newRequest = true);
// Connection thread. // Connection thread.
[[nodiscard]] bool acquireKeyCreation(); [[nodiscard]] CreatingKeyType acquireKeyCreation(TemporaryKeyType type);
void releaseKeyCreationOnDone( [[nodiscard]] bool releaseKeyCreationOnDone(
const AuthKeyPtr &temporaryKey, const AuthKeyPtr &temporaryKey,
const AuthKeyPtr &persistentKey); const AuthKeyPtr &persistentKeyUsedForBind);
void releaseKeyCreationOnFail(); void releaseKeyCreationOnFail();
void destroyTemporaryKey(uint64 keyId); void destroyTemporaryKey(uint64 keyId);
@ -289,7 +292,7 @@ private:
bool _needToReceive = false; bool _needToReceive = false;
AuthKeyPtr _dcKeyForCheck; AuthKeyPtr _dcKeyForCheck;
bool _myKeyCreation = false; CreatingKeyType _myKeyCreation = CreatingKeyType();
crl::time _msSendCall = 0; crl::time _msSendCall = 0;
crl::time _msWait = 0; crl::time _msWait = 0;