don't try to write data after error
[ghsmtp.git] / Base64.cpp
blob2d4d862c57f7414ef7cb4e813048707360a52201
1 #include "Base64.hpp"
3 #include <algorithm>
4 #include <cctype>
5 #include <stdexcept>
7 #include <glog/logging.h>
9 namespace Base64 {
// The Base64 alphabet of RFC 4648 ('+' and '/' for values 62 and 63).
// A character's position in this array is the 6-bit value it encodes.
// sizeof(CHARSET) is 65: the 64 alphabet characters plus the terminating NUL.
constexpr char const CHARSET[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                                 "abcdefghijklmnopqrstuvwxyz"
                                 "0123456789"
                                 "+/";
14 namespace {
15 auto CHARSET_find(unsigned char ch)
17 return static_cast<unsigned char>(
18 std::find(std::begin(CHARSET), std::end(CHARSET), ch)
19 - std::begin(CHARSET));
21 } // namespace
23 std::string enc(std::string_view text, std::string::size_type wrap)
25 unsigned char group_8bit[3];
26 unsigned char group_6bit[4];
27 int count_3_chars = 0;
29 auto const input_size = text.length();
30 auto const padding = ((input_size % 3) ? (3 - (input_size % 3)) : 0);
31 auto const code_padded_size = ((input_size + padding) / 3) * 4;
32 auto const newline_size = wrap ? ((code_padded_size) / wrap) * 2 : 0;
33 auto const total_size = code_padded_size + newline_size;
35 std::string enc_text;
36 enc_text.reserve(total_size);
37 std::string::size_type line_len = 0;
39 for (std::string::size_type ch = 0; ch < text.length(); ch++) {
40 group_8bit[count_3_chars++] = text[ch];
41 if (count_3_chars == 3) {
42 group_6bit[0] = (group_8bit[0] & 0xfc) >> 2;
43 group_6bit[1]
44 = ((group_8bit[0] & 0x03) << 4) + ((group_8bit[1] & 0xf0) >> 4);
45 group_6bit[2]
46 = ((group_8bit[1] & 0x0f) << 2) + ((group_8bit[2] & 0xc0) >> 6);
47 group_6bit[3] = group_8bit[2] & 0x3f;
49 for (int i = 0; i < 4; i++)
50 enc_text += CHARSET[group_6bit[i]];
51 count_3_chars = 0;
52 line_len += 4;
55 if (wrap && (line_len == wrap)) {
56 enc_text += "\r\n";
57 line_len = 0;
61 // encode remaining characters if any
63 if (count_3_chars > 0) {
64 for (int i = count_3_chars; i < 3; i++)
65 group_8bit[i] = '\0';
67 group_6bit[0] = (group_8bit[0] & 0xfc) >> 2;
68 group_6bit[1]
69 = ((group_8bit[0] & 0x03) << 4) + ((group_8bit[1] & 0xf0) >> 4);
70 group_6bit[2]
71 = ((group_8bit[1] & 0x0f) << 2) + ((group_8bit[2] & 0xc0) >> 6);
72 group_6bit[3] = group_8bit[2] & 0x3f;
74 for (int i = 0; i < count_3_chars + 1; i++) {
75 if (wrap && (line_len == wrap)) {
76 enc_text += "\r\n";
77 line_len = 0;
79 enc_text += CHARSET[group_6bit[i]];
80 line_len++;
83 while (count_3_chars++ < 3) {
84 if (wrap && (line_len == wrap)) {
85 enc_text += "\r\n";
86 line_len = 0;
88 enc_text += '=';
89 line_len++;
93 CHECK_EQ(enc_text.length(), total_size);
95 return enc_text;
/// True if `ch` is one of the 64 characters of the Base64 alphabet:
/// 'A'-'Z', 'a'-'z', '0'-'9', '+' or '/'. Padding '=' and line breaks are
/// not alphabet characters.
///
/// Explicit ASCII ranges are used instead of std::isalnum for two reasons.
/// isalnum has undefined behaviour for negative char values, which covers
/// every byte >= 0x80 where char is signed. It is also locale-dependent, so
/// it could accept bytes that are not in the alphabet.
bool is_base64char(char ch)
{
  auto const c = static_cast<unsigned char>(ch);
  return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
         || (c >= '0' && c <= '9') || c == '+' || c == '/';
}
103 std::string dec(std::string_view text)
105 auto const input_size = text.length();
106 auto const max_size = (input_size / 4) * 3;
108 std::string dec_text;
109 dec_text.reserve(max_size);
110 unsigned char group_6bit[4];
111 unsigned char group_8bit[3];
112 int count_4_chars = 0;
114 for (std::string::size_type ch = 0; ch < text.length(); ch++) {
115 if (text[ch] == '=')
116 break;
118 if ((text[ch] == '\r') || (text[ch] == '\n'))
119 continue;
121 if (!is_base64char(text[ch]))
122 throw std::invalid_argument("bad character in decode");
124 group_6bit[count_4_chars++] = text[ch];
125 if (count_4_chars == 4) {
126 for (int i = 0; i < 4; i++)
127 group_6bit[i] = CHARSET_find(group_6bit[i]);
129 group_8bit[0] = (group_6bit[0] << 2) + ((group_6bit[1] & 0x30) >> 4);
130 group_8bit[1]
131 = ((group_6bit[1] & 0xf) << 4) + ((group_6bit[2] & 0x3c) >> 2);
132 group_8bit[2] = ((group_6bit[2] & 0x3) << 6) + group_6bit[3];
134 for (int i = 0; i < 3; i++)
135 dec_text += group_8bit[i];
136 count_4_chars = 0;
140 // decode remaining characters if any
142 if (count_4_chars > 0) {
143 for (int i = count_4_chars; i < 4; i++)
144 group_6bit[i] = '\0';
146 for (int i = 0; i < 4; i++)
147 group_6bit[i] = CHARSET_find(group_6bit[i]);
149 group_8bit[0] = (group_6bit[0] << 2) + ((group_6bit[1] & 0x30) >> 4);
150 group_8bit[1]
151 = ((group_6bit[1] & 0xf) << 4) + ((group_6bit[2] & 0x3c) >> 2);
152 group_8bit[2] = ((group_6bit[2] & 0x3) << 6) + group_6bit[3];
154 for (int i = 0; i < count_4_chars - 1; i++)
155 dec_text += group_8bit[i];
158 return dec_text;
160 } // namespace Base64