#include "EncryptedConnection.h"

#include "CryptoHelper.h"
#include "rtc_base/logging.h"
#include "rtc_base/byte_buffer.h"
#include "rtc_base/time_utils.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
#include <type_traits>
#include <utility>

// NOTE(review): template arguments in this file (optional<...>,
// reinterpret_cast<...>, std::function<...>, numeric_limits<...>) were
// reconstructed from how the values are used below. The ones that depend on
// declarations in EncryptedConnection.h are marked NOTE(review) individually
// and must be confirmed against that header.

namespace tgcalls {
namespace {

// The two top bits of a 32-bit seq are flags, the remaining bits are the
// packet / message counter (see CounterFromSeq).
constexpr auto kSingleMessagePacketSeqBit = (uint32_t(1) << 31);
constexpr auto kMessageRequiresAckSeqBit = (uint32_t(1) << 30);
constexpr auto kMaxAllowedCounter = std::numeric_limits<uint32_t>::max()
    & ~kSingleMessagePacketSeqBit
    & ~kMessageRequiresAckSeqBit;

static_assert(kMaxAllowedCounter < kSingleMessagePacketSeqBit, "bad");
static_assert(kMaxAllowedCounter < kMessageRequiresAckSeqBit, "bad");

// A serialized ack is a 4-byte seq followed by the one-byte kAckId.
constexpr auto kAckSerializedSize = sizeof(uint32_t) + sizeof(uint8_t);
constexpr auto kNotAckedMessagesLimit = 64 * 1024;
constexpr auto kMaxIncomingPacketSize = 128 * 1024; // don't try decrypting more
constexpr auto kKeepIncomingCountersCount = 64;
constexpr auto kMaxFullPacketSize = 1500; // IP_PACKET_SIZE from webrtc.

// Max seen turn_overhead is around 36.
constexpr auto kMaxOuterPacketSize = kMaxFullPacketSize - 48;

constexpr auto kMaxSignalingPacketSize = 16 * 1024;

// Values passed as the `cause` argument of the send-service callback.
constexpr auto kServiceCauseAcks = 1;
constexpr auto kServiceCauseResend = 2;

// Reserved one-byte message type ids.
static constexpr uint8_t kAckId = uint8_t(-1);
static constexpr uint8_t kEmptyId = uint8_t(-2);

// Appends `seq` to `buffer` in network byte order (4 bytes).
void AppendSeq(rtc::CopyOnWriteBuffer &buffer, uint32_t seq) {
    const auto bytes = rtc::HostToNetwork32(seq);
    buffer.AppendData(reinterpret_cast<const char*>(&bytes), sizeof(bytes));
}

// Writes `seq` in network byte order into the first 4 bytes at `bytes`.
// memcpy is used instead of a pointer cast so that no alignment or aliasing
// assumptions are made about the destination.
void WriteSeq(void *bytes, uint32_t seq) {
    const auto value = rtc::HostToNetwork32(seq);
    std::memcpy(bytes, &value, sizeof(value));
}

// Reads a network byte order seq from the first 4 bytes at `bytes`.
uint32_t ReadSeq(const void *bytes) {
    auto value = uint32_t();
    std::memcpy(&value, bytes, sizeof(value));
    return rtc::NetworkToHost32(value);
}

// Strips both flag bits from `seq`, leaving only the counter.
uint32_t CounterFromSeq(uint32_t seq) {
    return seq & ~kSingleMessagePacketSeqBit & ~kMessageRequiresAckSeqBit;
}

// Logs `message` followed by `additional` at error level and returns
// absl::nullopt, so callers can write `return LogError(...)`.
absl::nullopt_t LogError(
        const char *message,
        const std::string &additional = std::string()) {
    RTC_LOG(LS_ERROR) << "ERROR! " << message << additional;
    return absl::nullopt;
}

} // namespace

// Stores the connection type and key, picks the delay intervals for the type
// and keeps the callback used to request a deferred service send.
// NOTE(review): the callback signature was reconstructed from its call sites
// (`_requestSendService(delay, cause)`) — confirm against the header.
EncryptedConnection::EncryptedConnection(
    Type type,
    const EncryptionKey &key,
    std::function<void(int delayMs, int cause)> requestSendService) :
_type(type),
_key(key),
_delayIntervals(DelayIntervalsByType(type)),
_requestSendService(std::move(requestSendService)) {
    assert(_key.value != nullptr);
}

// Serializes `message`, attaches pending acks / resends where they fit and
// returns the encrypted packet. Returns absl::nullopt when no seq could be
// allocated, when the serialized message does not fit in a packet, or when
// the message was enqueued and the follow-up service packet was not produced.
auto EncryptedConnection::prepareForSending(const Message &message)
-> absl::optional<EncryptedPacket> {
    const auto messageRequiresAck = absl::visit([](const auto &data) {
        return std::decay_t<decltype(data)>::kRequiresAck;
    }, message.data);

    // If message requires ack, then we can't serialize it as a single
    // message packet, because later it may be sent as a part of big packet.
    const auto singleMessagePacket = !haveAdditionalMessages()
        && !messageRequiresAck;
    const auto maybeSeq = computeNextSeq(
        messageRequiresAck,
        singleMessagePacket);
    if (!maybeSeq) {
        return absl::nullopt;
    }
    const auto seq = *maybeSeq;
    auto serialized = SerializeMessageWithSeq(
        message,
        seq,
        singleMessagePacket);
    if (!enoughSpaceInPacket(serialized, 0)) {
        return LogError(
            "Too large packet: ",
            std::to_string(serialized.size()));
    }
    // Keep the bare message (without anything appended below) for resending.
    const auto notYetAckedCopy = messageRequiresAck
        ? serialized
        : rtc::CopyOnWriteBuffer();
    if (!messageRequiresAck) {
        appendAdditionalMessages(serialized);
        return encryptPrepared(serialized);
    }
    const auto type = uint8_t(serialized.cdata()[4]);
    const auto sendEnqueued = !_myNotYetAckedMessages.empty();
    if (sendEnqueued) {
        // All requiring ack messages should always be sent in order within
        // one packet, starting with the least not-yet-acked one.
        // So if we still have those, we send an empty message with all
        // requiring ack messages that will fit in correct order.
        RTC_LOG(LS_INFO) << logHeader()
            << "Enqueue SEND:type" << type
            << "#" << CounterFromSeq(seq);
    } else {
        RTC_LOG(LS_INFO) << logHeader()
            << "Add SEND:type" << type
            << "#" << CounterFromSeq(seq);
        appendAdditionalMessages(serialized);
    }
    _myNotYetAckedMessages.push_back({ notYetAckedCopy, rtc::TimeMillis() });
    if (!sendEnqueued) {
        return encryptPrepared(serialized);
    }
    // Reset the send times so that the service packet below is allowed to
    // carry every queued message immediately.
    for (auto &queued : _myNotYetAckedMessages) {
        queued.lastSent = 0;
    }
    return prepareForSendingService(0);
}

// Builds a packet that carries only an empty message plus pending acks and
// resends. `cause` tells which timer (if any) triggered the call, so the
// matching "timer active" flag can be cleared. Returns absl::nullopt when
// there is nothing to send or no seq could be allocated.
auto EncryptedConnection::prepareForSendingService(int cause)
-> absl::optional<EncryptedPacket> {
    if (cause == kServiceCauseAcks) {
        _sendAcksTimerActive = false;
    } else if (cause == kServiceCauseResend) {
        _resendTimerActive = false;
    }
    if (!haveAdditionalMessages()) {
        return absl::nullopt;
    }
    const auto messageRequiresAck = false;
    const auto singleMessagePacket = false;
    const auto seq = computeNextSeq(messageRequiresAck, singleMessagePacket);
    if (!seq) {
        return absl::nullopt;
    }
    auto serialized = SerializeEmptyMessageWithSeq(*seq);
    assert(enoughSpaceInPacket(serialized, 0));

    RTC_LOG(LS_INFO) << logHeader()
        << "SEND:empty#" << CounterFromSeq(*seq);

    appendAdditionalMessages(serialized);
    return encryptPrepared(serialized);
}

// True when there are not-yet-acked messages or acks waiting to be sent.
bool EncryptedConnection::haveAdditionalMessages() const {
    return !_myNotYetAckedMessages.empty() || !_acksToSendSeqs.empty();
}

// Increments the outgoing counter and combines it with the flag bits.
// Returns absl::nullopt (after logging) when the not-acked queue is full for
// an ack-requiring message or when the counter has reached its maximum.
// NOTE(review): return type reconstructed as optional<uint32_t> from the
// returned expression — confirm against the header.
absl::optional<uint32_t> EncryptedConnection::computeNextSeq(
        bool messageRequiresAck,
        bool singleMessagePacket) {
    if (messageRequiresAck
        && _myNotYetAckedMessages.size() >= kNotAckedMessagesLimit) {
        return LogError("Too many not ACKed messages.");
    } else if (_counter == kMaxAllowedCounter) {
        return LogError("Outgoing packet limit reached.");
    }
    return (++_counter)
        | (singleMessagePacket ? kSingleMessagePacketSeqBit : 0)
        | (messageRequiresAck ? kMessageRequiresAckSeqBit : 0);
}

// Maximum size of a full encrypted packet for this connection type.
size_t EncryptedConnection::packetLimit() const {
    switch (_type) {
    case Type::Signaling:
        return kMaxSignalingPacketSize;
    default:
        return kMaxOuterPacketSize;
    }
}

// True when `amount` more bytes fit after `buffer`, counting the 16 bytes
// that encryptPrepared() puts in front of the payload.
bool EncryptedConnection::enoughSpaceInPacket(
        const rtc::CopyOnWriteBuffer &buffer,
        size_t amount) const {
    const auto limit = packetLimit();
    return (amount < limit)
        && (16 + buffer.size() + amount <= limit);
}

// Appends as many pending acks as fit into `buffer`, removes the appended
// ones from the pending list and logs the ones that had to be skipped.
void EncryptedConnection::appendAcksToSend(rtc::CopyOnWriteBuffer &buffer) {
    auto i = _acksToSendSeqs.begin();
    while ((i != _acksToSendSeqs.end())
        && enoughSpaceInPacket(buffer, kAckSerializedSize)) {
        RTC_LOG(LS_INFO) << logHeader()
            << "Add ACK#" << CounterFromSeq(*i);

        AppendSeq(buffer, *i);
        buffer.AppendData(&kAckId, 1);
        ++i;
    }
    _acksToSendSeqs.erase(_acksToSendSeqs.begin(), i);
    for (const auto seq : _acksToSendSeqs) {
        RTC_LOG(LS_INFO) << logHeader()
            << "Skip ACK#" << CounterFromSeq(seq)
            << " (no space, length: " << kAckSerializedSize
            << ", already: " << buffer.size() << ")";
    }
}

// Sum of the serialized sizes of all not-yet-acked messages.
size_t EncryptedConnection::fullNotAckedLength() const {
    assert(_myNotYetAckedMessages.size() < kNotAckedMessagesLimit);

    auto result = size_t();
    for (const auto &message : _myNotYetAckedMessages) {
        result += message.data.size();
    }
    return result;
}

// Appends pending acks and then, in queue order, every not-yet-acked message
// that is due for resend and still fits. Stops at the first message that is
// not due yet or does not fit, so the order is preserved. Arms the resend
// timer if it is not already active.
void EncryptedConnection::appendAdditionalMessages(
        rtc::CopyOnWriteBuffer &buffer) {
    appendAcksToSend(buffer);

    if (_myNotYetAckedMessages.empty()) {
        return;
    }

    const auto now = rtc::TimeMillis();
    for (auto &resending : _myNotYetAckedMessages) {
        const auto sent = resending.lastSent;
        // lastSent == 0 means "send as soon as possible".
        const auto when = sent
            ? (sent + _delayIntervals.minDelayBeforeMessageResend)
            : 0;

        assert(resending.data.size() >= 5);
        // Read-only access: cdata() is used, as in ackMyMessage(), so the
        // stored buffer is only read, never requested for writing.
        const auto counter = CounterFromSeq(ReadSeq(resending.data.cdata()));
        const auto type = uint8_t(resending.data.cdata()[4]);
        if (when > now) {
            RTC_LOG(LS_INFO) << logHeader()
                << "Skip RESEND:type" << type
                << "#" << counter
                << " (wait " << (when - now) << "ms).";
            break;
        } else if (enoughSpaceInPacket(buffer, resending.data.size())) {
            RTC_LOG(LS_INFO) << logHeader()
                << "Add RESEND:type" << type
                << "#" << counter;
            buffer.AppendData(resending.data);
            resending.lastSent = now;
        } else {
            RTC_LOG(LS_INFO) << logHeader()
                << "Skip RESEND:type" << type
                << "#" << counter
                << " (no space, length: " << resending.data.size()
                << ", already: " << buffer.size() << ")";
            break;
        }
    }
    if (!_resendTimerActive) {
        _resendTimerActive = true;
        _requestSendService(
            _delayIntervals.maxDelayBeforeMessageResend,
            kServiceCauseResend);
    }
}

// Encrypts a fully prepared plaintext buffer. The result holds the counter
// read from the buffer's leading seq and `16 + buffer.size()` bytes: a 16-byte
// message key (bytes 8..23 of the ConcatSHA256 result over a key slice and
// the plaintext) followed by the AesProcessCtr output for the plaintext.
auto EncryptedConnection::encryptPrepared(
        const rtc::CopyOnWriteBuffer &buffer)
-> EncryptedPacket {
    auto result = EncryptedPacket();
    result.counter = CounterFromSeq(ReadSeq(buffer.data()));
    result.bytes.resize(16 + buffer.size());

    // The key offset depends on the direction and on the connection type;
    // handleIncomingPacket() uses the mirrored direction offset.
    const auto x = (_key.isOutgoing ? 0 : 8)
        + (_type == Type::Signaling ? 128 : 0);
    const auto key = _key.value->data();

    const auto msgKeyLarge = ConcatSHA256(
        MemorySpan{ key + 88 + x, 32 },
        MemorySpan{ buffer.data(), buffer.size() });

    const auto msgKey = result.bytes.data();
    std::memcpy(msgKey, msgKeyLarge.data() + 8, 16);

    auto aesKeyIv = PrepareAesKeyIv(key, msgKey, x);

    AesProcessCtr(
        MemorySpan{ buffer.data(), buffer.size() },
        result.bytes.data() + 16,
        std::move(aesKeyIv));

    return result;
}

// Records `incomingCounter` in the sorted list of recently seen counters.
// Returns false when the counter is already in the list or is at least
// kKeepIncomingCountersCount behind the largest one seen; otherwise drops
// entries that fell out of the window, inserts the counter and returns true.
bool EncryptedConnection::registerIncomingCounter(uint32_t incomingCounter) {
    auto &list = _largestIncomingCounters;

    const auto position = std::lower_bound(
        list.begin(),
        list.end(),
        incomingCounter);
    const auto largest = list.empty() ? 0 : list.back();
    if (position != list.end() && *position == incomingCounter) {
        // The packet is in the list already.
        return false;
    } else if (incomingCounter + kKeepIncomingCountersCount <= largest) {
        // The packet is too old.
        return false;
    }
    const auto eraseTill = std::find_if(
        list.begin(),
        list.end(),
        [&](uint32_t counter) {
            return (counter + kKeepIncomingCountersCount > incomingCounter);
        });
    // Remember the insert position as an index: erase() invalidates
    // `position`.
    const auto eraseCount = eraseTill - list.begin();
    const auto positionIndex = (position - list.begin()) - eraseCount;
    list.erase(list.begin(), eraseTill);

    assert(positionIndex >= 0 && positionIndex <= list.size());
    list.insert(list.begin() + positionIndex, incomingCounter);
    return true;
}

// Decrypts and validates an incoming packet: checks the size, decrypts the
// payload, verifies the 16-byte message key against the plaintext hash and
// rejects counters that were already handled. On success the plaintext is
// passed to processPacket().
// NOTE(review): return type reconstructed as optional<DecryptedPacket> from
// processPacket()/appendReceivedMessage() — confirm against the header.
auto EncryptedConnection::handleIncomingPacket(const char *bytes, size_t size)
-> absl::optional<DecryptedPacket> {
    // 16 bytes of message key + 4 bytes of seq + at least 1 byte of type.
    if (size < 21 || size > kMaxIncomingPacketSize) {
        return LogError("Bad incoming packet size: ", std::to_string(size));
    }

    const auto x = (_key.isOutgoing ? 8 : 0)
        + (_type == Type::Signaling ? 128 : 0);
    const auto key = _key.value->data();
    // NOTE(review): cast target reconstructed as const uint8_t* — confirm
    // against the PrepareAesKeyIv / MemorySpan declarations.
    const auto msgKey = reinterpret_cast<const uint8_t*>(bytes);
    const auto encryptedData = msgKey + 16;
    const auto dataSize = size - 16;

    auto aesKeyIv = PrepareAesKeyIv(key, msgKey, x);

    auto decryptionBuffer = rtc::Buffer(dataSize);
    AesProcessCtr(
        MemorySpan{ encryptedData, dataSize },
        decryptionBuffer.data(),
        std::move(aesKeyIv));

    const auto msgKeyLarge = ConcatSHA256(
        MemorySpan{ key + 88 + x, 32 },
        MemorySpan{ decryptionBuffer.data(), decryptionBuffer.size() });
    if (std::memcmp(msgKeyLarge.data() + 8, msgKey, 16)) {
        return LogError("Bad incoming data hash.");
    }

    const auto incomingSeq = ReadSeq(decryptionBuffer.data());
    const auto incomingCounter = CounterFromSeq(incomingSeq);
    if (!registerIncomingCounter(incomingCounter)) {
        // We've received that packet already.
        return LogError(
            "Already handled packet received.",
            std::to_string(incomingCounter));
    }
    return processPacket(decryptionBuffer, incomingSeq);
}

// Parses every message in a decrypted packet. `fullBuffer` starts with the
// packet seq (already read into `packetSeq`); each following message is a
// one-byte type plus payload, and every additional message is preceded by
// its own 4-byte seq. Acks and empty messages are consumed here, real
// messages are collected into the result. Afterwards an ack send is
// requested: immediately if a new ack-requiring message arrived, otherwise
// through the ack timer.
auto EncryptedConnection::processPacket(
    const rtc::Buffer &fullBuffer,
    uint32_t packetSeq)
-> absl::optional<DecryptedPacket> {
    assert(fullBuffer.size() >= 5);

    auto additionalMessage = false;
    auto firstMessageRequiringAck = true;
    auto newRequiringAckReceived = false;

    auto currentSeq = packetSeq;
    auto currentCounter = CounterFromSeq(currentSeq);
    // NOTE(review): cast target reconstructed as const char* — confirm
    // against the rtc::ByteBufferReader constructor in use.
    rtc::ByteBufferReader reader(
        reinterpret_cast<const char*>(fullBuffer.data() + 4), // Skip seq.
        fullBuffer.size() - 4);

    auto result = absl::optional<DecryptedPacket>();
    while (true) {
        const auto type = uint8_t(*reader.Data());
        const auto singleMessagePacket
            = ((currentSeq & kSingleMessagePacketSeqBit) != 0);
        if (singleMessagePacket && additionalMessage) {
            return LogError(
                "Single message packet bit in not first message.");
        }

        if (type == kEmptyId) {
            RTC_LOG(LS_INFO) << logHeader()
                << "Got RECV:empty" << "#" << currentCounter;
            reader.Consume(1);
        } else if (type == kAckId) {
            ackMyMessage(currentSeq);
            reader.Consume(1);
        } else if (auto message = DeserializeMessage(
                reader,
                singleMessagePacket)) {
            const auto messageRequiresAck
                = ((currentSeq & kMessageRequiresAckSeqBit) != 0);
            // Ack-requiring messages are deduplicated through the sent-acks
            // list; other additional messages through the incoming counters
            // list (the first message was registered by the caller).
            const auto skipMessage = messageRequiresAck
                ? !registerSentAck(currentCounter, firstMessageRequiringAck)
                : (additionalMessage
                    && !registerIncomingCounter(currentCounter));
            if (messageRequiresAck) {
                firstMessageRequiringAck = false;
                if (!skipMessage) {
                    newRequiringAckReceived = true;
                }
                sendAckPostponed(currentSeq);
                RTC_LOG(LS_INFO) << logHeader()
                    << (skipMessage
                        ? "Repeated RECV:type"
                        : "Got RECV:type")
                    << type << "#" << currentCounter;
            }
            if (!skipMessage) {
                appendReceivedMessage(
                    result,
                    std::move(*message),
                    currentSeq);
            }
        } else {
            return LogError(
                "Could not parse message from packet, type: ",
                std::to_string(type));
        }
        if (!reader.Length()) {
            break;
        } else if (singleMessagePacket) {
            return LogError("Single message didn't fill the entire packet.");
        } else if (reader.Length() < 5) {
            return LogError(
                "Bad remaining data size: ",
                std::to_string(reader.Length()));
        }
        // At least 5 bytes remain (checked above), so the read succeeds.
        const auto success = reader.ReadUInt32(&currentSeq);
        assert(success);
        (void)success;
        currentCounter = CounterFromSeq(currentSeq);

        additionalMessage = true;
    }

    if (!_acksToSendSeqs.empty()) {
        if (newRequiringAckReceived) {
            _requestSendService(0, 0);
        } else if (!_sendAcksTimerActive) {
            _sendAcksTimerActive = true;
            _requestSendService(
                _delayIntervals.maxDelayBeforeAckResend,
                kServiceCauseAcks);
        }
    }

    return result;
}

// Stores a decoded message in `to`: the first one becomes the packet's main
// message, later ones are pushed to its `additional` list.
void EncryptedConnection::appendReceivedMessage(
        absl::optional<DecryptedPacket> &to,
        Message &&message,
        uint32_t incomingSeq) {
    auto decrypted = DecryptedMessage{
        std::move(message),
        CounterFromSeq(incomingSeq)
    };
    if (to) {
        to->additional.push_back(std::move(decrypted));
    } else {
        to = DecryptedPacket{ std::move(decrypted) };
    }
}

// Log line prefix identifying the connection type.
const char *EncryptedConnection::logHeader() const {
    return (_type == Type::Signaling) ? "(signaling) " : "(transport) ";
}

// Records `counter` in the sorted list of counters of received ack-requiring
// messages and returns true when it was not there yet. For the first such
// message in a packet every smaller counter is dropped from the list first.
bool EncryptedConnection::registerSentAck(
        uint32_t counter,
        bool firstInPacket) {
    auto &list = _acksSentCounters;

    const auto position = std::lower_bound(list.begin(), list.end(), counter);
    const auto already = (position != list.end()) && (*position == counter);

    if (firstInPacket) {
        list.erase(list.begin(), position);
        if (!already) {
            // Everything smaller was just erased, so the front is the
            // sorted position for `counter`.
            list.insert(list.begin(), counter);
        }
    } else if (!already) {
        list.insert(position, counter);
    }
    return !already;
}

// Queues an ack for `incomingSeq` unless one is already queued.
void EncryptedConnection::sendAckPostponed(uint32_t incomingSeq) {
    auto &list = _acksToSendSeqs;
    const auto already = std::find(list.begin(), list.end(), incomingSeq);
    if (already == list.end()) {
        list.push_back(incomingSeq);
    }
}

// Handles an incoming ack: removes the not-yet-acked message whose seq equals
// `seq` (if any) and logs whether the ack matched a queued message.
void EncryptedConnection::ackMyMessage(uint32_t seq) {
    auto type = uint8_t(0);
    auto &list = _myNotYetAckedMessages;
    for (auto i = list.begin(), e = list.end(); i != e; ++i) {
        assert(i->data.size() >= 5);
        if (ReadSeq(i->data.cdata()) == seq) {
            type = uint8_t(i->data.cdata()[4]);
            list.erase(i);
            break;
        }
    }
    RTC_LOG(LS_INFO) << logHeader()
        << (type
            ? "Got ACK:type" + std::to_string(type) + "#"
            : "Repeated ACK#")
        << CounterFromSeq(seq);
}

// Resend / ack delay settings for the given connection type.
auto EncryptedConnection::DelayIntervalsByType(Type type) -> DelayIntervals {
    auto result = DelayIntervals();
    const auto signaling = (type == Type::Signaling);

    // Don't resend faster than min delay even if we have a packet we can attach to.
    result.minDelayBeforeMessageResend = signaling ? 3000 : 300;

    // When max delay elapsed we resend anyway, in a dedicated packet.
    result.maxDelayBeforeMessageResend = signaling ? 5000 : 1000;
    result.maxDelayBeforeAckResend = signaling ? 5000 : 1000;

    return result;
}

// Serializes a 5-byte empty message: the seq followed by kEmptyId.
rtc::CopyOnWriteBuffer EncryptedConnection::SerializeEmptyMessageWithSeq(
        uint32_t seq) {
    auto result = rtc::CopyOnWriteBuffer(5);
    const auto bytes = result.data();
    WriteSeq(bytes, seq);
    bytes[4] = kEmptyId;
    return result;
}

} // namespace tgcalls