util.encodings: Spell out all IDNA 2008 options ICU has
[prosody.git] / net / websocket / frames.lua
blob5e17df075e7b9377f6947fe44aecbd6d8851450b
1 -- Prosody IM
2 -- Copyright (C) 2012 Florian Zeitz
3 -- Copyright (C) 2014 Daurnimator
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
9 local softreq = require "util.dependencies".softreq;
10 local random_bytes = require "util.random".bytes;
12 local bit = require "util.bitcompat";
13 local band = bit.band;
14 local bor = bit.bor;
15 local bxor = bit.bxor;
16 local lshift = bit.lshift;
17 local rshift = bit.rshift;
18 local unpack = table.unpack or unpack; -- luacheck: ignore 113
20 local t_concat = table.concat;
21 local s_byte = string.byte;
22 local s_char= string.char;
23 local s_sub = string.sub;
24 local s_pack = string.pack;
25 local s_unpack = string.unpack;
-- When string.pack is not available, fall back to the optional "struct"
-- library if it can be loaded. Both the packer and the unpacker are taken
-- from it so the two always come from the same implementation.
if not s_pack then
	local struct = softreq "struct";
	if struct then
		s_pack = struct.pack;
		s_unpack = struct.unpack;
	end
end
--- Reads a big-endian unsigned 16-bit integer from a string.
-- @param str string holding the encoded value
-- @param pos 1-based index of the most significant byte
-- @return the decoded number
local function read_uint16be(str, pos)
	local high, low = s_byte(str, pos, pos + 1);
	local shifted_high = high * 256;
	return shifted_high + low;
end
--- Reads a big-endian unsigned 64-bit integer from a string.
-- The two 32-bit halves are assembled with plain arithmetic instead of bit
-- shifts: bit libraries that return signed 32-bit results (such as LuaJIT's
-- BitOp) produce a negative number once bit 31 is set, which would turn any
-- half >= 2^31 into a negative value.
-- FIXME: this may lose precision (values above 2^53 cannot be represented
-- exactly when the result is a double).
-- @param str string holding the encoded value
-- @param pos 1-based index of the most significant byte
-- @return the decoded number
local function read_uint64be(str, pos)
	local b1, b2, b3, b4, b5, b6, b7, b8 = s_byte(str, pos, pos + 7);
	local high = ((b1 * 256 + b2) * 256 + b3) * 256 + b4;
	local low = ((b5 * 256 + b6) * 256 + b7) * 256 + b8;
	return high * 2^32 + low;
end
--- Encodes a number as a 2-byte big-endian string.
-- @param x the number to encode
-- @return a 2-character string: bits 8 and up first, then the low 8 bits
local function pack_uint16be(x)
	local high_byte = rshift(x, 8);
	local low_byte = band(x, 0xFF);
	return s_char(high_byte, low_byte);
end
--- Extracts one byte from a number.
-- @param x the number to read from
-- @param n how many bits to shift `x` right before taking the low 8 bits
-- @return the selected byte (0..255)
local function get_byte(x, n)
	local shifted = rshift(x, n);
	return band(shifted, 0xFF);
end
--- Encodes a number as an 8-byte big-endian string.
-- The high 32-bit word is obtained by dividing by 2^32. The quotient is
-- floored before it reaches any bit operation, because handing a fractional
-- number to a bit library is not portable: it may be rounded (which would
-- yield a high word of 1 for 2^31 <= x < 2^32), truncated, or rejected.
-- @param x the non-negative number to encode
-- @return an 8-character string, most significant byte first
local function pack_uint64be(x)
	local high = band(math.floor(x / 2^32), 2^32 - 1);
	return s_char(
		get_byte(high, 24), get_byte(high, 16), get_byte(high, 8), band(high, 0xFF),
		get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF));
end
-- If a native packer is available, use it in place of the pure-Lua
-- encoders. The assignments replace the functions held by the existing
-- pack_uint16be / pack_uint64be variables.
if s_pack then
	pack_uint16be = function (x)
		return s_pack(">I2", x);
	end;
	pack_uint64be = function (x)
		return s_pack(">I8", x);
	end;
end
-- If a native unpacker is available, use it in place of the pure-Lua
-- decoders. The assignments replace the functions held by the existing
-- read_uint16be / read_uint64be variables. Everything the unpacker returns
-- is passed through to the caller unchanged.
if s_unpack then
	read_uint16be = function (str, pos)
		return s_unpack(">I2", str, pos);
	end;
	read_uint64be = function (str, pos)
		return s_unpack(">I8", str, pos);
	end;
end
--- Parses the header at the start of a WebSocket frame.
-- @param frame string beginning at the first byte of a frame
-- @return a table with the boolean flags FIN, RSV1, RSV2, RSV3 and MASK, the
--   numeric `opcode` and payload `length`, and (only when MASK is set) `key`,
--   an array of the four masking-key bytes
-- @return the length of the header in bytes
-- Returns nothing when `frame` is too short to hold the complete header.
local function parse_frame_header(frame)
	local available = #frame;
	if available < 2 then return; end

	local first, second = s_byte(frame, 1, 2);
	local masked = band(second, 0x80) > 0;
	local length = band(second, 0x7F);

	-- The 7-bit values 126 and 127 mean the real length follows the first
	-- two bytes as a 2-byte or 8-byte big-endian integer respectively.
	local length_bytes = 0;
	if length == 126 then
		length_bytes = 2;
	elseif length == 127 then
		length_bytes = 8;
	end

	-- The masking key, when present, adds four bytes after the length.
	local header_length = 2 + length_bytes + (masked and 4 or 0);
	if available < header_length then return; end

	if length_bytes == 2 then
		length = read_uint16be(frame, 3);
	elseif length_bytes == 8 then
		length = read_uint64be(frame, 3);
	end

	local result = {
		FIN = band(first, 0x80) > 0;
		RSV1 = band(first, 0x40) > 0;
		RSV2 = band(first, 0x20) > 0;
		RSV3 = band(first, 0x10) > 0;
		opcode = band(first, 0x0F);

		MASK = masked;
		length = length;
	};

	if masked then
		result.key = { s_byte(frame, length_bytes + 3, length_bytes + 6) };
	end

	return result, header_length;
end
--- XORs a range of the string `str` with the array of bytes `key`.
-- The key is applied cyclically, starting from its first byte at `from`.
-- TODO: optimize
-- @param str the input string
-- @param key array of byte values
-- @param from first index to process (default 1); negative values count
--   from the end of `str`
-- @param to last index to process (default #str); negative values count
--   from the end of `str`
-- @return a new string containing the XORed bytes of the selected range
local function apply_mask(str, key, from, to)
	local str_len = #str;
	from = from or 1;
	if from < 0 then from = str_len + from + 1; end -- negative indices
	to = to or str_len;
	if to < 0 then to = str_len + to + 1; end -- negative indices

	local key_len = #key;
	local out = {};
	for n = 1, to - from + 1 do
		local key_byte = key[(n - 1) % key_len + 1];
		out[n] = s_char(bxor(key_byte, s_byte(str, from + n - 1)));
	end
	return t_concat(out);
end
--- Extracts the payload of a frame, unmasking it when necessary.
-- @param frame string containing the frame
-- @param header table providing `length`, `MASK` and, for masked frames, `key`
-- @param pos index of the first payload byte within `frame`
-- @return the payload as a string
local function parse_frame_body(frame, header, pos)
	local last = pos + header.length - 1;
	if not header.MASK then
		return frame:sub(pos, last);
	end
	return apply_mask(frame, header.key, pos, last);
end
--- Parses one complete frame from the start of a string.
-- @param frame string beginning at the first byte of a frame
-- @return the header table with the payload stored in its `data` field
-- @return the total length of the frame (header plus payload) in bytes
-- Returns nothing when the header or the payload is not yet complete.
local function parse_frame(frame)
	local header, header_length = parse_frame_header(frame);
	if header == nil then return; end

	local frame_length = header_length + header.length;
	if #frame < frame_length then return; end

	-- The payload begins on the byte immediately after the header.
	header.data = parse_frame_body(frame, header, header_length + 1);
	return header, frame_length;
end
--- Serializes a frame description into its wire representation.
-- @param desc table with `opcode` (0..0xF, required) and the optional fields
--   `data` (payload string, default ""), `FIN`, `RSV1`, `RSV2`, `RSV3`,
--   `MASK` (flags) and `key` (array of four masking-key bytes; a random key
--   is generated when MASK is set and no key is given)
-- @return the encoded frame as a string
-- Raises an error for an invalid opcode or an oversized control frame.
local function build_frame(desc)
	local data = desc.data or "";
	local opcode = desc.opcode;

	assert(opcode and opcode >= 0 and opcode <= 0xF, "Invalid WebSocket opcode");
	if opcode >= 0x8 then
		-- RFC 6455 5.5
		assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.");
	end

	local b1 = bor(opcode,
		desc.FIN and 0x80 or 0,
		desc.RSV1 and 0x40 or 0,
		desc.RSV2 and 0x20 or 0,
		desc.RSV3 and 0x10 or 0);

	-- Second byte: either the length itself or a marker (126 / 127) that an
	-- extended 2-byte / 8-byte length follows.
	local payload_length = #data;
	local b2, length_extra;
	if payload_length <= 125 then -- 7-bit length
		b2 = payload_length;
		length_extra = "";
	elseif payload_length <= 0xFFFF then -- 2-byte length
		b2 = 126;
		length_extra = pack_uint16be(payload_length);
	else -- 8-byte length
		b2 = 127;
		length_extra = pack_uint64be(payload_length);
	end

	local key = "";
	if desc.MASK then
		local key_a = desc.key;
		if not key_a then
			key = random_bytes(4);
			key_a = { key:byte(1, 4) };
		else
			key = s_char(unpack(key_a, 1, 4));
		end

		b2 = bor(b2, 0x80);
		data = apply_mask(data, key_a);
	end

	return s_char(b1, b2) .. length_extra .. key .. data;
end
--- Splits the payload of a close frame into status code and reason.
-- @param data payload string of the close frame
-- @return the 16-bit status code read from the first two bytes, or nil when
--   the payload is shorter than two bytes
-- @return the remainder of the payload as the message, or nil when there is
--   nothing after the code
local function parse_close(data)
	local data_len = #data;
	if data_len < 2 then
		return nil, nil;
	end

	local code = read_uint16be(data, 1);
	local message;
	if data_len > 2 then
		message = s_sub(data, 3);
	end
	return code, message;
end
--- Builds a close frame (opcode 0x8, FIN set).
-- @param code numeric status code, encoded as the first two payload bytes
-- @param message optional reason string appended after the code; at most
--   123 bytes, otherwise an error is raised
-- @param mask passed through as the frame's MASK flag
-- @return the encoded frame as a string
local function build_close(code, message, mask)
	local payload = pack_uint16be(code);
	if message then
		assert(#message<=123, "Close reason must be <=123 bytes");
		payload = payload .. message;
	end

	local frame_desc = {
		opcode = 0x8,
		FIN = true,
		MASK = mask,
		data = payload,
	};
	return build_frame(frame_desc);
end
-- Public interface of the module.
return {
	parse_header = parse_frame_header,
	parse_body = parse_frame_body,
	parse = parse_frame,
	build = build_frame,
	parse_close = parse_close,
	build_close = build_close,
};