From e426f2895b70d0f2b1d9c18b0d7a329a15cd74f6 Mon Sep 17 00:00:00 2001
From: John Preston <johnprestonmail@gmail.com>
Date: Thu, 5 Dec 2019 14:46:28 +0300
Subject: [PATCH] Add download sessions if it works fast enough.

---
 Telegram/SourceFiles/mtproto/mtp_instance.cpp |  20 ++
 Telegram/SourceFiles/mtproto/mtp_instance.h   |   3 +
 .../SourceFiles/mtproto/session_private.cpp   |  21 +-
 .../storage/download_manager_mtproto.cpp      | 309 +++++++++++++-----
 .../storage/download_manager_mtproto.h        |  46 ++-
 .../SourceFiles/storage/file_download_web.cpp |   4 +-
 6 files changed, 295 insertions(+), 108 deletions(-)

diff --git a/Telegram/SourceFiles/mtproto/mtp_instance.cpp b/Telegram/SourceFiles/mtproto/mtp_instance.cpp
index b48ddf734..0af0dc90d 100644
--- a/Telegram/SourceFiles/mtproto/mtp_instance.cpp
+++ b/Telegram/SourceFiles/mtproto/mtp_instance.cpp
@@ -85,6 +85,9 @@ public:
 	void badConfigurationError();
 	void syncHttpUnixtime();
 
+	void restartedByTimeout(ShiftedDcId shiftedDcId);
+	[[nodiscard]] rpl::producer<ShiftedDcId> restartsByTimeout() const;
+
 	void restart();
 	void restart(ShiftedDcId shiftedDcId);
 	[[nodiscard]] int32 dcstate(ShiftedDcId shiftedDcId = 0);
@@ -208,6 +211,7 @@ private:
 	Session *_mainSession = nullptr;
 	base::flat_map<ShiftedDcId, std::unique_ptr<Session>> _sessions;
 	std::vector<std::unique_ptr<Session>> _sessionsToDestroy;
+	rpl::event_stream<ShiftedDcId> _restartsByTimeout;
 
 	std::unique_ptr<ConfigLoader> _configLoader;
 	std::unique_ptr<DomainResolver> _domainResolver;
@@ -446,6 +450,14 @@ void Instance::Private::syncHttpUnixtime() {
 	});
 }
 
+void Instance::Private::restartedByTimeout(ShiftedDcId shiftedDcId) {
+	_restartsByTimeout.fire_copy(shiftedDcId);
+}
+
+rpl::producer<ShiftedDcId> Instance::Private::restartsByTimeout() const {
+	return _restartsByTimeout.events();
+}
+
 void Instance::Private::requestConfigIfOld() {
 	const auto timeout = Global::BlockedMode()
 		? kConfigBecomesOldForBlockedIn
@@ -1664,6 +1676,14 @@ void Instance::syncHttpUnixtime() {
 	_private->syncHttpUnixtime();
 }
 
+void Instance::restartedByTimeout(ShiftedDcId shiftedDcId) {
+	_private->restartedByTimeout(shiftedDcId);
+}
+
+rpl::producer<ShiftedDcId> Instance::restartsByTimeout() const {
+	return _private->restartsByTimeout();
+}
+
 void Instance::requestConfigIfOld() {
 	_private->requestConfigIfOld();
 }
diff --git a/Telegram/SourceFiles/mtproto/mtp_instance.h b/Telegram/SourceFiles/mtproto/mtp_instance.h
index 834ef947d..ed3797307 100644
--- a/Telegram/SourceFiles/mtproto/mtp_instance.h
+++ b/Telegram/SourceFiles/mtproto/mtp_instance.h
@@ -120,6 +120,9 @@ public:
 	void setUserPhone(const QString &phone);
 	void badConfigurationError();
 
+	void restartedByTimeout(ShiftedDcId shiftedDcId);
+	[[nodiscard]] rpl::producer<ShiftedDcId> restartsByTimeout() const;
+
 	void syncHttpUnixtime();
 
 	void sendAnything(ShiftedDcId shiftedDcId = 0, crl::time msCanWait = 0);
diff --git a/Telegram/SourceFiles/mtproto/session_private.cpp b/Telegram/SourceFiles/mtproto/session_private.cpp
index 9ee0039a5..96a95999f 100644
--- a/Telegram/SourceFiles/mtproto/session_private.cpp
+++ b/Telegram/SourceFiles/mtproto/session_private.cpp
@@ -1036,7 +1036,9 @@ void SessionPrivate::onSentSome(uint64 size) {
 		}
 		_waitForReceivedTimer.callOnce(remain);
 	}
-	if (!_firstSentAt) _firstSentAt = crl::now();
+	if (!_firstSentAt) {
+		_firstSentAt = crl::now();
+	}
 }
 
 void SessionPrivate::onReceivedSome() {
@@ -1048,7 +1050,9 @@ void SessionPrivate::onReceivedSome() {
 	_waitForReceivedTimer.cancel();
 	if (_firstSentAt > 0) {
 		const auto ms = crl::now() - _firstSentAt;
-		DEBUG_LOG(("MTP Info: response in %1ms, _waitForReceived: %2ms").arg(ms).arg(_waitForReceived));
+		DEBUG_LOG(("MTP Info: response in %1ms, _waitForReceived: %2ms"
+			).arg(ms
+			).arg(_waitForReceived));
 
 		if (ms > 0 && ms * 2 < _waitForReceived) {
 			_waitForReceived = qMax(ms * 2, kMinReceiveTimeout);
@@ -1060,7 +1064,8 @@ void SessionPrivate::onReceivedSome() {
 void SessionPrivate::markConnectionOld() {
 	_oldConnection = true;
 	_waitForReceived = kMinReceiveTimeout;
-	DEBUG_LOG(("This connection marked as old! _waitForReceived now %1ms").arg(_waitForReceived));
+	DEBUG_LOG(("This connection marked as old! _waitForReceived now %1ms"
+		).arg(_waitForReceived));
 }
 
 void SessionPrivate::sendPingByTimer() {
@@ -1094,10 +1099,6 @@ void SessionPrivate::sendPingForce() {
 void SessionPrivate::waitReceivedFailed() {
 	Expects(_options != nullptr);
 
-	if (!_options->useTcp) {
-		return;
-	}
-
 	DEBUG_LOG(("MTP Info: bad connection, _waitForReceived: %1ms").arg(_waitForReceived));
 	if (_waitForReceived < kMaxReceiveTimeout) {
 		_waitForReceived *= 2;
@@ -1109,6 +1110,12 @@ void SessionPrivate::waitReceivedFailed() {
 
 	DEBUG_LOG(("MTP Info: immediate restart!"));
 	InvokeQueued(this, [=] { connectToServer(); });
+
+	const auto instance = _instance;
+	const auto shiftedDcId = _shiftedDcId;
+	InvokeQueued(instance, [=] {
+		instance->restartedByTimeout(shiftedDcId);
+	});
 }
 
 void SessionPrivate::waitConnectedFailed() {
diff --git a/Telegram/SourceFiles/storage/download_manager_mtproto.cpp b/Telegram/SourceFiles/storage/download_manager_mtproto.cpp
index 92c0d828f..f60787f1c 100644
--- a/Telegram/SourceFiles/storage/download_manager_mtproto.cpp
+++ b/Telegram/SourceFiles/storage/download_manager_mtproto.cpp
@@ -17,21 +17,23 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
 namespace Storage {
 namespace {
 
-// How much time without download causes additional session kill.
 constexpr auto kKillSessionTimeout = 15 * crl::time(1000);
-
-// Max 16 file parts downloaded at the same time, 128 KB each.
-constexpr auto kMaxFileQueries = 16;
-
-constexpr auto kMaxWaitedInConnection = 512 * 1024;
-
-// Max 8 http[s] files downloaded at the same time.
-constexpr auto kMaxWebFileQueries = 8;
-
+constexpr auto kMaxWaitedInSession = 2 * 1024 * 1024;
 constexpr auto kStartSessionsCount = 1;
 constexpr auto kMaxSessionsCount = 8;
+constexpr auto kMaxTrackedSessionRemoves = 64;
+constexpr auto kRetryAddSessionTimeout = 8 * crl::time(1000);
+constexpr auto kRetryAddSessionSuccesses = 3;
+constexpr auto kMaxTrackedSuccesses = kRetryAddSessionSuccesses
+	* kMaxTrackedSessionRemoves;
+constexpr auto kRemoveSessionAfterTimeouts = 2;
 constexpr auto kResetDownloadPrioritiesTimeout = crl::time(200);
 
+// Each session remove by timeouts we wait for time
+// kRetryAddSessionTimeout * max(removesCount, kMaxTrackedSessionRemoves)
+// and for successes in all remaining sessions
+// kRetryAddSessionSuccesses * max(removesCount, kMaxTrackedSessionRemoves)
+
 } // namespace
 
 void DownloadManagerMtproto::Queue::enqueue(not_null<Task*> task) {
@@ -74,14 +76,30 @@ auto DownloadManagerMtproto::Queue::nextTask() const -> Task* {
 	return (i != all.end()) ? i->get() : nullptr;
 }
 
+DownloadManagerMtproto::DcSessionBalanceData::DcSessionBalanceData()
+: maxWaitedAmount(kDownloadPartSize) {
+}
+
+DownloadManagerMtproto::DcBalanceData::DcBalanceData()
+: sessions(kStartSessionsCount) {
+}
+
 DownloadManagerMtproto::DownloadManagerMtproto(not_null<ApiWrap*> api)
 : _api(api)
 , _resetGenerationTimer([=] { resetGeneration(); })
-, _killDownloadSessionsTimer([=] { killDownloadSessions(); }) {
+, _killSessionsTimer([=] { killSessions(); }) {
+	_api->instance()->restartsByTimeout(
+	) | rpl::filter([](MTP::ShiftedDcId shiftedDcId) {
+		return MTP::isDownloadDcId(shiftedDcId);
+	}) | rpl::start_with_next([=](MTP::ShiftedDcId shiftedDcId) {
+		sessionTimedOut(
+			MTP::BareDcId(shiftedDcId),
+			MTP::GetDcIdShift(shiftedDcId));
+	}, _lifetime);
 }
 
 DownloadManagerMtproto::~DownloadManagerMtproto() {
-	killDownloadSessions();
+	killSessions();
 }
 
 void DownloadManagerMtproto::enqueue(not_null<Task*> task) {
@@ -98,7 +116,6 @@ void DownloadManagerMtproto::remove(not_null<Task*> task) {
 	const auto dcId = task->dcId();
 	auto &queue = _queues[dcId];
 	queue.remove(task);
-	checkSendNext(dcId, queue);
 }
 
 void DownloadManagerMtproto::resetGeneration() {
@@ -118,89 +135,158 @@ void DownloadManagerMtproto::checkSendNext() {
 }
 
 void DownloadManagerMtproto::checkSendNext(MTP::DcId dcId, Queue &queue) {
+	while (trySendNextPart(dcId, queue)) {
+	}
+}
+
+bool DownloadManagerMtproto::trySendNextPart(MTP::DcId dcId, Queue &queue) {
 	const auto bestIndex = [&] {
-		const auto i = _requestedBytesAmount.find(dcId);
-		if (i == end(_requestedBytesAmount)) {
-			_requestedBytesAmount[dcId].resize(kStartSessionsCount);
-			return 0;
-		}
-		const auto j = ranges::min_element(i->second);
-		const auto already = *j;
-		return (already + kDownloadPartSize <= kMaxWaitedInConnection)
-			? (j - begin(i->second))
+		const auto &sessions = _balanceData[dcId].sessions;
+		const auto proj = [](const DcSessionBalanceData &data) {
+			return (data.requested < data.maxWaitedAmount)
+				? data.requested
+				: kMaxWaitedInSession;
+		};
+		const auto j = ranges::min_element(sessions, ranges::less(), proj);
+		return (j->requested + kDownloadPartSize <= j->maxWaitedAmount)
+			? (j - begin(sessions))
 			: -1;
 	}();
-	if (bestIndex >= 0) {
-		if (const auto task = queue.nextTask()) {
-			task->loadPart(bestIndex);
-		}
+	if (bestIndex < 0) {
+		return false;
 	}
+	if (const auto task = queue.nextTask()) {
+		task->loadPart(bestIndex);
+		return true;
+	}
+	return false;
 }
 
-void DownloadManagerMtproto::requestedAmountIncrement(
+void DownloadManagerMtproto::changeRequestedAmount(
 		MTP::DcId dcId,
 		int index,
-		int amount) {
-	Expects(dcId != 0);
+		int delta) {
+	const auto i = _balanceData.find(dcId);
+	Assert(i != _balanceData.end());
+	Assert(index < i->second.sessions.size());
+	i->second.sessions[index].requested += delta;
+	const auto findNonEmptySession = [](const DcBalanceData &data) {
+		using namespace rpl::mappers;
+		return ranges::find_if(
+			data.sessions,
+			_1 > 0,
+			&DcSessionBalanceData::requested);
+	};
+	if (delta > 0) {
+		killSessionsCancel(dcId);
+	} else if (findNonEmptySession(i->second) == end(i->second.sessions)) {
+		killSessionsSchedule(dcId);
+	}
+}
 
+void DownloadManagerMtproto::requestSucceeded(MTP::DcId dcId, int index) {
 	using namespace rpl::mappers;
 
-	auto it = _requestedBytesAmount.find(dcId);
-	if (it == _requestedBytesAmount.end()) {
-		it = _requestedBytesAmount.emplace(
-			dcId,
-			std::vector<int>(dcId ? kStartSessionsCount : 1, 0)
-		).first;
+	DEBUG_LOG(("Download (%1,%2) request done.").arg(dcId).arg(index));
+	const auto i = _balanceData.find(dcId);
+	Assert(i != end(_balanceData));
+	auto &dc = i->second;
+	Assert(index < dc.sessions.size());
+	auto &data = dc.sessions[index];
+	data.successes = std::min(data.successes + 1, kMaxTrackedSuccesses);
+	data.maxWaitedAmount = std::min(
+		data.maxWaitedAmount + kDownloadPartSize,
+		kMaxWaitedInSession);
+	const auto notEnough = ranges::find_if(
+		dc.sessions,
+		_1 < (dc.sessionRemoveTimes + 1) * kRetryAddSessionSuccesses,
+		&DcSessionBalanceData::successes);
+	if (notEnough != end(dc.sessions)) {
+		return;
 	}
-	it->second[index] += amount;
-	if (amount > 0) {
-		killDownloadSessionsStop(dcId);
+	for (auto &session : dc.sessions) {
+		session.successes = 0;
+	}
+	if (dc.timeouts > 0) {
+		--dc.timeouts;
+		return;
+	} else if (dc.sessions.size() == kMaxSessionsCount) {
+		return;
+	}
+	const auto now = crl::now();
+	const auto delay = (dc.sessionRemoveTimes + 1) * kRetryAddSessionTimeout;
+	if (dc.lastSessionRemove && now < dc.lastSessionRemove + delay) {
+		return;
+	}
+	DEBUG_LOG(("Download (%1,%2) added session."
+		).arg(dcId
+		).arg(dc.sessions.size()));
+	dc.sessions.emplace_back();
+	checkSendNext(dcId, _queues[dcId]);
+}
+
+void DownloadManagerMtproto::sessionTimedOut(MTP::DcId dcId, int index) {
+	const auto i = _balanceData.find(dcId);
+	if (i == end(_balanceData)) {
+		return;
+	}
+	auto &dc = i->second;
+	if (index >= dc.sessions.size()) {
+		return;
+	}
+	DEBUG_LOG(("Download (%1,%2) session timed-out.").arg(dcId).arg(index));
+	for (auto &session : dc.sessions) {
+		session.successes = 0;
+	}
+	if (dc.sessions.size() == kStartSessionsCount
+		|| ++dc.timeouts < kRemoveSessionAfterTimeouts) {
+		return;
+	}
+	dc.timeouts = 0;
+	removeSession(dcId);
+}
+
+void DownloadManagerMtproto::removeSession(MTP::DcId dcId) {
+	auto &dc = _balanceData[dcId];
+	Assert(dc.sessions.size() > kStartSessionsCount);
+	const auto index = int(dc.sessions.size() - 1);
+	DEBUG_LOG(("Download (%1,%2) removing session.").arg(dcId).arg(index));
+	auto &queue = _queues[dcId];
+	if (dc.sessionRemoveIndex == index) {
+		dc.sessionRemoveTimes = std::min(
+			dc.sessionRemoveTimes + 1,
+			kMaxTrackedSessionRemoves);
 	} else {
-		crl::on_main(this, [=] { checkSendNext(); });
-		if (ranges::find_if(it->second, _1 > 0) == end(it->second)) {
-			killDownloadSessionsStart(dcId);
-		}
+		dc.sessionRemoveIndex = index;
+		dc.sessionRemoveTimes = 1;
+	}
+	dc.lastSessionRemove = crl::now();
+//	dc.sessions.pop_back();
+}
+
+void DownloadManagerMtproto::killSessionsSchedule(MTP::DcId dcId) {
+	if (!_killSessionsWhen.contains(dcId)) {
+		_killSessionsWhen.emplace(dcId, crl::now() + kKillSessionTimeout);
+	}
+	if (!_killSessionsTimer.isActive()) {
+		_killSessionsTimer.callOnce(kKillSessionTimeout + 5);
 	}
 }
 
-int DownloadManagerMtproto::chooseDcIndexForRequest(MTP::DcId dcId) {
-	const auto i = _requestedBytesAmount.find(dcId);
-	return (i != end(_requestedBytesAmount))
-		? (ranges::min_element(i->second) - begin(i->second))
-		: 0;
-}
-
-void DownloadManagerMtproto::killDownloadSessionsStart(MTP::DcId dcId) {
-	if (!_killDownloadSessionTimes.contains(dcId)) {
-		_killDownloadSessionTimes.emplace(
-			dcId,
-			crl::now() + kKillSessionTimeout);
-	}
-	if (!_killDownloadSessionsTimer.isActive()) {
-		_killDownloadSessionsTimer.callOnce(kKillSessionTimeout + 5);
+void DownloadManagerMtproto::killSessionsCancel(MTP::DcId dcId) {
+	_killSessionsWhen.erase(dcId);
+	if (_killSessionsWhen.empty()) {
+		_killSessionsTimer.cancel();
 	}
 }
 
-void DownloadManagerMtproto::killDownloadSessionsStop(MTP::DcId dcId) {
-	_killDownloadSessionTimes.erase(dcId);
-	if (_killDownloadSessionTimes.empty()
-		&& _killDownloadSessionsTimer.isActive()) {
-		_killDownloadSessionsTimer.cancel();
-	}
-}
-
-void DownloadManagerMtproto::killDownloadSessions() {
+void DownloadManagerMtproto::killSessions() {
 	const auto now = crl::now();
 	auto left = kKillSessionTimeout;
-	for (auto i = _killDownloadSessionTimes.begin(); i != _killDownloadSessionTimes.end(); ) {
+	for (auto i = begin(_killSessionsWhen); i != end(_killSessionsWhen); ) {
 		if (i->second <= now) {
-			const auto j = _requestedBytesAmount.find(i->first);
-			if (j != end(_requestedBytesAmount)) {
-				for (auto index = 0; index != int(j->second.size()); ++index) {
-					MTP::stopSession(MTP::downloadDcId(i->first, index));
-				}
-			}
-			i = _killDownloadSessionTimes.erase(i);
+			killSessions(i->first);
+			i = _killSessionsWhen.erase(i);
 		} else {
 			if (i->second - now < left) {
 				left = i->second - now;
@@ -208,8 +294,23 @@ void DownloadManagerMtproto::killDownloadSessions() {
 			++i;
 		}
 	}
-	if (!_killDownloadSessionTimes.empty()) {
-		_killDownloadSessionsTimer.callOnce(left);
+	if (!_killSessionsWhen.empty()) {
+		_killSessionsTimer.callOnce(left);
+	}
+}
+
+void DownloadManagerMtproto::killSessions(MTP::DcId dcId) {
+	const auto i = _balanceData.find(dcId);
+	if (i != end(_balanceData)) {
+		auto &dc = i->second;
+		auto sessions = base::take(dc.sessions);
+		dc = DcBalanceData();
+		for (auto j = 0; j != int(sessions.size()); ++j) {
+			Assert(sessions[j].requested == 0);
+			sessions[j] = DcSessionBalanceData();
+			MTP::stopSession(MTP::downloadDcId(dcId, j));
+		}
+		dc.sessions = base::take(sessions);
 	}
 }
 
@@ -270,7 +371,11 @@ void DownloadMtprotoTask::refreshFileReferenceFrom(
 		cancelOnFail();
 		return;
 	}
-	makeRequest(finishSentRequest(requestId));
+	if (_sentRequests.contains(requestId)) {
+		makeRequest(finishSentRequest(
+			requestId,
+			FinishRequestReason::Redirect));
+	}
 }
 
 void DownloadMtprotoTask::loadPart(int dcIndex) {
@@ -370,7 +475,9 @@ void DownloadMtprotoTask::requestMoreCdnFileHashes() {
 void DownloadMtprotoTask::normalPartLoaded(
 		const MTPupload_File &result,
 		mtpRequestId requestId) {
-	const auto requestData = finishSentRequest(requestId);
+	const auto requestData = finishSentRequest(
+		requestId,
+		FinishRequestReason::Success);
 	result.match([&](const MTPDupload_fileCdnRedirect &data) {
 		switchToCDN(requestData, data);
 	}, [&](const MTPDupload_file &data) {
@@ -382,7 +489,9 @@ void DownloadMtprotoTask::webPartLoaded(
 		const MTPupload_WebFile &result,
 		mtpRequestId requestId) {
 	result.match([&](const MTPDupload_webFile &data) {
-		const auto requestData = finishSentRequest(requestId);
+		const auto requestData = finishSentRequest(
+			requestId,
+			FinishRequestReason::Success);
 		if (setWebFileSizeHook(data.vsize().v)) {
 			partLoaded(requestData.offset, data.vbytes().v);
 		}
@@ -390,8 +499,10 @@ void DownloadMtprotoTask::webPartLoaded(
 }
 
 void DownloadMtprotoTask::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId) {
-	const auto requestData = finishSentRequest(requestId);
 	result.match([&](const MTPDupload_cdnFileReuploadNeeded &data) {
+		const auto requestData = finishSentRequest(
+			requestId,
+			FinishRequestReason::Redirect);
 		const auto shiftedDcId = MTP::downloadDcId(
 			dcId(),
 			requestData.dcIndex);
@@ -405,6 +516,9 @@ void DownloadMtprotoTask::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequ
 		}).toDC(shiftedDcId).send();
 		placeSentRequest(requestId, requestData);
 	}, [&](const MTPDupload_cdnFile &data) {
+		const auto requestData = finishSentRequest(
+			requestId,
+			FinishRequestReason::Success);
 		auto key = bytes::make_span(_cdnEncryptionKey);
 		auto iv = bytes::make_span(_cdnEncryptionIV);
 		Expects(key.size() == MTP::CTRState::KeySize);
@@ -462,7 +576,9 @@ DownloadMtprotoTask::CheckCdnHashResult DownloadMtprotoTask::checkCdnFileHash(
 void DownloadMtprotoTask::reuploadDone(
 		const MTPVector<MTPFileHash> &result,
 		mtpRequestId requestId) {
-	const auto requestData = finishSentRequest(requestId);
+	const auto requestData = finishSentRequest(
+		requestId,
+		FinishRequestReason::Redirect);
 	addCdnHashes(result.v);
 	makeRequest(requestData);
 }
@@ -474,7 +590,9 @@ void DownloadMtprotoTask::getCdnFileHashesDone(
 
 	_cdnHashesRequestId = 0;
 
-	const auto requestData = finishSentRequest(requestId);
+	const auto requestData = finishSentRequest(
+		requestId,
+		FinishRequestReason::Redirect);
 	addCdnHashes(result.v);
 	auto someMoreChecked = false;
 	for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) {
@@ -521,7 +639,7 @@ void DownloadMtprotoTask::getCdnFileHashesDone(
 void DownloadMtprotoTask::placeSentRequest(
 		mtpRequestId requestId,
 		const RequestData &requestData) {
-	_owner->requestedAmountIncrement(
+	_owner->changeRequestedAmount(
 		dcId(),
 		requestData.dcIndex,
 		Storage::kDownloadPartSize);
@@ -533,19 +651,26 @@ void DownloadMtprotoTask::placeSentRequest(
 	Ensures(ok1 && ok2);
 }
 
-auto DownloadMtprotoTask::finishSentRequest(mtpRequestId requestId)
+auto DownloadMtprotoTask::finishSentRequest(
+	mtpRequestId requestId,
+	FinishRequestReason reason)
 -> RequestData {
 	auto it = _sentRequests.find(requestId);
 	Assert(it != _sentRequests.cend());
 
 	const auto result = it->second;
-	_owner->requestedAmountIncrement(
+	_owner->changeRequestedAmount(
 		dcId(),
 		result.dcIndex,
 		-Storage::kDownloadPartSize);
 	_sentRequests.erase(it);
-	_requestByOffset.erase(result.offset);
+	const auto ok = _requestByOffset.remove(result.offset);
 
+	if (reason == FinishRequestReason::Success) {
+		_owner->requestSucceeded(dcId(), result.dcIndex);
+	}
+
+	Ensures(ok);
 	return result;
 }
 
@@ -575,7 +700,9 @@ void DownloadMtprotoTask::cancelRequestForOffset(int offset) {
 
 void DownloadMtprotoTask::cancelRequest(mtpRequestId requestId) {
 	api().request(requestId).cancel();
-	[[maybe_unused]] const auto data = finishSentRequest(requestId);
+	[[maybe_unused]] const auto data = finishSentRequest(
+		requestId,
+		FinishRequestReason::Cancel);
 }
 
 void DownloadMtprotoTask::addToQueue() {
@@ -633,7 +760,9 @@ bool DownloadMtprotoTask::cdnPartFailed(
 	}
 	if (error.type() == qstr("FILE_TOKEN_INVALID")
 		|| error.type() == qstr("REQUEST_TOKEN_INVALID")) {
-		const auto requestData = finishSentRequest(requestId);
+		const auto requestData = finishSentRequest(
+			requestId,
+			FinishRequestReason::Redirect);
 		changeCDNParams(
 			requestData,
 			0,
@@ -702,7 +831,9 @@ void DownloadMtprotoTask::changeCDNParams(
 		while (!_sentRequests.empty()) {
 			const auto requestId = _sentRequests.begin()->first;
 			api().request(requestId).cancel();
-			resendRequests.push_back(finishSentRequest(requestId));
+			resendRequests.push_back(finishSentRequest(
+				requestId,
+				FinishRequestReason::Redirect));
 		}
 		for (const auto &requestData : resendRequests) {
 			makeRequest(requestData);
diff --git a/Telegram/SourceFiles/storage/download_manager_mtproto.h b/Telegram/SourceFiles/storage/download_manager_mtproto.h
index 7b78ad35e..fb4dade18 100644
--- a/Telegram/SourceFiles/storage/download_manager_mtproto.h
+++ b/Telegram/SourceFiles/storage/download_manager_mtproto.h
@@ -42,8 +42,8 @@ public:
 		return _taskFinishedObservable;
 	}
 
-	void requestedAmountIncrement(MTP::DcId dcId, int index, int amount);
-	[[nodiscard]] int chooseDcIndexForRequest(MTP::DcId dcId);
+	void changeRequestedAmount(MTP::DcId dcId, int index, int delta);
+	void requestSucceeded(MTP::DcId dcId, int index);
 
 private:
 	class Queue final {
@@ -59,27 +59,48 @@ private:
 		std::vector<not_null<Task*>> _previousGeneration;
 
 	};
+	struct DcSessionBalanceData {
+		DcSessionBalanceData();
+
+		int requested = 0;
+		int successes = 0; // Since last timeout in this dc in any session.
+		int maxWaitedAmount = 0;
+	};
+	struct DcBalanceData {
+		DcBalanceData();
+
+		std::vector<DcSessionBalanceData> sessions;
+		crl::time lastSessionRemove = 0;
+		int sessionRemoveIndex = 0;
+		int sessionRemoveTimes = 0;
+		int timeouts = 0; // Since all sessions had successes >= required.
+	};
 
 	void checkSendNext();
 	void checkSendNext(MTP::DcId dcId, Queue &queue);
+	bool trySendNextPart(MTP::DcId dcId, Queue &queue);
 
-	void killDownloadSessionsStart(MTP::DcId dcId);
-	void killDownloadSessionsStop(MTP::DcId dcId);
-	void killDownloadSessions();
+	void killSessionsSchedule(MTP::DcId dcId);
+	void killSessionsCancel(MTP::DcId dcId);
+	void killSessions();
+	void killSessions(MTP::DcId dcId);
 
 	void resetGeneration();
+	void sessionTimedOut(MTP::DcId dcId, int index);
+	void removeSession(MTP::DcId dcId);
 
 	const not_null<ApiWrap*> _api;
 
 	base::Observable<void> _taskFinishedObservable;
 
-	base::flat_map<MTP::DcId, std::vector<int>> _requestedBytesAmount;
+	base::flat_map<MTP::DcId, DcBalanceData> _balanceData;
 	base::Timer _resetGenerationTimer;
 
-	base::flat_map<MTP::DcId, crl::time> _killDownloadSessionTimes;
-	base::Timer _killDownloadSessionsTimer;
+	base::flat_map<MTP::DcId, crl::time> _killSessionsWhen;
+	base::Timer _killSessionsTimer;
 
 	base::flat_map<MTP::DcId, Queue> _queues;
+	rpl::lifetime _lifetime;
 
 };
 
@@ -148,6 +169,11 @@ private:
 		Invalid,
 		Good,
 	};
+	enum class FinishRequestReason {
+		Success,
+		Redirect,
+		Cancel,
+	};
 
 	// Called only if readyToRequest() == true.
 	[[nodiscard]] virtual int takeNextRequestOffset() = 0;
@@ -187,7 +213,9 @@ private:
 	void placeSentRequest(
 		mtpRequestId requestId,
 		const RequestData &requestData);
-	[[nodiscard]] RequestData finishSentRequest(mtpRequestId requestId);
+	[[nodiscard]] RequestData finishSentRequest(
+		mtpRequestId requestId,
+		FinishRequestReason reason);
 	void switchToCDN(
 		const RequestData &requestData,
 		const MTPDupload_fileCdnRedirect &redirect);
diff --git a/Telegram/SourceFiles/storage/file_download_web.cpp b/Telegram/SourceFiles/storage/file_download_web.cpp
index 56764a474..5d1878382 100644
--- a/Telegram/SourceFiles/storage/file_download_web.cpp
+++ b/Telegram/SourceFiles/storage/file_download_web.cpp
@@ -222,7 +222,6 @@ void WebLoadManager::remove(int id) {
 		ranges::remove(_previousGeneration, id, &Enqueued::id),
 		end(_previousGeneration));
 	removeSent(id);
-	checkSendNext();
 }
 
 void WebLoadManager::resetGeneration() {
@@ -258,6 +257,7 @@ void WebLoadManager::removeSent(int id) {
 	if (const auto i = _sent.find(id); i != end(_sent)) {
 		deleteDeferred(i->second.reply);
 		_sent.erase(i);
+		checkSendNext();
 	}
 }
 
@@ -265,11 +265,9 @@ not_null<QNetworkReply*> WebLoadManager::send(int id, const QString &url) {
 	const auto result = _network.get(QNetworkRequest(url));
 	const auto handleProgress = [=](qint64 ready, qint64 total) {
 		progress(id, result, ready, total);
-		checkSendNext();
 	};
 	const auto handleError = [=](QNetworkReply::NetworkError error) {
 		failed(id, result, error);
-		checkSendNext();
 	};
 	connect(result, &QNetworkReply::downloadProgress, handleProgress);
 	connect(result, QNetworkReply_error, handleError);