Skip to content

Commit

Permalink
Merge pull request #28 from orlandos-nl/feature/jo/official-niots
Browse files Browse the repository at this point in the history
Use the official NIOTS, and support NIOTS on macOS platforms
  • Loading branch information
Joannis authored Aug 25, 2023
2 parents 6fda9fd + 1e51c69 commit 3878acc
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 101 deletions.
10 changes: 5 additions & 5 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@
"repositoryURL": "https://github.com/apple/swift-nio.git",
"state": {
"branch": null,
"revision": "e0cc6dd6ffa8e6a6f565938acd858b24e47902d0",
"version": "2.50.0"
"revision": "cf281631ff10ec6111f2761052aa81896a83a007",
"version": "2.58.0"
}
},
{
"package": "swift-nio-transport-services",
"repositoryURL": "https://github.com/Joannis/swift-nio-transport-services.git",
"repositoryURL": "https://github.com/apple/swift-nio-transport-services.git",
"state": {
"branch": null,
"revision": "f3e37707974cab7d02785b66a24f1baccb4cdc8d",
"version": "1.17.0"
"revision": "e7403c35ca6bb539a7ca353b91cc2d8ec0362d58",
"version": "1.19.0"
}
}
]
Expand Down
6 changes: 5 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import PackageDescription

let package = Package(
name: "DNSClient",
platforms: [
.macOS(.v10_15),
.iOS(.v13),
],
products: [
// Products define the executables and libraries produced by a package, and make them visible to other packages.
.library(
Expand All @@ -16,7 +20,7 @@ let package = Package(
dependencies: [
// Dependencies declare other packages that this package depends on.
.package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"),
.package(url: "https://github.com/Joannis/swift-nio-transport-services.git", from: "1.17.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"),
],
targets: [
// Targets are the basic building blocks of a package. A target can define a module or a test suite.
Expand Down
85 changes: 77 additions & 8 deletions Sources/DNSClient/DNSClient+Connect.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,30 +179,41 @@ fileprivate extension Array where Element == SocketAddress {
}
}

#if canImport(NIOTransportServices) && os(iOS)
#if canImport(Network)
import NIOTransportServices

@available(iOS 12, *)
extension DNSClient {
public static func connectTS(on group: NIOTSEventLoopGroup, host: String) -> EventLoopFuture<DNSClient> {
do {
let address = try SocketAddress(ipAddress: host, port: 53)
return connectTS(on: group, config: [address])
} catch {
return group.next().makeFailedFuture(error)
}
}

/// Connect to the dns server using TCP using NIOTransportServices. This is only available on iOS 12 and above.
/// - parameters:
/// - group: EventLoops to use
/// - config: DNS servers to use
/// - returns: Future with the NioDNS client. Use
public static func connectTS(on group: NIOTSEventLoopGroup, config: [SocketAddress]) -> EventLoopFuture<DNSClient> {
guard let address = config.preferred else {
// Don't connect by UNIX domain socket. We currently don't intend to test & support that.
guard
let address = config.preferred,
let ipAddress = address.ipAddress,
let port = address.port
else {
return group.next().makeFailedFuture(MissingNameservers())
}

let dnsDecoder = DNSDecoder(group: group)

return NIOTSDatagramBootstrap(group: group)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: 1)
.channelInitializer { channel in
return channel.pipeline.addHandlers(dnsDecoder, DNSEncoder())
return NIOTSDatagramBootstrap(group: group).channelInitializer { channel in
return channel.pipeline.addHandlers(dnsDecoder, DNSEncoder())
}
.connect(to: address)
.connect(host: ipAddress, port: port)
.map { channel -> DNSClient in
let client = DNSClient(
channel: channel,
Expand All @@ -214,6 +225,7 @@ extension DNSClient {
return client
}
}

/// Connect to the dns server using TCP using NIOTransportServices. This is only available on iOS 12 and above.
/// The DNS Host is read from /etc/resolv.conf
/// - parameters:
Expand All @@ -228,5 +240,62 @@ extension DNSClient {
return group.next().makeFailedFuture(UnableToParseConfig())
}
}

public static func connectTSTCP(on group: NIOTSEventLoopGroup, host: String) -> EventLoopFuture<DNSClient> {
do {
let address = try SocketAddress(ipAddress: host, port: 53)
return connectTSTCP(on: group, config: [address])
} catch {
return group.next().makeFailedFuture(error)
}
}

/// Connect to the dns server using TCP using NIOTransportServices. This is only available on iOS 12 and above.
/// - parameters:
/// - group: EventLoops to use
/// - config: DNS servers to use
/// - returns: Future with the NioDNS client. Use
public static func connectTSTCP(on group: NIOTSEventLoopGroup, config: [SocketAddress]) -> EventLoopFuture<DNSClient> {
guard let address = config.preferred else {
return group.next().makeFailedFuture(MissingNameservers())
}

let dnsDecoder = DNSDecoder(group: group)

return NIOTSConnectionBootstrap(group: group).channelInitializer { channel in
return channel.pipeline.addHandlers(
ByteToMessageHandler(UInt16FrameDecoder()),
MessageToByteHandler(UInt16FrameEncoder()),
dnsDecoder,
DNSEncoder()
)
}
.connect(to: address)
.map { channel -> DNSClient in
let client = DNSClient(
channel: channel,
address: address,
decoder: dnsDecoder
)

dnsDecoder.mainClient = client
return client
}
}

/// Connect to the dns server using TCP using NIOTransportServices. This is only available on iOS 12 and above.
/// The DNS Host is read from /etc/resolv.conf
/// - parameters:
/// - group: EventLoops to use
public static func connectTSTCP(on group: NIOTSEventLoopGroup) -> EventLoopFuture<DNSClient> {
do {
let configString = try String(contentsOfFile: "/etc/resolv.conf")
let config = try ResolvConf(from: configString)

return connectTSTCP(on: group, config: config.nameservers)
} catch {
return group.next().makeFailedFuture(UnableToParseConfig())
}
}
}
#endif
117 changes: 70 additions & 47 deletions Tests/DNSClientTests/DNSTCPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,36 @@ import XCTest
import NIO
@testable import DNSClient

#if canImport(Network)
import NIOTransportServices
#endif

final class DNSTCPClientTests: XCTestCase {
var group: MultiThreadedEventLoopGroup!
var dnsClient: DNSClient!

#if canImport(Network)
var nwGroup: NIOTSEventLoopGroup!
var nwDnsClient: DNSClient!
#endif

override func setUp() {
super.setUp()
do {
group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
dnsClient = try DNSClient.connectTCP(on: group, host: "8.8.8.8").wait()
} catch let error {
XCTFail("\(error)")
}
override func setUpWithError() throws {
group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
dnsClient = try DNSClient.connectTCP(on: group, host: "8.8.8.8").wait()

#if canImport(Network)
nwGroup = NIOTSEventLoopGroup(loopCount: 1)
nwDnsClient = try DNSClient.connectTSTCP(on: nwGroup, host: "8.8.8.8").wait()
#endif
}

func testClient(_ perform: (DNSClient) throws -> Void) rethrows -> Void {
try perform(dnsClient)
#if canImport(Network)
try perform(nwDnsClient)
#endif
}

func testStringAddress() throws {
var buffer = ByteBuffer()
buffer.writeInteger(0x7F000001 as UInt32)
Expand All @@ -40,75 +56,82 @@ final class DNSTCPClientTests: XCTestCase {
}

func testAQuery() throws {
let results = try dnsClient.initiateAQuery(host: "google.com", port: 443).wait()
XCTAssertGreaterThanOrEqual(results.count, 1, "The returned result should be greater than or equal to 1")
try testClient { dnsClient in
let results = try dnsClient.initiateAQuery(host: "google.com", port: 443).wait()
XCTAssertGreaterThanOrEqual(results.count, 1, "The returned result should be greater than or equal to 1")
}
}

// Test that we can resolve a domain name to an IPv6 address
func testAAAAQuery() throws {
let results = try dnsClient.initiateAAAAQuery(host: "google.com", port: 443).wait()
XCTAssertGreaterThanOrEqual(results.count, 1, "The returned result should be greater than or equal to 1")
try testClient { dnsClient in
let results = try dnsClient.initiateAAAAQuery(host: "google.com", port: 443).wait()
XCTAssertGreaterThanOrEqual(results.count, 1, "The returned result should be greater than or equal to 1")
}
}

// Given a domain name, test that we can resolve it to an IPv4 address
func testSendQueryA() throws {
let result = try dnsClient.sendQuery(forHost: "google.com", type: .a).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
try testClient { dnsClient in
let result = try dnsClient.sendQuery(forHost: "google.com", type: .a).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
}
}

// Test that we can resolve example.com to an IPv6 address
func testResolveExampleCom() throws {
let result = try dnsClient.sendQuery(forHost: "example.com", type: .aaaa).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
try testClient { dnsClient in
let result = try dnsClient.sendQuery(forHost: "example.com", type: .aaaa).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
}
}

func testSendTxtQuery() throws {
let result = try dnsClient.sendQuery(forHost: "google.com", type: .txt).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
try testClient { dnsClient in
let result = try dnsClient.sendQuery(forHost: "google.com", type: .txt).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
}
}

func testSendQueryMX() throws {
let result = try dnsClient.sendQuery(forHost: "gmail.com", type: .mx).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
try testClient { dnsClient in
let result = try dnsClient.sendQuery(forHost: "gmail.com", type: .mx).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
}
}

func testSendQueryCNAME() throws {
let result = try dnsClient.sendQuery(forHost: "www.youtube.com", type: .cName).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
try testClient { dnsClient in
let result = try dnsClient.sendQuery(forHost: "www.youtube.com", type: .cName).wait()
XCTAssertGreaterThanOrEqual(result.header.answerCount, 1, "The returned answers should be greater than or equal to 1")
}
}

func testSRVRecords() throws {
let answers = try dnsClient.getSRVRecords(from: "_mongodb._tcp.ok0-xkvc1.mongodb.net").wait()
XCTAssertGreaterThanOrEqual(answers.count, 1, "The returned answers should be greater than or equal to 1")
try testClient { dnsClient in
let answers = try dnsClient.getSRVRecords(from: "_caldavs._tcp.google.com").wait()
XCTAssertGreaterThanOrEqual(answers.count, 1, "The returned answers should be greater than or equal to 1")
}
}

func testSRVRecordsAsyncRequest() throws {
let expectation = self.expectation(description: "getSRVRecords")

dnsClient.getSRVRecords(from: "_mongodb._tcp.ok0-xkvc1.mongodb.net")
.whenComplete { (result) in
switch result {
case .failure(let error):
XCTFail("\(error)")
case .success(let answers):
print(answers)
XCTAssertGreaterThanOrEqual(answers.count, 1, "The returned answers should be greater than or equal to 1")
testClient { dnsClient in
let expectation = self.expectation(description: "getSRVRecords")

dnsClient.getSRVRecords(from: "_caldavs._tcp.google.com")
.whenComplete { (result) in
switch result {
case .failure(let error):
XCTFail("\(error)")
case .success(let answers):
XCTAssertGreaterThanOrEqual(answers.count, 1, "The returned answers should be greater than or equal to 1")
}
expectation.fulfill()
}
expectation.fulfill()
}
self.waitForExpectations(timeout: 5, handler: nil)
self.waitForExpectations(timeout: 5, handler: nil)
}
}

// func testMulticastDNS() async throws {
// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
// let client = try await DNSClient.connectMulticast(on: eventLoopGroup).get()
// let addresses = try await client.sendQuery(
// forHost: "my-host.local",
// type: .any
// ).get()
// print(addresses)
// }

func testThreadSafety() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
let client = try await DNSClient.connectTCP(
Expand Down
Loading

0 comments on commit 3878acc

Please sign in to comment.