summaryrefslogtreecommitdiffstats
path: root/ext/misc/decimal.c
diff options
context:
space:
mode:
Diffstat (limited to 'ext/misc/decimal.c')
-rw-r--r--ext/misc/decimal.c635
1 files changed, 635 insertions, 0 deletions
diff --git a/ext/misc/decimal.c b/ext/misc/decimal.c
new file mode 100644
index 0000000..37c6c2f
--- /dev/null
+++ b/ext/misc/decimal.c
@@ -0,0 +1,635 @@
+/*
+** 2020-06-22
+**
+** The author disclaims copyright to this source code. In place of
+** a legal notice, here is a blessing:
+**
+** May you do good and not evil.
+** May you find forgiveness for yourself and forgive others.
+** May you share freely, never taking more than you give.
+**
+******************************************************************************
+**
+** Routines to implement arbitrary-precision decimal math.
+**
+** The focus here is on simplicity and correctness, not performance.
+*/
+#include "sqlite3ext.h"
+SQLITE_EXTENSION_INIT1
+#include <assert.h>
+#include <string.h>
+#include <ctype.h>
+#include <stdlib.h>
+
+/* Mark a function parameter as unused, to suppress nuisance compiler
+** warnings. */
+#ifndef UNUSED_PARAMETER
+# define UNUSED_PARAMETER(X) (void)(X)
+#endif
+
+
+/* A decimal object */
+typedef struct Decimal Decimal;
+struct Decimal {
+ char sign; /* 0 for positive, 1 for negative */
+ char oom; /* True if an OOM is encountered */
+ char isNull; /* True if holds a NULL rather than a number */
+ char isInit; /* True upon initialization */
+ int nDigit; /* Total number of digits */
+ int nFrac; /* Number of digits to the right of the decimal point */
+ signed char *a; /* Array of digits. Most significant first. */
+};
+
+/*
+** Release memory held by a Decimal, but do not free the object itself.
+*/
+static void decimal_clear(Decimal *p){
+ sqlite3_free(p->a);
+}
+
+/*
+** Destroy a Decimal object
+*/
+static void decimal_free(Decimal *p){
+ if( p ){
+ decimal_clear(p);
+ sqlite3_free(p);
+ }
+}
+
+/*
+** Allocate a new Decimal object. Initialize it to the number given
+** by the input string.
+*/
+static Decimal *decimal_new(
+ sqlite3_context *pCtx,
+ sqlite3_value *pIn,
+ int nAlt,
+ const unsigned char *zAlt
+){
+ Decimal *p;
+ int n, i;
+ const unsigned char *zIn;
+ int iExp = 0;
+ p = sqlite3_malloc( sizeof(*p) );
+ if( p==0 ) goto new_no_mem;
+ p->sign = 0;
+ p->oom = 0;
+ p->isInit = 1;
+ p->isNull = 0;
+ p->nDigit = 0;
+ p->nFrac = 0;
+ if( zAlt ){
+ n = nAlt,
+ zIn = zAlt;
+ }else{
+ if( sqlite3_value_type(pIn)==SQLITE_NULL ){
+ p->a = 0;
+ p->isNull = 1;
+ return p;
+ }
+ n = sqlite3_value_bytes(pIn);
+ zIn = sqlite3_value_text(pIn);
+ }
+ p->a = sqlite3_malloc64( n+1 );
+ if( p->a==0 ) goto new_no_mem;
+ for(i=0; isspace(zIn[i]); i++){}
+ if( zIn[i]=='-' ){
+ p->sign = 1;
+ i++;
+ }else if( zIn[i]=='+' ){
+ i++;
+ }
+ while( i<n && zIn[i]=='0' ) i++;
+ while( i<n ){
+ char c = zIn[i];
+ if( c>='0' && c<='9' ){
+ p->a[p->nDigit++] = c - '0';
+ }else if( c=='.' ){
+ p->nFrac = p->nDigit + 1;
+ }else if( c=='e' || c=='E' ){
+ int j = i+1;
+ int neg = 0;
+ if( j>=n ) break;
+ if( zIn[j]=='-' ){
+ neg = 1;
+ j++;
+ }else if( zIn[j]=='+' ){
+ j++;
+ }
+ while( j<n && iExp<1000000 ){
+ if( zIn[j]>='0' && zIn[j]<='9' ){
+ iExp = iExp*10 + zIn[j] - '0';
+ }
+ j++;
+ }
+ if( neg ) iExp = -iExp;
+ break;
+ }
+ i++;
+ }
+ if( p->nFrac ){
+ p->nFrac = p->nDigit - (p->nFrac - 1);
+ }
+ if( iExp>0 ){
+ if( p->nFrac>0 ){
+ if( iExp<=p->nFrac ){
+ p->nFrac -= iExp;
+ iExp = 0;
+ }else{
+ iExp -= p->nFrac;
+ p->nFrac = 0;
+ }
+ }
+ if( iExp>0 ){
+ p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
+ if( p->a==0 ) goto new_no_mem;
+ memset(p->a+p->nDigit, 0, iExp);
+ p->nDigit += iExp;
+ }
+ }else if( iExp<0 ){
+ int nExtra;
+ iExp = -iExp;
+ nExtra = p->nDigit - p->nFrac - 1;
+ if( nExtra ){
+ if( nExtra>=iExp ){
+ p->nFrac += iExp;
+ iExp = 0;
+ }else{
+ iExp -= nExtra;
+ p->nFrac = p->nDigit - 1;
+ }
+ }
+ if( iExp>0 ){
+ p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
+ if( p->a==0 ) goto new_no_mem;
+ memmove(p->a+iExp, p->a, p->nDigit);
+ memset(p->a, 0, iExp);
+ p->nDigit += iExp;
+ p->nFrac += iExp;
+ }
+ }
+ return p;
+
+new_no_mem:
+ if( pCtx ) sqlite3_result_error_nomem(pCtx);
+ sqlite3_free(p);
+ return 0;
+}
+
+/*
+** Make the given Decimal the result.
+*/
+static void decimal_result(sqlite3_context *pCtx, Decimal *p){
+ char *z;
+ int i, j;
+ int n;
+ if( p==0 || p->oom ){
+ sqlite3_result_error_nomem(pCtx);
+ return;
+ }
+ if( p->isNull ){
+ sqlite3_result_null(pCtx);
+ return;
+ }
+ z = sqlite3_malloc( p->nDigit+4 );
+ if( z==0 ){
+ sqlite3_result_error_nomem(pCtx);
+ return;
+ }
+ i = 0;
+ if( p->nDigit==0 || (p->nDigit==1 && p->a[0]==0) ){
+ p->sign = 0;
+ }
+ if( p->sign ){
+ z[0] = '-';
+ i = 1;
+ }
+ n = p->nDigit - p->nFrac;
+ if( n<=0 ){
+ z[i++] = '0';
+ }
+ j = 0;
+ while( n>1 && p->a[j]==0 ){
+ j++;
+ n--;
+ }
+ while( n>0 ){
+ z[i++] = p->a[j] + '0';
+ j++;
+ n--;
+ }
+ if( p->nFrac ){
+ z[i++] = '.';
+ do{
+ z[i++] = p->a[j] + '0';
+ j++;
+ }while( j<p->nDigit );
+ }
+ z[i] = 0;
+ sqlite3_result_text(pCtx, z, i, sqlite3_free);
+}
+
+/*
+** SQL Function: decimal(X)
+**
+** Convert input X into decimal and then back into text
+*/
+static void decimalFunc(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *p = decimal_new(context, argv[0], 0, 0);
+ UNUSED_PARAMETER(argc);
+ decimal_result(context, p);
+ decimal_free(p);
+}
+
+/*
+** Compare to Decimal objects. Return negative, 0, or positive if the
+** first object is less than, equal to, or greater than the second.
+**
+** Preconditions for this routine:
+**
+** pA!=0
+** pA->isNull==0
+** pB!=0
+** pB->isNull==0
+*/
+static int decimal_cmp(const Decimal *pA, const Decimal *pB){
+ int nASig, nBSig, rc, n;
+ if( pA->sign!=pB->sign ){
+ return pA->sign ? -1 : +1;
+ }
+ if( pA->sign ){
+ const Decimal *pTemp = pA;
+ pA = pB;
+ pB = pTemp;
+ }
+ nASig = pA->nDigit - pA->nFrac;
+ nBSig = pB->nDigit - pB->nFrac;
+ if( nASig!=nBSig ){
+ return nASig - nBSig;
+ }
+ n = pA->nDigit;
+ if( n>pB->nDigit ) n = pB->nDigit;
+ rc = memcmp(pA->a, pB->a, n);
+ if( rc==0 ){
+ rc = pA->nDigit - pB->nDigit;
+ }
+ return rc;
+}
+
+/*
+** SQL Function: decimal_cmp(X, Y)
+**
+** Return negative, zero, or positive if X is less then, equal to, or
+** greater than Y.
+*/
+static void decimalCmpFunc(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *pA = 0, *pB = 0;
+ int rc;
+
+ UNUSED_PARAMETER(argc);
+ pA = decimal_new(context, argv[0], 0, 0);
+ if( pA==0 || pA->isNull ) goto cmp_done;
+ pB = decimal_new(context, argv[1], 0, 0);
+ if( pB==0 || pB->isNull ) goto cmp_done;
+ rc = decimal_cmp(pA, pB);
+ if( rc<0 ) rc = -1;
+ else if( rc>0 ) rc = +1;
+ sqlite3_result_int(context, rc);
+cmp_done:
+ decimal_free(pA);
+ decimal_free(pB);
+}
+
+/*
+** Expand the Decimal so that it has a least nDigit digits and nFrac
+** digits to the right of the decimal point.
+*/
+static void decimal_expand(Decimal *p, int nDigit, int nFrac){
+ int nAddSig;
+ int nAddFrac;
+ if( p==0 ) return;
+ nAddFrac = nFrac - p->nFrac;
+ nAddSig = (nDigit - p->nDigit) - nAddFrac;
+ if( nAddFrac==0 && nAddSig==0 ) return;
+ p->a = sqlite3_realloc64(p->a, nDigit+1);
+ if( p->a==0 ){
+ p->oom = 1;
+ return;
+ }
+ if( nAddSig ){
+ memmove(p->a+nAddSig, p->a, p->nDigit);
+ memset(p->a, 0, nAddSig);
+ p->nDigit += nAddSig;
+ }
+ if( nAddFrac ){
+ memset(p->a+p->nDigit, 0, nAddFrac);
+ p->nDigit += nAddFrac;
+ p->nFrac += nAddFrac;
+ }
+}
+
+/*
+** Add the value pB into pA.
+**
+** Both pA and pB might become denormalized by this routine.
+*/
+static void decimal_add(Decimal *pA, Decimal *pB){
+ int nSig, nFrac, nDigit;
+ int i, rc;
+ if( pA==0 ){
+ return;
+ }
+ if( pA->oom || pB==0 || pB->oom ){
+ pA->oom = 1;
+ return;
+ }
+ if( pA->isNull || pB->isNull ){
+ pA->isNull = 1;
+ return;
+ }
+ nSig = pA->nDigit - pA->nFrac;
+ if( nSig && pA->a[0]==0 ) nSig--;
+ if( nSig<pB->nDigit-pB->nFrac ){
+ nSig = pB->nDigit - pB->nFrac;
+ }
+ nFrac = pA->nFrac;
+ if( nFrac<pB->nFrac ) nFrac = pB->nFrac;
+ nDigit = nSig + nFrac + 1;
+ decimal_expand(pA, nDigit, nFrac);
+ decimal_expand(pB, nDigit, nFrac);
+ if( pA->oom || pB->oom ){
+ pA->oom = 1;
+ }else{
+ if( pA->sign==pB->sign ){
+ int carry = 0;
+ for(i=nDigit-1; i>=0; i--){
+ int x = pA->a[i] + pB->a[i] + carry;
+ if( x>=10 ){
+ carry = 1;
+ pA->a[i] = x - 10;
+ }else{
+ carry = 0;
+ pA->a[i] = x;
+ }
+ }
+ }else{
+ signed char *aA, *aB;
+ int borrow = 0;
+ rc = memcmp(pA->a, pB->a, nDigit);
+ if( rc<0 ){
+ aA = pB->a;
+ aB = pA->a;
+ pA->sign = !pA->sign;
+ }else{
+ aA = pA->a;
+ aB = pB->a;
+ }
+ for(i=nDigit-1; i>=0; i--){
+ int x = aA[i] - aB[i] - borrow;
+ if( x<0 ){
+ pA->a[i] = x+10;
+ borrow = 1;
+ }else{
+ pA->a[i] = x;
+ borrow = 0;
+ }
+ }
+ }
+ }
+}
+
+/*
+** Compare text in decimal order.
+*/
+static int decimalCollFunc(
+ void *notUsed,
+ int nKey1, const void *pKey1,
+ int nKey2, const void *pKey2
+){
+ const unsigned char *zA = (const unsigned char*)pKey1;
+ const unsigned char *zB = (const unsigned char*)pKey2;
+ Decimal *pA = decimal_new(0, 0, nKey1, zA);
+ Decimal *pB = decimal_new(0, 0, nKey2, zB);
+ int rc;
+ UNUSED_PARAMETER(notUsed);
+ if( pA==0 || pB==0 ){
+ rc = 0;
+ }else{
+ rc = decimal_cmp(pA, pB);
+ }
+ decimal_free(pA);
+ decimal_free(pB);
+ return rc;
+}
+
+
+/*
+** SQL Function: decimal_add(X, Y)
+** decimal_sub(X, Y)
+**
+** Return the sum or difference of X and Y.
+*/
+static void decimalAddFunc(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *pA = decimal_new(context, argv[0], 0, 0);
+ Decimal *pB = decimal_new(context, argv[1], 0, 0);
+ UNUSED_PARAMETER(argc);
+ decimal_add(pA, pB);
+ decimal_result(context, pA);
+ decimal_free(pA);
+ decimal_free(pB);
+}
+static void decimalSubFunc(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *pA = decimal_new(context, argv[0], 0, 0);
+ Decimal *pB = decimal_new(context, argv[1], 0, 0);
+ UNUSED_PARAMETER(argc);
+ if( pB ){
+ pB->sign = !pB->sign;
+ decimal_add(pA, pB);
+ decimal_result(context, pA);
+ }
+ decimal_free(pA);
+ decimal_free(pB);
+}
+
+/* Aggregate funcion: decimal_sum(X)
+**
+** Works like sum() except that it uses decimal arithmetic for unlimited
+** precision.
+*/
+static void decimalSumStep(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *p;
+ Decimal *pArg;
+ UNUSED_PARAMETER(argc);
+ p = sqlite3_aggregate_context(context, sizeof(*p));
+ if( p==0 ) return;
+ if( !p->isInit ){
+ p->isInit = 1;
+ p->a = sqlite3_malloc(2);
+ if( p->a==0 ){
+ p->oom = 1;
+ }else{
+ p->a[0] = 0;
+ }
+ p->nDigit = 1;
+ p->nFrac = 0;
+ }
+ if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
+ pArg = decimal_new(context, argv[0], 0, 0);
+ decimal_add(p, pArg);
+ decimal_free(pArg);
+}
+static void decimalSumInverse(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *p;
+ Decimal *pArg;
+ UNUSED_PARAMETER(argc);
+ p = sqlite3_aggregate_context(context, sizeof(*p));
+ if( p==0 ) return;
+ if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
+ pArg = decimal_new(context, argv[0], 0, 0);
+ if( pArg ) pArg->sign = !pArg->sign;
+ decimal_add(p, pArg);
+ decimal_free(pArg);
+}
+static void decimalSumValue(sqlite3_context *context){
+ Decimal *p = sqlite3_aggregate_context(context, 0);
+ if( p==0 ) return;
+ decimal_result(context, p);
+}
+static void decimalSumFinalize(sqlite3_context *context){
+ Decimal *p = sqlite3_aggregate_context(context, 0);
+ if( p==0 ) return;
+ decimal_result(context, p);
+ decimal_clear(p);
+}
+
+/*
+** SQL Function: decimal_mul(X, Y)
+**
+** Return the product of X and Y.
+**
+** All significant digits after the decimal point are retained.
+** Trailing zeros after the decimal point are omitted as long as
+** the number of digits after the decimal point is no less than
+** either the number of digits in either input.
+*/
+static void decimalMulFunc(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *pA = decimal_new(context, argv[0], 0, 0);
+ Decimal *pB = decimal_new(context, argv[1], 0, 0);
+ signed char *acc = 0;
+ int i, j, k;
+ int minFrac;
+ UNUSED_PARAMETER(argc);
+ if( pA==0 || pA->oom || pA->isNull
+ || pB==0 || pB->oom || pB->isNull
+ ){
+ goto mul_end;
+ }
+ acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 );
+ if( acc==0 ){
+ sqlite3_result_error_nomem(context);
+ goto mul_end;
+ }
+ memset(acc, 0, pA->nDigit + pB->nDigit + 2);
+ minFrac = pA->nFrac;
+ if( pB->nFrac<minFrac ) minFrac = pB->nFrac;
+ for(i=pA->nDigit-1; i>=0; i--){
+ signed char f = pA->a[i];
+ int carry = 0, x;
+ for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){
+ x = acc[k] + f*pB->a[j] + carry;
+ acc[k] = x%10;
+ carry = x/10;
+ }
+ x = acc[k] + carry;
+ acc[k] = x%10;
+ acc[k-1] += x/10;
+ }
+ sqlite3_free(pA->a);
+ pA->a = acc;
+ acc = 0;
+ pA->nDigit += pB->nDigit + 2;
+ pA->nFrac += pB->nFrac;
+ pA->sign ^= pB->sign;
+ while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){
+ pA->nFrac--;
+ pA->nDigit--;
+ }
+ decimal_result(context, pA);
+
+mul_end:
+ sqlite3_free(acc);
+ decimal_free(pA);
+ decimal_free(pB);
+}
+
+#ifdef _WIN32
+__declspec(dllexport)
+#endif
+int sqlite3_decimal_init(
+ sqlite3 *db,
+ char **pzErrMsg,
+ const sqlite3_api_routines *pApi
+){
+ int rc = SQLITE_OK;
+ static const struct {
+ const char *zFuncName;
+ int nArg;
+ void (*xFunc)(sqlite3_context*,int,sqlite3_value**);
+ } aFunc[] = {
+ { "decimal", 1, decimalFunc },
+ { "decimal_cmp", 2, decimalCmpFunc },
+ { "decimal_add", 2, decimalAddFunc },
+ { "decimal_sub", 2, decimalSubFunc },
+ { "decimal_mul", 2, decimalMulFunc },
+ };
+ unsigned int i;
+ (void)pzErrMsg; /* Unused parameter */
+
+ SQLITE_EXTENSION_INIT2(pApi);
+
+ for(i=0; i<sizeof(aFunc)/sizeof(aFunc[0]) && rc==SQLITE_OK; i++){
+ rc = sqlite3_create_function(db, aFunc[i].zFuncName, aFunc[i].nArg,
+ SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC,
+ 0, aFunc[i].xFunc, 0, 0);
+ }
+ if( rc==SQLITE_OK ){
+ rc = sqlite3_create_window_function(db, "decimal_sum", 1,
+ SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 0,
+ decimalSumStep, decimalSumFinalize,
+ decimalSumValue, decimalSumInverse, 0);
+ }
+ if( rc==SQLITE_OK ){
+ rc = sqlite3_create_collation(db, "decimal", SQLITE_UTF8,
+ 0, decimalCollFunc);
+ }
+ return rc;
+}