/*
** Snapshot of upstream SQLite 3.38.5
** [sqlcipher.git] / ext / misc / decimal.c
** blob 37c6c2f52cc26169c6fd55ecd5887425f182680d
*/
/*
** 2020-06-22
**
** The author disclaims copyright to this source code.  In place of
** a legal notice, here is a blessing:
**
**    May you do good and not evil.
**    May you find forgiveness for yourself and forgive others.
**    May you share freely, never taking more than you give.
**
******************************************************************************
**
** Routines to implement arbitrary-precision decimal math.
**
** The focus here is on simplicity and correctness, not performance.
*/
17 #include "sqlite3ext.h"
18 SQLITE_EXTENSION_INIT1
19 #include <assert.h>
20 #include <string.h>
21 #include <ctype.h>
22 #include <stdlib.h>
/*
** Evaluate an argument and discard the value, so that compilers do not
** warn about function parameters that the body never reads.  A definition
** supplied earlier (for example by an amalgamation) takes precedence.
*/
#if !defined(UNUSED_PARAMETER)
# define UNUSED_PARAMETER(X) ((void)(X))
#endif
/*
** A decimal number held as an array of base-10 digits.
**
** The digits a[0..nDigit-1] are stored most significant first; the last
** nFrac of them lie to the right of the decimal point, and a nonzero
** sign makes the value negative.
*/
typedef struct Decimal {
  char sign;        /* 0 for positive, 1 for negative */
  char oom;         /* True if an OOM is encountered */
  char isNull;      /* True if holds a NULL rather than a number */
  char isInit;      /* True upon initialization */
  int nDigit;       /* Total number of digits */
  int nFrac;        /* Number of digits to the right of the decimal point */
  signed char *a;   /* Array of digits.  Most significant first. */
} Decimal;
44 ** Release memory held by a Decimal, but do not free the object itself.
46 static void decimal_clear(Decimal *p){
47 sqlite3_free(p->a);
51 ** Destroy a Decimal object
53 static void decimal_free(Decimal *p){
54 if( p ){
55 decimal_clear(p);
56 sqlite3_free(p);
61 ** Allocate a new Decimal object. Initialize it to the number given
62 ** by the input string.
64 static Decimal *decimal_new(
65 sqlite3_context *pCtx,
66 sqlite3_value *pIn,
67 int nAlt,
68 const unsigned char *zAlt
70 Decimal *p;
71 int n, i;
72 const unsigned char *zIn;
73 int iExp = 0;
74 p = sqlite3_malloc( sizeof(*p) );
75 if( p==0 ) goto new_no_mem;
76 p->sign = 0;
77 p->oom = 0;
78 p->isInit = 1;
79 p->isNull = 0;
80 p->nDigit = 0;
81 p->nFrac = 0;
82 if( zAlt ){
83 n = nAlt,
84 zIn = zAlt;
85 }else{
86 if( sqlite3_value_type(pIn)==SQLITE_NULL ){
87 p->a = 0;
88 p->isNull = 1;
89 return p;
91 n = sqlite3_value_bytes(pIn);
92 zIn = sqlite3_value_text(pIn);
94 p->a = sqlite3_malloc64( n+1 );
95 if( p->a==0 ) goto new_no_mem;
96 for(i=0; isspace(zIn[i]); i++){}
97 if( zIn[i]=='-' ){
98 p->sign = 1;
99 i++;
100 }else if( zIn[i]=='+' ){
101 i++;
103 while( i<n && zIn[i]=='0' ) i++;
104 while( i<n ){
105 char c = zIn[i];
106 if( c>='0' && c<='9' ){
107 p->a[p->nDigit++] = c - '0';
108 }else if( c=='.' ){
109 p->nFrac = p->nDigit + 1;
110 }else if( c=='e' || c=='E' ){
111 int j = i+1;
112 int neg = 0;
113 if( j>=n ) break;
114 if( zIn[j]=='-' ){
115 neg = 1;
116 j++;
117 }else if( zIn[j]=='+' ){
118 j++;
120 while( j<n && iExp<1000000 ){
121 if( zIn[j]>='0' && zIn[j]<='9' ){
122 iExp = iExp*10 + zIn[j] - '0';
124 j++;
126 if( neg ) iExp = -iExp;
127 break;
129 i++;
131 if( p->nFrac ){
132 p->nFrac = p->nDigit - (p->nFrac - 1);
134 if( iExp>0 ){
135 if( p->nFrac>0 ){
136 if( iExp<=p->nFrac ){
137 p->nFrac -= iExp;
138 iExp = 0;
139 }else{
140 iExp -= p->nFrac;
141 p->nFrac = 0;
144 if( iExp>0 ){
145 p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
146 if( p->a==0 ) goto new_no_mem;
147 memset(p->a+p->nDigit, 0, iExp);
148 p->nDigit += iExp;
150 }else if( iExp<0 ){
151 int nExtra;
152 iExp = -iExp;
153 nExtra = p->nDigit - p->nFrac - 1;
154 if( nExtra ){
155 if( nExtra>=iExp ){
156 p->nFrac += iExp;
157 iExp = 0;
158 }else{
159 iExp -= nExtra;
160 p->nFrac = p->nDigit - 1;
163 if( iExp>0 ){
164 p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
165 if( p->a==0 ) goto new_no_mem;
166 memmove(p->a+iExp, p->a, p->nDigit);
167 memset(p->a, 0, iExp);
168 p->nDigit += iExp;
169 p->nFrac += iExp;
172 return p;
174 new_no_mem:
175 if( pCtx ) sqlite3_result_error_nomem(pCtx);
176 sqlite3_free(p);
177 return 0;
181 ** Make the given Decimal the result.
183 static void decimal_result(sqlite3_context *pCtx, Decimal *p){
184 char *z;
185 int i, j;
186 int n;
187 if( p==0 || p->oom ){
188 sqlite3_result_error_nomem(pCtx);
189 return;
191 if( p->isNull ){
192 sqlite3_result_null(pCtx);
193 return;
195 z = sqlite3_malloc( p->nDigit+4 );
196 if( z==0 ){
197 sqlite3_result_error_nomem(pCtx);
198 return;
200 i = 0;
201 if( p->nDigit==0 || (p->nDigit==1 && p->a[0]==0) ){
202 p->sign = 0;
204 if( p->sign ){
205 z[0] = '-';
206 i = 1;
208 n = p->nDigit - p->nFrac;
209 if( n<=0 ){
210 z[i++] = '0';
212 j = 0;
213 while( n>1 && p->a[j]==0 ){
214 j++;
215 n--;
217 while( n>0 ){
218 z[i++] = p->a[j] + '0';
219 j++;
220 n--;
222 if( p->nFrac ){
223 z[i++] = '.';
225 z[i++] = p->a[j] + '0';
226 j++;
227 }while( j<p->nDigit );
229 z[i] = 0;
230 sqlite3_result_text(pCtx, z, i, sqlite3_free);
234 ** SQL Function: decimal(X)
236 ** Convert input X into decimal and then back into text
238 static void decimalFunc(
239 sqlite3_context *context,
240 int argc,
241 sqlite3_value **argv
243 Decimal *p = decimal_new(context, argv[0], 0, 0);
244 UNUSED_PARAMETER(argc);
245 decimal_result(context, p);
246 decimal_free(p);
250 ** Compare to Decimal objects. Return negative, 0, or positive if the
251 ** first object is less than, equal to, or greater than the second.
253 ** Preconditions for this routine:
255 ** pA!=0
256 ** pA->isNull==0
257 ** pB!=0
258 ** pB->isNull==0
260 static int decimal_cmp(const Decimal *pA, const Decimal *pB){
261 int nASig, nBSig, rc, n;
262 if( pA->sign!=pB->sign ){
263 return pA->sign ? -1 : +1;
265 if( pA->sign ){
266 const Decimal *pTemp = pA;
267 pA = pB;
268 pB = pTemp;
270 nASig = pA->nDigit - pA->nFrac;
271 nBSig = pB->nDigit - pB->nFrac;
272 if( nASig!=nBSig ){
273 return nASig - nBSig;
275 n = pA->nDigit;
276 if( n>pB->nDigit ) n = pB->nDigit;
277 rc = memcmp(pA->a, pB->a, n);
278 if( rc==0 ){
279 rc = pA->nDigit - pB->nDigit;
281 return rc;
285 ** SQL Function: decimal_cmp(X, Y)
287 ** Return negative, zero, or positive if X is less then, equal to, or
288 ** greater than Y.
290 static void decimalCmpFunc(
291 sqlite3_context *context,
292 int argc,
293 sqlite3_value **argv
295 Decimal *pA = 0, *pB = 0;
296 int rc;
298 UNUSED_PARAMETER(argc);
299 pA = decimal_new(context, argv[0], 0, 0);
300 if( pA==0 || pA->isNull ) goto cmp_done;
301 pB = decimal_new(context, argv[1], 0, 0);
302 if( pB==0 || pB->isNull ) goto cmp_done;
303 rc = decimal_cmp(pA, pB);
304 if( rc<0 ) rc = -1;
305 else if( rc>0 ) rc = +1;
306 sqlite3_result_int(context, rc);
307 cmp_done:
308 decimal_free(pA);
309 decimal_free(pB);
313 ** Expand the Decimal so that it has a least nDigit digits and nFrac
314 ** digits to the right of the decimal point.
316 static void decimal_expand(Decimal *p, int nDigit, int nFrac){
317 int nAddSig;
318 int nAddFrac;
319 if( p==0 ) return;
320 nAddFrac = nFrac - p->nFrac;
321 nAddSig = (nDigit - p->nDigit) - nAddFrac;
322 if( nAddFrac==0 && nAddSig==0 ) return;
323 p->a = sqlite3_realloc64(p->a, nDigit+1);
324 if( p->a==0 ){
325 p->oom = 1;
326 return;
328 if( nAddSig ){
329 memmove(p->a+nAddSig, p->a, p->nDigit);
330 memset(p->a, 0, nAddSig);
331 p->nDigit += nAddSig;
333 if( nAddFrac ){
334 memset(p->a+p->nDigit, 0, nAddFrac);
335 p->nDigit += nAddFrac;
336 p->nFrac += nAddFrac;
341 ** Add the value pB into pA.
343 ** Both pA and pB might become denormalized by this routine.
345 static void decimal_add(Decimal *pA, Decimal *pB){
346 int nSig, nFrac, nDigit;
347 int i, rc;
348 if( pA==0 ){
349 return;
351 if( pA->oom || pB==0 || pB->oom ){
352 pA->oom = 1;
353 return;
355 if( pA->isNull || pB->isNull ){
356 pA->isNull = 1;
357 return;
359 nSig = pA->nDigit - pA->nFrac;
360 if( nSig && pA->a[0]==0 ) nSig--;
361 if( nSig<pB->nDigit-pB->nFrac ){
362 nSig = pB->nDigit - pB->nFrac;
364 nFrac = pA->nFrac;
365 if( nFrac<pB->nFrac ) nFrac = pB->nFrac;
366 nDigit = nSig + nFrac + 1;
367 decimal_expand(pA, nDigit, nFrac);
368 decimal_expand(pB, nDigit, nFrac);
369 if( pA->oom || pB->oom ){
370 pA->oom = 1;
371 }else{
372 if( pA->sign==pB->sign ){
373 int carry = 0;
374 for(i=nDigit-1; i>=0; i--){
375 int x = pA->a[i] + pB->a[i] + carry;
376 if( x>=10 ){
377 carry = 1;
378 pA->a[i] = x - 10;
379 }else{
380 carry = 0;
381 pA->a[i] = x;
384 }else{
385 signed char *aA, *aB;
386 int borrow = 0;
387 rc = memcmp(pA->a, pB->a, nDigit);
388 if( rc<0 ){
389 aA = pB->a;
390 aB = pA->a;
391 pA->sign = !pA->sign;
392 }else{
393 aA = pA->a;
394 aB = pB->a;
396 for(i=nDigit-1; i>=0; i--){
397 int x = aA[i] - aB[i] - borrow;
398 if( x<0 ){
399 pA->a[i] = x+10;
400 borrow = 1;
401 }else{
402 pA->a[i] = x;
403 borrow = 0;
411 ** Compare text in decimal order.
413 static int decimalCollFunc(
414 void *notUsed,
415 int nKey1, const void *pKey1,
416 int nKey2, const void *pKey2
418 const unsigned char *zA = (const unsigned char*)pKey1;
419 const unsigned char *zB = (const unsigned char*)pKey2;
420 Decimal *pA = decimal_new(0, 0, nKey1, zA);
421 Decimal *pB = decimal_new(0, 0, nKey2, zB);
422 int rc;
423 UNUSED_PARAMETER(notUsed);
424 if( pA==0 || pB==0 ){
425 rc = 0;
426 }else{
427 rc = decimal_cmp(pA, pB);
429 decimal_free(pA);
430 decimal_free(pB);
431 return rc;
436 ** SQL Function: decimal_add(X, Y)
437 ** decimal_sub(X, Y)
439 ** Return the sum or difference of X and Y.
441 static void decimalAddFunc(
442 sqlite3_context *context,
443 int argc,
444 sqlite3_value **argv
446 Decimal *pA = decimal_new(context, argv[0], 0, 0);
447 Decimal *pB = decimal_new(context, argv[1], 0, 0);
448 UNUSED_PARAMETER(argc);
449 decimal_add(pA, pB);
450 decimal_result(context, pA);
451 decimal_free(pA);
452 decimal_free(pB);
454 static void decimalSubFunc(
455 sqlite3_context *context,
456 int argc,
457 sqlite3_value **argv
459 Decimal *pA = decimal_new(context, argv[0], 0, 0);
460 Decimal *pB = decimal_new(context, argv[1], 0, 0);
461 UNUSED_PARAMETER(argc);
462 if( pB ){
463 pB->sign = !pB->sign;
464 decimal_add(pA, pB);
465 decimal_result(context, pA);
467 decimal_free(pA);
468 decimal_free(pB);
471 /* Aggregate funcion: decimal_sum(X)
473 ** Works like sum() except that it uses decimal arithmetic for unlimited
474 ** precision.
476 static void decimalSumStep(
477 sqlite3_context *context,
478 int argc,
479 sqlite3_value **argv
481 Decimal *p;
482 Decimal *pArg;
483 UNUSED_PARAMETER(argc);
484 p = sqlite3_aggregate_context(context, sizeof(*p));
485 if( p==0 ) return;
486 if( !p->isInit ){
487 p->isInit = 1;
488 p->a = sqlite3_malloc(2);
489 if( p->a==0 ){
490 p->oom = 1;
491 }else{
492 p->a[0] = 0;
494 p->nDigit = 1;
495 p->nFrac = 0;
497 if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
498 pArg = decimal_new(context, argv[0], 0, 0);
499 decimal_add(p, pArg);
500 decimal_free(pArg);
502 static void decimalSumInverse(
503 sqlite3_context *context,
504 int argc,
505 sqlite3_value **argv
507 Decimal *p;
508 Decimal *pArg;
509 UNUSED_PARAMETER(argc);
510 p = sqlite3_aggregate_context(context, sizeof(*p));
511 if( p==0 ) return;
512 if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
513 pArg = decimal_new(context, argv[0], 0, 0);
514 if( pArg ) pArg->sign = !pArg->sign;
515 decimal_add(p, pArg);
516 decimal_free(pArg);
518 static void decimalSumValue(sqlite3_context *context){
519 Decimal *p = sqlite3_aggregate_context(context, 0);
520 if( p==0 ) return;
521 decimal_result(context, p);
523 static void decimalSumFinalize(sqlite3_context *context){
524 Decimal *p = sqlite3_aggregate_context(context, 0);
525 if( p==0 ) return;
526 decimal_result(context, p);
527 decimal_clear(p);
531 ** SQL Function: decimal_mul(X, Y)
533 ** Return the product of X and Y.
535 ** All significant digits after the decimal point are retained.
536 ** Trailing zeros after the decimal point are omitted as long as
537 ** the number of digits after the decimal point is no less than
538 ** either the number of digits in either input.
540 static void decimalMulFunc(
541 sqlite3_context *context,
542 int argc,
543 sqlite3_value **argv
545 Decimal *pA = decimal_new(context, argv[0], 0, 0);
546 Decimal *pB = decimal_new(context, argv[1], 0, 0);
547 signed char *acc = 0;
548 int i, j, k;
549 int minFrac;
550 UNUSED_PARAMETER(argc);
551 if( pA==0 || pA->oom || pA->isNull
552 || pB==0 || pB->oom || pB->isNull
554 goto mul_end;
556 acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 );
557 if( acc==0 ){
558 sqlite3_result_error_nomem(context);
559 goto mul_end;
561 memset(acc, 0, pA->nDigit + pB->nDigit + 2);
562 minFrac = pA->nFrac;
563 if( pB->nFrac<minFrac ) minFrac = pB->nFrac;
564 for(i=pA->nDigit-1; i>=0; i--){
565 signed char f = pA->a[i];
566 int carry = 0, x;
567 for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){
568 x = acc[k] + f*pB->a[j] + carry;
569 acc[k] = x%10;
570 carry = x/10;
572 x = acc[k] + carry;
573 acc[k] = x%10;
574 acc[k-1] += x/10;
576 sqlite3_free(pA->a);
577 pA->a = acc;
578 acc = 0;
579 pA->nDigit += pB->nDigit + 2;
580 pA->nFrac += pB->nFrac;
581 pA->sign ^= pB->sign;
582 while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){
583 pA->nFrac--;
584 pA->nDigit--;
586 decimal_result(context, pA);
588 mul_end:
589 sqlite3_free(acc);
590 decimal_free(pA);
591 decimal_free(pB);
594 #ifdef _WIN32
595 __declspec(dllexport)
596 #endif
597 int sqlite3_decimal_init(
598 sqlite3 *db,
599 char **pzErrMsg,
600 const sqlite3_api_routines *pApi
602 int rc = SQLITE_OK;
603 static const struct {
604 const char *zFuncName;
605 int nArg;
606 void (*xFunc)(sqlite3_context*,int,sqlite3_value**);
607 } aFunc[] = {
608 { "decimal", 1, decimalFunc },
609 { "decimal_cmp", 2, decimalCmpFunc },
610 { "decimal_add", 2, decimalAddFunc },
611 { "decimal_sub", 2, decimalSubFunc },
612 { "decimal_mul", 2, decimalMulFunc },
614 unsigned int i;
615 (void)pzErrMsg; /* Unused parameter */
617 SQLITE_EXTENSION_INIT2(pApi);
619 for(i=0; i<sizeof(aFunc)/sizeof(aFunc[0]) && rc==SQLITE_OK; i++){
620 rc = sqlite3_create_function(db, aFunc[i].zFuncName, aFunc[i].nArg,
621 SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC,
622 0, aFunc[i].xFunc, 0, 0);
624 if( rc==SQLITE_OK ){
625 rc = sqlite3_create_window_function(db, "decimal_sum", 1,
626 SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 0,
627 decimalSumStep, decimalSumFinalize,
628 decimalSumValue, decimalSumInverse, 0);
630 if( rc==SQLITE_OK ){
631 rc = sqlite3_create_collation(db, "decimal", SQLITE_UTF8,
632 0, decimalCollFunc);
634 return rc;