1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/spdy/hpack_huffman_table.h"
10 #include "base/logging.h"
11 #include "base/numerics/safe_conversions.h"
12 #include "net/spdy/hpack_input_stream.h"
13 #include "net/spdy/hpack_output_stream.h"
17 using base::StringPiece
;
// Sizing constants for the decode-table hierarchy. BuildDecodeTables()
// creates the root table with kDecodeTableRootBits index bits and caps each
// child table at kDecodeTableBranchBits index bits; DecodeString() derives
// its per-symbol iteration count from the same two values.
22 // How many bits to index in the root decode table.
23 const uint8 kDecodeTableRootBits
= 9;
24 // Maximum number of bits to index in successive decode tables.
25 const uint8 kDecodeTableBranchBits
= 6;
// Ordering predicate handed to std::sort() by Initialize() before it checks
// that the symbol codes are canonical. When the two code lengths differ, it
// returns true iff |a| has the shorter code.
// NOTE(review): the statement guarded by the equal-length test is not visible
// in this listing; the function name suggests ties are broken on symbol id --
// confirm against the complete source.
27 bool SymbolLengthAndIdCompare(const HpackHuffmanSymbol
& a
,
28 const HpackHuffmanSymbol
& b
) {
29 if (a
.length
== b
.length
) {
// Lengths differ (or the tie branch above did not return): order by length.
32 return a
.length
< b
.length
;
// Ordering predicate handed to std::sort() by Initialize() to put the symbols
// back into "symbol ID ascending" order before BuildEncodeTable() runs.
// NOTE(review): the function body is not visible in this listing; presumably
// it compares the |id| fields of |a| and |b| -- confirm.
34 bool SymbolIdCompare(const HpackHuffmanSymbol
& a
,
35 const HpackHuffmanSymbol
& b
) {
// Default-constructs an entry with every field zero. BuildDecodeTables()
// relies on length == 0 to recognise a slot that has not been written yet.
41 HpackHuffmanTable::DecodeEntry::DecodeEntry()
42 : next_table_index(0), length(0), symbol_id(0) {
// Constructs an entry from explicit field values; each member is initialized
// from the like-named argument.
// NOTE(review): only the first parameter is visible in this listing. The
// initializer list implies |length| and |symbol_id| parameters as well --
// confirm their types against the header.
44 HpackHuffmanTable::DecodeEntry::DecodeEntry(uint8 next_table_index
,
47 : next_table_index(next_table_index
), length(length
), symbol_id(symbol_id
) {
// Returns the number of entries in this table: 2^indexed_length, i.e. one
// slot for every possible value of the bits this table indexes. The shift is
// done on a size_t so the result is not limited to int width.
49 size_t HpackHuffmanTable::DecodeTable::size() const {
50 return size_t(1) << indexed_length
;
// Constructs a table with no explicit member initialization; the encode and
// decode tables are populated later by Initialize().
// NOTE(review): members assigned only in Initialize() (e.g. pad_bits_ and
// failed_symbol_id_) are not set here -- confirm nothing reads them earlier.
53 HpackHuffmanTable::HpackHuffmanTable() {}
// Empty destructor body: nothing is released explicitly here.
55 HpackHuffmanTable::~HpackHuffmanTable() {}
// Validates |symbol_count| entries of |input_symbols| and builds the decode
// and encode tables from them. CHECK-fails if the table was already
// initialized.
//
// Validation performed by the visible code:
//  - entry i must carry id i;
//  - after sorting with SymbolLengthAndIdCompare, the first code must be 0 and
//    every later code must equal the previous code plus one unit in the
//    previous symbol's last bit (codes are held left-aligned in 32 bits);
//  - the last symbol in that order must have a code of at least 8 bits.
// On an id or code mismatch, failed_symbol_id_ records the offending symbol.
//
// NOTE(review): the return statements are not visible in this listing;
// presumably each failed check returns false and the success path returns
// true -- confirm.
// NOTE(review): symbols[0] and symbols.back() are read without an emptiness
// check, so symbol_count == 0 appears unsupported -- confirm with callers.
57 bool HpackHuffmanTable::Initialize(const HpackHuffmanSymbol
* input_symbols
,
58 size_t symbol_count
) {
59 CHECK(!IsInitialized());
60 DCHECK(base::IsValueInRangeForNumericType
<uint16
>(symbol_count
));
62 std::vector
<Symbol
> symbols(symbol_count
);
63 // Validate symbol id sequence, and copy into |symbols|.
// NOTE(review): |i| is uint16 while |symbol_count| is size_t; the DCHECK above
// is the only guard against a count that does not fit in uint16 -- confirm
// that this is sufficient in builds where DCHECK is disabled.
64 for (uint16 i
= 0; i
< symbol_count
; i
++) {
65 if (i
!= input_symbols
[i
].id
) {
66 failed_symbol_id_
= i
;
69 symbols
[i
] = input_symbols
[i
];
71 // Order on length and ID ascending, to verify symbol codes are canonical.
72 std::sort(symbols
.begin(), symbols
.end(), SymbolLengthAndIdCompare
);
// The first symbol in (length, id) order must have the all-zero code.
73 if (symbols
[0].code
!= 0) {
74 failed_symbol_id_
= 0;
77 for (size_t i
= 1; i
!= symbols
.size(); i
++) {
// |code_shift| is the position of the previous code's least-significant bit
// within the 32-bit word, so adding 1 << code_shift increments that code.
78 unsigned code_shift
= 32 - symbols
[i
-1].length
;
// NOTE(review): the literal 1 is a signed int, so a previous length of 1
// (shift by 31) or 0 (shift by 32) would not be a well-defined shift --
// consider an unsigned literal; confirm such lengths cannot occur.
79 uint32 code
= symbols
[i
-1].code
+ (1 << code_shift
);
81 if (code
!= symbols
[i
].code
) {
82 failed_symbol_id_
= symbols
[i
].id
;
// Unsigned wrap-around check: the incremented code must not be smaller than
// the code it was derived from.
85 if (code
< symbols
[i
-1].code
) {
86 // An integer overflow occurred. This implies the input
87 // lengths do not represent a valid Huffman code.
88 failed_symbol_id_
= symbols
[i
].id
;
92 if (symbols
.back().length
< 8) {
93 // At least one code (such as an EOS symbol) must be 8 bits or longer.
94 // Without this, some inputs will not be encodable in a whole number
// (the sentence above is cut off in this listing; presumably "of bytes".)
// Keep the top 8 bits of the last (longest) code; EncodeString() uses them to
// pad a partially filled final byte.
98 pad_bits_
= static_cast<uint8
>(symbols
.back().code
>> 24);
// The decode tables are built while |symbols| is still in length order.
100 BuildDecodeTables(symbols
);
101 // Order on symbol ID ascending.
102 std::sort(symbols
.begin(), symbols
.end(), SymbolIdCompare
);
103 BuildEncodeTable(symbols
);
// Appends each symbol's code and length to code_by_id_ and length_by_id_.
// |symbols| must be ordered so that position i holds symbol id i (CHECKed),
// which makes both vectors directly indexable by symbol id afterwards.
107 void HpackHuffmanTable::BuildEncodeTable(const std::vector
<Symbol
>& symbols
) {
108 for (size_t i
= 0; i
!= symbols
.size(); i
++) {
109 const Symbol
& symbol
= symbols
[i
];
110 CHECK_EQ(i
, symbol
.id
);
111 code_by_id_
.push_back(symbol
.code
);
112 length_by_id_
.push_back(symbol
.length
);
// Populates decode_tables_ and decode_entries_ from |symbols|.
//
// Pass 1 walks |symbols| from last to first and, starting at the root table,
// indexes the next bits of each code into the current table. If the table's
// prefix + indexed bits cover the whole code, a terminal entry (length and
// symbol id) is written. Otherwise the slot becomes a link to a child table,
// which is created on the first visit and sized to
// min(kDecodeTableBranchBits, remaining bits of that first code).
//
// Pass 2 visits every table and copies each entry whose length is shorter
// than the table's total indexed bits into the slots that follow it, which
// are CHECKed to be empty beforehand.
//
// NOTE(review): several control-flow lines are not visible in this listing
// (the header of the per-symbol descent loop, the declaration and advance of
// |j|, and the closing braces) -- confirm against the complete source.
116 void HpackHuffmanTable::BuildDecodeTables(const std::vector
<Symbol
>& symbols
) {
// Create the root table: no prefix, kDecodeTableRootBits index bits.
117 AddDecodeTable(0, kDecodeTableRootBits
);
118 // We wish to maximize the flatness of the DecodeTable hierarchy (subject to
119 // the |kDecodeTableBranchBits| constraint), and to minimize the size of
120 // child tables. To achieve this, we iterate in order of descending code
121 // length. This ensures that child tables are visited with their longest
122 // entry first, and that the child can therefore be minimally sized to hold
123 // that entry without fear of introducing unnecessary branches later.
124 for (std::vector
<Symbol
>::const_reverse_iterator it
= symbols
.rbegin();
125 it
!= symbols
.rend(); ++it
) {
// Every symbol starts its descent at the root table (index 0).
126 uint8 table_index
= 0;
// The table descriptor is taken by value rather than by reference.
// NOTE(review): presumably because AddDecodeTable() below appends to
// decode_tables_ and could invalidate a reference -- confirm.
128 const DecodeTable table
= decode_tables_
[table_index
];
130 // Mask and shift the portion of the code being indexed into low bits.
131 uint32 index
= (it
->code
<< table
.prefix_length
);
132 index
= index
>> (32 - table
.indexed_length
);
134 CHECK_LT(index
, table
.size());
// Work on a copy of the slot; it is written back with SetEntry().
135 DecodeEntry entry
= Entry(table
, index
);
// Number of leading code bits consumed once this table has been indexed.
137 uint8 total_indexed
= table
.prefix_length
+ table
.indexed_length
;
138 if (total_indexed
>= it
->length
) {
139 // We're writing a terminal entry.
140 entry
.length
= it
->length
;
141 entry
.symbol_id
= it
->id
;
// A terminal entry links back to its own table, so a walk that follows
// next_table_index (as DecodeString() does) stays in this table.
142 entry
.next_table_index
= table_index
;
143 SetEntry(table
, index
, entry
);
147 if (entry
.length
== 0) {
148 // First visit to this placeholder. We need to create a new table.
149 CHECK_EQ(entry
.next_table_index
, 0);
// The link entry records the length of the first code routed through it.
150 entry
.length
= it
->length
;
151 entry
.next_table_index
= AddDecodeTable(
152 total_indexed
, // Becomes the new table prefix.
153 std::min
<uint8
>(kDecodeTableBranchBits
,
154 entry
.length
- total_indexed
));
155 SetEntry(table
, index
, entry
);
// Descend into the child table; a link must never point at its own table.
157 CHECK_NE(entry
.next_table_index
, table_index
);
158 table_index
= entry
.next_table_index
;
161 // Fill shorter table entries into the additional entry spots they map to.
162 for (size_t i
= 0; i
!= decode_tables_
.size(); i
++) {
163 const DecodeTable
& table
= decode_tables_
[i
];
164 uint8 total_indexed
= table
.prefix_length
+ table
.indexed_length
;
// NOTE(review): the declaration of |j| is not visible in this listing;
// presumably it starts at 0 -- confirm.
167 while (j
!= table
.size()) {
168 const DecodeEntry
& entry
= Entry(table
, j
);
// Only written entries (length != 0) that are shorter than the bits this
// table indexes need replicating.
169 if (entry
.length
!= 0 && entry
.length
< total_indexed
) {
170 // The difference between entry & table bit counts tells us how
171 // many additional entries map to this one.
172 size_t fill_count
= 1 << (total_indexed
- entry
.length
);
173 CHECK_LE(j
+ fill_count
, table
.size());
// Slot j already holds the entry; copy it into the next fill_count - 1 slots.
175 for (size_t k
= 1; k
!= fill_count
; k
++) {
176 CHECK_EQ(Entry(table
, j
+ k
).length
, 0);
177 SetEntry(table
, j
+ k
, entry
);
// Appends a new DecodeTable described by |prefix| (stored as prefix_length)
// and |indexed| (stored as indexed_length), and returns its index in
// decode_tables_. The table's entries start at the current end of
// decode_entries_, which is then grown by 2^indexed elements.
// NOTE(review): the declaration of the local |table| is not visible in this
// listing -- confirm against the complete source.
187 uint8
HpackHuffmanTable::AddDecodeTable(uint8 prefix
, uint8 indexed
) {
// The returned index is a uint8, so the table count is bounded before adding.
188 CHECK_LT(decode_tables_
.size(), 255u);
191 table
.prefix_length
= prefix
;
192 table
.indexed_length
= indexed
;
// Entries of all tables share one vector; remember where this table begins.
193 table
.entries_offset
= decode_entries_
.size();
194 decode_tables_
.push_back(table
);
196 decode_entries_
.resize(decode_entries_
.size() + (size_t(1) << indexed
));
197 return static_cast<uint8
>(decode_tables_
.size() - 1);
// Returns a const reference to entry |index| of |table|, located at
// table.entries_offset + index within decode_entries_. Both bounds are
// verified with DCHECK only.
200 const HpackHuffmanTable::DecodeEntry
& HpackHuffmanTable::Entry(
201 const DecodeTable
& table
,
202 uint32 index
) const {
203 DCHECK_LT(index
, table
.size());
204 DCHECK_LT(table
.entries_offset
+ index
, decode_entries_
.size());
205 return decode_entries_
[table
.entries_offset
+ index
];
// Overwrites entry |index| of |table| (stored at table.entries_offset + index
// within decode_entries_) with |entry|. Both bounds are verified with CHECK.
// NOTE(review): the declaration of the |index| parameter is not visible in
// this listing -- confirm its type against the header.
208 void HpackHuffmanTable::SetEntry(const DecodeTable
& table
,
210 const DecodeEntry
& entry
) {
211 CHECK_LT(index
, table
.size());
212 CHECK_LT(table
.entries_offset
+ index
, decode_entries_
.size());
213 decode_entries_
[table
.entries_offset
+ index
] = entry
;
// Returns true once code_by_id_ is non-empty, i.e. after BuildEncodeTable()
// has appended at least one symbol.
216 bool HpackHuffmanTable::IsInitialized() const {
217 return !code_by_id_
.empty();
// Huffman-encodes every byte of |in| and appends the resulting bits to |out|.
// Each byte, read as an unsigned value, is the symbol id used to look up the
// code and its length. The code is shifted down to the low bits and handed to
// |out| in pieces of at most 8 bits, most significant piece first. If the
// total bit count is not a multiple of 8, the final byte is completed using
// pad_bits_.
// NOTE(review): the conditions guarding the first three AppendBits() calls,
// and any adjustment of |length| between them, are not visible in this
// listing; presumably each call runs only when |length| exceeds 24, 16 and 8
// respectively -- confirm against the complete source.
220 void HpackHuffmanTable::EncodeString(StringPiece in
,
221 HpackOutputStream
* out
) const {
// Running count, modulo 8, of the bits emitted so far.
222 size_t bit_remnant
= 0;
223 for (size_t i
= 0; i
!= in
.size(); i
++) {
// Cast through uint8 so bytes >= 0x80 do not become negative ids.
224 uint16 symbol_id
= static_cast<uint8
>(in
[i
]);
225 CHECK_GT(code_by_id_
.size(), symbol_id
);
227 // Load, and shift code to low bits.
228 unsigned length
= length_by_id_
[symbol_id
];
229 uint32 code
= code_by_id_
[symbol_id
] >> (32 - length
);
231 bit_remnant
= (bit_remnant
+ length
) % 8;
// Bits above bit 24 of the right-aligned code.
234 out
->AppendBits(static_cast<uint8
>(code
>> 24), length
- 24);
// Bits 16..23 of the right-aligned code.
238 out
->AppendBits(static_cast<uint8
>(code
>> 16), length
- 16);
// Bits 8..15 of the right-aligned code.
242 out
->AppendBits(static_cast<uint8
>(code
>> 8), length
- 8);
// Remaining low bits of the code.
245 out
->AppendBits(static_cast<uint8
>(code
), length
);
247 if (bit_remnant
!= 0) {
248 // Pad current byte as required.
249 out
->AppendBits(pad_bits_
>> bit_remnant
, 8 - bit_remnant
);
// Returns the size in bytes of the Huffman encoding of |in|: the sum of the
// code lengths of its bytes, rounded up to a whole number of bytes. Each
// byte's symbol id is CHECKed to be within the encode table.
253 size_t HpackHuffmanTable::EncodedSize(StringPiece in
) const {
254 size_t bit_count
= 0;
255 for (size_t i
= 0; i
!= in
.size(); i
++) {
// Cast through uint8 so bytes >= 0x80 do not become negative ids.
256 uint16 symbol_id
= static_cast<uint8
>(in
[i
]);
257 CHECK_GT(code_by_id_
.size(), symbol_id
);
259 bit_count
+= length_by_id_
[symbol_id
];
// Round up to the next multiple of 8 bits to account for padding.
261 if (bit_count
% 8 != 0) {
262 bit_count
+= 8 - bit_count
% 8;
264 return bit_count
/ 8;
// Decodes Huffman-coded input read from |in|, appending decoded bytes to
// |out|.
//
// For each symbol, the visible code starts at the root decode table and
// follows next_table_index links a fixed number of times, then inspects the
// entry reached:
//  - entry longer than the bits available and no further bits could be
//    peeked: the remainder of the current byte is consumed and the result is
//    true iff |in| has no more data;
//  - entry length 0: the input is not a valid code prefix;
//  - |out| already holds |out_capacity| bytes: the output would overflow;
//  - otherwise the symbol is appended when its id is below 256, its bits are
//    consumed from |in|, and more input is peeked.
//
// NOTE(review): part of the signature (the declarations of |out_capacity| and
// |out|), the declaration of |bits|, the header of the outer per-symbol loop,
// several return statements and the end of the function are not visible in
// this listing -- confirm against the complete source.
267 bool HpackHuffmanTable::DecodeString(HpackInputStream
* in
,
270 // Number of decode iterations required for a 32-bit code.
// Computed as ceil((32 - root bits) / branch bits) from the file's constants.
271 const int kDecodeIterations
= static_cast<int>(
272 std::ceil((32.f
- kDecodeTableRootBits
) / kDecodeTableBranchBits
));
276 // Current input, stored in the high |bits_available| bits of |bits|.
278 size_t bits_available
= 0;
279 bool peeked_success
= in
->PeekBits(&bits_available
, &bits
);
// Start each lookup at the root table, indexed by the top root bits.
282 const DecodeTable
* table
= &decode_tables_
[0];
283 uint32 index
= bits
>> (32 - kDecodeTableRootBits
);
// Fixed-count walk; BuildDecodeTables() makes terminal entries link to their
// own table, so extra iterations after reaching one land on the same entry.
285 for (int i
= 0; i
!= kDecodeIterations
; i
++) {
286 DCHECK_LT(index
, table
->size());
287 DCHECK_LT(Entry(*table
, index
).next_table_index
, decode_tables_
.size());
289 table
= &decode_tables_
[Entry(*table
, index
).next_table_index
];
290 // Mask and shift the portion of the code being indexed into low bits.
291 index
= (bits
<< table
->prefix_length
) >> (32 - table
->indexed_length
);
293 const DecodeEntry
& entry
= Entry(*table
, index
);
295 if (entry
.length
> bits_available
) {
296 if (!peeked_success
) {
297 // Unable to read enough input for a match. If only a portion of
298 // the last byte remains, this is a successful EOF condition.
299 in
->ConsumeByteRemainder();
300 return !in
->HasMoreData();
302 } else if (entry
.length
== 0) {
303 // The input is an invalid prefix, larger than any prefix in the table.
306 if (out
->size() == out_capacity
) {
307 // This code would cause us to overflow |out_capacity|.
310 if (entry
.symbol_id
< 256) {
311 // Assume symbols >= 256 are used for padding.
312 out
->push_back(static_cast<char>(entry
.symbol_id
));
// Drop the decoded code from both the stream and the local bit buffer.
315 in
->ConsumeBits(entry
.length
);
316 bits
= bits
<< entry
.length
;
317 bits_available
-= entry
.length
;
// Refill the local bit buffer for the next symbol.
319 peeked_success
= in
->PeekBits(&bits_available
, &bits
);