From 6aa930d510d26daebeeffc218c1fd57b34c11af8 Mon Sep 17 00:00:00 2001 From: John Preston <johnprestonmail@gmail.com> Date: Thu, 18 Jul 2019 16:06:38 +0200 Subject: [PATCH] Remove exceptions in TL parsing. --- .../codegen/scheme/codegen_scheme.py | 101 +++++++---- Telegram/SourceFiles/core/application.cpp | 3 +- Telegram/SourceFiles/mainwidget.cpp | 42 ++--- Telegram/SourceFiles/mainwidget.h | 2 +- .../SourceFiles/mtproto/concurrent_sender.cpp | 11 +- .../SourceFiles/mtproto/concurrent_sender.h | 7 +- Telegram/SourceFiles/mtproto/connection.cpp | 140 +++++++++------ Telegram/SourceFiles/mtproto/connection.h | 5 +- .../mtproto/connection_abstract.cpp | 9 +- .../SourceFiles/mtproto/connection_abstract.h | 2 +- .../SourceFiles/mtproto/connection_http.cpp | 34 ++-- .../SourceFiles/mtproto/connection_tcp.cpp | 10 +- Telegram/SourceFiles/mtproto/core_types.cpp | 153 ++++++++-------- Telegram/SourceFiles/mtproto/core_types.h | 115 ++++++------ Telegram/SourceFiles/mtproto/mtp_instance.cpp | 30 ++-- Telegram/SourceFiles/mtproto/rpc_sender.h | 170 +++++++++++------- Telegram/SourceFiles/mtproto/sender.h | 9 +- .../mtproto/special_config_request.cpp | 4 +- 18 files changed, 463 insertions(+), 384 deletions(-) diff --git a/Telegram/SourceFiles/codegen/scheme/codegen_scheme.py b/Telegram/SourceFiles/codegen/scheme/codegen_scheme.py index bd4b8f3e9..319c2181a 100644 --- a/Telegram/SourceFiles/codegen/scheme/codegen_scheme.py +++ b/Telegram/SourceFiles/codegen/scheme/codegen_scheme.py @@ -391,19 +391,24 @@ for line in lines: funcsText += '\tmtpTypeId type() const {\n\t\treturn mtpc_' + name + ';\n\t}\n'; # type id - funcsText += '\tvoid read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_' + name + ');\n'; # read method + funcsText += '\t[[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_' + name + ');\n'; # read method if (isTemplate != ''): methodBodies += 'template <typename TQueryType>\n' - methodBodies += 'void MTP' + name + '<TQueryType>::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) {\n'; + methodBodies += 'bool MTP' + name + '<TQueryType>::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) {\n'; else: - methodBodies += 'void MTP' + name + '::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) {\n'; + methodBodies += 'bool MTP' + name + '::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) {\n'; + readFunc = '' for k in prmsList: v = prms[k]; if (k in conditionsList): if (not k in trivialConditions): - methodBodies += '\tif (_' + hasFlags + '.v & Flag::f_' + k + ') { _' + k + '.read(from, end); } else { _' + k + ' = MTP' + v + '(); }\n'; + readFunc += '\t\t&& ((_' + hasFlags + '.v & Flag::f_' + k + ') ? _' + k + '.read(from, end) : ((_' + k + ' = MTP' + v + '()), true))\n'; else: - methodBodies += '\t_' + k + '.read(from, end);\n'; + readFunc += '\t\t&& _' + k + '.read(from, end)\n'; + if readFunc != '': + methodBodies += '\treturn' + readFunc[4:len(readFunc)-1] + ';\n'; + else: + methodBodies += '\treturn true;\n'; methodBodies += '}\n'; funcsText += '\tvoid write(mtpBuffer &to) const;\n'; # write method @@ -485,7 +490,7 @@ def addTextSerialize(lst, dct, dataLetter): if (isTemplate != ''): templateArgument = '<MTP::SecureRequest>' - result += 'void Serialize_' + name + '(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n'; + result += 'bool Serialize_' + name + '(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n'; if (len(conditions)): result += '\tauto flag = MTP' + dataLetter + name + templateArgument + '::Flags::from_raw(iflag);\n\n'; if (len(prms)): @@ -501,7 +506,7 @@ def addTextSerialize(lst, dct, dataLetter): v = prms[k]; result += '\tcase ' + str(stage) + ': to.add(" ' + k + ': "); ++stages.back(); '; if (k == hasFlags): - result += 'if (start >= end) throw Exception("start >= end in flags"); else flags.back() = *start; '; + result += 'if (start >= end) return false; else flags.back() = *start; '; if (k in trivialConditions): result += 'if (flag & MTP' + dataLetter + name + templateArgument + '::Flag::f_' + k + ') { '; result += 'to.add("YES [ BY BIT ' + conditions[k] + ' IN FIELD ' + hasFlags + ' ]"); '; @@ -564,6 +569,7 @@ def addTextSerialize(lst, dct, dataLetter): result += '\t}\n'; else: result += '\tto.add("{ ' + name + ' }"); types.pop_back(); vtypes.pop_back(); stages.pop_back(); flags.pop_back();\n'; + result += '\treturn true;\n'; result += '}\n\n'; return result; @@ -693,29 +699,30 @@ for restype in typesList: creatorParams.append('const MTP' + paramType + ' &' + paramName + '_'); creatorParamsList.append(paramName + '_'); prmsInit.append('_' + paramName + '(' + paramName + '_)'); - if (withType): - readText += '\t'; - writeText += '\t'; if (paramName in conditions): - readText += '\tif (v' + paramName + '()) { _' + paramName + '.read(from, end); } else { _' + paramName + ' = MTP' + paramType + '(); }\n'; + readText += '\t\t&& (v' + paramName + '() ? _' + paramName + '.read(from, end) : ((_' + paramName + ' = MTP' + paramType + '()), true))\n'; writeText += '\tif (const auto v' + paramName + ' = v.v' + paramName + '()) v' + paramName + '->write(to);\n'; sizeList.append('(v.v' + paramName + '() ? v.v' + paramName + '()->innerLength() : 0)'); else: - readText += '\t_' + paramName + '.read(from, end);\n'; - writeText += '\tv.v' + paramName + '().write(to);\n'; + readText += '\t\t&& _' + paramName + '.read(from, end)\n'; + writeText += '\t\tv.v' + paramName + '().write(to);\n'; sizeList.append('v.v' + paramName + '().innerLength()'); dataText += ', '.join(prmsStr) + ');\n'; constructsBodies += 'MTPD' + name + '::MTPD' + name + '(' + ', '.join(prmsStr) + ') : ' + ', '.join(prmsInit) + ' {\n}\n'; - dataText += '\tMTPD' + name + '(const mtpPrime *&from, const mtpPrime *end);\n'; # reading constructor + dataText += '\n'; + dataText += '\t[[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end);\n'; + dataText += '\n'; - constructsBodies += 'MTPD' + name + '::MTPD' + name + '(const mtpPrime *&from, const mtpPrime *end) {\n'; - constructsBodies += readText; + constructsBodies += 'bool MTPD' + name + '::read(const mtpPrime *&from, const mtpPrime *end) {\n'; + if readText != '': + constructsBodies += '\treturn' + readText[4:len(readText)-1] + ';\n'; + else: + constructsBodies += '\treturn true;\n'; constructsBodies += '}\n'; - dataText += '\n'; if len(prmsList) > 0: for paramName in prmsList: # getters if (paramName in trivialConditions): @@ -780,7 +787,12 @@ for restype in typesList: reader += '\tcase mtpc_' + name + ': _type = cons; '; # read switch line if (len(prms) > len(trivialConditions)): reader += '{\n'; - reader += '\t\tsetData(new MTPD' + name + '(from, end));\n'; + reader += '\t\tif (const auto data = new MTPD' + name + '(); data->read(from, end)) {\n'; + reader += '\t\t\tsetData(data);\n'; + reader += '\t\t} else {\n'; + reader += '\t\t\tdelete data;\n'; + reader += '\t\t\treturn false;\n'; + reader += '\t\t}\n'; reader += '\t} break;\n'; writer += '\tcase mtpc_' + name + ': {\n'; # write switch line @@ -791,7 +803,12 @@ for restype in typesList: reader += 'break;\n'; else: if (len(prms) > len(trivialConditions)): - reader += '\tsetData(new MTPD' + name + '(from, end));\n'; + reader += '\tif (const auto data = new MTPD' + name + '(); data->read(from, end)) {\n'; + reader += '\t\tsetData(data);\n'; + reader += '\t} else {\n'; + reader += '\t\tdelete data;\n'; + reader += '\t\treturn false;\n'; + reader += '\t}\n'; writer += '\tconst MTPD' + name + ' &v = c_' + name + '();\n'; writer += writeText; @@ -844,21 +861,22 @@ for restype in typesList: methods += '\treturn mtpc_' + v[0][0] + ';\n'; methods += '}\n'; - typesText += '\tvoid read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons'; # read method + typesText += '\t[[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons'; # read method if (not withType): typesText += ' = mtpc_' + name; typesText += ');\n'; - methods += 'void MTP' + restype + '::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) {\n'; + methods += 'bool MTP' + restype + '::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) {\n'; if (withData): if not (withType): - methods += '\tif (cons != mtpc_' + v[0][0] + ') throw mtpErrorUnexpected(cons, "MTP' + restype + '");\n'; + methods += '\tif (cons != mtpc_' + v[0][0] + ') return false;\n'; if (withType): methods += '\tswitch (cons) {\n' methods += reader; - methods += '\tdefault: throw mtpErrorUnexpected(cons, "MTP' + restype + '");\n'; + methods += '\tdefault: return false;\n'; methods += '\t}\n'; else: methods += reader; + methods += '\treturn true;\n'; methods += '}\n'; typesText += '\tvoid write(mtpBuffer &to) const;\n'; # write method @@ -881,7 +899,8 @@ for restype in typesList: methods += ' {\n'; methods += '\tswitch (type) {\n'; # type id check methods += switchLines; - methods += '\tdefault: throw mtpErrorBadTypeId(type, "MTP' + restype + '");\n\t}\n'; + methods += '\tdefault: Unexpected("Type in MTP' + restype + '::MTP' + restype + '.");\n'; + methods += '\t}\n'; methods += '}\n'; # by-type-id constructor end if (withData): @@ -925,7 +944,7 @@ for childName in parentFlagsList: # manual types added here textSerializeMethods += '\ -void _serialize_rpc_result(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n\ +bool _serialize_rpc_result(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n\ if (stage) {\n\ to.add(",\\n").addSpaces(lev);\n\ } else {\n\ @@ -937,9 +956,10 @@ void _serialize_rpc_result(MTPStringLogger &to, int32 stage, int32 lev, Types &t case 1: to.add(" result: "); ++stages.back(); types.push_back(0); vtypes.push_back(0); stages.push_back(0); flags.push_back(0); break;\n\ default: to.add("}"); types.pop_back(); vtypes.pop_back(); stages.pop_back(); flags.pop_back(); break;\n\ }\n\ + return true;\n\ }\n\ \n\ -void _serialize_msg_container(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n\ +bool _serialize_msg_container(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n\ if (stage) {\n\ to.add(",\\n").addSpaces(lev);\n\ } else {\n\ @@ -950,9 +970,10 @@ void _serialize_msg_container(MTPStringLogger &to, int32 stage, int32 lev, Types case 0: to.add(" messages: "); ++stages.back(); types.push_back(mtpc_vector); vtypes.push_back(mtpc_core_message); stages.push_back(0); flags.push_back(0); break;\n\ default: to.add("}"); types.pop_back(); vtypes.pop_back(); stages.pop_back(); flags.pop_back(); break;\n\ }\n\ + return true;\n\ }\n\ \n\ -void _serialize_core_message(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n\ +bool _serialize_core_message(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag) {\n\ if (stage) {\n\ to.add(",\\n").addSpaces(lev);\n\ } else {\n\ @@ -966,6 +987,7 @@ void _serialize_core_message(MTPStringLogger &to, int32 stage, int32 lev, Types case 3: to.add(" body: "); ++stages.back(); types.push_back(0); vtypes.push_back(0); stages.push_back(0); flags.push_back(0); break;\n\ default: to.add("}"); types.pop_back(); vtypes.pop_back(); stages.pop_back(); flags.pop_back(); break;\n\ }\n\ + return true;\n\ }\n\ \n'; @@ -1027,7 +1049,7 @@ enum {\n\ // Factory methods declaration\n\ ' + factories + '\n\ // Human-readable text serialization\n\ -void mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpPrime cons, uint32 level, mtpPrime vcons);\n' +[[nodiscard]] bool mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpPrime cons, uint32 level, mtpPrime vcons);\n' source = '\ /*\n\ @@ -1063,20 +1085,19 @@ using StagesFlags = QVector<int32>;\n\ ' + textSerializeMethods + '\n\ namespace {\n\ \n\ -using TextSerializer = void (*)(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag);\n\ +using TextSerializer = bool (*)(MTPStringLogger &to, int32 stage, int32 lev, Types &types, Types &vtypes, StagesFlags &stages, StagesFlags &flags, const mtpPrime *start, const mtpPrime *end, uint32 iflag);\n\ using TextSerializers = QMap<mtpTypeId, TextSerializer>;\n\ \n\ QMap<mtpTypeId, TextSerializer> createTextSerializers() {\n\ auto result = QMap<mtpTypeId, TextSerializer>();\n\ \n\ ' + textSerializeInit + '\n\ -\n\ return result;\n\ }\n\ \n\ } // namespace\n\ \n\ -void mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpPrime cons, uint32 level, mtpPrime vcons) {\n\ +bool mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpPrime cons, uint32 level, mtpPrime vcons) {\n\ static auto serializers = createTextSerializers();\n\ \n\ QVector<mtpTypeId> types, vtypes;\n\ @@ -1094,9 +1115,11 @@ void mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpP flag = flags.back();\n\ if (!type) {\n\ if (from >= end) {\n\ - throw Exception("from >= end");\n\ + to.error("insufficient data");\n\ + return false;\n\ } else if (stage) {\n\ - throw Exception("unknown type on stage > 0");\n\ + to.error("unknown type on stage > 0");\n\ + return false;\n\ }\n\ types.back() = type = *from;\n\ ++from;\n\ @@ -1105,12 +1128,18 @@ void mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpP int32 lev = level + types.size() - 1;\n\ auto it = serializers.constFind(type);\n\ if (it != serializers.cend()) {\n\ - (*it.value())(to, stage, lev, types, vtypes, stages, flags, from, end, flag);\n\ - } else {\n\ - mtpTextSerializeCore(to, from, end, type, lev, vtype);\n\ + if (!(*it.value())(to, stage, lev, types, vtypes, stages, flags, from, end, flag)) {\n\ + to.error();\n\ + return false;\n\ + }\n\ + } else if (mtpTextSerializeCore(to, from, end, type, lev, vtype)) {\n\ types.pop_back(); vtypes.pop_back(); stages.pop_back(); flags.pop_back();\n\ + } else {\n\ + to.error();\n\ + return false;\n\ }\n\ }\n\ + return true;\n\ }\n'; already_header = '' diff --git a/Telegram/SourceFiles/core/application.cpp b/Telegram/SourceFiles/core/application.cpp index a351e1f7a..702531ae6 100644 --- a/Telegram/SourceFiles/core/application.cpp +++ b/Telegram/SourceFiles/core/application.cpp @@ -780,8 +780,9 @@ void Application::authSessionCreate(const MTPUser &user) { const mtpPrime *from, const mtpPrime *end) { if (const auto main = App::main()) { - main->updateReceived(from, end); + return main->updateReceived(from, end); } + return true; })); _mtproto->setGlobalFailHandler(::rpcFail([=](const RPCError &error) { if (activeAccount().sessionExists()) { diff --git a/Telegram/SourceFiles/mainwidget.cpp b/Telegram/SourceFiles/mainwidget.cpp index 0ded1e98c..de451843c 100644 --- a/Telegram/SourceFiles/mainwidget.cpp +++ b/Telegram/SourceFiles/mainwidget.cpp @@ -3666,35 +3666,35 @@ void MainWidget::checkIdleFinish() { } } -void MainWidget::updateReceived(const mtpPrime *from, const mtpPrime *end) { - if (end <= from) return; +bool MainWidget::updateReceived(const mtpPrime *from, const mtpPrime *end) { + if (end <= from) { + return false; + } session().checkAutoLock(); if (mtpTypeId(*from) == mtpc_new_session_created) { - try { - MTPNewSession newSession; - newSession.read(from, end); - } catch (mtpErrorUnexpected &) { + MTPNewSession newSession; + if (!newSession.read(from, end)) { + return false; } updSeq = 0; MTP_LOG(0, ("getDifference { after new_session_created }%1").arg(cTestMode() ? " TESTMODE" : "")); - return getDifference(); - } else { - try { - MTPUpdates updates; - updates.read(from, end); - - _lastUpdateTime = crl::now(); - _noUpdatesTimer.callOnce(kNoUpdatesTimeout); - if (!requestingDifference() - || HasForceLogoutNotification(updates)) { - feedUpdates(updates); - } - } catch (mtpErrorUnexpected &) { // just some other type - } + getDifference(); + return true; } - update(); + MTPUpdates updates; + if (!updates.read(from, end)) { + return false; + } + + _lastUpdateTime = crl::now(); + _noUpdatesTimer.callOnce(kNoUpdatesTimeout); + if (!requestingDifference() + || HasForceLogoutNotification(updates)) { + feedUpdates(updates); + } + return true; } void MainWidget::feedUpdates(const MTPUpdates &updates, uint64 randomId) { diff --git a/Telegram/SourceFiles/mainwidget.h b/Telegram/SourceFiles/mainwidget.h index 43b1e0e3b..4c5124e82 100644 --- a/Telegram/SourceFiles/mainwidget.h +++ b/Telegram/SourceFiles/mainwidget.h @@ -121,7 +121,7 @@ public: void incrementSticker(DocumentData *sticker); void activate(); - void updateReceived(const mtpPrime *from, const mtpPrime *end); + [[nodiscard]] bool updateReceived(const mtpPrime *from, const mtpPrime *end); void refreshDialog(Dialogs::Key key); void removeDialog(Dialogs::Key key); diff --git a/Telegram/SourceFiles/mtproto/concurrent_sender.cpp b/Telegram/SourceFiles/mtproto/concurrent_sender.cpp index a8877da60..a17e7031f 100644 --- a/Telegram/SourceFiles/mtproto/concurrent_sender.cpp +++ b/Telegram/SourceFiles/mtproto/concurrent_sender.cpp @@ -18,7 +18,7 @@ public: not_null<ConcurrentSender*> sender, Fn<void(FnMut<void()>)> runner); - void operator()( + bool operator()( mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override; @@ -54,7 +54,7 @@ ConcurrentSender::RPCDoneHandler::RPCDoneHandler( , _runner(std::move(runner)) { } -void ConcurrentSender::RPCDoneHandler::operator()( +bool ConcurrentSender::RPCDoneHandler::operator()( mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) { @@ -66,6 +66,7 @@ void ConcurrentSender::RPCDoneHandler::operator()( strong->senderRequestDone(requestId, std::move(moved)); } }); + return true; } ConcurrentSender::RPCFailHandler::RPCFailHandler( @@ -178,14 +179,12 @@ void ConcurrentSender::senderRequestDone( mtpRequestId requestId, bytes::const_span result) { if (auto handlers = _requests.take(requestId)) { - try { - handlers->done(requestId, result); - } catch (Exception &e) { + if (!handlers->done(requestId, result)) { handlers->fail( requestId, RPCError::Local( "RESPONSE_PARSE_FAILED", - QString("exception text: ") + e.what())); + "ConcurrentSender::senderRequestDone")); } } } diff --git a/Telegram/SourceFiles/mtproto/concurrent_sender.h b/Telegram/SourceFiles/mtproto/concurrent_sender.h index 64088efa2..5929f800c 100644 --- a/Telegram/SourceFiles/mtproto/concurrent_sender.h +++ b/Telegram/SourceFiles/mtproto/concurrent_sender.h @@ -32,7 +32,7 @@ class ConcurrentSender : public base::has_weak_ptr { auto with_instance(Method &&method) -> std::enable_if_t<is_callable_v<Method, not_null<Instance*>>>; - using DoneHandler = FnMut<void( + using DoneHandler = FnMut<bool( mtpRequestId requestId, bytes::const_span result)>; using FailHandler = FnMut<void( @@ -206,8 +206,11 @@ void ConcurrentSender::RequestBuilder::setDoneHandler( auto from = reinterpret_cast<const mtpPrime*>(result.data()); const auto end = from + result.size() / sizeof(mtpPrime); Response data; - data.read(from, end); + if (!data.read(from, end)) { + return false; + } std::move(handler)(requestId, std::move(data)); + return true; }; } diff --git a/Telegram/SourceFiles/mtproto/connection.cpp b/Telegram/SourceFiles/mtproto/connection.cpp index 8cabe5ca8..d4541a824 100644 --- a/Telegram/SourceFiles/mtproto/connection.cpp +++ b/Telegram/SourceFiles/mtproto/connection.cpp @@ -1637,8 +1637,7 @@ void ConnectionPrivate::handleReceived() { } ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPrime *from, const mtpPrime *end, uint64 msgId, int32 serverTime, uint64 serverSalt, bool badTime) { - mtpTypeId cons = *from; - try { + const auto cons = *from; switch (cons) { @@ -1652,17 +1651,23 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr } case mtpc_msg_container: { - if (++from >= end) throw mtpErrorInsufficient(); + if (++from >= end) { + return HandleResult::ParseError; + } const mtpPrime *otherEnd; - uint32 msgsCount = (uint32)*(from++); + const auto msgsCount = (uint32)*(from++); DEBUG_LOG(("Message Info: container received, count: %1").arg(msgsCount)); for (uint32 i = 0; i < msgsCount; ++i) { - if (from + 4 >= end) throw mtpErrorInsufficient(); + if (from + 4 >= end) { + return HandleResult::ParseError; + } otherEnd = from + 4; MTPlong inMsgId; - inMsgId.read(from, otherEnd); + if (!inMsgId.read(from, otherEnd)) { + return HandleResult::ParseError; + } bool isReply = ((inMsgId.v & 0x03) == 1); if (!isReply && ((inMsgId.v & 0x03) != 3)) { LOG(("Message Error: bad msg_id %1 in contained message received").arg(inMsgId.v)); @@ -1670,9 +1675,13 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr } MTPint inSeqNo; - inSeqNo.read(from, otherEnd); + if (!inSeqNo.read(from, otherEnd)) { + return HandleResult::ParseError; + } MTPint bytes; - bytes.read(from, otherEnd); + if (!bytes.read(from, otherEnd)) { + return HandleResult::ParseError; + } if ((bytes.v & 0x03) || bytes.v < 4) { LOG(("Message Error: bad length %1 of contained message received").arg(bytes.v)); return HandleResult::RestartConnection; @@ -1684,7 +1693,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr DEBUG_LOG(("Message Info: message from container, msg_id: %1, needAck: %2").arg(inMsgId.v).arg(Logs::b(needAck))); otherEnd = from + (bytes.v >> 2); - if (otherEnd > end) throw mtpErrorInsufficient(); + if (otherEnd > end) { + return HandleResult::ParseError; + } bool needToHandle = false; { @@ -1706,7 +1717,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_msgs_ack: { MTPMsgsAck msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } auto &ids = msg.c_msgs_ack().vmsg_ids().v; uint32 idsCount = ids.size(); @@ -1725,7 +1738,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_bad_msg_notification: { MTPBadMsgNotification msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } const auto &data(msg.c_bad_msg_notification()); LOG(("Message Info: bad message notification received (error_code %3) for msg_id = %1, seq_no = %2").arg(data.vbad_msg_id().v).arg(data.vbad_msg_seqno().v).arg(data.verror_code().v)); @@ -1822,7 +1837,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_bad_server_salt: { MTPBadMsgNotification msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } const auto &data(msg.c_bad_server_salt()); DEBUG_LOG(("Message Info: bad server salt received (error_code %4) for msg_id = %1, seq_no = %2, new salt: %3").arg(data.vbad_msg_id().v).arg(data.vbad_msg_seqno().v).arg(data.vnew_server_salt().v).arg(data.verror_code().v)); @@ -1857,7 +1874,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr return HandleResult::Ignored; } MTPMsgsStateReq msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } auto &ids = msg.c_msgs_state_req().vmsg_ids().v; auto idsCount = ids.size(); DEBUG_LOG(("Message Info: msgs_state_req received, ids: %1").arg(LogIdsVector(ids))); @@ -1905,7 +1924,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_msgs_state_info: { MTPMsgsStateInfo msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } auto &data = msg.c_msgs_state_info(); auto reqMsgId = data.vreq_msg_id().v; @@ -1938,20 +1959,21 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr LOG(("Message Error: bad request %1 found in requestMap, size: %2").arg(reqMsgId).arg(requestBuffer->size())); return HandleResult::RestartConnection; } - try { - const mtpPrime *rFrom = requestBuffer->constData() + 8, *rEnd = requestBuffer->constData() + requestBuffer->size(); - if (mtpTypeId(*rFrom) == mtpc_msgs_state_req) { - MTPMsgsStateReq request; - request.read(rFrom, rEnd); - handleMsgsStates(request.c_msgs_state_req().vmsg_ids().v, states, toAck); - } else { - MTPMsgResendReq request; - request.read(rFrom, rEnd); - handleMsgsStates(request.c_msg_resend_req().vmsg_ids().v, states, toAck); + const mtpPrime *rFrom = requestBuffer->constData() + 8, *rEnd = requestBuffer->constData() + requestBuffer->size(); + if (mtpTypeId(*rFrom) == mtpc_msgs_state_req) { + MTPMsgsStateReq request; + if (!request.read(rFrom, rEnd)) { + LOG(("Message Error: could not parse sent msgs_state_req")); + return HandleResult::ParseError; } - } catch(Exception &) { - LOG(("Message Error: could not parse sent msgs_state_req")); - throw; + handleMsgsStates(request.c_msgs_state_req().vmsg_ids().v, states, toAck); + } else { + MTPMsgResendReq request; + if (!request.read(rFrom, rEnd)) { + LOG(("Message Error: could not parse sent msgs_state_req")); + return HandleResult::ParseError; + } + handleMsgsStates(request.c_msg_resend_req().vmsg_ids().v, states, toAck); } requestsAcked(toAck); @@ -1964,7 +1986,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr } MTPMsgsAllInfo msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } auto &data = msg.c_msgs_all_info(); auto &ids = data.vmsg_ids().v; auto &states = data.vinfo().v; @@ -1979,7 +2003,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_msg_detailed_info: { MTPMsgDetailedInfo msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } const auto &data(msg.c_msg_detailed_info()); DEBUG_LOG(("Message Info: msg detailed info, sent msgId %1, answerId %2, status %3, bytes %4").arg(data.vmsg_id().v).arg(data.vanswer_msg_id().v).arg(data.vstatus().v).arg(data.vbytes().v)); @@ -2015,7 +2041,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr return HandleResult::Ignored; } MTPMsgDetailedInfo msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } const auto &data(msg.c_msg_new_detailed_info()); DEBUG_LOG(("Message Info: msg new detailed info, answerId %2, status %3, bytes %4").arg(data.vanswer_msg_id().v).arg(data.vstatus().v).arg(data.vbytes().v)); @@ -2036,7 +2064,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_msg_resend_req: { MTPMsgResendReq msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } auto &ids = msg.c_msg_resend_req().vmsg_ids().v; auto idsCount = ids.size(); @@ -2051,11 +2081,15 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr } return HandleResult::Success; case mtpc_rpc_result: { - if (from + 3 > end) throw mtpErrorInsufficient(); + if (from + 3 > end) { + return HandleResult::ParseError; + } auto response = SerializedMessage(); MTPlong reqMsgId; - reqMsgId.read(++from, end); + if (!reqMsgId.read(++from, end)) { + return HandleResult::ParseError; + } mtpTypeId typeId = from[0]; DEBUG_LOG(("RPC Info: response received for %1, queueing...").arg(reqMsgId.v)); @@ -2074,7 +2108,7 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr if (typeId == mtpc_gzip_packed) { DEBUG_LOG(("RPC Info: gzip container")); response = ungzip(++from, end); - if (!response.size()) { + if (response.empty()) { return HandleResult::RestartConnection; } typeId = response[0]; @@ -2105,7 +2139,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_new_session_created: { const mtpPrime *start = from; MTPNewSession msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } const auto &data(msg.c_new_session_created()); if (badTime) { @@ -2143,7 +2179,9 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr case mtpc_pong: { MTPPong msg; - msg.read(from, end); + if (!msg.read(from, end)) { + return HandleResult::ParseError; + } const auto &data(msg.c_pong()); DEBUG_LOG(("Message Info: pong received, msg_id: %1, ping_id: %2").arg(data.vmsg_id().v).arg(data.vping_id().v)); @@ -2170,10 +2208,6 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr } - } catch (Exception &) { - return HandleResult::RestartConnection; - } - if (badTime) { DEBUG_LOG(("Message Error: bad time in updates cons, must create new session")); return HandleResult::ResetSession; @@ -2205,12 +2239,16 @@ ConnectionPrivate::HandleResult ConnectionPrivate::handleOneReceived(const mtpPr } mtpBuffer ConnectionPrivate::ungzip(const mtpPrime *from, const mtpPrime *end) const { - MTPstring packed; - packed.read(from, end); // read packed string as serialized mtp string type - uint32 packedLen = packed.v.size(), unpackedChunk = packedLen, unpackedLen = 0; - mtpBuffer result; // * 4 because of mtpPrime type result.resize(0); + + MTPstring packed; + if (!packed.read(from, end)) { // read packed string as serialized mtp string type + LOG(("RPC Error: could not read gziped bytes.")); + return result; + } + uint32 packedLen = packed.v.size(), unpackedChunk = packedLen, unpackedLen = 0; + z_stream stream; stream.zalloc = 0; stream.zfree = 0; @@ -2730,7 +2768,10 @@ void ConnectionPrivate::dhParamsAnswered() { const mtpPrime *from(&decBuffer[5]), *to(from), *end(from + (encDHBufLen - 5)); MTPServer_DH_inner_data dh_inner; - dh_inner.read(to, end); + if (!dh_inner.read(to, end)) { + LOG(("AuthKey Error: could not decrypt server_DH_inner_data!")); + return restart(); + } const auto &dh_inner_data(dh_inner.c_server_DH_inner_data()); if (dh_inner_data.vnonce() != _authKeyData->nonce) { LOG(("AuthKey Error: received nonce <> sent nonce (in server_DH_inner_data)!")); @@ -3100,13 +3141,8 @@ bool ConnectionPrivate::readNotSecureResponse(Response &response) { if (answer.empty()) { return false; } - try { - auto from = answer.data(); - response.read(from, from + answer.size()); - } catch (Exception &) { - return false; - } - return true; + auto from = answer.data(); + return response.read(from, from + answer.size()); } bool ConnectionPrivate::sendSecureRequest( diff --git a/Telegram/SourceFiles/mtproto/connection.h b/Telegram/SourceFiles/mtproto/connection.h index 079864db9..ee1eed458 100644 --- a/Telegram/SourceFiles/mtproto/connection.h +++ b/Telegram/SourceFiles/mtproto/connection.h @@ -195,8 +195,9 @@ private: Ignored, RestartConnection, ResetSession, + ParseError, }; - HandleResult handleOneReceived(const mtpPrime *from, const mtpPrime *end, uint64 msgId, int32 serverTime, uint64 serverSalt, bool badTime); + [[nodiscard]] HandleResult handleOneReceived(const mtpPrime *from, const mtpPrime *end, uint64 msgId, int32 serverTime, uint64 serverSalt, bool badTime); mtpBuffer ungzip(const mtpPrime *from, const mtpPrime *end) const; void handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray &states, QVector<MTPlong> &acked); @@ -225,7 +226,7 @@ private: void sendNotSecureRequest(const Request &request); template <typename Response> - bool readNotSecureResponse(Response &response); + [[nodiscard]] bool readNotSecureResponse(Response &response); not_null<Instance*> _instance; DcType _dcType = DcType::Regular; diff --git a/Telegram/SourceFiles/mtproto/connection_abstract.cpp b/Telegram/SourceFiles/mtproto/connection_abstract.cpp index 507247830..48974856d 100644 --- a/Telegram/SourceFiles/mtproto/connection_abstract.cpp +++ b/Telegram/SourceFiles/mtproto/connection_abstract.cpp @@ -141,16 +141,17 @@ mtpBuffer AbstractConnection::preparePQFake(const MTPint128 &nonce) const { base::unixtime::mtproto_msg_id()); } -MTPResPQ AbstractConnection::readPQFakeReply( +std::optional<MTPResPQ> AbstractConnection::readPQFakeReply( const mtpBuffer &buffer) const { const auto answer = parseNotSecureResponse(buffer); if (answer.empty()) { - throw Exception("bad pq reply"); + return std::nullopt; } auto from = answer.data(); MTPResPQ response; - response.read(from, from + answer.size()); - return response; + return response.read(from, from + answer.size()) + ? std::make_optional(response) + : std::nullopt; } AbstractConnection::AbstractConnection( diff --git a/Telegram/SourceFiles/mtproto/connection_abstract.h b/Telegram/SourceFiles/mtproto/connection_abstract.h index 3f92cb4d1..e6e40129c 100644 --- a/Telegram/SourceFiles/mtproto/connection_abstract.h +++ b/Telegram/SourceFiles/mtproto/connection_abstract.h @@ -137,7 +137,7 @@ protected: // first we always send fake MTPReq_pq to see if connection works at all // we send them simultaneously through TCP/HTTP/IPv4/IPv6 to choose the working one mtpBuffer preparePQFake(const MTPint128 &nonce) const; - MTPResPQ readPQFakeReply(const mtpBuffer &buffer) const; + std::optional<MTPResPQ> readPQFakeReply(const mtpBuffer &buffer) const; }; diff --git a/Telegram/SourceFiles/mtproto/connection_http.cpp b/Telegram/SourceFiles/mtproto/connection_http.cpp index 082dc9b30..f7d6a81f7 100644 --- a/Telegram/SourceFiles/mtproto/connection_http.cpp +++ b/Telegram/SourceFiles/mtproto/connection_http.cpp @@ -167,28 +167,24 @@ void HttpConnection::requestFinished(QNetworkReply *reply) { if (_status == Status::Ready) { _receivedQueue.push_back(data); emit receivedData(); - } else { - try { - const auto res_pq = readPQFakeReply(data); - const auto &data = res_pq.c_resPQ(); - if (data.vnonce() == _checkNonce) { - DEBUG_LOG(("Connection Info: " - "HTTP-transport to %1 connected by pq-response" - ).arg(_address)); - _status = Status::Ready; - _pingTime = crl::now() - _pingTime; - emit connected(); - } else { - DEBUG_LOG(("Connection Error: " - "Wrong nonce received in HTTP fake pq-responce")); - emit error(kErrorCodeOther); - } - } catch (Exception &e) { + } else if (const auto res_pq = readPQFakeReply(data)) { + const auto &data = res_pq->c_resPQ(); + if (data.vnonce() == _checkNonce) { + DEBUG_LOG(("Connection Info: " + "HTTP-transport to %1 connected by pq-response" + ).arg(_address)); + _status = Status::Ready; + _pingTime = crl::now() - _pingTime; + emit connected(); + } else { DEBUG_LOG(("Connection Error: " - "Exception in parsing HTTP fake pq-responce, %1" - ).arg(e.what())); + "Wrong nonce received in HTTP fake pq-responce")); emit error(kErrorCodeOther); } + } else { + DEBUG_LOG(("Connection Error: " + "Could not parse HTTP fake pq-responce")); + emit error(kErrorCodeOther); } } } else { diff --git a/Telegram/SourceFiles/mtproto/connection_tcp.cpp b/Telegram/SourceFiles/mtproto/connection_tcp.cpp index bf48effbd..d48eed980 100644 --- a/Telegram/SourceFiles/mtproto/connection_tcp.cpp +++ b/Telegram/SourceFiles/mtproto/connection_tcp.cpp @@ -615,9 +615,8 @@ void TcpConnection::socketPacket(bytes::const_span bytes) { _receivedQueue.push_back(data); emit receivedData(); } else if (_status == Status::Waiting) { - try { - const auto res_pq = readPQFakeReply(data); - const auto &data = res_pq.c_resPQ(); + if (const auto res_pq = readPQFakeReply(data)) { + const auto &data = res_pq->c_resPQ(); if (data.vnonce() == _checkNonce) { DEBUG_LOG(("Connection Info: Valid pq response by TCP.")); _status = Status::Ready; @@ -629,10 +628,9 @@ void TcpConnection::socketPacket(bytes::const_span bytes) { "Wrong nonce received in TCP fake pq-responce")); emit error(kErrorCodeOther); } - } catch (Exception &e) { + } else { DEBUG_LOG(("Connection Error: " - "Exception in parsing TCP fake pq-responce, %1" - ).arg(e.what())); + "Could not parse TCP fake pq-responce")); emit error(kErrorCodeOther); } } diff --git a/Telegram/SourceFiles/mtproto/core_types.cpp b/Telegram/SourceFiles/mtproto/core_types.cpp index 84a57235d..0f1d7e2e9 100644 --- a/Telegram/SourceFiles/mtproto/core_types.cpp +++ b/Telegram/SourceFiles/mtproto/core_types.cpp @@ -147,36 +147,6 @@ bool SecureRequest::needAck() const { } // namespace MTP -Exception::Exception(const QString &msg) noexcept : _msg(msg.toUtf8()) { - LOG(("Exception: %1").arg(msg)); -} - -mtpErrorUnexpected::mtpErrorUnexpected( - mtpTypeId typeId, - const QString &type) noexcept -: Exception( - QString("MTP Unexpected type id #%1 read in %2" - ).arg(uint32(typeId), 0, 16 - ).arg(type)) { -} - -mtpErrorInsufficient::mtpErrorInsufficient() noexcept -: Exception("MTP Insufficient bytes in input buffer") { -} - -mtpErrorBadTypeId::mtpErrorBadTypeId( - mtpTypeId typeId, - const QString &type) noexcept -: Exception( - QString("MTP Bad type id #%1 passed to constructor of %2" - ).arg(uint32(typeId), 0, 16 - ).arg(type)) { -} - -const char *Exception::what() const noexcept { - return _msg.constData(); -} - uint32 MTPstring::innerLength() const { uint32 l = v.length(); if (l < 254) { @@ -189,9 +159,10 @@ uint32 MTPstring::innerLength() const { return l; } -void MTPstring::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) { - if (from + 1 > end) throw mtpErrorInsufficient(); - if (cons != mtpc_string) throw mtpErrorUnexpected(cons, "MTPstring"); +bool MTPstring::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) { + if (from + 1 > end || cons != mtpc_string) { + return false; + } uint32 l; const uchar *buf = (const uchar*)from; @@ -204,9 +175,12 @@ void MTPstring::read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons) ++buf; from += ((l + 1) >> 2) + (((l + 1) & 0x03) ? 1 : 0); } - if (from > end) throw mtpErrorInsufficient(); + if (from > end) { + return false; + } v = QByteArray(reinterpret_cast<const char*>(buf), l); + return true; } void MTPstring::write(mtpBuffer &to) const { @@ -229,74 +203,91 @@ void MTPstring::write(mtpBuffer &to) const { memcpy(buf, v.constData(), l); } -void mtpTextSerializeCore(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons, uint32 level, mtpPrime vcons) { +bool mtpTextSerializeCore(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons, uint32 level, mtpPrime vcons) { switch (mtpTypeId(cons)) { case mtpc_int: { MTPint value; - value.read(from, end, cons); - to.add(QString::number(value.v)).add(" [INT]"); + if (value.read(from, end, cons)) { + to.add(QString::number(value.v)).add(" [INT]"); + return true; + } } break; case mtpc_long: { MTPlong value; - value.read(from, end, cons); - to.add(QString::number(value.v)).add(" [LONG]"); + if (value.read(from, end, cons)) { + to.add(QString::number(value.v)).add(" [LONG]"); + return true; + } } break; case mtpc_int128: { MTPint128 value; - value.read(from, end, cons); - to.add(QString::number(value.h)).add(" * 2^64 + ").add(QString::number(value.l)).add(" [INT128]"); + if (value.read(from, end, cons)) { + to.add(QString::number(value.h)).add(" * 2^64 + ").add(QString::number(value.l)).add(" [INT128]"); + return true; + } } break; case mtpc_int256: { MTPint256 value; - value.read(from, end, cons); - to.add(QString::number(value.h.h)).add(" * 2^192 + ").add(QString::number(value.h.l)).add(" * 2^128 + ").add(QString::number(value.l.h)).add(" * 2 ^ 64 + ").add(QString::number(value.l.l)).add(" [INT256]"); + if (value.read(from, end, cons)) { + to.add(QString::number(value.h.h)).add(" * 2^192 + ").add(QString::number(value.h.l)).add(" * 2^128 + ").add(QString::number(value.l.h)).add(" * 2 ^ 64 + ").add(QString::number(value.l.l)).add(" [INT256]"); + return true; + } } break; case mtpc_double: { MTPdouble value; - value.read(from, end, cons); - to.add(QString::number(value.v)).add(" [DOUBLE]"); + if (value.read(from, end, cons)) { + to.add(QString::number(value.v)).add(" [DOUBLE]"); + return true; + } } break; case mtpc_string: { MTPstring value; - value.read(from, end, cons); - auto strUtf8 = value.v; - auto str = QString::fromUtf8(strUtf8); - if (str.toUtf8() == strUtf8) { - to.add("\"").add(str.replace('\\', "\\\\").replace('"', "\\\"").replace('\n', "\\n")).add("\" [STRING]"); - } else if (strUtf8.size() < 64) { - to.add(Logs::mb(strUtf8.constData(), strUtf8.size()).str()).add(" [").add(QString::number(strUtf8.size())).add(" BYTES]"); - } else { - to.add(Logs::mb(strUtf8.constData(), 16).str()).add("... [").add(QString::number(strUtf8.size())).add(" BYTES]"); + if (value.read(from, end, cons)) { + auto strUtf8 = value.v; + auto str = QString::fromUtf8(strUtf8); + if (str.toUtf8() == strUtf8) { + to.add("\"").add(str.replace('\\', "\\\\").replace('"', "\\\"").replace('\n', "\\n")).add("\" [STRING]"); + } else if (strUtf8.size() < 64) { + to.add(Logs::mb(strUtf8.constData(), strUtf8.size()).str()).add(" [").add(QString::number(strUtf8.size())).add(" BYTES]"); + } else { + to.add(Logs::mb(strUtf8.constData(), 16).str()).add("... [").add(QString::number(strUtf8.size())).add(" BYTES]"); + } + return true; } } break; case mtpc_vector: { - if (from >= end) { - throw Exception("from >= end in vector"); - } - int32 cnt = *(from++); - to.add("[ vector<0x").add(QString::number(vcons, 16)).add(">"); - if (cnt) { - to.add("\n").addSpaces(level); - for (int32 i = 0; i < cnt; ++i) { - to.add(" "); - mtpTextSerializeType(to, from, end, vcons, level + 1); - to.add(",\n").addSpaces(level); + if (from < end) { + int32 cnt = *(from++); + to.add("[ vector<0x").add(QString::number(vcons, 16)).add(">"); + if (cnt) { + to.add("\n").addSpaces(level); + for (int32 i = 0; i < cnt; ++i) { + to.add(" "); + if (!mtpTextSerializeType(to, from, end, vcons, level + 1)) { + return false; + } + to.add(",\n").addSpaces(level); + } + } else { + to.add(" "); } - } else { - to.add(" "); + to.add("]"); + return true; } - to.add("]"); } break; case mtpc_gzip_packed: { MTPstring packed; - packed.read(from, end); // read packed string as serialized mtp string type + // read packed string as serialized mtp string type + if (!packed.read(from, end)) { + return false; + } uint32 packedLen = packed.v.size(), unpackedChunk = packedLen; mtpBuffer result; // * 4 because of mtpPrime type result.resize(0); @@ -309,7 +300,7 @@ void mtpTextSerializeCore(MTPStringLogger &to, const mtpPrime *&from, const mtpP stream.next_in = nullptr; int res = inflateInit2(&stream, 16 + MAX_WBITS); if (res != Z_OK) { - throw Exception(QString("ungzip init, code: %1").arg(res)); + return false; } stream.avail_in = packedLen; stream.next_in = reinterpret_cast<Bytef*>(packed.v.data()); @@ -321,39 +312,39 @@ void mtpTextSerializeCore(MTPStringLogger &to, const mtpPrime *&from, const mtpP int res = inflate(&stream, Z_NO_FLUSH); if (res != Z_OK && res != Z_STREAM_END) { inflateEnd(&stream); - throw Exception(QString("ungzip unpack, code: %1").arg(res)); + return false; } } if (stream.avail_out & 0x03) { - uint32 badSize = result.size() * sizeof(mtpPrime) - stream.avail_out; - throw Exception(QString("ungzip bad length, size: %1").arg(badSize)); + return false; } result.resize(result.size() - (stream.avail_out >> 2)); inflateEnd(&stream); if (result.empty()) { - throw Exception("ungzip void data"); + return false; } const mtpPrime *newFrom = result.constData(), *newEnd = result.constData() + result.size(); - to.add("[GZIPPED] "); mtpTextSerializeType(to, newFrom, newEnd, 0, level); + to.add("[GZIPPED] "); + return mtpTextSerializeType(to, newFrom, newEnd, 0, level); } break; default: { for (uint32 i = 1; i < mtpLayerMaxSingle; ++i) { if (cons == mtpLayers[i]) { - to.add("[LAYER").add(QString::number(i + 1)).add("] "); mtpTextSerializeType(to, from, end, 0, level); - return; + to.add("[LAYER").add(QString::number(i + 1)).add("] "); + return mtpTextSerializeType(to, from, end, 0, level); } } if (cons == mtpc_invokeWithLayer) { if (from >= end) { - throw Exception("from >= end in invokeWithLayer"); + return false; } int32 layer = *(from++); - to.add("[LAYER").add(QString::number(layer)).add("] "); mtpTextSerializeType(to, from, end, 0, level); - return; + to.add("[LAYER").add(QString::number(layer)).add("] "); + return mtpTextSerializeType(to, from, end, 0, level); } - throw Exception(QString("unknown cons 0x%1").arg(cons, 0, 16)); } break; } + return false; } diff --git a/Telegram/SourceFiles/mtproto/core_types.h b/Telegram/SourceFiles/mtproto/core_types.h index ffedf9cd7..a2f474fd0 100644 --- a/Telegram/SourceFiles/mtproto/core_types.h +++ b/Telegram/SourceFiles/mtproto/core_types.h @@ -59,35 +59,6 @@ constexpr int GetDcIdShift(ShiftedDcId shiftedDcId) { } // namespace MTP -class Exception : public std::exception { -public: - explicit Exception(const QString &msg) noexcept; - - const char *what() const noexcept override; - -private: - QByteArray _msg; - -}; - -class mtpErrorUnexpected : public Exception { -public: - mtpErrorUnexpected(mtpTypeId typeId, const QString &type) noexcept; - -}; - -class mtpErrorInsufficient : public Exception { -public: - mtpErrorInsufficient() noexcept; - -}; - -class mtpErrorBadTypeId : public Exception { -public: - mtpErrorBadTypeId(mtpTypeId typeId, const QString &type) noexcept; - -}; - namespace MTP { namespace internal { @@ -259,10 +230,12 @@ public: uint32 innerLength() const { return sizeof(mtpTypeId) + bareT::innerLength(); } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = 0) { - if (from + 1 > end) throw mtpErrorInsufficient(); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = 0) { + if (from + 1 > end) { + return false; + } cons = (mtpTypeId)*(from++); - bareT::read(from, end, cons); + return bareT::read(from, end, cons); } void write(mtpBuffer &to) const { to.push_back(bareT::type()); @@ -439,10 +412,12 @@ public: mtpTypeId type() const { return mtpc_int; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_int) { - if (from + 1 > end) throw mtpErrorInsufficient(); - if (cons != mtpc_int) throw mtpErrorUnexpected(cons, "MTPint"); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_int) { + if (from + 1 > end || cons != mtpc_int) { + return false; + } v = (int32)*(from++); + return true; } void write(mtpBuffer &to) const { to.push_back((mtpPrime)v); @@ -477,10 +452,12 @@ public: mtpTypeId type() const { return mtpc_flags; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_flags) { - if (from + 1 > end) throw mtpErrorInsufficient(); - if (cons != mtpc_flags) throw mtpErrorUnexpected(cons, "MTPflags"); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_flags) { + if (from + 1 > end || cons != mtpc_flags) { + return false; + } v = Flags::from_raw(static_cast<typename Flags::Type>(*(from++))); + return true; } void write(mtpBuffer &to) const { to.push_back(static_cast<mtpPrime>(v.value())); @@ -534,11 +511,13 @@ public: mtpTypeId type() const { return mtpc_long; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_long) { - if (from + 2 > end) throw mtpErrorInsufficient(); - if (cons != mtpc_long) throw mtpErrorUnexpected(cons, "MTPlong"); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_long) { + if (from + 2 > end || cons != mtpc_long) { + return false; + } v = (uint64)(((uint32*)from)[0]) | ((uint64)(((uint32*)from)[1]) << 32); from += 2; + return true; } void write(mtpBuffer &to) const { to.push_back((mtpPrime)(v & 0xFFFFFFFFL)); @@ -576,12 +555,14 @@ public: mtpTypeId type() const { return mtpc_int128; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_int128) { - if (from + 4 > end) throw mtpErrorInsufficient(); - if (cons != mtpc_int128) throw mtpErrorUnexpected(cons, "MTPint128"); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_int128) { + if (from + 4 > end || cons != mtpc_int128) { + return false; + } l = (uint64)(((uint32*)from)[0]) | ((uint64)(((uint32*)from)[1]) << 32); h = (uint64)(((uint32*)from)[2]) | ((uint64)(((uint32*)from)[3]) << 32); from += 4; + return true; } void write(mtpBuffer &to) const { to.push_back((mtpPrime)(l & 0xFFFFFFFFL)); @@ -621,10 +602,12 @@ public: mtpTypeId type() const { return mtpc_int256; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_int256) { - if (cons != mtpc_int256) throw mtpErrorUnexpected(cons, "MTPint256"); - l.read(from, end); - h.read(from, end); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_int256) { + if (cons != mtpc_int256) { + return false; + } + return l.read(from, end) + && h.read(from, end); } void write(mtpBuffer &to) const { l.write(to); @@ -661,12 +644,14 @@ public: mtpTypeId type() const { return mtpc_double; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_double) { - if (from + 2 > end) throw mtpErrorInsufficient(); - if (cons != mtpc_double) throw mtpErrorUnexpected(cons, "MTPdouble"); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_double) { + if (from + 2 > end || cons != mtpc_double) { + return false; + } auto nv = (uint64)(((uint32*)from)[0]) | ((uint64)(((uint32*)from)[1]) << 32); std::memcpy(&v, &nv, sizeof(v)); from += 2; + return true; } void write(mtpBuffer &to) const { uint64 iv; @@ -704,7 +689,7 @@ public: mtpTypeId type() const { return mtpc_string; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_string); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_string); void write(mtpBuffer &to) const; QByteArray v; @@ -792,16 +777,20 @@ public: mtpTypeId type() const { return mtpc_vector; } - void read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_vector) { - if (from + 1 > end) throw mtpErrorInsufficient(); - if (cons != mtpc_vector) throw mtpErrorUnexpected(cons, "MTPvector"); + [[nodiscard]] bool read(const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons = mtpc_vector) { + if (from + 1 > end || cons != mtpc_vector) { + return false; + } auto count = static_cast<uint32>(*(from++)); auto vector = QVector<T>(count, T()); for (auto &item : vector) { - item.read(from, end); + if (!item.read(from, end)) { + return false; + } } v = std::move(vector); + return true; } void write(mtpBuffer &to) const { to.push_back(v.size()); @@ -900,6 +889,10 @@ struct MTPStringLogger { return (*this); } + MTPStringLogger &error(const char *problem = "could not decode type") { + return add("[ERROR] (").add(problem).add(")"); + } + void ensureLength(int32 add) { if (size + add <= alloced) return; @@ -920,16 +913,12 @@ struct MTPStringLogger { }; -void mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpPrime cons = 0, uint32 level = 0, mtpPrime vcons = 0); +[[nodiscard]] bool mtpTextSerializeType(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpPrime cons = 0, uint32 level = 0, mtpPrime vcons = 0); -void mtpTextSerializeCore(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons, uint32 level, mtpPrime vcons = 0); +[[nodiscard]] bool mtpTextSerializeCore(MTPStringLogger &to, const mtpPrime *&from, const mtpPrime *end, mtpTypeId cons, uint32 level, mtpPrime vcons = 0); inline QString mtpTextSerialize(const mtpPrime *&from, const mtpPrime *end) { MTPStringLogger to; - try { - mtpTextSerializeType(to, from, end, mtpc_core_message); - } catch (Exception &e) { - to.add("[ERROR] (").add(e.what()).add(")"); - } + [[maybe_unused]] bool result = mtpTextSerializeType(to, from, end, mtpc_core_message); return QString::fromUtf8(to.p, to.size); } diff --git a/Telegram/SourceFiles/mtproto/mtp_instance.cpp b/Telegram/SourceFiles/mtproto/mtp_instance.cpp index e5c89ea55..6564205eb 100644 --- a/Telegram/SourceFiles/mtproto/mtp_instance.cpp +++ b/Telegram/SourceFiles/mtproto/mtp_instance.cpp @@ -1064,22 +1064,24 @@ void Instance::Private::execCallback( } }; - try { - if (from >= end) throw mtpErrorInsufficient(); - if (*from == mtpc_rpc_error) { - auto error = MTPRpcError(); - error.read(from, end); - handleError(error); - } else { - if (h.onDone) { - (*h.onDone)(requestId, from, end); - } - unregisterRequest(requestId); - } - } catch (Exception &e) { + if (from >= end) { handleError(RPCError::Local( "RESPONSE_PARSE_FAILED", - QString("exception text: ") + e.what())); + "Empty response.")); + } else if (*from == mtpc_rpc_error) { + auto error = MTPRpcError(); + handleError(error.read(from, end) ? error : RPCError::Local( + "RESPONSE_PARSE_FAILED", + "Error parse failed.")); + } else { + if (h.onDone) { + if (!(*h.onDone)(requestId, from, end)) { + handleError(RPCError::Local( + "RESPONSE_PARSE_FAILED", + "Response parse failed.")); + } + } + unregisterRequest(requestId); } } else { DEBUG_LOG(("RPC Info: parser not found for %1").arg(requestId)); diff --git a/Telegram/SourceFiles/mtproto/rpc_sender.h b/Telegram/SourceFiles/mtproto/rpc_sender.h index 84c47449e..0e739eb92 100644 --- a/Telegram/SourceFiles/mtproto/rpc_sender.h +++ b/Telegram/SourceFiles/mtproto/rpc_sender.h @@ -67,7 +67,7 @@ inline bool isDefaultHandledError(const RPCError &error) { class RPCAbstractDoneHandler { // abstract done public: - virtual void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) = 0; + virtual [[nodiscard]] bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) = 0; virtual ~RPCAbstractDoneHandler() { } @@ -94,15 +94,14 @@ struct RPCResponseHandler { }; -template <typename TReturn> class RPCDoneHandlerBare : public RPCAbstractDoneHandler { // done(from, end) - using CallbackType = TReturn (*)(const mtpPrime *, const mtpPrime *); + using CallbackType = bool (*)(const mtpPrime *, const mtpPrime *); public: RPCDoneHandlerBare(CallbackType onDone) : _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - (*_onDone)(from, end); + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + return (*_onDone)(from, end); } private: @@ -110,15 +109,14 @@ private: }; -template <typename TReturn> class RPCDoneHandlerBareReq : public RPCAbstractDoneHandler { // done(from, end, req_id) - using CallbackType = TReturn (*)(const mtpPrime *, const mtpPrime *, mtpRequestId); + using CallbackType = bool (*)(const mtpPrime *, const mtpPrime *, mtpRequestId); public: RPCDoneHandlerBareReq(CallbackType onDone) : _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - (*_onDone)(from, end, requestId); + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + return (*_onDone)(from, end, requestId); } private: @@ -133,10 +131,13 @@ class RPCDoneHandlerPlain : public RPCAbstractDoneHandler { // done(result) public: RPCDoneHandlerPlain(CallbackType onDone) : _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { auto response = TResponse(); - response.read(from, end); + if (!response.read(from, end)) { + return false; + } (*_onDone)(std::move(response)); + return true; } private: @@ -151,10 +152,13 @@ class RPCDoneHandlerReq : public RPCAbstractDoneHandler { // done(result, req_id public: RPCDoneHandlerReq(CallbackType onDone) : _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { auto response = TResponse(); - response.read(from, end); + if (!response.read(from, end)) { + return false; + } (*_onDone)(std::move(response), requestId); + return true; } private: @@ -169,8 +173,9 @@ class RPCDoneHandlerNo : public RPCAbstractDoneHandler { // done() public: RPCDoneHandlerNo(CallbackType onDone) : _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { (*_onDone)(); + return true; } private: @@ -185,8 +190,9 @@ class RPCDoneHandlerNoReq : public RPCAbstractDoneHandler { // done(req_id) public: RPCDoneHandlerNoReq(CallbackType onDone) : _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { (*_onDone)(requestId); + return true; } private: @@ -265,14 +271,12 @@ struct RPCCallbackClear { }; -template <typename TReturn> -inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(const mtpPrime *, const mtpPrime *)) { // done(from, end) - return RPCDoneHandlerPtr(new RPCDoneHandlerBare<TReturn>(onDone)); +inline RPCDoneHandlerPtr rpcDone(bool (*onDone)(const mtpPrime *, const mtpPrime *)) { // done(from, end) + return RPCDoneHandlerPtr(new RPCDoneHandlerBare(onDone)); } -template <typename TReturn> -inline RPCDoneHandlerPtr rpcDone(TReturn (*onDone)(const mtpPrime *, const mtpPrime *, mtpRequestId)) { // done(from, end, req_id) - return RPCDoneHandlerPtr(new RPCDoneHandlerBareReq<TReturn>(onDone)); +inline RPCDoneHandlerPtr rpcDone(bool (*onDone)(const mtpPrime *, const mtpPrime *, mtpRequestId)) { // done(from, end, req_id) + return RPCDoneHandlerPtr(new RPCDoneHandlerBareReq(onDone)); } template <typename TReturn, typename TResponse> @@ -346,7 +350,7 @@ class RPCDoneHandlerBareOwned : public RPCOwnedDoneHandler { // done(from, end) public: RPCDoneHandlerBareOwned(TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(from, end); } @@ -355,15 +359,17 @@ private: }; -template <typename TReturn, typename TReceiver> +template <typename TReceiver> class RPCDoneHandlerBareOwnedReq : public RPCOwnedDoneHandler { // done(from, end, req_id) - using CallbackType = TReturn (TReceiver::*)(const mtpPrime *, const mtpPrime *, mtpRequestId); + using CallbackType = bool (TReceiver::*)(const mtpPrime *, const mtpPrime *, mtpRequestId); public: RPCDoneHandlerBareOwnedReq(TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(from, end, requestId); + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + return _owner + ? (static_cast<TReceiver*>(_owner)->*_onDone)(from, end, requestId) + : true; } private: @@ -378,12 +384,15 @@ class RPCDoneHandlerOwned : public RPCOwnedDoneHandler { // done(result) public: RPCDoneHandlerOwned(TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + auto response = TResponse(); + if (!response.read(from, end)) { + return false; + } if (_owner) { - auto response = TResponse(); - response.read(from, end); (static_cast<TReceiver*>(_owner)->*_onDone)(std::move(response)); } + return true; } private: @@ -398,12 +407,15 @@ class RPCDoneHandlerOwnedReq : public RPCOwnedDoneHandler { // done(result, req_ public: RPCDoneHandlerOwnedReq(TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + auto response = TResponse(); + if (!response.read(from, end)) { + return false; + } if (_owner) { - auto response = TResponse(); - response.read(from, end); (static_cast<TReceiver*>(_owner)->*_onDone)(std::move(response), requestId); } + return true; } private: @@ -418,8 +430,9 @@ class RPCDoneHandlerOwnedNo : public RPCOwnedDoneHandler { // done() public: RPCDoneHandlerOwnedNo(TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(); + return true; } private: @@ -434,8 +447,9 @@ class RPCDoneHandlerOwnedNoReq : public RPCOwnedDoneHandler { // done(req_id) public: RPCDoneHandlerOwnedNoReq(TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(requestId); + return true; } private: @@ -443,15 +457,17 @@ private: }; -template <typename T, typename TReturn, typename TReceiver> +template <typename T, typename TReceiver> class RPCBindedDoneHandlerBareOwned : public RPCOwnedDoneHandler { // done(b, from, end) - using CallbackType = TReturn (TReceiver::*)(T, const mtpPrime *, const mtpPrime *); + using CallbackType = bool (TReceiver::*)(T, const mtpPrime *, const mtpPrime *); public: RPCBindedDoneHandlerBareOwned(T b, TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _b(b), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(_b, from, end); + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + return _owner + ? (static_cast<TReceiver*>(_owner)->*_onDone)(_b, from, end) + : true; } private: @@ -462,13 +478,15 @@ private: template <typename T, typename TReturn, typename TReceiver> class RPCBindedDoneHandlerBareOwnedReq : public RPCOwnedDoneHandler { // done(b, from, end, req_id) - using CallbackType = TReturn (TReceiver::*)(T, const mtpPrime *, const mtpPrime *, mtpRequestId); + using CallbackType = bool (TReceiver::*)(T, const mtpPrime *, const mtpPrime *, mtpRequestId); public: RPCBindedDoneHandlerBareOwnedReq(T b, TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _b(b), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(_b, from, end, requestId); + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + return _owner + ? (static_cast<TReceiver*>(_owner)->*_onDone)(_b, from, end, requestId) + : true; } private: @@ -484,12 +502,15 @@ class RPCBindedDoneHandlerOwned : public RPCOwnedDoneHandler { // done(b, result public: RPCBindedDoneHandlerOwned(T b, TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone), _b(b) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + auto response = TResponse(); + if (!response.read(from, end)) { + return false; + } if (_owner) { - auto response = TResponse(); - response.read(from, end); (static_cast<TReceiver*>(_owner)->*_onDone)(_b, std::move(response)); } + return true; } private: @@ -505,12 +526,15 @@ class RPCBindedDoneHandlerOwnedReq : public RPCOwnedDoneHandler { // done(b, res public: RPCBindedDoneHandlerOwnedReq(T b, TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _onDone(onDone), _b(b) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + auto response = TResponse(); + if (!response.read(from, end)) { + return false; + } if (_owner) { - auto response = TResponse(); - response.read(from, end); (static_cast<TReceiver*>(_owner)->*_onDone)(_b, std::move(response), requestId); } + return true; } private: @@ -526,8 +550,9 @@ class RPCBindedDoneHandlerOwnedNo : public RPCOwnedDoneHandler { // done(b) public: RPCBindedDoneHandlerOwnedNo(T b, TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _b(b), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(_b); + return true; } private: @@ -543,8 +568,9 @@ class RPCBindedDoneHandlerOwnedNoReq : public RPCOwnedDoneHandler { // done(b, r public: RPCBindedDoneHandlerOwnedNoReq(T b, TReceiver *receiver, CallbackType onDone) : RPCOwnedDoneHandler(receiver), _b(b), _onDone(onDone) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { if (_owner) (static_cast<TReceiver*>(_owner)->*_onDone)(_b, requestId); + return true; } private: @@ -851,22 +877,20 @@ protected: template <typename FunctionType> using RPCDoneHandlerImplementation = RPCHandlerImplementation<RPCAbstractDoneHandler, FunctionType>; -template <typename R> -class RPCDoneHandlerImplementationBare : public RPCDoneHandlerImplementation<R(const mtpPrime*, const mtpPrime*)> { // done(from, end) +class RPCDoneHandlerImplementationBare : public RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*)> { // done(from, end) public: - using RPCDoneHandlerImplementation<R(const mtpPrime*, const mtpPrime*)>::Parent::Parent; - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - return this->_handler ? this->_handler(from, end) : void(0); + using RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*)>::Parent::Parent; + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + return this->_handler ? this->_handler(from, end) : true; } }; -template <typename R> -class RPCDoneHandlerImplementationBareReq : public RPCDoneHandlerImplementation<R(const mtpPrime*, const mtpPrime*, mtpRequestId)> { // done(from, end, req_id) +class RPCDoneHandlerImplementationBareReq : public RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*, mtpRequestId)> { // done(from, end, req_id) public: - using RPCDoneHandlerImplementation<R(const mtpPrime*, const mtpPrime*, mtpRequestId)>::Parent::Parent; - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { - return this->_handler ? this->_handler(from, end, requestId) : void(0); + using RPCDoneHandlerImplementation<bool(const mtpPrime*, const mtpPrime*, mtpRequestId)>::Parent::Parent; + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + return this->_handler ? this->_handler(from, end, requestId) : true; } }; @@ -875,12 +899,15 @@ template <typename R, typename TResponse> class RPCDoneHandlerImplementationPlain : public RPCDoneHandlerImplementation<R(const TResponse&)> { // done(result) public: using RPCDoneHandlerImplementation<R(const TResponse&)>::Parent::Parent; - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + auto response = TResponse(); + if (!response.read(from, end)) { + return false; + } if (this->_handler) { - auto response = TResponse(); - response.read(from, end); this->_handler(std::move(response)); } + return true; } }; @@ -889,12 +916,15 @@ template <typename R, typename TResponse> class RPCDoneHandlerImplementationReq : public RPCDoneHandlerImplementation<R(const TResponse&, mtpRequestId)> { // done(result, req_id) public: using RPCDoneHandlerImplementation<R(const TResponse&, mtpRequestId)>::Parent::Parent; - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + auto response = TResponse(); + if (!response.read(from, end)) { + return false; + } if (this->_handler) { - auto response = TResponse(); - response.read(from, end); this->_handler(std::move(response), requestId); } + return true; } }; @@ -903,10 +933,11 @@ template <typename R> class RPCDoneHandlerImplementationNo : public RPCDoneHandlerImplementation<R()> { // done() public: using RPCDoneHandlerImplementation<R()>::Parent::Parent; - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { if (this->_handler) { this->_handler(); } + return true; } }; @@ -915,10 +946,11 @@ template <typename R> class RPCDoneHandlerImplementationNoReq : public RPCDoneHandlerImplementation<R(mtpRequestId)> { // done(req_id) public: using RPCDoneHandlerImplementation<R(mtpRequestId)>::Parent::Parent; - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { if (this->_handler) { this->_handler(requestId); } + return true; } }; @@ -995,9 +1027,9 @@ template < RPCDoneHandlerPtr rpcDone(Lambda lambda) { using R = rpcDone_returnType_t<Function>; if constexpr (rpcDone_canCallBare_v<Lambda>) { - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBare<R>(std::move(lambda))); + return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBare(std::move(lambda))); } else if constexpr (rpcDone_canCallBareReq_v<Lambda>) { - return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBareReq<R>(std::move(lambda))); + return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationBareReq(std::move(lambda))); } else if constexpr (rpcDone_canCallNo_v<Lambda>) { return RPCDoneHandlerPtr(new RPCDoneHandlerImplementationNo<R>(std::move(lambda))); } else if constexpr (rpcDone_canCallNoReq_v<Lambda>) { diff --git a/Telegram/SourceFiles/mtproto/sender.h b/Telegram/SourceFiles/mtproto/sender.h index 71eead99a..37210d3fa 100644 --- a/Telegram/SourceFiles/mtproto/sender.h +++ b/Telegram/SourceFiles/mtproto/sender.h @@ -54,15 +54,18 @@ class Sender { DoneHandler(not_null<Sender*> sender, Callback handler) : _sender(sender), _handler(std::move(handler)) { } - void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { + bool operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override { auto handler = std::move(_handler); _sender->senderRequestHandled(requestId); + auto result = Response(); + if (!result.read(from, end)) { + return false; + } if (handler) { - auto result = Response(); - result.read(from, end); Policy::handle(std::move(handler), requestId, std::move(result)); } + return true; } private: diff --git a/Telegram/SourceFiles/mtproto/special_config_request.cpp b/Telegram/SourceFiles/mtproto/special_config_request.cpp index 4e011843a..9dd09fb23 100644 --- a/Telegram/SourceFiles/mtproto/special_config_request.cpp +++ b/Telegram/SourceFiles/mtproto/special_config_request.cpp @@ -458,9 +458,7 @@ bool SpecialConfigRequest::decryptSimpleConfig(const QByteArray &bytes) { return false; } - try { - _simpleConfig.read(from, end); - } catch (...) { + if (!_simpleConfig.read(from, end)) { LOG(("Config Error: Could not read configSimple.")); return false; }