Don't lock whole key creation by a mutex.

This commit is contained in:
John Preston 2019-11-15 16:04:32 +03:00
parent 055b99f5b0
commit be06d68468
12 changed files with 397 additions and 445 deletions

View File

@ -40,7 +40,6 @@ namespace MTP {
namespace internal { namespace internal {
namespace { namespace {
constexpr auto kRecreateKeyId = AuthKey::KeyId(0xFFFFFFFFFFFFFFFFULL);
constexpr auto kIntSize = static_cast<int>(sizeof(mtpPrime)); constexpr auto kIntSize = static_cast<int>(sizeof(mtpPrime));
constexpr auto kWaitForBetterTimeout = crl::time(2000); constexpr auto kWaitForBetterTimeout = crl::time(2000);
constexpr auto kMinConnectedTimeout = crl::time(1000); constexpr auto kMinConnectedTimeout = crl::time(1000);
@ -100,6 +99,14 @@ Connection::Connection(not_null<Instance*> instance)
: _instance(instance) { : _instance(instance) {
} }
Connection::~Connection() {
Expects(_private == nullptr);
if (_thread) {
waitTillFinish();
}
}
void Connection::start(SessionData *sessionData, ShiftedDcId shiftedDcId) { void Connection::start(SessionData *sessionData, ShiftedDcId shiftedDcId) {
Expects(_thread == nullptr && _private == nullptr); Expects(_thread == nullptr && _private == nullptr);
@ -144,14 +151,6 @@ QString Connection::transport() const {
return _private->transport(); return _private->transport();
} }
Connection::~Connection() {
Expects(_private == nullptr);
if (_thread) {
waitTillFinish();
}
}
void ConnectionPrivate::appendTestConnection( void ConnectionPrivate::appendTestConnection(
DcOptions::Variants::Protocol protocol, DcOptions::Variants::Protocol protocol,
const QString &ip, const QString &ip,
@ -215,11 +214,14 @@ int16 ConnectionPrivate::getProtocolDcId() const {
} }
void ConnectionPrivate::destroyAllConnections() { void ConnectionPrivate::destroyAllConnections() {
{
QReadLocker lockFinished(&_sessionDataMutex);
clearKeyCreatorOnFail();
}
_waitForBetterTimer.cancel(); _waitForBetterTimer.cancel();
_waitForReceivedTimer.cancel(); _waitForReceivedTimer.cancel();
_waitForConnectedTimer.cancel(); _waitForConnectedTimer.cancel();
_testConnections.clear(); _testConnections.clear();
_keyCreator = nullptr;
_connection = nullptr; _connection = nullptr;
} }
@ -250,7 +252,7 @@ ConnectionPrivate::ConnectionPrivate(
connect(thread, &QThread::started, this, [=] { connectToServer(); }); connect(thread, &QThread::started, this, [=] { connectToServer(); });
connect(thread, &QThread::finished, this, [=] { finishAndDestroy(); }); connect(thread, &QThread::finished, this, [=] { finishAndDestroy(); });
connect(_sessionData->owner(), SIGNAL(authKeyCreated()), this, SLOT(updateAuthKey()), Qt::QueuedConnection); connect(_sessionData->owner(), SIGNAL(authKeyChanged()), this, SLOT(updateAuthKey()), Qt::QueuedConnection);
connect(_sessionData->owner(), SIGNAL(needToRestart()), this, SLOT(restartNow()), Qt::QueuedConnection); connect(_sessionData->owner(), SIGNAL(needToRestart()), this, SLOT(restartNow()), Qt::QueuedConnection);
connect(this, SIGNAL(needToReceive()), _sessionData->owner(), SLOT(tryToReceive()), Qt::QueuedConnection); connect(this, SIGNAL(needToReceive()), _sessionData->owner(), SLOT(tryToReceive()), Qt::QueuedConnection);
connect(this, SIGNAL(stateChanged(qint32)), _sessionData->owner(), SLOT(onConnectionStateChange(qint32)), Qt::QueuedConnection); connect(this, SIGNAL(stateChanged(qint32)), _sessionData->owner(), SLOT(onConnectionStateChange(qint32)), Qt::QueuedConnection);
@ -274,6 +276,13 @@ ConnectionPrivate::ConnectionPrivate(
connect(this, SIGNAL(resendAllAsync()), _sessionData->owner(), SLOT(resendAll()), Qt::QueuedConnection); connect(this, SIGNAL(resendAllAsync()), _sessionData->owner(), SLOT(resendAll()), Qt::QueuedConnection);
} }
ConnectionPrivate::~ConnectionPrivate() {
Expects(_finished);
Expects(!_connection);
Expects(_testConnections.empty());
Expects(!_keyCreator);
}
void ConnectionPrivate::onConfigLoaded() { void ConnectionPrivate::onConfigLoaded() {
connectToServer(true); connectToServer(true);
} }
@ -562,11 +571,11 @@ mtpMsgId ConnectionPrivate::placeToContainer(SecureRequest &toSendRequest, mtpMs
void ConnectionPrivate::tryToSend() { void ConnectionPrivate::tryToSend() {
QReadLocker lockFinished(&_sessionDataMutex); QReadLocker lockFinished(&_sessionDataMutex);
if (!_sessionData || !_connection) { if (!_sessionData || !_connection || !_keyId) {
return; return;
} }
auto needsLayer = !_connectionOptions->inited; auto needsLayer = !_sessionData->owner()->connectionInited();
auto state = getState(); auto state = getState();
auto sendOnlyFirstPing = (state != ConnectedState); auto sendOnlyFirstPing = (state != ConnectedState);
if (sendOnlyFirstPing && !_pingIdToSend) { if (sendOnlyFirstPing && !_pingIdToSend) {
@ -650,7 +659,7 @@ void ConnectionPrivate::tryToSend() {
_shiftedDcId, _shiftedDcId,
keyForCheck); keyForCheck);
checkDcKeyRequest = _keyChecker->prepareRequest( checkDcKeyRequest = _keyChecker->prepareRequest(
_sessionData->getKey(), _key,
_sessionData->getSessionId()); _sessionData->getSessionId());
// This is a special request with msgId used inside the message // This is a special request with msgId used inside the message
@ -771,7 +780,9 @@ void ConnectionPrivate::tryToSend() {
auto &haveSent = _sessionData->haveSentMap(); auto &haveSent = _sessionData->haveSentMap();
haveSent.insert(msgId, toSendRequest); haveSent.insert(msgId, toSendRequest);
if (needsLayer && !toSendRequest->needsLayer) needsLayer = false; if (needsLayer && !toSendRequest->needsLayer) {
needsLayer = false;
}
if (toSendRequest->after) { if (toSendRequest->after) {
const auto toSendSize = tl::count_length(toSendRequest) >> 2; const auto toSendSize = tl::count_length(toSendRequest) >> 2;
auto wrappedRequest = SecureRequest::Prepare( auto wrappedRequest = SecureRequest::Prepare(
@ -844,6 +855,7 @@ void ConnectionPrivate::tryToSend() {
// prepare "request-like" wrap for msgId vector // prepare "request-like" wrap for msgId vector
auto haveSentIdsWrap = SecureRequest::Prepare(idsWrapSize); auto haveSentIdsWrap = SecureRequest::Prepare(idsWrapSize);
haveSentIdsWrap->msDate = 0; // Container: msDate = 0, seqNo = 0.
haveSentIdsWrap->requestId = 0; haveSentIdsWrap->requestId = 0;
haveSentIdsWrap->resize(haveSentIdsWrap->size() + idsWrapSize); haveSentIdsWrap->resize(haveSentIdsWrap->size() + idsWrapSize);
auto haveSentArr = (mtpMsgId*)(haveSentIdsWrap->data() + 8); auto haveSentArr = (mtpMsgId*)(haveSentIdsWrap->data() + 8);
@ -931,15 +943,6 @@ void ConnectionPrivate::retryByTimer() {
} else if (_retryTimeout < 64000) { } else if (_retryTimeout < 64000) {
_retryTimeout *= 2; _retryTimeout *= 2;
} }
if (_keyId == kRecreateKeyId) {
if (_sessionData->getKey()) {
unlockKey();
QWriteLocker lock(_sessionData->keyMutex());
_sessionData->owner()->destroyKey();
}
_keyId = 0;
}
connectToServer(); connectToServer();
} }
@ -964,7 +967,8 @@ void ConnectionPrivate::connectToServer(bool afterConfig) {
} }
_connectionOptions = std::make_unique<ConnectionOptions>( _connectionOptions = std::make_unique<ConnectionOptions>(
_sessionData->connectionOptions()); _sessionData->connectionOptions());
const auto hasKey = (_sessionData->getKey() != nullptr); // #TODO race.
const auto hasKey = (_sessionData->owner()->getKey() != nullptr);
lockFinished.unlock(); lockFinished.unlock();
const auto bareDc = BareDcId(_shiftedDcId); const auto bareDc = BareDcId(_shiftedDcId);
@ -1212,13 +1216,6 @@ void ConnectionPrivate::connectingTimedOut() {
void ConnectionPrivate::doDisconnect() { void ConnectionPrivate::doDisconnect() {
destroyAllConnections(); destroyAllConnections();
{
QReadLocker lockFinished(&_sessionDataMutex);
if (_sessionData) {
unlockKey();
}
}
setState(DisconnectedState); setState(DisconnectedState);
_restarted = false; _restarted = false;
} }
@ -1257,21 +1254,6 @@ void ConnectionPrivate::handleReceived() {
restart(); restart();
}; };
ReadLockerAttempt lock(_sessionData->keyMutex());
if (!lock) {
DEBUG_LOG(("MTP Error: auth_key for dc %1 busy, cant lock").arg(_shiftedDcId));
clearMessages();
_keyId = 0;
return restartOnError();
}
auto key = _sessionData->getKey();
if (!key || key->keyId() != _keyId) {
DEBUG_LOG(("MTP Error: auth_key id for dc %1 changed").arg(_shiftedDcId));
return restartOnError();
}
while (!_connection->received().empty()) { while (!_connection->received().empty()) {
auto intsBuffer = std::move(_connection->received().front()); auto intsBuffer = std::move(_connection->received().front());
_connection->received().pop_front(); _connection->received().pop_front();
@ -1302,9 +1284,9 @@ void ConnectionPrivate::handleReceived() {
auto msgKey = *(MTPint128*)(ints + 2); auto msgKey = *(MTPint128*)(ints + 2);
#ifdef TDESKTOP_MTPROTO_OLD #ifdef TDESKTOP_MTPROTO_OLD
aesIgeDecrypt_oldmtp(encryptedInts, decryptedBuffer.data(), encryptedBytesCount, key, msgKey); aesIgeDecrypt_oldmtp(encryptedInts, decryptedBuffer.data(), encryptedBytesCount, _key, msgKey);
#else // TDESKTOP_MTPROTO_OLD #else // TDESKTOP_MTPROTO_OLD
aesIgeDecrypt(encryptedInts, decryptedBuffer.data(), encryptedBytesCount, key, msgKey); aesIgeDecrypt(encryptedInts, decryptedBuffer.data(), encryptedBytesCount, _key, msgKey);
#endif // TDESKTOP_MTPROTO_OLD #endif // TDESKTOP_MTPROTO_OLD
auto decryptedInts = reinterpret_cast<const mtpPrime*>(decryptedBuffer.constData()); auto decryptedInts = reinterpret_cast<const mtpPrime*>(decryptedBuffer.constData());
@ -1351,7 +1333,7 @@ void ConnectionPrivate::handleReceived() {
SHA256_CTX msgKeyLargeContext; SHA256_CTX msgKeyLargeContext;
SHA256_Init(&msgKeyLargeContext); SHA256_Init(&msgKeyLargeContext);
SHA256_Update(&msgKeyLargeContext, key->partForMsgKey(false), 32); SHA256_Update(&msgKeyLargeContext, _key->partForMsgKey(false), 32);
SHA256_Update(&msgKeyLargeContext, decryptedInts, encryptedBytesCount); SHA256_Update(&msgKeyLargeContext, decryptedInts, encryptedBytesCount);
SHA256_Final(sha256Buffer.data(), &msgKeyLargeContext); SHA256_Final(sha256Buffer.data(), &msgKeyLargeContext);
@ -1960,10 +1942,7 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr
// An error could be some RPC_CALL_FAIL or other error inside // An error could be some RPC_CALL_FAIL or other error inside
// the initConnection, so we're not sure yet that it was inited. // the initConnection, so we're not sure yet that it was inited.
// Wait till a good response is received. // Wait till a good response is received.
if (!_connectionOptions->inited) { _sessionData->notifyConnectionInited(*_connectionOptions);
_connectionOptions->inited = true;
_sessionData->notifyConnectionInited(*_connectionOptions);
}
} }
if (_keyChecker && _keyChecker->handleResponse(reqMsgId, response)) { if (_keyChecker && _keyChecker->handleResponse(reqMsgId, response)) {
@ -2346,7 +2325,7 @@ void ConnectionPrivate::onConnected(
_testConnections.clear(); _testConnections.clear();
lockFinished.unlock(); lockFinished.unlock();
updateAuthKey(); checkAuthKey();
} }
} }
@ -2383,7 +2362,7 @@ void ConnectionPrivate::confirmBestConnection() {
_connection = std::move(i->data); _connection = std::move(i->data);
_testConnections.clear(); _testConnections.clear();
updateAuthKey(); checkAuthKey();
} }
void ConnectionPrivate::removeTestConnection( void ConnectionPrivate::removeTestConnection(
@ -2396,51 +2375,58 @@ void ConnectionPrivate::removeTestConnection(
end(_testConnections)); end(_testConnections));
} }
void ConnectionPrivate::checkAuthKey() {
if (!_keyId) {
updateAuthKey();
} else {
authKeyChecked();
}
}
void ConnectionPrivate::updateAuthKey() { void ConnectionPrivate::updateAuthKey() {
QReadLocker lockFinished(&_sessionDataMutex); QReadLocker lockFinished(&_sessionDataMutex);
if (!_sessionData || !_connection) { if (!_sessionData || _keyCreator) {
return; return;
} }
DEBUG_LOG(("AuthKey Info: Connection updating key from Session, dc %1").arg(_shiftedDcId)); DEBUG_LOG(("AuthKey Info: Connection updating key from Session, dc %1").arg(_shiftedDcId));
uint64 newKeyId = 0; _key = _sessionData->owner()->getKey();
{ const auto newKeyId = _key ? _key->keyId() : 0;
ReadLockerAttempt lock(_sessionData->keyMutex()); if (newKeyId) {
if (!lock) { if (_keyId == newKeyId) {
DEBUG_LOG(("MTP Info: could not lock auth_key for read, waiting signal emit")); return;
clearMessages();
_keyId = newKeyId;
return; // some other connection is getting key
} }
auto key = _sessionData->getKey(); _sessionData->setCurrentKeyId(newKeyId);
newKeyId = key ? key->keyId() : 0;
} }
if (_keyId != newKeyId) { _keyId = newKeyId;
clearMessages(); if (!_connection) {
_keyId = newKeyId; return;
}
if (const auto already = _connection->sentEncryptedWithKeyId()) {
Assert(already != newKeyId);
DEBUG_LOG(("MTP Error: auth_key id for dc %1 changed").arg(_shiftedDcId));
lockFinished.unlock();
restart();
return;
} }
DEBUG_LOG(("AuthKey Info: Connection update key from Session, dc %1 result: %2").arg(_shiftedDcId).arg(Logs::mb(&_keyId, sizeof(_keyId)).str())); DEBUG_LOG(("AuthKey Info: Connection update key from Session, dc %1 result: %2").arg(_shiftedDcId).arg(Logs::mb(&_keyId, sizeof(_keyId)).str()));
if (_keyId) { if (_keyId) {
return authKeyCreated(); return authKeyChecked();
} }
DEBUG_LOG(("AuthKey Info: No key in updateAuthKey(), will be creating auth_key")); if (_instance->isKeysDestroyer()) {
lockKey();
const auto &key = _sessionData->getKey();
if (key) {
if (_keyId != key->keyId()) clearMessages();
_keyId = key->keyId();
unlockKey();
return authKeyCreated();
} else if (_instance->isKeysDestroyer()) {
// We are here to destroy an old key, so we're done. // We are here to destroy an old key, so we're done.
LOG(("MTP Error: No key %1 in updateAuthKey() for destroying.").arg(_shiftedDcId)); LOG(("MTP Error: No key %1 in updateAuthKey() for destroying.").arg(_shiftedDcId));
_instance->checkIfKeyWasDestroyed(_shiftedDcId); _instance->checkIfKeyWasDestroyed(_shiftedDcId);
return; return;
} else if (!_sessionData->owner()->acquireKeyCreation()) {
DEBUG_LOG(("AuthKey Info: No key in updateAuthKey(), but someone is creating already."));
return;
} }
lockFinished.unlock(); lockFinished.unlock();
DEBUG_LOG(("AuthKey Info: No key in updateAuthKey(), creating."));
createDcKey(); createDcKey();
} }
@ -2449,23 +2435,24 @@ void ConnectionPrivate::createDcKey() {
using Error = DcKeyCreator::Error; using Error = DcKeyCreator::Error;
auto delegate = DcKeyCreator::Delegate(); auto delegate = DcKeyCreator::Delegate();
delegate.done = [=](base::expected<Result, Error> result) { delegate.done = [=](base::expected<Result, Error> result) {
_keyCreator = nullptr; QReadLocker lockFinished(&_sessionDataMutex);
if (!_sessionData) return;
if (result) { if (result) {
QReadLocker lockFinished(&_sessionDataMutex); DEBUG_LOG(("AuthKey Info: auth key gen succeed, id: %1, server salt: %2").arg(result->key->keyId()).arg(result->serverSalt));
if (!_sessionData) return;
_sessionData->setSalt(result->serverSalt); _sessionData->setSalt(result->serverSalt);
_sessionData->clearForNewKey(_instance);
auto authKey = std::move(result->key); _keyCreator = nullptr;
_sessionData->owner()->releaseKeyCreationOnDone(
std::move(result->key));
DEBUG_LOG(("AuthKey Info: auth key gen succeed, id: %1, server salt: %2").arg(authKey->keyId()).arg(result->serverSalt)); updateAuthKey();
return;
// slot will call authKeyCreated(). }
_sessionData->owner()->notifyKeyCreated(std::move(authKey)); clearKeyCreatorOnFail();
_sessionData->clear(_instance); if (result.error() == Error::UnknownPublicKey) {
unlockKey();
} else if (result.error() == Error::UnknownPublicKey) {
if (_dcType == DcType::Cdn) { if (_dcType == DcType::Cdn) {
LOG(("Warning: CDN public RSA key not found")); LOG(("Warning: CDN public RSA key not found"));
requestCDNConfig(); requestCDNConfig();
@ -2489,26 +2476,18 @@ void ConnectionPrivate::createDcKey() {
expireIn); expireIn);
} }
void ConnectionPrivate::clearMessages() { void ConnectionPrivate::authKeyChecked() {
if (_keyId && _keyId != kRecreateKeyId && _connection) {
_connection->received().clear();
}
}
void ConnectionPrivate::authKeyCreated() {
_keyCreator = nullptr;
connect(_connection, &AbstractConnection::receivedData, [=] { connect(_connection, &AbstractConnection::receivedData, [=] {
handleReceived(); handleReceived();
}); });
if (_sessionData->getSalt()) { // else receive salt in bad_server_salt first, then try to send all the requests if (_sessionData->getSalt()) {
setState(ConnectedState); setState(ConnectedState);
if (_restarted) { if (_restarted) {
emit resendAllAsync(); emit resendAllAsync();
_restarted = false; _restarted = false;
} }
} } // else receive salt in bad_server_salt first, then try to send all the requests
_pingIdToSend = rand_value<uint64>(); // get server_salt _pingIdToSend = rand_value<uint64>(); // get server_salt
@ -2542,8 +2521,7 @@ void ConnectionPrivate::handleError(int errorCode) {
if (errorCode == -404) { if (errorCode == -404) {
if (_dcType == DcType::Cdn && !_instance->isKeysDestroyer()) { if (_dcType == DcType::Cdn && !_instance->isKeysDestroyer()) {
LOG(("MTP Info: -404 error received in CDN dc %1, assuming it was destroyed, recreating.").arg(_shiftedDcId)); LOG(("MTP Info: -404 error received in CDN dc %1, assuming it was destroyed, recreating.").arg(_shiftedDcId));
clearMessages(); destroyCdnKey();
_keyId = kRecreateKeyId;
return restart(); return restart();
} else { } else {
LOG(("MTP Info: -404 error received, informing instance.")); LOG(("MTP Info: -404 error received, informing instance."));
@ -2557,7 +2535,16 @@ void ConnectionPrivate::handleError(int errorCode) {
return restart(); return restart();
} }
void ConnectionPrivate::onReadyData() { void ConnectionPrivate::destroyCdnKey() {
if (_key) {
QReadLocker lockFinished(&_sessionDataMutex);
if (_sessionData) {
_sessionData->owner()->destroyCdnKey(_keyId);
}
}
_key = nullptr;
_keyId = 0;
} }
bool ConnectionPrivate::sendSecureRequest( bool ConnectionPrivate::sendSecureRequest(
@ -2581,24 +2568,6 @@ bool ConnectionPrivate::sendSecureRequest(
return false; return false;
} }
auto lock = ReadLockerAttempt(_sessionData->keyMutex());
if (!lock) {
DEBUG_LOG(("MTP Info: could not lock key for read in sendBuffer(), dc %1, restarting...").arg(_shiftedDcId));
lockFinished.unlock();
restart();
return false;
}
auto key = _sessionData->getKey();
if (!key || key->keyId() != _keyId) {
DEBUG_LOG(("MTP Error: auth_key id for dc %1 changed").arg(_shiftedDcId));
lockFinished.unlock();
restart();
return false;
}
auto session = _sessionData->getSessionId(); auto session = _sessionData->getSessionId();
auto salt = _sessionData->getSalt(); auto salt = _sessionData->getSalt();
@ -2626,7 +2595,7 @@ bool ConnectionPrivate::sendSecureRequest(
request->constData(), request->constData(),
&packet[prefix], &packet[prefix],
fullSize * sizeof(mtpPrime), fullSize * sizeof(mtpPrime),
key, _key,
msgKey); msgKey);
#else // TDESKTOP_MTPROTO_OLD #else // TDESKTOP_MTPROTO_OLD
uchar encryptedSHA256[32]; uchar encryptedSHA256[32];
@ -2634,7 +2603,7 @@ bool ConnectionPrivate::sendSecureRequest(
SHA256_CTX msgKeyLargeContext; SHA256_CTX msgKeyLargeContext;
SHA256_Init(&msgKeyLargeContext); SHA256_Init(&msgKeyLargeContext);
SHA256_Update(&msgKeyLargeContext, key->partForMsgKey(true), 32); SHA256_Update(&msgKeyLargeContext, _key->partForMsgKey(true), 32);
SHA256_Update(&msgKeyLargeContext, request->constData(), fullSize * sizeof(mtpPrime)); SHA256_Update(&msgKeyLargeContext, request->constData(), fullSize * sizeof(mtpPrime));
SHA256_Final(encryptedSHA256, &msgKeyLargeContext); SHA256_Final(encryptedSHA256, &msgKeyLargeContext);
@ -2646,13 +2615,13 @@ bool ConnectionPrivate::sendSecureRequest(
request->constData(), request->constData(),
&packet[prefix], &packet[prefix],
fullSize * sizeof(mtpPrime), fullSize * sizeof(mtpPrime),
key, _key,
msgKey); msgKey);
#endif // TDESKTOP_MTPROTO_OLD #endif // TDESKTOP_MTPROTO_OLD
DEBUG_LOG(("MTP Info: sending request, size: %1, num: %2, time: %3").arg(fullSize + 6).arg((*request)[4]).arg((*request)[5])); DEBUG_LOG(("MTP Info: sending request, size: %1, num: %2, time: %3").arg(fullSize + 6).arg((*request)[4]).arg((*request)[5]));
_connection->setSentEncrypted(); _connection->setSentEncryptedWithKeyId(_keyId);
_connection->sendData(std::move(packet)); _connection->sendData(std::move(packet));
if (needAnyResponse) { if (needAnyResponse) {
@ -2689,39 +2658,24 @@ mtpRequestId ConnectionPrivate::wasSent(mtpMsgId msgId) const {
return 0; return 0;
} }
void ConnectionPrivate::lockKey() { // _sessionDataMutex must be locked for read.
unlockKey(); void ConnectionPrivate::clearKeyCreatorOnFail() {
if (const auto mutex = _sessionData->keyMutex()) { if (_keyCreator) {
mutex->lockForWrite(); _keyCreator = nullptr;
}
_myKeyLock = true;
}
void ConnectionPrivate::unlockKey() { Assert(_sessionData != nullptr);
if (_myKeyLock) { _sessionData->owner()->releaseKeyCreationOnFail();
_myKeyLock = false;
if (const auto mutex = _sessionData->keyMutex()) {
mutex->unlock();
}
} }
} }
ConnectionPrivate::~ConnectionPrivate() {
Expects(_finished);
Expects(!_connection);
Expects(_testConnections.empty());
Expects(!_keyCreator);
}
void ConnectionPrivate::stop() { void ConnectionPrivate::stop() {
QWriteLocker lockFinished(&_sessionDataMutex); QWriteLocker lockFinished(&_sessionDataMutex);
if (_sessionData) { if (!_sessionData) {
if (_myKeyLock) { Assert(_keyCreator == nullptr);
_sessionData->owner()->notifyKeyCreated(AuthKeyPtr()); // release key lock, let someone else create it return;
unlockKey();
}
_sessionData = nullptr;
} }
clearKeyCreatorOnFail();
_sessionData = nullptr;
} }
} // namespace internal } // namespace internal

View File

@ -41,12 +41,12 @@ public:
}; };
Connection(not_null<Instance*> instance); Connection(not_null<Instance*> instance);
~Connection();
void start(SessionData *data, ShiftedDcId shiftedDcId); void start(SessionData *data, ShiftedDcId shiftedDcId);
void kill(); void kill();
void waitTillFinish(); void waitTillFinish();
~Connection();
static const int UpdateAlways = 666; static const int UpdateAlways = 666;
@ -99,14 +99,6 @@ public slots:
void onPingSendForce(); void onPingSendForce();
void onSentSome(uint64 size);
void onReceivedSome();
void onReadyData();
// General packet receive slot, connected to conn->receivedData signal
void handleReceived();
// Sessions signals, when we need to send something // Sessions signals, when we need to send something
void tryToSend(); void tryToSend();
@ -132,6 +124,10 @@ private:
qint32 errorCode); qint32 errorCode);
void onConnected(not_null<AbstractConnection*> connection); void onConnected(not_null<AbstractConnection*> connection);
void onDisconnected(not_null<AbstractConnection*> connection); void onDisconnected(not_null<AbstractConnection*> connection);
void onSentSome(uint64 size);
void onReceivedSome();
void handleReceived();
void retryByTimer(); void retryByTimer();
void waitConnectedFailed(); void waitConnectedFailed();
@ -140,7 +136,9 @@ private:
void markConnectionOld(); void markConnectionOld();
void sendPingByTimer(); void sendPingByTimer();
// Locks _sessionDataMutex.
void destroyAllConnections(); void destroyAllConnections();
void confirmBestConnection(); void confirmBestConnection();
void removeTestConnection(not_null<AbstractConnection*> connection); void removeTestConnection(not_null<AbstractConnection*> connection);
int16 getProtocolDcId() const; int16 getProtocolDcId() const;
@ -170,8 +168,6 @@ private:
mtpBuffer ungzip(const mtpPrime *from, const mtpPrime *end) const; mtpBuffer ungzip(const mtpPrime *from, const mtpPrime *end) const;
void handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray &states, QVector<MTPlong> &acked); void handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray &states, QVector<MTPlong> &acked);
void clearMessages();
bool setState(int32 state, int32 ifState = Connection::UpdateAlways); bool setState(int32 state, int32 ifState = Connection::UpdateAlways);
void appendTestConnection( void appendTestConnection(
@ -191,9 +187,12 @@ private:
void createDcKey(); void createDcKey();
void resetSession(); void resetSession();
void lockKey(); void checkAuthKey();
void unlockKey(); void authKeyChecked();
void authKeyCreated(); void destroyCdnKey();
// _sessionDataMutex must be locked for read.
void clearKeyCreatorOnFail();
not_null<Instance*> _instance; not_null<Instance*> _instance;
DcType _dcType = DcType::Regular; DcType _dcType = DcType::Regular;
@ -235,13 +234,12 @@ private:
bool _restarted = false; bool _restarted = false;
bool _finished = false; bool _finished = false;
AuthKeyPtr _key;
uint64 _keyId = 0; uint64 _keyId = 0;
QReadWriteLock _sessionDataMutex; QReadWriteLock _sessionDataMutex;
SessionData *_sessionData = nullptr; SessionData *_sessionData = nullptr;
std::unique_ptr<ConnectionOptions> _connectionOptions; std::unique_ptr<ConnectionOptions> _connectionOptions;
bool _myKeyLock = false;
std::unique_ptr<details::DcKeyCreator> _keyCreator; std::unique_ptr<details::DcKeyCreator> _keyCreator;
std::unique_ptr<details::DcKeyChecker> _keyChecker; std::unique_ptr<details::DcKeyChecker> _keyChecker;

View File

@ -100,8 +100,11 @@ public:
[[nodiscard]] virtual QString transport() const = 0; [[nodiscard]] virtual QString transport() const = 0;
[[nodiscard]] virtual QString tag() const = 0; [[nodiscard]] virtual QString tag() const = 0;
void setSentEncrypted() { void setSentEncryptedWithKeyId(uint64 keyId) {
_sentEncrypted = true; _sentEncryptedWithKeyId = keyId;
}
[[nodiscard]] uint64 sentEncryptedWithKeyId() const {
return _sentEncryptedWithKeyId;
} }
using BuffersQueue = std::deque<mtpBuffer>; using BuffersQueue = std::deque<mtpBuffer>;
@ -137,7 +140,6 @@ signals:
protected: protected:
BuffersQueue _receivedQueue; // list of received packets, not processed yet BuffersQueue _receivedQueue; // list of received packets, not processed yet
bool _sentEncrypted = false;
int _pingTime = 0; int _pingTime = 0;
ProxyData _proxy; ProxyData _proxy;
@ -150,6 +152,8 @@ protected:
private: private:
[[nodiscard]] uint32 extendedNotSecurePadding() const; [[nodiscard]] uint32 extendedNotSecurePadding() const;
uint64 _sentEncryptedWithKeyId = 0;
}; };
template <typename Request> template <typename Request>

View File

@ -23,53 +23,81 @@ constexpr auto kSpecialRequestTimeoutMs = 6000; // 4 seconds timeout for it to w
} // namespace } // namespace
Dcenter::Dcenter(not_null<Instance*> instance, DcId dcId, AuthKeyPtr &&key) Dcenter::Dcenter(DcId dcId, AuthKeyPtr &&key)
: _instance(instance) : _id(dcId)
, _id(dcId)
, _key(std::move(key)) { , _key(std::move(key)) {
connect(this, SIGNAL(authKeyCreated()), this, SLOT(authKeyWrite()), Qt::QueuedConnection);
} }
void Dcenter::authKeyWrite() { DcId Dcenter::id() const {
DEBUG_LOG(("AuthKey Info: MTProtoDC::authKeyWrite() slot, dc %1").arg(_id)); return _id;
if (_key) {
Local::writeMtpData();
}
} }
void Dcenter::setKey(AuthKeyPtr &&key) { AuthKeyPtr Dcenter::getKey() const {
DEBUG_LOG(("AuthKey Info: MTProtoDC::setKey(%1), emitting authKeyCreated, dc %2").arg(key ? key->keyId() : 0).arg(_id)); QReadLocker lock(&_mutex);
_key = std::move(key);
_connectionInited = false;
_instance->setKeyForWrite(_id, _key);
emit authKeyCreated();
}
QReadWriteLock *Dcenter::keyMutex() const {
return &keyLock;
}
const AuthKeyPtr &Dcenter::getKey() const {
return _key; return _key;
} }
void Dcenter::destroyKey() { void Dcenter::destroyCdnKey(uint64 keyId) {
setKey(AuthKeyPtr()); destroyKey(keyId);
}
bool Dcenter::destroyConfirmedForgottenKey(uint64 keyId) {
return destroyKey(keyId);
}
bool Dcenter::destroyKey(uint64 keyId) {
Expects(!_creatingKey || !_key);
QWriteLocker lock(&_mutex);
if (_key->keyId() != keyId) {
return false;
}
_key = nullptr;
_connectionInited = false;
lock.unlock();
emit authKeyChanged();
return true;
} }
bool Dcenter::connectionInited() const { bool Dcenter::connectionInited() const {
const auto lock = QMutexLocker(&_initLock); QReadLocker lock(&_mutex);
return _connectionInited; return _connectionInited;
} }
void Dcenter::setConnectionInited(bool connectionInited) { void Dcenter::setConnectionInited(bool connectionInited) {
auto lock = QMutexLocker(&_initLock); QWriteLocker lock(&_mutex);
_connectionInited = connectionInited; _connectionInited = connectionInited;
}
bool Dcenter::acquireKeyCreation() {
QReadLocker lock(&_mutex);
if (_key != nullptr) {
return false;
}
auto expected = false;
return _creatingKey.compare_exchange_strong(expected, true);
}
void Dcenter::releaseKeyCreationOnFail() {
Expects(_creatingKey);
Expects(_key == nullptr);
_creatingKey = false;
}
void Dcenter::releaseKeyCreationOnDone(AuthKeyPtr &&key) {
Expects(_creatingKey);
Expects(_key == nullptr);
QWriteLocker lock(&_mutex);
DEBUG_LOG(("AuthKey Info: Dcenter::releaseKeyCreationOnDone(%1), emitting authKeyChanged, dc %2").arg(key ? key->keyId() : 0).arg(_id));
_key = std::move(key);
_connectionInited = false;
_creatingKey = false;
lock.unlock(); lock.unlock();
if (connectionInited) { emit authKeyChanged();
emit connectionWasInited();
}
} }
} // namespace internal } // namespace internal

View File

@ -19,30 +19,35 @@ class Dcenter : public QObject {
Q_OBJECT Q_OBJECT
public: public:
Dcenter(not_null<Instance*> instance, DcId dcId, AuthKeyPtr &&key); // Main thread.
Dcenter(DcId dcId, AuthKeyPtr &&key);
QReadWriteLock *keyMutex() const; // Thread-safe.
const AuthKeyPtr &getKey() const; [[nodiscard]] DcId id() const;
void setKey(AuthKeyPtr &&key);
void destroyKey(); [[nodiscard]] AuthKeyPtr getKey() const;
void destroyCdnKey(uint64 keyId);
bool destroyConfirmedForgottenKey(uint64 keyId);
void releaseKeyCreationOnDone(AuthKeyPtr &&key);
[[nodiscard]] bool connectionInited() const; [[nodiscard]] bool connectionInited() const;
void setConnectionInited(bool connectionInited = true); void setConnectionInited(bool connectionInited = true);
signals: [[nodiscard]] bool acquireKeyCreation();
void authKeyCreated(); void releaseKeyCreationOnFail();
void connectionWasInited();
private slots: signals:
void authKeyWrite(); void authKeyChanged();
private: private:
mutable QReadWriteLock keyLock; bool destroyKey(uint64 keyId);
mutable QMutex _initLock;
not_null<Instance*> _instance; const DcId _id = 0;
DcId _id = 0; mutable QReadWriteLock _mutex;
AuthKeyPtr _key; AuthKeyPtr _key;
bool _connectionInited = false; bool _connectionInited = false;
std::atomic<bool> _creatingKey = false;
}; };

View File

@ -172,6 +172,9 @@ DcKeyCreator::DcKeyCreator(
} }
DcKeyCreator::~DcKeyCreator() { DcKeyCreator::~DcKeyCreator() {
if (_delegate.done) {
stopReceiving();
}
const auto clearBytes = [](bytes::span bytes) { const auto clearBytes = [](bytes::span bytes) {
OPENSSL_cleanse(bytes.data(), bytes.size()); OPENSSL_cleanse(bytes.data(), bytes.size());
}; };
@ -191,11 +194,7 @@ void DcKeyCreator::pqSend() {
} }
void DcKeyCreator::pqAnswered() { void DcKeyCreator::pqAnswered() {
QObject::disconnect( stopReceiving();
_connection,
&AbstractConnection::receivedData,
nullptr,
nullptr);
DEBUG_LOG(("AuthKey Info: receiving Req_pq answer...")); DEBUG_LOG(("AuthKey Info: receiving Req_pq answer..."));
MTPReq_pq::ResponseType res_pq; MTPReq_pq::ResponseType res_pq;
@ -272,11 +271,7 @@ void DcKeyCreator::pqAnswered() {
} }
void DcKeyCreator::dhParamsAnswered() { void DcKeyCreator::dhParamsAnswered() {
QObject::disconnect( stopReceiving();
_connection,
&AbstractConnection::receivedData,
nullptr,
nullptr);
DEBUG_LOG(("AuthKey Info: receiving Req_DH_params answer...")); DEBUG_LOG(("AuthKey Info: receiving Req_DH_params answer..."));
MTPReq_DH_params::ResponseType res_DH_params; MTPReq_DH_params::ResponseType res_DH_params;
@ -450,11 +445,7 @@ void DcKeyCreator::dhClientParamsSend() {
} }
void DcKeyCreator::dhClientParamsAnswered() { void DcKeyCreator::dhClientParamsAnswered() {
QObject::disconnect( stopReceiving();
_connection,
&AbstractConnection::receivedData,
nullptr,
nullptr);
DEBUG_LOG(("AuthKey Info: receiving Req_client_DH_params answer...")); DEBUG_LOG(("AuthKey Info: receiving Req_client_DH_params answer..."));
MTPSet_client_DH_params::ResponseType res_client_DH_params; MTPSet_client_DH_params::ResponseType res_client_DH_params;
@ -578,7 +569,8 @@ bool DcKeyCreator::readNotSecureResponse(Response &response) {
} }
void DcKeyCreator::failed(Error error) { void DcKeyCreator::failed(Error error) {
auto onstack = std::move(_delegate.done); stopReceiving();
auto onstack = base::take(_delegate.done);
onstack(tl::unexpected(error)); onstack(tl::unexpected(error));
} }
@ -589,8 +581,18 @@ void DcKeyCreator::done(uint64 serverSalt) {
_dcId, _dcId,
_authKey); _authKey);
result.serverSalt = serverSalt; result.serverSalt = serverSalt;
auto onstack = std::move(_delegate.done);
stopReceiving();
auto onstack = base::take(_delegate.done);
onstack(std::move(result)); onstack(std::move(result));
} }
void DcKeyCreator::stopReceiving() {
QObject::disconnect(
_connection,
&AbstractConnection::receivedData,
nullptr,
nullptr);
}
} // namespace MTP::details } // namespace MTP::details

View File

@ -83,6 +83,7 @@ private:
void dhClientParamsSend(); void dhClientParamsSend();
void dhClientParamsAnswered(); void dhClientParamsAnswered();
void stopReceiving();
void failed(Error error = Error::Other); void failed(Error error = Error::Other);
void done(uint64 serverSalt); void done(uint64 serverSalt);

View File

@ -645,7 +645,7 @@ not_null<Dcenter*> Instance::Private::addDc(
const auto dcId = BareDcId(shiftedDcId); const auto dcId = BareDcId(shiftedDcId);
return _dcenters.emplace( return _dcenters.emplace(
shiftedDcId, shiftedDcId,
std::make_unique<Dcenter>(_instance, dcId, std::move(key)) std::make_unique<Dcenter>(dcId, std::move(key))
).first->second.get(); ).first->second.get();
} }
@ -690,6 +690,10 @@ void Instance::Private::setKeyForWrite(DcId dcId, const AuthKeyPtr &key) {
} else { } else {
_keysForWrite.erase(dcId); _keysForWrite.erase(dcId);
} }
crl::on_main(_instance, [=] {
DEBUG_LOG(("AuthKey Info: writing auth keys, called by dc %1").arg(dcId));
Local::writeMtpData();
});
} }
AuthKeysList Instance::Private::getKeysForWrite() const { AuthKeysList Instance::Private::getKeysForWrite() const {
@ -1578,17 +1582,16 @@ void Instance::Private::checkMainDcKey() {
} }
void Instance::Private::keyDestroyedOnServer(DcId dcId, uint64 keyId) { void Instance::Private::keyDestroyedOnServer(DcId dcId, uint64 keyId) {
if (dcId == _mainDcId) { LOG(("Destroying key for dc: %1").arg(dcId));
for (const auto &[id, dc] : _dcenters) { if (const auto dc = findDc(dcId)) {
dc->destroyKey(); if (dc->destroyConfirmedForgottenKey(keyId)) {
LOG(("Key destroyed!"));
setKeyForWrite(dcId, nullptr);
} else {
LOG(("Key already is different."));
} }
restart();
} else {
if (const auto dc = findDc(dcId)) {
return dc->destroyKey();
}
restart(dcId);
} }
restart(dcId);
} }
void Instance::Private::setUpdatesHandler(RPCDoneHandlerPtr onDone) { void Instance::Private::setUpdatesHandler(RPCDoneHandlerPtr onDone) {
@ -1737,7 +1740,9 @@ void Instance::logout(RPCDoneHandlerPtr onDone, RPCFailHandlerPtr onFail) {
} }
void Instance::setKeyForWrite(DcId dcId, const AuthKeyPtr &key) { void Instance::setKeyForWrite(DcId dcId, const AuthKeyPtr &key) {
_private->setKeyForWrite(dcId, key); InvokeQueued(this, [=] {
_private->setKeyForWrite(dcId, key);
});
} }
AuthKeysList Instance::getKeysForWrite() const { AuthKeysList Instance::getKeysForWrite() const {

View File

@ -38,14 +38,16 @@ public:
QString deviceModel; QString deviceModel;
QString systemVersion; QString systemVersion;
}; };
enum class Mode { enum class Mode {
Normal, Normal,
KeysDestroyer, KeysDestroyer,
}; };
Instance(not_null<DcOptions*> options, Mode mode, Config &&config);
Instance(not_null<DcOptions*> options, Mode mode, Config &&config);
Instance(const Instance &other) = delete; Instance(const Instance &other) = delete;
Instance &operator=(const Instance &other) = delete; Instance &operator=(const Instance &other) = delete;
~Instance();
void resolveProxyDomain(const QString &host); void resolveProxyDomain(const QString &host);
void setGoodProxyDomain(const QString &host, const QString &ip); void setGoodProxyDomain(const QString &host, const QString &ip);
@ -56,16 +58,71 @@ public:
[[nodiscard]] QString cloudLangCode() const; [[nodiscard]] QString cloudLangCode() const;
[[nodiscard]] QString langPackName() const; [[nodiscard]] QString langPackName() const;
// Thread safe. // Thread-safe.
[[nodiscard]] QString deviceModel() const; [[nodiscard]] QString deviceModel() const;
[[nodiscard]] QString systemVersion() const; [[nodiscard]] QString systemVersion() const;
void setKeyForWrite(DcId dcId, const AuthKeyPtr &key); void setKeyForWrite(DcId dcId, const AuthKeyPtr &key);
// Main thread.
[[nodiscard]] AuthKeysList getKeysForWrite() const; [[nodiscard]] AuthKeysList getKeysForWrite() const;
void addKeysForDestroy(AuthKeysList &&keys); void addKeysForDestroy(AuthKeysList &&keys);
[[nodiscard]] not_null<DcOptions*> dcOptions(); [[nodiscard]] not_null<DcOptions*> dcOptions();
void restart();
void restart(ShiftedDcId shiftedDcId);
int32 dcstate(ShiftedDcId shiftedDcId = 0);
QString dctransport(ShiftedDcId shiftedDcId = 0);
void ping();
void cancel(mtpRequestId requestId);
int32 state(mtpRequestId requestId); // < 0 means waiting for such count of ms
// Main thread.
void killSession(ShiftedDcId shiftedDcId);
void stopSession(ShiftedDcId shiftedDcId);
void reInitConnection(DcId dcId);
void logout(RPCDoneHandlerPtr onDone, RPCFailHandlerPtr onFail);
void unpaused();
void queueQuittingConnection(std::unique_ptr<internal::Connection> &&connection);
void setUpdatesHandler(RPCDoneHandlerPtr onDone);
void setGlobalFailHandler(RPCFailHandlerPtr onFail);
void setStateChangedHandler(Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler);
void setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler);
void clearGlobalHandlers();
void onStateChange(ShiftedDcId shiftedDcId, int32 state);
void onSessionReset(ShiftedDcId shiftedDcId);
void clearCallbacksDelayed(std::vector<RPCCallbackClear> &&ids);
void execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end);
bool hasCallbacks(mtpRequestId requestId);
void globalCallback(const mtpPrime *from, const mtpPrime *end);
// return true if need to clean request data
bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err);
bool isKeysDestroyer() const;
void scheduleKeyDestroy(ShiftedDcId shiftedDcId);
void checkIfKeyWasDestroyed(ShiftedDcId shiftedDcId);
void keyDestroyedOnServer(DcId dcId, uint64 keyId);
void requestConfig();
void requestConfigIfOld();
void requestCDNConfig();
void setUserPhone(const QString &phone);
void badConfigurationError();
void syncHttpUnixtime();
void connectionFinished(not_null<internal::Connection*> connection);
void sendAnything(ShiftedDcId shiftedDcId = 0, crl::time msCanWait = 0);
void sendDcKeyCheck(ShiftedDcId shiftedDcId, const AuthKeyPtr &key);
template <typename Request> template <typename Request>
mtpRequestId send( mtpRequestId send(
const Request &request, const Request &request,
@ -134,60 +191,6 @@ public:
afterRequestId); afterRequestId);
} }
void sendAnything(ShiftedDcId shiftedDcId = 0, crl::time msCanWait = 0);
void sendDcKeyCheck(ShiftedDcId shiftedDcId, const AuthKeyPtr &key);
void restart();
void restart(ShiftedDcId shiftedDcId);
int32 dcstate(ShiftedDcId shiftedDcId = 0);
QString dctransport(ShiftedDcId shiftedDcId = 0);
void ping();
void cancel(mtpRequestId requestId);
int32 state(mtpRequestId requestId); // < 0 means waiting for such count of ms
void killSession(ShiftedDcId shiftedDcId);
void stopSession(ShiftedDcId shiftedDcId);
void reInitConnection(DcId dcId);
void logout(RPCDoneHandlerPtr onDone, RPCFailHandlerPtr onFail);
void unpaused();
void queueQuittingConnection(std::unique_ptr<internal::Connection> &&connection);
void setUpdatesHandler(RPCDoneHandlerPtr onDone);
void setGlobalFailHandler(RPCFailHandlerPtr onFail);
void setStateChangedHandler(Fn<void(ShiftedDcId shiftedDcId, int32 state)> handler);
void setSessionResetHandler(Fn<void(ShiftedDcId shiftedDcId)> handler);
void clearGlobalHandlers();
void onStateChange(ShiftedDcId shiftedDcId, int32 state);
void onSessionReset(ShiftedDcId shiftedDcId);
void clearCallbacksDelayed(std::vector<RPCCallbackClear> &&ids);
void execCallback(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end);
bool hasCallbacks(mtpRequestId requestId);
void globalCallback(const mtpPrime *from, const mtpPrime *end);
// return true if need to clean request data
bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err);
bool isKeysDestroyer() const;
void scheduleKeyDestroy(ShiftedDcId shiftedDcId);
void checkIfKeyWasDestroyed(ShiftedDcId shiftedDcId);
void keyDestroyedOnServer(DcId dcId, uint64 keyId);
void requestConfig();
void requestConfigIfOld();
void requestCDNConfig();
void setUserPhone(const QString &phone);
void badConfigurationError();
void syncHttpUnixtime();
void connectionFinished(not_null<internal::Connection*> connection);
~Instance();
signals: signals:
void configLoaded(); void configLoaded();
void cdnConfigLoaded(); void cdnConfigLoaded();

View File

@ -261,13 +261,13 @@ private:
}; };
struct RPCCallbackClear { struct RPCCallbackClear {
RPCCallbackClear(mtpRequestId id , int32 code = RPCError::NoError) RPCCallbackClear(mtpRequestId id, int32 code = RPCError::NoError)
: requestId(id) : requestId(id)
, errorCode(code) { , errorCode(code) {
} }
mtpRequestId requestId; mtpRequestId requestId = 0;
int32 errorCode; int32 errorCode = 0;
}; };

View File

@ -12,6 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "mtproto/dcenter.h" #include "mtproto/dcenter.h"
#include "mtproto/auth_key.h" #include "mtproto/auth_key.h"
#include "base/unixtime.h" #include "base/unixtime.h"
#include "base/openssl_help.h"
#include "core/crash_reports.h" #include "core/crash_reports.h"
#include "facades.h" #include "facades.h"
@ -63,19 +64,15 @@ ConnectionOptions::ConnectionOptions(
, useTcp(useTcp) { , useTcp(useTcp) {
} }
void SessionData::setKey(const AuthKeyPtr &key) { void SessionData::setCurrentKeyId(uint64 keyId) {
if (_authKey != key) { QWriteLocker locker(&_lock);
const auto sessionId = rand_value<uint64>(); if (_keyId == keyId) {
_authKey = key; return;
DEBUG_LOG(("MTP Info: new auth key set in SessionData, id %1, setting random server_session %2").arg(key ? key->keyId() : 0).arg(sessionId));
QWriteLocker locker(&_lock);
if (_sessionId != sessionId) {
_sessionId = sessionId;
_messagesSent = 0;
}
_layerInited = false;
} }
_keyId = keyId;
_sessionId = openssl::RandomValue<uint64>();
_messagesSent = 0;
DEBUG_LOG(("MTP Info: new auth key set in SessionData, id %1, setting random server_session %2").arg(_keyId).arg(_sessionId));
} }
void SessionData::setKeyForCheck(const AuthKeyPtr &key) { void SessionData::setKeyForCheck(const AuthKeyPtr &key) {
@ -83,25 +80,24 @@ void SessionData::setKeyForCheck(const AuthKeyPtr &key) {
} }
void SessionData::notifyConnectionInited(const ConnectionOptions &options) { void SessionData::notifyConnectionInited(const ConnectionOptions &options) {
QWriteLocker locker(&_lock); // #TODO race
if (options.cloudLangCode == _options.cloudLangCode const auto current = connectionOptions();
&& options.systemLangCode == _options.systemLangCode if (current.cloudLangCode == _options.cloudLangCode
&& options.langPackName == _options.langPackName && current.systemLangCode == _options.systemLangCode
&& options.proxy == _options.proxy && current.langPackName == _options.langPackName
&& !_options.inited) { && current.proxy == _options.proxy) {
_options.inited = true;
locker.unlock();
owner()->notifyDcConnectionInited(); owner()->notifyDcConnectionInited();
} }
} }
void SessionData::clear(Instance *instance) { void SessionData::clearForNewKey(not_null<Instance*> instance) {
auto clearCallbacks = std::vector<RPCCallbackClear>(); auto clearCallbacks = std::vector<RPCCallbackClear>();
{ {
QReadLocker locker1(haveSentMutex()), locker2(toResendMutex()), locker3(haveReceivedMutex()), locker4(wereAckedMutex()); QReadLocker locker1(haveSentMutex());
auto receivedResponsesEnd = _receivedResponses.cend(); QReadLocker locker2(toResendMutex());
clearCallbacks.reserve(_haveSent.size() + _wereAcked.size()); QReadLocker locker3(haveReceivedMutex());
QReadLocker locker4(wereAckedMutex());
clearCallbacks.reserve(_haveSent.size() + _toResend.size() + _wereAcked.size());
for (auto i = _haveSent.cbegin(), e = _haveSent.cend(); i != e; ++i) { for (auto i = _haveSent.cbegin(), e = _haveSent.cend(); i != e; ++i) {
auto requestId = i.value()->requestId; auto requestId = i.value()->requestId;
if (!_receivedResponses.contains(requestId)) { if (!_receivedResponses.contains(requestId)) {
@ -147,21 +143,15 @@ Session::Session(
: QObject() : QObject()
, _instance(instance) , _instance(instance)
, _shiftedDcId(shiftedDcId) , _shiftedDcId(shiftedDcId)
, _dc(dc) , _ownedDc(dc ? nullptr : std::make_unique<Dcenter>(shiftedDcId, nullptr))
, _dc(dc ? dc : _ownedDc.get())
, _data(this) , _data(this)
, _timeouter([=] { checkRequestsByTimer(); }) , _timeouter([=] { checkRequestsByTimer(); })
, _sender([=] { needToResumeAndSend(); }) { , _sender([=] { needToResumeAndSend(); }) {
_timeouter.callEach(1000); _timeouter.callEach(1000);
refreshOptions(); refreshOptions();
if (_dc) { if (sharedDc()) {
if (const auto lock = ReadLockerAttempt(keyMutex())) { connect(_dc, SIGNAL(authKeyChanged()), this, SLOT(authKeyChangedForDC()), Qt::QueuedConnection);
_data.setKey(_dc->getKey());
if (_dc->connectionInited()) {
_data.setConnectionInited();
}
}
connect(_dc, SIGNAL(authKeyCreated()), this, SLOT(authKeyCreatedForDC()), Qt::QueuedConnection);
connect(_dc, SIGNAL(connectionWasInited()), this, SLOT(connectionWasInitedForDC()), Qt::QueuedConnection);
} }
} }
@ -199,7 +189,7 @@ void Session::refreshOptions() {
const auto useHttp = (proxyType != ProxyData::Type::Mtproto); const auto useHttp = (proxyType != ProxyData::Type::Mtproto);
const auto useIPv4 = true; const auto useIPv4 = true;
const auto useIPv6 = Global::TryIPv6(); const auto useIPv6 = Global::TryIPv6();
_data.applyConnectionOptions(ConnectionOptions( _data.setConnectionOptions(ConnectionOptions(
_instance->systemLangCode(), _instance->systemLangCode(),
_instance->cloudLangCode(), _instance->cloudLangCode(),
_instance->langPackName(), _instance->langPackName(),
@ -213,10 +203,7 @@ void Session::refreshOptions() {
} }
void Session::reInitConnection() { void Session::reInitConnection() {
if (_dc) { _dc->setConnectionInited(false);
_dc->setConnectionInited(false);
}
_data.setConnectionInited(false);
restart(); restart();
} }
@ -315,6 +302,10 @@ void Session::sendMsgsStateInfo(quint64 msgId, QByteArray data) {
MTP_msgs_state_info(MTP_long(msgId), MTP_bytes(data)))); MTP_msgs_state_info(MTP_long(msgId), MTP_bytes(data))));
} }
bool Session::sharedDc() const {
return (_ownedDc == nullptr);
}
void Session::checkRequestsByTimer() { void Session::checkRequestsByTimer() {
QVector<mtpMsgId> resendingIds; QVector<mtpMsgId> resendingIds;
QVector<mtpMsgId> removingIds; // remove very old (10 minutes) containers and resend requests QVector<mtpMsgId> removingIds; // remove very old (10 minutes) containers and resend requests
@ -555,51 +546,44 @@ void Session::sendPrepared(
sendAnything(msCanWait); sendAnything(msCanWait);
} }
QReadWriteLock *Session::keyMutex() const { void Session::authKeyChangedForDC() {
return _dc ? _dc->keyMutex() : nullptr; DEBUG_LOG(("AuthKey Info: Session::authKeyCreatedForDC slot, emitting authKeyChanged(), dcWithShift %1").arg(_shiftedDcId));
emit authKeyChanged();
} }
void Session::authKeyCreatedForDC() { bool Session::acquireKeyCreation() {
Expects(_dc != nullptr); return _dc->acquireKeyCreation();
DEBUG_LOG(("AuthKey Info: Session::authKeyCreatedForDC slot, emitting authKeyCreated(), dcWithShift %1").arg(_shiftedDcId));
_data.setKey(_dc->getKey());
emit authKeyCreated();
} }
void Session::notifyKeyCreated(AuthKeyPtr &&key) { void Session::releaseKeyCreationOnFail() {
DEBUG_LOG(("AuthKey Info: Session::keyCreated(), setting, dcWithShift %1").arg(_shiftedDcId)); _dc->releaseKeyCreationOnFail();
if (_dc) { }
_dc->setKey(std::move(key));
} else { void Session::releaseKeyCreationOnDone(AuthKeyPtr &&key) {
_data.setKey(std::move(key)); DEBUG_LOG(("AuthKey Info: Session key created, setting, dcWithShift %1").arg(_shiftedDcId));
emit authKeyCreated(); if (sharedDc()) {
const auto dcId = _dc->id();
const auto instance = _instance;
InvokeQueued(instance, [=] {
instance->setKeyForWrite(dcId, key);
});
} }
} _dc->releaseKeyCreationOnDone(std::move(key));
void Session::connectionWasInitedForDC() {
Expects(_dc != nullptr);
DEBUG_LOG(("MTP Info: Session::connectionWasInitedForDC slot, dcWithShift %1").arg(_shiftedDcId));
_data.setConnectionInited();
} }
void Session::notifyDcConnectionInited() { void Session::notifyDcConnectionInited() {
DEBUG_LOG(("MTP Info: emitting MTProtoDC::connectionWasInited(), dcWithShift %1").arg(_shiftedDcId)); DEBUG_LOG(("MTP Info: emitting MTProtoDC::connectionWasInited(), dcWithShift %1").arg(_shiftedDcId));
if (_dc) { _dc->setConnectionInited();
_dc->setConnectionInited();
} else {
_data.setConnectionInited();
}
} }
void Session::destroyKey() { void Session::destroyCdnKey(uint64 keyId) {
if (const auto key = _data.getKey()) { _dc->destroyCdnKey(keyId);
DEBUG_LOG(("MTP Info: destroying auth_key for dcWithShift %1").arg(_shiftedDcId)); if (sharedDc()) {
if (_dc && _dc->getKey() == key) { const auto dcId = _dc->id();
_dc->destroyKey(); const auto instance = _instance;
} InvokeQueued(instance, [=] {
_data.setKey(nullptr); instance->setKeyForWrite(dcId, nullptr);
});
} }
} }
@ -607,6 +591,14 @@ int32 Session::getDcWithShift() const {
return _shiftedDcId; return _shiftedDcId;
} }
AuthKeyPtr Session::getKey() const {
return _dc->getKey();
}
bool Session::connectionInited() const {
return _dc->connectionInited();
}
void Session::tryToReceive() { void Session::tryToReceive() {
if (_killed) { if (_killed) {
DEBUG_LOG(("Session Error: can't receive in a killed session")); DEBUG_LOG(("Session Error: can't receive in a killed session"));

View File

@ -131,7 +131,6 @@ struct ConnectionOptions {
bool useIPv6 = true; bool useIPv6 = true;
bool useHttp = true; bool useHttp = true;
bool useTcp = true; bool useTcp = true;
bool inited = false;
}; };
@ -141,6 +140,7 @@ public:
SessionData(not_null<Session*> creator) : _owner(creator) { SessionData(not_null<Session*> creator) : _owner(creator) {
} }
void setCurrentKeyId(uint64 keyId);
void setSessionId(uint64 sessionId) { void setSessionId(uint64 sessionId) {
DEBUG_LOG(("MTP Info: setting server_session: %1").arg(sessionId)); DEBUG_LOG(("MTP Info: setting server_session: %1").arg(sessionId));
@ -150,22 +150,16 @@ public:
_messagesSent = 0; _messagesSent = 0;
} }
} }
uint64 getSessionId() const { [[nodiscard]] uint64 getSessionId() const {
QReadLocker locker(&_lock); QReadLocker locker(&_lock);
return _sessionId; return _sessionId;
} }
void setConnectionInited(bool inited = true) {
QWriteLocker locker(&_lock);
_options.inited = inited;
}
void notifyConnectionInited(const ConnectionOptions &options); void notifyConnectionInited(const ConnectionOptions &options);
void applyConnectionOptions(ConnectionOptions options) { void setConnectionOptions(ConnectionOptions options) {
QWriteLocker locker(&_lock); QWriteLocker locker(&_lock);
const auto inited = _options.inited;
_options = options; _options = options;
_options.inited = inited;
} }
ConnectionOptions connectionOptions() const { [[nodiscard]] ConnectionOptions connectionOptions() const {
QReadLocker locker(&_lock); QReadLocker locker(&_lock);
return _options; return _options;
} }
@ -174,23 +168,16 @@ public:
QWriteLocker locker(&_lock); QWriteLocker locker(&_lock);
_salt = salt; _salt = salt;
} }
uint64 getSalt() const { [[nodiscard]] uint64 getSalt() const {
QReadLocker locker(&_lock); QReadLocker locker(&_lock);
return _salt; return _salt;
} }
const AuthKeyPtr &getKey() const { [[nodiscard]] const AuthKeyPtr &getKeyForCheck() const {
return _authKey;
}
void setKey(const AuthKeyPtr &key);
const AuthKeyPtr &getKeyForCheck() const {
return _dcKeyForCheck; return _dcKeyForCheck;
} }
void setKeyForCheck(const AuthKeyPtr &key); void setKeyForCheck(const AuthKeyPtr &key);
QReadWriteLock *keyMutex() const;
not_null<QReadWriteLock*> toSendMutex() const { not_null<QReadWriteLock*> toSendMutex() const {
return &_toSendLock; return &_toSendLock;
} }
@ -276,19 +263,17 @@ public:
return result * 2 + (needAck ? 1 : 0); return result * 2 + (needAck ? 1 : 0);
} }
void clear(Instance *instance); void clearForNewKey(not_null<Instance*> instance);
private: private:
uint64 _keyId = 0;
uint64 _sessionId = 0; uint64 _sessionId = 0;
uint64 _salt = 0; uint64 _salt = 0;
uint32 _messagesSent = 0; uint32 _messagesSent = 0;
not_null<Session*> _owner; not_null<Session*> _owner;
AuthKeyPtr _authKey;
AuthKeyPtr _dcKeyForCheck; AuthKeyPtr _dcKeyForCheck;
bool _layerInited = false;
ConnectionOptions _options; ConnectionOptions _options;
PreRequestMap _toSend; // map of request_id -> request, that is waiting to be sent PreRequestMap _toSend; // map of request_id -> request, that is waiting to be sent
@ -317,25 +302,34 @@ class Session : public QObject {
Q_OBJECT Q_OBJECT
public: public:
// Main thread.
Session( Session(
not_null<Instance*> instance, not_null<Instance*> instance,
ShiftedDcId shiftedDcId, ShiftedDcId shiftedDcId,
Dcenter *dc); Dcenter *dc);
~Session();
void start(); void start();
void reInitConnection();
void restart(); void restart();
void refreshOptions(); void refreshOptions();
void reInitConnection();
void stop(); void stop();
void kill(); void kill();
void unpaused(); void unpaused();
ShiftedDcId getDcWithShift() const; // Thread-safe.
[[nodiscard]] ShiftedDcId getDcWithShift() const;
[[nodiscard]] AuthKeyPtr getKey() const;
[[nodiscard]] bool connectionInited() const;
// Connection thread.
[[nodiscard]] bool acquireKeyCreation();
void releaseKeyCreationOnFail();
void releaseKeyCreationOnDone(AuthKeyPtr &&key);
void destroyCdnKey(uint64 keyId);
QReadWriteLock *keyMutex() const;
void notifyKeyCreated(AuthKeyPtr &&key);
void destroyKey();
void notifyDcConnectionInited(); void notifyDcConnectionInited();
void ping(); void ping();
@ -352,10 +346,8 @@ public:
crl::time msCanWait = 0, crl::time msCanWait = 0,
bool newRequest = true); bool newRequest = true);
~Session();
signals: signals:
void authKeyCreated(); void authKeyChanged();
void needToSend(); void needToSend();
void needToPing(); void needToPing();
void needToRestart(); void needToRestart();
@ -367,8 +359,7 @@ public slots:
void resendMany(QVector<quint64> msgIds, qint64 msCanWait, bool forceContainer, bool sendMsgStateInfo); void resendMany(QVector<quint64> msgIds, qint64 msCanWait, bool forceContainer, bool sendMsgStateInfo);
void resendAll(); // after connection restart void resendAll(); // after connection restart
void authKeyCreatedForDC(); void authKeyChangedForDC();
void connectionWasInitedForDC();
void tryToReceive(); void tryToReceive();
void onConnectionStateChange(qint32 newState); void onConnectionStateChange(qint32 newState);
@ -379,13 +370,15 @@ public slots:
void sendMsgsStateInfo(quint64 msgId, QByteArray data); void sendMsgsStateInfo(quint64 msgId, QByteArray data);
private: private:
[[nodiscard]] bool sharedDc() const;
void checkRequestsByTimer(); void checkRequestsByTimer();
bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err); bool rpcErrorOccured(mtpRequestId requestId, const RPCFailHandlerPtr &onFail, const RPCError &err);
const not_null<Instance*> _instance; const not_null<Instance*> _instance;
const ShiftedDcId _shiftedDcId = 0; const ShiftedDcId _shiftedDcId = 0;
Dcenter *_dc = nullptr; const std::unique_ptr<Dcenter> _ownedDc;
const not_null<Dcenter*> _dc;
std::unique_ptr<Connection> _connection; std::unique_ptr<Connection> _connection;
@ -406,38 +399,5 @@ private:
}; };
inline QReadWriteLock *SessionData::keyMutex() const {
return _owner->keyMutex();
}
class ReadLockerAttempt {
public:
ReadLockerAttempt(QReadWriteLock *lock) : _lock(lock), _locked(_lock ? _lock->tryLockForRead() : true) {
}
ReadLockerAttempt(const ReadLockerAttempt &other) = delete;
ReadLockerAttempt &operator=(const ReadLockerAttempt &other) = delete;
ReadLockerAttempt(ReadLockerAttempt &&other) : _lock(other._lock), _locked(base::take(other._locked)) {
}
ReadLockerAttempt &operator=(ReadLockerAttempt &&other) {
_lock = other._lock;
_locked = base::take(other._locked);
return *this;
}
~ReadLockerAttempt() {
if (_lock && _locked) {
_lock->unlock();
}
}
operator bool() const {
return _locked;
}
private:
QReadWriteLock *_lock = nullptr;
bool _locked = false;
};
} // namespace internal } // namespace internal
} // namespace MTP } // namespace MTP