Skip to content

Commit

Permalink
Fix continuation memory leak in Ares.query (#31)
Browse files Browse the repository at this point in the history
* Fix continuation memory leak in Ares.query

* Use class for QueryReplyHandler

* Deallocate QueryReplyHandler for DNSSD

* Move defer block after allocate/initialize, use class

Move defer deallocation block to after initialization.

Use class instead of struct for DNSSD.QueryReplyHandler.
  • Loading branch information
dieb authored Feb 21, 2024
1 parent d9afa74 commit b7079b7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
15 changes: 7 additions & 8 deletions Sources/AsyncDNSResolver/c-ares/DNSResolver_c-ares.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,12 @@ class Ares {
preconditionFailure("'arg' is nil. This is a bug.")
}

let handler = QueryReplyHandler(pointer: handlerPointer)
defer { handlerPointer.deallocate() }
let pointer = handlerPointer.assumingMemoryBound(to: QueryReplyHandler.self)
let handler = pointer.pointee
defer {
pointer.deinitialize(count: 1)
pointer.deallocate()
}

handler.handle(status: status, buffer: buf, length: len)
}
Expand Down Expand Up @@ -258,7 +262,7 @@ extension Ares {

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension Ares {
struct QueryReplyHandler {
class QueryReplyHandler {
private let _handler: (CInt, UnsafeMutablePointer<CUnsignedChar>?, CInt) -> Void

init<Parser: AresQueryReplyParser>(parser: Parser, _ continuation: CheckedContinuation<Parser.Reply, Error>) {
Expand All @@ -276,11 +280,6 @@ extension Ares {
}
}

init(pointer: UnsafeMutableRawPointer) {
let handlerPointer = pointer.assumingMemoryBound(to: Self.self)
self = handlerPointer.pointee
}

func handle(status: CInt, buffer: UnsafeMutablePointer<CUnsignedChar>?, length: CInt) {
self._handler(status, buffer, length)
}
Expand Down
20 changes: 11 additions & 9 deletions Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,25 @@ struct DNSSD {
byteCount: MemoryLayout<QueryReplyHandler>.stride,
alignment: MemoryLayout<QueryReplyHandler>.alignment
)
// The handler might be called multiple times so don't deallocate inside `callback`
defer { handlerPointer.deallocate() }

handlerPointer.initializeMemory(as: QueryReplyHandler.self, repeating: handler, count: 1)

// The handler might be called multiple times so don't deallocate inside `callback`
defer {
let pointer = handlerPointer.assumingMemoryBound(to: QueryReplyHandler.self)
pointer.deinitialize(count: 1)
pointer.deallocate()
}

// This is called once per record received
let callback: DNSServiceQueryRecordReply = { _, _, _, errorCode, _, _, _, rdlen, rdata, _, context in
guard let handlerPointer = context else {
preconditionFailure("'context' is nil. This is a bug.")
}

let handler = QueryReplyHandler(pointer: handlerPointer)
let pointer = handlerPointer.assumingMemoryBound(to: QueryReplyHandler.self)
let handler = pointer.pointee

// This parses a record then adds it to the stream
handler.handleRecord(errorCode: errorCode, data: rdata, length: rdlen)
}
Expand Down Expand Up @@ -171,7 +178,7 @@ struct DNSSD {

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension DNSSD {
struct QueryReplyHandler {
class QueryReplyHandler {
private let _handleRecord: (DNSServiceErrorType, UnsafeRawPointer?, UInt16) -> Void

init<Handler: DNSSDQueryReplyHandler>(handler: Handler, _ continuation: AsyncThrowingStream<Handler.Record, Error>.Continuation) {
Expand All @@ -189,11 +196,6 @@ extension DNSSD {
}
}

init(pointer: UnsafeMutableRawPointer) {
let handlerPointer = pointer.assumingMemoryBound(to: Self.self)
self = handlerPointer.pointee
}

func handleRecord(errorCode: DNSServiceErrorType, data: UnsafeRawPointer?, length: UInt16) {
self._handleRecord(errorCode, data, length)
}
Expand Down

0 comments on commit b7079b7

Please sign in to comment.