Skip to content

Commit 053e79f

Browse files
rewrite masking function to be tail-recursive and use binary matches (#17)
1 parent 653e4ed commit 053e79f

1 file changed

Lines changed: 23 additions & 12 deletions

File tree

lib/mint/web_socket/frame.ex

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ defmodule Mint.WebSocket.Frame do
88
alias Mint.WebSocket.{Utils, Extension}
99
alias Mint.WebSocketError
1010

11+
@compile {:inline, apply_mask: 2, apply_mask: 3}
12+
1113
shared = [{:reserved, <<0::size(3)>>}, :mask, :data, :fin?]
1214

1315
defrecord :continuation, shared
@@ -134,20 +136,29 @@ defmodule Mint.WebSocket.Frame do
134136
# bytes (where the mask bytes repeat).
135137
# This is an "involution" function: applying the mask will mask
136138
# the data and applying the mask again will unmask it.
137-
def apply_mask(payload, nil), do: payload
138-
139-
def apply_mask(payload, _mask = <<a, b, c, d>>) do
140-
[a, b, c, d]
141-
|> Stream.cycle()
142-
|> Enum.reduce_while({payload, _acc = <<>>}, fn
143-
_mask_key, {<<>>, acc} ->
144-
{:halt, acc}
145-
146-
mask_key, {<<part_key::integer, payload_rest::binary>>, acc} ->
147-
{:cont, {payload_rest, <<acc::binary, Bitwise.bxor(mask_key, part_key)::integer>>}}
148-
end)
139+
def apply_mask(payload, mask, acc \\ <<>>)
140+
141+
def apply_mask(payload, nil, _acc), do: payload
142+
143+
# n=4 is the happy path
144+
# n=3..1 catches cases where the remaining byte_size/1 of the payload is shorter
145+
# than the mask
146+
for n <- 4..1 do
147+
def apply_mask(
148+
<<part_key::integer-size(8)-unit(unquote(n)), payload_rest::binary>>,
149+
<<mask_key::integer-size(8)-unit(unquote(n)), _::binary>> = mask,
150+
acc
151+
) do
152+
apply_mask(
153+
payload_rest,
154+
mask,
155+
<<acc::binary, Bitwise.bxor(mask_key, part_key)::integer-size(8)-unit(unquote(n))>>
156+
)
157+
end
149158
end
150159

160+
def apply_mask(<<>>, _mask, acc), do: acc
161+
151162
@spec decode(Mint.WebSocket.t(), binary()) ::
152163
{:ok, Mint.WebSocket.t(), [Mint.WebSocket.frame() | {:error, term()}]}
153164
| {:error, Mint.WebSocket.t(), any()}

0 commit comments

Comments
 (0)