fix
[libpgclient.git] / src / pgprov3.c
blobaf91071b0ffe575ea5b90eb3793558d6fd3d604d
1 /*Copyright (c) Brian B.
3 This library is free software; you can redistribute it and/or
4 modify it under the terms of the GNU Lesser General Public
5 License as published by the Free Software Foundation; either
6 version 3 of the License, or (at your option) any later version.
7 See the file LICENSE included with this distribution for more
8 information.
9 */
10 #include "md5.h"
11 #include "hmac.h"
12 #include "libpgcli/pgprov3.h"
/* Growth granularity, in bytes, of outgoing message buffers (see _resize). */
#define CHUNKSIZE 32
/* Size of a backend message header: 1-byte type tag plus 4-byte length.
   Parenthesized so it is safe inside larger expressions (e.g. 2 * HEADER_SIZE). */
#define HEADER_SIZE (sizeof(char) + sizeof(int32_t))
/* Append `item` at the tail of the circular doubly-linked list `list`
   (the element just before head) and bump list->len.
   The expansion is a single braced compound statement, so it is safe as the
   body of an unbraced if/else, and it may be followed by an optional ';'.
   Arguments are evaluated more than once: pass plain lvalues only. */
#define ADD_LIST_ITEM(list,item) { \
    if (!(list)->head) { \
        (item)->next = (item)->prev = (item); \
        (list)->head = (item); \
    } else { \
        (item)->next = (list)->head; \
        (item)->prev = (list)->head->prev; \
        (list)->head->prev = (item); \
        (item)->prev->next = (item); \
    } \
    ++(list)->len; \
}
29 void pgmsg_list_add (pgmsg_list_t *list, pgmsg_t *msg) {
30 ADD_LIST_ITEM(list, msg)
33 void pgmsg_datarow_list_add (pgmsg_datarow_list_t *list, pgmsg_datarow_t *datarow) {
34 pgmsg_datarow_item_t *item = calloc(1, sizeof(pgmsg_datarow_item_t));
35 item->datarow = datarow;
36 ADD_LIST_ITEM(list, item)
/* Unlink `item` (of element type `type`) from the circular list `list`,
   moving head forward when the head is removed and clearing it when the list
   becomes empty. The item itself is not freed.
   The expansion is a braced block, so its temporaries are scoped: the macro can
   be used several times in one function, and the trailing-underscore names
   cannot capture a caller argument called `prev` or `next`. */
#define DEL_LIST_ITEM(type,list,item) { \
    type *dl_prev_ = (item)->prev, *dl_next_ = (item)->next; \
    if (dl_prev_) dl_prev_->next = dl_next_; \
    if (dl_next_) dl_next_->prev = dl_prev_; \
    if ((item) == (list)->head) (list)->head = dl_next_; \
    if (0 == --(list)->len) (list)->head = NULL; \
}
46 static void _pgmsg_list_del (pgmsg_list_t *list, pgmsg_t *msg) {
47 DEL_LIST_ITEM(pgmsg_t,list, msg)
50 static void _pgmsg_datarow_list_del (pgmsg_datarow_list_t *list, pgmsg_datarow_item_t *item) {
51 DEL_LIST_ITEM(pgmsg_datarow_item_t,list, item)
52 free(item);
55 void pgmsg_list_clear (pgmsg_list_t *list) {
56 while (list->head)
57 _pgmsg_list_del(list, list->head);
60 void pgmsg_datarow_list_clear (pgmsg_datarow_list_t *list) {
61 while (list->head)
62 _pgmsg_datarow_list_del(list, list->head);
65 static int _resize (pgmsg_t **msg, size_t len) {
66 pgmsg_t *res;
67 int nlen = (*msg)->len + len;
68 size_t pc_len,
69 bufsize = sizeof(pgmsg_t) + (nlen / CHUNKSIZE) * CHUNKSIZE + CHUNKSIZE;
70 if (bufsize == (*msg)->bufsize)
71 return 0;
72 pc_len = (uintptr_t)(*msg)->pc - (uintptr_t)(*msg)->body.ptr;
73 if (!(res = realloc(*msg, bufsize)))
74 return -1;
75 res->bufsize = bufsize;
76 res->pc = res->body.ptr + pc_len;
77 *msg = res;
78 return 0;
81 pgmsg_t *pgmsg_create (char type) {
82 size_t bufsize = sizeof(pgmsg_t) + CHUNKSIZE;
83 pgmsg_t *msg = malloc(bufsize);
84 if (!msg) return NULL;
85 msg->bufsize = bufsize;
86 msg->len = sizeof(pgmsg_t);
87 msg->pc = msg->body.ptr;
88 msg->body.type = type;
89 msg->body.len = htobe32(sizeof(int32_t));
90 return msg;
93 static int _seti8 (pgmsg_t **msg, int8_t x) {
94 if (-1 == _resize(msg, sizeof(int8_t)))
95 return -1;
96 *((int8_t*)(*msg)->pc) = x;
97 (*msg)->pc += sizeof(int8_t);
98 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int8_t));
99 (*msg)->len += sizeof(int8_t);
100 return 0;
103 static int _seti16 (pgmsg_t **msg, int16_t x) {
104 if (-1 == _resize(msg, sizeof(int16_t)))
105 return -1;
106 *((int16_t*)(*msg)->pc) = htobe16(x);
107 (*msg)->pc += sizeof(int16_t);
108 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int16_t));
109 (*msg)->len += sizeof(int16_t);
110 return 0;
113 static int _seti32 (pgmsg_t **msg, int32_t x) {
114 if (-1 == _resize(msg, sizeof(int32_t)))
115 return -1;
116 *((int32_t*)(*msg)->pc) = htobe32(x);
117 (*msg)->pc += sizeof(int32_t);
118 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int32_t));
119 (*msg)->len += sizeof(int32_t);
120 return 0;
123 static int _setstr (pgmsg_t **msg, const char *s, size_t slen) {
124 if (0 == slen)
125 return 0;
126 if (-1 == _resize(msg, slen))
127 return -1;
128 memcpy((*msg)->pc, s, slen);
129 (*msg)->pc += slen;
130 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + slen);
131 (*msg)->len += slen;
132 return 0;
135 static int8_t _geti8 (pgmsg_t *msg) {
136 int8_t r = *msg->pc;
137 msg->pc += sizeof(int8_t);
138 return r;
141 static int16_t _geti16 (pgmsg_t *msg) {
142 int16_t r = be16toh(*((int16_t*)msg->pc));
143 msg->pc += sizeof(int16_t);
144 return r;
147 static int32_t _geti32 (pgmsg_t *msg) {
148 int32_t r = be32toh(*((int32_t*)msg->pc));
149 msg->pc += sizeof(int32_t);
150 return r;
153 static const char *_getstr (pgmsg_t *msg) {
154 char *p = msg->pc, *e = msg->body.ptr + be32toh(msg->body.len);
155 while (*(msg->pc) && msg->pc < e) ++msg->pc;
156 if (e == msg->pc)
157 return NULL;
158 ++msg->pc;
159 return p;
162 int pgmsg_set_param (pgmsg_t **msg, const char *name, size_t name_len, const char *value, size_t value_len) {
163 if (-1 == _setstr(msg, name, name_len) ||
164 -1 == _seti8(msg, 0) ||
165 -1 == _setstr(msg, value, value_len) ||
166 -1 == _seti8(msg, 0))
167 return -1;
168 return 0;
171 pgmsg_t *pgmsg_create_startup (const char *user, size_t user_len, const char *database, size_t database_len) {
172 pgmsg_t *msg = pgmsg_create('\0');
173 _seti16(&msg, PG_MAJOR_VER);
174 _seti16(&msg, PG_MINOR_VER);
175 pgmsg_set_param(&msg, CONST_STR_LEN("user"), user, user_len);
176 pgmsg_set_param(&msg, CONST_STR_LEN("database"), database, database_len);
177 return msg;
/* Bits collected in the `flags` word while a connection string is parsed;
   _startup_params checks them to decide which defaults still need adding. */
#define PG_USERDEFS (1u << 0)
#define PG_DBDEFS   (1u << 1)
#define PG_ENCDEFS  (1u << 2)
184 static inline pgmsg_t *_add_str_param (pgmsg_t *msg, const char *k, size_t lk, const char *v, size_t lv) {
185 _setstr(&msg, k, lk);
186 _seti8(&msg, 0);
187 _setstr(&msg, v, lv);
188 _seti8(&msg, 0);
189 return msg;
192 static pgmsg_t *_add_param (pgmsg_t *msg, const char *begin, const char *end, uint32_t *flags) {
193 const char *p = begin, *p1;
194 while (p < end && '=' != *p) ++p;
195 if (0 == strncmp(begin, "password", (uintptr_t)p - (uintptr_t)begin))
196 return msg;
197 if (p == end)
198 return msg;
199 p1 = p;
200 while (p1 > begin && isspace(*(p1 - 1))) --p1;
201 if (p == p1 - 1)
202 return msg;
203 ++p;
204 while (p < end && isspace(*p)) ++p;
205 if (p == end)
206 return msg;
207 if (0 == strncmp(begin, "dbname", (uintptr_t)p1 - (uintptr_t)begin)) {
208 msg = _add_str_param(msg, CONST_STR_LEN("database"), p, (uintptr_t)end - (uintptr_t)p);
209 if (flags)
210 *flags |= PG_DBDEFS;
211 } else
212 if (0 == strncmp(begin, "user", (uintptr_t)p1 - (uintptr_t)begin)) {
213 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
214 if (flags)
215 *flags |= PG_USERDEFS;
216 } else
217 if (0 != strncmp(begin, "host", (uintptr_t)p1 - (uintptr_t)begin) &&
218 0 != strncmp(begin, "port", (uintptr_t)p1 - (uintptr_t)begin))
219 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
220 return msg;
223 void *parse_conninfo (void *data, const char *conn_info, parse_param_h fn, uint32_t *flags) {
224 if (conn_info) {
225 const char *p = conn_info;
226 while (*p) {
227 const char *q;
228 while (isspace(*p)) ++p;
229 if (!(*p)) break;
230 q = p;
231 while (*p && !isspace(*p)) ++p;
232 data = fn(data, q, p, flags);
235 return data;
238 static pgmsg_t *_startup_params (pgmsg_t *msg, const char *conn_info) {
239 uint32_t flags = 0;
240 msg = (pgmsg_t*)parse_conninfo((void*)msg, conn_info, (parse_param_h)_add_param, &flags);
241 if (!(flags & PG_USERDEFS)) {
242 char *s = getenv("USER");
243 if (s)
244 msg = _add_str_param(msg, CONST_STR_LEN("user"), s, strlen(s));
246 if (!(flags & PG_DBDEFS)) {
247 char *s = getenv("USER");
248 if (s)
249 msg = _add_str_param(msg, CONST_STR_LEN("database"), s, strlen(s));
251 if (!(flags & PG_ENCDEFS)) {
252 char *s = getenv("LANG");
253 if (s && (s = strchr(s, '.')) && *++s)
254 msg = _add_str_param(msg, CONST_STR_LEN("client_encoding"), s, strlen(s));
256 _seti8(&msg, 0);
257 return msg;
260 pgmsg_t *pgmsg_create_startup_params (const char *conn_info) {
261 pgmsg_t *msg = pgmsg_create('\0');
262 _seti16(&msg, PG_MAJOR_VER);
263 _seti16(&msg, PG_MINOR_VER);
264 return _startup_params(msg, conn_info);
267 pgmsg_t *pgmsg_create_simple_query (const char *sql, size_t sql_len) {
268 pgmsg_t *msg = pgmsg_create(PG_SIMPLEQUERY);
269 _setstr(&msg, sql, sql_len);
270 _seti8(&msg, 0);
271 return msg;
274 pgmsg_t *pgmsg_create_parse (const char *name, size_t name_len, const char *sql, size_t sql_len, int fld_len, pgfld_t **flds) {
275 pgmsg_t *msg = pgmsg_create(PG_PARSE);
276 if (name && 0 == name_len)
277 name_len = strlen(name);
278 if (0 == sql_len)
279 sql_len = strlen(sql);
280 _setstr(&msg, name, name_len);
281 _seti8(&msg, 0);
282 _setstr(&msg, sql, sql_len);
283 _seti8(&msg, 0);
284 _seti16(&msg, fld_len);
285 for (int i = 0; i < fld_len; ++i)
286 _seti32(&msg, flds[i]->oid);
287 return msg;
290 pgmsg_t *pgmsg_create_bind (const char *portal, size_t portal_len, const char *stmt, size_t stmt_len,
291 int fld_len, pgfld_t **flds, int res_fmt_len, int *res_fmt) {
292 pgmsg_t *msg = pgmsg_create(PG_BIND);
293 _setstr(&msg, portal, portal_len);
294 _seti8(&msg, 0);
295 _setstr(&msg, stmt, stmt_len);
296 _seti8(&msg, 0);
297 _seti16(&msg, fld_len);
298 for (int i = 0; i < fld_len; ++i)
299 _seti16(&msg, flds[i]->fmt);
300 _seti16(&msg, fld_len);
301 for (int i = 0; i < fld_len; ++i) {
302 if (flds[i]->is_null)
303 _seti32(&msg, -1);
304 else {
305 _seti32(&msg, flds[i]->len);
306 if (0 == flds[i]->fmt || OID_UUID == flds[i]->oid)
307 _setstr(&msg, flds[i]->data.s, flds[i]->len);
308 else
309 _setstr(&msg, (const char*)&flds[i]->data, flds[i]->len);
312 _seti16(&msg, res_fmt_len);
313 for (int i = 0; i < res_fmt_len; ++i)
314 _seti16(&msg, res_fmt[i]);
315 return msg;
318 pgmsg_t *pgmsg_create_describe (uint8_t op, const char *name, size_t name_len) {
319 pgmsg_t *msg = pgmsg_create(PG_DESCRIBE);
320 _seti8(&msg, op);
321 _setstr(&msg, name, name_len);
322 _seti8(&msg, 0);
323 return msg;
326 pgmsg_t *pgmsg_create_execute (const char *portal, size_t portal_len, int32_t max_rows) {
327 pgmsg_t *msg = pgmsg_create(PG_EXECUTE);
328 _setstr(&msg, portal, portal_len);
329 _seti8(&msg, 0);
330 _seti32(&msg, max_rows);
331 return msg;
334 pgmsg_t *pgmsg_create_close(char what, const char *str, size_t slen) {
335 pgmsg_t *msg = pgmsg_create(PG_CLOSE);
336 _seti8(&msg, (uint8_t)what);
337 _setstr(&msg, str, slen);
338 _seti8(&msg, 0);
339 return msg;
342 static void _pg_md5_hash (const void *buff, size_t len, char *out) {
343 uint8_t digest [16];
344 md5_t ctx;
345 md5_init(&ctx);
346 while (len > 0) {
347 if (len > 512)
348 md5_update(&ctx, (uint8_t*)buff, 512);
349 else
350 md5_update(&ctx, (uint8_t*)buff, len);
351 buff += 512;
352 len -= 512;
354 md5_final(&ctx);
355 memcpy(digest, ctx.digest, sizeof digest);
356 for (int i = 0; i < 16; ++i)
357 snprintf(&(out[i*2]), 16*2, "%02x", (uint8_t)digest[i]);
/* Write "md5" + hex(MD5(passwd || salt[0..salt_len))) to buf.
   buf must hold at least 36 bytes: 3 prefix + 32 hex digits + NUL.
   On allocation failure the first 4 bytes of buf are zeroed. buf is then an
   empty string and so is buf + 3, so callers can detect the failure via
   buf[0] == '\0' and never read uninitialised memory. */
static void _pg_md5_encrypt(const char *passwd, const char *salt, size_t salt_len, char *buf) {
    size_t passwd_len = strlen(passwd);
    char *crypt_buf = malloc(passwd_len + salt_len + 1);
    if (!crypt_buf) {
        memset(buf, 0, 4);
        return;
    }
    memcpy(crypt_buf, passwd, passwd_len);
    memcpy(crypt_buf + passwd_len, salt, salt_len);
    strcpy(buf, "md5");
    _pg_md5_hash(crypt_buf, passwd_len + salt_len, buf + 3);
    free(crypt_buf);
}
370 pgmsg_t *pgmsg_create_pass (int req, const char *salt, size_t salt_len, const char *user, const char *pass) {
371 pgmsg_t *msg = pgmsg_create(PG_PASS);
372 char *pwd = malloc(2 * (PG_MD5PASS_LEN + 1)),
373 *pwd2 = pwd + PG_MD5PASS_LEN + 1,
374 *pwd_to_send;
375 _pg_md5_encrypt(pass, user, strlen(user), pwd2);
376 _pg_md5_encrypt(pwd2 + sizeof("md5")-1, salt, 4, pwd);
377 pwd_to_send = pwd;
378 switch (req) {
379 case PG_REQMD5:
380 pwd_to_send = pwd;
381 break;
382 case PG_REQPASS:
383 pwd_to_send = (char*)pass;
385 _setstr(&msg, pwd_to_send, strlen(pwd_to_send));
386 _seti8(&msg, 0);
387 free(pwd);
388 return msg;
391 pgmsg_t *pgmsg_create_sasl_init (conninfo_t *cinfo) {
392 char raw_nonce [SCRAM_RAW_NONCE_LEN+1];
393 strand(raw_nonce, SCRAM_RAW_NONCE_LEN, RAND_ALNUM);
394 raw_nonce[SCRAM_RAW_NONCE_LEN] = '\0';
395 cinfo->nonce = cstr_b64encode(raw_nonce, SCRAM_RAW_NONCE_LEN);
396 str_t *str = strfmt("n,,n=,r=%s", cinfo->nonce->ptr);
397 cinfo->fmsg_bare = mkcstr(str->ptr+3, str->len-3);
398 pgmsg_t *msg = pgmsg_create(PG_PASS);
399 _setstr(&msg, CONST_STR_LEN("SCRAM-SHA-256"));
400 _seti8(&msg, 0);
401 _seti32(&msg, str->len);
402 _setstr(&msg, str->ptr, str->len);
403 free(str);
404 return msg;
407 static int _parse_scram_final (pgmsg_resp_t *resp, conninfo_t *cinfo) {
408 strptr_t entry = { .ptr = NULL, .len = 0 };
409 char *data = (char*)resp->msg_auth.kind.sasl_auth.data;
410 int len = resp->msg_auth.kind.sasl_auth.len;
411 if ('e' == *resp->msg_auth.kind.sasl_auth.data)
412 return -1;
413 if (!(cinfo->srv_scram_msg = strsplit(data, len, ',')))
414 return -1;
415 cinfo->fmsg_srv = mkcstr(data, len);
416 while (0 == strnext(cinfo->srv_scram_msg, &entry)) {
417 if (entry.len < 3 && '=' != entry.ptr[1])
418 continue;
419 switch (*entry.ptr) {
420 case 'r':
421 cinfo->r_attr = entry.ptr+2;
422 break;
423 case 's':
424 cinfo->s_attr = entry.ptr+2;
425 break;
426 case 'i':
427 cinfo->i_attr = entry.ptr+2;
428 break;
431 return cinfo->r_attr && cinfo->s_attr && cinfo->i_attr ? 0 : -1;
434 static void _scram_create_key (uint8_t *salted_password, uint8_t *result) {
435 hmac_t ctx;
436 hmac_init(&ctx, salted_password, SCRAM_KEY_LEN);
437 hmac_update(&ctx, (uint8_t*)"Client Key", sizeof("Client Key")-1);
438 hmac_final(&ctx, result, SCRAM_KEY_LEN);
/* Optional application hook; when set, _pg_auth delegates prompting to it. */
char *(*on_pgauth) (const char *prompt, int is_echo);
static struct termios term_settings;

/* Prompt on stdout and read one line from stdin, with terminal echo disabled
   when is_echo is 0.
   Returns a malloc'd string without the trailing newline / CR. If reading
   fails, returns strdup(def_auth), or NULL when def_auth is NULL (e.g. $USER
   unset). The getline buffer is freed on that path.
   Echo is only toggled if the terminal attributes could be read, and the
   prompt is flushed before blocking on input. */
static char *_pg_auth (const char *prompt, const char *def_auth, int is_echo) {
    char *s = NULL;
    size_t cap = 0;
    ssize_t ilen;
    int echo_off = 0;
    if (on_pgauth)
        return on_pgauth(prompt, is_echo);
    printf("%s", prompt);
    fflush(stdout);
    if (!is_echo && 0 == tcgetattr(0, &term_settings)) {
        struct termios tc = term_settings;
        tc.c_lflag &= ~ECHO;
        echo_off = (0 == tcsetattr(0, TCSANOW, &tc));
    }
    ilen = getline(&s, &cap, stdin);
    if (echo_off)
        tcsetattr(0, TCSANOW, &term_settings);
    if (-1 == ilen) {
        free(s);
        return def_auth ? strdup(def_auth) : NULL;
    }
    if (!is_echo)
        printf("\n");
    /* strip only a real line terminator; a final line without '\n' keeps all its chars */
    while (ilen > 0 && ('\n' == s[ilen-1] || '\r' == s[ilen-1]))
        s[--ilen] = '\0';
    return s;
}
470 static void _scram_salted_password (conninfo_t *cinfo, cstr_t *salt) {
471 hmac_t ctx;
472 if (!cinfo->user)
473 cinfo->user = _pg_auth("Enter username: ", getenv("USER"), 1);
474 if (!cinfo->pass)
475 cinfo->pass = _pg_auth("Enter password: ", "", 0);
476 int password_len = strlen(cinfo->pass);
477 uint32_t one = htobe32(1);
478 uint8_t ui_prev [SCRAM_KEY_LEN],
479 ui [SCRAM_KEY_LEN];
480 int iterations = strtol(cinfo->i_attr, NULL, 0);
481 hmac_init(&ctx, (uint8_t*)cinfo->pass, password_len);
482 hmac_update(&ctx, (uint8_t*)salt->ptr, strlen(salt->ptr));
483 hmac_update(&ctx, (uint8_t*)&one, sizeof(uint32_t));
484 hmac_final(&ctx, ui_prev, sizeof(ui_prev));
485 memcpy(cinfo->salted_password, ui_prev, SCRAM_KEY_LEN);
486 for (int i = 2; i <= iterations; ++i) {
487 hmac_init(&ctx, (uint8_t*)cinfo->pass, password_len);
488 hmac_update(&ctx, ui_prev, SCRAM_KEY_LEN);
489 hmac_final(&ctx, ui, SCRAM_KEY_LEN);
490 for (int j = 0; j < SCRAM_KEY_LEN; ++j)
491 cinfo->salted_password[j] ^= ui[j];
492 memcpy(ui_prev, ui, SCRAM_KEY_LEN);
496 static void _scram_h (uint8_t *in, int len, uint8_t *result) {
497 sha_t ctx;
498 memset(&ctx, 0, sizeof ctx);
499 sha_init(&ctx);
500 sha_update(&ctx, in, len);
501 sha_final(&ctx, result);
504 static void _calc_scram_proof (conninfo_t *cinfo, uint8_t *result) {
505 hmac_t ctx;
506 uint8_t client_key [SCRAM_KEY_LEN],
507 stored_key [SCRAM_KEY_LEN],
508 clsign_key [SCRAM_KEY_LEN];
509 cstr_t *salt = cstr_b64decode(cinfo->s_attr, strlen(cinfo->s_attr));
510 _scram_salted_password(cinfo, salt);
511 free(salt);
512 _scram_create_key((uint8_t*)cinfo->salted_password, client_key);
513 _scram_h(client_key, SCRAM_KEY_LEN, stored_key);
514 hmac_init(&ctx, stored_key, SCRAM_KEY_LEN);
515 hmac_update(&ctx, (uint8_t*)cinfo->fmsg_bare->ptr, cinfo->fmsg_bare->len);
516 hmac_update(&ctx, (const uint8_t*)",", 1);
517 hmac_update(&ctx, (uint8_t*)cinfo->fmsg_srv->ptr, cinfo->fmsg_srv->len);
518 hmac_update(&ctx, (const uint8_t*)",", 1);
519 hmac_update(&ctx, (const uint8_t*)cinfo->fmsg_wproof->ptr, cinfo->fmsg_wproof->len);
520 hmac_final(&ctx, clsign_key, SCRAM_KEY_LEN);
521 for (int i = 0; i < SCRAM_KEY_LEN; ++i)
522 result[i] = client_key[i] ^ clsign_key[i];
525 pgmsg_t *pgmsg_create_sasl_fin (pgmsg_resp_t *resp, conninfo_t *cinfo) {
526 uint8_t cln_proof_key [SCRAM_KEY_LEN];
527 _parse_scram_final(resp, cinfo);
528 str_t *str = strfmt("c=biws,r=%s", cinfo->r_attr);
529 cinfo->fmsg_wproof = mkcstr(str->ptr, str->len);
530 strnadd(&str, CONST_STR_LEN(",p="));
531 _calc_scram_proof(cinfo, cln_proof_key);
532 cstr_t *cln_proof = cstr_b64encode((char*)cln_proof_key, SCRAM_KEY_LEN);
533 strnadd(&str, cln_proof->ptr, cln_proof->len);
534 pgmsg_t *msg = pgmsg_create(PG_PASS);
535 _setstr(&msg, str->ptr, str->len);
536 free(cln_proof);
537 free(str);
538 return msg;
541 int pgmsg_send (int fd, pgmsg_t *msg) {
542 void *buf;
543 size_t size;
544 ssize_t sent = 0, wrote = 0;
545 if ('\0' == msg->body.type) {
546 buf = &msg->body.len;
547 size = be32toh(msg->body.len);
548 } else {
549 buf = &msg->body;
550 size = be32toh(msg->body.len) + sizeof(char);
552 while (sent < size) {
554 wrote = send(fd, buf, size - sent, 0);
555 while (wrote < 0 && EINTR == errno);
556 sent += wrote;
557 buf += wrote;
559 return sent == size ? 0 : -1;
562 int pgmsg_recv (int fd, pgmsg_t **msg) {
563 ssize_t readed = 0, total = 0;
564 pgmsg_t *m;
565 size_t bufsize;
566 struct {
567 char type;
568 int32_t len;
569 } __attribute__ ((packed)) header = { 0, 0 };
570 while (total < HEADER_SIZE && (readed = recv(fd, (void*)(&header) + total, HEADER_SIZE - total, 0)) > 0)
571 total += readed;
572 if (-1 == readed)
573 return -1;
574 header.len = be32toh(header.len);
575 bufsize = sizeof(pgmsg_t) + header.len;
576 if (!(m = calloc(bufsize, sizeof(int8_t))))
577 return -1;
578 m->body.type = header.type;
579 m->body.len = header.len;
580 total = (header.len - sizeof(int32_t));
581 char *ptr = m->body.ptr;
582 while (total > 0) {
583 if (-1 == (readed = recv(fd, ptr, total, 0))) {
584 free(m);
585 return -1;
587 total -= readed;
588 ptr += readed;
590 *msg = m;
591 return 0;
594 static int _parse_param_status (pgmsg_body_t *body, pgmsg_param_status_t *pmsg) {
595 char *p = body->ptr, *e = p + body->len, *q = p;
596 while (*q && q < e) ++q;
597 if (q == e)
598 return -1;
599 pmsg->name = p;
600 pmsg->value = q + 1;
601 return 0;
604 static int _parse_error (pgmsg_body_t *body, pgmsg_error_t *pmsg) {
605 char *p = body->ptr, *e = p + body->len;
606 while (p < e) {
607 switch (*p) {
608 case PG_SEVERITY:
609 pmsg->severity = ++p;
610 break;
611 case PG_FATAL:
612 pmsg->text = ++p;
613 break;
614 case PG_SQLSTATE:
615 pmsg->code = ++p;
616 break;
617 case PG_MESSAGE:
618 pmsg->message = ++p;
619 break;
620 case PG_POSITION:
621 pmsg->position = ++p;
622 break;
623 case PG_FILE:
624 pmsg->file = ++p;
625 break;
626 case PG_LINE:
627 pmsg->line = ++p;
628 break;
629 case PG_ROUTINE:
630 pmsg->routine = ++p;
631 break;
632 case PG_DETAIL:
633 pmsg->detail = ++p;
634 break;
635 case '\0':
636 return 0;
637 default:
638 break;
640 while (p < e && *p) ++p;
641 if (p == e)
642 return 0;
643 if (++p == e)
644 return 0;
646 return 0;
649 static int _parse_rowdesc (pgmsg_t *msg, pgmsg_rowdesc_t *pmsg) {
650 pmsg->nflds = _geti16(msg);
651 for (int i = 0; i < pmsg->nflds; ++i) {
652 const char *fname = _getstr(msg);
653 if (!fname)
654 return -1;
655 pmsg->fields[i].fname = fname;
656 pmsg->fields[i].oid_table = _geti32(msg);
657 pmsg->fields[i].idx_field = _geti16(msg);
658 pmsg->fields[i].oid_field = _geti32(msg);
659 pmsg->fields[i].field_len = _geti16(msg);
660 pmsg->fields[i].type_mod = _geti32(msg);
661 pmsg->fields[i].field_fmt = _geti16(msg);
663 return 0;
666 static int _parse_datarow (pgmsg_t *msg, pgmsg_datarow_t *pmsg) {
667 pmsg->nflds = _geti16(msg);
668 for (int i = 0; i < pmsg->nflds; ++i) {
669 int32_t len = pmsg->fields[i].len = _geti32(msg);
670 pmsg->fields[i].data = NULL;
671 if (len >= 0) {
672 pmsg->fields[i].data = (uint8_t*)msg->pc;
673 msg->pc += len;
676 return 0;
679 static void _parse_copyin (pgmsg_t *msg, pgmsg_copyin_t *pmsg) {
680 pmsg->fmt = _geti8(msg);
681 pmsg->cols = _geti16(msg);
682 pmsg->fmtcol = _geti16(msg);
685 pgmsg_resp_t *pgmsg_parse (pgmsg_t *msg) {
686 pgmsg_resp_t *resp = NULL;
687 msg->pc = msg->body.ptr;
688 switch (msg->body.type) {
689 case PG_AUTHOK:
690 switch (be32toh(*((int32_t*)msg->body.ptr))) {
691 case PG_OK:
692 resp = malloc(sizeof(pgmsg_auth_t));
693 resp->type = msg->body.type;
694 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
695 break;
696 case PG_REQMD5:
697 resp = malloc(sizeof(pgmsg_auth_t));
698 resp->type = msg->body.type;
699 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
700 memcpy(resp->msg_auth.kind.md5_auth, msg->body.ptr + sizeof(int32_t), sizeof(uint8_t)*4);
701 break;
702 case PG_REQSASL:
703 resp = malloc(sizeof(pgmsg_auth_t));
704 resp->type = msg->body.type;
705 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
706 break;
707 case PG_SASLCON:
708 case PG_SASLCOMP:
709 resp = malloc(sizeof(pgmsg_auth_t) + sizeof(sasl_t) + msg->body.len - sizeof(int32_t) * 2);
710 resp->type = msg->body.type;
711 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
712 resp->msg_auth.kind.sasl_auth.len = msg->body.len - sizeof(int32_t) * 2;
713 memcpy(resp->msg_auth.kind.sasl_auth.data, msg->body.ptr + sizeof(int32_t), resp->msg_auth.kind.sasl_auth.len);
714 break;
716 break;
717 case PG_PARAMSTATUS:
718 resp = malloc(sizeof(pgmsg_param_status_t));
719 resp->type = msg->body.type;
720 if (-1 == _parse_param_status(&msg->body, &resp->msg_param_status)) {
721 free(resp);
722 resp = NULL;
724 break;
725 case PG_BACKENDKEYDATA:
726 resp = malloc(sizeof(pgmsg_backend_keydata_t));
727 resp->type = msg->body.type;
728 resp->msg_backend_keydata.pid = be32toh(*((int32_t*)msg->body.ptr));
729 resp->msg_backend_keydata.sk = be32toh(*((int32_t*)msg->body.ptr + sizeof(int32_t)));
730 break;
731 case PG_READY:
732 resp = malloc(sizeof(pgmsg_ready_t));
733 resp->type = msg->body.type;
734 resp->msg_ready.tr = msg->body.ptr[0];
735 break;
736 case PG_TERM:
737 resp =malloc(sizeof(pgmsg_auth_t));
738 resp->type = msg->body.type;
739 break;
740 case PG_ERROR:
741 resp = calloc(1, sizeof(pgmsg_error_t));
742 resp->type = msg->body.type;
743 if (-1 == _parse_error(&msg->body, &resp->msg_error)) {
744 free(resp);
745 resp = NULL;
747 break;
748 case PG_ROWDESC:
749 resp = malloc(sizeof(pgmsg_rowdesc_t) + sizeof(pgmsg_field_t) * be16toh(*((int16_t*)msg->body.ptr)));
750 resp->type = msg->body.type;
751 if (-1 == _parse_rowdesc(msg, &resp->msg_rowdesc)) {
752 free(resp);
753 resp = NULL;
755 break;
756 case PG_DATAROW:
757 resp = malloc(sizeof(pgmsg_datarow_t) + sizeof(pgmsg_data_t) * be16toh(*((int16_t*)msg->body.ptr)));
758 resp->type = msg->body.type;
759 if (-1 == _parse_datarow(msg, &resp->msg_datarow)) {
760 free(resp);
761 resp = NULL;
763 break;
764 case PG_CMDCOMPLETE:
765 resp = malloc(sizeof(pgmsg_cmd_complete_t));
766 resp->type = msg->body.type;
767 resp->msg_complete.tag = msg->body.ptr;
768 break;
769 case PG_COPYIN:
770 resp = malloc(sizeof(pgmsg_copyin_t));
771 resp->type = msg->body.type;
772 _parse_copyin(msg, &resp->msg_copyin);
773 break;
775 return resp;
778 static str_t *_prepare_str (const char *s, size_t l) {
779 str_t *str = stralloc(l, 0);
780 for (size_t i = 0; i < l; ++i) {
781 switch (s[i]) {
782 case '\\': strnadd(&str, CONST_STR_LEN("\\\\")); break;
783 case '\b': strnadd(&str, CONST_STR_LEN("\\b")); break;
784 case '\f': strnadd(&str, CONST_STR_LEN("\\f")); break;
785 case '\n': strnadd(&str, CONST_STR_LEN("\\n")); break;
786 case '\r': strnadd(&str, CONST_STR_LEN("\\r")); break;
787 case '\t': strnadd(&str, CONST_STR_LEN("\\t")); break;
788 default: strnadd(&str, &s[i], sizeof(char)); break;
791 return str;
794 static int _copyin_field (pgmsg_t **msg, pgfld_t *fld) {
795 struct tm tm;
796 char buf [32];
797 int len;
798 if (fld->is_null)
799 _setstr(msg, CONST_STR_LEN("\\N"));
800 else if (PG_TEXT == fld->fmt) {
801 str_t *str = _prepare_str(fld->data.s, fld->len);
802 _setstr(msg, str->ptr, str->len);
803 free(str);
804 } else
805 switch (fld->oid) {
806 case OID_INT2:
807 len = snprintf(buf, sizeof(buf),"%hd", be16toh(fld->data.i2));
808 _setstr(msg, buf, len);
809 break;
810 case OID_INT4:
811 len = snprintf(buf, sizeof(buf), "%d", be32toh(fld->data.i4));
812 _setstr(msg, buf, len);
813 break;
814 case OID_INT8:
815 len = snprintf(buf, sizeof(buf), LONG_FMT, be64toh(fld->data.i8));
816 _setstr(msg, buf, len);
817 break;
818 case OID_FLOAT4:
819 len = snprintf(buf, sizeof(buf), "%f", pg_conv_float(fld->data.f4));
820 _setstr(msg, buf, len);
821 break;
822 case OID_FLOAT8:
823 len = snprintf(buf, sizeof(buf), "%f", pg_conv_double(fld->data.f8));
824 _setstr(msg, buf, len);
825 break;
826 case OID_VARCHAR:
827 case OID_CHAR:
828 case OID_TEXT: {
829 str_t *str = _prepare_str(fld->data.s, fld->len);
830 _setstr(msg, str->ptr, str->len);
831 free(str);
833 break;
834 case OID_BOOL:
835 if (fld->data.b)
836 _setstr(msg, CONST_STR_LEN("t"));
837 else
838 _setstr(msg, CONST_STR_LEN("f"));
839 break;
840 case OID_DATE:
841 tm_dec(be64toh(fld->data.tm), &tm);
842 len = snprintf(buf, sizeof(buf), "%d-%02d-%02d", tm.tm_year, tm.tm_mon, tm.tm_mday);
843 _setstr(msg, buf, len);
844 break;
845 case OID_TIMESTAMP:
846 tm_dec(be64toh(fld->data.tm), &tm);
847 len = snprintf(buf, sizeof(buf), "%d-%02d-%02d %02d:%02d:%02d", tm.tm_year, tm.tm_mon, tm.tm_mday, tm.tm_hour, tm.tm_min, tm.tm_sec);
848 _setstr(msg, buf, len);
849 break;
850 // OID_UUID
851 case OID_BYTEA:
852 _setstr(msg, CONST_STR_LEN("\\\\x"));
853 if (!fld->len)
854 _setstr(msg, CONST_STR_LEN("0"));
855 else for (int i = 0; i < fld->len; ++i) {
856 len = snprintf(buf, sizeof(buf), "%02x", fld->data.s[i]);
857 _setstr(msg, buf, len);
859 break;
860 case OID_BIT:
861 len = snprintf(buf, sizeof(buf), "X'%08x'", fld->data.bit32.bit);
862 _setstr(msg, buf, len);
863 break;
864 //OID_MONEY
865 default:
866 return -1;
868 return 0;
871 pgmsg_t *pgmsg_copyin_flds (int len, pgfld_t **flds) {
872 pgmsg_t *msg = pgmsg_create(PG_COPYDATA);
873 if (-1 == _copyin_field(&msg, flds[0]))
874 goto err;
875 for (int i = 1; i < len; ++i) {
876 _setstr(&msg, CONST_STR_LEN("\t"));
877 if (-1 == _copyin_field(&msg, flds[i]))
878 goto err;
880 _seti8(&msg, 0x0a);
881 return msg;
882 err:
883 free(msg);
884 return NULL;