init
Some checks failed
ci / macos (push) Has been cancelled
ci / ios (push) Has been cancelled
ci / check-linter (push) Has been cancelled

This commit is contained in:
wenzuhuai
2026-01-12 18:29:52 +08:00
commit d7bdb4f378
87 changed files with 12664 additions and 0 deletions

View File

@@ -0,0 +1,128 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import NIO
import NIOConcurrencyHelpers
extension BatchBuffer {
struct State {
private var buffer: ByteBuffer
private var allocator: ByteBufferAllocator
var waitingPromises: [(ClientOp, UnsafeContinuation<Void, Error>)] = []
var isWriteInProgress: Bool = false
internal init(allocator: ByteBufferAllocator, batchSize: Int = 16 * 1024) {
self.allocator = allocator
self.buffer = allocator.buffer(capacity: batchSize)
}
var readableBytes: Int {
return self.buffer.readableBytes
}
mutating func clear() {
buffer.clear()
}
mutating func getWriteBuffer() -> ByteBuffer {
var writeBuffer = allocator.buffer(capacity: buffer.readableBytes)
writeBuffer.writeBytes(buffer.readableBytesView)
buffer.clear()
return writeBuffer
}
mutating func writeMessage(_ message: ClientOp) {
self.buffer.writeClientOp(message)
}
}
}
internal class BatchBuffer {
private let batchSize: Int
private let channel: Channel
private let state: NIOLockedValueBox<State>
init(channel: Channel, batchSize: Int = 16 * 1024) {
self.batchSize = batchSize
self.channel = channel
self.state = .init(
State(allocator: channel.allocator)
)
}
func writeMessage(_ message: ClientOp) async throws {
#if SWIFT_NATS_BATCH_BUFFER_DISABLED
let b = channel.allocator.buffer(bytes: data)
try await channel.writeAndFlush(b)
#else
// Batch writes and if we have more than the batch size
// already in the buffer await until buffer is flushed
// to handle any back pressure
try await withUnsafeThrowingContinuation { continuation in
self.state.withLockedValue { state in
guard state.readableBytes < self.batchSize else {
state.waitingPromises.append((message, continuation))
return
}
state.writeMessage(message)
self.flushWhenIdle(state: &state)
continuation.resume()
}
}
#endif
}
private func flushWhenIdle(state: inout State) {
// The idea is to keep writing to the buffer while a writeAndFlush() is
// in progress, so we can batch as many messages as possible.
guard !state.isWriteInProgress else {
return
}
// We need a separate write buffer so we can free the message buffer for more
// messages to be collected.
let writeBuffer = state.getWriteBuffer()
state.isWriteInProgress = true
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
writePromise.futureResult.whenComplete { result in
self.state.withLockedValue { state in
state.isWriteInProgress = false
switch result {
case .success:
for (message, continuation) in state.waitingPromises {
state.writeMessage(message)
continuation.resume()
}
state.waitingPromises.removeAll()
case .failure(let error):
for (_, continuation) in state.waitingPromises {
continuation.resume(throwing: error)
}
state.waitingPromises.removeAll()
state.clear()
}
// Check if there are any pending flushes
if state.readableBytes > 0 {
self.flushWhenIdle(state: &state)
}
}
}
self.channel.writeAndFlush(writeBuffer, promise: writePromise)
}
}

View File

@@ -0,0 +1,32 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import NIOConcurrencyHelpers
internal class ConcurrentQueue<T> {
private var elements: [T] = []
private let lock = NIOLock()
func enqueue(_ element: T) {
lock.lock()
defer { lock.unlock() }
elements.append(element)
}
func dequeue() -> T? {
lock.lock()
defer { lock.unlock() }
guard !elements.isEmpty else { return nil }
return elements.removeFirst()
}
}

View File

@@ -0,0 +1,87 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import NIO
extension ByteBuffer {
mutating func writeClientOp(_ op: ClientOp) {
switch op {
case .publish((let subject, let reply, let payload, let headers)):
if let payload = payload {
self.reserveCapacity(
minimumWritableBytes: payload.count + subject.utf8.count
+ NatsOperation.publish.rawValue.count + 12)
if headers != nil {
self.writeBytes(NatsOperation.hpublish.rawBytes)
} else {
self.writeBytes(NatsOperation.publish.rawBytes)
}
self.writeString(" ")
self.writeString(subject)
self.writeString(" ")
if let reply = reply {
self.writeString("\(reply) ")
}
if let headers = headers {
let headers = headers.toBytes()
let totalLen = headers.count + payload.count
let headersLen = headers.count
self.writeString("\(headersLen) \(totalLen)\r\n")
self.writeData(headers)
} else {
self.writeString("\(payload.count)\r\n")
}
self.writeData(payload)
self.writeString("\r\n")
} else {
self.reserveCapacity(
minimumWritableBytes: subject.utf8.count + NatsOperation.publish.rawValue.count
+ 12)
self.writeBytes(NatsOperation.publish.rawBytes)
self.writeString(" ")
self.writeString(subject)
if let reply = reply {
self.writeString("\(reply) ")
}
self.writeString("\r\n")
}
case .subscribe((let sid, let subject, let queue)):
if let queue {
self.writeString(
"\(NatsOperation.subscribe.rawValue) \(subject) \(queue) \(sid)\r\n")
} else {
self.writeString("\(NatsOperation.subscribe.rawValue) \(subject) \(sid)\r\n")
}
case .unsubscribe((let sid, let max)):
if let max {
self.writeString("\(NatsOperation.unsubscribe.rawValue) \(sid) \(max)\r\n")
} else {
self.writeString("\(NatsOperation.unsubscribe.rawValue) \(sid)\r\n")
}
case .connect(let info):
// This encode can't actually fail
let json = try! JSONEncoder().encode(info)
self.reserveCapacity(minimumWritableBytes: json.count + 5)
self.writeString("\(NatsOperation.connect.rawValue) ")
self.writeData(json)
self.writeString("\r\n")
case .ping:
self.writeString("\(NatsOperation.ping.rawValue)\r\n")
case .pong:
self.writeString("\(NatsOperation.pong.rawValue)\r\n")
}
}
}

View File

@@ -0,0 +1,24 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
extension Data {
/// Swift does not provide a way to encode data to base64 without padding in URL safe way.
func base64EncodedURLSafeNotPadded() -> String {
return self.base64EncodedString()
.replacingOccurrences(of: "+", with: "-")
.replacingOccurrences(of: "/", with: "_")
.trimmingCharacters(in: CharacterSet(charactersIn: "="))
}
}

View File

@@ -0,0 +1,184 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
extension Data {
private static let cr = UInt8(ascii: "\r")
private static let lf = UInt8(ascii: "\n")
private static let crlf = Data([cr, lf])
private static var currentNum = 0
private static var errored = false
internal static let versionLinePrefix = "NATS/1.0"
func removePrefix(_ prefix: Data) -> Data {
guard self.starts(with: prefix) else { return self }
return self.dropFirst(prefix.count)
}
func split(
separator: Data, maxSplits: Int = .max, omittingEmptySubsequences: Bool = true
)
-> [Data]
{
var chunks: [Data] = []
var start = startIndex
var end = startIndex
var splitsCount = 0
while end < count {
if splitsCount >= maxSplits {
break
}
if self[start..<end].elementsEqual(separator) {
if !omittingEmptySubsequences || start != end {
chunks.append(self[start..<end])
}
start = index(end, offsetBy: separator.count)
end = start
splitsCount += 1
continue
}
end = index(after: end)
}
if start <= endIndex {
if !omittingEmptySubsequences || start != endIndex {
chunks.append(self[start..<endIndex])
}
}
return chunks
}
func getMessageType() -> NatsOperation? {
guard self.count > 2 else { return nil }
for operation in NatsOperation.allOperations() {
if self.starts(with: operation.rawBytes) {
return operation
}
}
return nil
}
func starts(with bytes: [UInt8]) -> Bool {
guard self.count >= bytes.count else { return false }
return self.prefix(bytes.count).elementsEqual(bytes)
}
internal mutating func prepend(_ other: Data) {
self = other + self
}
internal func parseOutMessages() throws -> (ops: [ServerOp], remainder: Data?) {
var serverOps = [ServerOp]()
var startIndex = self.startIndex
var remainder: Data?
while startIndex < self.endIndex {
var nextLineStartIndex: Int
var lineData: Data
if let range = self[startIndex...].range(of: Data.crlf) {
let lineEndIndex = range.lowerBound
nextLineStartIndex =
self.index(range.upperBound, offsetBy: 0, limitedBy: self.endIndex)
?? self.endIndex
lineData = self[startIndex..<lineEndIndex]
} else {
remainder = self[startIndex..<self.endIndex]
break
}
if lineData.count == 0 {
startIndex = nextLineStartIndex
continue
}
let serverOp = try ServerOp.parse(from: lineData)
// if it's a message, get the full payload and add to returned data
if case .message(var msg) = serverOp {
if msg.length == 0 {
serverOps.append(serverOp)
} else {
var payload = Data()
let payloadEndIndex = nextLineStartIndex + msg.length
let payloadStartIndex = nextLineStartIndex
// include crlf in the expected payload leangth
if payloadEndIndex + Data.crlf.count > endIndex {
remainder = self[startIndex..<self.endIndex]
break
}
payload.append(self[payloadStartIndex..<payloadEndIndex])
msg.payload = payload
startIndex =
self.index(
payloadEndIndex, offsetBy: Data.crlf.count, limitedBy: self.endIndex)
?? self.endIndex
serverOps.append(.message(msg))
continue
}
//TODO(jrm): Add HMSG handling here too.
} else if case .hMessage(var msg) = serverOp {
if msg.length == 0 {
serverOps.append(serverOp)
} else {
let headersStartIndex = nextLineStartIndex
let headersEndIndex = nextLineStartIndex + msg.headersLength
let payloadStartIndex = headersEndIndex
let payloadEndIndex = nextLineStartIndex + msg.length
var payload: Data?
if msg.length > msg.headersLength {
payload = Data()
}
var headers = NatsHeaderMap()
// if the whole msg length (including training crlf) is longer
// than the remaining chunk, break and return the remainder
if payloadEndIndex + Data.crlf.count > endIndex {
remainder = self[startIndex..<self.endIndex]
break
}
let headersData = self[headersStartIndex..<headersEndIndex]
if let headersString = String(data: headersData, encoding: .utf8) {
headers = try NatsHeaderMap(from: headersString)
}
msg.status = headers.status
msg.description = headers.description
msg.headers = headers
if var payload = payload {
payload.append(self[payloadStartIndex..<payloadEndIndex])
msg.payload = payload
}
startIndex =
self.index(
payloadEndIndex, offsetBy: Data.crlf.count, limitedBy: self.endIndex)
?? self.endIndex
serverOps.append(.hMessage(msg))
continue
}
} else {
// otherwise, just add this server op to the result
serverOps.append(serverOp)
}
startIndex = nextLineStartIndex
}
return (serverOps, remainder)
}
}

View File

@@ -0,0 +1,24 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import NIOPosix
extension Data {
func toString() -> String? {
if let str = String(data: self, encoding: .utf8) {
return str
}
return nil
}
}

View File

@@ -0,0 +1,38 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
extension String {
private static let charactersToTrim: CharacterSet = .whitespacesAndNewlines.union(
CharacterSet(charactersIn: "'"))
static func hash() -> String {
let uuid = String.uuid()
return uuid[0...7]
}
func trimWhitespacesAndApostrophes() -> String {
return self.trimmingCharacters(in: String.charactersToTrim)
}
static func uuid() -> String {
return UUID().uuidString.trimmingCharacters(in: .punctuationCharacters)
}
subscript(bounds: CountableClosedRange<Int>) -> String {
let start = index(startIndex, offsetBy: bounds.lowerBound)
let end = index(startIndex, offsetBy: bounds.upperBound)
return String(self[start...end])
}
}

View File

@@ -0,0 +1,182 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import NIO
import NIOHTTP1
import NIOWebSocket
// Adapted from https://github.com/vapor/websocket-kit/blob/main/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift
internal final class HTTPUpgradeRequestHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPClientResponsePart
typealias OutboundOut = HTTPClientRequestPart
let host: String
let path: String
let query: String?
let headers: HTTPHeaders
let upgradePromise: EventLoopPromise<Void>
private var requestSent = false
init(
host: String, path: String, query: String?, headers: HTTPHeaders,
upgradePromise: EventLoopPromise<Void>
) {
self.host = host
self.path = path
self.query = query
self.headers = headers
self.upgradePromise = upgradePromise
}
func channelActive(context: ChannelHandlerContext) {
self.sendRequest(context: context)
context.fireChannelActive()
}
func handlerAdded(context: ChannelHandlerContext) {
if context.channel.isActive {
self.sendRequest(context: context)
}
}
private func sendRequest(context: ChannelHandlerContext) {
if self.requestSent {
// we might run into this handler twice, once in handlerAdded and once in channelActive.
return
}
self.requestSent = true
var headers = self.headers
headers.add(name: "Host", value: self.host)
var uri: String
if self.path.hasPrefix("/") || self.path.hasPrefix("ws://") || self.path.hasPrefix("wss://")
{
uri = self.path
} else {
uri = "/" + self.path
}
if let query = self.query {
uri += "?\(query)"
}
let requestHead = HTTPRequestHead(
version: HTTPVersion(major: 1, minor: 1),
method: .GET,
uri: uri,
headers: headers
)
context.write(self.wrapOutboundOut(.head(requestHead)), promise: nil)
let emptyBuffer = context.channel.allocator.buffer(capacity: 0)
let body = HTTPClientRequestPart.body(.byteBuffer(emptyBuffer))
context.write(self.wrapOutboundOut(body), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
// `NIOHTTPClientUpgradeHandler` should consume the first response in the success case,
// any response we see here indicates a failure. Report the failure and tidy up at the end of the response.
let clientResponse = self.unwrapInboundIn(data)
switch clientResponse {
case .head(let responseHead):
self.upgradePromise.fail(
NatsError.ClientError.invalidConnection("ws error \(responseHead)"))
case .body: break
case .end:
context.close(promise: nil)
}
}
func errorCaught(context: ChannelHandlerContext, error: Error) {
self.upgradePromise.fail(error)
context.close(promise: nil)
}
}
internal final class WebSocketByteBufferCodec: ChannelDuplexHandler {
typealias InboundIn = WebSocketFrame
typealias InboundOut = ByteBuffer
typealias OutboundIn = ByteBuffer
typealias OutboundOut = WebSocketFrame
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let frame = unwrapInboundIn(data)
switch frame.opcode {
case .binary:
context.fireChannelRead(wrapInboundOut(frame.data))
case .text:
preconditionFailure("We will never receive a text frame")
case .continuation:
preconditionFailure("We will never receive a continuation frame")
case .pong:
break
case .ping:
if frame.fin {
var frameData = frame.data
let maskingKey = frame.maskKey
if let maskingKey = maskingKey {
frameData.webSocketUnmask(maskingKey)
}
let bb = context.channel.allocator.buffer(bytes: frameData.readableBytesView)
self.send(
bb,
context: context,
opcode: .pong
)
} else {
context.close(promise: nil)
}
default:
// We ignore all other frames.
break
}
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let buffer = unwrapOutboundIn(data)
let frame = WebSocketFrame(
fin: true,
opcode: .binary,
maskKey: self.makeMaskKey(),
data: buffer
)
context.write(wrapOutboundOut(frame), promise: promise)
}
public func send(
_ data: ByteBuffer,
context: ChannelHandlerContext,
opcode: WebSocketOpcode = .binary,
fin: Bool = true,
promise: EventLoopPromise<Void>? = nil
) {
let frame = WebSocketFrame(
fin: fin,
opcode: opcode,
maskKey: self.makeMaskKey(),
data: data
)
context.writeAndFlush(wrapOutboundOut(frame), promise: promise)
}
func makeMaskKey() -> WebSocketMaskingKey? {
/// See https://github.com/apple/swift/issues/66099
var generator = SystemRandomNumberGenerator()
return WebSocketMaskingKey.random(using: &generator)
}
}

View File

@@ -0,0 +1,57 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
extension NatsClient {
/// Registers a callback for given event types.
///
/// - Parameters:
/// - events: an array of ``NatsEventKind`` for which the handler will be invoked.
/// - handler: a callback invoked upon triggering a specific event.
///
/// - Returns an ID of the registered listener which can be used to disable it.
@discardableResult
public func on(_ events: [NatsEventKind], _ handler: @escaping (NatsEvent) -> Void) -> String {
guard let connectionHandler = self.connectionHandler else {
return ""
}
return connectionHandler.addListeners(for: events, using: handler)
}
/// Registers a callback for given event type.
///
/// - Parameters:
/// - events: a ``NatsEventKind`` for which the handler will be invoked.
/// - handler: a callback invoked upon triggering a specific event.
///
/// - Returns an ID of the registered listener which can be used to disable it.
@discardableResult
public func on(_ event: NatsEventKind, _ handler: @escaping (NatsEvent) -> Void) -> String {
guard let connectionHandler = self.connectionHandler else {
return ""
}
return connectionHandler.addListeners(for: [event], using: handler)
}
/// Disables the event listener.
///
/// - Parameter id: an ID of a listener to be disabled (returned when creating it).
public func off(_ id: String) {
guard let connectionHandler = self.connectionHandler else {
return
}
connectionHandler.removeListener(id)
}
}

View File

@@ -0,0 +1,352 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Dispatch
import Foundation
import Logging
import NIO
import NIOFoundationCompat
import Nuid
public var logger = Logger(label: "Nats")
/// NatsClient connection states
public enum NatsState {
case pending
case connecting
case connected
case disconnected
case closed
case suspended
}
public struct Auth {
var user: String?
var password: String?
var token: String?
var credentialsPath: URL?
var nkeyPath: URL?
var nkey: String?
init() {
}
init(user: String, password: String) {
self.user = user
self.password = password
}
init(token: String) {
self.token = token
}
static func fromCredentials(_ credentials: URL) -> Auth {
var auth = Auth()
auth.credentialsPath = credentials
return auth
}
static func fromNkey(_ nkey: URL) -> Auth {
var auth = Auth()
auth.nkeyPath = nkey
return auth
}
static func fromNkey(_ nkey: String) -> Auth {
var auth = Auth()
auth.nkey = nkey
return auth
}
}
public class NatsClient {
public var connectedUrl: URL? {
connectionHandler?.connectedUrl
}
internal let allocator = ByteBufferAllocator()
internal var buffer: ByteBuffer
internal var connectionHandler: ConnectionHandler?
internal var inboxPrefix: String = "_INBOX."
internal init() {
self.buffer = allocator.buffer(capacity: 1024)
}
/// Returns a new inbox subject using the configured prefix and a generated NUID.
public func newInbox() -> String {
return inboxPrefix + nextNuid()
}
}
extension NatsClient {
/// Connects to a NATS server using configuration provided via ``NatsClientOptions``.
/// If ``NatsClientOptions/retryOnfailedConnect()`` is used, `connect()`
/// will not wait until the connection is established but rather return immediatelly.
///
/// > **Throws:**
/// > - ``NatsError/ConnectError/invalidConfig(_:)`` if the provided configuration is invalid
/// > - ``NatsError/ConnectError/tlsFailure(_:)`` if upgrading to TLS connection fails
/// > - ``NatsError/ConnectError/timeout`` if there was a timeout waiting to establish TCP connection
/// > - ``NatsError/ConnectError/dns(_:)`` if there was an error during dns lookup
/// > - ``NatsError/ConnectError/io`` if there was other error establishing connection
/// > - ``NatsError/ServerError/autorization(_:)`` if connection could not be established due to invalid/missing/expired auth
/// > - ``NatsError/ServerError/other(_:)`` if the server responds to client connection with a different error (e.g. max connections exceeded)
public func connect() async throws {
logger.debug("connect")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
// Check if already connected or in invalid state for connect()
let currentState = connectionHandler.currentState
switch currentState {
case .connected, .connecting:
throw NatsError.ClientError.alreadyConnected
case .closed:
throw NatsError.ClientError.connectionClosed
case .suspended:
throw NatsError.ClientError.invalidConnection(
"connection is suspended, use resume() instead")
case .pending, .disconnected:
// These states allow connection/reconnection
break
}
// Set state to connecting immediately to prevent concurrent connect() calls
connectionHandler.setState(.connecting)
do {
if !connectionHandler.retryOnFailedConnect {
try await connectionHandler.connect()
connectionHandler.setState(.connected)
connectionHandler.fire(.connected)
} else {
connectionHandler.handleReconnect()
}
} catch {
// Reset state on connection failure
connectionHandler.setState(.disconnected)
throw error
}
}
/// Closes a connection to NATS server.
///
/// - Throws ``NatsError/ClientError/connectionClosed`` if the conneciton is already closed.
public func close() async throws {
logger.debug("close")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
try await connectionHandler.close()
}
/// Suspends a connection to NATS server.
/// A suspended connection does not receive messages on subscriptions.
/// It can be resumed using ``resume()`` which restores subscriptions on successful reconnect.
///
/// - Throws ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
public func suspend() async throws {
logger.debug("suspend")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
try await connectionHandler.suspend()
}
/// Resumes a suspended connection.
/// ``resume()`` will not wait for successful reconnection but rather trigger a reconnect process and return.
/// Register ``NatsEvent`` using ``NatsClient/on()`` to wait for successful reconnection.
///
/// - Throws ``NatsError/ClientError`` if the conneciton is not in suspended state.
public func resume() async throws {
logger.debug("resume")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
try await connectionHandler.resume()
}
/// Forces a reconnect attempt to the server.
/// This is a non-blocking operation and will start the process without waiting for it to complete.
///
/// - Throws ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
public func reconnect() async throws {
logger.debug("resume")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
try await connectionHandler.reconnect()
}
/// Publishes a message on a given subject.
///
/// - Parameters:
/// - payload: data to be published.
/// - subject: a NATS subject on which the message will be published.
/// - reply: optional reply subject when publishing a request.
/// - headers: optional message headers.
///
/// > **Throws:**
/// > - ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
/// > - ``NatsError/ClientError/io(_:)`` if there is an error writing message to a TCP socket (e.g. bloken pipe).
public func publish(
_ payload: Data, subject: String, reply: String? = nil, headers: NatsHeaderMap? = nil
) async throws {
logger.debug("publish")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
try await connectionHandler.write(
operation: ClientOp.publish((subject, reply, payload, headers)))
}
/// Sends a blocking request on a given subject.
///
/// - Parameters:
/// - payload: data to be published in the request.
/// - subject: a NATS subject on which the request will be published.
/// - headers: optional request headers.
/// - timeout: request timeout - defaults to 5 seconds.
///
/// - Returns a ``NatsMessage`` containing the response.
///
/// > **Throws:**
/// > - ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
/// > - ``NatsError/ClientError/io(_:)`` if there is an error writing message to a TCP socket (e.g. bloken pipe).
/// > - ``NatsError/RequestError/noResponders`` if there are no responders available for the request.
/// > - ``NatsError/RequestError/timeout`` if there was a timeout waiting for the response.
public func request(
_ payload: Data, subject: String, headers: NatsHeaderMap? = nil, timeout: TimeInterval = 5
) async throws -> NatsMessage {
logger.debug("request")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
let inbox = newInbox()
let sub = try await connectionHandler.subscribe(inbox)
try await sub.unsubscribe(after: 1)
try await connectionHandler.write(
operation: ClientOp.publish((subject, inbox, payload, headers)))
return try await withThrowingTaskGroup(
of: NatsMessage?.self
) { group in
group.addTask {
do {
return try await sub.makeAsyncIterator().next()
} catch NatsError.SubscriptionError.permissionDenied {
throw NatsError.RequestError.permissionDenied
}
}
// task for the timeout
group.addTask {
try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000))
return nil
}
for try await result in group {
// if the result is not empty, return it (or throw status error)
if let msg = result {
group.cancelAll()
if let status = msg.status, status == StatusCode.noResponders {
throw NatsError.RequestError.noResponders
}
return msg
} else {
try await sub.unsubscribe()
group.cancelAll()
throw NatsError.RequestError.timeout
}
}
// this should not be reachable
throw NatsError.ClientError.internalError("error waiting for response")
}
}
/// Flushes the internal buffer ensuring that all messages are sent.
///
/// - Throws ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
public func flush() async throws {
logger.debug("flush")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
connectionHandler.channel?.flush()
}
/// Subscribes to a subject to receive messages.
///
/// - Parameters:
/// - subject:a subject the client want's to subscribe to.
/// - queue: optional queue group name.
///
/// - Returns a ``NatsSubscription`` allowing iteration over incoming messages.
///
/// > **Throws:**
/// > - ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
/// > - ``NatsError/ClientError/io(_:)`` if there is an error sending the SUB request to the server.
/// > - ``NatsError/SubscriptionError/invalidSubject`` if the provided subject is invalid.
/// > - ``NatsError/SubscriptionError/invalidQueue`` if the provided queue group is invalid.
public func subscribe(subject: String, queue: String? = nil) async throws -> NatsSubscription {
logger.info("subscribe to subject \(subject)")
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
return try await connectionHandler.subscribe(subject, queue: queue)
}
/// Sends a PING to the server, returning the time it took for the server to respond.
///
/// - Returns rtt of the request.
///
/// > **Throws:**
/// > - ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
/// > - ``NatsError/ClientError/io(_:)`` if there is an error sending the SUB request to the server.
public func rtt() async throws -> TimeInterval {
guard let connectionHandler = self.connectionHandler else {
throw NatsError.ClientError.internalError("empty connection handler")
}
if case .closed = connectionHandler.currentState {
throw NatsError.ClientError.connectionClosed
}
let ping = RttCommand.makeFrom(channel: connectionHandler.channel)
await connectionHandler.sendPing(ping)
return try await ping.getRoundTripTime()
}
}

View File

@@ -0,0 +1,202 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Dispatch
import Foundation
import Logging
import NIO
import NIOFoundationCompat
public class NatsClientOptions {
private var urls: [URL] = []
private var pingInterval: TimeInterval = 60.0
private var reconnectWait: TimeInterval = 2.0
private var maxReconnects: Int?
private var initialReconnect = false
private var noRandomize = false
private var auth: Auth? = nil
private var withTls = false
private var tlsFirst = false
private var rootCertificate: URL? = nil
private var clientCertificate: URL? = nil
private var clientKey: URL? = nil
private var inboxPrefix: String = "_INBOX."
public init() {}
/// Sets the prefix for inbox subjects used for request/reply.
/// Defaults to "_INBOX."
public func inboxPrefix(_ prefix: String) -> NatsClientOptions {
if prefix.isEmpty {
self.inboxPrefix = "_INBOX."
return self
}
if prefix.last != "." {
self.inboxPrefix = prefix + "."
return self
}
self.inboxPrefix = prefix
return self
}
/// A list of server urls that a client can connect to.
public func urls(_ urls: [URL]) -> NatsClientOptions {
self.urls = urls
return self
}
/// A single url that the client can connect to.
public func url(_ url: URL) -> NatsClientOptions {
self.urls = [url]
return self
}
/// The interval with which the client will send pings to NATS server.
/// Defaults to 60s.
public func pingInterval(_ pingInterval: TimeInterval) -> NatsClientOptions {
self.pingInterval = pingInterval
return self
}
/// Wait time between reconnect attempts.
/// Defaults to 2s.
public func reconnectWait(_ reconnectWait: TimeInterval) -> NatsClientOptions {
self.reconnectWait = reconnectWait
return self
}
/// Maximum number of reconnect attempts after each disconnect.
/// Defaults to unlimited.
public func maxReconnects(_ maxReconnects: Int) -> NatsClientOptions {
self.maxReconnects = maxReconnects
return self
}
/// Username and password used to connect to the server.
public func usernameAndPassword(_ username: String, _ password: String) -> NatsClientOptions {
if self.auth == nil {
self.auth = Auth(user: username, password: password)
} else {
self.auth?.user = username
self.auth?.password = password
}
return self
}
/// Token used for token auth to NATS server.
public func token(_ token: String) -> NatsClientOptions {
if self.auth == nil {
self.auth = Auth(token: token)
} else {
self.auth?.token = token
}
return self
}
/// The location of a credentials file containing user JWT and Nkey seed.
public func credentialsFile(_ credentials: URL) -> NatsClientOptions {
if self.auth == nil {
self.auth = Auth.fromCredentials(credentials)
} else {
self.auth?.credentialsPath = credentials
}
return self
}
/// The location of a public nkey file.
/// This and ``NatsClientOptions/nkey(_:)`` are mutually exclusive.
public func nkeyFile(_ nkey: URL) -> NatsClientOptions {
if self.auth == nil {
self.auth = Auth.fromNkey(nkey)
} else {
self.auth?.nkeyPath = nkey
}
return self
}
/// Public nkey.
/// This and ``NatsClientOptions/nkeyFile(_:)`` are mutually exclusive.
public func nkey(_ nkey: String) -> NatsClientOptions {
if self.auth == nil {
self.auth = Auth.fromNkey(nkey)
} else {
self.auth?.nkey = nkey
}
return self
}
/// Indicates whether the client requires an SSL connection.
public func requireTls() -> NatsClientOptions {
self.withTls = true
return self
}
/// Indicates whether the client will attempt to perform a TLS handshake first, that is
/// before receiving the INFO protocol. This requires the server to also be
/// configured with such option, otherwise the connection will fail.
public func withTlsFirst() -> NatsClientOptions {
self.tlsFirst = true
return self
}
/// The location of a root CAs file.
public func rootCertificates(_ rootCertificate: URL) -> NatsClientOptions {
self.rootCertificate = rootCertificate
return self
}
/// The location of a client cert file.
public func clientCertificate(_ clientCertificate: URL, _ clientKey: URL) -> NatsClientOptions {
self.clientCertificate = clientCertificate
self.clientKey = clientKey
return self
}
/// Indicates whether the client will retain the order of URLs to connect to provided in ``NatsClientOptions/urls(_:)``
/// If not set, the client will randomize the server pool.
public func retainServersOrder() -> NatsClientOptions {
self.noRandomize = true
return self
}
/// By default, ``NatsClient/connect()`` will return an error if
/// the connection to the server cannot be established.
///
/// Setting `retryOnfailedConnect()` makes the client
/// establish the connection in the background even if the initial connect fails.
public func retryOnfailedConnect() -> NatsClientOptions {
self.initialReconnect = true
return self
}
public func build() -> NatsClient {
let client = NatsClient()
client.inboxPrefix = inboxPrefix
client.connectionHandler = ConnectionHandler(
inputBuffer: client.buffer,
urls: urls,
reconnectWait: reconnectWait,
maxReconnects: maxReconnects,
retainServersOrder: noRandomize,
pingInterval: pingInterval,
auth: auth,
requireTls: withTls,
tlsFirst: tlsFirst,
clientCertificate: clientCertificate,
clientKey: clientKey,
rootCertificate: rootCertificate,
retryOnFailedConnect: initialReconnect
)
return client
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,244 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
public protocol NatsErrorProtocol: Error, CustomStringConvertible {}
public enum NatsError {
public enum ServerError: NatsErrorProtocol, Equatable {
case staleConnection
case maxConnectionsExceeded
case authorizationViolation
case authenticationExpired
case authenticationRevoked
case authenticationTimeout
case permissionsViolation(Operation, String, String?)
case proto(String)
public var description: String {
switch self {
case .staleConnection:
return "nats: stale connection"
case .maxConnectionsExceeded:
return "nats: maximum connections exceeded"
case .authorizationViolation:
return "nats: authorization violation"
case .authenticationExpired:
return "nats: authentication expired"
case .authenticationRevoked:
return "nats: authentication revoked"
case .authenticationTimeout:
return "nats: authentication timeout"
case .permissionsViolation(let operation, let subject, let queue):
if let queue {
return
"nats: permissions violation for operation \"\(operation)\" on subject \"\(subject)\" using queue \"\(queue)\""
} else {
return
"nats: permissions violation for operation \"\(operation)\" on subject \"\(subject)\""
}
case .proto(let error):
return "nats: \(error)"
}
}
var normalizedError: String {
return description.trimWhitespacesAndApostrophes().lowercased()
}
init(_ error: String) {
let normalizedError = error.trimWhitespacesAndApostrophes().lowercased()
if normalizedError.contains("stale connection") {
self = .staleConnection
} else if normalizedError.contains("maximum connections exceeded") {
self = .maxConnectionsExceeded
} else if normalizedError.contains("authorization violation") {
self = .authorizationViolation
} else if normalizedError.contains("authentication expired") {
self = .authenticationExpired
} else if normalizedError.contains("authentication revoked") {
self = .authenticationRevoked
} else if normalizedError.contains("authentication timeout") {
self = .authenticationTimeout
} else if normalizedError.contains("permissions violation") {
if let (operation, subject, queue) = NatsError.ServerError.parsePermissions(
error: error)
{
self = .permissionsViolation(operation, subject, queue)
} else {
self = .proto(error)
}
} else {
self = .proto(error)
}
}
public enum Operation: String, Equatable {
case publish = "Publish"
case subscribe = "Subscription"
}
internal static func parsePermissions(error: String) -> (Operation, String, String?)? {
let pattern = "(Publish|Subscription) to \"(\\S+)\""
let regex = try! NSRegularExpression(pattern: pattern)
let matches = regex.matches(
in: error, options: [], range: NSRange(location: 0, length: error.utf16.count))
guard let match = matches.first else {
return nil
}
var operation: Operation?
if let operationRange = Range(match.range(at: 1), in: error) {
let operationString = String(error[operationRange])
operation = Operation(rawValue: operationString)
}
var subject: String?
if let subjectRange = Range(match.range(at: 2), in: error) {
subject = String(error[subjectRange])
}
let queuePattern = "using queue \"(\\S+)\""
let queueRegex = try! NSRegularExpression(pattern: queuePattern)
let queueMatches = queueRegex.matches(
in: error, options: [], range: NSRange(location: 0, length: error.utf16.count))
var queue: String?
if let match = queueMatches.first, let queueRange = Range(match.range(at: 1), in: error)
{
queue = String(error[queueRange])
}
if let operation, let subject {
return (operation, subject, queue)
} else {
return nil
}
}
}
public enum ProtocolError: NatsErrorProtocol, Equatable {
case invalidOperation(String)
case parserFailure(String)
public var description: String {
switch self {
case .invalidOperation(let op):
return "nats: unknown server operation: \(op)"
case .parserFailure(let cause):
return "nats: parser failure: \(cause)"
}
}
}
public enum ClientError: NatsErrorProtocol {
case internalError(String)
case maxReconnects
case connectionClosed
case io(Error)
case invalidConnection(String)
case cancelled
case alreadyConnected
public var description: String {
switch self {
case .internalError(let error):
return "nats: internal error: \(error)"
case .maxReconnects:
return "nats: max reconnects exceeded"
case .connectionClosed:
return "nats: connection is closed"
case .io(let error):
return "nats: IO error: \(error)"
case .invalidConnection(let error):
return "nats: \(error)"
case .cancelled:
return "nats: operation cancelled"
case .alreadyConnected:
return "nats: client is already connected or connecting"
}
}
}
public enum ConnectError: NatsErrorProtocol {
case invalidConfig(String)
case tlsFailure(Error)
case timeout
case dns(Error)
case io(Error)
public var description: String {
switch self {
case .invalidConfig(let error):
return "nats: invalid client configuration: \(error)"
case .tlsFailure(let error):
return "nats: TLS error: \(error)"
case .timeout:
return "nats: timed out waiting for connection"
case .dns(let error):
return "nats: DNS lookup error: \(error)"
case .io(let error):
return "nats: error establishing connection: \(error)"
}
}
}
public enum RequestError: NatsErrorProtocol, Equatable {
case noResponders
case timeout
case permissionDenied
public var description: String {
switch self {
case .noResponders:
return "nats: no responders available for request"
case .timeout:
return "nats: request timed out"
case .permissionDenied:
return "nats: permission denied"
}
}
}
public enum SubscriptionError: NatsErrorProtocol, Equatable {
case invalidSubject
case invalidQueue
case permissionDenied
case subscriptionClosed
public var description: String {
switch self {
case .invalidSubject:
return "nats: invalid subject name"
case .invalidQueue:
return "nats: invalid queue group name"
case .permissionDenied:
return "nats: permission denied"
case .subscriptionClosed:
return "nats: subscription closed"
}
}
}
public enum ParseHeaderError: NatsErrorProtocol, Equatable {
case invalidCharacter
public var description: String {
switch self {
case .invalidCharacter:
return
"nats: invalid header name (name cannot contain non-ascii alphanumeric characters other than '-')"
}
}
}
}

View File

@@ -0,0 +1,159 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
// Represents NATS header field value in Swift.
public struct NatsHeaderValue: Equatable, CustomStringConvertible {
private var inner: String
public init(_ value: String) {
self.inner = value
}
public var description: String {
return inner
}
}
// Custom header representation in Swift
public struct NatsHeaderName: Equatable, Hashable, CustomStringConvertible {
private var inner: String
public init(_ value: String) throws {
if value.contains(where: { $0 == ":" || $0.asciiValue! < 33 || $0.asciiValue! > 126 }) {
throw NatsError.ParseHeaderError.invalidCharacter
}
self.inner = value
}
public var description: String {
return inner
}
// Example of standard headers
public static let natsStream = try! NatsHeaderName("Nats-Stream")
public static let natsSequence = try! NatsHeaderName("Nats-Sequence")
public static let natsTimestamp = try! NatsHeaderName("Nats-Time-Stamp")
public static let natsSubject = try! NatsHeaderName("Nats-Subject")
// Add other standard headers as needed...
}
// Represents a NATS header map in Swift.
public struct NatsHeaderMap: Equatable {
private var inner: [NatsHeaderName: [NatsHeaderValue]]
internal var status: StatusCode? = nil
internal var description: String? = nil
public init() {
self.inner = [:]
}
public init(from headersString: String) throws {
self.inner = [:]
let headersArray = headersString.split(separator: "\r\n")
let versionLine = headersArray[0]
guard versionLine.hasPrefix(Data.versionLinePrefix) else {
throw NatsError.ProtocolError.parserFailure(
"header version line does not begin with `NATS/1.0`")
}
let versionLineSuffix =
versionLine
.dropFirst(Data.versionLinePrefix.count)
.trimmingCharacters(in: .whitespacesAndNewlines)
// handle inlines status and description
if versionLineSuffix.count > 0 {
let statusAndDesc = versionLineSuffix.split(
separator: " ", maxSplits: 1)
guard let status = StatusCode(statusAndDesc[0]) else {
throw NatsError.ProtocolError.parserFailure("could not parse status parameter")
}
self.status = status
if statusAndDesc.count > 1 {
self.description = String(statusAndDesc[1])
}
}
for header in headersArray.dropFirst() {
let headerParts = header.split(separator: ":", maxSplits: 1)
if headerParts.count == 2 {
self.append(
try NatsHeaderName(String(headerParts[0])),
NatsHeaderValue(String(headerParts[1]).trimmingCharacters(in: .whitespaces)))
} else {
logger.error("Error parsing header: \(header)")
}
}
}
var isEmpty: Bool {
return inner.isEmpty
}
public mutating func insert(_ name: NatsHeaderName, _ value: NatsHeaderValue) {
self.inner[name] = [value]
}
public mutating func append(_ name: NatsHeaderName, _ value: NatsHeaderValue) {
if inner[name] != nil {
inner[name]?.append(value)
} else {
insert(name, value)
}
}
public func get(_ name: NatsHeaderName) -> NatsHeaderValue? {
return inner[name]?.first
}
public func getAll(_ name: NatsHeaderName) -> [NatsHeaderValue] {
return inner[name] ?? []
}
//TODO(jrm): can we use unsafe methods here? Probably yes.
func toBytes() -> [UInt8] {
var bytes: [UInt8] = []
bytes.append(contentsOf: "NATS/1.0\r\n".utf8)
for (name, values) in inner {
for value in values {
bytes.append(contentsOf: name.description.utf8)
bytes.append(contentsOf: ":".utf8)
bytes.append(contentsOf: value.description.utf8)
bytes.append(contentsOf: "\r\n".utf8)
}
}
bytes.append(contentsOf: "\r\n".utf8)
return bytes
}
// Implementing the == operator to exclude status and desc internal properties
public static func == (lhs: NatsHeaderMap, rhs: NatsHeaderMap) -> Bool {
return lhs.inner == rhs.inner
}
}
extension NatsHeaderMap {
public subscript(name: NatsHeaderName) -> NatsHeaderValue? {
get {
return get(name)
}
set {
if let value = newValue {
insert(name, value)
} else {
inner[name] = nil
}
}
}
}

View File

@@ -0,0 +1,69 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
class JwtUtils {
// This regular expression is equivalent to the one used in Rust.
static let userConfigRE: NSRegularExpression = {
do {
return try NSRegularExpression(
pattern:
"\\s*(?:(?:-{3,}.*-{3,}\\r?\\n)([\\w\\-.=]+)(?:\\r?\\n-{3,}.*-{3,}\\r?\\n))",
options: [])
} catch {
fatalError("Invalid regular expression: \(error)")
}
}()
/// Parses a credentials file and returns its user JWT.
static func parseDecoratedJWT(contents: String) -> String? {
let matches = userConfigRE.matches(
in: contents, options: [], range: NSRange(contents.startIndex..., in: contents))
if let match = matches.first, let range = Range(match.range(at: 1), in: contents) {
return String(contents[range])
}
return nil
}
/// Parses a credentials file and returns its user JWT.
static func parseDecoratedJWT(contents: Data) -> Data? {
guard let contentsString = String(data: contents, encoding: .utf8) else {
return nil
}
if let match = parseDecoratedJWT(contents: contentsString) {
return match.data(using: .utf8)
}
return nil
}
/// Parses a credentials file and returns its nkey.
static func parseDecoratedNKey(contents: String) -> String? {
let matches = userConfigRE.matches(
in: contents, options: [], range: NSRange(contents.startIndex..., in: contents))
if matches.count > 1, let range = Range(matches[1].range(at: 1), in: contents) {
return String(contents[range])
}
return nil
}
/// Parses a credentials file and returns its nkey.
static func parseDecoratedNKey(contents: Data) -> Data? {
guard let contentsString = String(data: contents, encoding: .utf8) else {
return nil
}
if let match = parseDecoratedNKey(contents: contentsString) {
return match.data(using: .utf8)
}
return nil
}
}

View File

@@ -0,0 +1,60 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
public struct NatsMessage {
public let payload: Data?
public let subject: String
public let replySubject: String?
public let length: Int
public let headers: NatsHeaderMap?
public let status: StatusCode?
public let description: String?
}
public struct StatusCode: Equatable {
public static let idleHeartbeat = StatusCode(value: 100)
public static let ok = StatusCode(value: 200)
public static let badRequest = StatusCode(value: 400)
public static let notFound = StatusCode(value: 404)
public static let timeout = StatusCode(value: 408)
public static let noResponders = StatusCode(value: 503)
public static let requestTerminated = StatusCode(value: 409)
let value: UInt16
// non-optional initializer for static status codes
private init(value: UInt16) {
self.value = value
}
init?(_ value: UInt16) {
if !(100..<1000 ~= value) {
return nil
}
self.value = value
}
init?(_ value: any StringProtocol) {
guard let status = UInt16(value) else {
return nil
}
if !(100..<1000 ~= status) {
return nil
}
self.value = status
}
}

View File

@@ -0,0 +1,341 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import NIO
internal struct NatsOperation: RawRepresentable, Hashable {
let rawValue: String
static let connect = NatsOperation(rawValue: "CONNECT")
static let subscribe = NatsOperation(rawValue: "SUB")
static let unsubscribe = NatsOperation(rawValue: "UNSUB")
static let publish = NatsOperation(rawValue: "PUB")
static let hpublish = NatsOperation(rawValue: "HPUB")
static let message = NatsOperation(rawValue: "MSG")
static let hmessage = NatsOperation(rawValue: "HMSG")
static let info = NatsOperation(rawValue: "INFO")
static let ok = NatsOperation(rawValue: "+OK")
static let error = NatsOperation(rawValue: "-ERR")
static let ping = NatsOperation(rawValue: "PING")
static let pong = NatsOperation(rawValue: "PONG")
var rawBytes: String.UTF8View {
self.rawValue.utf8
}
static func allOperations() -> [NatsOperation] {
return [
.connect, .subscribe, .unsubscribe, .publish, .message, .hmessage, .info, .ok, .error,
.ping, .pong,
]
}
}
enum ServerOp {
case ok
case info(ServerInfo)
case ping
case pong
case error(NatsError.ServerError)
case message(MessageInbound)
case hMessage(HMessageInbound)
static func parse(from msg: Data) throws -> ServerOp {
guard msg.count > 2 else {
throw NatsError.ProtocolError.parserFailure(
"unable to parse inbound message: \(String(data: msg, encoding: .utf8)!)")
}
guard let msgType = msg.getMessageType() else {
throw NatsError.ProtocolError.invalidOperation(String(data: msg, encoding: .utf8)!)
}
switch msgType {
case .message:
return try message(MessageInbound.parse(data: msg))
case .hmessage:
return try hMessage(HMessageInbound.parse(data: msg))
case .info:
return try info(ServerInfo.parse(data: msg))
case .ok:
return ok
case .error:
if let errMsg = msg.removePrefix(Data(NatsOperation.error.rawBytes)).toString() {
return error(NatsError.ServerError(errMsg))
}
return error(NatsError.ServerError("unexpected error"))
case .ping:
return ping
case .pong:
return pong
default:
throw NatsError.ProtocolError.invalidOperation(
"unknown server op: \(String(data: msg, encoding: .utf8)!)")
}
}
}
internal struct HMessageInbound: Equatable {
private static let newline = UInt8(ascii: "\n")
private static let space = UInt8(ascii: " ")
var subject: String
var sid: UInt64
var reply: String?
var payload: Data?
var headers: NatsHeaderMap
var headersLength: Int
var length: Int
var status: StatusCode?
var description: String?
// Parse the operation syntax: HMSG <subject> <sid> [reply-to]
internal static func parse(data: Data) throws -> HMessageInbound {
let protoComponents =
data
.dropFirst(NatsOperation.hmessage.rawValue.count) // Assuming msg starts with "HMSG "
.split(separator: space)
.filter { !$0.isEmpty }
let parseArgs: ((Data, Data, Data?, Data, Data) throws -> HMessageInbound) = {
subjectData, sidData, replyData, lengthHeaders, lengthData in
let subject = String(decoding: subjectData, as: UTF8.self)
guard let sid = UInt64(String(decoding: sidData, as: UTF8.self)) else {
throw NatsError.ProtocolError.parserFailure(
"unable to parse subscription ID as number")
}
var replySubject: String? = nil
if let replyData = replyData {
replySubject = String(decoding: replyData, as: UTF8.self)
}
let headersLength = Int(String(decoding: lengthHeaders, as: UTF8.self)) ?? 0
let length = Int(String(decoding: lengthData, as: UTF8.self)) ?? 0
return HMessageInbound(
subject: subject, sid: sid, reply: replySubject, payload: nil,
headers: NatsHeaderMap(),
headersLength: headersLength, length: length)
}
var msg: HMessageInbound
switch protoComponents.count {
case 4:
msg = try parseArgs(
protoComponents[0], protoComponents[1], nil, protoComponents[2],
protoComponents[3])
case 5:
msg = try parseArgs(
protoComponents[0], protoComponents[1], protoComponents[2], protoComponents[3],
protoComponents[4])
default:
throw NatsError.ProtocolError.parserFailure("unable to parse inbound message header")
}
return msg
}
}
// TODO(pp): add headers and HMSG parsing
internal struct MessageInbound: Equatable {
private static let newline = UInt8(ascii: "\n")
private static let space = UInt8(ascii: " ")
var subject: String
var sid: UInt64
var reply: String?
var payload: Data?
var length: Int
// Parse the operation syntax: MSG <subject> <sid> [reply-to]
internal static func parse(data: Data) throws -> MessageInbound {
let protoComponents =
data
.dropFirst(NatsOperation.message.rawValue.count) // Assuming msg starts with "MSG "
.split(separator: space)
.filter { !$0.isEmpty }
let parseArgs: ((Data, Data, Data?, Data) throws -> MessageInbound) = {
subjectData, sidData, replyData, lengthData in
let subject = String(decoding: subjectData, as: UTF8.self)
guard let sid = UInt64(String(decoding: sidData, as: UTF8.self)) else {
throw NatsError.ProtocolError.parserFailure(
"unable to parse subscription ID as number")
}
var replySubject: String? = nil
if let replyData = replyData {
replySubject = String(decoding: replyData, as: UTF8.self)
}
let length = Int(String(decoding: lengthData, as: UTF8.self)) ?? 0
return MessageInbound(
subject: subject, sid: sid, reply: replySubject, payload: nil, length: length)
}
var msg: MessageInbound
switch protoComponents.count {
case 3:
msg = try parseArgs(protoComponents[0], protoComponents[1], nil, protoComponents[2])
case 4:
msg = try parseArgs(
protoComponents[0], protoComponents[1], protoComponents[2], protoComponents[3])
default:
throw NatsError.ProtocolError.parserFailure("unable to parse inbound message header")
}
return msg
}
}
/// Struct representing server information in NATS.
struct ServerInfo: Codable, Equatable {
/// The unique identifier of the NATS server.
let serverId: String
/// Generated Server Name.
let serverName: String
/// The host specified in the cluster parameter/options.
let host: String
/// The port number specified in the cluster parameter/options.
let port: UInt16
/// The version of the NATS server.
let version: String
/// If this is set, then the server should try to authenticate upon connect.
let authRequired: Bool?
/// If this is set, then the server must authenticate using TLS.
let tlsRequired: Bool?
/// Maximum payload size that the server will accept.
let maxPayload: UInt
/// The protocol version in use.
let proto: Int8
/// The server-assigned client ID. This may change during reconnection.
let clientId: UInt64?
/// The version of golang the NATS server was built with.
let go: String
/// The nonce used for nkeys.
let nonce: String?
/// A list of server urls that a client can connect to.
let connectUrls: [String]?
/// The client IP as known by the server.
let clientIp: String
/// Whether the server supports headers.
let headers: Bool
/// Whether server goes into lame duck
private let _lameDuckMode: Bool?
var lameDuckMode: Bool {
return _lameDuckMode ?? false
}
private static let prefix = NatsOperation.info.rawValue.data(using: .utf8)!
private enum CodingKeys: String, CodingKey {
case serverId = "server_id"
case serverName = "server_name"
case host
case port
case version
case authRequired = "auth_required"
case tlsRequired = "tls_required"
case maxPayload = "max_payload"
case proto
case clientId = "client_id"
case go
case nonce
case connectUrls = "connect_urls"
case clientIp = "client_ip"
case headers
case _lameDuckMode = "ldm"
}
internal static func parse(data: Data) throws -> ServerInfo {
let info = data.removePrefix(prefix)
return try JSONDecoder().decode(self, from: info)
}
}
enum ClientOp {
case publish((subject: String, reply: String?, payload: Data?, headers: NatsHeaderMap?))
case subscribe((sid: UInt64, subject: String, queue: String?))
case unsubscribe((sid: UInt64, max: UInt64?))
case connect(ConnectInfo)
case ping
case pong
}
/// Info to construct a CONNECT message.
struct ConnectInfo: Encodable {
/// Turns on +OK protocol acknowledgments.
var verbose: Bool
/// Turns on additional strict format checking, e.g. for properly formed
/// subjects.
var pedantic: Bool
/// User's JWT.
var userJwt: String?
/// Public nkey.
var nkey: String
/// Signed nonce, encoded to Base64URL.
var signature: String?
/// Optional client name.
var name: String
/// If set to `true`, the server (version 1.2.0+) will not send originating
/// messages from this connection to its own subscriptions. Clients should
/// set this to `true` only for server supporting this feature, which is
/// when proto in the INFO protocol is set to at least 1.
var echo: Bool
/// The implementation language of the client.
var lang: String
/// The version of the client.
var version: String
/// Sending 0 (or absent) indicates client supports original protocol.
/// Sending 1 indicates that the client supports dynamic reconfiguration
/// of cluster topology changes by asynchronously receiving INFO messages
/// with known servers it can reconnect to.
var natsProtocol: NatsProtocol
/// Indicates whether the client requires an SSL connection.
var tlsRequired: Bool
/// Connection username (if `auth_required` is set)
var user: String
/// Connection password (if auth_required is set)
var pass: String
/// Client authorization token (if auth_required is set)
var authToken: String
/// Whether the client supports the usage of headers.
var headers: Bool
/// Whether the client supports no_responders.
var noResponders: Bool
enum CodingKeys: String, CodingKey {
case verbose
case pedantic
case userJwt = "jwt"
case nkey
case signature = "sig" // Custom key name for JSON
case name
case echo
case lang
case version
case natsProtocol = "protocol"
case tlsRequired = "tls_required"
case user
case pass
case authToken = "auth_token"
case headers
case noResponders = "no_responders"
}
}
enum NatsProtocol: Encodable {
case original
case dynamic
func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .original:
try container.encode(0)
case .dynamic:
try container.encode(1)
}
}
}

View File

@@ -0,0 +1,185 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
// TODO(pp): Implement slow consumer
public class NatsSubscription: AsyncSequence {
public typealias Element = NatsMessage
public typealias AsyncIterator = SubscriptionIterator
public let subject: String
public let queue: String?
internal var max: UInt64?
internal var delivered: UInt64 = 0
internal let sid: UInt64
private var buffer: [Result<Element, NatsError.SubscriptionError>]
private let capacity: UInt64
private var closed = false
private var continuation:
CheckedContinuation<Result<Element, NatsError.SubscriptionError>?, Never>?
private let lock = NSLock()
private let conn: ConnectionHandler
private static let defaultSubCapacity: UInt64 = 512 * 1024
convenience init(sid: UInt64, subject: String, queue: String?, conn: ConnectionHandler) throws {
try self.init(
sid: sid, subject: subject, queue: queue, capacity: NatsSubscription.defaultSubCapacity,
conn: conn)
}
init(
sid: UInt64, subject: String, queue: String?, capacity: UInt64, conn: ConnectionHandler
) throws {
if !NatsSubscription.validSubject(subject) {
throw NatsError.SubscriptionError.invalidSubject
}
if let queue, !NatsSubscription.validQueue(queue) {
throw NatsError.SubscriptionError.invalidQueue
}
self.sid = sid
self.subject = subject
self.queue = queue
self.capacity = capacity
self.buffer = []
self.conn = conn
}
public func makeAsyncIterator() -> SubscriptionIterator {
return SubscriptionIterator(subscription: self)
}
func receiveMessage(_ message: NatsMessage) {
lock.withLock {
if let continuation = self.continuation {
// Immediately use the continuation if it exists
self.continuation = nil
continuation.resume(returning: .success(message))
} else if buffer.count < capacity {
// Only append to buffer if no continuation is available
// TODO(pp): Hadndle SlowConsumer as subscription event
buffer.append(.success(message))
}
}
}
func receiveError(_ error: NatsError.SubscriptionError) {
lock.withLock {
if let continuation = self.continuation {
// Immediately use the continuation if it exists
self.continuation = nil
continuation.resume(returning: .failure(error))
} else {
buffer.append(.failure(error))
}
}
}
internal func complete() {
lock.withLock {
closed = true
if let continuation {
self.continuation = nil
continuation.resume(returning: nil)
}
}
}
// AsyncIterator implementation
public class SubscriptionIterator: AsyncIteratorProtocol {
private var subscription: NatsSubscription
init(subscription: NatsSubscription) {
self.subscription = subscription
}
public func next() async throws -> Element? {
try await subscription.nextMessage()
}
}
private func nextMessage() async throws -> Element? {
let result: Result<Element, NatsError.SubscriptionError>? = await withCheckedContinuation {
continuation in
lock.withLock {
if closed {
continuation.resume(returning: nil)
return
}
delivered += 1
if let message = buffer.first {
buffer.removeFirst()
continuation.resume(returning: message)
} else {
self.continuation = continuation
}
}
}
if let max, delivered >= max {
conn.removeSub(sub: self)
}
switch result {
case .success(let msg):
return msg
case .failure(let error):
throw error
default:
return nil
}
}
/// Unsubscribes from subscription.
///
/// - Parameter after: If set, unsubscribe will be performed after reaching given number of messages.
/// If it already reached or surpassed the passed value, it will immediately stop.
///
/// > **Throws:**
/// > - ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
/// > - ``NatsError/SubscriptionError/subscriptionClosed`` if the subscription is already closed
public func unsubscribe(after: UInt64? = nil) async throws {
logger.info("unsubscribe from subject \(subject)")
if case .closed = self.conn.currentState {
throw NatsError.ClientError.connectionClosed
}
if self.closed {
throw NatsError.SubscriptionError.subscriptionClosed
}
return try await self.conn.unsubscribe(sub: self, max: after)
}
// validateSubject will do a basic subject validation.
// Spaces are not allowed and all tokens should be > 0 in length.
private static func validSubject(_ subj: String) -> Bool {
let whitespaceCharacterSet = CharacterSet.whitespacesAndNewlines
if subj.rangeOfCharacter(from: whitespaceCharacterSet) != nil {
return false
}
let tokens = subj.split(separator: ".")
for token in tokens {
if token.isEmpty {
return false
}
}
return true
}
// validQueue will check a queue name for whitespaces.
private static func validQueue(_ queue: String) -> Bool {
let whitespaceCharacterSet = CharacterSet.whitespacesAndNewlines
return queue.rangeOfCharacter(from: whitespaceCharacterSet) == nil
}
}

View File

@@ -0,0 +1,39 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import NIOCore
internal class RttCommand {
let startTime = DispatchTime.now()
let promise: EventLoopPromise<TimeInterval>?
static func makeFrom(channel: Channel?) -> RttCommand {
RttCommand(promise: channel?.eventLoop.makePromise(of: TimeInterval.self))
}
private init(promise: EventLoopPromise<TimeInterval>?) {
self.promise = promise
}
func setRoundTripTime() {
let now = DispatchTime.now()
let nanoTime = now.uptimeNanoseconds - startTime.uptimeNanoseconds
let rtt = TimeInterval(nanoTime) / 1_000_000_000 // Convert nanos to seconds
promise?.succeed(rtt)
}
func getRoundTripTime() async throws -> TimeInterval {
try await promise?.futureResult.get() ?? 0
}
}