1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/spdy/hpack/hpack_huffman_table.h"
10 #include "base/logging.h"
11 #include "base/numerics/safe_conversions.h"
12 #include "net/spdy/hpack/hpack_input_stream.h"
13 #include "net/spdy/hpack/hpack_output_stream.h"
17 using base::StringPiece
;
// Sizing parameters for the decode-table hierarchy. In this file the root
// width is passed to AddDecodeTable() by BuildDecodeTables() and used by
// DecodeString() to form the first index; the branch width caps the size of
// every child table and also determines DecodeString()'s iteration count.
22 // How many bits to index in the root decode table.
23 const uint8 kDecodeTableRootBits
= 9;
24 // Maximum number of bits to index in successive decode tables.
25 const uint8 kDecodeTableBranchBits
= 6;
// Strict-ordering predicate over Huffman symbols, passed to std::sort by
// HpackHuffmanTable::Initialize() before it checks that codes are canonical.
// Symbols whose |length| fields differ are ordered by ascending |length|.
// NOTE(review): the statement inside the equal-length branch is not visible
// in this excerpt; the function name suggests ties are broken on |id| --
// confirm against the full file.
27 bool SymbolLengthAndIdCompare(const HpackHuffmanSymbol
& a
,
28 const HpackHuffmanSymbol
& b
) {
29 if (a
.length
== b
.length
) {
32 return a
.length
< b
.length
;
// Ordering predicate over Huffman symbols, passed to std::sort by
// HpackHuffmanTable::Initialize() just before BuildEncodeTable() is called.
// NOTE(review): the body is not visible in this excerpt; presumably it
// orders on ascending |id| -- confirm against the full file.
34 bool SymbolIdCompare(const HpackHuffmanSymbol
& a
, const HpackHuffmanSymbol
& b
) {
// Default constructor: zeroes |next_table_index|, |length| and |symbol_id|.
// Elsewhere in this file a |length| of 0 is used to recognise a slot that has
// not been written yet (see the CHECKs in BuildDecodeTables()).
40 HpackHuffmanTable::DecodeEntry::DecodeEntry()
41 : next_table_index(0), length(0), symbol_id(0) {}
// Value constructor: initialises each member from the parameter of the same
// name.
// NOTE(review): only the |next_table_index| parameter and the
// |next_table_index| / |symbol_id| initialisers are visible in this excerpt;
// the remaining parameters and initialisers are elided.
42 HpackHuffmanTable::DecodeEntry::DecodeEntry(uint8 next_table_index
,
45 : next_table_index(next_table_index
),
47 symbol_id(symbol_id
) {}
// Returns the number of entries in this table, 2^|indexed_length|: one slot
// for every possible value of the |indexed_length| bits used as the index.
// The shift is done on a size_t so that the result is not limited to int.
48 size_t HpackHuffmanTable::DecodeTable::size() const {
49 return size_t(1) << indexed_length
;
// Constructor with an empty body; the encode and decode tables are populated
// later by Initialize().
52 HpackHuffmanTable::HpackHuffmanTable() {}
// Destructor with an empty body; no explicit cleanup is performed here.
54 HpackHuffmanTable::~HpackHuffmanTable() {}
// Validates |input_symbols| (an array of |symbol_count| symbols) and builds
// the decode and encode tables from it. CHECKs that the table has not already
// been initialised.
//
// Validation visible in this excerpt; each failing check records the
// offending symbol id in |failed_symbol_id_|:
//   * input_symbols[i].id must equal i for every i;
//   * after sorting by SymbolLengthAndIdCompare, the first code must be 0;
//   * each later code must equal the previous code plus 1 shifted left by
//     (32 - previous length), i.e. the codes must be canonical when stored
//     left-aligned in 32 bits;
//   * that addition must not wrap around.
// In addition, the last symbol after sorting must have a length of at least
// 8 bits. The top byte of that symbol's code is saved in |pad_bits_|, which
// EncodeString() uses to pad the final byte.
//
// NOTE(review): the statements that follow each failed check, and the final
// return, are elided in this excerpt; given the bool return type they
// presumably return false / true respectively -- confirm.
// NOTE(review): |symbols[0]| and |symbols.back()| are read without an
// emptiness check, so |symbol_count| == 0 looks unsafe -- confirm callers
// never pass 0.
56 bool HpackHuffmanTable::Initialize(const HpackHuffmanSymbol
* input_symbols
,
57 size_t symbol_count
) {
58 CHECK(!IsInitialized());
59 DCHECK(base::IsValueInRangeForNumericType
<uint16
>(symbol_count
));
61 std::vector
<Symbol
> symbols(symbol_count
);
62 // Validate symbol id sequence, and copy into |symbols|.
// NOTE(review): |i| is a uint16 compared against a size_t bound; the DCHECK
// above is the only visible guard against |i| wrapping before it reaches
// |symbol_count|.
63 for (uint16 i
= 0; i
< symbol_count
; i
++) {
64 if (i
!= input_symbols
[i
].id
) {
65 failed_symbol_id_
= i
;
68 symbols
[i
] = input_symbols
[i
];
70 // Order on length and ID ascending, to verify symbol codes are canonical.
71 std::sort(symbols
.begin(), symbols
.end(), SymbolLengthAndIdCompare
);
72 if (symbols
[0].code
!= 0) {
73 failed_symbol_id_
= 0;
76 for (size_t i
= 1; i
!= symbols
.size(); i
++) {
// NOTE(review): `1 << code_shift` below is evaluated as an int shift; a
// previous length of 0 or 1 would give a shift count of 32 or 31 -- confirm
// that lengths are constrained so this cannot happen.
77 unsigned code_shift
= 32 - symbols
[i
- 1].length
;
78 uint32 code
= symbols
[i
- 1].code
+ (1 << code_shift
);
80 if (code
!= symbols
[i
].code
) {
81 failed_symbol_id_
= symbols
[i
].id
;
84 if (code
< symbols
[i
- 1].code
) {
85 // An integer overflow occurred. This implies the input
86 // lengths do not represent a valid Huffman code.
87 failed_symbol_id_
= symbols
[i
].id
;
91 if (symbols
.back().length
< 8) {
92 // At least one code (such as an EOS symbol) must be 8 bits or longer.
93 // Without this, some inputs will not be encodable in a whole number
97 pad_bits_
= static_cast<uint8
>(symbols
.back().code
>> 24);
99 BuildDecodeTables(symbols
);
100 // Order on symbol ID ascending.
101 std::sort(symbols
.begin(), symbols
.end(), SymbolIdCompare
);
102 BuildEncodeTable(symbols
);
// Appends the code and length of every symbol to |code_by_id_| and
// |length_by_id_|, in the order given. The CHECK requires the symbol at
// position i to have id i, so after this call both vectors can be indexed
// directly by symbol id (as EncodeString() and EncodedSize() do).
106 void HpackHuffmanTable::BuildEncodeTable(const std::vector
<Symbol
>& symbols
) {
107 for (size_t i
= 0; i
!= symbols
.size(); i
++) {
108 const Symbol
& symbol
= symbols
[i
];
109 CHECK_EQ(i
, symbol
.id
);
110 code_by_id_
.push_back(symbol
.code
);
111 length_by_id_
.push_back(symbol
.length
);
// Builds the hierarchy of decode tables from |symbols|.
//
// Pass 1 creates the root table (prefix 0, kDecodeTableRootBits indexed bits)
// and then visits |symbols| from last to first. For each symbol it starts at
// table 0, extracts the slice of the left-aligned code that the current table
// indexes, and either
//   * writes a terminal entry (length, symbol id, and next_table_index set to
//     the current table) when the table already covers the whole code, or
//   * on the first visit to an empty slot, creates a child table via
//     AddDecodeTable() and stores its index and the symbol's length in the
//     slot, then continues in the table the slot points at.
// Pass 2 walks every table and copies each entry that is shorter than the
// table's total indexed width into the following slots that share its prefix.
//
// NOTE(review): the inner loop header of pass 1, and the declaration and
// increment of |j| in pass 2, are elided in this excerpt.
115 void HpackHuffmanTable::BuildDecodeTables(const std::vector
<Symbol
>& symbols
) {
116 AddDecodeTable(0, kDecodeTableRootBits
);
117 // We wish to maximize the flatness of the DecodeTable hierarchy (subject to
118 // the |kDecodeTableBranchBits| constraint), and to minimize the size of
119 // child tables. To achieve this, we iterate in order of descending code
120 // length. This ensures that child tables are visited with their longest
121 // entry first, and that the child can therefore be minimally sized to hold
122 // that entry without fear of introducing unnecessary branches later.
123 for (std::vector
<Symbol
>::const_reverse_iterator it
= symbols
.rbegin();
124 it
!= symbols
.rend(); ++it
) {
125 uint8 table_index
= 0;
// |table| and |entry| are taken by value: AddDecodeTable() below appends to
// both |decode_tables_| and |decode_entries_|, which could invalidate
// references into them before the later SetEntry() call.
127 const DecodeTable table
= decode_tables_
[table_index
];
129 // Mask and shift the portion of the code being indexed into low bits.
130 uint32 index
= (it
->code
<< table
.prefix_length
);
131 index
= index
>> (32 - table
.indexed_length
);
133 CHECK_LT(index
, table
.size());
134 DecodeEntry entry
= Entry(table
, index
);
136 uint8 total_indexed
= table
.prefix_length
+ table
.indexed_length
;
137 if (total_indexed
>= it
->length
) {
138 // We're writing a terminal entry.
139 entry
.length
= it
->length
;
140 entry
.symbol_id
= it
->id
;
141 entry
.next_table_index
= table_index
;
142 SetEntry(table
, index
, entry
);
146 if (entry
.length
== 0) {
147 // First visit to this placeholder. We need to create a new table.
148 CHECK_EQ(entry
.next_table_index
, 0);
// The placeholder records the length of the first (longest) code routed
// through it; the child table indexes the remaining bits of that code, capped
// at kDecodeTableBranchBits.
149 entry
.length
= it
->length
;
150 entry
.next_table_index
=
151 AddDecodeTable(total_indexed
, // Becomes the new table prefix.
152 std::min
<uint8
>(kDecodeTableBranchBits
,
153 entry
.length
- total_indexed
));
154 SetEntry(table
, index
, entry
);
156 CHECK_NE(entry
.next_table_index
, table_index
);
157 table_index
= entry
.next_table_index
;
160 // Fill shorter table entries into the additional entry spots they map to.
161 for (size_t i
= 0; i
!= decode_tables_
.size(); i
++) {
162 const DecodeTable
& table
= decode_tables_
[i
];
163 uint8 total_indexed
= table
.prefix_length
+ table
.indexed_length
;
166 while (j
!= table
.size()) {
167 const DecodeEntry
& entry
= Entry(table
, j
);
// Only entries strictly shorter than the table's total indexed width are
// replicated; empty slots (length 0) and entries at least that long are
// skipped by this condition.
168 if (entry
.length
!= 0 && entry
.length
< total_indexed
) {
169 // The difference between entry & table bit counts tells us how
170 // many additional entries map to this one.
171 size_t fill_count
= 1 << (total_indexed
- entry
.length
);
172 CHECK_LE(j
+ fill_count
, table
.size());
// Slot j already holds the entry, so copying starts at offset 1; every
// target slot must still be empty.
174 for (size_t k
= 1; k
!= fill_count
; k
++) {
175 CHECK_EQ(Entry(table
, j
+ k
).length
, 0);
176 SetEntry(table
, j
+ k
, entry
);
// Appends a new DecodeTable to |decode_tables_| and returns its index.
//   prefix  - stored as the table's |prefix_length|.
//   indexed - stored as the table's |indexed_length|.
// The table's |entries_offset| is the size of |decode_entries_| before the
// call; |decode_entries_| is then grown by 2^|indexed| elements to hold the
// table's slots. The CHECK limits the table count to fewer than 255 before
// the append, so the returned index fits in the uint8 return type.
// NOTE(review): the declaration of the local |table| is elided in this
// excerpt.
186 uint8
HpackHuffmanTable::AddDecodeTable(uint8 prefix
, uint8 indexed
) {
187 CHECK_LT(decode_tables_
.size(), 255u);
190 table
.prefix_length
= prefix
;
191 table
.indexed_length
= indexed
;
192 table
.entries_offset
= decode_entries_
.size();
193 decode_tables_
.push_back(table
);
195 decode_entries_
.resize(decode_entries_
.size() + (size_t(1) << indexed
));
196 return static_cast<uint8
>(decode_tables_
.size() - 1);
// Returns a const reference to slot |index| of |table|, which lives at
// position |table.entries_offset| + |index| in |decode_entries_|.
// Both bounds are verified with DCHECK only.
199 const HpackHuffmanTable::DecodeEntry
& HpackHuffmanTable::Entry(
200 const DecodeTable
& table
,
201 uint32 index
) const {
202 DCHECK_LT(index
, table
.size());
203 DCHECK_LT(table
.entries_offset
+ index
, decode_entries_
.size());
204 return decode_entries_
[table
.entries_offset
+ index
];
// Overwrites slot |index| of |table| (position |table.entries_offset| +
// |index| in |decode_entries_|) with a copy of |entry|.
// Unlike Entry(), the bounds here are verified with CHECK.
// NOTE(review): the declaration of the |index| parameter is elided in this
// excerpt.
207 void HpackHuffmanTable::SetEntry(const DecodeTable
& table
,
209 const DecodeEntry
& entry
) {
210 CHECK_LT(index
, table
.size());
211 CHECK_LT(table
.entries_offset
+ index
, decode_entries_
.size());
212 decode_entries_
[table
.entries_offset
+ index
] = entry
;
// Returns true when |code_by_id_| is non-empty, i.e. once BuildEncodeTable()
// has appended at least one symbol.
215 bool HpackHuffmanTable::IsInitialized() const {
216 return !code_by_id_
.empty();
// Huffman-encodes |in| and appends the resulting bits to |out|.
//
// Each input element is treated as a symbol id (CHECKed to be within
// |code_by_id_|). The symbol's left-aligned 32-bit code is shifted down so
// that its |length| significant bits occupy the low bits, then handed to
// |out| one byte-sized piece at a time, most significant piece first.
// |bit_remnant| tracks the total number of bits written modulo 8; if the
// output does not end on a byte boundary, the last byte is completed with
// the leading bits of |pad_bits_|.
//
// NOTE(review): the conditions guarding the AppendBits() calls, and any
// adjustment of |length| between them, are elided in this excerpt;
// presumably they select on |length| exceeding 24 / 16 / 8 -- confirm.
// NOTE(review): a stored length of 0 would make `32 - length` a shift by 32
// -- confirm that lengths are always non-zero.
219 void HpackHuffmanTable::EncodeString(StringPiece in
,
220 HpackOutputStream
* out
) const {
221 size_t bit_remnant
= 0;
222 for (size_t i
= 0; i
!= in
.size(); i
++) {
// The cast through uint8 keeps the id in 0..255 regardless of the signedness
// of the element type.
223 uint16 symbol_id
= static_cast<uint8
>(in
[i
]);
224 CHECK_GT(code_by_id_
.size(), symbol_id
);
226 // Load, and shift code to low bits.
227 unsigned length
= length_by_id_
[symbol_id
];
228 uint32 code
= code_by_id_
[symbol_id
] >> (32 - length
);
230 bit_remnant
= (bit_remnant
+ length
) % 8;
233 out
->AppendBits(static_cast<uint8
>(code
>> 24), length
- 24);
237 out
->AppendBits(static_cast<uint8
>(code
>> 16), length
- 16);
241 out
->AppendBits(static_cast<uint8
>(code
>> 8), length
- 8);
244 out
->AppendBits(static_cast<uint8
>(code
), length
);
246 if (bit_remnant
!= 0) {
247 // Pad current byte as required.
248 out
->AppendBits(pad_bits_
>> bit_remnant
, 8 - bit_remnant
);
// Returns the number of whole bytes needed to hold the Huffman codes of all
// elements of |in|: the per-symbol bit lengths are summed, rounded up to the
// next multiple of 8, and divided by 8. Each element is treated as a symbol
// id and CHECKed to be within |code_by_id_|. An empty |in| yields 0.
252 size_t HpackHuffmanTable::EncodedSize(StringPiece in
) const {
253 size_t bit_count
= 0;
254 for (size_t i
= 0; i
!= in
.size(); i
++) {
255 uint16 symbol_id
= static_cast<uint8
>(in
[i
]);
256 CHECK_GT(code_by_id_
.size(), symbol_id
);
258 bit_count
+= length_by_id_
[symbol_id
];
260 if (bit_count
% 8 != 0) {
261 bit_count
+= 8 - bit_count
% 8;
263 return bit_count
/ 8;
// Decodes Huffman-coded input from |in|, appending decoded bytes to |out|.
//
// Visible behaviour, per symbol:
//   * start at the root table, indexed by the top kDecodeTableRootBits bits
//     of |bits|, and follow |next_table_index| links for at most
//     kDecodeIterations steps;
//   * if the matched entry needs more bits than are available and the last
//     peek failed, discard the remainder of the current byte and return true
//     only if no input is left;
//   * an entry of length 0 marks an invalid prefix;
//   * otherwise, symbols with id < 256 are appended to |out| (subject to the
//     |out_capacity| check), the entry's bits are consumed from |in| and
//     shifted out of |bits|, and more input is peeked.
//
// NOTE(review): this excerpt omits the rest of the parameter list (|out| and
// |out_capacity| are used but their declarations are not shown), the
// declaration of |bits|, the enclosing loop, the exit condition of the table
// walk, the results returned on the invalid-prefix and overflow paths, and
// the end of the function -- confirm all of these against the full file.
266 bool HpackHuffmanTable::DecodeString(HpackInputStream
* in
,
269 // Number of decode iterations required for a 32-bit code.
// With the file's constants (root 9 bits, branch 6 bits) this is
// ceil(23 / 6) = 4.
270 const int kDecodeIterations
= static_cast<int>(
271 std::ceil((32.f
- kDecodeTableRootBits
) / kDecodeTableBranchBits
));
275 // Current input, stored in the high |bits_available| bits of |bits|.
277 size_t bits_available
= 0;
278 bool peeked_success
= in
->PeekBits(&bits_available
, &bits
);
281 const DecodeTable
* table
= &decode_tables_
[0];
282 uint32 index
= bits
>> (32 - kDecodeTableRootBits
);
284 for (int i
= 0; i
!= kDecodeIterations
; i
++) {
285 DCHECK_LT(index
, table
->size());
286 DCHECK_LT(Entry(*table
, index
).next_table_index
, decode_tables_
.size());
288 table
= &decode_tables_
[Entry(*table
, index
).next_table_index
];
289 // Mask and shift the portion of the code being indexed into low bits.
290 index
= (bits
<< table
->prefix_length
) >> (32 - table
->indexed_length
);
292 const DecodeEntry
& entry
= Entry(*table
, index
);
294 if (entry
.length
> bits_available
) {
295 if (!peeked_success
) {
296 // Unable to read enough input for a match. If only a portion of
297 // the last byte remains, this is a successful EOF condition.
298 in
->ConsumeByteRemainder();
299 return !in
->HasMoreData();
301 } else if (entry
.length
== 0) {
302 // The input is an invalid prefix, larger than any prefix in the table.
305 if (out
->size() == out_capacity
) {
306 // This code would cause us to overflow |out_capacity|.
309 if (entry
.symbol_id
< 256) {
310 // Assume symbols >= 256 are used for padding.
311 out
->push_back(static_cast<char>(entry
.symbol_id
));
// Drop the decoded code from both the stream and the local bit window, then
// try to refill the window.
314 in
->ConsumeBits(entry
.length
);
315 bits
= bits
<< entry
.length
;
316 bits_available
-= entry
.length
;
318 peeked_success
= in
->PeekBits(&bits_available
, &bits
);