jam
[libpgclient.git] / src / pgprov3.c
blob3ff581109ac0926e75ed499329315641ed393c3b
1 #include "md5.h"
2 #include "libpgcli/pgprov3.h"
/* Allocation granularity, in bytes, used by _resize() when growing outgoing message buffers. */
#define CHUNKSIZE 32
/* Size of a backend message header: 1-byte type tag followed by a 4-byte length.
 * Parenthesized so the macro expands correctly inside larger expressions. */
#define HEADER_SIZE (sizeof(char) + sizeof(int32_t))
7 static int _resize (pgmsg_t **msg, size_t len) {
8 pgmsg_t *res;
9 int nlen = (*msg)->len + len;
10 size_t pc_len,
11 bufsize = sizeof(pgmsg_t) + (nlen / CHUNKSIZE) * CHUNKSIZE + CHUNKSIZE;
12 if (bufsize == (*msg)->bufsize)
13 return 0;
14 pc_len = (uintptr_t)(*msg)->pc - (uintptr_t)(*msg)->body.ptr;
15 if (!(res = realloc(*msg, bufsize)))
16 return -1;
17 res->bufsize = bufsize;
18 res->pc = res->body.ptr + pc_len;
19 *msg = res;
20 return 0;
23 pgmsg_t *pgmsg_create (char type) {
24 size_t bufsize = sizeof(pgmsg_t) + CHUNKSIZE;
25 pgmsg_t *msg = malloc(bufsize);
26 if (!msg) return NULL;
27 msg->bufsize = bufsize;
28 msg->len = sizeof(pgmsg_t);
29 msg->pc = msg->body.ptr;
30 msg->body.type = type;
31 msg->body.len = htobe32(sizeof(int32_t));
32 return msg;
35 static int _seti8 (pgmsg_t **msg, int8_t x) {
36 if (-1 == _resize(msg, sizeof(int8_t)))
37 return -1;
38 *((int8_t*)(*msg)->pc) = x;
39 (*msg)->pc += sizeof(int8_t);
40 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int8_t));
41 (*msg)->len += sizeof(int8_t);
42 return 0;
45 static int _seti16 (pgmsg_t **msg, int16_t x) {
46 if (-1 == _resize(msg, sizeof(int16_t)))
47 return -1;
48 *((int16_t*)(*msg)->pc) = htobe16(x);
49 (*msg)->pc += sizeof(int16_t);
50 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int16_t));
51 (*msg)->len += sizeof(int16_t);
52 return 0;
55 static int _seti32 (pgmsg_t **msg, int32_t x) {
56 if (-1 == _resize(msg, sizeof(int32_t)))
57 return -1;
58 *((int32_t*)(*msg)->pc) = htobe32(x);
59 (*msg)->pc += sizeof(int32_t);
60 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int32_t));
61 (*msg)->len += sizeof(int32_t);
62 return 0;
65 static int _setstr (pgmsg_t **msg, const char *s, size_t slen) {
66 if (0 == slen)
67 return 0;
68 if (-1 == _resize(msg, slen))
69 return -1;
70 memcpy((*msg)->pc, s, slen);
71 (*msg)->pc += slen;
72 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + slen);
73 (*msg)->len += slen;
74 return 0;
77 static int16_t _geti16 (pgmsg_t *msg) {
78 int16_t r = be16toh(*((int16_t*)msg->pc));
79 msg->pc += sizeof(int16_t);
80 return r;
83 static int32_t _geti32 (pgmsg_t *msg) {
84 int32_t r = be32toh(*((int32_t*)msg->pc));
85 msg->pc += sizeof(int32_t);
86 return r;
89 static const char *_getstr (pgmsg_t *msg) {
90 char *p = msg->pc, *e = msg->body.ptr + be32toh(msg->body.len);
91 while (*(msg->pc) && msg->pc < e) ++msg->pc;
92 if (e == msg->pc)
93 return NULL;
94 ++msg->pc;
95 return p;
98 int pgmsg_set_param (pgmsg_t **msg, const char *name, size_t name_len, const char *value, size_t value_len) {
99 if (-1 == _setstr(msg, name, name_len) ||
100 -1 == _seti8(msg, 0) ||
101 -1 == _setstr(msg, value, value_len) ||
102 -1 == _seti8(msg, 0))
103 return -1;
104 return 0;
107 pgmsg_t *pgmsg_create_startup (const char *user, size_t user_len, const char *database, size_t database_len) {
108 pgmsg_t *msg = pgmsg_create('\0');
109 _seti16(&msg, PG_MAJOR_VER);
110 _seti16(&msg, PG_MINOR_VER);
111 pgmsg_set_param(&msg, CONST_STR_LEN("user"), user, user_len);
112 pgmsg_set_param(&msg, CONST_STR_LEN("database"), database, database_len);
113 return msg;
/*
 * Bits in the flags word that parse_conninfo() threads through _add_param().
 * _startup_params() checks them to decide which defaults to add:
 *   PG_USERDEFS - a user was supplied
 *   PG_DBDEFS   - a database (dbname) was supplied
 *   PG_ENCDEFS  - a client encoding was supplied
 */
enum {
    PG_USERDEFS = 0x00000001,
    PG_DBDEFS   = 0x00000002,
    PG_ENCDEFS  = 0x00000004
};
120 static inline pgmsg_t *_add_str_param (pgmsg_t *msg, const char *k, size_t lk, const char *v, size_t lv) {
121 _setstr(&msg, k, lk);
122 _seti8(&msg, 0);
123 _setstr(&msg, v, lv);
124 _seti8(&msg, 0);
125 return msg;
128 static pgmsg_t *_add_param (pgmsg_t *msg, const char *begin, const char *end, uint32_t *flags) {
129 const char *p = begin, *p1;
130 while (p < end && '=' != *p) ++p;
131 if (p == end)
132 return msg;
133 p1 = p;
134 while (p1 > begin && isspace(*(p1 - 1))) --p1;
135 if (p == p1 - 1)
136 return msg;
137 ++p;
138 while (p < end && isspace(*p)) ++p;
139 if (p == end)
140 return msg;
141 if (0 == strncmp(begin, "dbname", (uintptr_t)p1 - (uintptr_t)begin)) {
142 msg = _add_str_param(msg, CONST_STR_LEN("database"), p, (uintptr_t)end - (uintptr_t)p);
143 if (flags)
144 *flags |= PG_DBDEFS;
145 } else
146 if (0 == strncmp(begin, "user", (uintptr_t)p1 - (uintptr_t)begin)) {
147 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
148 if (flags)
149 *flags |= PG_USERDEFS;
150 } else
151 if (0 != strncmp(begin, "host", (uintptr_t)p1 - (uintptr_t)begin) &&
152 0 != strncmp(begin, "port", (uintptr_t)p1 - (uintptr_t)begin))
153 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
154 return msg;
157 void *parse_conninfo (void *data, const char *conn_info, parse_param_h fn, uint32_t *flags) {
158 if (conn_info) {
159 const char *p = conn_info;
160 while (*p) {
161 const char *q;
162 while (isspace(*p)) ++p;
163 if (!(*p)) break;
164 q = p;
165 while (*p && !isspace(*p)) ++p;
166 data = fn(data, q, p, flags);
169 return data;
172 static pgmsg_t *_startup_params (pgmsg_t *msg, const char *conn_info) {
173 uint32_t flags = 0;
174 msg = (pgmsg_t*)parse_conninfo((void*)msg, conn_info, (parse_param_h)_add_param, &flags);
175 if (!(flags & PG_USERDEFS)) {
176 char *s = getenv("USER");
177 if (s)
178 msg = _add_str_param(msg, CONST_STR_LEN("user"), s, strlen(s));
180 if (!(flags & PG_DBDEFS)) {
181 char *s = getenv("USER");
182 if (s)
183 msg = _add_str_param(msg, CONST_STR_LEN("database"), s, strlen(s));
185 if (!(flags & PG_ENCDEFS)) {
186 char *s = getenv("LANG");
187 if (s && (s = strchr(s, '.')) && *++s)
188 msg = _add_str_param(msg, CONST_STR_LEN("client_encoding"), s, strlen(s));
190 _seti8(&msg, 0);
191 return msg;
194 pgmsg_t *pgmsg_create_startup_params (const char *conn_info) {
195 pgmsg_t *msg = pgmsg_create('\0');
196 _seti16(&msg, PG_MAJOR_VER);
197 _seti16(&msg, PG_MINOR_VER);
198 return _startup_params(msg, conn_info);
201 pgmsg_t *pgmsg_create_simple_query (const char *sql, size_t sql_len) {
202 pgmsg_t *msg = pgmsg_create(PG_SIMPLEQUERY);
203 _setstr(&msg, sql, sql_len);
204 _seti8(&msg, 0);
205 return msg;
208 pgmsg_t *pgmsg_create_parse (const char *name, size_t name_len, const char *sql, size_t sql_len, int fld_len, pgfld_t **flds) {
209 pgmsg_t *msg = pgmsg_create(PG_PARSE);
210 if (name && 0 == name_len)
211 name_len = strlen(name);
212 if (0 == sql_len)
213 sql_len = strlen(sql);
214 _setstr(&msg, name, name_len);
215 _seti8(&msg, 0);
216 _setstr(&msg, sql, sql_len);
217 _seti8(&msg, 0);
218 _seti16(&msg, fld_len);
219 for (int i = 0; i < fld_len; ++i)
220 _seti32(&msg, flds[i]->oid);
221 return msg;
224 pgmsg_t *pgmsg_create_bind (const char *portal, size_t portal_len, const char *stmt, size_t stmt_len,
225 int fld_len, pgfld_t **flds, int res_fmt_len, int *res_fmt) {
226 pgmsg_t *msg = pgmsg_create(PG_BIND);
227 _setstr(&msg, portal, portal_len);
228 _seti8(&msg, 0);
229 _setstr(&msg, stmt, stmt_len);
230 _seti8(&msg, 0);
231 _seti16(&msg, fld_len);
232 for (int i = 0; i < fld_len; ++i)
233 _seti16(&msg, flds[i]->fmt);
234 _seti16(&msg, fld_len);
235 for (int i = 0; i < fld_len; ++i) {
236 if (flds[i]->is_null)
237 _seti32(&msg, -1);
238 else {
239 _seti32(&msg, flds[i]->len);
240 if (0 == flds[i]->fmt || OID_UUID == flds[i]->oid)
241 _setstr(&msg, flds[i]->data.s, flds[i]->len);
242 else
243 _setstr(&msg, (const char*)&flds[i]->data, flds[i]->len);
246 _seti16(&msg, res_fmt_len);
247 for (int i = 0; i < res_fmt_len; ++i)
248 _seti16(&msg, res_fmt[i]);
249 return msg;
252 pgmsg_t *pgmsg_create_describe (uint8_t op, const char *name, size_t name_len) {
253 pgmsg_t *msg = pgmsg_create(PG_DESCRIBE);
254 _seti8(&msg, op);
255 _setstr(&msg, name, name_len);
256 _seti8(&msg, 0);
257 return msg;
260 pgmsg_t *pgmsg_create_execute (const char *portal, size_t portal_len, int32_t max_rows) {
261 pgmsg_t *msg = pgmsg_create(PG_EXECUTE);
262 _setstr(&msg, portal, portal_len);
263 _seti8(&msg, 0);
264 _seti32(&msg, max_rows);
265 return msg;
268 pgmsg_t *pgmsg_create_close(char what, const char *str, size_t slen) {
269 pgmsg_t *msg = pgmsg_create(PG_CLOSE);
270 _seti8(&msg, (uint8_t)what);
271 _setstr(&msg, str, slen);
272 _seti8(&msg, 0);
273 return msg;
276 static void _pg_md5_hash (const void *buff, size_t len, char *out) {
277 uint8_t digest [16];
278 MD5Context ctx;
279 md5Init(&ctx);
280 while (len > 0) {
281 if (len > 512)
282 md5Update(&ctx, (uint8_t*)buff, 512);
283 else
284 md5Update(&ctx, (uint8_t*)buff, len);
285 buff += 512;
286 len -= 512;
288 md5Finalize(&ctx);
289 memcpy(digest, ctx.digest, 16);
290 for (int i = 0; i < 16; ++i)
291 snprintf(&(out[i*2]), 16*2, "%02x", (uint8_t)digest[i]);
/*
 * Write "md5" + hex(md5(passwd || salt)) into `buf`.
 *
 * `buf` needs room for 3 + 32 + 1 bytes. NOTE(review): callers size it with
 * PG_MD5PASS_LEN + 1; confirm the constant is 35.
 *
 * Returns 0 on success. On allocation failure it returns -1 and `buf` is set
 * to the empty string. Previously the malloc() result was used unchecked.
 * The temporary plaintext copy is wiped before it is freed.
 * Existing callers that ignore the result still work unchanged.
 */
static int _pg_md5_encrypt(const char *passwd, const char *salt, size_t salt_len, char *buf) {
    size_t passwd_len = strlen(passwd),
           total = passwd_len + salt_len;
    char *crypt_buf = malloc(total + 1);
    if (!crypt_buf) {
        buf[0] = '\0';
        return -1;
    }
    memcpy(crypt_buf, passwd, passwd_len);
    memcpy(crypt_buf + passwd_len, salt, salt_len);
    strcpy(buf, "md5");
    _pg_md5_hash(crypt_buf, total, buf + 3);
    {
        /* volatile keeps the compiler from eliding the wipe of the password copy */
        volatile char *wipe = crypt_buf;
        for (size_t i = 0; i < total; ++i)
            wipe[i] = 0;
    }
    free(crypt_buf);
    return 0;
}
304 pgmsg_t *pgmsg_create_pass (int req, const char *salt, size_t salt_len, const char *user, const char *pass) {
305 pgmsg_t *msg = pgmsg_create(PG_PASS);
306 char *pwd = malloc(2 * (PG_MD5PASS_LEN + 1)),
307 *pwd2 = pwd + PG_MD5PASS_LEN + 1,
308 *pwd_to_send;
309 _pg_md5_encrypt(pass, user, strlen(user), pwd2);
310 _pg_md5_encrypt(pwd2 + sizeof("md5")-1, salt, 4, pwd);
311 pwd_to_send = pwd;
312 switch (req) {
313 case PG_REQMD5:
314 pwd_to_send = pwd;
315 break;
316 case PG_REQPASS:
317 pwd_to_send = (char*)pass;
319 _setstr(&msg, pwd_to_send, strlen(pwd_to_send));
320 _seti8(&msg, 0);
321 return msg;
324 int pgmsg_send (int fd, pgmsg_t *msg) {
325 void *buf;
326 size_t size;
327 ssize_t sent = 0, wrote = 0;
328 if ('\0' == msg->body.type) {
329 buf = &msg->body.len;
330 size = be32toh(msg->body.len);
331 } else {
332 buf = &msg->body;
333 size = be32toh(msg->body.len) + sizeof(char);
335 while (sent < size) {
337 wrote = send(fd, buf, size - sent, 0);
338 while (wrote < 0 && EINTR == errno);
339 sent += wrote;
340 buf += wrote;
342 return sent == size ? 0 : -1;
345 int pgmsg_recv (int fd, pgmsg_t **msg) {
346 ssize_t readed = 0, total = 0;
347 pgmsg_t *m;
348 size_t bufsize;
349 struct {
350 char type;
351 int32_t len;
352 } __attribute__ ((packed)) header;
353 while (total < HEADER_SIZE && (readed = recv(fd, (void*)(&header + total), HEADER_SIZE - total, 0)) > 0)
354 total += readed;
355 if (-1 == readed)
356 return -1;
357 header.len = be32toh(header.len);
358 bufsize = sizeof(pgmsg_t) + header.len;
359 if (!(m = calloc(bufsize, sizeof(int8_t))))
360 return -1;
361 m->body.type = header.type;
362 m->body.len = header.len;
363 total = header.len - sizeof(int32_t);
364 while (total > 0) {
365 char *p = m->body.ptr;
366 if (-1 == (readed = recv(fd, p, total, 0))) {
367 free(m);
368 return -1;
370 total -= readed;
371 p += readed;
373 *msg = m;
374 return 0;
377 static int _parse_param_status (pgmsg_body_t *body, pgmsg_param_status_t *pmsg) {
378 char *p = body->ptr, *e = p + body->len, *q = p;
379 while (*q && q < e) ++q;
380 if (q == e)
381 return -1;
382 pmsg->name = p;
383 pmsg->value = q + 1;
384 return 0;
387 static int _parse_error (pgmsg_body_t *body, pgmsg_error_t *pmsg) {
388 char *p = body->ptr, *e = p + body->len;
389 while (p < e) {
390 switch (*p) {
391 case PG_SEVERITY:
392 pmsg->severity = ++p;
393 break;
394 case PG_FATAL:
395 pmsg->text = ++p;
396 break;
397 case PG_SQLSTATE:
398 pmsg->code = ++p;
399 break;
400 case PG_MESSAGE:
401 pmsg->message = ++p;
402 break;
403 case PG_POSITION:
404 pmsg->position = ++p;
405 break;
406 case PG_FILE:
407 pmsg->file = ++p;
408 break;
409 case PG_LINE:
410 pmsg->line = ++p;
411 break;
412 case PG_ROUTINE:
413 pmsg->routine = ++p;
414 break;
415 case '\0':
416 return 0;
417 default:
418 break;
420 while (p < e && *p) ++p;
421 if (p == e)
422 return 0;
423 if (++p == e)
424 return 0;
426 return 0;
429 static int _parse_rowdesc (pgmsg_t *msg, pgmsg_rowdesc_t *pmsg) {
430 pmsg->nflds = _geti16(msg);
431 for (int i = 0; i < pmsg->nflds; ++i) {
432 const char *fname = _getstr(msg);
433 if (!fname)
434 return -1;
435 pmsg->fields[i].fname = fname;
436 pmsg->fields[i].oid_table = _geti32(msg);
437 pmsg->fields[i].idx_field = _geti16(msg);
438 pmsg->fields[i].oid_field = _geti32(msg);
439 pmsg->fields[i].field_len = _geti16(msg);
440 pmsg->fields[i].type_mod = _geti32(msg);
441 pmsg->fields[i].field_fmt = _geti16(msg);
443 return 0;
446 static int _parse_datarow (pgmsg_t *msg, pgmsg_datarow_t *pmsg) {
447 pmsg->nflds = _geti16(msg);
448 for (int i = 0; i < pmsg->nflds; ++i) {
449 int32_t len = pmsg->fields[i].len = _geti32(msg);
450 pmsg->fields[i].data = NULL;
451 if (len >= 0) {
452 pmsg->fields[i].data = (uint8_t*)msg->pc;
453 msg->pc += len;
456 return 0;
459 pgmsg_resp_t *pgmsg_parse (pgmsg_t *msg) {
460 pgmsg_resp_t *resp = NULL;
461 msg->pc = msg->body.ptr;
462 switch (msg->body.type) {
463 case PG_AUTHOK:
464 resp = malloc(sizeof(pgmsg_auth_t));
465 resp->type = msg->body.type;
466 switch (resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr))) {
467 case PG_OK:
468 break;
469 case PG_REQMD5:
470 memcpy(resp->msg_auth.kind.md5_auth, msg->body.ptr + sizeof(int32_t), sizeof(uint8_t)*4);
471 break;
472 default:
473 free(resp);
474 resp = NULL;
475 break;
477 break;
478 case PG_PARAMSTATUS:
479 resp = malloc(sizeof(pgmsg_param_status_t));
480 resp->type = msg->body.type;
481 if (-1 == _parse_param_status(&msg->body, &resp->msg_param_status)) {
482 free(resp);
483 resp = NULL;
485 break;
486 case PG_BACKENDKEYDATA:
487 resp = malloc(sizeof(pgmsg_backend_keydata_t));
488 resp->type = msg->body.type;
489 resp->msg_backend_keydata.pid = be32toh(*((int32_t*)msg->body.ptr));
490 resp->msg_backend_keydata.sk = be32toh(*((int32_t*)msg->body.ptr + sizeof(int32_t)));
491 break;
492 case PG_READY:
493 resp = malloc(sizeof(pgmsg_ready_t));
494 resp->type = msg->body.type;
495 resp->msg_ready.tr = msg->body.ptr[0];
496 break;
497 case PG_TERM:
498 resp =malloc(sizeof(pgmsg_auth_t));
499 resp->type = msg->body.type;
500 break;
501 case PG_ERROR:
502 resp = malloc(sizeof(pgmsg_error_t));
503 resp->type = msg->body.type;
504 if (-1 == _parse_error(&msg->body, &resp->msg_error)) {
505 free(resp);
506 resp = NULL;
508 break;
509 case PG_ROWDESC:
510 resp = malloc(sizeof(pgmsg_rowdesc_t) + sizeof(pgmsg_field_t) * be16toh(*((int16_t*)msg->body.ptr)));
511 resp->type = msg->body.type;
512 if (-1 == _parse_rowdesc(msg, &resp->msg_rowdesc)) {
513 free(resp);
514 resp = NULL;
516 break;
517 case PG_DATAROW:
518 resp = malloc(sizeof(pgmsg_datarow_t) + sizeof(pgmsg_data_t) * be16toh(*((int16_t*)msg->body.ptr)));
519 resp->type = msg->body.type;
520 if (-1 == _parse_datarow(msg, &resp->msg_datarow)) {
521 free(resp);
522 resp = NULL;
524 break;
525 case PG_CMDCOMPLETE:
526 resp = malloc(sizeof(pgmsg_cmd_complete_t));
527 resp->type = msg->body.type;
528 resp->msg_complete.tag = msg->body.ptr;
529 break;
531 return resp;