/*
 *  Copyright 2018 The WebRTC Project Authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "p2p/base/mdns_message.h"

#include <stddef.h>
#include <stdint.h>

#include <string>
#include <vector>

#include "rtc_base/logging.h"
#include "rtc_base/net_helpers.h"
#include "rtc_base/string_encode.h"

namespace webrtc {

namespace {

// RFC 1035, Section 4.1.1.
//
// QR bit.
constexpr uint16_t kMdnsFlagMaskQueryOrResponse = 0x8000;
// AA bit.
constexpr uint16_t kMdnsFlagMaskAuthoritative = 0x0400;
// RFC 1035, Section 4.1.2, QCLASS and RFC 6762, Section 18.12, repurposing of
// top bit of QCLASS as the unicast response bit.
constexpr uint16_t kMdnsQClassMaskUnicastResponse = 0x8000;
constexpr size_t kMdnsHeaderSizeBytes = 12;

// Wire values used for the TYPE and CLASS fields of a section entry.
constexpr uint16_t kMdnsTypeA = 1;
constexpr uint16_t kMdnsTypeAAAA = 28;
constexpr uint16_t kMdnsClassIN = 1;

// Fixed RDATA sizes of the two supported record types.
constexpr size_t kIPv4AddressSizeBytes = 4;
constexpr size_t kIPv6AddressSizeBytes = 16;

// Reads a (possibly compressed) domain name starting at the current offset of
// |buf| and appends it to |name| as a sequence of "label." pieces. Note that
// |name| is appended to, not overwritten, which is what lets the recursive
// call below continue a name after following a compression pointer.
//
// Returns false if the buffer runs out, a label is longer than 63 octets, or
// a compression pointer does not point strictly backward to a position after
// the header.
//
// NOTE(review): one recursion level is used per compression pointer. The
// strictly-backward rule guarantees termination, but the depth is bounded only
// by the message size.
bool ReadDomainName(MessageBufferReader* buf, std::string* name) {
  const size_t name_start_pos = buf->CurrentOffset();
  uint8_t label_length;
  if (!buf->ReadUInt8(&label_length)) {
    return false;
  }
  // RFC 1035, Section 4.1.4.
  //
  // If the first two bits of the length octet are ones, the name is compressed
  // and the rest six bits with the next octet denotes its position in the
  // message by the offset from the start of the message.
  auto is_pointer = [](uint8_t octet) {
    return (octet & 0x80) && (octet & 0x40);
  };
  while (label_length && !is_pointer(label_length)) {
    // RFC 1035, Section 2.3.1, labels are restricted to 63 octets or less.
    if (label_length > 63) {
      return false;
    }
    std::string label;
    if (!buf->ReadString(&label, label_length)) {
      return false;
    }
    (*name) += label + ".";
    if (!buf->ReadUInt8(&label_length)) {
      return false;
    }
  }
  if (is_pointer(label_length)) {
    uint8_t next_octet;
    if (!buf->ReadUInt8(&next_octet)) {
      return false;
    }
    const size_t pos_jump_to = ((label_length & 0x3f) << 8) | next_octet;
    // A legitimate pointer only refers to a prior occurrence of the same name,
    // and we should only move strictly backward to a prior name field after the
    // header.
    if (pos_jump_to >= name_start_pos || pos_jump_to < kMdnsHeaderSizeBytes) {
      return false;
    }
    // Continue parsing from the pointed-to offset using a separate reader over
    // the same message, so that |buf| stays positioned right after the pointer.
    MessageBufferReader new_buf(buf->MessageData(), buf->MessageLength());
    if (!new_buf.Consume(pos_jump_to)) {
      return false;
    }
    return ReadDomainName(&new_buf, name);
  }
  return true;
}

// Writes |name| to |buf| in uncompressed form: each '.'-separated label is
// emitted as a length octet followed by the label bytes, and the name is
// terminated by a zero-length octet.
void WriteDomainName(rtc::ByteBufferWriter* buf, const std::string& name) {
  std::vector<std::string> labels;
  rtc::tokenize(name, '.', &labels);
  for (const auto& label : labels) {
    buf->WriteUInt8(label.length());
    buf->WriteString(label);
  }
  buf->WriteUInt8(0);
}

}  // namespace

// Clears the QR bit for a query and sets it for a response.
void MdnsHeader::SetQueryOrResponse(bool is_query) {
  if (is_query) {
    flags &= ~kMdnsFlagMaskQueryOrResponse;
  } else {
    flags |= kMdnsFlagMaskQueryOrResponse;
  }
}

// Sets or clears the AA bit.
void MdnsHeader::SetAuthoritative(bool is_authoritative) {
  if (is_authoritative) {
    flags |= kMdnsFlagMaskAuthoritative;
  } else {
    flags &= ~kMdnsFlagMaskAuthoritative;
  }
}

// Returns true if the AA bit is set.
bool MdnsHeader::IsAuthoritative() const {
  return flags & kMdnsFlagMaskAuthoritative;
}

// Reads the six 16-bit header fields from |buf| in wire order. Logs an error
// and returns false if any of them cannot be read.
bool MdnsHeader::Read(MessageBufferReader* buf) {
  if (!buf->ReadUInt16(&id) || !buf->ReadUInt16(&flags) ||
      !buf->ReadUInt16(&qdcount) || !buf->ReadUInt16(&ancount) ||
      !buf->ReadUInt16(&nscount) || !buf->ReadUInt16(&arcount)) {
    RTC_LOG(LS_ERROR) << "Invalid mDNS header.";
    return false;
  }
  return true;
}

// Writes the six 16-bit header fields to |buf| in wire order.
void MdnsHeader::Write(rtc::ByteBufferWriter* buf) const {
  buf->WriteUInt16(id);
  buf->WriteUInt16(flags);
  buf->WriteUInt16(qdcount);
  buf->WriteUInt16(ancount);
  buf->WriteUInt16(nscount);
  buf->WriteUInt16(arcount);
}

// Returns true if the QR bit is clear.
bool MdnsHeader::IsQuery() const {
  return !(flags & kMdnsFlagMaskQueryOrResponse);
}

MdnsSectionEntry::MdnsSectionEntry() = default;
MdnsSectionEntry::~MdnsSectionEntry() = default;
MdnsSectionEntry::MdnsSectionEntry(const MdnsSectionEntry& other) = default;

// Stores the wire value for |type|. Only kA and kAAAA are accepted; any other
// value hits RTC_NOTREACHED and leaves the stored type unchanged.
void MdnsSectionEntry::SetType(SectionEntryType type) {
  switch (type) {
    case SectionEntryType::kA:
      type_ = kMdnsTypeA;
      return;
    case SectionEntryType::kAAAA:
      type_ = kMdnsTypeAAAA;
      return;
    default:
      RTC_NOTREACHED();
  }
}

// Maps the stored wire type to SectionEntryType; anything other than A (1) or
// AAAA (28) is reported as kUnsupported.
SectionEntryType MdnsSectionEntry::GetType() const {
  switch (type_) {
    case kMdnsTypeA:
      return SectionEntryType::kA;
    case kMdnsTypeAAAA:
      return SectionEntryType::kAAAA;
    default:
      return SectionEntryType::kUnsupported;
  }
}

// Stores the wire value for |cls|. Only kIN is accepted; any other value hits
// RTC_NOTREACHED and leaves the stored class unchanged.
void MdnsSectionEntry::SetClass(SectionEntryClass cls) {
  switch (cls) {
    case SectionEntryClass::kIN:
      class_ = kMdnsClassIN;
      return;
    default:
      RTC_NOTREACHED();
  }
}

// Maps the stored wire class to SectionEntryClass.
//
// NOTE(review): the full 16-bit field is compared against 1, so a value whose
// top bit is set (see MdnsQuestion::SetUnicastResponse) is reported as
// kUnsupported. Confirm whether the top bit should be masked off here.
SectionEntryClass MdnsSectionEntry::GetClass() const {
  switch (class_) {
    case kMdnsClassIN:
      return SectionEntryClass::kIN;
    default:
      return SectionEntryClass::kUnsupported;
  }
}

MdnsQuestion::MdnsQuestion() = default;
MdnsQuestion::MdnsQuestion(const MdnsQuestion& other) = default;
MdnsQuestion::~MdnsQuestion() = default;

// Parses one question entry (name, type, class) from |buf|. Logs an error and
// returns false on malformed input.
bool MdnsQuestion::Read(MessageBufferReader* buf) {
  // ReadDomainName appends to its output, so start from an empty name to avoid
  // concatenating onto a name left over from earlier use of this object.
  name_.clear();
  if (!ReadDomainName(buf, &name_)) {
    RTC_LOG(LS_ERROR) << "Invalid name.";
    return false;
  }
  if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_)) {
    RTC_LOG(LS_ERROR) << "Invalid type and class.";
    return false;
  }
  return true;
}

// Serializes the question (name, type, class) to |buf|. Always returns true.
bool MdnsQuestion::Write(rtc::ByteBufferWriter* buf) const {
  WriteDomainName(buf, name_);
  buf->WriteUInt16(type_);
  buf->WriteUInt16(class_);
  return true;
}

// Sets or clears the unicast response bit (top bit of the class field).
void MdnsQuestion::SetUnicastResponse(bool should_unicast) {
  if (should_unicast) {
    class_ |= kMdnsQClassMaskUnicastResponse;
  } else {
    class_ &= ~kMdnsQClassMaskUnicastResponse;
  }
}

// Returns true if the unicast response bit is set in the class field.
bool MdnsQuestion::ShouldUnicastResponse() const {
  return class_ & kMdnsQClassMaskUnicastResponse;
}

MdnsResourceRecord::MdnsResourceRecord() = default;
MdnsResourceRecord::MdnsResourceRecord(const MdnsResourceRecord& other) =
    default;
MdnsResourceRecord::~MdnsResourceRecord() = default;

// Parses one resource record (name, type, class, TTL, RDLENGTH, RDATA) from
// |buf|. Returns false if the input is truncated, the record type is not A or
// AAAA, or RDLENGTH does not match the fixed RDATA size of the record type.
bool MdnsResourceRecord::Read(MessageBufferReader* buf) {
  // ReadDomainName appends to its output, so reset the name first. The record
  // data is reset as well so that nothing from an earlier parse can remain.
  name_.clear();
  rdata_.clear();
  if (!ReadDomainName(buf, &name_)) {
    return false;
  }
  if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_) ||
      !buf->ReadUInt32(&ttl_seconds_) || !buf->ReadUInt16(&rdlength_)) {
    return false;
  }

  // The RDATA readers below consume a fixed number of bytes. If RDLENGTH
  // disagrees with that number the record is malformed and every subsequent
  // field in the message would be read at the wrong offset, so reject it.
  switch (GetType()) {
    case SectionEntryType::kA:
      if (rdlength_ != kIPv4AddressSizeBytes) {
        return false;
      }
      return ReadARData(buf);
    case SectionEntryType::kAAAA:
      if (rdlength_ != kIPv6AddressSizeBytes) {
        return false;
      }
      return ReadQuadARData(buf);
    case SectionEntryType::kUnsupported:
      return false;
    default:
      RTC_NOTREACHED();
  }
  return false;
}

bool MdnsResourceRecord::ReadARData(MessageBufferReader* buf) {
  // A RDATA contains a 32-bit IPv4 address.
  return buf->ReadString(&rdata_, kIPv4AddressSizeBytes);
}

bool MdnsResourceRecord::ReadQuadARData(MessageBufferReader* buf) {
  // AAAA RDATA contains a 128-bit IPv6 address.
  return buf->ReadString(&rdata_, kIPv6AddressSizeBytes);
}

// Serializes the record to |buf|. Returns false for an unsupported record
// type; in that case the name and fixed fields have already been written.
bool MdnsResourceRecord::Write(rtc::ByteBufferWriter* buf) const {
  WriteDomainName(buf, name_);
  buf->WriteUInt16(type_);
  buf->WriteUInt16(class_);
  buf->WriteUInt32(ttl_seconds_);
  buf->WriteUInt16(rdlength_);
  switch (GetType()) {
    case SectionEntryType::kA:
      WriteARData(buf);
      return true;
    case SectionEntryType::kAAAA:
      WriteQuadARData(buf);
      return true;
    case SectionEntryType::kUnsupported:
      return false;
    default:
      RTC_NOTREACHED();
  }
  return true;
}

// Writes the raw record data of an A record.
void MdnsResourceRecord::WriteARData(rtc::ByteBufferWriter* buf) const {
  buf->WriteString(rdata_);
}

// Writes the raw record data of an AAAA record.
void MdnsResourceRecord::WriteQuadARData(rtc::ByteBufferWriter* buf) const {
  buf->WriteString(rdata_);
}

// Stores |address| as raw network-order bytes in the record data and updates
// the record data length (4 for IPv4, 16 for IPv6). Returns false if the
// address family is neither AF_INET nor AF_INET6 or the conversion fails.
// The record type is not modified.
bool MdnsResourceRecord::SetIPAddressInRecordData(
    const rtc::IPAddress& address) {
  const int af = address.family();
  if (af != AF_INET && af != AF_INET6) {
    return false;
  }
  char out[kIPv6AddressSizeBytes] = {0};
  if (!rtc::inet_pton(af, address.ToString().c_str(), out)) {
    return false;
  }
  rdlength_ =
      (af == AF_INET) ? kIPv4AddressSizeBytes : kIPv6AddressSizeBytes;
  rdata_ = std::string(out, rdlength_);
  return true;
}

// Converts the record data to an IP address according to the record type.
// Returns false if the type is not A or AAAA, or if the stored data does not
// have exactly the size required by that type.
bool MdnsResourceRecord::GetIPAddressFromRecordData(
    rtc::IPAddress* address) const {
  int af;
  size_t expected_size;
  switch (GetType()) {
    case SectionEntryType::kA:
      af = AF_INET;
      expected_size = kIPv4AddressSizeBytes;
      break;
    case SectionEntryType::kAAAA:
      af = AF_INET6;
      expected_size = kIPv6AddressSizeBytes;
      break;
    default:
      return false;
  }
  // inet_ntop reads 4 bytes for AF_INET and 16 bytes for AF_INET6 from the
  // source buffer, so the data must match the family implied by the type;
  // accepting a 4-byte buffer for an AAAA record would read out of bounds.
  if (rdata_.size() != expected_size) {
    return false;
  }
  char out[INET6_ADDRSTRLEN] = {0};
  if (!rtc::inet_ntop(af, rdata_.data(), out, sizeof(out))) {
    return false;
  }
  return rtc::IPFromString(std::string(out), address);
}

MdnsMessage::MdnsMessage() = default;
MdnsMessage::~MdnsMessage() = default;

// Parses a complete message from |buf|, which must be positioned at offset 0:
// the header first, then the question, answer, authority and additional
// sections with the entry counts given by the header. Returns false on the
// first entry that fails to parse; sections then hold only the entries parsed
// successfully up to that point.
bool MdnsMessage::Read(MessageBufferReader* buf) {
  RTC_DCHECK_EQ(0u, buf->CurrentOffset());
  if (!header_.Read(buf)) {
    return false;
  }

  // The counts come straight from the wire, so the sections are grown one
  // parsed entry at a time instead of being resized up front. This keeps the
  // allocation proportional to the data actually present and drops any
  // entries left over from a previous parse.
  auto read_question = [&buf](std::vector<MdnsQuestion>* section,
                              uint16_t count) {
    section->clear();
    for (size_t i = 0; i < count; ++i) {
      section->emplace_back();
      if (!section->back().Read(buf)) {
        section->pop_back();
        return false;
      }
    }
    return true;
  };
  auto read_rr = [&buf](std::vector<MdnsResourceRecord>* section,
                        uint16_t count) {
    section->clear();
    for (size_t i = 0; i < count; ++i) {
      section->emplace_back();
      if (!section->back().Read(buf)) {
        section->pop_back();
        return false;
      }
    }
    return true;
  };

  if (!read_question(&question_section_, header_.qdcount) ||
      !read_rr(&answer_section_, header_.ancount) ||
      !read_rr(&authority_section_, header_.nscount) ||
      !read_rr(&additional_section_, header_.arcount)) {
    return false;
  }
  return true;
}

// Serializes the header followed by the question, answer, authority and
// additional sections. Returns false as soon as an entry fails to write; |buf|
// may already contain partial output in that case.
bool MdnsMessage::Write(rtc::ByteBufferWriter* buf) const {
  header_.Write(buf);

  auto write_rr = [&buf](const std::vector<MdnsResourceRecord>& section) {
    for (const auto& rr : section) {
      if (!rr.Write(buf)) {
        return false;
      }
    }
    return true;
  };

  for (const auto& question : question_section_) {
    if (!question.Write(buf)) {
      return false;
    }
  }
  if (!write_rr(answer_section_) || !write_rr(authority_section_) ||
      !write_rr(additional_section_)) {
    return false;
  }

  return true;
}

// Returns true if any question in the message has the unicast response bit
// set.
bool MdnsMessage::ShouldUnicastResponse() const {
  bool should_unicast = false;
  for (const auto& question : question_section_) {
    should_unicast |= question.ShouldUnicastResponse();
  }
  return should_unicast;
}

// Appends |question| to the question section and updates the header count.
void MdnsMessage::AddQuestion(const MdnsQuestion& question) {
  question_section_.push_back(question);
  header_.qdcount = question_section_.size();
}

// Appends |answer| to the answer section and updates the header count.
void MdnsMessage::AddAnswerRecord(const MdnsResourceRecord& answer) {
  answer_section_.push_back(answer);
  header_.ancount = answer_section_.size();
}

}  // namespace webrtc