diff --git a/src/cow_ws.erl b/src/cow_ws.erl index 3bb46c5..bbcdd6d 100644 --- a/src/cow_ws.erl +++ b/src/cow_ws.erl @@ -47,7 +47,9 @@ %% LZ77 sliding window size limits. server_max_window_bits => 8..15, - client_max_window_bits => 8..15 + client_max_window_bits => 8..15, + + approx_max => nil() | integer() }. -export_type([deflate_opts/0]). @@ -148,7 +150,8 @@ negotiate_permessage_deflate1(Params, Extensions, Opts) -> deflate => Deflate, deflate_takeover => maps:get(server_context_takeover, Negotiated), inflate => Inflate, - inflate_takeover => maps:get(client_context_takeover, Negotiated)}} + inflate_takeover => maps:get(client_context_takeover, Negotiated), + approx_max => maps:get(approx_max, Opts, nil)}} end. negotiate_params([], Negotiated, RespParams) -> @@ -297,7 +300,8 @@ negotiate_x_webkit_deflate_frame(_Params, Extensions, Opts) -> deflate => Deflate, deflate_takeover => takeover, inflate => Inflate, - inflate_takeover => takeover}}. + inflate_takeover => takeover, + approx_max => maps:get(approx_max, Opts, nil)}}. %% @doc Validate the negotiated permessage-deflate extension. @@ -319,7 +323,8 @@ validate_permessage_deflate(Params, Extensions, Opts) -> deflate => Deflate, deflate_takeover => ClientTakeOver, inflate => Inflate, - inflate_takeover => ServerTakeOver}} + inflate_takeover => ServerTakeOver, + approx_max => maps:get(approx_max, Opts, nil)}} end end. @@ -443,7 +448,7 @@ frag_state(_, 1, _, FragState) -> FragState. | {ok, close_code(), binary(), utf8_state(), binary()} | {more, binary(), utf8_state()} | {more, close_code(), binary(), utf8_state()} - | {error, badframe | badencoding}. + | {error, badframe | badencoding | overflow}. %% Empty last frame of compressed message. parse_payload(Data, _, Utf8State, _, _, 0, {fin, _, << 1:1, 0:2 >>}, #{inflate := Inflate, inflate_takeover := TakeOver}, _) -> @@ -455,16 +460,26 @@ parse_payload(Data, _, Utf8State, _, _, 0, {fin, _, << 1:1, 0:2 >>}, {ok, <<>>, Utf8State, Data}; %% Compressed fragmented frame. parse_payload(Data, MaskKey, Utf8State, ParsedLen, Type, Len, FragState = {_, _, << 1:1, 0:2 >>}, - #{inflate := Inflate, inflate_takeover := TakeOver}, _) -> + #{inflate := Inflate, inflate_takeover := TakeOver, approx_max := ApproxMax}, _) -> {Data2, Rest, Eof} = split_payload(Data, Len), - Payload = inflate_frame(unmask(Data2, MaskKey, ParsedLen), Inflate, TakeOver, FragState, Eof), - validate_payload(Payload, Rest, Utf8State, ParsedLen, Type, FragState, Eof); + Result = inflate_frame(unmask(Data2, MaskKey, ParsedLen), Inflate, TakeOver, FragState, ApproxMax, Eof), + case Result of + {error, _} -> + Result; + {ok, Payload} -> + validate_payload(Payload, Rest, Utf8State, ParsedLen, Type, FragState, Eof) + end; %% Compressed frame. parse_payload(Data, MaskKey, Utf8State, ParsedLen, Type, Len, FragState, - #{inflate := Inflate, inflate_takeover := TakeOver}, << 1:1, 0:2 >>) when Type =:= text; Type =:= binary -> + #{inflate := Inflate, inflate_takeover := TakeOver, approx_max := ApproxMax}, << 1:1, 0:2 >>) when Type =:= text; Type =:= binary -> {Data2, Rest, Eof} = split_payload(Data, Len), - Payload = inflate_frame(unmask(Data2, MaskKey, ParsedLen), Inflate, TakeOver, FragState, Eof), - validate_payload(Payload, Rest, Utf8State, ParsedLen, Type, FragState, Eof); + Result = inflate_frame(unmask(Data2, MaskKey, ParsedLen), Inflate, TakeOver, FragState, ApproxMax, Eof), + case Result of + {error, _} -> + Result; + {ok, Payload} -> + validate_payload(Payload, Rest, Utf8State, ParsedLen, Type, FragState, Eof) + end; %% Empty frame. parse_payload(Data, _, Utf8State, 0, _, 0, _, _, _) when Utf8State =:= 0; Utf8State =:= undefined -> @@ -541,16 +556,34 @@ mask(<< O:8 >>, MaskKey, Acc) -> T = O bxor MaskKey2, << Acc/binary, T:8 >>. -inflate_frame(Data, Inflate, TakeOver, FragState, true) +safe_inflate(ApproxMax, Inflate, {continue, Data}, Acc) -> + Overflow = iolist_size(Acc) + iolist_size(Data) > ApproxMax, + if + Overflow -> + {error, overflow}; + true -> + safe_inflate(ApproxMax, Inflate, zlib:safeInflate(Inflate, []), [Acc, Data]) + end; +safe_inflate(_A, _I, {finished, Data}, Acc) -> + {ok, iolist_to_binary([Acc, Data])}. + +inflate_dispatch(Data, Inflate, ApproxMax) -> + Data2 = << Data/binary, 0, 0, 255, 255 >>, + case ApproxMax of + nil -> {ok, iolist_to_binary(zlib:inflate(Inflate, Data2))}; + _ -> safe_inflate(ApproxMax, Inflate, zlib:safeInflate(Inflate, Data2), []) + end. + +inflate_frame(Data, Inflate, TakeOver, FragState, ApproxMax, true) when FragState =:= undefined; element(1, FragState) =:= fin -> - Data2 = zlib:inflate(Inflate, << Data/binary, 0, 0, 255, 255 >>), + Result = inflate_dispatch(Data, Inflate, ApproxMax), case TakeOver of no_takeover -> zlib:inflateReset(Inflate); takeover -> ok end, - iolist_to_binary(Data2); -inflate_frame(Data, Inflate, _T, _F, _E) -> - iolist_to_binary(zlib:inflate(Inflate, Data)). + Result; +inflate_frame(Data, Inflate, _T, _F, ApproxMax, _E) -> + inflate_dispatch(Data, Inflate, ApproxMax). %% The Utf8State variable can be set to 'undefined' to disable the validation. validate_payload(Payload, _, undefined, _, _, _, false) ->