- Change handleReconnect logs from debug to info/warning - show the reconnect attempt number and the reconnectAttempts counter - log the wait time between attempts - to make it easier to diagnose why reconnection fails

// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Atomics
import Dispatch
import Foundation
import NIO
import NIOConcurrencyHelpers
import NIOFoundationCompat
import NIOHTTP1
import NIOSSL
import NIOWebSocket
import NKeys

class ConnectionHandler: ChannelInboundHandler {
    let lang = "Swift"
    let version = "0.0.1"

    internal var connectedUrl: URL?
    internal let allocator = ByteBufferAllocator()
    internal var inputBuffer: ByteBuffer
    internal var channel: Channel?

    private var eventHandlerStore: [NatsEventKind: [NatsEventHandler]] = [:]

    // Connection options
    internal var retryOnFailedConnect = false
    private var urls: [URL]
    // nanoseconds representation of TimeInterval
    private let reconnectWait: UInt64
    private let maxReconnects: Int?
    private let retainServersOrder: Bool
    private let pingInterval: TimeInterval
    private let requireTls: Bool
    private let tlsFirst: Bool
    private var rootCertificate: URL?
    private var clientCertificate: URL?
    private var clientKey: URL?

    typealias InboundIn = ByteBuffer
    private let state = NIOLockedValueBox(NatsState.pending)
    private let subscriptions = NIOLockedValueBox([UInt64: NatsSubscription]())

    // Helper methods for state access
    internal var currentState: NatsState {
        state.withLockedValue { $0 }
    }

    internal func setState(_ newState: NatsState) {
        state.withLockedValue { $0 = newState }
    }

    private var subscriptionCounter = ManagedAtomic<UInt64>(0)
    private var serverInfo: ServerInfo?
    private var auth: Auth?
    private let parseRemainder = NIOLockedValueBox<Data?>(nil)
    private var pingTask: RepeatedTask?
    private var outstandingPings = ManagedAtomic<UInt8>(0)
    private var reconnectAttempts = 0
    private var reconnectTask: Task<(), Error>? = nil
    private let capturedConnectionError = NIOLockedValueBox<Error?>(nil)

    private var group: MultiThreadedEventLoopGroup

    private let serverInfoContinuation = NIOLockedValueBox<CheckedContinuation<ServerInfo, Error>?>(
        nil)
    private let connectionEstablishedContinuation = NIOLockedValueBox<
        CheckedContinuation<Void, Error>?
    >(nil)

    private let pingQueue = ConcurrentQueue<RttCommand>()
    private(set) var batchBuffer: BatchBuffer?

    init(
        inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?,
        retainServersOrder: Bool,
        pingInterval: TimeInterval, auth: Auth?, requireTls: Bool, tlsFirst: Bool,
        clientCertificate: URL?, clientKey: URL?,
        rootCertificate: URL?, retryOnFailedConnect: Bool
    ) {
        self.urls = urls
        self.group = .singleton
        self.inputBuffer = allocator.buffer(capacity: 1024)
        self.reconnectWait = UInt64(reconnectWait * 1_000_000_000)
        self.maxReconnects = maxReconnects
        self.retainServersOrder = retainServersOrder
        self.auth = auth
        self.pingInterval = pingInterval
        self.requireTls = requireTls
        self.tlsFirst = tlsFirst
        self.clientCertificate = clientCertificate
        self.clientKey = clientKey
        self.rootCertificate = rootCertificate
        self.retryOnFailedConnect = retryOnFailedConnect
    }

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        var byteBuffer = self.unwrapInboundIn(data)
        inputBuffer.writeBuffer(&byteBuffer)
    }

    func channelReadComplete(context: ChannelHandlerContext) {
        guard inputBuffer.readableBytes > 0 else {
            return
        }

        var inputChunk = Data(buffer: inputBuffer)

        let remainder = parseRemainder.withLockedValue { value in
            let current = value
            value = nil
            return current
        }

        if let remainder = remainder, !remainder.isEmpty {
            inputChunk.prepend(remainder)
        }

        let parseResult: (ops: [ServerOp], remainder: Data?)
        do {
            parseResult = try inputChunk.parseOutMessages()
        } catch {
            // if parsing throws an error, clear buffer and remainder, then reconnect
            inputBuffer.clear()
            parseRemainder.withLockedValue { $0 = nil }
            context.fireErrorCaught(error)
            return
        }
        if let remainder = parseResult.remainder {
            parseRemainder.withLockedValue { $0 = remainder }
        }
        for op in parseResult.ops {
            // Only resume the server info continuation when we actually receive
            // an INFO or -ERR op. Do NOT clear it for unrelated ops.
            switch op {
            case .error(let err):
                if let continuation = serverInfoContinuation.withLockedValue({ cont in
                    let toResume = cont
                    cont = nil
                    return toResume
                }) {
                    logger.debug("server info error")
                    continuation.resume(throwing: err)
                    continue
                }
            case .info(let info):
                if let continuation = serverInfoContinuation.withLockedValue({ cont in
                    let toResume = cont
                    cont = nil
                    return toResume
                }) {
                    logger.debug("server info")
                    continuation.resume(returning: info)
                    continue
                }
            default:
                break
            }

            let connEstablishedCont = connectionEstablishedContinuation.withLockedValue { cont in
                let toResume = cont
                cont = nil
                return toResume
            }

            if let continuation = connEstablishedCont {
                logger.debug("conn established")
                switch op {
                case .error(let err):
                    continuation.resume(throwing: err)
                default:
                    continuation.resume()
                }
                continue
            }

            switch op {
            case .ping:
                logger.debug("ping")
                Task {
                    do {
                        try await self.write(operation: .pong)
                    } catch let err as NatsError.ClientError {
                        logger.error("error sending pong: \(err)")
                        self.fire(
                            .error(err))
                    } catch {
                        logger.error("unexpected error sending pong: \(error)")
                    }
                }
            case .pong:
                logger.debug("pong")
                self.outstandingPings.store(0, ordering: AtomicStoreOrdering.relaxed)
                self.pingQueue.dequeue()?.setRoundTripTime()
            case .error(let err):
                logger.debug("error \(err)")

                switch err {
                case .staleConnection, .maxConnectionsExceeded:
                    inputBuffer.clear()
                    parseRemainder.withLockedValue { $0 = nil }
                    context.fireErrorCaught(err)
                case .permissionsViolation(let operation, let subject, _):
                    switch operation {
                    case .subscribe:
                        subscriptions.withLockedValue { subs in
                            for (_, s) in subs {
                                if s.subject == subject {
                                    s.receiveError(NatsError.SubscriptionError.permissionDenied)
                                }
                            }
                        }
                    case .publish:
                        self.fire(.error(err))
                    }
                default:
                    self.fire(.error(err))
                }

                let normalizedError = err.normalizedError
                // on some errors, force reconnect
                if normalizedError == "stale connection"
                    || normalizedError == "maximum connections exceeded"
                {
                    inputBuffer.clear()
                    parseRemainder.withLockedValue { $0 = nil }
                    context.fireErrorCaught(err)
                } else {
                    self.fire(.error(err))
                }
            case .message(let msg):
                self.handleIncomingMessage(msg)
            case .hMessage(let msg):
                self.handleIncomingMessage(msg)
            case .info(let serverInfo):
                logger.debug("info \(op)")
                self.serverInfo = serverInfo
                if serverInfo.lameDuckMode {
                    self.fire(.lameDuckMode)
                }
                self.serverInfo = serverInfo
                updateServersList(info: serverInfo)
            default:
                logger.debug("unknown operation type: \(op)")
            }
        }
        inputBuffer.clear()
    }

    private func handleIncomingMessage(_ message: MessageInbound) {
        let natsMsg = NatsMessage(
            payload: message.payload, subject: message.subject, replySubject: message.reply,
            length: message.length, headers: nil, status: nil, description: nil)
        subscriptions.withLockedValue { subs in
            if let sub = subs[message.sid] {
                sub.receiveMessage(natsMsg)
            }
        }
    }

    private func handleIncomingMessage(_ message: HMessageInbound) {
        let natsMsg = NatsMessage(
            payload: message.payload, subject: message.subject, replySubject: message.reply,
            length: message.length, headers: message.headers, status: message.status,
            description: message.description)
        subscriptions.withLockedValue { subs in
            if let sub = subs[message.sid] {
                sub.receiveMessage(natsMsg)
            }
        }
    }

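    /// Attempts to establish a connection, walking the (optionally shuffled) server list.
    /// Each candidate counts against `reconnectAttempts` and, once every server has been
    /// tried at least once, the loop sleeps `reconnectWait` between candidates. On success
    /// the attempt and outstanding-ping counters are reset and the periodic PING task is
    /// scheduled; on failure the last error is mapped to a `NatsError.ConnectError`.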
    func connect() async throws {
        self.setState(.connecting)
        var servers = self.urls
        if !self.retainServersOrder {
            servers = self.urls.shuffled()
        }
        var lastErr: Error?

        // if there are more reconnect attempts than the number of servers,
        // we are after the initial connect, so sleep between servers
        let shouldSleep = self.reconnectAttempts >= self.urls.count
        for s in servers {
            if let maxReconnects {
                if reconnectAttempts > 0 && reconnectAttempts >= maxReconnects {
                    throw NatsError.ClientError.maxReconnects
                }
            }
            self.reconnectAttempts += 1
            if shouldSleep {
                try await Task.sleep(nanoseconds: self.reconnectWait)
            }

            do {
                try await connectToServer(s: s)
            } catch let error as NatsError.ConnectError {
                if case .invalidConfig(_) = error {
                    throw error
                }
                logger.debug("error connecting to server: \(error)")
                lastErr = error
                continue
            } catch {
                logger.debug("error connecting to server: \(error)")
                lastErr = error
                continue
            }
            lastErr = nil
            break
        }
        if let lastErr {
            self.state.withLockedValue { $0 = .disconnected }
            switch lastErr {
            case let error as ChannelError:
                serverInfoContinuation.withLockedValue { $0 = nil }
                var err: NatsError.ConnectError
                switch error {
                case .connectTimeout(_):
                    err = .timeout
                default:
                    err = .io(error)
                }
                throw err
            case let error as NIOConnectionError:
                if let dnsAAAAError = error.dnsAAAAError {
                    throw NatsError.ConnectError.dns(dnsAAAAError)
                } else if let dnsAError = error.dnsAError {
                    throw NatsError.ConnectError.dns(dnsAError)
                } else {
                    throw NatsError.ConnectError.io(error)
                }
            case let err as NIOSSLError:
                throw NatsError.ConnectError.tlsFailure(err)
            case let err as BoringSSLError:
                throw NatsError.ConnectError.tlsFailure(err)
            case let err as NatsError.ServerError:
                throw err
            case let err as NatsError.ConnectError:
                throw err
            default:
                throw NatsError.ConnectError.io(lastErr)
            }
        }
        self.reconnectAttempts = 0
        guard let channel = self.channel else {
            throw NatsError.ClientError.internalError("empty channel")
        }

        // Reset the ping counter after a successful reconnect so accumulated
        // failures don't trigger an immediate disconnect.
        self.outstandingPings.store(0, ordering: .relaxed)

        // Schedule the task to send a PING periodically
        let pingInterval = TimeAmount.nanoseconds(Int64(self.pingInterval * 1_000_000_000))
        self.pingTask = channel.eventLoop.scheduleRepeatedTask(
            initialDelay: pingInterval, delay: pingInterval
        ) { _ in
            Task { await self.sendPing() }
        }
        logger.debug("connection established")
        return
    }

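    /// Bootstraps a single TCP (or WebSocket) connection to `s`, waits for the server INFO
    /// via `serverInfoContinuation`, upgrades to TLS when required (unless TLS-first or a
    /// wss upgrade already handled it), and completes the handshake with `sendClientConnectInit()`.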
    private func connectToServer(s: URL) async throws {
        var infoTask: Task<(), Never>? = nil
        // this continuation can throw NatsError.ServerError if server responds with
        // -ERR to client connect (e.g. auth error)
        let info: ServerInfo = try await withCheckedThrowingContinuation { continuation in
            serverInfoContinuation.withLockedValue { $0 = continuation }
            infoTask = Task {
                await withTaskCancellationHandler {
                    do {
                        let (bootstrap, upgradePromise) = self.bootstrapConnection(to: s)

                        guard let host = s.host, let port = s.port else {
                            upgradePromise.succeed()  // avoid promise leaks
                            throw NatsError.ConnectError.invalidConfig("no url")
                        }

                        let connect = bootstrap.connect(host: host, port: port)
                        connect.cascadeFailure(to: upgradePromise)
                        self.channel = try await connect.get()

                        guard let channel = self.channel else {
                            upgradePromise.succeed()  // avoid promise leaks
                            throw NatsError.ClientError.internalError("empty channel")
                        }

                        try await upgradePromise.futureResult.get()
                        self.batchBuffer = BatchBuffer(channel: channel)
                    } catch {
                        let continuationToResume: CheckedContinuation<ServerInfo, Error>? = self
                            .serverInfoContinuation.withLockedValue { cont in
                                guard let c = cont else { return nil }
                                cont = nil
                                return c
                            }
                        if let continuation = continuationToResume {
                            continuation.resume(throwing: error)
                        }
                    }
                } onCancel: {
                    logger.debug("Connection task cancelled")
                    // Clean up resources safely to avoid race conditions
                    // Capture references first to avoid concurrent access issues
                    let channelToClose = self.channel
                    let bufferToRelease = self.batchBuffer

                    // Clear references first to prevent other threads from using them
                    self.channel = nil
                    self.batchBuffer = nil

                    // Close channel asynchronously after clearing references
                    // This ensures BatchBuffer's deinit won't conflict with channel close
                    if let channel = channelToClose {
                        channel.eventLoop.execute {
                            channel.close(mode: .all, promise: nil)
                        }
                    }

                    // bufferToRelease will be released here after channel close is scheduled
                    _ = bufferToRelease

                    let continuationToResume: CheckedContinuation<ServerInfo, Error>? = self
                        .serverInfoContinuation.withLockedValue { cont in
                            guard let c = cont else { return nil }
                            cont = nil
                            return c
                        }
                    if let continuation = continuationToResume {
                        continuation.resume(throwing: NatsError.ClientError.cancelled)
                    }
                }
            }
        }

        await infoTask?.value
        self.serverInfo = info
        if (info.tlsRequired ?? false || self.requireTls) && !self.tlsFirst && s.scheme != "wss" {
            let tlsConfig = try makeTLSConfig()
            let sslContext = try NIOSSLContext(configuration: tlsConfig)
            let sslHandler = try NIOSSLClientHandler(
                context: sslContext, serverHostname: s.host)
            try await self.channel?.pipeline.addHandler(sslHandler, position: .first)
        }

        try await sendClientConnectInit()
        self.connectedUrl = s
    }

    private func makeTLSConfig() throws -> TLSConfiguration {
        var tlsConfiguration =
            TLSConfiguration.makeClientConfiguration()
        if let rootCertificate = self.rootCertificate {
            tlsConfiguration.trustRoots = .file(
                rootCertificate.path)
        }
        if let clientCertificate = self.clientCertificate,
            let clientKey = self.clientKey
        {
            // Load the client certificate from the PEM file
            let certificate = try NIOSSLCertificate.fromPEMFile(
                clientCertificate.path
            ).map { NIOSSLCertificateSource.certificate($0) }
            tlsConfiguration.certificateChain = certificate

            // Load the private key from the file
            let privateKey = try NIOSSLPrivateKey(
                file: clientKey.path, format: .pem)
            tlsConfiguration.privateKey = .privateKey(
                privateKey)
        }
        return tlsConfiguration
    }

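    /// Builds the CONNECT payload (user/pass, token, decorated credentials file, or NKEY seed,
    /// signing the server-provided nonce where applicable), sends CONNECT followed by PING,
    /// and waits on `connectionEstablishedContinuation` for the server's PONG or -ERR.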
    private func sendClientConnectInit() async throws {
        var initialConnect = ConnectInfo(
            verbose: false, pedantic: false, userJwt: nil, nkey: "", name: "", echo: true,
            lang: self.lang, version: self.version, natsProtocol: .dynamic, tlsRequired: false,
            user: self.auth?.user ?? "", pass: self.auth?.password ?? "",
            authToken: self.auth?.token ?? "", headers: true, noResponders: true)

        if self.auth?.nkey != nil && self.auth?.nkeyPath != nil {
            throw NatsError.ConnectError.invalidConfig("cannot use both nkey and nkeyPath")
        }
        if let auth = self.auth, let credentialsPath = auth.credentialsPath {
            let credentials = try await URLSession.shared.data(from: credentialsPath).0
            guard let jwt = JwtUtils.parseDecoratedJWT(contents: credentials) else {
                throw NatsError.ConnectError.invalidConfig(
                    "failed to extract JWT from credentials file")
            }
            guard let nkey = JwtUtils.parseDecoratedNKey(contents: credentials) else {
                throw NatsError.ConnectError.invalidConfig(
                    "failed to extract NKEY from credentials file")
            }
            guard let nonce = self.serverInfo?.nonce else {
                throw NatsError.ConnectError.invalidConfig("missing nonce")
            }
            let keypair = try KeyPair(seed: String(data: nkey, encoding: .utf8)!)
            let nonceData = nonce.data(using: .utf8)!
            let sig = try keypair.sign(input: nonceData)
            let base64sig = sig.base64EncodedURLSafeNotPadded()
            initialConnect.signature = base64sig
            initialConnect.userJwt = String(data: jwt, encoding: .utf8)!
        }
        if let nkey = self.auth?.nkeyPath {
            let nkeyData = try await URLSession.shared.data(from: nkey).0

            guard let nkeyContent = String(data: nkeyData, encoding: .utf8) else {
                throw NatsError.ConnectError.invalidConfig("failed to read NKEY file")
            }
            let keypair = try KeyPair(
                seed: nkeyContent.trimmingCharacters(in: .whitespacesAndNewlines)
            )

            guard let nonce = self.serverInfo?.nonce else {
                throw NatsError.ConnectError.invalidConfig("missing nonce")
            }
            let sig = try keypair.sign(input: nonce.data(using: .utf8)!)
            let base64sig = sig.base64EncodedURLSafeNotPadded()
            initialConnect.signature = base64sig
            initialConnect.nkey = keypair.publicKeyEncoded
        }
        if let nkey = self.auth?.nkey {
            let keypair = try KeyPair(seed: nkey)
            guard let nonce = self.serverInfo?.nonce else {
                throw NatsError.ConnectError.invalidConfig("missing nonce")
            }
            let nonceData = nonce.data(using: .utf8)!
            let sig = try keypair.sign(input: nonceData)
            let base64sig = sig.base64EncodedURLSafeNotPadded()
            initialConnect.signature = base64sig
            initialConnect.nkey = keypair.publicKeyEncoded
        }
        let connect = initialConnect
        // this continuation can throw NatsError.ServerError if server responds with
        // -ERR to client connect (e.g. auth error)
        try await withTaskCancellationHandler {
            try await withCheckedThrowingContinuation { continuation in
                connectionEstablishedContinuation.withLockedValue { $0 = continuation }
                Task.detached {
                    do {
                        try await self.write(operation: ClientOp.connect(connect))
                        try await self.write(operation: ClientOp.ping)
                        self.channel?.flush()
                    } catch {
                        let continuationToResume: CheckedContinuation<Void, Error>? = self
                            .connectionEstablishedContinuation.withLockedValue { cont in
                                guard let c = cont else { return nil }
                                cont = nil
                                return c
                            }
                        if let continuation = continuationToResume {
                            continuation.resume(throwing: error)
                        }
                    }
                }
            }
        } onCancel: {
            logger.debug("Client connect initialization cancelled")
            // Clean up resources safely to avoid race conditions
            // Capture references first to avoid concurrent access issues
            let channelToClose = self.channel
            let bufferToRelease = self.batchBuffer

            // Clear references first to prevent other threads from using them
            self.channel = nil
            self.batchBuffer = nil

            // Close channel asynchronously after clearing references
            // This ensures BatchBuffer's deinit won't conflict with channel close
            if let channel = channelToClose {
                channel.eventLoop.execute {
                    channel.close(mode: .all, promise: nil)
                }
            }

            // bufferToRelease will be released here after channel close is scheduled
            _ = bufferToRelease

            let continuationToResume: CheckedContinuation<Void, Error>? = self
                .connectionEstablishedContinuation.withLockedValue { cont in
                    guard let c = cont else { return nil }
                    cont = nil
                    return c
                }
            if let continuation = continuationToResume {
                continuation.resume(throwing: NatsError.ClientError.cancelled)
            }
        }
    }

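    /// Creates the `ClientBootstrap` and an upgrade promise for `server`. The channel
    /// initializer picks one of three paths: TLS-first (handshake before anything else),
    /// ws/wss (HTTP upgrade to WebSocket, with TLS added first for wss), or plain TCP.
    /// The promise succeeds once the pipeline is ready and is failed on connect or upgrade errors.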
    private func bootstrapConnection(
        to server: URL
    ) -> (ClientBootstrap, EventLoopPromise<Void>) {
        let upgradePromise: EventLoopPromise<Void> = self.group.any().makePromise(of: Void.self)
        let bootstrap = ClientBootstrap(group: self.group)
            .channelOption(
                ChannelOptions.socket(
                    SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR),
                value: 1
            )
            .channelInitializer { channel in
                if self.requireTls && self.tlsFirst {
                    upgradePromise.succeed(())
                    do {
                        let tlsConfig = try self.makeTLSConfig()
                        let sslContext = try NIOSSLContext(
                            configuration: tlsConfig)
                        let sslHandler = try NIOSSLClientHandler(
                            context: sslContext, serverHostname: server.host!)
                        //Fixme(jrm): do not ignore error from addHandler future.
                        channel.pipeline.addHandler(sslHandler).flatMap { _ in
                            channel.pipeline.addHandler(self)
                        }.whenComplete { result in
                            switch result {
                            case .success():
                                print("success")
                            case .failure(let error):
                                print("error: \(error)")
                            }
                        }
                        return channel.eventLoop.makeSucceededFuture(())
                    } catch {
                        let tlsError = NatsError.ConnectError.tlsFailure(error)
                        return channel.eventLoop.makeFailedFuture(tlsError)
                    }
                } else {
                    if server.scheme == "ws" || server.scheme == "wss" {
                        let httpUpgradeRequestHandler = HTTPUpgradeRequestHandler(
                            host: server.host ?? "localhost",
                            path: server.path,
                            query: server.query,
                            headers: HTTPHeaders(),  // TODO (mtmk): pass in from client options
                            upgradePromise: upgradePromise)
                        let httpUpgradeRequestHandlerBox = NIOLoopBound(
                            httpUpgradeRequestHandler, eventLoop: channel.eventLoop)

                        let websocketUpgrader = NIOWebSocketClientUpgrader(
                            maxFrameSize: 8 * 1024 * 1024,
                            automaticErrorHandling: true,
                            upgradePipelineHandler: { channel, _ in
                                let wsh = NIOWebSocketFrameAggregator(
                                    minNonFinalFragmentSize: 0,
                                    maxAccumulatedFrameCount: Int.max,
                                    maxAccumulatedFrameSize: Int.max
                                )
                                return channel.pipeline.addHandler(wsh).flatMap {
                                    channel.pipeline.addHandler(WebSocketByteBufferCodec()).flatMap {
                                        channel.pipeline.addHandler(self)
                                    }
                                }
                            }
                        )

                        let config: NIOHTTPClientUpgradeConfiguration = (
                            upgraders: [websocketUpgrader],
                            completionHandler: { context in
                                upgradePromise.succeed(())
                                channel.pipeline.removeHandler(
                                    httpUpgradeRequestHandlerBox.value, promise: nil)
                            }
                        )

                        if server.scheme == "wss" {
                            do {
                                let tlsConfig = try self.makeTLSConfig()
                                let sslContext = try NIOSSLContext(
                                    configuration: tlsConfig)
                                let sslHandler = try NIOSSLClientHandler(
                                    context: sslContext, serverHostname: server.host!)
                                // The sync methods here are safe because we're on the channel event loop
                                // due to the promise originating on the event loop of the channel.
                                try channel.pipeline.syncOperations.addHandler(sslHandler)
                            } catch {
                                let tlsError = NatsError.ConnectError.tlsFailure(error)
                                upgradePromise.fail(tlsError)
                                return channel.eventLoop.makeFailedFuture(tlsError)
                            }
                        }

                        //Fixme(jrm): do not ignore error from addHandler future.
                        channel.pipeline.addHTTPClientHandlers(
                            leftOverBytesStrategy: .forwardBytes,
                            withClientUpgrade: config
                        ).flatMap {
                            channel.pipeline.addHandler(httpUpgradeRequestHandlerBox.value)
                        }.whenComplete { result in
                            switch result {
                            case .success():
                                logger.debug("success")
                            case .failure(let error):
                                logger.debug("error: \(error)")
                            }
                        }
                    } else {
                        upgradePromise.succeed(())
                        //Fixme(jrm): do not ignore error from addHandler future.
                        channel.pipeline.addHandler(self).whenComplete { result in
                            switch result {
                            case .success():
                                logger.debug("success")
                            case .failure(let error):
                                logger.debug("error: \(error)")
                            }
                        }
                    }
                    return channel.eventLoop.makeSucceededFuture(())
                }
            }.connectTimeout(.seconds(5))
        return (bootstrap, upgradePromise)
    }

    private func updateServersList(info: ServerInfo) {
        if let connectUrls = info.connectUrls {
            for connectUrl in connectUrls {
                guard let url = URL(string: connectUrl) else {
                    continue
                }
                if !self.urls.contains(url) {
                    urls.append(url)
                }
            }
        }
    }

    func close() async throws {
        self.reconnectTask?.cancel()
        try await self.reconnectTask?.value

        guard let eventLoop = self.channel?.eventLoop else {
            self.state.withLockedValue { $0 = .closed }
            self.pingTask?.cancel()
            clearPendingPings()  // Clear pending pings to avoid promise leaks
            self.fire(.closed)
            return
        }
        let promise = eventLoop.makePromise(of: Void.self)

        eventLoop.execute {
            self.state.withLockedValue { $0 = .closed }
            self.pingTask?.cancel()
            self.clearPendingPings()  // Clear pending pings to avoid promise leaks
            self.channel?.close(mode: .all, promise: promise)
        }

        do {
            try await promise.futureResult.get()
        } catch ChannelError.alreadyClosed {
            // we don't want to throw an error if channel is already closed
            // as that would mean we would get an error closing client during reconnect
        }

        self.fire(.closed)
    }

    private func disconnect() async throws {
        self.pingTask?.cancel()
        clearPendingPings()  // Clear pending pings to avoid promise leaks

        // Safely clear batchBuffer before closing channel
        // This prevents race conditions during deallocation
        let bufferToRelease = self.batchBuffer
        self.batchBuffer = nil
        _ = bufferToRelease  // Release after clearing reference

        try await self.channel?.close().get()
    }

    /// Clear all pending ping requests to avoid promise leaks
    private func clearPendingPings() {
        let pendingPings = pingQueue.dequeueAll()
        for ping in pendingPings {
            ping.cancel()
        }
        if !pendingPings.isEmpty {
            logger.debug("Cleared \(pendingPings.count) pending ping(s)")
        }
    }

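    /// Suspends the connection: cancels any in-flight reconnect task, marks the state as
    /// `.suspended`, clears pending pings, closes the channel if it was connected, and
    /// fires the `.suspended` event.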
    func suspend() async throws {
        self.reconnectTask?.cancel()
        _ = try await self.reconnectTask?.value

        // Handle case where channel is already nil (e.g., during rapid reconnections)
        guard let eventLoop = self.channel?.eventLoop else {
            // Set state to suspended even if channel is nil
            self.state.withLockedValue { $0 = .suspended }
            clearPendingPings()  // Clear pending pings to avoid promise leaks
            return
        }
        let promise = eventLoop.makePromise(of: Void.self)

        eventLoop.execute {  // This ensures the code block runs on the event loop
            let shouldClose = self.state.withLockedValue { currentState in
                let wasConnected = currentState == .connected
                currentState = .suspended
                return wasConnected
            }

            if shouldClose {
                self.pingTask?.cancel()
                self.clearPendingPings()  // Clear pending pings to avoid promise leaks
                self.channel?.close(mode: .all, promise: promise)
            } else {
                self.clearPendingPings()  // Clear pending pings even if not closing
                promise.succeed()
            }
        }

        try await promise.futureResult.get()
        self.fire(.suspended)
    }

    func resume() async throws {
        guard let eventLoop = self.channel?.eventLoop else {
            throw NatsError.ClientError.internalError("channel should not be nil")
        }
        try await eventLoop.submit {
            let canResume = self.state.withLockedValue { $0 == .suspended }
            guard canResume else {
                throw NatsError.ClientError.invalidConnection(
                    "unable to resume connection - connection is not in suspended state")
            }
            self.handleReconnect()
        }.get()
    }

    func reconnect() async throws {
        try await suspend()
        try await resume()
    }

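    /// Sends a protocol PING and records it on `pingQueue` for RTT measurement. If more than
    /// two PINGs are already outstanding (no PONG received), or the write fails while the
    /// state is `.connected`, the connection is treated as broken via `handleDisconnect()`.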
    internal func sendPing(_ rttCommand: RttCommand? = nil) async {
        let pingsOut = self.outstandingPings.wrappingIncrementThenLoad(
            ordering: AtomicUpdateOrdering.relaxed)
        if pingsOut > 2 {
            handleDisconnect()
            return
        }
        let ping = ClientOp.ping
        do {
            self.pingQueue.enqueue(rttCommand ?? RttCommand.makeFrom(channel: self.channel))
            try await self.write(operation: ping)
            logger.debug("sent ping: \(pingsOut)")
        } catch {
            logger.error("Unable to send ping: \(error)")
            // Trigger reconnect on ping failure - connection may be broken
            let currentState = state.withLockedValue { $0 }
            if currentState == .connected {
                handleDisconnect()
            }
        }
    }

    func channelActive(context: ChannelHandlerContext) {
        logger.debug("TCP channel active")

        parseRemainder.withLockedValue { $0 = nil }

        inputBuffer = context.channel.allocator.buffer(capacity: 1024 * 1024 * 8)
    }

    func channelInactive(context: ChannelHandlerContext) {
        logger.debug("TCP channel inactive")

        // If we lost the channel before we delivered server INFO or connection
        // establishment, make sure to fail any pending continuations to avoid leaks.
        // Use captured error if available (e.g., TLS failure), otherwise use connectionClosed.
        let errorToUse: Error = capturedConnectionError.withLockedValue({ err in
            let captured = err
            err = nil  // Clear after using
            if let capturedError = captured {
                return NatsError.ConnectError.tlsFailure(capturedError)
            } else {
                return NatsError.ClientError.connectionClosed
            }
        })

        if let continuation = serverInfoContinuation.withLockedValue({ cont in
            let toResume = cont
            cont = nil
            return toResume
        }) {
            continuation.resume(throwing: errorToUse)
        }

        if let continuation = connectionEstablishedContinuation.withLockedValue({ cont in
            let toResume = cont
            cont = nil
            return toResume
        }) {
            continuation.resume(throwing: errorToUse)
        }

        let shouldHandleDisconnect = state.withLockedValue { $0 != .closed && $0 != .disconnected }
        if shouldHandleDisconnect {
            handleDisconnect()
        }
    }

    func errorCaught(context: ChannelHandlerContext, error: Error) {
        logger.debug("Encountered error on the channel: \(error)")

        // Capture connection-stage errors (especially TLS) for proper error reporting BEFORE closing
        let isConnecting = state.withLockedValue { $0 == .pending || $0 == .connecting }
        if isConnecting {
            capturedConnectionError.withLockedValue { $0 = error }
        }

        context.close(promise: nil)

        if let natsErr = error as? NatsErrorProtocol {
            self.fire(.error(natsErr))
        } else {
            // Logged at debug level to avoid noisy error output;
            // uncleanShutdown is a common TLS teardown case and does not need to be treated as an error.
            logger.debug("Channel error (will reconnect if needed): \(error)")
        }

        // Note: handleDisconnect is intentionally not called here.
        // context.close() triggers channelInactive, which owns the disconnect handling.
        // This avoids duplicate handling and excessive reconnects.
    }

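    /// Central disconnect path: deduplicates concurrent calls by checking the current state,
    /// fails any pending continuations, drops the batch buffer, closes the channel, fires
    /// `.disconnected`, and only then kicks off `handleReconnect()`.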
    func handleDisconnect() {
        // Prevent duplicate disconnect handling
        let shouldProceed = state.withLockedValue { currentState -> Bool in
            if currentState == .disconnected || currentState == .closed {
                return false  // Already in disconnected/closed state
            }
            return true
        }

        guard shouldProceed else {
            return
        }

        // Set state to disconnected after check
        state.withLockedValue { $0 = .disconnected }

        // Clean up pending continuations to prevent leaks
        if let continuation = serverInfoContinuation.withLockedValue({ cont in
            let toResume = cont
            cont = nil
            return toResume
        }) {
            continuation.resume(throwing: NatsError.ClientError.connectionClosed)
        }

        if let continuation = connectionEstablishedContinuation.withLockedValue({ cont in
            let toResume = cont
            cont = nil
            return toResume
        }) {
            continuation.resume(throwing: NatsError.ClientError.connectionClosed)
        }

        // Safely clear batchBuffer first to avoid race conditions
        let bufferToRelease = self.batchBuffer
        self.batchBuffer = nil
        _ = bufferToRelease  // Release after clearing reference

        if let channel = self.channel {
            let promise = channel.eventLoop.makePromise(of: Void.self)
            Task {
                do {
                    try await self.disconnect()
                    promise.succeed()
                } catch ChannelError.alreadyClosed {
                    // if the channel was already closed, no need to return error
                    promise.succeed()
                } catch {
                    promise.fail(error)
                }
            }
            promise.futureResult.whenComplete { result in
                do {
                    try result.get()
                } catch {
                    // Logged at debug level: when the network drops, TLS cannot complete a
                    // clean shutdown handshake, so this error is expected.
                    logger.debug("Connection closed with error (will reconnect): \(error)")
                }
                // Fire .disconnected whether the close succeeded or failed,
                // so upper layers can observe that the connection dropped.
                self.fire(.disconnected)
                // Only start reconnect after disconnect is complete
                self.handleReconnect()
            }
        } else {
            self.fire(.disconnected)
            handleReconnect()
        }
    }

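    /// Reconnect loop. Logs at info/warning (rather than debug) so the per-attempt number, the
    /// `reconnectAttempts` counter, and the wait between attempts are visible when diagnosing
    /// reconnect failures. Retries `connect()` until it succeeds, the task is cancelled, or
    /// `maxReconnects` is exceeded, then re-establishes existing subscriptions and fires `.connected`.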
    func handleReconnect() {
        // Cancel any existing reconnect task to prevent multiple concurrent reconnections
        if let oldTask = reconnectTask {
            oldTask.cancel()
        }

        logger.info("🔄 Starting reconnection process...")

        reconnectTask = Task {
            var connected = false
            var attempt = 0
            while !Task.isCancelled
                && (maxReconnects == nil || self.reconnectAttempts < maxReconnects!)
            {
                attempt += 1
                logger.info("🔄 Reconnect attempt \(attempt), total reconnectAttempts: \(self.reconnectAttempts)")
                do {
                    try await self.connect()
                    connected = true
                    logger.info("✅ Reconnection successful after \(attempt) attempts")
                    break  // Successfully connected
                } catch is CancellationError {
                    logger.info("⚠️ Reconnect task cancelled")
                    return
                } catch {
                    logger.warning("⚠️ Reconnect attempt \(attempt) failed: \(error)")
                    if !Task.isCancelled {
                        logger.info("⏳ Waiting \(Double(self.reconnectWait) / 1_000_000_000)s before next attempt...")
                        try? await Task.sleep(nanoseconds: self.reconnectWait)
                    }
                }
            }

            // Early return if cancelled
            if Task.isCancelled {
                logger.info("⚠️ Reconnect task cancelled after \(attempt) connection attempts")
                return
            }

            // If we got here without connecting and weren't cancelled, we hit max reconnects
            if !connected {
                logger.error("❌ Could not reconnect; maxReconnects exceeded (\(self.maxReconnects ?? -1))")
                try? await self.close()
                return
            }

            // Recreate subscriptions - safely copy first
            let subsToRestore = subscriptions.withLockedValue { Array($0) }
            for (sid, sub) in subsToRestore {
                do {
                    try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
                } catch {
                    logger.error("Error recreating subscription \(sid): \(error)")
                }
            }

            self.channel?.eventLoop.execute {
                self.state.withLockedValue { $0 = .connected }
                self.fire(.connected)
            }
        }
    }

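    /// Serializes a client operation through the current `BatchBuffer`. A missing buffer or a
    /// failed write while the state is still `.connected` indicates a dead connection, so the
    /// handler triggers `handleDisconnect()` before surfacing the error to the caller.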
    func write(operation: ClientOp) async throws {
        guard let buffer = self.batchBuffer else {
            // If state is connected but batchBuffer is nil, this is a "fake connection" state
            // Trigger reconnect to recover
            let currentState = state.withLockedValue { $0 }
            if currentState == .connected {
                // Logged at debug level: this is the normal state while recovering from a network outage
                logger.debug("Write failed: batchBuffer is nil, triggering reconnect")
                handleDisconnect()
            }
            throw NatsError.ClientError.invalidConnection("not connected")
        }
        do {
            try await buffer.writeMessage(operation)
        } catch {
            // Trigger reconnect on write failure - connection may be broken
            let currentState = state.withLockedValue { $0 }
            if currentState == .connected {
                // Logged at debug level: write failures are expected while the network is down
                logger.debug("Write operation failed, triggering reconnect: \(error)")
                handleDisconnect()
            }
            throw NatsError.ClientError.io(error)
        }
    }

    internal func subscribe(
        _ subject: String, queue: String? = nil
    ) async throws -> NatsSubscription {
        let sid = self.subscriptionCounter.wrappingIncrementThenLoad(
            ordering: AtomicUpdateOrdering.relaxed)
        let sub = try NatsSubscription(sid: sid, subject: subject, queue: queue, conn: self)

        // Add subscription BEFORE sending command to avoid race condition
        subscriptions.withLockedValue { $0[sid] = sub }

        do {
            try await write(operation: ClientOp.subscribe((sid, subject, queue)))
        } catch {
            // Remove subscription if subscribe command fails
            subscriptions.withLockedValue { $0.removeValue(forKey: sid) }
            throw error
        }

        return sub
    }

    internal func unsubscribe(sub: NatsSubscription, max: UInt64?) async throws {
        if let max, sub.delivered < max {
            // if max is set and the sub has not yet reached it, send unsub with max set
            // and do not remove the sub from connection
            try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: max)))
            sub.max = max
        } else {
            // if max is not set or the subscription received at least as many
            // messages as max, send unsub command without max and remove sub from connection
            try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: nil)))
            self.removeSub(sub: sub)
        }
    }

    internal func removeSub(sub: NatsSubscription) {
        subscriptions.withLockedValue { $0.removeValue(forKey: sub.sid) }
        sub.complete()
    }
}

extension ConnectionHandler {

    internal func fire(_ event: NatsEvent) {
        let eventKind = event.kind()
        guard let handlerStore = self.eventHandlerStore[eventKind] else { return }

        for handler in handlerStore {
            handler.handler(event)
        }
    }

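    /// Registers `handler` for each of the given event kinds and returns a listener id that can
    /// later be passed to `removeListener(_:)`.
    ///
    /// Illustrative usage (how a caller might wire this up; names outside this file, such as
    /// `connectionHandler` and `logger` availability at the call site, are assumptions):
    ///
    ///     let id = connectionHandler.addListeners(for: [.connected, .disconnected]) { event in
    ///         logger.info("nats event: \(event.kind().rawValue)")
    ///     }
    ///     // ... later
    ///     connectionHandler.removeListener(id)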
    internal func addListeners(
        for events: [NatsEventKind], using handler: @escaping (NatsEvent) -> Void
    ) -> String {
        let id = String.hash()

        for event in events {
            if self.eventHandlerStore[event] == nil {
                self.eventHandlerStore[event] = []
            }
            self.eventHandlerStore[event]?.append(
                NatsEventHandler(lid: id, handler: handler))
        }

        return id
    }

    internal func removeListener(_ id: String) {
        for event in NatsEventKind.all {
            let handlerStore = self.eventHandlerStore[event]
            if let store = handlerStore {
                self.eventHandlerStore[event] = store.filter { $0.listenerId != id }
            }
        }
    }
}

/// Nats events
public enum NatsEventKind: String {
    case connected = "connected"
    case disconnected = "disconnected"
    case closed = "closed"
    case suspended = "suspended"
    case lameDuckMode = "lameDuckMode"
    case error = "error"
    // include suspended so listeners registered for it can also be removed
    static let all = [connected, disconnected, closed, suspended, lameDuckMode, error]
}

public enum NatsEvent {
    case connected
    case disconnected
    case suspended
    case closed
    case lameDuckMode
    case error(NatsErrorProtocol)

    public func kind() -> NatsEventKind {
        switch self {
        case .connected:
            return .connected
        case .disconnected:
            return .disconnected
        case .suspended:
            return .suspended
        case .closed:
            return .closed
        case .lameDuckMode:
            return .lameDuckMode
        case .error(_):
            return .error
        }
    }
}

internal struct NatsEventHandler {
    let listenerId: String
    let handler: (NatsEvent) -> Void
    init(lid: String, handler: @escaping (NatsEvent) -> Void) {
        self.listenerId = lid
        self.handler = handler
    }
}