Files
nats.swift/Sources/Nats/NatsConnection.swift

1220 lines
48 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Atomics
import Dispatch
import Foundation
import NIO
import NIOConcurrencyHelpers
import NIOFoundationCompat
import NIOHTTP1
import NIOSSL
import NIOWebSocket
import NKeys
class ConnectionHandler: ChannelInboundHandler {
// Client identification reported to the server in the CONNECT payload.
let lang = "Swift"
let version = "0.0.1"
// URL of the server the handler last connected to successfully.
internal var connectedUrl: URL?
internal let allocator = ByteBufferAllocator()
// Accumulates inbound bytes in channelRead until channelReadComplete parses them.
internal var inputBuffer: ByteBuffer
// Currently active channel, if any.
internal var channel: Channel?
// Registered event listeners, grouped by the event kind they listen for.
private var eventHandlerStore: [NatsEventKind: [NatsEventHandler]] = [:]
// Connection options
internal var retryOnFailedConnect = false
// Known server URLs; extended with the connect URLs advertised in server INFO.
private var urls: [URL]
// nanoseconds representation of TimeInterval
private let reconnectWait: UInt64
// Upper bound on reconnect attempts; nil means no limit.
private let maxReconnects: Int?
// When false, the server list is shuffled before each connect pass.
private let retainServersOrder: Bool
private let pingInterval: TimeInterval
private let requireTls: Bool
// When set together with requireTls, the TLS handler is installed in the
// channel initializer rather than after the server INFO has been received.
private let tlsFirst: Bool
// Optional PEM file locations used to build the TLS configuration.
private var rootCertificate: URL?
private var clientCertificate: URL?
private var clientKey: URL?
typealias InboundIn = ByteBuffer
// Lock-protected connection state and subscription registry (keyed by sid).
private let state = NIOLockedValueBox(NatsState.pending)
private let subscriptions = NIOLockedValueBox([UInt64: NatsSubscription]())
// Helper methods for state access
// Snapshot of the connection state, read under the lock.
internal var currentState: NatsState {
state.withLockedValue { $0 }
}
// Replaces the connection state under the lock.
internal func setState(_ newState: NatsState) {
state.withLockedValue { $0 = newState }
}
// Source of subscription ids (incremented before use, so the first sid is 1).
private var subscriptionCounter = ManagedAtomic<UInt64>(0)
// Most recent INFO received from the server.
private var serverInfo: ServerInfo?
private var auth: Auth?
// Trailing bytes of an incomplete protocol frame, kept for the next read.
private let parseRemainder = NIOLockedValueBox<Data?>(nil)
// Repeating task that sends PINGs every pingInterval.
private var pingTask: RepeatedTask?
// Number of PINGs sent since the last PONG was received.
private var outstandingPings = ManagedAtomic<UInt8>(0)
private var reconnectAttempts = 0
private var reconnectTask: Task<(), Error>? = nil
// Error seen by errorCaught while still connecting; consumed by channelInactive.
private let capturedConnectionError = NIOLockedValueBox<Error?>(nil)
private var group: MultiThreadedEventLoopGroup
// Continuation awaiting the first INFO (or -ERR) from the server.
private let serverInfoContinuation = NIOLockedValueBox<CheckedContinuation<ServerInfo, Error>?>(
nil)
// Continuation awaiting the server's first response after CONNECT + PING.
private let connectionEstablishedContinuation = NIOLockedValueBox<
CheckedContinuation<Void, Error>?
>(nil)
// Pending round-trip-time measurements, resolved in order as PONGs arrive.
private let pingQueue = ConcurrentQueue<RttCommand>()
// Outbound write buffer bound to the current channel; nil while not connected.
private(set) var batchBuffer: BatchBuffer?
/// Creates a handler configured with the supplied connection options.
///
/// `reconnectWait` is converted to nanoseconds for storage. Note that the
/// `inputBuffer` argument is not stored: the handler always starts with a
/// fresh 1024-byte buffer from its own allocator.
init(
    inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?,
    retainServersOrder: Bool,
    pingInterval: TimeInterval, auth: Auth?, requireTls: Bool, tlsFirst: Bool,
    clientCertificate: URL?, clientKey: URL?,
    rootCertificate: URL?, retryOnFailedConnect: Bool
) {
    // Runtime plumbing
    self.group = .singleton
    self.inputBuffer = allocator.buffer(capacity: 1024)

    // Server selection and reconnect policy
    self.urls = urls
    self.retainServersOrder = retainServersOrder
    self.maxReconnects = maxReconnects
    self.reconnectWait = UInt64(reconnectWait * 1_000_000_000)
    self.retryOnFailedConnect = retryOnFailedConnect
    self.pingInterval = pingInterval

    // Authentication and TLS
    self.auth = auth
    self.requireTls = requireTls
    self.tlsFirst = tlsFirst
    self.rootCertificate = rootCertificate
    self.clientCertificate = clientCertificate
    self.clientKey = clientKey
}
/// Appends every inbound chunk to `inputBuffer`; no parsing happens here.
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
    var incoming = unwrapInboundIn(data)
    self.inputBuffer.writeBuffer(&incoming)
}
/// Parses the bytes accumulated in `inputBuffer` (prefixed with any partial
/// frame kept from the previous read) and dispatches each server operation.
///
/// Dispatch order for every op:
/// 1. INFO / -ERR resume the pending server-info continuation, if there is one.
/// 2. Otherwise any op resumes the pending connection-established continuation.
/// 3. Otherwise the op is handled as regular traffic (PING, PONG, -ERR, MSG,
///    HMSG, INFO).
///
/// A parse failure clears the buffer and the stored remainder and is forwarded
/// with `fireErrorCaught`.
func channelReadComplete(context: ChannelHandlerContext) {
    guard inputBuffer.readableBytes > 0 else {
        return
    }
    var inputChunk = Data(buffer: inputBuffer)
    // Take (and clear) the partial frame left over from the previous read.
    let remainder = parseRemainder.withLockedValue { value in
        let current = value
        value = nil
        return current
    }
    if let remainder = remainder, !remainder.isEmpty {
        inputChunk.prepend(remainder)
    }
    let parseResult: (ops: [ServerOp], remainder: Data?)
    do {
        parseResult = try inputChunk.parseOutMessages()
    } catch {
        // if parsing throws an error, clear buffer and remainder, then reconnect
        inputBuffer.clear()
        parseRemainder.withLockedValue { $0 = nil }
        context.fireErrorCaught(error)
        return
    }
    if let remainder = parseResult.remainder {
        parseRemainder.withLockedValue { $0 = remainder }
    }
    for op in parseResult.ops {
        // Only resume the server info continuation when we actually receive
        // an INFO or -ERR op. Do NOT clear it for unrelated ops.
        switch op {
        case .error(let err):
            if let continuation = serverInfoContinuation.withLockedValue({ cont in
                let toResume = cont
                cont = nil
                return toResume
            }) {
                logger.debug("server info error")
                continuation.resume(throwing: err)
                continue
            }
        case .info(let info):
            if let continuation = serverInfoContinuation.withLockedValue({ cont in
                let toResume = cont
                cont = nil
                return toResume
            }) {
                logger.debug("server info")
                continuation.resume(returning: info)
                continue
            }
        default:
            break
        }
        // The first op after CONNECT + PING settles the connection handshake.
        let connEstablishedCont = connectionEstablishedContinuation.withLockedValue { cont in
            let toResume = cont
            cont = nil
            return toResume
        }
        if let continuation = connEstablishedCont {
            logger.debug("conn established")
            switch op {
            case .error(let err):
                continuation.resume(throwing: err)
            default:
                continuation.resume()
            }
            continue
        }
        switch op {
        case .ping:
            logger.debug("ping")
            Task {
                do {
                    try await self.write(operation: .pong)
                } catch let err as NatsError.ClientError {
                    logger.error("error sending pong: \(err)")
                    self.fire(
                        .error(err))
                } catch {
                    logger.error("unexpected error sending pong: \(error)")
                }
            }
        case .pong:
            logger.debug("pong")
            self.outstandingPings.store(0, ordering: AtomicStoreOrdering.relaxed)
            self.pingQueue.dequeue()?.setRoundTripTime()
        case .error(let err):
            logger.debug("error \(err)")
            // Each server error is handled exactly once. Previously a second,
            // string-based check on `normalizedError` ran after this switch and
            // repeated the handling (double events / double fireErrorCaught).
            switch err {
            case .staleConnection, .maxConnectionsExceeded:
                // Connection-level errors: drop buffered input and forward the
                // error through the pipeline.
                inputBuffer.clear()
                parseRemainder.withLockedValue { $0 = nil }
                context.fireErrorCaught(err)
            case .permissionsViolation(let operation, let subject, _):
                switch operation {
                case .subscribe:
                    // Deliver the violation to every subscription on that subject.
                    subscriptions.withLockedValue { subs in
                        for (_, s) in subs {
                            if s.subject == subject {
                                s.receiveError(NatsError.SubscriptionError.permissionDenied)
                            }
                        }
                    }
                case .publish:
                    self.fire(.error(err))
                }
            default:
                self.fire(.error(err))
            }
        case .message(let msg):
            self.handleIncomingMessage(msg)
        case .hMessage(let msg):
            self.handleIncomingMessage(msg)
        case .info(let serverInfo):
            logger.debug("info \(op)")
            self.serverInfo = serverInfo
            if serverInfo.lameDuckMode {
                self.fire(.lameDuckMode)
            }
            updateServersList(info: serverInfo)
        default:
            logger.debug("unknown operation type: \(op)")
        }
    }
    inputBuffer.clear()
}
/// Wraps a header-less inbound message and hands it to the subscription
/// registered under its sid. Messages for unknown sids are dropped.
private func handleIncomingMessage(_ message: MessageInbound) {
    let delivered = NatsMessage(
        payload: message.payload, subject: message.subject, replySubject: message.reply,
        length: message.length, headers: nil, status: nil, description: nil)
    subscriptions.withLockedValue { registry in
        guard let target = registry[message.sid] else { return }
        target.receiveMessage(delivered)
    }
}
/// Wraps an inbound message that carries headers (plus status and description)
/// and hands it to the subscription registered under its sid. Messages for
/// unknown sids are dropped.
private func handleIncomingMessage(_ message: HMessageInbound) {
    let delivered = NatsMessage(
        payload: message.payload, subject: message.subject, replySubject: message.reply,
        length: message.length, headers: message.headers, status: message.status,
        description: message.description)
    subscriptions.withLockedValue { registry in
        guard let target = registry[message.sid] else { return }
        target.receiveMessage(delivered)
    }
}
/// Tries each known server in turn until one connection succeeds, then starts
/// the periodic PING task.
///
/// - Throws: `NatsError.ClientError.maxReconnects` once the attempt budget is
///   used up; an `invalidConfig` connect error immediately; otherwise the last
///   connection failure mapped onto `NatsError.ConnectError` (server and
///   connect errors are rethrown unchanged).
func connect() async throws {
    setState(.connecting)
    let servers = retainServersOrder ? urls : urls.shuffled()
    var lastErr: Error?
    // if there are more reconnect attempts than the number of servers,
    // we are after the initial connect, so sleep between servers
    let shouldSleep = reconnectAttempts >= urls.count
    for s in servers {
        if let limit = maxReconnects, reconnectAttempts > 0, reconnectAttempts >= limit {
            throw NatsError.ClientError.maxReconnects
        }
        reconnectAttempts += 1
        if shouldSleep {
            try await Task.sleep(nanoseconds: reconnectWait)
        }
        do {
            try await connectToServer(s: s)
        } catch {
            // A configuration problem will not go away by trying another server.
            if let configError = error as? NatsError.ConnectError,
                case .invalidConfig(_) = configError
            {
                throw configError
            }
            logger.debug("error connecting to server: \(error)")
            lastErr = error
            continue
        }
        lastErr = nil
        break
    }
    if let failure = lastErr {
        setState(.disconnected)
        if let channelError = failure as? ChannelError {
            serverInfoContinuation.withLockedValue { $0 = nil }
            if case .connectTimeout(_) = channelError {
                throw NatsError.ConnectError.timeout
            }
            throw NatsError.ConnectError.io(channelError)
        }
        if let connectionError = failure as? NIOConnectionError {
            // Prefer reporting a DNS failure (AAAA first, then A) over a generic I/O error.
            if let dnsFailure = connectionError.dnsAAAAError ?? connectionError.dnsAError {
                throw NatsError.ConnectError.dns(dnsFailure)
            }
            throw NatsError.ConnectError.io(connectionError)
        }
        if let sslError = failure as? NIOSSLError {
            throw NatsError.ConnectError.tlsFailure(sslError)
        }
        if let boringError = failure as? BoringSSLError {
            throw NatsError.ConnectError.tlsFailure(boringError)
        }
        if let serverError = failure as? NatsError.ServerError {
            throw serverError
        }
        if let connectError = failure as? NatsError.ConnectError {
            throw connectError
        }
        throw NatsError.ConnectError.io(failure)
    }
    reconnectAttempts = 0
    guard let liveChannel = channel else {
        throw NatsError.ClientError.internalError("empty channel")
    }
    // Start from a clean slate, then send a PING every `pingInterval`.
    outstandingPings.store(0, ordering: .relaxed)
    let interval = TimeAmount.nanoseconds(Int64(pingInterval * 1_000_000_000))
    pingTask = liveChannel.eventLoop.scheduleRepeatedTask(
        initialDelay: interval, delay: interval
    ) { _ in
        Task { await self.sendPing() }
    }
    logger.debug("connection established")
}
/// Opens a channel to `s`, waits for the server's INFO, upgrades to TLS when
/// required (and not already done), and then performs the CONNECT handshake.
///
/// The INFO is delivered through `serverInfoContinuation`, which is resumed
/// either by the read path (INFO / -ERR) or, on failure or cancellation, by the
/// connection task started here.
///
/// - Parameter s: Server URL; both host and port must be present.
/// - Throws: Whatever the continuation is resumed with, plus errors from TLS
///   setup and from `sendClientConnectInit()`.
private func connectToServer(s: URL) async throws {
var infoTask: Task<(), Never>? = nil
// this continuation can throw NatsError.ServerError if server responds with
// -ERR to client connect (e.g. auth error)
let info: ServerInfo = try await withCheckedThrowingContinuation { continuation in
// Publish the continuation before connecting so the read path can find it.
serverInfoContinuation.withLockedValue { $0 = continuation }
infoTask = Task {
await withTaskCancellationHandler {
do {
let (bootstrap, upgradePromise) = self.bootstrapConnection(to: s)
guard let host = s.host, let port = s.port else {
upgradePromise.succeed() // avoid promise leaks
throw NatsError.ConnectError.invalidConfig("no url")
}
let connect = bootstrap.connect(host: host, port: port)
// A failed connect also fails the upgrade promise.
connect.cascadeFailure(to: upgradePromise)
self.channel = try await connect.get()
guard let channel = self.channel else {
upgradePromise.succeed() // avoid promise leaks
throw NatsError.ClientError.internalError("empty channel")
}
// Wait for the upgrade promise created by bootstrapConnection(to:).
try await upgradePromise.futureResult.get()
self.batchBuffer = BatchBuffer(channel: channel)
} catch {
// Take the continuation under the lock so it is resumed at most once.
let continuationToResume: CheckedContinuation<ServerInfo, Error>? = self
.serverInfoContinuation.withLockedValue { cont in
guard let c = cont else { return nil }
cont = nil
return c
}
if let continuation = continuationToResume {
continuation.resume(throwing: error)
}
}
} onCancel: {
logger.debug("Connection task cancelled")
// Clean up resources safely to avoid race conditions
// Capture references first to avoid concurrent access issues
let channelToClose = self.channel
let bufferToRelease = self.batchBuffer
// Clear references first to prevent other threads from using them
self.channel = nil
self.batchBuffer = nil
// Close channel asynchronously after clearing references
// This ensures BatchBuffer's deinit won't conflict with channel close
if let channel = channelToClose {
channel.eventLoop.execute {
channel.close(mode: .all, promise: nil)
}
}
// bufferToRelease will be released here after channel close is scheduled
_ = bufferToRelease
// Fail the pending INFO wait, unless something else already resumed it.
let continuationToResume: CheckedContinuation<ServerInfo, Error>? = self
.serverInfoContinuation.withLockedValue { cont in
guard let c = cont else { return nil }
cont = nil
return c
}
if let continuation = continuationToResume {
continuation.resume(throwing: NatsError.ClientError.cancelled)
}
}
}
}
// Let the connection task finish before touching the channel.
await infoTask?.value
self.serverInfo = info
// Upgrade to TLS after INFO, unless TLS was already set up first or the
// URL scheme is wss.
if (info.tlsRequired ?? false || self.requireTls) && !self.tlsFirst && s.scheme != "wss" {
let tlsConfig = try makeTLSConfig()
let sslContext = try NIOSSLContext(configuration: tlsConfig)
let sslHandler = try NIOSSLClientHandler(
context: sslContext, serverHostname: s.host)
try await self.channel?.pipeline.addHandler(sslHandler, position: .first)
}
try await sendClientConnectInit()
self.connectedUrl = s
}
/// Builds the client TLS configuration from the configured certificate files.
///
/// The root certificate, when set, replaces the default trust roots. A client
/// certificate chain and private key are only installed when both the
/// certificate and the key file are configured.
///
/// - Throws: Errors from loading the PEM certificate or the private key.
private func makeTLSConfig() throws -> TLSConfiguration {
    var config = TLSConfiguration.makeClientConfiguration()
    if let roots = rootCertificate {
        config.trustRoots = .file(roots.path)
    }
    guard let certificateURL = clientCertificate, let keyURL = clientKey else {
        return config
    }
    // Load the client certificate(s) from the PEM file
    let pemCertificates = try NIOSSLCertificate.fromPEMFile(certificateURL.path)
    config.certificateChain = pemCertificates.map { NIOSSLCertificateSource.certificate($0) }
    // Load the private key from the file
    let key = try NIOSSLPrivateKey(file: keyURL.path, format: .pem)
    config.privateKey = .privateKey(key)
    return config
}
/// Builds the CONNECT payload (including credentials / NKEY signatures when
/// configured), sends CONNECT followed by PING, and waits for the server's
/// first response.
///
/// - Throws: `NatsError.ConnectError.invalidConfig` for unusable auth settings
///   (both `nkey` and `nkeyPath` set, unreadable or non-UTF-8 credentials,
///   missing nonce); errors from key handling and signing; the error the
///   connection-established continuation is resumed with (e.g. a server -ERR);
///   `NatsError.ClientError.cancelled` when the task is cancelled.
private func sendClientConnectInit() async throws {
    var initialConnect = ConnectInfo(
        verbose: false, pedantic: false, userJwt: nil, nkey: "", name: "", echo: true,
        lang: self.lang, version: self.version, natsProtocol: .dynamic, tlsRequired: false,
        user: self.auth?.user ?? "", pass: self.auth?.password ?? "",
        authToken: self.auth?.token ?? "", headers: true, noResponders: true)
    if self.auth?.nkey != nil && self.auth?.nkeyPath != nil {
        throw NatsError.ConnectError.invalidConfig("cannot use both nkey and nkeyPath")
    }
    // Credentials file: JWT plus an NKEY seed used to sign the server nonce.
    if let auth = self.auth, let credentialsPath = auth.credentialsPath {
        let credentials = try await URLSession.shared.data(from: credentialsPath).0
        guard let jwt = JwtUtils.parseDecoratedJWT(contents: credentials) else {
            throw NatsError.ConnectError.invalidConfig(
                "failed to extract JWT from credentials file")
        }
        guard let nkey = JwtUtils.parseDecoratedNKey(contents: credentials) else {
            throw NatsError.ConnectError.invalidConfig(
                "failed to extract NKEY from credentials file")
        }
        let nonceData = try nonceDataForSigning()
        // Non-UTF-8 file contents are a configuration error, not a crash.
        guard let seed = String(data: nkey, encoding: .utf8) else {
            throw NatsError.ConnectError.invalidConfig(
                "failed to extract NKEY from credentials file")
        }
        let keypair = try KeyPair(seed: seed)
        let sig = try keypair.sign(input: nonceData)
        guard let userJwt = String(data: jwt, encoding: .utf8) else {
            throw NatsError.ConnectError.invalidConfig(
                "failed to extract JWT from credentials file")
        }
        initialConnect.signature = sig.base64EncodedURLSafeNotPadded()
        initialConnect.userJwt = userJwt
    }
    // NKEY seed stored in a file.
    if let nkey = self.auth?.nkeyPath {
        let nkeyData = try await URLSession.shared.data(from: nkey).0
        guard let nkeyContent = String(data: nkeyData, encoding: .utf8) else {
            throw NatsError.ConnectError.invalidConfig("failed to read NKEY file")
        }
        let keypair = try KeyPair(
            seed: nkeyContent.trimmingCharacters(in: .whitespacesAndNewlines)
        )
        let sig = try keypair.sign(input: try nonceDataForSigning())
        initialConnect.signature = sig.base64EncodedURLSafeNotPadded()
        initialConnect.nkey = keypair.publicKeyEncoded
    }
    // NKEY seed supplied inline.
    if let nkey = self.auth?.nkey {
        let keypair = try KeyPair(seed: nkey)
        let sig = try keypair.sign(input: try nonceDataForSigning())
        initialConnect.signature = sig.base64EncodedURLSafeNotPadded()
        initialConnect.nkey = keypair.publicKeyEncoded
    }
    let connect = initialConnect
    // this continuation can throw NatsError.ServerError if server responds with
    // -ERR to client connect (e.g. auth error)
    try await withTaskCancellationHandler {
        try await withCheckedThrowingContinuation { continuation in
            // Publish the continuation before writing so the read path can find it.
            connectionEstablishedContinuation.withLockedValue { $0 = continuation }
            Task.detached {
                do {
                    try await self.write(operation: ClientOp.connect(connect))
                    try await self.write(operation: ClientOp.ping)
                    self.channel?.flush()
                } catch {
                    // Take the continuation under the lock so it is resumed at most once.
                    let continuationToResume: CheckedContinuation<Void, Error>? = self
                        .connectionEstablishedContinuation.withLockedValue { cont in
                            guard let c = cont else { return nil }
                            cont = nil
                            return c
                        }
                    if let continuation = continuationToResume {
                        continuation.resume(throwing: error)
                    }
                }
            }
        }
    } onCancel: {
        logger.debug("Client connect initialization cancelled")
        // Clean up resources safely to avoid race conditions
        // Capture references first to avoid concurrent access issues
        let channelToClose = self.channel
        let bufferToRelease = self.batchBuffer
        // Clear references first to prevent other threads from using them
        self.channel = nil
        self.batchBuffer = nil
        // Close channel asynchronously after clearing references
        // This ensures BatchBuffer's deinit won't conflict with channel close
        if let channel = channelToClose {
            channel.eventLoop.execute {
                channel.close(mode: .all, promise: nil)
            }
        }
        // bufferToRelease will be released here after channel close is scheduled
        _ = bufferToRelease
        let continuationToResume: CheckedContinuation<Void, Error>? = self
            .connectionEstablishedContinuation.withLockedValue { cont in
                guard let c = cont else { return nil }
                cont = nil
                return c
            }
        if let continuation = continuationToResume {
            continuation.resume(throwing: NatsError.ClientError.cancelled)
        }
    }
}

/// Returns the nonce from the last server INFO as UTF-8 bytes, ready to be signed.
///
/// - Throws: `NatsError.ConnectError.invalidConfig("missing nonce")` when no
///   server INFO is stored or it carries no nonce.
private func nonceDataForSigning() throws -> Data {
    guard let nonce = self.serverInfo?.nonce else {
        throw NatsError.ConnectError.invalidConfig("missing nonce")
    }
    return Data(nonce.utf8)
}
/// Creates the client bootstrap for `server` together with a promise that is
/// completed once the channel is ready for NATS traffic.
///
/// Three pipeline layouts are set up by the channel initializer:
/// - TLS first (`requireTls && tlsFirst`): TLS handler, then this handler;
///   the promise is succeeded immediately.
/// - `ws` / `wss`: HTTP client handlers with a WebSocket upgrader (preceded by
///   a TLS handler for `wss`); the promise is succeeded by the upgrade
///   completion handler.
/// - Plain TCP: this handler only; the promise is succeeded immediately.
///
/// - Parameter server: URL of the server to connect to.
/// - Returns: The configured bootstrap (5 second connect timeout) and the
///   upgrade promise.
private func bootstrapConnection(
    to server: URL
) -> (ClientBootstrap, EventLoopPromise<Void>) {
    let upgradePromise: EventLoopPromise<Void> = self.group.any().makePromise(of: Void.self)
    let bootstrap = ClientBootstrap(group: self.group)
        .channelOption(
            ChannelOptions.socket(
                SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR),
            value: 1
        )
        .channelInitializer { channel in
            if self.requireTls && self.tlsFirst {
                upgradePromise.succeed(())
                do {
                    let tlsConfig = try self.makeTLSConfig()
                    let sslContext = try NIOSSLContext(
                        configuration: tlsConfig)
                    // serverHostname is optional, so no force-unwrap is needed.
                    let sslHandler = try NIOSSLClientHandler(
                        context: sslContext, serverHostname: server.host)
                    //Fixme(jrm): do not ignore error from addHandler future.
                    channel.pipeline.addHandler(sslHandler).flatMap { _ in
                        channel.pipeline.addHandler(self)
                    }.whenComplete { result in
                        // Report through the logger, like the other branches.
                        switch result {
                        case .success():
                            logger.debug("success")
                        case .failure(let error):
                            logger.debug("error: \(error)")
                        }
                    }
                    return channel.eventLoop.makeSucceededFuture(())
                } catch {
                    let tlsError = NatsError.ConnectError.tlsFailure(error)
                    return channel.eventLoop.makeFailedFuture(tlsError)
                }
            } else {
                if server.scheme == "ws" || server.scheme == "wss" {
                    let httpUpgradeRequestHandler = HTTPUpgradeRequestHandler(
                        host: server.host ?? "localhost",
                        path: server.path,
                        query: server.query,
                        headers: HTTPHeaders(),  // TODO (mtmk): pass in from client options
                        upgradePromise: upgradePromise)
                    let httpUpgradeRequestHandlerBox = NIOLoopBound(
                        httpUpgradeRequestHandler, eventLoop: channel.eventLoop)
                    let websocketUpgrader = NIOWebSocketClientUpgrader(
                        maxFrameSize: 8 * 1024 * 1024,
                        automaticErrorHandling: true,
                        upgradePipelineHandler: { channel, _ in
                            // After the upgrade: frame aggregator, byte-buffer
                            // codec, then this handler.
                            let wsh = NIOWebSocketFrameAggregator(
                                minNonFinalFragmentSize: 0,
                                maxAccumulatedFrameCount: Int.max,
                                maxAccumulatedFrameSize: Int.max
                            )
                            return channel.pipeline.addHandler(wsh).flatMap {
                                channel.pipeline.addHandler(WebSocketByteBufferCodec()).flatMap
                                {
                                    channel.pipeline.addHandler(self)
                                }
                            }
                        }
                    )
                    let config: NIOHTTPClientUpgradeConfiguration = (
                        upgraders: [websocketUpgrader],
                        completionHandler: { context in
                            upgradePromise.succeed(())
                            channel.pipeline.removeHandler(
                                httpUpgradeRequestHandlerBox.value, promise: nil)
                        }
                    )
                    if server.scheme == "wss" {
                        do {
                            let tlsConfig = try self.makeTLSConfig()
                            let sslContext = try NIOSSLContext(
                                configuration: tlsConfig)
                            let sslHandler = try NIOSSLClientHandler(
                                context: sslContext, serverHostname: server.host)
                            // The sync methods here are safe because we're on the channel event loop
                            // due to the promise originating on the event loop of the channel.
                            try channel.pipeline.syncOperations.addHandler(sslHandler)
                        } catch {
                            let tlsError = NatsError.ConnectError.tlsFailure(error)
                            upgradePromise.fail(tlsError)
                            return channel.eventLoop.makeFailedFuture(tlsError)
                        }
                    }
                    //Fixme(jrm): do not ignore error from addHandler future.
                    channel.pipeline.addHTTPClientHandlers(
                        leftOverBytesStrategy: .forwardBytes,
                        withClientUpgrade: config
                    ).flatMap {
                        channel.pipeline.addHandler(httpUpgradeRequestHandlerBox.value)
                    }.whenComplete { result in
                        switch result {
                        case .success():
                            logger.debug("success")
                        case .failure(let error):
                            logger.debug("error: \(error)")
                        }
                    }
                } else {
                    upgradePromise.succeed(())
                    //Fixme(jrm): do not ignore error from addHandler future.
                    channel.pipeline.addHandler(self).whenComplete { result in
                        switch result {
                        case .success():
                            logger.debug("success")
                        case .failure(let error):
                            logger.debug("error: \(error)")
                        }
                    }
                }
                return channel.eventLoop.makeSucceededFuture(())
            }
        }.connectTimeout(.seconds(5))
    return (bootstrap, upgradePromise)
}
/// Adds the connect URLs advertised in `info` to the known server list.
/// Entries that do not parse as URLs or are already known are skipped.
private func updateServersList(info: ServerInfo) {
    guard let advertised = info.connectUrls else { return }
    let candidates = advertised.compactMap { URL(string: $0) }
    for candidate in candidates where !urls.contains(candidate) {
        urls.append(candidate)
    }
}
/// Closes the connection for good: cancels any reconnect in progress, moves
/// the state to `.closed`, stops pinging, closes the channel (if there is one)
/// and fires the `.closed` event.
///
/// - Throws: Errors from awaiting the reconnect task or from closing the
///   channel, except `ChannelError.alreadyClosed`, which is ignored.
func close() async throws {
// Stop a running reconnect first and wait for it to wind down.
self.reconnectTask?.cancel()
try await self.reconnectTask?.value
// Without a channel there is nothing to close; just update state and notify.
guard let eventLoop = self.channel?.eventLoop else {
self.state.withLockedValue { $0 = .closed }
self.pingTask?.cancel()
clearPendingPings()  // Clear pending pings to avoid promise leaks
self.fire(.closed)
return
}
let promise = eventLoop.makePromise(of: Void.self)
// Perform the state change and the close on the channel's event loop.
eventLoop.execute {
self.state.withLockedValue { $0 = .closed }
self.pingTask?.cancel()
self.clearPendingPings()  // Clear pending pings to avoid promise leaks
self.channel?.close(mode: .all, promise: promise)
}
do {
try await promise.futureResult.get()
} catch ChannelError.alreadyClosed {
// we don't want to throw an error if channel is already closed
// as that would mean we would get an error closing client during reconnect
}
self.fire(.closed)
}
/// Tears down the current channel: stops pinging, cancels pending pings,
/// drops the batch buffer and closes the channel if there is one.
///
/// - Throws: The error reported by the channel close.
private func disconnect() async throws {
    pingTask?.cancel()
    clearPendingPings()  // Clear pending pings to avoid promise leaks
    // Detach the batch buffer from the handler before the channel is closed.
    let retiredBuffer = batchBuffer
    batchBuffer = nil
    _ = retiredBuffer  // Release after clearing reference
    if let activeChannel = channel {
        try await activeChannel.close().get()
    }
}
/// Drains the ping queue and cancels every pending ping so that no promise
/// is left unresolved.
private func clearPendingPings() {
    let drained = pingQueue.dequeueAll()
    for pending in drained {
        pending.cancel()
    }
    guard !drained.isEmpty else { return }
    logger.debug("Cleared \(drained.count) pending ping(s)")
}
/// Suspends the connection: cancels any reconnect in progress, moves the
/// state to `.suspended` and closes the channel if the connection was
/// `.connected`. Fires `.suspended` when a channel was present.
///
/// - Throws: Errors from awaiting the reconnect task or from closing the channel.
func suspend() async throws {
self.reconnectTask?.cancel()
_ = try await self.reconnectTask?.value
// Handle case where channel is already nil (e.g., during rapid reconnections)
guard let eventLoop = self.channel?.eventLoop else {
// Set state to suspended even if channel is nil
// (no `.suspended` event is fired on this path)
self.state.withLockedValue { $0 = .suspended }
clearPendingPings()  // Clear pending pings to avoid promise leaks
return
}
let promise = eventLoop.makePromise(of: Void.self)
eventLoop.execute {  // This ensures the code block runs on the event loop
// Switch to .suspended and remember, in the same critical section,
// whether there was a live connection to close.
let shouldClose = self.state.withLockedValue { currentState in
let wasConnected = currentState == .connected
currentState = .suspended
return wasConnected
}
if shouldClose {
self.pingTask?.cancel()
self.clearPendingPings()  // Clear pending pings to avoid promise leaks
self.channel?.close(mode: .all, promise: promise)
} else {
self.clearPendingPings()  // Clear pending pings even if not closing
promise.succeed()
}
}
try await promise.futureResult.get()
self.fire(.suspended)
}
/// Resumes a suspended connection by starting the reconnect procedure on the
/// channel's event loop.
///
/// - Throws: `NatsError.ClientError.internalError` when there is no channel;
///   `NatsError.ClientError.invalidConnection` when the connection is not in
///   the suspended state.
func resume() async throws {
    guard let eventLoop = channel?.eventLoop else {
        throw NatsError.ClientError.internalError("channel should not be nil")
    }
    let resumption: EventLoopFuture<Void> = eventLoop.submit {
        let isSuspended = self.state.withLockedValue { $0 == .suspended }
        if !isSuspended {
            throw NatsError.ClientError.invalidConnection(
                "unable to resume connection - connection is not in suspended state")
        }
        self.handleReconnect()
    }
    try await resumption.get()
}
/// Forces a reconnect by suspending the connection and resuming it right away.
///
/// - Throws: Any error thrown by `suspend()` or `resume()`.
func reconnect() async throws {
    try await self.suspend()
    try await self.resume()
}
/// Sends a PING to the server and queues a round-trip-time command for it.
///
/// When more than two PINGs are outstanding the connection is treated as
/// broken and disconnect handling starts instead. A failed write also starts
/// disconnect handling if the state is still `.connected`.
///
/// - Parameter rttCommand: Command to resolve when the matching PONG arrives;
///   a new one is created when omitted.
internal func sendPing(_ rttCommand: RttCommand? = nil) async {
    let outstanding = outstandingPings.wrappingIncrementThenLoad(ordering: .relaxed)
    guard outstanding <= 2 else {
        handleDisconnect()
        return
    }
    do {
        let command = rttCommand ?? RttCommand.makeFrom(channel: channel)
        pingQueue.enqueue(command)
        try await write(operation: ClientOp.ping)
        logger.debug("sent ping: \(outstanding)")
    } catch {
        logger.error("Unable to send ping: \(error)")
        // Trigger reconnect on ping failure - connection may be broken
        if currentState == .connected {
            handleDisconnect()
        }
    }
}
/// Resets the parsing state for a freshly activated channel: discards any
/// stored remainder and allocates a new 8 MiB input buffer.
func channelActive(context: ChannelHandlerContext) {
    logger.debug("TCP channel active")
    let eightMiB = 1024 * 1024 * 8
    parseRemainder.withLockedValue { $0 = nil }
    inputBuffer = context.channel.allocator.buffer(capacity: eightMiB)
}
/// Handles loss of the channel.
///
/// Any continuation still waiting for server INFO or for connection
/// establishment is failed so it does not leak. The failure is the error
/// captured by `errorCaught` during connecting (wrapped as a TLS failure) when
/// there is one, otherwise `connectionClosed`. Disconnect handling is started
/// unless the state is already `.closed` or `.disconnected`.
func channelInactive(context: ChannelHandlerContext) {
    logger.debug("TCP channel inactive")
    // Consume the captured connection-stage error, if any.
    let captured = capturedConnectionError.withLockedValue { slot -> Error? in
        defer { slot = nil }
        return slot
    }
    let failure: Error
    if let captured {
        failure = NatsError.ConnectError.tlsFailure(captured)
    } else {
        failure = NatsError.ClientError.connectionClosed
    }

    let pendingInfo = serverInfoContinuation.withLockedValue {
        slot -> CheckedContinuation<ServerInfo, Error>? in
        defer { slot = nil }
        return slot
    }
    pendingInfo?.resume(throwing: failure)

    let pendingEstablished = connectionEstablishedContinuation.withLockedValue {
        slot -> CheckedContinuation<Void, Error>? in
        defer { slot = nil }
        return slot
    }
    pendingEstablished?.resume(throwing: failure)

    let needsDisconnectHandling = state.withLockedValue {
        $0 != .closed && $0 != .disconnected
    }
    if needsDisconnectHandling {
        handleDisconnect()
    }
}
/// Handles an error raised on the channel: remembers it when it happens while
/// connecting, closes the channel and reports NATS errors to listeners.
func errorCaught(context: ChannelHandlerContext, error: Error) {
    logger.debug("Encountered error on the channel: \(error)")
    // Capture connection-stage errors (especially TLS) BEFORE closing so that
    // channelInactive can report them.
    let duringConnect = state.withLockedValue { $0 == .pending || $0 == .connecting }
    if duringConnect {
        capturedConnectionError.withLockedValue { $0 = error }
    }
    context.close(promise: nil)
    guard let natsErr = error as? NatsErrorProtocol else {
        // Other channel errors (e.g. an unclean TLS shutdown) are only logged at
        // debug level.
        logger.debug("Channel error (will reconnect if needed): \(error)")
        return
    }
    fire(.error(natsErr))
    // handleDisconnect is deliberately not called here; closing the context is
    // expected to lead to channelInactive, which takes care of it.
}
/// Moves the connection to `.disconnected`, fails pending continuations,
/// tears down the channel and then starts the reconnect procedure.
///
/// Calls made while the state is already `.disconnected` or `.closed` are
/// ignored, so concurrent callers cannot start a second teardown/reconnect.
func handleDisconnect() {
    // Prevent duplicate disconnect handling. The check and the transition
    // happen in one critical section; doing them under separate lock
    // acquisitions would let two callers both pass the check.
    let shouldProceed = state.withLockedValue { currentState -> Bool in
        if currentState == .disconnected || currentState == .closed {
            return false  // Already in disconnected/closed state
        }
        currentState = .disconnected
        return true
    }
    guard shouldProceed else {
        return
    }
    // Clean up pending continuations to prevent leaks
    if let continuation = serverInfoContinuation.withLockedValue({ cont in
        let toResume = cont
        cont = nil
        return toResume
    }) {
        continuation.resume(throwing: NatsError.ClientError.connectionClosed)
    }
    if let continuation = connectionEstablishedContinuation.withLockedValue({ cont in
        let toResume = cont
        cont = nil
        return toResume
    }) {
        continuation.resume(throwing: NatsError.ClientError.connectionClosed)
    }
    // Safely clear batchBuffer first to avoid race conditions
    let bufferToRelease = self.batchBuffer
    self.batchBuffer = nil
    _ = bufferToRelease  // Release after clearing reference
    if let channel = self.channel {
        let promise = channel.eventLoop.makePromise(of: Void.self)
        Task {
            do {
                try await self.disconnect()
                promise.succeed()
            } catch ChannelError.alreadyClosed {
                // if the channel was already closed, no need to return error
                promise.succeed()
            } catch {
                promise.fail(error)
            }
        }
        promise.futureResult.whenComplete { result in
            do {
                try result.get()
            } catch {
                // A failed close is only logged; reconnect proceeds regardless.
                logger.debug("Connection closed with error (will reconnect): \(error)")
            }
            // Notify listeners once the teardown has finished.
            self.fire(.disconnected)
            // Only start reconnect after disconnect is complete
            self.handleReconnect()
        }
    } else {
        self.fire(.disconnected)
        handleReconnect()
    }
}
/// Starts a task that keeps calling `connect()` until it succeeds, the task is
/// cancelled, or the reconnect budget is used up. On success all registered
/// subscriptions are re-sent and the state becomes `.connected`; when the
/// budget is exhausted the connection is closed.
func handleReconnect() {
// Cancel any existing reconnect task to prevent multiple concurrent reconnections
if let oldTask = reconnectTask {
oldTask.cancel()
}
reconnectTask = Task {
var connected = false
// Retry while not cancelled and (when a limit is set) attempts remain.
while !Task.isCancelled
&& (maxReconnects == nil || self.reconnectAttempts < maxReconnects!)
{
do {
try await self.connect()
connected = true
break  // Successfully connected
} catch is CancellationError {
logger.debug("Reconnect task cancelled")
return
} catch {
logger.debug("Could not reconnect: \(error)")
// Wait before the next round unless the task was cancelled meanwhile.
if !Task.isCancelled {
try? await Task.sleep(nanoseconds: self.reconnectWait)
}
}
}
// Early return if cancelled
if Task.isCancelled {
logger.debug("Reconnect task cancelled after connection attempts")
return
}
// If we got here without connecting and weren't cancelled, we hit max reconnects
if !connected {
logger.error("Could not reconnect; maxReconnects exceeded")
try? await self.close()
return
}
// Recreate subscriptions - safely copy first
let subsToRestore = subscriptions.withLockedValue { Array($0) }
for (sid, sub) in subsToRestore {
do {
// NOTE(review): the queue argument is passed as nil here, whereas
// subscribe(_:queue:) forwards the queue group - confirm this is intended.
try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
} catch {
logger.error("Error recreating subscription \(sid): \(error)")
}
}
// Flip the state and notify listeners on the channel's event loop.
self.channel?.eventLoop.execute {
self.state.withLockedValue { $0 = .connected }
self.fire(.connected)
}
}
}
/// Writes a client operation through the batch buffer.
///
/// If the write cannot be performed while the state is `.connected` (no batch
/// buffer, or the buffer write fails), disconnect handling is started so the
/// connection can recover.
///
/// - Throws: `NatsError.ClientError.invalidConnection` when there is no batch
///   buffer; `NatsError.ClientError.io` wrapping a failed buffer write.
func write(operation: ClientOp) async throws {
    guard let buffer = batchBuffer else {
        // State says connected but there is nothing to write to: recover by
        // going through disconnect handling.
        if currentState == .connected {
            logger.debug("Write failed: batchBuffer is nil, triggering reconnect")
            handleDisconnect()
        }
        throw NatsError.ClientError.invalidConnection("not connected")
    }
    do {
        try await buffer.writeMessage(operation)
    } catch {
        // Trigger reconnect on write failure - connection may be broken
        if currentState == .connected {
            logger.debug("Write operation failed, triggering reconnect: \(error)")
            handleDisconnect()
        }
        throw NatsError.ClientError.io(error)
    }
}
/// Creates a subscription for `subject`, registers it and sends SUB to the
/// server.
///
/// - Parameters:
///   - subject: Subject to subscribe to.
///   - queue: Optional queue group.
/// - Returns: The registered subscription.
/// - Throws: Errors from creating the subscription or from writing the SUB
///   command; in the latter case the subscription is unregistered again.
internal func subscribe(
    _ subject: String, queue: String? = nil
) async throws -> NatsSubscription {
    let sid = subscriptionCounter.wrappingIncrementThenLoad(ordering: .relaxed)
    let subscription = try NatsSubscription(
        sid: sid, subject: subject, queue: queue, conn: self)
    // Register BEFORE sending the command so no early message is missed.
    subscriptions.withLockedValue { $0[sid] = subscription }
    do {
        try await write(operation: ClientOp.subscribe((sid, subject, queue)))
        return subscription
    } catch {
        // Roll back the registration if the command could not be sent.
        subscriptions.withLockedValue { $0[sid] = nil }
        throw error
    }
}
/// Sends UNSUB for `sub`.
///
/// With a `max` the subscription has not reached yet, UNSUB is sent with that
/// limit and the subscription stays registered. Otherwise UNSUB is sent
/// without a limit and the subscription is removed from the connection.
///
/// - Throws: Errors from writing the UNSUB command.
internal func unsubscribe(sub: NatsSubscription, max: UInt64?) async throws {
    guard let limit = max, sub.delivered < limit else {
        // No limit, or the limit has already been reached: unsubscribe now.
        try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: nil)))
        removeSub(sub: sub)
        return
    }
    // Limit not reached yet: let the server stop delivery after `limit` messages.
    try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: limit)))
    sub.max = limit
}
/// Unregisters `sub` from the connection and completes it.
internal func removeSub(sub: NatsSubscription) {
    subscriptions.withLockedValue { registry in
        registry[sub.sid] = nil
    }
    sub.complete()
}
}
/// Event listener management for `ConnectionHandler`.
extension ConnectionHandler {
    /// Invokes every handler registered for the kind of `event`.
    internal func fire(_ event: NatsEvent) {
        let eventKind = event.kind()
        guard let handlerStore = self.eventHandlerStore[eventKind] else { return }
        for handler in handlerStore {
            handler.handler(event)
        }
    }

    /// Registers `handler` for each of the given event kinds.
    ///
    /// - Returns: The listener id to pass to `removeListener(_:)`.
    internal func addListeners(
        for events: [NatsEventKind], using handler: @escaping (NatsEvent) -> Void
    ) -> String {
        let id = String.hash()
        for event in events {
            self.eventHandlerStore[event, default: []].append(
                NatsEventHandler(lid: id, handler: handler))
        }
        return id
    }

    /// Removes the listener with the given id from every event kind it was
    /// registered for.
    ///
    /// The store's own keys are walked rather than a fixed list of kinds, so
    /// listeners are removed for every kind that can be registered (including
    /// `.suspended`).
    internal func removeListener(_ id: String) {
        for kind in Array(self.eventHandlerStore.keys) {
            self.eventHandlerStore[kind]?.removeAll { $0.listenerId == id }
        }
    }
}
/// Kinds of events a NATS connection can emit; used to register listeners.
public enum NatsEventKind: String {
    case connected = "connected"
    case disconnected = "disconnected"
    case closed = "closed"
    case suspended = "suspended"
    case lameDuckMode = "lameDuckMode"
    case error = "error"
    /// Every event kind. Must list all cases: code that walks this list
    /// (e.g. listener removal) would otherwise skip the missing kind.
    static let all = [connected, disconnected, closed, suspended, lameDuckMode, error]
}
/// An event emitted by a NATS connection. `error` carries the error that
/// occurred.
public enum NatsEvent {
    case connected
    case disconnected
    case suspended
    case closed
    case lameDuckMode
    case error(NatsErrorProtocol)

    /// The payload-free kind that corresponds to this event.
    public func kind() -> NatsEventKind {
        switch self {
        case .connected: return .connected
        case .disconnected: return .disconnected
        case .suspended: return .suspended
        case .closed: return .closed
        case .lameDuckMode: return .lameDuckMode
        case .error: return .error
        }
    }
}
/// Pairs an event callback with the id of the listener that registered it.
internal struct NatsEventHandler {
    /// Id shared by all handlers registered in one `addListeners` call.
    let listenerId: String
    /// Callback invoked with each matching event.
    let handler: (NatsEvent) -> Void

    init(lid: String, handler: @escaping (NatsEvent) -> Void) {
        self.handler = handler
        self.listenerId = lid
    }
}