copyin
[libpgclient.git] / src / pgprov3.c
blob4d5bc97e4dc417bfbafde9da8e6771550d0c99ac
1 /*Copyright (c) Brian B.
3 This library is free software; you can redistribute it and/or
4 modify it under the terms of the GNU Lesser General Public
5 License as published by the Free Software Foundation; either
6 version 3 of the License, or (at your option) any later version.
7 See the file LICENSE included with this distribution for more
8 information.
9 */
10 #include "md5.h"
11 #include "hmac.h"
12 #include "libpgcli/pgprov3.h"
/* Growth granularity, in bytes, of outgoing message buffers (see _resize). */
#define CHUNKSIZE 32
/* Size of a backend message header on the wire: one type byte followed by a
 * 32-bit length. Fully parenthesized so the macro expands correctly inside
 * larger expressions (e.g. 2 * HEADER_SIZE). */
#define HEADER_SIZE (sizeof(char) + sizeof(int32_t))
17 static int _resize (pgmsg_t **msg, size_t len) {
18 pgmsg_t *res;
19 int nlen = (*msg)->len + len;
20 size_t pc_len,
21 bufsize = sizeof(pgmsg_t) + (nlen / CHUNKSIZE) * CHUNKSIZE + CHUNKSIZE;
22 if (bufsize == (*msg)->bufsize)
23 return 0;
24 pc_len = (uintptr_t)(*msg)->pc - (uintptr_t)(*msg)->body.ptr;
25 if (!(res = realloc(*msg, bufsize)))
26 return -1;
27 res->bufsize = bufsize;
28 res->pc = res->body.ptr + pc_len;
29 *msg = res;
30 return 0;
33 pgmsg_t *pgmsg_create (char type) {
34 size_t bufsize = sizeof(pgmsg_t) + CHUNKSIZE;
35 pgmsg_t *msg = malloc(bufsize);
36 if (!msg) return NULL;
37 msg->bufsize = bufsize;
38 msg->len = sizeof(pgmsg_t);
39 msg->pc = msg->body.ptr;
40 msg->body.type = type;
41 msg->body.len = htobe32(sizeof(int32_t));
42 return msg;
45 static int _seti8 (pgmsg_t **msg, int8_t x) {
46 if (-1 == _resize(msg, sizeof(int8_t)))
47 return -1;
48 *((int8_t*)(*msg)->pc) = x;
49 (*msg)->pc += sizeof(int8_t);
50 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int8_t));
51 (*msg)->len += sizeof(int8_t);
52 return 0;
55 static int _seti16 (pgmsg_t **msg, int16_t x) {
56 if (-1 == _resize(msg, sizeof(int16_t)))
57 return -1;
58 *((int16_t*)(*msg)->pc) = htobe16(x);
59 (*msg)->pc += sizeof(int16_t);
60 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int16_t));
61 (*msg)->len += sizeof(int16_t);
62 return 0;
65 static int _seti32 (pgmsg_t **msg, int32_t x) {
66 if (-1 == _resize(msg, sizeof(int32_t)))
67 return -1;
68 *((int32_t*)(*msg)->pc) = htobe32(x);
69 (*msg)->pc += sizeof(int32_t);
70 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int32_t));
71 (*msg)->len += sizeof(int32_t);
72 return 0;
75 static int _setstr (pgmsg_t **msg, const char *s, size_t slen) {
76 if (0 == slen)
77 return 0;
78 if (-1 == _resize(msg, slen))
79 return -1;
80 memcpy((*msg)->pc, s, slen);
81 (*msg)->pc += slen;
82 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + slen);
83 (*msg)->len += slen;
84 return 0;
87 static int8_t _geti8 (pgmsg_t *msg) {
88 int8_t r = *msg->pc;
89 msg->pc += sizeof(int8_t);
90 return r;
93 static int16_t _geti16 (pgmsg_t *msg) {
94 int16_t r = be16toh(*((int16_t*)msg->pc));
95 msg->pc += sizeof(int16_t);
96 return r;
99 static int32_t _geti32 (pgmsg_t *msg) {
100 int32_t r = be32toh(*((int32_t*)msg->pc));
101 msg->pc += sizeof(int32_t);
102 return r;
105 static const char *_getstr (pgmsg_t *msg) {
106 char *p = msg->pc, *e = msg->body.ptr + be32toh(msg->body.len);
107 while (*(msg->pc) && msg->pc < e) ++msg->pc;
108 if (e == msg->pc)
109 return NULL;
110 ++msg->pc;
111 return p;
114 int pgmsg_set_param (pgmsg_t **msg, const char *name, size_t name_len, const char *value, size_t value_len) {
115 if (-1 == _setstr(msg, name, name_len) ||
116 -1 == _seti8(msg, 0) ||
117 -1 == _setstr(msg, value, value_len) ||
118 -1 == _seti8(msg, 0))
119 return -1;
120 return 0;
123 pgmsg_t *pgmsg_create_startup (const char *user, size_t user_len, const char *database, size_t database_len) {
124 pgmsg_t *msg = pgmsg_create('\0');
125 _seti16(&msg, PG_MAJOR_VER);
126 _seti16(&msg, PG_MINOR_VER);
127 pgmsg_set_param(&msg, CONST_STR_LEN("user"), user, user_len);
128 pgmsg_set_param(&msg, CONST_STR_LEN("database"), database, database_len);
129 return msg;
/* Bits recorded by _add_param for startup parameters that the connection
 * string supplied explicitly, so _startup_params knows which defaults it
 * must not add from the environment. */
enum {
    PG_USERDEFS = 1 << 0,   /* "user" given      */
    PG_DBDEFS   = 1 << 1,   /* "dbname" given    */
    PG_ENCDEFS  = 1 << 2    /* encoding given    */
};
136 static inline pgmsg_t *_add_str_param (pgmsg_t *msg, const char *k, size_t lk, const char *v, size_t lv) {
137 _setstr(&msg, k, lk);
138 _seti8(&msg, 0);
139 _setstr(&msg, v, lv);
140 _seti8(&msg, 0);
141 return msg;
144 static pgmsg_t *_add_param (pgmsg_t *msg, const char *begin, const char *end, uint32_t *flags) {
145 const char *p = begin, *p1;
146 while (p < end && '=' != *p) ++p;
147 if (0 == strncmp(begin, "password", (uintptr_t)p - (uintptr_t)begin))
148 return msg;
149 if (p == end)
150 return msg;
151 p1 = p;
152 while (p1 > begin && isspace(*(p1 - 1))) --p1;
153 if (p == p1 - 1)
154 return msg;
155 ++p;
156 while (p < end && isspace(*p)) ++p;
157 if (p == end)
158 return msg;
159 if (0 == strncmp(begin, "dbname", (uintptr_t)p1 - (uintptr_t)begin)) {
160 msg = _add_str_param(msg, CONST_STR_LEN("database"), p, (uintptr_t)end - (uintptr_t)p);
161 if (flags)
162 *flags |= PG_DBDEFS;
163 } else
164 if (0 == strncmp(begin, "user", (uintptr_t)p1 - (uintptr_t)begin)) {
165 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
166 if (flags)
167 *flags |= PG_USERDEFS;
168 } else
169 if (0 != strncmp(begin, "host", (uintptr_t)p1 - (uintptr_t)begin) &&
170 0 != strncmp(begin, "port", (uintptr_t)p1 - (uintptr_t)begin))
171 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
172 return msg;
175 void *parse_conninfo (void *data, const char *conn_info, parse_param_h fn, uint32_t *flags) {
176 if (conn_info) {
177 const char *p = conn_info;
178 while (*p) {
179 const char *q;
180 while (isspace(*p)) ++p;
181 if (!(*p)) break;
182 q = p;
183 while (*p && !isspace(*p)) ++p;
184 data = fn(data, q, p, flags);
187 return data;
190 static pgmsg_t *_startup_params (pgmsg_t *msg, const char *conn_info) {
191 uint32_t flags = 0;
192 msg = (pgmsg_t*)parse_conninfo((void*)msg, conn_info, (parse_param_h)_add_param, &flags);
193 if (!(flags & PG_USERDEFS)) {
194 char *s = getenv("USER");
195 if (s)
196 msg = _add_str_param(msg, CONST_STR_LEN("user"), s, strlen(s));
198 if (!(flags & PG_DBDEFS)) {
199 char *s = getenv("USER");
200 if (s)
201 msg = _add_str_param(msg, CONST_STR_LEN("database"), s, strlen(s));
203 if (!(flags & PG_ENCDEFS)) {
204 char *s = getenv("LANG");
205 if (s && (s = strchr(s, '.')) && *++s)
206 msg = _add_str_param(msg, CONST_STR_LEN("client_encoding"), s, strlen(s));
208 _seti8(&msg, 0);
209 return msg;
212 pgmsg_t *pgmsg_create_startup_params (const char *conn_info) {
213 pgmsg_t *msg = pgmsg_create('\0');
214 _seti16(&msg, PG_MAJOR_VER);
215 _seti16(&msg, PG_MINOR_VER);
216 return _startup_params(msg, conn_info);
219 pgmsg_t *pgmsg_create_simple_query (const char *sql, size_t sql_len) {
220 pgmsg_t *msg = pgmsg_create(PG_SIMPLEQUERY);
221 _setstr(&msg, sql, sql_len);
222 _seti8(&msg, 0);
223 return msg;
226 pgmsg_t *pgmsg_create_parse (const char *name, size_t name_len, const char *sql, size_t sql_len, int fld_len, pgfld_t **flds) {
227 pgmsg_t *msg = pgmsg_create(PG_PARSE);
228 if (name && 0 == name_len)
229 name_len = strlen(name);
230 if (0 == sql_len)
231 sql_len = strlen(sql);
232 _setstr(&msg, name, name_len);
233 _seti8(&msg, 0);
234 _setstr(&msg, sql, sql_len);
235 _seti8(&msg, 0);
236 _seti16(&msg, fld_len);
237 for (int i = 0; i < fld_len; ++i)
238 _seti32(&msg, flds[i]->oid);
239 return msg;
242 pgmsg_t *pgmsg_create_bind (const char *portal, size_t portal_len, const char *stmt, size_t stmt_len,
243 int fld_len, pgfld_t **flds, int res_fmt_len, int *res_fmt) {
244 pgmsg_t *msg = pgmsg_create(PG_BIND);
245 _setstr(&msg, portal, portal_len);
246 _seti8(&msg, 0);
247 _setstr(&msg, stmt, stmt_len);
248 _seti8(&msg, 0);
249 _seti16(&msg, fld_len);
250 for (int i = 0; i < fld_len; ++i)
251 _seti16(&msg, flds[i]->fmt);
252 _seti16(&msg, fld_len);
253 for (int i = 0; i < fld_len; ++i) {
254 if (flds[i]->is_null)
255 _seti32(&msg, -1);
256 else {
257 _seti32(&msg, flds[i]->len);
258 if (0 == flds[i]->fmt || OID_UUID == flds[i]->oid)
259 _setstr(&msg, flds[i]->data.s, flds[i]->len);
260 else
261 _setstr(&msg, (const char*)&flds[i]->data, flds[i]->len);
264 _seti16(&msg, res_fmt_len);
265 for (int i = 0; i < res_fmt_len; ++i)
266 _seti16(&msg, res_fmt[i]);
267 return msg;
270 pgmsg_t *pgmsg_create_describe (uint8_t op, const char *name, size_t name_len) {
271 pgmsg_t *msg = pgmsg_create(PG_DESCRIBE);
272 _seti8(&msg, op);
273 _setstr(&msg, name, name_len);
274 _seti8(&msg, 0);
275 return msg;
278 pgmsg_t *pgmsg_create_execute (const char *portal, size_t portal_len, int32_t max_rows) {
279 pgmsg_t *msg = pgmsg_create(PG_EXECUTE);
280 _setstr(&msg, portal, portal_len);
281 _seti8(&msg, 0);
282 _seti32(&msg, max_rows);
283 return msg;
286 pgmsg_t *pgmsg_create_close(char what, const char *str, size_t slen) {
287 pgmsg_t *msg = pgmsg_create(PG_CLOSE);
288 _seti8(&msg, (uint8_t)what);
289 _setstr(&msg, str, slen);
290 _seti8(&msg, 0);
291 return msg;
294 static void _pg_md5_hash (const void *buff, size_t len, char *out) {
295 uint8_t digest [16];
296 md5_t ctx;
297 md5_init(&ctx);
298 while (len > 0) {
299 if (len > 512)
300 md5_update(&ctx, (uint8_t*)buff, 512);
301 else
302 md5_update(&ctx, (uint8_t*)buff, len);
303 buff += 512;
304 len -= 512;
306 md5_final(&ctx);
307 memcpy(digest, ctx.digest, sizeof digest);
308 for (int i = 0; i < 16; ++i)
309 snprintf(&(out[i*2]), 16*2, "%02x", (uint8_t)digest[i]);
/* Write "md5" followed by the hex MD5 of (passwd || salt) into `buf`, which
 * needs 3 + 32 + 1 = 36 bytes.
 * If the temporary concatenation buffer cannot be allocated, `buf` is set
 * to the empty string so callers can detect the failure (the old code
 * passed NULL straight to memcpy). */
static void _pg_md5_encrypt(const char *passwd, const char *salt, size_t salt_len, char *buf) {
    size_t passwd_len = strlen(passwd);
    char *crypt_buf = malloc(passwd_len + salt_len + 1);
    if (!crypt_buf) {
        buf[0] = '\0';
        return;
    }
    memcpy(crypt_buf, passwd, passwd_len);
    memcpy(crypt_buf + passwd_len, salt, salt_len);
    strcpy(buf, "md5");
    _pg_md5_hash(crypt_buf, passwd_len + salt_len, buf + 3);
    free(crypt_buf);
}
322 pgmsg_t *pgmsg_create_pass (int req, const char *salt, size_t salt_len, const char *user, const char *pass) {
323 pgmsg_t *msg = pgmsg_create(PG_PASS);
324 char *pwd = malloc(2 * (PG_MD5PASS_LEN + 1)),
325 *pwd2 = pwd + PG_MD5PASS_LEN + 1,
326 *pwd_to_send;
327 _pg_md5_encrypt(pass, user, strlen(user), pwd2);
328 _pg_md5_encrypt(pwd2 + sizeof("md5")-1, salt, 4, pwd);
329 pwd_to_send = pwd;
330 switch (req) {
331 case PG_REQMD5:
332 pwd_to_send = pwd;
333 break;
334 case PG_REQPASS:
335 pwd_to_send = (char*)pass;
337 _setstr(&msg, pwd_to_send, strlen(pwd_to_send));
338 _seti8(&msg, 0);
339 return msg;
342 pgmsg_t *pgmsg_create_sasl_init (conninfo_t *cinfo) {
343 char raw_nonce [SCRAM_RAW_NONCE_LEN+1];
344 strand(raw_nonce, SCRAM_RAW_NONCE_LEN, RAND_ALNUM);
345 raw_nonce[SCRAM_RAW_NONCE_LEN] = '\0';
346 cinfo->nonce = cstr_b64encode(raw_nonce, SCRAM_RAW_NONCE_LEN);
347 str_t *str = strprintf("n,,n=,r=%s", cinfo->nonce->ptr);
348 cinfo->fmsg_bare = mkcstr(str->ptr+3, str->len-3);
349 pgmsg_t *msg = pgmsg_create(PG_PASS);
350 _setstr(&msg, CONST_STR_LEN("SCRAM-SHA-256"));
351 _seti8(&msg, 0);
352 _seti32(&msg, str->len);
353 _setstr(&msg, str->ptr, str->len);
354 free(str);
355 return msg;
358 static int _parse_scram_final (pgmsg_resp_t *resp, conninfo_t *cinfo) {
359 strptr_t entry = { .ptr = NULL, .len = 0 };
360 char *data = (char*)resp->msg_auth.kind.sasl_auth.data;
361 int len = resp->msg_auth.kind.sasl_auth.len;
362 if ('e' == *resp->msg_auth.kind.sasl_auth.data)
363 return -1;
364 if (!(cinfo->srv_scram_msg = strsplit(data, len, ',')))
365 return -1;
366 cinfo->fmsg_srv = mkcstr(data, len);
367 while (0 == strnext(cinfo->srv_scram_msg, &entry)) {
368 if (entry.len < 3 && '=' != entry.ptr[1])
369 continue;
370 switch (*entry.ptr) {
371 case 'r':
372 cinfo->r_attr = entry.ptr+2;
373 break;
374 case 's':
375 cinfo->s_attr = entry.ptr+2;
376 break;
377 case 'i':
378 cinfo->i_attr = entry.ptr+2;
379 break;
382 return cinfo->r_attr && cinfo->s_attr && cinfo->i_attr ? 0 : -1;
385 static void _scram_create_key (uint8_t *salted_password, uint8_t *result) {
386 hmac_t ctx;
387 hmac_init(&ctx, salted_password, SCRAM_KEY_LEN);
388 hmac_update(&ctx, (uint8_t*)"Client Key", sizeof("Client Key")-1);
389 hmac_final(&ctx, result, SCRAM_KEY_LEN);
/* Optional application hook for credential prompts; when set, it replaces
 * the terminal prompt in _pg_auth. */
char *(*on_pgauth) (const char *prompt, int is_echo);
/* Terminal settings saved while echo is disabled for a hidden prompt. */
static struct termios term_settings;

/* Prompt for a credential and return it as a malloc'ed string without the
 * trailing newline.
 *
 * When on_pgauth is set, its result is returned instead. If `is_echo` is 0,
 * terminal echo is turned off while reading, and only restored if it was
 * actually changed (stdin may not be a terminal). On read failure or EOF,
 * a copy of `def_auth` is returned, or NULL if `def_auth` is NULL. The
 * getline buffer is freed on that path. Only a real trailing '\n' is
 * stripped, so input ending at EOF keeps its last character. */
static char *_pg_auth (const char *prompt, const char *def_auth, int is_echo) {
    char *s = NULL;
    size_t cap = 0;
    ssize_t ilen;
    int restore = 0;
    if (on_pgauth)
        return on_pgauth(prompt, is_echo);
    printf("%s", prompt);
    fflush(stdout);   /* the prompt has no newline; make sure it is visible */
    if (!is_echo && 0 == tcgetattr(0, &term_settings)) {
        struct termios tc = term_settings;
        tc.c_lflag &= ~ECHO;
        restore = (0 == tcsetattr(0, TCSANOW, &tc));
    }
    ilen = getline(&s, &cap, stdin);
    if (restore)
        tcsetattr(0, TCSANOW, &term_settings);
    if (-1 == ilen) {
        free(s);
        return def_auth ? strdup(def_auth) : NULL;
    }
    if (!is_echo)
        printf("\n");
    if (ilen > 0 && '\n' == s[ilen-1])
        s[ilen-1] = '\0';
    return s;
}
421 static void _scram_salted_password (conninfo_t *cinfo, cstr_t *salt) {
422 hmac_t ctx;
423 if (!cinfo->user)
424 cinfo->user = _pg_auth("Enter username: ", getenv("USER"), 1);
425 if (!cinfo->pass)
426 cinfo->pass = _pg_auth("Enter password: ", "", 0);
427 int password_len = strlen(cinfo->pass);
428 uint32_t one = htobe32(1);
429 uint8_t ui_prev [SCRAM_KEY_LEN],
430 ui [SCRAM_KEY_LEN];
431 int iterations = strtol(cinfo->i_attr, NULL, 0);
432 hmac_init(&ctx, (uint8_t*)cinfo->pass, password_len);
433 hmac_update(&ctx, (uint8_t*)salt->ptr, strlen(salt->ptr));
434 hmac_update(&ctx, (uint8_t*)&one, sizeof(uint32_t));
435 hmac_final(&ctx, ui_prev, sizeof(ui_prev));
436 memcpy(cinfo->salted_password, ui_prev, SCRAM_KEY_LEN);
437 for (int i = 2; i <= iterations; ++i) {
438 hmac_init(&ctx, (uint8_t*)cinfo->pass, password_len);
439 hmac_update(&ctx, ui_prev, SCRAM_KEY_LEN);
440 hmac_final(&ctx, ui, SCRAM_KEY_LEN);
441 for (int j = 0; j < SCRAM_KEY_LEN; ++j)
442 cinfo->salted_password[j] ^= ui[j];
443 memcpy(ui_prev, ui, SCRAM_KEY_LEN);
447 static void _scram_h (uint8_t *in, int len, uint8_t *result) {
448 sha_t ctx;
449 memset(&ctx, 0, sizeof ctx);
450 sha_init(&ctx);
451 sha_update(&ctx, in, len);
452 sha_final(&ctx, result);
455 static void _calc_scram_proof (conninfo_t *cinfo, uint8_t *result) {
456 hmac_t ctx;
457 uint8_t client_key [SCRAM_KEY_LEN],
458 stored_key [SCRAM_KEY_LEN],
459 clsign_key [SCRAM_KEY_LEN];
460 cstr_t *salt = cstr_b64decode(cinfo->s_attr, strlen(cinfo->s_attr));
461 _scram_salted_password(cinfo, salt);
462 free(salt);
463 _scram_create_key((uint8_t*)cinfo->salted_password, client_key);
464 _scram_h(client_key, SCRAM_KEY_LEN, stored_key);
465 hmac_init(&ctx, stored_key, SCRAM_KEY_LEN);
466 hmac_update(&ctx, (uint8_t*)cinfo->fmsg_bare->ptr, cinfo->fmsg_bare->len);
467 hmac_update(&ctx, (const uint8_t*)",", 1);
468 hmac_update(&ctx, (uint8_t*)cinfo->fmsg_srv->ptr, cinfo->fmsg_srv->len);
469 hmac_update(&ctx, (const uint8_t*)",", 1);
470 hmac_update(&ctx, (const uint8_t*)cinfo->fmsg_wproof->ptr, cinfo->fmsg_wproof->len);
471 hmac_final(&ctx, clsign_key, SCRAM_KEY_LEN);
472 for (int i = 0; i < SCRAM_KEY_LEN; ++i)
473 result[i] = client_key[i] ^ clsign_key[i];
476 pgmsg_t *pgmsg_create_sasl_fin (pgmsg_resp_t *resp, conninfo_t *cinfo) {
477 uint8_t cln_proof_key [SCRAM_KEY_LEN];
478 _parse_scram_final(resp, cinfo);
479 str_t *str = strprintf("c=biws,r=%s", cinfo->r_attr);
480 cinfo->fmsg_wproof = mkcstr(str->ptr, str->len);
481 strnadd(&str, CONST_STR_LEN(",p="));
482 _calc_scram_proof(cinfo, cln_proof_key);
483 cstr_t *cln_proof = cstr_b64encode((char*)cln_proof_key, SCRAM_KEY_LEN);
484 strnadd(&str, cln_proof->ptr, cln_proof->len);
485 pgmsg_t *msg = pgmsg_create(PG_PASS);
486 _setstr(&msg, str->ptr, str->len);
487 free(cln_proof);
488 free(str);
489 return msg;
492 int pgmsg_send (int fd, pgmsg_t *msg) {
493 void *buf;
494 size_t size;
495 ssize_t sent = 0, wrote = 0;
496 if ('\0' == msg->body.type) {
497 buf = &msg->body.len;
498 size = be32toh(msg->body.len);
499 } else {
500 buf = &msg->body;
501 size = be32toh(msg->body.len) + sizeof(char);
503 while (sent < size) {
505 wrote = send(fd, buf, size - sent, 0);
506 while (wrote < 0 && EINTR == errno);
507 sent += wrote;
508 buf += wrote;
510 return sent == size ? 0 : -1;
513 int pgmsg_recv (int fd, pgmsg_t **msg) {
514 ssize_t readed = 0, total = 0;
515 pgmsg_t *m;
516 size_t bufsize;
517 struct {
518 char type;
519 int32_t len;
520 } __attribute__ ((packed)) header;
521 while (total < HEADER_SIZE && (readed = recv(fd, (void*)(&header + total), HEADER_SIZE - total, 0)) > 0)
522 total += readed;
523 if (-1 == readed)
524 return -1;
525 header.len = be32toh(header.len);
526 bufsize = sizeof(pgmsg_t) + header.len;
527 if (!(m = calloc(bufsize, sizeof(int8_t))))
528 return -1;
529 m->body.type = header.type;
530 m->body.len = header.len;
531 total = (header.len - sizeof(int32_t));
532 char *ptr = m->body.ptr;
533 while (total > 0) {
534 if (-1 == (readed = recv(fd, ptr, total, 0))) {
535 free(m);
536 return -1;
538 total -= readed;
539 ptr += readed;
541 *msg = m;
542 return 0;
545 static int _parse_param_status (pgmsg_body_t *body, pgmsg_param_status_t *pmsg) {
546 char *p = body->ptr, *e = p + body->len, *q = p;
547 while (*q && q < e) ++q;
548 if (q == e)
549 return -1;
550 pmsg->name = p;
551 pmsg->value = q + 1;
552 return 0;
555 static int _parse_error (pgmsg_body_t *body, pgmsg_error_t *pmsg) {
556 char *p = body->ptr, *e = p + body->len;
557 while (p < e) {
558 switch (*p) {
559 case PG_SEVERITY:
560 pmsg->severity = ++p;
561 break;
562 case PG_FATAL:
563 pmsg->text = ++p;
564 break;
565 case PG_SQLSTATE:
566 pmsg->code = ++p;
567 break;
568 case PG_MESSAGE:
569 pmsg->message = ++p;
570 break;
571 case PG_POSITION:
572 pmsg->position = ++p;
573 break;
574 case PG_FILE:
575 pmsg->file = ++p;
576 break;
577 case PG_LINE:
578 pmsg->line = ++p;
579 break;
580 case PG_ROUTINE:
581 pmsg->routine = ++p;
582 break;
583 case '\0':
584 return 0;
585 default:
586 break;
588 while (p < e && *p) ++p;
589 if (p == e)
590 return 0;
591 if (++p == e)
592 return 0;
594 return 0;
597 static int _parse_rowdesc (pgmsg_t *msg, pgmsg_rowdesc_t *pmsg) {
598 pmsg->nflds = _geti16(msg);
599 for (int i = 0; i < pmsg->nflds; ++i) {
600 const char *fname = _getstr(msg);
601 if (!fname)
602 return -1;
603 pmsg->fields[i].fname = fname;
604 pmsg->fields[i].oid_table = _geti32(msg);
605 pmsg->fields[i].idx_field = _geti16(msg);
606 pmsg->fields[i].oid_field = _geti32(msg);
607 pmsg->fields[i].field_len = _geti16(msg);
608 pmsg->fields[i].type_mod = _geti32(msg);
609 pmsg->fields[i].field_fmt = _geti16(msg);
611 return 0;
614 static int _parse_datarow (pgmsg_t *msg, pgmsg_datarow_t *pmsg) {
615 pmsg->nflds = _geti16(msg);
616 for (int i = 0; i < pmsg->nflds; ++i) {
617 int32_t len = pmsg->fields[i].len = _geti32(msg);
618 pmsg->fields[i].data = NULL;
619 if (len >= 0) {
620 pmsg->fields[i].data = (uint8_t*)msg->pc;
621 msg->pc += len;
624 return 0;
627 static void _parse_copyin (pgmsg_t *msg, pgmsg_copyin_t *pmsg) {
628 pmsg->fmt = _geti8(msg);
629 pmsg->cols = _geti16(msg);
630 pmsg->fmtcol = _geti16(msg);
633 pgmsg_resp_t *pgmsg_parse (pgmsg_t *msg) {
634 pgmsg_resp_t *resp = NULL;
635 msg->pc = msg->body.ptr;
636 switch (msg->body.type) {
637 case PG_AUTHOK:
638 switch (be32toh(*((int32_t*)msg->body.ptr))) {
639 case PG_OK:
640 resp = malloc(sizeof(pgmsg_auth_t));
641 resp->type = msg->body.type;
642 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
643 break;
644 case PG_REQMD5:
645 resp = malloc(sizeof(pgmsg_auth_t));
646 resp->type = msg->body.type;
647 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
648 memcpy(resp->msg_auth.kind.md5_auth, msg->body.ptr + sizeof(int32_t), sizeof(uint8_t)*4);
649 break;
650 case PG_REQSASL:
651 resp = malloc(sizeof(pgmsg_auth_t));
652 resp->type = msg->body.type;
653 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
654 break;
655 case PG_SASLCON:
656 case PG_SASLCOMP:
657 resp = malloc(sizeof(pgmsg_auth_t) + sizeof(sasl_t) + msg->body.len - sizeof(int32_t) * 2);
658 resp->type = msg->body.type;
659 resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr));
660 resp->msg_auth.kind.sasl_auth.len = msg->body.len - sizeof(int32_t) * 2;
661 memcpy(resp->msg_auth.kind.sasl_auth.data, msg->body.ptr + sizeof(int32_t), resp->msg_auth.kind.sasl_auth.len);
662 break;
664 break;
665 case PG_PARAMSTATUS:
666 resp = malloc(sizeof(pgmsg_param_status_t));
667 resp->type = msg->body.type;
668 if (-1 == _parse_param_status(&msg->body, &resp->msg_param_status)) {
669 free(resp);
670 resp = NULL;
672 break;
673 case PG_BACKENDKEYDATA:
674 resp = malloc(sizeof(pgmsg_backend_keydata_t));
675 resp->type = msg->body.type;
676 resp->msg_backend_keydata.pid = be32toh(*((int32_t*)msg->body.ptr));
677 resp->msg_backend_keydata.sk = be32toh(*((int32_t*)msg->body.ptr + sizeof(int32_t)));
678 break;
679 case PG_READY:
680 resp = malloc(sizeof(pgmsg_ready_t));
681 resp->type = msg->body.type;
682 resp->msg_ready.tr = msg->body.ptr[0];
683 break;
684 case PG_TERM:
685 resp =malloc(sizeof(pgmsg_auth_t));
686 resp->type = msg->body.type;
687 break;
688 case PG_ERROR:
689 resp = malloc(sizeof(pgmsg_error_t));
690 resp->type = msg->body.type;
691 if (-1 == _parse_error(&msg->body, &resp->msg_error)) {
692 free(resp);
693 resp = NULL;
695 break;
696 case PG_ROWDESC:
697 resp = malloc(sizeof(pgmsg_rowdesc_t) + sizeof(pgmsg_field_t) * be16toh(*((int16_t*)msg->body.ptr)));
698 resp->type = msg->body.type;
699 if (-1 == _parse_rowdesc(msg, &resp->msg_rowdesc)) {
700 free(resp);
701 resp = NULL;
703 break;
704 case PG_DATAROW:
705 resp = malloc(sizeof(pgmsg_datarow_t) + sizeof(pgmsg_data_t) * be16toh(*((int16_t*)msg->body.ptr)));
706 resp->type = msg->body.type;
707 if (-1 == _parse_datarow(msg, &resp->msg_datarow)) {
708 free(resp);
709 resp = NULL;
711 break;
712 case PG_CMDCOMPLETE:
713 resp = malloc(sizeof(pgmsg_cmd_complete_t));
714 resp->type = msg->body.type;
715 resp->msg_complete.tag = msg->body.ptr;
716 break;
717 case PG_COPYIN:
718 resp = malloc(sizeof(pgmsg_copyin_t));
719 resp->type = msg->body.type;
720 _parse_copyin(msg, &resp->msg_copyin);
721 break;
723 return resp;
726 static int _copyin_field (pgmsg_t **msg, pgfld_t *fld) {
727 struct tm tm;
728 char buf [32];
729 int len;
730 if (fld->is_null)
731 _setstr(msg, CONST_STR_LEN("\\N"));
732 else if (PG_TEXT == fld->fmt)
733 _setstr(msg, fld->data.s, fld->len);
734 else
735 switch (fld->oid) {
736 case OID_INT2:
737 len = snprintf(buf, sizeof(buf),"%hd", be16toh(fld->data.i2));
738 _setstr(msg, buf, len);
739 break;
740 case OID_INT4:
741 len = snprintf(buf, sizeof(buf), "%d", be32toh(fld->data.i4));
742 _setstr(msg, buf, len);
743 break;
744 case OID_INT8:
745 len = snprintf(buf, sizeof(buf), LONG_FMT, be64toh(fld->data.i8));
746 _setstr(msg, buf, len);
747 break;
748 case OID_FLOAT4:
749 len = snprintf(buf, sizeof(buf), "%f", pg_conv_float(fld->data.f4));
750 _setstr(msg, buf, len);
751 break;
752 case OID_FLOAT8:
753 len = snprintf(buf, sizeof(buf), "%f", pg_conv_double(fld->data.f8));
754 _setstr(msg, buf, len);
755 break;
756 case OID_VARCHAR:
757 case OID_CHAR:
758 case OID_TEXT:
759 _setstr(msg, fld->data.s, fld->len);
760 break;
761 case OID_BOOL:
762 if (fld->data.b)
763 _setstr(msg, CONST_STR_LEN("t"));
764 else
765 _setstr(msg, CONST_STR_LEN("f"));
766 break;
767 case OID_DATE:
768 tm_dec(be64toh(fld->data.tm), &tm);
769 len = snprintf(buf, sizeof(buf), "%d-%02d-%02d", tm.tm_year, tm.tm_mon, tm.tm_mday);
770 _setstr(msg, buf, len);
771 break;
772 case OID_TIMESTAMP:
773 tm_dec(be64toh(fld->data.tm), &tm);
774 len = snprintf(buf, sizeof(buf), "%d-%02d-%02d %02d:%02d:%02d", tm.tm_year, tm.tm_mon, tm.tm_mday, tm.tm_hour, tm.tm_min, tm.tm_sec);
775 _setstr(msg, buf, len);
776 break;
777 // OID_UUID
778 case OID_BYTEA:
779 _setstr(msg, CONST_STR_LEN("0x"));
780 for (int i = 0; i < fld->len; ++i) {
781 len = snprintf(buf, sizeof(buf), "%02x", fld->data.s[i]);
782 _setstr(msg, buf, len);
784 break;
785 case OID_BIT:
786 len = snprintf(buf, sizeof(buf), "X'%08x'", fld->data.bit32.bit);
787 _setstr(msg, buf, len);
788 break;
789 //OID_MONEY
790 default:
791 return -1;
793 return 0;
796 pgmsg_t *pgmsg_copyin_flds (int len, pgfld_t **flds) {
797 pgmsg_t *msg = pgmsg_create(PG_COPYDATA);
798 if (-1 == _copyin_field(&msg, flds[0]))
799 goto err;
800 for (int i = 1; i < len; ++i) {
801 _setstr(&msg, CONST_STR_LEN("\t"));
802 if (-1 == _copyin_field(&msg, flds[i]))
803 goto err;
805 _seti8(&msg, 0x0a);
806 return msg;
807 err:
808 free(msg);
809 return NULL;