4 ** The author disclaims copyright to this source code. In place of
5 ** a legal notice, here is a blessing:
7 ** May you do good and not evil.
8 ** May you find forgiveness for yourself and forgive others.
9 ** May you share freely, never taking more than you give.
11 ******************************************************************************
13 ** Routines to implement arbitrary-precision decimal math.
15 ** The focus here is on simplicity and correctness, not performance.
17 #include "sqlite3ext.h"
18 SQLITE_EXTENSION_INIT1
24 /* Mark a function parameter as unused, to suppress nuisance compiler
26 #ifndef UNUSED_PARAMETER
27 # define UNUSED_PARAMETER(X) (void)(X)
31 /* A decimal object */
32 typedef struct Decimal Decimal
;
34 char sign
; /* 0 for positive, 1 for negative */
35 char oom
; /* True if an OOM is encountered */
36 char isNull
; /* True if holds a NULL rather than a number */
37 char isInit
; /* True upon initialization */
38 int nDigit
; /* Total number of digits */
39 int nFrac
; /* Number of digits to the right of the decimal point */
40 signed char *a
; /* Array of digits. Most significant first. */
44 ** Release memory held by a Decimal, but do not free the object itself.
46 static void decimal_clear(Decimal
*p
){
51 ** Destroy a Decimal object
53 static void decimal_free(Decimal
*p
){
61 ** Allocate a new Decimal object. Initialize it to the number given
62 ** by the input string.
64 static Decimal
*decimal_new(
65 sqlite3_context
*pCtx
,
68 const unsigned char *zAlt
72 const unsigned char *zIn
;
74 p
= sqlite3_malloc( sizeof(*p
) );
75 if( p
==0 ) goto new_no_mem
;
86 if( sqlite3_value_type(pIn
)==SQLITE_NULL
){
91 n
= sqlite3_value_bytes(pIn
);
92 zIn
= sqlite3_value_text(pIn
);
94 p
->a
= sqlite3_malloc64( n
+1 );
95 if( p
->a
==0 ) goto new_no_mem
;
96 for(i
=0; isspace(zIn
[i
]); i
++){}
100 }else if( zIn
[i
]=='+' ){
103 while( i
<n
&& zIn
[i
]=='0' ) i
++;
106 if( c
>='0' && c
<='9' ){
107 p
->a
[p
->nDigit
++] = c
- '0';
109 p
->nFrac
= p
->nDigit
+ 1;
110 }else if( c
=='e' || c
=='E' ){
117 }else if( zIn
[j
]=='+' ){
120 while( j
<n
&& iExp
<1000000 ){
121 if( zIn
[j
]>='0' && zIn
[j
]<='9' ){
122 iExp
= iExp
*10 + zIn
[j
] - '0';
126 if( neg
) iExp
= -iExp
;
132 p
->nFrac
= p
->nDigit
- (p
->nFrac
- 1);
136 if( iExp
<=p
->nFrac
){
145 p
->a
= sqlite3_realloc64(p
->a
, p
->nDigit
+ iExp
+ 1 );
146 if( p
->a
==0 ) goto new_no_mem
;
147 memset(p
->a
+p
->nDigit
, 0, iExp
);
153 nExtra
= p
->nDigit
- p
->nFrac
- 1;
160 p
->nFrac
= p
->nDigit
- 1;
164 p
->a
= sqlite3_realloc64(p
->a
, p
->nDigit
+ iExp
+ 1 );
165 if( p
->a
==0 ) goto new_no_mem
;
166 memmove(p
->a
+iExp
, p
->a
, p
->nDigit
);
167 memset(p
->a
, 0, iExp
);
175 if( pCtx
) sqlite3_result_error_nomem(pCtx
);
181 ** Make the given Decimal the result.
183 static void decimal_result(sqlite3_context
*pCtx
, Decimal
*p
){
187 if( p
==0 || p
->oom
){
188 sqlite3_result_error_nomem(pCtx
);
192 sqlite3_result_null(pCtx
);
195 z
= sqlite3_malloc( p
->nDigit
+4 );
197 sqlite3_result_error_nomem(pCtx
);
201 if( p
->nDigit
==0 || (p
->nDigit
==1 && p
->a
[0]==0) ){
208 n
= p
->nDigit
- p
->nFrac
;
213 while( n
>1 && p
->a
[j
]==0 ){
218 z
[i
++] = p
->a
[j
] + '0';
225 z
[i
++] = p
->a
[j
] + '0';
227 }while( j
<p
->nDigit
);
230 sqlite3_result_text(pCtx
, z
, i
, sqlite3_free
);
234 ** SQL Function: decimal(X)
236 ** Convert input X into decimal and then back into text
238 static void decimalFunc(
239 sqlite3_context
*context
,
243 Decimal
*p
= decimal_new(context
, argv
[0], 0, 0);
244 UNUSED_PARAMETER(argc
);
245 decimal_result(context
, p
);
250 ** Compare to Decimal objects. Return negative, 0, or positive if the
251 ** first object is less than, equal to, or greater than the second.
253 ** Preconditions for this routine:
260 static int decimal_cmp(const Decimal
*pA
, const Decimal
*pB
){
261 int nASig
, nBSig
, rc
, n
;
262 if( pA
->sign
!=pB
->sign
){
263 return pA
->sign
? -1 : +1;
266 const Decimal
*pTemp
= pA
;
270 nASig
= pA
->nDigit
- pA
->nFrac
;
271 nBSig
= pB
->nDigit
- pB
->nFrac
;
273 return nASig
- nBSig
;
276 if( n
>pB
->nDigit
) n
= pB
->nDigit
;
277 rc
= memcmp(pA
->a
, pB
->a
, n
);
279 rc
= pA
->nDigit
- pB
->nDigit
;
285 ** SQL Function: decimal_cmp(X, Y)
287 ** Return negative, zero, or positive if X is less then, equal to, or
290 static void decimalCmpFunc(
291 sqlite3_context
*context
,
295 Decimal
*pA
= 0, *pB
= 0;
298 UNUSED_PARAMETER(argc
);
299 pA
= decimal_new(context
, argv
[0], 0, 0);
300 if( pA
==0 || pA
->isNull
) goto cmp_done
;
301 pB
= decimal_new(context
, argv
[1], 0, 0);
302 if( pB
==0 || pB
->isNull
) goto cmp_done
;
303 rc
= decimal_cmp(pA
, pB
);
305 else if( rc
>0 ) rc
= +1;
306 sqlite3_result_int(context
, rc
);
313 ** Expand the Decimal so that it has a least nDigit digits and nFrac
314 ** digits to the right of the decimal point.
316 static void decimal_expand(Decimal
*p
, int nDigit
, int nFrac
){
320 nAddFrac
= nFrac
- p
->nFrac
;
321 nAddSig
= (nDigit
- p
->nDigit
) - nAddFrac
;
322 if( nAddFrac
==0 && nAddSig
==0 ) return;
323 p
->a
= sqlite3_realloc64(p
->a
, nDigit
+1);
329 memmove(p
->a
+nAddSig
, p
->a
, p
->nDigit
);
330 memset(p
->a
, 0, nAddSig
);
331 p
->nDigit
+= nAddSig
;
334 memset(p
->a
+p
->nDigit
, 0, nAddFrac
);
335 p
->nDigit
+= nAddFrac
;
336 p
->nFrac
+= nAddFrac
;
341 ** Add the value pB into pA.
343 ** Both pA and pB might become denormalized by this routine.
345 static void decimal_add(Decimal
*pA
, Decimal
*pB
){
346 int nSig
, nFrac
, nDigit
;
351 if( pA
->oom
|| pB
==0 || pB
->oom
){
355 if( pA
->isNull
|| pB
->isNull
){
359 nSig
= pA
->nDigit
- pA
->nFrac
;
360 if( nSig
&& pA
->a
[0]==0 ) nSig
--;
361 if( nSig
<pB
->nDigit
-pB
->nFrac
){
362 nSig
= pB
->nDigit
- pB
->nFrac
;
365 if( nFrac
<pB
->nFrac
) nFrac
= pB
->nFrac
;
366 nDigit
= nSig
+ nFrac
+ 1;
367 decimal_expand(pA
, nDigit
, nFrac
);
368 decimal_expand(pB
, nDigit
, nFrac
);
369 if( pA
->oom
|| pB
->oom
){
372 if( pA
->sign
==pB
->sign
){
374 for(i
=nDigit
-1; i
>=0; i
--){
375 int x
= pA
->a
[i
] + pB
->a
[i
] + carry
;
385 signed char *aA
, *aB
;
387 rc
= memcmp(pA
->a
, pB
->a
, nDigit
);
391 pA
->sign
= !pA
->sign
;
396 for(i
=nDigit
-1; i
>=0; i
--){
397 int x
= aA
[i
] - aB
[i
] - borrow
;
411 ** Compare text in decimal order.
413 static int decimalCollFunc(
415 int nKey1
, const void *pKey1
,
416 int nKey2
, const void *pKey2
418 const unsigned char *zA
= (const unsigned char*)pKey1
;
419 const unsigned char *zB
= (const unsigned char*)pKey2
;
420 Decimal
*pA
= decimal_new(0, 0, nKey1
, zA
);
421 Decimal
*pB
= decimal_new(0, 0, nKey2
, zB
);
423 UNUSED_PARAMETER(notUsed
);
424 if( pA
==0 || pB
==0 ){
427 rc
= decimal_cmp(pA
, pB
);
436 ** SQL Function: decimal_add(X, Y)
439 ** Return the sum or difference of X and Y.
441 static void decimalAddFunc(
442 sqlite3_context
*context
,
446 Decimal
*pA
= decimal_new(context
, argv
[0], 0, 0);
447 Decimal
*pB
= decimal_new(context
, argv
[1], 0, 0);
448 UNUSED_PARAMETER(argc
);
450 decimal_result(context
, pA
);
454 static void decimalSubFunc(
455 sqlite3_context
*context
,
459 Decimal
*pA
= decimal_new(context
, argv
[0], 0, 0);
460 Decimal
*pB
= decimal_new(context
, argv
[1], 0, 0);
461 UNUSED_PARAMETER(argc
);
463 pB
->sign
= !pB
->sign
;
465 decimal_result(context
, pA
);
471 /* Aggregate funcion: decimal_sum(X)
473 ** Works like sum() except that it uses decimal arithmetic for unlimited
476 static void decimalSumStep(
477 sqlite3_context
*context
,
483 UNUSED_PARAMETER(argc
);
484 p
= sqlite3_aggregate_context(context
, sizeof(*p
));
488 p
->a
= sqlite3_malloc(2);
497 if( sqlite3_value_type(argv
[0])==SQLITE_NULL
) return;
498 pArg
= decimal_new(context
, argv
[0], 0, 0);
499 decimal_add(p
, pArg
);
502 static void decimalSumInverse(
503 sqlite3_context
*context
,
509 UNUSED_PARAMETER(argc
);
510 p
= sqlite3_aggregate_context(context
, sizeof(*p
));
512 if( sqlite3_value_type(argv
[0])==SQLITE_NULL
) return;
513 pArg
= decimal_new(context
, argv
[0], 0, 0);
514 if( pArg
) pArg
->sign
= !pArg
->sign
;
515 decimal_add(p
, pArg
);
518 static void decimalSumValue(sqlite3_context
*context
){
519 Decimal
*p
= sqlite3_aggregate_context(context
, 0);
521 decimal_result(context
, p
);
523 static void decimalSumFinalize(sqlite3_context
*context
){
524 Decimal
*p
= sqlite3_aggregate_context(context
, 0);
526 decimal_result(context
, p
);
531 ** SQL Function: decimal_mul(X, Y)
533 ** Return the product of X and Y.
535 ** All significant digits after the decimal point are retained.
536 ** Trailing zeros after the decimal point are omitted as long as
537 ** the number of digits after the decimal point is no less than
538 ** either the number of digits in either input.
540 static void decimalMulFunc(
541 sqlite3_context
*context
,
545 Decimal
*pA
= decimal_new(context
, argv
[0], 0, 0);
546 Decimal
*pB
= decimal_new(context
, argv
[1], 0, 0);
547 signed char *acc
= 0;
550 UNUSED_PARAMETER(argc
);
551 if( pA
==0 || pA
->oom
|| pA
->isNull
552 || pB
==0 || pB
->oom
|| pB
->isNull
556 acc
= sqlite3_malloc64( pA
->nDigit
+ pB
->nDigit
+ 2 );
558 sqlite3_result_error_nomem(context
);
561 memset(acc
, 0, pA
->nDigit
+ pB
->nDigit
+ 2);
563 if( pB
->nFrac
<minFrac
) minFrac
= pB
->nFrac
;
564 for(i
=pA
->nDigit
-1; i
>=0; i
--){
565 signed char f
= pA
->a
[i
];
567 for(j
=pB
->nDigit
-1, k
=i
+j
+3; j
>=0; j
--, k
--){
568 x
= acc
[k
] + f
*pB
->a
[j
] + carry
;
579 pA
->nDigit
+= pB
->nDigit
+ 2;
580 pA
->nFrac
+= pB
->nFrac
;
581 pA
->sign
^= pB
->sign
;
582 while( pA
->nFrac
>minFrac
&& pA
->a
[pA
->nDigit
-1]==0 ){
586 decimal_result(context
, pA
);
595 __declspec(dllexport
)
597 int sqlite3_decimal_init(
600 const sqlite3_api_routines
*pApi
603 static const struct {
604 const char *zFuncName
;
606 void (*xFunc
)(sqlite3_context
*,int,sqlite3_value
**);
608 { "decimal", 1, decimalFunc
},
609 { "decimal_cmp", 2, decimalCmpFunc
},
610 { "decimal_add", 2, decimalAddFunc
},
611 { "decimal_sub", 2, decimalSubFunc
},
612 { "decimal_mul", 2, decimalMulFunc
},
615 (void)pzErrMsg
; /* Unused parameter */
617 SQLITE_EXTENSION_INIT2(pApi
);
619 for(i
=0; i
<sizeof(aFunc
)/sizeof(aFunc
[0]) && rc
==SQLITE_OK
; i
++){
620 rc
= sqlite3_create_function(db
, aFunc
[i
].zFuncName
, aFunc
[i
].nArg
,
621 SQLITE_UTF8
|SQLITE_INNOCUOUS
|SQLITE_DETERMINISTIC
,
622 0, aFunc
[i
].xFunc
, 0, 0);
625 rc
= sqlite3_create_window_function(db
, "decimal_sum", 1,
626 SQLITE_UTF8
|SQLITE_INNOCUOUS
|SQLITE_DETERMINISTIC
, 0,
627 decimalSumStep
, decimalSumFinalize
,
628 decimalSumValue
, decimalSumInverse
, 0);
631 rc
= sqlite3_create_collation(db
, "decimal", SQLITE_UTF8
,