-- Copyright (C) 2012 Florian Zeitz
-- Copyright (C) 2014 Daurnimator
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
local softreq = require "util.dependencies".softreq;
local random_bytes = require "util.random".bytes;

local bit = require "util.bitcompat";
local band = bit.band;
local bor = bit.bor;
local bxor = bit.bxor;
local lshift = bit.lshift;
local rshift = bit.rshift;
local unpack = table.unpack or unpack; -- luacheck: ignore 113

local t_concat = table.concat;
local s_byte = string.byte;
local s_char = string.char;
local s_sub = string.sub;
local s_pack = string.pack;
local s_unpack = string.unpack;
-- string.pack/string.unpack only exist from Lua 5.3 on. When they are
-- missing, try the optional "struct" module; softreq's result is treated as
-- "not available" when it is falsy, in which case s_pack/s_unpack stay nil.
if not s_pack then
	local struct = softreq "struct";
	if struct then
		s_pack = struct.pack;
		s_unpack = struct.unpack;
	end
end
-- Reads a big-endian unsigned 16-bit integer from `str`.
-- `pos` is the 1-based index of the most significant byte.
local function read_uint16be(str, pos)
	local l1, l2 = s_byte(str, pos, pos+1);
	return l1*256 + l2;
end
-- Reads a big-endian unsigned 64-bit integer from `str`, starting at the
-- 1-based index `pos`.
-- FIXME: this may lose precision
local function read_uint64be(str, pos)
	local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
	-- Plain arithmetic instead of lshift(): bit libraries that return signed
	-- 32-bit results would make a word negative once its top bit is set.
	local h = ((l1*256 + l2)*256 + l3)*256 + l4;
	local l = ((l5*256 + l6)*256 + l7)*256 + l8;
	return h * 2^32 + l;
end
-- Encodes `x` as a 2-byte big-endian string using only the bit library and
-- string.char (no string.pack needed).
local function pack_uint16be(x)
	return s_char(rshift(x, 8), band(x, 0xFF));
end
-- Returns the 8 bits of `x` that start at bit offset `n`
-- (n = 0 is the least significant byte).
local function get_byte(x, n)
	return band(rshift(x, n), 0xFF);
end
-- Encodes `x` as an 8-byte big-endian string using only the bit library and
-- string.char. The value is split into a high and a low 32-bit word.
local function pack_uint64be(x)
	-- Floor before masking: x / 2^32 is usually fractional, and bit libraries
	-- may round such arguments instead of truncating them.
	local h = band(math.floor(x / 2^32), 2^32-1);
	return s_char(get_byte(h, 24), get_byte(h, 16), get_byte(h, 8), band(h, 0xFF),
		get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF));
end
-- When a pack implementation is available, replace the hand-written packers
-- with it. s_pack can be nil, so the overrides must stay behind this guard.
if s_pack then
	function pack_uint16be(x)
		return s_pack(">I2", x);
	end
	function pack_uint64be(x)
		return s_pack(">I8", x);
	end
end
-- When an unpack implementation is available, replace the hand-written
-- readers with it. s_unpack can be nil, so the overrides must stay behind
-- this guard. Note that string.unpack also returns the position after the
-- value read, so these versions return a second value.
if s_unpack then
	function read_uint16be(str, pos)
		return s_unpack(">I2", str, pos);
	end
	function read_uint64be(str, pos)
		return s_unpack(">I8", str, pos);
	end
end
-- Parses the header at the start of `frame` (layout per RFC 6455 section 5.2).
-- Returns nothing while `frame` is too short to hold the complete header.
-- Otherwise returns:
--   1. a table with the fields FIN, RSV1, RSV2, RSV3 (booleans), opcode,
--      MASK (boolean), length (payload length) and, when MASK is set,
--      key (array of the 4 masking-key bytes);
--   2. the header length in bytes.
local function parse_frame_header(frame)
	if #frame < 2 then return; end

	local byte1, byte2 = s_byte(frame, 1, 2);
	local result = {
		FIN = band(byte1, 0x80) > 0;
		RSV1 = band(byte1, 0x40) > 0;
		RSV2 = band(byte1, 0x20) > 0;
		RSV3 = band(byte1, 0x10) > 0;
		opcode = band(byte1, 0x0F);

		MASK = band(byte2, 0x80) > 0;
		length = band(byte2, 0x7F);
	};

	-- 126 and 127 are markers: the real length follows in 2 or 8 extra bytes.
	local length_bytes = 0;
	if result.length == 126 then
		length_bytes = 2;
	elseif result.length == 127 then
		length_bytes = 8;
	end

	local header_length = 2 + length_bytes + (result.MASK and 4 or 0);
	if #frame < header_length then return; end

	if length_bytes == 2 then
		result.length = read_uint16be(frame, 3);
	elseif length_bytes == 8 then
		result.length = read_uint64be(frame, 3);
	end

	-- The masking key sits directly after the (extended) length field.
	if result.MASK then
		result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
	end

	return result, header_length;
end
-- XORs the string `str` with the array of bytes `key`
-- Only the bytes from index `from` to index `to` (inclusive) are processed
-- and returned; `from` defaults to 1 and `to` to #str, and negative values
-- count from the end of the string. The key is cycled, starting with its
-- first byte at `from`.
local function apply_mask(str, key, from, to)
	from = from or 1
	if from < 0 then from = #str + from + 1 end -- negative indices
	to = to or #str
	if to < 0 then to = #str + to + 1 end -- negative indices
	local counter = 0;
	local data = {};
	local key_len = #key;
	for i = from, to do
		local key_index = counter%key_len + 1;
		counter = counter + 1;
		data[counter] = s_char(bxor(key[key_index], s_byte(str, i)));
	end
	return t_concat(data);
end
-- Extracts the payload of a frame. `header` is a parsed header table and
-- `pos` the index of the first payload byte in `frame`. Masked payloads are
-- unmasked with header.key; unmasked ones are returned as a plain substring.
local function parse_frame_body(frame, header, pos)
	if header.MASK then
		return apply_mask(frame, header.key, pos, pos + header.length - 1);
	else
		return frame:sub(pos, pos + header.length - 1);
	end
end
-- Parses one complete frame from the start of `frame`.
-- Returns nothing when the header or the payload is not fully present yet.
-- Otherwise returns the header table (with the payload in `.data`) and the
-- total number of bytes the frame occupies.
local function parse_frame(frame)
	local result, pos = parse_frame_header(frame);
	-- `pos` is the header length here, so pos + length is the full frame size.
	if result == nil or #frame < (pos + result.length) then return; end
	result.data = parse_frame_body(frame, result, pos+1);
	return result, pos + result.length;
end
-- Serialises a frame description into its wire format.
-- `desc` fields read here: opcode (required, 0..0xF), data (payload string,
-- defaults to ""), FIN, RSV1, RSV2, RSV3, MASK (flags) and key (optional
-- array of 4 masking-key bytes; a random key is generated when MASK is set
-- and no key is given).
-- Raises an error for an invalid opcode or an oversized control frame.
local function build_frame(desc)
	local data = desc.data or "";

	assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode");
	if desc.opcode >= 0x8 then
		-- Control frames (RFC 6455 section 5.5)
		assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.");
	end

	local b1 = bor(desc.opcode,
		desc.FIN and 0x80 or 0,
		desc.RSV1 and 0x40 or 0,
		desc.RSV2 and 0x20 or 0,
		desc.RSV3 and 0x10 or 0);

	local b2 = #data;
	local length_extra;
	if b2 <= 125 then -- 7-bit length
		length_extra = "";
	elseif b2 <= 0xFFFF then -- 2-byte length
		b2 = 126;
		length_extra = pack_uint16be(#data);
	else -- 8-byte length
		b2 = 127;
		length_extra = pack_uint64be(#data);
	end

	local key = "";
	if desc.MASK then
		local key_a = desc.key
		if key_a then
			key = s_char(unpack(key_a, 1, 4));
		else
			key = random_bytes(4);
			key_a = {key:byte(1,4)};
		end
		b2 = bor(b2, 0x80); -- mask bit
		data = apply_mask(data, key_a);
	end

	return s_char(b1, b2) .. length_extra .. key .. data
end
-- Splits the payload of a close frame into its status code and message.
-- Returns `code, message`; either is nil when the payload is too short to
-- contain it (the code takes the first 2 bytes, the message the rest).
local function parse_close(data)
	local code, message
	if #data >= 2 then
		code = read_uint16be(data, 1);
		if #data > 2 then
			message = s_sub(data, 3);
		end
	end
	return code, message
end
-- Builds a close frame (opcode 0x8) carrying the status `code` and an
-- optional `message`. `mask` is passed on as the frame's MASK flag.
-- Raises an error when `message` is longer than 123 bytes, so that code plus
-- message stay within the 125-byte control-frame payload limit.
local function build_close(code, message, mask)
	local data = pack_uint16be(code);
	if message then
		assert(#message<=123, "Close reason must be <=123 bytes");
		data = data .. message;
	end
	return build_frame({
		opcode = 0x8;
		FIN = true;
		MASK = mask;
		data = data;
	});
end
-- Public interface of the module.
-- NOTE(review): the `parse` and `build` entries were restored from a gap in
-- the damaged source; confirm the key names against callers.
return {
	parse_header = parse_frame_header;
	parse_body = parse_frame_body;
	parse = parse_frame;
	build = build_frame;
	parse_close = parse_close;
	build_close = build_close;
};