auth, not tested
[libpgclient.git] / src / pgprov3.c
blobff1d961c15c20c8cb6b90244d34207d069cf393d
1 #include "libpgcli/pgprov3.h"
/* Outgoing message buffers grow in steps of this many bytes. */
#define CHUNKSIZE 32
/* Backend message header on the wire: 1 type byte + 4-byte length.
   Parenthesized so the macro composes safely inside larger expressions. */
#define HEADER_SIZE (sizeof(char) + sizeof(int32_t))
6 static int _resize (pgmsg_t **msg, size_t len) {
7 pgmsg_t *res;
8 int nlen = (*msg)->len + len;
9 size_t pc_len,
10 bufsize = sizeof(pgmsg_t) + (nlen / CHUNKSIZE) * CHUNKSIZE + CHUNKSIZE;
11 if (bufsize == (*msg)->bufsize)
12 return 0;
13 pc_len = (uintptr_t)(*msg)->pc - (uintptr_t)(*msg)->body.ptr;
14 if (!(res = realloc(*msg, bufsize)))
15 return -1;
16 res->bufsize = bufsize;
17 res->pc = res->body.ptr + pc_len;
18 *msg = res;
19 return 0;
22 pgmsg_t *pgmsg_create (char type) {
23 size_t bufsize = sizeof(pgmsg_t) + CHUNKSIZE;
24 pgmsg_t *msg = malloc(bufsize);
25 if (!msg) return NULL;
26 msg->bufsize = bufsize;
27 msg->len = sizeof(pgmsg_t);
28 msg->pc = msg->body.ptr;
29 msg->body.type = type;
30 msg->body.len = htobe32(sizeof(int32_t));
31 return msg;
34 static int _seti8 (pgmsg_t **msg, int8_t x) {
35 if (-1 == _resize(msg, sizeof(int8_t)))
36 return -1;
37 *((int8_t*)(*msg)->pc) = x;
38 (*msg)->pc += sizeof(int8_t);
39 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int8_t));
40 (*msg)->len += sizeof(int8_t);
41 return 0;
44 static int _seti16 (pgmsg_t **msg, int16_t x) {
45 if (-1 == _resize(msg, sizeof(int16_t)))
46 return -1;
47 *((int16_t*)(*msg)->pc) = htobe16(x);
48 (*msg)->pc += sizeof(int16_t);
49 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int16_t));
50 (*msg)->len += sizeof(int16_t);
51 return 0;
54 static int _seti32 (pgmsg_t **msg, int32_t x) {
55 if (-1 == _resize(msg, sizeof(int32_t)))
56 return -1;
57 *((int32_t*)(*msg)->pc) = htobe32(x);
58 (*msg)->pc += sizeof(int32_t);
59 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + sizeof(int32_t));
60 (*msg)->len += sizeof(int32_t);
61 return 0;
64 static int _setstr (pgmsg_t **msg, const char *s, size_t slen) {
65 if (0 == slen)
66 return 0;
67 if (-1 == _resize(msg, slen))
68 return -1;
69 memcpy((*msg)->pc, s, slen);
70 (*msg)->pc += slen;
71 (*msg)->body.len = htobe32(be32toh((*msg)->body.len) + slen);
72 (*msg)->len += slen;
73 return 0;
76 static int16_t _geti16 (pgmsg_t *msg) {
77 int16_t r = be16toh(*((int16_t*)msg->pc));
78 msg->pc += sizeof(int16_t);
79 return r;
82 static int32_t _geti32 (pgmsg_t *msg) {
83 int32_t r = be32toh(*((int32_t*)msg->pc));
84 msg->pc += sizeof(int32_t);
85 return r;
88 static const char *_getstr (pgmsg_t *msg) {
89 char *p = msg->pc, *e = msg->body.ptr + be32toh(msg->body.len);
90 while (*(msg->pc) && msg->pc < e) ++msg->pc;
91 if (e == msg->pc)
92 return NULL;
93 ++msg->pc;
94 return p;
97 int pgmsg_set_param (pgmsg_t **msg, const char *name, size_t name_len, const char *value, size_t value_len) {
98 if (-1 == _setstr(msg, name, name_len) ||
99 -1 == _seti8(msg, 0) ||
100 -1 == _setstr(msg, value, value_len) ||
101 -1 == _seti8(msg, 0))
102 return -1;
103 return 0;
106 pgmsg_t *pgmsg_create_startup (const char *user, size_t user_len, const char *database, size_t database_len) {
107 pgmsg_t *msg = pgmsg_create('\0');
108 _seti16(&msg, PG_MAJOR_VER);
109 _seti16(&msg, PG_MINOR_VER);
110 pgmsg_set_param(&msg, CONST_STR_LEN("user"), user, user_len);
111 pgmsg_set_param(&msg, CONST_STR_LEN("database"), database, database_len);
112 return msg;
/* Bits recording which startup parameters the connection string supplied,
   so that _startup_params only fills in the missing ones from the environment. */
enum {
    PG_USERDEFS = 0x00000001,
    PG_DBDEFS   = 0x00000002,
    PG_ENCDEFS  = 0x00000004
};
119 static inline pgmsg_t *_add_str_param (pgmsg_t *msg, const char *k, size_t lk, const char *v, size_t lv) {
120 _setstr(&msg, k, lk);
121 _seti8(&msg, 0);
122 _setstr(&msg, v, lv);
123 _seti8(&msg, 0);
124 return msg;
127 static pgmsg_t *_add_param (pgmsg_t *msg, const char *begin, const char *end, uint32_t *flags) {
128 const char *p = begin, *p1;
129 while (p < end && '=' != *p) ++p;
130 if (p == end)
131 return msg;
132 p1 = p;
133 while (p1 > begin && isspace(*(p1 - 1))) --p1;
134 if (p == p1 - 1)
135 return msg;
136 ++p;
137 while (p < end && isspace(*p)) ++p;
138 if (p == end)
139 return msg;
140 if (0 == strncmp(begin, "dbname", (uintptr_t)p1 - (uintptr_t)begin)) {
141 msg = _add_str_param(msg, CONST_STR_LEN("database"), p, (uintptr_t)end - (uintptr_t)p);
142 if (flags)
143 *flags |= PG_DBDEFS;
144 } else
145 if (0 == strncmp(begin, "user", (uintptr_t)p1 - (uintptr_t)begin)) {
146 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
147 if (flags)
148 *flags |= PG_USERDEFS;
149 } else
150 if (0 != strncmp(begin, "host", (uintptr_t)p1 - (uintptr_t)begin) &&
151 0 != strncmp(begin, "port", (uintptr_t)p1 - (uintptr_t)begin))
152 msg = _add_str_param(msg, begin, (uintptr_t)p1 - (uintptr_t)begin, p, (uintptr_t)end - (uintptr_t)p);
153 return msg;
156 void *parse_conninfo (void *data, const char *conn_info, parse_param_h fn, uint32_t *flags) {
157 if (conn_info) {
158 const char *p = conn_info;
159 while (*p) {
160 const char *q;
161 while (isspace(*p)) ++p;
162 if (!(*p)) break;
163 q = p;
164 while (*p && !isspace(*p)) ++p;
165 data = fn(data, q, p, flags);
168 return data;
171 static pgmsg_t *_startup_params (pgmsg_t *msg, const char *conn_info) {
172 uint32_t flags = 0;
173 msg = (pgmsg_t*)parse_conninfo((void*)msg, conn_info, (parse_param_h)_add_param, &flags);
174 if (!(flags & PG_USERDEFS)) {
175 char *s = getenv("USER");
176 if (s)
177 msg = _add_str_param(msg, CONST_STR_LEN("user"), s, strlen(s));
179 if (!(flags & PG_DBDEFS)) {
180 char *s = getenv("USER");
181 if (s)
182 msg = _add_str_param(msg, CONST_STR_LEN("database"), s, strlen(s));
184 if (!(flags & PG_ENCDEFS)) {
185 char *s = getenv("LANG");
186 if (s && (s = strchr(s, '.')) && *++s)
187 msg = _add_str_param(msg, CONST_STR_LEN("client_encoding"), s, strlen(s));
189 _seti8(&msg, 0);
190 return msg;
193 pgmsg_t *pgmsg_create_startup_params (const char *conn_info) {
194 pgmsg_t *msg = pgmsg_create('\0');
195 _seti16(&msg, PG_MAJOR_VER);
196 _seti16(&msg, PG_MINOR_VER);
197 return _startup_params(msg, conn_info);
200 pgmsg_t *pgmsg_create_simple_query (const char *sql, size_t sql_len) {
201 pgmsg_t *msg = pgmsg_create(PG_SIMPLEQUERY);
202 _setstr(&msg, sql, sql_len);
203 _seti8(&msg, 0);
204 return msg;
207 pgmsg_t *pgmsg_create_parse (const char *name, size_t name_len, const char *sql, size_t sql_len, int fld_len, pgfld_t **flds) {
208 pgmsg_t *msg = pgmsg_create(PG_PARSE);
209 if (name && 0 == name_len)
210 name_len = strlen(name);
211 if (0 == sql_len)
212 sql_len = strlen(sql);
213 _setstr(&msg, name, name_len);
214 _seti8(&msg, 0);
215 _setstr(&msg, sql, sql_len);
216 _seti8(&msg, 0);
217 _seti16(&msg, fld_len);
218 for (int i = 0; i < fld_len; ++i)
219 _seti32(&msg, flds[i]->oid);
220 return msg;
223 pgmsg_t *pgmsg_create_bind (const char *portal, size_t portal_len, const char *stmt, size_t stmt_len,
224 int fld_len, pgfld_t **flds, int res_fmt_len, int *res_fmt) {
225 pgmsg_t *msg = pgmsg_create(PG_BIND);
226 _setstr(&msg, portal, portal_len);
227 _seti8(&msg, 0);
228 _setstr(&msg, stmt, stmt_len);
229 _seti8(&msg, 0);
230 _seti16(&msg, fld_len);
231 for (int i = 0; i < fld_len; ++i)
232 _seti16(&msg, flds[i]->fmt);
233 _seti16(&msg, fld_len);
234 for (int i = 0; i < fld_len; ++i) {
235 if (flds[i]->is_null)
236 _seti32(&msg, -1);
237 else {
238 _seti32(&msg, flds[i]->len);
239 if (0 == flds[i]->fmt || OID_UUID == flds[i]->oid)
240 _setstr(&msg, flds[i]->data.s, flds[i]->len);
241 else
242 _setstr(&msg, (const char*)&flds[i]->data, flds[i]->len);
245 _seti16(&msg, res_fmt_len);
246 for (int i = 0; i < res_fmt_len; ++i)
247 _seti16(&msg, res_fmt[i]);
248 return msg;
251 pgmsg_t *pgmsg_create_describe (uint8_t op, const char *name, size_t name_len) {
252 pgmsg_t *msg = pgmsg_create(PG_DESCRIBE);
253 _seti8(&msg, op);
254 _setstr(&msg, name, name_len);
255 _seti8(&msg, 0);
256 return msg;
259 pgmsg_t *pgmsg_create_execute (const char *portal, size_t portal_len, int32_t max_rows) {
260 pgmsg_t *msg = pgmsg_create(PG_EXECUTE);
261 _setstr(&msg, portal, portal_len);
262 _seti8(&msg, 0);
263 _seti32(&msg, max_rows);
264 return msg;
267 pgmsg_t *pgmsg_create_close(char what, const char *str, size_t slen) {
268 pgmsg_t *msg = pgmsg_create(PG_CLOSE);
269 _seti8(&msg, (uint8_t)what);
270 _setstr(&msg, str, slen);
271 _seti8(&msg, 0);
272 return msg;
275 static void _pg_md5_hash (const void *buff, size_t len, char *out) {
276 uint8_t digest [16];
277 MD5_CTX ctx;
278 MD5_Init(&ctx);
279 while (len > 0) {
280 if (len > 512)
281 MD5_Update(&ctx, buff, 512);
282 else
283 MD5_Update(&ctx, buff, len);
284 buff += 512;
285 len -= 512;
287 MD5_Final(digest, &ctx);
288 for (int i = 0; i < 16; ++i)
289 snprintf(&(out[i*2]), 16*2, "%02x", (uint8_t)digest[i]);
/*
 * Write "md5" + hex(md5(passwd || salt[0..salt_len))) into buf, which must
 * hold at least 36 bytes. If the scratch allocation fails, buf is set to the
 * empty string instead of dereferencing NULL.
 * The scratch copy holds the plaintext password, so it is wiped before it is
 * freed. The volatile stores keep the compiler from eliding the wipe.
 */
static void _pg_md5_encrypt(const char *passwd, const char *salt, size_t salt_len, char *buf) {
    size_t passwd_len = strlen(passwd);
    size_t total = passwd_len + salt_len;
    char *crypt_buf = malloc(total + 1);
    volatile char *wipe;

    if (NULL == crypt_buf) {
        buf[0] = '\0';
        return;
    }
    memcpy(crypt_buf, passwd, passwd_len);
    memcpy(crypt_buf + passwd_len, salt, salt_len);
    strcpy(buf, "md5");
    _pg_md5_hash(crypt_buf, total, buf + 3);

    wipe = crypt_buf;
    for (size_t i = 0; i < total; ++i)
        wipe[i] = 0;
    free(crypt_buf);
}
302 pgmsg_t *pgmsg_create_pass (int req, const char *salt, size_t salt_len, const char *user, const char *pass) {
303 pgmsg_t *msg = pgmsg_create(PG_PASS);
304 char *pwd = malloc(2 * (PG_MD5PASS_LEN + 1)),
305 *pwd2 = pwd + PG_MD5PASS_LEN + 1,
306 *pwd_to_send;
307 _pg_md5_encrypt(pass, user, strlen(user), pwd2);
308 _pg_md5_encrypt(pwd2 + sizeof("md5")-1, salt, 4, pwd);
309 pwd_to_send = pwd;
310 switch (req) {
311 case PG_REQMD5:
312 pwd_to_send = pwd;
313 break;
314 case PG_REQPASS:
315 pwd_to_send = (char*)pass;
317 _setstr(&msg, pwd_to_send, strlen(pwd_to_send));
318 _seti8(&msg, 0);
319 return msg;
322 int pgmsg_send (int fd, pgmsg_t *msg) {
323 void *buf;
324 size_t size;
325 ssize_t sent = 0, wrote = 0;
326 if ('\0' == msg->body.type) {
327 buf = &msg->body.len;
328 size = be32toh(msg->body.len);
329 } else {
330 buf = &msg->body;
331 size = be32toh(msg->body.len) + sizeof(char);
333 while (sent < size) {
335 wrote = send(fd, buf, size - sent, 0);
336 while (wrote < 0 && EINTR == errno);
337 sent += wrote;
338 buf += wrote;
340 return sent == size ? 0 : -1;
343 int pgmsg_recv (int fd, pgmsg_t **msg) {
344 ssize_t readed = 0, total = 0;
345 pgmsg_t *m;
346 size_t bufsize;
347 struct {
348 char type;
349 int32_t len;
350 } __attribute__ ((packed)) header;
351 while (total < HEADER_SIZE && (readed = recv(fd, (void*)(&header + total), HEADER_SIZE - total, 0)) > 0)
352 total += readed;
353 if (-1 == readed)
354 return -1;
355 header.len = be32toh(header.len);
356 bufsize = sizeof(pgmsg_t) + header.len;
357 if (!(m = calloc(bufsize, sizeof(int8_t))))
358 return -1;
359 m->body.type = header.type;
360 m->body.len = header.len;
361 total = header.len - sizeof(int32_t);
362 while (total > 0) {
363 char *p = m->body.ptr;
364 if (-1 == (readed = recv(fd, p, total, 0))) {
365 free(m);
366 return -1;
368 total -= readed;
369 p += readed;
371 *msg = m;
372 return 0;
375 static int _parse_param_status (pgmsg_body_t *body, pgmsg_param_status_t *pmsg) {
376 char *p = body->ptr, *e = p + body->len, *q = p;
377 while (*q && q < e) ++q;
378 if (q == e)
379 return -1;
380 pmsg->name = p;
381 pmsg->value = q + 1;
382 return 0;
385 static int _parse_error (pgmsg_body_t *body, pgmsg_error_t *pmsg) {
386 char *p = body->ptr, *e = p + body->len;
387 while (p < e) {
388 switch (*p) {
389 case PG_SEVERITY:
390 pmsg->severity = ++p;
391 break;
392 case PG_FATAL:
393 pmsg->text = ++p;
394 break;
395 case PG_SQLSTATE:
396 pmsg->code = ++p;
397 break;
398 case PG_MESSAGE:
399 pmsg->message = ++p;
400 break;
401 case PG_POSITION:
402 pmsg->position = ++p;
403 break;
404 case PG_FILE:
405 pmsg->file = ++p;
406 break;
407 case PG_LINE:
408 pmsg->line = ++p;
409 break;
410 case PG_ROUTINE:
411 pmsg->routine = ++p;
412 break;
413 case '\0':
414 return 0;
415 default:
416 break;
418 while (p < e && *p) ++p;
419 if (p == e)
420 return 0;
421 if (++p == e)
422 return 0;
424 return 0;
427 static int _parse_rowdesc (pgmsg_t *msg, pgmsg_rowdesc_t *pmsg) {
428 pmsg->nflds = _geti16(msg);
429 for (int i = 0; i < pmsg->nflds; ++i) {
430 const char *fname = _getstr(msg);
431 if (!fname)
432 return -1;
433 pmsg->fields[i].fname = fname;
434 pmsg->fields[i].oid_table = _geti32(msg);
435 pmsg->fields[i].idx_field = _geti16(msg);
436 pmsg->fields[i].oid_field = _geti32(msg);
437 pmsg->fields[i].field_len = _geti16(msg);
438 pmsg->fields[i].type_mod = _geti32(msg);
439 pmsg->fields[i].field_fmt = _geti16(msg);
441 return 0;
444 static int _parse_datarow (pgmsg_t *msg, pgmsg_datarow_t *pmsg) {
445 pmsg->nflds = _geti16(msg);
446 for (int i = 0; i < pmsg->nflds; ++i) {
447 int32_t len = pmsg->fields[i].len = _geti32(msg);
448 pmsg->fields[i].data = NULL;
449 if (len >= 0) {
450 pmsg->fields[i].data = (uint8_t*)msg->pc;
451 msg->pc += len;
454 return 0;
457 pgmsg_resp_t *pgmsg_parse (pgmsg_t *msg) {
458 pgmsg_resp_t *resp = NULL;
459 msg->pc = msg->body.ptr;
460 switch (msg->body.type) {
461 case PG_AUTHOK:
462 resp = malloc(sizeof(pgmsg_auth_t));
463 resp->type = msg->body.type;
464 switch (resp->msg_auth.success = be32toh(*((int32_t*)msg->body.ptr))) {
465 case PG_OK:
466 break;
467 case PG_REQMD5:
468 memcpy(resp->msg_auth.kind.md5_auth, msg->body.ptr + sizeof(int32_t), sizeof(uint8_t)*4);
469 break;
470 default:
471 free(resp);
472 resp = NULL;
473 break;
475 break;
476 case PG_PARAMSTATUS:
477 resp = malloc(sizeof(pgmsg_param_status_t));
478 resp->type = msg->body.type;
479 if (-1 == _parse_param_status(&msg->body, &resp->msg_param_status)) {
480 free(resp);
481 resp = NULL;
483 break;
484 case PG_BACKENDKEYDATA:
485 resp = malloc(sizeof(pgmsg_backend_keydata_t));
486 resp->type = msg->body.type;
487 resp->msg_backend_keydata.pid = be32toh(*((int32_t*)msg->body.ptr));
488 resp->msg_backend_keydata.sk = be32toh(*((int32_t*)msg->body.ptr + sizeof(int32_t)));
489 break;
490 case PG_READY:
491 resp = malloc(sizeof(pgmsg_ready_t));
492 resp->type = msg->body.type;
493 resp->msg_ready.tr = msg->body.ptr[0];
494 break;
495 case PG_TERM:
496 resp =malloc(sizeof(pgmsg_auth_t));
497 resp->type = msg->body.type;
498 break;
499 case PG_ERROR:
500 resp = malloc(sizeof(pgmsg_error_t));
501 resp->type = msg->body.type;
502 if (-1 == _parse_error(&msg->body, &resp->msg_error)) {
503 free(resp);
504 resp = NULL;
506 break;
507 case PG_ROWDESC:
508 resp = malloc(sizeof(pgmsg_rowdesc_t) + sizeof(pgmsg_field_t) * be16toh(*((int16_t*)msg->body.ptr)));
509 resp->type = msg->body.type;
510 if (-1 == _parse_rowdesc(msg, &resp->msg_rowdesc)) {
511 free(resp);
512 resp = NULL;
514 break;
515 case PG_DATAROW:
516 resp = malloc(sizeof(pgmsg_datarow_t) + sizeof(pgmsg_data_t) * be16toh(*((int16_t*)msg->body.ptr)));
517 resp->type = msg->body.type;
518 if (-1 == _parse_datarow(msg, &resp->msg_datarow)) {
519 free(resp);
520 resp = NULL;
522 break;
523 case PG_CMDCOMPLETE:
524 resp = malloc(sizeof(pgmsg_cmd_complete_t));
525 resp->type = msg->body.type;
526 resp->msg_complete.tag = msg->body.ptr;
527 break;
529 return resp;