4 ** The author disclaims copyright to this source code. In place of
5 ** a legal notice, here is a blessing:
7 ** May you do good and not evil.
8 ** May you find forgiveness for yourself and forgive others.
9 ** May you share freely, never taking more than you give.
11 ******************************************************************************
13 ** Routines to implement arbitrary-precision decimal math.
15 ** The focus here is on simplicity and correctness, not performance.
17 #include "sqlite3ext.h"
18 SQLITE_EXTENSION_INIT1
24 /* Mark a function parameter as unused, to suppress nuisance compiler
26 #ifndef UNUSED_PARAMETER
27 # define UNUSED_PARAMETER(X) (void)(X)
31 /* A decimal object */
32 typedef struct Decimal Decimal
;
34 char sign
; /* 0 for positive, 1 for negative */
35 char oom
; /* True if an OOM is encountered */
36 char isNull
; /* True if holds a NULL rather than a number */
37 char isInit
; /* True upon initialization */
38 int nDigit
; /* Total number of digits */
39 int nFrac
; /* Number of digits to the right of the decimal point */
40 signed char *a
; /* Array of digits. Most significant first. */
44 ** Release memory held by a Decimal, but do not free the object itself.
46 static void decimal_clear(Decimal
*p
){
51 ** Destroy a Decimal object
53 static void decimal_free(Decimal
*p
){
61 ** Allocate a new Decimal object initialized to the text in zIn[].
62 ** Return NULL if any kind of error occurs.
64 static Decimal
*decimalNewFromText(const char *zIn
, int n
){
69 p
= sqlite3_malloc( sizeof(*p
) );
70 if( p
==0 ) goto new_from_text_failed
;
77 p
->a
= sqlite3_malloc64( n
+1 );
78 if( p
->a
==0 ) goto new_from_text_failed
;
79 for(i
=0; isspace(zIn
[i
]); i
++){}
83 }else if( zIn
[i
]=='+' ){
86 while( i
<n
&& zIn
[i
]=='0' ) i
++;
89 if( c
>='0' && c
<='9' ){
90 p
->a
[p
->nDigit
++] = c
- '0';
92 p
->nFrac
= p
->nDigit
+ 1;
93 }else if( c
=='e' || c
=='E' ){
100 }else if( zIn
[j
]=='+' ){
103 while( j
<n
&& iExp
<1000000 ){
104 if( zIn
[j
]>='0' && zIn
[j
]<='9' ){
105 iExp
= iExp
*10 + zIn
[j
] - '0';
109 if( neg
) iExp
= -iExp
;
115 p
->nFrac
= p
->nDigit
- (p
->nFrac
- 1);
119 if( iExp
<=p
->nFrac
){
128 p
->a
= sqlite3_realloc64(p
->a
, p
->nDigit
+ iExp
+ 1 );
129 if( p
->a
==0 ) goto new_from_text_failed
;
130 memset(p
->a
+p
->nDigit
, 0, iExp
);
136 nExtra
= p
->nDigit
- p
->nFrac
- 1;
143 p
->nFrac
= p
->nDigit
- 1;
147 p
->a
= sqlite3_realloc64(p
->a
, p
->nDigit
+ iExp
+ 1 );
148 if( p
->a
==0 ) goto new_from_text_failed
;
149 memmove(p
->a
+iExp
, p
->a
, p
->nDigit
);
150 memset(p
->a
, 0, iExp
);
157 new_from_text_failed
:
159 if( p
->a
) sqlite3_free(p
->a
);
165 /* Forward reference */
166 static Decimal
*decimalFromDouble(double);
169 ** Allocate a new Decimal object from an sqlite3_value. Return a pointer
170 ** to the new object, or NULL if there is an error. If the pCtx argument
171 ** is not NULL, then errors are reported on it as well.
173 ** If the pIn argument is SQLITE_TEXT or SQLITE_INTEGER, it is converted
174 ** directly into a Decimal. For SQLITE_FLOAT or for SQLITE_BLOB of length
175 ** 8 bytes, the resulting double value is expanded into its decimal equivalent.
176 ** If pIn is NULL or if it is a BLOB that is not exactly 8 bytes in length,
177 ** then NULL is returned.
179 static Decimal
*decimal_new(
180 sqlite3_context
*pCtx
, /* Report error here, if not null */
181 sqlite3_value
*pIn
, /* Construct the decimal object from this */
182 int bTextOnly
/* Always interpret pIn as text if true */
185 int eType
= sqlite3_value_type(pIn
);
186 if( bTextOnly
&& (eType
==SQLITE_FLOAT
|| eType
==SQLITE_BLOB
) ){
191 case SQLITE_INTEGER
: {
192 const char *zIn
= (const char*)sqlite3_value_text(pIn
);
193 int n
= sqlite3_value_bytes(pIn
);
194 p
= decimalNewFromText(zIn
, n
);
195 if( p
==0 ) goto new_failed
;
200 p
= decimalFromDouble(sqlite3_value_double(pIn
));
205 const unsigned char *x
;
207 sqlite3_uint64 v
= 0;
210 if( sqlite3_value_bytes(pIn
)!=sizeof(r
) ) break;
211 x
= sqlite3_value_blob(pIn
);
212 for(i
=0; i
<sizeof(r
); i
++){
215 memcpy(&r
, &v
, sizeof(r
));
216 p
= decimalFromDouble(r
);
227 if( pCtx
) sqlite3_result_error_nomem(pCtx
);
233 ** Make the given Decimal the result.
235 static void decimal_result(sqlite3_context
*pCtx
, Decimal
*p
){
239 if( p
==0 || p
->oom
){
240 sqlite3_result_error_nomem(pCtx
);
244 sqlite3_result_null(pCtx
);
247 z
= sqlite3_malloc( p
->nDigit
+4 );
249 sqlite3_result_error_nomem(pCtx
);
253 if( p
->nDigit
==0 || (p
->nDigit
==1 && p
->a
[0]==0) ){
260 n
= p
->nDigit
- p
->nFrac
;
265 while( n
>1 && p
->a
[j
]==0 ){
270 z
[i
++] = p
->a
[j
] + '0';
277 z
[i
++] = p
->a
[j
] + '0';
279 }while( j
<p
->nDigit
);
282 sqlite3_result_text(pCtx
, z
, i
, sqlite3_free
);
286 ** Make the given Decimal the result in an format similar to '%+#e'.
287 ** In other words, show exponential notation with leading and trailing
290 static void decimal_result_sci(sqlite3_context
*pCtx
, Decimal
*p
){
291 char *z
; /* The output buffer */
292 int i
; /* Loop counter */
293 int nZero
; /* Number of leading zeros */
294 int nDigit
; /* Number of digits not counting trailing zeros */
295 int nFrac
; /* Digits to the right of the decimal point */
296 int exp
; /* Exponent value */
297 signed char zero
; /* Zero value */
298 signed char *a
; /* Array of digits */
300 if( p
==0 || p
->oom
){
301 sqlite3_result_error_nomem(pCtx
);
305 sqlite3_result_null(pCtx
);
308 for(nDigit
=p
->nDigit
; nDigit
>0 && p
->a
[nDigit
-1]==0; nDigit
--){}
309 for(nZero
=0; nZero
<nDigit
&& p
->a
[nZero
]==0; nZero
++){}
310 nFrac
= p
->nFrac
+ (nDigit
- p
->nDigit
);
312 z
= sqlite3_malloc( nDigit
+20 );
314 sqlite3_result_error_nomem(pCtx
);
325 if( p
->sign
&& nDigit
>0 ){
336 for(i
=1; i
<nDigit
; i
++){
341 exp
= nDigit
- nFrac
- 1;
342 sqlite3_snprintf(nDigit
+20-i
, &z
[i
], "e%+03d", exp
);
343 sqlite3_result_text(pCtx
, z
, -1, sqlite3_free
);
347 ** Compare to Decimal objects. Return negative, 0, or positive if the
348 ** first object is less than, equal to, or greater than the second.
350 ** Preconditions for this routine:
357 static int decimal_cmp(const Decimal
*pA
, const Decimal
*pB
){
358 int nASig
, nBSig
, rc
, n
;
359 if( pA
->sign
!=pB
->sign
){
360 return pA
->sign
? -1 : +1;
363 const Decimal
*pTemp
= pA
;
367 nASig
= pA
->nDigit
- pA
->nFrac
;
368 nBSig
= pB
->nDigit
- pB
->nFrac
;
370 return nASig
- nBSig
;
373 if( n
>pB
->nDigit
) n
= pB
->nDigit
;
374 rc
= memcmp(pA
->a
, pB
->a
, n
);
376 rc
= pA
->nDigit
- pB
->nDigit
;
382 ** SQL Function: decimal_cmp(X, Y)
384 ** Return negative, zero, or positive if X is less then, equal to, or
387 static void decimalCmpFunc(
388 sqlite3_context
*context
,
392 Decimal
*pA
= 0, *pB
= 0;
395 UNUSED_PARAMETER(argc
);
396 pA
= decimal_new(context
, argv
[0], 1);
397 if( pA
==0 || pA
->isNull
) goto cmp_done
;
398 pB
= decimal_new(context
, argv
[1], 1);
399 if( pB
==0 || pB
->isNull
) goto cmp_done
;
400 rc
= decimal_cmp(pA
, pB
);
402 else if( rc
>0 ) rc
= +1;
403 sqlite3_result_int(context
, rc
);
410 ** Expand the Decimal so that it has a least nDigit digits and nFrac
411 ** digits to the right of the decimal point.
413 static void decimal_expand(Decimal
*p
, int nDigit
, int nFrac
){
417 nAddFrac
= nFrac
- p
->nFrac
;
418 nAddSig
= (nDigit
- p
->nDigit
) - nAddFrac
;
419 if( nAddFrac
==0 && nAddSig
==0 ) return;
420 p
->a
= sqlite3_realloc64(p
->a
, nDigit
+1);
426 memmove(p
->a
+nAddSig
, p
->a
, p
->nDigit
);
427 memset(p
->a
, 0, nAddSig
);
428 p
->nDigit
+= nAddSig
;
431 memset(p
->a
+p
->nDigit
, 0, nAddFrac
);
432 p
->nDigit
+= nAddFrac
;
433 p
->nFrac
+= nAddFrac
;
438 ** Add the value pB into pA. A := A + B.
440 ** Both pA and pB might become denormalized by this routine.
442 static void decimal_add(Decimal
*pA
, Decimal
*pB
){
443 int nSig
, nFrac
, nDigit
;
448 if( pA
->oom
|| pB
==0 || pB
->oom
){
452 if( pA
->isNull
|| pB
->isNull
){
456 nSig
= pA
->nDigit
- pA
->nFrac
;
457 if( nSig
&& pA
->a
[0]==0 ) nSig
--;
458 if( nSig
<pB
->nDigit
-pB
->nFrac
){
459 nSig
= pB
->nDigit
- pB
->nFrac
;
462 if( nFrac
<pB
->nFrac
) nFrac
= pB
->nFrac
;
463 nDigit
= nSig
+ nFrac
+ 1;
464 decimal_expand(pA
, nDigit
, nFrac
);
465 decimal_expand(pB
, nDigit
, nFrac
);
466 if( pA
->oom
|| pB
->oom
){
469 if( pA
->sign
==pB
->sign
){
471 for(i
=nDigit
-1; i
>=0; i
--){
472 int x
= pA
->a
[i
] + pB
->a
[i
] + carry
;
482 signed char *aA
, *aB
;
484 rc
= memcmp(pA
->a
, pB
->a
, nDigit
);
488 pA
->sign
= !pA
->sign
;
493 for(i
=nDigit
-1; i
>=0; i
--){
494 int x
= aA
[i
] - aB
[i
] - borrow
;
508 ** Multiply A by B. A := A * B
510 ** All significant digits after the decimal point are retained.
511 ** Trailing zeros after the decimal point are omitted as long as
512 ** the number of digits after the decimal point is no less than
513 ** either the number of digits in either input.
515 static void decimalMul(Decimal
*pA
, Decimal
*pB
){
516 signed char *acc
= 0;
520 if( pA
==0 || pA
->oom
|| pA
->isNull
521 || pB
==0 || pB
->oom
|| pB
->isNull
525 acc
= sqlite3_malloc64( pA
->nDigit
+ pB
->nDigit
+ 2 );
530 memset(acc
, 0, pA
->nDigit
+ pB
->nDigit
+ 2);
532 if( pB
->nFrac
<minFrac
) minFrac
= pB
->nFrac
;
533 for(i
=pA
->nDigit
-1; i
>=0; i
--){
534 signed char f
= pA
->a
[i
];
536 for(j
=pB
->nDigit
-1, k
=i
+j
+3; j
>=0; j
--, k
--){
537 x
= acc
[k
] + f
*pB
->a
[j
] + carry
;
548 pA
->nDigit
+= pB
->nDigit
+ 2;
549 pA
->nFrac
+= pB
->nFrac
;
550 pA
->sign
^= pB
->sign
;
551 while( pA
->nFrac
>minFrac
&& pA
->a
[pA
->nDigit
-1]==0 ){
561 ** Create a new Decimal object that contains an integer power of 2.
563 static Decimal
*decimalPow2(int N
){
564 Decimal
*pA
= 0; /* The result to be returned */
565 Decimal
*pX
= 0; /* Multiplier */
566 if( N
<-20000 || N
>20000 ) goto pow2_fault
;
567 pA
= decimalNewFromText("1.0", 3);
568 if( pA
==0 || pA
->oom
) goto pow2_fault
;
569 if( N
==0 ) return pA
;
571 pX
= decimalNewFromText("2.0", 3);
574 pX
= decimalNewFromText("0.5", 3);
576 if( pX
==0 || pX
->oom
) goto pow2_fault
;
577 while( 1 /* Exit by break */ ){
580 if( pA
->oom
) goto pow2_fault
;
596 ** Use an IEEE754 binary64 ("double") to generate a new Decimal object.
598 static Decimal
*decimalFromDouble(double r
){
611 memcpy(&a
,&r
,sizeof(a
));
617 m
= a
& ((((sqlite3_int64
)1)<<52)-1);
621 m
|= ((sqlite3_int64
)1)<<52;
623 while( e
<1075 && m
>0 && (m
&1)==0 ){
630 return 0; /* A NaN or an Infinity */
634 /* At this point m is the integer significand and e is the exponent */
635 sqlite3_snprintf(sizeof(zNum
), zNum
, "%lld", m
);
636 pA
= decimalNewFromText(zNum
, (int)strlen(zNum
));
644 ** SQL Function: decimal(X)
645 ** OR: decimal_exp(X)
647 ** Convert input X into decimal and then back into text.
649 ** If X is originally a float, then a full decimal expansion of that floating
650 ** point value is done. Or if X is an 8-byte blob, it is interpreted
651 ** as a float and similarly expanded.
653 ** The decimal_exp(X) function returns the result in exponential notation.
654 ** decimal(X) returns a complete decimal, without the e+NNN at the end.
656 static void decimalFunc(
657 sqlite3_context
*context
,
661 Decimal
*p
= decimal_new(context
, argv
[0], 0);
662 UNUSED_PARAMETER(argc
);
664 if( sqlite3_user_data(context
)!=0 ){
665 decimal_result_sci(context
, p
);
667 decimal_result(context
, p
);
674 ** Compare text in decimal order.
676 static int decimalCollFunc(
678 int nKey1
, const void *pKey1
,
679 int nKey2
, const void *pKey2
681 const unsigned char *zA
= (const unsigned char*)pKey1
;
682 const unsigned char *zB
= (const unsigned char*)pKey2
;
683 Decimal
*pA
= decimalNewFromText((const char*)zA
, nKey1
);
684 Decimal
*pB
= decimalNewFromText((const char*)zB
, nKey2
);
686 UNUSED_PARAMETER(notUsed
);
687 if( pA
==0 || pB
==0 ){
690 rc
= decimal_cmp(pA
, pB
);
699 ** SQL Function: decimal_add(X, Y)
702 ** Return the sum or difference of X and Y.
704 static void decimalAddFunc(
705 sqlite3_context
*context
,
709 Decimal
*pA
= decimal_new(context
, argv
[0], 1);
710 Decimal
*pB
= decimal_new(context
, argv
[1], 1);
711 UNUSED_PARAMETER(argc
);
713 decimal_result(context
, pA
);
717 static void decimalSubFunc(
718 sqlite3_context
*context
,
722 Decimal
*pA
= decimal_new(context
, argv
[0], 1);
723 Decimal
*pB
= decimal_new(context
, argv
[1], 1);
724 UNUSED_PARAMETER(argc
);
726 pB
->sign
= !pB
->sign
;
728 decimal_result(context
, pA
);
734 /* Aggregate funcion: decimal_sum(X)
736 ** Works like sum() except that it uses decimal arithmetic for unlimited
739 static void decimalSumStep(
740 sqlite3_context
*context
,
746 UNUSED_PARAMETER(argc
);
747 p
= sqlite3_aggregate_context(context
, sizeof(*p
));
751 p
->a
= sqlite3_malloc(2);
760 if( sqlite3_value_type(argv
[0])==SQLITE_NULL
) return;
761 pArg
= decimal_new(context
, argv
[0], 1);
762 decimal_add(p
, pArg
);
765 static void decimalSumInverse(
766 sqlite3_context
*context
,
772 UNUSED_PARAMETER(argc
);
773 p
= sqlite3_aggregate_context(context
, sizeof(*p
));
775 if( sqlite3_value_type(argv
[0])==SQLITE_NULL
) return;
776 pArg
= decimal_new(context
, argv
[0], 1);
777 if( pArg
) pArg
->sign
= !pArg
->sign
;
778 decimal_add(p
, pArg
);
781 static void decimalSumValue(sqlite3_context
*context
){
782 Decimal
*p
= sqlite3_aggregate_context(context
, 0);
784 decimal_result(context
, p
);
786 static void decimalSumFinalize(sqlite3_context
*context
){
787 Decimal
*p
= sqlite3_aggregate_context(context
, 0);
789 decimal_result(context
, p
);
794 ** SQL Function: decimal_mul(X, Y)
796 ** Return the product of X and Y.
798 static void decimalMulFunc(
799 sqlite3_context
*context
,
803 Decimal
*pA
= decimal_new(context
, argv
[0], 1);
804 Decimal
*pB
= decimal_new(context
, argv
[1], 1);
805 UNUSED_PARAMETER(argc
);
806 if( pA
==0 || pA
->oom
|| pA
->isNull
807 || pB
==0 || pB
->oom
|| pB
->isNull
815 decimal_result(context
, pA
);
823 ** SQL Function: decimal_pow2(N)
825 ** Return the N-th power of 2. N must be an integer.
827 static void decimalPow2Func(
828 sqlite3_context
*context
,
832 UNUSED_PARAMETER(argc
);
833 if( sqlite3_value_type(argv
[0])==SQLITE_INTEGER
){
834 Decimal
*pA
= decimalPow2(sqlite3_value_int(argv
[0]));
835 decimal_result_sci(context
, pA
);
841 __declspec(dllexport
)
843 int sqlite3_decimal_init(
846 const sqlite3_api_routines
*pApi
849 static const struct {
850 const char *zFuncName
;
853 void (*xFunc
)(sqlite3_context
*,int,sqlite3_value
**);
855 { "decimal", 1, 0, decimalFunc
},
856 { "decimal_exp", 1, 1, decimalFunc
},
857 { "decimal_cmp", 2, 0, decimalCmpFunc
},
858 { "decimal_add", 2, 0, decimalAddFunc
},
859 { "decimal_sub", 2, 0, decimalSubFunc
},
860 { "decimal_mul", 2, 0, decimalMulFunc
},
861 { "decimal_pow2", 1, 0, decimalPow2Func
},
864 (void)pzErrMsg
; /* Unused parameter */
866 SQLITE_EXTENSION_INIT2(pApi
);
868 for(i
=0; i
<(int)(sizeof(aFunc
)/sizeof(aFunc
[0])) && rc
==SQLITE_OK
; i
++){
869 rc
= sqlite3_create_function(db
, aFunc
[i
].zFuncName
, aFunc
[i
].nArg
,
870 SQLITE_UTF8
|SQLITE_INNOCUOUS
|SQLITE_DETERMINISTIC
,
871 aFunc
[i
].iArg
? db
: 0, aFunc
[i
].xFunc
, 0, 0);
874 rc
= sqlite3_create_window_function(db
, "decimal_sum", 1,
875 SQLITE_UTF8
|SQLITE_INNOCUOUS
|SQLITE_DETERMINISTIC
, 0,
876 decimalSumStep
, decimalSumFinalize
,
877 decimalSumValue
, decimalSumInverse
, 0);
880 rc
= sqlite3_create_collation(db
, "decimal", SQLITE_UTF8
,