From a492d04a31d675d38c115cf248a6b4416a386811 Mon Sep 17 00:00:00 2001 From: Abtin Keshavarzian Date: Tue, 19 Dec 2023 10:45:46 -0800 Subject: [PATCH] [dns-types] add template variants of reading DNS names or labels (#9720) This commit introduces template variants for `ReadName()` and other related methods, allowing flexible reading of DNS names and labels from messages into a given array buffer. This simplifies the code and improves readability. Additionally, this commit defines new types, `Dns::Name::Buffer` and `Dns::Name::LabelBuffer`, as arrays of char with fixed sizes to hold DNS names and labels, respectively. --- src/core/net/dns_types.hpp | 125 +++++++++++++++++++++++++++++++++ src/core/net/dnssd_server.cpp | 58 +++++++-------- src/core/net/dnssd_server.hpp | 10 ++- src/core/net/srp_server.cpp | 68 +++++++++--------- src/core/net/srp_server.hpp | 6 ++ tests/unit/test_dns.cpp | 74 +++++++++---------- tests/unit/test_dns_client.cpp | 22 +++--- 7 files changed, 245 insertions(+), 118 deletions(-) diff --git a/src/core/net/dns_types.hpp b/src/core/net/dns_types.hpp index 9f8a93093a1..a8fb88a158e 100644 --- a/src/core/net/dns_types.hpp +++ b/src/core/net/dns_types.hpp @@ -511,8 +511,24 @@ class Name : public Clearable */ static constexpr uint8_t kMaxLabelLength = kMaxLabelSize - 1; + /** + * Dot character separating labels in a name. + * + */ static constexpr char kLabelSeparatorChar = '.'; + /** + * Represents a string buffer (with `kMaxNameSize`) intended to hold a DNS name. + * + */ + typedef char Buffer[kMaxNameSize]; + + /** + * Represents a string buffer (with `kMaxLabelSize`) intended to hold a DNS label. + * + */ + typedef char LabelBuffer[kMaxLabelSize]; + /** * Represents the name type. * @@ -821,6 +837,35 @@ class Name : public Clearable */ static Error ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize); + /** + * Reads a full name from a message. + * + * On successful read, the read name follows "...", i.e., a sequence of labels separated by + * dot '.' character. The read name will ALWAYS end with a dot. + * + * Verifies that the labels after the first label in message do not contain any dot character. If they do, + * returns `kErrorParse`. + * + * @tparam kNameBufferSize Size of the string buffer array. + * + * @param[in] aMessage The message to read the name from. `aMessage.GetOffset()` MUST point to + * the start of DNS header (this is used to handle compressed names). + * @param[in,out] aOffset On input, the offset in @p aMessage pointing to the start of the name field. + * On exit (when parsed successfully), @p aOffset is updated to point to the byte + * after the end of name field. + * @param[out] aNameBuffer Reference to a name string buffer to output the read name. + * + * @retval kErrorNone Successfully read the name, @p aNameBuffer and @p Offset are updated. + * @retval kErrorParse Name could not be parsed (invalid format). + * @retval kErrorNoBufs Name could not fit in @p aNameBuffer. + * + */ + template + static Error ReadName(const Message &aMessage, uint16_t &aOffset, char (&aNameBuffer)[kNameBufferSize]) + { + return ReadName(aMessage, aOffset, aNameBuffer, kNameBufferSize); + } + /** * Compares a single name label from a message with a given label string. * @@ -946,6 +991,30 @@ class Name : public Clearable */ static Error ExtractLabels(const char *aName, const char *aSuffixName, char *aLabels, uint16_t aLabelsSize); + /** + * Extracts label(s) from a full name by checking that it contains a given suffix name (e.g., suffix name can be + * a domain name) and removing it. + * + * Both @p aName and @p aSuffixName must be full DNS name and end with ('.'), otherwise the behavior of this method + * is undefined. + * + * @tparam kLabelsBufferSize Size of the buffer string. + * + * @param[in] aName The full name to extract labels from. + * @param[in] aSuffixName The suffix name (e.g. can be domain name). + * @param[out] aLabelsBuffer A buffer to copy the extracted labels. + * + * @retval kErrorNone Successfully extracted the labels, @p aLabels is updated. + * @retval kErrorParse @p aName does not contain @p aSuffixName. + * @retval kErrorNoBufs Could not fit the labels in @p aLabels. + * + */ + template + static Error ExtractLabels(const char *aName, const char *aSuffixName, char (&aLabels)[kLabelsBufferSize]) + { + return ExtractLabels(aName, aSuffixName, aLabels, kLabelsBufferSize); + } + /** * Tests if a DNS name is a sub-domain of a given domain. * @@ -1660,6 +1729,36 @@ class PtrRecord : public ResourceRecord char *aNameBuffer, uint16_t aNameBufferSize) const; + /** + * Parses and reads the PTR name from a message. + * + * This is a template variation of the previous method with name and label buffer sizes as template parameters. + * + * @tparam kLabelBufferSize The size of label buffer. + * @tparam kNameBufferSize The size of name buffer. + * + * @param[in] aMessage The message to read from. `aMessage.GetOffset()` MUST point to the start of + * DNS header. + * @param[in,out] aOffset On input, the offset in @p aMessage to the start of PTR name field. + * On exit, when successfully read, @p aOffset is updated to point to the byte + * after the entire PTR record (skipping over the record). + * @param[out] aLabelBuffer A char array buffer to output the first label as a null-terminated C string. + * @param[out] aNameBuffer A char array to output the rest of name (after first label). + * + * @retval kErrorNone The PTR name was read successfully. @p aOffset, @aLabelBuffer and @aNameBuffer are updated. + * @retval kErrorParse The PTR record in @p aMessage could not be parsed (invalid format). + * @retval kErrorNoBufs Either label or name could not fit in the related given buffers. + * + */ + template + Error ReadPtrName(const Message &aMessage, + uint16_t &aOffset, + char (&aLabelBuffer)[kLabelBufferSize], + char (&aNameBuffer)[kNameBufferSize]) const + { + return ReadPtrName(aMessage, aOffset, aLabelBuffer, kLabelBufferSize, aNameBuffer, kNameBufferSize); + } + } OT_TOOL_PACKED_END; /** @@ -1868,6 +1967,32 @@ class SrvRecord : public ResourceRecord /* aSkipRecord */ true); } + /** + * Parses and reads the SRV target host name from a message. + * + * Also verifies that the SRV record is well-formed (e.g., the record data length `GetLength()` matches + * the SRV encoded name). + * + * @tparam kNameBufferSize Size of the name buffer. + * + * @param[in] aMessage The message to read from. `aMessage.GetOffset()` MUST point to the start of + * DNS header. + * @param[in,out] aOffset On input, the offset in @p aMessage to start of target host name field. + * On exit when successfully read, @p aOffset is updated to point to the byte + * after the entire SRV record (skipping over the record). + * @param[out] aNameBuffer A char array to output the read name as a null-terminated C string + * + * @retval kErrorNone The host name was read successfully. @p aOffset and @p aNameBuffer are updated. + * @retval kErrorParse The SRV record in @p aMessage could not be parsed (invalid format). + * @retval kErrorNoBufs Name could not fit in @p aNameBuffer. + * + */ + template + Error ReadTargetHostName(const Message &aMessage, uint16_t &aOffset, char (&aNameBuffer)[kNameBufferSize]) const + { + return ReadTargetHostName(aMessage, aOffset, aNameBuffer, kNameBufferSize); + } + private: uint16_t mPriority; uint16_t mWeight; diff --git a/src/core/net/dnssd_server.cpp b/src/core/net/dnssd_server.cpp index ea59c08d515..33aabbd66b0 100644 --- a/src/core/net/dnssd_server.cpp +++ b/src/core/net/dnssd_server.cpp @@ -412,12 +412,12 @@ Error Server::Response::ParseQueryName(void) // Parses and validates the query name and updates // the name compression offsets. - Error error = kErrorNone; - DnsName name; - uint16_t offset; + Error error = kErrorNone; + Name::Buffer name; + uint16_t offset; offset = sizeof(Header); - SuccessOrExit(error = Name::ReadName(*mMessage, offset, name, sizeof(name))); + SuccessOrExit(error = Name::ReadName(*mMessage, offset, name)); switch (mType) { @@ -446,9 +446,9 @@ Error Server::Response::ParseQueryName(void) while (true) { - DnsLabel label; - uint8_t labelLength = sizeof(label); - uint16_t comapreOffset; + Name::LabelBuffer label; + uint8_t labelLength = sizeof(label); + uint16_t comapreOffset; SuccessOrExit(error = Name::ReadLabel(*mMessage, offset, label, labelLength)); @@ -472,7 +472,7 @@ Error Server::Response::ParseQueryName(void) return error; } -void Server::Response::ReadQueryName(DnsName &aName) const { Server::ReadQueryName(*mMessage, aName); } +void Server::Response::ReadQueryName(Name::Buffer &aName) const { Server::ReadQueryName(*mMessage, aName); } bool Server::Response::QueryNameMatches(const char *aName) const { return Server::QueryNameMatches(*mMessage, aName); } @@ -526,12 +526,12 @@ Error Server::Response::AppendSrvRecord(const char *aHostName, uint16_t aWeight, uint16_t aPort) { - Error error = kErrorNone; - SrvRecord srvRecord; - uint16_t recordOffset; - DnsName hostLabels; + Error error = kErrorNone; + SrvRecord srvRecord; + uint16_t recordOffset; + Name::Buffer hostLabels; - SuccessOrExit(error = Name::ExtractLabels(aHostName, kDefaultDomainName, hostLabels, sizeof(hostLabels))); + SuccessOrExit(error = Name::ExtractLabels(aHostName, kDefaultDomainName, hostLabels)); srvRecord.Init(); srvRecord.SetTtl(aTtl); @@ -674,7 +674,7 @@ uint8_t Server::GetNameLength(const char *aName) #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO) void Server::Response::Log(void) const { - DnsName name; + Name::Buffer name; ReadQueryName(name); LogInfo("%s query for '%s'", QueryTypeToString(mType), name); @@ -830,16 +830,16 @@ bool Server::Response::QueryNameMatchesService(const Srp::Server::Service &aServ #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE bool Server::ShouldForwardToUpstream(const Request &aRequest) { - bool shouldForward = false; - uint16_t readOffset; - DnsName name; + bool shouldForward = false; + uint16_t readOffset; + Name::Buffer name; VerifyOrExit(aRequest.mHeader.IsRecursionDesiredFlagSet()); readOffset = sizeof(Header); for (uint16_t i = 0; i < aRequest.mHeader.GetQuestionCount(); i++) { - SuccessOrExit(Name::ReadName(*aRequest.mMessage, readOffset, name, sizeof(name))); + SuccessOrExit(Name::ReadName(*aRequest.mMessage, readOffset, name)); readOffset += sizeof(Question); VerifyOrExit(!Name::IsSubDomainOf(name, kDefaultDomainName)); @@ -915,7 +915,7 @@ void Server::ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessag { ProxyQuery *query; ProxyQueryInfo info; - DnsName name; + Name::Buffer name; VerifyOrExit(mQuerySubscribe.IsSet()); @@ -950,11 +950,11 @@ void Server::ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessag return; } -void Server::ReadQueryName(const Message &aQuery, DnsName &aName) +void Server::ReadQueryName(const Message &aQuery, Name::Buffer &aName) { uint16_t offset = sizeof(Header); - IgnoreError(Name::ReadName(aQuery, offset, aName, sizeof(aName))); + IgnoreError(Name::ReadName(aQuery, offset, aName)); } bool Server::QueryNameMatches(const Message &aQuery, const char *aName) @@ -979,20 +979,20 @@ void Server::ProxyQueryInfo::UpdateIn(ProxyQuery &aQuery) const aQuery.Write(aQuery.GetLength() - sizeof(ProxyQueryInfo), *this); } -Error Server::Response::ExtractServiceInstanceLabel(const char *aInstanceName, DnsLabel &aLabel) +Error Server::Response::ExtractServiceInstanceLabel(const char *aInstanceName, Name::LabelBuffer &aLabel) { - uint16_t offset; - DnsName serviceName; + uint16_t offset; + Name::Buffer serviceName; offset = mOffsets.mServiceName; - IgnoreError(Name::ReadName(*mMessage, offset, serviceName, sizeof(serviceName))); + IgnoreError(Name::ReadName(*mMessage, offset, serviceName)); - return Name::ExtractLabels(aInstanceName, serviceName, aLabel, sizeof(aLabel)); + return Name::ExtractLabels(aInstanceName, serviceName, aLabel); } void Server::RemoveQueryAndPrepareResponse(ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Response &aResponse) { - DnsName name; + Name::Buffer name; mProxyQueries.Dequeue(aQuery); aInfo.RemoveFrom(aQuery); @@ -1021,7 +1021,7 @@ void Server::Response::Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip if (mType == kPtrQuery) { - DnsLabel instanceLabel; + Name::LabelBuffer instanceLabel; SuccessOrExit(error = ExtractServiceInstanceLabel(aInstanceInfo.mFullName, instanceLabel)); mSection = kAnswerSection; @@ -1148,7 +1148,7 @@ const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const return (query == nullptr) ? mProxyQueries.GetHead() : query->GetNext(); } -Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize]) +Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, Dns::Name::Buffer &aName) { const ProxyQuery *query = static_cast(aQuery); ProxyQueryInfo info; diff --git a/src/core/net/dnssd_server.hpp b/src/core/net/dnssd_server.hpp index de01d5534ea..3b6b67993b9 100644 --- a/src/core/net/dnssd_server.hpp +++ b/src/core/net/dnssd_server.hpp @@ -258,7 +258,7 @@ class Server : public InstanceLocator, private NonCopyable * @returns The DNS-SD query type. * */ - static DnsQueryType GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize]); + static DnsQueryType GetQueryTypeAndName(const otDnssdQuery *aQuery, Dns::Name::Buffer &aName); /** * Returns the counters of the DNS-SD server. @@ -297,8 +297,6 @@ class Server : public InstanceLocator, private NonCopyable static constexpr uint16_t kMaxConcurrentUpstreamQueries = 32; typedef Header::Response ResponseCode; - typedef char DnsName[Name::kMaxNameSize]; - typedef char DnsLabel[Name::kMaxLabelSize]; typedef Message ProxyQuery; typedef MessageQueue ProxyQueryList; @@ -347,7 +345,7 @@ class Server : public InstanceLocator, private NonCopyable void SetResponseCode(ResponseCode aResponseCode) { mHeader.SetResponseCode(aResponseCode); } ResponseCode AddQuestionsFrom(const Request &aRequest); Error ParseQueryName(void); - void ReadQueryName(DnsName &aName) const; + void ReadQueryName(Name::Buffer &aName) const; bool QueryNameMatches(const char *aName) const; Error AppendQueryName(void); Error AppendPtrRecord(const char *aInstanceLabel, uint32_t aTtl); @@ -367,7 +365,7 @@ class Server : public InstanceLocator, private NonCopyable void Send(const Ip6::MessageInfo &aMessageInfo); void Answer(const HostInfo &aHostInfo, const Ip6::MessageInfo &aMessageInfo); void Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip6::MessageInfo &aMessageInfo); - Error ExtractServiceInstanceLabel(const char *aInstanceName, DnsLabel &aLabel); + Error ExtractServiceInstanceLabel(const char *aInstanceName, Name::LabelBuffer &aLabel); #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE Error ResolveBySrp(void); bool QueryNameMatchesService(const Srp::Server::Service &aService) const; @@ -408,7 +406,7 @@ class Server : public InstanceLocator, private NonCopyable void ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessageInfo); void RemoveQueryAndPrepareResponse(ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Response &aResponse); void Finalize(ProxyQuery &aQuery, ResponseCode aResponseCode); - static void ReadQueryName(const Message &aQuery, DnsName &aName); + static void ReadQueryName(const Message &aQuery, Name::Buffer &aName); static bool QueryNameMatches(const Message &aQuery, const char *aName); #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE diff --git a/src/core/net/srp_server.cpp b/src/core/net/srp_server.cpp index aaaaaa71f85..946b84163e6 100644 --- a/src/core/net/srp_server.cpp +++ b/src/core/net/srp_server.cpp @@ -796,13 +796,13 @@ void Server::ProcessDnsUpdate(Message &aMessage, MessageMetadata &aMetadata) Error Server::ProcessZoneSection(const Message &aMessage, MessageMetadata &aMetadata) const { - Error error = kErrorNone; - char name[Dns::Name::kMaxNameSize]; - uint16_t offset = aMetadata.mOffset; + Error error = kErrorNone; + Dns::Name::Buffer name; + uint16_t offset = aMetadata.mOffset; VerifyOrExit(aMetadata.mDnsHeader.GetZoneRecordCount() == 1, error = kErrorParse); - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name))); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name)); // TODO: return `Dns::kResponseNotAuth` for not authorized zone names. VerifyOrExit(StringMatch(name, GetDomain(), kStringCaseInsensitiveMatch), error = kErrorSecurity); SuccessOrExit(error = aMessage.Read(offset, aMetadata.mDnsZone)); @@ -861,10 +861,10 @@ Error Server::ProcessHostDescriptionInstruction(Host &aHost, for (uint16_t numRecords = aMetadata.mDnsHeader.GetUpdateRecordCount(); numRecords > 0; numRecords--) { - char name[Dns::Name::kMaxNameSize]; + Dns::Name::Buffer name; Dns::ResourceRecord record; - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name))); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name)); SuccessOrExit(error = aMessage.Read(offset, record)); @@ -952,9 +952,9 @@ Error Server::ProcessServiceDiscoveryInstructions(Host &aHost, for (uint16_t numRecords = aMetadata.mDnsHeader.GetUpdateRecordCount(); numRecords > 0; numRecords--) { - char serviceName[Dns::Name::kMaxNameSize]; - char instanceLabel[Dns::Name::kMaxLabelSize]; - char instanceServiceName[Dns::Name::kMaxNameSize]; + Dns::Name::Buffer serviceName; + Dns::Name::LabelBuffer instanceLabel; + Dns::Name::Buffer instanceServiceName; String instanceName; Dns::PtrRecord ptrRecord; const char *subServiceName; @@ -962,7 +962,7 @@ Error Server::ProcessServiceDiscoveryInstructions(Host &aHost, bool isSubType; bool isDelete; - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, serviceName, sizeof(serviceName))); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, serviceName)); VerifyOrExit(Dns::Name::IsSubDomainOf(serviceName, GetDomain()), error = kErrorSecurity); error = Dns::ResourceRecord::ReadRecord(aMessage, offset, ptrRecord); @@ -977,8 +977,7 @@ Error Server::ProcessServiceDiscoveryInstructions(Host &aHost, SuccessOrExit(error); - SuccessOrExit(error = ptrRecord.ReadPtrName(aMessage, offset, instanceLabel, sizeof(instanceLabel), - instanceServiceName, sizeof(instanceServiceName))); + SuccessOrExit(error = ptrRecord.ReadPtrName(aMessage, offset, instanceLabel, instanceServiceName)); instanceName.Append("%s.%s", instanceLabel, instanceServiceName); // Class None indicates "Delete an RR from an RRset". @@ -1076,11 +1075,11 @@ Error Server::ProcessServiceDescriptionInstructions(Host &aHost, for (uint16_t numRecords = aMetadata.mDnsHeader.GetUpdateRecordCount(); numRecords > 0; numRecords--) { - char name[Dns::Name::kMaxNameSize]; + Dns::Name::Buffer name; Dns::ResourceRecord record; Service *service; - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name))); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name)); SuccessOrExit(error = aMessage.Read(offset, record)); if (record.GetClass() == Dns::ResourceRecord::kClassAny) @@ -1102,9 +1101,8 @@ Error Server::ProcessServiceDescriptionInstructions(Host &aHost, if (record.GetType() == Dns::ResourceRecord::kTypeSrv) { - Dns::SrvRecord srvRecord; - char hostName[Dns::Name::kMaxNameSize]; - uint16_t hostNameLength = sizeof(hostName); + Dns::SrvRecord srvRecord; + Dns::Name::Buffer hostName; VerifyOrExit(record.GetClass() == aMetadata.mDnsZone.GetClass(), error = kErrorFailed); @@ -1113,7 +1111,7 @@ Error Server::ProcessServiceDescriptionInstructions(Host &aHost, SuccessOrExit(error = aMessage.Read(offset, srvRecord)); offset += sizeof(srvRecord); - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, hostName, hostNameLength)); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, hostName)); VerifyOrExit(Dns::Name::IsSubDomainOf(name, GetDomain()), error = kErrorSecurity); VerifyOrExit(aHost.Matches(hostName), error = kErrorFailed); @@ -1179,22 +1177,22 @@ bool Server::IsValidDeleteAllRecord(const Dns::ResourceRecord &aRecord) Error Server::ProcessAdditionalSection(Host *aHost, const Message &aMessage, MessageMetadata &aMetadata) const { - Error error = kErrorNone; - Dns::OptRecord optRecord; - Dns::LeaseOption leaseOption; - Dns::SigRecord sigRecord; - char name[2]; // The root domain name (".") is expected. - uint16_t offset = aMetadata.mOffset; - uint16_t sigOffset; - uint16_t sigRdataOffset; - char signerName[Dns::Name::kMaxNameSize]; - uint16_t signatureLength; + Error error = kErrorNone; + Dns::OptRecord optRecord; + Dns::LeaseOption leaseOption; + Dns::SigRecord sigRecord; + char name[2]; // The root domain name (".") is expected. + uint16_t offset = aMetadata.mOffset; + uint16_t sigOffset; + uint16_t sigRdataOffset; + Dns::Name::Buffer signerName; + uint16_t signatureLength; VerifyOrExit(aMetadata.mDnsHeader.GetAdditionalRecordCount() == 2, error = kErrorFailed); // EDNS(0) Update Lease Option. - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name))); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name)); SuccessOrExit(error = aMessage.Read(offset, optRecord)); SuccessOrExit(error = leaseOption.ReadFrom(aMessage, offset + sizeof(optRecord), optRecord.GetLength())); @@ -1221,7 +1219,7 @@ Error Server::ProcessAdditionalSection(Host *aHost, const Message &aMessage, Mes // SIG(0). sigOffset = offset; - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name))); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name)); SuccessOrExit(error = aMessage.Read(offset, sigRecord)); VerifyOrExit(sigRecord.IsValid(), error = kErrorParse); @@ -1232,7 +1230,7 @@ Error Server::ProcessAdditionalSection(Host *aHost, const Message &aMessage, Mes // implemented because the end device may not be able to get // the synchronized date/time. - SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, signerName, sizeof(signerName))); + SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, signerName)); signatureLength = sigRecord.GetLength() - (offset - sigRdataOffset); offset += signatureLength; @@ -1383,9 +1381,9 @@ void Server::InformUpdateHandlerOrCommit(Error aError, Host &aHost, const Messag for (const Heap::String &subType : service.mSubTypes) { - char label[Dns::Name::kMaxLabelSize]; + Dns::Name::LabelBuffer label; - IgnoreError(Service::ParseSubTypeServiceName(subType.AsCString(), label, sizeof(label))); + IgnoreError(Service::ParseSubTypeServiceName(subType.AsCString(), label)); LogInfo(" sub-type: %s", label); } } @@ -1892,9 +1890,9 @@ void Server::Service::Log(Action aAction) const for (const Heap::String &subType : mSubTypes) { - char label[Dns::Name::kMaxLabelSize]; + Dns::Name::LabelBuffer label; - IgnoreError(ParseSubTypeServiceName(subType.AsCString(), label, sizeof(label))); + IgnoreError(ParseSubTypeServiceName(subType.AsCString(), label)); LogInfo(" sub-type: %s", subType.AsCString()); } } diff --git a/src/core/net/srp_server.hpp b/src/core/net/srp_server.hpp index bb73c154a97..c506416eca4 100644 --- a/src/core/net/srp_server.hpp +++ b/src/core/net/srp_server.hpp @@ -407,6 +407,12 @@ class Server : public InstanceLocator, private NonCopyable bool Matches(const char *aInstanceName) const; void Log(Action aAction) const; + template + static Error ParseSubTypeServiceName(const char *aSubTypeServiceName, char (&aLabel)[kLabelSize]) + { + return ParseSubTypeServiceName(aSubTypeServiceName, aLabel, kLabelSize); + } + Service *mNext; Heap::String mInstanceName; Heap::String mInstanceLabel; diff --git a/tests/unit/test_dns.cpp b/tests/unit/test_dns.cpp index a8408aad487..8e6ef0dd8a4 100644 --- a/tests/unit/test_dns.cpp +++ b/tests/unit/test_dns.cpp @@ -56,20 +56,20 @@ void TestDnsName(void) const char *mExpectedReadName; }; - Instance *instance; - MessagePool *messagePool; - Message *message; - uint8_t buffer[kMaxSize]; - uint16_t len; - uint16_t offset; - char label[Dns::Name::kMaxLabelSize]; - uint8_t labelLength; - char name[Dns::Name::kMaxNameSize]; - const char *subDomain; - const char *domain; - const char *domain2; - const char *fullName; - const char *suffixName; + Instance *instance; + MessagePool *messagePool; + Message *message; + uint8_t buffer[kMaxSize]; + uint16_t len; + uint16_t offset; + Dns::Name::LabelBuffer label; + uint8_t labelLength; + Dns::Name::Buffer name; + const char *subDomain; + const char *domain; + const char *domain2; + const char *fullName; + const char *suffixName; static const uint8_t kEncodedName1[] = {7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}; static const uint8_t kEncodedName2[] = {3, 'f', 'o', 'o', 1, 'a', 2, 'b', 'b', 3, 'e', 'd', 'u', 0}; @@ -254,42 +254,42 @@ void TestDnsName(void) fullName = "my-service._ipps._tcp.default.service.arpa."; suffixName = "default.service.arpa."; - SuccessOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name)); VerifyOrQuit(strcmp(name, "my-service._ipps._tcp") == 0); fullName = "my.service._ipps._tcp.default.service.arpa."; suffixName = "_ipps._tcp.default.service.arpa."; - SuccessOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name)); VerifyOrQuit(strcmp(name, "my.service") == 0); fullName = "my-service._ipps._tcp.default.service.arpa."; suffixName = "DeFault.SerVice.ARPA."; - SuccessOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name)); VerifyOrQuit(strcmp(name, "my-service._ipps._tcp") == 0); fullName = "my-service._ipps._tcp.default.service.arpa."; suffixName = "efault.service.arpa."; - VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name)) == kErrorParse); + VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name) == kErrorParse); fullName = "my-service._ipps._tcp.default.service.arpa."; suffixName = "xdefault.service.arpa."; - VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name)) == kErrorParse); + VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name) == kErrorParse); fullName = "my-service._ipps._tcp.default.service.arpa."; suffixName = ".default.service.arpa."; - VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name)) == kErrorParse); + VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name) == kErrorParse); fullName = "my-service._ipps._tcp.default.service.arpa."; suffixName = "default.service.arp."; - VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name)) == kErrorParse); + VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name) == kErrorParse); fullName = "default.service.arpa."; suffixName = "default.service.arpa."; - VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name)) == kErrorParse); + VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name) == kErrorParse); fullName = "efault.service.arpa."; suffixName = "default.service.arpa."; - VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name, sizeof(name)) == kErrorParse); + VerifyOrQuit(Dns::Name::ExtractLabels(fullName, suffixName, name) == kErrorParse); fullName = "my-service._ipps._tcp.default.service.arpa."; suffixName = "default.service.arpa."; @@ -342,7 +342,7 @@ void TestDnsName(void) // Read entire name offset = 0; - SuccessOrQuit(Dns::Name::ReadName(*message, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message, offset, name)); printf("Read name =\"%s\"\n", name); @@ -633,7 +633,7 @@ void TestDnsCompressedName(void) "Name::ReadLabel() failed at end of the name"); offset = name1Offset; - SuccessOrQuit(Dns::Name::ReadName(*message, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message, offset, name)); printf("Read name =\"%s\"\n", name); VerifyOrQuit(strcmp(name, kExpectedReadName1) == 0, "Name::ReadName() did not return expected name"); VerifyOrQuit(offset == name1Offset + sizeof(kEncodedName), "Name::ReadName() returned incorrect offset"); @@ -690,7 +690,7 @@ void TestDnsCompressedName(void) "Name::ReadLabel() failed at end of the name"); offset = name2Offset; - SuccessOrQuit(Dns::Name::ReadName(*message, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message, offset, name)); printf("Read name =\"%s\"\n", name); VerifyOrQuit(strcmp(name, kExpectedReadName2) == 0, "Name::ReadName() did not return expected name"); VerifyOrQuit(offset == name2Offset + kName2EncodedSize, "Name::ReadName() returned incorrect offset"); @@ -747,7 +747,7 @@ void TestDnsCompressedName(void) "Name::ReadLabel() failed at end of the name"); offset = name3Offset; - SuccessOrQuit(Dns::Name::ReadName(*message, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message, offset, name)); printf("Read name =\"%s\"\n", name); VerifyOrQuit(strcmp(name, kExpectedReadName3) == 0, "Name::ReadName() did not return expected name"); VerifyOrQuit(offset == name3Offset + kName3EncodedSize, "Name::ReadName() returned incorrect offset"); @@ -801,7 +801,7 @@ void TestDnsCompressedName(void) // `ReadName()` for name-4 should still succeed since only the first label contains dot char offset = name4Offset; - SuccessOrQuit(Dns::Name::ReadName(*message, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message, offset, name)); printf("Read name =\"%s\"\n", name); VerifyOrQuit(strcmp(name, kExpectedReadName4) == 0, "Name::ReadName() did not return expected name"); VerifyOrQuit(offset == name4Offset + kName4EncodedSize, "Name::ParseName() returned incorrect offset"); @@ -848,11 +848,11 @@ void TestDnsCompressedName(void) SuccessOrQuit(Dns::Name::CompareName(*message2, offset, dnsName4)); offset = 0; - SuccessOrQuit(Dns::Name::ReadName(*message2, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message2, offset, name)); printf("- Name1 after `AppendTo()`: \"%s\"\n", name); - SuccessOrQuit(Dns::Name::ReadName(*message2, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message2, offset, name)); printf("- Name2 after `AppendTo()`: \"%s\"\n", name); - SuccessOrQuit(Dns::Name::ReadName(*message2, offset, name, sizeof(name))); + SuccessOrQuit(Dns::Name::ReadName(*message2, offset, name)); printf("- Name3 after `AppendTo()`: \"%s\"\n", name); // `ReadName()` for name-4 will fail due to first label containing dot char. @@ -913,9 +913,9 @@ void TestHeaderAndResourceRecords(void) Dns::ResourceRecord record; Ip6::Address hostAddress; - char label[Dns::Name::kMaxLabelSize]; - char name[Dns::Name::kMaxNameSize]; - uint8_t buffer[kMaxSize]; + Dns::Name::LabelBuffer label; + Dns::Name::Buffer name; + uint8_t buffer[kMaxSize]; printf("================================================================\n"); printf("TestHeaderAndResourceRecords()\n"); @@ -1050,7 +1050,7 @@ void TestHeaderAndResourceRecords(void) SuccessOrQuit(Dns::ResourceRecord::ReadRecord(*message, offset, ptrRecord)); VerifyOrQuit(ptrRecord.GetTtl() == kTtl, "Read PTR is incorrect"); - SuccessOrQuit(ptrRecord.ReadPtrName(*message, offset, label, sizeof(label), name, sizeof(name))); + SuccessOrQuit(ptrRecord.ReadPtrName(*message, offset, label, name)); VerifyOrQuit(strcmp(label, instanceLabel) == 0, "Inst label is incorrect"); VerifyOrQuit(strcmp(name, kServiceName) == 0); @@ -1077,7 +1077,7 @@ void TestHeaderAndResourceRecords(void) VerifyOrQuit(numRecords == prevNumRecords - 1, "Incorrect num records"); SuccessOrQuit(Dns::ResourceRecord::ReadRecord(*message, offset, ptrRecord)); VerifyOrQuit(ptrRecord.GetTtl() == kTtl, "Read PTR is incorrect"); - SuccessOrQuit(ptrRecord.ReadPtrName(*message, offset, label, sizeof(label), name, sizeof(name))); + SuccessOrQuit(ptrRecord.ReadPtrName(*message, offset, label, name)); printf(" \"%s\" PTR %u %d inst:\"%s\" at \"%s\"\n", kServiceName, ptrRecord.GetTtl(), ptrRecord.GetLength(), label, name); } @@ -1125,7 +1125,7 @@ void TestHeaderAndResourceRecords(void) VerifyOrQuit(srvRecord.GetPort() == kSrvPort); VerifyOrQuit(srvRecord.GetWeight() == kSrvWeight); VerifyOrQuit(srvRecord.GetPriority() == kSrvPriority); - SuccessOrQuit(srvRecord.ReadTargetHostName(*message, offset, name, sizeof(name))); + SuccessOrQuit(srvRecord.ReadTargetHostName(*message, offset, name)); VerifyOrQuit(strcmp(name, kHostName) == 0); printf(" \"%s\" SRV %u %d %d %d %d \"%s\"\n", instanceName, srvRecord.GetTtl(), srvRecord.GetLength(), srvRecord.GetPort(), srvRecord.GetWeight(), srvRecord.GetPriority(), name); diff --git a/tests/unit/test_dns_client.cpp b/tests/unit/test_dns_client.cpp index 59efab72df2..74c5598c036 100644 --- a/tests/unit/test_dns_client.cpp +++ b/tests/unit/test_dns_client.cpp @@ -357,10 +357,10 @@ struct BrowseInfo { void Reset(void) { mCallbackCount = 0; } - uint16_t mCallbackCount; - Error mError; - char mServiceName[Dns::Name::kMaxNameSize]; - uint16_t mNumInstances; + uint16_t mCallbackCount; + Error mError; + Dns::Name::Buffer mServiceName; + uint16_t mNumInstances; }; static BrowseInfo sBrowseInfo; @@ -384,8 +384,8 @@ void BrowseCallback(otError aError, const otDnsBrowseResponse *aResponse, void * for (uint16_t index = 0;; index++) { - char instLabel[Dns::Name::kMaxLabelSize]; - Error error; + Dns::Name::LabelBuffer instLabel; + Error error; error = response.GetServiceInstance(index, instLabel, sizeof(instLabel)); @@ -423,7 +423,7 @@ struct ResolveServiceInfo uint16_t mCallbackCount; Error mError; Dns::Client::ServiceInfo mInfo; - char mNameBuffer[Dns::Name::kMaxNameSize]; + Dns::Name::Buffer mNameBuffer; uint8_t mTxtBuffer[kMaxTxtBuffer]; Ip6::Address mHostAddresses[kMaxHostAddresses]; uint8_t mNumHostAddresses; @@ -434,8 +434,8 @@ static ResolveServiceInfo sResolveServiceInfo; void ServiceCallback(otError aError, const otDnsServiceResponse *aResponse, void *aContext) { const Dns::Client::ServiceResponse &response = AsCoreType(aResponse); - char instLabel[Dns::Name::kMaxLabelSize]; - char serviceName[Dns::Name::kMaxNameSize]; + Dns::Name::LabelBuffer instLabel; + Dns::Name::Buffer serviceName; Log("ServiceCallback"); Log(" Error: %s", ErrorToString(aError)); @@ -912,8 +912,8 @@ void TestDnsClient(void) //---------------------------------------------------------------------------------------------------------------------- -char sLastSubscribeName[Dns::Name::kMaxNameSize]; -char sLastUnsubscribeName[Dns::Name::kMaxNameSize]; +Dns::Name::Buffer sLastSubscribeName; +Dns::Name::Buffer sLastUnsubscribeName; void QuerySubscribe(void *aContext, const char *aFullName) {