Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions Sources/HummingbirdCompression/Allocator.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2025 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOConcurrencyHelpers

protocol ZlibAllocator<Value> {
associatedtype Value
func allocate() throws -> Value
func free(_ compressor: inout Value?)
}

/// Wrapper for value that uses allocator to manage its lifecycle
class AllocatedValue<Allocator: ZlibAllocator> {
let value: Allocator.Value
let allocator: Allocator

init(allocator: Allocator) throws {
self.allocator = allocator
self.value = try allocator.allocate()
}

deinit {
var optionalValue: Allocator.Value? = self.value
self.allocator.free(&optionalValue)
}
}

/// Type that can be used with the PoolAllocator
protocol PoolReusable {
func reset() throws
}

/// Allocator that keeps a pool of values around to be re-used.
///
/// It will use a value from the pool if it isnt empty. Otherwise it will
/// allocate a new value. When values are freed they are passed back to the pool
/// up until the point where the pool grows to its maximum size.
struct PoolAllocator<BaseAllocator: ZlibAllocator>: ZlibAllocator where BaseAllocator.Value: PoolReusable {
typealias Value = BaseAllocator.Value
@usableFromInline
let base: BaseAllocator
@usableFromInline
let poolSize: Int
@usableFromInline
let values: NIOLockedValueBox<[Value]>

@inlinable
init(size: Int, base: BaseAllocator) {
self.base = base
self.poolSize = size
self.values = .init([])
}

@inlinable
func allocate() throws -> Value {
let value = self.values.withLockedValue {
$0.popLast()
}
if let value {
try value.reset()
return value
}
return try base.allocate()
}

@inlinable
func free(_ value: inout Value?) {
guard let nonOptionalValue = value else { preconditionFailure("Cannot ball free twice on a compressor") }
self.values.withLockedValue {
if $0.count < poolSize {
$0.append(nonOptionalValue)
}
}
base.free(&value)
}
}
54 changes: 44 additions & 10 deletions Sources/HummingbirdCompression/CompressedBodyWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2024 the Hummingbird authors
// Copyright (c) 2021-2025 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -15,24 +15,25 @@
import CompressNIO
import Hummingbird
import Logging
import NIOConcurrencyHelpers

// ResponseBodyWriter that writes a compressed version of the response to a parent writer
final class CompressedBodyWriter<ParentWriter: ResponseBodyWriter & Sendable>: ResponseBodyWriter {
final class CompressedBodyWriter<ParentWriter: ResponseBodyWriter & Sendable, Allocator: ZlibAllocator>: ResponseBodyWriter
where Allocator.Value == ZlibCompressor {
var parentWriter: ParentWriter
private let compressor: ZlibCompressor
private var compressor: AllocatedValue<Allocator>
private var window: ByteBuffer
var lastBuffer: ByteBuffer?
let logger: Logger

init(
parent: ParentWriter,
algorithm: ZlibAlgorithm,
configuration: ZlibConfiguration,
allocator: Allocator,
windowSize: Int,
logger: Logger
) throws {
self.parentWriter = parent
self.compressor = try ZlibCompressor(algorithm: algorithm, configuration: configuration)
self.compressor = try .init(allocator: allocator)
self.window = ByteBufferAllocator().buffer(capacity: windowSize)
self.lastBuffer = nil
self.logger = logger
Expand All @@ -41,7 +42,7 @@
/// Write response buffer
func write(_ buffer: ByteBuffer) async throws {
var buffer = buffer
try await buffer.compressStream(with: self.compressor, window: &self.window, flush: .sync) { buffer in
try await buffer.compressStream(with: compressor.value, window: &self.window, flush: .sync) { buffer in
try await self.parentWriter.write(buffer)
}
// need to store the last buffer so it can be finished once the writer is done
Expand All @@ -56,7 +57,7 @@
// keep finishing stream until we don't get a buffer overflow
while true {
do {
try lastBuffer.compressStream(to: &self.window, with: self.compressor, flush: .finish)
try lastBuffer.compressStream(to: &self.window, with: compressor.value, flush: .finish)
try await self.parentWriter.write(self.window)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this write fails, the allocation isn't freed. Can we use a non-copyable with a deinit to handle this instead? Or a class or something similar.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using AllocatedValue class to hold these.

self.window.clear()
break
Expand All @@ -67,7 +68,6 @@
}
}
self.lastBuffer = nil

try await self.parentWriter.finish(trailingHeaders)
}
}
Expand All @@ -79,12 +79,46 @@
/// - windowSize: Window size (in bytes) to use when compressing data
/// - logger: Logger used to output compression errors
/// - Returns: new ``HummingbirdCore/ResponseBodyWriter``
func compressed<Allocator: ZlibAllocator<ZlibCompressor>>(
compressorPool: Allocator,
windowSize: Int,
logger: Logger
) throws -> some ResponseBodyWriter {
try CompressedBodyWriter(parent: self, allocator: compressorPool, windowSize: windowSize, logger: logger)
}

/// Return ``HummingbirdCore/ResponseBodyWriter`` that compresses the contents of this ResponseBodyWriter
/// - Parameters:
/// - algorithm: Compression algorithm
/// - configuration: Zlib configuration
/// - windowSize: Window size (in bytes) to use when compressing data
/// - logger: Logger used to output compression errors
/// - Returns: new ``HummingbirdCore/ResponseBodyWriter``
public func compressed(
algorithm: ZlibAlgorithm,
configuration: ZlibConfiguration,
windowSize: Int,
logger: Logger
) throws -> some ResponseBodyWriter {
try CompressedBodyWriter(parent: self, algorithm: algorithm, configuration: configuration, windowSize: windowSize, logger: logger)
try compressed(
compressorPool: ZlibCompressorAllocator(algorithm: algorithm, configuration: configuration),
windowSize: windowSize,
logger: logger
)
}

Check warning on line 108 in Sources/HummingbirdCompression/CompressedBodyWriter.swift

View check run for this annotation

Codecov / codecov/patch

Sources/HummingbirdCompression/CompressedBodyWriter.swift#L103-L108

Added lines #L103 - L108 were not covered by tests
}

extension ZlibCompressor: PoolReusable {}
struct ZlibCompressorAllocator: ZlibAllocator, Sendable {
typealias Value = ZlibCompressor
let algorithm: ZlibAlgorithm
let configuration: ZlibConfiguration

func allocate() throws -> ZlibCompressor {
try ZlibCompressor(algorithm: algorithm, configuration: configuration)
}

func free(_ compressor: inout ZlibCompressor?) {
compressor = nil
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,41 @@ import Logging
/// if the content-encoding header is set to gzip or deflate then the middleware will attempt
/// to decompress the contents of the request body and pass that down the middleware chain.
public struct RequestDecompressionMiddleware<Context: RequestContext>: RouterMiddleware {
/// decompression window size
/// Decompression window size. This is not the internal zlib window
let windowSize: Int
/// Pool of gzip decompressors
let gzipDecompressorPool: PoolAllocator<ZlibDecompressorAllocator>
/// Pool of deflate compressors
let deflateDecompressorPool: PoolAllocator<ZlibDecompressorAllocator>

/// Initialize RequestDecompressionMiddleware
/// - Parameters
/// - windowSize: Decompression window size
/// - gzipDecompressorPoolSize: Maximum size of the gzip decompressor pool
/// - deflateDecompressorPoolSize: Maximum size of the deflate decompressor pool
public init(
windowSize: Int = 32768,
gzipDecompressorPoolSize: Int,
deflateDecompressorPoolSize: Int
) {
self.windowSize = windowSize
self.gzipDecompressorPool = .init(size: gzipDecompressorPoolSize, base: .init(algorithm: .gzip, windowSize: 15))
self.deflateDecompressorPool = .init(size: deflateDecompressorPoolSize, base: .init(algorithm: .zlib, windowSize: 15))
}

/// Initialize RequestDecompressionMiddleware
/// - Parameter windowSize: Decompression window size
public init(windowSize: Int = 32768) {
self.windowSize = windowSize
self.init(windowSize: windowSize, gzipDecompressorPoolSize: 16, deflateDecompressorPoolSize: 16)
}

public func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response {
if let algorithm = algorithm(from: request.headers[values: .contentEncoding]) {
if let pool = algorithm(from: request.headers[values: .contentEncoding]) {
var request = request
request.body = .init(
asyncSequence: DecompressByteBufferSequence(
base: request.body,
algorithm: algorithm,
allocator: pool,
windowSize: self.windowSize,
logger: context.logger
)
Expand All @@ -49,13 +68,13 @@ public struct RequestDecompressionMiddleware<Context: RequestContext>: RouterMid
}

/// Determines the decompression algorithm based off content encoding header.
private func algorithm(from contentEncodingHeaders: [String]) -> ZlibAlgorithm? {
private func algorithm(from contentEncodingHeaders: [String]) -> PoolAllocator<ZlibDecompressorAllocator>? {
for encoding in contentEncodingHeaders {
switch encoding {
case "gzip":
return .gzip
return self.gzipDecompressorPool
case "deflate":
return .zlib
return self.deflateDecompressorPool
default:
break
}
Expand All @@ -65,44 +84,45 @@ public struct RequestDecompressionMiddleware<Context: RequestContext>: RouterMid
}

/// AsyncSequence of decompressed ByteBuffers
struct DecompressByteBufferSequence<Base: AsyncSequence & Sendable>: AsyncSequence, Sendable where Base.Element == ByteBuffer {
struct DecompressByteBufferSequence<Base: AsyncSequence & Sendable, Allocator: ZlibAllocator>: AsyncSequence, Sendable
where Base.Element == ByteBuffer, Allocator.Value == ZlibDecompressor, Allocator: Sendable {
typealias Element = ByteBuffer

let base: Base
let algorithm: ZlibAlgorithm
let allocator: Allocator
let windowSize: Int
let logger: Logger

init(base: Base, algorithm: ZlibAlgorithm, windowSize: Int, logger: Logger) {
init(base: Base, allocator: Allocator, windowSize: Int, logger: Logger) {
self.base = base
self.algorithm = algorithm
self.allocator = allocator
self.windowSize = windowSize
self.logger = logger
}

struct AsyncIterator: AsyncIteratorProtocol {
class AsyncIterator: AsyncIteratorProtocol {
enum State {
case uninitialized(ZlibAlgorithm, windowSize: Int)
case decompressing(ZlibDecompressor, buffer: ByteBuffer, window: ByteBuffer)
case uninitialized(Allocator, windowSize: Int)
case decompressing(AllocatedValue<Allocator>, buffer: ByteBuffer, window: ByteBuffer)
case done
}

var baseIterator: Base.AsyncIterator
var state: State

init(baseIterator: Base.AsyncIterator, algorithm: ZlibAlgorithm, windowSize: Int) {
init(baseIterator: Base.AsyncIterator, allocator: Allocator, windowSize: Int) {
self.baseIterator = baseIterator
self.state = .uninitialized(algorithm, windowSize: windowSize)
self.state = .uninitialized(allocator, windowSize: windowSize)
}

mutating func next() async throws -> ByteBuffer? {
func next() async throws -> ByteBuffer? {
switch self.state {
case .uninitialized(let algorithm, let windowSize):
case .uninitialized(let allocator, let windowSize):
guard let buffer = try await self.baseIterator.next() else {
self.state = .done
return nil
}
let decompressor = try ZlibDecompressor(algorithm: algorithm)
let decompressor = try AllocatedValue(allocator: allocator)
self.state = .decompressing(decompressor, buffer: buffer, window: ByteBufferAllocator().buffer(capacity: windowSize))
return try await self.next()

Expand All @@ -111,7 +131,7 @@ struct DecompressByteBufferSequence<Base: AsyncSequence & Sendable>: AsyncSequen
window.clear()
while true {
do {
try buffer.decompressStream(to: &window, with: decompressor)
try buffer.decompressStream(to: &window, with: decompressor.value)
} catch let error as CompressNIOError where error == .bufferOverflow {
self.state = .decompressing(decompressor, buffer: buffer, window: window)
return window
Expand All @@ -138,6 +158,21 @@ struct DecompressByteBufferSequence<Base: AsyncSequence & Sendable>: AsyncSequen
}

func makeAsyncIterator() -> AsyncIterator {
.init(baseIterator: self.base.makeAsyncIterator(), algorithm: self.algorithm, windowSize: self.windowSize)
.init(baseIterator: self.base.makeAsyncIterator(), allocator: self.allocator, windowSize: self.windowSize)
}
}

extension ZlibDecompressor: PoolReusable {}
struct ZlibDecompressorAllocator: ZlibAllocator, Sendable {
typealias Value = ZlibDecompressor
let algorithm: ZlibAlgorithm
let windowSize: Int32

func allocate() throws -> ZlibDecompressor {
try ZlibDecompressor(algorithm: algorithm, windowSize: self.windowSize)
}

func free(_ decompressor: inout ZlibDecompressor?) {
decompressor = nil
}
}
Loading