diff --git a/Sources/Nats/NatsConnection.swift b/Sources/Nats/NatsConnection.swift index fac0d0f..bbb6457 100644 --- a/Sources/Nats/NatsConnection.swift +++ b/Sources/Nats/NatsConnection.swift @@ -901,7 +901,7 @@ class ConnectionHandler: ChannelInboundHandler { continuation.resume(throwing: errorToUse) } - let shouldHandleDisconnect = state.withLockedValue { $0 == .connected } + let shouldHandleDisconnect = state.withLockedValue { $0 != .closed && $0 != .disconnected } if shouldHandleDisconnect { handleDisconnect() } @@ -923,16 +923,44 @@ class ConnectionHandler: ChannelInboundHandler { } else { logger.error("unexpected error: \(error)") } + + // Unified handling: use handleDisconnect for all non-closed/non-disconnected states let currentState = state.withLockedValue { $0 } - if currentState == .pending || currentState == .connecting { + if currentState != .closed && currentState != .disconnected { handleDisconnect() - } else if currentState == .disconnected { - handleReconnect() } } func handleDisconnect() { - state.withLockedValue { $0 = .disconnected } + // Prevent duplicate disconnect handling + let shouldProceed = state.withLockedValue { currentState -> Bool in + if currentState == .disconnected || currentState == .closed { + return false // Already in disconnected/closed state + } + $0 = .disconnected + return true + } + + guard shouldProceed else { + return + } + + // Clean up pending continuations to prevent leaks + if let continuation = serverInfoContinuation.withLockedValue({ cont in + let toResume = cont + cont = nil + return toResume + }) { + continuation.resume(throwing: NatsError.ClientError.connectionClosed) + } + + if let continuation = connectionEstablishedContinuation.withLockedValue({ cont in + let toResume = cont + cont = nil + return toResume + }) { + continuation.resume(throwing: NatsError.ClientError.connectionClosed) + } // Safely clear batchBuffer first to avoid race conditions let bufferToRelease = self.batchBuffer @@ -963,12 +991,17 @@ class ConnectionHandler: ChannelInboundHandler { self.handleReconnect() } } else { - // No channel, start reconnect immediately + self.fire(.disconnected) handleReconnect() } } func handleReconnect() { + // Cancel any existing reconnect task to prevent multiple concurrent reconnections + if let oldTask = reconnectTask { + oldTask.cancel() + } + reconnectTask = Task { var connected = false while !Task.isCancelled @@ -984,7 +1017,7 @@ class ConnectionHandler: ChannelInboundHandler { } catch { logger.debug("Could not reconnect: \(error)") if !Task.isCancelled { - try await Task.sleep(nanoseconds: self.reconnectWait) + try? await Task.sleep(nanoseconds: self.reconnectWait) } } } @@ -998,7 +1031,7 @@ class ConnectionHandler: ChannelInboundHandler { // If we got here without connecting and weren't cancelled, we hit max reconnects if !connected { logger.error("Could not reconnect; maxReconnects exceeded") - try await self.close() + try? await self.close() return }