1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/spdy/hpack_huffman_table.h"
10 #include "base/logging.h"
11 #include "net/spdy/hpack_input_stream.h"
12 #include "net/spdy/hpack_output_stream.h"
16 using base::StringPiece
;
21 // How many bits to index in the root decode table.
22 const uint8 kDecodeTableRootBits
= 9;
23 // Maximum number of bits to index in successive decode tables.
24 const uint8 kDecodeTableBranchBits
= 6;
// Ordering predicate over Huffman symbols, passed to std::sort when the
// symbol codes are checked for being canonical. A symbol with a shorter
// |length| orders before a symbol with a longer one.
// NOTE(review): the statement inside the equal-length branch is not visible
// in this listing. Going by the function name and the "Order on length and ID
// ascending" comment at the std::sort call, ties are presumably broken by
// ascending |id| -- confirm against the full file.
26 bool SymbolLengthAndIdCompare(const HpackHuffmanSymbol
& a
,
27 const HpackHuffmanSymbol
& b
) {
28 if (a
.length
== b
.length
) {
31 return a
.length
< b
.length
;
// Ordering predicate over Huffman symbols, passed to std::sort to reorder the
// symbols before the encode table is built.
// NOTE(review): the body is not visible in this listing. The "Order on symbol
// ID ascending" comment at the call site suggests it compares |a.id| against
// |b.id| -- confirm against the full file.
33 bool SymbolIdCompare(const HpackHuffmanSymbol
& a
,
34 const HpackHuffmanSymbol
& b
) {
40 HpackHuffmanTable::DecodeEntry::DecodeEntry()
41 : next_table_index(0), length(0), symbol_id(0) {
// Creates an entry from explicit field values; each argument initializes the
// member of the same name.
// NOTE(review): the declarations of the |length| and |symbol_id| parameters
// are not visible in this listing -- confirm their types against the header.
43 HpackHuffmanTable::DecodeEntry::DecodeEntry(uint8 next_table_index
,
46 : next_table_index(next_table_index
), length(length
), symbol_id(symbol_id
) {
48 size_t HpackHuffmanTable::DecodeTable::size() const {
49 return size_t(1) << indexed_length
;
52 HpackHuffmanTable::HpackHuffmanTable() {}
54 HpackHuffmanTable::~HpackHuffmanTable() {}
56 bool HpackHuffmanTable::Initialize(const HpackHuffmanSymbol
* input_symbols
,
57 size_t symbol_count
) {
58 CHECK(!IsInitialized());
60 std::vector
<Symbol
> symbols(symbol_count
);
61 // Validate symbol id sequence, and copy into |symbols|.
62 for (size_t i
= 0; i
!= symbol_count
; i
++) {
63 if (i
!= input_symbols
[i
].id
) {
64 failed_symbol_id_
= i
;
67 symbols
[i
] = input_symbols
[i
];
69 // Order on length and ID ascending, to verify symbol codes are canonical.
70 std::sort(symbols
.begin(), symbols
.end(), SymbolLengthAndIdCompare
);
71 if (symbols
[0].code
!= 0) {
72 failed_symbol_id_
= 0;
75 for (size_t i
= 1; i
!= symbols
.size(); i
++) {
76 unsigned code_shift
= 32 - symbols
[i
-1].length
;
77 uint32 code
= symbols
[i
-1].code
+ (1 << code_shift
);
79 if (code
!= symbols
[i
].code
) {
80 failed_symbol_id_
= symbols
[i
].id
;
83 if (code
< symbols
[i
-1].code
) {
84 // An integer overflow occurred. This implies the input
85 // lengths do not represent a valid Huffman code.
86 failed_symbol_id_
= symbols
[i
].id
;
90 if (symbols
.back().length
< 8) {
91 // At least one code (such as an EOS symbol) must be 8 bits or longer.
92 // Without this, some inputs will not be encodable in a whole number
96 pad_bits_
= static_cast<uint8
>(symbols
.back().code
>> 24);
98 BuildDecodeTables(symbols
);
99 // Order on symbol ID ascending.
100 std::sort(symbols
.begin(), symbols
.end(), SymbolIdCompare
);
101 BuildEncodeTable(symbols
);
105 void HpackHuffmanTable::BuildEncodeTable(const std::vector
<Symbol
>& symbols
) {
106 for (size_t i
= 0; i
!= symbols
.size(); i
++) {
107 const Symbol
& symbol
= symbols
[i
];
108 CHECK_EQ(i
, symbol
.id
);
109 code_by_id_
.push_back(symbol
.code
);
110 length_by_id_
.push_back(symbol
.length
);
// Builds the decode tables for |symbols|. A root table indexing
// |kDecodeTableRootBits| bits is created first. Then each symbol, walking
// |symbols| from back to front, is either written as a terminal entry into a
// table whose prefix and indexed bits together cover the whole code, or
// routed through a placeholder entry that points at a child table (created on
// first use). A final pass copies every entry whose code is shorter than the
// bits its table covers into the neighbouring slots that share that code as
// a prefix.
// NOTE(review): several statements are not visible in this listing: the loop
// that descends from the root table to the terminal table (and its exit after
// a terminal entry is written), and the declaration and advancement of |j| in
// the fill pass. Confirm against the full file before changing this logic.
114 void HpackHuffmanTable::BuildDecodeTables(const std::vector
<Symbol
>& symbols
) {
115 AddDecodeTable(0, kDecodeTableRootBits
);
116 // We wish to maximize the flatness of the DecodeTable hierarchy (subject to
117 // the |kDecodeTableBranchBits| constraint), and to minimize the size of
118 // child tables. To achieve this, we iterate in order of descending code
119 // length. This ensures that child tables are visited with their longest
120 // entry first, and that the child can therefore be minimally sized to hold
121 // that entry without fear of introducing unnecessary branches later.
// NOTE(review): descending length relies on |symbols| arriving sorted by
// ascending length, as the caller's preceding std::sort suggests -- confirm.
122 for (std::vector
<Symbol
>::const_reverse_iterator it
= symbols
.rbegin();
123 it
!= symbols
.rend(); ++it
) {
124 uint8 table_index
= 0;
// |table| is held by value, not by reference. NOTE(review): presumably this
// is because AddDecodeTable() below pushes onto |decode_tables_|, which could
// invalidate a reference into that vector -- confirm.
126 const DecodeTable table
= decode_tables_
[table_index
];
128 // Mask and shift the portion of the code being indexed into low bits.
129 uint32 index
= (it
->code
<< table
.prefix_length
);
130 index
= index
>> (32 - table
.indexed_length
);
132 CHECK_LT(index
, table
.size());
133 DecodeEntry entry
= Entry(table
, index
);
// Number of leading code bits resolved once this table has been consulted.
135 uint8 total_indexed
= table
.prefix_length
+ table
.indexed_length
;
136 if (total_indexed
>= it
->length
) {
137 // We're writing a terminal entry.
138 entry
.length
= it
->length
;
139 entry
.symbol_id
= it
->id
;
140 entry
.next_table_index
= table_index
;
141 SetEntry(table
, index
, entry
);
145 if (entry
.length
== 0) {
146 // First visit to this placeholder. We need to create a new table.
147 CHECK_EQ(entry
.next_table_index
, 0);
148 entry
.length
= it
->length
;
// The child indexes the bits still unresolved for this code, capped at
// |kDecodeTableBranchBits|.
149 entry
.next_table_index
= AddDecodeTable(
150 total_indexed
, // Becomes the new table prefix.
151 std::min
<uint8
>(kDecodeTableBranchBits
,
152 entry
.length
- total_indexed
));
153 SetEntry(table
, index
, entry
);
155 CHECK_NE(entry
.next_table_index
, table_index
);
156 table_index
= entry
.next_table_index
;
159 // Fill shorter table entries into the additional entry spots they map to.
160 for (size_t i
= 0; i
!= decode_tables_
.size(); i
++) {
161 const DecodeTable
& table
= decode_tables_
[i
];
162 uint8 total_indexed
= table
.prefix_length
+ table
.indexed_length
;
165 while (j
!= table
.size()) {
166 const DecodeEntry
& entry
= Entry(table
, j
);
167 if (entry
.length
!= 0 && entry
.length
< total_indexed
) {
168 // The difference between entry & table bit counts tells us how
169 // many additional entries map to this one.
170 size_t fill_count
= 1 << (total_indexed
- entry
.length
);
171 CHECK_LE(j
+ fill_count
, table
.size());
// Slots j+1 .. j+fill_count-1 must still be empty; each receives a copy of
// the entry at j.
173 for (size_t k
= 1; k
!= fill_count
; k
++) {
174 CHECK_EQ(Entry(table
, j
+ k
).length
, 0);
175 SetEntry(table
, j
+ k
, entry
);
// Appends a new decode table to |decode_tables_| and returns its index.
// |prefix| is stored as the table's |prefix_length| and |indexed| as its
// |indexed_length|. The table's entries start at the current end of
// |decode_entries_|, which is grown by 2^|indexed| entries to hold them.
// The CHECK keeps the returned index representable in a uint8.
// NOTE(review): the declaration of the local |table| is not visible in this
// listing.
185 uint8
HpackHuffmanTable::AddDecodeTable(uint8 prefix
, uint8 indexed
) {
186 CHECK_LT(decode_tables_
.size(), 255u);
189 table
.prefix_length
= prefix
;
190 table
.indexed_length
= indexed
;
191 table
.entries_offset
= decode_entries_
.size();
192 decode_tables_
.push_back(table
);
// Reserve one entry per possible value of the |indexed| bits.
194 decode_entries_
.resize(decode_entries_
.size() + (size_t(1) << indexed
));
195 return static_cast<uint8
>(decode_tables_
.size() - 1);
198 const HpackHuffmanTable::DecodeEntry
& HpackHuffmanTable::Entry(
199 const DecodeTable
& table
,
200 uint32 index
) const {
201 DCHECK_LT(index
, table
.size());
202 DCHECK_LT(table
.entries_offset
+ index
, decode_entries_
.size());
203 return decode_entries_
[table
.entries_offset
+ index
];
// Stores |entry| at position |index| of |table|, i.e. at
// |table.entries_offset + index| within |decode_entries_|. Both the
// table-relative and the absolute position are CHECKed before the write.
// NOTE(review): the declaration of the |index| parameter is not visible in
// this listing -- confirm its type against the header.
206 void HpackHuffmanTable::SetEntry(const DecodeTable
& table
,
208 const DecodeEntry
& entry
) {
209 CHECK_LT(index
, table
.size());
210 CHECK_LT(table
.entries_offset
+ index
, decode_entries_
.size());
211 decode_entries_
[table
.entries_offset
+ index
] = entry
;
214 bool HpackHuffmanTable::IsInitialized() const {
215 return !code_by_id_
.empty();
// Encodes |in| to |out|. Each byte of |in| is used as a symbol id; its code
// is looked up in |code_by_id_| / |length_by_id_| and handed to
// |out->AppendBits()|. If the total number of emitted bits is not a multiple
// of 8, the final byte is completed using |pad_bits_|.
// CHECK-fails if a byte has no code, i.e. its value is not below
// |code_by_id_.size()|.
// NOTE(review): the conditions guarding the 24-, 16- and 8-bit AppendBits
// calls, and any updates to |length| between them, are not visible in this
// listing. Presumably a code longer than 8 bits is emitted most-significant
// byte first, one AppendBits call per byte -- confirm against the full file.
218 void HpackHuffmanTable::EncodeString(StringPiece in
,
219 HpackOutputStream
* out
) const {
// Number of bits emitted so far, modulo 8.
220 size_t bit_remnant
= 0;
221 for (size_t i
= 0; i
!= in
.size(); i
++) {
// Go through uint8 so that a char with the high bit set maps to 128..255.
222 uint16 symbol_id
= static_cast<uint8
>(in
[i
]);
223 CHECK_GT(code_by_id_
.size(), symbol_id
);
225 // Load, and shift code to low bits.
226 unsigned length
= length_by_id_
[symbol_id
];
227 uint32 code
= code_by_id_
[symbol_id
] >> (32 - length
);
229 bit_remnant
= (bit_remnant
+ length
) % 8;
232 out
->AppendBits(static_cast<uint8
>(code
>> 24), length
- 24);
236 out
->AppendBits(static_cast<uint8
>(code
>> 16), length
- 16);
240 out
->AppendBits(static_cast<uint8
>(code
>> 8), length
- 8);
243 out
->AppendBits(static_cast<uint8
>(code
), length
);
245 if (bit_remnant
!= 0) {
246 // Pad current byte as required.
247 out
->AppendBits(pad_bits_
>> bit_remnant
, 8 - bit_remnant
);
251 size_t HpackHuffmanTable::EncodedSize(StringPiece in
) const {
252 size_t bit_count
= 0;
253 for (size_t i
= 0; i
!= in
.size(); i
++) {
254 uint16 symbol_id
= static_cast<uint8
>(in
[i
]);
255 CHECK_GT(code_by_id_
.size(), symbol_id
);
257 bit_count
+= length_by_id_
[symbol_id
];
259 if (bit_count
% 8 != 0) {
260 bit_count
+= 8 - bit_count
% 8;
262 return bit_count
/ 8;
265 bool HpackHuffmanTable::DecodeString(HpackInputStream
* in
,
268 // Number of decode iterations required for a 32-bit code.
269 const int kDecodeIterations
= static_cast<int>(
270 std::ceil((32.f
- kDecodeTableRootBits
) / kDecodeTableBranchBits
));
274 // Current input, stored in the high |bits_available| bits of |bits|.
276 size_t bits_available
= 0;
277 bool peeked_success
= in
->PeekBits(&bits_available
, &bits
);
280 const DecodeTable
* table
= &decode_tables_
[0];
281 uint32 index
= bits
>> (32 - kDecodeTableRootBits
);
283 for (int i
= 0; i
!= kDecodeIterations
; i
++) {
284 DCHECK_LT(index
, table
->size());
285 DCHECK_LT(Entry(*table
, index
).next_table_index
, decode_tables_
.size());
287 table
= &decode_tables_
[Entry(*table
, index
).next_table_index
];
288 // Mask and shift the portion of the code being indexed into low bits.
289 index
= (bits
<< table
->prefix_length
) >> (32 - table
->indexed_length
);
291 const DecodeEntry
& entry
= Entry(*table
, index
);
293 if (entry
.length
> bits_available
) {
294 if (!peeked_success
) {
295 // Unable to read enough input for a match. If only a portion of
296 // the last byte remains, this is a successful EOF condition.
297 in
->ConsumeByteRemainder();
298 return !in
->HasMoreData();
300 } else if (entry
.length
== 0) {
301 // The input is an invalid prefix, larger than any prefix in the table.
304 if (out
->size() == out_capacity
) {
305 // This code would cause us to overflow |out_capacity|.
308 if (entry
.symbol_id
< 256) {
309 // Assume symbols >= 256 are used for padding.
310 out
->push_back(static_cast<char>(entry
.symbol_id
));
313 in
->ConsumeBits(entry
.length
);
314 bits
= bits
<< entry
.length
;
315 bits_available
-= entry
.length
;
317 peeked_success
= in
->PeekBits(&bits_available
, &bits
);