1 #include "libpgcli/pgprov3.h"
4 #define HEADER_SIZE sizeof(char) + sizeof(int32_t)
6 static int _resize (pgmsg_t
**msg
, size_t len
) {
8 int nlen
= (*msg
)->len
+ len
;
10 bufsize
= sizeof(pgmsg_t
) + (nlen
/ CHUNKSIZE
) * CHUNKSIZE
+ CHUNKSIZE
;
11 if (bufsize
== (*msg
)->bufsize
)
13 pc_len
= (uintptr_t)(*msg
)->pc
- (uintptr_t)(*msg
)->body
.ptr
;
14 if (!(res
= realloc(*msg
, bufsize
)))
16 res
->bufsize
= bufsize
;
17 res
->pc
= res
->body
.ptr
+ pc_len
;
22 pgmsg_t
*pgmsg_create (char type
) {
23 size_t bufsize
= sizeof(pgmsg_t
) + CHUNKSIZE
;
24 pgmsg_t
*msg
= malloc(bufsize
);
25 if (!msg
) return NULL
;
26 msg
->bufsize
= bufsize
;
27 msg
->len
= sizeof(pgmsg_t
);
28 msg
->pc
= msg
->body
.ptr
;
29 msg
->body
.type
= type
;
30 msg
->body
.len
= htobe32(sizeof(int32_t));
34 static int _seti8 (pgmsg_t
**msg
, int8_t x
) {
35 if (-1 == _resize(msg
, sizeof(int8_t)))
37 *((int8_t*)(*msg
)->pc
) = x
;
38 (*msg
)->pc
+= sizeof(int8_t);
39 (*msg
)->body
.len
= htobe32(be32toh((*msg
)->body
.len
) + sizeof(int8_t));
40 (*msg
)->len
+= sizeof(int8_t);
44 static int _seti16 (pgmsg_t
**msg
, int16_t x
) {
45 if (-1 == _resize(msg
, sizeof(int16_t)))
47 *((int16_t*)(*msg
)->pc
) = htobe16(x
);
48 (*msg
)->pc
+= sizeof(int16_t);
49 (*msg
)->body
.len
= htobe32(be32toh((*msg
)->body
.len
) + sizeof(int16_t));
50 (*msg
)->len
+= sizeof(int16_t);
54 static int _seti32 (pgmsg_t
**msg
, int32_t x
) {
55 if (-1 == _resize(msg
, sizeof(int32_t)))
57 *((int32_t*)(*msg
)->pc
) = htobe32(x
);
58 (*msg
)->pc
+= sizeof(int32_t);
59 (*msg
)->body
.len
= htobe32(be32toh((*msg
)->body
.len
) + sizeof(int32_t));
60 (*msg
)->len
+= sizeof(int32_t);
64 static int _setstr (pgmsg_t
**msg
, const char *s
, size_t slen
) {
67 if (-1 == _resize(msg
, slen
))
69 memcpy((*msg
)->pc
, s
, slen
);
71 (*msg
)->body
.len
= htobe32(be32toh((*msg
)->body
.len
) + slen
);
76 static int16_t _geti16 (pgmsg_t
*msg
) {
77 int16_t r
= be16toh(*((int16_t*)msg
->pc
));
78 msg
->pc
+= sizeof(int16_t);
82 static int32_t _geti32 (pgmsg_t
*msg
) {
83 int32_t r
= be32toh(*((int32_t*)msg
->pc
));
84 msg
->pc
+= sizeof(int32_t);
88 static const char *_getstr (pgmsg_t
*msg
) {
89 char *p
= msg
->pc
, *e
= msg
->body
.ptr
+ be32toh(msg
->body
.len
);
90 while (*(msg
->pc
) && msg
->pc
< e
) ++msg
->pc
;
97 int pgmsg_set_param (pgmsg_t
**msg
, const char *name
, size_t name_len
, const char *value
, size_t value_len
) {
98 if (-1 == _setstr(msg
, name
, name_len
) ||
99 -1 == _seti8(msg
, 0) ||
100 -1 == _setstr(msg
, value
, value_len
) ||
101 -1 == _seti8(msg
, 0))
106 pgmsg_t
*pgmsg_create_startup (const char *user
, size_t user_len
, const char *database
, size_t database_len
) {
107 pgmsg_t
*msg
= pgmsg_create('\0');
108 _seti16(&msg
, PG_MAJOR_VER
);
109 _seti16(&msg
, PG_MINOR_VER
);
110 pgmsg_set_param(&msg
, CONST_STR_LEN("user"), user
, user_len
);
111 pgmsg_set_param(&msg
, CONST_STR_LEN("database"), database
, database_len
);
115 #define PG_USERDEFS 0x00000001
116 #define PG_DBDEFS 0x00000002
117 #define PG_ENCDEFS 0x00000004
119 static inline pgmsg_t
*_add_str_param (pgmsg_t
*msg
, const char *k
, size_t lk
, const char *v
, size_t lv
) {
120 _setstr(&msg
, k
, lk
);
122 _setstr(&msg
, v
, lv
);
127 static pgmsg_t
*_add_param (pgmsg_t
*msg
, const char *begin
, const char *end
, uint32_t *flags
) {
128 const char *p
= begin
, *p1
;
129 while (p
< end
&& '=' != *p
) ++p
;
133 while (p1
> begin
&& isspace(*(p1
- 1))) --p1
;
137 while (p
< end
&& isspace(*p
)) ++p
;
140 if (0 == strncmp(begin
, "dbname", (uintptr_t)p1
- (uintptr_t)begin
)) {
141 msg
= _add_str_param(msg
, CONST_STR_LEN("database"), p
, (uintptr_t)end
- (uintptr_t)p
);
145 if (0 == strncmp(begin
, "user", (uintptr_t)p1
- (uintptr_t)begin
)) {
146 msg
= _add_str_param(msg
, begin
, (uintptr_t)p1
- (uintptr_t)begin
, p
, (uintptr_t)end
- (uintptr_t)p
);
148 *flags
|= PG_USERDEFS
;
150 if (0 != strncmp(begin
, "host", (uintptr_t)p1
- (uintptr_t)begin
) &&
151 0 != strncmp(begin
, "port", (uintptr_t)p1
- (uintptr_t)begin
))
152 msg
= _add_str_param(msg
, begin
, (uintptr_t)p1
- (uintptr_t)begin
, p
, (uintptr_t)end
- (uintptr_t)p
);
156 void *parse_conninfo (void *data
, const char *conn_info
, parse_param_h fn
, uint32_t *flags
) {
158 const char *p
= conn_info
;
161 while (isspace(*p
)) ++p
;
164 while (*p
&& !isspace(*p
)) ++p
;
165 data
= fn(data
, q
, p
, flags
);
171 static pgmsg_t
*_startup_params (pgmsg_t
*msg
, const char *conn_info
) {
173 msg
= (pgmsg_t
*)parse_conninfo((void*)msg
, conn_info
, (parse_param_h
)_add_param
, &flags
);
174 if (!(flags
& PG_USERDEFS
)) {
175 char *s
= getenv("USER");
177 msg
= _add_str_param(msg
, CONST_STR_LEN("user"), s
, strlen(s
));
179 if (!(flags
& PG_DBDEFS
)) {
180 char *s
= getenv("USER");
182 msg
= _add_str_param(msg
, CONST_STR_LEN("database"), s
, strlen(s
));
184 if (!(flags
& PG_ENCDEFS
)) {
185 char *s
= getenv("LANG");
186 if (s
&& (s
= strchr(s
, '.')) && *++s
)
187 msg
= _add_str_param(msg
, CONST_STR_LEN("client_encoding"), s
, strlen(s
));
193 pgmsg_t
*pgmsg_create_startup_params (const char *conn_info
) {
194 pgmsg_t
*msg
= pgmsg_create('\0');
195 _seti16(&msg
, PG_MAJOR_VER
);
196 _seti16(&msg
, PG_MINOR_VER
);
197 return _startup_params(msg
, conn_info
);
200 pgmsg_t
*pgmsg_create_simple_query (const char *sql
, size_t sql_len
) {
201 pgmsg_t
*msg
= pgmsg_create(PG_SIMPLEQUERY
);
202 _setstr(&msg
, sql
, sql_len
);
207 pgmsg_t
*pgmsg_create_parse (const char *name
, size_t name_len
, const char *sql
, size_t sql_len
, int fld_len
, pgfld_t
**flds
) {
208 pgmsg_t
*msg
= pgmsg_create(PG_PARSE
);
209 if (name
&& 0 == name_len
)
210 name_len
= strlen(name
);
212 sql_len
= strlen(sql
);
213 _setstr(&msg
, name
, name_len
);
215 _setstr(&msg
, sql
, sql_len
);
217 _seti16(&msg
, fld_len
);
218 for (int i
= 0; i
< fld_len
; ++i
)
219 _seti32(&msg
, flds
[i
]->oid
);
223 pgmsg_t
*pgmsg_create_bind (const char *portal
, size_t portal_len
, const char *stmt
, size_t stmt_len
,
224 int fld_len
, pgfld_t
**flds
, int res_fmt_len
, int *res_fmt
) {
225 pgmsg_t
*msg
= pgmsg_create(PG_BIND
);
226 _setstr(&msg
, portal
, portal_len
);
228 _setstr(&msg
, stmt
, stmt_len
);
230 _seti16(&msg
, fld_len
);
231 for (int i
= 0; i
< fld_len
; ++i
)
232 _seti16(&msg
, flds
[i
]->fmt
);
233 _seti16(&msg
, fld_len
);
234 for (int i
= 0; i
< fld_len
; ++i
) {
235 if (flds
[i
]->is_null
)
238 _seti32(&msg
, flds
[i
]->len
);
239 if (0 == flds
[i
]->fmt
|| OID_UUID
== flds
[i
]->oid
)
240 _setstr(&msg
, flds
[i
]->data
.s
, flds
[i
]->len
);
242 _setstr(&msg
, (const char*)&flds
[i
]->data
, flds
[i
]->len
);
245 _seti16(&msg
, res_fmt_len
);
246 for (int i
= 0; i
< res_fmt_len
; ++i
)
247 _seti16(&msg
, res_fmt
[i
]);
251 pgmsg_t
*pgmsg_create_describe (uint8_t op
, const char *name
, size_t name_len
) {
252 pgmsg_t
*msg
= pgmsg_create(PG_DESCRIBE
);
254 _setstr(&msg
, name
, name_len
);
259 pgmsg_t
*pgmsg_create_execute (const char *portal
, size_t portal_len
, int32_t max_rows
) {
260 pgmsg_t
*msg
= pgmsg_create(PG_EXECUTE
);
261 _setstr(&msg
, portal
, portal_len
);
263 _seti32(&msg
, max_rows
);
267 pgmsg_t
*pgmsg_create_close(char what
, const char *str
, size_t slen
) {
268 pgmsg_t
*msg
= pgmsg_create(PG_CLOSE
);
269 _seti8(&msg
, (uint8_t)what
);
270 _setstr(&msg
, str
, slen
);
275 static void _pg_md5_hash (const void *buff
, size_t len
, char *out
) {
281 MD5_Update(&ctx
, buff
, 512);
283 MD5_Update(&ctx
, buff
, len
);
287 MD5_Final(digest
, &ctx
);
288 for (int i
= 0; i
< 16; ++i
)
289 snprintf(&(out
[i
*2]), 16*2, "%02x", (uint8_t)digest
[i
]);
292 static void _pg_md5_encrypt(const char *passwd
, const char *salt
, size_t salt_len
, char *buf
) {
293 size_t passwd_len
= strlen(passwd
);
294 char *crypt_buf
= malloc(passwd_len
+ salt_len
+ 1);
295 memcpy(crypt_buf
, passwd
, passwd_len
);
296 memcpy(crypt_buf
+ passwd_len
, salt
, salt_len
);
298 _pg_md5_hash(crypt_buf
, passwd_len
+ salt_len
, buf
+ 3);
302 pgmsg_t
*pgmsg_create_pass (int req
, const char *salt
, size_t salt_len
, const char *user
, const char *pass
) {
303 pgmsg_t
*msg
= pgmsg_create(PG_PASS
);
304 char *pwd
= malloc(2 * (PG_MD5PASS_LEN
+ 1)),
305 *pwd2
= pwd
+ PG_MD5PASS_LEN
+ 1,
307 _pg_md5_encrypt(pass
, user
, strlen(user
), pwd2
);
308 _pg_md5_encrypt(pwd2
+ sizeof("md5")-1, salt
, 4, pwd
);
315 pwd_to_send
= (char*)pass
;
317 _setstr(&msg
, pwd_to_send
, strlen(pwd_to_send
));
322 int pgmsg_send (int fd
, pgmsg_t
*msg
) {
325 ssize_t sent
= 0, wrote
= 0;
326 if ('\0' == msg
->body
.type
) {
327 buf
= &msg
->body
.len
;
328 size
= be32toh(msg
->body
.len
);
331 size
= be32toh(msg
->body
.len
) + sizeof(char);
333 while (sent
< size
) {
335 wrote
= send(fd
, buf
, size
- sent
, 0);
336 while (wrote
< 0 && EINTR
== errno
);
340 return sent
== size
? 0 : -1;
343 int pgmsg_recv (int fd
, pgmsg_t
**msg
) {
344 ssize_t readed
= 0, total
= 0;
350 } __attribute__ ((packed
)) header
;
351 while (total
< HEADER_SIZE
&& (readed
= recv(fd
, (void*)(&header
+ total
), HEADER_SIZE
- total
, 0)) > 0)
355 header
.len
= be32toh(header
.len
);
356 bufsize
= sizeof(pgmsg_t
) + header
.len
;
357 if (!(m
= calloc(bufsize
, sizeof(int8_t))))
359 m
->body
.type
= header
.type
;
360 m
->body
.len
= header
.len
;
361 total
= header
.len
- sizeof(int32_t);
363 char *p
= m
->body
.ptr
;
364 if (-1 == (readed
= recv(fd
, p
, total
, 0))) {
375 static int _parse_param_status (pgmsg_body_t
*body
, pgmsg_param_status_t
*pmsg
) {
376 char *p
= body
->ptr
, *e
= p
+ body
->len
, *q
= p
;
377 while (*q
&& q
< e
) ++q
;
385 static int _parse_error (pgmsg_body_t
*body
, pgmsg_error_t
*pmsg
) {
386 char *p
= body
->ptr
, *e
= p
+ body
->len
;
390 pmsg
->severity
= ++p
;
402 pmsg
->position
= ++p
;
418 while (p
< e
&& *p
) ++p
;
427 static int _parse_rowdesc (pgmsg_t
*msg
, pgmsg_rowdesc_t
*pmsg
) {
428 pmsg
->nflds
= _geti16(msg
);
429 for (int i
= 0; i
< pmsg
->nflds
; ++i
) {
430 const char *fname
= _getstr(msg
);
433 pmsg
->fields
[i
].fname
= fname
;
434 pmsg
->fields
[i
].oid_table
= _geti32(msg
);
435 pmsg
->fields
[i
].idx_field
= _geti16(msg
);
436 pmsg
->fields
[i
].oid_field
= _geti32(msg
);
437 pmsg
->fields
[i
].field_len
= _geti16(msg
);
438 pmsg
->fields
[i
].type_mod
= _geti32(msg
);
439 pmsg
->fields
[i
].field_fmt
= _geti16(msg
);
444 static int _parse_datarow (pgmsg_t
*msg
, pgmsg_datarow_t
*pmsg
) {
445 pmsg
->nflds
= _geti16(msg
);
446 for (int i
= 0; i
< pmsg
->nflds
; ++i
) {
447 int32_t len
= pmsg
->fields
[i
].len
= _geti32(msg
);
448 pmsg
->fields
[i
].data
= NULL
;
450 pmsg
->fields
[i
].data
= (uint8_t*)msg
->pc
;
457 pgmsg_resp_t
*pgmsg_parse (pgmsg_t
*msg
) {
458 pgmsg_resp_t
*resp
= NULL
;
459 msg
->pc
= msg
->body
.ptr
;
460 switch (msg
->body
.type
) {
462 resp
= malloc(sizeof(pgmsg_auth_t
));
463 resp
->type
= msg
->body
.type
;
464 switch (resp
->msg_auth
.success
= be32toh(*((int32_t*)msg
->body
.ptr
))) {
468 memcpy(resp
->msg_auth
.kind
.md5_auth
, msg
->body
.ptr
+ sizeof(int32_t), sizeof(uint8_t)*4);
477 resp
= malloc(sizeof(pgmsg_param_status_t
));
478 resp
->type
= msg
->body
.type
;
479 if (-1 == _parse_param_status(&msg
->body
, &resp
->msg_param_status
)) {
484 case PG_BACKENDKEYDATA
:
485 resp
= malloc(sizeof(pgmsg_backend_keydata_t
));
486 resp
->type
= msg
->body
.type
;
487 resp
->msg_backend_keydata
.pid
= be32toh(*((int32_t*)msg
->body
.ptr
));
488 resp
->msg_backend_keydata
.sk
= be32toh(*((int32_t*)msg
->body
.ptr
+ sizeof(int32_t)));
491 resp
= malloc(sizeof(pgmsg_ready_t
));
492 resp
->type
= msg
->body
.type
;
493 resp
->msg_ready
.tr
= msg
->body
.ptr
[0];
496 resp
=malloc(sizeof(pgmsg_auth_t
));
497 resp
->type
= msg
->body
.type
;
500 resp
= malloc(sizeof(pgmsg_error_t
));
501 resp
->type
= msg
->body
.type
;
502 if (-1 == _parse_error(&msg
->body
, &resp
->msg_error
)) {
508 resp
= malloc(sizeof(pgmsg_rowdesc_t
) + sizeof(pgmsg_field_t
) * be16toh(*((int16_t*)msg
->body
.ptr
)));
509 resp
->type
= msg
->body
.type
;
510 if (-1 == _parse_rowdesc(msg
, &resp
->msg_rowdesc
)) {
516 resp
= malloc(sizeof(pgmsg_datarow_t
) + sizeof(pgmsg_data_t
) * be16toh(*((int16_t*)msg
->body
.ptr
)));
517 resp
->type
= msg
->body
.type
;
518 if (-1 == _parse_datarow(msg
, &resp
->msg_datarow
)) {
524 resp
= malloc(sizeof(pgmsg_cmd_complete_t
));
525 resp
->type
= msg
->body
.type
;
526 resp
->msg_complete
.tag
= msg
->body
.ptr
;