Skip to content

Commit

Permalink
[dns-types] add template variants of reading DNS names or labels (ope…
Browse files Browse the repository at this point in the history
…nthread#9720)

This commit introduces template variants for `ReadName()` and other
related methods, allowing flexible reading of DNS names and labels
from messages into a given array buffer. This simplifies the code and
improves readability.

Additionally, this commit defines new types, `Dns::Name::Buffer` and
`Dns::Name::LabelBuffer`, as arrays of char with fixed sizes to hold
DNS names and labels, respectively.
  • Loading branch information
abtink authored Dec 19, 2023
1 parent eec54ef commit a492d04
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 118 deletions.
125 changes: 125 additions & 0 deletions src/core/net/dns_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,24 @@ class Name : public Clearable<Name>
*/
static constexpr uint8_t kMaxLabelLength = kMaxLabelSize - 1;

/**
* Dot character separating labels in a name.
*
*/
static constexpr char kLabelSeparatorChar = '.';

/**
* Represents a string buffer (with `kMaxNameSize`) intended to hold a DNS name.
*
*/
typedef char Buffer[kMaxNameSize];

/**
* Represents a string buffer (with `kMaxLabelSize`) intended to hold a DNS label.
*
*/
typedef char LabelBuffer[kMaxLabelSize];

/**
* Represents the name type.
*
Expand Down Expand Up @@ -821,6 +837,35 @@ class Name : public Clearable<Name>
*/
static Error ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize);

/**
* Reads a full name from a message.
*
* On successful read, the read name follows "<label1>.<label2>.<label3>.", i.e., a sequence of labels separated by
* dot '.' character. The read name will ALWAYS end with a dot.
*
* Verifies that the labels after the first label in message do not contain any dot character. If they do,
* returns `kErrorParse`.
*
* @tparam kNameBufferSize Size of the string buffer array.
*
* @param[in] aMessage The message to read the name from. `aMessage.GetOffset()` MUST point to
* the start of DNS header (this is used to handle compressed names).
* @param[in,out] aOffset On input, the offset in @p aMessage pointing to the start of the name field.
* On exit (when parsed successfully), @p aOffset is updated to point to the byte
* after the end of name field.
* @param[out] aNameBuffer Reference to a name string buffer to output the read name.
*
* @retval kErrorNone Successfully read the name, @p aNameBuffer and @p Offset are updated.
* @retval kErrorParse Name could not be parsed (invalid format).
* @retval kErrorNoBufs Name could not fit in @p aNameBuffer.
*
*/
template <uint16_t kNameBufferSize>
static Error ReadName(const Message &aMessage, uint16_t &aOffset, char (&aNameBuffer)[kNameBufferSize])
{
return ReadName(aMessage, aOffset, aNameBuffer, kNameBufferSize);
}

/**
* Compares a single name label from a message with a given label string.
*
Expand Down Expand Up @@ -946,6 +991,30 @@ class Name : public Clearable<Name>
*/
static Error ExtractLabels(const char *aName, const char *aSuffixName, char *aLabels, uint16_t aLabelsSize);

/**
* Extracts label(s) from a full name by checking that it contains a given suffix name (e.g., suffix name can be
* a domain name) and removing it.
*
* Both @p aName and @p aSuffixName must be full DNS name and end with ('.'), otherwise the behavior of this method
* is undefined.
*
* @tparam kLabelsBufferSize Size of the buffer string.
*
* @param[in] aName The full name to extract labels from.
* @param[in] aSuffixName The suffix name (e.g. can be domain name).
* @param[out] aLabelsBuffer A buffer to copy the extracted labels.
*
* @retval kErrorNone Successfully extracted the labels, @p aLabels is updated.
* @retval kErrorParse @p aName does not contain @p aSuffixName.
* @retval kErrorNoBufs Could not fit the labels in @p aLabels.
*
*/
template <uint16_t kLabelsBufferSize>
static Error ExtractLabels(const char *aName, const char *aSuffixName, char (&aLabels)[kLabelsBufferSize])
{
return ExtractLabels(aName, aSuffixName, aLabels, kLabelsBufferSize);
}

/**
* Tests if a DNS name is a sub-domain of a given domain.
*
Expand Down Expand Up @@ -1660,6 +1729,36 @@ class PtrRecord : public ResourceRecord
char *aNameBuffer,
uint16_t aNameBufferSize) const;

/**
* Parses and reads the PTR name from a message.
*
* This is a template variation of the previous method with name and label buffer sizes as template parameters.
*
* @tparam kLabelBufferSize The size of label buffer.
* @tparam kNameBufferSize The size of name buffer.
*
* @param[in] aMessage The message to read from. `aMessage.GetOffset()` MUST point to the start of
* DNS header.
* @param[in,out] aOffset On input, the offset in @p aMessage to the start of PTR name field.
* On exit, when successfully read, @p aOffset is updated to point to the byte
* after the entire PTR record (skipping over the record).
* @param[out] aLabelBuffer A char array buffer to output the first label as a null-terminated C string.
* @param[out] aNameBuffer A char array to output the rest of name (after first label).
*
* @retval kErrorNone The PTR name was read successfully. @p aOffset, @aLabelBuffer and @aNameBuffer are updated.
* @retval kErrorParse The PTR record in @p aMessage could not be parsed (invalid format).
* @retval kErrorNoBufs Either label or name could not fit in the related given buffers.
*
*/
template <uint16_t kLabelBufferSize, uint16_t kNameBufferSize>
Error ReadPtrName(const Message &aMessage,
uint16_t &aOffset,
char (&aLabelBuffer)[kLabelBufferSize],
char (&aNameBuffer)[kNameBufferSize]) const
{
return ReadPtrName(aMessage, aOffset, aLabelBuffer, kLabelBufferSize, aNameBuffer, kNameBufferSize);
}

} OT_TOOL_PACKED_END;

/**
Expand Down Expand Up @@ -1868,6 +1967,32 @@ class SrvRecord : public ResourceRecord
/* aSkipRecord */ true);
}

/**
* Parses and reads the SRV target host name from a message.
*
* Also verifies that the SRV record is well-formed (e.g., the record data length `GetLength()` matches
* the SRV encoded name).
*
* @tparam kNameBufferSize Size of the name buffer.
*
* @param[in] aMessage The message to read from. `aMessage.GetOffset()` MUST point to the start of
* DNS header.
* @param[in,out] aOffset On input, the offset in @p aMessage to start of target host name field.
* On exit when successfully read, @p aOffset is updated to point to the byte
* after the entire SRV record (skipping over the record).
* @param[out] aNameBuffer A char array to output the read name as a null-terminated C string
*
* @retval kErrorNone The host name was read successfully. @p aOffset and @p aNameBuffer are updated.
* @retval kErrorParse The SRV record in @p aMessage could not be parsed (invalid format).
* @retval kErrorNoBufs Name could not fit in @p aNameBuffer.
*
*/
template <uint16_t kNameBufferSize>
Error ReadTargetHostName(const Message &aMessage, uint16_t &aOffset, char (&aNameBuffer)[kNameBufferSize]) const
{
return ReadTargetHostName(aMessage, aOffset, aNameBuffer, kNameBufferSize);
}

private:
uint16_t mPriority;
uint16_t mWeight;
Expand Down
58 changes: 29 additions & 29 deletions src/core/net/dnssd_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,12 @@ Error Server::Response::ParseQueryName(void)
// Parses and validates the query name and updates
// the name compression offsets.

Error error = kErrorNone;
DnsName name;
uint16_t offset;
Error error = kErrorNone;
Name::Buffer name;
uint16_t offset;

offset = sizeof(Header);
SuccessOrExit(error = Name::ReadName(*mMessage, offset, name, sizeof(name)));
SuccessOrExit(error = Name::ReadName(*mMessage, offset, name));

switch (mType)
{
Expand Down Expand Up @@ -446,9 +446,9 @@ Error Server::Response::ParseQueryName(void)

while (true)
{
DnsLabel label;
uint8_t labelLength = sizeof(label);
uint16_t comapreOffset;
Name::LabelBuffer label;
uint8_t labelLength = sizeof(label);
uint16_t comapreOffset;

SuccessOrExit(error = Name::ReadLabel(*mMessage, offset, label, labelLength));

Expand All @@ -472,7 +472,7 @@ Error Server::Response::ParseQueryName(void)
return error;
}

void Server::Response::ReadQueryName(DnsName &aName) const { Server::ReadQueryName(*mMessage, aName); }
void Server::Response::ReadQueryName(Name::Buffer &aName) const { Server::ReadQueryName(*mMessage, aName); }

bool Server::Response::QueryNameMatches(const char *aName) const { return Server::QueryNameMatches(*mMessage, aName); }

Expand Down Expand Up @@ -526,12 +526,12 @@ Error Server::Response::AppendSrvRecord(const char *aHostName,
uint16_t aWeight,
uint16_t aPort)
{
Error error = kErrorNone;
SrvRecord srvRecord;
uint16_t recordOffset;
DnsName hostLabels;
Error error = kErrorNone;
SrvRecord srvRecord;
uint16_t recordOffset;
Name::Buffer hostLabels;

SuccessOrExit(error = Name::ExtractLabels(aHostName, kDefaultDomainName, hostLabels, sizeof(hostLabels)));
SuccessOrExit(error = Name::ExtractLabels(aHostName, kDefaultDomainName, hostLabels));

srvRecord.Init();
srvRecord.SetTtl(aTtl);
Expand Down Expand Up @@ -674,7 +674,7 @@ uint8_t Server::GetNameLength(const char *aName)
#if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
void Server::Response::Log(void) const
{
DnsName name;
Name::Buffer name;

ReadQueryName(name);
LogInfo("%s query for '%s'", QueryTypeToString(mType), name);
Expand Down Expand Up @@ -830,16 +830,16 @@ bool Server::Response::QueryNameMatchesService(const Srp::Server::Service &aServ
#if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
bool Server::ShouldForwardToUpstream(const Request &aRequest)
{
bool shouldForward = false;
uint16_t readOffset;
DnsName name;
bool shouldForward = false;
uint16_t readOffset;
Name::Buffer name;

VerifyOrExit(aRequest.mHeader.IsRecursionDesiredFlagSet());
readOffset = sizeof(Header);

for (uint16_t i = 0; i < aRequest.mHeader.GetQuestionCount(); i++)
{
SuccessOrExit(Name::ReadName(*aRequest.mMessage, readOffset, name, sizeof(name)));
SuccessOrExit(Name::ReadName(*aRequest.mMessage, readOffset, name));
readOffset += sizeof(Question);

VerifyOrExit(!Name::IsSubDomainOf(name, kDefaultDomainName));
Expand Down Expand Up @@ -915,7 +915,7 @@ void Server::ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessag
{
ProxyQuery *query;
ProxyQueryInfo info;
DnsName name;
Name::Buffer name;

VerifyOrExit(mQuerySubscribe.IsSet());

Expand Down Expand Up @@ -950,11 +950,11 @@ void Server::ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessag
return;
}

void Server::ReadQueryName(const Message &aQuery, DnsName &aName)
void Server::ReadQueryName(const Message &aQuery, Name::Buffer &aName)
{
uint16_t offset = sizeof(Header);

IgnoreError(Name::ReadName(aQuery, offset, aName, sizeof(aName)));
IgnoreError(Name::ReadName(aQuery, offset, aName));
}

bool Server::QueryNameMatches(const Message &aQuery, const char *aName)
Expand All @@ -979,20 +979,20 @@ void Server::ProxyQueryInfo::UpdateIn(ProxyQuery &aQuery) const
aQuery.Write(aQuery.GetLength() - sizeof(ProxyQueryInfo), *this);
}

Error Server::Response::ExtractServiceInstanceLabel(const char *aInstanceName, DnsLabel &aLabel)
Error Server::Response::ExtractServiceInstanceLabel(const char *aInstanceName, Name::LabelBuffer &aLabel)
{
uint16_t offset;
DnsName serviceName;
uint16_t offset;
Name::Buffer serviceName;

offset = mOffsets.mServiceName;
IgnoreError(Name::ReadName(*mMessage, offset, serviceName, sizeof(serviceName)));
IgnoreError(Name::ReadName(*mMessage, offset, serviceName));

return Name::ExtractLabels(aInstanceName, serviceName, aLabel, sizeof(aLabel));
return Name::ExtractLabels(aInstanceName, serviceName, aLabel);
}

void Server::RemoveQueryAndPrepareResponse(ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Response &aResponse)
{
DnsName name;
Name::Buffer name;

mProxyQueries.Dequeue(aQuery);
aInfo.RemoveFrom(aQuery);
Expand Down Expand Up @@ -1021,7 +1021,7 @@ void Server::Response::Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip

if (mType == kPtrQuery)
{
DnsLabel instanceLabel;
Name::LabelBuffer instanceLabel;

SuccessOrExit(error = ExtractServiceInstanceLabel(aInstanceInfo.mFullName, instanceLabel));
mSection = kAnswerSection;
Expand Down Expand Up @@ -1148,7 +1148,7 @@ const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const
return (query == nullptr) ? mProxyQueries.GetHead() : query->GetNext();
}

Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize])
Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, Dns::Name::Buffer &aName)
{
const ProxyQuery *query = static_cast<const ProxyQuery *>(aQuery);
ProxyQueryInfo info;
Expand Down
10 changes: 4 additions & 6 deletions src/core/net/dnssd_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ class Server : public InstanceLocator, private NonCopyable
* @returns The DNS-SD query type.
*
*/
static DnsQueryType GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize]);
static DnsQueryType GetQueryTypeAndName(const otDnssdQuery *aQuery, Dns::Name::Buffer &aName);

/**
* Returns the counters of the DNS-SD server.
Expand Down Expand Up @@ -297,8 +297,6 @@ class Server : public InstanceLocator, private NonCopyable
static constexpr uint16_t kMaxConcurrentUpstreamQueries = 32;

typedef Header::Response ResponseCode;
typedef char DnsName[Name::kMaxNameSize];
typedef char DnsLabel[Name::kMaxLabelSize];

typedef Message ProxyQuery;
typedef MessageQueue ProxyQueryList;
Expand Down Expand Up @@ -347,7 +345,7 @@ class Server : public InstanceLocator, private NonCopyable
void SetResponseCode(ResponseCode aResponseCode) { mHeader.SetResponseCode(aResponseCode); }
ResponseCode AddQuestionsFrom(const Request &aRequest);
Error ParseQueryName(void);
void ReadQueryName(DnsName &aName) const;
void ReadQueryName(Name::Buffer &aName) const;
bool QueryNameMatches(const char *aName) const;
Error AppendQueryName(void);
Error AppendPtrRecord(const char *aInstanceLabel, uint32_t aTtl);
Expand All @@ -367,7 +365,7 @@ class Server : public InstanceLocator, private NonCopyable
void Send(const Ip6::MessageInfo &aMessageInfo);
void Answer(const HostInfo &aHostInfo, const Ip6::MessageInfo &aMessageInfo);
void Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip6::MessageInfo &aMessageInfo);
Error ExtractServiceInstanceLabel(const char *aInstanceName, DnsLabel &aLabel);
Error ExtractServiceInstanceLabel(const char *aInstanceName, Name::LabelBuffer &aLabel);
#if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
Error ResolveBySrp(void);
bool QueryNameMatchesService(const Srp::Server::Service &aService) const;
Expand Down Expand Up @@ -408,7 +406,7 @@ class Server : public InstanceLocator, private NonCopyable
void ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessageInfo);
void RemoveQueryAndPrepareResponse(ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Response &aResponse);
void Finalize(ProxyQuery &aQuery, ResponseCode aResponseCode);
static void ReadQueryName(const Message &aQuery, DnsName &aName);
static void ReadQueryName(const Message &aQuery, Name::Buffer &aName);
static bool QueryNameMatches(const Message &aQuery, const char *aName);

#if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
Expand Down
Loading

0 comments on commit a492d04

Please sign in to comment.