diff --git a/Sources/Nats/NatsClient/NatsClient.swift b/Sources/Nats/NatsClient/NatsClient.swift index 61e0ab1..3d55f41 100755 --- a/Sources/Nats/NatsClient/NatsClient.swift +++ b/Sources/Nats/NatsClient/NatsClient.swift @@ -70,6 +70,17 @@ public class NatsClient { public var connectedUrl: URL? { connectionHandler?.connectedUrl } + + /// Returns the current connection state + public var connectionState: NatsState { + connectionHandler?.currentState ?? .closed + } + + /// Returns true if the client is currently connected to a server + public var isConnected: Bool { + connectionHandler?.currentState == .connected + } + internal let allocator = ByteBufferAllocator() internal var buffer: ByteBuffer internal var connectionHandler: ConnectionHandler? @@ -349,4 +360,90 @@ extension NatsClient { await connectionHandler.sendPing(ping) return try await ping.getRoundTripTime() } + + /// Checks if the connection is alive by sending a ping. + /// If the connection is not in connected state, returns false. + /// If the ping fails, it triggers a reconnect and returns false. + /// + /// - Parameter timeout: The maximum time to wait for a pong response (default: 5 seconds) + /// - Returns: true if the connection is alive, false otherwise + public func checkConnection(timeout: TimeInterval = 5) async -> Bool { + guard let connectionHandler = self.connectionHandler else { + return false + } + + // Quick state check first + guard connectionHandler.currentState == .connected else { + return false + } + + // Try to send a ping and wait for pong + do { + let ping = RttCommand.makeFrom(channel: connectionHandler.channel) + await connectionHandler.sendPing(ping) + + // Wait for pong with timeout + let rtt = try await withThrowingTaskGroup(of: TimeInterval.self) { group in + group.addTask { + return try await ping.getRoundTripTime() + } + group.addTask { + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + throw NatsError.RequestError.timeout + } + + // Return first successful result or timeout + if let result = try await group.next() { + group.cancelAll() + return result + } + throw NatsError.RequestError.timeout + } + + logger.debug("Connection check successful, RTT: \(rtt)s") + return true + } catch { + logger.warning("Connection check failed: \(error)") + // Connection is broken, the ping failure will trigger reconnect + return false + } + } + + /// Triggers a reconnect if the connection is not in a healthy state. + /// This is useful for proactively recovering from a broken connection. + /// + /// - Returns: true if a reconnect was triggered, false if already connected + @discardableResult + public func ensureConnected() async throws -> Bool { + guard let connectionHandler = self.connectionHandler else { + throw NatsError.ClientError.internalError("empty connection handler") + } + + if case .closed = connectionHandler.currentState { + throw NatsError.ClientError.connectionClosed + } + + // If already connected, verify with a ping + if connectionHandler.currentState == .connected { + let isAlive = await checkConnection(timeout: 3) + if isAlive { + return false // Already connected, no reconnect needed + } + // Connection check failed, reconnect will be triggered by ping failure + return true + } + + // If disconnected or suspended, trigger reconnect + if connectionHandler.currentState == .disconnected { + connectionHandler.handleReconnect() + return true + } + + if connectionHandler.currentState == .suspended { + try await connectionHandler.resume() + return true + } + + return false + } } diff --git a/Sources/Nats/NatsConnection.swift b/Sources/Nats/NatsConnection.swift index 8a3e625..fac0d0f 100644 --- a/Sources/Nats/NatsConnection.swift +++ b/Sources/Nats/NatsConnection.swift @@ -407,12 +407,25 @@ class ConnectionHandler: ChannelInboundHandler { } } onCancel: { logger.debug("Connection task cancelled") - // Clean up resources - if let channel = self.channel { - channel.close(mode: .all, promise: nil) - self.channel = nil - } + // Clean up resources safely to avoid race conditions + // Capture references first to avoid concurrent access issues + let channelToClose = self.channel + let bufferToRelease = self.batchBuffer + + // Clear references first to prevent other threads from using them + self.channel = nil self.batchBuffer = nil + + // Close channel asynchronously after clearing references + // This ensures BatchBuffer's deinit won't conflict with channel close + if let channel = channelToClose { + channel.eventLoop.execute { + channel.close(mode: .all, promise: nil) + } + } + + // bufferToRelease will be released here after channel close is scheduled + _ = bufferToRelease let continuationToResume: CheckedContinuation? = self .serverInfoContinuation.withLockedValue { cont in @@ -551,12 +564,25 @@ class ConnectionHandler: ChannelInboundHandler { } } onCancel: { logger.debug("Client connect initialization cancelled") - // Clean up resources - if let channel = self.channel { - channel.close(mode: .all, promise: nil) - self.channel = nil - } + // Clean up resources safely to avoid race conditions + // Capture references first to avoid concurrent access issues + let channelToClose = self.channel + let bufferToRelease = self.batchBuffer + + // Clear references first to prevent other threads from using them + self.channel = nil self.batchBuffer = nil + + // Close channel asynchronously after clearing references + // This ensures BatchBuffer's deinit won't conflict with channel close + if let channel = channelToClose { + channel.eventLoop.execute { + channel.close(mode: .all, promise: nil) + } + } + + // bufferToRelease will be released here after channel close is scheduled + _ = bufferToRelease let continuationToResume: CheckedContinuation? = self .connectionEstablishedContinuation.withLockedValue { cont in @@ -738,6 +764,13 @@ class ConnectionHandler: ChannelInboundHandler { private func disconnect() async throws { self.pingTask?.cancel() clearPendingPings() // Clear pending pings to avoid promise leaks + + // Safely clear batchBuffer before closing channel + // This prevents race conditions during deallocation + let bufferToRelease = self.batchBuffer + self.batchBuffer = nil + _ = bufferToRelease // Release after clearing reference + try await self.channel?.close().get() } @@ -819,6 +852,11 @@ class ConnectionHandler: ChannelInboundHandler { logger.debug("sent ping: \(pingsOut)") } catch { logger.error("Unable to send ping: \(error)") + // Trigger reconnect on ping failure - connection may be broken + let currentState = state.withLockedValue { $0 } + if currentState == .connected { + handleDisconnect() + } } } @@ -895,6 +933,12 @@ class ConnectionHandler: ChannelInboundHandler { func handleDisconnect() { state.withLockedValue { $0 = .disconnected } + + // Safely clear batchBuffer first to avoid race conditions + let bufferToRelease = self.batchBuffer + self.batchBuffer = nil + _ = bufferToRelease // Release after clearing reference + if let channel = self.channel { let promise = channel.eventLoop.makePromise(of: Void.self) Task { @@ -915,10 +959,13 @@ class ConnectionHandler: ChannelInboundHandler { } catch { logger.error("Error closing connection: \(error)") } + // Only start reconnect after disconnect is complete + self.handleReconnect() } + } else { + // No channel, start reconnect immediately + handleReconnect() } - - handleReconnect() } func handleReconnect() { @@ -979,6 +1026,12 @@ class ConnectionHandler: ChannelInboundHandler { do { try await buffer.writeMessage(operation) } catch { + // Trigger reconnect on write failure - connection may be broken + let currentState = state.withLockedValue { $0 } + if currentState == .connected { + logger.error("Write operation failed, triggering reconnect: \(error)") + handleDisconnect() + } throw NatsError.ClientError.io(error) } }