- Add cancel() method to RttCommand to fail its promise on connection close
- Add dequeueAll() method to ConcurrentQueue for batch cleanup
- Call clearPendingPings() in close(), suspend(), and disconnect() methods
- Prevents the 'leaking promise' crash when the connection is closed while pings are pending
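
The `RttCommand` and `ConcurrentQueue` types referenced above are defined outside this file, so the new `cancel()` and `dequeueAll()` members do not appear in the handler below. As a rough, non-authoritative sketch of what they could look like, assuming `RttCommand` wraps a SwiftNIO `EventLoopPromise<TimeInterval>` (the real type is created via `RttCommand.makeFrom(channel:)`, which this sketch does not reproduce) and `ConcurrentQueue` is a lock-protected FIFO:

```swift
import Dispatch
import NIO
import NIOConcurrencyHelpers

// Hypothetical sketch of a ping command whose promise is resolved with the round-trip time.
final class RttCommand {
    private let startTime = DispatchTime.now()
    private let promise: EventLoopPromise<TimeInterval>?

    init(promise: EventLoopPromise<TimeInterval>?) {
        self.promise = promise
    }

    // Called when the matching PONG arrives.
    func setRoundTripTime() {
        let elapsed = DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds
        promise?.succeed(TimeInterval(elapsed) / 1_000_000_000)
    }

    // New: fail the promise so awaiting callers are released (instead of the
    // promise leaking) when the connection closes while the ping is in flight.
    func cancel() {
        promise?.fail(ChannelError.ioOnClosedChannel)
    }
}

// Hypothetical sketch of a minimal lock-protected FIFO queue.
final class ConcurrentQueue<T> {
    private let items = NIOLockedValueBox<[T]>([])

    func enqueue(_ item: T) {
        items.withLockedValue { $0.append(item) }
    }

    func dequeue() -> T? {
        items.withLockedValue { $0.isEmpty ? nil : $0.removeFirst() }
    }

    // New: drain the queue in one pass so close()/suspend()/disconnect()
    // can cancel every pending ping exactly once.
    func dequeueAll() -> [T] {
        items.withLockedValue { queued in
            let drained = queued
            queued.removeAll()
            return drained
        }
    }
}
```

With `dequeueAll()` draining the queue atomically, `clearPendingPings()` in the handler below can fail every outstanding ping once instead of leaving the promises to leak.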
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Atomics
import Dispatch
import Foundation
import NIO
import NIOConcurrencyHelpers
import NIOFoundationCompat
import NIOHTTP1
import NIOSSL
import NIOWebSocket
import NKeys

class ConnectionHandler: ChannelInboundHandler {
    let lang = "Swift"
    let version = "0.0.1"

    internal var connectedUrl: URL?
    internal let allocator = ByteBufferAllocator()
    internal var inputBuffer: ByteBuffer
    internal var channel: Channel?

    private var eventHandlerStore: [NatsEventKind: [NatsEventHandler]] = [:]

    // Connection options
    internal var retryOnFailedConnect = false
    private var urls: [URL]
    // nanoseconds representation of TimeInterval
    private let reconnectWait: UInt64
    private let maxReconnects: Int?
    private let retainServersOrder: Bool
    private let pingInterval: TimeInterval
    private let requireTls: Bool
    private let tlsFirst: Bool
    private var rootCertificate: URL?
    private var clientCertificate: URL?
    private var clientKey: URL?

    typealias InboundIn = ByteBuffer
    private let state = NIOLockedValueBox(NatsState.pending)
    private let subscriptions = NIOLockedValueBox([UInt64: NatsSubscription]())

    // Helper methods for state access
    internal var currentState: NatsState {
        state.withLockedValue { $0 }
    }

    internal func setState(_ newState: NatsState) {
        state.withLockedValue { $0 = newState }
    }

    private var subscriptionCounter = ManagedAtomic<UInt64>(0)
    private var serverInfo: ServerInfo?
    private var auth: Auth?
    private let parseRemainder = NIOLockedValueBox<Data?>(nil)
    private var pingTask: RepeatedTask?
    private var outstandingPings = ManagedAtomic<UInt8>(0)
    private var reconnectAttempts = 0
    private var reconnectTask: Task<(), Error>? = nil
    private let capturedConnectionError = NIOLockedValueBox<Error?>(nil)

    private var group: MultiThreadedEventLoopGroup

    private let serverInfoContinuation = NIOLockedValueBox<CheckedContinuation<ServerInfo, Error>?>(nil)
    private let connectionEstablishedContinuation = NIOLockedValueBox<CheckedContinuation<Void, Error>?>(nil)

    private let pingQueue = ConcurrentQueue<RttCommand>()
    private(set) var batchBuffer: BatchBuffer?

    init(
        inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?,
        retainServersOrder: Bool,
        pingInterval: TimeInterval, auth: Auth?, requireTls: Bool, tlsFirst: Bool,
        clientCertificate: URL?, clientKey: URL?,
        rootCertificate: URL?, retryOnFailedConnect: Bool
    ) {
        self.urls = urls
        self.group = .singleton
        self.inputBuffer = allocator.buffer(capacity: 1024)
        self.reconnectWait = UInt64(reconnectWait * 1_000_000_000)
        self.maxReconnects = maxReconnects
        self.retainServersOrder = retainServersOrder
        self.auth = auth
        self.pingInterval = pingInterval
        self.requireTls = requireTls
        self.tlsFirst = tlsFirst
        self.clientCertificate = clientCertificate
        self.clientKey = clientKey
        self.rootCertificate = rootCertificate
        self.retryOnFailedConnect = retryOnFailedConnect
    }

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        var byteBuffer = self.unwrapInboundIn(data)
        inputBuffer.writeBuffer(&byteBuffer)
    }

    func channelReadComplete(context: ChannelHandlerContext) {
        guard inputBuffer.readableBytes > 0 else {
            return
        }

        var inputChunk = Data(buffer: inputBuffer)

        let remainder = parseRemainder.withLockedValue { value in
            let current = value
            value = nil
            return current
        }

        if let remainder = remainder, !remainder.isEmpty {
            inputChunk.prepend(remainder)
        }

        let parseResult: (ops: [ServerOp], remainder: Data?)
        do {
            parseResult = try inputChunk.parseOutMessages()
        } catch {
            // if parsing throws an error, clear buffer and remainder, then reconnect
            inputBuffer.clear()
            parseRemainder.withLockedValue { $0 = nil }
            context.fireErrorCaught(error)
            return
        }
        if let remainder = parseResult.remainder {
            parseRemainder.withLockedValue { $0 = remainder }
        }
        for op in parseResult.ops {
            // Only resume the server info continuation when we actually receive
            // an INFO or -ERR op. Do NOT clear it for unrelated ops.
            switch op {
            case .error(let err):
                if let continuation = serverInfoContinuation.withLockedValue({ cont in
                    let toResume = cont
                    cont = nil
                    return toResume
                }) {
                    logger.debug("server info error")
                    continuation.resume(throwing: err)
                    continue
                }
            case .info(let info):
                if let continuation = serverInfoContinuation.withLockedValue({ cont in
                    let toResume = cont
                    cont = nil
                    return toResume
                }) {
                    logger.debug("server info")
                    continuation.resume(returning: info)
                    continue
                }
            default:
                break
            }

            let connEstablishedCont = connectionEstablishedContinuation.withLockedValue { cont in
                let toResume = cont
                cont = nil
                return toResume
            }

            if let continuation = connEstablishedCont {
                logger.debug("conn established")
                switch op {
                case .error(let err):
                    continuation.resume(throwing: err)
                default:
                    continuation.resume()
                }
                continue
            }

            switch op {
            case .ping:
                logger.debug("ping")
                Task {
                    do {
                        try await self.write(operation: .pong)
                    } catch let err as NatsError.ClientError {
                        logger.error("error sending pong: \(err)")
                        self.fire(.error(err))
                    } catch {
                        logger.error("unexpected error sending pong: \(error)")
                    }
                }
            case .pong:
                logger.debug("pong")
                self.outstandingPings.store(0, ordering: AtomicStoreOrdering.relaxed)
                self.pingQueue.dequeue()?.setRoundTripTime()
            case .error(let err):
                logger.debug("error \(err)")

                switch err {
                case .staleConnection, .maxConnectionsExceeded:
                    // on these errors, force reconnect
                    inputBuffer.clear()
                    parseRemainder.withLockedValue { $0 = nil }
                    context.fireErrorCaught(err)
                case .permissionsViolation(let operation, let subject, _):
                    switch operation {
                    case .subscribe:
                        subscriptions.withLockedValue { subs in
                            for (_, s) in subs {
                                if s.subject == subject {
                                    s.receiveError(NatsError.SubscriptionError.permissionDenied)
                                }
                            }
                        }
                    case .publish:
                        self.fire(.error(err))
                    }
                default:
                    self.fire(.error(err))
                }
            case .message(let msg):
                self.handleIncomingMessage(msg)
            case .hMessage(let msg):
                self.handleIncomingMessage(msg)
            case .info(let serverInfo):
                logger.debug("info \(op)")
                self.serverInfo = serverInfo
                if serverInfo.lameDuckMode {
                    self.fire(.lameDuckMode)
                }
                updateServersList(info: serverInfo)
            default:
                logger.debug("unknown operation type: \(op)")
            }
        }
        inputBuffer.clear()
    }

    private func handleIncomingMessage(_ message: MessageInbound) {
        let natsMsg = NatsMessage(
            payload: message.payload, subject: message.subject, replySubject: message.reply,
            length: message.length, headers: nil, status: nil, description: nil)
        subscriptions.withLockedValue { subs in
            if let sub = subs[message.sid] {
                sub.receiveMessage(natsMsg)
            }
        }
    }

    private func handleIncomingMessage(_ message: HMessageInbound) {
        let natsMsg = NatsMessage(
            payload: message.payload, subject: message.subject, replySubject: message.reply,
            length: message.length, headers: message.headers, status: message.status,
            description: message.description)
        subscriptions.withLockedValue { subs in
            if let sub = subs[message.sid] {
                sub.receiveMessage(natsMsg)
            }
        }
    }

    func connect() async throws {
        self.setState(.connecting)
        var servers = self.urls
        if !self.retainServersOrder {
            servers = self.urls.shuffled()
        }
        var lastErr: Error?

        // if there are more reconnect attempts than the number of servers,
        // we are after the initial connect, so sleep between servers
        let shouldSleep = self.reconnectAttempts >= self.urls.count
        for s in servers {
            if let maxReconnects {
                if reconnectAttempts > 0 && reconnectAttempts >= maxReconnects {
                    throw NatsError.ClientError.maxReconnects
                }
            }
            self.reconnectAttempts += 1
            if shouldSleep {
                try await Task.sleep(nanoseconds: self.reconnectWait)
            }

            do {
                try await connectToServer(s: s)
            } catch let error as NatsError.ConnectError {
                if case .invalidConfig(_) = error {
                    throw error
                }
                logger.debug("error connecting to server: \(error)")
                lastErr = error
                continue
            } catch {
                logger.debug("error connecting to server: \(error)")
                lastErr = error
                continue
            }
            lastErr = nil
            break
        }
        if let lastErr {
            self.state.withLockedValue { $0 = .disconnected }
            switch lastErr {
            case let error as ChannelError:
                serverInfoContinuation.withLockedValue { $0 = nil }
                var err: NatsError.ConnectError
                switch error.self {
                case .connectTimeout(_):
                    err = .timeout
                default:
                    err = .io(error)
                }
                throw err
            case let error as NIOConnectionError:
                if let dnsAAAAError = error.dnsAAAAError {
                    throw NatsError.ConnectError.dns(dnsAAAAError)
                } else if let dnsAError = error.dnsAError {
                    throw NatsError.ConnectError.dns(dnsAError)
                } else {
                    throw NatsError.ConnectError.io(error)
                }
            case let err as NIOSSLError:
                throw NatsError.ConnectError.tlsFailure(err)
            case let err as BoringSSLError:
                throw NatsError.ConnectError.tlsFailure(err)
            case let err as NatsError.ServerError:
                throw err
            case let err as NatsError.ConnectError:
                throw err
            default:
                throw NatsError.ConnectError.io(lastErr)
            }
        }
        self.reconnectAttempts = 0
        guard let channel = self.channel else {
            throw NatsError.ClientError.internalError("empty channel")
        }
        // Schedule the task to send a PING periodically
        let pingInterval = TimeAmount.nanoseconds(Int64(self.pingInterval * 1_000_000_000))
        self.pingTask = channel.eventLoop.scheduleRepeatedTask(
            initialDelay: pingInterval, delay: pingInterval
        ) { _ in
            Task { await self.sendPing() }
        }
        logger.debug("connection established")
        return
    }

    private func connectToServer(s: URL) async throws {
        var infoTask: Task<(), Never>? = nil
        // this continuation can throw NatsError.ServerError if server responds with
        // -ERR to client connect (e.g. auth error)
        let info: ServerInfo = try await withCheckedThrowingContinuation { continuation in
            serverInfoContinuation.withLockedValue { $0 = continuation }
            infoTask = Task {
                await withTaskCancellationHandler {
                    do {
                        let (bootstrap, upgradePromise) = self.bootstrapConnection(to: s)

                        guard let host = s.host, let port = s.port else {
                            upgradePromise.succeed()  // avoid promise leaks
                            throw NatsError.ConnectError.invalidConfig("no url")
                        }

                        let connect = bootstrap.connect(host: host, port: port)
                        connect.cascadeFailure(to: upgradePromise)
                        self.channel = try await connect.get()

                        guard let channel = self.channel else {
                            upgradePromise.succeed()  // avoid promise leaks
                            throw NatsError.ClientError.internalError("empty channel")
                        }

                        try await upgradePromise.futureResult.get()
                        self.batchBuffer = BatchBuffer(channel: channel)
                    } catch {
                        let continuationToResume: CheckedContinuation<ServerInfo, Error>? = self
                            .serverInfoContinuation.withLockedValue { cont in
                                guard let c = cont else { return nil }
                                cont = nil
                                return c
                            }
                        if let continuation = continuationToResume {
                            continuation.resume(throwing: error)
                        }
                    }
                } onCancel: {
                    logger.debug("Connection task cancelled")
                    // Clean up resources
                    if let channel = self.channel {
                        channel.close(mode: .all, promise: nil)
                        self.channel = nil
                    }
                    self.batchBuffer = nil

                    let continuationToResume: CheckedContinuation<ServerInfo, Error>? = self
                        .serverInfoContinuation.withLockedValue { cont in
                            guard let c = cont else { return nil }
                            cont = nil
                            return c
                        }
                    if let continuation = continuationToResume {
                        continuation.resume(throwing: NatsError.ClientError.cancelled)
                    }
                }
            }
        }

        await infoTask?.value
        self.serverInfo = info
        if (info.tlsRequired ?? false || self.requireTls) && !self.tlsFirst && s.scheme != "wss" {
            let tlsConfig = try makeTLSConfig()
            let sslContext = try NIOSSLContext(configuration: tlsConfig)
            let sslHandler = try NIOSSLClientHandler(
                context: sslContext, serverHostname: s.host)
            try await self.channel?.pipeline.addHandler(sslHandler, position: .first)
        }

        try await sendClientConnectInit()
        self.connectedUrl = s
    }

    private func makeTLSConfig() throws -> TLSConfiguration {
        var tlsConfiguration = TLSConfiguration.makeClientConfiguration()
        if let rootCertificate = self.rootCertificate {
            tlsConfiguration.trustRoots = .file(rootCertificate.path)
        }
        if let clientCertificate = self.clientCertificate,
            let clientKey = self.clientKey
        {
            // Load the client certificate from the PEM file
            let certificate = try NIOSSLCertificate.fromPEMFile(
                clientCertificate.path
            ).map { NIOSSLCertificateSource.certificate($0) }
            tlsConfiguration.certificateChain = certificate

            // Load the private key from the file
            let privateKey = try NIOSSLPrivateKey(
                file: clientKey.path, format: .pem)
            tlsConfiguration.privateKey = .privateKey(privateKey)
        }
        return tlsConfiguration
    }

    private func sendClientConnectInit() async throws {
        var initialConnect = ConnectInfo(
            verbose: false, pedantic: false, userJwt: nil, nkey: "", name: "", echo: true,
            lang: self.lang, version: self.version, natsProtocol: .dynamic, tlsRequired: false,
            user: self.auth?.user ?? "", pass: self.auth?.password ?? "",
            authToken: self.auth?.token ?? "", headers: true, noResponders: true)

        if self.auth?.nkey != nil && self.auth?.nkeyPath != nil {
            throw NatsError.ConnectError.invalidConfig("cannot use both nkey and nkeyPath")
        }
        if let auth = self.auth, let credentialsPath = auth.credentialsPath {
            let credentials = try await URLSession.shared.data(from: credentialsPath).0
            guard let jwt = JwtUtils.parseDecoratedJWT(contents: credentials) else {
                throw NatsError.ConnectError.invalidConfig(
                    "failed to extract JWT from credentials file")
            }
            guard let nkey = JwtUtils.parseDecoratedNKey(contents: credentials) else {
                throw NatsError.ConnectError.invalidConfig(
                    "failed to extract NKEY from credentials file")
            }
            guard let nonce = self.serverInfo?.nonce else {
                throw NatsError.ConnectError.invalidConfig("missing nonce")
            }
            let keypair = try KeyPair(seed: String(data: nkey, encoding: .utf8)!)
            let nonceData = nonce.data(using: .utf8)!
            let sig = try keypair.sign(input: nonceData)
            let base64sig = sig.base64EncodedURLSafeNotPadded()
            initialConnect.signature = base64sig
            initialConnect.userJwt = String(data: jwt, encoding: .utf8)!
        }
        if let nkey = self.auth?.nkeyPath {
            let nkeyData = try await URLSession.shared.data(from: nkey).0

            guard let nkeyContent = String(data: nkeyData, encoding: .utf8) else {
                throw NatsError.ConnectError.invalidConfig("failed to read NKEY file")
            }
            let keypair = try KeyPair(
                seed: nkeyContent.trimmingCharacters(in: .whitespacesAndNewlines)
            )

            guard let nonce = self.serverInfo?.nonce else {
                throw NatsError.ConnectError.invalidConfig("missing nonce")
            }
            let sig = try keypair.sign(input: nonce.data(using: .utf8)!)
            let base64sig = sig.base64EncodedURLSafeNotPadded()
            initialConnect.signature = base64sig
            initialConnect.nkey = keypair.publicKeyEncoded
        }
        if let nkey = self.auth?.nkey {
            let keypair = try KeyPair(seed: nkey)
            guard let nonce = self.serverInfo?.nonce else {
                throw NatsError.ConnectError.invalidConfig("missing nonce")
            }
            let nonceData = nonce.data(using: .utf8)!
            let sig = try keypair.sign(input: nonceData)
            let base64sig = sig.base64EncodedURLSafeNotPadded()
            initialConnect.signature = base64sig
            initialConnect.nkey = keypair.publicKeyEncoded
        }
        let connect = initialConnect
        // this continuation can throw NatsError.ServerError if server responds with
        // -ERR to client connect (e.g. auth error)
        try await withTaskCancellationHandler {
            try await withCheckedThrowingContinuation { continuation in
                connectionEstablishedContinuation.withLockedValue { $0 = continuation }
                Task.detached {
                    do {
                        try await self.write(operation: ClientOp.connect(connect))
                        try await self.write(operation: ClientOp.ping)
                        self.channel?.flush()
                    } catch {
                        let continuationToResume: CheckedContinuation<Void, Error>? = self
                            .connectionEstablishedContinuation.withLockedValue { cont in
                                guard let c = cont else { return nil }
                                cont = nil
                                return c
                            }
                        if let continuation = continuationToResume {
                            continuation.resume(throwing: error)
                        }
                    }
                }
            }
        } onCancel: {
            logger.debug("Client connect initialization cancelled")
            // Clean up resources
            if let channel = self.channel {
                channel.close(mode: .all, promise: nil)
                self.channel = nil
            }
            self.batchBuffer = nil

            let continuationToResume: CheckedContinuation<Void, Error>? = self
                .connectionEstablishedContinuation.withLockedValue { cont in
                    guard let c = cont else { return nil }
                    cont = nil
                    return c
                }
            if let continuation = continuationToResume {
                continuation.resume(throwing: NatsError.ClientError.cancelled)
            }
        }
    }

    private func bootstrapConnection(
        to server: URL
    ) -> (ClientBootstrap, EventLoopPromise<Void>) {
        let upgradePromise: EventLoopPromise<Void> = self.group.any().makePromise(of: Void.self)
        let bootstrap = ClientBootstrap(group: self.group)
            .channelOption(
                ChannelOptions.socket(
                    SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR),
                value: 1
            )
            .channelInitializer { channel in
                if self.requireTls && self.tlsFirst {
                    upgradePromise.succeed(())
                    do {
                        let tlsConfig = try self.makeTLSConfig()
                        let sslContext = try NIOSSLContext(
                            configuration: tlsConfig)
                        let sslHandler = try NIOSSLClientHandler(
                            context: sslContext, serverHostname: server.host!)
                        //Fixme(jrm): do not ignore error from addHandler future.
                        channel.pipeline.addHandler(sslHandler).flatMap { _ in
                            channel.pipeline.addHandler(self)
                        }.whenComplete { result in
                            switch result {
                            case .success():
                                print("success")
                            case .failure(let error):
                                print("error: \(error)")
                            }
                        }
                        return channel.eventLoop.makeSucceededFuture(())
                    } catch {
                        let tlsError = NatsError.ConnectError.tlsFailure(error)
                        return channel.eventLoop.makeFailedFuture(tlsError)
                    }
                } else {
                    if server.scheme == "ws" || server.scheme == "wss" {
                        let httpUpgradeRequestHandler = HTTPUpgradeRequestHandler(
                            host: server.host ?? "localhost",
                            path: server.path,
                            query: server.query,
                            headers: HTTPHeaders(),  // TODO (mtmk): pass in from client options
                            upgradePromise: upgradePromise)
                        let httpUpgradeRequestHandlerBox = NIOLoopBound(
                            httpUpgradeRequestHandler, eventLoop: channel.eventLoop)

                        let websocketUpgrader = NIOWebSocketClientUpgrader(
                            maxFrameSize: 8 * 1024 * 1024,
                            automaticErrorHandling: true,
                            upgradePipelineHandler: { channel, _ in
                                let wsh = NIOWebSocketFrameAggregator(
                                    minNonFinalFragmentSize: 0,
                                    maxAccumulatedFrameCount: Int.max,
                                    maxAccumulatedFrameSize: Int.max
                                )
                                return channel.pipeline.addHandler(wsh).flatMap {
                                    channel.pipeline.addHandler(WebSocketByteBufferCodec()).flatMap {
                                        channel.pipeline.addHandler(self)
                                    }
                                }
                            }
                        )

                        let config: NIOHTTPClientUpgradeConfiguration = (
                            upgraders: [websocketUpgrader],
                            completionHandler: { context in
                                upgradePromise.succeed(())
                                channel.pipeline.removeHandler(
                                    httpUpgradeRequestHandlerBox.value, promise: nil)
                            }
                        )

                        if server.scheme == "wss" {
                            do {
                                let tlsConfig = try self.makeTLSConfig()
                                let sslContext = try NIOSSLContext(
                                    configuration: tlsConfig)
                                let sslHandler = try NIOSSLClientHandler(
                                    context: sslContext, serverHostname: server.host!)
                                // The sync methods here are safe because we're on the channel event loop
                                // due to the promise originating on the event loop of the channel.
                                try channel.pipeline.syncOperations.addHandler(sslHandler)
                            } catch {
                                let tlsError = NatsError.ConnectError.tlsFailure(error)
                                upgradePromise.fail(tlsError)
                                return channel.eventLoop.makeFailedFuture(tlsError)
                            }
                        }

                        //Fixme(jrm): do not ignore error from addHandler future.
                        channel.pipeline.addHTTPClientHandlers(
                            leftOverBytesStrategy: .forwardBytes,
                            withClientUpgrade: config
                        ).flatMap {
                            channel.pipeline.addHandler(httpUpgradeRequestHandlerBox.value)
                        }.whenComplete { result in
                            switch result {
                            case .success():
                                logger.debug("success")
                            case .failure(let error):
                                logger.debug("error: \(error)")
                            }
                        }
                    } else {
                        upgradePromise.succeed(())
                        //Fixme(jrm): do not ignore error from addHandler future.
                        channel.pipeline.addHandler(self).whenComplete { result in
                            switch result {
                            case .success():
                                logger.debug("success")
                            case .failure(let error):
                                logger.debug("error: \(error)")
                            }
                        }
                    }
                    return channel.eventLoop.makeSucceededFuture(())
                }
            }.connectTimeout(.seconds(5))
        return (bootstrap, upgradePromise)
    }

    private func updateServersList(info: ServerInfo) {
        if let connectUrls = info.connectUrls {
            for connectUrl in connectUrls {
                guard let url = URL(string: connectUrl) else {
                    continue
                }
                if !self.urls.contains(url) {
                    urls.append(url)
                }
            }
        }
    }

    func close() async throws {
        self.reconnectTask?.cancel()
        try await self.reconnectTask?.value

        guard let eventLoop = self.channel?.eventLoop else {
            self.state.withLockedValue { $0 = .closed }
            self.pingTask?.cancel()
            clearPendingPings()  // Clear pending pings to avoid promise leaks
            self.fire(.closed)
            return
        }
        let promise = eventLoop.makePromise(of: Void.self)

        eventLoop.execute {
            self.state.withLockedValue { $0 = .closed }
            self.pingTask?.cancel()
            self.clearPendingPings()  // Clear pending pings to avoid promise leaks
            self.channel?.close(mode: .all, promise: promise)
        }

        do {
            try await promise.futureResult.get()
        } catch ChannelError.alreadyClosed {
            // we don't want to throw an error if channel is already closed
            // as that would mean we would get an error closing client during reconnect
        }

        self.fire(.closed)
    }

    private func disconnect() async throws {
        self.pingTask?.cancel()
        clearPendingPings()  // Clear pending pings to avoid promise leaks
        try await self.channel?.close().get()
    }

    /// Clear all pending ping requests to avoid promise leaks
    private func clearPendingPings() {
        let pendingPings = pingQueue.dequeueAll()
        for ping in pendingPings {
            ping.cancel()
        }
        if !pendingPings.isEmpty {
            logger.debug("Cleared \(pendingPings.count) pending ping(s)")
        }
    }

    func suspend() async throws {
        self.reconnectTask?.cancel()
        _ = try await self.reconnectTask?.value

        // Handle case where channel is already nil (e.g., during rapid reconnections)
        guard let eventLoop = self.channel?.eventLoop else {
            // Set state to suspended even if channel is nil
            self.state.withLockedValue { $0 = .suspended }
            clearPendingPings()  // Clear pending pings to avoid promise leaks
            return
        }
        let promise = eventLoop.makePromise(of: Void.self)

        eventLoop.execute {  // This ensures the code block runs on the event loop
            let shouldClose = self.state.withLockedValue { currentState in
                let wasConnected = currentState == .connected
                currentState = .suspended
                return wasConnected
            }

            if shouldClose {
                self.pingTask?.cancel()
                self.clearPendingPings()  // Clear pending pings to avoid promise leaks
                self.channel?.close(mode: .all, promise: promise)
            } else {
                self.clearPendingPings()  // Clear pending pings even if not closing
                promise.succeed()
            }
        }

        try await promise.futureResult.get()
        self.fire(.suspended)
    }

    func resume() async throws {
        guard let eventLoop = self.channel?.eventLoop else {
            throw NatsError.ClientError.internalError("channel should not be nil")
        }
        try await eventLoop.submit {
            let canResume = self.state.withLockedValue { $0 == .suspended }
            guard canResume else {
                throw NatsError.ClientError.invalidConnection(
                    "unable to resume connection - connection is not in suspended state")
            }
            self.handleReconnect()
        }.get()
    }

    func reconnect() async throws {
        try await suspend()
        try await resume()
    }

    internal func sendPing(_ rttCommand: RttCommand? = nil) async {
        let pingsOut = self.outstandingPings.wrappingIncrementThenLoad(
            ordering: AtomicUpdateOrdering.relaxed)
        if pingsOut > 2 {
            handleDisconnect()
            return
        }
        let ping = ClientOp.ping
        do {
            self.pingQueue.enqueue(rttCommand ?? RttCommand.makeFrom(channel: self.channel))
            try await self.write(operation: ping)
            logger.debug("sent ping: \(pingsOut)")
        } catch {
            logger.error("Unable to send ping: \(error)")
        }
    }

    func channelActive(context: ChannelHandlerContext) {
        logger.debug("TCP channel active")

        parseRemainder.withLockedValue { $0 = nil }

        inputBuffer = context.channel.allocator.buffer(capacity: 1024 * 1024 * 8)
    }

    func channelInactive(context: ChannelHandlerContext) {
        logger.debug("TCP channel inactive")

        // If we lost the channel before we delivered server INFO or connection
        // establishment, make sure to fail any pending continuations to avoid leaks.
        // Use captured error if available (e.g., TLS failure), otherwise use connectionClosed.
        let errorToUse: Error = capturedConnectionError.withLockedValue({ err in
            let captured = err
            err = nil  // Clear after using
            if let capturedError = captured {
                return NatsError.ConnectError.tlsFailure(capturedError)
            } else {
                return NatsError.ClientError.connectionClosed
            }
        })

        if let continuation = serverInfoContinuation.withLockedValue({ cont in
            let toResume = cont
            cont = nil
            return toResume
        }) {
            continuation.resume(throwing: errorToUse)
        }

        if let continuation = connectionEstablishedContinuation.withLockedValue({ cont in
            let toResume = cont
            cont = nil
            return toResume
        }) {
            continuation.resume(throwing: errorToUse)
        }

        let shouldHandleDisconnect = state.withLockedValue { $0 == .connected }
        if shouldHandleDisconnect {
            handleDisconnect()
        }
    }

    func errorCaught(context: ChannelHandlerContext, error: Error) {
        logger.debug("Encountered error on the channel: \(error)")

        // Capture connection-stage errors (especially TLS) for proper error reporting BEFORE closing
        let isConnecting = state.withLockedValue { $0 == .pending || $0 == .connecting }
        if isConnecting {
            capturedConnectionError.withLockedValue { $0 = error }
        }

        context.close(promise: nil)

        if let natsErr = error as? NatsErrorProtocol {
            self.fire(.error(natsErr))
        } else {
            logger.error("unexpected error: \(error)")
        }
        let currentState = state.withLockedValue { $0 }
        if currentState == .pending || currentState == .connecting {
            handleDisconnect()
        } else if currentState == .disconnected {
            handleReconnect()
        }
    }

    func handleDisconnect() {
        state.withLockedValue { $0 = .disconnected }
        if let channel = self.channel {
            let promise = channel.eventLoop.makePromise(of: Void.self)
            Task {
                do {
                    try await self.disconnect()
                    promise.succeed()
                } catch ChannelError.alreadyClosed {
                    // if the channel was already closed, no need to return error
                    promise.succeed()
                } catch {
                    promise.fail(error)
                }
            }
            promise.futureResult.whenComplete { result in
                do {
                    try result.get()
                    self.fire(.disconnected)
                } catch {
                    logger.error("Error closing connection: \(error)")
                }
            }
        }

        handleReconnect()
    }

    func handleReconnect() {
        reconnectTask = Task {
            var connected = false
            while !Task.isCancelled
                && (maxReconnects == nil || self.reconnectAttempts < maxReconnects!)
            {
                do {
                    try await self.connect()
                    connected = true
                    break  // Successfully connected
                } catch is CancellationError {
                    logger.debug("Reconnect task cancelled")
                    return
                } catch {
                    logger.debug("Could not reconnect: \(error)")
                    if !Task.isCancelled {
                        try await Task.sleep(nanoseconds: self.reconnectWait)
                    }
                }
            }

            // Early return if cancelled
            if Task.isCancelled {
                logger.debug("Reconnect task cancelled after connection attempts")
                return
            }

            // If we got here without connecting and weren't cancelled, we hit max reconnects
            if !connected {
                logger.error("Could not reconnect; maxReconnects exceeded")
                try await self.close()
                return
            }

            // Recreate subscriptions - safely copy first
            let subsToRestore = subscriptions.withLockedValue { Array($0) }
            for (sid, sub) in subsToRestore {
                do {
                    try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
                } catch {
                    logger.error("Error recreating subscription \(sid): \(error)")
                }
            }

            self.channel?.eventLoop.execute {
                self.state.withLockedValue { $0 = .connected }
                self.fire(.connected)
            }
        }
    }

    func write(operation: ClientOp) async throws {
        guard let buffer = self.batchBuffer else {
            throw NatsError.ClientError.invalidConnection("not connected")
        }
        do {
            try await buffer.writeMessage(operation)
        } catch {
            throw NatsError.ClientError.io(error)
        }
    }

    internal func subscribe(
        _ subject: String, queue: String? = nil
    ) async throws -> NatsSubscription {
        let sid = self.subscriptionCounter.wrappingIncrementThenLoad(
            ordering: AtomicUpdateOrdering.relaxed)
        let sub = try NatsSubscription(sid: sid, subject: subject, queue: queue, conn: self)

        // Add subscription BEFORE sending command to avoid race condition
        subscriptions.withLockedValue { $0[sid] = sub }

        do {
            try await write(operation: ClientOp.subscribe((sid, subject, queue)))
        } catch {
            // Remove subscription if subscribe command fails
            subscriptions.withLockedValue { $0.removeValue(forKey: sid) }
            throw error
        }

        return sub
    }

    internal func unsubscribe(sub: NatsSubscription, max: UInt64?) async throws {
        if let max, sub.delivered < max {
            // if max is set and the sub has not yet reached it, send unsub with max set
            // and do not remove the sub from connection
            try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: max)))
            sub.max = max
        } else {
            // if max is not set or the subscription received at least as many
            // messages as max, send unsub command without max and remove sub from connection
            try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: nil)))
            self.removeSub(sub: sub)
        }
    }

    internal func removeSub(sub: NatsSubscription) {
        subscriptions.withLockedValue { $0.removeValue(forKey: sub.sid) }
        sub.complete()
    }
}

extension ConnectionHandler {

    internal func fire(_ event: NatsEvent) {
        let eventKind = event.kind()
        guard let handlerStore = self.eventHandlerStore[eventKind] else { return }

        for handler in handlerStore {
            handler.handler(event)
        }
    }

    internal func addListeners(
        for events: [NatsEventKind], using handler: @escaping (NatsEvent) -> Void
    ) -> String {
        let id = String.hash()

        for event in events {
            if self.eventHandlerStore[event] == nil {
                self.eventHandlerStore[event] = []
            }
            self.eventHandlerStore[event]?.append(
                NatsEventHandler(lid: id, handler: handler))
        }

        return id
    }

    internal func removeListener(_ id: String) {
        for event in NatsEventKind.all {
            let handlerStore = self.eventHandlerStore[event]
            if let store = handlerStore {
                self.eventHandlerStore[event] = store.filter { $0.listenerId != id }
            }
        }
    }
}

/// Nats events
public enum NatsEventKind: String {
    case connected = "connected"
    case disconnected = "disconnected"
    case closed = "closed"
    case suspended = "suspended"
    case lameDuckMode = "lameDuckMode"
    case error = "error"
    static let all = [connected, disconnected, closed, suspended, lameDuckMode, error]
}

public enum NatsEvent {
    case connected
    case disconnected
    case suspended
    case closed
    case lameDuckMode
    case error(NatsErrorProtocol)

    public func kind() -> NatsEventKind {
        switch self {
        case .connected:
            return .connected
        case .disconnected:
            return .disconnected
        case .suspended:
            return .suspended
        case .closed:
            return .closed
        case .lameDuckMode:
            return .lameDuckMode
        case .error(_):
            return .error
        }
    }
}

internal struct NatsEventHandler {
    let listenerId: String
    let handler: (NatsEvent) -> Void
    init(lid: String, handler: @escaping (NatsEvent) -> Void) {
        self.listenerId = lid
        self.handler = handler
    }
}