diff options
Diffstat (limited to '')
60 files changed, 23389 insertions, 0 deletions
diff --git a/src/math/big/accuracy_string.go b/src/math/big/accuracy_string.go new file mode 100644 index 0000000..1501ace --- /dev/null +++ b/src/math/big/accuracy_string.go @@ -0,0 +1,17 @@ +// Code generated by "stringer -type=Accuracy"; DO NOT EDIT. + +package big + +import "strconv" + +const _Accuracy_name = "BelowExactAbove" + +var _Accuracy_index = [...]uint8{0, 5, 10, 15} + +func (i Accuracy) String() string { + i -= -1 + if i < 0 || i >= Accuracy(len(_Accuracy_index)-1) { + return "Accuracy(" + strconv.FormatInt(int64(i+-1), 10) + ")" + } + return _Accuracy_name[_Accuracy_index[i]:_Accuracy_index[i+1]] +} diff --git a/src/math/big/alias_test.go b/src/math/big/alias_test.go new file mode 100644 index 0000000..36c37fb --- /dev/null +++ b/src/math/big/alias_test.go @@ -0,0 +1,312 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big_test + +import ( + cryptorand "crypto/rand" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +func equal(z, x *big.Int) bool { + return z.Cmp(x) == 0 +} + +type bigInt struct { + *big.Int +} + +func generatePositiveInt(rand *rand.Rand, size int) *big.Int { + n := big.NewInt(1) + n.Lsh(n, uint(rand.Intn(size*8))) + n.Rand(rand, n) + return n +} + +func (bigInt) Generate(rand *rand.Rand, size int) reflect.Value { + n := generatePositiveInt(rand, size) + if rand.Intn(4) == 0 { + n.Neg(n) + } + return reflect.ValueOf(bigInt{n}) +} + +type notZeroInt struct { + *big.Int +} + +func (notZeroInt) Generate(rand *rand.Rand, size int) reflect.Value { + n := generatePositiveInt(rand, size) + if rand.Intn(4) == 0 { + n.Neg(n) + } + if n.Sign() == 0 { + n.SetInt64(1) + } + return reflect.ValueOf(notZeroInt{n}) +} + +type positiveInt struct { + *big.Int +} + +func (positiveInt) Generate(rand *rand.Rand, size int) reflect.Value { + n := generatePositiveInt(rand, size) + return reflect.ValueOf(positiveInt{n}) +} + +type prime struct { + *big.Int +} + +func (prime) Generate(r *rand.Rand, size int) reflect.Value { + n, err := cryptorand.Prime(r, r.Intn(size*8-2)+2) + if err != nil { + panic(err) + } + return reflect.ValueOf(prime{n}) +} + +type zeroOrOne struct { + uint +} + +func (zeroOrOne) Generate(rand *rand.Rand, size int) reflect.Value { + return reflect.ValueOf(zeroOrOne{uint(rand.Intn(2))}) +} + +type smallUint struct { + uint +} + +func (smallUint) Generate(rand *rand.Rand, size int) reflect.Value { + return reflect.ValueOf(smallUint{uint(rand.Intn(1024))}) +} + +// checkAliasingOneArg checks if f returns a correct result when v and x alias. +// +// f is a function that takes x as an argument, doesn't modify it, sets v to the +// result, and returns v. It is the function signature of unbound methods like +// +// func (v *big.Int) m(x *big.Int) *big.Int +// +// v and x are two random Int values. v is randomized even if it will be +// overwritten to test for improper buffer reuse. +func checkAliasingOneArg(t *testing.T, f func(v, x *big.Int) *big.Int, v, x *big.Int) bool { + x1, v1 := new(big.Int).Set(x), new(big.Int).Set(x) + + // Calculate a reference f(x) without aliasing. + if out := f(v, x); out != v { + return false + } + + // Test aliasing the argument and the receiver. + if out := f(v1, v1); out != v1 || !equal(v1, v) { + t.Logf("f(v, x) != f(x, x)") + return false + } + + // Ensure the arguments was not modified. + return equal(x, x1) +} + +// checkAliasingTwoArgs checks if f returns a correct result when any +// combination of v, x and y alias. +// +// f is a function that takes x and y as arguments, doesn't modify them, sets v +// to the result, and returns v. It is the function signature of unbound methods +// like +// +// func (v *big.Int) m(x, y *big.Int) *big.Int +// +// v, x and y are random Int values. v is randomized even if it will be +// overwritten to test for improper buffer reuse. +func checkAliasingTwoArgs(t *testing.T, f func(v, x, y *big.Int) *big.Int, v, x, y *big.Int) bool { + x1, y1, v1 := new(big.Int).Set(x), new(big.Int).Set(y), new(big.Int).Set(v) + + // Calculate a reference f(x, y) without aliasing. + if out := f(v, x, y); out == nil { + // Certain functions like ModInverse return nil for certain inputs. + // Check that receiver and arguments were unchanged and move on. + return equal(x, x1) && equal(y, y1) && equal(v, v1) + } else if out != v { + return false + } + + // Test aliasing the first argument and the receiver. + v1.Set(x) + if out := f(v1, v1, y); out != v1 || !equal(v1, v) { + t.Logf("f(v, x, y) != f(x, x, y)") + return false + } + // Test aliasing the second argument and the receiver. + v1.Set(y) + if out := f(v1, x, v1); out != v1 || !equal(v1, v) { + t.Logf("f(v, x, y) != f(y, x, y)") + return false + } + + // Calculate a reference f(y, y) without aliasing. + // We use y because it's the one that commonly has restrictions + // like being prime or non-zero. + v1.Set(v) + y2 := new(big.Int).Set(y) + if out := f(v, y, y2); out == nil { + return equal(y, y1) && equal(y2, y1) && equal(v, v1) + } else if out != v { + return false + } + + // Test aliasing the two arguments. + if out := f(v1, y, y); out != v1 || !equal(v1, v) { + t.Logf("f(v, y1, y2) != f(v, y, y)") + return false + } + // Test aliasing the two arguments and the receiver. + v1.Set(y) + if out := f(v1, v1, v1); out != v1 || !equal(v1, v) { + t.Logf("f(v, y1, y2) != f(y, y, y)") + return false + } + + // Ensure the arguments were not modified. + return equal(x, x1) && equal(y, y1) +} + +func TestAliasing(t *testing.T) { + for name, f := range map[string]interface{}{ + "Abs": func(v, x bigInt) bool { + return checkAliasingOneArg(t, (*big.Int).Abs, v.Int, x.Int) + }, + "Add": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Add, v.Int, x.Int, y.Int) + }, + "And": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).And, v.Int, x.Int, y.Int) + }, + "AndNot": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).AndNot, v.Int, x.Int, y.Int) + }, + "Div": func(v, x bigInt, y notZeroInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Div, v.Int, x.Int, y.Int) + }, + "Exp-XY": func(v, x, y bigInt, z notZeroInt) bool { + return checkAliasingTwoArgs(t, func(v, x, y *big.Int) *big.Int { + return v.Exp(x, y, z.Int) + }, v.Int, x.Int, y.Int) + }, + "Exp-XZ": func(v, x, y bigInt, z notZeroInt) bool { + return checkAliasingTwoArgs(t, func(v, x, z *big.Int) *big.Int { + return v.Exp(x, y.Int, z) + }, v.Int, x.Int, z.Int) + }, + "Exp-YZ": func(v, x, y bigInt, z notZeroInt) bool { + return checkAliasingTwoArgs(t, func(v, y, z *big.Int) *big.Int { + return v.Exp(x.Int, y, z) + }, v.Int, y.Int, z.Int) + }, + "GCD": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, func(v, x, y *big.Int) *big.Int { + return v.GCD(nil, nil, x, y) + }, v.Int, x.Int, y.Int) + }, + "GCD-X": func(v, x, y bigInt) bool { + a, b := new(big.Int), new(big.Int) + return checkAliasingTwoArgs(t, func(v, x, y *big.Int) *big.Int { + a.GCD(v, b, x, y) + return v + }, v.Int, x.Int, y.Int) + }, + "GCD-Y": func(v, x, y bigInt) bool { + a, b := new(big.Int), new(big.Int) + return checkAliasingTwoArgs(t, func(v, x, y *big.Int) *big.Int { + a.GCD(b, v, x, y) + return v + }, v.Int, x.Int, y.Int) + }, + "Lsh": func(v, x bigInt, n smallUint) bool { + return checkAliasingOneArg(t, func(v, x *big.Int) *big.Int { + return v.Lsh(x, n.uint) + }, v.Int, x.Int) + }, + "Mod": func(v, x bigInt, y notZeroInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Mod, v.Int, x.Int, y.Int) + }, + "ModInverse": func(v, x bigInt, y notZeroInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).ModInverse, v.Int, x.Int, y.Int) + }, + "ModSqrt": func(v, x bigInt, p prime) bool { + return checkAliasingTwoArgs(t, (*big.Int).ModSqrt, v.Int, x.Int, p.Int) + }, + "Mul": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Mul, v.Int, x.Int, y.Int) + }, + "Neg": func(v, x bigInt) bool { + return checkAliasingOneArg(t, (*big.Int).Neg, v.Int, x.Int) + }, + "Not": func(v, x bigInt) bool { + return checkAliasingOneArg(t, (*big.Int).Not, v.Int, x.Int) + }, + "Or": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Or, v.Int, x.Int, y.Int) + }, + "Quo": func(v, x bigInt, y notZeroInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Quo, v.Int, x.Int, y.Int) + }, + "Rand": func(v, x bigInt, seed int64) bool { + return checkAliasingOneArg(t, func(v, x *big.Int) *big.Int { + rnd := rand.New(rand.NewSource(seed)) + return v.Rand(rnd, x) + }, v.Int, x.Int) + }, + "Rem": func(v, x bigInt, y notZeroInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Rem, v.Int, x.Int, y.Int) + }, + "Rsh": func(v, x bigInt, n smallUint) bool { + return checkAliasingOneArg(t, func(v, x *big.Int) *big.Int { + return v.Rsh(x, n.uint) + }, v.Int, x.Int) + }, + "Set": func(v, x bigInt) bool { + return checkAliasingOneArg(t, (*big.Int).Set, v.Int, x.Int) + }, + "SetBit": func(v, x bigInt, i smallUint, b zeroOrOne) bool { + return checkAliasingOneArg(t, func(v, x *big.Int) *big.Int { + return v.SetBit(x, int(i.uint), b.uint) + }, v.Int, x.Int) + }, + "Sqrt": func(v bigInt, x positiveInt) bool { + return checkAliasingOneArg(t, (*big.Int).Sqrt, v.Int, x.Int) + }, + "Sub": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Sub, v.Int, x.Int, y.Int) + }, + "Xor": func(v, x, y bigInt) bool { + return checkAliasingTwoArgs(t, (*big.Int).Xor, v.Int, x.Int, y.Int) + }, + } { + t.Run(name, func(t *testing.T) { + scale := 1.0 + switch name { + case "ModInverse", "GCD-Y", "GCD-X": + scale /= 5 + case "Rand": + scale /= 10 + case "Exp-XZ", "Exp-XY", "Exp-YZ": + scale /= 50 + case "ModSqrt": + scale /= 500 + } + if err := quick.Check(f, &quick.Config{ + MaxCountScale: scale, + }); err != nil { + t.Error(err) + } + }) + } +} diff --git a/src/math/big/arith.go b/src/math/big/arith.go new file mode 100644 index 0000000..06e63e2 --- /dev/null +++ b/src/math/big/arith.go @@ -0,0 +1,277 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file provides Go implementations of elementary multi-precision +// arithmetic operations on word vectors. These have the suffix _g. +// These are needed for platforms without assembly implementations of these routines. +// This file also contains elementary operations that can be implemented +// sufficiently efficiently in Go. + +package big + +import "math/bits" + +// A Word represents a single digit of a multi-precision unsigned integer. +type Word uint + +const ( + _S = _W / 8 // word size in bytes + + _W = bits.UintSize // word size in bits + _B = 1 << _W // digit base + _M = _B - 1 // digit mask +) + +// Many of the loops in this file are of the form +// for i := 0; i < len(z) && i < len(x) && i < len(y); i++ +// i < len(z) is the real condition. +// However, checking i < len(x) && i < len(y) as well is faster than +// having the compiler do a bounds check in the body of the loop; +// remarkably it is even faster than hoisting the bounds check +// out of the loop, by doing something like +// _, _ = x[len(z)-1], y[len(z)-1] +// There are other ways to hoist the bounds check out of the loop, +// but the compiler's BCE isn't powerful enough for them (yet?). +// See the discussion in CL 164966. + +// ---------------------------------------------------------------------------- +// Elementary operations on words +// +// These operations are used by the vector operations below. + +// z1<<_W + z0 = x*y +func mulWW(x, y Word) (z1, z0 Word) { + hi, lo := bits.Mul(uint(x), uint(y)) + return Word(hi), Word(lo) +} + +// z1<<_W + z0 = x*y + c +func mulAddWWW_g(x, y, c Word) (z1, z0 Word) { + hi, lo := bits.Mul(uint(x), uint(y)) + var cc uint + lo, cc = bits.Add(lo, uint(c), 0) + return Word(hi + cc), Word(lo) +} + +// nlz returns the number of leading zeros in x. +// Wraps bits.LeadingZeros call for convenience. +func nlz(x Word) uint { + return uint(bits.LeadingZeros(uint(x))) +} + +// The resulting carry c is either 0 or 1. +func addVV_g(z, x, y []Word) (c Word) { + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x) && i < len(y); i++ { + zi, cc := bits.Add(uint(x[i]), uint(y[i]), uint(c)) + z[i] = Word(zi) + c = Word(cc) + } + return +} + +// The resulting carry c is either 0 or 1. +func subVV_g(z, x, y []Word) (c Word) { + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x) && i < len(y); i++ { + zi, cc := bits.Sub(uint(x[i]), uint(y[i]), uint(c)) + z[i] = Word(zi) + c = Word(cc) + } + return +} + +// The resulting carry c is either 0 or 1. +func addVW_g(z, x []Word, y Word) (c Word) { + c = y + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x); i++ { + zi, cc := bits.Add(uint(x[i]), uint(c), 0) + z[i] = Word(zi) + c = Word(cc) + } + return +} + +// addVWlarge is addVW, but intended for large z. +// The only difference is that we check on every iteration +// whether we are done with carries, +// and if so, switch to a much faster copy instead. +// This is only a good idea for large z, +// because the overhead of the check and the function call +// outweigh the benefits when z is small. +func addVWlarge(z, x []Word, y Word) (c Word) { + c = y + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x); i++ { + if c == 0 { + copy(z[i:], x[i:]) + return + } + zi, cc := bits.Add(uint(x[i]), uint(c), 0) + z[i] = Word(zi) + c = Word(cc) + } + return +} + +func subVW_g(z, x []Word, y Word) (c Word) { + c = y + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x); i++ { + zi, cc := bits.Sub(uint(x[i]), uint(c), 0) + z[i] = Word(zi) + c = Word(cc) + } + return +} + +// subVWlarge is to subVW as addVWlarge is to addVW. +func subVWlarge(z, x []Word, y Word) (c Word) { + c = y + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x); i++ { + if c == 0 { + copy(z[i:], x[i:]) + return + } + zi, cc := bits.Sub(uint(x[i]), uint(c), 0) + z[i] = Word(zi) + c = Word(cc) + } + return +} + +func shlVU_g(z, x []Word, s uint) (c Word) { + if s == 0 { + copy(z, x) + return + } + if len(z) == 0 { + return + } + s &= _W - 1 // hint to the compiler that shifts by s don't need guard code + ŝ := _W - s + ŝ &= _W - 1 // ditto + c = x[len(z)-1] >> ŝ + for i := len(z) - 1; i > 0; i-- { + z[i] = x[i]<<s | x[i-1]>>ŝ + } + z[0] = x[0] << s + return +} + +func shrVU_g(z, x []Word, s uint) (c Word) { + if s == 0 { + copy(z, x) + return + } + if len(z) == 0 { + return + } + if len(x) != len(z) { + // This is an invariant guaranteed by the caller. + panic("len(x) != len(z)") + } + s &= _W - 1 // hint to the compiler that shifts by s don't need guard code + ŝ := _W - s + ŝ &= _W - 1 // ditto + c = x[0] << ŝ + for i := 1; i < len(z); i++ { + z[i-1] = x[i-1]>>s | x[i]<<ŝ + } + z[len(z)-1] = x[len(z)-1] >> s + return +} + +func mulAddVWW_g(z, x []Word, y, r Word) (c Word) { + c = r + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x); i++ { + c, z[i] = mulAddWWW_g(x[i], y, c) + } + return +} + +func addMulVVW_g(z, x []Word, y Word) (c Word) { + // The comment near the top of this file discusses this for loop condition. + for i := 0; i < len(z) && i < len(x); i++ { + z1, z0 := mulAddWWW_g(x[i], y, z[i]) + lo, cc := bits.Add(uint(z0), uint(c), 0) + c, z[i] = Word(cc), Word(lo) + c += z1 + } + return +} + +// q = ( x1 << _W + x0 - r)/y. m = floor(( _B^2 - 1 ) / d - _B). Requiring x1<y. +// An approximate reciprocal with a reference to "Improved Division by Invariant Integers +// (IEEE Transactions on Computers, 11 Jun. 2010)" +func divWW(x1, x0, y, m Word) (q, r Word) { + s := nlz(y) + if s != 0 { + x1 = x1<<s | x0>>(_W-s) + x0 <<= s + y <<= s + } + d := uint(y) + // We know that + // m = ⎣(B^2-1)/d⎦-B + // ⎣(B^2-1)/d⎦ = m+B + // (B^2-1)/d = m+B+delta1 0 <= delta1 <= (d-1)/d + // B^2/d = m+B+delta2 0 <= delta2 <= 1 + // The quotient we're trying to compute is + // quotient = ⎣(x1*B+x0)/d⎦ + // = ⎣(x1*B*(B^2/d)+x0*(B^2/d))/B^2⎦ + // = ⎣(x1*B*(m+B+delta2)+x0*(m+B+delta2))/B^2⎦ + // = ⎣(x1*m+x1*B+x0)/B + x0*m/B^2 + delta2*(x1*B+x0)/B^2⎦ + // The latter two terms of this three-term sum are between 0 and 1. + // So we can compute just the first term, and we will be low by at most 2. + t1, t0 := bits.Mul(uint(m), uint(x1)) + _, c := bits.Add(t0, uint(x0), 0) + t1, _ = bits.Add(t1, uint(x1), c) + // The quotient is either t1, t1+1, or t1+2. + // We'll try t1 and adjust if needed. + qq := t1 + // compute remainder r=x-d*q. + dq1, dq0 := bits.Mul(d, qq) + r0, b := bits.Sub(uint(x0), dq0, 0) + r1, _ := bits.Sub(uint(x1), dq1, b) + // The remainder we just computed is bounded above by B+d: + // r = x1*B + x0 - d*q. + // = x1*B + x0 - d*⎣(x1*m+x1*B+x0)/B⎦ + // = x1*B + x0 - d*((x1*m+x1*B+x0)/B-alpha) 0 <= alpha < 1 + // = x1*B + x0 - x1*d/B*m - x1*d - x0*d/B + d*alpha + // = x1*B + x0 - x1*d/B*⎣(B^2-1)/d-B⎦ - x1*d - x0*d/B + d*alpha + // = x1*B + x0 - x1*d/B*⎣(B^2-1)/d-B⎦ - x1*d - x0*d/B + d*alpha + // = x1*B + x0 - x1*d/B*((B^2-1)/d-B-beta) - x1*d - x0*d/B + d*alpha 0 <= beta < 1 + // = x1*B + x0 - x1*B + x1/B + x1*d + x1*d/B*beta - x1*d - x0*d/B + d*alpha + // = x0 + x1/B + x1*d/B*beta - x0*d/B + d*alpha + // = x0*(1-d/B) + x1*(1+d*beta)/B + d*alpha + // < B*(1-d/B) + d*B/B + d because x0<B (and 1-d/B>0), x1<d, 1+d*beta<=B, alpha<1 + // = B - d + d + d + // = B+d + // So r1 can only be 0 or 1. If r1 is 1, then we know q was too small. + // Add 1 to q and subtract d from r. That guarantees that r is <B, so + // we no longer need to keep track of r1. + if r1 != 0 { + qq++ + r0 -= d + } + // If the remainder is still too large, increment q one more time. + if r0 >= d { + qq++ + r0 -= d + } + return Word(qq), Word(r0 >> s) +} + +// reciprocalWord return the reciprocal of the divisor. rec = floor(( _B^2 - 1 ) / u - _B). u = d1 << nlz(d1). +func reciprocalWord(d1 Word) Word { + u := uint(d1 << nlz(d1)) + x1 := ^u + x0 := uint(_M) + rec, _ := bits.Div(x1, x0, u) // (_B^2-1)/U-_B = (_B*(_M-C)+_M)/U + return Word(rec) +} diff --git a/src/math/big/arith_386.s b/src/math/big/arith_386.s new file mode 100644 index 0000000..8cf4665 --- /dev/null +++ b/src/math/big/arith_386.s @@ -0,0 +1,236 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), CX + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + MOVL $0, DX // c = 0 + JMP E1 + +L1: MOVL (SI)(BX*4), AX + ADDL DX, DX // restore CF + ADCL (CX)(BX*4), AX + SBBL DX, DX // save CF + MOVL AX, (DI)(BX*4) + ADDL $1, BX // i++ + +E1: CMPL BX, BP // i < n + JL L1 + + NEGL DX + MOVL DX, c+36(FP) + RET + + +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SBBL instead of ADCL and label names) +TEXT ·subVV(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), CX + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + MOVL $0, DX // c = 0 + JMP E2 + +L2: MOVL (SI)(BX*4), AX + ADDL DX, DX // restore CF + SBBL (CX)(BX*4), AX + SBBL DX, DX // save CF + MOVL AX, (DI)(BX*4) + ADDL $1, BX // i++ + +E2: CMPL BX, BP // i < n + JL L2 + + NEGL DX + MOVL DX, c+36(FP) + RET + + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), AX // c = y + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + JMP E3 + +L3: ADDL (SI)(BX*4), AX + MOVL AX, (DI)(BX*4) + SBBL AX, AX // save CF + NEGL AX + ADDL $1, BX // i++ + +E3: CMPL BX, BP // i < n + JL L3 + + MOVL AX, c+28(FP) + RET + + +// func subVW(z, x []Word, y Word) (c Word) +TEXT ·subVW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), AX // c = y + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + JMP E4 + +L4: MOVL (SI)(BX*4), DX + SUBL AX, DX + MOVL DX, (DI)(BX*4) + SBBL AX, AX // save CF + NEGL AX + ADDL $1, BX // i++ + +E4: CMPL BX, BP // i < n + JL L4 + + MOVL AX, c+28(FP) + RET + + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + MOVL z_len+4(FP), BX // i = z + SUBL $1, BX // i-- + JL X8b // i < 0 (n <= 0) + + // n > 0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL s+24(FP), CX + MOVL (SI)(BX*4), AX // w1 = x[n-1] + MOVL $0, DX + SHLL CX, AX, DX // w1>>ŝ + MOVL DX, c+28(FP) + + CMPL BX, $0 + JLE X8a // i <= 0 + + // i > 0 +L8: MOVL AX, DX // w = w1 + MOVL -4(SI)(BX*4), AX // w1 = x[i-1] + SHLL CX, AX, DX // w<<s | w1>>ŝ + MOVL DX, (DI)(BX*4) // z[i] = w<<s | w1>>ŝ + SUBL $1, BX // i-- + JG L8 // i > 0 + + // i <= 0 +X8a: SHLL CX, AX // w1<<s + MOVL AX, (DI) // z[0] = w1<<s + RET + +X8b: MOVL $0, c+28(FP) + RET + + +// func shrVU(z, x []Word, s uint) (c Word) +TEXT ·shrVU(SB),NOSPLIT,$0 + MOVL z_len+4(FP), BP + SUBL $1, BP // n-- + JL X9b // n < 0 (n <= 0) + + // n > 0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL s+24(FP), CX + MOVL (SI), AX // w1 = x[0] + MOVL $0, DX + SHRL CX, AX, DX // w1<<ŝ + MOVL DX, c+28(FP) + + MOVL $0, BX // i = 0 + JMP E9 + + // i < n-1 +L9: MOVL AX, DX // w = w1 + MOVL 4(SI)(BX*4), AX // w1 = x[i+1] + SHRL CX, AX, DX // w>>s | w1<<ŝ + MOVL DX, (DI)(BX*4) // z[i] = w>>s | w1<<ŝ + ADDL $1, BX // i++ + +E9: CMPL BX, BP + JL L9 // i < n-1 + + // i >= n-1 +X9a: SHRL CX, AX // w1>>s + MOVL AX, (DI)(BP*4) // z[n-1] = w1>>s + RET + +X9b: MOVL $0, c+28(FP) + RET + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), BP + MOVL r+28(FP), CX // c = r + MOVL z_len+4(FP), BX + LEAL (DI)(BX*4), DI + LEAL (SI)(BX*4), SI + NEGL BX // i = -n + JMP E5 + +L5: MOVL (SI)(BX*4), AX + MULL BP + ADDL CX, AX + ADCL $0, DX + MOVL AX, (DI)(BX*4) + MOVL DX, CX + ADDL $1, BX // i++ + +E5: CMPL BX, $0 // i < 0 + JL L5 + + MOVL CX, c+32(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), BP + MOVL z_len+4(FP), BX + LEAL (DI)(BX*4), DI + LEAL (SI)(BX*4), SI + NEGL BX // i = -n + MOVL $0, CX // c = 0 + JMP E6 + +L6: MOVL (SI)(BX*4), AX + MULL BP + ADDL CX, AX + ADCL $0, DX + ADDL AX, (DI)(BX*4) + ADCL $0, DX + MOVL DX, CX + ADDL $1, BX // i++ + +E6: CMPL BX, $0 // i < 0 + JL L6 + + MOVL CX, c+28(FP) + RET + + + diff --git a/src/math/big/arith_amd64.go b/src/math/big/arith_amd64.go new file mode 100644 index 0000000..89108fe --- /dev/null +++ b/src/math/big/arith_amd64.go @@ -0,0 +1,12 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +package big + +import "internal/cpu" + +var support_adx = cpu.X86.HasADX && cpu.X86.HasBMI2 diff --git a/src/math/big/arith_amd64.s b/src/math/big/arith_amd64.s new file mode 100644 index 0000000..b1e914c --- /dev/null +++ b/src/math/big/arith_amd64.s @@ -0,0 +1,516 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// The carry bit is saved with SBBQ Rx, Rx: if the carry was set, Rx is -1, otherwise it is 0. +// It is restored with ADDQ Rx, Rx: if Rx was -1 the carry is set, otherwise it is cleared. +// This is faster than using rotate instructions. + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ z+0(FP), R10 + + MOVQ $0, CX // c = 0 + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V1 // if n < 0 goto V1 + +U1: // n >= 0 + // regular loop body unrolled 4x + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + ADCQ 0(R9)(SI*8), R11 + ADCQ 8(R9)(SI*8), R12 + ADCQ 16(R9)(SI*8), R13 + ADCQ 24(R9)(SI*8), R14 + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U1 // if n >= 0 goto U1 + +V1: ADDQ $4, DI // n += 4 + JLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + ADCQ 0(R9)(SI*8), R11 + MOVQ R11, 0(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L1 // if n > 0 goto L1 + +E1: NEGQ CX + MOVQ CX, c+72(FP) // return c + RET + + +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SBBQ instead of ADCQ and label names) +TEXT ·subVV(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ z+0(FP), R10 + + MOVQ $0, CX // c = 0 + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V2 // if n < 0 goto V2 + +U2: // n >= 0 + // regular loop body unrolled 4x + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + SBBQ 0(R9)(SI*8), R11 + SBBQ 8(R9)(SI*8), R12 + SBBQ 16(R9)(SI*8), R13 + SBBQ 24(R9)(SI*8), R14 + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U2 // if n >= 0 goto U2 + +V2: ADDQ $4, DI // n += 4 + JLE E2 // if n <= 0 goto E2 + +L2: // n > 0 + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + SBBQ 0(R9)(SI*8), R11 + MOVQ R11, 0(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L2 // if n > 0 goto L2 + +E2: NEGQ CX + MOVQ CX, c+72(FP) // return c + RET + + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + CMPQ DI, $32 + JG large + MOVQ x+24(FP), R8 + MOVQ y+48(FP), CX // c = y + MOVQ z+0(FP), R10 + + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V3 // if n < 4 goto V3 + +U3: // n >= 0 + // regular loop body unrolled 4x + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + ADDQ CX, R11 + ADCQ $0, R12 + ADCQ $0, R13 + ADCQ $0, R14 + SBBQ CX, CX // save CF + NEGQ CX + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U3 // if n >= 0 goto U3 + +V3: ADDQ $4, DI // n += 4 + JLE E3 // if n <= 0 goto E3 + +L3: // n > 0 + ADDQ 0(R8)(SI*8), CX + MOVQ CX, 0(R10)(SI*8) + SBBQ CX, CX // save CF + NEGQ CX + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L3 // if n > 0 goto L3 + +E3: MOVQ CX, c+56(FP) // return c + RET +large: + JMP ·addVWlarge(SB) + + +// func subVW(z, x []Word, y Word) (c Word) +// (same as addVW except for SUBQ/SBBQ instead of ADDQ/ADCQ and label names) +TEXT ·subVW(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + CMPQ DI, $32 + JG large + MOVQ x+24(FP), R8 + MOVQ y+48(FP), CX // c = y + MOVQ z+0(FP), R10 + + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V4 // if n < 4 goto V4 + +U4: // n >= 0 + // regular loop body unrolled 4x + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + SUBQ CX, R11 + SBBQ $0, R12 + SBBQ $0, R13 + SBBQ $0, R14 + SBBQ CX, CX // save CF + NEGQ CX + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U4 // if n >= 0 goto U4 + +V4: ADDQ $4, DI // n += 4 + JLE E4 // if n <= 0 goto E4 + +L4: // n > 0 + MOVQ 0(R8)(SI*8), R11 + SUBQ CX, R11 + MOVQ R11, 0(R10)(SI*8) + SBBQ CX, CX // save CF + NEGQ CX + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L4 // if n > 0 goto L4 + +E4: MOVQ CX, c+56(FP) // return c + RET +large: + JMP ·subVWlarge(SB) + + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), BX // i = z + SUBQ $1, BX // i-- + JL X8b // i < 0 (n <= 0) + + // n > 0 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ s+48(FP), CX + MOVQ (R8)(BX*8), AX // w1 = x[n-1] + MOVQ $0, DX + SHLQ CX, AX, DX // w1>>ŝ + MOVQ DX, c+56(FP) + + CMPQ BX, $0 + JLE X8a // i <= 0 + + // i > 0 +L8: MOVQ AX, DX // w = w1 + MOVQ -8(R8)(BX*8), AX // w1 = x[i-1] + SHLQ CX, AX, DX // w<<s | w1>>ŝ + MOVQ DX, (R10)(BX*8) // z[i] = w<<s | w1>>ŝ + SUBQ $1, BX // i-- + JG L8 // i > 0 + + // i <= 0 +X8a: SHLQ CX, AX // w1<<s + MOVQ AX, (R10) // z[0] = w1<<s + RET + +X8b: MOVQ $0, c+56(FP) + RET + + +// func shrVU(z, x []Word, s uint) (c Word) +TEXT ·shrVU(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), R11 + SUBQ $1, R11 // n-- + JL X9b // n < 0 (n <= 0) + + // n > 0 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ s+48(FP), CX + MOVQ (R8), AX // w1 = x[0] + MOVQ $0, DX + SHRQ CX, AX, DX // w1<<ŝ + MOVQ DX, c+56(FP) + + MOVQ $0, BX // i = 0 + JMP E9 + + // i < n-1 +L9: MOVQ AX, DX // w = w1 + MOVQ 8(R8)(BX*8), AX // w1 = x[i+1] + SHRQ CX, AX, DX // w>>s | w1<<ŝ + MOVQ DX, (R10)(BX*8) // z[i] = w>>s | w1<<ŝ + ADDQ $1, BX // i++ + +E9: CMPQ BX, R11 + JL L9 // i < n-1 + + // i >= n-1 +X9a: SHRQ CX, AX // w1>>s + MOVQ AX, (R10)(R11*8) // z[n-1] = w1>>s + RET + +X9b: MOVQ $0, c+56(FP) + RET + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ r+56(FP), CX // c = r + MOVQ z_len+8(FP), R11 + MOVQ $0, BX // i = 0 + + CMPQ R11, $4 + JL E5 + +U5: // i+4 <= n + // regular loop body unrolled 4x + MOVQ (0*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (0*8)(R10)(BX*8) + MOVQ DX, CX + MOVQ (1*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (1*8)(R10)(BX*8) + MOVQ DX, CX + MOVQ (2*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (2*8)(R10)(BX*8) + MOVQ DX, CX + MOVQ (3*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (3*8)(R10)(BX*8) + MOVQ DX, CX + ADDQ $4, BX // i += 4 + + LEAQ 4(BX), DX + CMPQ DX, R11 + JLE U5 + JMP E5 + +L5: MOVQ (R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (R10)(BX*8) + MOVQ DX, CX + ADDQ $1, BX // i++ + +E5: CMPQ BX, R11 // i < n + JL L5 + + MOVQ CX, c+64(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + CMPB ·support_adx(SB), $1 + JEQ adx + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ z_len+8(FP), R11 + MOVQ $0, BX // i = 0 + MOVQ $0, CX // c = 0 + MOVQ R11, R12 + ANDQ $-2, R12 + CMPQ R11, $2 + JAE A6 + JMP E6 + +A6: + MOVQ (R8)(BX*8), AX + MULQ R9 + ADDQ (R10)(BX*8), AX + ADCQ $0, DX + ADDQ CX, AX + ADCQ $0, DX + MOVQ DX, CX + MOVQ AX, (R10)(BX*8) + + MOVQ (8)(R8)(BX*8), AX + MULQ R9 + ADDQ (8)(R10)(BX*8), AX + ADCQ $0, DX + ADDQ CX, AX + ADCQ $0, DX + MOVQ DX, CX + MOVQ AX, (8)(R10)(BX*8) + + ADDQ $2, BX + CMPQ BX, R12 + JL A6 + JMP E6 + +L6: MOVQ (R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + ADDQ AX, (R10)(BX*8) + ADCQ $0, DX + MOVQ DX, CX + ADDQ $1, BX // i++ + +E6: CMPQ BX, R11 // i < n + JL L6 + + MOVQ CX, c+56(FP) + RET + +adx: + MOVQ z_len+8(FP), R11 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ y+48(FP), DX + MOVQ $0, BX // i = 0 + MOVQ $0, CX // carry + CMPQ R11, $8 + JAE adx_loop_header + CMPQ BX, R11 + JL adx_short + MOVQ CX, c+56(FP) + RET + +adx_loop_header: + MOVQ R11, R13 + ANDQ $-8, R13 +adx_loop: + XORQ R9, R9 // unset flags + MULXQ (R8), SI, DI + ADCXQ CX,SI + ADOXQ (R10), SI + MOVQ SI,(R10) + + MULXQ 8(R8), AX, CX + ADCXQ DI, AX + ADOXQ 8(R10), AX + MOVQ AX, 8(R10) + + MULXQ 16(R8), SI, DI + ADCXQ CX, SI + ADOXQ 16(R10), SI + MOVQ SI, 16(R10) + + MULXQ 24(R8), AX, CX + ADCXQ DI, AX + ADOXQ 24(R10), AX + MOVQ AX, 24(R10) + + MULXQ 32(R8), SI, DI + ADCXQ CX, SI + ADOXQ 32(R10), SI + MOVQ SI, 32(R10) + + MULXQ 40(R8), AX, CX + ADCXQ DI, AX + ADOXQ 40(R10), AX + MOVQ AX, 40(R10) + + MULXQ 48(R8), SI, DI + ADCXQ CX, SI + ADOXQ 48(R10), SI + MOVQ SI, 48(R10) + + MULXQ 56(R8), AX, CX + ADCXQ DI, AX + ADOXQ 56(R10), AX + MOVQ AX, 56(R10) + + ADCXQ R9, CX + ADOXQ R9, CX + + ADDQ $64, R8 + ADDQ $64, R10 + ADDQ $8, BX + + CMPQ BX, R13 + JL adx_loop + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + CMPQ BX, R11 + JL adx_short + MOVQ CX, c+56(FP) + RET + +adx_short: + MULXQ (R8)(BX*8), SI, DI + ADDQ CX, SI + ADCQ $0, DI + ADDQ SI, (R10)(BX*8) + ADCQ $0, DI + MOVQ DI, CX + ADDQ $1, BX // i++ + + CMPQ BX, R11 + JL adx_short + + MOVQ CX, c+56(FP) + RET + + + diff --git a/src/math/big/arith_arm.s b/src/math/big/arith_arm.s new file mode 100644 index 0000000..10054bd --- /dev/null +++ b/src/math/big/arith_arm.s @@ -0,0 +1,273 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + ADD.S $0, R0 // clear carry flag + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + B E1 +L1: + MOVW.P 4(R2), R5 + MOVW.P 4(R3), R6 + ADC.S R6, R5 + MOVW.P R5, 4(R1) +E1: + TEQ R1, R4 + BNE L1 + + MOVW $0, R0 + MOVW.CS $1, R0 + MOVW R0, c+36(FP) + RET + + +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SBC instead of ADC and label names) +TEXT ·subVV(SB),NOSPLIT,$0 + SUB.S $0, R0 // clear borrow flag + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + B E2 +L2: + MOVW.P 4(R2), R5 + MOVW.P 4(R3), R6 + SBC.S R6, R5 + MOVW.P R5, 4(R1) +E2: + TEQ R1, R4 + BNE L2 + + MOVW $0, R0 + MOVW.CC $1, R0 + MOVW R0, c+36(FP) + RET + + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB),NOSPLIT,$0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + TEQ R1, R4 + BNE L3a + MOVW R3, c+28(FP) + RET +L3a: + MOVW.P 4(R2), R5 + ADD.S R3, R5 + MOVW.P R5, 4(R1) + B E3 +L3: + MOVW.P 4(R2), R5 + ADC.S $0, R5 + MOVW.P R5, 4(R1) +E3: + TEQ R1, R4 + BNE L3 + + MOVW $0, R0 + MOVW.CS $1, R0 + MOVW R0, c+28(FP) + RET + + +// func subVW(z, x []Word, y Word) (c Word) +TEXT ·subVW(SB),NOSPLIT,$0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + TEQ R1, R4 + BNE L4a + MOVW R3, c+28(FP) + RET +L4a: + MOVW.P 4(R2), R5 + SUB.S R3, R5 + MOVW.P R5, 4(R1) + B E4 +L4: + MOVW.P 4(R2), R5 + SBC.S $0, R5 + MOVW.P R5, 4(R1) +E4: + TEQ R1, R4 + BNE L4 + + MOVW $0, R0 + MOVW.CC $1, R0 + MOVW R0, c+28(FP) + RET + + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + MOVW z_len+4(FP), R5 + TEQ $0, R5 + BEQ X7 + + MOVW z+0(FP), R1 + MOVW x+12(FP), R2 + ADD R5<<2, R2, R2 + ADD R5<<2, R1, R5 + MOVW s+24(FP), R3 + TEQ $0, R3 // shift 0 is special + BEQ Y7 + ADD $4, R1 // stop one word early + MOVW $32, R4 + SUB R3, R4 + MOVW $0, R7 + + MOVW.W -4(R2), R6 + MOVW R6<<R3, R7 + MOVW R6>>R4, R6 + MOVW R6, c+28(FP) + B E7 + +L7: + MOVW.W -4(R2), R6 + ORR R6>>R4, R7 + MOVW.W R7, -4(R5) + MOVW R6<<R3, R7 +E7: + TEQ R1, R5 + BNE L7 + + MOVW R7, -4(R5) + RET + +Y7: // copy loop, because shift 0 == shift 32 + MOVW.W -4(R2), R6 + MOVW.W R6, -4(R5) + TEQ R1, R5 + BNE Y7 + +X7: + MOVW $0, R1 + MOVW R1, c+28(FP) + RET + + +// func shrVU(z, x []Word, s uint) (c Word) +TEXT ·shrVU(SB),NOSPLIT,$0 + MOVW z_len+4(FP), R5 + TEQ $0, R5 + BEQ X6 + + MOVW z+0(FP), R1 + MOVW x+12(FP), R2 + ADD R5<<2, R1, R5 + MOVW s+24(FP), R3 + TEQ $0, R3 // shift 0 is special + BEQ Y6 + SUB $4, R5 // stop one word early + MOVW $32, R4 + SUB R3, R4 + MOVW $0, R7 + + // first word + MOVW.P 4(R2), R6 + MOVW R6>>R3, R7 + MOVW R6<<R4, R6 + MOVW R6, c+28(FP) + B E6 + + // word loop +L6: + MOVW.P 4(R2), R6 + ORR R6<<R4, R7 + MOVW.P R7, 4(R1) + MOVW R6>>R3, R7 +E6: + TEQ R1, R5 + BNE L6 + + MOVW R7, 0(R1) + RET + +Y6: // copy loop, because shift 0 == shift 32 + MOVW.P 4(R2), R6 + MOVW.P R6, 4(R1) + TEQ R1, R5 + BNE Y6 + +X6: + MOVW $0, R1 + MOVW R1, c+28(FP) + RET + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVW $0, R0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R5 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + MOVW r+28(FP), R4 + ADD R5<<2, R1, R5 + B E8 + + // word loop +L8: + MOVW.P 4(R2), R6 + MULLU R6, R3, (R7, R6) + ADD.S R4, R6 + ADC R0, R7 + MOVW.P R6, 4(R1) + MOVW R7, R4 +E8: + TEQ R1, R5 + BNE L8 + + MOVW R4, c+32(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + MOVW $0, R0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R5 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R5<<2, R1, R5 + MOVW $0, R4 + B E9 + + // word loop +L9: + MOVW.P 4(R2), R6 + MULLU R6, R3, (R7, R6) + ADD.S R4, R6 + ADC R0, R7 + MOVW 0(R1), R4 + ADD.S R4, R6 + ADC R0, R7 + MOVW.P R6, 4(R1) + MOVW R7, R4 +E9: + TEQ R1, R5 + BNE L9 + + MOVW R4, c+28(FP) + RET diff --git a/src/math/big/arith_arm64.s b/src/math/big/arith_arm64.s new file mode 100644 index 0000000..addf2d6 --- /dev/null +++ b/src/math/big/arith_arm64.s @@ -0,0 +1,573 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// TODO: Consider re-implementing using Advanced SIMD +// once the assembler supports those instructions. + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R10 + ADDS $0, R0 // clear carry flag + TBZ $0, R0, two + MOVD.P 8(R8), R11 + MOVD.P 8(R9), R15 + ADCS R15, R11 + MOVD.P R11, 8(R10) + SUB $1, R0 +two: + TBZ $1, R0, loop + LDP.P 16(R8), (R11, R12) + LDP.P 16(R9), (R15, R16) + ADCS R15, R11 + ADCS R16, R12 + STP.P (R11, R12), 16(R10) + SUB $2, R0 +loop: + CBZ R0, done // careful not to touch the carry flag + LDP.P 32(R8), (R11, R12) + LDP -16(R8), (R13, R14) + LDP.P 32(R9), (R15, R16) + LDP -16(R9), (R17, R19) + ADCS R15, R11 + ADCS R16, R12 + ADCS R17, R13 + ADCS R19, R14 + STP.P (R11, R12), 32(R10) + STP (R13, R14), -16(R10) + SUB $4, R0 + B loop +done: + CSET HS, R0 // extract carry flag + MOVD R0, c+72(FP) + RET + + +// func subVV(z, x, y []Word) (c Word) +TEXT ·subVV(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R10 + CMP R0, R0 // set carry flag + TBZ $0, R0, two + MOVD.P 8(R8), R11 + MOVD.P 8(R9), R15 + SBCS R15, R11 + MOVD.P R11, 8(R10) + SUB $1, R0 +two: + TBZ $1, R0, loop + LDP.P 16(R8), (R11, R12) + LDP.P 16(R9), (R15, R16) + SBCS R15, R11 + SBCS R16, R12 + STP.P (R11, R12), 16(R10) + SUB $2, R0 +loop: + CBZ R0, done // careful not to touch the carry flag + LDP.P 32(R8), (R11, R12) + LDP -16(R8), (R13, R14) + LDP.P 32(R9), (R15, R16) + LDP -16(R9), (R17, R19) + SBCS R15, R11 + SBCS R16, R12 + SBCS R17, R13 + SBCS R19, R14 + STP.P (R11, R12), 32(R10) + STP (R13, R14), -16(R10) + SUB $4, R0 + B loop +done: + CSET LO, R0 // extract carry flag + MOVD R0, c+72(FP) + RET + +#define vwOneOp(instr, op1) \ + MOVD.P 8(R1), R4; \ + instr op1, R4; \ + MOVD.P R4, 8(R3); + +// handle the first 1~4 elements before starting iteration in addVW/subVW +#define vwPreIter(instr1, instr2, counter, target) \ + vwOneOp(instr1, R2); \ + SUB $1, counter; \ + CBZ counter, target; \ + vwOneOp(instr2, $0); \ + SUB $1, counter; \ + CBZ counter, target; \ + vwOneOp(instr2, $0); \ + SUB $1, counter; \ + CBZ counter, target; \ + vwOneOp(instr2, $0); + +// do one iteration of add or sub in addVW/subVW +#define vwOneIter(instr, counter, exit) \ + CBZ counter, exit; \ // careful not to touch the carry flag + LDP.P 32(R1), (R4, R5); \ + LDP -16(R1), (R6, R7); \ + instr $0, R4, R8; \ + instr $0, R5, R9; \ + instr $0, R6, R10; \ + instr $0, R7, R11; \ + STP.P (R8, R9), 32(R3); \ + STP (R10, R11), -16(R3); \ + SUB $4, counter; + +// do one iteration of copy in addVW/subVW +#define vwOneIterCopy(counter, exit) \ + CBZ counter, exit; \ + LDP.P 32(R1), (R4, R5); \ + LDP -16(R1), (R6, R7); \ + STP.P (R4, R5), 32(R3); \ + STP (R6, R7), -16(R3); \ + SUB $4, counter; + +// func addVW(z, x []Word, y Word) (c Word) +// The 'large' branch handles large 'z'. It checks the carry flag on every iteration +// and switches to copy if we are done with carries. The copying is skipped as well +// if 'x' and 'z' happen to share the same underlying storage. +// The overhead of the checking and branching is visible when 'z' are small (~5%), +// so set a threshold of 32, and remain the small-sized part entirely untouched. +TEXT ·addVW(SB),NOSPLIT,$0 + MOVD z+0(FP), R3 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R1 + MOVD y+48(FP), R2 + CMP $32, R0 + BGE large // large-sized 'z' and 'x' + CBZ R0, len0 // the length of z is 0 + MOVD.P 8(R1), R4 + ADDS R2, R4 // z[0] = x[0] + y, set carry + MOVD.P R4, 8(R3) + SUB $1, R0 + CBZ R0, len1 // the length of z is 1 + TBZ $0, R0, two + MOVD.P 8(R1), R4 // do it once + ADCS $0, R4 + MOVD.P R4, 8(R3) + SUB $1, R0 +two: // do it twice + TBZ $1, R0, loop + LDP.P 16(R1), (R4, R5) + ADCS $0, R4, R8 // c, z[i] = x[i] + c + ADCS $0, R5, R9 + STP.P (R8, R9), 16(R3) + SUB $2, R0 +loop: // do four times per round + vwOneIter(ADCS, R0, len1) + B loop +len1: + CSET HS, R2 // extract carry flag +len0: + MOVD R2, c+56(FP) +done: + RET +large: + AND $0x3, R0, R10 + AND $~0x3, R0 + // unrolling for the first 1~4 elements to avoid saving the carry + // flag in each step, adjust $R0 if we unrolled 4 elements + vwPreIter(ADDS, ADCS, R10, add4) + SUB $4, R0 +add4: + BCC copy + vwOneIter(ADCS, R0, len1) + B add4 +copy: + MOVD ZR, c+56(FP) + CMP R1, R3 + BEQ done +copy_4: // no carry flag, copy the rest + vwOneIterCopy(R0, done) + B copy_4 + +// func subVW(z, x []Word, y Word) (c Word) +// The 'large' branch handles large 'z'. It checks the carry flag on every iteration +// and switches to copy if we are done with carries. The copying is skipped as well +// if 'x' and 'z' happen to share the same underlying storage. +// The overhead of the checking and branching is visible when 'z' are small (~5%), +// so set a threshold of 32, and remain the small-sized part entirely untouched. +TEXT ·subVW(SB),NOSPLIT,$0 + MOVD z+0(FP), R3 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R1 + MOVD y+48(FP), R2 + CMP $32, R0 + BGE large // large-sized 'z' and 'x' + CBZ R0, len0 // the length of z is 0 + MOVD.P 8(R1), R4 + SUBS R2, R4 // z[0] = x[0] - y, set carry + MOVD.P R4, 8(R3) + SUB $1, R0 + CBZ R0, len1 // the length of z is 1 + TBZ $0, R0, two // do it once + MOVD.P 8(R1), R4 + SBCS $0, R4 + MOVD.P R4, 8(R3) + SUB $1, R0 +two: // do it twice + TBZ $1, R0, loop + LDP.P 16(R1), (R4, R5) + SBCS $0, R4, R8 // c, z[i] = x[i] + c + SBCS $0, R5, R9 + STP.P (R8, R9), 16(R3) + SUB $2, R0 +loop: // do four times per round + vwOneIter(SBCS, R0, len1) + B loop +len1: + CSET LO, R2 // extract carry flag +len0: + MOVD R2, c+56(FP) +done: + RET +large: + AND $0x3, R0, R10 + AND $~0x3, R0 + // unrolling for the first 1~4 elements to avoid saving the carry + // flag in each step, adjust $R0 if we unrolled 4 elements + vwPreIter(SUBS, SBCS, R10, sub4) + SUB $4, R0 +sub4: + BCS copy + vwOneIter(SBCS, R0, len1) + B sub4 +copy: + MOVD ZR, c+56(FP) + CMP R1, R3 + BEQ done +copy_4: // no carry flag, copy the rest + vwOneIterCopy(R0, done) + B copy_4 + +// func shlVU(z, x []Word, s uint) (c Word) +// This implementation handles the shift operation from the high word to the low word, +// which may be an error for the case where the low word of x overlaps with the high +// word of z. When calling this function directly, you need to pay attention to this +// situation. +TEXT ·shlVU(SB),NOSPLIT,$0 + LDP z+0(FP), (R0, R1) // R0 = z.ptr, R1 = len(z) + MOVD x+24(FP), R2 + MOVD s+48(FP), R3 + ADD R1<<3, R0 // R0 = &z[n] + ADD R1<<3, R2 // R2 = &x[n] + CBZ R1, len0 + CBZ R3, copy // if the number of shift is 0, just copy x to z + MOVD $64, R4 + SUB R3, R4 + // handling the most significant element x[n-1] + MOVD.W -8(R2), R6 + LSR R4, R6, R5 // return value + LSL R3, R6, R8 // x[i] << s + SUB $1, R1 +one: TBZ $0, R1, two + MOVD.W -8(R2), R6 + LSR R4, R6, R7 + ORR R8, R7 + LSL R3, R6, R8 + SUB $1, R1 + MOVD.W R7, -8(R0) +two: + TBZ $1, R1, loop + LDP.W -16(R2), (R6, R7) + LSR R4, R7, R10 + ORR R8, R10 + LSL R3, R7 + LSR R4, R6, R9 + ORR R7, R9 + LSL R3, R6, R8 + SUB $2, R1 + STP.W (R9, R10), -16(R0) +loop: + CBZ R1, done + LDP.W -32(R2), (R10, R11) + LDP 16(R2), (R12, R13) + LSR R4, R13, R23 + ORR R8, R23 // z[i] = (x[i] << s) | (x[i-1] >> (64 - s)) + LSL R3, R13 + LSR R4, R12, R22 + ORR R13, R22 + LSL R3, R12 + LSR R4, R11, R21 + ORR R12, R21 + LSL R3, R11 + LSR R4, R10, R20 + ORR R11, R20 + LSL R3, R10, R8 + STP.W (R20, R21), -32(R0) + STP (R22, R23), 16(R0) + SUB $4, R1 + B loop +done: + MOVD.W R8, -8(R0) // the first element x[0] + MOVD R5, c+56(FP) // the part moved out from x[n-1] + RET +copy: + CMP R0, R2 + BEQ len0 + TBZ $0, R1, ctwo + MOVD.W -8(R2), R4 + MOVD.W R4, -8(R0) + SUB $1, R1 +ctwo: + TBZ $1, R1, cloop + LDP.W -16(R2), (R4, R5) + STP.W (R4, R5), -16(R0) + SUB $2, R1 +cloop: + CBZ R1, len0 + LDP.W -32(R2), (R4, R5) + LDP 16(R2), (R6, R7) + STP.W (R4, R5), -32(R0) + STP (R6, R7), 16(R0) + SUB $4, R1 + B cloop +len0: + MOVD $0, c+56(FP) + RET + +// func shrVU(z, x []Word, s uint) (c Word) +// This implementation handles the shift operation from the low word to the high word, +// which may be an error for the case where the high word of x overlaps with the low +// word of z. When calling this function directly, you need to pay attention to this +// situation. +TEXT ·shrVU(SB),NOSPLIT,$0 + MOVD z+0(FP), R0 + MOVD z_len+8(FP), R1 + MOVD x+24(FP), R2 + MOVD s+48(FP), R3 + MOVD $0, R8 + MOVD $64, R4 + SUB R3, R4 + CBZ R1, len0 + CBZ R3, copy // if the number of shift is 0, just copy x to z + + MOVD.P 8(R2), R20 + LSR R3, R20, R8 + LSL R4, R20 + MOVD R20, c+56(FP) // deal with the first element + SUB $1, R1 + + TBZ $0, R1, two + MOVD.P 8(R2), R6 + LSL R4, R6, R20 + ORR R8, R20 + LSR R3, R6, R8 + MOVD.P R20, 8(R0) + SUB $1, R1 +two: + TBZ $1, R1, loop + LDP.P 16(R2), (R6, R7) + LSL R4, R6, R20 + LSR R3, R6 + ORR R8, R20 + LSL R4, R7, R21 + LSR R3, R7, R8 + ORR R6, R21 + STP.P (R20, R21), 16(R0) + SUB $2, R1 +loop: + CBZ R1, done + LDP.P 32(R2), (R10, R11) + LDP -16(R2), (R12, R13) + LSL R4, R10, R20 + LSR R3, R10 + ORR R8, R20 // z[i] = (x[i] >> s) | (x[i+1] << (64 - s)) + LSL R4, R11, R21 + LSR R3, R11 + ORR R10, R21 + LSL R4, R12, R22 + LSR R3, R12 + ORR R11, R22 + LSL R4, R13, R23 + LSR R3, R13, R8 + ORR R12, R23 + STP.P (R20, R21), 32(R0) + STP (R22, R23), -16(R0) + SUB $4, R1 + B loop +done: + MOVD R8, (R0) // deal with the last element + RET +copy: + CMP R0, R2 + BEQ len0 + TBZ $0, R1, ctwo + MOVD.P 8(R2), R3 + MOVD.P R3, 8(R0) + SUB $1, R1 +ctwo: + TBZ $1, R1, cloop + LDP.P 16(R2), (R4, R5) + STP.P (R4, R5), 16(R0) + SUB $2, R1 +cloop: + CBZ R1, len0 + LDP.P 32(R2), (R4, R5) + LDP -16(R2), (R6, R7) + STP.P (R4, R5), 32(R0) + STP (R6, R7), -16(R0) + SUB $4, R1 + B cloop +len0: + MOVD $0, c+56(FP) + RET + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVD z+0(FP), R1 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R2 + MOVD y+48(FP), R3 + MOVD r+56(FP), R4 + // c, z = x * y + r + TBZ $0, R0, two + MOVD.P 8(R2), R5 + MUL R3, R5, R7 + UMULH R3, R5, R8 + ADDS R4, R7 + ADC $0, R8, R4 // c, z[i] = x[i] * y + r + MOVD.P R7, 8(R1) + SUB $1, R0 +two: + TBZ $1, R0, loop + LDP.P 16(R2), (R5, R6) + MUL R3, R5, R10 + UMULH R3, R5, R11 + ADDS R4, R10 + MUL R3, R6, R12 + UMULH R3, R6, R13 + ADCS R12, R11 + ADC $0, R13, R4 + + STP.P (R10, R11), 16(R1) + SUB $2, R0 +loop: + CBZ R0, done + LDP.P 32(R2), (R5, R6) + LDP -16(R2), (R7, R8) + + MUL R3, R5, R10 + UMULH R3, R5, R11 + ADDS R4, R10 + MUL R3, R6, R12 + UMULH R3, R6, R13 + ADCS R11, R12 + + MUL R3, R7, R14 + UMULH R3, R7, R15 + ADCS R13, R14 + MUL R3, R8, R16 + UMULH R3, R8, R17 + ADCS R15, R16 + ADC $0, R17, R4 + + STP.P (R10, R12), 32(R1) + STP (R14, R16), -16(R1) + SUB $4, R0 + B loop +done: + MOVD R4, c+64(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + MOVD z+0(FP), R1 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R2 + MOVD y+48(FP), R3 + MOVD $0, R4 + + TBZ $0, R0, two + + MOVD.P 8(R2), R5 + MOVD (R1), R6 + + MUL R5, R3, R7 + UMULH R5, R3, R8 + + ADDS R7, R6 + ADC $0, R8, R4 + + MOVD.P R6, 8(R1) + SUB $1, R0 + +two: + TBZ $1, R0, loop + + LDP.P 16(R2), (R5, R10) + LDP (R1), (R6, R11) + + MUL R10, R3, R13 + UMULH R10, R3, R12 + + MUL R5, R3, R7 + UMULH R5, R3, R8 + + ADDS R4, R6 + ADCS R13, R11 + ADC $0, R12 + + ADDS R7, R6 + ADCS R8, R11 + ADC $0, R12, R4 + + STP.P (R6, R11), 16(R1) + SUB $2, R0 + +// The main loop of this code operates on a block of 4 words every iteration +// performing [R4:R12:R11:R10:R9] = R4 + R3 * [R8:R7:R6:R5] + [R12:R11:R10:R9] +// where R4 is carried from the previous iteration, R8:R7:R6:R5 hold the next +// 4 words of x, R3 is y and R12:R11:R10:R9 are part of the result z. +loop: + CBZ R0, done + + LDP.P 16(R2), (R5, R6) + LDP.P 16(R2), (R7, R8) + + LDP (R1), (R9, R10) + ADDS R4, R9 + MUL R6, R3, R14 + ADCS R14, R10 + MUL R7, R3, R15 + LDP 16(R1), (R11, R12) + ADCS R15, R11 + MUL R8, R3, R16 + ADCS R16, R12 + UMULH R8, R3, R20 + ADC $0, R20 + + MUL R5, R3, R13 + ADDS R13, R9 + UMULH R5, R3, R17 + ADCS R17, R10 + UMULH R6, R3, R21 + STP.P (R9, R10), 16(R1) + ADCS R21, R11 + UMULH R7, R3, R19 + ADCS R19, R12 + STP.P (R11, R12), 16(R1) + ADC $0, R20, R4 + + SUB $4, R0 + B loop + +done: + MOVD R4, c+56(FP) + RET + + diff --git a/src/math/big/arith_decl.go b/src/math/big/arith_decl.go new file mode 100644 index 0000000..9b254f2 --- /dev/null +++ b/src/math/big/arith_decl.go @@ -0,0 +1,34 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +package big + +// implemented in arith_$GOARCH.s + +//go:noescape +func addVV(z, x, y []Word) (c Word) + +//go:noescape +func subVV(z, x, y []Word) (c Word) + +//go:noescape +func addVW(z, x []Word, y Word) (c Word) + +//go:noescape +func subVW(z, x []Word, y Word) (c Word) + +//go:noescape +func shlVU(z, x []Word, s uint) (c Word) + +//go:noescape +func shrVU(z, x []Word, s uint) (c Word) + +//go:noescape +func mulAddVWW(z, x []Word, y, r Word) (c Word) + +//go:noescape +func addMulVVW(z, x []Word, y Word) (c Word) diff --git a/src/math/big/arith_decl_pure.go b/src/math/big/arith_decl_pure.go new file mode 100644 index 0000000..75f3ed2 --- /dev/null +++ b/src/math/big/arith_decl_pure.go @@ -0,0 +1,50 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build math_big_pure_go +// +build math_big_pure_go + +package big + +func addVV(z, x, y []Word) (c Word) { + return addVV_g(z, x, y) +} + +func subVV(z, x, y []Word) (c Word) { + return subVV_g(z, x, y) +} + +func addVW(z, x []Word, y Word) (c Word) { + // TODO: remove indirect function call when golang.org/issue/30548 is fixed + fn := addVW_g + if len(z) > 32 { + fn = addVWlarge + } + return fn(z, x, y) +} + +func subVW(z, x []Word, y Word) (c Word) { + // TODO: remove indirect function call when golang.org/issue/30548 is fixed + fn := subVW_g + if len(z) > 32 { + fn = subVWlarge + } + return fn(z, x, y) +} + +func shlVU(z, x []Word, s uint) (c Word) { + return shlVU_g(z, x, s) +} + +func shrVU(z, x []Word, s uint) (c Word) { + return shrVU_g(z, x, s) +} + +func mulAddVWW(z, x []Word, y, r Word) (c Word) { + return mulAddVWW_g(z, x, y, r) +} + +func addMulVVW(z, x []Word, y Word) (c Word) { + return addMulVVW_g(z, x, y) +} diff --git a/src/math/big/arith_decl_s390x.go b/src/math/big/arith_decl_s390x.go new file mode 100644 index 0000000..4193f32 --- /dev/null +++ b/src/math/big/arith_decl_s390x.go @@ -0,0 +1,19 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +package big + +import "internal/cpu" + +func addVV_check(z, x, y []Word) (c Word) +func addVV_vec(z, x, y []Word) (c Word) +func addVV_novec(z, x, y []Word) (c Word) +func subVV_check(z, x, y []Word) (c Word) +func subVV_vec(z, x, y []Word) (c Word) +func subVV_novec(z, x, y []Word) (c Word) + +var hasVX = cpu.S390X.HasVX diff --git a/src/math/big/arith_loong64.s b/src/math/big/arith_loong64.s new file mode 100644 index 0000000..0ae3031 --- /dev/null +++ b/src/math/big/arith_loong64.s @@ -0,0 +1,34 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go,loong64 + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) diff --git a/src/math/big/arith_mips64x.s b/src/math/big/arith_mips64x.s new file mode 100644 index 0000000..3ee6e27 --- /dev/null +++ b/src/math/big/arith_mips64x.s @@ -0,0 +1,37 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go && (mips64 || mips64le) +// +build !math_big_pure_go +// +build mips64 mips64le + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) + diff --git a/src/math/big/arith_mipsx.s b/src/math/big/arith_mipsx.s new file mode 100644 index 0000000..b1d3282 --- /dev/null +++ b/src/math/big/arith_mipsx.s @@ -0,0 +1,37 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go && (mips || mipsle) +// +build !math_big_pure_go +// +build mips mipsle + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) + diff --git a/src/math/big/arith_ppc64x.s b/src/math/big/arith_ppc64x.s new file mode 100644 index 0000000..5fdbf40 --- /dev/null +++ b/src/math/big/arith_ppc64x.s @@ -0,0 +1,633 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go && (ppc64 || ppc64le) +// +build !math_big_pure_go +// +build ppc64 ppc64le + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// func addVV(z, y, y []Word) (c Word) +// z[i] = x[i] + y[i] for all i, carrying +TEXT ·addVV(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R7 // R7 = z_len + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R9 // R9 = y[] + MOVD z+0(FP), R10 // R10 = z[] + + // If z_len = 0, we are done + CMP R0, R7 + MOVD R0, R4 + BEQ done + + // Process the first iteration out of the loop so we can + // use MOVDU and avoid 3 index registers updates. + MOVD 0(R8), R11 // R11 = x[i] + MOVD 0(R9), R12 // R12 = y[i] + ADD $-1, R7 // R7 = z_len - 1 + ADDC R12, R11, R15 // R15 = x[i] + y[i], set CA + CMP R0, R7 + MOVD R15, 0(R10) // z[i] + BEQ final // If z_len was 1, we are done + + SRD $2, R7, R5 // R5 = z_len/4 + CMP R0, R5 + MOVD R5, CTR // Set up loop counter + BEQ tail // If R5 = 0, we can't use the loop + + // Process 4 elements per iteration. Unrolling this loop + // means a performance trade-off: we will lose performance + // for small values of z_len (0.90x in the worst case), but + // gain significant performance as z_len increases (up to + // 1.45x). + + PCALIGN $32 +loop: + MOVD 8(R8), R11 // R11 = x[i] + MOVD 16(R8), R12 // R12 = x[i+1] + MOVD 24(R8), R14 // R14 = x[i+2] + MOVDU 32(R8), R15 // R15 = x[i+3] + MOVD 8(R9), R16 // R16 = y[i] + MOVD 16(R9), R17 // R17 = y[i+1] + MOVD 24(R9), R18 // R18 = y[i+2] + MOVDU 32(R9), R19 // R19 = y[i+3] + ADDE R11, R16, R20 // R20 = x[i] + y[i] + CA + ADDE R12, R17, R21 // R21 = x[i+1] + y[i+1] + CA + ADDE R14, R18, R22 // R22 = x[i+2] + y[i+2] + CA + ADDE R15, R19, R23 // R23 = x[i+3] + y[i+3] + CA + MOVD R20, 8(R10) // z[i] + MOVD R21, 16(R10) // z[i+1] + MOVD R22, 24(R10) // z[i+2] + MOVDU R23, 32(R10) // z[i+3] + ADD $-4, R7 // R7 = z_len - 4 + BC 16, 0, loop // bdnz + + // We may have more elements to read + CMP R0, R7 + BEQ final + + // Process the remaining elements, one at a time +tail: + MOVDU 8(R8), R11 // R11 = x[i] + MOVDU 8(R9), R16 // R16 = y[i] + ADD $-1, R7 // R7 = z_len - 1 + ADDE R11, R16, R20 // R20 = x[i] + y[i] + CA + CMP R0, R7 + MOVDU R20, 8(R10) // z[i] + BEQ final // If R7 = 0, we are done + + MOVDU 8(R8), R11 + MOVDU 8(R9), R16 + ADD $-1, R7 + ADDE R11, R16, R20 + CMP R0, R7 + MOVDU R20, 8(R10) + BEQ final + + MOVD 8(R8), R11 + MOVD 8(R9), R16 + ADDE R11, R16, R20 + MOVD R20, 8(R10) + +final: + ADDZE R4 // Capture CA + +done: + MOVD R4, c+72(FP) + RET + +// func subVV(z, x, y []Word) (c Word) +// z[i] = x[i] - y[i] for all i, carrying +TEXT ·subVV(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R7 // R7 = z_len + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R9 // R9 = y[] + MOVD z+0(FP), R10 // R10 = z[] + + // If z_len = 0, we are done + CMP R0, R7 + MOVD R0, R4 + BEQ done + + // Process the first iteration out of the loop so we can + // use MOVDU and avoid 3 index registers updates. + MOVD 0(R8), R11 // R11 = x[i] + MOVD 0(R9), R12 // R12 = y[i] + ADD $-1, R7 // R7 = z_len - 1 + SUBC R12, R11, R15 // R15 = x[i] - y[i], set CA + CMP R0, R7 + MOVD R15, 0(R10) // z[i] + BEQ final // If z_len was 1, we are done + + SRD $2, R7, R5 // R5 = z_len/4 + CMP R0, R5 + MOVD R5, CTR // Set up loop counter + BEQ tail // If R5 = 0, we can't use the loop + + // Process 4 elements per iteration. Unrolling this loop + // means a performance trade-off: we will lose performance + // for small values of z_len (0.92x in the worst case), but + // gain significant performance as z_len increases (up to + // 1.45x). + + PCALIGN $32 +loop: + MOVD 8(R8), R11 // R11 = x[i] + MOVD 16(R8), R12 // R12 = x[i+1] + MOVD 24(R8), R14 // R14 = x[i+2] + MOVDU 32(R8), R15 // R15 = x[i+3] + MOVD 8(R9), R16 // R16 = y[i] + MOVD 16(R9), R17 // R17 = y[i+1] + MOVD 24(R9), R18 // R18 = y[i+2] + MOVDU 32(R9), R19 // R19 = y[i+3] + SUBE R16, R11, R20 // R20 = x[i] - y[i] + CA + SUBE R17, R12, R21 // R21 = x[i+1] - y[i+1] + CA + SUBE R18, R14, R22 // R22 = x[i+2] - y[i+2] + CA + SUBE R19, R15, R23 // R23 = x[i+3] - y[i+3] + CA + MOVD R20, 8(R10) // z[i] + MOVD R21, 16(R10) // z[i+1] + MOVD R22, 24(R10) // z[i+2] + MOVDU R23, 32(R10) // z[i+3] + ADD $-4, R7 // R7 = z_len - 4 + BC 16, 0, loop // bdnz + + // We may have more elements to read + CMP R0, R7 + BEQ final + + // Process the remaining elements, one at a time +tail: + MOVDU 8(R8), R11 // R11 = x[i] + MOVDU 8(R9), R16 // R16 = y[i] + ADD $-1, R7 // R7 = z_len - 1 + SUBE R16, R11, R20 // R20 = x[i] - y[i] + CA + CMP R0, R7 + MOVDU R20, 8(R10) // z[i] + BEQ final // If R7 = 0, we are done + + MOVDU 8(R8), R11 + MOVDU 8(R9), R16 + ADD $-1, R7 + SUBE R16, R11, R20 + CMP R0, R7 + MOVDU R20, 8(R10) + BEQ final + + MOVD 8(R8), R11 + MOVD 8(R9), R16 + SUBE R16, R11, R20 + MOVD R20, 8(R10) + +final: + ADDZE R4 + XOR $1, R4 + +done: + MOVD R4, c+72(FP) + RET + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB), NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R4 // R4 = y = c + MOVD z_len+8(FP), R11 // R11 = z_len + + CMP R0, R11 // If z_len is zero, return + BEQ done + + // We will process the first iteration out of the loop so we capture + // the value of c. In the subsequent iterations, we will rely on the + // value of CA set here. + MOVD 0(R8), R20 // R20 = x[i] + ADD $-1, R11 // R11 = z_len - 1 + ADDC R20, R4, R6 // R6 = x[i] + c + CMP R0, R11 // If z_len was 1, we are done + MOVD R6, 0(R10) // z[i] + BEQ final + + // We will read 4 elements per iteration + SRD $2, R11, R9 // R9 = z_len/4 + DCBT (R8) + CMP R0, R9 + MOVD R9, CTR // Set up the loop counter + BEQ tail // If R9 = 0, we can't use the loop + PCALIGN $32 + +loop: + MOVD 8(R8), R20 // R20 = x[i] + MOVD 16(R8), R21 // R21 = x[i+1] + MOVD 24(R8), R22 // R22 = x[i+2] + MOVDU 32(R8), R23 // R23 = x[i+3] + ADDZE R20, R24 // R24 = x[i] + CA + ADDZE R21, R25 // R25 = x[i+1] + CA + ADDZE R22, R26 // R26 = x[i+2] + CA + ADDZE R23, R27 // R27 = x[i+3] + CA + MOVD R24, 8(R10) // z[i] + MOVD R25, 16(R10) // z[i+1] + MOVD R26, 24(R10) // z[i+2] + MOVDU R27, 32(R10) // z[i+3] + ADD $-4, R11 // R11 = z_len - 4 + BC 16, 0, loop // bdnz + + // We may have some elements to read + CMP R0, R11 + BEQ final + +tail: + MOVDU 8(R8), R20 + ADDZE R20, R24 + ADD $-1, R11 + MOVDU R24, 8(R10) + CMP R0, R11 + BEQ final + + MOVDU 8(R8), R20 + ADDZE R20, R24 + ADD $-1, R11 + MOVDU R24, 8(R10) + CMP R0, R11 + BEQ final + + MOVD 8(R8), R20 + ADDZE R20, R24 + MOVD R24, 8(R10) + +final: + ADDZE R0, R4 // c = CA +done: + MOVD R4, c+56(FP) + RET + +// func subVW(z, x []Word, y Word) (c Word) +TEXT ·subVW(SB), NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R4 // R4 = y = c + MOVD z_len+8(FP), R11 // R11 = z_len + + CMP R0, R11 // If z_len is zero, return + BEQ done + + // We will process the first iteration out of the loop so we capture + // the value of c. In the subsequent iterations, we will rely on the + // value of CA set here. + MOVD 0(R8), R20 // R20 = x[i] + ADD $-1, R11 // R11 = z_len - 1 + SUBC R4, R20, R6 // R6 = x[i] - c + CMP R0, R11 // If z_len was 1, we are done + MOVD R6, 0(R10) // z[i] + BEQ final + + // We will read 4 elements per iteration + SRD $2, R11, R9 // R9 = z_len/4 + DCBT (R8) + CMP R0, R9 + MOVD R9, CTR // Set up the loop counter + BEQ tail // If R9 = 0, we can't use the loop + + // The loop here is almost the same as the one used in s390x, but + // we don't need to capture CA every iteration because we've already + // done that above. + + PCALIGN $32 +loop: + MOVD 8(R8), R20 + MOVD 16(R8), R21 + MOVD 24(R8), R22 + MOVDU 32(R8), R23 + SUBE R0, R20 + SUBE R0, R21 + SUBE R0, R22 + SUBE R0, R23 + MOVD R20, 8(R10) + MOVD R21, 16(R10) + MOVD R22, 24(R10) + MOVDU R23, 32(R10) + ADD $-4, R11 + BC 16, 0, loop // bdnz + + // We may have some elements to read + CMP R0, R11 + BEQ final + +tail: + MOVDU 8(R8), R20 + SUBE R0, R20 + ADD $-1, R11 + MOVDU R20, 8(R10) + CMP R0, R11 + BEQ final + + MOVDU 8(R8), R20 + SUBE R0, R20 + ADD $-1, R11 + MOVDU R20, 8(R10) + CMP R0, R11 + BEQ final + + MOVD 8(R8), R20 + SUBE R0, R20 + MOVD R20, 8(R10) + +final: + // Capture CA + SUBE R4, R4 + NEG R4, R4 + +done: + MOVD R4, c+56(FP) + RET + +//func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB), NOSPLIT, $0 + MOVD z+0(FP), R3 + MOVD x+24(FP), R6 + MOVD s+48(FP), R9 + MOVD z_len+8(FP), R4 + MOVD x_len+32(FP), R7 + CMP R9, R0 // s==0 copy(z,x) + BEQ zeroshift + CMP R4, R0 // len(z)==0 return + BEQ done + + ADD $-1, R4, R5 // len(z)-1 + SUBC R9, $64, R4 // ŝ=_W-s, we skip & by _W-1 as the caller ensures s < _W(64) + SLD $3, R5, R7 + ADD R6, R7, R15 // save starting address &x[len(z)-1] + ADD R3, R7, R16 // save starting address &z[len(z)-1] + MOVD (R6)(R7), R14 + SRD R4, R14, R7 // compute x[len(z)-1]>>ŝ into R7 + CMP R5, R0 // iterate from i=len(z)-1 to 0 + BEQ loopexit // Already at end? + MOVD 0(R15),R10 // x[i] + PCALIGN $32 +shloop: + SLD R9, R10, R10 // x[i]<<s + MOVDU -8(R15), R14 + SRD R4, R14, R11 // x[i-1]>>ŝ + OR R11, R10, R10 + MOVD R10, 0(R16) // z[i-1]=x[i]<<s | x[i-1]>>ŝ + MOVD R14, R10 // reuse x[i-1] for next iteration + ADD $-8, R16 // i-- + CMP R15, R6 // &x[i-1]>&x[0]? + BGT shloop +loopexit: + MOVD 0(R6), R4 + SLD R9, R4, R4 + MOVD R4, 0(R3) // z[0]=x[0]<<s + MOVD R7, c+56(FP) // store pre-computed x[len(z)-1]>>ŝ into c + RET + +zeroshift: + CMP R6, R0 // x is null, nothing to copy + BEQ done + CMP R6, R3 // if x is same as z, nothing to copy + BEQ done + CMP R7, R4 + ISEL $0, R7, R4, R7 // Take the lower bound of lengths of x,z + SLD $3, R7, R7 + SUB R6, R3, R11 // dest - src + CMPU R11, R7, CR2 // < len? + BLT CR2, backward // there is overlap, copy backwards + MOVD $0, R14 + // shlVU processes backwards, but added a forward copy option + // since its faster on POWER +repeat: + MOVD (R6)(R14), R15 // Copy 8 bytes at a time + MOVD R15, (R3)(R14) + ADD $8, R14 + CMP R14, R7 // More 8 bytes left? + BLT repeat + BR done +backward: + ADD $-8,R7, R14 +repeatback: + MOVD (R6)(R14), R15 // copy x into z backwards + MOVD R15, (R3)(R14) // copy 8 bytes at a time + SUB $8, R14 + CMP R14, $-8 // More 8 bytes left? + BGT repeatback + +done: + MOVD R0, c+56(FP) // c=0 + RET + +//func shrVU(z, x []Word, s uint) (c Word) +TEXT ·shrVU(SB), NOSPLIT, $0 + MOVD z+0(FP), R3 + MOVD x+24(FP), R6 + MOVD s+48(FP), R9 + MOVD z_len+8(FP), R4 + MOVD x_len+32(FP), R7 + + CMP R9, R0 // s==0, copy(z,x) + BEQ zeroshift + CMP R4, R0 // len(z)==0 return + BEQ done + SUBC R9, $64, R5 // ŝ=_W-s, we skip & by _W-1 as the caller ensures s < _W(64) + + MOVD 0(R6), R7 + SLD R5, R7, R7 // compute x[0]<<ŝ + MOVD $1, R8 // iterate from i=1 to i<len(z) + CMP R8, R4 + BGE loopexit // Already at end? + + // vectorize if len(z) is >=3, else jump to scalar loop + CMP R4, $3 + BLT scalar + MTVSRD R9, VS38 // s + VSPLTB $7, V6, V4 + MTVSRD R5, VS39 // ŝ + VSPLTB $7, V7, V2 + ADD $-2, R4, R16 + PCALIGN $16 +loopback: + ADD $-1, R8, R10 + SLD $3, R10 + LXVD2X (R6)(R10), VS32 // load x[i-1], x[i] + SLD $3, R8, R12 + LXVD2X (R6)(R12), VS33 // load x[i], x[i+1] + + VSRD V0, V4, V3 // x[i-1]>>s, x[i]>>s + VSLD V1, V2, V5 // x[i]<<ŝ, x[i+1]<<ŝ + VOR V3, V5, V5 // Or(|) the two registers together + STXVD2X VS37, (R3)(R10) // store into z[i-1] and z[i] + ADD $2, R8 // Done processing 2 entries, i and i+1 + CMP R8, R16 // Are there at least a couple of more entries left? + BLE loopback + CMP R8, R4 // Are we at the last element? + BEQ loopexit +scalar: + ADD $-1, R8, R10 + SLD $3, R10 + MOVD (R6)(R10),R11 + SRD R9, R11, R11 // x[len(z)-2] >> s + SLD $3, R8, R12 + MOVD (R6)(R12), R12 + SLD R5, R12, R12 // x[len(z)-1]<<ŝ + OR R12, R11, R11 // x[len(z)-2]>>s | x[len(z)-1]<<ŝ + MOVD R11, (R3)(R10) // z[len(z)-2]=x[len(z)-2]>>s | x[len(z)-1]<<ŝ +loopexit: + ADD $-1, R4 + SLD $3, R4 + MOVD (R6)(R4), R5 + SRD R9, R5, R5 // x[len(z)-1]>>s + MOVD R5, (R3)(R4) // z[len(z)-1]=x[len(z)-1]>>s + MOVD R7, c+56(FP) // store pre-computed x[0]<<ŝ into c + RET + +zeroshift: + CMP R6, R0 // x is null, nothing to copy + BEQ done + CMP R6, R3 // if x is same as z, nothing to copy + BEQ done + CMP R7, R4 + ISEL $0, R7, R4, R7 // Take the lower bounds of lengths of x, z + SLD $3, R7, R7 + MOVD $0, R14 +repeat: + MOVD (R6)(R14), R15 // copy 8 bytes at a time + MOVD R15, (R3)(R14) // shrVU processes bytes only forwards + ADD $8, R14 + CMP R14, R7 // More 8 bytes left? + BLT repeat +done: + MOVD R0, c+56(FP) + RET + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB), NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R9 // R9 = y + MOVD r+56(FP), R4 // R4 = r = c + MOVD z_len+8(FP), R11 // R11 = z_len + + CMP R0, R11 + BEQ done + + MOVD 0(R8), R20 + ADD $-1, R11 + MULLD R9, R20, R6 // R6 = z0 = Low-order(x[i]*y) + MULHDU R9, R20, R7 // R7 = z1 = High-order(x[i]*y) + ADDC R4, R6 // R6 = z0 + r + ADDZE R7 // R7 = z1 + CA + CMP R0, R11 + MOVD R7, R4 // R4 = c + MOVD R6, 0(R10) // z[i] + BEQ done + + // We will read 4 elements per iteration + SRD $2, R11, R14 // R14 = z_len/4 + DCBT (R8) + CMP R0, R14 + MOVD R14, CTR // Set up the loop counter + BEQ tail // If R9 = 0, we can't use the loop + PCALIGN $32 + +loop: + MOVD 8(R8), R20 // R20 = x[i] + MOVD 16(R8), R21 // R21 = x[i+1] + MOVD 24(R8), R22 // R22 = x[i+2] + MOVDU 32(R8), R23 // R23 = x[i+3] + MULLD R9, R20, R24 // R24 = z0[i] + MULHDU R9, R20, R20 // R20 = z1[i] + ADDC R4, R24 // R24 = z0[i] + c + ADDZE R20 // R7 = z1[i] + CA + MULLD R9, R21, R25 + MULHDU R9, R21, R21 + ADDC R20, R25 + ADDZE R21 + MULLD R9, R22, R26 + MULHDU R9, R22, R22 + MULLD R9, R23, R27 + MULHDU R9, R23, R23 + ADDC R21, R26 + ADDZE R22 + MOVD R24, 8(R10) // z[i] + MOVD R25, 16(R10) // z[i+1] + ADDC R22, R27 + ADDZE R23,R4 // update carry + MOVD R26, 24(R10) // z[i+2] + MOVDU R27, 32(R10) // z[i+3] + ADD $-4, R11 // R11 = z_len - 4 + BC 16, 0, loop // bdnz + + // We may have some elements to read + CMP R0, R11 + BEQ done + + // Process the remaining elements, one at a time +tail: + MOVDU 8(R8), R20 // R20 = x[i] + MULLD R9, R20, R24 // R24 = z0[i] + MULHDU R9, R20, R25 // R25 = z1[i] + ADD $-1, R11 // R11 = z_len - 1 + ADDC R4, R24 + ADDZE R25 + MOVDU R24, 8(R10) // z[i] + CMP R0, R11 + MOVD R25, R4 // R4 = c + BEQ done // If R11 = 0, we are done + + MOVDU 8(R8), R20 + MULLD R9, R20, R24 + MULHDU R9, R20, R25 + ADD $-1, R11 + ADDC R4, R24 + ADDZE R25 + MOVDU R24, 8(R10) + CMP R0, R11 + MOVD R25, R4 + BEQ done + + MOVD 8(R8), R20 + MULLD R9, R20, R24 + MULHDU R9, R20, R25 + ADD $-1, R11 + ADDC R4, R24 + ADDZE R25 + MOVD R24, 8(R10) + MOVD R25, R4 + +done: + MOVD R4, c+64(FP) + RET + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB), NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R9 // R9 = y + MOVD z_len+8(FP), R22 // R22 = z_len + + MOVD R0, R3 // R3 will be the index register + CMP R0, R22 + MOVD R0, R4 // R4 = c = 0 + MOVD R22, CTR // Initialize loop counter + BEQ done + PCALIGN $32 + +loop: + MOVD (R8)(R3), R20 // Load x[i] + MOVD (R10)(R3), R21 // Load z[i] + MULLD R9, R20, R6 // R6 = Low-order(x[i]*y) + MULHDU R9, R20, R7 // R7 = High-order(x[i]*y) + ADDC R21, R6 // R6 = z0 + ADDZE R7 // R7 = z1 + ADDC R4, R6 // R6 = z0 + c + 0 + ADDZE R7, R4 // c += z1 + MOVD R6, (R10)(R3) // Store z[i] + ADD $8, R3 + BC 16, 0, loop // bdnz + +done: + MOVD R4, c+56(FP) + RET + + diff --git a/src/math/big/arith_riscv64.s b/src/math/big/arith_riscv64.s new file mode 100644 index 0000000..cb9ac18 --- /dev/null +++ b/src/math/big/arith_riscv64.s @@ -0,0 +1,36 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go && riscv64 +// +build !math_big_pure_go,riscv64 + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) + diff --git a/src/math/big/arith_s390x.s b/src/math/big/arith_s390x.s new file mode 100644 index 0000000..aa6590e --- /dev/null +++ b/src/math/big/arith_s390x.s @@ -0,0 +1,787 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// DI = R3, CX = R4, SI = r10, r8 = r8, r9=r9, r10 = r2, r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) + use R11 +// func addVV(z, x, y []Word) (c Word) + +TEXT ·addVV(SB), NOSPLIT, $0 + MOVD addvectorfacility+0x00(SB), R1 + BR (R1) + +TEXT ·addVV_check(SB), NOSPLIT, $0 + MOVB ·hasVX(SB), R1 + CMPBEQ R1, $1, vectorimpl // vectorfacility = 1, vector supported + MOVD $addvectorfacility+0x00(SB), R1 + MOVD $·addVV_novec(SB), R2 + MOVD R2, 0(R1) + + // MOVD $·addVV_novec(SB), 0(R1) + BR ·addVV_novec(SB) + +vectorimpl: + MOVD $addvectorfacility+0x00(SB), R1 + MOVD $·addVV_vec(SB), R2 + MOVD R2, 0(R1) + + // MOVD $·addVV_vec(SB), 0(R1) + BR ·addVV_vec(SB) + +GLOBL addvectorfacility+0x00(SB), NOPTR, $8 +DATA addvectorfacility+0x00(SB)/8, $·addVV_check(SB) + +TEXT ·addVV_vec(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 + BLT v1 + SUB $12, R3 // n -= 16 + BLT A1 // if n < 0 goto A1 + + MOVD R8, R5 + MOVD R9, R6 + MOVD R2, R7 + + // n >= 0 + // regular loop body unrolled 16x + VZERO V0 // c = 0 + +UU1: + VLM 0(R5), V1, V4 // 64-bytes into V1..V8 + ADD $64, R5 + VPDI $0x4, V1, V1, V1 // flip the doublewords to big-endian order + VPDI $0x4, V2, V2, V2 // flip the doublewords to big-endian order + + VLM 0(R6), V9, V12 // 64-bytes into V9..V16 + ADD $64, R6 + VPDI $0x4, V9, V9, V9 // flip the doublewords to big-endian order + VPDI $0x4, V10, V10, V10 // flip the doublewords to big-endian order + + VACCCQ V1, V9, V0, V25 + VACQ V1, V9, V0, V17 + VACCCQ V2, V10, V25, V26 + VACQ V2, V10, V25, V18 + + VLM 0(R5), V5, V6 // 32-bytes into V1..V8 + VLM 0(R6), V13, V14 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4, V3, V3, V3 // flip the doublewords to big-endian order + VPDI $0x4, V4, V4, V4 // flip the doublewords to big-endian order + VPDI $0x4, V11, V11, V11 // flip the doublewords to big-endian order + VPDI $0x4, V12, V12, V12 // flip the doublewords to big-endian order + + VACCCQ V3, V11, V26, V27 + VACQ V3, V11, V26, V19 + VACCCQ V4, V12, V27, V28 + VACQ V4, V12, V27, V20 + + VLM 0(R5), V7, V8 // 32-bytes into V1..V8 + VLM 0(R6), V15, V16 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4, V5, V5, V5 // flip the doublewords to big-endian order + VPDI $0x4, V6, V6, V6 // flip the doublewords to big-endian order + VPDI $0x4, V13, V13, V13 // flip the doublewords to big-endian order + VPDI $0x4, V14, V14, V14 // flip the doublewords to big-endian order + + VACCCQ V5, V13, V28, V29 + VACQ V5, V13, V28, V21 + VACCCQ V6, V14, V29, V30 + VACQ V6, V14, V29, V22 + + VPDI $0x4, V7, V7, V7 // flip the doublewords to big-endian order + VPDI $0x4, V8, V8, V8 // flip the doublewords to big-endian order + VPDI $0x4, V15, V15, V15 // flip the doublewords to big-endian order + VPDI $0x4, V16, V16, V16 // flip the doublewords to big-endian order + + VACCCQ V7, V15, V30, V31 + VACQ V7, V15, V30, V23 + VACCCQ V8, V16, V31, V0 // V0 has carry-over + VACQ V8, V16, V31, V24 + + VPDI $0x4, V17, V17, V17 // flip the doublewords to big-endian order + VPDI $0x4, V18, V18, V18 // flip the doublewords to big-endian order + VPDI $0x4, V19, V19, V19 // flip the doublewords to big-endian order + VPDI $0x4, V20, V20, V20 // flip the doublewords to big-endian order + VPDI $0x4, V21, V21, V21 // flip the doublewords to big-endian order + VPDI $0x4, V22, V22, V22 // flip the doublewords to big-endian order + VPDI $0x4, V23, V23, V23 // flip the doublewords to big-endian order + VPDI $0x4, V24, V24, V24 // flip the doublewords to big-endian order + VSTM V17, V24, 0(R7) // 128-bytes into z + ADD $128, R7 + ADD $128, R10 // i += 16 + SUB $16, R3 // n -= 16 + BGE UU1 // if n >= 0 goto U1 + VLGVG $1, V0, R4 // put cf into R4 + NEG R4, R4 // save cf + +A1: + ADD $12, R3 // n += 16 + + // s/JL/JMP/ below to disable the unrolled loop + BLT v1 // if n < 0 goto v1 + +U1: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + ADDC R4, R4 // restore CF + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD 8(R9)(R10*1), R11 + ADDE R11, R6 + MOVD 16(R9)(R10*1), R11 + ADDE R11, R7 + MOVD 24(R9)(R10*1), R11 + ADDE R11, R1 + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1 // if n >= 0 goto U1 + +v1: + ADD $4, R3 // n += 4 + BLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + ADDC R4, R4 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1 // if n > 0 goto L1 + +E1: + NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + +TEXT ·addVV_novec(SB), NOSPLIT, $0 +novec: + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v1n // if n < 0 goto v1n + +U1n: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + ADDC R4, R4 // restore CF + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD 8(R9)(R10*1), R11 + ADDE R11, R6 + MOVD 16(R9)(R10*1), R11 + ADDE R11, R7 + MOVD 24(R9)(R10*1), R11 + ADDE R11, R1 + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1n // if n >= 0 goto U1n + +v1n: + ADD $4, R3 // n += 4 + BLE E1n // if n <= 0 goto E1n + +L1n: // n > 0 + ADDC R4, R4 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1n // if n > 0 goto L1n + +E1n: + NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + +TEXT ·subVV(SB), NOSPLIT, $0 + MOVD subvectorfacility+0x00(SB), R1 + BR (R1) + +TEXT ·subVV_check(SB), NOSPLIT, $0 + MOVB ·hasVX(SB), R1 + CMPBEQ R1, $1, vectorimpl // vectorfacility = 1, vector supported + MOVD $subvectorfacility+0x00(SB), R1 + MOVD $·subVV_novec(SB), R2 + MOVD R2, 0(R1) + + // MOVD $·subVV_novec(SB), 0(R1) + BR ·subVV_novec(SB) + +vectorimpl: + MOVD $subvectorfacility+0x00(SB), R1 + MOVD $·subVV_vec(SB), R2 + MOVD R2, 0(R1) + + // MOVD $·subVV_vec(SB), 0(R1) + BR ·subVV_vec(SB) + +GLOBL subvectorfacility+0x00(SB), NOPTR, $8 +DATA subvectorfacility+0x00(SB)/8, $·subVV_check(SB) + +// DI = R3, CX = R4, SI = r10, r8 = r8, r9=r9, r10 = r2, r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) + use R11 +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SUBC/SUBE instead of ADDC/ADDE and label names) +TEXT ·subVV_vec(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v1 // if n < 0 goto v1 + SUB $12, R3 // n -= 16 + BLT A1 // if n < 0 goto A1 + + MOVD R8, R5 + MOVD R9, R6 + MOVD R2, R7 + + // n >= 0 + // regular loop body unrolled 16x + VZERO V0 // cf = 0 + MOVD $1, R4 // for 390 subtraction cf starts as 1 (no borrow) + VLVGG $1, R4, V0 // put carry into V0 + +UU1: + VLM 0(R5), V1, V4 // 64-bytes into V1..V8 + ADD $64, R5 + VPDI $0x4, V1, V1, V1 // flip the doublewords to big-endian order + VPDI $0x4, V2, V2, V2 // flip the doublewords to big-endian order + + VLM 0(R6), V9, V12 // 64-bytes into V9..V16 + ADD $64, R6 + VPDI $0x4, V9, V9, V9 // flip the doublewords to big-endian order + VPDI $0x4, V10, V10, V10 // flip the doublewords to big-endian order + + VSBCBIQ V1, V9, V0, V25 + VSBIQ V1, V9, V0, V17 + VSBCBIQ V2, V10, V25, V26 + VSBIQ V2, V10, V25, V18 + + VLM 0(R5), V5, V6 // 32-bytes into V1..V8 + VLM 0(R6), V13, V14 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4, V3, V3, V3 // flip the doublewords to big-endian order + VPDI $0x4, V4, V4, V4 // flip the doublewords to big-endian order + VPDI $0x4, V11, V11, V11 // flip the doublewords to big-endian order + VPDI $0x4, V12, V12, V12 // flip the doublewords to big-endian order + + VSBCBIQ V3, V11, V26, V27 + VSBIQ V3, V11, V26, V19 + VSBCBIQ V4, V12, V27, V28 + VSBIQ V4, V12, V27, V20 + + VLM 0(R5), V7, V8 // 32-bytes into V1..V8 + VLM 0(R6), V15, V16 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4, V5, V5, V5 // flip the doublewords to big-endian order + VPDI $0x4, V6, V6, V6 // flip the doublewords to big-endian order + VPDI $0x4, V13, V13, V13 // flip the doublewords to big-endian order + VPDI $0x4, V14, V14, V14 // flip the doublewords to big-endian order + + VSBCBIQ V5, V13, V28, V29 + VSBIQ V5, V13, V28, V21 + VSBCBIQ V6, V14, V29, V30 + VSBIQ V6, V14, V29, V22 + + VPDI $0x4, V7, V7, V7 // flip the doublewords to big-endian order + VPDI $0x4, V8, V8, V8 // flip the doublewords to big-endian order + VPDI $0x4, V15, V15, V15 // flip the doublewords to big-endian order + VPDI $0x4, V16, V16, V16 // flip the doublewords to big-endian order + + VSBCBIQ V7, V15, V30, V31 + VSBIQ V7, V15, V30, V23 + VSBCBIQ V8, V16, V31, V0 // V0 has carry-over + VSBIQ V8, V16, V31, V24 + + VPDI $0x4, V17, V17, V17 // flip the doublewords to big-endian order + VPDI $0x4, V18, V18, V18 // flip the doublewords to big-endian order + VPDI $0x4, V19, V19, V19 // flip the doublewords to big-endian order + VPDI $0x4, V20, V20, V20 // flip the doublewords to big-endian order + VPDI $0x4, V21, V21, V21 // flip the doublewords to big-endian order + VPDI $0x4, V22, V22, V22 // flip the doublewords to big-endian order + VPDI $0x4, V23, V23, V23 // flip the doublewords to big-endian order + VPDI $0x4, V24, V24, V24 // flip the doublewords to big-endian order + VSTM V17, V24, 0(R7) // 128-bytes into z + ADD $128, R7 + ADD $128, R10 // i += 16 + SUB $16, R3 // n -= 16 + BGE UU1 // if n >= 0 goto U1 + VLGVG $1, V0, R4 // put cf into R4 + SUB $1, R4 // save cf + +A1: + ADD $12, R3 // n += 16 + BLT v1 // if n < 0 goto v1 + +U1: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD 8(R9)(R10*1), R11 + SUBE R11, R6 + MOVD 16(R9)(R10*1), R11 + SUBE R11, R7 + MOVD 24(R9)(R10*1), R11 + SUBE R11, R1 + MOVD R0, R4 + SUBE R4, R4 // save CF + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1 // if n >= 0 goto U1n + +v1: + ADD $4, R3 // n += 4 + BLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + SUBE R4, R4 // save CF + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1 // if n > 0 goto L1n + +E1: + NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + +// DI = R3, CX = R4, SI = r10, r8 = r8, r9=r9, r10 = r2, r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) + use R11 +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SUBC/SUBE instead of ADDC/ADDE and label names) +TEXT ·subVV_novec(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v1 // if n < 0 goto v1 + +U1: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD 8(R9)(R10*1), R11 + SUBE R11, R6 + MOVD 16(R9)(R10*1), R11 + SUBE R11, R7 + MOVD 24(R9)(R10*1), R11 + SUBE R11, R1 + MOVD R0, R4 + SUBE R4, R4 // save CF + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1 // if n >= 0 goto U1 + +v1: + ADD $4, R3 // n += 4 + BLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + SUBE R4, R4 // save CF + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1 // if n > 0 goto L1 + +E1: + NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + +TEXT ·addVW(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R5 // length of z + MOVD x+24(FP), R6 + MOVD y+48(FP), R7 // c = y + MOVD z+0(FP), R8 + + CMPBEQ R5, $0, returnC // if len(z) == 0, we can have an early return + + // Add the first two words, and determine which path (copy path or loop path) to take based on the carry flag. + ADDC 0(R6), R7 + MOVD R7, 0(R8) + CMPBEQ R5, $1, returnResult // len(z) == 1 + MOVD $0, R9 + ADDE 8(R6), R9 + MOVD R9, 8(R8) + CMPBEQ R5, $2, returnResult // len(z) == 2 + + // Update the counters + MOVD $16, R12 // i = 2 + MOVD $-2(R5), R5 // n = n - 2 + +loopOverEachWord: + BRC $12, copySetup // carry = 0, copy the rest + MOVD $1, R9 + + // Originally we used the carry flag generated in the previous iteration + // (i.e: ADDE could be used here to do the addition). However, since we + // already know carry is 1 (otherwise we will go to copy section), we can use + // ADDC here so the current iteration does not depend on the carry flag + // generated in the previous iteration. This could be useful when branch prediction happens. + ADDC 0(R6)(R12*1), R9 + MOVD R9, 0(R8)(R12*1) // z[i] = x[i] + c + + MOVD $8(R12), R12 // i++ + BRCTG R5, loopOverEachWord // n-- + +// Return the current carry value +returnResult: + MOVD $0, R0 + ADDE R0, R0 + MOVD R0, c+56(FP) + RET + +// Update position of x(R6) and z(R8) based on the current counter value and perform copying. +// With the assumption that x and z will not overlap with each other or x and z will +// point to same memory region, we can use a faster version of copy using only MVC here. +// In the following implementation, we have three copy loops, each copying a word, 4 words, and +// 32 words at a time. Via benchmarking, this implementation is faster than calling runtime·memmove. +copySetup: + ADD R12, R6 + ADD R12, R8 + + CMPBGE R5, $4, mediumLoop + +smallLoop: // does a loop unrolling to copy word when n < 4 + CMPBEQ R5, $0, returnZero + MVC $8, 0(R6), 0(R8) + CMPBEQ R5, $1, returnZero + MVC $8, 8(R6), 8(R8) + CMPBEQ R5, $2, returnZero + MVC $8, 16(R6), 16(R8) + +returnZero: + MOVD $0, c+56(FP) // return 0 as carry + RET + +mediumLoop: + CMPBLT R5, $4, smallLoop + CMPBLT R5, $32, mediumLoopBody + +largeLoop: // Copying 256 bytes at a time. + MVC $256, 0(R6), 0(R8) + MOVD $256(R6), R6 + MOVD $256(R8), R8 + MOVD $-32(R5), R5 + CMPBGE R5, $32, largeLoop + BR mediumLoop + +mediumLoopBody: // Copying 32 bytes at a time + MVC $32, 0(R6), 0(R8) + MOVD $32(R6), R6 + MOVD $32(R8), R8 + MOVD $-4(R5), R5 + CMPBGE R5, $4, mediumLoopBody + BR smallLoop + +returnC: + MOVD R7, c+56(FP) + RET + +TEXT ·subVW(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R5 + MOVD x+24(FP), R6 + MOVD y+48(FP), R7 // The borrow bit passed in + MOVD z+0(FP), R8 + MOVD $0, R0 // R0 is a temporary variable used during computation. Ensure it has zero in it. + + CMPBEQ R5, $0, returnC // len(z) == 0, have an early return + + // Subtract the first two words, and determine which path (copy path or loop path) to take based on the borrow flag + MOVD 0(R6), R9 + SUBC R7, R9 + MOVD R9, 0(R8) + CMPBEQ R5, $1, returnResult + MOVD 8(R6), R9 + SUBE R0, R9 + MOVD R9, 8(R8) + CMPBEQ R5, $2, returnResult + + // Update the counters + MOVD $16, R12 // i = 2 + MOVD $-2(R5), R5 // n = n - 2 + +loopOverEachWord: + BRC $3, copySetup // no borrow, copy the rest + MOVD 0(R6)(R12*1), R9 + + // Originally we used the borrow flag generated in the previous iteration + // (i.e: SUBE could be used here to do the subtraction). However, since we + // already know borrow is 1 (otherwise we will go to copy section), we can + // use SUBC here so the current iteration does not depend on the borrow flag + // generated in the previous iteration. This could be useful when branch prediction happens. + SUBC $1, R9 + MOVD R9, 0(R8)(R12*1) // z[i] = x[i] - 1 + + MOVD $8(R12), R12 // i++ + BRCTG R5, loopOverEachWord // n-- + +// return the current borrow value +returnResult: + SUBE R0, R0 + NEG R0, R0 + MOVD R0, c+56(FP) + RET + +// Update position of x(R6) and z(R8) based on the current counter value and perform copying. +// With the assumption that x and z will not overlap with each other or x and z will +// point to same memory region, we can use a faster version of copy using only MVC here. +// In the following implementation, we have three copy loops, each copying a word, 4 words, and +// 32 words at a time. Via benchmarking, this implementation is faster than calling runtime·memmove. +copySetup: + ADD R12, R6 + ADD R12, R8 + + CMPBGE R5, $4, mediumLoop + +smallLoop: // does a loop unrolling to copy word when n < 4 + CMPBEQ R5, $0, returnZero + MVC $8, 0(R6), 0(R8) + CMPBEQ R5, $1, returnZero + MVC $8, 8(R6), 8(R8) + CMPBEQ R5, $2, returnZero + MVC $8, 16(R6), 16(R8) + +returnZero: + MOVD $0, c+56(FP) // return 0 as borrow + RET + +mediumLoop: + CMPBLT R5, $4, smallLoop + CMPBLT R5, $32, mediumLoopBody + +largeLoop: // Copying 256 bytes at a time + MVC $256, 0(R6), 0(R8) + MOVD $256(R6), R6 + MOVD $256(R8), R8 + MOVD $-32(R5), R5 + CMPBGE R5, $32, largeLoop + BR mediumLoop + +mediumLoopBody: // Copying 32 bytes at a time + MVC $32, 0(R6), 0(R8) + MOVD $32(R6), R6 + MOVD $32(R8), R8 + MOVD $-4(R5), R5 + CMPBGE R5, $4, mediumLoopBody + BR smallLoop + +returnC: + MOVD R7, c+56(FP) + RET + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB), NOSPLIT, $0 + BR ·shlVU_g(SB) + +// func shrVU(z, x []Word, s uint) (c Word) +TEXT ·shrVU(SB), NOSPLIT, $0 + BR ·shrVU_g(SB) + +// CX = R4, r8 = r8, r9=r9, r10 = r2, r11 = r5, DX = r3, AX = r6, BX = R1, (R0 set to 0) + use R11 + use R7 for i +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB), NOSPLIT, $0 + MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD r+56(FP), R4 // c = r + MOVD z_len+8(FP), R5 + MOVD $0, R1 // i = 0 + MOVD $0, R7 // i*8 = 0 + MOVD $0, R0 // make sure it's zero + BR E5 + +L5: + MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + ADDC R4, R11 // add to low order bits + ADDE R0, R6 + MOVD R11, (R2)(R1*1) + MOVD R6, R4 + ADD $8, R1 // i*8 + 8 + ADD $1, R7 // i++ + +E5: + CMPBLT R7, R5, L5 // i < n + + MOVD R4, c+64(FP) + RET + +// func addMulVVW(z, x []Word, y Word) (c Word) +// CX = R4, r8 = r8, r9=r9, r10 = r2, r11 = r5, AX = r11, DX = R6, r12=r12, BX = R1, (R0 set to 0) + use R11 + use R7 for i +TEXT ·addMulVVW(SB), NOSPLIT, $0 + MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z_len+8(FP), R5 + + MOVD $0, R1 // i*8 = 0 + MOVD $0, R7 // i = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R4 // c = 0 + + MOVD R5, R12 + AND $-2, R12 + CMPBGE R5, $2, A6 + BR E6 + +A6: + MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + MOVD (8)(R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (8)(R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (8)(R2)(R1*1) + + ADD $16, R1 // i*8 + 8 + ADD $2, R7 // i++ + + CMPBLT R7, R12, A6 + BR E6 + +L6: + MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 // add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + ADD $8, R1 // i*8 + 8 + ADD $1, R7 // i++ + +E6: + CMPBLT R7, R5, L6 // i < n + + MOVD R4, c+56(FP) + RET + diff --git a/src/math/big/arith_s390x_test.go b/src/math/big/arith_s390x_test.go new file mode 100644 index 0000000..8375ddb --- /dev/null +++ b/src/math/big/arith_s390x_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build s390x && !math_big_pure_go +// +build s390x,!math_big_pure_go + +package big + +import ( + "testing" +) + +// Tests whether the non vector routines are working, even when the tests are run on a +// vector-capable machine + +func TestFunVVnovec(t *testing.T) { + if hasVX == true { + for _, a := range sumVV { + arg := a + testFunVV(t, "addVV_novec", addVV_novec, arg) + + arg = argVV{a.z, a.y, a.x, a.c} + testFunVV(t, "addVV_novec symmetric", addVV_novec, arg) + + arg = argVV{a.x, a.z, a.y, a.c} + testFunVV(t, "subVV_novec", subVV_novec, arg) + + arg = argVV{a.y, a.z, a.x, a.c} + testFunVV(t, "subVV_novec symmetric", subVV_novec, arg) + } + } +} diff --git a/src/math/big/arith_test.go b/src/math/big/arith_test.go new file mode 100644 index 0000000..64225bb --- /dev/null +++ b/src/math/big/arith_test.go @@ -0,0 +1,697 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "fmt" + "internal/testenv" + "math/bits" + "math/rand" + "strings" + "testing" +) + +var isRaceBuilder = strings.HasSuffix(testenv.Builder(), "-race") + +type funVV func(z, x, y []Word) (c Word) +type argVV struct { + z, x, y nat + c Word +} + +var sumVV = []argVV{ + {}, + {nat{0}, nat{0}, nat{0}, 0}, + {nat{1}, nat{1}, nat{0}, 0}, + {nat{0}, nat{_M}, nat{1}, 1}, + {nat{80235}, nat{12345}, nat{67890}, 0}, + {nat{_M - 1}, nat{_M}, nat{_M}, 1}, + {nat{0, 0, 0, 0}, nat{_M, _M, _M, _M}, nat{1, 0, 0, 0}, 1}, + {nat{0, 0, 0, _M}, nat{_M, _M, _M, _M - 1}, nat{1, 0, 0, 0}, 0}, + {nat{0, 0, 0, 0}, nat{_M, 0, _M, 0}, nat{1, _M, 0, _M}, 1}, +} + +func testFunVV(t *testing.T, msg string, f funVV, a argVV) { + z := make(nat, len(a.z)) + c := f(z, a.x, a.y) + for i, zi := range z { + if zi != a.z[i] { + t.Errorf("%s%+v\n\tgot z[%d] = %#x; want %#x", msg, a, i, zi, a.z[i]) + break + } + } + if c != a.c { + t.Errorf("%s%+v\n\tgot c = %#x; want %#x", msg, a, c, a.c) + } +} + +func TestFunVV(t *testing.T) { + for _, a := range sumVV { + arg := a + testFunVV(t, "addVV_g", addVV_g, arg) + testFunVV(t, "addVV", addVV, arg) + + arg = argVV{a.z, a.y, a.x, a.c} + testFunVV(t, "addVV_g symmetric", addVV_g, arg) + testFunVV(t, "addVV symmetric", addVV, arg) + + arg = argVV{a.x, a.z, a.y, a.c} + testFunVV(t, "subVV_g", subVV_g, arg) + testFunVV(t, "subVV", subVV, arg) + + arg = argVV{a.y, a.z, a.x, a.c} + testFunVV(t, "subVV_g symmetric", subVV_g, arg) + testFunVV(t, "subVV symmetric", subVV, arg) + } +} + +// Always the same seed for reproducible results. +var rnd = rand.New(rand.NewSource(0)) + +func rndW() Word { + return Word(rnd.Int63()<<1 | rnd.Int63n(2)) +} + +func rndV(n int) []Word { + v := make([]Word, n) + for i := range v { + v[i] = rndW() + } + return v +} + +var benchSizes = []int{1, 2, 3, 4, 5, 1e1, 1e2, 1e3, 1e4, 1e5} + +func BenchmarkAddVV(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := rndV(n) + y := rndV(n) + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _W)) + for i := 0; i < b.N; i++ { + addVV(z, x, y) + } + }) + } +} + +func BenchmarkSubVV(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := rndV(n) + y := rndV(n) + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _W)) + for i := 0; i < b.N; i++ { + subVV(z, x, y) + } + }) + } +} + +type funVW func(z, x []Word, y Word) (c Word) +type argVW struct { + z, x nat + y Word + c Word +} + +var sumVW = []argVW{ + {}, + {nil, nil, 2, 2}, + {nat{0}, nat{0}, 0, 0}, + {nat{1}, nat{0}, 1, 0}, + {nat{1}, nat{1}, 0, 0}, + {nat{0}, nat{_M}, 1, 1}, + {nat{0, 0, 0, 0}, nat{_M, _M, _M, _M}, 1, 1}, + {nat{585}, nat{314}, 271, 0}, +} + +var lshVW = []argVW{ + {}, + {nat{0}, nat{0}, 0, 0}, + {nat{0}, nat{0}, 1, 0}, + {nat{0}, nat{0}, 20, 0}, + + {nat{_M}, nat{_M}, 0, 0}, + {nat{_M << 1 & _M}, nat{_M}, 1, 1}, + {nat{_M << 20 & _M}, nat{_M}, 20, _M >> (_W - 20)}, + + {nat{_M, _M, _M}, nat{_M, _M, _M}, 0, 0}, + {nat{_M << 1 & _M, _M, _M}, nat{_M, _M, _M}, 1, 1}, + {nat{_M << 20 & _M, _M, _M}, nat{_M, _M, _M}, 20, _M >> (_W - 20)}, +} + +var rshVW = []argVW{ + {}, + {nat{0}, nat{0}, 0, 0}, + {nat{0}, nat{0}, 1, 0}, + {nat{0}, nat{0}, 20, 0}, + + {nat{_M}, nat{_M}, 0, 0}, + {nat{_M >> 1}, nat{_M}, 1, _M << (_W - 1) & _M}, + {nat{_M >> 20}, nat{_M}, 20, _M << (_W - 20) & _M}, + + {nat{_M, _M, _M}, nat{_M, _M, _M}, 0, 0}, + {nat{_M, _M, _M >> 1}, nat{_M, _M, _M}, 1, _M << (_W - 1) & _M}, + {nat{_M, _M, _M >> 20}, nat{_M, _M, _M}, 20, _M << (_W - 20) & _M}, +} + +func testFunVW(t *testing.T, msg string, f funVW, a argVW) { + z := make(nat, len(a.z)) + c := f(z, a.x, a.y) + for i, zi := range z { + if zi != a.z[i] { + t.Errorf("%s%+v\n\tgot z[%d] = %#x; want %#x", msg, a, i, zi, a.z[i]) + break + } + } + if c != a.c { + t.Errorf("%s%+v\n\tgot c = %#x; want %#x", msg, a, c, a.c) + } +} + +func testFunVWext(t *testing.T, msg string, f funVW, f_g funVW, a argVW) { + // using the result of addVW_g/subVW_g as golden + z_g := make(nat, len(a.z)) + c_g := f_g(z_g, a.x, a.y) + c := f(a.z, a.x, a.y) + + for i, zi := range a.z { + if zi != z_g[i] { + t.Errorf("%s\n\tgot z[%d] = %#x; want %#x", msg, i, zi, z_g[i]) + break + } + } + if c != c_g { + t.Errorf("%s\n\tgot c = %#x; want %#x", msg, c, c_g) + } +} + +func makeFunVW(f func(z, x []Word, s uint) (c Word)) funVW { + return func(z, x []Word, s Word) (c Word) { + return f(z, x, uint(s)) + } +} + +func TestFunVW(t *testing.T) { + for _, a := range sumVW { + arg := a + testFunVW(t, "addVW_g", addVW_g, arg) + testFunVW(t, "addVW", addVW, arg) + + arg = argVW{a.x, a.z, a.y, a.c} + testFunVW(t, "subVW_g", subVW_g, arg) + testFunVW(t, "subVW", subVW, arg) + } + + shlVW_g := makeFunVW(shlVU_g) + shlVW := makeFunVW(shlVU) + for _, a := range lshVW { + arg := a + testFunVW(t, "shlVU_g", shlVW_g, arg) + testFunVW(t, "shlVU", shlVW, arg) + } + + shrVW_g := makeFunVW(shrVU_g) + shrVW := makeFunVW(shrVU) + for _, a := range rshVW { + arg := a + testFunVW(t, "shrVU_g", shrVW_g, arg) + testFunVW(t, "shrVU", shrVW, arg) + } +} + +// Construct a vector comprising the same word, usually '0' or 'maximum uint' +func makeWordVec(e Word, n int) []Word { + v := make([]Word, n) + for i := range v { + v[i] = e + } + return v +} + +// Extended testing to addVW and subVW using various kinds of input data. +// We utilize the results of addVW_g and subVW_g as golden reference to check +// correctness. +func TestFunVWExt(t *testing.T) { + // 32 is the current threshold that triggers an optimized version of + // calculation for large-sized vector, ensure we have sizes around it tested. + var vwSizes = []int{0, 1, 3, 4, 5, 8, 9, 23, 31, 32, 33, 34, 35, 36, 50, 120} + for _, n := range vwSizes { + // vector of random numbers, using the result of addVW_g/subVW_g as golden + x := rndV(n) + y := rndW() + z := make(nat, n) + arg := argVW{z, x, y, 0} + testFunVWext(t, "addVW, random inputs", addVW, addVW_g, arg) + testFunVWext(t, "subVW, random inputs", subVW, subVW_g, arg) + + // vector of random numbers, but make 'x' and 'z' share storage + arg = argVW{x, x, y, 0} + testFunVWext(t, "addVW, random inputs, sharing storage", addVW, addVW_g, arg) + testFunVWext(t, "subVW, random inputs, sharing storage", subVW, subVW_g, arg) + + // vector of maximum uint, to force carry flag set in each 'add' + y = ^Word(0) + x = makeWordVec(y, n) + arg = argVW{z, x, y, 0} + testFunVWext(t, "addVW, vector of max uint", addVW, addVW_g, arg) + + // vector of '0', to force carry flag set in each 'sub' + x = makeWordVec(0, n) + arg = argVW{z, x, 1, 0} + testFunVWext(t, "subVW, vector of zero", subVW, subVW_g, arg) + } +} + +type argVU struct { + d []Word // d is a Word slice, the input parameters x and z come from this array. + l uint // l is the length of the input parameters x and z. + xp uint // xp is the starting position of the input parameter x, x := d[xp:xp+l]. + zp uint // zp is the starting position of the input parameter z, z := d[zp:zp+l]. + s uint // s is the shift number. + r []Word // r is the expected output result z. + c Word // c is the expected return value. + m string // message. +} + +var argshlVUIn = []Word{1, 2, 4, 8, 16, 32, 64, 0, 0, 0} +var argshlVUr0 = []Word{1, 2, 4, 8, 16, 32, 64} +var argshlVUr1 = []Word{2, 4, 8, 16, 32, 64, 128} +var argshlVUrWm1 = []Word{1 << (_W - 1), 0, 1, 2, 4, 8, 16} + +var argshlVU = []argVU{ + // test cases for shlVU + {[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0}, 7, 0, 0, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "complete overlap of shlVU"}, + {[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0, 0, 0, 0}, 7, 0, 3, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "partial overlap by half of shlVU"}, + {[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0, 0, 0, 0, 0, 0, 0}, 7, 0, 6, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "partial overlap by 1 Word of shlVU"}, + {[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0, 0, 0, 0, 0, 0, 0, 0}, 7, 0, 7, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "no overlap of shlVU"}, + // additional test cases with shift values of 0, 1 and (_W-1) + {argshlVUIn, 7, 0, 0, 0, argshlVUr0, 0, "complete overlap of shlVU and shift of 0"}, + {argshlVUIn, 7, 0, 0, 1, argshlVUr1, 0, "complete overlap of shlVU and shift of 1"}, + {argshlVUIn, 7, 0, 0, _W - 1, argshlVUrWm1, 32, "complete overlap of shlVU and shift of _W - 1"}, + {argshlVUIn, 7, 0, 1, 0, argshlVUr0, 0, "partial overlap by 6 Words of shlVU and shift of 0"}, + {argshlVUIn, 7, 0, 1, 1, argshlVUr1, 0, "partial overlap by 6 Words of shlVU and shift of 1"}, + {argshlVUIn, 7, 0, 1, _W - 1, argshlVUrWm1, 32, "partial overlap by 6 Words of shlVU and shift of _W - 1"}, + {argshlVUIn, 7, 0, 2, 0, argshlVUr0, 0, "partial overlap by 5 Words of shlVU and shift of 0"}, + {argshlVUIn, 7, 0, 2, 1, argshlVUr1, 0, "partial overlap by 5 Words of shlVU and shift of 1"}, + {argshlVUIn, 7, 0, 2, _W - 1, argshlVUrWm1, 32, "partial overlap by 5 Words of shlVU abd shift of _W - 1"}, + {argshlVUIn, 7, 0, 3, 0, argshlVUr0, 0, "partial overlap by 4 Words of shlVU and shift of 0"}, + {argshlVUIn, 7, 0, 3, 1, argshlVUr1, 0, "partial overlap by 4 Words of shlVU and shift of 1"}, + {argshlVUIn, 7, 0, 3, _W - 1, argshlVUrWm1, 32, "partial overlap by 4 Words of shlVU and shift of _W - 1"}, +} + +var argshrVUIn = []Word{0, 0, 0, 1, 2, 4, 8, 16, 32, 64} +var argshrVUr0 = []Word{1, 2, 4, 8, 16, 32, 64} +var argshrVUr1 = []Word{0, 1, 2, 4, 8, 16, 32} +var argshrVUrWm1 = []Word{4, 8, 16, 32, 64, 128, 0} + +var argshrVU = []argVU{ + // test cases for shrVU + {[]Word{0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 1, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "complete overlap of shrVU"}, + {[]Word{0, 0, 0, 0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 4, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "partial overlap by half of shrVU"}, + {[]Word{0, 0, 0, 0, 0, 0, 0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 7, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "partial overlap by 1 Word of shrVU"}, + {[]Word{0, 0, 0, 0, 0, 0, 0, 0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 8, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "no overlap of shrVU"}, + // additional test cases with shift values of 0, 1 and (_W-1) + {argshrVUIn, 7, 3, 3, 0, argshrVUr0, 0, "complete overlap of shrVU and shift of 0"}, + {argshrVUIn, 7, 3, 3, 1, argshrVUr1, 1 << (_W - 1), "complete overlap of shrVU and shift of 1"}, + {argshrVUIn, 7, 3, 3, _W - 1, argshrVUrWm1, 2, "complete overlap of shrVU and shift of _W - 1"}, + {argshrVUIn, 7, 3, 2, 0, argshrVUr0, 0, "partial overlap by 6 Words of shrVU and shift of 0"}, + {argshrVUIn, 7, 3, 2, 1, argshrVUr1, 1 << (_W - 1), "partial overlap by 6 Words of shrVU and shift of 1"}, + {argshrVUIn, 7, 3, 2, _W - 1, argshrVUrWm1, 2, "partial overlap by 6 Words of shrVU and shift of _W - 1"}, + {argshrVUIn, 7, 3, 1, 0, argshrVUr0, 0, "partial overlap by 5 Words of shrVU and shift of 0"}, + {argshrVUIn, 7, 3, 1, 1, argshrVUr1, 1 << (_W - 1), "partial overlap by 5 Words of shrVU and shift of 1"}, + {argshrVUIn, 7, 3, 1, _W - 1, argshrVUrWm1, 2, "partial overlap by 5 Words of shrVU and shift of _W - 1"}, + {argshrVUIn, 7, 3, 0, 0, argshrVUr0, 0, "partial overlap by 4 Words of shrVU and shift of 0"}, + {argshrVUIn, 7, 3, 0, 1, argshrVUr1, 1 << (_W - 1), "partial overlap by 4 Words of shrVU and shift of 1"}, + {argshrVUIn, 7, 3, 0, _W - 1, argshrVUrWm1, 2, "partial overlap by 4 Words of shrVU and shift of _W - 1"}, +} + +func testShiftFunc(t *testing.T, f func(z, x []Word, s uint) Word, a argVU) { + // work on copy of a.d to preserve the original data. + b := make([]Word, len(a.d)) + copy(b, a.d) + z := b[a.zp : a.zp+a.l] + x := b[a.xp : a.xp+a.l] + c := f(z, x, a.s) + for i, zi := range z { + if zi != a.r[i] { + t.Errorf("d := %v, %s(d[%d:%d], d[%d:%d], %d)\n\tgot z[%d] = %#x; want %#x", a.d, a.m, a.zp, a.zp+a.l, a.xp, a.xp+a.l, a.s, i, zi, a.r[i]) + break + } + } + if c != a.c { + t.Errorf("d := %v, %s(d[%d:%d], d[%d:%d], %d)\n\tgot c = %#x; want %#x", a.d, a.m, a.zp, a.zp+a.l, a.xp, a.xp+a.l, a.s, c, a.c) + } +} + +func TestShiftOverlap(t *testing.T) { + for _, a := range argshlVU { + arg := a + testShiftFunc(t, shlVU, arg) + } + + for _, a := range argshrVU { + arg := a + testShiftFunc(t, shrVU, arg) + } +} + +func TestIssue31084(t *testing.T) { + // compute 10^n via 5^n << n. + const n = 165 + p := nat(nil).expNN(nat{5}, nat{n}, nil, false) + p = p.shl(p, n) + got := string(p.utoa(10)) + want := "1" + strings.Repeat("0", n) + if got != want { + t.Errorf("shl(%v, %v)\n\tgot %s\n\twant %s", p, n, got, want) + } +} + +const issue42838Value = "159309191113245227702888039776771180559110455519261878607388585338616290151305816094308987472018268594098344692611135542392730712890625" + +func TestIssue42838(t *testing.T) { + const s = 192 + z, _, _, _ := nat(nil).scan(strings.NewReader(issue42838Value), 0, false) + z = z.shl(z, s) + got := string(z.utoa(10)) + want := "1" + strings.Repeat("0", s) + if got != want { + t.Errorf("shl(%v, %v)\n\tgot %s\n\twant %s", z, s, got, want) + } +} + +func BenchmarkAddVW(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := rndV(n) + y := rndW() + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _S)) + for i := 0; i < b.N; i++ { + addVW(z, x, y) + } + }) + } +} + +// Benchmarking addVW using vector of maximum uint to force carry flag set +func BenchmarkAddVWext(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + y := ^Word(0) + x := makeWordVec(y, n) + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _S)) + for i := 0; i < b.N; i++ { + addVW(z, x, y) + } + }) + } +} + +func BenchmarkSubVW(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := rndV(n) + y := rndW() + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _S)) + for i := 0; i < b.N; i++ { + subVW(z, x, y) + } + }) + } +} + +// Benchmarking subVW using vector of zero to force carry flag set +func BenchmarkSubVWext(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := makeWordVec(0, n) + y := Word(1) + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _S)) + for i := 0; i < b.N; i++ { + subVW(z, x, y) + } + }) + } +} + +type funVWW func(z, x []Word, y, r Word) (c Word) +type argVWW struct { + z, x nat + y, r Word + c Word +} + +var prodVWW = []argVWW{ + {}, + {nat{0}, nat{0}, 0, 0, 0}, + {nat{991}, nat{0}, 0, 991, 0}, + {nat{0}, nat{_M}, 0, 0, 0}, + {nat{991}, nat{_M}, 0, 991, 0}, + {nat{0}, nat{0}, _M, 0, 0}, + {nat{991}, nat{0}, _M, 991, 0}, + {nat{1}, nat{1}, 1, 0, 0}, + {nat{992}, nat{1}, 1, 991, 0}, + {nat{22793}, nat{991}, 23, 0, 0}, + {nat{22800}, nat{991}, 23, 7, 0}, + {nat{0, 0, 0, 22793}, nat{0, 0, 0, 991}, 23, 0, 0}, + {nat{7, 0, 0, 22793}, nat{0, 0, 0, 991}, 23, 7, 0}, + {nat{0, 0, 0, 0}, nat{7893475, 7395495, 798547395, 68943}, 0, 0, 0}, + {nat{991, 0, 0, 0}, nat{7893475, 7395495, 798547395, 68943}, 0, 991, 0}, + {nat{0, 0, 0, 0}, nat{0, 0, 0, 0}, 894375984, 0, 0}, + {nat{991, 0, 0, 0}, nat{0, 0, 0, 0}, 894375984, 991, 0}, + {nat{_M << 1 & _M}, nat{_M}, 1 << 1, 0, _M >> (_W - 1)}, + {nat{_M<<1&_M + 1}, nat{_M}, 1 << 1, 1, _M >> (_W - 1)}, + {nat{_M << 7 & _M}, nat{_M}, 1 << 7, 0, _M >> (_W - 7)}, + {nat{_M<<7&_M + 1<<6}, nat{_M}, 1 << 7, 1 << 6, _M >> (_W - 7)}, + {nat{_M << 7 & _M, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 0, _M >> (_W - 7)}, + {nat{_M<<7&_M + 1<<6, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 1 << 6, _M >> (_W - 7)}, +} + +func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) { + z := make(nat, len(a.z)) + c := f(z, a.x, a.y, a.r) + for i, zi := range z { + if zi != a.z[i] { + t.Errorf("%s%+v\n\tgot z[%d] = %#x; want %#x", msg, a, i, zi, a.z[i]) + break + } + } + if c != a.c { + t.Errorf("%s%+v\n\tgot c = %#x; want %#x", msg, a, c, a.c) + } +} + +// TODO(gri) mulAddVWW and divWVW are symmetric operations but +// their signature is not symmetric. Try to unify. + +type funWVW func(z []Word, xn Word, x []Word, y Word) (r Word) +type argWVW struct { + z nat + xn Word + x nat + y Word + r Word +} + +func testFunWVW(t *testing.T, msg string, f funWVW, a argWVW) { + z := make(nat, len(a.z)) + r := f(z, a.xn, a.x, a.y) + for i, zi := range z { + if zi != a.z[i] { + t.Errorf("%s%+v\n\tgot z[%d] = %#x; want %#x", msg, a, i, zi, a.z[i]) + break + } + } + if r != a.r { + t.Errorf("%s%+v\n\tgot r = %#x; want %#x", msg, a, r, a.r) + } +} + +func TestFunVWW(t *testing.T) { + for _, a := range prodVWW { + arg := a + testFunVWW(t, "mulAddVWW_g", mulAddVWW_g, arg) + testFunVWW(t, "mulAddVWW", mulAddVWW, arg) + + if a.y != 0 && a.r < a.y { + arg := argWVW{a.x, a.c, a.z, a.y, a.r} + testFunWVW(t, "divWVW", divWVW, arg) + } + } +} + +var mulWWTests = []struct { + x, y Word + q, r Word +}{ + {_M, _M, _M - 1, 1}, + // 32 bit only: {0xc47dfa8c, 50911, 0x98a4, 0x998587f4}, +} + +func TestMulWW(t *testing.T) { + for i, test := range mulWWTests { + q, r := mulWW(test.x, test.y) + if q != test.q || r != test.r { + t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r) + } + } +} + +var mulAddWWWTests = []struct { + x, y, c Word + q, r Word +}{ + // TODO(agl): These will only work on 64-bit platforms. + // {15064310297182388543, 0xe7df04d2d35d5d80, 13537600649892366549, 13644450054494335067, 10832252001440893781}, + // {15064310297182388543, 0xdab2f18048baa68d, 13644450054494335067, 12869334219691522700, 14233854684711418382}, + {_M, _M, 0, _M - 1, 1}, + {_M, _M, _M, _M, 0}, +} + +func TestMulAddWWW(t *testing.T) { + for i, test := range mulAddWWWTests { + q, r := mulAddWWW_g(test.x, test.y, test.c) + if q != test.q || r != test.r { + t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r) + } + } +} + +var divWWTests = []struct { + x1, x0, y Word + q, r Word +}{ + {_M >> 1, 0, _M, _M >> 1, _M >> 1}, + {_M - (1 << (_W - 2)), _M, 3 << (_W - 2), _M, _M - (1 << (_W - 2))}, +} + +const testsNumber = 1 << 16 + +func TestDivWW(t *testing.T) { + i := 0 + for i, test := range divWWTests { + rec := reciprocalWord(test.y) + q, r := divWW(test.x1, test.x0, test.y, rec) + if q != test.q || r != test.r { + t.Errorf("#%d got (%x, %x) want (%x, %x)", i, q, r, test.q, test.r) + } + } + //random tests + for ; i < testsNumber; i++ { + x1 := rndW() + x0 := rndW() + y := rndW() + if x1 >= y { + continue + } + rec := reciprocalWord(y) + qGot, rGot := divWW(x1, x0, y, rec) + qWant, rWant := bits.Div(uint(x1), uint(x0), uint(y)) + if uint(qGot) != qWant || uint(rGot) != rWant { + t.Errorf("#%d got (%x, %x) want (%x, %x)", i, qGot, rGot, qWant, rWant) + } + } +} + +func BenchmarkMulAddVWW(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + z := make([]Word, n+1) + x := rndV(n) + y := rndW() + r := rndW() + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _W)) + for i := 0; i < b.N; i++ { + mulAddVWW(z, x, y, r) + } + }) + } +} + +func BenchmarkAddMulVVW(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := rndV(n) + y := rndW() + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _W)) + for i := 0; i < b.N; i++ { + addMulVVW(z, x, y) + } + }) + } +} +func BenchmarkDivWVW(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := rndV(n) + y := rndW() + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _W)) + for i := 0; i < b.N; i++ { + divWVW(z, 0, x, y) + } + }) + } +} + +func BenchmarkNonZeroShifts(b *testing.B) { + for _, n := range benchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + x := rndV(n) + s := uint(rand.Int63n(_W-2)) + 1 // avoid 0 and over-large shifts + z := make([]Word, n) + b.Run(fmt.Sprint(n), func(b *testing.B) { + b.SetBytes(int64(n * _W)) + b.Run("shrVU", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = shrVU(z, x, s) + } + }) + b.Run("shlVU", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = shlVU(z, x, s) + } + }) + }) + } +} diff --git a/src/math/big/arith_wasm.s b/src/math/big/arith_wasm.s new file mode 100644 index 0000000..93eb16d --- /dev/null +++ b/src/math/big/arith_wasm.s @@ -0,0 +1,33 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !math_big_pure_go +// +build !math_big_pure_go + +#include "textflag.h" + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) + diff --git a/src/math/big/bits_test.go b/src/math/big/bits_test.go new file mode 100644 index 0000000..985b60b --- /dev/null +++ b/src/math/big/bits_test.go @@ -0,0 +1,224 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements the Bits type used for testing Float operations +// via an independent (albeit slower) representations for floating-point +// numbers. + +package big + +import ( + "fmt" + "sort" + "testing" +) + +// A Bits value b represents a finite floating-point number x of the form +// +// x = 2**b[0] + 2**b[1] + ... 2**b[len(b)-1] +// +// The order of slice elements is not significant. Negative elements may be +// used to form fractions. A Bits value is normalized if each b[i] occurs at +// most once. For instance Bits{0, 0, 1} is not normalized but represents the +// same floating-point number as Bits{2}, which is normalized. The zero (nil) +// value of Bits is a ready to use Bits value and represents the value 0. +type Bits []int + +func (x Bits) add(y Bits) Bits { + return append(x, y...) +} + +func (x Bits) mul(y Bits) Bits { + var p Bits + for _, x := range x { + for _, y := range y { + p = append(p, x+y) + } + } + return p +} + +func TestMulBits(t *testing.T) { + for _, test := range []struct { + x, y, want Bits + }{ + {nil, nil, nil}, + {Bits{}, Bits{}, nil}, + {Bits{0}, Bits{0}, Bits{0}}, + {Bits{0}, Bits{1}, Bits{1}}, + {Bits{1}, Bits{1, 2, 3}, Bits{2, 3, 4}}, + {Bits{-1}, Bits{1}, Bits{0}}, + {Bits{-10, -1, 0, 1, 10}, Bits{1, 2, 3}, Bits{-9, -8, -7, 0, 1, 2, 1, 2, 3, 2, 3, 4, 11, 12, 13}}, + } { + got := fmt.Sprintf("%v", test.x.mul(test.y)) + want := fmt.Sprintf("%v", test.want) + if got != want { + t.Errorf("%v * %v = %s; want %s", test.x, test.y, got, want) + } + + } +} + +// norm returns the normalized bits for x: It removes multiple equal entries +// by treating them as an addition (e.g., Bits{5, 5} => Bits{6}), and it sorts +// the result list for reproducible results. +func (x Bits) norm() Bits { + m := make(map[int]bool) + for _, b := range x { + for m[b] { + m[b] = false + b++ + } + m[b] = true + } + var z Bits + for b, set := range m { + if set { + z = append(z, b) + } + } + sort.Ints([]int(z)) + return z +} + +func TestNormBits(t *testing.T) { + for _, test := range []struct { + x, want Bits + }{ + {nil, nil}, + {Bits{}, Bits{}}, + {Bits{0}, Bits{0}}, + {Bits{0, 0}, Bits{1}}, + {Bits{3, 1, 1}, Bits{2, 3}}, + {Bits{10, 9, 8, 7, 6, 6}, Bits{11}}, + } { + got := fmt.Sprintf("%v", test.x.norm()) + want := fmt.Sprintf("%v", test.want) + if got != want { + t.Errorf("normBits(%v) = %s; want %s", test.x, got, want) + } + + } +} + +// round returns the Float value corresponding to x after rounding x +// to prec bits according to mode. +func (x Bits) round(prec uint, mode RoundingMode) *Float { + x = x.norm() + + // determine range + var min, max int + for i, b := range x { + if i == 0 || b < min { + min = b + } + if i == 0 || b > max { + max = b + } + } + prec0 := uint(max + 1 - min) + if prec >= prec0 { + return x.Float() + } + // prec < prec0 + + // determine bit 0, rounding, and sticky bit, and result bits z + var bit0, rbit, sbit uint + var z Bits + r := max - int(prec) + for _, b := range x { + switch { + case b == r: + rbit = 1 + case b < r: + sbit = 1 + default: + // b > r + if b == r+1 { + bit0 = 1 + } + z = append(z, b) + } + } + + // round + f := z.Float() // rounded to zero + if mode == ToNearestAway { + panic("not yet implemented") + } + if mode == ToNearestEven && rbit == 1 && (sbit == 1 || sbit == 0 && bit0 != 0) || mode == AwayFromZero { + // round away from zero + f.SetMode(ToZero).SetPrec(prec) + f.Add(f, Bits{int(r) + 1}.Float()) + } + return f +} + +// Float returns the *Float z of the smallest possible precision such that +// z = sum(2**bits[i]), with i = range bits. If multiple bits[i] are equal, +// they are added: Bits{0, 1, 0}.Float() == 2**0 + 2**1 + 2**0 = 4. +func (bits Bits) Float() *Float { + // handle 0 + if len(bits) == 0 { + return new(Float) + } + // len(bits) > 0 + + // determine lsb exponent + var min int + for i, b := range bits { + if i == 0 || b < min { + min = b + } + } + + // create bit pattern + x := NewInt(0) + for _, b := range bits { + badj := b - min + // propagate carry if necessary + for x.Bit(badj) != 0 { + x.SetBit(x, badj, 0) + badj++ + } + x.SetBit(x, badj, 1) + } + + // create corresponding float + z := new(Float).SetInt(x) // normalized + if e := int64(z.exp) + int64(min); MinExp <= e && e <= MaxExp { + z.exp = int32(e) + } else { + // this should never happen for our test cases + panic("exponent out of range") + } + return z +} + +func TestFromBits(t *testing.T) { + for _, test := range []struct { + bits Bits + want string + }{ + // all different bit numbers + {nil, "0"}, + {Bits{0}, "0x.8p+1"}, + {Bits{1}, "0x.8p+2"}, + {Bits{-1}, "0x.8p+0"}, + {Bits{63}, "0x.8p+64"}, + {Bits{33, -30}, "0x.8000000000000001p+34"}, + {Bits{255, 0}, "0x.8000000000000000000000000000000000000000000000000000000000000001p+256"}, + + // multiple equal bit numbers + {Bits{0, 0}, "0x.8p+2"}, + {Bits{0, 0, 0, 0}, "0x.8p+3"}, + {Bits{0, 1, 0}, "0x.8p+3"}, + {append(Bits{2, 1, 0} /* 7 */, Bits{3, 1} /* 10 */ ...), "0x.88p+5" /* 17 */}, + } { + f := test.bits.Float() + if got := f.Text('p', 0); got != test.want { + t.Errorf("setBits(%v) = %s; want %s", test.bits, got, test.want) + } + } +} diff --git a/src/math/big/calibrate_test.go b/src/math/big/calibrate_test.go new file mode 100644 index 0000000..4fa663f --- /dev/null +++ b/src/math/big/calibrate_test.go @@ -0,0 +1,173 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Calibration used to determine thresholds for using +// different algorithms. Ideally, this would be converted +// to go generate to create thresholds.go + +// This file prints execution times for the Mul benchmark +// given different Karatsuba thresholds. The result may be +// used to manually fine-tune the threshold constant. The +// results are somewhat fragile; use repeated runs to get +// a clear picture. + +// Calculates lower and upper thresholds for when basicSqr +// is faster than standard multiplication. + +// Usage: go test -run=TestCalibrate -v -calibrate + +package big + +import ( + "flag" + "fmt" + "testing" + "time" +) + +var calibrate = flag.Bool("calibrate", false, "run calibration test") + +const ( + sqrModeMul = "mul(x, x)" + sqrModeBasic = "basicSqr(x)" + sqrModeKaratsuba = "karatsubaSqr(x)" +) + +func TestCalibrate(t *testing.T) { + if !*calibrate { + return + } + + computeKaratsubaThresholds() + + // compute basicSqrThreshold where overhead becomes negligible + minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic) + // compute karatsubaSqrThreshold where karatsuba is faster + maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba) + if minSqr != 0 { + fmt.Printf("found basicSqrThreshold = %d\n", minSqr) + } else { + fmt.Println("no basicSqrThreshold found") + } + if maxSqr != 0 { + fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr) + } else { + fmt.Println("no karatsubaSqrThreshold found") + } +} + +func karatsubaLoad(b *testing.B) { + BenchmarkMul(b) +} + +// measureKaratsuba returns the time to run a Karatsuba-relevant benchmark +// given Karatsuba threshold th. +func measureKaratsuba(th int) time.Duration { + th, karatsubaThreshold = karatsubaThreshold, th + res := testing.Benchmark(karatsubaLoad) + karatsubaThreshold = th + return time.Duration(res.NsPerOp()) +} + +func computeKaratsubaThresholds() { + fmt.Printf("Multiplication times for varying Karatsuba thresholds\n") + fmt.Printf("(run repeatedly for good results)\n") + + // determine Tk, the work load execution time using basic multiplication + Tb := measureKaratsuba(1e9) // th == 1e9 => Karatsuba multiplication disabled + fmt.Printf("Tb = %10s\n", Tb) + + // thresholds + th := 4 + th1 := -1 + th2 := -1 + + var deltaOld time.Duration + for count := -1; count != 0 && th < 128; count-- { + // determine Tk, the work load execution time using Karatsuba multiplication + Tk := measureKaratsuba(th) + + // improvement over Tb + delta := (Tb - Tk) * 100 / Tb + + fmt.Printf("th = %3d Tk = %10s %4d%%", th, Tk, delta) + + // determine break-even point + if Tk < Tb && th1 < 0 { + th1 = th + fmt.Print(" break-even point") + } + + // determine diminishing return + if 0 < delta && delta < deltaOld && th2 < 0 { + th2 = th + fmt.Print(" diminishing return") + } + deltaOld = delta + + fmt.Println() + + // trigger counter + if th1 >= 0 && th2 >= 0 && count < 0 { + count = 10 // this many extra measurements after we got both thresholds + } + + th++ + } +} + +func measureSqr(words, nruns int, mode string) time.Duration { + // more runs for better statistics + initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold + + switch mode { + case sqrModeMul: + basicSqrThreshold = words + 1 + case sqrModeBasic: + basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1 + case sqrModeKaratsuba: + karatsubaSqrThreshold = words - 1 + } + + var testval int64 + for i := 0; i < nruns; i++ { + res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) }) + testval += res.NsPerOp() + } + testval /= int64(nruns) + + basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr + + return time.Duration(testval) +} + +func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int { + fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper) + fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step) + var initPos bool + var threshold int + for i := from; i <= to; i += step { + baseline := measureSqr(i, nruns, lower) + testval := measureSqr(i, nruns, upper) + pos := baseline > testval + delta := baseline - testval + percent := delta * 100 / baseline + fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos) + if i == from { + initPos = pos + } + if threshold == 0 && pos != initPos { + threshold = i + fmt.Printf(" threshold found") + } + fmt.Println() + + } + if threshold != 0 { + fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to) + } else { + fmt.Printf("Found NO threshold between %d - %d\n", from, to) + } + return threshold +} diff --git a/src/math/big/decimal.go b/src/math/big/decimal.go new file mode 100644 index 0000000..716f03b --- /dev/null +++ b/src/math/big/decimal.go @@ -0,0 +1,270 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements multi-precision decimal numbers. +// The implementation is for float to decimal conversion only; +// not general purpose use. +// The only operations are precise conversion from binary to +// decimal and rounding. +// +// The key observation and some code (shr) is borrowed from +// strconv/decimal.go: conversion of binary fractional values can be done +// precisely in multi-precision decimal because 2 divides 10 (required for +// >> of mantissa); but conversion of decimal floating-point values cannot +// be done precisely in binary representation. +// +// In contrast to strconv/decimal.go, only right shift is implemented in +// decimal format - left shift can be done precisely in binary format. + +package big + +// A decimal represents an unsigned floating-point number in decimal representation. +// The value of a non-zero decimal d is d.mant * 10**d.exp with 0.1 <= d.mant < 1, +// with the most-significant mantissa digit at index 0. For the zero decimal, the +// mantissa length and exponent are 0. +// The zero value for decimal represents a ready-to-use 0.0. +type decimal struct { + mant []byte // mantissa ASCII digits, big-endian + exp int // exponent +} + +// at returns the i'th mantissa digit, starting with the most significant digit at 0. +func (d *decimal) at(i int) byte { + if 0 <= i && i < len(d.mant) { + return d.mant[i] + } + return '0' +} + +// Maximum shift amount that can be done in one pass without overflow. +// A Word has _W bits and (1<<maxShift - 1)*10 + 9 must fit into Word. +const maxShift = _W - 4 + +// TODO(gri) Since we know the desired decimal precision when converting +// a floating-point number, we may be able to limit the number of decimal +// digits that need to be computed by init by providing an additional +// precision argument and keeping track of when a number was truncated early +// (equivalent of "sticky bit" in binary rounding). + +// TODO(gri) Along the same lines, enforce some limit to shift magnitudes +// to avoid "infinitely" long running conversions (until we run out of space). + +// Init initializes x to the decimal representation of m << shift (for +// shift >= 0), or m >> -shift (for shift < 0). +func (x *decimal) init(m nat, shift int) { + // special case 0 + if len(m) == 0 { + x.mant = x.mant[:0] + x.exp = 0 + return + } + + // Optimization: If we need to shift right, first remove any trailing + // zero bits from m to reduce shift amount that needs to be done in + // decimal format (since that is likely slower). + if shift < 0 { + ntz := m.trailingZeroBits() + s := uint(-shift) + if s >= ntz { + s = ntz // shift at most ntz bits + } + m = nat(nil).shr(m, s) + shift += int(s) + } + + // Do any shift left in binary representation. + if shift > 0 { + m = nat(nil).shl(m, uint(shift)) + shift = 0 + } + + // Convert mantissa into decimal representation. + s := m.utoa(10) + n := len(s) + x.exp = n + // Trim trailing zeros; instead the exponent is tracking + // the decimal point independent of the number of digits. + for n > 0 && s[n-1] == '0' { + n-- + } + x.mant = append(x.mant[:0], s[:n]...) + + // Do any (remaining) shift right in decimal representation. + if shift < 0 { + for shift < -maxShift { + shr(x, maxShift) + shift += maxShift + } + shr(x, uint(-shift)) + } +} + +// shr implements x >> s, for s <= maxShift. +func shr(x *decimal, s uint) { + // Division by 1<<s using shift-and-subtract algorithm. + + // pick up enough leading digits to cover first shift + r := 0 // read index + var n Word + for n>>s == 0 && r < len(x.mant) { + ch := Word(x.mant[r]) + r++ + n = n*10 + ch - '0' + } + if n == 0 { + // x == 0; shouldn't get here, but handle anyway + x.mant = x.mant[:0] + return + } + for n>>s == 0 { + r++ + n *= 10 + } + x.exp += 1 - r + + // read a digit, write a digit + w := 0 // write index + mask := Word(1)<<s - 1 + for r < len(x.mant) { + ch := Word(x.mant[r]) + r++ + d := n >> s + n &= mask // n -= d << s + x.mant[w] = byte(d + '0') + w++ + n = n*10 + ch - '0' + } + + // write extra digits that still fit + for n > 0 && w < len(x.mant) { + d := n >> s + n &= mask + x.mant[w] = byte(d + '0') + w++ + n = n * 10 + } + x.mant = x.mant[:w] // the number may be shorter (e.g. 1024 >> 10) + + // append additional digits that didn't fit + for n > 0 { + d := n >> s + n &= mask + x.mant = append(x.mant, byte(d+'0')) + n = n * 10 + } + + trim(x) +} + +func (x *decimal) String() string { + if len(x.mant) == 0 { + return "0" + } + + var buf []byte + switch { + case x.exp <= 0: + // 0.00ddd + buf = make([]byte, 0, 2+(-x.exp)+len(x.mant)) + buf = append(buf, "0."...) + buf = appendZeros(buf, -x.exp) + buf = append(buf, x.mant...) + + case /* 0 < */ x.exp < len(x.mant): + // dd.ddd + buf = make([]byte, 0, 1+len(x.mant)) + buf = append(buf, x.mant[:x.exp]...) + buf = append(buf, '.') + buf = append(buf, x.mant[x.exp:]...) + + default: // len(x.mant) <= x.exp + // ddd00 + buf = make([]byte, 0, x.exp) + buf = append(buf, x.mant...) + buf = appendZeros(buf, x.exp-len(x.mant)) + } + + return string(buf) +} + +// appendZeros appends n 0 digits to buf and returns buf. +func appendZeros(buf []byte, n int) []byte { + for ; n > 0; n-- { + buf = append(buf, '0') + } + return buf +} + +// shouldRoundUp reports if x should be rounded up +// if shortened to n digits. n must be a valid index +// for x.mant. +func shouldRoundUp(x *decimal, n int) bool { + if x.mant[n] == '5' && n+1 == len(x.mant) { + // exactly halfway - round to even + return n > 0 && (x.mant[n-1]-'0')&1 != 0 + } + // not halfway - digit tells all (x.mant has no trailing zeros) + return x.mant[n] >= '5' +} + +// round sets x to (at most) n mantissa digits by rounding it +// to the nearest even value with n (or fever) mantissa digits. +// If n < 0, x remains unchanged. +func (x *decimal) round(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + + if shouldRoundUp(x, n) { + x.roundUp(n) + } else { + x.roundDown(n) + } +} + +func (x *decimal) roundUp(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + // 0 <= n < len(x.mant) + + // find first digit < '9' + for n > 0 && x.mant[n-1] >= '9' { + n-- + } + + if n == 0 { + // all digits are '9's => round up to '1' and update exponent + x.mant[0] = '1' // ok since len(x.mant) > n + x.mant = x.mant[:1] + x.exp++ + return + } + + // n > 0 && x.mant[n-1] < '9' + x.mant[n-1]++ + x.mant = x.mant[:n] + // x already trimmed +} + +func (x *decimal) roundDown(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + x.mant = x.mant[:n] + trim(x) +} + +// trim cuts off any trailing zeros from x's mantissa; +// they are meaningless for the value of x. +func trim(x *decimal) { + i := len(x.mant) + for i > 0 && x.mant[i-1] == '0' { + i-- + } + x.mant = x.mant[:i] + if i == 0 { + x.exp = 0 + } +} diff --git a/src/math/big/decimal_test.go b/src/math/big/decimal_test.go new file mode 100644 index 0000000..424811e --- /dev/null +++ b/src/math/big/decimal_test.go @@ -0,0 +1,134 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "fmt" + "testing" +) + +func TestDecimalString(t *testing.T) { + for _, test := range []struct { + x decimal + want string + }{ + {want: "0"}, + {decimal{nil, 1000}, "0"}, // exponent of 0 is ignored + {decimal{[]byte("12345"), 0}, "0.12345"}, + {decimal{[]byte("12345"), -3}, "0.00012345"}, + {decimal{[]byte("12345"), +3}, "123.45"}, + {decimal{[]byte("12345"), +10}, "1234500000"}, + } { + if got := test.x.String(); got != test.want { + t.Errorf("%v == %s; want %s", test.x, got, test.want) + } + } +} + +func TestDecimalInit(t *testing.T) { + for _, test := range []struct { + x Word + shift int + want string + }{ + {0, 0, "0"}, + {0, -100, "0"}, + {0, 100, "0"}, + {1, 0, "1"}, + {1, 10, "1024"}, + {1, 100, "1267650600228229401496703205376"}, + {1, -100, "0.0000000000000000000000000000007888609052210118054117285652827862296732064351090230047702789306640625"}, + {12345678, 8, "3160493568"}, + {12345678, -8, "48225.3046875"}, + {195312, 9, "99999744"}, + {1953125, 9, "1000000000"}, + } { + var d decimal + d.init(nat{test.x}.norm(), test.shift) + if got := d.String(); got != test.want { + t.Errorf("%d << %d == %s; want %s", test.x, test.shift, got, test.want) + } + } +} + +func TestDecimalRounding(t *testing.T) { + for _, test := range []struct { + x uint64 + n int + down, even, up string + }{ + {0, 0, "0", "0", "0"}, + {0, 1, "0", "0", "0"}, + + {1, 0, "0", "0", "10"}, + {5, 0, "0", "0", "10"}, + {9, 0, "0", "10", "10"}, + + {15, 1, "10", "20", "20"}, + {45, 1, "40", "40", "50"}, + {95, 1, "90", "100", "100"}, + + {12344999, 4, "12340000", "12340000", "12350000"}, + {12345000, 4, "12340000", "12340000", "12350000"}, + {12345001, 4, "12340000", "12350000", "12350000"}, + {23454999, 4, "23450000", "23450000", "23460000"}, + {23455000, 4, "23450000", "23460000", "23460000"}, + {23455001, 4, "23450000", "23460000", "23460000"}, + + {99994999, 4, "99990000", "99990000", "100000000"}, + {99995000, 4, "99990000", "100000000", "100000000"}, + {99999999, 4, "99990000", "100000000", "100000000"}, + + {12994999, 4, "12990000", "12990000", "13000000"}, + {12995000, 4, "12990000", "13000000", "13000000"}, + {12999999, 4, "12990000", "13000000", "13000000"}, + } { + x := nat(nil).setUint64(test.x) + + var d decimal + d.init(x, 0) + d.roundDown(test.n) + if got := d.String(); got != test.down { + t.Errorf("roundDown(%d, %d) = %s; want %s", test.x, test.n, got, test.down) + } + + d.init(x, 0) + d.round(test.n) + if got := d.String(); got != test.even { + t.Errorf("round(%d, %d) = %s; want %s", test.x, test.n, got, test.even) + } + + d.init(x, 0) + d.roundUp(test.n) + if got := d.String(); got != test.up { + t.Errorf("roundUp(%d, %d) = %s; want %s", test.x, test.n, got, test.up) + } + } +} + +var sink string + +func BenchmarkDecimalConversion(b *testing.B) { + for i := 0; i < b.N; i++ { + for shift := -100; shift <= +100; shift++ { + var d decimal + d.init(natOne, shift) + sink = d.String() + } + } +} + +func BenchmarkFloatString(b *testing.B) { + x := new(Float) + for _, prec := range []uint{1e2, 1e3, 1e4, 1e5} { + x.SetPrec(prec).SetRat(NewRat(1, 3)) + b.Run(fmt.Sprintf("%v", prec), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sink = x.String() + } + }) + } +} diff --git a/src/math/big/doc.go b/src/math/big/doc.go new file mode 100644 index 0000000..65ed019 --- /dev/null +++ b/src/math/big/doc.go @@ -0,0 +1,99 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package big implements arbitrary-precision arithmetic (big numbers). +The following numeric types are supported: + + Int signed integers + Rat rational numbers + Float floating-point numbers + +The zero value for an Int, Rat, or Float correspond to 0. Thus, new +values can be declared in the usual ways and denote 0 without further +initialization: + + var x Int // &x is an *Int of value 0 + var r = &Rat{} // r is a *Rat of value 0 + y := new(Float) // y is a *Float of value 0 + +Alternatively, new values can be allocated and initialized with factory +functions of the form: + + func NewT(v V) *T + +For instance, NewInt(x) returns an *Int set to the value of the int64 +argument x, NewRat(a, b) returns a *Rat set to the fraction a/b where +a and b are int64 values, and NewFloat(f) returns a *Float initialized +to the float64 argument f. More flexibility is provided with explicit +setters, for instance: + + var z1 Int + z1.SetUint64(123) // z1 := 123 + z2 := new(Rat).SetFloat64(1.25) // z2 := 5/4 + z3 := new(Float).SetInt(z1) // z3 := 123.0 + +Setters, numeric operations and predicates are represented as methods of +the form: + + func (z *T) SetV(v V) *T // z = v + func (z *T) Unary(x *T) *T // z = unary x + func (z *T) Binary(x, y *T) *T // z = x binary y + func (x *T) Pred() P // p = pred(x) + +with T one of Int, Rat, or Float. For unary and binary operations, the +result is the receiver (usually named z in that case; see below); if it +is one of the operands x or y it may be safely overwritten (and its memory +reused). + +Arithmetic expressions are typically written as a sequence of individual +method calls, with each call corresponding to an operation. The receiver +denotes the result and the method arguments are the operation's operands. +For instance, given three *Int values a, b and c, the invocation + + c.Add(a, b) + +computes the sum a + b and stores the result in c, overwriting whatever +value was held in c before. Unless specified otherwise, operations permit +aliasing of parameters, so it is perfectly ok to write + + sum.Add(sum, x) + +to accumulate values x in a sum. + +(By always passing in a result value via the receiver, memory use can be +much better controlled. Instead of having to allocate new memory for each +result, an operation can reuse the space allocated for the result value, +and overwrite that value with the new result in the process.) + +Notational convention: Incoming method parameters (including the receiver) +are named consistently in the API to clarify their use. Incoming operands +are usually named x, y, a, b, and so on, but never z. A parameter specifying +the result is named z (typically the receiver). + +For instance, the arguments for (*Int).Add are named x and y, and because +the receiver specifies the result destination, it is called z: + + func (z *Int) Add(x, y *Int) *Int + +Methods of this form typically return the incoming receiver as well, to +enable simple call chaining. + +Methods which don't require a result value to be passed in (for instance, +Int.Sign), simply return the result. In this case, the receiver is typically +the first operand, named x: + + func (x *Int) Sign() int + +Various methods support conversions between strings and corresponding +numeric values, and vice versa: *Int, *Rat, and *Float values implement +the Stringer interface for a (default) string representation of the value, +but also provide SetString methods to initialize a value from a string in +a variety of supported formats (see the respective SetString documentation). + +Finally, *Int, *Rat, and *Float satisfy the fmt package's Scanner interface +for scanning and (except for *Rat) the Formatter interface for formatted +printing. +*/ +package big diff --git a/src/math/big/example_rat_test.go b/src/math/big/example_rat_test.go new file mode 100644 index 0000000..dc67430 --- /dev/null +++ b/src/math/big/example_rat_test.go @@ -0,0 +1,68 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big_test + +import ( + "fmt" + "math/big" +) + +// Use the classic continued fraction for e +// +// e = [1; 0, 1, 1, 2, 1, 1, ... 2n, 1, 1, ...] +// +// i.e., for the nth term, use +// +// 1 if n mod 3 != 1 +// (n-1)/3 * 2 if n mod 3 == 1 +func recur(n, lim int64) *big.Rat { + term := new(big.Rat) + if n%3 != 1 { + term.SetInt64(1) + } else { + term.SetInt64((n - 1) / 3 * 2) + } + + if n > lim { + return term + } + + // Directly initialize frac as the fractional + // inverse of the result of recur. + frac := new(big.Rat).Inv(recur(n+1, lim)) + + return term.Add(term, frac) +} + +// This example demonstrates how to use big.Rat to compute the +// first 15 terms in the sequence of rational convergents for +// the constant e (base of natural logarithm). +func Example_eConvergents() { + for i := 1; i <= 15; i++ { + r := recur(0, int64(i)) + + // Print r both as a fraction and as a floating-point number. + // Since big.Rat implements fmt.Formatter, we can use %-13s to + // get a left-aligned string representation of the fraction. + fmt.Printf("%-13s = %s\n", r, r.FloatString(8)) + } + + // Output: + // 2/1 = 2.00000000 + // 3/1 = 3.00000000 + // 8/3 = 2.66666667 + // 11/4 = 2.75000000 + // 19/7 = 2.71428571 + // 87/32 = 2.71875000 + // 106/39 = 2.71794872 + // 193/71 = 2.71830986 + // 1264/465 = 2.71827957 + // 1457/536 = 2.71828358 + // 2721/1001 = 2.71828172 + // 23225/8544 = 2.71828184 + // 25946/9545 = 2.71828182 + // 49171/18089 = 2.71828183 + // 517656/190435 = 2.71828183 +} diff --git a/src/math/big/example_test.go b/src/math/big/example_test.go new file mode 100644 index 0000000..31ca784 --- /dev/null +++ b/src/math/big/example_test.go @@ -0,0 +1,148 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big_test + +import ( + "fmt" + "log" + "math" + "math/big" +) + +func ExampleRat_SetString() { + r := new(big.Rat) + r.SetString("355/113") + fmt.Println(r.FloatString(3)) + // Output: 3.142 +} + +func ExampleInt_SetString() { + i := new(big.Int) + i.SetString("644", 8) // octal + fmt.Println(i) + // Output: 420 +} + +func ExampleFloat_SetString() { + f := new(big.Float) + f.SetString("3.14159") + fmt.Println(f) + // Output: 3.14159 +} + +func ExampleRat_Scan() { + // The Scan function is rarely used directly; + // the fmt package recognizes it as an implementation of fmt.Scanner. + r := new(big.Rat) + _, err := fmt.Sscan("1.5000", r) + if err != nil { + log.Println("error scanning value:", err) + } else { + fmt.Println(r) + } + // Output: 3/2 +} + +func ExampleInt_Scan() { + // The Scan function is rarely used directly; + // the fmt package recognizes it as an implementation of fmt.Scanner. + i := new(big.Int) + _, err := fmt.Sscan("18446744073709551617", i) + if err != nil { + log.Println("error scanning value:", err) + } else { + fmt.Println(i) + } + // Output: 18446744073709551617 +} + +func ExampleFloat_Scan() { + // The Scan function is rarely used directly; + // the fmt package recognizes it as an implementation of fmt.Scanner. + f := new(big.Float) + _, err := fmt.Sscan("1.19282e99", f) + if err != nil { + log.Println("error scanning value:", err) + } else { + fmt.Println(f) + } + // Output: 1.19282e+99 +} + +// This example demonstrates how to use big.Int to compute the smallest +// Fibonacci number with 100 decimal digits and to test whether it is prime. +func Example_fibonacci() { + // Initialize two big ints with the first two numbers in the sequence. + a := big.NewInt(0) + b := big.NewInt(1) + + // Initialize limit as 10^99, the smallest integer with 100 digits. + var limit big.Int + limit.Exp(big.NewInt(10), big.NewInt(99), nil) + + // Loop while a is smaller than 1e100. + for a.Cmp(&limit) < 0 { + // Compute the next Fibonacci number, storing it in a. + a.Add(a, b) + // Swap a and b so that b is the next number in the sequence. + a, b = b, a + } + fmt.Println(a) // 100-digit Fibonacci number + + // Test a for primality. + // (ProbablyPrimes' argument sets the number of Miller-Rabin + // rounds to be performed. 20 is a good value.) + fmt.Println(a.ProbablyPrime(20)) + + // Output: + // 1344719667586153181419716641724567886890850696275767987106294472017884974410332069524504824747437757 + // false +} + +// This example shows how to use big.Float to compute the square root of 2 with +// a precision of 200 bits, and how to print the result as a decimal number. +func Example_sqrt2() { + // We'll do computations with 200 bits of precision in the mantissa. + const prec = 200 + + // Compute the square root of 2 using Newton's Method. We start with + // an initial estimate for sqrt(2), and then iterate: + // x_{n+1} = 1/2 * ( x_n + (2.0 / x_n) ) + + // Since Newton's Method doubles the number of correct digits at each + // iteration, we need at least log_2(prec) steps. + steps := int(math.Log2(prec)) + + // Initialize values we need for the computation. + two := new(big.Float).SetPrec(prec).SetInt64(2) + half := new(big.Float).SetPrec(prec).SetFloat64(0.5) + + // Use 1 as the initial estimate. + x := new(big.Float).SetPrec(prec).SetInt64(1) + + // We use t as a temporary variable. There's no need to set its precision + // since big.Float values with unset (== 0) precision automatically assume + // the largest precision of the arguments when used as the result (receiver) + // of a big.Float operation. + t := new(big.Float) + + // Iterate. + for i := 0; i <= steps; i++ { + t.Quo(two, x) // t = 2.0 / x_n + t.Add(x, t) // t = x_n + (2.0 / x_n) + x.Mul(half, t) // x_{n+1} = 0.5 * t + } + + // We can use the usual fmt.Printf verbs since big.Float implements fmt.Formatter + fmt.Printf("sqrt(2) = %.50f\n", x) + + // Print the error between 2 and x*x. + t.Mul(x, x) // t = x*x + fmt.Printf("error = %e\n", t.Sub(two, t)) + + // Output: + // sqrt(2) = 1.41421356237309504880168872420969807856967187537695 + // error = 0.000000e+00 +} diff --git a/src/math/big/float.go b/src/math/big/float.go new file mode 100644 index 0000000..84666d8 --- /dev/null +++ b/src/math/big/float.go @@ -0,0 +1,1729 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements multi-precision floating-point numbers. +// Like in the GNU MPFR library (https://www.mpfr.org/), operands +// can be of mixed precision. Unlike MPFR, the rounding mode is +// not specified with each operation, but with each operand. The +// rounding mode of the result operand determines the rounding +// mode of an operation. This is a from-scratch implementation. + +package big + +import ( + "fmt" + "math" + "math/bits" +) + +const debugFloat = false // enable for debugging + +// A nonzero finite Float represents a multi-precision floating point number +// +// sign × mantissa × 2**exponent +// +// with 0.5 <= mantissa < 1.0, and MinExp <= exponent <= MaxExp. +// A Float may also be zero (+0, -0) or infinite (+Inf, -Inf). +// All Floats are ordered, and the ordering of two Floats x and y +// is defined by x.Cmp(y). +// +// Each Float value also has a precision, rounding mode, and accuracy. +// The precision is the maximum number of mantissa bits available to +// represent the value. The rounding mode specifies how a result should +// be rounded to fit into the mantissa bits, and accuracy describes the +// rounding error with respect to the exact result. +// +// Unless specified otherwise, all operations (including setters) that +// specify a *Float variable for the result (usually via the receiver +// with the exception of MantExp), round the numeric result according +// to the precision and rounding mode of the result variable. +// +// If the provided result precision is 0 (see below), it is set to the +// precision of the argument with the largest precision value before any +// rounding takes place, and the rounding mode remains unchanged. Thus, +// uninitialized Floats provided as result arguments will have their +// precision set to a reasonable value determined by the operands, and +// their mode is the zero value for RoundingMode (ToNearestEven). +// +// By setting the desired precision to 24 or 53 and using matching rounding +// mode (typically ToNearestEven), Float operations produce the same results +// as the corresponding float32 or float64 IEEE-754 arithmetic for operands +// that correspond to normal (i.e., not denormal) float32 or float64 numbers. +// Exponent underflow and overflow lead to a 0 or an Infinity for different +// values than IEEE-754 because Float exponents have a much larger range. +// +// The zero (uninitialized) value for a Float is ready to use and represents +// the number +0.0 exactly, with precision 0 and rounding mode ToNearestEven. +// +// Operations always take pointer arguments (*Float) rather +// than Float values, and each unique Float value requires +// its own unique *Float pointer. To "copy" a Float value, +// an existing (or newly allocated) Float must be set to +// a new value using the Float.Set method; shallow copies +// of Floats are not supported and may lead to errors. +type Float struct { + prec uint32 + mode RoundingMode + acc Accuracy + form form + neg bool + mant nat + exp int32 +} + +// An ErrNaN panic is raised by a Float operation that would lead to +// a NaN under IEEE-754 rules. An ErrNaN implements the error interface. +type ErrNaN struct { + msg string +} + +func (err ErrNaN) Error() string { + return err.msg +} + +// NewFloat allocates and returns a new Float set to x, +// with precision 53 and rounding mode ToNearestEven. +// NewFloat panics with ErrNaN if x is a NaN. +func NewFloat(x float64) *Float { + if math.IsNaN(x) { + panic(ErrNaN{"NewFloat(NaN)"}) + } + return new(Float).SetFloat64(x) +} + +// Exponent and precision limits. +const ( + MaxExp = math.MaxInt32 // largest supported exponent + MinExp = math.MinInt32 // smallest supported exponent + MaxPrec = math.MaxUint32 // largest (theoretically) supported precision; likely memory-limited +) + +// Internal representation: The mantissa bits x.mant of a nonzero finite +// Float x are stored in a nat slice long enough to hold up to x.prec bits; +// the slice may (but doesn't have to) be shorter if the mantissa contains +// trailing 0 bits. x.mant is normalized if the msb of x.mant == 1 (i.e., +// the msb is shifted all the way "to the left"). Thus, if the mantissa has +// trailing 0 bits or x.prec is not a multiple of the Word size _W, +// x.mant[0] has trailing zero bits. The msb of the mantissa corresponds +// to the value 0.5; the exponent x.exp shifts the binary point as needed. +// +// A zero or non-finite Float x ignores x.mant and x.exp. +// +// x form neg mant exp +// ---------------------------------------------------------- +// ±0 zero sign - - +// 0 < |x| < +Inf finite sign mantissa exponent +// ±Inf inf sign - - + +// A form value describes the internal representation. +type form byte + +// The form value order is relevant - do not change! +const ( + zero form = iota + finite + inf +) + +// RoundingMode determines how a Float value is rounded to the +// desired precision. Rounding may change the Float value; the +// rounding error is described by the Float's Accuracy. +type RoundingMode byte + +// These constants define supported rounding modes. +const ( + ToNearestEven RoundingMode = iota // == IEEE 754-2008 roundTiesToEven + ToNearestAway // == IEEE 754-2008 roundTiesToAway + ToZero // == IEEE 754-2008 roundTowardZero + AwayFromZero // no IEEE 754-2008 equivalent + ToNegativeInf // == IEEE 754-2008 roundTowardNegative + ToPositiveInf // == IEEE 754-2008 roundTowardPositive +) + +//go:generate stringer -type=RoundingMode + +// Accuracy describes the rounding error produced by the most recent +// operation that generated a Float value, relative to the exact value. +type Accuracy int8 + +// Constants describing the Accuracy of a Float. +const ( + Below Accuracy = -1 + Exact Accuracy = 0 + Above Accuracy = +1 +) + +//go:generate stringer -type=Accuracy + +// SetPrec sets z's precision to prec and returns the (possibly) rounded +// value of z. Rounding occurs according to z's rounding mode if the mantissa +// cannot be represented in prec bits without loss of precision. +// SetPrec(0) maps all finite values to ±0; infinite values remain unchanged. +// If prec > MaxPrec, it is set to MaxPrec. +func (z *Float) SetPrec(prec uint) *Float { + z.acc = Exact // optimistically assume no rounding is needed + + // special case + if prec == 0 { + z.prec = 0 + if z.form == finite { + // truncate z to 0 + z.acc = makeAcc(z.neg) + z.form = zero + } + return z + } + + // general case + if prec > MaxPrec { + prec = MaxPrec + } + old := z.prec + z.prec = uint32(prec) + if z.prec < old { + z.round(0) + } + return z +} + +func makeAcc(above bool) Accuracy { + if above { + return Above + } + return Below +} + +// SetMode sets z's rounding mode to mode and returns an exact z. +// z remains unchanged otherwise. +// z.SetMode(z.Mode()) is a cheap way to set z's accuracy to Exact. +func (z *Float) SetMode(mode RoundingMode) *Float { + z.mode = mode + z.acc = Exact + return z +} + +// Prec returns the mantissa precision of x in bits. +// The result may be 0 for |x| == 0 and |x| == Inf. +func (x *Float) Prec() uint { + return uint(x.prec) +} + +// MinPrec returns the minimum precision required to represent x exactly +// (i.e., the smallest prec before x.SetPrec(prec) would start rounding x). +// The result is 0 for |x| == 0 and |x| == Inf. +func (x *Float) MinPrec() uint { + if x.form != finite { + return 0 + } + return uint(len(x.mant))*_W - x.mant.trailingZeroBits() +} + +// Mode returns the rounding mode of x. +func (x *Float) Mode() RoundingMode { + return x.mode +} + +// Acc returns the accuracy of x produced by the most recent +// operation, unless explicitly documented otherwise by that +// operation. +func (x *Float) Acc() Accuracy { + return x.acc +} + +// Sign returns: +// +// -1 if x < 0 +// 0 if x is ±0 +// +1 if x > 0 +func (x *Float) Sign() int { + if debugFloat { + x.validate() + } + if x.form == zero { + return 0 + } + if x.neg { + return -1 + } + return 1 +} + +// MantExp breaks x into its mantissa and exponent components +// and returns the exponent. If a non-nil mant argument is +// provided its value is set to the mantissa of x, with the +// same precision and rounding mode as x. The components +// satisfy x == mant × 2**exp, with 0.5 <= |mant| < 1.0. +// Calling MantExp with a nil argument is an efficient way to +// get the exponent of the receiver. +// +// Special cases are: +// +// ( ±0).MantExp(mant) = 0, with mant set to ±0 +// (±Inf).MantExp(mant) = 0, with mant set to ±Inf +// +// x and mant may be the same in which case x is set to its +// mantissa value. +func (x *Float) MantExp(mant *Float) (exp int) { + if debugFloat { + x.validate() + } + if x.form == finite { + exp = int(x.exp) + } + if mant != nil { + mant.Copy(x) + if mant.form == finite { + mant.exp = 0 + } + } + return +} + +func (z *Float) setExpAndRound(exp int64, sbit uint) { + if exp < MinExp { + // underflow + z.acc = makeAcc(z.neg) + z.form = zero + return + } + + if exp > MaxExp { + // overflow + z.acc = makeAcc(!z.neg) + z.form = inf + return + } + + z.form = finite + z.exp = int32(exp) + z.round(sbit) +} + +// SetMantExp sets z to mant × 2**exp and returns z. +// The result z has the same precision and rounding mode +// as mant. SetMantExp is an inverse of MantExp but does +// not require 0.5 <= |mant| < 1.0. Specifically, for a +// given x of type *Float, SetMantExp relates to MantExp +// as follows: +// +// mant := new(Float) +// new(Float).SetMantExp(mant, x.MantExp(mant)).Cmp(x) == 0 +// +// Special cases are: +// +// z.SetMantExp( ±0, exp) = ±0 +// z.SetMantExp(±Inf, exp) = ±Inf +// +// z and mant may be the same in which case z's exponent +// is set to exp. +func (z *Float) SetMantExp(mant *Float, exp int) *Float { + if debugFloat { + z.validate() + mant.validate() + } + z.Copy(mant) + + if z.form == finite { + // 0 < |mant| < +Inf + z.setExpAndRound(int64(z.exp)+int64(exp), 0) + } + return z +} + +// Signbit reports whether x is negative or negative zero. +func (x *Float) Signbit() bool { + return x.neg +} + +// IsInf reports whether x is +Inf or -Inf. +func (x *Float) IsInf() bool { + return x.form == inf +} + +// IsInt reports whether x is an integer. +// ±Inf values are not integers. +func (x *Float) IsInt() bool { + if debugFloat { + x.validate() + } + // special cases + if x.form != finite { + return x.form == zero + } + // x.form == finite + if x.exp <= 0 { + return false + } + // x.exp > 0 + return x.prec <= uint32(x.exp) || x.MinPrec() <= uint(x.exp) // not enough bits for fractional mantissa +} + +// debugging support +func (x *Float) validate() { + if !debugFloat { + // avoid performance bugs + panic("validate called but debugFloat is not set") + } + if x.form != finite { + return + } + m := len(x.mant) + if m == 0 { + panic("nonzero finite number with empty mantissa") + } + const msb = 1 << (_W - 1) + if x.mant[m-1]&msb == 0 { + panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.Text('p', 0))) + } + if x.prec == 0 { + panic("zero precision finite number") + } +} + +// round rounds z according to z.mode to z.prec bits and sets z.acc accordingly. +// sbit must be 0 or 1 and summarizes any "sticky bit" information one might +// have before calling round. z's mantissa must be normalized (with the msb set) +// or empty. +// +// CAUTION: The rounding modes ToNegativeInf, ToPositiveInf are affected by the +// sign of z. For correct rounding, the sign of z must be set correctly before +// calling round. +func (z *Float) round(sbit uint) { + if debugFloat { + z.validate() + } + + z.acc = Exact + if z.form != finite { + // ±0 or ±Inf => nothing left to do + return + } + // z.form == finite && len(z.mant) > 0 + // m > 0 implies z.prec > 0 (checked by validate) + + m := uint32(len(z.mant)) // present mantissa length in words + bits := m * _W // present mantissa bits; bits > 0 + if bits <= z.prec { + // mantissa fits => nothing to do + return + } + // bits > z.prec + + // Rounding is based on two bits: the rounding bit (rbit) and the + // sticky bit (sbit). The rbit is the bit immediately before the + // z.prec leading mantissa bits (the "0.5"). The sbit is set if any + // of the bits before the rbit are set (the "0.25", "0.125", etc.): + // + // rbit sbit => "fractional part" + // + // 0 0 == 0 + // 0 1 > 0 , < 0.5 + // 1 0 == 0.5 + // 1 1 > 0.5, < 1.0 + + // bits > z.prec: mantissa too large => round + r := uint(bits - z.prec - 1) // rounding bit position; r >= 0 + rbit := z.mant.bit(r) & 1 // rounding bit; be safe and ensure it's a single bit + // The sticky bit is only needed for rounding ToNearestEven + // or when the rounding bit is zero. Avoid computation otherwise. + if sbit == 0 && (rbit == 0 || z.mode == ToNearestEven) { + sbit = z.mant.sticky(r) + } + sbit &= 1 // be safe and ensure it's a single bit + + // cut off extra words + n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision + if m > n { + copy(z.mant, z.mant[m-n:]) // move n last words to front + z.mant = z.mant[:n] + } + + // determine number of trailing zero bits (ntz) and compute lsb mask of mantissa's least-significant word + ntz := n*_W - z.prec // 0 <= ntz < _W + lsb := Word(1) << ntz + + // round if result is inexact + if rbit|sbit != 0 { + // Make rounding decision: The result mantissa is truncated ("rounded down") + // by default. Decide if we need to increment, or "round up", the (unsigned) + // mantissa. + inc := false + switch z.mode { + case ToNegativeInf: + inc = z.neg + case ToZero: + // nothing to do + case ToNearestEven: + inc = rbit != 0 && (sbit != 0 || z.mant[0]&lsb != 0) + case ToNearestAway: + inc = rbit != 0 + case AwayFromZero: + inc = true + case ToPositiveInf: + inc = !z.neg + default: + panic("unreachable") + } + + // A positive result (!z.neg) is Above the exact result if we increment, + // and it's Below if we truncate (Exact results require no rounding). + // For a negative result (z.neg) it is exactly the opposite. + z.acc = makeAcc(inc != z.neg) + + if inc { + // add 1 to mantissa + if addVW(z.mant, z.mant, lsb) != 0 { + // mantissa overflow => adjust exponent + if z.exp >= MaxExp { + // exponent overflow + z.form = inf + return + } + z.exp++ + // adjust mantissa: divide by 2 to compensate for exponent adjustment + shrVU(z.mant, z.mant, 1) + // set msb == carry == 1 from the mantissa overflow above + const msb = 1 << (_W - 1) + z.mant[n-1] |= msb + } + } + } + + // zero out trailing bits in least-significant word + z.mant[0] &^= lsb - 1 + + if debugFloat { + z.validate() + } +} + +func (z *Float) setBits64(neg bool, x uint64) *Float { + if z.prec == 0 { + z.prec = 64 + } + z.acc = Exact + z.neg = neg + if x == 0 { + z.form = zero + return z + } + // x != 0 + z.form = finite + s := bits.LeadingZeros64(x) + z.mant = z.mant.setUint64(x << uint(s)) + z.exp = int32(64 - s) // always fits + if z.prec < 64 { + z.round(0) + } + return z +} + +// SetUint64 sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to 64 (and rounding will have +// no effect). +func (z *Float) SetUint64(x uint64) *Float { + return z.setBits64(false, x) +} + +// SetInt64 sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to 64 (and rounding will have +// no effect). +func (z *Float) SetInt64(x int64) *Float { + u := x + if u < 0 { + u = -u + } + // We cannot simply call z.SetUint64(uint64(u)) and change + // the sign afterwards because the sign affects rounding. + return z.setBits64(x < 0, uint64(u)) +} + +// SetFloat64 sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to 53 (and rounding will have +// no effect). SetFloat64 panics with ErrNaN if x is a NaN. +func (z *Float) SetFloat64(x float64) *Float { + if z.prec == 0 { + z.prec = 53 + } + if math.IsNaN(x) { + panic(ErrNaN{"Float.SetFloat64(NaN)"}) + } + z.acc = Exact + z.neg = math.Signbit(x) // handle -0, -Inf correctly + if x == 0 { + z.form = zero + return z + } + if math.IsInf(x, 0) { + z.form = inf + return z + } + // normalized x != 0 + z.form = finite + fmant, exp := math.Frexp(x) // get normalized mantissa + z.mant = z.mant.setUint64(1<<63 | math.Float64bits(fmant)<<11) + z.exp = int32(exp) // always fits + if z.prec < 53 { + z.round(0) + } + return z +} + +// fnorm normalizes mantissa m by shifting it to the left +// such that the msb of the most-significant word (msw) is 1. +// It returns the shift amount. It assumes that len(m) != 0. +func fnorm(m nat) int64 { + if debugFloat && (len(m) == 0 || m[len(m)-1] == 0) { + panic("msw of mantissa is 0") + } + s := nlz(m[len(m)-1]) + if s > 0 { + c := shlVU(m, m, s) + if debugFloat && c != 0 { + panic("nlz or shlVU incorrect") + } + } + return int64(s) +} + +// SetInt sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to the larger of x.BitLen() +// or 64 (and rounding will have no effect). +func (z *Float) SetInt(x *Int) *Float { + // TODO(gri) can be more efficient if z.prec > 0 + // but small compared to the size of x, or if there + // are many trailing 0's. + bits := uint32(x.BitLen()) + if z.prec == 0 { + z.prec = umax32(bits, 64) + } + z.acc = Exact + z.neg = x.neg + if len(x.abs) == 0 { + z.form = zero + return z + } + // x != 0 + z.mant = z.mant.set(x.abs) + fnorm(z.mant) + z.setExpAndRound(int64(bits), 0) + return z +} + +// SetRat sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to the largest of a.BitLen(), +// b.BitLen(), or 64; with x = a/b. +func (z *Float) SetRat(x *Rat) *Float { + if x.IsInt() { + return z.SetInt(x.Num()) + } + var a, b Float + a.SetInt(x.Num()) + b.SetInt(x.Denom()) + if z.prec == 0 { + z.prec = umax32(a.prec, b.prec) + } + return z.Quo(&a, &b) +} + +// SetInf sets z to the infinite Float -Inf if signbit is +// set, or +Inf if signbit is not set, and returns z. The +// precision of z is unchanged and the result is always +// Exact. +func (z *Float) SetInf(signbit bool) *Float { + z.acc = Exact + z.form = inf + z.neg = signbit + return z +} + +// Set sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to the precision of x +// before setting z (and rounding will have no effect). +// Rounding is performed according to z's precision and rounding +// mode; and z's accuracy reports the result error relative to the +// exact (not rounded) result. +func (z *Float) Set(x *Float) *Float { + if debugFloat { + x.validate() + } + z.acc = Exact + if z != x { + z.form = x.form + z.neg = x.neg + if x.form == finite { + z.exp = x.exp + z.mant = z.mant.set(x.mant) + } + if z.prec == 0 { + z.prec = x.prec + } else if z.prec < x.prec { + z.round(0) + } + } + return z +} + +// Copy sets z to x, with the same precision, rounding mode, and +// accuracy as x, and returns z. x is not changed even if z and +// x are the same. +func (z *Float) Copy(x *Float) *Float { + if debugFloat { + x.validate() + } + if z != x { + z.prec = x.prec + z.mode = x.mode + z.acc = x.acc + z.form = x.form + z.neg = x.neg + if z.form == finite { + z.mant = z.mant.set(x.mant) + z.exp = x.exp + } + } + return z +} + +// msb32 returns the 32 most significant bits of x. +func msb32(x nat) uint32 { + i := len(x) - 1 + if i < 0 { + return 0 + } + if debugFloat && x[i]&(1<<(_W-1)) == 0 { + panic("x not normalized") + } + switch _W { + case 32: + return uint32(x[i]) + case 64: + return uint32(x[i] >> 32) + } + panic("unreachable") +} + +// msb64 returns the 64 most significant bits of x. +func msb64(x nat) uint64 { + i := len(x) - 1 + if i < 0 { + return 0 + } + if debugFloat && x[i]&(1<<(_W-1)) == 0 { + panic("x not normalized") + } + switch _W { + case 32: + v := uint64(x[i]) << 32 + if i > 0 { + v |= uint64(x[i-1]) + } + return v + case 64: + return uint64(x[i]) + } + panic("unreachable") +} + +// Uint64 returns the unsigned integer resulting from truncating x +// towards zero. If 0 <= x <= math.MaxUint64, the result is Exact +// if x is an integer and Below otherwise. +// The result is (0, Above) for x < 0, and (math.MaxUint64, Below) +// for x > math.MaxUint64. +func (x *Float) Uint64() (uint64, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + if x.neg { + return 0, Above + } + // 0 < x < +Inf + if x.exp <= 0 { + // 0 < x < 1 + return 0, Below + } + // 1 <= x < Inf + if x.exp <= 64 { + // u = trunc(x) fits into a uint64 + u := msb64(x.mant) >> (64 - uint32(x.exp)) + if x.MinPrec() <= 64 { + return u, Exact + } + return u, Below // x truncated + } + // x too large + return math.MaxUint64, Below + + case zero: + return 0, Exact + + case inf: + if x.neg { + return 0, Above + } + return math.MaxUint64, Below + } + + panic("unreachable") +} + +// Int64 returns the integer resulting from truncating x towards zero. +// If math.MinInt64 <= x <= math.MaxInt64, the result is Exact if x is +// an integer, and Above (x < 0) or Below (x > 0) otherwise. +// The result is (math.MinInt64, Above) for x < math.MinInt64, +// and (math.MaxInt64, Below) for x > math.MaxInt64. +func (x *Float) Int64() (int64, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + acc := makeAcc(x.neg) + if x.exp <= 0 { + // 0 < |x| < 1 + return 0, acc + } + // x.exp > 0 + + // 1 <= |x| < +Inf + if x.exp <= 63 { + // i = trunc(x) fits into an int64 (excluding math.MinInt64) + i := int64(msb64(x.mant) >> (64 - uint32(x.exp))) + if x.neg { + i = -i + } + if x.MinPrec() <= uint(x.exp) { + return i, Exact + } + return i, acc // x truncated + } + if x.neg { + // check for special case x == math.MinInt64 (i.e., x == -(0.5 << 64)) + if x.exp == 64 && x.MinPrec() == 1 { + acc = Exact + } + return math.MinInt64, acc + } + // x too large + return math.MaxInt64, Below + + case zero: + return 0, Exact + + case inf: + if x.neg { + return math.MinInt64, Above + } + return math.MaxInt64, Below + } + + panic("unreachable") +} + +// Float32 returns the float32 value nearest to x. If x is too small to be +// represented by a float32 (|x| < math.SmallestNonzeroFloat32), the result +// is (0, Below) or (-0, Above), respectively, depending on the sign of x. +// If x is too large to be represented by a float32 (|x| > math.MaxFloat32), +// the result is (+Inf, Above) or (-Inf, Below), depending on the sign of x. +func (x *Float) Float32() (float32, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + + const ( + fbits = 32 // float size + mbits = 23 // mantissa size (excluding implicit msb) + ebits = fbits - mbits - 1 // 8 exponent size + bias = 1<<(ebits-1) - 1 // 127 exponent bias + dmin = 1 - bias - mbits // -149 smallest unbiased exponent (denormal) + emin = 1 - bias // -126 smallest unbiased exponent (normal) + emax = bias // 127 largest unbiased exponent (normal) + ) + + // Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float32 mantissa. + e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0 + + // Compute precision p for float32 mantissa. + // If the exponent is too small, we have a denormal number before + // rounding and fewer than p mantissa bits of precision available + // (the exponent remains fixed but the mantissa gets shifted right). + p := mbits + 1 // precision of normal float + if e < emin { + // recompute precision + p = mbits + 1 - emin + int(e) + // If p == 0, the mantissa of x is shifted so much to the right + // that its msb falls immediately to the right of the float32 + // mantissa space. In other words, if the smallest denormal is + // considered "1.0", for p == 0, the mantissa value m is >= 0.5. + // If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal. + // If m == 0.5, it is rounded down to even, i.e., 0.0. + // If p < 0, the mantissa value m is <= "0.25" which is never rounded up. + if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ { + // underflow to ±0 + if x.neg { + var z float32 + return -z, Above + } + return 0.0, Below + } + // otherwise, round up + // We handle p == 0 explicitly because it's easy and because + // Float.round doesn't support rounding to 0 bits of precision. + if p == 0 { + if x.neg { + return -math.SmallestNonzeroFloat32, Below + } + return math.SmallestNonzeroFloat32, Above + } + } + // p > 0 + + // round + var r Float + r.prec = uint32(p) + r.Set(x) + e = r.exp - 1 + + // Rounding may have caused r to overflow to ±Inf + // (rounding never causes underflows to 0). + // If the exponent is too large, also overflow to ±Inf. + if r.form == inf || e > emax { + // overflow + if x.neg { + return float32(math.Inf(-1)), Below + } + return float32(math.Inf(+1)), Above + } + // e <= emax + + // Determine sign, biased exponent, and mantissa. + var sign, bexp, mant uint32 + if x.neg { + sign = 1 << (fbits - 1) + } + + // Rounding may have caused a denormal number to + // become normal. Check again. + if e < emin { + // denormal number: recompute precision + // Since rounding may have at best increased precision + // and we have eliminated p <= 0 early, we know p > 0. + // bexp == 0 for denormals + p = mbits + 1 - emin + int(e) + mant = msb32(r.mant) >> uint(fbits-p) + } else { + // normal number: emin <= e <= emax + bexp = uint32(e+bias) << mbits + mant = msb32(r.mant) >> ebits & (1<<mbits - 1) // cut off msb (implicit 1 bit) + } + + return math.Float32frombits(sign | bexp | mant), r.acc + + case zero: + if x.neg { + var z float32 + return -z, Exact + } + return 0.0, Exact + + case inf: + if x.neg { + return float32(math.Inf(-1)), Exact + } + return float32(math.Inf(+1)), Exact + } + + panic("unreachable") +} + +// Float64 returns the float64 value nearest to x. If x is too small to be +// represented by a float64 (|x| < math.SmallestNonzeroFloat64), the result +// is (0, Below) or (-0, Above), respectively, depending on the sign of x. +// If x is too large to be represented by a float64 (|x| > math.MaxFloat64), +// the result is (+Inf, Above) or (-Inf, Below), depending on the sign of x. +func (x *Float) Float64() (float64, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + + const ( + fbits = 64 // float size + mbits = 52 // mantissa size (excluding implicit msb) + ebits = fbits - mbits - 1 // 11 exponent size + bias = 1<<(ebits-1) - 1 // 1023 exponent bias + dmin = 1 - bias - mbits // -1074 smallest unbiased exponent (denormal) + emin = 1 - bias // -1022 smallest unbiased exponent (normal) + emax = bias // 1023 largest unbiased exponent (normal) + ) + + // Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float64 mantissa. + e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0 + + // Compute precision p for float64 mantissa. + // If the exponent is too small, we have a denormal number before + // rounding and fewer than p mantissa bits of precision available + // (the exponent remains fixed but the mantissa gets shifted right). + p := mbits + 1 // precision of normal float + if e < emin { + // recompute precision + p = mbits + 1 - emin + int(e) + // If p == 0, the mantissa of x is shifted so much to the right + // that its msb falls immediately to the right of the float64 + // mantissa space. In other words, if the smallest denormal is + // considered "1.0", for p == 0, the mantissa value m is >= 0.5. + // If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal. + // If m == 0.5, it is rounded down to even, i.e., 0.0. + // If p < 0, the mantissa value m is <= "0.25" which is never rounded up. + if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ { + // underflow to ±0 + if x.neg { + var z float64 + return -z, Above + } + return 0.0, Below + } + // otherwise, round up + // We handle p == 0 explicitly because it's easy and because + // Float.round doesn't support rounding to 0 bits of precision. + if p == 0 { + if x.neg { + return -math.SmallestNonzeroFloat64, Below + } + return math.SmallestNonzeroFloat64, Above + } + } + // p > 0 + + // round + var r Float + r.prec = uint32(p) + r.Set(x) + e = r.exp - 1 + + // Rounding may have caused r to overflow to ±Inf + // (rounding never causes underflows to 0). + // If the exponent is too large, also overflow to ±Inf. + if r.form == inf || e > emax { + // overflow + if x.neg { + return math.Inf(-1), Below + } + return math.Inf(+1), Above + } + // e <= emax + + // Determine sign, biased exponent, and mantissa. + var sign, bexp, mant uint64 + if x.neg { + sign = 1 << (fbits - 1) + } + + // Rounding may have caused a denormal number to + // become normal. Check again. + if e < emin { + // denormal number: recompute precision + // Since rounding may have at best increased precision + // and we have eliminated p <= 0 early, we know p > 0. + // bexp == 0 for denormals + p = mbits + 1 - emin + int(e) + mant = msb64(r.mant) >> uint(fbits-p) + } else { + // normal number: emin <= e <= emax + bexp = uint64(e+bias) << mbits + mant = msb64(r.mant) >> ebits & (1<<mbits - 1) // cut off msb (implicit 1 bit) + } + + return math.Float64frombits(sign | bexp | mant), r.acc + + case zero: + if x.neg { + var z float64 + return -z, Exact + } + return 0.0, Exact + + case inf: + if x.neg { + return math.Inf(-1), Exact + } + return math.Inf(+1), Exact + } + + panic("unreachable") +} + +// Int returns the result of truncating x towards zero; +// or nil if x is an infinity. +// The result is Exact if x.IsInt(); otherwise it is Below +// for x > 0, and Above for x < 0. +// If a non-nil *Int argument z is provided, Int stores +// the result in z instead of allocating a new Int. +func (x *Float) Int(z *Int) (*Int, Accuracy) { + if debugFloat { + x.validate() + } + + if z == nil && x.form <= finite { + z = new(Int) + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + acc := makeAcc(x.neg) + if x.exp <= 0 { + // 0 < |x| < 1 + return z.SetInt64(0), acc + } + // x.exp > 0 + + // 1 <= |x| < +Inf + // determine minimum required precision for x + allBits := uint(len(x.mant)) * _W + exp := uint(x.exp) + if x.MinPrec() <= exp { + acc = Exact + } + // shift mantissa as needed + if z == nil { + z = new(Int) + } + z.neg = x.neg + switch { + case exp > allBits: + z.abs = z.abs.shl(x.mant, exp-allBits) + default: + z.abs = z.abs.set(x.mant) + case exp < allBits: + z.abs = z.abs.shr(x.mant, allBits-exp) + } + return z, acc + + case zero: + return z.SetInt64(0), Exact + + case inf: + return nil, makeAcc(x.neg) + } + + panic("unreachable") +} + +// Rat returns the rational number corresponding to x; +// or nil if x is an infinity. +// The result is Exact if x is not an Inf. +// If a non-nil *Rat argument z is provided, Rat stores +// the result in z instead of allocating a new Rat. +func (x *Float) Rat(z *Rat) (*Rat, Accuracy) { + if debugFloat { + x.validate() + } + + if z == nil && x.form <= finite { + z = new(Rat) + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + allBits := int32(len(x.mant)) * _W + // build up numerator and denominator + z.a.neg = x.neg + switch { + case x.exp > allBits: + z.a.abs = z.a.abs.shl(x.mant, uint(x.exp-allBits)) + z.b.abs = z.b.abs[:0] // == 1 (see Rat) + // z already in normal form + default: + z.a.abs = z.a.abs.set(x.mant) + z.b.abs = z.b.abs[:0] // == 1 (see Rat) + // z already in normal form + case x.exp < allBits: + z.a.abs = z.a.abs.set(x.mant) + t := z.b.abs.setUint64(1) + z.b.abs = t.shl(t, uint(allBits-x.exp)) + z.norm() + } + return z, Exact + + case zero: + return z.SetInt64(0), Exact + + case inf: + return nil, makeAcc(x.neg) + } + + panic("unreachable") +} + +// Abs sets z to the (possibly rounded) value |x| (the absolute value of x) +// and returns z. +func (z *Float) Abs(x *Float) *Float { + z.Set(x) + z.neg = false + return z +} + +// Neg sets z to the (possibly rounded) value of x with its sign negated, +// and returns z. +func (z *Float) Neg(x *Float) *Float { + z.Set(x) + z.neg = !z.neg + return z +} + +func validateBinaryOperands(x, y *Float) { + if !debugFloat { + // avoid performance bugs + panic("validateBinaryOperands called but debugFloat is not set") + } + if len(x.mant) == 0 { + panic("empty mantissa for x") + } + if len(y.mant) == 0 { + panic("empty mantissa for y") + } +} + +// z = x + y, ignoring signs of x and y for the addition +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) uadd(x, y *Float) { + // Note: This implementation requires 2 shifts most of the + // time. It is also inefficient if exponents or precisions + // differ by wide margins. The following article describes + // an efficient (but much more complicated) implementation + // compatible with the internal representation used here: + // + // Vincent Lefèvre: "The Generic Multiple-Precision Floating- + // Point Addition With Exact Rounding (as in the MPFR Library)" + // http://www.vinc17.net/research/papers/rnc6.pdf + + if debugFloat { + validateBinaryOperands(x, y) + } + + // compute exponents ex, ey for mantissa with "binary point" + // on the right (mantissa.0) - use int64 to avoid overflow + ex := int64(x.exp) - int64(len(x.mant))*_W + ey := int64(y.exp) - int64(len(y.mant))*_W + + al := alias(z.mant, x.mant) || alias(z.mant, y.mant) + + // TODO(gri) having a combined add-and-shift primitive + // could make this code significantly faster + switch { + case ex < ey: + if al { + t := nat(nil).shl(y.mant, uint(ey-ex)) + z.mant = z.mant.add(x.mant, t) + } else { + z.mant = z.mant.shl(y.mant, uint(ey-ex)) + z.mant = z.mant.add(x.mant, z.mant) + } + default: + // ex == ey, no shift needed + z.mant = z.mant.add(x.mant, y.mant) + case ex > ey: + if al { + t := nat(nil).shl(x.mant, uint(ex-ey)) + z.mant = z.mant.add(t, y.mant) + } else { + z.mant = z.mant.shl(x.mant, uint(ex-ey)) + z.mant = z.mant.add(z.mant, y.mant) + } + ex = ey + } + // len(z.mant) > 0 + + z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0) +} + +// z = x - y for |x| > |y|, ignoring signs of x and y for the subtraction +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) usub(x, y *Float) { + // This code is symmetric to uadd. + // We have not factored the common code out because + // eventually uadd (and usub) should be optimized + // by special-casing, and the code will diverge. + + if debugFloat { + validateBinaryOperands(x, y) + } + + ex := int64(x.exp) - int64(len(x.mant))*_W + ey := int64(y.exp) - int64(len(y.mant))*_W + + al := alias(z.mant, x.mant) || alias(z.mant, y.mant) + + switch { + case ex < ey: + if al { + t := nat(nil).shl(y.mant, uint(ey-ex)) + z.mant = t.sub(x.mant, t) + } else { + z.mant = z.mant.shl(y.mant, uint(ey-ex)) + z.mant = z.mant.sub(x.mant, z.mant) + } + default: + // ex == ey, no shift needed + z.mant = z.mant.sub(x.mant, y.mant) + case ex > ey: + if al { + t := nat(nil).shl(x.mant, uint(ex-ey)) + z.mant = t.sub(t, y.mant) + } else { + z.mant = z.mant.shl(x.mant, uint(ex-ey)) + z.mant = z.mant.sub(z.mant, y.mant) + } + ex = ey + } + + // operands may have canceled each other out + if len(z.mant) == 0 { + z.acc = Exact + z.form = zero + z.neg = false + return + } + // len(z.mant) > 0 + + z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0) +} + +// z = x * y, ignoring signs of x and y for the multiplication +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) umul(x, y *Float) { + if debugFloat { + validateBinaryOperands(x, y) + } + + // Note: This is doing too much work if the precision + // of z is less than the sum of the precisions of x + // and y which is often the case (e.g., if all floats + // have the same precision). + // TODO(gri) Optimize this for the common case. + + e := int64(x.exp) + int64(y.exp) + if x == y { + z.mant = z.mant.sqr(x.mant) + } else { + z.mant = z.mant.mul(x.mant, y.mant) + } + z.setExpAndRound(e-fnorm(z.mant), 0) +} + +// z = x / y, ignoring signs of x and y for the division +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) uquo(x, y *Float) { + if debugFloat { + validateBinaryOperands(x, y) + } + + // mantissa length in words for desired result precision + 1 + // (at least one extra bit so we get the rounding bit after + // the division) + n := int(z.prec/_W) + 1 + + // compute adjusted x.mant such that we get enough result precision + xadj := x.mant + if d := n - len(x.mant) + len(y.mant); d > 0 { + // d extra words needed => add d "0 digits" to x + xadj = make(nat, len(x.mant)+d) + copy(xadj[d:], x.mant) + } + // TODO(gri): If we have too many digits (d < 0), we should be able + // to shorten x for faster division. But we must be extra careful + // with rounding in that case. + + // Compute d before division since there may be aliasing of x.mant + // (via xadj) or y.mant with z.mant. + d := len(xadj) - len(y.mant) + + // divide + var r nat + z.mant, r = z.mant.div(nil, xadj, y.mant) + e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W + + // The result is long enough to include (at least) the rounding bit. + // If there's a non-zero remainder, the corresponding fractional part + // (if it were computed), would have a non-zero sticky bit (if it were + // zero, it couldn't have a non-zero remainder). + var sbit uint + if len(r) > 0 { + sbit = 1 + } + + z.setExpAndRound(e-fnorm(z.mant), sbit) +} + +// ucmp returns -1, 0, or +1, depending on whether +// |x| < |y|, |x| == |y|, or |x| > |y|. +// x and y must have a non-empty mantissa and valid exponent. +func (x *Float) ucmp(y *Float) int { + if debugFloat { + validateBinaryOperands(x, y) + } + + switch { + case x.exp < y.exp: + return -1 + case x.exp > y.exp: + return +1 + } + // x.exp == y.exp + + // compare mantissas + i := len(x.mant) + j := len(y.mant) + for i > 0 || j > 0 { + var xm, ym Word + if i > 0 { + i-- + xm = x.mant[i] + } + if j > 0 { + j-- + ym = y.mant[j] + } + switch { + case xm < ym: + return -1 + case xm > ym: + return +1 + } + } + + return 0 +} + +// Handling of sign bit as defined by IEEE 754-2008, section 6.3: +// +// When neither the inputs nor result are NaN, the sign of a product or +// quotient is the exclusive OR of the operands’ signs; the sign of a sum, +// or of a difference x−y regarded as a sum x+(−y), differs from at most +// one of the addends’ signs; and the sign of the result of conversions, +// the quantize operation, the roundToIntegral operations, and the +// roundToIntegralExact (see 5.3.1) is the sign of the first or only operand. +// These rules shall apply even when operands or results are zero or infinite. +// +// When the sum of two operands with opposite signs (or the difference of +// two operands with like signs) is exactly zero, the sign of that sum (or +// difference) shall be +0 in all rounding-direction attributes except +// roundTowardNegative; under that attribute, the sign of an exact zero +// sum (or difference) shall be −0. However, x+x = x−(−x) retains the same +// sign as x even when x is zero. +// +// See also: https://play.golang.org/p/RtH3UCt5IH + +// Add sets z to the rounded sum x+y and returns z. If z's precision is 0, +// it is changed to the larger of x's or y's precision before the operation. +// Rounding is performed according to z's precision and rounding mode; and +// z's accuracy reports the result error relative to the exact (not rounded) +// result. Add panics with ErrNaN if x and y are infinities with opposite +// signs. The value of z is undefined in that case. +func (z *Float) Add(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + if x.form == finite && y.form == finite { + // x + y (common case) + + // Below we set z.neg = x.neg, and when z aliases y this will + // change the y operand's sign. This is fine, because if an + // operand aliases the receiver it'll be overwritten, but we still + // want the original x.neg and y.neg values when we evaluate + // x.neg != y.neg, so we need to save y.neg before setting z.neg. + yneg := y.neg + + z.neg = x.neg + if x.neg == yneg { + // x + y == x + y + // (-x) + (-y) == -(x + y) + z.uadd(x, y) + } else { + // x + (-y) == x - y == -(y - x) + // (-x) + y == y - x == -(x - y) + if x.ucmp(y) > 0 { + z.usub(x, y) + } else { + z.neg = !z.neg + z.usub(y, x) + } + } + if z.form == zero && z.mode == ToNegativeInf && z.acc == Exact { + z.neg = true + } + return z + } + + if x.form == inf && y.form == inf && x.neg != y.neg { + // +Inf + -Inf + // -Inf + +Inf + // value of z is undefined but make sure it's valid + z.acc = Exact + z.form = zero + z.neg = false + panic(ErrNaN{"addition of infinities with opposite signs"}) + } + + if x.form == zero && y.form == zero { + // ±0 + ±0 + z.acc = Exact + z.form = zero + z.neg = x.neg && y.neg // -0 + -0 == -0 + return z + } + + if x.form == inf || y.form == zero { + // ±Inf + y + // x + ±0 + return z.Set(x) + } + + // ±0 + y + // x + ±Inf + return z.Set(y) +} + +// Sub sets z to the rounded difference x-y and returns z. +// Precision, rounding, and accuracy reporting are as for Add. +// Sub panics with ErrNaN if x and y are infinities with equal +// signs. The value of z is undefined in that case. +func (z *Float) Sub(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + if x.form == finite && y.form == finite { + // x - y (common case) + yneg := y.neg + z.neg = x.neg + if x.neg != yneg { + // x - (-y) == x + y + // (-x) - y == -(x + y) + z.uadd(x, y) + } else { + // x - y == x - y == -(y - x) + // (-x) - (-y) == y - x == -(x - y) + if x.ucmp(y) > 0 { + z.usub(x, y) + } else { + z.neg = !z.neg + z.usub(y, x) + } + } + if z.form == zero && z.mode == ToNegativeInf && z.acc == Exact { + z.neg = true + } + return z + } + + if x.form == inf && y.form == inf && x.neg == y.neg { + // +Inf - +Inf + // -Inf - -Inf + // value of z is undefined but make sure it's valid + z.acc = Exact + z.form = zero + z.neg = false + panic(ErrNaN{"subtraction of infinities with equal signs"}) + } + + if x.form == zero && y.form == zero { + // ±0 - ±0 + z.acc = Exact + z.form = zero + z.neg = x.neg && !y.neg // -0 - +0 == -0 + return z + } + + if x.form == inf || y.form == zero { + // ±Inf - y + // x - ±0 + return z.Set(x) + } + + // ±0 - y + // x - ±Inf + return z.Neg(y) +} + +// Mul sets z to the rounded product x*y and returns z. +// Precision, rounding, and accuracy reporting are as for Add. +// Mul panics with ErrNaN if one operand is zero and the other +// operand an infinity. The value of z is undefined in that case. +func (z *Float) Mul(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + z.neg = x.neg != y.neg + + if x.form == finite && y.form == finite { + // x * y (common case) + z.umul(x, y) + return z + } + + z.acc = Exact + if x.form == zero && y.form == inf || x.form == inf && y.form == zero { + // ±0 * ±Inf + // ±Inf * ±0 + // value of z is undefined but make sure it's valid + z.form = zero + z.neg = false + panic(ErrNaN{"multiplication of zero with infinity"}) + } + + if x.form == inf || y.form == inf { + // ±Inf * y + // x * ±Inf + z.form = inf + return z + } + + // ±0 * y + // x * ±0 + z.form = zero + return z +} + +// Quo sets z to the rounded quotient x/y and returns z. +// Precision, rounding, and accuracy reporting are as for Add. +// Quo panics with ErrNaN if both operands are zero or infinities. +// The value of z is undefined in that case. +func (z *Float) Quo(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + z.neg = x.neg != y.neg + + if x.form == finite && y.form == finite { + // x / y (common case) + z.uquo(x, y) + return z + } + + z.acc = Exact + if x.form == zero && y.form == zero || x.form == inf && y.form == inf { + // ±0 / ±0 + // ±Inf / ±Inf + // value of z is undefined but make sure it's valid + z.form = zero + z.neg = false + panic(ErrNaN{"division of zero by zero or infinity by infinity"}) + } + + if x.form == zero || y.form == inf { + // ±0 / y + // x / ±Inf + z.form = zero + return z + } + + // x / ±0 + // ±Inf / y + z.form = inf + return z +} + +// Cmp compares x and y and returns: +// +// -1 if x < y +// 0 if x == y (incl. -0 == 0, -Inf == -Inf, and +Inf == +Inf) +// +1 if x > y +func (x *Float) Cmp(y *Float) int { + if debugFloat { + x.validate() + y.validate() + } + + mx := x.ord() + my := y.ord() + switch { + case mx < my: + return -1 + case mx > my: + return +1 + } + // mx == my + + // only if |mx| == 1 we have to compare the mantissae + switch mx { + case -1: + return y.ucmp(x) + case +1: + return x.ucmp(y) + } + + return 0 +} + +// ord classifies x and returns: +// +// -2 if -Inf == x +// -1 if -Inf < x < 0 +// 0 if x == 0 (signed or unsigned) +// +1 if 0 < x < +Inf +// +2 if x == +Inf +func (x *Float) ord() int { + var m int + switch x.form { + case finite: + m = 1 + case zero: + return 0 + case inf: + m = 2 + } + if x.neg { + m = -m + } + return m +} + +func umax32(x, y uint32) uint32 { + if x > y { + return x + } + return y +} diff --git a/src/math/big/float_test.go b/src/math/big/float_test.go new file mode 100644 index 0000000..7d6bf03 --- /dev/null +++ b/src/math/big/float_test.go @@ -0,0 +1,1858 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "flag" + "fmt" + "math" + "strconv" + "strings" + "testing" +) + +// Verify that ErrNaN implements the error interface. +var _ error = ErrNaN{} + +func (x *Float) uint64() uint64 { + u, acc := x.Uint64() + if acc != Exact { + panic(fmt.Sprintf("%s is not a uint64", x.Text('g', 10))) + } + return u +} + +func (x *Float) int64() int64 { + i, acc := x.Int64() + if acc != Exact { + panic(fmt.Sprintf("%s is not an int64", x.Text('g', 10))) + } + return i +} + +func TestFloatZeroValue(t *testing.T) { + // zero (uninitialized) value is a ready-to-use 0.0 + var x Float + if s := x.Text('f', 1); s != "0.0" { + t.Errorf("zero value = %s; want 0.0", s) + } + + // zero value has precision 0 + if prec := x.Prec(); prec != 0 { + t.Errorf("prec = %d; want 0", prec) + } + + // zero value can be used in any and all positions of binary operations + make := func(x int) *Float { + var f Float + if x != 0 { + f.SetInt64(int64(x)) + } + // x == 0 translates into the zero value + return &f + } + for _, test := range []struct { + z, x, y, want int + opname rune + op func(z, x, y *Float) *Float + }{ + {0, 0, 0, 0, '+', (*Float).Add}, + {0, 1, 2, 3, '+', (*Float).Add}, + {1, 2, 0, 2, '+', (*Float).Add}, + {2, 0, 1, 1, '+', (*Float).Add}, + + {0, 0, 0, 0, '-', (*Float).Sub}, + {0, 1, 2, -1, '-', (*Float).Sub}, + {1, 2, 0, 2, '-', (*Float).Sub}, + {2, 0, 1, -1, '-', (*Float).Sub}, + + {0, 0, 0, 0, '*', (*Float).Mul}, + {0, 1, 2, 2, '*', (*Float).Mul}, + {1, 2, 0, 0, '*', (*Float).Mul}, + {2, 0, 1, 0, '*', (*Float).Mul}, + + // {0, 0, 0, 0, '/', (*Float).Quo}, // panics + {0, 2, 1, 2, '/', (*Float).Quo}, + {1, 2, 0, 0, '/', (*Float).Quo}, // = +Inf + {2, 0, 1, 0, '/', (*Float).Quo}, + } { + z := make(test.z) + test.op(z, make(test.x), make(test.y)) + got := 0 + if !z.IsInf() { + got = int(z.int64()) + } + if got != test.want { + t.Errorf("%d %c %d = %d; want %d", test.x, test.opname, test.y, got, test.want) + } + } + + // TODO(gri) test how precision is set for zero value results +} + +func makeFloat(s string) *Float { + x, _, err := ParseFloat(s, 0, 1000, ToNearestEven) + if err != nil { + panic(err) + } + return x +} + +func TestFloatSetPrec(t *testing.T) { + for _, test := range []struct { + x string + prec uint + want string + acc Accuracy + }{ + // prec 0 + {"0", 0, "0", Exact}, + {"-0", 0, "-0", Exact}, + {"-Inf", 0, "-Inf", Exact}, + {"+Inf", 0, "+Inf", Exact}, + {"123", 0, "0", Below}, + {"-123", 0, "-0", Above}, + + // prec at upper limit + {"0", MaxPrec, "0", Exact}, + {"-0", MaxPrec, "-0", Exact}, + {"-Inf", MaxPrec, "-Inf", Exact}, + {"+Inf", MaxPrec, "+Inf", Exact}, + + // just a few regular cases - general rounding is tested elsewhere + {"1.5", 1, "2", Above}, + {"-1.5", 1, "-2", Below}, + {"123", 1e6, "123", Exact}, + {"-123", 1e6, "-123", Exact}, + } { + x := makeFloat(test.x).SetPrec(test.prec) + prec := test.prec + if prec > MaxPrec { + prec = MaxPrec + } + if got := x.Prec(); got != prec { + t.Errorf("%s.SetPrec(%d).Prec() == %d; want %d", test.x, test.prec, got, prec) + } + if got, acc := x.String(), x.Acc(); got != test.want || acc != test.acc { + t.Errorf("%s.SetPrec(%d) = %s (%s); want %s (%s)", test.x, test.prec, got, acc, test.want, test.acc) + } + } +} + +func TestFloatMinPrec(t *testing.T) { + const max = 100 + for _, test := range []struct { + x string + want uint + }{ + {"0", 0}, + {"-0", 0}, + {"+Inf", 0}, + {"-Inf", 0}, + {"1", 1}, + {"2", 1}, + {"3", 2}, + {"0x8001", 16}, + {"0x8001p-1000", 16}, + {"0x8001p+1000", 16}, + {"0.1", max}, + } { + x := makeFloat(test.x).SetPrec(max) + if got := x.MinPrec(); got != test.want { + t.Errorf("%s.MinPrec() = %d; want %d", test.x, got, test.want) + } + } +} + +func TestFloatSign(t *testing.T) { + for _, test := range []struct { + x string + s int + }{ + {"-Inf", -1}, + {"-1", -1}, + {"-0", 0}, + {"+0", 0}, + {"+1", +1}, + {"+Inf", +1}, + } { + x := makeFloat(test.x) + s := x.Sign() + if s != test.s { + t.Errorf("%s.Sign() = %d; want %d", test.x, s, test.s) + } + } +} + +// alike(x, y) is like x.Cmp(y) == 0 but also considers the sign of 0 (0 != -0). +func alike(x, y *Float) bool { + return x.Cmp(y) == 0 && x.Signbit() == y.Signbit() +} + +func alike32(x, y float32) bool { + // we can ignore NaNs + return x == y && math.Signbit(float64(x)) == math.Signbit(float64(y)) + +} + +func alike64(x, y float64) bool { + // we can ignore NaNs + return x == y && math.Signbit(x) == math.Signbit(y) + +} + +func TestFloatMantExp(t *testing.T) { + for _, test := range []struct { + x string + mant string + exp int + }{ + {"0", "0", 0}, + {"+0", "0", 0}, + {"-0", "-0", 0}, + {"Inf", "+Inf", 0}, + {"+Inf", "+Inf", 0}, + {"-Inf", "-Inf", 0}, + {"1.5", "0.75", 1}, + {"1.024e3", "0.5", 11}, + {"-0.125", "-0.5", -2}, + } { + x := makeFloat(test.x) + mant := makeFloat(test.mant) + m := new(Float) + e := x.MantExp(m) + if !alike(m, mant) || e != test.exp { + t.Errorf("%s.MantExp() = %s, %d; want %s, %d", test.x, m.Text('g', 10), e, test.mant, test.exp) + } + } +} + +func TestFloatMantExpAliasing(t *testing.T) { + x := makeFloat("0.5p10") + if e := x.MantExp(x); e != 10 { + t.Fatalf("Float.MantExp aliasing error: got %d; want 10", e) + } + if want := makeFloat("0.5"); !alike(x, want) { + t.Fatalf("Float.MantExp aliasing error: got %s; want %s", x.Text('g', 10), want.Text('g', 10)) + } +} + +func TestFloatSetMantExp(t *testing.T) { + for _, test := range []struct { + frac string + exp int + z string + }{ + {"0", 0, "0"}, + {"+0", 0, "0"}, + {"-0", 0, "-0"}, + {"Inf", 1234, "+Inf"}, + {"+Inf", -1234, "+Inf"}, + {"-Inf", -1234, "-Inf"}, + {"0", MinExp, "0"}, + {"0.25", MinExp, "+0"}, // exponent underflow + {"-0.25", MinExp, "-0"}, // exponent underflow + {"1", MaxExp, "+Inf"}, // exponent overflow + {"2", MaxExp - 1, "+Inf"}, // exponent overflow + {"0.75", 1, "1.5"}, + {"0.5", 11, "1024"}, + {"-0.5", -2, "-0.125"}, + {"32", 5, "1024"}, + {"1024", -10, "1"}, + } { + frac := makeFloat(test.frac) + want := makeFloat(test.z) + var z Float + z.SetMantExp(frac, test.exp) + if !alike(&z, want) { + t.Errorf("SetMantExp(%s, %d) = %s; want %s", test.frac, test.exp, z.Text('g', 10), test.z) + } + // test inverse property + mant := new(Float) + if z.SetMantExp(mant, want.MantExp(mant)).Cmp(want) != 0 { + t.Errorf("Inverse property not satisfied: got %s; want %s", z.Text('g', 10), test.z) + } + } +} + +func TestFloatPredicates(t *testing.T) { + for _, test := range []struct { + x string + sign int + signbit, inf bool + }{ + {x: "-Inf", sign: -1, signbit: true, inf: true}, + {x: "-1", sign: -1, signbit: true}, + {x: "-0", signbit: true}, + {x: "0"}, + {x: "1", sign: 1}, + {x: "+Inf", sign: 1, inf: true}, + } { + x := makeFloat(test.x) + if got := x.Signbit(); got != test.signbit { + t.Errorf("(%s).Signbit() = %v; want %v", test.x, got, test.signbit) + } + if got := x.Sign(); got != test.sign { + t.Errorf("(%s).Sign() = %d; want %d", test.x, got, test.sign) + } + if got := x.IsInf(); got != test.inf { + t.Errorf("(%s).IsInf() = %v; want %v", test.x, got, test.inf) + } + } +} + +func TestFloatIsInt(t *testing.T) { + for _, test := range []string{ + "0 int", + "-0 int", + "1 int", + "-1 int", + "0.5", + "1.23", + "1.23e1", + "1.23e2 int", + "0.000000001e+8", + "0.000000001e+9 int", + "1.2345e200 int", + "Inf", + "+Inf", + "-Inf", + } { + s := strings.TrimSuffix(test, " int") + want := s != test + if got := makeFloat(s).IsInt(); got != want { + t.Errorf("%s.IsInt() == %t", s, got) + } + } +} + +func fromBinary(s string) int64 { + x, err := strconv.ParseInt(s, 2, 64) + if err != nil { + panic(err) + } + return x +} + +func toBinary(x int64) string { + return strconv.FormatInt(x, 2) +} + +func testFloatRound(t *testing.T, x, r int64, prec uint, mode RoundingMode) { + // verify test data + var ok bool + switch mode { + case ToNearestEven, ToNearestAway: + ok = true // nothing to do for now + case ToZero: + if x < 0 { + ok = r >= x + } else { + ok = r <= x + } + case AwayFromZero: + if x < 0 { + ok = r <= x + } else { + ok = r >= x + } + case ToNegativeInf: + ok = r <= x + case ToPositiveInf: + ok = r >= x + default: + panic("unreachable") + } + if !ok { + t.Fatalf("incorrect test data for prec = %d, %s: x = %s, r = %s", prec, mode, toBinary(x), toBinary(r)) + } + + // compute expected accuracy + a := Exact + switch { + case r < x: + a = Below + case r > x: + a = Above + } + + // round + f := new(Float).SetMode(mode).SetInt64(x).SetPrec(prec) + + // check result + r1 := f.int64() + p1 := f.Prec() + a1 := f.Acc() + if r1 != r || p1 != prec || a1 != a { + t.Errorf("round %s (%d bits, %s) incorrect: got %s (%d bits, %s); want %s (%d bits, %s)", + toBinary(x), prec, mode, + toBinary(r1), p1, a1, + toBinary(r), prec, a) + return + } + + // g and f should be the same + // (rounding by SetPrec after SetInt64 using default precision + // should be the same as rounding by SetInt64 after setting the + // precision) + g := new(Float).SetMode(mode).SetPrec(prec).SetInt64(x) + if !alike(g, f) { + t.Errorf("round %s (%d bits, %s) not symmetric: got %s and %s; want %s", + toBinary(x), prec, mode, + toBinary(g.int64()), + toBinary(r1), + toBinary(r), + ) + return + } + + // h and f should be the same + // (repeated rounding should be idempotent) + h := new(Float).SetMode(mode).SetPrec(prec).Set(f) + if !alike(h, f) { + t.Errorf("round %s (%d bits, %s) not idempotent: got %s and %s; want %s", + toBinary(x), prec, mode, + toBinary(h.int64()), + toBinary(r1), + toBinary(r), + ) + return + } +} + +// TestFloatRound tests basic rounding. +func TestFloatRound(t *testing.T) { + for _, test := range []struct { + prec uint + x, zero, neven, naway, away string // input, results rounded to prec bits + }{ + {5, "1000", "1000", "1000", "1000", "1000"}, + {5, "1001", "1001", "1001", "1001", "1001"}, + {5, "1010", "1010", "1010", "1010", "1010"}, + {5, "1011", "1011", "1011", "1011", "1011"}, + {5, "1100", "1100", "1100", "1100", "1100"}, + {5, "1101", "1101", "1101", "1101", "1101"}, + {5, "1110", "1110", "1110", "1110", "1110"}, + {5, "1111", "1111", "1111", "1111", "1111"}, + + {4, "1000", "1000", "1000", "1000", "1000"}, + {4, "1001", "1001", "1001", "1001", "1001"}, + {4, "1010", "1010", "1010", "1010", "1010"}, + {4, "1011", "1011", "1011", "1011", "1011"}, + {4, "1100", "1100", "1100", "1100", "1100"}, + {4, "1101", "1101", "1101", "1101", "1101"}, + {4, "1110", "1110", "1110", "1110", "1110"}, + {4, "1111", "1111", "1111", "1111", "1111"}, + + {3, "1000", "1000", "1000", "1000", "1000"}, + {3, "1001", "1000", "1000", "1010", "1010"}, + {3, "1010", "1010", "1010", "1010", "1010"}, + {3, "1011", "1010", "1100", "1100", "1100"}, + {3, "1100", "1100", "1100", "1100", "1100"}, + {3, "1101", "1100", "1100", "1110", "1110"}, + {3, "1110", "1110", "1110", "1110", "1110"}, + {3, "1111", "1110", "10000", "10000", "10000"}, + + {3, "1000001", "1000000", "1000000", "1000000", "1010000"}, + {3, "1001001", "1000000", "1010000", "1010000", "1010000"}, + {3, "1010001", "1010000", "1010000", "1010000", "1100000"}, + {3, "1011001", "1010000", "1100000", "1100000", "1100000"}, + {3, "1100001", "1100000", "1100000", "1100000", "1110000"}, + {3, "1101001", "1100000", "1110000", "1110000", "1110000"}, + {3, "1110001", "1110000", "1110000", "1110000", "10000000"}, + {3, "1111001", "1110000", "10000000", "10000000", "10000000"}, + + {2, "1000", "1000", "1000", "1000", "1000"}, + {2, "1001", "1000", "1000", "1000", "1100"}, + {2, "1010", "1000", "1000", "1100", "1100"}, + {2, "1011", "1000", "1100", "1100", "1100"}, + {2, "1100", "1100", "1100", "1100", "1100"}, + {2, "1101", "1100", "1100", "1100", "10000"}, + {2, "1110", "1100", "10000", "10000", "10000"}, + {2, "1111", "1100", "10000", "10000", "10000"}, + + {2, "1000001", "1000000", "1000000", "1000000", "1100000"}, + {2, "1001001", "1000000", "1000000", "1000000", "1100000"}, + {2, "1010001", "1000000", "1100000", "1100000", "1100000"}, + {2, "1011001", "1000000", "1100000", "1100000", "1100000"}, + {2, "1100001", "1100000", "1100000", "1100000", "10000000"}, + {2, "1101001", "1100000", "1100000", "1100000", "10000000"}, + {2, "1110001", "1100000", "10000000", "10000000", "10000000"}, + {2, "1111001", "1100000", "10000000", "10000000", "10000000"}, + + {1, "1000", "1000", "1000", "1000", "1000"}, + {1, "1001", "1000", "1000", "1000", "10000"}, + {1, "1010", "1000", "1000", "1000", "10000"}, + {1, "1011", "1000", "1000", "1000", "10000"}, + {1, "1100", "1000", "10000", "10000", "10000"}, + {1, "1101", "1000", "10000", "10000", "10000"}, + {1, "1110", "1000", "10000", "10000", "10000"}, + {1, "1111", "1000", "10000", "10000", "10000"}, + + {1, "1000001", "1000000", "1000000", "1000000", "10000000"}, + {1, "1001001", "1000000", "1000000", "1000000", "10000000"}, + {1, "1010001", "1000000", "1000000", "1000000", "10000000"}, + {1, "1011001", "1000000", "1000000", "1000000", "10000000"}, + {1, "1100001", "1000000", "10000000", "10000000", "10000000"}, + {1, "1101001", "1000000", "10000000", "10000000", "10000000"}, + {1, "1110001", "1000000", "10000000", "10000000", "10000000"}, + {1, "1111001", "1000000", "10000000", "10000000", "10000000"}, + } { + x := fromBinary(test.x) + z := fromBinary(test.zero) + e := fromBinary(test.neven) + n := fromBinary(test.naway) + a := fromBinary(test.away) + prec := test.prec + + testFloatRound(t, x, z, prec, ToZero) + testFloatRound(t, x, e, prec, ToNearestEven) + testFloatRound(t, x, n, prec, ToNearestAway) + testFloatRound(t, x, a, prec, AwayFromZero) + + testFloatRound(t, x, z, prec, ToNegativeInf) + testFloatRound(t, x, a, prec, ToPositiveInf) + + testFloatRound(t, -x, -a, prec, ToNegativeInf) + testFloatRound(t, -x, -z, prec, ToPositiveInf) + } +} + +// TestFloatRound24 tests that rounding a float64 to 24 bits +// matches IEEE-754 rounding to nearest when converting a +// float64 to a float32 (excluding denormal numbers). +func TestFloatRound24(t *testing.T) { + const x0 = 1<<26 - 0x10 // 11...110000 (26 bits) + for d := 0; d <= 0x10; d++ { + x := float64(x0 + d) + f := new(Float).SetPrec(24).SetFloat64(x) + got, _ := f.Float32() + want := float32(x) + if got != want { + t.Errorf("Round(%g, 24) = %g; want %g", x, got, want) + } + } +} + +func TestFloatSetUint64(t *testing.T) { + for _, want := range []uint64{ + 0, + 1, + 2, + 10, + 100, + 1<<32 - 1, + 1 << 32, + 1<<64 - 1, + } { + var f Float + f.SetUint64(want) + if got := f.uint64(); got != want { + t.Errorf("got %#x (%s); want %#x", got, f.Text('p', 0), want) + } + } + + // test basic rounding behavior (exhaustive rounding testing is done elsewhere) + const x uint64 = 0x8765432187654321 // 64 bits needed + for prec := uint(1); prec <= 64; prec++ { + f := new(Float).SetPrec(prec).SetMode(ToZero).SetUint64(x) + got := f.uint64() + want := x &^ (1<<(64-prec) - 1) // cut off (round to zero) low 64-prec bits + if got != want { + t.Errorf("got %#x (%s); want %#x", got, f.Text('p', 0), want) + } + } +} + +func TestFloatSetInt64(t *testing.T) { + for _, want := range []int64{ + 0, + 1, + 2, + 10, + 100, + 1<<32 - 1, + 1 << 32, + 1<<63 - 1, + } { + for i := range [2]int{} { + if i&1 != 0 { + want = -want + } + var f Float + f.SetInt64(want) + if got := f.int64(); got != want { + t.Errorf("got %#x (%s); want %#x", got, f.Text('p', 0), want) + } + } + } + + // test basic rounding behavior (exhaustive rounding testing is done elsewhere) + const x int64 = 0x7654321076543210 // 63 bits needed + for prec := uint(1); prec <= 63; prec++ { + f := new(Float).SetPrec(prec).SetMode(ToZero).SetInt64(x) + got := f.int64() + want := x &^ (1<<(63-prec) - 1) // cut off (round to zero) low 63-prec bits + if got != want { + t.Errorf("got %#x (%s); want %#x", got, f.Text('p', 0), want) + } + } +} + +func TestFloatSetFloat64(t *testing.T) { + for _, want := range []float64{ + 0, + 1, + 2, + 12345, + 1e10, + 1e100, + 3.14159265e10, + 2.718281828e-123, + 1.0 / 3, + math.MaxFloat32, + math.MaxFloat64, + math.SmallestNonzeroFloat32, + math.SmallestNonzeroFloat64, + math.Inf(-1), + math.Inf(0), + -math.Inf(1), + } { + for i := range [2]int{} { + if i&1 != 0 { + want = -want + } + var f Float + f.SetFloat64(want) + if got, acc := f.Float64(); got != want || acc != Exact { + t.Errorf("got %g (%s, %s); want %g (Exact)", got, f.Text('p', 0), acc, want) + } + } + } + + // test basic rounding behavior (exhaustive rounding testing is done elsewhere) + const x uint64 = 0x8765432143218 // 53 bits needed + for prec := uint(1); prec <= 52; prec++ { + f := new(Float).SetPrec(prec).SetMode(ToZero).SetFloat64(float64(x)) + got, _ := f.Float64() + want := float64(x &^ (1<<(52-prec) - 1)) // cut off (round to zero) low 53-prec bits + if got != want { + t.Errorf("got %g (%s); want %g", got, f.Text('p', 0), want) + } + } + + // test NaN + defer func() { + if p, ok := recover().(ErrNaN); !ok { + t.Errorf("got %v; want ErrNaN panic", p) + } + }() + var f Float + f.SetFloat64(math.NaN()) + // should not reach here + t.Errorf("got %s; want ErrNaN panic", f.Text('p', 0)) +} + +func TestFloatSetInt(t *testing.T) { + for _, want := range []string{ + "0", + "1", + "-1", + "1234567890", + "123456789012345678901234567890", + "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890", + } { + var x Int + _, ok := x.SetString(want, 0) + if !ok { + t.Errorf("invalid integer %s", want) + continue + } + n := x.BitLen() + + var f Float + f.SetInt(&x) + + // check precision + if n < 64 { + n = 64 + } + if prec := f.Prec(); prec != uint(n) { + t.Errorf("got prec = %d; want %d", prec, n) + } + + // check value + got := f.Text('g', 100) + if got != want { + t.Errorf("got %s (%s); want %s", got, f.Text('p', 0), want) + } + } + + // TODO(gri) test basic rounding behavior +} + +func TestFloatSetRat(t *testing.T) { + for _, want := range []string{ + "0", + "1", + "-1", + "1234567890", + "123456789012345678901234567890", + "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890", + "1.2", + "3.14159265", + // TODO(gri) expand + } { + var x Rat + _, ok := x.SetString(want) + if !ok { + t.Errorf("invalid fraction %s", want) + continue + } + n := max(x.Num().BitLen(), x.Denom().BitLen()) + + var f1, f2 Float + f2.SetPrec(1000) + f1.SetRat(&x) + f2.SetRat(&x) + + // check precision when set automatically + if n < 64 { + n = 64 + } + if prec := f1.Prec(); prec != uint(n) { + t.Errorf("got prec = %d; want %d", prec, n) + } + + got := f2.Text('g', 100) + if got != want { + t.Errorf("got %s (%s); want %s", got, f2.Text('p', 0), want) + } + } +} + +func TestFloatSetInf(t *testing.T) { + var f Float + for _, test := range []struct { + signbit bool + prec uint + want string + }{ + {false, 0, "+Inf"}, + {true, 0, "-Inf"}, + {false, 10, "+Inf"}, + {true, 30, "-Inf"}, + } { + x := f.SetPrec(test.prec).SetInf(test.signbit) + if got := x.String(); got != test.want || x.Prec() != test.prec { + t.Errorf("SetInf(%v) = %s (prec = %d); want %s (prec = %d)", test.signbit, got, x.Prec(), test.want, test.prec) + } + } +} + +func TestFloatUint64(t *testing.T) { + for _, test := range []struct { + x string + out uint64 + acc Accuracy + }{ + {"-Inf", 0, Above}, + {"-1", 0, Above}, + {"-1e-1000", 0, Above}, + {"-0", 0, Exact}, + {"0", 0, Exact}, + {"1e-1000", 0, Below}, + {"1", 1, Exact}, + {"1.000000000000000000001", 1, Below}, + {"12345.0", 12345, Exact}, + {"12345.000000000000000000001", 12345, Below}, + {"18446744073709551615", 18446744073709551615, Exact}, + {"18446744073709551615.000000000000000000001", math.MaxUint64, Below}, + {"18446744073709551616", math.MaxUint64, Below}, + {"1e10000", math.MaxUint64, Below}, + {"+Inf", math.MaxUint64, Below}, + } { + x := makeFloat(test.x) + out, acc := x.Uint64() + if out != test.out || acc != test.acc { + t.Errorf("%s: got %d (%s); want %d (%s)", test.x, out, acc, test.out, test.acc) + } + } +} + +func TestFloatInt64(t *testing.T) { + for _, test := range []struct { + x string + out int64 + acc Accuracy + }{ + {"-Inf", math.MinInt64, Above}, + {"-1e10000", math.MinInt64, Above}, + {"-9223372036854775809", math.MinInt64, Above}, + {"-9223372036854775808.000000000000000000001", math.MinInt64, Above}, + {"-9223372036854775808", -9223372036854775808, Exact}, + {"-9223372036854775807.000000000000000000001", -9223372036854775807, Above}, + {"-9223372036854775807", -9223372036854775807, Exact}, + {"-12345.000000000000000000001", -12345, Above}, + {"-12345.0", -12345, Exact}, + {"-1.000000000000000000001", -1, Above}, + {"-1.5", -1, Above}, + {"-1", -1, Exact}, + {"-1e-1000", 0, Above}, + {"0", 0, Exact}, + {"1e-1000", 0, Below}, + {"1", 1, Exact}, + {"1.000000000000000000001", 1, Below}, + {"1.5", 1, Below}, + {"12345.0", 12345, Exact}, + {"12345.000000000000000000001", 12345, Below}, + {"9223372036854775807", 9223372036854775807, Exact}, + {"9223372036854775807.000000000000000000001", math.MaxInt64, Below}, + {"9223372036854775808", math.MaxInt64, Below}, + {"1e10000", math.MaxInt64, Below}, + {"+Inf", math.MaxInt64, Below}, + } { + x := makeFloat(test.x) + out, acc := x.Int64() + if out != test.out || acc != test.acc { + t.Errorf("%s: got %d (%s); want %d (%s)", test.x, out, acc, test.out, test.acc) + } + } +} + +func TestFloatFloat32(t *testing.T) { + for _, test := range []struct { + x string + out float32 + acc Accuracy + }{ + {"0", 0, Exact}, + + // underflow to zero + {"1e-1000", 0, Below}, + {"0x0.000002p-127", 0, Below}, + {"0x.0000010p-126", 0, Below}, + + // denormals + {"1.401298464e-45", math.SmallestNonzeroFloat32, Above}, // rounded up to smallest denormal + {"0x.ffffff8p-149", math.SmallestNonzeroFloat32, Above}, // rounded up to smallest denormal + {"0x.0000018p-126", math.SmallestNonzeroFloat32, Above}, // rounded up to smallest denormal + {"0x.0000020p-126", math.SmallestNonzeroFloat32, Exact}, + {"0x.8p-148", math.SmallestNonzeroFloat32, Exact}, + {"1p-149", math.SmallestNonzeroFloat32, Exact}, + {"0x.fffffep-126", math.Float32frombits(0x7fffff), Exact}, // largest denormal + + // special denormal cases (see issues 14553, 14651) + {"0x0.0000001p-126", math.Float32frombits(0x00000000), Below}, // underflow to zero + {"0x0.0000008p-126", math.Float32frombits(0x00000000), Below}, // underflow to zero + {"0x0.0000010p-126", math.Float32frombits(0x00000000), Below}, // rounded down to even + {"0x0.0000011p-126", math.Float32frombits(0x00000001), Above}, // rounded up to smallest denormal + {"0x0.0000018p-126", math.Float32frombits(0x00000001), Above}, // rounded up to smallest denormal + + {"0x1.0000000p-149", math.Float32frombits(0x00000001), Exact}, // smallest denormal + {"0x0.0000020p-126", math.Float32frombits(0x00000001), Exact}, // smallest denormal + {"0x0.fffffe0p-126", math.Float32frombits(0x007fffff), Exact}, // largest denormal + {"0x1.0000000p-126", math.Float32frombits(0x00800000), Exact}, // smallest normal + + {"0x0.8p-149", math.Float32frombits(0x000000000), Below}, // rounded down to even + {"0x0.9p-149", math.Float32frombits(0x000000001), Above}, // rounded up to smallest denormal + {"0x0.ap-149", math.Float32frombits(0x000000001), Above}, // rounded up to smallest denormal + {"0x0.bp-149", math.Float32frombits(0x000000001), Above}, // rounded up to smallest denormal + {"0x0.cp-149", math.Float32frombits(0x000000001), Above}, // rounded up to smallest denormal + + {"0x1.0p-149", math.Float32frombits(0x000000001), Exact}, // smallest denormal + {"0x1.7p-149", math.Float32frombits(0x000000001), Below}, + {"0x1.8p-149", math.Float32frombits(0x000000002), Above}, + {"0x1.9p-149", math.Float32frombits(0x000000002), Above}, + + {"0x2.0p-149", math.Float32frombits(0x000000002), Exact}, + {"0x2.8p-149", math.Float32frombits(0x000000002), Below}, // rounded down to even + {"0x2.9p-149", math.Float32frombits(0x000000003), Above}, + + {"0x3.0p-149", math.Float32frombits(0x000000003), Exact}, + {"0x3.7p-149", math.Float32frombits(0x000000003), Below}, + {"0x3.8p-149", math.Float32frombits(0x000000004), Above}, // rounded up to even + + {"0x4.0p-149", math.Float32frombits(0x000000004), Exact}, + {"0x4.8p-149", math.Float32frombits(0x000000004), Below}, // rounded down to even + {"0x4.9p-149", math.Float32frombits(0x000000005), Above}, + + // specific case from issue 14553 + {"0x7.7p-149", math.Float32frombits(0x000000007), Below}, + {"0x7.8p-149", math.Float32frombits(0x000000008), Above}, + {"0x7.9p-149", math.Float32frombits(0x000000008), Above}, + + // normals + {"0x.ffffffp-126", math.Float32frombits(0x00800000), Above}, // rounded up to smallest normal + {"1p-126", math.Float32frombits(0x00800000), Exact}, // smallest normal + {"0x1.fffffep-126", math.Float32frombits(0x00ffffff), Exact}, + {"0x1.ffffffp-126", math.Float32frombits(0x01000000), Above}, // rounded up + {"1", 1, Exact}, + {"1.000000000000000000001", 1, Below}, + {"12345.0", 12345, Exact}, + {"12345.000000000000000000001", 12345, Below}, + {"0x1.fffffe0p127", math.MaxFloat32, Exact}, + {"0x1.fffffe8p127", math.MaxFloat32, Below}, + + // overflow + {"0x1.ffffff0p127", float32(math.Inf(+1)), Above}, + {"0x1p128", float32(math.Inf(+1)), Above}, + {"1e10000", float32(math.Inf(+1)), Above}, + {"0x1.ffffff0p2147483646", float32(math.Inf(+1)), Above}, // overflow in rounding + + // inf + {"Inf", float32(math.Inf(+1)), Exact}, + } { + for i := 0; i < 2; i++ { + // test both signs + tx, tout, tacc := test.x, test.out, test.acc + if i != 0 { + tx = "-" + tx + tout = -tout + tacc = -tacc + } + + // conversion should match strconv where syntax is agreeable + if f, err := strconv.ParseFloat(tx, 32); err == nil && !alike32(float32(f), tout) { + t.Errorf("%s: got %g; want %g (incorrect test data)", tx, f, tout) + } + + x := makeFloat(tx) + out, acc := x.Float32() + if !alike32(out, tout) || acc != tacc { + t.Errorf("%s: got %g (%#08x, %s); want %g (%#08x, %s)", tx, out, math.Float32bits(out), acc, test.out, math.Float32bits(test.out), tacc) + } + + // test that x.SetFloat64(float64(f)).Float32() == f + var x2 Float + out2, acc2 := x2.SetFloat64(float64(out)).Float32() + if !alike32(out2, out) || acc2 != Exact { + t.Errorf("idempotency test: got %g (%s); want %g (Exact)", out2, acc2, out) + } + } + } +} + +func TestFloatFloat64(t *testing.T) { + const smallestNormalFloat64 = 2.2250738585072014e-308 // 1p-1022 + for _, test := range []struct { + x string + out float64 + acc Accuracy + }{ + {"0", 0, Exact}, + + // underflow to zero + {"1e-1000", 0, Below}, + {"0x0.0000000000001p-1023", 0, Below}, + {"0x0.00000000000008p-1022", 0, Below}, + + // denormals + {"0x0.0000000000000cp-1022", math.SmallestNonzeroFloat64, Above}, // rounded up to smallest denormal + {"0x0.00000000000010p-1022", math.SmallestNonzeroFloat64, Exact}, // smallest denormal + {"0x.8p-1073", math.SmallestNonzeroFloat64, Exact}, + {"1p-1074", math.SmallestNonzeroFloat64, Exact}, + {"0x.fffffffffffffp-1022", math.Float64frombits(0x000fffffffffffff), Exact}, // largest denormal + + // special denormal cases (see issues 14553, 14651) + {"0x0.00000000000001p-1022", math.Float64frombits(0x00000000000000000), Below}, // underflow to zero + {"0x0.00000000000004p-1022", math.Float64frombits(0x00000000000000000), Below}, // underflow to zero + {"0x0.00000000000008p-1022", math.Float64frombits(0x00000000000000000), Below}, // rounded down to even + {"0x0.00000000000009p-1022", math.Float64frombits(0x00000000000000001), Above}, // rounded up to smallest denormal + {"0x0.0000000000000ap-1022", math.Float64frombits(0x00000000000000001), Above}, // rounded up to smallest denormal + + {"0x0.8p-1074", math.Float64frombits(0x00000000000000000), Below}, // rounded down to even + {"0x0.9p-1074", math.Float64frombits(0x00000000000000001), Above}, // rounded up to smallest denormal + {"0x0.ap-1074", math.Float64frombits(0x00000000000000001), Above}, // rounded up to smallest denormal + {"0x0.bp-1074", math.Float64frombits(0x00000000000000001), Above}, // rounded up to smallest denormal + {"0x0.cp-1074", math.Float64frombits(0x00000000000000001), Above}, // rounded up to smallest denormal + + {"0x1.0p-1074", math.Float64frombits(0x00000000000000001), Exact}, + {"0x1.7p-1074", math.Float64frombits(0x00000000000000001), Below}, + {"0x1.8p-1074", math.Float64frombits(0x00000000000000002), Above}, + {"0x1.9p-1074", math.Float64frombits(0x00000000000000002), Above}, + + {"0x2.0p-1074", math.Float64frombits(0x00000000000000002), Exact}, + {"0x2.8p-1074", math.Float64frombits(0x00000000000000002), Below}, // rounded down to even + {"0x2.9p-1074", math.Float64frombits(0x00000000000000003), Above}, + + {"0x3.0p-1074", math.Float64frombits(0x00000000000000003), Exact}, + {"0x3.7p-1074", math.Float64frombits(0x00000000000000003), Below}, + {"0x3.8p-1074", math.Float64frombits(0x00000000000000004), Above}, // rounded up to even + + {"0x4.0p-1074", math.Float64frombits(0x00000000000000004), Exact}, + {"0x4.8p-1074", math.Float64frombits(0x00000000000000004), Below}, // rounded down to even + {"0x4.9p-1074", math.Float64frombits(0x00000000000000005), Above}, + + // normals + {"0x.fffffffffffff8p-1022", math.Float64frombits(0x0010000000000000), Above}, // rounded up to smallest normal + {"1p-1022", math.Float64frombits(0x0010000000000000), Exact}, // smallest normal + {"1", 1, Exact}, + {"1.000000000000000000001", 1, Below}, + {"12345.0", 12345, Exact}, + {"12345.000000000000000000001", 12345, Below}, + {"0x1.fffffffffffff0p1023", math.MaxFloat64, Exact}, + {"0x1.fffffffffffff4p1023", math.MaxFloat64, Below}, + + // overflow + {"0x1.fffffffffffff8p1023", math.Inf(+1), Above}, + {"0x1p1024", math.Inf(+1), Above}, + {"1e10000", math.Inf(+1), Above}, + {"0x1.fffffffffffff8p2147483646", math.Inf(+1), Above}, // overflow in rounding + {"Inf", math.Inf(+1), Exact}, + + // selected denormalized values that were handled incorrectly in the past + {"0x.fffffffffffffp-1022", smallestNormalFloat64 - math.SmallestNonzeroFloat64, Exact}, + {"4503599627370495p-1074", smallestNormalFloat64 - math.SmallestNonzeroFloat64, Exact}, + + // https://www.exploringbinary.com/php-hangs-on-numeric-value-2-2250738585072011e-308/ + {"2.2250738585072011e-308", 2.225073858507201e-308, Below}, + // https://www.exploringbinary.com/java-hangs-when-converting-2-2250738585072012e-308/ + {"2.2250738585072012e-308", 2.2250738585072014e-308, Above}, + } { + for i := 0; i < 2; i++ { + // test both signs + tx, tout, tacc := test.x, test.out, test.acc + if i != 0 { + tx = "-" + tx + tout = -tout + tacc = -tacc + } + + // conversion should match strconv where syntax is agreeable + if f, err := strconv.ParseFloat(tx, 64); err == nil && !alike64(f, tout) { + t.Errorf("%s: got %g; want %g (incorrect test data)", tx, f, tout) + } + + x := makeFloat(tx) + out, acc := x.Float64() + if !alike64(out, tout) || acc != tacc { + t.Errorf("%s: got %g (%#016x, %s); want %g (%#016x, %s)", tx, out, math.Float64bits(out), acc, test.out, math.Float64bits(test.out), tacc) + } + + // test that x.SetFloat64(f).Float64() == f + var x2 Float + out2, acc2 := x2.SetFloat64(out).Float64() + if !alike64(out2, out) || acc2 != Exact { + t.Errorf("idempotency test: got %g (%s); want %g (Exact)", out2, acc2, out) + } + } + } +} + +func TestFloatInt(t *testing.T) { + for _, test := range []struct { + x string + want string + acc Accuracy + }{ + {"0", "0", Exact}, + {"+0", "0", Exact}, + {"-0", "0", Exact}, + {"Inf", "nil", Below}, + {"+Inf", "nil", Below}, + {"-Inf", "nil", Above}, + {"1", "1", Exact}, + {"-1", "-1", Exact}, + {"1.23", "1", Below}, + {"-1.23", "-1", Above}, + {"123e-2", "1", Below}, + {"123e-3", "0", Below}, + {"123e-4", "0", Below}, + {"1e-1000", "0", Below}, + {"-1e-1000", "0", Above}, + {"1e+10", "10000000000", Exact}, + {"1e+100", "10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", Exact}, + } { + x := makeFloat(test.x) + res, acc := x.Int(nil) + got := "nil" + if res != nil { + got = res.String() + } + if got != test.want || acc != test.acc { + t.Errorf("%s: got %s (%s); want %s (%s)", test.x, got, acc, test.want, test.acc) + } + } + + // check that supplied *Int is used + for _, f := range []string{"0", "1", "-1", "1234"} { + x := makeFloat(f) + i := new(Int) + if res, _ := x.Int(i); res != i { + t.Errorf("(%s).Int is not using supplied *Int", f) + } + } +} + +func TestFloatRat(t *testing.T) { + for _, test := range []struct { + x, want string + acc Accuracy + }{ + {"0", "0/1", Exact}, + {"+0", "0/1", Exact}, + {"-0", "0/1", Exact}, + {"Inf", "nil", Below}, + {"+Inf", "nil", Below}, + {"-Inf", "nil", Above}, + {"1", "1/1", Exact}, + {"-1", "-1/1", Exact}, + {"1.25", "5/4", Exact}, + {"-1.25", "-5/4", Exact}, + {"1e10", "10000000000/1", Exact}, + {"1p10", "1024/1", Exact}, + {"-1p-10", "-1/1024", Exact}, + {"3.14159265", "7244019449799623199/2305843009213693952", Exact}, + } { + x := makeFloat(test.x).SetPrec(64) + res, acc := x.Rat(nil) + got := "nil" + if res != nil { + got = res.String() + } + if got != test.want { + t.Errorf("%s: got %s; want %s", test.x, got, test.want) + continue + } + if acc != test.acc { + t.Errorf("%s: got %s; want %s", test.x, acc, test.acc) + continue + } + + // inverse conversion + if res != nil { + got := new(Float).SetPrec(64).SetRat(res) + if got.Cmp(x) != 0 { + t.Errorf("%s: got %s; want %s", test.x, got, x) + } + } + } + + // check that supplied *Rat is used + for _, f := range []string{"0", "1", "-1", "1234"} { + x := makeFloat(f) + r := new(Rat) + if res, _ := x.Rat(r); res != r { + t.Errorf("(%s).Rat is not using supplied *Rat", f) + } + } +} + +func TestFloatAbs(t *testing.T) { + for _, test := range []string{ + "0", + "1", + "1234", + "1.23e-2", + "1e-1000", + "1e1000", + "Inf", + } { + p := makeFloat(test) + a := new(Float).Abs(p) + if !alike(a, p) { + t.Errorf("%s: got %s; want %s", test, a.Text('g', 10), test) + } + + n := makeFloat("-" + test) + a.Abs(n) + if !alike(a, p) { + t.Errorf("-%s: got %s; want %s", test, a.Text('g', 10), test) + } + } +} + +func TestFloatNeg(t *testing.T) { + for _, test := range []string{ + "0", + "1", + "1234", + "1.23e-2", + "1e-1000", + "1e1000", + "Inf", + } { + p1 := makeFloat(test) + n1 := makeFloat("-" + test) + n2 := new(Float).Neg(p1) + p2 := new(Float).Neg(n2) + if !alike(n2, n1) { + t.Errorf("%s: got %s; want %s", test, n2.Text('g', 10), n1.Text('g', 10)) + } + if !alike(p2, p1) { + t.Errorf("%s: got %s; want %s", test, p2.Text('g', 10), p1.Text('g', 10)) + } + } +} + +func TestFloatInc(t *testing.T) { + const n = 10 + for _, prec := range precList { + if 1<<prec < n { + continue // prec must be large enough to hold all numbers from 0 to n + } + var x, one Float + x.SetPrec(prec) + one.SetInt64(1) + for i := 0; i < n; i++ { + x.Add(&x, &one) + } + if x.Cmp(new(Float).SetInt64(n)) != 0 { + t.Errorf("prec = %d: got %s; want %d", prec, &x, n) + } + } +} + +// Selected precisions with which to run various tests. +var precList = [...]uint{1, 2, 5, 8, 10, 16, 23, 24, 32, 50, 53, 64, 100, 128, 500, 511, 512, 513, 1000, 10000} + +// Selected bits with which to run various tests. +// Each entry is a list of bits representing a floating-point number (see fromBits). +var bitsList = [...]Bits{ + {}, // = 0 + {0}, // = 1 + {1}, // = 2 + {-1}, // = 1/2 + {10}, // = 2**10 == 1024 + {-10}, // = 2**-10 == 1/1024 + {100, 10, 1}, // = 2**100 + 2**10 + 2**1 + {0, -1, -2, -10}, + // TODO(gri) add more test cases +} + +// TestFloatAdd tests Float.Add/Sub by comparing the result of a "manual" +// addition/subtraction of arguments represented by Bits values with the +// respective Float addition/subtraction for a variety of precisions +// and rounding modes. +func TestFloatAdd(t *testing.T) { + for _, xbits := range bitsList { + for _, ybits := range bitsList { + // exact values + x := xbits.Float() + y := ybits.Float() + zbits := xbits.add(ybits) + z := zbits.Float() + + for i, mode := range [...]RoundingMode{ToZero, ToNearestEven, AwayFromZero} { + for _, prec := range precList { + got := new(Float).SetPrec(prec).SetMode(mode) + got.Add(x, y) + want := zbits.round(prec, mode) + if got.Cmp(want) != 0 { + t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t+ %s %v\n\t= %s\n\twant %s", + i, prec, mode, x, xbits, y, ybits, got, want) + } + + got.Sub(z, x) + want = ybits.round(prec, mode) + if got.Cmp(want) != 0 { + t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t- %s %v\n\t= %s\n\twant %s", + i, prec, mode, z, zbits, x, xbits, got, want) + } + } + } + } + } +} + +// TestFloatAddRoundZero tests Float.Add/Sub rounding when the result is exactly zero. +// x + (-x) or x - x for non-zero x should be +0 in all cases except when +// the rounding mode is ToNegativeInf in which case it should be -0. +func TestFloatAddRoundZero(t *testing.T) { + for _, mode := range [...]RoundingMode{ToNearestEven, ToNearestAway, ToZero, AwayFromZero, ToPositiveInf, ToNegativeInf} { + x := NewFloat(5.0) + y := new(Float).Neg(x) + want := NewFloat(0.0) + if mode == ToNegativeInf { + want.Neg(want) + } + got := new(Float).SetMode(mode) + got.Add(x, y) + if got.Cmp(want) != 0 || got.neg != (mode == ToNegativeInf) { + t.Errorf("%s:\n\t %v\n\t+ %v\n\t= %v\n\twant %v", + mode, x, y, got, want) + } + got.Sub(x, x) + if got.Cmp(want) != 0 || got.neg != (mode == ToNegativeInf) { + t.Errorf("%v:\n\t %v\n\t- %v\n\t= %v\n\twant %v", + mode, x, x, got, want) + } + } +} + +// TestFloatAdd32 tests that Float.Add/Sub of numbers with +// 24bit mantissa behaves like float32 addition/subtraction +// (excluding denormal numbers). +func TestFloatAdd32(t *testing.T) { + // chose base such that we cross the mantissa precision limit + const base = 1<<26 - 0x10 // 11...110000 (26 bits) + for d := 0; d <= 0x10; d++ { + for i := range [2]int{} { + x0, y0 := float64(base), float64(d) + if i&1 != 0 { + x0, y0 = y0, x0 + } + + x := NewFloat(x0) + y := NewFloat(y0) + z := new(Float).SetPrec(24) + + z.Add(x, y) + got, acc := z.Float32() + want := float32(y0) + float32(x0) + if got != want || acc != Exact { + t.Errorf("d = %d: %g + %g = %g (%s); want %g (Exact)", d, x0, y0, got, acc, want) + } + + z.Sub(z, y) + got, acc = z.Float32() + want = float32(want) - float32(y0) + if got != want || acc != Exact { + t.Errorf("d = %d: %g - %g = %g (%s); want %g (Exact)", d, x0+y0, y0, got, acc, want) + } + } + } +} + +// TestFloatAdd64 tests that Float.Add/Sub of numbers with +// 53bit mantissa behaves like float64 addition/subtraction. +func TestFloatAdd64(t *testing.T) { + // chose base such that we cross the mantissa precision limit + const base = 1<<55 - 0x10 // 11...110000 (55 bits) + for d := 0; d <= 0x10; d++ { + for i := range [2]int{} { + x0, y0 := float64(base), float64(d) + if i&1 != 0 { + x0, y0 = y0, x0 + } + + x := NewFloat(x0) + y := NewFloat(y0) + z := new(Float).SetPrec(53) + + z.Add(x, y) + got, acc := z.Float64() + want := x0 + y0 + if got != want || acc != Exact { + t.Errorf("d = %d: %g + %g = %g (%s); want %g (Exact)", d, x0, y0, got, acc, want) + } + + z.Sub(z, y) + got, acc = z.Float64() + want -= y0 + if got != want || acc != Exact { + t.Errorf("d = %d: %g - %g = %g (%s); want %g (Exact)", d, x0+y0, y0, got, acc, want) + } + } + } +} + +func TestIssue20490(t *testing.T) { + var tests = []struct { + a, b float64 + }{ + {4, 1}, + {-4, 1}, + {4, -1}, + {-4, -1}, + } + + for _, test := range tests { + a, b := NewFloat(test.a), NewFloat(test.b) + diff := new(Float).Sub(a, b) + b.Sub(a, b) + if b.Cmp(diff) != 0 { + t.Errorf("got %g - %g = %g; want %g\n", a, NewFloat(test.b), b, diff) + } + + b = NewFloat(test.b) + sum := new(Float).Add(a, b) + b.Add(a, b) + if b.Cmp(sum) != 0 { + t.Errorf("got %g + %g = %g; want %g\n", a, NewFloat(test.b), b, sum) + } + + } +} + +// TestFloatMul tests Float.Mul/Quo by comparing the result of a "manual" +// multiplication/division of arguments represented by Bits values with the +// respective Float multiplication/division for a variety of precisions +// and rounding modes. +func TestFloatMul(t *testing.T) { + for _, xbits := range bitsList { + for _, ybits := range bitsList { + // exact values + x := xbits.Float() + y := ybits.Float() + zbits := xbits.mul(ybits) + z := zbits.Float() + + for i, mode := range [...]RoundingMode{ToZero, ToNearestEven, AwayFromZero} { + for _, prec := range precList { + got := new(Float).SetPrec(prec).SetMode(mode) + got.Mul(x, y) + want := zbits.round(prec, mode) + if got.Cmp(want) != 0 { + t.Errorf("i = %d, prec = %d, %s:\n\t %v %v\n\t* %v %v\n\t= %v\n\twant %v", + i, prec, mode, x, xbits, y, ybits, got, want) + } + + if x.Sign() == 0 { + continue // ignore div-0 case (not invertable) + } + got.Quo(z, x) + want = ybits.round(prec, mode) + if got.Cmp(want) != 0 { + t.Errorf("i = %d, prec = %d, %s:\n\t %v %v\n\t/ %v %v\n\t= %v\n\twant %v", + i, prec, mode, z, zbits, x, xbits, got, want) + } + } + } + } + } +} + +// TestFloatMul64 tests that Float.Mul/Quo of numbers with +// 53bit mantissa behaves like float64 multiplication/division. +func TestFloatMul64(t *testing.T) { + for _, test := range []struct { + x, y float64 + }{ + {0, 0}, + {0, 1}, + {1, 1}, + {1, 1.5}, + {1.234, 0.5678}, + {2.718281828, 3.14159265358979}, + {2.718281828e10, 3.14159265358979e-32}, + {1.0 / 3, 1e200}, + } { + for i := range [8]int{} { + x0, y0 := test.x, test.y + if i&1 != 0 { + x0 = -x0 + } + if i&2 != 0 { + y0 = -y0 + } + if i&4 != 0 { + x0, y0 = y0, x0 + } + + x := NewFloat(x0) + y := NewFloat(y0) + z := new(Float).SetPrec(53) + + z.Mul(x, y) + got, _ := z.Float64() + want := x0 * y0 + if got != want { + t.Errorf("%g * %g = %g; want %g", x0, y0, got, want) + } + + if y0 == 0 { + continue // avoid division-by-zero + } + z.Quo(z, y) + got, _ = z.Float64() + want /= y0 + if got != want { + t.Errorf("%g / %g = %g; want %g", x0*y0, y0, got, want) + } + } + } +} + +func TestIssue6866(t *testing.T) { + for _, prec := range precList { + two := new(Float).SetPrec(prec).SetInt64(2) + one := new(Float).SetPrec(prec).SetInt64(1) + three := new(Float).SetPrec(prec).SetInt64(3) + msix := new(Float).SetPrec(prec).SetInt64(-6) + psix := new(Float).SetPrec(prec).SetInt64(+6) + + p := new(Float).SetPrec(prec) + z1 := new(Float).SetPrec(prec) + z2 := new(Float).SetPrec(prec) + + // z1 = 2 + 1.0/3*-6 + p.Quo(one, three) + p.Mul(p, msix) + z1.Add(two, p) + + // z2 = 2 - 1.0/3*+6 + p.Quo(one, three) + p.Mul(p, psix) + z2.Sub(two, p) + + if z1.Cmp(z2) != 0 { + t.Fatalf("prec %d: got z1 = %v != z2 = %v; want z1 == z2\n", prec, z1, z2) + } + if z1.Sign() != 0 { + t.Errorf("prec %d: got z1 = %v; want 0", prec, z1) + } + if z2.Sign() != 0 { + t.Errorf("prec %d: got z2 = %v; want 0", prec, z2) + } + } +} + +func TestFloatQuo(t *testing.T) { + // TODO(gri) make the test vary these precisions + preci := 200 // precision of integer part + precf := 20 // precision of fractional part + + for i := 0; i < 8; i++ { + // compute accurate (not rounded) result z + bits := Bits{preci - 1} + if i&3 != 0 { + bits = append(bits, 0) + } + if i&2 != 0 { + bits = append(bits, -1) + } + if i&1 != 0 { + bits = append(bits, -precf) + } + z := bits.Float() + + // compute accurate x as z*y + y := NewFloat(3.14159265358979323e123) + + x := new(Float).SetPrec(z.Prec() + y.Prec()).SetMode(ToZero) + x.Mul(z, y) + + // leave for debugging + // fmt.Printf("x = %s\ny = %s\nz = %s\n", x, y, z) + + if got := x.Acc(); got != Exact { + t.Errorf("got acc = %s; want exact", got) + } + + // round accurate z for a variety of precisions and + // modes and compare against result of x / y. + for _, mode := range [...]RoundingMode{ToZero, ToNearestEven, AwayFromZero} { + for d := -5; d < 5; d++ { + prec := uint(preci + d) + got := new(Float).SetPrec(prec).SetMode(mode).Quo(x, y) + want := bits.round(prec, mode) + if got.Cmp(want) != 0 { + t.Errorf("i = %d, prec = %d, %s:\n\t %s\n\t/ %s\n\t= %s\n\twant %s", + i, prec, mode, x, y, got, want) + } + } + } + } +} + +var long = flag.Bool("long", false, "run very long tests") + +// TestFloatQuoSmoke tests all divisions x/y for values x, y in the range [-n, +n]; +// it serves as a smoke test for basic correctness of division. +func TestFloatQuoSmoke(t *testing.T) { + n := 10 + if *long { + n = 1000 + } + + const dprec = 3 // max. precision variation + const prec = 10 + dprec // enough bits to hold n precisely + for x := -n; x <= n; x++ { + for y := -n; y < n; y++ { + if y == 0 { + continue + } + + a := float64(x) + b := float64(y) + c := a / b + + // vary operand precision (only ok as long as a, b can be represented correctly) + for ad := -dprec; ad <= dprec; ad++ { + for bd := -dprec; bd <= dprec; bd++ { + A := new(Float).SetPrec(uint(prec + ad)).SetFloat64(a) + B := new(Float).SetPrec(uint(prec + bd)).SetFloat64(b) + C := new(Float).SetPrec(53).Quo(A, B) // C has float64 mantissa width + + cc, acc := C.Float64() + if cc != c { + t.Errorf("%g/%g = %s; want %.5g\n", a, b, C.Text('g', 5), c) + continue + } + if acc != Exact { + t.Errorf("%g/%g got %s result; want exact result", a, b, acc) + } + } + } + } + } +} + +// TestFloatArithmeticSpecialValues tests that Float operations produce the +// correct results for combinations of zero (±0), finite (±1 and ±2.71828), +// and infinite (±Inf) operands. +func TestFloatArithmeticSpecialValues(t *testing.T) { + zero := 0.0 + args := []float64{math.Inf(-1), -2.71828, -1, -zero, zero, 1, 2.71828, math.Inf(1)} + xx := new(Float) + yy := new(Float) + got := new(Float) + want := new(Float) + for i := 0; i < 4; i++ { + for _, x := range args { + xx.SetFloat64(x) + // check conversion is correct + // (no need to do this for y, since we see exactly the + // same values there) + if got, acc := xx.Float64(); got != x || acc != Exact { + t.Errorf("Float(%g) == %g (%s)", x, got, acc) + } + for _, y := range args { + yy.SetFloat64(y) + var ( + op string + z float64 + f func(z, x, y *Float) *Float + ) + switch i { + case 0: + op = "+" + z = x + y + f = (*Float).Add + case 1: + op = "-" + z = x - y + f = (*Float).Sub + case 2: + op = "*" + z = x * y + f = (*Float).Mul + case 3: + op = "/" + z = x / y + f = (*Float).Quo + default: + panic("unreachable") + } + var errnan bool // set if execution of f panicked with ErrNaN + // protect execution of f + func() { + defer func() { + if p := recover(); p != nil { + _ = p.(ErrNaN) // re-panic if not ErrNaN + errnan = true + } + }() + f(got, xx, yy) + }() + if math.IsNaN(z) { + if !errnan { + t.Errorf("%5g %s %5g = %5s; want ErrNaN panic", x, op, y, got) + } + continue + } + if errnan { + t.Errorf("%5g %s %5g panicked with ErrNan; want %5s", x, op, y, want) + continue + } + want.SetFloat64(z) + if !alike(got, want) { + t.Errorf("%5g %s %5g = %5s; want %5s", x, op, y, got, want) + } + } + } + } +} + +func TestFloatArithmeticOverflow(t *testing.T) { + for _, test := range []struct { + prec uint + mode RoundingMode + op byte + x, y, want string + acc Accuracy + }{ + {4, ToNearestEven, '+', "0", "0", "0", Exact}, // smoke test + {4, ToNearestEven, '+', "0x.8p+0", "0x.8p+0", "0x.8p+1", Exact}, // smoke test + + {4, ToNearestEven, '+', "0", "0x.8p2147483647", "0x.8p+2147483647", Exact}, + {4, ToNearestEven, '+', "0x.8p2147483500", "0x.8p2147483647", "0x.8p+2147483647", Below}, // rounded to zero + {4, ToNearestEven, '+', "0x.8p2147483647", "0x.8p2147483647", "+Inf", Above}, // exponent overflow in + + {4, ToNearestEven, '+', "-0x.8p2147483647", "-0x.8p2147483647", "-Inf", Below}, // exponent overflow in + + {4, ToNearestEven, '-', "-0x.8p2147483647", "0x.8p2147483647", "-Inf", Below}, // exponent overflow in - + + {4, ToZero, '+', "0x.fp2147483647", "0x.8p2147483643", "0x.fp+2147483647", Below}, // rounded to zero + {4, ToNearestEven, '+', "0x.fp2147483647", "0x.8p2147483643", "+Inf", Above}, // exponent overflow in rounding + {4, AwayFromZero, '+', "0x.fp2147483647", "0x.8p2147483643", "+Inf", Above}, // exponent overflow in rounding + + {4, AwayFromZero, '-', "-0x.fp2147483647", "0x.8p2147483644", "-Inf", Below}, // exponent overflow in rounding + {4, ToNearestEven, '-', "-0x.fp2147483647", "0x.8p2147483643", "-Inf", Below}, // exponent overflow in rounding + {4, ToZero, '-', "-0x.fp2147483647", "0x.8p2147483643", "-0x.fp+2147483647", Above}, // rounded to zero + + {4, ToNearestEven, '+', "0", "0x.8p-2147483648", "0x.8p-2147483648", Exact}, + {4, ToNearestEven, '+', "0x.8p-2147483648", "0x.8p-2147483648", "0x.8p-2147483647", Exact}, + + {4, ToNearestEven, '*', "1", "0x.8p2147483647", "0x.8p+2147483647", Exact}, + {4, ToNearestEven, '*', "2", "0x.8p2147483647", "+Inf", Above}, // exponent overflow in * + {4, ToNearestEven, '*', "-2", "0x.8p2147483647", "-Inf", Below}, // exponent overflow in * + + {4, ToNearestEven, '/', "0.5", "0x.8p2147483647", "0x.8p-2147483646", Exact}, + {4, ToNearestEven, '/', "0x.8p+0", "0x.8p2147483647", "0x.8p-2147483646", Exact}, + {4, ToNearestEven, '/', "0x.8p-1", "0x.8p2147483647", "0x.8p-2147483647", Exact}, + {4, ToNearestEven, '/', "0x.8p-2", "0x.8p2147483647", "0x.8p-2147483648", Exact}, + {4, ToNearestEven, '/', "0x.8p-3", "0x.8p2147483647", "0", Below}, // exponent underflow in / + } { + x := makeFloat(test.x) + y := makeFloat(test.y) + z := new(Float).SetPrec(test.prec).SetMode(test.mode) + switch test.op { + case '+': + z.Add(x, y) + case '-': + z.Sub(x, y) + case '*': + z.Mul(x, y) + case '/': + z.Quo(x, y) + default: + panic("unreachable") + } + if got := z.Text('p', 0); got != test.want || z.Acc() != test.acc { + t.Errorf( + "prec = %d (%s): %s %c %s = %s (%s); want %s (%s)", + test.prec, test.mode, x.Text('p', 0), test.op, y.Text('p', 0), got, z.Acc(), test.want, test.acc, + ) + } + } +} + +// TODO(gri) Add tests that check correctness in the presence of aliasing. + +// For rounding modes ToNegativeInf and ToPositiveInf, rounding is affected +// by the sign of the value to be rounded. Test that rounding happens after +// the sign of a result has been set. +// This test uses specific values that are known to fail if rounding is +// "factored" out before setting the result sign. +func TestFloatArithmeticRounding(t *testing.T) { + for _, test := range []struct { + mode RoundingMode + prec uint + x, y, want int64 + op byte + }{ + {ToZero, 3, -0x8, -0x1, -0x8, '+'}, + {AwayFromZero, 3, -0x8, -0x1, -0xa, '+'}, + {ToNegativeInf, 3, -0x8, -0x1, -0xa, '+'}, + + {ToZero, 3, -0x8, 0x1, -0x8, '-'}, + {AwayFromZero, 3, -0x8, 0x1, -0xa, '-'}, + {ToNegativeInf, 3, -0x8, 0x1, -0xa, '-'}, + + {ToZero, 3, -0x9, 0x1, -0x8, '*'}, + {AwayFromZero, 3, -0x9, 0x1, -0xa, '*'}, + {ToNegativeInf, 3, -0x9, 0x1, -0xa, '*'}, + + {ToZero, 3, -0x9, 0x1, -0x8, '/'}, + {AwayFromZero, 3, -0x9, 0x1, -0xa, '/'}, + {ToNegativeInf, 3, -0x9, 0x1, -0xa, '/'}, + } { + var x, y, z Float + x.SetInt64(test.x) + y.SetInt64(test.y) + z.SetPrec(test.prec).SetMode(test.mode) + switch test.op { + case '+': + z.Add(&x, &y) + case '-': + z.Sub(&x, &y) + case '*': + z.Mul(&x, &y) + case '/': + z.Quo(&x, &y) + default: + panic("unreachable") + } + if got, acc := z.Int64(); got != test.want || acc != Exact { + t.Errorf("%s, %d bits: %d %c %d = %d (%s); want %d (Exact)", + test.mode, test.prec, test.x, test.op, test.y, got, acc, test.want, + ) + } + } +} + +// TestFloatCmpSpecialValues tests that Cmp produces the correct results for +// combinations of zero (±0), finite (±1 and ±2.71828), and infinite (±Inf) +// operands. +func TestFloatCmpSpecialValues(t *testing.T) { + zero := 0.0 + args := []float64{math.Inf(-1), -2.71828, -1, -zero, zero, 1, 2.71828, math.Inf(1)} + xx := new(Float) + yy := new(Float) + for i := 0; i < 4; i++ { + for _, x := range args { + xx.SetFloat64(x) + // check conversion is correct + // (no need to do this for y, since we see exactly the + // same values there) + if got, acc := xx.Float64(); got != x || acc != Exact { + t.Errorf("Float(%g) == %g (%s)", x, got, acc) + } + for _, y := range args { + yy.SetFloat64(y) + got := xx.Cmp(yy) + want := 0 + switch { + case x < y: + want = -1 + case x > y: + want = +1 + } + if got != want { + t.Errorf("(%g).Cmp(%g) = %v; want %v", x, y, got, want) + } + } + } + } +} + +func BenchmarkFloatAdd(b *testing.B) { + x := new(Float) + y := new(Float) + z := new(Float) + + for _, prec := range []uint{10, 1e2, 1e3, 1e4, 1e5} { + x.SetPrec(prec).SetRat(NewRat(1, 3)) + y.SetPrec(prec).SetRat(NewRat(1, 6)) + z.SetPrec(prec) + + b.Run(fmt.Sprintf("%v", prec), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + z.Add(x, y) + } + }) + } +} + +func BenchmarkFloatSub(b *testing.B) { + x := new(Float) + y := new(Float) + z := new(Float) + + for _, prec := range []uint{10, 1e2, 1e3, 1e4, 1e5} { + x.SetPrec(prec).SetRat(NewRat(1, 3)) + y.SetPrec(prec).SetRat(NewRat(1, 6)) + z.SetPrec(prec) + + b.Run(fmt.Sprintf("%v", prec), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + z.Sub(x, y) + } + }) + } +} diff --git a/src/math/big/floatconv.go b/src/math/big/floatconv.go new file mode 100644 index 0000000..3bb51c7 --- /dev/null +++ b/src/math/big/floatconv.go @@ -0,0 +1,302 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements string-to-Float conversion functions. + +package big + +import ( + "fmt" + "io" + "strings" +) + +var floatZero Float + +// SetString sets z to the value of s and returns z and a boolean indicating +// success. s must be a floating-point number of the same format as accepted +// by Parse, with base argument 0. The entire string (not just a prefix) must +// be valid for success. If the operation failed, the value of z is undefined +// but the returned value is nil. +func (z *Float) SetString(s string) (*Float, bool) { + if f, _, err := z.Parse(s, 0); err == nil { + return f, true + } + return nil, false +} + +// scan is like Parse but reads the longest possible prefix representing a valid +// floating point number from an io.ByteScanner rather than a string. It serves +// as the implementation of Parse. It does not recognize ±Inf and does not expect +// EOF at the end. +func (z *Float) scan(r io.ByteScanner, base int) (f *Float, b int, err error) { + prec := z.prec + if prec == 0 { + prec = 64 + } + + // A reasonable value in case of an error. + z.form = zero + + // sign + z.neg, err = scanSign(r) + if err != nil { + return + } + + // mantissa + var fcount int // fractional digit count; valid if <= 0 + z.mant, b, fcount, err = z.mant.scan(r, base, true) + if err != nil { + return + } + + // exponent + var exp int64 + var ebase int + exp, ebase, err = scanExponent(r, true, base == 0) + if err != nil { + return + } + + // special-case 0 + if len(z.mant) == 0 { + z.prec = prec + z.acc = Exact + z.form = zero + f = z + return + } + // len(z.mant) > 0 + + // The mantissa may have a radix point (fcount <= 0) and there + // may be a nonzero exponent exp. The radix point amounts to a + // division by b**(-fcount). An exponent means multiplication by + // ebase**exp. Finally, mantissa normalization (shift left) requires + // a correcting multiplication by 2**(-shiftcount). Multiplications + // are commutative, so we can apply them in any order as long as there + // is no loss of precision. We only have powers of 2 and 10, and + // we split powers of 10 into the product of the same powers of + // 2 and 5. This reduces the size of the multiplication factor + // needed for base-10 exponents. + + // normalize mantissa and determine initial exponent contributions + exp2 := int64(len(z.mant))*_W - fnorm(z.mant) + exp5 := int64(0) + + // determine binary or decimal exponent contribution of radix point + if fcount < 0 { + // The mantissa has a radix point ddd.dddd; and + // -fcount is the number of digits to the right + // of '.'. Adjust relevant exponent accordingly. + d := int64(fcount) + switch b { + case 10: + exp5 = d + fallthrough // 10**e == 5**e * 2**e + case 2: + exp2 += d + case 8: + exp2 += d * 3 // octal digits are 3 bits each + case 16: + exp2 += d * 4 // hexadecimal digits are 4 bits each + default: + panic("unexpected mantissa base") + } + // fcount consumed - not needed anymore + } + + // take actual exponent into account + switch ebase { + case 10: + exp5 += exp + fallthrough // see fallthrough above + case 2: + exp2 += exp + default: + panic("unexpected exponent base") + } + // exp consumed - not needed anymore + + // apply 2**exp2 + if MinExp <= exp2 && exp2 <= MaxExp { + z.prec = prec + z.form = finite + z.exp = int32(exp2) + f = z + } else { + err = fmt.Errorf("exponent overflow") + return + } + + if exp5 == 0 { + // no decimal exponent contribution + z.round(0) + return + } + // exp5 != 0 + + // apply 5**exp5 + p := new(Float).SetPrec(z.Prec() + 64) // use more bits for p -- TODO(gri) what is the right number? + if exp5 < 0 { + z.Quo(z, p.pow5(uint64(-exp5))) + } else { + z.Mul(z, p.pow5(uint64(exp5))) + } + + return +} + +// These powers of 5 fit into a uint64. +// +// for p, q := uint64(0), uint64(1); p < q; p, q = q, q*5 { +// fmt.Println(q) +// } +var pow5tab = [...]uint64{ + 1, + 5, + 25, + 125, + 625, + 3125, + 15625, + 78125, + 390625, + 1953125, + 9765625, + 48828125, + 244140625, + 1220703125, + 6103515625, + 30517578125, + 152587890625, + 762939453125, + 3814697265625, + 19073486328125, + 95367431640625, + 476837158203125, + 2384185791015625, + 11920928955078125, + 59604644775390625, + 298023223876953125, + 1490116119384765625, + 7450580596923828125, +} + +// pow5 sets z to 5**n and returns z. +// n must not be negative. +func (z *Float) pow5(n uint64) *Float { + const m = uint64(len(pow5tab) - 1) + if n <= m { + return z.SetUint64(pow5tab[n]) + } + // n > m + + z.SetUint64(pow5tab[m]) + n -= m + + // use more bits for f than for z + // TODO(gri) what is the right number? + f := new(Float).SetPrec(z.Prec() + 64).SetUint64(5) + + for n > 0 { + if n&1 != 0 { + z.Mul(z, f) + } + f.Mul(f, f) + n >>= 1 + } + + return z +} + +// Parse parses s which must contain a text representation of a floating- +// point number with a mantissa in the given conversion base (the exponent +// is always a decimal number), or a string representing an infinite value. +// +// For base 0, an underscore character “_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number, or the returned +// digit count. Incorrect placement of underscores is reported as an +// error if there are no other errors. If base != 0, underscores are +// not recognized and thus terminate scanning like any other character +// that is not a valid radix point or digit. +// +// It sets z to the (possibly rounded) value of the corresponding floating- +// point value, and returns z, the actual base b, and an error err, if any. +// The entire string (not just a prefix) must be consumed for success. +// If z's precision is 0, it is changed to 64 before rounding takes effect. +// The number must be of the form: +// +// number = [ sign ] ( float | "inf" | "Inf" ) . +// sign = "+" | "-" . +// float = ( mantissa | prefix pmantissa ) [ exponent ] . +// prefix = "0" [ "b" | "B" | "o" | "O" | "x" | "X" ] . +// mantissa = digits "." [ digits ] | digits | "." digits . +// pmantissa = [ "_" ] digits "." [ digits ] | [ "_" ] digits | "." digits . +// exponent = ( "e" | "E" | "p" | "P" ) [ sign ] digits . +// digits = digit { [ "_" ] digit } . +// digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" . +// +// The base argument must be 0, 2, 8, 10, or 16. Providing an invalid base +// argument will lead to a run-time panic. +// +// For base 0, the number prefix determines the actual base: A prefix of +// “0b” or “0B” selects base 2, “0o” or “0O” selects base 8, and +// “0x” or “0X” selects base 16. Otherwise, the actual base is 10 and +// no prefix is accepted. The octal prefix "0" is not supported (a leading +// "0" is simply considered a "0"). +// +// A "p" or "P" exponent indicates a base 2 (rather then base 10) exponent; +// for instance, "0x1.fffffffffffffp1023" (using base 0) represents the +// maximum float64 value. For hexadecimal mantissae, the exponent character +// must be one of 'p' or 'P', if present (an "e" or "E" exponent indicator +// cannot be distinguished from a mantissa digit). +// +// The returned *Float f is nil and the value of z is valid but not +// defined if an error is reported. +func (z *Float) Parse(s string, base int) (f *Float, b int, err error) { + // scan doesn't handle ±Inf + if len(s) == 3 && (s == "Inf" || s == "inf") { + f = z.SetInf(false) + return + } + if len(s) == 4 && (s[0] == '+' || s[0] == '-') && (s[1:] == "Inf" || s[1:] == "inf") { + f = z.SetInf(s[0] == '-') + return + } + + r := strings.NewReader(s) + if f, b, err = z.scan(r, base); err != nil { + return + } + + // entire string must have been consumed + if ch, err2 := r.ReadByte(); err2 == nil { + err = fmt.Errorf("expected end of string, found %q", ch) + } else if err2 != io.EOF { + err = err2 + } + + return +} + +// ParseFloat is like f.Parse(s, base) with f set to the given precision +// and rounding mode. +func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) { + return new(Float).SetPrec(prec).SetMode(mode).Parse(s, base) +} + +var _ fmt.Scanner = (*Float)(nil) // *Float must implement fmt.Scanner + +// Scan is a support routine for fmt.Scanner; it sets z to the value of +// the scanned number. It accepts formats whose verbs are supported by +// fmt.Scan for floating point values, which are: +// 'b' (binary), 'e', 'E', 'f', 'F', 'g' and 'G'. +// Scan doesn't handle ±Inf. +func (z *Float) Scan(s fmt.ScanState, ch rune) error { + s.SkipSpace() + _, _, err := z.scan(byteReader{s}, 0) + return err +} diff --git a/src/math/big/floatconv_test.go b/src/math/big/floatconv_test.go new file mode 100644 index 0000000..a1cc38a --- /dev/null +++ b/src/math/big/floatconv_test.go @@ -0,0 +1,825 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "fmt" + "math" + "math/bits" + "strconv" + "testing" +) + +var zero_ float64 + +func TestFloatSetFloat64String(t *testing.T) { + inf := math.Inf(0) + nan := math.NaN() + + for _, test := range []struct { + s string + x float64 // NaNs represent invalid inputs + }{ + // basics + {"0", 0}, + {"-0", -zero_}, + {"+0", 0}, + {"1", 1}, + {"-1", -1}, + {"+1", 1}, + {"1.234", 1.234}, + {"-1.234", -1.234}, + {"+1.234", 1.234}, + {".1", 0.1}, + {"1.", 1}, + {"+1.", 1}, + + // various zeros + {"0e100", 0}, + {"-0e+100", -zero_}, + {"+0e-100", 0}, + {"0E100", 0}, + {"-0E+100", -zero_}, + {"+0E-100", 0}, + + // various decimal exponent formats + {"1.e10", 1e10}, + {"1e+10", 1e10}, + {"+1e-10", 1e-10}, + {"1E10", 1e10}, + {"1.E+10", 1e10}, + {"+1E-10", 1e-10}, + + // infinities + {"Inf", inf}, + {"+Inf", inf}, + {"-Inf", -inf}, + {"inf", inf}, + {"+inf", inf}, + {"-inf", -inf}, + + // invalid numbers + {"", nan}, + {"-", nan}, + {"0x", nan}, + {"0e", nan}, + {"1.2ef", nan}, + {"2..3", nan}, + {"123..", nan}, + {"infinity", nan}, + {"foobar", nan}, + + // invalid underscores + {"_", nan}, + {"0_", nan}, + {"1__0", nan}, + {"123_.", nan}, + {"123._", nan}, + {"123._4", nan}, + {"1_2.3_4_", nan}, + {"_.123", nan}, + {"_123.456", nan}, + {"10._0", nan}, + {"10.0e_0", nan}, + {"10.0e0_", nan}, + {"0P-0__0", nan}, + + // misc decimal values + {"3.14159265", 3.14159265}, + {"-687436.79457e-245", -687436.79457e-245}, + {"-687436.79457E245", -687436.79457e245}, + {".0000000000000000000000000000000000000001", 1e-40}, + {"+10000000000000000000000000000000000000000e-0", 1e40}, + + // decimal mantissa, binary exponent + {"0p0", 0}, + {"-0p0", -zero_}, + {"1p10", 1 << 10}, + {"1p+10", 1 << 10}, + {"+1p-10", 1.0 / (1 << 10)}, + {"1024p-12", 0.25}, + {"-1p10", -1024}, + {"1.5p1", 3}, + + // binary mantissa, decimal exponent + {"0b0", 0}, + {"-0b0", -zero_}, + {"0b0e+10", 0}, + {"-0b0e-10", -zero_}, + {"0b1010", 10}, + {"0B1010E2", 1000}, + {"0b.1", 0.5}, + {"0b.001", 0.125}, + {"0b.001e3", 125}, + + // binary mantissa, binary exponent + {"0b0p+10", 0}, + {"-0b0p-10", -zero_}, + {"0b.1010p4", 10}, + {"0b1p-1", 0.5}, + {"0b001p-3", 0.125}, + {"0b.001p3", 1}, + {"0b0.01p2", 1}, + {"0b0.01P+2", 1}, + + // octal mantissa, decimal exponent + {"0o0", 0}, + {"-0o0", -zero_}, + {"0o0e+10", 0}, + {"-0o0e-10", -zero_}, + {"0o12", 10}, + {"0O12E2", 1000}, + {"0o.4", 0.5}, + {"0o.01", 0.015625}, + {"0o.01e3", 15.625}, + + // octal mantissa, binary exponent + {"0o0p+10", 0}, + {"-0o0p-10", -zero_}, + {"0o.12p6", 10}, + {"0o4p-3", 0.5}, + {"0o0014p-6", 0.1875}, + {"0o.001p9", 1}, + {"0o0.01p7", 2}, + {"0O0.01P+2", 0.0625}, + + // hexadecimal mantissa and exponent + {"0x0", 0}, + {"-0x0", -zero_}, + {"0x0p+10", 0}, + {"-0x0p-10", -zero_}, + {"0xff", 255}, + {"0X.8p1", 1}, + {"-0X0.00008p16", -0.5}, + {"-0X0.00008P+16", -0.5}, + {"0x0.0000000000001p-1022", math.SmallestNonzeroFloat64}, + {"0x1.fffffffffffffp1023", math.MaxFloat64}, + + // underscores + {"0_0", 0}, + {"1_000.", 1000}, + {"1_2_3.4_5_6", 123.456}, + {"1.0e0_0", 1}, + {"1p+1_0", 1024}, + {"0b_1000", 0x8}, + {"0b_1011_1101", 0xbd}, + {"0x_f0_0d_1eP+0_8", 0xf00d1e00}, + } { + var x Float + x.SetPrec(53) + _, ok := x.SetString(test.s) + if math.IsNaN(test.x) { + // test.s is invalid + if ok { + t.Errorf("%s: want parse error", test.s) + } + continue + } + // test.s is valid + if !ok { + t.Errorf("%s: got parse error", test.s) + continue + } + f, _ := x.Float64() + want := new(Float).SetFloat64(test.x) + if x.Cmp(want) != 0 || x.Signbit() != want.Signbit() { + t.Errorf("%s: got %v (%v); want %v", test.s, &x, f, test.x) + } + } +} + +func fdiv(a, b float64) float64 { return a / b } + +const ( + below1e23 = 99999999999999974834176 + above1e23 = 100000000000000008388608 +) + +func TestFloat64Text(t *testing.T) { + for _, test := range []struct { + x float64 + format byte + prec int + want string + }{ + {0, 'f', 0, "0"}, + {math.Copysign(0, -1), 'f', 0, "-0"}, + {1, 'f', 0, "1"}, + {-1, 'f', 0, "-1"}, + + {0.001, 'e', 0, "1e-03"}, + {0.459, 'e', 0, "5e-01"}, + {1.459, 'e', 0, "1e+00"}, + {2.459, 'e', 1, "2.5e+00"}, + {3.459, 'e', 2, "3.46e+00"}, + {4.459, 'e', 3, "4.459e+00"}, + {5.459, 'e', 4, "5.4590e+00"}, + + {0.001, 'f', 0, "0"}, + {0.459, 'f', 0, "0"}, + {1.459, 'f', 0, "1"}, + {2.459, 'f', 1, "2.5"}, + {3.459, 'f', 2, "3.46"}, + {4.459, 'f', 3, "4.459"}, + {5.459, 'f', 4, "5.4590"}, + + {0, 'b', 0, "0"}, + {math.Copysign(0, -1), 'b', 0, "-0"}, + {1.0, 'b', 0, "4503599627370496p-52"}, + {-1.0, 'b', 0, "-4503599627370496p-52"}, + {4503599627370496, 'b', 0, "4503599627370496p+0"}, + + {0, 'p', 0, "0"}, + {math.Copysign(0, -1), 'p', 0, "-0"}, + {1024.0, 'p', 0, "0x.8p+11"}, + {-1024.0, 'p', 0, "-0x.8p+11"}, + + // all test cases below from strconv/ftoa_test.go + {1, 'e', 5, "1.00000e+00"}, + {1, 'f', 5, "1.00000"}, + {1, 'g', 5, "1"}, + {1, 'g', -1, "1"}, + {20, 'g', -1, "20"}, + {1234567.8, 'g', -1, "1.2345678e+06"}, + {200000, 'g', -1, "200000"}, + {2000000, 'g', -1, "2e+06"}, + + // g conversion and zero suppression + {400, 'g', 2, "4e+02"}, + {40, 'g', 2, "40"}, + {4, 'g', 2, "4"}, + {.4, 'g', 2, "0.4"}, + {.04, 'g', 2, "0.04"}, + {.004, 'g', 2, "0.004"}, + {.0004, 'g', 2, "0.0004"}, + {.00004, 'g', 2, "4e-05"}, + {.000004, 'g', 2, "4e-06"}, + + {0, 'e', 5, "0.00000e+00"}, + {0, 'f', 5, "0.00000"}, + {0, 'g', 5, "0"}, + {0, 'g', -1, "0"}, + + {-1, 'e', 5, "-1.00000e+00"}, + {-1, 'f', 5, "-1.00000"}, + {-1, 'g', 5, "-1"}, + {-1, 'g', -1, "-1"}, + + {12, 'e', 5, "1.20000e+01"}, + {12, 'f', 5, "12.00000"}, + {12, 'g', 5, "12"}, + {12, 'g', -1, "12"}, + + {123456700, 'e', 5, "1.23457e+08"}, + {123456700, 'f', 5, "123456700.00000"}, + {123456700, 'g', 5, "1.2346e+08"}, + {123456700, 'g', -1, "1.234567e+08"}, + + {1.2345e6, 'e', 5, "1.23450e+06"}, + {1.2345e6, 'f', 5, "1234500.00000"}, + {1.2345e6, 'g', 5, "1.2345e+06"}, + + {1e23, 'e', 17, "9.99999999999999916e+22"}, + {1e23, 'f', 17, "99999999999999991611392.00000000000000000"}, + {1e23, 'g', 17, "9.9999999999999992e+22"}, + + {1e23, 'e', -1, "1e+23"}, + {1e23, 'f', -1, "100000000000000000000000"}, + {1e23, 'g', -1, "1e+23"}, + + {below1e23, 'e', 17, "9.99999999999999748e+22"}, + {below1e23, 'f', 17, "99999999999999974834176.00000000000000000"}, + {below1e23, 'g', 17, "9.9999999999999975e+22"}, + + {below1e23, 'e', -1, "9.999999999999997e+22"}, + {below1e23, 'f', -1, "99999999999999970000000"}, + {below1e23, 'g', -1, "9.999999999999997e+22"}, + + {above1e23, 'e', 17, "1.00000000000000008e+23"}, + {above1e23, 'f', 17, "100000000000000008388608.00000000000000000"}, + {above1e23, 'g', 17, "1.0000000000000001e+23"}, + + {above1e23, 'e', -1, "1.0000000000000001e+23"}, + {above1e23, 'f', -1, "100000000000000010000000"}, + {above1e23, 'g', -1, "1.0000000000000001e+23"}, + + {5e-304 / 1e20, 'g', -1, "5e-324"}, + {-5e-304 / 1e20, 'g', -1, "-5e-324"}, + {fdiv(5e-304, 1e20), 'g', -1, "5e-324"}, // avoid constant arithmetic + {fdiv(-5e-304, 1e20), 'g', -1, "-5e-324"}, // avoid constant arithmetic + + {32, 'g', -1, "32"}, + {32, 'g', 0, "3e+01"}, + + {100, 'x', -1, "0x1.9p+06"}, + + // {math.NaN(), 'g', -1, "NaN"}, // Float doesn't support NaNs + // {-math.NaN(), 'g', -1, "NaN"}, // Float doesn't support NaNs + {math.Inf(0), 'g', -1, "+Inf"}, + {math.Inf(-1), 'g', -1, "-Inf"}, + {-math.Inf(0), 'g', -1, "-Inf"}, + + {-1, 'b', -1, "-4503599627370496p-52"}, + + // fixed bugs + {0.9, 'f', 1, "0.9"}, + {0.09, 'f', 1, "0.1"}, + {0.0999, 'f', 1, "0.1"}, + {0.05, 'f', 1, "0.1"}, + {0.05, 'f', 0, "0"}, + {0.5, 'f', 1, "0.5"}, + {0.5, 'f', 0, "0"}, + {1.5, 'f', 0, "2"}, + + // https://www.exploringbinary.com/java-hangs-when-converting-2-2250738585072012e-308/ + {2.2250738585072012e-308, 'g', -1, "2.2250738585072014e-308"}, + // https://www.exploringbinary.com/php-hangs-on-numeric-value-2-2250738585072011e-308/ + {2.2250738585072011e-308, 'g', -1, "2.225073858507201e-308"}, + + // Issue 2625. + {383260575764816448, 'f', 0, "383260575764816448"}, + {383260575764816448, 'g', -1, "3.8326057576481645e+17"}, + + // Issue 15918. + {1, 'f', -10, "1"}, + {1, 'f', -11, "1"}, + {1, 'f', -12, "1"}, + } { + // The test cases are from the strconv package which tests float64 values. + // When formatting values with prec = -1 (shortest representation), + // the actually available mantissa precision matters. + // For denormalized values, that precision is < 53 (SetFloat64 default). + // Compute and set the actual precision explicitly. + f := new(Float).SetPrec(actualPrec(test.x)).SetFloat64(test.x) + got := f.Text(test.format, test.prec) + if got != test.want { + t.Errorf("%v: got %s; want %s", test, got, test.want) + continue + } + + if test.format == 'b' && test.x == 0 { + continue // 'b' format in strconv.Float requires knowledge of bias for 0.0 + } + if test.format == 'p' { + continue // 'p' format not supported in strconv.Format + } + + // verify that Float format matches strconv format + want := strconv.FormatFloat(test.x, test.format, test.prec, 64) + if got != want { + t.Errorf("%v: got %s; want %s (strconv)", test, got, want) + } + } +} + +// actualPrec returns the number of actually used mantissa bits. +func actualPrec(x float64) uint { + if mant := math.Float64bits(x); x != 0 && mant&(0x7ff<<52) == 0 { + // x is denormalized + return 64 - uint(bits.LeadingZeros64(mant&(1<<52-1))) + } + return 53 +} + +func TestFloatText(t *testing.T) { + const defaultRound = ^RoundingMode(0) + + for _, test := range []struct { + x string + round RoundingMode + prec uint + format byte + digits int + want string + }{ + {"0", defaultRound, 10, 'f', 0, "0"}, + {"-0", defaultRound, 10, 'f', 0, "-0"}, + {"1", defaultRound, 10, 'f', 0, "1"}, + {"-1", defaultRound, 10, 'f', 0, "-1"}, + + {"1.459", defaultRound, 100, 'e', 0, "1e+00"}, + {"2.459", defaultRound, 100, 'e', 1, "2.5e+00"}, + {"3.459", defaultRound, 100, 'e', 2, "3.46e+00"}, + {"4.459", defaultRound, 100, 'e', 3, "4.459e+00"}, + {"5.459", defaultRound, 100, 'e', 4, "5.4590e+00"}, + + {"1.459", defaultRound, 100, 'E', 0, "1E+00"}, + {"2.459", defaultRound, 100, 'E', 1, "2.5E+00"}, + {"3.459", defaultRound, 100, 'E', 2, "3.46E+00"}, + {"4.459", defaultRound, 100, 'E', 3, "4.459E+00"}, + {"5.459", defaultRound, 100, 'E', 4, "5.4590E+00"}, + + {"1.459", defaultRound, 100, 'f', 0, "1"}, + {"2.459", defaultRound, 100, 'f', 1, "2.5"}, + {"3.459", defaultRound, 100, 'f', 2, "3.46"}, + {"4.459", defaultRound, 100, 'f', 3, "4.459"}, + {"5.459", defaultRound, 100, 'f', 4, "5.4590"}, + + {"1.459", defaultRound, 100, 'g', 0, "1"}, + {"2.459", defaultRound, 100, 'g', 1, "2"}, + {"3.459", defaultRound, 100, 'g', 2, "3.5"}, + {"4.459", defaultRound, 100, 'g', 3, "4.46"}, + {"5.459", defaultRound, 100, 'g', 4, "5.459"}, + + {"1459", defaultRound, 53, 'g', 0, "1e+03"}, + {"2459", defaultRound, 53, 'g', 1, "2e+03"}, + {"3459", defaultRound, 53, 'g', 2, "3.5e+03"}, + {"4459", defaultRound, 53, 'g', 3, "4.46e+03"}, + {"5459", defaultRound, 53, 'g', 4, "5459"}, + + {"1459", defaultRound, 53, 'G', 0, "1E+03"}, + {"2459", defaultRound, 53, 'G', 1, "2E+03"}, + {"3459", defaultRound, 53, 'G', 2, "3.5E+03"}, + {"4459", defaultRound, 53, 'G', 3, "4.46E+03"}, + {"5459", defaultRound, 53, 'G', 4, "5459"}, + + {"3", defaultRound, 10, 'e', 40, "3.0000000000000000000000000000000000000000e+00"}, + {"3", defaultRound, 10, 'f', 40, "3.0000000000000000000000000000000000000000"}, + {"3", defaultRound, 10, 'g', 40, "3"}, + + {"3e40", defaultRound, 100, 'e', 40, "3.0000000000000000000000000000000000000000e+40"}, + {"3e40", defaultRound, 100, 'f', 4, "30000000000000000000000000000000000000000.0000"}, + {"3e40", defaultRound, 100, 'g', 40, "3e+40"}, + + // make sure "stupid" exponents don't stall the machine + {"1e1000000", defaultRound, 64, 'p', 0, "0x.88b3a28a05eade3ap+3321929"}, + {"1e646456992", defaultRound, 64, 'p', 0, "0x.e883a0c5c8c7c42ap+2147483644"}, + {"1e646456993", defaultRound, 64, 'p', 0, "+Inf"}, + {"1e1000000000", defaultRound, 64, 'p', 0, "+Inf"}, + {"1e-1000000", defaultRound, 64, 'p', 0, "0x.efb4542cc8ca418ap-3321928"}, + {"1e-646456993", defaultRound, 64, 'p', 0, "0x.e17c8956983d9d59p-2147483647"}, + {"1e-646456994", defaultRound, 64, 'p', 0, "0"}, + {"1e-1000000000", defaultRound, 64, 'p', 0, "0"}, + + // minimum and maximum values + {"1p2147483646", defaultRound, 64, 'p', 0, "0x.8p+2147483647"}, + {"0x.8p2147483647", defaultRound, 64, 'p', 0, "0x.8p+2147483647"}, + {"0x.8p-2147483647", defaultRound, 64, 'p', 0, "0x.8p-2147483647"}, + {"1p-2147483649", defaultRound, 64, 'p', 0, "0x.8p-2147483648"}, + + // TODO(gri) need tests for actual large Floats + + {"0", defaultRound, 53, 'b', 0, "0"}, + {"-0", defaultRound, 53, 'b', 0, "-0"}, + {"1.0", defaultRound, 53, 'b', 0, "4503599627370496p-52"}, + {"-1.0", defaultRound, 53, 'b', 0, "-4503599627370496p-52"}, + {"4503599627370496", defaultRound, 53, 'b', 0, "4503599627370496p+0"}, + + // issue 9939 + {"3", defaultRound, 350, 'b', 0, "1720123961992553633708115671476565205597423741876210842803191629540192157066363606052513914832594264915968p-348"}, + {"03", defaultRound, 350, 'b', 0, "1720123961992553633708115671476565205597423741876210842803191629540192157066363606052513914832594264915968p-348"}, + {"3.", defaultRound, 350, 'b', 0, "1720123961992553633708115671476565205597423741876210842803191629540192157066363606052513914832594264915968p-348"}, + {"3.0", defaultRound, 350, 'b', 0, "1720123961992553633708115671476565205597423741876210842803191629540192157066363606052513914832594264915968p-348"}, + {"3.00", defaultRound, 350, 'b', 0, "1720123961992553633708115671476565205597423741876210842803191629540192157066363606052513914832594264915968p-348"}, + {"3.000", defaultRound, 350, 'b', 0, "1720123961992553633708115671476565205597423741876210842803191629540192157066363606052513914832594264915968p-348"}, + + {"3", defaultRound, 350, 'p', 0, "0x.cp+2"}, + {"03", defaultRound, 350, 'p', 0, "0x.cp+2"}, + {"3.", defaultRound, 350, 'p', 0, "0x.cp+2"}, + {"3.0", defaultRound, 350, 'p', 0, "0x.cp+2"}, + {"3.00", defaultRound, 350, 'p', 0, "0x.cp+2"}, + {"3.000", defaultRound, 350, 'p', 0, "0x.cp+2"}, + + {"0", defaultRound, 64, 'p', 0, "0"}, + {"-0", defaultRound, 64, 'p', 0, "-0"}, + {"1024.0", defaultRound, 64, 'p', 0, "0x.8p+11"}, + {"-1024.0", defaultRound, 64, 'p', 0, "-0x.8p+11"}, + + {"0", defaultRound, 64, 'x', -1, "0x0p+00"}, + {"0", defaultRound, 64, 'x', 0, "0x0p+00"}, + {"0", defaultRound, 64, 'x', 1, "0x0.0p+00"}, + {"0", defaultRound, 64, 'x', 5, "0x0.00000p+00"}, + {"3.25", defaultRound, 64, 'x', 0, "0x1p+02"}, + {"-3.25", defaultRound, 64, 'x', 0, "-0x1p+02"}, + {"3.25", defaultRound, 64, 'x', 1, "0x1.ap+01"}, + {"-3.25", defaultRound, 64, 'x', 1, "-0x1.ap+01"}, + {"3.25", defaultRound, 64, 'x', -1, "0x1.ap+01"}, + {"-3.25", defaultRound, 64, 'x', -1, "-0x1.ap+01"}, + {"1024.0", defaultRound, 64, 'x', 0, "0x1p+10"}, + {"-1024.0", defaultRound, 64, 'x', 0, "-0x1p+10"}, + {"1024.0", defaultRound, 64, 'x', 5, "0x1.00000p+10"}, + {"8191.0", defaultRound, 53, 'x', -1, "0x1.fffp+12"}, + {"8191.5", defaultRound, 53, 'x', -1, "0x1.fff8p+12"}, + {"8191.53125", defaultRound, 53, 'x', -1, "0x1.fff88p+12"}, + {"8191.53125", defaultRound, 53, 'x', 4, "0x1.fff8p+12"}, + {"8191.53125", defaultRound, 53, 'x', 3, "0x1.000p+13"}, + {"8191.53125", defaultRound, 53, 'x', 0, "0x1p+13"}, + {"8191.533203125", defaultRound, 53, 'x', -1, "0x1.fff888p+12"}, + {"8191.533203125", defaultRound, 53, 'x', 5, "0x1.fff88p+12"}, + {"8191.533203125", defaultRound, 53, 'x', 4, "0x1.fff9p+12"}, + + {"8191.53125", defaultRound, 53, 'x', -1, "0x1.fff88p+12"}, + {"8191.53125", ToNearestEven, 53, 'x', 5, "0x1.fff88p+12"}, + {"8191.53125", ToNearestAway, 53, 'x', 5, "0x1.fff88p+12"}, + {"8191.53125", ToZero, 53, 'x', 5, "0x1.fff88p+12"}, + {"8191.53125", AwayFromZero, 53, 'x', 5, "0x1.fff88p+12"}, + {"8191.53125", ToNegativeInf, 53, 'x', 5, "0x1.fff88p+12"}, + {"8191.53125", ToPositiveInf, 53, 'x', 5, "0x1.fff88p+12"}, + + {"8191.53125", defaultRound, 53, 'x', 4, "0x1.fff8p+12"}, + {"8191.53125", defaultRound, 53, 'x', 3, "0x1.000p+13"}, + {"8191.53125", defaultRound, 53, 'x', 0, "0x1p+13"}, + {"8191.533203125", defaultRound, 53, 'x', -1, "0x1.fff888p+12"}, + {"8191.533203125", defaultRound, 53, 'x', 6, "0x1.fff888p+12"}, + {"8191.533203125", defaultRound, 53, 'x', 5, "0x1.fff88p+12"}, + {"8191.533203125", defaultRound, 53, 'x', 4, "0x1.fff9p+12"}, + + {"8191.53125", ToNearestEven, 53, 'x', 4, "0x1.fff8p+12"}, + {"8191.53125", ToNearestAway, 53, 'x', 4, "0x1.fff9p+12"}, + {"8191.53125", ToZero, 53, 'x', 4, "0x1.fff8p+12"}, + {"8191.53125", ToZero, 53, 'x', 2, "0x1.ffp+12"}, + {"8191.53125", AwayFromZero, 53, 'x', 4, "0x1.fff9p+12"}, + {"8191.53125", ToNegativeInf, 53, 'x', 4, "0x1.fff8p+12"}, + {"-8191.53125", ToNegativeInf, 53, 'x', 4, "-0x1.fff9p+12"}, + {"8191.53125", ToPositiveInf, 53, 'x', 4, "0x1.fff9p+12"}, + {"-8191.53125", ToPositiveInf, 53, 'x', 4, "-0x1.fff8p+12"}, + + // issue 34343 + {"0x.8p-2147483648", ToNearestEven, 4, 'p', -1, "0x.8p-2147483648"}, + {"0x.8p-2147483648", ToNearestEven, 4, 'x', -1, "0x1p-2147483649"}, + } { + f, _, err := ParseFloat(test.x, 0, test.prec, ToNearestEven) + if err != nil { + t.Errorf("%v: %s", test, err) + continue + } + if test.round != defaultRound { + f.SetMode(test.round) + } + + got := f.Text(test.format, test.digits) + if got != test.want { + t.Errorf("%v: got %s; want %s", test, got, test.want) + } + + // compare with strconv.FormatFloat output if possible + // ('p' format is not supported by strconv.FormatFloat, + // and its output for 0.0 prints a biased exponent value + // as in 0p-1074 which makes no sense to emulate here) + if test.prec == 53 && test.format != 'p' && f.Sign() != 0 && (test.round == ToNearestEven || test.round == defaultRound) { + f64, acc := f.Float64() + if acc != Exact { + t.Errorf("%v: expected exact conversion to float64", test) + continue + } + got := strconv.FormatFloat(f64, test.format, test.digits, 64) + if got != test.want { + t.Errorf("%v: got %s; want %s", test, got, test.want) + } + } + } +} + +func TestFloatFormat(t *testing.T) { + for _, test := range []struct { + format string + value any // float32, float64, or string (== 512bit *Float) + want string + }{ + // from fmt/fmt_test.go + {"%+.3e", 0.0, "+0.000e+00"}, + {"%+.3e", 1.0, "+1.000e+00"}, + {"%+.3f", -1.0, "-1.000"}, + {"%+.3F", -1.0, "-1.000"}, + {"%+.3F", float32(-1.0), "-1.000"}, + {"%+07.2f", 1.0, "+001.00"}, + {"%+07.2f", -1.0, "-001.00"}, + {"%+10.2f", +1.0, " +1.00"}, + {"%+10.2f", -1.0, " -1.00"}, + {"% .3E", -1.0, "-1.000E+00"}, + {"% .3e", 1.0, " 1.000e+00"}, + {"%+.3g", 0.0, "+0"}, + {"%+.3g", 1.0, "+1"}, + {"%+.3g", -1.0, "-1"}, + {"% .3g", -1.0, "-1"}, + {"% .3g", 1.0, " 1"}, + {"%b", float32(1.0), "8388608p-23"}, + {"%b", 1.0, "4503599627370496p-52"}, + + // from fmt/fmt_test.go: old test/fmt_test.go + {"%e", 1.0, "1.000000e+00"}, + {"%e", 1234.5678e3, "1.234568e+06"}, + {"%e", 1234.5678e-8, "1.234568e-05"}, + {"%e", -7.0, "-7.000000e+00"}, + {"%e", -1e-9, "-1.000000e-09"}, + {"%f", 1234.5678e3, "1234567.800000"}, + {"%f", 1234.5678e-8, "0.000012"}, + {"%f", -7.0, "-7.000000"}, + {"%f", -1e-9, "-0.000000"}, + {"%g", 1234.5678e3, "1.2345678e+06"}, + {"%g", float32(1234.5678e3), "1.2345678e+06"}, + {"%g", 1234.5678e-8, "1.2345678e-05"}, + {"%g", -7.0, "-7"}, + {"%g", -1e-9, "-1e-09"}, + {"%g", float32(-1e-9), "-1e-09"}, + {"%E", 1.0, "1.000000E+00"}, + {"%E", 1234.5678e3, "1.234568E+06"}, + {"%E", 1234.5678e-8, "1.234568E-05"}, + {"%E", -7.0, "-7.000000E+00"}, + {"%E", -1e-9, "-1.000000E-09"}, + {"%G", 1234.5678e3, "1.2345678E+06"}, + {"%G", float32(1234.5678e3), "1.2345678E+06"}, + {"%G", 1234.5678e-8, "1.2345678E-05"}, + {"%G", -7.0, "-7"}, + {"%G", -1e-9, "-1E-09"}, + {"%G", float32(-1e-9), "-1E-09"}, + + {"%20.6e", 1.2345e3, " 1.234500e+03"}, + {"%20.6e", 1.2345e-3, " 1.234500e-03"}, + {"%20e", 1.2345e3, " 1.234500e+03"}, + {"%20e", 1.2345e-3, " 1.234500e-03"}, + {"%20.8e", 1.2345e3, " 1.23450000e+03"}, + {"%20f", 1.23456789e3, " 1234.567890"}, + {"%20f", 1.23456789e-3, " 0.001235"}, + {"%20f", 12345678901.23456789, " 12345678901.234568"}, + {"%-20f", 1.23456789e3, "1234.567890 "}, + {"%20.8f", 1.23456789e3, " 1234.56789000"}, + {"%20.8f", 1.23456789e-3, " 0.00123457"}, + {"%g", 1.23456789e3, "1234.56789"}, + {"%g", 1.23456789e-3, "0.00123456789"}, + {"%g", 1.23456789e20, "1.23456789e+20"}, + {"%20e", math.Inf(1), " +Inf"}, + {"%-20f", math.Inf(-1), "-Inf "}, + + // from fmt/fmt_test.go: comparison of padding rules with C printf + {"%.2f", 1.0, "1.00"}, + {"%.2f", -1.0, "-1.00"}, + {"% .2f", 1.0, " 1.00"}, + {"% .2f", -1.0, "-1.00"}, + {"%+.2f", 1.0, "+1.00"}, + {"%+.2f", -1.0, "-1.00"}, + {"%7.2f", 1.0, " 1.00"}, + {"%7.2f", -1.0, " -1.00"}, + {"% 7.2f", 1.0, " 1.00"}, + {"% 7.2f", -1.0, " -1.00"}, + {"%+7.2f", 1.0, " +1.00"}, + {"%+7.2f", -1.0, " -1.00"}, + {"%07.2f", 1.0, "0001.00"}, + {"%07.2f", -1.0, "-001.00"}, + {"% 07.2f", 1.0, " 001.00"}, + {"% 07.2f", -1.0, "-001.00"}, + {"%+07.2f", 1.0, "+001.00"}, + {"%+07.2f", -1.0, "-001.00"}, + + // from fmt/fmt_test.go: zero padding does not apply to infinities + {"%020f", math.Inf(-1), " -Inf"}, + {"%020f", math.Inf(+1), " +Inf"}, + {"% 020f", math.Inf(-1), " -Inf"}, + {"% 020f", math.Inf(+1), " Inf"}, + {"%+020f", math.Inf(-1), " -Inf"}, + {"%+020f", math.Inf(+1), " +Inf"}, + {"%20f", -1.0, " -1.000000"}, + + // handle %v like %g + {"%v", 0.0, "0"}, + {"%v", -7.0, "-7"}, + {"%v", -1e-9, "-1e-09"}, + {"%v", float32(-1e-9), "-1e-09"}, + {"%010v", 0.0, "0000000000"}, + + // *Float cases + {"%.20f", "1e-20", "0.00000000000000000001"}, + {"%.20f", "-1e-20", "-0.00000000000000000001"}, + {"%30.20f", "-1e-20", " -0.00000000000000000001"}, + {"%030.20f", "-1e-20", "-00000000.00000000000000000001"}, + {"%030.20f", "+1e-20", "000000000.00000000000000000001"}, + {"% 030.20f", "+1e-20", " 00000000.00000000000000000001"}, + + // erroneous formats + {"%s", 1.0, "%!s(*big.Float=1)"}, + } { + value := new(Float) + switch v := test.value.(type) { + case float32: + value.SetPrec(24).SetFloat64(float64(v)) + case float64: + value.SetPrec(53).SetFloat64(v) + case string: + value.SetPrec(512).Parse(v, 0) + default: + t.Fatalf("unsupported test value: %v (%T)", v, v) + } + + if got := fmt.Sprintf(test.format, value); got != test.want { + t.Errorf("%v: got %q; want %q", test, got, test.want) + } + } +} + +func BenchmarkParseFloatSmallExp(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, s := range []string{ + "1e0", + "1e-1", + "1e-2", + "1e-3", + "1e-4", + "1e-5", + "1e-10", + "1e-20", + "1e-50", + "1e1", + "1e2", + "1e3", + "1e4", + "1e5", + "1e10", + "1e20", + "1e50", + } { + var x Float + _, _, err := x.Parse(s, 0) + if err != nil { + b.Fatalf("%s: %v", s, err) + } + } + } +} + +func BenchmarkParseFloatLargeExp(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, s := range []string{ + "1e0", + "1e-10", + "1e-20", + "1e-30", + "1e-40", + "1e-50", + "1e-100", + "1e-500", + "1e-1000", + "1e-5000", + "1e-10000", + "1e10", + "1e20", + "1e30", + "1e40", + "1e50", + "1e100", + "1e500", + "1e1000", + "1e5000", + "1e10000", + } { + var x Float + _, _, err := x.Parse(s, 0) + if err != nil { + b.Fatalf("%s: %v", s, err) + } + } + } +} + +func TestFloatScan(t *testing.T) { + var floatScanTests = []struct { + input string + format string + output string + remaining int + wantErr bool + }{ + 0: {"10.0", "%f", "10", 0, false}, + 1: {"23.98+2.0", "%v", "23.98", 4, false}, + 2: {"-1+1", "%v", "-1", 2, false}, + 3: {" 00000", "%v", "0", 0, false}, + 4: {"-123456p-78", "%b", "-4.084816388e-19", 0, false}, + 5: {"+123", "%b", "123", 0, false}, + 6: {"-1.234e+56", "%e", "-1.234e+56", 0, false}, + 7: {"-1.234E-56", "%E", "-1.234e-56", 0, false}, + 8: {"-1.234e+567", "%g", "-1.234e+567", 0, false}, + 9: {"+1234567891011.234", "%G", "1.234567891e+12", 0, false}, + + // Scan doesn't handle ±Inf. + 10: {"Inf", "%v", "", 3, true}, + 11: {"-Inf", "%v", "", 3, true}, + 12: {"-Inf", "%v", "", 3, true}, + } + + var buf bytes.Buffer + for i, test := range floatScanTests { + x := new(Float) + buf.Reset() + buf.WriteString(test.input) + _, err := fmt.Fscanf(&buf, test.format, x) + if test.wantErr { + if err == nil { + t.Errorf("#%d want non-nil err", i) + } + continue + } + + if err != nil { + t.Errorf("#%d error: %s", i, err) + } + + if x.String() != test.output { + t.Errorf("#%d got %s; want %s", i, x.String(), test.output) + } + if buf.Len() != test.remaining { + t.Errorf("#%d got %d bytes remaining; want %d", i, buf.Len(), test.remaining) + } + } +} diff --git a/src/math/big/floatexample_test.go b/src/math/big/floatexample_test.go new file mode 100644 index 0000000..0c6668c --- /dev/null +++ b/src/math/big/floatexample_test.go @@ -0,0 +1,141 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big_test + +import ( + "fmt" + "math" + "math/big" +) + +func ExampleFloat_Add() { + // Operate on numbers of different precision. + var x, y, z big.Float + x.SetInt64(1000) // x is automatically set to 64bit precision + y.SetFloat64(2.718281828) // y is automatically set to 53bit precision + z.SetPrec(32) + z.Add(&x, &y) + fmt.Printf("x = %.10g (%s, prec = %d, acc = %s)\n", &x, x.Text('p', 0), x.Prec(), x.Acc()) + fmt.Printf("y = %.10g (%s, prec = %d, acc = %s)\n", &y, y.Text('p', 0), y.Prec(), y.Acc()) + fmt.Printf("z = %.10g (%s, prec = %d, acc = %s)\n", &z, z.Text('p', 0), z.Prec(), z.Acc()) + // Output: + // x = 1000 (0x.fap+10, prec = 64, acc = Exact) + // y = 2.718281828 (0x.adf85458248cd8p+2, prec = 53, acc = Exact) + // z = 1002.718282 (0x.faadf854p+10, prec = 32, acc = Below) +} + +func ExampleFloat_shift() { + // Implement Float "shift" by modifying the (binary) exponents directly. + for s := -5; s <= 5; s++ { + x := big.NewFloat(0.5) + x.SetMantExp(x, x.MantExp(nil)+s) // shift x by s + fmt.Println(x) + } + // Output: + // 0.015625 + // 0.03125 + // 0.0625 + // 0.125 + // 0.25 + // 0.5 + // 1 + // 2 + // 4 + // 8 + // 16 +} + +func ExampleFloat_Cmp() { + inf := math.Inf(1) + zero := 0.0 + + operands := []float64{-inf, -1.2, -zero, 0, +1.2, +inf} + + fmt.Println(" x y cmp") + fmt.Println("---------------") + for _, x64 := range operands { + x := big.NewFloat(x64) + for _, y64 := range operands { + y := big.NewFloat(y64) + fmt.Printf("%4g %4g %3d\n", x, y, x.Cmp(y)) + } + fmt.Println() + } + + // Output: + // x y cmp + // --------------- + // -Inf -Inf 0 + // -Inf -1.2 -1 + // -Inf -0 -1 + // -Inf 0 -1 + // -Inf 1.2 -1 + // -Inf +Inf -1 + // + // -1.2 -Inf 1 + // -1.2 -1.2 0 + // -1.2 -0 -1 + // -1.2 0 -1 + // -1.2 1.2 -1 + // -1.2 +Inf -1 + // + // -0 -Inf 1 + // -0 -1.2 1 + // -0 -0 0 + // -0 0 0 + // -0 1.2 -1 + // -0 +Inf -1 + // + // 0 -Inf 1 + // 0 -1.2 1 + // 0 -0 0 + // 0 0 0 + // 0 1.2 -1 + // 0 +Inf -1 + // + // 1.2 -Inf 1 + // 1.2 -1.2 1 + // 1.2 -0 1 + // 1.2 0 1 + // 1.2 1.2 0 + // 1.2 +Inf -1 + // + // +Inf -Inf 1 + // +Inf -1.2 1 + // +Inf -0 1 + // +Inf 0 1 + // +Inf 1.2 1 + // +Inf +Inf 0 +} + +func ExampleRoundingMode() { + operands := []float64{2.6, 2.5, 2.1, -2.1, -2.5, -2.6} + + fmt.Print(" x") + for mode := big.ToNearestEven; mode <= big.ToPositiveInf; mode++ { + fmt.Printf(" %s", mode) + } + fmt.Println() + + for _, f64 := range operands { + fmt.Printf("%4g", f64) + for mode := big.ToNearestEven; mode <= big.ToPositiveInf; mode++ { + // sample operands above require 2 bits to represent mantissa + // set binary precision to 2 to round them to integer values + f := new(big.Float).SetPrec(2).SetMode(mode).SetFloat64(f64) + fmt.Printf(" %*g", len(mode.String()), f) + } + fmt.Println() + } + + // Output: + // x ToNearestEven ToNearestAway ToZero AwayFromZero ToNegativeInf ToPositiveInf + // 2.6 3 3 2 3 2 3 + // 2.5 2 3 2 3 2 3 + // 2.1 2 2 2 3 2 3 + // -2.1 -2 -2 -2 -3 -3 -2 + // -2.5 -2 -3 -2 -3 -3 -2 + // -2.6 -3 -3 -2 -3 -3 -2 +} diff --git a/src/math/big/floatmarsh.go b/src/math/big/floatmarsh.go new file mode 100644 index 0000000..990e085 --- /dev/null +++ b/src/math/big/floatmarsh.go @@ -0,0 +1,127 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements encoding/decoding of Floats. + +package big + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// Gob codec version. Permits backward-compatible changes to the encoding. +const floatGobVersion byte = 1 + +// GobEncode implements the gob.GobEncoder interface. +// The Float value and all its attributes (precision, +// rounding mode, accuracy) are marshaled. +func (x *Float) GobEncode() ([]byte, error) { + if x == nil { + return nil, nil + } + + // determine max. space (bytes) required for encoding + sz := 1 + 1 + 4 // version + mode|acc|form|neg (3+2+2+1bit) + prec + n := 0 // number of mantissa words + if x.form == finite { + // add space for mantissa and exponent + n = int((x.prec + (_W - 1)) / _W) // required mantissa length in words for given precision + // actual mantissa slice could be shorter (trailing 0's) or longer (unused bits): + // - if shorter, only encode the words present + // - if longer, cut off unused words when encoding in bytes + // (in practice, this should never happen since rounding + // takes care of it, but be safe and do it always) + if len(x.mant) < n { + n = len(x.mant) + } + // len(x.mant) >= n + sz += 4 + n*_S // exp + mant + } + buf := make([]byte, sz) + + buf[0] = floatGobVersion + b := byte(x.mode&7)<<5 | byte((x.acc+1)&3)<<3 | byte(x.form&3)<<1 + if x.neg { + b |= 1 + } + buf[1] = b + binary.BigEndian.PutUint32(buf[2:], x.prec) + + if x.form == finite { + binary.BigEndian.PutUint32(buf[6:], uint32(x.exp)) + x.mant[len(x.mant)-n:].bytes(buf[10:]) // cut off unused trailing words + } + + return buf, nil +} + +// GobDecode implements the gob.GobDecoder interface. +// The result is rounded per the precision and rounding mode of +// z unless z's precision is 0, in which case z is set exactly +// to the decoded value. +func (z *Float) GobDecode(buf []byte) error { + if len(buf) == 0 { + // Other side sent a nil or default value. + *z = Float{} + return nil + } + if len(buf) < 6 { + return errors.New("Float.GobDecode: buffer too small") + } + + if buf[0] != floatGobVersion { + return fmt.Errorf("Float.GobDecode: encoding version %d not supported", buf[0]) + } + + oldPrec := z.prec + oldMode := z.mode + + b := buf[1] + z.mode = RoundingMode((b >> 5) & 7) + z.acc = Accuracy((b>>3)&3) - 1 + z.form = form((b >> 1) & 3) + z.neg = b&1 != 0 + z.prec = binary.BigEndian.Uint32(buf[2:]) + + if z.form == finite { + if len(buf) < 10 { + return errors.New("Float.GobDecode: buffer too small for finite form float") + } + z.exp = int32(binary.BigEndian.Uint32(buf[6:])) + z.mant = z.mant.setBytes(buf[10:]) + } + + if oldPrec != 0 { + z.mode = oldMode + z.SetPrec(uint(oldPrec)) + } + + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. +// Only the Float value is marshaled (in full precision), other +// attributes such as precision or accuracy are ignored. +func (x *Float) MarshalText() (text []byte, err error) { + if x == nil { + return []byte("<nil>"), nil + } + var buf []byte + return x.Append(buf, 'g', -1), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The result is rounded per the precision and rounding mode of z. +// If z's precision is 0, it is changed to 64 before rounding takes +// effect. +func (z *Float) UnmarshalText(text []byte) error { + // TODO(gri): get rid of the []byte/string conversion + _, _, err := z.Parse(string(text), 0) + if err != nil { + err = fmt.Errorf("math/big: cannot unmarshal %q into a *big.Float (%v)", text, err) + } + return err +} diff --git a/src/math/big/floatmarsh_test.go b/src/math/big/floatmarsh_test.go new file mode 100644 index 0000000..401f45a --- /dev/null +++ b/src/math/big/floatmarsh_test.go @@ -0,0 +1,151 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "io" + "testing" +) + +var floatVals = []string{ + "0", + "1", + "0.1", + "2.71828", + "1234567890", + "3.14e1234", + "3.14e-1234", + "0.738957395793475734757349579759957975985497e100", + "0.73895739579347546656564656573475734957975995797598589749859834759476745986795497e100", + "inf", + "Inf", +} + +func TestFloatGobEncoding(t *testing.T) { + var medium bytes.Buffer + enc := gob.NewEncoder(&medium) + dec := gob.NewDecoder(&medium) + for _, test := range floatVals { + for _, sign := range []string{"", "+", "-"} { + for _, prec := range []uint{0, 1, 2, 10, 53, 64, 100, 1000} { + for _, mode := range []RoundingMode{ToNearestEven, ToNearestAway, ToZero, AwayFromZero, ToNegativeInf, ToPositiveInf} { + medium.Reset() // empty buffer for each test case (in case of failures) + x := sign + test + + var tx Float + _, _, err := tx.SetPrec(prec).SetMode(mode).Parse(x, 0) + if err != nil { + t.Errorf("parsing of %s (%dbits, %v) failed (invalid test case): %v", x, prec, mode, err) + continue + } + + // If tx was set to prec == 0, tx.Parse(x, 0) assumes precision 64. Correct it. + if prec == 0 { + tx.SetPrec(0) + } + + if err := enc.Encode(&tx); err != nil { + t.Errorf("encoding of %v (%dbits, %v) failed: %v", &tx, prec, mode, err) + continue + } + + var rx Float + if err := dec.Decode(&rx); err != nil { + t.Errorf("decoding of %v (%dbits, %v) failed: %v", &tx, prec, mode, err) + continue + } + + if rx.Cmp(&tx) != 0 { + t.Errorf("transmission of %s failed: got %s want %s", x, rx.String(), tx.String()) + continue + } + + if rx.Prec() != prec { + t.Errorf("transmission of %s's prec failed: got %d want %d", x, rx.Prec(), prec) + } + + if rx.Mode() != mode { + t.Errorf("transmission of %s's mode failed: got %s want %s", x, rx.Mode(), mode) + } + + if rx.Acc() != tx.Acc() { + t.Errorf("transmission of %s's accuracy failed: got %s want %s", x, rx.Acc(), tx.Acc()) + } + } + } + } + } +} + +func TestFloatCorruptGob(t *testing.T) { + var buf bytes.Buffer + tx := NewFloat(4 / 3).SetPrec(1000).SetMode(ToPositiveInf) + if err := gob.NewEncoder(&buf).Encode(tx); err != nil { + t.Fatal(err) + } + b := buf.Bytes() + + var rx Float + if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err != nil { + t.Fatal(err) + } + + if err := gob.NewDecoder(bytes.NewReader(b[:10])).Decode(&rx); err != io.ErrUnexpectedEOF { + t.Errorf("got %v want EOF", err) + } + + b[1] = 0 + if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err == nil { + t.Fatal("got nil want version error") + } +} + +func TestFloatJSONEncoding(t *testing.T) { + for _, test := range floatVals { + for _, sign := range []string{"", "+", "-"} { + for _, prec := range []uint{0, 1, 2, 10, 53, 64, 100, 1000} { + if prec > 53 && testing.Short() { + continue + } + x := sign + test + var tx Float + _, _, err := tx.SetPrec(prec).Parse(x, 0) + if err != nil { + t.Errorf("parsing of %s (prec = %d) failed (invalid test case): %v", x, prec, err) + continue + } + b, err := json.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %v (prec = %d) failed: %v", &tx, prec, err) + continue + } + var rx Float + rx.SetPrec(prec) + if err := json.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %v (prec = %d) failed: %v", &tx, prec, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("JSON encoding of %v (prec = %d) failed: got %v want %v", &tx, prec, &rx, &tx) + } + } + } + } +} + +func TestFloatGobDecodeShortBuffer(t *testing.T) { + for _, tc := range [][]byte{ + []byte{0x1, 0x0, 0x0, 0x0}, + []byte{0x1, 0xfa, 0x0, 0x0, 0x0, 0x0}, + } { + err := NewFloat(0).GobDecode(tc) + if err == nil { + t.Error("expected GobDecode to return error for malformed input") + } + } +} diff --git a/src/math/big/ftoa.go b/src/math/big/ftoa.go new file mode 100644 index 0000000..5506e6e --- /dev/null +++ b/src/math/big/ftoa.go @@ -0,0 +1,536 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements Float-to-string conversion functions. +// It is closely following the corresponding implementation +// in strconv/ftoa.go, but modified and simplified for Float. + +package big + +import ( + "bytes" + "fmt" + "strconv" +) + +// Text converts the floating-point number x to a string according +// to the given format and precision prec. The format is one of: +// +// 'e' -d.dddde±dd, decimal exponent, at least two (possibly 0) exponent digits +// 'E' -d.ddddE±dd, decimal exponent, at least two (possibly 0) exponent digits +// 'f' -ddddd.dddd, no exponent +// 'g' like 'e' for large exponents, like 'f' otherwise +// 'G' like 'E' for large exponents, like 'f' otherwise +// 'x' -0xd.dddddp±dd, hexadecimal mantissa, decimal power of two exponent +// 'p' -0x.dddp±dd, hexadecimal mantissa, decimal power of two exponent (non-standard) +// 'b' -ddddddp±dd, decimal mantissa, decimal power of two exponent (non-standard) +// +// For the power-of-two exponent formats, the mantissa is printed in normalized form: +// +// 'x' hexadecimal mantissa in [1, 2), or 0 +// 'p' hexadecimal mantissa in [½, 1), or 0 +// 'b' decimal integer mantissa using x.Prec() bits, or 0 +// +// Note that the 'x' form is the one used by most other languages and libraries. +// +// If format is a different character, Text returns a "%" followed by the +// unrecognized format character. +// +// The precision prec controls the number of digits (excluding the exponent) +// printed by the 'e', 'E', 'f', 'g', 'G', and 'x' formats. +// For 'e', 'E', 'f', and 'x', it is the number of digits after the decimal point. +// For 'g' and 'G' it is the total number of digits. A negative precision selects +// the smallest number of decimal digits necessary to identify the value x uniquely +// using x.Prec() mantissa bits. +// The prec value is ignored for the 'b' and 'p' formats. +func (x *Float) Text(format byte, prec int) string { + cap := 10 // TODO(gri) determine a good/better value here + if prec > 0 { + cap += prec + } + return string(x.Append(make([]byte, 0, cap), format, prec)) +} + +// String formats x like x.Text('g', 10). +// (String must be called explicitly, Float.Format does not support %s verb.) +func (x *Float) String() string { + return x.Text('g', 10) +} + +// Append appends to buf the string form of the floating-point number x, +// as generated by x.Text, and returns the extended buffer. +func (x *Float) Append(buf []byte, fmt byte, prec int) []byte { + // sign + if x.neg { + buf = append(buf, '-') + } + + // Inf + if x.form == inf { + if !x.neg { + buf = append(buf, '+') + } + return append(buf, "Inf"...) + } + + // pick off easy formats + switch fmt { + case 'b': + return x.fmtB(buf) + case 'p': + return x.fmtP(buf) + case 'x': + return x.fmtX(buf, prec) + } + + // Algorithm: + // 1) convert Float to multiprecision decimal + // 2) round to desired precision + // 3) read digits out and format + + // 1) convert Float to multiprecision decimal + var d decimal // == 0.0 + if x.form == finite { + // x != 0 + d.init(x.mant, int(x.exp)-x.mant.bitLen()) + } + + // 2) round to desired precision + shortest := false + if prec < 0 { + shortest = true + roundShortest(&d, x) + // Precision for shortest representation mode. + switch fmt { + case 'e', 'E': + prec = len(d.mant) - 1 + case 'f': + prec = max(len(d.mant)-d.exp, 0) + case 'g', 'G': + prec = len(d.mant) + } + } else { + // round appropriately + switch fmt { + case 'e', 'E': + // one digit before and number of digits after decimal point + d.round(1 + prec) + case 'f': + // number of digits before and after decimal point + d.round(d.exp + prec) + case 'g', 'G': + if prec == 0 { + prec = 1 + } + d.round(prec) + } + } + + // 3) read digits out and format + switch fmt { + case 'e', 'E': + return fmtE(buf, fmt, prec, d) + case 'f': + return fmtF(buf, prec, d) + case 'g', 'G': + // trim trailing fractional zeros in %e format + eprec := prec + if eprec > len(d.mant) && len(d.mant) >= d.exp { + eprec = len(d.mant) + } + // %e is used if the exponent from the conversion + // is less than -4 or greater than or equal to the precision. + // If precision was the shortest possible, use eprec = 6 for + // this decision. + if shortest { + eprec = 6 + } + exp := d.exp - 1 + if exp < -4 || exp >= eprec { + if prec > len(d.mant) { + prec = len(d.mant) + } + return fmtE(buf, fmt+'e'-'g', prec-1, d) + } + if prec > d.exp { + prec = len(d.mant) + } + return fmtF(buf, max(prec-d.exp, 0), d) + } + + // unknown format + if x.neg { + buf = buf[:len(buf)-1] // sign was added prematurely - remove it again + } + return append(buf, '%', fmt) +} + +func roundShortest(d *decimal, x *Float) { + // if the mantissa is zero, the number is zero - stop now + if len(d.mant) == 0 { + return + } + + // Approach: All numbers in the interval [x - 1/2ulp, x + 1/2ulp] + // (possibly exclusive) round to x for the given precision of x. + // Compute the lower and upper bound in decimal form and find the + // shortest decimal number d such that lower <= d <= upper. + + // TODO(gri) strconv/ftoa.do describes a shortcut in some cases. + // See if we can use it (in adjusted form) here as well. + + // 1) Compute normalized mantissa mant and exponent exp for x such + // that the lsb of mant corresponds to 1/2 ulp for the precision of + // x (i.e., for mant we want x.prec + 1 bits). + mant := nat(nil).set(x.mant) + exp := int(x.exp) - mant.bitLen() + s := mant.bitLen() - int(x.prec+1) + switch { + case s < 0: + mant = mant.shl(mant, uint(-s)) + case s > 0: + mant = mant.shr(mant, uint(+s)) + } + exp += s + // x = mant * 2**exp with lsb(mant) == 1/2 ulp of x.prec + + // 2) Compute lower bound by subtracting 1/2 ulp. + var lower decimal + var tmp nat + lower.init(tmp.sub(mant, natOne), exp) + + // 3) Compute upper bound by adding 1/2 ulp. + var upper decimal + upper.init(tmp.add(mant, natOne), exp) + + // The upper and lower bounds are possible outputs only if + // the original mantissa is even, so that ToNearestEven rounding + // would round to the original mantissa and not the neighbors. + inclusive := mant[0]&2 == 0 // test bit 1 since original mantissa was shifted by 1 + + // Now we can figure out the minimum number of digits required. + // Walk along until d has distinguished itself from upper and lower. + for i, m := range d.mant { + l := lower.at(i) + u := upper.at(i) + + // Okay to round down (truncate) if lower has a different digit + // or if lower is inclusive and is exactly the result of rounding + // down (i.e., and we have reached the final digit of lower). + okdown := l != m || inclusive && i+1 == len(lower.mant) + + // Okay to round up if upper has a different digit and either upper + // is inclusive or upper is bigger than the result of rounding up. + okup := m != u && (inclusive || m+1 < u || i+1 < len(upper.mant)) + + // If it's okay to do either, then round to the nearest one. + // If it's okay to do only one, do it. + switch { + case okdown && okup: + d.round(i + 1) + return + case okdown: + d.roundDown(i + 1) + return + case okup: + d.roundUp(i + 1) + return + } + } +} + +// %e: d.ddddde±dd +func fmtE(buf []byte, fmt byte, prec int, d decimal) []byte { + // first digit + ch := byte('0') + if len(d.mant) > 0 { + ch = d.mant[0] + } + buf = append(buf, ch) + + // .moredigits + if prec > 0 { + buf = append(buf, '.') + i := 1 + m := min(len(d.mant), prec+1) + if i < m { + buf = append(buf, d.mant[i:m]...) + i = m + } + for ; i <= prec; i++ { + buf = append(buf, '0') + } + } + + // e± + buf = append(buf, fmt) + var exp int64 + if len(d.mant) > 0 { + exp = int64(d.exp) - 1 // -1 because first digit was printed before '.' + } + if exp < 0 { + ch = '-' + exp = -exp + } else { + ch = '+' + } + buf = append(buf, ch) + + // dd...d + if exp < 10 { + buf = append(buf, '0') // at least 2 exponent digits + } + return strconv.AppendInt(buf, exp, 10) +} + +// %f: ddddddd.ddddd +func fmtF(buf []byte, prec int, d decimal) []byte { + // integer, padded with zeros as needed + if d.exp > 0 { + m := min(len(d.mant), d.exp) + buf = append(buf, d.mant[:m]...) + for ; m < d.exp; m++ { + buf = append(buf, '0') + } + } else { + buf = append(buf, '0') + } + + // fraction + if prec > 0 { + buf = append(buf, '.') + for i := 0; i < prec; i++ { + buf = append(buf, d.at(d.exp+i)) + } + } + + return buf +} + +// fmtB appends the string of x in the format mantissa "p" exponent +// with a decimal mantissa and a binary exponent, or 0" if x is zero, +// and returns the extended buffer. +// The mantissa is normalized such that is uses x.Prec() bits in binary +// representation. +// The sign of x is ignored, and x must not be an Inf. +// (The caller handles Inf before invoking fmtB.) +func (x *Float) fmtB(buf []byte) []byte { + if x.form == zero { + return append(buf, '0') + } + + if debugFloat && x.form != finite { + panic("non-finite float") + } + // x != 0 + + // adjust mantissa to use exactly x.prec bits + m := x.mant + switch w := uint32(len(x.mant)) * _W; { + case w < x.prec: + m = nat(nil).shl(m, uint(x.prec-w)) + case w > x.prec: + m = nat(nil).shr(m, uint(w-x.prec)) + } + + buf = append(buf, m.utoa(10)...) + buf = append(buf, 'p') + e := int64(x.exp) - int64(x.prec) + if e >= 0 { + buf = append(buf, '+') + } + return strconv.AppendInt(buf, e, 10) +} + +// fmtX appends the string of x in the format "0x1." mantissa "p" exponent +// with a hexadecimal mantissa and a binary exponent, or "0x0p0" if x is zero, +// and returns the extended buffer. +// A non-zero mantissa is normalized such that 1.0 <= mantissa < 2.0. +// The sign of x is ignored, and x must not be an Inf. +// (The caller handles Inf before invoking fmtX.) +func (x *Float) fmtX(buf []byte, prec int) []byte { + if x.form == zero { + buf = append(buf, "0x0"...) + if prec > 0 { + buf = append(buf, '.') + for i := 0; i < prec; i++ { + buf = append(buf, '0') + } + } + buf = append(buf, "p+00"...) + return buf + } + + if debugFloat && x.form != finite { + panic("non-finite float") + } + + // round mantissa to n bits + var n uint + if prec < 0 { + n = 1 + (x.MinPrec()-1+3)/4*4 // round MinPrec up to 1 mod 4 + } else { + n = 1 + 4*uint(prec) + } + // n%4 == 1 + x = new(Float).SetPrec(n).SetMode(x.mode).Set(x) + + // adjust mantissa to use exactly n bits + m := x.mant + switch w := uint(len(x.mant)) * _W; { + case w < n: + m = nat(nil).shl(m, n-w) + case w > n: + m = nat(nil).shr(m, w-n) + } + exp64 := int64(x.exp) - 1 // avoid wrap-around + + hm := m.utoa(16) + if debugFloat && hm[0] != '1' { + panic("incorrect mantissa: " + string(hm)) + } + buf = append(buf, "0x1"...) + if len(hm) > 1 { + buf = append(buf, '.') + buf = append(buf, hm[1:]...) + } + + buf = append(buf, 'p') + if exp64 >= 0 { + buf = append(buf, '+') + } else { + exp64 = -exp64 + buf = append(buf, '-') + } + // Force at least two exponent digits, to match fmt. + if exp64 < 10 { + buf = append(buf, '0') + } + return strconv.AppendInt(buf, exp64, 10) +} + +// fmtP appends the string of x in the format "0x." mantissa "p" exponent +// with a hexadecimal mantissa and a binary exponent, or "0" if x is zero, +// and returns the extended buffer. +// The mantissa is normalized such that 0.5 <= 0.mantissa < 1.0. +// The sign of x is ignored, and x must not be an Inf. +// (The caller handles Inf before invoking fmtP.) +func (x *Float) fmtP(buf []byte) []byte { + if x.form == zero { + return append(buf, '0') + } + + if debugFloat && x.form != finite { + panic("non-finite float") + } + // x != 0 + + // remove trailing 0 words early + // (no need to convert to hex 0's and trim later) + m := x.mant + i := 0 + for i < len(m) && m[i] == 0 { + i++ + } + m = m[i:] + + buf = append(buf, "0x."...) + buf = append(buf, bytes.TrimRight(m.utoa(16), "0")...) + buf = append(buf, 'p') + if x.exp >= 0 { + buf = append(buf, '+') + } + return strconv.AppendInt(buf, int64(x.exp), 10) +} + +func min(x, y int) int { + if x < y { + return x + } + return y +} + +var _ fmt.Formatter = &floatZero // *Float must implement fmt.Formatter + +// Format implements fmt.Formatter. It accepts all the regular +// formats for floating-point numbers ('b', 'e', 'E', 'f', 'F', +// 'g', 'G', 'x') as well as 'p' and 'v'. See (*Float).Text for the +// interpretation of 'p'. The 'v' format is handled like 'g'. +// Format also supports specification of the minimum precision +// in digits, the output field width, as well as the format flags +// '+' and ' ' for sign control, '0' for space or zero padding, +// and '-' for left or right justification. See the fmt package +// for details. +func (x *Float) Format(s fmt.State, format rune) { + prec, hasPrec := s.Precision() + if !hasPrec { + prec = 6 // default precision for 'e', 'f' + } + + switch format { + case 'e', 'E', 'f', 'b', 'p', 'x': + // nothing to do + case 'F': + // (*Float).Text doesn't support 'F'; handle like 'f' + format = 'f' + case 'v': + // handle like 'g' + format = 'g' + fallthrough + case 'g', 'G': + if !hasPrec { + prec = -1 // default precision for 'g', 'G' + } + default: + fmt.Fprintf(s, "%%!%c(*big.Float=%s)", format, x.String()) + return + } + var buf []byte + buf = x.Append(buf, byte(format), prec) + if len(buf) == 0 { + buf = []byte("?") // should never happen, but don't crash + } + // len(buf) > 0 + + var sign string + switch { + case buf[0] == '-': + sign = "-" + buf = buf[1:] + case buf[0] == '+': + // +Inf + sign = "+" + if s.Flag(' ') { + sign = " " + } + buf = buf[1:] + case s.Flag('+'): + sign = "+" + case s.Flag(' '): + sign = " " + } + + var padding int + if width, hasWidth := s.Width(); hasWidth && width > len(sign)+len(buf) { + padding = width - len(sign) - len(buf) + } + + switch { + case s.Flag('0') && !x.IsInf(): + // 0-padding on left + writeMultiple(s, sign, 1) + writeMultiple(s, "0", padding) + s.Write(buf) + case s.Flag('-'): + // padding on right + writeMultiple(s, sign, 1) + s.Write(buf) + writeMultiple(s, " ", padding) + default: + // padding on left + writeMultiple(s, " ", padding) + writeMultiple(s, sign, 1) + s.Write(buf) + } +} diff --git a/src/math/big/gcd_test.go b/src/math/big/gcd_test.go new file mode 100644 index 0000000..3cca2ec --- /dev/null +++ b/src/math/big/gcd_test.go @@ -0,0 +1,64 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements a GCD benchmark. +// Usage: go test math/big -test.bench GCD + +package big + +import ( + "math/rand" + "testing" +) + +// randInt returns a pseudo-random Int in the range [1<<(size-1), (1<<size) - 1] +func randInt(r *rand.Rand, size uint) *Int { + n := new(Int).Lsh(intOne, size-1) + x := new(Int).Rand(r, n) + return x.Add(x, n) // make sure result > 1<<(size-1) +} + +func runGCD(b *testing.B, aSize, bSize uint) { + if isRaceBuilder && (aSize > 1000 || bSize > 1000) { + b.Skip("skipping on race builder") + } + b.Run("WithoutXY", func(b *testing.B) { + runGCDExt(b, aSize, bSize, false) + }) + b.Run("WithXY", func(b *testing.B) { + runGCDExt(b, aSize, bSize, true) + }) +} + +func runGCDExt(b *testing.B, aSize, bSize uint, calcXY bool) { + b.StopTimer() + var r = rand.New(rand.NewSource(1234)) + aa := randInt(r, aSize) + bb := randInt(r, bSize) + var x, y *Int + if calcXY { + x = new(Int) + y = new(Int) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + new(Int).GCD(x, y, aa, bb) + } +} + +func BenchmarkGCD10x10(b *testing.B) { runGCD(b, 10, 10) } +func BenchmarkGCD10x100(b *testing.B) { runGCD(b, 10, 100) } +func BenchmarkGCD10x1000(b *testing.B) { runGCD(b, 10, 1000) } +func BenchmarkGCD10x10000(b *testing.B) { runGCD(b, 10, 10000) } +func BenchmarkGCD10x100000(b *testing.B) { runGCD(b, 10, 100000) } +func BenchmarkGCD100x100(b *testing.B) { runGCD(b, 100, 100) } +func BenchmarkGCD100x1000(b *testing.B) { runGCD(b, 100, 1000) } +func BenchmarkGCD100x10000(b *testing.B) { runGCD(b, 100, 10000) } +func BenchmarkGCD100x100000(b *testing.B) { runGCD(b, 100, 100000) } +func BenchmarkGCD1000x1000(b *testing.B) { runGCD(b, 1000, 1000) } +func BenchmarkGCD1000x10000(b *testing.B) { runGCD(b, 1000, 10000) } +func BenchmarkGCD1000x100000(b *testing.B) { runGCD(b, 1000, 100000) } +func BenchmarkGCD10000x10000(b *testing.B) { runGCD(b, 10000, 10000) } +func BenchmarkGCD10000x100000(b *testing.B) { runGCD(b, 10000, 100000) } +func BenchmarkGCD100000x100000(b *testing.B) { runGCD(b, 100000, 100000) } diff --git a/src/math/big/hilbert_test.go b/src/math/big/hilbert_test.go new file mode 100644 index 0000000..1a84341 --- /dev/null +++ b/src/math/big/hilbert_test.go @@ -0,0 +1,160 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// A little test program and benchmark for rational arithmetics. +// Computes a Hilbert matrix, its inverse, multiplies them +// and verifies that the product is the identity matrix. + +package big + +import ( + "fmt" + "testing" +) + +type matrix struct { + n, m int + a []*Rat +} + +func (a *matrix) at(i, j int) *Rat { + if !(0 <= i && i < a.n && 0 <= j && j < a.m) { + panic("index out of range") + } + return a.a[i*a.m+j] +} + +func (a *matrix) set(i, j int, x *Rat) { + if !(0 <= i && i < a.n && 0 <= j && j < a.m) { + panic("index out of range") + } + a.a[i*a.m+j] = x +} + +func newMatrix(n, m int) *matrix { + if !(0 <= n && 0 <= m) { + panic("illegal matrix") + } + a := new(matrix) + a.n = n + a.m = m + a.a = make([]*Rat, n*m) + return a +} + +func newUnit(n int) *matrix { + a := newMatrix(n, n) + for i := 0; i < n; i++ { + for j := 0; j < n; j++ { + x := NewRat(0, 1) + if i == j { + x.SetInt64(1) + } + a.set(i, j, x) + } + } + return a +} + +func newHilbert(n int) *matrix { + a := newMatrix(n, n) + for i := 0; i < n; i++ { + for j := 0; j < n; j++ { + a.set(i, j, NewRat(1, int64(i+j+1))) + } + } + return a +} + +func newInverseHilbert(n int) *matrix { + a := newMatrix(n, n) + for i := 0; i < n; i++ { + for j := 0; j < n; j++ { + x1 := new(Rat).SetInt64(int64(i + j + 1)) + x2 := new(Rat).SetInt(new(Int).Binomial(int64(n+i), int64(n-j-1))) + x3 := new(Rat).SetInt(new(Int).Binomial(int64(n+j), int64(n-i-1))) + x4 := new(Rat).SetInt(new(Int).Binomial(int64(i+j), int64(i))) + + x1.Mul(x1, x2) + x1.Mul(x1, x3) + x1.Mul(x1, x4) + x1.Mul(x1, x4) + + if (i+j)&1 != 0 { + x1.Neg(x1) + } + + a.set(i, j, x1) + } + } + return a +} + +func (a *matrix) mul(b *matrix) *matrix { + if a.m != b.n { + panic("illegal matrix multiply") + } + c := newMatrix(a.n, b.m) + for i := 0; i < c.n; i++ { + for j := 0; j < c.m; j++ { + x := NewRat(0, 1) + for k := 0; k < a.m; k++ { + x.Add(x, new(Rat).Mul(a.at(i, k), b.at(k, j))) + } + c.set(i, j, x) + } + } + return c +} + +func (a *matrix) eql(b *matrix) bool { + if a.n != b.n || a.m != b.m { + return false + } + for i := 0; i < a.n; i++ { + for j := 0; j < a.m; j++ { + if a.at(i, j).Cmp(b.at(i, j)) != 0 { + return false + } + } + } + return true +} + +func (a *matrix) String() string { + s := "" + for i := 0; i < a.n; i++ { + for j := 0; j < a.m; j++ { + s += fmt.Sprintf("\t%s", a.at(i, j)) + } + s += "\n" + } + return s +} + +func doHilbert(t *testing.T, n int) { + a := newHilbert(n) + b := newInverseHilbert(n) + I := newUnit(n) + ab := a.mul(b) + if !ab.eql(I) { + if t == nil { + panic("Hilbert failed") + } + t.Errorf("a = %s\n", a) + t.Errorf("b = %s\n", b) + t.Errorf("a*b = %s\n", ab) + t.Errorf("I = %s\n", I) + } +} + +func TestHilbert(t *testing.T) { + doHilbert(t, 10) +} + +func BenchmarkHilbert(b *testing.B) { + for i := 0; i < b.N; i++ { + doHilbert(nil, 10) + } +} diff --git a/src/math/big/int.go b/src/math/big/int.go new file mode 100644 index 0000000..76d6eb9 --- /dev/null +++ b/src/math/big/int.go @@ -0,0 +1,1293 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements signed multi-precision integers. + +package big + +import ( + "fmt" + "io" + "math/rand" + "strings" +) + +// An Int represents a signed multi-precision integer. +// The zero value for an Int represents the value 0. +// +// Operations always take pointer arguments (*Int) rather +// than Int values, and each unique Int value requires +// its own unique *Int pointer. To "copy" an Int value, +// an existing (or newly allocated) Int must be set to +// a new value using the Int.Set method; shallow copies +// of Ints are not supported and may lead to errors. +type Int struct { + neg bool // sign + abs nat // absolute value of the integer +} + +var intOne = &Int{false, natOne} + +// Sign returns: +// +// -1 if x < 0 +// 0 if x == 0 +// +1 if x > 0 +func (x *Int) Sign() int { + // This function is used in cryptographic operations. It must not leak + // anything but the Int's sign and bit size through side-channels. Any + // changes must be reviewed by a security expert. + if len(x.abs) == 0 { + return 0 + } + if x.neg { + return -1 + } + return 1 +} + +// SetInt64 sets z to x and returns z. +func (z *Int) SetInt64(x int64) *Int { + neg := false + if x < 0 { + neg = true + x = -x + } + z.abs = z.abs.setUint64(uint64(x)) + z.neg = neg + return z +} + +// SetUint64 sets z to x and returns z. +func (z *Int) SetUint64(x uint64) *Int { + z.abs = z.abs.setUint64(x) + z.neg = false + return z +} + +// NewInt allocates and returns a new Int set to x. +func NewInt(x int64) *Int { + // This code is arranged to be inlineable and produce + // zero allocations when inlined. See issue 29951. + u := uint64(x) + if x < 0 { + u = -u + } + var abs []Word + if x == 0 { + } else if _W == 32 && u>>32 != 0 { + abs = []Word{Word(u), Word(u >> 32)} + } else { + abs = []Word{Word(u)} + } + return &Int{neg: x < 0, abs: abs} +} + +// Set sets z to x and returns z. +func (z *Int) Set(x *Int) *Int { + if z != x { + z.abs = z.abs.set(x.abs) + z.neg = x.neg + } + return z +} + +// Bits provides raw (unchecked but fast) access to x by returning its +// absolute value as a little-endian Word slice. The result and x share +// the same underlying array. +// Bits is intended to support implementation of missing low-level Int +// functionality outside this package; it should be avoided otherwise. +func (x *Int) Bits() []Word { + // This function is used in cryptographic operations. It must not leak + // anything but the Int's sign and bit size through side-channels. Any + // changes must be reviewed by a security expert. + return x.abs +} + +// SetBits provides raw (unchecked but fast) access to z by setting its +// value to abs, interpreted as a little-endian Word slice, and returning +// z. The result and abs share the same underlying array. +// SetBits is intended to support implementation of missing low-level Int +// functionality outside this package; it should be avoided otherwise. +func (z *Int) SetBits(abs []Word) *Int { + z.abs = nat(abs).norm() + z.neg = false + return z +} + +// Abs sets z to |x| (the absolute value of x) and returns z. +func (z *Int) Abs(x *Int) *Int { + z.Set(x) + z.neg = false + return z +} + +// Neg sets z to -x and returns z. +func (z *Int) Neg(x *Int) *Int { + z.Set(x) + z.neg = len(z.abs) > 0 && !z.neg // 0 has no sign + return z +} + +// Add sets z to the sum x+y and returns z. +func (z *Int) Add(x, y *Int) *Int { + neg := x.neg + if x.neg == y.neg { + // x + y == x + y + // (-x) + (-y) == -(x + y) + z.abs = z.abs.add(x.abs, y.abs) + } else { + // x + (-y) == x - y == -(y - x) + // (-x) + y == y - x == -(x - y) + if x.abs.cmp(y.abs) >= 0 { + z.abs = z.abs.sub(x.abs, y.abs) + } else { + neg = !neg + z.abs = z.abs.sub(y.abs, x.abs) + } + } + z.neg = len(z.abs) > 0 && neg // 0 has no sign + return z +} + +// Sub sets z to the difference x-y and returns z. +func (z *Int) Sub(x, y *Int) *Int { + neg := x.neg + if x.neg != y.neg { + // x - (-y) == x + y + // (-x) - y == -(x + y) + z.abs = z.abs.add(x.abs, y.abs) + } else { + // x - y == x - y == -(y - x) + // (-x) - (-y) == y - x == -(x - y) + if x.abs.cmp(y.abs) >= 0 { + z.abs = z.abs.sub(x.abs, y.abs) + } else { + neg = !neg + z.abs = z.abs.sub(y.abs, x.abs) + } + } + z.neg = len(z.abs) > 0 && neg // 0 has no sign + return z +} + +// Mul sets z to the product x*y and returns z. +func (z *Int) Mul(x, y *Int) *Int { + // x * y == x * y + // x * (-y) == -(x * y) + // (-x) * y == -(x * y) + // (-x) * (-y) == x * y + if x == y { + z.abs = z.abs.sqr(x.abs) + z.neg = false + return z + } + z.abs = z.abs.mul(x.abs, y.abs) + z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign + return z +} + +// MulRange sets z to the product of all integers +// in the range [a, b] inclusively and returns z. +// If a > b (empty range), the result is 1. +func (z *Int) MulRange(a, b int64) *Int { + switch { + case a > b: + return z.SetInt64(1) // empty range + case a <= 0 && b >= 0: + return z.SetInt64(0) // range includes 0 + } + // a <= b && (b < 0 || a > 0) + + neg := false + if a < 0 { + neg = (b-a)&1 == 0 + a, b = -b, -a + } + + z.abs = z.abs.mulRange(uint64(a), uint64(b)) + z.neg = neg + return z +} + +// Binomial sets z to the binomial coefficient C(n, k) and returns z. +func (z *Int) Binomial(n, k int64) *Int { + if k > n { + return z.SetInt64(0) + } + // reduce the number of multiplications by reducing k + if k > n-k { + k = n - k // C(n, k) == C(n, n-k) + } + // C(n, k) == n * (n-1) * ... * (n-k+1) / k * (k-1) * ... * 1 + // == n * (n-1) * ... * (n-k+1) / 1 * (1+1) * ... * k + // + // Using the multiplicative formula produces smaller values + // at each step, requiring fewer allocations and computations: + // + // z = 1 + // for i := 0; i < k; i = i+1 { + // z *= n-i + // z /= i+1 + // } + // + // finally to avoid computing i+1 twice per loop: + // + // z = 1 + // i := 0 + // for i < k { + // z *= n-i + // i++ + // z /= i + // } + var N, K, i, t Int + N.SetInt64(n) + K.SetInt64(k) + z.Set(intOne) + for i.Cmp(&K) < 0 { + z.Mul(z, t.Sub(&N, &i)) + i.Add(&i, intOne) + z.Quo(z, &i) + } + return z +} + +// Quo sets z to the quotient x/y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Quo implements truncated division (like Go); see QuoRem for more details. +func (z *Int) Quo(x, y *Int) *Int { + z.abs, _ = z.abs.div(nil, x.abs, y.abs) + z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign + return z +} + +// Rem sets z to the remainder x%y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Rem implements truncated modulus (like Go); see QuoRem for more details. +func (z *Int) Rem(x, y *Int) *Int { + _, z.abs = nat(nil).div(z.abs, x.abs, y.abs) + z.neg = len(z.abs) > 0 && x.neg // 0 has no sign + return z +} + +// QuoRem sets z to the quotient x/y and r to the remainder x%y +// and returns the pair (z, r) for y != 0. +// If y == 0, a division-by-zero run-time panic occurs. +// +// QuoRem implements T-division and modulus (like Go): +// +// q = x/y with the result truncated to zero +// r = x - y*q +// +// (See Daan Leijen, “Division and Modulus for Computer Scientists”.) +// See DivMod for Euclidean division and modulus (unlike Go). +func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) { + z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs) + z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign + return z, r +} + +// Div sets z to the quotient x/y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Div implements Euclidean division (unlike Go); see DivMod for more details. +func (z *Int) Div(x, y *Int) *Int { + y_neg := y.neg // z may be an alias for y + var r Int + z.QuoRem(x, y, &r) + if r.neg { + if y_neg { + z.Add(z, intOne) + } else { + z.Sub(z, intOne) + } + } + return z +} + +// Mod sets z to the modulus x%y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Mod implements Euclidean modulus (unlike Go); see DivMod for more details. +func (z *Int) Mod(x, y *Int) *Int { + y0 := y // save y + if z == y || alias(z.abs, y.abs) { + y0 = new(Int).Set(y) + } + var q Int + q.QuoRem(x, y, z) + if z.neg { + if y0.neg { + z.Sub(z, y0) + } else { + z.Add(z, y0) + } + } + return z +} + +// DivMod sets z to the quotient x div y and m to the modulus x mod y +// and returns the pair (z, m) for y != 0. +// If y == 0, a division-by-zero run-time panic occurs. +// +// DivMod implements Euclidean division and modulus (unlike Go): +// +// q = x div y such that +// m = x - y*q with 0 <= m < |y| +// +// (See Raymond T. Boute, “The Euclidean definition of the functions +// div and mod”. ACM Transactions on Programming Languages and +// Systems (TOPLAS), 14(2):127-144, New York, NY, USA, 4/1992. +// ACM press.) +// See QuoRem for T-division and modulus (like Go). +func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) { + y0 := y // save y + if z == y || alias(z.abs, y.abs) { + y0 = new(Int).Set(y) + } + z.QuoRem(x, y, m) + if m.neg { + if y0.neg { + z.Add(z, intOne) + m.Sub(m, y0) + } else { + z.Sub(z, intOne) + m.Add(m, y0) + } + } + return z, m +} + +// Cmp compares x and y and returns: +// +// -1 if x < y +// 0 if x == y +// +1 if x > y +func (x *Int) Cmp(y *Int) (r int) { + // x cmp y == x cmp y + // x cmp (-y) == x + // (-x) cmp y == y + // (-x) cmp (-y) == -(x cmp y) + switch { + case x == y: + // nothing to do + case x.neg == y.neg: + r = x.abs.cmp(y.abs) + if x.neg { + r = -r + } + case x.neg: + r = -1 + default: + r = 1 + } + return +} + +// CmpAbs compares the absolute values of x and y and returns: +// +// -1 if |x| < |y| +// 0 if |x| == |y| +// +1 if |x| > |y| +func (x *Int) CmpAbs(y *Int) int { + return x.abs.cmp(y.abs) +} + +// low32 returns the least significant 32 bits of x. +func low32(x nat) uint32 { + if len(x) == 0 { + return 0 + } + return uint32(x[0]) +} + +// low64 returns the least significant 64 bits of x. +func low64(x nat) uint64 { + if len(x) == 0 { + return 0 + } + v := uint64(x[0]) + if _W == 32 && len(x) > 1 { + return uint64(x[1])<<32 | v + } + return v +} + +// Int64 returns the int64 representation of x. +// If x cannot be represented in an int64, the result is undefined. +func (x *Int) Int64() int64 { + v := int64(low64(x.abs)) + if x.neg { + v = -v + } + return v +} + +// Uint64 returns the uint64 representation of x. +// If x cannot be represented in a uint64, the result is undefined. +func (x *Int) Uint64() uint64 { + return low64(x.abs) +} + +// IsInt64 reports whether x can be represented as an int64. +func (x *Int) IsInt64() bool { + if len(x.abs) <= 64/_W { + w := int64(low64(x.abs)) + return w >= 0 || x.neg && w == -w + } + return false +} + +// IsUint64 reports whether x can be represented as a uint64. +func (x *Int) IsUint64() bool { + return !x.neg && len(x.abs) <= 64/_W +} + +// SetString sets z to the value of s, interpreted in the given base, +// and returns z and a boolean indicating success. The entire string +// (not just a prefix) must be valid for success. If SetString fails, +// the value of z is undefined but the returned value is nil. +// +// The base argument must be 0 or a value between 2 and MaxBase. +// For base 0, the number prefix determines the actual base: A prefix of +// “0b” or “0B” selects base 2, “0”, “0o” or “0O” selects base 8, +// and “0x” or “0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For bases <= 36, lower and upper case letters are considered the same: +// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35. +// For bases > 36, the upper case letters 'A' to 'Z' represent the digit +// values 36 to 61. +// +// For base 0, an underscore character “_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as an error if there +// are no other errors. If base != 0, underscores are not recognized +// and act like any other character that is not a valid digit. +func (z *Int) SetString(s string, base int) (*Int, bool) { + return z.setFromScanner(strings.NewReader(s), base) +} + +// setFromScanner implements SetString given an io.ByteScanner. +// For documentation see comments of SetString. +func (z *Int) setFromScanner(r io.ByteScanner, base int) (*Int, bool) { + if _, _, err := z.scan(r, base); err != nil { + return nil, false + } + // entire content must have been consumed + if _, err := r.ReadByte(); err != io.EOF { + return nil, false + } + return z, true // err == io.EOF => scan consumed all content of r +} + +// SetBytes interprets buf as the bytes of a big-endian unsigned +// integer, sets z to that value, and returns z. +func (z *Int) SetBytes(buf []byte) *Int { + z.abs = z.abs.setBytes(buf) + z.neg = false + return z +} + +// Bytes returns the absolute value of x as a big-endian byte slice. +// +// To use a fixed length slice, or a preallocated one, use FillBytes. +func (x *Int) Bytes() []byte { + // This function is used in cryptographic operations. It must not leak + // anything but the Int's sign and bit size through side-channels. Any + // changes must be reviewed by a security expert. + buf := make([]byte, len(x.abs)*_S) + return buf[x.abs.bytes(buf):] +} + +// FillBytes sets buf to the absolute value of x, storing it as a zero-extended +// big-endian byte slice, and returns buf. +// +// If the absolute value of x doesn't fit in buf, FillBytes will panic. +func (x *Int) FillBytes(buf []byte) []byte { + // Clear whole buffer. (This gets optimized into a memclr.) + for i := range buf { + buf[i] = 0 + } + x.abs.bytes(buf) + return buf +} + +// BitLen returns the length of the absolute value of x in bits. +// The bit length of 0 is 0. +func (x *Int) BitLen() int { + // This function is used in cryptographic operations. It must not leak + // anything but the Int's sign and bit size through side-channels. Any + // changes must be reviewed by a security expert. + return x.abs.bitLen() +} + +// TrailingZeroBits returns the number of consecutive least significant zero +// bits of |x|. +func (x *Int) TrailingZeroBits() uint { + return x.abs.trailingZeroBits() +} + +// Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z. +// If m == nil or m == 0, z = x**y unless y <= 0 then z = 1. If m != 0, y < 0, +// and x and m are not relatively prime, z is unchanged and nil is returned. +// +// Modular exponentiation of inputs of a particular size is not a +// cryptographically constant-time operation. +func (z *Int) Exp(x, y, m *Int) *Int { + return z.exp(x, y, m, false) +} + +func (z *Int) expSlow(x, y, m *Int) *Int { + return z.exp(x, y, m, true) +} + +func (z *Int) exp(x, y, m *Int, slow bool) *Int { + // See Knuth, volume 2, section 4.6.3. + xWords := x.abs + if y.neg { + if m == nil || len(m.abs) == 0 { + return z.SetInt64(1) + } + // for y < 0: x**y mod m == (x**(-1))**|y| mod m + inverse := new(Int).ModInverse(x, m) + if inverse == nil { + return nil + } + xWords = inverse.abs + } + yWords := y.abs + + var mWords nat + if m != nil { + if z == m || alias(z.abs, m.abs) { + m = new(Int).Set(m) + } + mWords = m.abs // m.abs may be nil for m == 0 + } + + z.abs = z.abs.expNN(xWords, yWords, mWords, slow) + z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign + if z.neg && len(mWords) > 0 { + // make modulus result positive + z.abs = z.abs.sub(mWords, z.abs) // z == x**y mod |m| && 0 <= z < |m| + z.neg = false + } + + return z +} + +// GCD sets z to the greatest common divisor of a and b and returns z. +// If x or y are not nil, GCD sets their value such that z = a*x + b*y. +// +// a and b may be positive, zero or negative. (Before Go 1.14 both had +// to be > 0.) Regardless of the signs of a and b, z is always >= 0. +// +// If a == b == 0, GCD sets z = x = y = 0. +// +// If a == 0 and b != 0, GCD sets z = |b|, x = 0, y = sign(b) * 1. +// +// If a != 0 and b == 0, GCD sets z = |a|, x = sign(a) * 1, y = 0. +func (z *Int) GCD(x, y, a, b *Int) *Int { + if len(a.abs) == 0 || len(b.abs) == 0 { + lenA, lenB, negA, negB := len(a.abs), len(b.abs), a.neg, b.neg + if lenA == 0 { + z.Set(b) + } else { + z.Set(a) + } + z.neg = false + if x != nil { + if lenA == 0 { + x.SetUint64(0) + } else { + x.SetUint64(1) + x.neg = negA + } + } + if y != nil { + if lenB == 0 { + y.SetUint64(0) + } else { + y.SetUint64(1) + y.neg = negB + } + } + return z + } + + return z.lehmerGCD(x, y, a, b) +} + +// lehmerSimulate attempts to simulate several Euclidean update steps +// using the leading digits of A and B. It returns u0, u1, v0, v1 +// such that A and B can be updated as: +// +// A = u0*A + v0*B +// B = u1*A + v1*B +// +// Requirements: A >= B and len(B.abs) >= 2 +// Since we are calculating with full words to avoid overflow, +// we use 'even' to track the sign of the cosequences. +// For even iterations: u0, v1 >= 0 && u1, v0 <= 0 +// For odd iterations: u0, v1 <= 0 && u1, v0 >= 0 +func lehmerSimulate(A, B *Int) (u0, u1, v0, v1 Word, even bool) { + // initialize the digits + var a1, a2, u2, v2 Word + + m := len(B.abs) // m >= 2 + n := len(A.abs) // n >= m >= 2 + + // extract the top Word of bits from A and B + h := nlz(A.abs[n-1]) + a1 = A.abs[n-1]<<h | A.abs[n-2]>>(_W-h) + // B may have implicit zero words in the high bits if the lengths differ + switch { + case n == m: + a2 = B.abs[n-1]<<h | B.abs[n-2]>>(_W-h) + case n == m+1: + a2 = B.abs[n-2] >> (_W - h) + default: + a2 = 0 + } + + // Since we are calculating with full words to avoid overflow, + // we use 'even' to track the sign of the cosequences. + // For even iterations: u0, v1 >= 0 && u1, v0 <= 0 + // For odd iterations: u0, v1 <= 0 && u1, v0 >= 0 + // The first iteration starts with k=1 (odd). + even = false + // variables to track the cosequences + u0, u1, u2 = 0, 1, 0 + v0, v1, v2 = 0, 0, 1 + + // Calculate the quotient and cosequences using Collins' stopping condition. + // Note that overflow of a Word is not possible when computing the remainder + // sequence and cosequences since the cosequence size is bounded by the input size. + // See section 4.2 of Jebelean for details. + for a2 >= v2 && a1-a2 >= v1+v2 { + q, r := a1/a2, a1%a2 + a1, a2 = a2, r + u0, u1, u2 = u1, u2, u1+q*u2 + v0, v1, v2 = v1, v2, v1+q*v2 + even = !even + } + return +} + +// lehmerUpdate updates the inputs A and B such that: +// +// A = u0*A + v0*B +// B = u1*A + v1*B +// +// where the signs of u0, u1, v0, v1 are given by even +// For even == true: u0, v1 >= 0 && u1, v0 <= 0 +// For even == false: u0, v1 <= 0 && u1, v0 >= 0 +// q, r, s, t are temporary variables to avoid allocations in the multiplication. +func lehmerUpdate(A, B, q, r, s, t *Int, u0, u1, v0, v1 Word, even bool) { + + t.abs = t.abs.setWord(u0) + s.abs = s.abs.setWord(v0) + t.neg = !even + s.neg = even + + t.Mul(A, t) + s.Mul(B, s) + + r.abs = r.abs.setWord(u1) + q.abs = q.abs.setWord(v1) + r.neg = even + q.neg = !even + + r.Mul(A, r) + q.Mul(B, q) + + A.Add(t, s) + B.Add(r, q) +} + +// euclidUpdate performs a single step of the Euclidean GCD algorithm +// if extended is true, it also updates the cosequence Ua, Ub. +func euclidUpdate(A, B, Ua, Ub, q, r, s, t *Int, extended bool) { + q, r = q.QuoRem(A, B, r) + + *A, *B, *r = *B, *r, *A + + if extended { + // Ua, Ub = Ub, Ua - q*Ub + t.Set(Ub) + s.Mul(Ub, q) + Ub.Sub(Ua, s) + Ua.Set(t) + } +} + +// lehmerGCD sets z to the greatest common divisor of a and b, +// which both must be != 0, and returns z. +// If x or y are not nil, their values are set such that z = a*x + b*y. +// See Knuth, The Art of Computer Programming, Vol. 2, Section 4.5.2, Algorithm L. +// This implementation uses the improved condition by Collins requiring only one +// quotient and avoiding the possibility of single Word overflow. +// See Jebelean, "Improving the multiprecision Euclidean algorithm", +// Design and Implementation of Symbolic Computation Systems, pp 45-58. +// The cosequences are updated according to Algorithm 10.45 from +// Cohen et al. "Handbook of Elliptic and Hyperelliptic Curve Cryptography" pp 192. +func (z *Int) lehmerGCD(x, y, a, b *Int) *Int { + var A, B, Ua, Ub *Int + + A = new(Int).Abs(a) + B = new(Int).Abs(b) + + extended := x != nil || y != nil + + if extended { + // Ua (Ub) tracks how many times input a has been accumulated into A (B). + Ua = new(Int).SetInt64(1) + Ub = new(Int) + } + + // temp variables for multiprecision update + q := new(Int) + r := new(Int) + s := new(Int) + t := new(Int) + + // ensure A >= B + if A.abs.cmp(B.abs) < 0 { + A, B = B, A + Ub, Ua = Ua, Ub + } + + // loop invariant A >= B + for len(B.abs) > 1 { + // Attempt to calculate in single-precision using leading words of A and B. + u0, u1, v0, v1, even := lehmerSimulate(A, B) + + // multiprecision Step + if v0 != 0 { + // Simulate the effect of the single-precision steps using the cosequences. + // A = u0*A + v0*B + // B = u1*A + v1*B + lehmerUpdate(A, B, q, r, s, t, u0, u1, v0, v1, even) + + if extended { + // Ua = u0*Ua + v0*Ub + // Ub = u1*Ua + v1*Ub + lehmerUpdate(Ua, Ub, q, r, s, t, u0, u1, v0, v1, even) + } + + } else { + // Single-digit calculations failed to simulate any quotients. + // Do a standard Euclidean step. + euclidUpdate(A, B, Ua, Ub, q, r, s, t, extended) + } + } + + if len(B.abs) > 0 { + // extended Euclidean algorithm base case if B is a single Word + if len(A.abs) > 1 { + // A is longer than a single Word, so one update is needed. + euclidUpdate(A, B, Ua, Ub, q, r, s, t, extended) + } + if len(B.abs) > 0 { + // A and B are both a single Word. + aWord, bWord := A.abs[0], B.abs[0] + if extended { + var ua, ub, va, vb Word + ua, ub = 1, 0 + va, vb = 0, 1 + even := true + for bWord != 0 { + q, r := aWord/bWord, aWord%bWord + aWord, bWord = bWord, r + ua, ub = ub, ua+q*ub + va, vb = vb, va+q*vb + even = !even + } + + t.abs = t.abs.setWord(ua) + s.abs = s.abs.setWord(va) + t.neg = !even + s.neg = even + + t.Mul(Ua, t) + s.Mul(Ub, s) + + Ua.Add(t, s) + } else { + for bWord != 0 { + aWord, bWord = bWord, aWord%bWord + } + } + A.abs[0] = aWord + } + } + negA := a.neg + if y != nil { + // avoid aliasing b needed in the division below + if y == b { + B.Set(b) + } else { + B = b + } + // y = (z - a*x)/b + y.Mul(a, Ua) // y can safely alias a + if negA { + y.neg = !y.neg + } + y.Sub(A, y) + y.Div(y, B) + } + + if x != nil { + *x = *Ua + if negA { + x.neg = !x.neg + } + } + + *z = *A + + return z +} + +// Rand sets z to a pseudo-random number in [0, n) and returns z. +// +// As this uses the math/rand package, it must not be used for +// security-sensitive work. Use crypto/rand.Int instead. +func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int { + // z.neg is not modified before the if check, because z and n might alias. + if n.neg || len(n.abs) == 0 { + z.neg = false + z.abs = nil + return z + } + z.neg = false + z.abs = z.abs.random(rnd, n.abs, n.abs.bitLen()) + return z +} + +// ModInverse sets z to the multiplicative inverse of g in the ring ℤ/nℤ +// and returns z. If g and n are not relatively prime, g has no multiplicative +// inverse in the ring ℤ/nℤ. In this case, z is unchanged and the return value +// is nil. If n == 0, a division-by-zero run-time panic occurs. +func (z *Int) ModInverse(g, n *Int) *Int { + // GCD expects parameters a and b to be > 0. + if n.neg { + var n2 Int + n = n2.Neg(n) + } + if g.neg { + var g2 Int + g = g2.Mod(g, n) + } + var d, x Int + d.GCD(&x, nil, g, n) + + // if and only if d==1, g and n are relatively prime + if d.Cmp(intOne) != 0 { + return nil + } + + // x and y are such that g*x + n*y = 1, therefore x is the inverse element, + // but it may be negative, so convert to the range 0 <= z < |n| + if x.neg { + z.Add(&x, n) + } else { + z.Set(&x) + } + return z +} + +func (z nat) modInverse(g, n nat) nat { + // TODO(rsc): ModInverse should be implemented in terms of this function. + return (&Int{abs: z}).ModInverse(&Int{abs: g}, &Int{abs: n}).abs +} + +// Jacobi returns the Jacobi symbol (x/y), either +1, -1, or 0. +// The y argument must be an odd integer. +func Jacobi(x, y *Int) int { + if len(y.abs) == 0 || y.abs[0]&1 == 0 { + panic(fmt.Sprintf("big: invalid 2nd argument to Int.Jacobi: need odd integer but got %s", y.String())) + } + + // We use the formulation described in chapter 2, section 2.4, + // "The Yacas Book of Algorithms": + // http://yacas.sourceforge.net/Algo.book.pdf + + var a, b, c Int + a.Set(x) + b.Set(y) + j := 1 + + if b.neg { + if a.neg { + j = -1 + } + b.neg = false + } + + for { + if b.Cmp(intOne) == 0 { + return j + } + if len(a.abs) == 0 { + return 0 + } + a.Mod(&a, &b) + if len(a.abs) == 0 { + return 0 + } + // a > 0 + + // handle factors of 2 in 'a' + s := a.abs.trailingZeroBits() + if s&1 != 0 { + bmod8 := b.abs[0] & 7 + if bmod8 == 3 || bmod8 == 5 { + j = -j + } + } + c.Rsh(&a, s) // a = 2^s*c + + // swap numerator and denominator + if b.abs[0]&3 == 3 && c.abs[0]&3 == 3 { + j = -j + } + a.Set(&b) + b.Set(&c) + } +} + +// modSqrt3Mod4 uses the identity +// +// (a^((p+1)/4))^2 mod p +// == u^(p+1) mod p +// == u^2 mod p +// +// to calculate the square root of any quadratic residue mod p quickly for 3 +// mod 4 primes. +func (z *Int) modSqrt3Mod4Prime(x, p *Int) *Int { + e := new(Int).Add(p, intOne) // e = p + 1 + e.Rsh(e, 2) // e = (p + 1) / 4 + z.Exp(x, e, p) // z = x^e mod p + return z +} + +// modSqrt5Mod8 uses Atkin's observation that 2 is not a square mod p +// +// alpha == (2*a)^((p-5)/8) mod p +// beta == 2*a*alpha^2 mod p is a square root of -1 +// b == a*alpha*(beta-1) mod p is a square root of a +// +// to calculate the square root of any quadratic residue mod p quickly for 5 +// mod 8 primes. +func (z *Int) modSqrt5Mod8Prime(x, p *Int) *Int { + // p == 5 mod 8 implies p = e*8 + 5 + // e is the quotient and 5 the remainder on division by 8 + e := new(Int).Rsh(p, 3) // e = (p - 5) / 8 + tx := new(Int).Lsh(x, 1) // tx = 2*x + alpha := new(Int).Exp(tx, e, p) + beta := new(Int).Mul(alpha, alpha) + beta.Mod(beta, p) + beta.Mul(beta, tx) + beta.Mod(beta, p) + beta.Sub(beta, intOne) + beta.Mul(beta, x) + beta.Mod(beta, p) + beta.Mul(beta, alpha) + z.Mod(beta, p) + return z +} + +// modSqrtTonelliShanks uses the Tonelli-Shanks algorithm to find the square +// root of a quadratic residue modulo any prime. +func (z *Int) modSqrtTonelliShanks(x, p *Int) *Int { + // Break p-1 into s*2^e such that s is odd. + var s Int + s.Sub(p, intOne) + e := s.abs.trailingZeroBits() + s.Rsh(&s, e) + + // find some non-square n + var n Int + n.SetInt64(2) + for Jacobi(&n, p) != -1 { + n.Add(&n, intOne) + } + + // Core of the Tonelli-Shanks algorithm. Follows the description in + // section 6 of "Square roots from 1; 24, 51, 10 to Dan Shanks" by Ezra + // Brown: + // https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + var y, b, g, t Int + y.Add(&s, intOne) + y.Rsh(&y, 1) + y.Exp(x, &y, p) // y = x^((s+1)/2) + b.Exp(x, &s, p) // b = x^s + g.Exp(&n, &s, p) // g = n^s + r := e + for { + // find the least m such that ord_p(b) = 2^m + var m uint + t.Set(&b) + for t.Cmp(intOne) != 0 { + t.Mul(&t, &t).Mod(&t, p) + m++ + } + + if m == 0 { + return z.Set(&y) + } + + t.SetInt64(0).SetBit(&t, int(r-m-1), 1).Exp(&g, &t, p) + // t = g^(2^(r-m-1)) mod p + g.Mul(&t, &t).Mod(&g, p) // g = g^(2^(r-m)) mod p + y.Mul(&y, &t).Mod(&y, p) + b.Mul(&b, &g).Mod(&b, p) + r = m + } +} + +// ModSqrt sets z to a square root of x mod p if such a square root exists, and +// returns z. The modulus p must be an odd prime. If x is not a square mod p, +// ModSqrt leaves z unchanged and returns nil. This function panics if p is +// not an odd integer, its behavior is undefined if p is odd but not prime. +func (z *Int) ModSqrt(x, p *Int) *Int { + switch Jacobi(x, p) { + case -1: + return nil // x is not a square mod p + case 0: + return z.SetInt64(0) // sqrt(0) mod p = 0 + case 1: + break + } + if x.neg || x.Cmp(p) >= 0 { // ensure 0 <= x < p + x = new(Int).Mod(x, p) + } + + switch { + case p.abs[0]%4 == 3: + // Check whether p is 3 mod 4, and if so, use the faster algorithm. + return z.modSqrt3Mod4Prime(x, p) + case p.abs[0]%8 == 5: + // Check whether p is 5 mod 8, use Atkin's algorithm. + return z.modSqrt5Mod8Prime(x, p) + default: + // Otherwise, use Tonelli-Shanks. + return z.modSqrtTonelliShanks(x, p) + } +} + +// Lsh sets z = x << n and returns z. +func (z *Int) Lsh(x *Int, n uint) *Int { + z.abs = z.abs.shl(x.abs, n) + z.neg = x.neg + return z +} + +// Rsh sets z = x >> n and returns z. +func (z *Int) Rsh(x *Int, n uint) *Int { + if x.neg { + // (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1) + t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0 + t = t.shr(t, n) + z.abs = t.add(t, natOne) + z.neg = true // z cannot be zero if x is negative + return z + } + + z.abs = z.abs.shr(x.abs, n) + z.neg = false + return z +} + +// Bit returns the value of the i'th bit of x. That is, it +// returns (x>>i)&1. The bit index i must be >= 0. +func (x *Int) Bit(i int) uint { + if i == 0 { + // optimization for common case: odd/even test of x + if len(x.abs) > 0 { + return uint(x.abs[0] & 1) // bit 0 is same for -x + } + return 0 + } + if i < 0 { + panic("negative bit index") + } + if x.neg { + t := nat(nil).sub(x.abs, natOne) + return t.bit(uint(i)) ^ 1 + } + + return x.abs.bit(uint(i)) +} + +// SetBit sets z to x, with x's i'th bit set to b (0 or 1). +// That is, if b is 1 SetBit sets z = x | (1 << i); +// if b is 0 SetBit sets z = x &^ (1 << i). If b is not 0 or 1, +// SetBit will panic. +func (z *Int) SetBit(x *Int, i int, b uint) *Int { + if i < 0 { + panic("negative bit index") + } + if x.neg { + t := z.abs.sub(x.abs, natOne) + t = t.setBit(t, uint(i), b^1) + z.abs = t.add(t, natOne) + z.neg = len(z.abs) > 0 + return z + } + z.abs = z.abs.setBit(x.abs, uint(i), b) + z.neg = false + return z +} + +// And sets z = x & y and returns z. +func (z *Int) And(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.or(x1, y1), natOne) + z.neg = true // z cannot be zero if x and y are negative + return z + } + + // x & y == x & y + z.abs = z.abs.and(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + if x.neg { + x, y = y, x // & is symmetric + } + + // x & (-y) == x & ^(y-1) == x &^ (y-1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.andNot(x.abs, y1) + z.neg = false + return z +} + +// AndNot sets z = x &^ y and returns z. +func (z *Int) AndNot(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.andNot(y1, x1) + z.neg = false + return z + } + + // x &^ y == x &^ y + z.abs = z.abs.andNot(x.abs, y.abs) + z.neg = false + return z + } + + if x.neg { + // (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1) + x1 := nat(nil).sub(x.abs, natOne) + z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne) + z.neg = true // z cannot be zero if x is negative and y is positive + return z + } + + // x &^ (-y) == x &^ ^(y-1) == x & (y-1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.and(x.abs, y1) + z.neg = false + return z +} + +// Or sets z = x | y and returns z. +func (z *Int) Or(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.and(x1, y1), natOne) + z.neg = true // z cannot be zero if x and y are negative + return z + } + + // x | y == x | y + z.abs = z.abs.or(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + if x.neg { + x, y = y, x // | is symmetric + } + + // x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne) + z.neg = true // z cannot be zero if one of x or y is negative + return z +} + +// Xor sets z = x ^ y and returns z. +func (z *Int) Xor(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.xor(x1, y1) + z.neg = false + return z + } + + // x ^ y == x ^ y + z.abs = z.abs.xor(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + if x.neg { + x, y = y, x // ^ is symmetric + } + + // x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne) + z.neg = true // z cannot be zero if only one of x or y is negative + return z +} + +// Not sets z = ^x and returns z. +func (z *Int) Not(x *Int) *Int { + if x.neg { + // ^(-x) == ^(^(x-1)) == x-1 + z.abs = z.abs.sub(x.abs, natOne) + z.neg = false + return z + } + + // ^x == -x-1 == -(x+1) + z.abs = z.abs.add(x.abs, natOne) + z.neg = true // z cannot be zero if x is positive + return z +} + +// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z. +// It panics if x is negative. +func (z *Int) Sqrt(x *Int) *Int { + if x.neg { + panic("square root of negative number") + } + z.neg = false + z.abs = z.abs.sqrt(x.abs) + return z +} diff --git a/src/math/big/int_test.go b/src/math/big/int_test.go new file mode 100644 index 0000000..53cd399 --- /dev/null +++ b/src/math/big/int_test.go @@ -0,0 +1,1957 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "encoding/hex" + "fmt" + "internal/testenv" + "math" + "math/rand" + "strconv" + "strings" + "testing" + "testing/quick" +) + +func isNormalized(x *Int) bool { + if len(x.abs) == 0 { + return !x.neg + } + // len(x.abs) > 0 + return x.abs[len(x.abs)-1] != 0 +} + +type funZZ func(z, x, y *Int) *Int +type argZZ struct { + z, x, y *Int +} + +var sumZZ = []argZZ{ + {NewInt(0), NewInt(0), NewInt(0)}, + {NewInt(1), NewInt(1), NewInt(0)}, + {NewInt(1111111110), NewInt(123456789), NewInt(987654321)}, + {NewInt(-1), NewInt(-1), NewInt(0)}, + {NewInt(864197532), NewInt(-123456789), NewInt(987654321)}, + {NewInt(-1111111110), NewInt(-123456789), NewInt(-987654321)}, +} + +var prodZZ = []argZZ{ + {NewInt(0), NewInt(0), NewInt(0)}, + {NewInt(0), NewInt(1), NewInt(0)}, + {NewInt(1), NewInt(1), NewInt(1)}, + {NewInt(-991 * 991), NewInt(991), NewInt(-991)}, + // TODO(gri) add larger products +} + +func TestSignZ(t *testing.T) { + var zero Int + for _, a := range sumZZ { + s := a.z.Sign() + e := a.z.Cmp(&zero) + if s != e { + t.Errorf("got %d; want %d for z = %v", s, e, a.z) + } + } +} + +func TestSetZ(t *testing.T) { + for _, a := range sumZZ { + var z Int + z.Set(a.z) + if !isNormalized(&z) { + t.Errorf("%v is not normalized", z) + } + if (&z).Cmp(a.z) != 0 { + t.Errorf("got z = %v; want %v", z, a.z) + } + } +} + +func TestAbsZ(t *testing.T) { + var zero Int + for _, a := range sumZZ { + var z Int + z.Abs(a.z) + var e Int + e.Set(a.z) + if e.Cmp(&zero) < 0 { + e.Sub(&zero, &e) + } + if z.Cmp(&e) != 0 { + t.Errorf("got z = %v; want %v", z, e) + } + } +} + +func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) { + var z Int + f(&z, a.x, a.y) + if !isNormalized(&z) { + t.Errorf("%s%v is not normalized", msg, z) + } + if (&z).Cmp(a.z) != 0 { + t.Errorf("%v %s %v\n\tgot z = %v; want %v", a.x, msg, a.y, &z, a.z) + } +} + +func TestSumZZ(t *testing.T) { + AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) } + SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) } + for _, a := range sumZZ { + arg := a + testFunZZ(t, "AddZZ", AddZZ, arg) + + arg = argZZ{a.z, a.y, a.x} + testFunZZ(t, "AddZZ symmetric", AddZZ, arg) + + arg = argZZ{a.x, a.z, a.y} + testFunZZ(t, "SubZZ", SubZZ, arg) + + arg = argZZ{a.y, a.z, a.x} + testFunZZ(t, "SubZZ symmetric", SubZZ, arg) + } +} + +func TestProdZZ(t *testing.T) { + MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) } + for _, a := range prodZZ { + arg := a + testFunZZ(t, "MulZZ", MulZZ, arg) + + arg = argZZ{a.z, a.y, a.x} + testFunZZ(t, "MulZZ symmetric", MulZZ, arg) + } +} + +// mulBytes returns x*y via grade school multiplication. Both inputs +// and the result are assumed to be in big-endian representation (to +// match the semantics of Int.Bytes and Int.SetBytes). +func mulBytes(x, y []byte) []byte { + z := make([]byte, len(x)+len(y)) + + // multiply + k0 := len(z) - 1 + for j := len(y) - 1; j >= 0; j-- { + d := int(y[j]) + if d != 0 { + k := k0 + carry := 0 + for i := len(x) - 1; i >= 0; i-- { + t := int(z[k]) + int(x[i])*d + carry + z[k], carry = byte(t), t>>8 + k-- + } + z[k] = byte(carry) + } + k0-- + } + + // normalize (remove leading 0's) + i := 0 + for i < len(z) && z[i] == 0 { + i++ + } + + return z[i:] +} + +func checkMul(a, b []byte) bool { + var x, y, z1 Int + x.SetBytes(a) + y.SetBytes(b) + z1.Mul(&x, &y) + + var z2 Int + z2.SetBytes(mulBytes(a, b)) + + return z1.Cmp(&z2) == 0 +} + +func TestMul(t *testing.T) { + if err := quick.Check(checkMul, nil); err != nil { + t.Error(err) + } +} + +var mulRangesZ = []struct { + a, b int64 + prod string +}{ + // entirely positive ranges are covered by mulRangesN + {-1, 1, "0"}, + {-2, -1, "2"}, + {-3, -2, "6"}, + {-3, -1, "-6"}, + {1, 3, "6"}, + {-10, -10, "-10"}, + {0, -1, "1"}, // empty range + {-1, -100, "1"}, // empty range + {-1, 1, "0"}, // range includes 0 + {-1e9, 0, "0"}, // range includes 0 + {-1e9, 1e9, "0"}, // range includes 0 + {-10, -1, "3628800"}, // 10! + {-20, -2, "-2432902008176640000"}, // -20! + {-99, -1, + "-933262154439441526816992388562667004907159682643816214685929" + + "638952175999932299156089414639761565182862536979208272237582" + + "511852109168640000000000000000000000", // -99! + }, +} + +func TestMulRangeZ(t *testing.T) { + var tmp Int + // test entirely positive ranges + for i, r := range mulRangesN { + prod := tmp.MulRange(int64(r.a), int64(r.b)).String() + if prod != r.prod { + t.Errorf("#%da: got %s; want %s", i, prod, r.prod) + } + } + // test other ranges + for i, r := range mulRangesZ { + prod := tmp.MulRange(r.a, r.b).String() + if prod != r.prod { + t.Errorf("#%db: got %s; want %s", i, prod, r.prod) + } + } +} + +func TestBinomial(t *testing.T) { + var z Int + for _, test := range []struct { + n, k int64 + want string + }{ + {0, 0, "1"}, + {0, 1, "0"}, + {1, 0, "1"}, + {1, 1, "1"}, + {1, 10, "0"}, + {4, 0, "1"}, + {4, 1, "4"}, + {4, 2, "6"}, + {4, 3, "4"}, + {4, 4, "1"}, + {10, 1, "10"}, + {10, 9, "10"}, + {10, 5, "252"}, + {11, 5, "462"}, + {11, 6, "462"}, + {100, 10, "17310309456440"}, + {100, 90, "17310309456440"}, + {1000, 10, "263409560461970212832400"}, + {1000, 990, "263409560461970212832400"}, + } { + if got := z.Binomial(test.n, test.k).String(); got != test.want { + t.Errorf("Binomial(%d, %d) = %s; want %s", test.n, test.k, got, test.want) + } + } +} + +func BenchmarkBinomial(b *testing.B) { + var z Int + for i := b.N - 1; i >= 0; i-- { + z.Binomial(1000, 990) + } +} + +// Examples from the Go Language Spec, section "Arithmetic operators" +var divisionSignsTests = []struct { + x, y int64 + q, r int64 // T-division + d, m int64 // Euclidean division +}{ + {5, 3, 1, 2, 1, 2}, + {-5, 3, -1, -2, -2, 1}, + {5, -3, -1, 2, -1, 2}, + {-5, -3, 1, -2, 2, 1}, + {1, 2, 0, 1, 0, 1}, + {8, 4, 2, 0, 2, 0}, +} + +func TestDivisionSigns(t *testing.T) { + for i, test := range divisionSignsTests { + x := NewInt(test.x) + y := NewInt(test.y) + q := NewInt(test.q) + r := NewInt(test.r) + d := NewInt(test.d) + m := NewInt(test.m) + + q1 := new(Int).Quo(x, y) + r1 := new(Int).Rem(x, y) + if !isNormalized(q1) { + t.Errorf("#%d Quo: %v is not normalized", i, *q1) + } + if !isNormalized(r1) { + t.Errorf("#%d Rem: %v is not normalized", i, *r1) + } + if q1.Cmp(q) != 0 || r1.Cmp(r) != 0 { + t.Errorf("#%d QuoRem: got (%s, %s), want (%s, %s)", i, q1, r1, q, r) + } + + q2, r2 := new(Int).QuoRem(x, y, new(Int)) + if !isNormalized(q2) { + t.Errorf("#%d Quo: %v is not normalized", i, *q2) + } + if !isNormalized(r2) { + t.Errorf("#%d Rem: %v is not normalized", i, *r2) + } + if q2.Cmp(q) != 0 || r2.Cmp(r) != 0 { + t.Errorf("#%d QuoRem: got (%s, %s), want (%s, %s)", i, q2, r2, q, r) + } + + d1 := new(Int).Div(x, y) + m1 := new(Int).Mod(x, y) + if !isNormalized(d1) { + t.Errorf("#%d Div: %v is not normalized", i, *d1) + } + if !isNormalized(m1) { + t.Errorf("#%d Mod: %v is not normalized", i, *m1) + } + if d1.Cmp(d) != 0 || m1.Cmp(m) != 0 { + t.Errorf("#%d DivMod: got (%s, %s), want (%s, %s)", i, d1, m1, d, m) + } + + d2, m2 := new(Int).DivMod(x, y, new(Int)) + if !isNormalized(d2) { + t.Errorf("#%d Div: %v is not normalized", i, *d2) + } + if !isNormalized(m2) { + t.Errorf("#%d Mod: %v is not normalized", i, *m2) + } + if d2.Cmp(d) != 0 || m2.Cmp(m) != 0 { + t.Errorf("#%d DivMod: got (%s, %s), want (%s, %s)", i, d2, m2, d, m) + } + } +} + +func norm(x nat) nat { + i := len(x) + for i > 0 && x[i-1] == 0 { + i-- + } + return x[:i] +} + +func TestBits(t *testing.T) { + for _, test := range []nat{ + nil, + {0}, + {1}, + {0, 1, 2, 3, 4}, + {4, 3, 2, 1, 0}, + {4, 3, 2, 1, 0, 0, 0, 0}, + } { + var z Int + z.neg = true + got := z.SetBits(test) + want := norm(test) + if got.abs.cmp(want) != 0 { + t.Errorf("SetBits(%v) = %v; want %v", test, got.abs, want) + } + + if got.neg { + t.Errorf("SetBits(%v): got negative result", test) + } + + bits := nat(z.Bits()) + if bits.cmp(want) != 0 { + t.Errorf("%v.Bits() = %v; want %v", z.abs, bits, want) + } + } +} + +func checkSetBytes(b []byte) bool { + hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes()) + hex2 := hex.EncodeToString(b) + + for len(hex1) < len(hex2) { + hex1 = "0" + hex1 + } + + for len(hex1) > len(hex2) { + hex2 = "0" + hex2 + } + + return hex1 == hex2 +} + +func TestSetBytes(t *testing.T) { + if err := quick.Check(checkSetBytes, nil); err != nil { + t.Error(err) + } +} + +func checkBytes(b []byte) bool { + // trim leading zero bytes since Bytes() won't return them + // (was issue 12231) + for len(b) > 0 && b[0] == 0 { + b = b[1:] + } + b2 := new(Int).SetBytes(b).Bytes() + return bytes.Equal(b, b2) +} + +func TestBytes(t *testing.T) { + if err := quick.Check(checkBytes, nil); err != nil { + t.Error(err) + } +} + +func checkQuo(x, y []byte) bool { + u := new(Int).SetBytes(x) + v := new(Int).SetBytes(y) + + if len(v.abs) == 0 { + return true + } + + r := new(Int) + q, r := new(Int).QuoRem(u, v, r) + + if r.Cmp(v) >= 0 { + return false + } + + uprime := new(Int).Set(q) + uprime.Mul(uprime, v) + uprime.Add(uprime, r) + + return uprime.Cmp(u) == 0 +} + +var quoTests = []struct { + x, y string + q, r string +}{ + { + "476217953993950760840509444250624797097991362735329973741718102894495832294430498335824897858659711275234906400899559094370964723884706254265559534144986498357", + "9353930466774385905609975137998169297361893554149986716853295022578535724979483772383667534691121982974895531435241089241440253066816724367338287092081996", + "50911", + "1", + }, + { + "11510768301994997771168", + "1328165573307167369775", + "8", + "885443715537658812968", + }, +} + +func TestQuo(t *testing.T) { + if err := quick.Check(checkQuo, nil); err != nil { + t.Error(err) + } + + for i, test := range quoTests { + x, _ := new(Int).SetString(test.x, 10) + y, _ := new(Int).SetString(test.y, 10) + expectedQ, _ := new(Int).SetString(test.q, 10) + expectedR, _ := new(Int).SetString(test.r, 10) + + r := new(Int) + q, r := new(Int).QuoRem(x, y, r) + + if q.Cmp(expectedQ) != 0 || r.Cmp(expectedR) != 0 { + t.Errorf("#%d got (%s, %s) want (%s, %s)", i, q, r, expectedQ, expectedR) + } + } +} + +func TestQuoStepD6(t *testing.T) { + // See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises + // a code path which only triggers 1 in 10^{-19} cases. + + u := &Int{false, nat{0, 0, 1 + 1<<(_W-1), _M ^ (1 << (_W - 1))}} + v := &Int{false, nat{5, 2 + 1<<(_W-1), 1 << (_W - 1)}} + + r := new(Int) + q, r := new(Int).QuoRem(u, v, r) + const expectedQ64 = "18446744073709551613" + const expectedR64 = "3138550867693340382088035895064302439801311770021610913807" + const expectedQ32 = "4294967293" + const expectedR32 = "39614081266355540837921718287" + if q.String() != expectedQ64 && q.String() != expectedQ32 || + r.String() != expectedR64 && r.String() != expectedR32 { + t.Errorf("got (%s, %s) want (%s, %s) or (%s, %s)", q, r, expectedQ64, expectedR64, expectedQ32, expectedR32) + } +} + +func BenchmarkQuoRem(b *testing.B) { + x, _ := new(Int).SetString("153980389784927331788354528594524332344709972855165340650588877572729725338415474372475094155672066328274535240275856844648695200875763869073572078279316458648124537905600131008790701752441155668003033945258023841165089852359980273279085783159654751552359397986180318708491098942831252291841441726305535546071", 0) + y, _ := new(Int).SetString("7746362281539803897849273317883545285945243323447099728551653406505888775727297253384154743724750941556720663282745352402758568446486952008757638690735720782793164586481245379056001310087907017524411556680030339452580238411650898523599802732790857831596547515523593979861803187084910989428312522918414417263055355460715745539358014631136245887418412633787074173796862711588221766398229333338511838891484974940633857861775630560092874987828057333663969469797013996401149696897591265769095952887917296740109742927689053276850469671231961384715398038978492733178835452859452433234470997285516534065058887757272972533841547437247509415567206632827453524027585684464869520087576386907357207827931645864812453790560013100879070175244115566800303394525802384116508985235998027327908578315965475155235939798618031870849109894283125229184144172630553554607112725169432413343763989564437170644270643461665184965150423819594083121075825", 0) + q := new(Int) + r := new(Int) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + q.QuoRem(y, x, r) + } +} + +var bitLenTests = []struct { + in string + out int +}{ + {"-1", 1}, + {"0", 0}, + {"1", 1}, + {"2", 2}, + {"4", 3}, + {"0xabc", 12}, + {"0x8000", 16}, + {"0x80000000", 32}, + {"0x800000000000", 48}, + {"0x8000000000000000", 64}, + {"0x80000000000000000000", 80}, + {"-0x4000000000000000000000", 87}, +} + +func TestBitLen(t *testing.T) { + for i, test := range bitLenTests { + x, ok := new(Int).SetString(test.in, 0) + if !ok { + t.Errorf("#%d test input invalid: %s", i, test.in) + continue + } + + if n := x.BitLen(); n != test.out { + t.Errorf("#%d got %d want %d", i, n, test.out) + } + } +} + +var expTests = []struct { + x, y, m string + out string +}{ + // y <= 0 + {"0", "0", "", "1"}, + {"1", "0", "", "1"}, + {"-10", "0", "", "1"}, + {"1234", "-1", "", "1"}, + {"1234", "-1", "0", "1"}, + {"17", "-100", "1234", "865"}, + {"2", "-100", "1234", ""}, + + // m == 1 + {"0", "0", "1", "0"}, + {"1", "0", "1", "0"}, + {"-10", "0", "1", "0"}, + {"1234", "-1", "1", "0"}, + + // misc + {"5", "1", "3", "2"}, + {"5", "-7", "", "1"}, + {"-5", "-7", "", "1"}, + {"5", "0", "", "1"}, + {"-5", "0", "", "1"}, + {"5", "1", "", "5"}, + {"-5", "1", "", "-5"}, + {"-5", "1", "7", "2"}, + {"-2", "3", "2", "0"}, + {"5", "2", "", "25"}, + {"1", "65537", "2", "1"}, + {"0x8000000000000000", "2", "", "0x40000000000000000000000000000000"}, + {"0x8000000000000000", "2", "6719", "4944"}, + {"0x8000000000000000", "3", "6719", "5447"}, + {"0x8000000000000000", "1000", "6719", "1603"}, + {"0x8000000000000000", "1000000", "6719", "3199"}, + {"0x8000000000000000", "-1000000", "6719", "3663"}, // 3663 = ModInverse(3199, 6719) Issue #25865 + + {"0xffffffffffffffffffffffffffffffff", "0x12345678123456781234567812345678123456789", "0x01112222333344445555666677778889", "0x36168FA1DB3AAE6C8CE647E137F97A"}, + + { + "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347", + "298472983472983471903246121093472394872319615612417471234712061", + "29834729834729834729347290846729561262544958723956495615629569234729836259263598127342374289365912465901365498236492183464", + "23537740700184054162508175125554701713153216681790245129157191391322321508055833908509185839069455749219131480588829346291", + }, + // test case for issue 8822 + { + "11001289118363089646017359372117963499250546375269047542777928006103246876688756735760905680604646624353196869572752623285140408755420374049317646428185270079555372763503115646054602867593662923894140940837479507194934267532831694565516466765025434902348314525627418515646588160955862839022051353653052947073136084780742729727874803457643848197499548297570026926927502505634297079527299004267769780768565695459945235586892627059178884998772989397505061206395455591503771677500931269477503508150175717121828518985901959919560700853226255420793148986854391552859459511723547532575574664944815966793196961286234040892865", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", + "21484252197776302499639938883777710321993113097987201050501182909581359357618579566746556372589385361683610524730509041328855066514963385522570894839035884713051640171474186548713546686476761306436434146475140156284389181808675016576845833340494848283681088886584219750554408060556769486628029028720727393293111678826356480455433909233520504112074401376133077150471237549474149190242010469539006449596611576612573955754349042329130631128234637924786466585703488460540228477440853493392086251021228087076124706778899179648655221663765993962724699135217212118535057766739392069738618682722216712319320435674779146070442", + }, + { + "-0x1BCE04427D8032319A89E5C4136456671AC620883F2C4139E57F91307C485AD2D6204F4F87A58262652DB5DBBAC72B0613E51B835E7153BEC6068F5C8D696B74DBD18FEC316AEF73985CF0475663208EB46B4F17DD9DA55367B03323E5491A70997B90C059FB34809E6EE55BCFBD5F2F52233BFE62E6AA9E4E26A1D4C2439883D14F2633D55D8AA66A1ACD5595E778AC3A280517F1157989E70C1A437B849F1877B779CC3CDDEDE2DAA6594A6C66D181A00A5F777EE60596D8773998F6E988DEAE4CCA60E4DDCF9590543C89F74F603259FCAD71660D30294FBBE6490300F78A9D63FA660DC9417B8B9DDA28BEB3977B621B988E23D4D954F322C3540541BC649ABD504C50FADFD9F0987D58A2BF689313A285E773FF02899A6EF887D1D4A0D2", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", + "21484252197776302499639938883777710321993113097987201050501182909581359357618579566746556372589385361683610524730509041328855066514963385522570894839035884713051640171474186548713546686476761306436434146475140156284389181808675016576845833340494848283681088886584219750554408060556769486628029028720727393293111678826356480455433909233520504112074401376133077150471237549474149190242010469539006449596611576612573955754349042329130631128234637924786466585703488460540228477440853493392086251021228087076124706778899179648655221663765993962724699135217212118535057766739392069738618682722216712319320435674779146070442", + }, + + // test cases for issue 13907 + {"0xffffffff00000001", "0xffffffff00000001", "0xffffffff00000001", "0"}, + {"0xffffffffffffffff00000001", "0xffffffffffffffff00000001", "0xffffffffffffffff00000001", "0"}, + {"0xffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffff00000001", "0"}, + {"0xffffffffffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffffffffffff00000001", "0"}, + + { + "2", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", // odd + "0x6AADD3E3E424D5B713FCAA8D8945B1E055166132038C57BBD2D51C833F0C5EA2007A2324CE514F8E8C2F008A2F36F44005A4039CB55830986F734C93DAF0EB4BAB54A6A8C7081864F44346E9BC6F0A3EB9F2C0146A00C6A05187D0C101E1F2D038CDB70CB5E9E05A2D188AB6CBB46286624D4415E7D4DBFAD3BCC6009D915C406EED38F468B940F41E6BEDC0430DD78E6F19A7DA3A27498A4181E24D738B0072D8F6ADB8C9809A5B033A09785814FD9919F6EF9F83EEA519BEC593855C4C10CBEEC582D4AE0792158823B0275E6AEC35242740468FAF3D5C60FD1E376362B6322F78B7ED0CA1C5BBCD2B49734A56C0967A1D01A100932C837B91D592CE08ABFF", + }, + { + "2", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF72", // even + "0x7858794B5897C29F4ED0B40913416AB6C48588484E6A45F2ED3E26C941D878E923575AAC434EE2750E6439A6976F9BB4D64CEDB2A53CE8D04DD48CADCDF8E46F22747C6B81C6CEA86C0D873FBF7CEF262BAAC43A522BD7F32F3CDAC52B9337C77B3DCFB3DB3EDD80476331E82F4B1DF8EFDC1220C92656DFC9197BDC1877804E28D928A2A284B8DED506CBA304435C9D0133C246C98A7D890D1DE60CBC53A024361DA83A9B8775019083D22AC6820ED7C3C68F8E801DD4EC779EE0A05C6EB682EF9840D285B838369BA7E148FA27691D524FAEAF7C6ECE2A4B99A294B9F2C241857B5B90CC8BFFCFCF18DFA7D676131D5CD3855A5A3E8EBFA0CDFADB4D198B4A", + }, +} + +func TestExp(t *testing.T) { + for i, test := range expTests { + x, ok1 := new(Int).SetString(test.x, 0) + y, ok2 := new(Int).SetString(test.y, 0) + + var ok3, ok4 bool + var out, m *Int + + if len(test.out) == 0 { + out, ok3 = nil, true + } else { + out, ok3 = new(Int).SetString(test.out, 0) + } + + if len(test.m) == 0 { + m, ok4 = nil, true + } else { + m, ok4 = new(Int).SetString(test.m, 0) + } + + if !ok1 || !ok2 || !ok3 || !ok4 { + t.Errorf("#%d: error in input", i) + continue + } + + z1 := new(Int).Exp(x, y, m) + if z1 != nil && !isNormalized(z1) { + t.Errorf("#%d: %v is not normalized", i, *z1) + } + if !(z1 == nil && out == nil || z1.Cmp(out) == 0) { + t.Errorf("#%d: got %x want %x", i, z1, out) + } + + if m == nil { + // The result should be the same as for m == 0; + // specifically, there should be no div-zero panic. + m = &Int{abs: nat{}} // m != nil && len(m.abs) == 0 + z2 := new(Int).Exp(x, y, m) + if z2.Cmp(z1) != 0 { + t.Errorf("#%d: got %x want %x", i, z2, z1) + } + } + } +} + +func BenchmarkExp(b *testing.B) { + x, _ := new(Int).SetString("11001289118363089646017359372117963499250546375269047542777928006103246876688756735760905680604646624353196869572752623285140408755420374049317646428185270079555372763503115646054602867593662923894140940837479507194934267532831694565516466765025434902348314525627418515646588160955862839022051353653052947073136084780742729727874803457643848197499548297570026926927502505634297079527299004267769780768565695459945235586892627059178884998772989397505061206395455591503771677500931269477503508150175717121828518985901959919560700853226255420793148986854391552859459511723547532575574664944815966793196961286234040892865", 0) + y, _ := new(Int).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF72", 0) + n, _ := new(Int).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", 0) + out := new(Int) + for i := 0; i < b.N; i++ { + out.Exp(x, y, n) + } +} + +func BenchmarkExpMont(b *testing.B) { + x, _ := new(Int).SetString("297778224889315382157302278696111964193", 0) + y, _ := new(Int).SetString("2548977943381019743024248146923164919440527843026415174732254534318292492375775985739511369575861449426580651447974311336267954477239437734832604782764979371984246675241012538135715981292390886872929238062252506842498360562303324154310849745753254532852868768268023732398278338025070694508489163836616810661033068070127919590264734220833816416141878688318329193389865030063416339367925710474801991305827284114894677717927892032165200876093838921477120036402410731159852999623461591709308405270748511350289172153076023215", 0) + var mods = []struct { + name string + val string + }{ + {"Odd", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF"}, + {"Even1", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FE"}, + {"Even2", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FC"}, + {"Even3", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281F8"}, + {"Even4", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281F0"}, + {"Even8", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B21828100"}, + {"Even32", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B00000000"}, + {"Even64", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828282828200FF0000000000000000"}, + {"Even96", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF82828283000000000000000000000000"}, + {"Even128", "0x82828282828200FFFF28FF2B218281FF82828282828200FFFF28FF2B218281FF00000000000000000000000000000000"}, + {"Even255", "0x82828282828200FFFF28FF2B218281FF8000000000000000000000000000000000000000000000000000000000000000"}, + {"SmallEven1", "0x7E"}, + {"SmallEven2", "0x7C"}, + {"SmallEven3", "0x78"}, + {"SmallEven4", "0x70"}, + } + for _, mod := range mods { + n, _ := new(Int).SetString(mod.val, 0) + out := new(Int) + b.Run(mod.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Exp(x, y, n) + } + }) + } +} + +func BenchmarkExp2(b *testing.B) { + x, _ := new(Int).SetString("2", 0) + y, _ := new(Int).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF72", 0) + n, _ := new(Int).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", 0) + out := new(Int) + for i := 0; i < b.N; i++ { + out.Exp(x, y, n) + } +} + +func checkGcd(aBytes, bBytes []byte) bool { + x := new(Int) + y := new(Int) + a := new(Int).SetBytes(aBytes) + b := new(Int).SetBytes(bBytes) + + d := new(Int).GCD(x, y, a, b) + x.Mul(x, a) + y.Mul(y, b) + x.Add(x, y) + + return x.Cmp(d) == 0 +} + +// euclidExtGCD is a reference implementation of Euclid's +// extended GCD algorithm for testing against optimized algorithms. +// Requirements: a, b > 0 +func euclidExtGCD(a, b *Int) (g, x, y *Int) { + A := new(Int).Set(a) + B := new(Int).Set(b) + + // A = Ua*a + Va*b + // B = Ub*a + Vb*b + Ua := new(Int).SetInt64(1) + Va := new(Int) + + Ub := new(Int) + Vb := new(Int).SetInt64(1) + + q := new(Int) + temp := new(Int) + + r := new(Int) + for len(B.abs) > 0 { + q, r = q.QuoRem(A, B, r) + + A, B, r = B, r, A + + // Ua, Ub = Ub, Ua-q*Ub + temp.Set(Ub) + Ub.Mul(Ub, q) + Ub.Sub(Ua, Ub) + Ua.Set(temp) + + // Va, Vb = Vb, Va-q*Vb + temp.Set(Vb) + Vb.Mul(Vb, q) + Vb.Sub(Va, Vb) + Va.Set(temp) + } + return A, Ua, Va +} + +func checkLehmerGcd(aBytes, bBytes []byte) bool { + a := new(Int).SetBytes(aBytes) + b := new(Int).SetBytes(bBytes) + + if a.Sign() <= 0 || b.Sign() <= 0 { + return true // can only test positive arguments + } + + d := new(Int).lehmerGCD(nil, nil, a, b) + d0, _, _ := euclidExtGCD(a, b) + + return d.Cmp(d0) == 0 +} + +func checkLehmerExtGcd(aBytes, bBytes []byte) bool { + a := new(Int).SetBytes(aBytes) + b := new(Int).SetBytes(bBytes) + x := new(Int) + y := new(Int) + + if a.Sign() <= 0 || b.Sign() <= 0 { + return true // can only test positive arguments + } + + d := new(Int).lehmerGCD(x, y, a, b) + d0, x0, y0 := euclidExtGCD(a, b) + + return d.Cmp(d0) == 0 && x.Cmp(x0) == 0 && y.Cmp(y0) == 0 +} + +var gcdTests = []struct { + d, x, y, a, b string +}{ + // a <= 0 || b <= 0 + {"0", "0", "0", "0", "0"}, + {"7", "0", "1", "0", "7"}, + {"7", "0", "-1", "0", "-7"}, + {"11", "1", "0", "11", "0"}, + {"7", "-1", "-2", "-77", "35"}, + {"935", "-3", "8", "64515", "24310"}, + {"935", "-3", "-8", "64515", "-24310"}, + {"935", "3", "-8", "-64515", "-24310"}, + + {"1", "-9", "47", "120", "23"}, + {"7", "1", "-2", "77", "35"}, + {"935", "-3", "8", "64515", "24310"}, + {"935000000000000000", "-3", "8", "64515000000000000000", "24310000000000000000"}, + {"1", "-221", "22059940471369027483332068679400581064239780177629666810348940098015901108344", "98920366548084643601728869055592650835572950932266967461790948584315647051443", "991"}, +} + +func testGcd(t *testing.T, d, x, y, a, b *Int) { + var X *Int + if x != nil { + X = new(Int) + } + var Y *Int + if y != nil { + Y = new(Int) + } + + D := new(Int).GCD(X, Y, a, b) + if D.Cmp(d) != 0 { + t.Errorf("GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, D, d) + } + if x != nil && X.Cmp(x) != 0 { + t.Errorf("GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, X, x) + } + if y != nil && Y.Cmp(y) != 0 { + t.Errorf("GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, Y, y) + } + + // check results in presence of aliasing (issue #11284) + a2 := new(Int).Set(a) + b2 := new(Int).Set(b) + a2.GCD(X, Y, a2, b2) // result is same as 1st argument + if a2.Cmp(d) != 0 { + t.Errorf("aliased z = a GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, a2, d) + } + if x != nil && X.Cmp(x) != 0 { + t.Errorf("aliased z = a GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, X, x) + } + if y != nil && Y.Cmp(y) != 0 { + t.Errorf("aliased z = a GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, Y, y) + } + + a2 = new(Int).Set(a) + b2 = new(Int).Set(b) + b2.GCD(X, Y, a2, b2) // result is same as 2nd argument + if b2.Cmp(d) != 0 { + t.Errorf("aliased z = b GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, b2, d) + } + if x != nil && X.Cmp(x) != 0 { + t.Errorf("aliased z = b GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, X, x) + } + if y != nil && Y.Cmp(y) != 0 { + t.Errorf("aliased z = b GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, Y, y) + } + + a2 = new(Int).Set(a) + b2 = new(Int).Set(b) + D = new(Int).GCD(a2, b2, a2, b2) // x = a, y = b + if D.Cmp(d) != 0 { + t.Errorf("aliased x = a, y = b GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, D, d) + } + if x != nil && a2.Cmp(x) != 0 { + t.Errorf("aliased x = a, y = b GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, a2, x) + } + if y != nil && b2.Cmp(y) != 0 { + t.Errorf("aliased x = a, y = b GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, b2, y) + } + + a2 = new(Int).Set(a) + b2 = new(Int).Set(b) + D = new(Int).GCD(b2, a2, a2, b2) // x = b, y = a + if D.Cmp(d) != 0 { + t.Errorf("aliased x = b, y = a GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, D, d) + } + if x != nil && b2.Cmp(x) != 0 { + t.Errorf("aliased x = b, y = a GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, b2, x) + } + if y != nil && a2.Cmp(y) != 0 { + t.Errorf("aliased x = b, y = a GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, a2, y) + } +} + +func TestGcd(t *testing.T) { + for _, test := range gcdTests { + d, _ := new(Int).SetString(test.d, 0) + x, _ := new(Int).SetString(test.x, 0) + y, _ := new(Int).SetString(test.y, 0) + a, _ := new(Int).SetString(test.a, 0) + b, _ := new(Int).SetString(test.b, 0) + + testGcd(t, d, nil, nil, a, b) + testGcd(t, d, x, nil, a, b) + testGcd(t, d, nil, y, a, b) + testGcd(t, d, x, y, a, b) + } + + if err := quick.Check(checkGcd, nil); err != nil { + t.Error(err) + } + + if err := quick.Check(checkLehmerGcd, nil); err != nil { + t.Error(err) + } + + if err := quick.Check(checkLehmerExtGcd, nil); err != nil { + t.Error(err) + } +} + +type intShiftTest struct { + in string + shift uint + out string +} + +var rshTests = []intShiftTest{ + {"0", 0, "0"}, + {"-0", 0, "0"}, + {"0", 1, "0"}, + {"0", 2, "0"}, + {"1", 0, "1"}, + {"1", 1, "0"}, + {"1", 2, "0"}, + {"2", 0, "2"}, + {"2", 1, "1"}, + {"-1", 0, "-1"}, + {"-1", 1, "-1"}, + {"-1", 10, "-1"}, + {"-100", 2, "-25"}, + {"-100", 3, "-13"}, + {"-100", 100, "-1"}, + {"4294967296", 0, "4294967296"}, + {"4294967296", 1, "2147483648"}, + {"4294967296", 2, "1073741824"}, + {"18446744073709551616", 0, "18446744073709551616"}, + {"18446744073709551616", 1, "9223372036854775808"}, + {"18446744073709551616", 2, "4611686018427387904"}, + {"18446744073709551616", 64, "1"}, + {"340282366920938463463374607431768211456", 64, "18446744073709551616"}, + {"340282366920938463463374607431768211456", 128, "1"}, +} + +func TestRsh(t *testing.T) { + for i, test := range rshTests { + in, _ := new(Int).SetString(test.in, 10) + expected, _ := new(Int).SetString(test.out, 10) + out := new(Int).Rsh(in, test.shift) + + if !isNormalized(out) { + t.Errorf("#%d: %v is not normalized", i, *out) + } + if out.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, out, expected) + } + } +} + +func TestRshSelf(t *testing.T) { + for i, test := range rshTests { + z, _ := new(Int).SetString(test.in, 10) + expected, _ := new(Int).SetString(test.out, 10) + z.Rsh(z, test.shift) + + if !isNormalized(z) { + t.Errorf("#%d: %v is not normalized", i, *z) + } + if z.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, z, expected) + } + } +} + +var lshTests = []intShiftTest{ + {"0", 0, "0"}, + {"0", 1, "0"}, + {"0", 2, "0"}, + {"1", 0, "1"}, + {"1", 1, "2"}, + {"1", 2, "4"}, + {"2", 0, "2"}, + {"2", 1, "4"}, + {"2", 2, "8"}, + {"-87", 1, "-174"}, + {"4294967296", 0, "4294967296"}, + {"4294967296", 1, "8589934592"}, + {"4294967296", 2, "17179869184"}, + {"18446744073709551616", 0, "18446744073709551616"}, + {"9223372036854775808", 1, "18446744073709551616"}, + {"4611686018427387904", 2, "18446744073709551616"}, + {"1", 64, "18446744073709551616"}, + {"18446744073709551616", 64, "340282366920938463463374607431768211456"}, + {"1", 128, "340282366920938463463374607431768211456"}, +} + +func TestLsh(t *testing.T) { + for i, test := range lshTests { + in, _ := new(Int).SetString(test.in, 10) + expected, _ := new(Int).SetString(test.out, 10) + out := new(Int).Lsh(in, test.shift) + + if !isNormalized(out) { + t.Errorf("#%d: %v is not normalized", i, *out) + } + if out.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, out, expected) + } + } +} + +func TestLshSelf(t *testing.T) { + for i, test := range lshTests { + z, _ := new(Int).SetString(test.in, 10) + expected, _ := new(Int).SetString(test.out, 10) + z.Lsh(z, test.shift) + + if !isNormalized(z) { + t.Errorf("#%d: %v is not normalized", i, *z) + } + if z.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, z, expected) + } + } +} + +func TestLshRsh(t *testing.T) { + for i, test := range rshTests { + in, _ := new(Int).SetString(test.in, 10) + out := new(Int).Lsh(in, test.shift) + out = out.Rsh(out, test.shift) + + if !isNormalized(out) { + t.Errorf("#%d: %v is not normalized", i, *out) + } + if in.Cmp(out) != 0 { + t.Errorf("#%d: got %s want %s", i, out, in) + } + } + for i, test := range lshTests { + in, _ := new(Int).SetString(test.in, 10) + out := new(Int).Lsh(in, test.shift) + out.Rsh(out, test.shift) + + if !isNormalized(out) { + t.Errorf("#%d: %v is not normalized", i, *out) + } + if in.Cmp(out) != 0 { + t.Errorf("#%d: got %s want %s", i, out, in) + } + } +} + +// Entries must be sorted by value in ascending order. +var cmpAbsTests = []string{ + "0", + "1", + "2", + "10", + "10000000", + "2783678367462374683678456387645876387564783686583485", + "2783678367462374683678456387645876387564783686583486", + "32957394867987420967976567076075976570670947609750670956097509670576075067076027578341538", +} + +func TestCmpAbs(t *testing.T) { + values := make([]*Int, len(cmpAbsTests)) + var prev *Int + for i, s := range cmpAbsTests { + x, ok := new(Int).SetString(s, 0) + if !ok { + t.Fatalf("SetString(%s, 0) failed", s) + } + if prev != nil && prev.Cmp(x) >= 0 { + t.Fatal("cmpAbsTests entries not sorted in ascending order") + } + values[i] = x + prev = x + } + + for i, x := range values { + for j, y := range values { + // try all combinations of signs for x, y + for k := 0; k < 4; k++ { + var a, b Int + a.Set(x) + b.Set(y) + if k&1 != 0 { + a.Neg(&a) + } + if k&2 != 0 { + b.Neg(&b) + } + + got := a.CmpAbs(&b) + want := 0 + switch { + case i > j: + want = 1 + case i < j: + want = -1 + } + if got != want { + t.Errorf("absCmp |%s|, |%s|: got %d; want %d", &a, &b, got, want) + } + } + } + } +} + +func TestIntCmpSelf(t *testing.T) { + for _, s := range cmpAbsTests { + x, ok := new(Int).SetString(s, 0) + if !ok { + t.Fatalf("SetString(%s, 0) failed", s) + } + got := x.Cmp(x) + want := 0 + if got != want { + t.Errorf("x = %s: x.Cmp(x): got %d; want %d", x, got, want) + } + } +} + +var int64Tests = []string{ + // int64 + "0", + "1", + "-1", + "4294967295", + "-4294967295", + "4294967296", + "-4294967296", + "9223372036854775807", + "-9223372036854775807", + "-9223372036854775808", + + // not int64 + "0x8000000000000000", + "-0x8000000000000001", + "38579843757496759476987459679745", + "-38579843757496759476987459679745", +} + +func TestInt64(t *testing.T) { + for _, s := range int64Tests { + var x Int + _, ok := x.SetString(s, 0) + if !ok { + t.Errorf("SetString(%s, 0) failed", s) + continue + } + + want, err := strconv.ParseInt(s, 0, 64) + if err != nil { + if err.(*strconv.NumError).Err == strconv.ErrRange { + if x.IsInt64() { + t.Errorf("IsInt64(%s) succeeded unexpectedly", s) + } + } else { + t.Errorf("ParseInt(%s) failed", s) + } + continue + } + + if !x.IsInt64() { + t.Errorf("IsInt64(%s) failed unexpectedly", s) + } + + got := x.Int64() + if got != want { + t.Errorf("Int64(%s) = %d; want %d", s, got, want) + } + } +} + +var uint64Tests = []string{ + // uint64 + "0", + "1", + "4294967295", + "4294967296", + "8589934591", + "8589934592", + "9223372036854775807", + "9223372036854775808", + "0x08000000000000000", + + // not uint64 + "0x10000000000000000", + "-0x08000000000000000", + "-1", +} + +func TestUint64(t *testing.T) { + for _, s := range uint64Tests { + var x Int + _, ok := x.SetString(s, 0) + if !ok { + t.Errorf("SetString(%s, 0) failed", s) + continue + } + + want, err := strconv.ParseUint(s, 0, 64) + if err != nil { + // check for sign explicitly (ErrRange doesn't cover signed input) + if s[0] == '-' || err.(*strconv.NumError).Err == strconv.ErrRange { + if x.IsUint64() { + t.Errorf("IsUint64(%s) succeeded unexpectedly", s) + } + } else { + t.Errorf("ParseUint(%s) failed", s) + } + continue + } + + if !x.IsUint64() { + t.Errorf("IsUint64(%s) failed unexpectedly", s) + } + + got := x.Uint64() + if got != want { + t.Errorf("Uint64(%s) = %d; want %d", s, got, want) + } + } +} + +var bitwiseTests = []struct { + x, y string + and, or, xor, andNot string +}{ + {"0x00", "0x00", "0x00", "0x00", "0x00", "0x00"}, + {"0x00", "0x01", "0x00", "0x01", "0x01", "0x00"}, + {"0x01", "0x00", "0x00", "0x01", "0x01", "0x01"}, + {"-0x01", "0x00", "0x00", "-0x01", "-0x01", "-0x01"}, + {"-0xaf", "-0x50", "-0xf0", "-0x0f", "0xe1", "0x41"}, + {"0x00", "-0x01", "0x00", "-0x01", "-0x01", "0x00"}, + {"0x01", "0x01", "0x01", "0x01", "0x00", "0x00"}, + {"-0x01", "-0x01", "-0x01", "-0x01", "0x00", "0x00"}, + {"0x07", "0x08", "0x00", "0x0f", "0x0f", "0x07"}, + {"0x05", "0x0f", "0x05", "0x0f", "0x0a", "0x00"}, + {"0xff", "-0x0a", "0xf6", "-0x01", "-0xf7", "0x09"}, + {"0x013ff6", "0x9a4e", "0x1a46", "0x01bffe", "0x01a5b8", "0x0125b0"}, + {"-0x013ff6", "0x9a4e", "0x800a", "-0x0125b2", "-0x01a5bc", "-0x01c000"}, + {"-0x013ff6", "-0x9a4e", "-0x01bffe", "-0x1a46", "0x01a5b8", "0x8008"}, + { + "0x1000009dc6e3d9822cba04129bcbe3401", + "0xb9bd7d543685789d57cb918e833af352559021483cdb05cc21fd", + "0x1000001186210100001000009048c2001", + "0xb9bd7d543685789d57cb918e8bfeff7fddb2ebe87dfbbdfe35fd", + "0xb9bd7d543685789d57ca918e8ae69d6fcdb2eae87df2b97215fc", + "0x8c40c2d8822caa04120b8321400", + }, + { + "0x1000009dc6e3d9822cba04129bcbe3401", + "-0xb9bd7d543685789d57cb918e833af352559021483cdb05cc21fd", + "0x8c40c2d8822caa04120b8321401", + "-0xb9bd7d543685789d57ca918e82229142459020483cd2014001fd", + "-0xb9bd7d543685789d57ca918e8ae69d6fcdb2eae87df2b97215fe", + "0x1000001186210100001000009048c2000", + }, + { + "-0x1000009dc6e3d9822cba04129bcbe3401", + "-0xb9bd7d543685789d57cb918e833af352559021483cdb05cc21fd", + "-0xb9bd7d543685789d57cb918e8bfeff7fddb2ebe87dfbbdfe35fd", + "-0x1000001186210100001000009048c2001", + "0xb9bd7d543685789d57ca918e8ae69d6fcdb2eae87df2b97215fc", + "0xb9bd7d543685789d57ca918e82229142459020483cd2014001fc", + }, +} + +type bitFun func(z, x, y *Int) *Int + +func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) { + expected := new(Int) + expected.SetString(exp, 0) + + out := f(new(Int), x, y) + if out.Cmp(expected) != 0 { + t.Errorf("%s: got %s want %s", msg, out, expected) + } +} + +func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) { + self := new(Int) + self.Set(x) + expected := new(Int) + expected.SetString(exp, 0) + + self = f(self, self, y) + if self.Cmp(expected) != 0 { + t.Errorf("%s: got %s want %s", msg, self, expected) + } +} + +func altBit(x *Int, i int) uint { + z := new(Int).Rsh(x, uint(i)) + z = z.And(z, NewInt(1)) + if z.Cmp(new(Int)) != 0 { + return 1 + } + return 0 +} + +func altSetBit(z *Int, x *Int, i int, b uint) *Int { + one := NewInt(1) + m := one.Lsh(one, uint(i)) + switch b { + case 1: + return z.Or(x, m) + case 0: + return z.AndNot(x, m) + } + panic("set bit is not 0 or 1") +} + +func testBitset(t *testing.T, x *Int) { + n := x.BitLen() + z := new(Int).Set(x) + z1 := new(Int).Set(x) + for i := 0; i < n+10; i++ { + old := z.Bit(i) + old1 := altBit(z1, i) + if old != old1 { + t.Errorf("bitset: inconsistent value for Bit(%s, %d), got %v want %v", z1, i, old, old1) + } + z := new(Int).SetBit(z, i, 1) + z1 := altSetBit(new(Int), z1, i, 1) + if z.Bit(i) == 0 { + t.Errorf("bitset: bit %d of %s got 0 want 1", i, x) + } + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit 1, got %s want %s", z, z1) + } + z.SetBit(z, i, 0) + altSetBit(z1, z1, i, 0) + if z.Bit(i) != 0 { + t.Errorf("bitset: bit %d of %s got 1 want 0", i, x) + } + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit 0, got %s want %s", z, z1) + } + altSetBit(z1, z1, i, old) + z.SetBit(z, i, old) + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit old, got %s want %s", z, z1) + } + } + if z.Cmp(x) != 0 { + t.Errorf("bitset: got %s want %s", z, x) + } +} + +var bitsetTests = []struct { + x string + i int + b uint +}{ + {"0", 0, 0}, + {"0", 200, 0}, + {"1", 0, 1}, + {"1", 1, 0}, + {"-1", 0, 1}, + {"-1", 200, 1}, + {"0x2000000000000000000000000000", 108, 0}, + {"0x2000000000000000000000000000", 109, 1}, + {"0x2000000000000000000000000000", 110, 0}, + {"-0x2000000000000000000000000001", 108, 1}, + {"-0x2000000000000000000000000001", 109, 0}, + {"-0x2000000000000000000000000001", 110, 1}, +} + +func TestBitSet(t *testing.T) { + for _, test := range bitwiseTests { + x := new(Int) + x.SetString(test.x, 0) + testBitset(t, x) + x = new(Int) + x.SetString(test.y, 0) + testBitset(t, x) + } + for i, test := range bitsetTests { + x := new(Int) + x.SetString(test.x, 0) + b := x.Bit(test.i) + if b != test.b { + t.Errorf("#%d got %v want %v", i, b, test.b) + } + } + z := NewInt(1) + z.SetBit(NewInt(0), 2, 1) + if z.Cmp(NewInt(4)) != 0 { + t.Errorf("destination leaked into result; got %s want 4", z) + } +} + +var tzbTests = []struct { + in string + out uint +}{ + {"0", 0}, + {"1", 0}, + {"-1", 0}, + {"4", 2}, + {"-8", 3}, + {"0x4000000000000000000", 74}, + {"-0x8000000000000000000", 75}, +} + +func TestTrailingZeroBits(t *testing.T) { + for i, test := range tzbTests { + in, _ := new(Int).SetString(test.in, 0) + want := test.out + got := in.TrailingZeroBits() + + if got != want { + t.Errorf("#%d: got %v want %v", i, got, want) + } + } +} + +func BenchmarkBitset(b *testing.B) { + z := new(Int) + z.SetBit(z, 512, 1) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + z.SetBit(z, i&512, 1) + } +} + +func BenchmarkBitsetNeg(b *testing.B) { + z := NewInt(-1) + z.SetBit(z, 512, 0) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + z.SetBit(z, i&512, 0) + } +} + +func BenchmarkBitsetOrig(b *testing.B) { + z := new(Int) + altSetBit(z, z, 512, 1) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + altSetBit(z, z, i&512, 1) + } +} + +func BenchmarkBitsetNegOrig(b *testing.B) { + z := NewInt(-1) + altSetBit(z, z, 512, 0) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + altSetBit(z, z, i&512, 0) + } +} + +// tri generates the trinomial 2**(n*2) - 2**n - 1, which is always 3 mod 4 and +// 7 mod 8, so that 2 is always a quadratic residue. +func tri(n uint) *Int { + x := NewInt(1) + x.Lsh(x, n) + x2 := new(Int).Lsh(x, n) + x2.Sub(x2, x) + x2.Sub(x2, intOne) + return x2 +} + +func BenchmarkModSqrt225_Tonelli(b *testing.B) { + p := tri(225) + x := NewInt(2) + for i := 0; i < b.N; i++ { + x.SetUint64(2) + x.modSqrtTonelliShanks(x, p) + } +} + +func BenchmarkModSqrt225_3Mod4(b *testing.B) { + p := tri(225) + x := new(Int).SetUint64(2) + for i := 0; i < b.N; i++ { + x.SetUint64(2) + x.modSqrt3Mod4Prime(x, p) + } +} + +func BenchmarkModSqrt231_Tonelli(b *testing.B) { + p := tri(231) + p.Sub(p, intOne) + p.Sub(p, intOne) // tri(231) - 2 is a prime == 5 mod 8 + x := new(Int).SetUint64(7) + for i := 0; i < b.N; i++ { + x.SetUint64(7) + x.modSqrtTonelliShanks(x, p) + } +} + +func BenchmarkModSqrt231_5Mod8(b *testing.B) { + p := tri(231) + p.Sub(p, intOne) + p.Sub(p, intOne) // tri(231) - 2 is a prime == 5 mod 8 + x := new(Int).SetUint64(7) + for i := 0; i < b.N; i++ { + x.SetUint64(7) + x.modSqrt5Mod8Prime(x, p) + } +} + +func TestBitwise(t *testing.T) { + x := new(Int) + y := new(Int) + for _, test := range bitwiseTests { + x.SetString(test.x, 0) + y.SetString(test.y, 0) + + testBitFun(t, "and", (*Int).And, x, y, test.and) + testBitFunSelf(t, "and", (*Int).And, x, y, test.and) + testBitFun(t, "andNot", (*Int).AndNot, x, y, test.andNot) + testBitFunSelf(t, "andNot", (*Int).AndNot, x, y, test.andNot) + testBitFun(t, "or", (*Int).Or, x, y, test.or) + testBitFunSelf(t, "or", (*Int).Or, x, y, test.or) + testBitFun(t, "xor", (*Int).Xor, x, y, test.xor) + testBitFunSelf(t, "xor", (*Int).Xor, x, y, test.xor) + } +} + +var notTests = []struct { + in string + out string +}{ + {"0", "-1"}, + {"1", "-2"}, + {"7", "-8"}, + {"0", "-1"}, + {"-81910", "81909"}, + { + "298472983472983471903246121093472394872319615612417471234712061", + "-298472983472983471903246121093472394872319615612417471234712062", + }, +} + +func TestNot(t *testing.T) { + in := new(Int) + out := new(Int) + expected := new(Int) + for i, test := range notTests { + in.SetString(test.in, 10) + expected.SetString(test.out, 10) + out = out.Not(in) + if out.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, out, expected) + } + out = out.Not(out) + if out.Cmp(in) != 0 { + t.Errorf("#%d: got %s want %s", i, out, in) + } + } +} + +var modInverseTests = []struct { + element string + modulus string +}{ + {"1234567", "458948883992"}, + {"239487239847", "2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919"}, + {"-10", "13"}, // issue #16984 + {"10", "-13"}, + {"-17", "-13"}, +} + +func TestModInverse(t *testing.T) { + var element, modulus, gcd, inverse Int + one := NewInt(1) + for _, test := range modInverseTests { + (&element).SetString(test.element, 10) + (&modulus).SetString(test.modulus, 10) + (&inverse).ModInverse(&element, &modulus) + (&inverse).Mul(&inverse, &element) + (&inverse).Mod(&inverse, &modulus) + if (&inverse).Cmp(one) != 0 { + t.Errorf("ModInverse(%d,%d)*%d%%%d=%d, not 1", &element, &modulus, &element, &modulus, &inverse) + } + } + // exhaustive test for small values + for n := 2; n < 100; n++ { + (&modulus).SetInt64(int64(n)) + for x := 1; x < n; x++ { + (&element).SetInt64(int64(x)) + (&gcd).GCD(nil, nil, &element, &modulus) + if (&gcd).Cmp(one) != 0 { + continue + } + (&inverse).ModInverse(&element, &modulus) + (&inverse).Mul(&inverse, &element) + (&inverse).Mod(&inverse, &modulus) + if (&inverse).Cmp(one) != 0 { + t.Errorf("ModInverse(%d,%d)*%d%%%d=%d, not 1", &element, &modulus, &element, &modulus, &inverse) + } + } + } +} + +func BenchmarkModInverse(b *testing.B) { + p := new(Int).SetInt64(1) // Mersenne prime 2**1279 -1 + p.abs = p.abs.shl(p.abs, 1279) + p.Sub(p, intOne) + x := new(Int).Sub(p, intOne) + z := new(Int) + for i := 0; i < b.N; i++ { + z.ModInverse(x, p) + } +} + +// testModSqrt is a helper for TestModSqrt, +// which checks that ModSqrt can compute a square-root of elt^2. +func testModSqrt(t *testing.T, elt, mod, sq, sqrt *Int) bool { + var sqChk, sqrtChk, sqrtsq Int + sq.Mul(elt, elt) + sq.Mod(sq, mod) + z := sqrt.ModSqrt(sq, mod) + if z != sqrt { + t.Errorf("ModSqrt returned wrong value %s", z) + } + + // test ModSqrt arguments outside the range [0,mod) + sqChk.Add(sq, mod) + z = sqrtChk.ModSqrt(&sqChk, mod) + if z != &sqrtChk || z.Cmp(sqrt) != 0 { + t.Errorf("ModSqrt returned inconsistent value %s", z) + } + sqChk.Sub(sq, mod) + z = sqrtChk.ModSqrt(&sqChk, mod) + if z != &sqrtChk || z.Cmp(sqrt) != 0 { + t.Errorf("ModSqrt returned inconsistent value %s", z) + } + + // test x aliasing z + z = sqrtChk.ModSqrt(sqrtChk.Set(sq), mod) + if z != &sqrtChk || z.Cmp(sqrt) != 0 { + t.Errorf("ModSqrt returned inconsistent value %s", z) + } + + // make sure we actually got a square root + if sqrt.Cmp(elt) == 0 { + return true // we found the "desired" square root + } + sqrtsq.Mul(sqrt, sqrt) // make sure we found the "other" one + sqrtsq.Mod(&sqrtsq, mod) + return sq.Cmp(&sqrtsq) == 0 +} + +func TestModSqrt(t *testing.T) { + var elt, mod, modx4, sq, sqrt Int + r := rand.New(rand.NewSource(9)) + for i, s := range primes[1:] { // skip 2, use only odd primes + mod.SetString(s, 10) + modx4.Lsh(&mod, 2) + + // test a few random elements per prime + for x := 1; x < 5; x++ { + elt.Rand(r, &modx4) + elt.Sub(&elt, &mod) // test range [-mod, 3*mod) + if !testModSqrt(t, &elt, &mod, &sq, &sqrt) { + t.Errorf("#%d: failed (sqrt(e) = %s)", i, &sqrt) + } + } + + if testing.Short() && i > 2 { + break + } + } + + if testing.Short() { + return + } + + // exhaustive test for small values + for n := 3; n < 100; n++ { + mod.SetInt64(int64(n)) + if !mod.ProbablyPrime(10) { + continue + } + isSquare := make([]bool, n) + + // test all the squares + for x := 1; x < n; x++ { + elt.SetInt64(int64(x)) + if !testModSqrt(t, &elt, &mod, &sq, &sqrt) { + t.Errorf("#%d: failed (sqrt(%d,%d) = %s)", x, &elt, &mod, &sqrt) + } + isSquare[sq.Uint64()] = true + } + + // test all non-squares + for x := 1; x < n; x++ { + sq.SetInt64(int64(x)) + z := sqrt.ModSqrt(&sq, &mod) + if !isSquare[x] && z != nil { + t.Errorf("#%d: failed (sqrt(%d,%d) = nil)", x, &sqrt, &mod) + } + } + } +} + +func TestJacobi(t *testing.T) { + testCases := []struct { + x, y int64 + result int + }{ + {0, 1, 1}, + {0, -1, 1}, + {1, 1, 1}, + {1, -1, 1}, + {0, 5, 0}, + {1, 5, 1}, + {2, 5, -1}, + {-2, 5, -1}, + {2, -5, -1}, + {-2, -5, 1}, + {3, 5, -1}, + {5, 5, 0}, + {-5, 5, 0}, + {6, 5, 1}, + {6, -5, 1}, + {-6, 5, 1}, + {-6, -5, -1}, + } + + var x, y Int + + for i, test := range testCases { + x.SetInt64(test.x) + y.SetInt64(test.y) + expected := test.result + actual := Jacobi(&x, &y) + if actual != expected { + t.Errorf("#%d: Jacobi(%d, %d) = %d, but expected %d", i, test.x, test.y, actual, expected) + } + } +} + +func TestJacobiPanic(t *testing.T) { + const failureMsg = "test failure" + defer func() { + msg := recover() + if msg == nil || msg == failureMsg { + panic(msg) + } + t.Log(msg) + }() + x := NewInt(1) + y := NewInt(2) + // Jacobi should panic when the second argument is even. + Jacobi(x, y) + panic(failureMsg) +} + +func TestIssue2607(t *testing.T) { + // This code sequence used to hang. + n := NewInt(10) + n.Rand(rand.New(rand.NewSource(9)), n) +} + +func TestSqrt(t *testing.T) { + root := 0 + r := new(Int) + for i := 0; i < 10000; i++ { + if (root+1)*(root+1) <= i { + root++ + } + n := NewInt(int64(i)) + r.SetInt64(-2) + r.Sqrt(n) + if r.Cmp(NewInt(int64(root))) != 0 { + t.Errorf("Sqrt(%v) = %v, want %v", n, r, root) + } + } + + for i := 0; i < 1000; i += 10 { + n, _ := new(Int).SetString("1"+strings.Repeat("0", i), 10) + r := new(Int).Sqrt(n) + root, _ := new(Int).SetString("1"+strings.Repeat("0", i/2), 10) + if r.Cmp(root) != 0 { + t.Errorf("Sqrt(1e%d) = %v, want 1e%d", i, r, i/2) + } + } + + // Test aliasing. + r.SetInt64(100) + r.Sqrt(r) + if r.Int64() != 10 { + t.Errorf("Sqrt(100) = %v, want 10 (aliased output)", r.Int64()) + } +} + +// We can't test this together with the other Exp tests above because +// it requires a different receiver setup. +func TestIssue22830(t *testing.T) { + one := new(Int).SetInt64(1) + base, _ := new(Int).SetString("84555555300000000000", 10) + mod, _ := new(Int).SetString("66666670001111111111", 10) + want, _ := new(Int).SetString("17888885298888888889", 10) + + var tests = []int64{ + 0, 1, -1, + } + + for _, n := range tests { + m := NewInt(n) + if got := m.Exp(base, one, mod); got.Cmp(want) != 0 { + t.Errorf("(%v).Exp(%s, 1, %s) = %s, want %s", n, base, mod, got, want) + } + } +} + +func BenchmarkSqrt(b *testing.B) { + n, _ := new(Int).SetString("1"+strings.Repeat("0", 1001), 10) + b.ResetTimer() + t := new(Int) + for i := 0; i < b.N; i++ { + t.Sqrt(n) + } +} + +func benchmarkIntSqr(b *testing.B, nwords int) { + x := new(Int) + x.abs = rndNat(nwords) + t := new(Int) + b.ResetTimer() + for i := 0; i < b.N; i++ { + t.Mul(x, x) + } +} + +func BenchmarkIntSqr(b *testing.B) { + for _, n := range sqrBenchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + benchmarkIntSqr(b, n) + }) + } +} + +func benchmarkDiv(b *testing.B, aSize, bSize int) { + var r = rand.New(rand.NewSource(1234)) + aa := randInt(r, uint(aSize)) + bb := randInt(r, uint(bSize)) + if aa.Cmp(bb) < 0 { + aa, bb = bb, aa + } + x := new(Int) + y := new(Int) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.DivMod(aa, bb, y) + } +} + +func BenchmarkDiv(b *testing.B) { + sizes := []int{ + 10, 20, 50, 100, 200, 500, 1000, + 1e4, 1e5, 1e6, 1e7, + } + for _, i := range sizes { + j := 2 * i + b.Run(fmt.Sprintf("%d/%d", j, i), func(b *testing.B) { + benchmarkDiv(b, j, i) + }) + } +} + +func TestFillBytes(t *testing.T) { + checkResult := func(t *testing.T, buf []byte, want *Int) { + t.Helper() + got := new(Int).SetBytes(buf) + if got.CmpAbs(want) != 0 { + t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf) + } + } + panics := func(f func()) (panic bool) { + defer func() { panic = recover() != nil }() + f() + return + } + + for _, n := range []string{ + "0", + "1000", + "0xffffffff", + "-0xffffffff", + "0xffffffffffffffff", + "0x10000000000000000", + "0xabababababababababababababababababababababababababa", + "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + } { + t.Run(n, func(t *testing.T) { + t.Logf(n) + x, ok := new(Int).SetString(n, 0) + if !ok { + panic("invalid test entry") + } + + // Perfectly sized buffer. + byteLen := (x.BitLen() + 7) / 8 + buf := make([]byte, byteLen) + checkResult(t, x.FillBytes(buf), x) + + // Way larger, checking all bytes get zeroed. + buf = make([]byte, 100) + for i := range buf { + buf[i] = 0xff + } + checkResult(t, x.FillBytes(buf), x) + + // Too small. + if byteLen > 0 { + buf = make([]byte, byteLen-1) + if !panics(func() { x.FillBytes(buf) }) { + t.Errorf("expected panic for small buffer and value %x", x) + } + } + }) + } +} + +func TestNewIntMinInt64(t *testing.T) { + // Test for uint64 cast in NewInt. + want := int64(math.MinInt64) + if got := NewInt(want).Int64(); got != want { + t.Fatalf("wanted %d, got %d", want, got) + } +} + +func TestNewIntAllocs(t *testing.T) { + testenv.SkipIfOptimizationOff(t) + for _, n := range []int64{0, 7, -7, 1 << 30, -1 << 30, 1 << 50, -1 << 50} { + x := NewInt(3) + got := testing.AllocsPerRun(100, func() { + // NewInt should inline, and all its allocations + // can happen on the stack. Passing the result of NewInt + // to Add should not cause any of those allocations to escape. + x.Add(x, NewInt(n)) + }) + if got != 0 { + t.Errorf("x.Add(x, NewInt(%d)), wanted 0 allocations, got %f", n, got) + } + } +} diff --git a/src/math/big/intconv.go b/src/math/big/intconv.go new file mode 100644 index 0000000..04e8c24 --- /dev/null +++ b/src/math/big/intconv.go @@ -0,0 +1,255 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements int-to-string conversion functions. + +package big + +import ( + "errors" + "fmt" + "io" +) + +// Text returns the string representation of x in the given base. +// Base must be between 2 and 62, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35, and +// the upper-case letters 'A' to 'Z' for digit values 36 to 61. +// No prefix (such as "0x") is added to the string. If x is a nil +// pointer it returns "<nil>". +func (x *Int) Text(base int) string { + if x == nil { + return "<nil>" + } + return string(x.abs.itoa(x.neg, base)) +} + +// Append appends the string representation of x, as generated by +// x.Text(base), to buf and returns the extended buffer. +func (x *Int) Append(buf []byte, base int) []byte { + if x == nil { + return append(buf, "<nil>"...) + } + return append(buf, x.abs.itoa(x.neg, base)...) +} + +// String returns the decimal representation of x as generated by +// x.Text(10). +func (x *Int) String() string { + return x.Text(10) +} + +// write count copies of text to s. +func writeMultiple(s fmt.State, text string, count int) { + if len(text) > 0 { + b := []byte(text) + for ; count > 0; count-- { + s.Write(b) + } + } +} + +var _ fmt.Formatter = intOne // *Int must implement fmt.Formatter + +// Format implements fmt.Formatter. It accepts the formats +// 'b' (binary), 'o' (octal with 0 prefix), 'O' (octal with 0o prefix), +// 'd' (decimal), 'x' (lowercase hexadecimal), and +// 'X' (uppercase hexadecimal). +// Also supported are the full suite of package fmt's format +// flags for integral types, including '+' and ' ' for sign +// control, '#' for leading zero in octal and for hexadecimal, +// a leading "0x" or "0X" for "%#x" and "%#X" respectively, +// specification of minimum digits precision, output field +// width, space or zero padding, and '-' for left or right +// justification. +func (x *Int) Format(s fmt.State, ch rune) { + // determine base + var base int + switch ch { + case 'b': + base = 2 + case 'o', 'O': + base = 8 + case 'd', 's', 'v': + base = 10 + case 'x', 'X': + base = 16 + default: + // unknown format + fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String()) + return + } + + if x == nil { + fmt.Fprint(s, "<nil>") + return + } + + // determine sign character + sign := "" + switch { + case x.neg: + sign = "-" + case s.Flag('+'): // supersedes ' ' when both specified + sign = "+" + case s.Flag(' '): + sign = " " + } + + // determine prefix characters for indicating output base + prefix := "" + if s.Flag('#') { + switch ch { + case 'b': // binary + prefix = "0b" + case 'o': // octal + prefix = "0" + case 'x': // hexadecimal + prefix = "0x" + case 'X': + prefix = "0X" + } + } + if ch == 'O' { + prefix = "0o" + } + + digits := x.abs.utoa(base) + if ch == 'X' { + // faster than bytes.ToUpper + for i, d := range digits { + if 'a' <= d && d <= 'z' { + digits[i] = 'A' + (d - 'a') + } + } + } + + // number of characters for the three classes of number padding + var left int // space characters to left of digits for right justification ("%8d") + var zeros int // zero characters (actually cs[0]) as left-most digits ("%.8d") + var right int // space characters to right of digits for left justification ("%-8d") + + // determine number padding from precision: the least number of digits to output + precision, precisionSet := s.Precision() + if precisionSet { + switch { + case len(digits) < precision: + zeros = precision - len(digits) // count of zero padding + case len(digits) == 1 && digits[0] == '0' && precision == 0: + return // print nothing if zero value (x == 0) and zero precision ("." or ".0") + } + } + + // determine field pad from width: the least number of characters to output + length := len(sign) + len(prefix) + zeros + len(digits) + if width, widthSet := s.Width(); widthSet && length < width { // pad as specified + switch d := width - length; { + case s.Flag('-'): + // pad on the right with spaces; supersedes '0' when both specified + right = d + case s.Flag('0') && !precisionSet: + // pad with zeros unless precision also specified + zeros = d + default: + // pad on the left with spaces + left = d + } + } + + // print number as [left pad][sign][prefix][zero pad][digits][right pad] + writeMultiple(s, " ", left) + writeMultiple(s, sign, 1) + writeMultiple(s, prefix, 1) + writeMultiple(s, "0", zeros) + s.Write(digits) + writeMultiple(s, " ", right) +} + +// scan sets z to the integer value corresponding to the longest possible prefix +// read from r representing a signed integer number in a given conversion base. +// It returns z, the actual conversion base used, and an error, if any. In the +// error case, the value of z is undefined but the returned value is nil. The +// syntax follows the syntax of integer literals in Go. +// +// The base argument must be 0 or a value from 2 through MaxBase. If the base +// is 0, the string prefix determines the actual conversion base. A prefix of +// “0b” or “0B” selects base 2; a “0”, “0o”, or “0O” prefix selects +// base 8, and a “0x” or “0X” prefix selects base 16. Otherwise the selected +// base is 10. +func (z *Int) scan(r io.ByteScanner, base int) (*Int, int, error) { + // determine sign + neg, err := scanSign(r) + if err != nil { + return nil, 0, err + } + + // determine mantissa + z.abs, base, _, err = z.abs.scan(r, base, false) + if err != nil { + return nil, base, err + } + z.neg = len(z.abs) > 0 && neg // 0 has no sign + + return z, base, nil +} + +func scanSign(r io.ByteScanner) (neg bool, err error) { + var ch byte + if ch, err = r.ReadByte(); err != nil { + return false, err + } + switch ch { + case '-': + neg = true + case '+': + // nothing to do + default: + r.UnreadByte() + } + return +} + +// byteReader is a local wrapper around fmt.ScanState; +// it implements the ByteReader interface. +type byteReader struct { + fmt.ScanState +} + +func (r byteReader) ReadByte() (byte, error) { + ch, size, err := r.ReadRune() + if size != 1 && err == nil { + err = fmt.Errorf("invalid rune %#U", ch) + } + return byte(ch), err +} + +func (r byteReader) UnreadByte() error { + return r.UnreadRune() +} + +var _ fmt.Scanner = intOne // *Int must implement fmt.Scanner + +// Scan is a support routine for fmt.Scanner; it sets z to the value of +// the scanned number. It accepts the formats 'b' (binary), 'o' (octal), +// 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal). +func (z *Int) Scan(s fmt.ScanState, ch rune) error { + s.SkipSpace() // skip leading space characters + base := 0 + switch ch { + case 'b': + base = 2 + case 'o': + base = 8 + case 'd': + base = 10 + case 'x', 'X': + base = 16 + case 's', 'v': + // let scan determine the base + default: + return errors.New("Int.Scan: invalid verb") + } + _, _, err := z.scan(byteReader{s}, base) + return err +} diff --git a/src/math/big/intconv_test.go b/src/math/big/intconv_test.go new file mode 100644 index 0000000..5ba2926 --- /dev/null +++ b/src/math/big/intconv_test.go @@ -0,0 +1,431 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "fmt" + "testing" +) + +var stringTests = []struct { + in string + out string + base int + val int64 + ok bool +}{ + // invalid inputs + {in: ""}, + {in: "a"}, + {in: "z"}, + {in: "+"}, + {in: "-"}, + {in: "0b"}, + {in: "0o"}, + {in: "0x"}, + {in: "0y"}, + {in: "2", base: 2}, + {in: "0b2", base: 0}, + {in: "08"}, + {in: "8", base: 8}, + {in: "0xg", base: 0}, + {in: "g", base: 16}, + + // invalid inputs with separators + // (smoke tests only - a comprehensive set of tests is in natconv_test.go) + {in: "_"}, + {in: "0_"}, + {in: "_0"}, + {in: "-1__0"}, + {in: "0x10_"}, + {in: "1_000", base: 10}, // separators are not permitted for bases != 0 + {in: "d_e_a_d", base: 16}, + + // valid inputs + {"0", "0", 0, 0, true}, + {"0", "0", 10, 0, true}, + {"0", "0", 16, 0, true}, + {"+0", "0", 0, 0, true}, + {"-0", "0", 0, 0, true}, + {"10", "10", 0, 10, true}, + {"10", "10", 10, 10, true}, + {"10", "10", 16, 16, true}, + {"-10", "-10", 16, -16, true}, + {"+10", "10", 16, 16, true}, + {"0b10", "2", 0, 2, true}, + {"0o10", "8", 0, 8, true}, + {"0x10", "16", 0, 16, true}, + {in: "0x10", base: 16}, + {"-0x10", "-16", 0, -16, true}, + {"+0x10", "16", 0, 16, true}, + {"00", "0", 0, 0, true}, + {"0", "0", 8, 0, true}, + {"07", "7", 0, 7, true}, + {"7", "7", 8, 7, true}, + {"023", "19", 0, 19, true}, + {"23", "23", 8, 19, true}, + {"cafebabe", "cafebabe", 16, 0xcafebabe, true}, + {"0b0", "0", 0, 0, true}, + {"-111", "-111", 2, -7, true}, + {"-0b111", "-7", 0, -7, true}, + {"0b1001010111", "599", 0, 0x257, true}, + {"1001010111", "1001010111", 2, 0x257, true}, + {"A", "a", 36, 10, true}, + {"A", "A", 37, 36, true}, + {"ABCXYZ", "abcxyz", 36, 623741435, true}, + {"ABCXYZ", "ABCXYZ", 62, 33536793425, true}, + + // valid input with separators + // (smoke tests only - a comprehensive set of tests is in natconv_test.go) + {"1_000", "1000", 0, 1000, true}, + {"0b_1010", "10", 0, 10, true}, + {"+0o_660", "432", 0, 0660, true}, + {"-0xF00D_1E", "-15731998", 0, -0xf00d1e, true}, +} + +func TestIntText(t *testing.T) { + z := new(Int) + for _, test := range stringTests { + if !test.ok { + continue + } + + _, ok := z.SetString(test.in, test.base) + if !ok { + t.Errorf("%v: failed to parse", test) + continue + } + + base := test.base + if base == 0 { + base = 10 + } + + if got := z.Text(base); got != test.out { + t.Errorf("%v: got %s; want %s", test, got, test.out) + } + } +} + +func TestAppendText(t *testing.T) { + z := new(Int) + var buf []byte + for _, test := range stringTests { + if !test.ok { + continue + } + + _, ok := z.SetString(test.in, test.base) + if !ok { + t.Errorf("%v: failed to parse", test) + continue + } + + base := test.base + if base == 0 { + base = 10 + } + + i := len(buf) + buf = z.Append(buf, base) + if got := string(buf[i:]); got != test.out { + t.Errorf("%v: got %s; want %s", test, got, test.out) + } + } +} + +func format(base int) string { + switch base { + case 2: + return "%b" + case 8: + return "%o" + case 16: + return "%x" + } + return "%d" +} + +func TestGetString(t *testing.T) { + z := new(Int) + for i, test := range stringTests { + if !test.ok { + continue + } + z.SetInt64(test.val) + + if test.base == 10 { + if got := z.String(); got != test.out { + t.Errorf("#%da got %s; want %s", i, got, test.out) + } + } + + f := format(test.base) + got := fmt.Sprintf(f, z) + if f == "%d" { + if got != fmt.Sprintf("%d", test.val) { + t.Errorf("#%db got %s; want %d", i, got, test.val) + } + } else { + if got != test.out { + t.Errorf("#%dc got %s; want %s", i, got, test.out) + } + } + } +} + +func TestSetString(t *testing.T) { + tmp := new(Int) + for i, test := range stringTests { + // initialize to a non-zero value so that issues with parsing + // 0 are detected + tmp.SetInt64(1234567890) + n1, ok1 := new(Int).SetString(test.in, test.base) + n2, ok2 := tmp.SetString(test.in, test.base) + expected := NewInt(test.val) + if ok1 != test.ok || ok2 != test.ok { + t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok) + continue + } + if !ok1 { + if n1 != nil { + t.Errorf("#%d (input '%s') n1 != nil", i, test.in) + } + continue + } + if !ok2 { + if n2 != nil { + t.Errorf("#%d (input '%s') n2 != nil", i, test.in) + } + continue + } + + if ok1 && !isNormalized(n1) { + t.Errorf("#%d (input '%s'): %v is not normalized", i, test.in, *n1) + } + if ok2 && !isNormalized(n2) { + t.Errorf("#%d (input '%s'): %v is not normalized", i, test.in, *n2) + } + + if n1.Cmp(expected) != 0 { + t.Errorf("#%d (input '%s') got: %s want: %d", i, test.in, n1, test.val) + } + if n2.Cmp(expected) != 0 { + t.Errorf("#%d (input '%s') got: %s want: %d", i, test.in, n2, test.val) + } + } +} + +var formatTests = []struct { + input string + format string + output string +}{ + {"<nil>", "%x", "<nil>"}, + {"<nil>", "%#x", "<nil>"}, + {"<nil>", "%#y", "%!y(big.Int=<nil>)"}, + + {"10", "%b", "1010"}, + {"10", "%o", "12"}, + {"10", "%d", "10"}, + {"10", "%v", "10"}, + {"10", "%x", "a"}, + {"10", "%X", "A"}, + {"-10", "%X", "-A"}, + {"10", "%y", "%!y(big.Int=10)"}, + {"-10", "%y", "%!y(big.Int=-10)"}, + + {"10", "%#b", "0b1010"}, + {"10", "%#o", "012"}, + {"10", "%O", "0o12"}, + {"-10", "%#b", "-0b1010"}, + {"-10", "%#o", "-012"}, + {"-10", "%O", "-0o12"}, + {"10", "%#d", "10"}, + {"10", "%#v", "10"}, + {"10", "%#x", "0xa"}, + {"10", "%#X", "0XA"}, + {"-10", "%#X", "-0XA"}, + {"10", "%#y", "%!y(big.Int=10)"}, + {"-10", "%#y", "%!y(big.Int=-10)"}, + + {"1234", "%d", "1234"}, + {"1234", "%3d", "1234"}, + {"1234", "%4d", "1234"}, + {"-1234", "%d", "-1234"}, + {"1234", "% 5d", " 1234"}, + {"1234", "%+5d", "+1234"}, + {"1234", "%-5d", "1234 "}, + {"1234", "%x", "4d2"}, + {"1234", "%X", "4D2"}, + {"-1234", "%3x", "-4d2"}, + {"-1234", "%4x", "-4d2"}, + {"-1234", "%5x", " -4d2"}, + {"-1234", "%-5x", "-4d2 "}, + {"1234", "%03d", "1234"}, + {"1234", "%04d", "1234"}, + {"1234", "%05d", "01234"}, + {"1234", "%06d", "001234"}, + {"-1234", "%06d", "-01234"}, + {"1234", "%+06d", "+01234"}, + {"1234", "% 06d", " 01234"}, + {"1234", "%-6d", "1234 "}, + {"1234", "%-06d", "1234 "}, + {"-1234", "%-06d", "-1234 "}, + + {"1234", "%.3d", "1234"}, + {"1234", "%.4d", "1234"}, + {"1234", "%.5d", "01234"}, + {"1234", "%.6d", "001234"}, + {"-1234", "%.3d", "-1234"}, + {"-1234", "%.4d", "-1234"}, + {"-1234", "%.5d", "-01234"}, + {"-1234", "%.6d", "-001234"}, + + {"1234", "%8.3d", " 1234"}, + {"1234", "%8.4d", " 1234"}, + {"1234", "%8.5d", " 01234"}, + {"1234", "%8.6d", " 001234"}, + {"-1234", "%8.3d", " -1234"}, + {"-1234", "%8.4d", " -1234"}, + {"-1234", "%8.5d", " -01234"}, + {"-1234", "%8.6d", " -001234"}, + + {"1234", "%+8.3d", " +1234"}, + {"1234", "%+8.4d", " +1234"}, + {"1234", "%+8.5d", " +01234"}, + {"1234", "%+8.6d", " +001234"}, + {"-1234", "%+8.3d", " -1234"}, + {"-1234", "%+8.4d", " -1234"}, + {"-1234", "%+8.5d", " -01234"}, + {"-1234", "%+8.6d", " -001234"}, + + {"1234", "% 8.3d", " 1234"}, + {"1234", "% 8.4d", " 1234"}, + {"1234", "% 8.5d", " 01234"}, + {"1234", "% 8.6d", " 001234"}, + {"-1234", "% 8.3d", " -1234"}, + {"-1234", "% 8.4d", " -1234"}, + {"-1234", "% 8.5d", " -01234"}, + {"-1234", "% 8.6d", " -001234"}, + + {"1234", "%.3x", "4d2"}, + {"1234", "%.4x", "04d2"}, + {"1234", "%.5x", "004d2"}, + {"1234", "%.6x", "0004d2"}, + {"-1234", "%.3x", "-4d2"}, + {"-1234", "%.4x", "-04d2"}, + {"-1234", "%.5x", "-004d2"}, + {"-1234", "%.6x", "-0004d2"}, + + {"1234", "%8.3x", " 4d2"}, + {"1234", "%8.4x", " 04d2"}, + {"1234", "%8.5x", " 004d2"}, + {"1234", "%8.6x", " 0004d2"}, + {"-1234", "%8.3x", " -4d2"}, + {"-1234", "%8.4x", " -04d2"}, + {"-1234", "%8.5x", " -004d2"}, + {"-1234", "%8.6x", " -0004d2"}, + + {"1234", "%+8.3x", " +4d2"}, + {"1234", "%+8.4x", " +04d2"}, + {"1234", "%+8.5x", " +004d2"}, + {"1234", "%+8.6x", " +0004d2"}, + {"-1234", "%+8.3x", " -4d2"}, + {"-1234", "%+8.4x", " -04d2"}, + {"-1234", "%+8.5x", " -004d2"}, + {"-1234", "%+8.6x", " -0004d2"}, + + {"1234", "% 8.3x", " 4d2"}, + {"1234", "% 8.4x", " 04d2"}, + {"1234", "% 8.5x", " 004d2"}, + {"1234", "% 8.6x", " 0004d2"}, + {"1234", "% 8.7x", " 00004d2"}, + {"1234", "% 8.8x", " 000004d2"}, + {"-1234", "% 8.3x", " -4d2"}, + {"-1234", "% 8.4x", " -04d2"}, + {"-1234", "% 8.5x", " -004d2"}, + {"-1234", "% 8.6x", " -0004d2"}, + {"-1234", "% 8.7x", "-00004d2"}, + {"-1234", "% 8.8x", "-000004d2"}, + + {"1234", "%-8.3d", "1234 "}, + {"1234", "%-8.4d", "1234 "}, + {"1234", "%-8.5d", "01234 "}, + {"1234", "%-8.6d", "001234 "}, + {"1234", "%-8.7d", "0001234 "}, + {"1234", "%-8.8d", "00001234"}, + {"-1234", "%-8.3d", "-1234 "}, + {"-1234", "%-8.4d", "-1234 "}, + {"-1234", "%-8.5d", "-01234 "}, + {"-1234", "%-8.6d", "-001234 "}, + {"-1234", "%-8.7d", "-0001234"}, + {"-1234", "%-8.8d", "-00001234"}, + + {"16777215", "%b", "111111111111111111111111"}, // 2**24 - 1 + + {"0", "%.d", ""}, + {"0", "%.0d", ""}, + {"0", "%3.d", ""}, +} + +func TestFormat(t *testing.T) { + for i, test := range formatTests { + var x *Int + if test.input != "<nil>" { + var ok bool + x, ok = new(Int).SetString(test.input, 0) + if !ok { + t.Errorf("#%d failed reading input %s", i, test.input) + } + } + output := fmt.Sprintf(test.format, x) + if output != test.output { + t.Errorf("#%d got %q; want %q, {%q, %q, %q}", i, output, test.output, test.input, test.format, test.output) + } + } +} + +var scanTests = []struct { + input string + format string + output string + remaining int +}{ + {"1010", "%b", "10", 0}, + {"0b1010", "%v", "10", 0}, + {"12", "%o", "10", 0}, + {"012", "%v", "10", 0}, + {"10", "%d", "10", 0}, + {"10", "%v", "10", 0}, + {"a", "%x", "10", 0}, + {"0xa", "%v", "10", 0}, + {"A", "%X", "10", 0}, + {"-A", "%X", "-10", 0}, + {"+0b1011001", "%v", "89", 0}, + {"0xA", "%v", "10", 0}, + {"0 ", "%v", "0", 1}, + {"2+3", "%v", "2", 2}, + {"0XABC 12", "%v", "2748", 3}, +} + +func TestScan(t *testing.T) { + var buf bytes.Buffer + for i, test := range scanTests { + x := new(Int) + buf.Reset() + buf.WriteString(test.input) + if _, err := fmt.Fscanf(&buf, test.format, x); err != nil { + t.Errorf("#%d error: %s", i, err) + } + if x.String() != test.output { + t.Errorf("#%d got %s; want %s", i, x.String(), test.output) + } + if buf.Len() != test.remaining { + t.Errorf("#%d got %d bytes remaining; want %d", i, buf.Len(), test.remaining) + } + } +} diff --git a/src/math/big/intmarsh.go b/src/math/big/intmarsh.go new file mode 100644 index 0000000..ce429ff --- /dev/null +++ b/src/math/big/intmarsh.go @@ -0,0 +1,83 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements encoding/decoding of Ints. + +package big + +import ( + "bytes" + "fmt" +) + +// Gob codec version. Permits backward-compatible changes to the encoding. +const intGobVersion byte = 1 + +// GobEncode implements the gob.GobEncoder interface. +func (x *Int) GobEncode() ([]byte, error) { + if x == nil { + return nil, nil + } + buf := make([]byte, 1+len(x.abs)*_S) // extra byte for version and sign bit + i := x.abs.bytes(buf) - 1 // i >= 0 + b := intGobVersion << 1 // make space for sign bit + if x.neg { + b |= 1 + } + buf[i] = b + return buf[i:], nil +} + +// GobDecode implements the gob.GobDecoder interface. +func (z *Int) GobDecode(buf []byte) error { + if len(buf) == 0 { + // Other side sent a nil or default value. + *z = Int{} + return nil + } + b := buf[0] + if b>>1 != intGobVersion { + return fmt.Errorf("Int.GobDecode: encoding version %d not supported", b>>1) + } + z.neg = b&1 != 0 + z.abs = z.abs.setBytes(buf[1:]) + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (x *Int) MarshalText() (text []byte, err error) { + if x == nil { + return []byte("<nil>"), nil + } + return x.abs.itoa(x.neg, 10), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +func (z *Int) UnmarshalText(text []byte) error { + if _, ok := z.setFromScanner(bytes.NewReader(text), 0); !ok { + return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Int", text) + } + return nil +} + +// The JSON marshalers are only here for API backward compatibility +// (programs that explicitly look for these two methods). JSON works +// fine with the TextMarshaler only. + +// MarshalJSON implements the json.Marshaler interface. +func (x *Int) MarshalJSON() ([]byte, error) { + if x == nil { + return []byte("null"), nil + } + return x.abs.itoa(x.neg, 10), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (z *Int) UnmarshalJSON(text []byte) error { + // Ignore null, like in the main JSON package. + if string(text) == "null" { + return nil + } + return z.UnmarshalText(text) +} diff --git a/src/math/big/intmarsh_test.go b/src/math/big/intmarsh_test.go new file mode 100644 index 0000000..8e7d29f --- /dev/null +++ b/src/math/big/intmarsh_test.go @@ -0,0 +1,134 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "encoding/xml" + "testing" +) + +var encodingTests = []string{ + "0", + "1", + "2", + "10", + "1000", + "1234567890", + "298472983472983471903246121093472394872319615612417471234712061", +} + +func TestIntGobEncoding(t *testing.T) { + var medium bytes.Buffer + enc := gob.NewEncoder(&medium) + dec := gob.NewDecoder(&medium) + for _, test := range encodingTests { + for _, sign := range []string{"", "+", "-"} { + x := sign + test + medium.Reset() // empty buffer for each test case (in case of failures) + var tx Int + tx.SetString(x, 10) + if err := enc.Encode(&tx); err != nil { + t.Errorf("encoding of %s failed: %s", &tx, err) + continue + } + var rx Int + if err := dec.Decode(&rx); err != nil { + t.Errorf("decoding of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("transmission of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} + +// Sending a nil Int pointer (inside a slice) on a round trip through gob should yield a zero. +// TODO: top-level nils. +func TestGobEncodingNilIntInSlice(t *testing.T) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + dec := gob.NewDecoder(buf) + + var in = make([]*Int, 1) + err := enc.Encode(&in) + if err != nil { + t.Errorf("gob encode failed: %q", err) + } + var out []*Int + err = dec.Decode(&out) + if err != nil { + t.Fatalf("gob decode failed: %q", err) + } + if len(out) != 1 { + t.Fatalf("wrong len; want 1 got %d", len(out)) + } + var zero Int + if out[0].Cmp(&zero) != 0 { + t.Fatalf("transmission of (*Int)(nil) failed: got %s want 0", out) + } +} + +func TestIntJSONEncoding(t *testing.T) { + for _, test := range encodingTests { + for _, sign := range []string{"", "+", "-"} { + x := sign + test + var tx Int + tx.SetString(x, 10) + b, err := json.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %s failed: %s", &tx, err) + continue + } + var rx Int + if err := json.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("JSON encoding of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} + +func TestIntJSONEncodingNil(t *testing.T) { + var x *Int + b, err := x.MarshalJSON() + if err != nil { + t.Fatalf("marshaling of nil failed: %s", err) + } + got := string(b) + want := "null" + if got != want { + t.Fatalf("marshaling of nil failed: got %s want %s", got, want) + } +} + +func TestIntXMLEncoding(t *testing.T) { + for _, test := range encodingTests { + for _, sign := range []string{"", "+", "-"} { + x := sign + test + var tx Int + tx.SetString(x, 0) + b, err := xml.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %s failed: %s", &tx, err) + continue + } + var rx Int + if err := xml.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("XML encoding of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} diff --git a/src/math/big/link_test.go b/src/math/big/link_test.go new file mode 100644 index 0000000..6e33aa5 --- /dev/null +++ b/src/math/big/link_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "internal/testenv" + "os" + "os/exec" + "path/filepath" + "testing" +) + +// Tests that the linker is able to remove references to Float, Rat, +// and Int if unused (notably, not used by init). +func TestLinkerGC(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + t.Parallel() + tmp := t.TempDir() + goBin := testenv.GoToolPath(t) + goFile := filepath.Join(tmp, "x.go") + file := []byte(`package main +import _ "math/big" +func main() {} +`) + if err := os.WriteFile(goFile, file, 0644); err != nil { + t.Fatal(err) + } + cmd := exec.Command(goBin, "build", "-o", "x.exe", "x.go") + cmd.Dir = tmp + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("compile: %v, %s", err, out) + } + + cmd = exec.Command(goBin, "tool", "nm", "x.exe") + cmd.Dir = tmp + nm, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("nm: %v, %s", err, nm) + } + const want = "runtime.main" + if !bytes.Contains(nm, []byte(want)) { + // Test the test. + t.Errorf("expected symbol %q not found", want) + } + bad := []string{ + "math/big.(*Float)", + "math/big.(*Rat)", + "math/big.(*Int)", + } + for _, sym := range bad { + if bytes.Contains(nm, []byte(sym)) { + t.Errorf("unexpected symbol %q found", sym) + } + } + if t.Failed() { + t.Logf("Got: %s", nm) + } +} diff --git a/src/math/big/nat.go b/src/math/big/nat.go new file mode 100644 index 0000000..90ce6d1 --- /dev/null +++ b/src/math/big/nat.go @@ -0,0 +1,1429 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements unsigned multi-precision integers (natural +// numbers). They are the building blocks for the implementation +// of signed integers, rationals, and floating-point numbers. +// +// Caution: This implementation relies on the function "alias" +// which assumes that (nat) slice capacities are never +// changed (no 3-operand slice expressions). If that +// changes, alias needs to be updated for correctness. + +package big + +import ( + "encoding/binary" + "math/bits" + "math/rand" + "sync" +) + +// An unsigned integer x of the form +// +// x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] +// +// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, +// with the digits x[i] as the slice elements. +// +// A number is normalized if the slice contains no leading 0 digits. +// During arithmetic operations, denormalized values may occur but are +// always normalized before returning the final result. The normalized +// representation of 0 is the empty or nil slice (length = 0). +type nat []Word + +var ( + natOne = nat{1} + natTwo = nat{2} + natFive = nat{5} + natTen = nat{10} +) + +func (z nat) String() string { + return "0x" + string(z.itoa(false, 16)) +} + +func (z nat) clear() { + for i := range z { + z[i] = 0 + } +} + +func (z nat) norm() nat { + i := len(z) + for i > 0 && z[i-1] == 0 { + i-- + } + return z[0:i] +} + +func (z nat) make(n int) nat { + if n <= cap(z) { + return z[:n] // reuse z + } + if n == 1 { + // Most nats start small and stay that way; don't over-allocate. + return make(nat, 1) + } + // Choosing a good value for e has significant performance impact + // because it increases the chance that a value can be reused. + const e = 4 // extra capacity + return make(nat, n, n+e) +} + +func (z nat) setWord(x Word) nat { + if x == 0 { + return z[:0] + } + z = z.make(1) + z[0] = x + return z +} + +func (z nat) setUint64(x uint64) nat { + // single-word value + if w := Word(x); uint64(w) == x { + return z.setWord(w) + } + // 2-word value + z = z.make(2) + z[1] = Word(x >> 32) + z[0] = Word(x) + return z +} + +func (z nat) set(x nat) nat { + z = z.make(len(x)) + copy(z, x) + return z +} + +func (z nat) add(x, y nat) nat { + m := len(x) + n := len(y) + + switch { + case m < n: + return z.add(y, x) + case m == 0: + // n == 0 because m >= n; result is 0 + return z[:0] + case n == 0: + // result is x + return z.set(x) + } + // m > 0 + + z = z.make(m + 1) + c := addVV(z[0:n], x, y) + if m > n { + c = addVW(z[n:m], x[n:], c) + } + z[m] = c + + return z.norm() +} + +func (z nat) sub(x, y nat) nat { + m := len(x) + n := len(y) + + switch { + case m < n: + panic("underflow") + case m == 0: + // n == 0 because m >= n; result is 0 + return z[:0] + case n == 0: + // result is x + return z.set(x) + } + // m > 0 + + z = z.make(m) + c := subVV(z[0:n], x, y) + if m > n { + c = subVW(z[n:], x[n:], c) + } + if c != 0 { + panic("underflow") + } + + return z.norm() +} + +func (x nat) cmp(y nat) (r int) { + m := len(x) + n := len(y) + if m != n || m == 0 { + switch { + case m < n: + r = -1 + case m > n: + r = 1 + } + return + } + + i := m - 1 + for i > 0 && x[i] == y[i] { + i-- + } + + switch { + case x[i] < y[i]: + r = -1 + case x[i] > y[i]: + r = 1 + } + return +} + +func (z nat) mulAddWW(x nat, y, r Word) nat { + m := len(x) + if m == 0 || y == 0 { + return z.setWord(r) // result is r + } + // m > 0 + + z = z.make(m + 1) + z[m] = mulAddVWW(z[0:m], x, y, r) + + return z.norm() +} + +// basicMul multiplies x and y and leaves the result in z. +// The (non-normalized) result is placed in z[0 : len(x) + len(y)]. +func basicMul(z, x, y nat) { + z[0 : len(x)+len(y)].clear() // initialize z + for i, d := range y { + if d != 0 { + z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) + } + } +} + +// montgomery computes z mod m = x*y*2**(-n*_W) mod m, +// assuming k = -1/m mod 2**_W. +// z is used for storing the result which is returned; +// z must not alias x, y or m. +// See Gueron, "Efficient Software Implementations of Modular Exponentiation". +// https://eprint.iacr.org/2011/239.pdf +// In the terminology of that paper, this is an "Almost Montgomery Multiplication": +// x and y are required to satisfy 0 <= z < 2**(n*_W) and then the result +// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m. +func (z nat) montgomery(x, y, m nat, k Word, n int) nat { + // This code assumes x, y, m are all the same length, n. + // (required by addMulVVW and the for loop). + // It also assumes that x, y are already reduced mod m, + // or else the result will not be properly reduced. + if len(x) != n || len(y) != n || len(m) != n { + panic("math/big: mismatched montgomery number lengths") + } + z = z.make(n * 2) + z.clear() + var c Word + for i := 0; i < n; i++ { + d := y[i] + c2 := addMulVVW(z[i:n+i], x, d) + t := z[i] * k + c3 := addMulVVW(z[i:n+i], m, t) + cx := c + c2 + cy := cx + c3 + z[n+i] = cy + if cx < c2 || cy < c3 { + c = 1 + } else { + c = 0 + } + } + if c != 0 { + subVV(z[:n], z[n:], m) + } else { + copy(z[:n], z[n:]) + } + return z[:n] +} + +// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. +// Factored out for readability - do not use outside karatsuba. +func karatsubaAdd(z, x nat, n int) { + if c := addVV(z[0:n], z, x); c != 0 { + addVW(z[n:n+n>>1], z[n:], c) + } +} + +// Like karatsubaAdd, but does subtract. +func karatsubaSub(z, x nat, n int) { + if c := subVV(z[0:n], z, x); c != 0 { + subVW(z[n:n+n>>1], z[n:], c) + } +} + +// Operands that are shorter than karatsubaThreshold are multiplied using +// "grade school" multiplication; for longer operands the Karatsuba algorithm +// is used. +var karatsubaThreshold = 40 // computed by calibrate_test.go + +// karatsuba multiplies x and y and leaves the result in z. +// Both x and y must have the same length n and n must be a +// power of 2. The result vector z must have len(z) >= 6*n. +// The (non-normalized) result is placed in z[0 : 2*n]. +func karatsuba(z, x, y nat) { + n := len(y) + + // Switch to basic multiplication if numbers are odd or small. + // (n is always even if karatsubaThreshold is even, but be + // conservative) + if n&1 != 0 || n < karatsubaThreshold || n < 2 { + basicMul(z, x, y) + return + } + // n&1 == 0 && n >= karatsubaThreshold && n >= 2 + + // Karatsuba multiplication is based on the observation that + // for two numbers x and y with: + // + // x = x1*b + x0 + // y = y1*b + y0 + // + // the product x*y can be obtained with 3 products z2, z1, z0 + // instead of 4: + // + // x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0 + // = z2*b*b + z1*b + z0 + // + // with: + // + // xd = x1 - x0 + // yd = y0 - y1 + // + // z1 = xd*yd + z2 + z0 + // = (x1-x0)*(y0 - y1) + z2 + z0 + // = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0 + // = x1*y0 - z2 - z0 + x0*y1 + z2 + z0 + // = x1*y0 + x0*y1 + + // split x, y into "digits" + n2 := n >> 1 // n2 >= 1 + x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 + y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 + + // z is used for the result and temporary storage: + // + // 6*n 5*n 4*n 3*n 2*n 1*n 0*n + // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] + // + // For each recursive call of karatsuba, an unused slice of + // z is passed in that has (at least) half the length of the + // caller's z. + + // compute z0 and z2 with the result "in place" in z + karatsuba(z, x0, y0) // z0 = x0*y0 + karatsuba(z[n:], x1, y1) // z2 = x1*y1 + + // compute xd (or the negative value if underflow occurs) + s := 1 // sign of product xd*yd + xd := z[2*n : 2*n+n2] + if subVV(xd, x1, x0) != 0 { // x1-x0 + s = -s + subVV(xd, x0, x1) // x0-x1 + } + + // compute yd (or the negative value if underflow occurs) + yd := z[2*n+n2 : 3*n] + if subVV(yd, y0, y1) != 0 { // y0-y1 + s = -s + subVV(yd, y1, y0) // y1-y0 + } + + // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 + // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 + p := z[n*3:] + karatsuba(p, xd, yd) + + // save original z2:z0 + // (ok to use upper half of z since we're done recurring) + r := z[n*4:] + copy(r, z[:n*2]) + + // add up all partial products + // + // 2*n n 0 + // z = [ z2 | z0 ] + // + [ z0 ] + // + [ z2 ] + // + [ p ] + // + karatsubaAdd(z[n2:], r, n) + karatsubaAdd(z[n2:], r[n:], n) + if s > 0 { + karatsubaAdd(z[n2:], p, n) + } else { + karatsubaSub(z[n2:], p, n) + } +} + +// alias reports whether x and y share the same base array. +// +// Note: alias assumes that the capacity of underlying arrays +// is never changed for nat values; i.e. that there are +// no 3-operand slice expressions in this code (or worse, +// reflect-based operations to the same effect). +func alias(x, y nat) bool { + return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] +} + +// addAt implements z += x<<(_W*i); z must be long enough. +// (we don't use nat.add because we need z to stay the same +// slice, and we don't need to normalize z after each addition) +func addAt(z, x nat, i int) { + if n := len(x); n > 0 { + if c := addVV(z[i:i+n], z[i:], x); c != 0 { + j := i + n + if j < len(z) { + addVW(z[j:], z[j:], c) + } + } + } +} + +func max(x, y int) int { + if x > y { + return x + } + return y +} + +// karatsubaLen computes an approximation to the maximum k <= n such that +// k = p<<i for a number p <= threshold and an i >= 0. Thus, the +// result is the largest number that can be divided repeatedly by 2 before +// becoming about the value of threshold. +func karatsubaLen(n, threshold int) int { + i := uint(0) + for n > threshold { + n >>= 1 + i++ + } + return n << i +} + +func (z nat) mul(x, y nat) nat { + m := len(x) + n := len(y) + + switch { + case m < n: + return z.mul(y, x) + case m == 0 || n == 0: + return z[:0] + case n == 1: + return z.mulAddWW(x, y[0], 0) + } + // m >= n > 1 + + // determine if z can be reused + if alias(z, x) || alias(z, y) { + z = nil // z is an alias for x or y - cannot reuse + } + + // use basic multiplication if the numbers are small + if n < karatsubaThreshold { + z = z.make(m + n) + basicMul(z, x, y) + return z.norm() + } + // m >= n && n >= karatsubaThreshold && n >= 2 + + // determine Karatsuba length k such that + // + // x = xh*b + x0 (0 <= x0 < b) + // y = yh*b + y0 (0 <= y0 < b) + // b = 1<<(_W*k) ("base" of digits xi, yi) + // + k := karatsubaLen(n, karatsubaThreshold) + // k <= n + + // multiply x0 and y0 via Karatsuba + x0 := x[0:k] // x0 is not normalized + y0 := y[0:k] // y0 is not normalized + z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y + karatsuba(z, x0, y0) + z = z[0 : m+n] // z has final length but may be incomplete + z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) + + // If xh != 0 or yh != 0, add the missing terms to z. For + // + // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) + // yh = y1*b (0 <= y1 < b) + // + // the missing terms are + // + // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 + // + // since all the yi for i > 1 are 0 by choice of k: If any of them + // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would + // be a larger valid threshold contradicting the assumption about k. + // + if k < n || m != n { + tp := getNat(3 * k) + t := *tp + + // add x0*y1*b + x0 := x0.norm() + y1 := y[k:] // y1 is normalized because y is + t = t.mul(x0, y1) // update t so we don't lose t's underlying array + addAt(z, t, k) + + // add xi*y0<<i, xi*y1*b<<(i+k) + y0 := y0.norm() + for i := k; i < len(x); i += k { + xi := x[i:] + if len(xi) > k { + xi = xi[:k] + } + xi = xi.norm() + t = t.mul(xi, y0) + addAt(z, t, i) + t = t.mul(xi, y1) + addAt(z, t, i+k) + } + + putNat(tp) + } + + return z.norm() +} + +// basicSqr sets z = x*x and is asymptotically faster than basicMul +// by about a factor of 2, but slower for small arguments due to overhead. +// Requirements: len(x) > 0, len(z) == 2*len(x) +// The (non-normalized) result is placed in z. +func basicSqr(z, x nat) { + n := len(x) + tp := getNat(2 * n) + t := *tp // temporary variable to hold the products + t.clear() + z[1], z[0] = mulWW(x[0], x[0]) // the initial square + for i := 1; i < n; i++ { + d := x[i] + // z collects the squares x[i] * x[i] + z[2*i+1], z[2*i] = mulWW(d, d) + // t collects the products x[i] * x[j] where j < i + t[2*i] = addMulVVW(t[i:2*i], x[0:i], d) + } + t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products + addVV(z, z, t) // combine the result + putNat(tp) +} + +// karatsubaSqr squares x and leaves the result in z. +// len(x) must be a power of 2 and len(z) >= 6*len(x). +// The (non-normalized) result is placed in z[0 : 2*len(x)]. +// +// The algorithm and the layout of z are the same as for karatsuba. +func karatsubaSqr(z, x nat) { + n := len(x) + + if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { + basicSqr(z[:2*n], x) + return + } + + n2 := n >> 1 + x1, x0 := x[n2:], x[0:n2] + + karatsubaSqr(z, x0) + karatsubaSqr(z[n:], x1) + + // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 + xd := z[2*n : 2*n+n2] + if subVV(xd, x1, x0) != 0 { + subVV(xd, x0, x1) + } + + p := z[n*3:] + karatsubaSqr(p, xd) + + r := z[n*4:] + copy(r, z[:n*2]) + + karatsubaAdd(z[n2:], r, n) + karatsubaAdd(z[n2:], r[n:], n) + karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0 +} + +// Operands that are shorter than basicSqrThreshold are squared using +// "grade school" multiplication; for operands longer than karatsubaSqrThreshold +// we use the Karatsuba algorithm optimized for x == y. +var basicSqrThreshold = 20 // computed by calibrate_test.go +var karatsubaSqrThreshold = 260 // computed by calibrate_test.go + +// z = x*x +func (z nat) sqr(x nat) nat { + n := len(x) + switch { + case n == 0: + return z[:0] + case n == 1: + d := x[0] + z = z.make(2) + z[1], z[0] = mulWW(d, d) + return z.norm() + } + + if alias(z, x) { + z = nil // z is an alias for x - cannot reuse + } + + if n < basicSqrThreshold { + z = z.make(2 * n) + basicMul(z, x, x) + return z.norm() + } + if n < karatsubaSqrThreshold { + z = z.make(2 * n) + basicSqr(z, x) + return z.norm() + } + + // Use Karatsuba multiplication optimized for x == y. + // The algorithm and layout of z are the same as for mul. + + // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2 + + k := karatsubaLen(n, karatsubaSqrThreshold) + + x0 := x[0:k] + z = z.make(max(6*k, 2*n)) + karatsubaSqr(z, x0) // z = x0^2 + z = z[0 : 2*n] + z[2*k:].clear() + + if k < n { + tp := getNat(2 * k) + t := *tp + x0 := x0.norm() + x1 := x[k:] + t = t.mul(x0, x1) + addAt(z, t, k) + addAt(z, t, k) // z = 2*x1*x0*b + x0^2 + t = t.sqr(x1) + addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 + putNat(tp) + } + + return z.norm() +} + +// mulRange computes the product of all the unsigned integers in the +// range [a, b] inclusively. If a > b (empty range), the result is 1. +func (z nat) mulRange(a, b uint64) nat { + switch { + case a == 0: + // cut long ranges short (optimization) + return z.setUint64(0) + case a > b: + return z.setUint64(1) + case a == b: + return z.setUint64(a) + case a+1 == b: + return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) + } + m := (a + b) / 2 + return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) +} + +// getNat returns a *nat of len n. The contents may not be zero. +// The pool holds *nat to avoid allocation when converting to interface{}. +func getNat(n int) *nat { + var z *nat + if v := natPool.Get(); v != nil { + z = v.(*nat) + } + if z == nil { + z = new(nat) + } + *z = z.make(n) + if n > 0 { + (*z)[0] = 0xfedcb // break code expecting zero + } + return z +} + +func putNat(x *nat) { + natPool.Put(x) +} + +var natPool sync.Pool + +// bitLen returns the length of x in bits. +// Unlike most methods, it works even if x is not normalized. +func (x nat) bitLen() int { + // This function is used in cryptographic operations. It must not leak + // anything but the Int's sign and bit size through side-channels. Any + // changes must be reviewed by a security expert. + if i := len(x) - 1; i >= 0 { + // bits.Len uses a lookup table for the low-order bits on some + // architectures. Neutralize any input-dependent behavior by setting all + // bits after the first one bit. + top := uint(x[i]) + top |= top >> 1 + top |= top >> 2 + top |= top >> 4 + top |= top >> 8 + top |= top >> 16 + top |= top >> 16 >> 16 // ">> 32" doesn't compile on 32-bit architectures + return i*_W + bits.Len(top) + } + return 0 +} + +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func (x nat) trailingZeroBits() uint { + if len(x) == 0 { + return 0 + } + var i uint + for x[i] == 0 { + i++ + } + // x[i] != 0 + return i*_W + uint(bits.TrailingZeros(uint(x[i]))) +} + +// isPow2 returns i, true when x == 2**i and 0, false otherwise. +func (x nat) isPow2() (uint, bool) { + var i uint + for x[i] == 0 { + i++ + } + if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 { + return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true + } + return 0, false +} + +func same(x, y nat) bool { + return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0] +} + +// z = x << s +func (z nat) shl(x nat, s uint) nat { + if s == 0 { + if same(z, x) { + return z + } + if !alias(z, x) { + return z.set(x) + } + } + + m := len(x) + if m == 0 { + return z[:0] + } + // m > 0 + + n := m + int(s/_W) + z = z.make(n + 1) + z[n] = shlVU(z[n-m:n], x, s%_W) + z[0 : n-m].clear() + + return z.norm() +} + +// z = x >> s +func (z nat) shr(x nat, s uint) nat { + if s == 0 { + if same(z, x) { + return z + } + if !alias(z, x) { + return z.set(x) + } + } + + m := len(x) + n := m - int(s/_W) + if n <= 0 { + return z[:0] + } + // n > 0 + + z = z.make(n) + shrVU(z, x[m-n:], s%_W) + + return z.norm() +} + +func (z nat) setBit(x nat, i uint, b uint) nat { + j := int(i / _W) + m := Word(1) << (i % _W) + n := len(x) + switch b { + case 0: + z = z.make(n) + copy(z, x) + if j >= n { + // no need to grow + return z + } + z[j] &^= m + return z.norm() + case 1: + if j >= n { + z = z.make(j + 1) + z[n:].clear() + } else { + z = z.make(n) + } + copy(z, x) + z[j] |= m + // no need to normalize + return z + } + panic("set bit is not 0 or 1") +} + +// bit returns the value of the i'th bit, with lsb == bit 0. +func (x nat) bit(i uint) uint { + j := i / _W + if j >= uint(len(x)) { + return 0 + } + // 0 <= j < len(x) + return uint(x[j] >> (i % _W) & 1) +} + +// sticky returns 1 if there's a 1 bit within the +// i least significant bits, otherwise it returns 0. +func (x nat) sticky(i uint) uint { + j := i / _W + if j >= uint(len(x)) { + if len(x) == 0 { + return 0 + } + return 1 + } + // 0 <= j < len(x) + for _, x := range x[:j] { + if x != 0 { + return 1 + } + } + if x[j]<<(_W-i%_W) != 0 { + return 1 + } + return 0 +} + +func (z nat) and(x, y nat) nat { + m := len(x) + n := len(y) + if m > n { + m = n + } + // m <= n + + z = z.make(m) + for i := 0; i < m; i++ { + z[i] = x[i] & y[i] + } + + return z.norm() +} + +// trunc returns z = x mod 2ⁿ. +func (z nat) trunc(x nat, n uint) nat { + w := (n + _W - 1) / _W + if uint(len(x)) < w { + return z.set(x) + } + z = z.make(int(w)) + copy(z, x) + if n%_W != 0 { + z[len(z)-1] &= 1<<(n%_W) - 1 + } + return z.norm() +} + +func (z nat) andNot(x, y nat) nat { + m := len(x) + n := len(y) + if n > m { + n = m + } + // m >= n + + z = z.make(m) + for i := 0; i < n; i++ { + z[i] = x[i] &^ y[i] + } + copy(z[n:m], x[n:m]) + + return z.norm() +} + +func (z nat) or(x, y nat) nat { + m := len(x) + n := len(y) + s := x + if m < n { + n, m = m, n + s = y + } + // m >= n + + z = z.make(m) + for i := 0; i < n; i++ { + z[i] = x[i] | y[i] + } + copy(z[n:m], s[n:m]) + + return z.norm() +} + +func (z nat) xor(x, y nat) nat { + m := len(x) + n := len(y) + s := x + if m < n { + n, m = m, n + s = y + } + // m >= n + + z = z.make(m) + for i := 0; i < n; i++ { + z[i] = x[i] ^ y[i] + } + copy(z[n:m], s[n:m]) + + return z.norm() +} + +// random creates a random integer in [0..limit), using the space in z if +// possible. n is the bit length of limit. +func (z nat) random(rand *rand.Rand, limit nat, n int) nat { + if alias(z, limit) { + z = nil // z is an alias for limit - cannot reuse + } + z = z.make(len(limit)) + + bitLengthOfMSW := uint(n % _W) + if bitLengthOfMSW == 0 { + bitLengthOfMSW = _W + } + mask := Word((1 << bitLengthOfMSW) - 1) + + for { + switch _W { + case 32: + for i := range z { + z[i] = Word(rand.Uint32()) + } + case 64: + for i := range z { + z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 + } + default: + panic("unknown word size") + } + z[len(limit)-1] &= mask + if z.cmp(limit) < 0 { + break + } + } + + return z.norm() +} + +// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; +// otherwise it sets z to x**y. The result is the value of z. +func (z nat) expNN(x, y, m nat, slow bool) nat { + if alias(z, x) || alias(z, y) { + // We cannot allow in-place modification of x or y. + z = nil + } + + // x**y mod 1 == 0 + if len(m) == 1 && m[0] == 1 { + return z.setWord(0) + } + // m == 0 || m > 1 + + // x**0 == 1 + if len(y) == 0 { + return z.setWord(1) + } + // y > 0 + + // 0**y = 0 + if len(x) == 0 { + return z.setWord(0) + } + // x > 0 + + // 1**y = 1 + if len(x) == 1 && x[0] == 1 { + return z.setWord(1) + } + // x > 1 + + // x**1 == x + if len(y) == 1 && y[0] == 1 { + if len(m) != 0 { + return z.rem(x, m) + } + return z.set(x) + } + // y > 1 + + if len(m) != 0 { + // We likely end up being as long as the modulus. + z = z.make(len(m)) + + // If the exponent is large, we use the Montgomery method for odd values, + // and a 4-bit, windowed exponentiation for powers of two, + // and a CRT-decomposed Montgomery method for the remaining values + // (even values times non-trivial odd values, which decompose into one + // instance of each of the first two cases). + if len(y) > 1 && !slow { + if m[0]&1 == 1 { + return z.expNNMontgomery(x, y, m) + } + if logM, ok := m.isPow2(); ok { + return z.expNNWindowed(x, y, logM) + } + return z.expNNMontgomeryEven(x, y, m) + } + } + + z = z.set(x) + v := y[len(y)-1] // v > 0 because y is normalized and y > 0 + shift := nlz(v) + 1 + v <<= shift + var q nat + + const mask = 1 << (_W - 1) + + // We walk through the bits of the exponent one by one. Each time we + // see a bit, we square, thus doubling the power. If the bit is a one, + // we also multiply by x, thus adding one to the power. + + w := _W - int(shift) + // zz and r are used to avoid allocating in mul and div as + // otherwise the arguments would alias. + var zz, r nat + for j := 0; j < w; j++ { + zz = zz.sqr(z) + zz, z = z, zz + + if v&mask != 0 { + zz = zz.mul(z, x) + zz, z = z, zz + } + + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r + } + + v <<= 1 + } + + for i := len(y) - 2; i >= 0; i-- { + v = y[i] + + for j := 0; j < _W; j++ { + zz = zz.sqr(z) + zz, z = z, zz + + if v&mask != 0 { + zz = zz.mul(z, x) + zz, z = z, zz + } + + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r + } + + v <<= 1 + } + } + + return z.norm() +} + +// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd. +// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2 +// and then uses the Chinese Remainder Theorem to combine the results. +// The recursive call using m1 will use expNNWindowed, +// while the recursive call using m2 will use expNNMontgomery. +// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”, +// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994. +// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf +func (z nat) expNNMontgomeryEven(x, y, m nat) nat { + // Split m = m₁ × m₂ where m₁ = 2ⁿ + n := m.trailingZeroBits() + m1 := nat(nil).shl(natOne, n) + m2 := nat(nil).shr(m, n) + + // We want z = x**y mod m. + // z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1 + // z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2 + // (We are using the math/big convention for names here, + // where the computation is z = x**y mod m, so its parts are z1 and z2. + // The paper is computing x = a**e mod n; it refers to these as x2 and z1.) + z1 := nat(nil).expNN(x, y, m1, false) + z2 := nat(nil).expNN(x, y, m2, false) + + // Reconstruct z from z₁, z₂ using CRT, using algorithm from paper, + // which uses only a single modInverse (and an easy one at that). + // p = (z₁ - z₂) × m₂⁻¹ (mod m₁) + // z = z₂ + p × m₂ + // The final addition is in range because: + // z = z₂ + p × m₂ + // ≤ z₂ + (m₁-1) × m₂ + // < m₂ + (m₁-1) × m₂ + // = m₁ × m₂ + // = m. + z = z.set(z2) + + // Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1. + z1 = z1.subMod2N(z1, z2, n) + + // Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]). + m2inv := nat(nil).modInverse(m2, m1) + z2 = z2.mul(z1, m2inv) + z2 = z2.trunc(z2, n) + + // Reuse z1 for p * m2. + z = z.add(z, z1.mul(z2, m2)) + + return z +} + +// expNNWindowed calculates x**y mod m using a fixed, 4-bit window, +// where m = 2**logM. +func (z nat) expNNWindowed(x, y nat, logM uint) nat { + if len(y) <= 1 { + panic("big: misuse of expNNWindowed") + } + if x[0]&1 == 0 { + // len(y) > 1, so y > logM. + // x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM. + return z.setWord(0) + } + if logM == 1 { + return z.setWord(1) + } + + // zz is used to avoid allocating in mul as otherwise + // the arguments would alias. + w := int((logM + _W - 1) / _W) + zzp := getNat(w) + zz := *zzp + + const n = 4 + // powers[i] contains x^i. + var powers [1 << n]*nat + for i := range powers { + powers[i] = getNat(w) + } + *powers[0] = powers[0].set(natOne) + *powers[1] = powers[1].trunc(x, logM) + for i := 2; i < 1<<n; i += 2 { + p2, p, p1 := powers[i/2], powers[i], powers[i+1] + *p = p.sqr(*p2) + *p = p.trunc(*p, logM) + *p1 = p1.mul(*p, x) + *p1 = p1.trunc(*p1, logM) + } + + // Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1, + // so we can compute x**(y mod 2**(logM-1)) instead of x**y. + // That is, we can throw away all but the bottom logM-1 bits of y. + // Instead of allocating a new y, we start reading y at the right word + // and truncate it appropriately at the start of the loop. + i := len(y) - 1 + mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word. + mmask := ^Word(0) + if mbits := (logM - 1) & (_W - 1); mbits != 0 { + mmask = (1 << mbits) - 1 + } + if i > mtop { + i = mtop + } + advance := false + z = z.setWord(1) + for ; i >= 0; i-- { + yi := y[i] + if i == mtop { + yi &= mmask + } + for j := 0; j < _W; j += n { + if advance { + // Account for use of 4 bits in previous iteration. + // Unrolled loop for significant performance + // gain. Use go test -bench=".*" in crypto/rsa + // to check performance before making changes. + zz = zz.sqr(z) + zz, z = z, zz + z = z.trunc(z, logM) + + zz = zz.sqr(z) + zz, z = z, zz + z = z.trunc(z, logM) + + zz = zz.sqr(z) + zz, z = z, zz + z = z.trunc(z, logM) + + zz = zz.sqr(z) + zz, z = z, zz + z = z.trunc(z, logM) + } + + zz = zz.mul(z, *powers[yi>>(_W-n)]) + zz, z = z, zz + z = z.trunc(z, logM) + + yi <<= n + advance = true + } + } + + *zzp = zz + putNat(zzp) + for i := range powers { + putNat(powers[i]) + } + + return z.norm() +} + +// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window. +// Uses Montgomery representation. +func (z nat) expNNMontgomery(x, y, m nat) nat { + numWords := len(m) + + // We want the lengths of x and m to be equal. + // It is OK if x >= m as long as len(x) == len(m). + if len(x) > numWords { + _, x = nat(nil).div(nil, x, m) + // Note: now len(x) <= numWords, not guaranteed ==. + } + if len(x) < numWords { + rr := make(nat, numWords) + copy(rr, x) + x = rr + } + + // Ideally the precomputations would be performed outside, and reused + // k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson + // Iteration for Multiplicative Inverses Modulo Prime Powers". + k0 := 2 - m[0] + t := m[0] - 1 + for i := 1; i < _W; i <<= 1 { + t *= t + k0 *= (t + 1) + } + k0 = -k0 + + // RR = 2**(2*_W*len(m)) mod m + RR := nat(nil).setWord(1) + zz := nat(nil).shl(RR, uint(2*numWords*_W)) + _, RR = nat(nil).div(RR, zz, m) + if len(RR) < numWords { + zz = zz.make(numWords) + copy(zz, RR) + RR = zz + } + // one = 1, with equal length to that of m + one := make(nat, numWords) + one[0] = 1 + + const n = 4 + // powers[i] contains x^i + var powers [1 << n]nat + powers[0] = powers[0].montgomery(one, RR, m, k0, numWords) + powers[1] = powers[1].montgomery(x, RR, m, k0, numWords) + for i := 2; i < 1<<n; i++ { + powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords) + } + + // initialize z = 1 (Montgomery 1) + z = z.make(numWords) + copy(z, powers[0]) + + zz = zz.make(numWords) + + // same windowed exponent, but with Montgomery multiplications + for i := len(y) - 1; i >= 0; i-- { + yi := y[i] + for j := 0; j < _W; j += n { + if i != len(y)-1 || j != 0 { + zz = zz.montgomery(z, z, m, k0, numWords) + z = z.montgomery(zz, zz, m, k0, numWords) + zz = zz.montgomery(z, z, m, k0, numWords) + z = z.montgomery(zz, zz, m, k0, numWords) + } + zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords) + z, zz = zz, z + yi <<= n + } + } + // convert to regular number + zz = zz.montgomery(z, one, m, k0, numWords) + + // One last reduction, just in case. + // See golang.org/issue/13907. + if zz.cmp(m) >= 0 { + // Common case is m has high bit set; in that case, + // since zz is the same length as m, there can be just + // one multiple of m to remove. Just subtract. + // We think that the subtract should be sufficient in general, + // so do that unconditionally, but double-check, + // in case our beliefs are wrong. + // The div is not expected to be reached. + zz = zz.sub(zz, m) + if zz.cmp(m) >= 0 { + _, zz = nat(nil).div(nil, zz, m) + } + } + + return zz.norm() +} + +// bytes writes the value of z into buf using big-endian encoding. +// The value of z is encoded in the slice buf[i:]. If the value of z +// cannot be represented in buf, bytes panics. The number i of unused +// bytes at the beginning of buf is returned as result. +func (z nat) bytes(buf []byte) (i int) { + // This function is used in cryptographic operations. It must not leak + // anything but the Int's sign and bit size through side-channels. Any + // changes must be reviewed by a security expert. + i = len(buf) + for _, d := range z { + for j := 0; j < _S; j++ { + i-- + if i >= 0 { + buf[i] = byte(d) + } else if byte(d) != 0 { + panic("math/big: buffer too small to fit value") + } + d >>= 8 + } + } + + if i < 0 { + i = 0 + } + for i < len(buf) && buf[i] == 0 { + i++ + } + + return +} + +// bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value. +func bigEndianWord(buf []byte) Word { + if _W == 64 { + return Word(binary.BigEndian.Uint64(buf)) + } + return Word(binary.BigEndian.Uint32(buf)) +} + +// setBytes interprets buf as the bytes of a big-endian unsigned +// integer, sets z to that value, and returns z. +func (z nat) setBytes(buf []byte) nat { + z = z.make((len(buf) + _S - 1) / _S) + + i := len(buf) + for k := 0; i >= _S; k++ { + z[k] = bigEndianWord(buf[i-_S : i]) + i -= _S + } + if i > 0 { + var d Word + for s := uint(0); i > 0; s += 8 { + d |= Word(buf[i-1]) << s + i-- + } + z[len(z)-1] = d + } + + return z.norm() +} + +// sqrt sets z = ⌊√x⌋ +func (z nat) sqrt(x nat) nat { + if x.cmp(natOne) <= 0 { + return z.set(x) + } + if alias(z, x) { + z = nil + } + + // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. + // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). + // https://members.loria.fr/PZimmermann/mca/pub226.html + // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1; + // otherwise it converges to the correct z and stays there. + var z1, z2 nat + z1 = z + z1 = z1.setUint64(1) + z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x + for n := 0; ; n++ { + z2, _ = z2.div(nil, x, z1) + z2 = z2.add(z2, z1) + z2 = z2.shr(z2, 1) + if z2.cmp(z1) >= 0 { + // z1 is answer. + // Figure out whether z1 or z2 is currently aliased to z by looking at loop count. + if n&1 == 0 { + return z1 + } + return z.set(z1) + } + z1, z2 = z2, z1 + } +} + +// subMod2N returns z = (x - y) mod 2ⁿ. +func (z nat) subMod2N(x, y nat, n uint) nat { + if uint(x.bitLen()) > n { + if alias(z, x) { + // ok to overwrite x in place + x = x.trunc(x, n) + } else { + x = nat(nil).trunc(x, n) + } + } + if uint(y.bitLen()) > n { + if alias(z, y) { + // ok to overwrite y in place + y = y.trunc(y, n) + } else { + y = nat(nil).trunc(y, n) + } + } + if x.cmp(y) >= 0 { + return z.sub(x, y) + } + // x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x). + z = z.sub(y, x) + for uint(len(z))*_W < n { + z = append(z, 0) + } + for i := range z { + z[i] = ^z[i] + } + z = z.trunc(z, n) + return z.add(z, natOne) +} diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go new file mode 100644 index 0000000..b84a7be --- /dev/null +++ b/src/math/big/nat_test.go @@ -0,0 +1,886 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "fmt" + "runtime" + "strings" + "testing" +) + +var cmpTests = []struct { + x, y nat + r int +}{ + {nil, nil, 0}, + {nil, nat(nil), 0}, + {nat(nil), nil, 0}, + {nat(nil), nat(nil), 0}, + {nat{0}, nat{0}, 0}, + {nat{0}, nat{1}, -1}, + {nat{1}, nat{0}, 1}, + {nat{1}, nat{1}, 0}, + {nat{0, _M}, nat{1}, 1}, + {nat{1}, nat{0, _M}, -1}, + {nat{1, _M}, nat{0, _M}, 1}, + {nat{0, _M}, nat{1, _M}, -1}, + {nat{16, 571956, 8794, 68}, nat{837, 9146, 1, 754489}, -1}, + {nat{34986, 41, 105, 1957}, nat{56, 7458, 104, 1957}, 1}, +} + +func TestCmp(t *testing.T) { + for i, a := range cmpTests { + r := a.x.cmp(a.y) + if r != a.r { + t.Errorf("#%d got r = %v; want %v", i, r, a.r) + } + } +} + +type funNN func(z, x, y nat) nat +type argNN struct { + z, x, y nat +} + +var sumNN = []argNN{ + {}, + {nat{1}, nil, nat{1}}, + {nat{1111111110}, nat{123456789}, nat{987654321}}, + {nat{0, 0, 0, 1}, nil, nat{0, 0, 0, 1}}, + {nat{0, 0, 0, 1111111110}, nat{0, 0, 0, 123456789}, nat{0, 0, 0, 987654321}}, + {nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}}, +} + +var prodNN = []argNN{ + {}, + {nil, nil, nil}, + {nil, nat{991}, nil}, + {nat{991}, nat{991}, nat{1}}, + {nat{991 * 991}, nat{991}, nat{991}}, + {nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}}, + {nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}}, + {nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}}, + // 3^100 * 3^28 = 3^128 + { + natFromString("11790184577738583171520872861412518665678211592275841109096961"), + natFromString("515377520732011331036461129765621272702107522001"), + natFromString("22876792454961"), + }, + // z = 111....1 (70000 digits) + // x = 10^(99*700) + ... + 10^1400 + 10^700 + 1 + // y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit) + { + natFromString(strings.Repeat("1", 70000)), + natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)), + natFromString(strings.Repeat("1", 700)), + }, + // z = 111....1 (20000 digits) + // x = 10^10000 + 1 + // y = 111....1 (10000 digits) + { + natFromString(strings.Repeat("1", 20000)), + natFromString("1" + strings.Repeat("0", 9999) + "1"), + natFromString(strings.Repeat("1", 10000)), + }, +} + +func natFromString(s string) nat { + x, _, _, err := nat(nil).scan(strings.NewReader(s), 0, false) + if err != nil { + panic(err) + } + return x +} + +func TestSet(t *testing.T) { + for _, a := range sumNN { + z := nat(nil).set(a.z) + if z.cmp(a.z) != 0 { + t.Errorf("got z = %v; want %v", z, a.z) + } + } +} + +func testFunNN(t *testing.T, msg string, f funNN, a argNN) { + z := f(nil, a.x, a.y) + if z.cmp(a.z) != 0 { + t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z) + } +} + +func TestFunNN(t *testing.T) { + for _, a := range sumNN { + arg := a + testFunNN(t, "add", nat.add, arg) + + arg = argNN{a.z, a.y, a.x} + testFunNN(t, "add symmetric", nat.add, arg) + + arg = argNN{a.x, a.z, a.y} + testFunNN(t, "sub", nat.sub, arg) + + arg = argNN{a.y, a.z, a.x} + testFunNN(t, "sub symmetric", nat.sub, arg) + } + + for _, a := range prodNN { + arg := a + testFunNN(t, "mul", nat.mul, arg) + + arg = argNN{a.z, a.y, a.x} + testFunNN(t, "mul symmetric", nat.mul, arg) + } +} + +var mulRangesN = []struct { + a, b uint64 + prod string +}{ + {0, 0, "0"}, + {1, 1, "1"}, + {1, 2, "2"}, + {1, 3, "6"}, + {10, 10, "10"}, + {0, 100, "0"}, + {0, 1e9, "0"}, + {1, 0, "1"}, // empty range + {100, 1, "1"}, // empty range + {1, 10, "3628800"}, // 10! + {1, 20, "2432902008176640000"}, // 20! + {1, 100, + "933262154439441526816992388562667004907159682643816214685929" + + "638952175999932299156089414639761565182862536979208272237582" + + "51185210916864000000000000000000000000", // 100! + }, +} + +func TestMulRangeN(t *testing.T) { + for i, r := range mulRangesN { + prod := string(nat(nil).mulRange(r.a, r.b).utoa(10)) + if prod != r.prod { + t.Errorf("#%d: got %s; want %s", i, prod, r.prod) + } + } +} + +// allocBytes returns the number of bytes allocated by invoking f. +func allocBytes(f func()) uint64 { + var stats runtime.MemStats + runtime.ReadMemStats(&stats) + t := stats.TotalAlloc + f() + runtime.ReadMemStats(&stats) + return stats.TotalAlloc - t +} + +// TestMulUnbalanced tests that multiplying numbers of different lengths +// does not cause deep recursion and in turn allocate too much memory. +// Test case for issue 3807. +func TestMulUnbalanced(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + x := rndNat(50000) + y := rndNat(40) + allocSize := allocBytes(func() { + nat(nil).mul(x, y) + }) + inputSize := uint64(len(x)+len(y)) * _S + if ratio := allocSize / uint64(inputSize); ratio > 10 { + t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio) + } +} + +// rndNat returns a random nat value >= 0 of (usually) n words in length. +// In extremely unlikely cases it may be smaller than n words if the top- +// most words are 0. +func rndNat(n int) nat { + return nat(rndV(n)).norm() +} + +// rndNat1 is like rndNat but the result is guaranteed to be > 0. +func rndNat1(n int) nat { + x := nat(rndV(n)).norm() + if len(x) == 0 { + x.setWord(1) + } + return x +} + +func BenchmarkMul(b *testing.B) { + mulx := rndNat(1e4) + muly := rndNat(1e4) + b.ResetTimer() + for i := 0; i < b.N; i++ { + var z nat + z.mul(mulx, muly) + } +} + +func benchmarkNatMul(b *testing.B, nwords int) { + x := rndNat(nwords) + y := rndNat(nwords) + var z nat + b.ResetTimer() + for i := 0; i < b.N; i++ { + z.mul(x, y) + } +} + +var mulBenchSizes = []int{10, 100, 1000, 10000, 100000} + +func BenchmarkNatMul(b *testing.B) { + for _, n := range mulBenchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + benchmarkNatMul(b, n) + }) + } +} + +func TestNLZ(t *testing.T) { + var x Word = _B >> 1 + for i := 0; i <= _W; i++ { + if int(nlz(x)) != i { + t.Errorf("failed at %x: got %d want %d", x, nlz(x), i) + } + x >>= 1 + } +} + +type shiftTest struct { + in nat + shift uint + out nat +} + +var leftShiftTests = []shiftTest{ + {nil, 0, nil}, + {nil, 1, nil}, + {natOne, 0, natOne}, + {natOne, 1, natTwo}, + {nat{1 << (_W - 1)}, 1, nat{0}}, + {nat{1 << (_W - 1), 0}, 1, nat{0, 1}}, +} + +func TestShiftLeft(t *testing.T) { + for i, test := range leftShiftTests { + var z nat + z = z.shl(test.in, test.shift) + for j, d := range test.out { + if j >= len(z) || z[j] != d { + t.Errorf("#%d: got: %v want: %v", i, z, test.out) + break + } + } + } +} + +var rightShiftTests = []shiftTest{ + {nil, 0, nil}, + {nil, 1, nil}, + {natOne, 0, natOne}, + {natOne, 1, nil}, + {natTwo, 1, natOne}, + {nat{0, 1}, 1, nat{1 << (_W - 1)}}, + {nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}}, +} + +func TestShiftRight(t *testing.T) { + for i, test := range rightShiftTests { + var z nat + z = z.shr(test.in, test.shift) + for j, d := range test.out { + if j >= len(z) || z[j] != d { + t.Errorf("#%d: got: %v want: %v", i, z, test.out) + break + } + } + } +} + +func BenchmarkZeroShifts(b *testing.B) { + x := rndNat(800) + + b.Run("Shl", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var z nat + z.shl(x, 0) + } + }) + b.Run("ShlSame", func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.shl(x, 0) + } + }) + + b.Run("Shr", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var z nat + z.shr(x, 0) + } + }) + b.Run("ShrSame", func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.shr(x, 0) + } + }) +} + +type modWTest struct { + in string + dividend string + out string +} + +var modWTests32 = []modWTest{ + {"23492635982634928349238759823742", "252341", "220170"}, +} + +var modWTests64 = []modWTest{ + {"6527895462947293856291561095690465243862946", "524326975699234", "375066989628668"}, +} + +func runModWTests(t *testing.T, tests []modWTest) { + for i, test := range tests { + in, _ := new(Int).SetString(test.in, 10) + d, _ := new(Int).SetString(test.dividend, 10) + out, _ := new(Int).SetString(test.out, 10) + + r := in.abs.modW(d.abs[0]) + if r != out.abs[0] { + t.Errorf("#%d failed: got %d want %s", i, r, out) + } + } +} + +func TestModW(t *testing.T) { + if _W >= 32 { + runModWTests(t, modWTests32) + } + if _W >= 64 { + runModWTests(t, modWTests64) + } +} + +var montgomeryTests = []struct { + x, y, m string + k0 uint64 + out32, out64 string +}{ + { + "0xffffffffffffffffffffffffffffffffffffffffffffffffe", + "0xffffffffffffffffffffffffffffffffffffffffffffffffe", + "0xfffffffffffffffffffffffffffffffffffffffffffffffff", + 1, + "0x1000000000000000000000000000000000000000000", + "0x10000000000000000000000000000000000", + }, + { + "0x000000000ffffff5", + "0x000000000ffffff0", + "0x0000000010000001", + 0xff0000000fffffff, + "0x000000000bfffff4", + "0x0000000003400001", + }, + { + "0x0000000080000000", + "0x00000000ffffffff", + "0x1000000000000001", + 0xfffffffffffffff, + "0x0800000008000001", + "0x0800000008000001", + }, + { + "0x0000000080000000", + "0x0000000080000000", + "0xffffffff00000001", + 0xfffffffeffffffff, + "0xbfffffff40000001", + "0xbfffffff40000001", + }, + { + "0x0000000080000000", + "0x0000000080000000", + "0x00ffffff00000001", + 0xfffffeffffffff, + "0xbfffff40000001", + "0xbfffff40000001", + }, + { + "0x0000000080000000", + "0x0000000080000000", + "0x0000ffff00000001", + 0xfffeffffffff, + "0xbfff40000001", + "0xbfff40000001", + }, + { + "0x3321ffffffffffffffffffffffffffff00000000000022222623333333332bbbb888c0", + "0x3321ffffffffffffffffffffffffffff00000000000022222623333333332bbbb888c0", + "0x33377fffffffffffffffffffffffffffffffffffffffffffff0000000000022222eee1", + 0xdecc8f1249812adf, + "0x04eb0e11d72329dc0915f86784820fc403275bf2f6620a20e0dd344c5cd0875e50deb5", + "0x0d7144739a7d8e11d72329dc0915f86784820fc403275bf2f61ed96f35dd34dbb3d6a0", + }, + { + "0x10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000ffffffffffffffffffffffffffffffff00000000000022222223333333333444444444", + "0x10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000ffffffffffffffffffffffffffffffff999999999999999aaabbbbbbbbcccccccccccc", + "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff33377fffffffffffffffffffffffffffffffffffffffffffff0000000000022222eee1", + 0xdecc8f1249812adf, + "0x5c0d52f451aec609b15da8e5e5626c4eaa88723bdeac9d25ca9b961269400410ca208a16af9c2fb07d7a11c7772cba02c22f9711078d51a3797eb18e691295293284d988e349fa6deba46b25a4ecd9f715", + "0x92fcad4b5c0d52f451aec609b15da8e5e5626c4eaa88723bdeac9d25ca9b961269400410ca208a16af9c2fb07d799c32fe2f3cc5422f9711078d51a3797eb18e691295293284d8f5e69caf6decddfe1df6", + }, +} + +func TestMontgomery(t *testing.T) { + one := NewInt(1) + _B := new(Int).Lsh(one, _W) + for i, test := range montgomeryTests { + x := natFromString(test.x) + y := natFromString(test.y) + m := natFromString(test.m) + for len(x) < len(m) { + x = append(x, 0) + } + for len(y) < len(m) { + y = append(y, 0) + } + + if x.cmp(m) > 0 { + _, r := nat(nil).div(nil, x, m) + t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16)) + } + if y.cmp(m) > 0 { + _, r := nat(nil).div(nil, x, m) + t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16)) + } + + var out nat + if _W == 32 { + out = natFromString(test.out32) + } else { + out = natFromString(test.out64) + } + + // t.Logf("#%d: len=%d\n", i, len(m)) + + // check output in table + xi := &Int{abs: x} + yi := &Int{abs: y} + mi := &Int{abs: m} + p := new(Int).Mod(new(Int).Mul(xi, new(Int).Mul(yi, new(Int).ModInverse(new(Int).Lsh(one, uint(len(m))*_W), mi))), mi) + if out.cmp(p.abs.norm()) != 0 { + t.Errorf("#%d: out in table=0x%s, computed=0x%s", i, out.utoa(16), p.abs.norm().utoa(16)) + } + + // check k0 in table + k := new(Int).Mod(&Int{abs: m}, _B) + k = new(Int).Sub(_B, k) + k = new(Int).Mod(k, _B) + k0 := Word(new(Int).ModInverse(k, _B).Uint64()) + if k0 != Word(test.k0) { + t.Errorf("#%d: k0 in table=%#x, computed=%#x\n", i, test.k0, k0) + } + + // check montgomery with correct k0 produces correct output + z := nat(nil).montgomery(x, y, m, k0, len(m)) + z = z.norm() + if z.cmp(out) != 0 { + t.Errorf("#%d: got 0x%s want 0x%s", i, z.utoa(16), out.utoa(16)) + } + } +} + +var expNNTests = []struct { + x, y, m string + out string +}{ + {"0", "0", "0", "1"}, + {"0", "0", "1", "0"}, + {"1", "1", "1", "0"}, + {"2", "1", "1", "0"}, + {"2", "2", "1", "0"}, + {"10", "100000000000", "1", "0"}, + {"0x8000000000000000", "2", "", "0x40000000000000000000000000000000"}, + {"0x8000000000000000", "2", "6719", "4944"}, + {"0x8000000000000000", "3", "6719", "5447"}, + {"0x8000000000000000", "1000", "6719", "1603"}, + {"0x8000000000000000", "1000000", "6719", "3199"}, + { + "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347", + "298472983472983471903246121093472394872319615612417471234712061", + "29834729834729834729347290846729561262544958723956495615629569234729836259263598127342374289365912465901365498236492183464", + "23537740700184054162508175125554701713153216681790245129157191391322321508055833908509185839069455749219131480588829346291", + }, + { + "11521922904531591643048817447554701904414021819823889996244743037378330903763518501116638828335352811871131385129455853417360623007349090150042001944696604737499160174391019030572483602867266711107136838523916077674888297896995042968746762200926853379", + "426343618817810911523", + "444747819283133684179", + "42", + }, + {"375", "249", "388", "175"}, + {"375", "18446744073709551801", "388", "175"}, + {"0", "0x40000000000000", "0x200", "0"}, + {"0xeffffff900002f00", "0x40000000000000", "0x200", "0"}, + {"5", "1435700818", "72", "49"}, + {"0xffff", "0x300030003000300030003000300030003000302a3000300030003000300030003000300030003000300030003000300030003030623066307f3030783062303430383064303630343036", "0x300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", "0xa3f94c08b0b90e87af637cacc9383f7ea032352b8961fc036a52b659b6c9b33491b335ffd74c927f64ddd62cfca0001"}, +} + +func TestExpNN(t *testing.T) { + for i, test := range expNNTests { + x := natFromString(test.x) + y := natFromString(test.y) + out := natFromString(test.out) + + var m nat + if len(test.m) > 0 { + m = natFromString(test.m) + } + + z := nat(nil).expNN(x, y, m, false) + if z.cmp(out) != 0 { + t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10)) + } + } +} + +func FuzzExpMont(f *testing.F) { + f.Fuzz(func(t *testing.T, x1, x2, x3, y1, y2, y3, m1, m2, m3 uint) { + if m1 == 0 && m2 == 0 && m3 == 0 { + return + } + x := new(Int).SetBits([]Word{Word(x1), Word(x2), Word(x3)}) + y := new(Int).SetBits([]Word{Word(y1), Word(y2), Word(y3)}) + m := new(Int).SetBits([]Word{Word(m1), Word(m2), Word(m3)}) + out := new(Int).Exp(x, y, m) + want := new(Int).expSlow(x, y, m) + if out.Cmp(want) != 0 { + t.Errorf("x = %#x\ny=%#x\nz=%#x\nout=%#x\nwant=%#x\ndc: 16o 16i %X %X %X |p", x, y, m, out, want, x, y, m) + } + }) +} + +func BenchmarkExp3Power(b *testing.B) { + const x = 3 + for _, y := range []Word{ + 0x10, 0x40, 0x100, 0x400, 0x1000, 0x4000, 0x10000, 0x40000, 0x100000, 0x400000, + } { + b.Run(fmt.Sprintf("%#x", y), func(b *testing.B) { + var z nat + for i := 0; i < b.N; i++ { + z.expWW(x, y) + } + }) + } +} + +func fibo(n int) nat { + switch n { + case 0: + return nil + case 1: + return nat{1} + } + f0 := fibo(0) + f1 := fibo(1) + var f2 nat + for i := 1; i < n; i++ { + f2 = f2.add(f0, f1) + f0, f1, f2 = f1, f2, f0 + } + return f1 +} + +var fiboNums = []string{ + "0", + "55", + "6765", + "832040", + "102334155", + "12586269025", + "1548008755920", + "190392490709135", + "23416728348467685", + "2880067194370816120", + "354224848179261915075", +} + +func TestFibo(t *testing.T) { + for i, want := range fiboNums { + n := i * 10 + got := string(fibo(n).utoa(10)) + if got != want { + t.Errorf("fibo(%d) failed: got %s want %s", n, got, want) + } + } +} + +func BenchmarkFibo(b *testing.B) { + for i := 0; i < b.N; i++ { + fibo(1e0) + fibo(1e1) + fibo(1e2) + fibo(1e3) + fibo(1e4) + fibo(1e5) + } +} + +var bitTests = []struct { + x string + i uint + want uint +}{ + {"0", 0, 0}, + {"0", 1, 0}, + {"0", 1000, 0}, + + {"0x1", 0, 1}, + {"0x10", 0, 0}, + {"0x10", 3, 0}, + {"0x10", 4, 1}, + {"0x10", 5, 0}, + + {"0x8000000000000000", 62, 0}, + {"0x8000000000000000", 63, 1}, + {"0x8000000000000000", 64, 0}, + + {"0x3" + strings.Repeat("0", 32), 127, 0}, + {"0x3" + strings.Repeat("0", 32), 128, 1}, + {"0x3" + strings.Repeat("0", 32), 129, 1}, + {"0x3" + strings.Repeat("0", 32), 130, 0}, +} + +func TestBit(t *testing.T) { + for i, test := range bitTests { + x := natFromString(test.x) + if got := x.bit(test.i); got != test.want { + t.Errorf("#%d: %s.bit(%d) = %v; want %v", i, test.x, test.i, got, test.want) + } + } +} + +var stickyTests = []struct { + x string + i uint + want uint +}{ + {"0", 0, 0}, + {"0", 1, 0}, + {"0", 1000, 0}, + + {"0x1", 0, 0}, + {"0x1", 1, 1}, + + {"0x1350", 0, 0}, + {"0x1350", 4, 0}, + {"0x1350", 5, 1}, + + {"0x8000000000000000", 63, 0}, + {"0x8000000000000000", 64, 1}, + + {"0x1" + strings.Repeat("0", 100), 400, 0}, + {"0x1" + strings.Repeat("0", 100), 401, 1}, +} + +func TestSticky(t *testing.T) { + for i, test := range stickyTests { + x := natFromString(test.x) + if got := x.sticky(test.i); got != test.want { + t.Errorf("#%d: %s.sticky(%d) = %v; want %v", i, test.x, test.i, got, test.want) + } + if test.want == 1 { + // all subsequent i's should also return 1 + for d := uint(1); d <= 3; d++ { + if got := x.sticky(test.i + d); got != 1 { + t.Errorf("#%d: %s.sticky(%d) = %v; want %v", i, test.x, test.i+d, got, 1) + } + } + } + } +} + +func testSqr(t *testing.T, x nat) { + got := make(nat, 2*len(x)) + want := make(nat, 2*len(x)) + got = got.sqr(x) + want = want.mul(x, x) + if got.cmp(want) != 0 { + t.Errorf("basicSqr(%v), got %v, want %v", x, got, want) + } +} + +func TestSqr(t *testing.T) { + for _, a := range prodNN { + if a.x != nil { + testSqr(t, a.x) + } + if a.y != nil { + testSqr(t, a.y) + } + if a.z != nil { + testSqr(t, a.z) + } + } +} + +func benchmarkNatSqr(b *testing.B, nwords int) { + x := rndNat(nwords) + var z nat + b.ResetTimer() + for i := 0; i < b.N; i++ { + z.sqr(x) + } +} + +var sqrBenchSizes = []int{ + 1, 2, 3, 5, 8, 10, 20, 30, 50, 80, + 100, 200, 300, 500, 800, + 1000, 10000, 100000, +} + +func BenchmarkNatSqr(b *testing.B) { + for _, n := range sqrBenchSizes { + if isRaceBuilder && n > 1e3 { + continue + } + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + benchmarkNatSqr(b, n) + }) + } +} + +var subMod2NTests = []struct { + x string + y string + n uint + z string +}{ + {"1", "2", 0, "0"}, + {"1", "0", 1, "1"}, + {"0", "1", 1, "1"}, + {"3", "5", 3, "6"}, + {"5", "3", 3, "2"}, + // 2^65, 2^66-1, 2^65 - (2^66-1) + 2^67 + {"36893488147419103232", "73786976294838206463", 67, "110680464442257309697"}, + // 2^66-1, 2^65, 2^65-1 + {"73786976294838206463", "36893488147419103232", 67, "36893488147419103231"}, +} + +func TestNatSubMod2N(t *testing.T) { + for _, mode := range []string{"noalias", "aliasX", "aliasY"} { + t.Run(mode, func(t *testing.T) { + for _, tt := range subMod2NTests { + x0 := natFromString(tt.x) + y0 := natFromString(tt.y) + want := natFromString(tt.z) + x := nat(nil).set(x0) + y := nat(nil).set(y0) + var z nat + switch mode { + case "aliasX": + z = x + case "aliasY": + z = y + } + z = z.subMod2N(x, y, tt.n) + if z.cmp(want) != 0 { + t.Fatalf("subMod2N(%d, %d, %d) = %d, want %d", x0, y0, tt.n, z, want) + } + if mode != "aliasX" && x.cmp(x0) != 0 { + t.Fatalf("subMod2N(%d, %d, %d) modified x", x0, y0, tt.n) + } + if mode != "aliasY" && y.cmp(y0) != 0 { + t.Fatalf("subMod2N(%d, %d, %d) modified y", x0, y0, tt.n) + } + } + }) + } +} + +func BenchmarkNatSetBytes(b *testing.B) { + const maxLength = 128 + lengths := []int{ + // No remainder: + 8, 24, maxLength, + // With remainder: + 7, 23, maxLength - 1, + } + n := make(nat, maxLength/_W) // ensure n doesn't need to grow during the test + buf := make([]byte, maxLength) + for _, l := range lengths { + b.Run(fmt.Sprint(l), func(b *testing.B) { + for i := 0; i < b.N; i++ { + n.setBytes(buf[:l]) + } + }) + } +} + +func TestNatDiv(t *testing.T) { + sizes := []int{ + 1, 2, 5, 8, 15, 25, 40, 65, 100, + 200, 500, 800, 1500, 2500, 4000, 6500, 10000, + } + for _, i := range sizes { + for _, j := range sizes { + a := rndNat1(i) + b := rndNat1(j) + // the test requires b >= 2 + if len(b) == 1 && b[0] == 1 { + b[0] = 2 + } + // choose a remainder c < b + c := rndNat1(len(b)) + if len(c) == len(b) && c[len(c)-1] >= b[len(b)-1] { + c[len(c)-1] = 0 + c = c.norm() + } + // compute x = a*b+c + x := nat(nil).mul(a, b) + x = x.add(x, c) + + var q, r nat + q, r = q.div(r, x, b) + if q.cmp(a) != 0 { + t.Fatalf("wrong quotient: got %s; want %s for %s/%s", q.utoa(10), a.utoa(10), x.utoa(10), b.utoa(10)) + } + if r.cmp(c) != 0 { + t.Fatalf("wrong remainder: got %s; want %s for %s/%s", r.utoa(10), c.utoa(10), x.utoa(10), b.utoa(10)) + } + } + } +} + +// TestIssue37499 triggers the edge case of divBasic where +// the inaccurate estimate of the first word's quotient +// happens at the very beginning of the loop. +func TestIssue37499(t *testing.T) { + // Choose u and v such that v is slightly larger than u >> N. + // This tricks divBasic into choosing 1 as the first word + // of the quotient. This works in both 32-bit and 64-bit settings. + u := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706b39f8489c1d28e57bb5ba4ef9fd9387a3e344402c0a453381") + v := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706c") + + q := nat(nil).make(8) + q.divBasic(u, v) + q = q.norm() + if s := string(q.utoa(16)); s != "fffffffffffffffffffffffffffffffffffffffffffffffb" { + t.Fatalf("incorrect quotient: %s", s) + } +} + +// TestIssue42552 triggers an edge case of recursive division +// where the first division loop is never entered, and correcting +// the remainder takes exactly two iterations in the final loop. +func TestIssue42552(t *testing.T) { + u := natFromString("0xc23b166884c3869092a520eceedeced2b00847bd256c9cf3b2c5e2227c15bd5e6ee7ef8a2f49236ad0eedf2c8a3b453cf6e0706f64285c526b372c4b1321245519d430540804a50b7ca8b6f1b34a2ec05cdbc24de7599af112d3e3c8db347e8799fe70f16e43c6566ba3aeb169463a3ecc486172deb2d9b80a3699c776e44fef20036bd946f1b4d054dd88a2c1aeb986199b0b2b7e58c42288824b74934d112fe1fc06e06b4d99fe1c5e725946b23210521e209cd507cce90b5f39a523f27e861f9e232aee50c3f585208b4573dcc0b897b6177f2ba20254fd5c50a033e849dee1b3a93bd2dc44ba8ca836cab2c2ae50e50b126284524fa0187af28628ff0face68d87709200329db1392852c8b8963fbe3d05fb1efe19f0ed5ca9fadc2f96f82187c24bb2512b2e85a66333a7e176605695211e1c8e0b9b9e82813e50654964945b1e1e66a90840396c7d10e23e47f364d2d3f660fa54598e18d1ca2ea4fe4f35a40a11f69f201c80b48eaee3e2e9b0eda63decf92bec08a70f731587d4ed0f218d5929285c8b2ccbc497e20db42de73885191fa453350335990184d8df805072f958d5354debda38f5421effaaafd6cb9b721ace74be0892d77679f62a4a126697cd35797f6858193da4ba1770c06aea2e5c59ec04b8ea26749e61b72ecdde403f3bc7e5e546cd799578cc939fa676dfd5e648576d4a06cbadb028adc2c0b461f145b2321f42e5e0f3b4fb898ecd461df07a6f5154067787bf74b5cc5c03704a1ce47494961931f0263b0aac32505102595957531a2de69dd71aac51f8a49902f81f21283dbe8e21e01e5d82517868826f86acf338d935aa6b4d5a25c8d540389b277dd9d64569d68baf0f71bd03dba45b92a7fc052601d1bd011a2fc6790a23f97c6fa5caeea040ab86841f268d39ce4f7caf01069df78bba098e04366492f0c2ac24f1bf16828752765fa523c9a4d42b71109d123e6be8c7b1ab3ccf8ea03404075fe1a9596f1bba1d267f9a7879ceece514818316c9c0583469d2367831fc42b517ea028a28df7c18d783d16ea2436cee2b15d52db68b5dfdee6b4d26f0905f9b030c911a04d078923a4136afea96eed6874462a482917353264cc9bee298f167ac65a6db4e4eda88044b39cc0b33183843eaa946564a00c3a0ab661f2c915e70bf0bb65bfbb6fa2eea20aed16bf2c1a1d00ec55fb4ff2f76b8e462ea70c19efa579c9ee78194b86708fdae66a9ce6e2cf3d366037798cfb50277ba6d2fd4866361022fd788ab7735b40b8b61d55e32243e06719e53992e9ac16c9c4b6e6933635c3c47c8f7e73e17dd54d0dd8aeba5d76de46894e7b3f9d3ec25ad78ee82297ba69905ea0fa094b8667faa2b8885e2187b3da80268aa1164761d7b0d6de206b676777348152b8ae1d4afed753bc63c739a5ca8ce7afb2b241a226bd9e502baba391b5b13f5054f070b65a9cf3a67063bfaa803ba390732cd03888f664023f888741d04d564e0b5674b0a183ace81452001b3fbb4214c77d42ca75376742c471e58f67307726d56a1032bd236610cbcbcd03d0d7a452900136897dc55bb3ce959d10d4e6a10fb635006bd8c41cd9ded2d3dfdd8f2e229590324a7370cb2124210b2330f4c56155caa09a2564932ceded8d92c79664dcdeb87faad7d3da006cc2ea267ee3df41e9677789cc5a8cc3b83add6491561b3047919e0648b1b2e97d7ad6f6c2aa80cab8e9ae10e1f75b1fdd0246151af709d259a6a0ed0b26bd711024965ecad7c41387de45443defce53f66612948694a6032279131c257119ed876a8e805dfb49576ef5c563574115ee87050d92d191bc761ef51d966918e2ef925639400069e3959d8fe19f36136e947ff430bf74e71da0aa5923b00000000") + v := natFromString("0x838332321d443a3d30373d47301d47073847473a383d3030f25b3d3d3e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002e00000000000000000041603038331c3d32f5303441e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e01c0a5459bfc7b9be9fcbb9d2383840464319434707303030f43a32f53034411c0a5459413820878787878787878787878787878787878787878787878787878787878787878787870630303a3a30334036605b923a6101f83638413943413960204337602043323801526040523241846038414143015238604060328452413841413638523c0240384141364036605b923a6101f83638413943413960204334602043323801526040523241846038414143015238604060328452413841413638523c02403841413638433030f25a8b83838383838383838383838383838383837d838383ffffffffffffffff838383838383838383000000000000000000030000007d26e27c7c8b83838383838383838383838383838383837d838383ffffffffffffffff83838383838383838383838383838383838383838383435960f535073030f3343200000000000000011881301938343030fa398383300000002300000000000000000000f11af4600c845252904141364138383c60406032414443095238010241414303364443434132305b595a15434160b042385341ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff47476043410536613603593a6005411c437405fcfcfcfcfcfcfc0000000000005a3b075815054359000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + q := nat(nil).make(16) + q.div(q, u, v) +} diff --git a/src/math/big/natconv.go b/src/math/big/natconv.go new file mode 100644 index 0000000..ce94f2c --- /dev/null +++ b/src/math/big/natconv.go @@ -0,0 +1,511 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements nat-to-string conversion functions. + +package big + +import ( + "errors" + "fmt" + "io" + "math" + "math/bits" + "sync" +) + +const digits = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +// Note: MaxBase = len(digits), but it must remain an untyped rune constant +// for API compatibility. + +// MaxBase is the largest number base accepted for string conversions. +const MaxBase = 10 + ('z' - 'a' + 1) + ('Z' - 'A' + 1) +const maxBaseSmall = 10 + ('z' - 'a' + 1) + +// maxPow returns (b**n, n) such that b**n is the largest power b**n <= _M. +// For instance maxPow(10) == (1e19, 19) for 19 decimal digits in a 64bit Word. +// In other words, at most n digits in base b fit into a Word. +// TODO(gri) replace this with a table, generated at build time. +func maxPow(b Word) (p Word, n int) { + p, n = b, 1 // assuming b <= _M + for max := _M / b; p <= max; { + // p == b**n && p <= max + p *= b + n++ + } + // p == b**n && p <= _M + return +} + +// pow returns x**n for n > 0, and 1 otherwise. +func pow(x Word, n int) (p Word) { + // n == sum of bi * 2**i, for 0 <= i < imax, and bi is 0 or 1 + // thus x**n == product of x**(2**i) for all i where bi == 1 + // (Russian Peasant Method for exponentiation) + p = 1 + for n > 0 { + if n&1 != 0 { + p *= x + } + x *= x + n >>= 1 + } + return +} + +// scan errors +var ( + errNoDigits = errors.New("number has no digits") + errInvalSep = errors.New("'_' must separate successive digits") +) + +// scan scans the number corresponding to the longest possible prefix +// from r representing an unsigned number in a given conversion base. +// scan returns the corresponding natural number res, the actual base b, +// a digit count, and a read or syntax error err, if any. +// +// For base 0, an underscore character “_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number, or the returned +// digit count. Incorrect placement of underscores is reported as an +// error if there are no other errors. If base != 0, underscores are +// not recognized and thus terminate scanning like any other character +// that is not a valid radix point or digit. +// +// number = mantissa | prefix pmantissa . +// prefix = "0" [ "b" | "B" | "o" | "O" | "x" | "X" ] . +// mantissa = digits "." [ digits ] | digits | "." digits . +// pmantissa = [ "_" ] digits "." [ digits ] | [ "_" ] digits | "." digits . +// digits = digit { [ "_" ] digit } . +// digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" . +// +// Unless fracOk is set, the base argument must be 0 or a value between +// 2 and MaxBase. If fracOk is set, the base argument must be one of +// 0, 2, 8, 10, or 16. Providing an invalid base argument leads to a run- +// time panic. +// +// For base 0, the number prefix determines the actual base: A prefix of +// “0b” or “0B” selects base 2, “0o” or “0O” selects base 8, and +// “0x” or “0X” selects base 16. If fracOk is false, a “0” prefix +// (immediately followed by digits) selects base 8 as well. Otherwise, +// the selected base is 10 and no prefix is accepted. +// +// If fracOk is set, a period followed by a fractional part is permitted. +// The result value is computed as if there were no period present; and +// the count value is used to determine the fractional part. +// +// For bases <= 36, lower and upper case letters are considered the same: +// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35. +// For bases > 36, the upper case letters 'A' to 'Z' represent the digit +// values 36 to 61. +// +// A result digit count > 0 corresponds to the number of (non-prefix) digits +// parsed. A digit count <= 0 indicates the presence of a period (if fracOk +// is set, only), and -count is the number of fractional digits found. +// In this case, the actual value of the scanned number is res * b**count. +func (z nat) scan(r io.ByteScanner, base int, fracOk bool) (res nat, b, count int, err error) { + // reject invalid bases + baseOk := base == 0 || + !fracOk && 2 <= base && base <= MaxBase || + fracOk && (base == 2 || base == 8 || base == 10 || base == 16) + if !baseOk { + panic(fmt.Sprintf("invalid number base %d", base)) + } + + // prev encodes the previously seen char: it is one + // of '_', '0' (a digit), or '.' (anything else). A + // valid separator '_' may only occur after a digit + // and if base == 0. + prev := '.' + invalSep := false + + // one char look-ahead + ch, err := r.ReadByte() + + // determine actual base + b, prefix := base, 0 + if base == 0 { + // actual base is 10 unless there's a base prefix + b = 10 + if err == nil && ch == '0' { + prev = '0' + count = 1 + ch, err = r.ReadByte() + if err == nil { + // possibly one of 0b, 0B, 0o, 0O, 0x, 0X + switch ch { + case 'b', 'B': + b, prefix = 2, 'b' + case 'o', 'O': + b, prefix = 8, 'o' + case 'x', 'X': + b, prefix = 16, 'x' + default: + if !fracOk { + b, prefix = 8, '0' + } + } + if prefix != 0 { + count = 0 // prefix is not counted + if prefix != '0' { + ch, err = r.ReadByte() + } + } + } + } + } + + // convert string + // Algorithm: Collect digits in groups of at most n digits in di + // and then use mulAddWW for every such group to add them to the + // result. + z = z[:0] + b1 := Word(b) + bn, n := maxPow(b1) // at most n digits in base b1 fit into Word + di := Word(0) // 0 <= di < b1**i < bn + i := 0 // 0 <= i < n + dp := -1 // position of decimal point + for err == nil { + if ch == '.' && fracOk { + fracOk = false + if prev == '_' { + invalSep = true + } + prev = '.' + dp = count + } else if ch == '_' && base == 0 { + if prev != '0' { + invalSep = true + } + prev = '_' + } else { + // convert rune into digit value d1 + var d1 Word + switch { + case '0' <= ch && ch <= '9': + d1 = Word(ch - '0') + case 'a' <= ch && ch <= 'z': + d1 = Word(ch - 'a' + 10) + case 'A' <= ch && ch <= 'Z': + if b <= maxBaseSmall { + d1 = Word(ch - 'A' + 10) + } else { + d1 = Word(ch - 'A' + maxBaseSmall) + } + default: + d1 = MaxBase + 1 + } + if d1 >= b1 { + r.UnreadByte() // ch does not belong to number anymore + break + } + prev = '0' + count++ + + // collect d1 in di + di = di*b1 + d1 + i++ + + // if di is "full", add it to the result + if i == n { + z = z.mulAddWW(z, bn, di) + di = 0 + i = 0 + } + } + + ch, err = r.ReadByte() + } + + if err == io.EOF { + err = nil + } + + // other errors take precedence over invalid separators + if err == nil && (invalSep || prev == '_') { + err = errInvalSep + } + + if count == 0 { + // no digits found + if prefix == '0' { + // there was only the octal prefix 0 (possibly followed by separators and digits > 7); + // interpret as decimal 0 + return z[:0], 10, 1, err + } + err = errNoDigits // fall through; result will be 0 + } + + // add remaining digits to result + if i > 0 { + z = z.mulAddWW(z, pow(b1, i), di) + } + res = z.norm() + + // adjust count for fraction, if any + if dp >= 0 { + // 0 <= dp <= count + count = dp - count + } + + return +} + +// utoa converts x to an ASCII representation in the given base; +// base must be between 2 and MaxBase, inclusive. +func (x nat) utoa(base int) []byte { + return x.itoa(false, base) +} + +// itoa is like utoa but it prepends a '-' if neg && x != 0. +func (x nat) itoa(neg bool, base int) []byte { + if base < 2 || base > MaxBase { + panic("invalid base") + } + + // x == 0 + if len(x) == 0 { + return []byte("0") + } + // len(x) > 0 + + // allocate buffer for conversion + i := int(float64(x.bitLen())/math.Log2(float64(base))) + 1 // off by 1 at most + if neg { + i++ + } + s := make([]byte, i) + + // convert power of two and non power of two bases separately + if b := Word(base); b == b&-b { + // shift is base b digit size in bits + shift := uint(bits.TrailingZeros(uint(b))) // shift > 0 because b >= 2 + mask := Word(1<<shift - 1) + w := x[0] // current word + nbits := uint(_W) // number of unprocessed bits in w + + // convert less-significant words (include leading zeros) + for k := 1; k < len(x); k++ { + // convert full digits + for nbits >= shift { + i-- + s[i] = digits[w&mask] + w >>= shift + nbits -= shift + } + + // convert any partial leading digit and advance to next word + if nbits == 0 { + // no partial digit remaining, just advance + w = x[k] + nbits = _W + } else { + // partial digit in current word w (== x[k-1]) and next word x[k] + w |= x[k] << nbits + i-- + s[i] = digits[w&mask] + + // advance + w = x[k] >> (shift - nbits) + nbits = _W - (shift - nbits) + } + } + + // convert digits of most-significant word w (omit leading zeros) + for w != 0 { + i-- + s[i] = digits[w&mask] + w >>= shift + } + + } else { + bb, ndigits := maxPow(b) + + // construct table of successive squares of bb*leafSize to use in subdivisions + // result (table != nil) <=> (len(x) > leafSize > 0) + table := divisors(len(x), b, ndigits, bb) + + // preserve x, create local copy for use by convertWords + q := nat(nil).set(x) + + // convert q to string s in base b + q.convertWords(s, b, ndigits, bb, table) + + // strip leading zeros + // (x != 0; thus s must contain at least one non-zero digit + // and the loop will terminate) + i = 0 + for s[i] == '0' { + i++ + } + } + + if neg { + i-- + s[i] = '-' + } + + return s[i:] +} + +// Convert words of q to base b digits in s. If q is large, it is recursively "split in half" +// by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using +// repeated nat/Word division. +// +// The iterative method processes n Words by n divW() calls, each of which visits every Word in the +// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s. +// Recursive conversion divides q by its approximate square root, yielding two parts, each half +// the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s +// plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and +// is made better by splitting the subblocks recursively. Best is to split blocks until one more +// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the +// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the +// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and +// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for +// specific hardware. +func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) { + // split larger blocks recursively + if table != nil { + // len(q) > leafSize > 0 + var r nat + index := len(table) - 1 + for len(q) > leafSize { + // find divisor close to sqrt(q) if possible, but in any case < q + maxLength := q.bitLen() // ~= log2 q, or at of least largest possible q of this bit length + minLength := maxLength >> 1 // ~= log2 sqrt(q) + for index > 0 && table[index-1].nbits > minLength { + index-- // desired + } + if table[index].nbits >= maxLength && table[index].bbb.cmp(q) >= 0 { + index-- + if index < 0 { + panic("internal inconsistency") + } + } + + // split q into the two digit number (q'*bbb + r) to form independent subblocks + q, r = q.div(r, q, table[index].bbb) + + // convert subblocks and collect results in s[:h] and s[h:] + h := len(s) - table[index].ndigits + r.convertWords(s[h:], b, ndigits, bb, table[0:index]) + s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1]) + } + } + + // having split any large blocks now process the remaining (small) block iteratively + i := len(s) + var r Word + if b == 10 { + // hard-coding for 10 here speeds this up by 1.25x (allows for / and % by constants) + for len(q) > 0 { + // extract least significant, base bb "digit" + q, r = q.divW(q, bb) + for j := 0; j < ndigits && i > 0; j++ { + i-- + // avoid % computation since r%10 == r - int(r/10)*10; + // this appears to be faster for BenchmarkString10000Base10 + // and smaller strings (but a bit slower for larger ones) + t := r / 10 + s[i] = '0' + byte(r-t*10) + r = t + } + } + } else { + for len(q) > 0 { + // extract least significant, base bb "digit" + q, r = q.divW(q, bb) + for j := 0; j < ndigits && i > 0; j++ { + i-- + s[i] = digits[r%b] + r /= b + } + } + } + + // prepend high-order zeros + for i > 0 { // while need more leading zeros + i-- + s[i] = '0' + } +} + +// Split blocks greater than leafSize Words (or set to 0 to disable recursive conversion) +// Benchmark and configure leafSize using: go test -bench="Leaf" +// +// 8 and 16 effective on 3.0 GHz Xeon "Clovertown" CPU (128 byte cache lines) +// 8 and 16 effective on 2.66 GHz Core 2 Duo "Penryn" CPU +var leafSize int = 8 // number of Word-size binary values treat as a monolithic block + +type divisor struct { + bbb nat // divisor + nbits int // bit length of divisor (discounting leading zeros) ~= log2(bbb) + ndigits int // digit length of divisor in terms of output base digits +} + +var cacheBase10 struct { + sync.Mutex + table [64]divisor // cached divisors for base 10 +} + +// expWW computes x**y +func (z nat) expWW(x, y Word) nat { + return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false) +} + +// construct table of powers of bb*leafSize to use in subdivisions. +func divisors(m int, b Word, ndigits int, bb Word) []divisor { + // only compute table when recursive conversion is enabled and x is large + if leafSize == 0 || m <= leafSize { + return nil + } + + // determine k where (bb**leafSize)**(2**k) >= sqrt(x) + k := 1 + for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 { + k++ + } + + // reuse and extend existing table of divisors or create new table as appropriate + var table []divisor // for b == 10, table overlaps with cacheBase10.table + if b == 10 { + cacheBase10.Lock() + table = cacheBase10.table[0:k] // reuse old table for this conversion + } else { + table = make([]divisor, k) // create new table for this conversion + } + + // extend table + if table[k-1].ndigits == 0 { + // add new entries as needed + var larger nat + for i := 0; i < k; i++ { + if table[i].ndigits == 0 { + if i == 0 { + table[0].bbb = nat(nil).expWW(bb, Word(leafSize)) + table[0].ndigits = ndigits * leafSize + } else { + table[i].bbb = nat(nil).sqr(table[i-1].bbb) + table[i].ndigits = 2 * table[i-1].ndigits + } + + // optimization: exploit aggregated extra bits in macro blocks + larger = nat(nil).set(table[i].bbb) + for mulAddVWW(larger, larger, b, 0) == 0 { + table[i].bbb = table[i].bbb.set(larger) + table[i].ndigits++ + } + + table[i].nbits = table[i].bbb.bitLen() + } + } + } + + if b == 10 { + cacheBase10.Unlock() + } + + return table +} diff --git a/src/math/big/natconv_test.go b/src/math/big/natconv_test.go new file mode 100644 index 0000000..d390272 --- /dev/null +++ b/src/math/big/natconv_test.go @@ -0,0 +1,463 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "fmt" + "io" + "math/bits" + "strings" + "testing" +) + +func TestMaxBase(t *testing.T) { + if MaxBase != len(digits) { + t.Fatalf("%d != %d", MaxBase, len(digits)) + } +} + +// log2 computes the integer binary logarithm of x. +// The result is the integer n for which 2^n <= x < 2^(n+1). +// If x == 0, the result is -1. +func log2(x Word) int { + return bits.Len(uint(x)) - 1 +} + +func itoa(x nat, base int) []byte { + // special cases + switch { + case base < 2: + panic("illegal base") + case len(x) == 0: + return []byte("0") + } + + // allocate buffer for conversion + i := x.bitLen()/log2(Word(base)) + 1 // +1: round up + s := make([]byte, i) + + // don't destroy x + q := nat(nil).set(x) + + // convert + for len(q) > 0 { + i-- + var r Word + q, r = q.divW(q, Word(base)) + s[i] = digits[r] + } + + return s[i:] +} + +var strTests = []struct { + x nat // nat value to be converted + b int // conversion base + s string // expected result +}{ + {nil, 2, "0"}, + {nat{1}, 2, "1"}, + {nat{0xc5}, 2, "11000101"}, + {nat{03271}, 8, "3271"}, + {nat{10}, 10, "10"}, + {nat{1234567890}, 10, "1234567890"}, + {nat{0xdeadbeef}, 16, "deadbeef"}, + {nat{0x229be7}, 17, "1a2b3c"}, + {nat{0x309663e6}, 32, "o9cov6"}, + {nat{0x309663e6}, 62, "TakXI"}, +} + +func TestString(t *testing.T) { + // test invalid base explicitly + var panicStr string + func() { + defer func() { + panicStr = recover().(string) + }() + natOne.utoa(1) + }() + if panicStr != "invalid base" { + t.Errorf("expected panic for invalid base") + } + + for _, a := range strTests { + s := string(a.x.utoa(a.b)) + if s != a.s { + t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s) + } + + x, b, _, err := nat(nil).scan(strings.NewReader(a.s), a.b, false) + if x.cmp(a.x) != 0 { + t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) + } + if b != a.b { + t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.b) + } + if err != nil { + t.Errorf("scan%+v\n\tgot error = %s", a, err) + } + } +} + +var natScanTests = []struct { + s string // string to be scanned + base int // input base + frac bool // fraction ok + x nat // expected nat + b int // expected base + count int // expected digit count + err error // expected error + next rune // next character (or 0, if at EOF) +}{ + // invalid: no digits + {"", 0, false, nil, 10, 0, errNoDigits, 0}, + {"_", 0, false, nil, 10, 0, errNoDigits, 0}, + {"?", 0, false, nil, 10, 0, errNoDigits, '?'}, + {"?", 10, false, nil, 10, 0, errNoDigits, '?'}, + {"", 10, false, nil, 10, 0, errNoDigits, 0}, + {"", 36, false, nil, 36, 0, errNoDigits, 0}, + {"", 62, false, nil, 62, 0, errNoDigits, 0}, + {"0b", 0, false, nil, 2, 0, errNoDigits, 0}, + {"0o", 0, false, nil, 8, 0, errNoDigits, 0}, + {"0x", 0, false, nil, 16, 0, errNoDigits, 0}, + {"0x_", 0, false, nil, 16, 0, errNoDigits, 0}, + {"0b2", 0, false, nil, 2, 0, errNoDigits, '2'}, + {"0B2", 0, false, nil, 2, 0, errNoDigits, '2'}, + {"0o8", 0, false, nil, 8, 0, errNoDigits, '8'}, + {"0O8", 0, false, nil, 8, 0, errNoDigits, '8'}, + {"0xg", 0, false, nil, 16, 0, errNoDigits, 'g'}, + {"0Xg", 0, false, nil, 16, 0, errNoDigits, 'g'}, + {"345", 2, false, nil, 2, 0, errNoDigits, '3'}, + + // invalid: incorrect use of decimal point + {"._", 0, true, nil, 10, 0, errNoDigits, 0}, + {".0", 0, false, nil, 10, 0, errNoDigits, '.'}, + {".0", 10, false, nil, 10, 0, errNoDigits, '.'}, + {".", 0, true, nil, 10, 0, errNoDigits, 0}, + {"0x.", 0, true, nil, 16, 0, errNoDigits, 0}, + {"0x.g", 0, true, nil, 16, 0, errNoDigits, 'g'}, + {"0x.0", 0, false, nil, 16, 0, errNoDigits, '.'}, + + // invalid: incorrect use of separators + {"_0", 0, false, nil, 10, 1, errInvalSep, 0}, + {"0_", 0, false, nil, 10, 1, errInvalSep, 0}, + {"0__0", 0, false, nil, 8, 1, errInvalSep, 0}, + {"0x___0", 0, false, nil, 16, 1, errInvalSep, 0}, + {"0_x", 0, false, nil, 10, 1, errInvalSep, 'x'}, + {"0_8", 0, false, nil, 10, 1, errInvalSep, '8'}, + {"123_.", 0, true, nat{123}, 10, 0, errInvalSep, 0}, + {"._123", 0, true, nat{123}, 10, -3, errInvalSep, 0}, + {"0b__1000", 0, false, nat{0x8}, 2, 4, errInvalSep, 0}, + {"0o60___0", 0, false, nat{0600}, 8, 3, errInvalSep, 0}, + {"0466_", 0, false, nat{0466}, 8, 3, errInvalSep, 0}, + {"01234567_8", 0, false, nat{01234567}, 8, 7, errInvalSep, '8'}, + {"1_.", 0, true, nat{1}, 10, 0, errInvalSep, 0}, + {"0._1", 0, true, nat{1}, 10, -1, errInvalSep, 0}, + {"2.7_", 0, true, nat{27}, 10, -1, errInvalSep, 0}, + {"0x1.0_", 0, true, nat{0x10}, 16, -1, errInvalSep, 0}, + + // valid: separators are not accepted for base != 0 + {"0_", 10, false, nil, 10, 1, nil, '_'}, + {"1__0", 10, false, nat{1}, 10, 1, nil, '_'}, + {"0__8", 10, false, nil, 10, 1, nil, '_'}, + {"xy_z_", 36, false, nat{33*36 + 34}, 36, 2, nil, '_'}, + + // valid, no decimal point + {"0", 0, false, nil, 10, 1, nil, 0}, + {"0", 36, false, nil, 36, 1, nil, 0}, + {"0", 62, false, nil, 62, 1, nil, 0}, + {"1", 0, false, nat{1}, 10, 1, nil, 0}, + {"1", 10, false, nat{1}, 10, 1, nil, 0}, + {"0 ", 0, false, nil, 10, 1, nil, ' '}, + {"00 ", 0, false, nil, 8, 1, nil, ' '}, // octal 0 + {"0b1", 0, false, nat{1}, 2, 1, nil, 0}, + {"0B11000101", 0, false, nat{0xc5}, 2, 8, nil, 0}, + {"0B110001012", 0, false, nat{0xc5}, 2, 8, nil, '2'}, + {"07", 0, false, nat{7}, 8, 1, nil, 0}, + {"08", 0, false, nil, 10, 1, nil, '8'}, + {"08", 10, false, nat{8}, 10, 2, nil, 0}, + {"018", 0, false, nat{1}, 8, 1, nil, '8'}, + {"0o7", 0, false, nat{7}, 8, 1, nil, 0}, + {"0o18", 0, false, nat{1}, 8, 1, nil, '8'}, + {"0O17", 0, false, nat{017}, 8, 2, nil, 0}, + {"03271", 0, false, nat{03271}, 8, 4, nil, 0}, + {"10ab", 0, false, nat{10}, 10, 2, nil, 'a'}, + {"1234567890", 0, false, nat{1234567890}, 10, 10, nil, 0}, + {"A", 36, false, nat{10}, 36, 1, nil, 0}, + {"A", 37, false, nat{36}, 37, 1, nil, 0}, + {"xyz", 36, false, nat{(33*36+34)*36 + 35}, 36, 3, nil, 0}, + {"XYZ?", 36, false, nat{(33*36+34)*36 + 35}, 36, 3, nil, '?'}, + {"XYZ?", 62, false, nat{(59*62+60)*62 + 61}, 62, 3, nil, '?'}, + {"0x", 16, false, nil, 16, 1, nil, 'x'}, + {"0xdeadbeef", 0, false, nat{0xdeadbeef}, 16, 8, nil, 0}, + {"0XDEADBEEF", 0, false, nat{0xdeadbeef}, 16, 8, nil, 0}, + + // valid, with decimal point + {"0.", 0, false, nil, 10, 1, nil, '.'}, + {"0.", 10, true, nil, 10, 0, nil, 0}, + {"0.1.2", 10, true, nat{1}, 10, -1, nil, '.'}, + {".000", 10, true, nil, 10, -3, nil, 0}, + {"12.3", 10, true, nat{123}, 10, -1, nil, 0}, + {"012.345", 10, true, nat{12345}, 10, -3, nil, 0}, + {"0.1", 0, true, nat{1}, 10, -1, nil, 0}, + {"0.1", 2, true, nat{1}, 2, -1, nil, 0}, + {"0.12", 2, true, nat{1}, 2, -1, nil, '2'}, + {"0b0.1", 0, true, nat{1}, 2, -1, nil, 0}, + {"0B0.12", 0, true, nat{1}, 2, -1, nil, '2'}, + {"0o0.7", 0, true, nat{7}, 8, -1, nil, 0}, + {"0O0.78", 0, true, nat{7}, 8, -1, nil, '8'}, + {"0xdead.beef", 0, true, nat{0xdeadbeef}, 16, -4, nil, 0}, + + // valid, with separators + {"1_000", 0, false, nat{1000}, 10, 4, nil, 0}, + {"0_466", 0, false, nat{0466}, 8, 3, nil, 0}, + {"0o_600", 0, false, nat{0600}, 8, 3, nil, 0}, + {"0x_f0_0d", 0, false, nat{0xf00d}, 16, 4, nil, 0}, + {"0b1000_0001", 0, false, nat{0x81}, 2, 8, nil, 0}, + {"1_000.000_1", 0, true, nat{10000001}, 10, -4, nil, 0}, + {"0x_f00d.1e", 0, true, nat{0xf00d1e}, 16, -2, nil, 0}, + {"0x_f00d.1E2", 0, true, nat{0xf00d1e2}, 16, -3, nil, 0}, + {"0x_f00d.1eg", 0, true, nat{0xf00d1e}, 16, -2, nil, 'g'}, +} + +func TestScanBase(t *testing.T) { + for _, a := range natScanTests { + r := strings.NewReader(a.s) + x, b, count, err := nat(nil).scan(r, a.base, a.frac) + if err != a.err { + t.Errorf("scan%+v\n\tgot error = %v; want %v", a, err, a.err) + } + if x.cmp(a.x) != 0 { + t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) + } + if b != a.b { + t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.base) + } + if count != a.count { + t.Errorf("scan%+v\n\tgot count = %d; want %d", a, count, a.count) + } + next, _, err := r.ReadRune() + if err == io.EOF { + next = 0 + err = nil + } + if err == nil && next != a.next { + t.Errorf("scan%+v\n\tgot next = %q; want %q", a, next, a.next) + } + } +} + +var pi = "3" + + "14159265358979323846264338327950288419716939937510582097494459230781640628620899862803482534211706798214808651" + + "32823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461" + + "28475648233786783165271201909145648566923460348610454326648213393607260249141273724587006606315588174881520920" + + "96282925409171536436789259036001133053054882046652138414695194151160943305727036575959195309218611738193261179" + + "31051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798" + + "60943702770539217176293176752384674818467669405132000568127145263560827785771342757789609173637178721468440901" + + "22495343014654958537105079227968925892354201995611212902196086403441815981362977477130996051870721134999999837" + + "29780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083" + + "81420617177669147303598253490428755468731159562863882353787593751957781857780532171226806613001927876611195909" + + "21642019893809525720106548586327886593615338182796823030195203530185296899577362259941389124972177528347913151" + + "55748572424541506959508295331168617278558890750983817546374649393192550604009277016711390098488240128583616035" + + "63707660104710181942955596198946767837449448255379774726847104047534646208046684259069491293313677028989152104" + + "75216205696602405803815019351125338243003558764024749647326391419927260426992279678235478163600934172164121992" + + "45863150302861829745557067498385054945885869269956909272107975093029553211653449872027559602364806654991198818" + + "34797753566369807426542527862551818417574672890977772793800081647060016145249192173217214772350141441973568548" + + "16136115735255213347574184946843852332390739414333454776241686251898356948556209921922218427255025425688767179" + + "04946016534668049886272327917860857843838279679766814541009538837863609506800642251252051173929848960841284886" + + "26945604241965285022210661186306744278622039194945047123713786960956364371917287467764657573962413890865832645" + + "99581339047802759009946576407895126946839835259570982582262052248940772671947826848260147699090264013639443745" + + "53050682034962524517493996514314298091906592509372216964615157098583874105978859597729754989301617539284681382" + + "68683868942774155991855925245953959431049972524680845987273644695848653836736222626099124608051243884390451244" + + "13654976278079771569143599770012961608944169486855584840635342207222582848864815845602850601684273945226746767" + + "88952521385225499546667278239864565961163548862305774564980355936345681743241125150760694794510965960940252288" + + "79710893145669136867228748940560101503308617928680920874760917824938589009714909675985261365549781893129784821" + + "68299894872265880485756401427047755513237964145152374623436454285844479526586782105114135473573952311342716610" + + "21359695362314429524849371871101457654035902799344037420073105785390621983874478084784896833214457138687519435" + + "06430218453191048481005370614680674919278191197939952061419663428754440643745123718192179998391015919561814675" + + "14269123974894090718649423196156794520809514655022523160388193014209376213785595663893778708303906979207734672" + + "21825625996615014215030680384477345492026054146659252014974428507325186660021324340881907104863317346496514539" + + "05796268561005508106658796998163574736384052571459102897064140110971206280439039759515677157700420337869936007" + + "23055876317635942187312514712053292819182618612586732157919841484882916447060957527069572209175671167229109816" + + "90915280173506712748583222871835209353965725121083579151369882091444210067510334671103141267111369908658516398" + + "31501970165151168517143765761835155650884909989859982387345528331635507647918535893226185489632132933089857064" + + "20467525907091548141654985946163718027098199430992448895757128289059232332609729971208443357326548938239119325" + + "97463667305836041428138830320382490375898524374417029132765618093773444030707469211201913020330380197621101100" + + "44929321516084244485963766983895228684783123552658213144957685726243344189303968642624341077322697802807318915" + + "44110104468232527162010526522721116603966655730925471105578537634668206531098965269186205647693125705863566201" + + "85581007293606598764861179104533488503461136576867532494416680396265797877185560845529654126654085306143444318" + + "58676975145661406800700237877659134401712749470420562230538994561314071127000407854733269939081454664645880797" + + "27082668306343285878569830523580893306575740679545716377525420211495576158140025012622859413021647155097925923" + + "09907965473761255176567513575178296664547791745011299614890304639947132962107340437518957359614589019389713111" + + "79042978285647503203198691514028708085990480109412147221317947647772622414254854540332157185306142288137585043" + + "06332175182979866223717215916077166925474873898665494945011465406284336639379003976926567214638530673609657120" + + "91807638327166416274888800786925602902284721040317211860820419000422966171196377921337575114959501566049631862" + + "94726547364252308177036751590673502350728354056704038674351362222477158915049530984448933309634087807693259939" + + "78054193414473774418426312986080998886874132604721569516239658645730216315981931951673538129741677294786724229" + + "24654366800980676928238280689964004824354037014163149658979409243237896907069779422362508221688957383798623001" + + "59377647165122893578601588161755782973523344604281512627203734314653197777416031990665541876397929334419521541" + + "34189948544473456738316249934191318148092777710386387734317720754565453220777092120190516609628049092636019759" + + "88281613323166636528619326686336062735676303544776280350450777235547105859548702790814356240145171806246436267" + + "94561275318134078330336254232783944975382437205835311477119926063813346776879695970309833913077109870408591337" + +// Test case for BenchmarkScanPi. +func TestScanPi(t *testing.T) { + var x nat + z, _, _, err := x.scan(strings.NewReader(pi), 10, false) + if err != nil { + t.Errorf("scanning pi: %s", err) + } + if s := string(z.utoa(10)); s != pi { + t.Errorf("scanning pi: got %s", s) + } +} + +func TestScanPiParallel(t *testing.T) { + const n = 2 + c := make(chan int) + for i := 0; i < n; i++ { + go func() { + TestScanPi(t) + c <- 0 + }() + } + for i := 0; i < n; i++ { + <-c + } +} + +func BenchmarkScanPi(b *testing.B) { + for i := 0; i < b.N; i++ { + var x nat + x.scan(strings.NewReader(pi), 10, false) + } +} + +func BenchmarkStringPiParallel(b *testing.B) { + var x nat + x, _, _, _ = x.scan(strings.NewReader(pi), 0, false) + if string(x.utoa(10)) != pi { + panic("benchmark incorrect: conversion failed") + } + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + x.utoa(10) + } + }) +} + +func BenchmarkScan(b *testing.B) { + const x = 10 + for _, base := range []int{2, 8, 10, 16} { + for _, y := range []Word{10, 100, 1000, 10000, 100000} { + if isRaceBuilder && y > 1000 { + continue + } + b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) { + b.StopTimer() + var z nat + z = z.expWW(x, y) + + s := z.utoa(base) + if t := itoa(z, base); !bytes.Equal(s, t) { + b.Fatalf("scanning: got %s; want %s", s, t) + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + z.scan(bytes.NewReader(s), base, false) + } + }) + } + } +} + +func BenchmarkString(b *testing.B) { + const x = 10 + for _, base := range []int{2, 8, 10, 16} { + for _, y := range []Word{10, 100, 1000, 10000, 100000} { + if isRaceBuilder && y > 1000 { + continue + } + b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) { + b.StopTimer() + var z nat + z = z.expWW(x, y) + z.utoa(base) // warm divisor cache + b.StartTimer() + + for i := 0; i < b.N; i++ { + _ = z.utoa(base) + } + }) + } + } +} + +func BenchmarkLeafSize(b *testing.B) { + for n := 0; n <= 16; n++ { + b.Run(fmt.Sprint(n), func(b *testing.B) { LeafSizeHelper(b, 10, n) }) + } + // Try some large lengths + for _, n := range []int{32, 64} { + b.Run(fmt.Sprint(n), func(b *testing.B) { LeafSizeHelper(b, 10, n) }) + } +} + +func LeafSizeHelper(b *testing.B, base, size int) { + b.StopTimer() + originalLeafSize := leafSize + resetTable(cacheBase10.table[:]) + leafSize = size + b.StartTimer() + + for d := 1; d <= 10000; d *= 10 { + b.StopTimer() + var z nat + z = z.expWW(Word(base), Word(d)) // build target number + _ = z.utoa(base) // warm divisor cache + b.StartTimer() + + for i := 0; i < b.N; i++ { + _ = z.utoa(base) + } + } + + b.StopTimer() + resetTable(cacheBase10.table[:]) + leafSize = originalLeafSize + b.StartTimer() +} + +func resetTable(table []divisor) { + if table != nil && table[0].bbb != nil { + for i := 0; i < len(table); i++ { + table[i].bbb = nil + table[i].nbits = 0 + table[i].ndigits = 0 + } + } +} + +func TestStringPowers(t *testing.T) { + var p Word + for b := 2; b <= 16; b++ { + for p = 0; p <= 512; p++ { + if testing.Short() && p > 10 { + break + } + x := nat(nil).expWW(Word(b), p) + xs := x.utoa(b) + xs2 := itoa(x, b) + if !bytes.Equal(xs, xs2) { + t.Errorf("failed at %d ** %d in base %d: %s != %s", b, p, b, xs, xs2) + } + } + if b >= 3 && testing.Short() { + break + } + } +} diff --git a/src/math/big/natdiv.go b/src/math/big/natdiv.go new file mode 100644 index 0000000..14233a2 --- /dev/null +++ b/src/math/big/natdiv.go @@ -0,0 +1,897 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + +Multi-precision division. Here be dragons. + +Given u and v, where u is n+m digits, and v is n digits (with no leading zeros), +the goal is to return quo, rem such that u = quo*v + rem, where 0 ≤ rem < v. +That is, quo = ⌊u/v⌋ where ⌊x⌋ denotes the floor (truncation to integer) of x, +and rem = u - quo·v. + + +Long Division + +Division in a computer proceeds the same as long division in elementary school, +but computers are not as good as schoolchildren at following vague directions, +so we have to be much more precise about the actual steps and what can happen. + +We work from most to least significant digit of the quotient, doing: + + • Guess a digit q, the number of v to subtract from the current + section of u to zero out the topmost digit. + • Check the guess by multiplying q·v and comparing it against + the current section of u, adjusting the guess as needed. + • Subtract q·v from the current section of u. + • Add q to the corresponding section of the result quo. + +When all digits have been processed, the final remainder is left in u +and returned as rem. + +For example, here is a sketch of dividing 5 digits by 3 digits (n=3, m=2). + + q₂ q₁ q₀ + _________________ + v₂ v₁ v₀ ) u₄ u₃ u₂ u₁ u₀ + ↓ ↓ ↓ | | + [u₄ u₃ u₂]| | + - [ q₂·v ]| | + ----------- ↓ | + [ rem | u₁]| + - [ q₁·v ]| + ----------- ↓ + [ rem | u₀] + - [ q₀·v ] + ------------ + [ rem ] + +Instead of creating new storage for the remainders and copying digits from u +as indicated by the arrows, we use u's storage directly as both the source +and destination of the subtractions, so that the remainders overwrite +successive overlapping sections of u as the division proceeds, using a slice +of u to identify the current section. This avoids all the copying as well as +shifting of remainders. + +Division of u with n+m digits by v with n digits (in base B) can in general +produce at most m+1 digits, because: + + • u < B^(n+m) [B^(n+m) has n+m+1 digits] + • v ≥ B^(n-1) [B^(n-1) is the smallest n-digit number] + • u/v < B^(n+m) / B^(n-1) [divide bounds for u, v] + • u/v < B^(m+1) [simplify] + +The first step is special: it takes the top n digits of u and divides them by +the n digits of v, producing the first quotient digit and an n-digit remainder. +In the example, q₂ = ⌊u₄u₃u₂ / v⌋. + +The first step divides n digits by n digits to ensure that it produces only a +single digit. + +Each subsequent step appends the next digit from u to the remainder and divides +those n+1 digits by the n digits of v, producing another quotient digit and a +new n-digit remainder. + +Subsequent steps divide n+1 digits by n digits, an operation that in general +might produce two digits. However, as used in the algorithm, that division is +guaranteed to produce only a single digit. The dividend is of the form +rem·B + d, where rem is a remainder from the previous step and d is a single +digit, so: + + • rem ≤ v - 1 [rem is a remainder from dividing by v] + • rem·B ≤ v·B - B [multiply by B] + • d ≤ B - 1 [d is a single digit] + • rem·B + d ≤ v·B - 1 [add] + • rem·B + d < v·B [change ≤ to <] + • (rem·B + d)/v < B [divide by v] + + +Guess and Check + +At each step we need to divide n+1 digits by n digits, but this is for the +implementation of division by n digits, so we can't just invoke a division +routine: we _are_ the division routine. Instead, we guess at the answer and +then check it using multiplication. If the guess is wrong, we correct it. + +How can this guessing possibly be efficient? It turns out that the following +statement (let's call it the Good Guess Guarantee) is true. + +If + + • q = ⌊u/v⌋ where u is n+1 digits and v is n digits, + • q < B, and + • the topmost digit of v = vₙ₋₁ ≥ B/2, + +then q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ satisfies q ≤ q̂ ≤ q+2. (Proof below.) + +That is, if we know the answer has only a single digit and we guess an answer +by ignoring the bottom n-1 digits of u and v, using a 2-by-1-digit division, +then that guess is at least as large as the correct answer. It is also not +too much larger: it is off by at most two from the correct answer. + +Note that in the first step of the overall division, which is an n-by-n-digit +division, the 2-by-1 guess uses an implicit uₙ = 0. + +Note that using a 2-by-1-digit division here does not mean calling ourselves +recursively. Instead, we use an efficient direct hardware implementation of +that operation. + +Note that because q is u/v rounded down, q·v must not exceed u: u ≥ q·v. +If a guess q̂ is too big, it will not satisfy this test. Viewed a different way, +the remainder r̂ for a given q̂ is u - q̂·v, which must be positive. If it is +negative, then the guess q̂ is too big. + +This gives us a way to compute q. First compute q̂ with 2-by-1-digit division. +Then, while u < q̂·v, decrement q̂; this loop executes at most twice, because +q̂ ≤ q+2. + + +Scaling Inputs + +The Good Guess Guarantee requires that the top digit of v (vₙ₋₁) be at least B/2. +For example in base 10, ⌊172/19⌋ = 9, but ⌊18/1⌋ = 18: the guess is wildly off +because the first digit 1 is smaller than B/2 = 5. + +We can ensure that v has a large top digit by multiplying both u and v by the +right amount. Continuing the example, if we multiply both 172 and 19 by 3, we +now have ⌊516/57⌋, the leading digit of v is now ≥ 5, and sure enough +⌊51/5⌋ = 10 is much closer to the correct answer 9. It would be easier here +to multiply by 4, because that can be done with a shift. Specifically, we can +always count the number of leading zeros i in the first digit of v and then +shift both u and v left by i bits. + +Having scaled u and v, the value ⌊u/v⌋ is unchanged, but the remainder will +be scaled: 172 mod 19 is 1, but 516 mod 57 is 3. We have to divide the remainder +by the scaling factor (shifting right i bits) when we finish. + +Note that these shifts happen before and after the entire division algorithm, +not at each step in the per-digit iteration. + +Note the effect of scaling inputs on the size of the possible quotient. +In the scaled u/v, u can gain a digit from scaling; v never does, because we +pick the scaling factor to make v's top digit larger but without overflowing. +If u and v have n+m and n digits after scaling, then: + + • u < B^(n+m) [B^(n+m) has n+m+1 digits] + • v ≥ B^n / 2 [vₙ₋₁ ≥ B/2, so vₙ₋₁·B^(n-1) ≥ B^n/2] + • u/v < B^(n+m) / (B^n / 2) [divide bounds for u, v] + • u/v < 2 B^m [simplify] + +The quotient can still have m+1 significant digits, but if so the top digit +must be a 1. This provides a different way to handle the first digit of the +result: compare the top n digits of u against v and fill in either a 0 or a 1. + + +Refining Guesses + +Before we check whether u < q̂·v, we can adjust our guess to change it from +q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ into the refined guess ⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋. +Although not mentioned above, the Good Guess Guarantee also promises that this +3-by-2-digit division guess is more precise and at most one away from the real +answer q. The improvement from the 2-by-1 to the 3-by-2 guess can also be done +without n-digit math. + +If we have a guess q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ and we want to see if it also equal to +⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋, we can use the same check we would for the full division: +if uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂, then the guess is too large and should be reduced. + +Checking uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂ is the same as uₙuₙ₋₁uₙ₋₂ - q̂·vₙ₋₁vₙ₋₂ < 0, +and + + uₙuₙ₋₁uₙ₋₂ - q̂·vₙ₋₁vₙ₋₂ = (uₙuₙ₋₁·B + uₙ₋₂) - q̂·(vₙ₋₁·B + vₙ₋₂) + [splitting off the bottom digit] + = (uₙuₙ₋₁ - q̂·vₙ₋₁)·B + uₙ₋₂ - q̂·vₙ₋₂ + [regrouping] + +The expression (uₙuₙ₋₁ - q̂·vₙ₋₁) is the remainder of uₙuₙ₋₁ / vₙ₋₁. +If the initial guess returns both q̂ and its remainder r̂, then checking +whether uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂ is the same as checking r̂·B + uₙ₋₂ < q̂·vₙ₋₂. + +If we find that r̂·B + uₙ₋₂ < q̂·vₙ₋₂, then we can adjust the guess by +decrementing q̂ and adding vₙ₋₁ to r̂. We repeat until r̂·B + uₙ₋₂ ≥ q̂·vₙ₋₂. +(As before, this fixup is only needed at most twice.) + +Now that q̂ = ⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋, as mentioned above it is at most one +away from the correct q, and we've avoided doing any n-digit math. +(If we need the new remainder, it can be computed as r̂·B + uₙ₋₂ - q̂·vₙ₋₂.) + +The final check u < q̂·v and the possible fixup must be done at full precision. +For random inputs, a fixup at this step is exceedingly rare: the 3-by-2 guess +is not often wrong at all. But still we must do the check. Note that since the +3-by-2 guess is off by at most 1, it can be convenient to perform the final +u < q̂·v as part of the computation of the remainder r = u - q̂·v. If the +subtraction underflows, decremeting q̂ and adding one v back to r is enough to +arrive at the final q, r. + +That's the entirety of long division: scale the inputs, and then loop over +each output position, guessing, checking, and correcting the next output digit. + +For a 2n-digit number divided by an n-digit number (the worst size-n case for +division complexity), this algorithm uses n+1 iterations, each of which must do +at least the 1-by-n-digit multiplication q̂·v. That's O(n) iterations of +O(n) time each, so O(n²) time overall. + + +Recursive Division + +For very large inputs, it is possible to improve on the O(n²) algorithm. +Let's call a group of n/2 real digits a (very) “wide digit”. We can run the +standard long division algorithm explained above over the wide digits instead of +the actual digits. This will result in many fewer steps, but the math involved in +each step is more work. + +Where basic long division uses a 2-by-1-digit division to guess the initial q̂, +the new algorithm must use a 2-by-1-wide-digit division, which is of course +really an n-by-n/2-digit division. That's OK: if we implement n-digit division +in terms of n/2-digit division, the recursion will terminate when the divisor +becomes small enough to handle with standard long division or even with the +2-by-1 hardware instruction. + +For example, here is a sketch of dividing 10 digits by 4, proceeding with +wide digits corresponding to two regular digits. The first step, still special, +must leave off a (regular) digit, dividing 5 by 4 and producing a 4-digit +remainder less than v. The middle steps divide 6 digits by 4, guaranteed to +produce two output digits each (one wide digit) with 4-digit remainders. +The final step must use what it has: the 4-digit remainder plus one more, +5 digits to divide by 4. + + q₆ q₅ q₄ q₃ q₂ q₁ q₀ + _______________________________ + v₃ v₂ v₁ v₀ ) u₉ u₈ u₇ u₆ u₅ u₄ u₃ u₂ u₁ u₀ + ↓ ↓ ↓ ↓ ↓ | | | | | + [u₉ u₈ u₇ u₆ u₅]| | | | | + - [ q₆q₅·v ]| | | | | + ----------------- ↓ ↓ | | | + [ rem |u₄ u₃]| | | + - [ q₄q₃·v ]| | | + -------------------- ↓ ↓ | + [ rem |u₂ u₁]| + - [ q₂q₁·v ]| + -------------------- ↓ + [ rem |u₀] + - [ q₀·v ] + ------------------ + [ rem ] + +An alternative would be to look ahead to how well n/2 divides into n+m and +adjust the first step to use fewer digits as needed, making the first step +more special to make the last step not special at all. For example, using the +same input, we could choose to use only 4 digits in the first step, leaving +a full wide digit for the last step: + + q₆ q₅ q₄ q₃ q₂ q₁ q₀ + _______________________________ + v₃ v₂ v₁ v₀ ) u₉ u₈ u₇ u₆ u₅ u₄ u₃ u₂ u₁ u₀ + ↓ ↓ ↓ ↓ | | | | | | + [u₉ u₈ u₇ u₆]| | | | | | + - [ q₆·v ]| | | | | | + -------------- ↓ ↓ | | | | + [ rem |u₅ u₄]| | | | + - [ q₅q₄·v ]| | | | + -------------------- ↓ ↓ | | + [ rem |u₃ u₂]| | + - [ q₃q₂·v ]| | + -------------------- ↓ ↓ + [ rem |u₁ u₀] + - [ q₁q₀·v ] + --------------------- + [ rem ] + +Today, the code in divRecursiveStep works like the first example. Perhaps in +the future we will make it work like the alternative, to avoid a special case +in the final iteration. + +Either way, each step is a 3-by-2-wide-digit division approximated first by +a 2-by-1-wide-digit division, just as we did for regular digits in long division. +Because the actual answer we want is a 3-by-2-wide-digit division, instead of +multiplying q̂·v directly during the fixup, we can use the quick refinement +from long division (an n/2-by-n/2 multiply) to correct q to its actual value +and also compute the remainder (as mentioned above), and then stop after that, +never doing a full n-by-n multiply. + +Instead of using an n-by-n/2-digit division to produce n/2 digits, we can add +(not discard) one more real digit, doing an (n+1)-by-(n/2+1)-digit division that +produces n/2+1 digits. That single extra digit tightens the Good Guess Guarantee +to q ≤ q̂ ≤ q+1 and lets us drop long division's special treatment of the first +digit. These benefits are discussed more after the Good Guess Guarantee proof +below. + + +How Fast is Recursive Division? + +For a 2n-by-n-digit division, this algorithm runs a 4-by-2 long division over +wide digits, producing two wide digits plus a possible leading regular digit 1, +which can be handled without a recursive call. That is, the algorithm uses two +full iterations, each using an n-by-n/2-digit division and an n/2-by-n/2-digit +multiplication, along with a few n-digit additions and subtractions. The standard +n-by-n-digit multiplication algorithm requires O(n²) time, making the overall +algorithm require time T(n) where + + T(n) = 2T(n/2) + O(n) + O(n²) + +which, by the Bentley-Haken-Saxe theorem, ends up reducing to T(n) = O(n²). +This is not an improvement over regular long division. + +When the number of digits n becomes large enough, Karatsuba's algorithm for +multiplication can be used instead, which takes O(n^log₂3) = O(n^1.6) time. +(Karatsuba multiplication is implemented in func karatsuba in nat.go.) +That makes the overall recursive division algorithm take O(n^1.6) time as well, +which is an improvement, but again only for large enough numbers. + +It is not critical to make sure that every recursion does only two recursive +calls. While in general the number of recursive calls can change the time +analysis, in this case doing three calls does not change the analysis: + + T(n) = 3T(n/2) + O(n) + O(n^log₂3) + +ends up being T(n) = O(n^log₂3). Because the Karatsuba multiplication taking +time O(n^log₂3) is itself doing 3 half-sized recursions, doing three for the +division does not hurt the asymptotic performance. Of course, it is likely +still faster in practice to do two. + + +Proof of the Good Guess Guarantee + +Given numbers x, y, let us break them into the quotients and remainders when +divided by some scaling factor S, with the added constraints that the quotient +x/y and the high part of y are both less than some limit T, and that the high +part of y is at least half as big as T. + + x₁ = ⌊x/S⌋ y₁ = ⌊y/S⌋ + x₀ = x mod S y₀ = y mod S + + x = x₁·S + x₀ 0 ≤ x₀ < S x/y < T + y = y₁·S + y₀ 0 ≤ y₀ < S T/2 ≤ y₁ < T + +And consider the two truncated quotients: + + q = ⌊x/y⌋ + q̂ = ⌊x₁/y₁⌋ + +We will prove that q ≤ q̂ ≤ q+2. + +The guarantee makes no real demands on the scaling factor S: it is simply the +magnitude of the digits cut from both x and y to produce x₁ and y₁. +The guarantee makes only limited demands on T: it must be large enough to hold +the quotient x/y, and y₁ must have roughly the same size. + +To apply to the earlier discussion of 2-by-1 guesses in long division, +we would choose: + + S = Bⁿ⁻¹ + T = B + x = u + x₁ = uₙuₙ₋₁ + x₀ = uₙ₋₂...u₀ + y = v + y₁ = vₙ₋₁ + y₀ = vₙ₋₂...u₀ + +These simpler variables avoid repeating those longer expressions in the proof. + +Note also that, by definition, truncating division ⌊x/y⌋ satisfies + + x/y - 1 < ⌊x/y⌋ ≤ x/y. + +This fact will be used a few times in the proofs. + +Proof that q ≤ q̂: + + q̂·y₁ = ⌊x₁/y₁⌋·y₁ [by definition, q̂ = ⌊x₁/y₁⌋] + > (x₁/y₁ - 1)·y₁ [x₁/y₁ - 1 < ⌊x₁/y₁⌋] + = x₁ - y₁ [distribute y₁] + + So q̂·y₁ > x₁ - y₁. + Since q̂·y₁ is an integer, q̂·y₁ ≥ x₁ - y₁ + 1. + + q̂ - q = q̂ - ⌊x/y⌋ [by definition, q = ⌊x/y⌋] + ≥ q̂ - x/y [⌊x/y⌋ < x/y] + = (1/y)·(q̂·y - x) [factor out 1/y] + ≥ (1/y)·(q̂·y₁·S - x) [y = y₁·S + y₀ ≥ y₁·S] + ≥ (1/y)·((x₁ - y₁ + 1)·S - x) [above: q̂·y₁ ≥ x₁ - y₁ + 1] + = (1/y)·(x₁·S - y₁·S + S - x) [distribute S] + = (1/y)·(S - x₀ - y₁·S) [-x = -x₁·S - x₀] + > -y₁·S / y [x₀ < S, so S - x₀ < 0; drop it] + ≥ -1 [y₁·S ≤ y] + + So q̂ - q > -1. + Since q̂ - q is an integer, q̂ - q ≥ 0, or equivalently q ≤ q̂. + +Proof that q̂ ≤ q+2: + + x₁/y₁ - x/y = x₁·S/y₁·S - x/y [multiply left term by S/S] + ≤ x/y₁·S - x/y [x₁S ≤ x] + = (x/y)·(y/y₁·S - 1) [factor out x/y] + = (x/y)·((y - y₁·S)/y₁·S) [move -1 into y/y₁·S fraction] + = (x/y)·(y₀/y₁·S) [y - y₁·S = y₀] + = (x/y)·(1/y₁)·(y₀/S) [factor out 1/y₁] + < (x/y)·(1/y₁) [y₀ < S, so y₀/S < 1] + ≤ (x/y)·(2/T) [y₁ ≥ T/2, so 1/y₁ ≤ 2/T] + < T·(2/T) [x/y < T] + = 2 [T·(2/T) = 2] + + So x₁/y₁ - x/y < 2. + + q̂ - q = ⌊x₁/y₁⌋ - q [by definition, q̂ = ⌊x₁/y₁⌋] + = ⌊x₁/y₁⌋ - ⌊x/y⌋ [by definition, q = ⌊x/y⌋] + ≤ x₁/y₁ - ⌊x/y⌋ [⌊x₁/y₁⌋ ≤ x₁/y₁] + < x₁/y₁ - (x/y - 1) [⌊x/y⌋ > x/y - 1] + = (x₁/y₁ - x/y) + 1 [regrouping] + < 2 + 1 [above: x₁/y₁ - x/y < 2] + = 3 + + So q̂ - q < 3. + Since q̂ - q is an integer, q̂ - q ≤ 2. + +Note that when x/y < T/2, the bounds tighten to x₁/y₁ - x/y < 1 and therefore +q̂ - q ≤ 1. + +Note also that in the general case 2n-by-n division where we don't know that +x/y < T, we do know that x/y < 2T, yielding the bound q̂ - q ≤ 4. So we could +remove the special case first step of long division as long as we allow the +first fixup loop to run up to four times. (Using a simple comparison to decide +whether the first digit is 0 or 1 is still more efficient, though.) + +Finally, note that when dividing three leading base-B digits by two (scaled), +we have T = B² and x/y < B = T/B, a much tighter bound than x/y < T. +This in turn yields the much tighter bound x₁/y₁ - x/y < 2/B. This means that +⌊x₁/y₁⌋ and ⌊x/y⌋ can only differ when x/y is less than 2/B greater than an +integer. For random x and y, the chance of this is 2/B, or, for large B, +approximately zero. This means that after we produce the 3-by-2 guess in the +long division algorithm, the fixup loop essentially never runs. + +In the recursive algorithm, the extra digit in (2·⌊n/2⌋+1)-by-(⌊n/2⌋+1)-digit +division has exactly the same effect: the probability of needing a fixup is the +same 2/B. Even better, we can allow the general case x/y < 2T and the fixup +probability only grows to 4/B, still essentially zero. + + +References + +There are no great references for implementing long division; thus this comment. +Here are some notes about what to expect from the obvious references. + +Knuth Volume 2 (Seminumerical Algorithms) section 4.3.1 is the usual canonical +reference for long division, but that entire series is highly compressed, never +repeating a necessary fact and leaving important insights to the exercises. +For example, no rationale whatsoever is given for the calculation that extends +q̂ from a 2-by-1 to a 3-by-2 guess, nor why it reduces the error bound. +The proof that the calculation even has the desired effect is left to exercises. +The solutions to those exercises provided at the back of the book are entirely +calculations, still with no explanation as to what is going on or how you would +arrive at the idea of doing those exact calculations. Nowhere is it mentioned +that this test extends the 2-by-1 guess into a 3-by-2 guess. The proof of the +Good Guess Guarantee is only for the 2-by-1 guess and argues by contradiction, +making it difficult to understand how modifications like adding another digit +or adjusting the quotient range affects the overall bound. + +All that said, Knuth remains the canonical reference. It is dense but packed +full of information and references, and the proofs are simpler than many other +presentations. The proofs above are reworkings of Knuth's to remove the +arguments by contradiction and add explanations or steps that Knuth omitted. +But beware of errors in older printings. Take the published errata with you. + +Brinch Hansen's “Multiple-length Division Revisited: a Tour of the Minefield” +starts with a blunt critique of Knuth's presentation (among others) and then +presents a more detailed and easier to follow treatment of long division, +including an implementation in Pascal. But the algorithm and implementation +work entirely in terms of 3-by-2 division, which is much less useful on modern +hardware than an algorithm using 2-by-1 division. The proofs are a bit too +focused on digit counting and seem needlessly complex, especially compared to +the ones given above. + +Burnikel and Ziegler's “Fast Recursive Division” introduced the key insight of +implementing division by an n-digit divisor using recursive calls to division +by an n/2-digit divisor, relying on Karatsuba multiplication to yield a +sub-quadratic run time. However, the presentation decisions are made almost +entirely for the purpose of simplifying the run-time analysis, rather than +simplifying the presentation. Instead of a single algorithm that loops over +quotient digits, the paper presents two mutually-recursive algorithms, for +2n-by-n and 3n-by-2n. The paper also does not present any general (n+m)-by-n +algorithm. + +The proofs in the paper are remarkably complex, especially considering that +the algorithm is at its core just long division on wide digits, so that the +usual long division proofs apply essentially unaltered. +*/ + +package big + +import "math/bits" + +// rem returns r such that r = u%v. +// It uses z as the storage for r. +func (z nat) rem(u, v nat) (r nat) { + if alias(z, u) { + z = nil + } + qp := getNat(0) + q, r := qp.div(z, u, v) + *qp = q + putNat(qp) + return r +} + +// div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v. +// It uses z and z2 as the storage for q and r. +func (z nat) div(z2, u, v nat) (q, r nat) { + if len(v) == 0 { + panic("division by zero") + } + + if u.cmp(v) < 0 { + q = z[:0] + r = z2.set(u) + return + } + + if len(v) == 1 { + // Short division: long optimized for a single-word divisor. + // In that case, the 2-by-1 guess is all we need at each step. + var r2 Word + q, r2 = z.divW(u, v[0]) + r = z2.setWord(r2) + return + } + + q, r = z.divLarge(z2, u, v) + return +} + +// divW returns q, r such that q = ⌊x/y⌋ and r = x%y = x - q·y. +// It uses z as the storage for q. +// Note that y is a single digit (Word), not a big number. +func (z nat) divW(x nat, y Word) (q nat, r Word) { + m := len(x) + switch { + case y == 0: + panic("division by zero") + case y == 1: + q = z.set(x) // result is x + return + case m == 0: + q = z[:0] // result is 0 + return + } + // m > 0 + z = z.make(m) + r = divWVW(z, 0, x, y) + q = z.norm() + return +} + +// modW returns x % d. +func (x nat) modW(d Word) (r Word) { + // TODO(agl): we don't actually need to store the q value. + var q nat + q = q.make(len(x)) + return divWVW(q, 0, x, d) +} + +// divWVW overwrites z with ⌊x/y⌋, returning the remainder r. +// The caller must ensure that len(z) = len(x). +func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { + r = xn + if len(x) == 1 { + qq, rr := bits.Div(uint(r), uint(x[0]), uint(y)) + z[0] = Word(qq) + return Word(rr) + } + rec := reciprocalWord(y) + for i := len(z) - 1; i >= 0; i-- { + z[i], r = divWW(r, x[i], y, rec) + } + return r +} + +// div returns q, r such that q = ⌊uIn/vIn⌋ and r = uIn%vIn = uIn - q·vIn. +// It uses z and u as the storage for q and r. +// The caller must ensure that len(vIn) ≥ 2 (use divW otherwise) +// and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise). +func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { + n := len(vIn) + m := len(uIn) - n + + // Scale the inputs so vIn's top bit is 1 (see “Scaling Inputs” above). + // vIn is treated as a read-only input (it may be in use by another + // goroutine), so we must make a copy. + // uIn is copied to u. + shift := nlz(vIn[n-1]) + vp := getNat(n) + v := *vp + shlVU(v, vIn, shift) + u = u.make(len(uIn) + 1) + u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift) + + // The caller should not pass aliased z and u, since those are + // the two different outputs, but correct just in case. + if alias(z, u) { + z = nil + } + q = z.make(m + 1) + + // Use basic or recursive long division depending on size. + if n < divRecursiveThreshold { + q.divBasic(u, v) + } else { + q.divRecursive(u, v) + } + putNat(vp) + + q = q.norm() + + // Undo scaling of remainder. + shrVU(u, u, shift) + r = u.norm() + + return q, r +} + +// divBasic implements long division as described above. +// It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r. +// q must be large enough to hold ⌊u/v⌋. +func (q nat) divBasic(u, v nat) { + n := len(v) + m := len(u) - n + + qhatvp := getNat(n + 1) + qhatv := *qhatvp + + // Set up for divWW below, precomputing reciprocal argument. + vn1 := v[n-1] + rec := reciprocalWord(vn1) + + // Compute each digit of quotient. + for j := m; j >= 0; j-- { + // Compute the 2-by-1 guess q̂. + // The first iteration must invent a leading 0 for u. + qhat := Word(_M) + var ujn Word + if j+n < len(u) { + ujn = u[j+n] + } + + // ujn ≤ vn1, or else q̂ would be more than one digit. + // For ujn == vn1, we set q̂ to the max digit M above. + // Otherwise, we compute the 2-by-1 guess. + if ujn != vn1 { + var rhat Word + qhat, rhat = divWW(ujn, u[j+n-1], vn1, rec) + + // Refine q̂ to a 3-by-2 guess. See “Refining Guesses” above. + vn2 := v[n-2] + x1, x2 := mulWW(qhat, vn2) + ujn2 := u[j+n-2] + for greaterThan(x1, x2, rhat, ujn2) { // x1x2 > r̂ u[j+n-2] + qhat-- + prevRhat := rhat + rhat += vn1 + // If r̂ overflows, then + // r̂ u[j+n-2]v[n-1] is now definitely > x1 x2. + if rhat < prevRhat { + break + } + // TODO(rsc): No need for a full mulWW. + // x2 += vn2; if x2 overflows, x1++ + x1, x2 = mulWW(qhat, vn2) + } + } + + // Compute q̂·v. + qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0) + qhl := len(qhatv) + if j+qhl > len(u) && qhatv[n] == 0 { + qhl-- + } + + // Subtract q̂·v from the current section of u. + // If it underflows, q̂·v > u, which we fix up + // by decrementing q̂ and adding v back. + c := subVV(u[j:j+qhl], u[j:], qhatv) + if c != 0 { + c := addVV(u[j:j+n], u[j:], v) + // If n == qhl, the carry from subVV and the carry from addVV + // cancel out and don't affect u[j+n]. + if n < qhl { + u[j+n] += c + } + qhat-- + } + + // Save quotient digit. + // Caller may know the top digit is zero and not leave room for it. + if j == m && m == len(q) && qhat == 0 { + continue + } + q[j] = qhat + } + + putNat(qhatvp) +} + +// greaterThan reports whether the two digit numbers x1 x2 > y1 y2. +// TODO(rsc): In contradiction to most of this file, x1 is the high +// digit and x2 is the low digit. This should be fixed. +func greaterThan(x1, x2, y1, y2 Word) bool { + return x1 > y1 || x1 == y1 && x2 > y2 +} + +// divRecursiveThreshold is the number of divisor digits +// at which point divRecursive is faster than divBasic. +const divRecursiveThreshold = 100 + +// divRecursive implements recursive division as described above. +// It overwrites z with ⌊u/v⌋ and overwrites u with the remainder r. +// z must be large enough to hold ⌊u/v⌋. +// This function is just for allocating and freeing temporaries +// around divRecursiveStep, the real implementation. +func (z nat) divRecursive(u, v nat) { + // Recursion depth is (much) less than 2 log₂(len(v)). + // Allocate a slice of temporaries to be reused across recursion, + // plus one extra temporary not live across the recursion. + recDepth := 2 * bits.Len(uint(len(v))) + tmp := getNat(3 * len(v)) + temps := make([]*nat, recDepth) + + z.clear() + z.divRecursiveStep(u, v, 0, tmp, temps) + + // Free temporaries. + for _, n := range temps { + if n != nil { + putNat(n) + } + } + putNat(tmp) +} + +// divRecursiveStep is the actual implementation of recursive division. +// It adds ⌊u/v⌋ to z and overwrites u with the remainder r. +// z must be large enough to hold ⌊u/v⌋. +// It uses temps[depth] (allocating if needed) as a temporary live across +// the recursive call. It also uses tmp, but not live across the recursion. +func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { + // u is a subsection of the original and may have leading zeros. + // TODO(rsc): The v = v.norm() is useless and should be removed. + // We know (and require) that v's top digit is ≥ B/2. + u = u.norm() + v = v.norm() + if len(u) == 0 { + z.clear() + return + } + + // Fall back to basic division if the problem is now small enough. + n := len(v) + if n < divRecursiveThreshold { + z.divBasic(u, v) + return + } + + // Nothing to do if u is shorter than v (implies u < v). + m := len(u) - n + if m < 0 { + return + } + + // We consider B digits in a row as a single wide digit. + // (See “Recursive Division” above.) + // + // TODO(rsc): rename B to Wide, to avoid confusion with _B, + // which is something entirely different. + // TODO(rsc): Look into whether using ⌈n/2⌉ is better than ⌊n/2⌋. + B := n / 2 + + // Allocate a nat for qhat below. + if temps[depth] == nil { + temps[depth] = getNat(n) // TODO(rsc): Can be just B+1. + } else { + *temps[depth] = temps[depth].make(B + 1) + } + + // Compute each wide digit of the quotient. + // + // TODO(rsc): Change the loop to be + // for j := (m+B-1)/B*B; j > 0; j -= B { + // which will make the final step a regular step, letting us + // delete what amounts to an extra copy of the loop body below. + j := m + for j > B { + // Divide u[j-B:j+n] (3 wide digits) by v (2 wide digits). + // First make the 2-by-1-wide-digit guess using a recursive call. + // Then extend the guess to the full 3-by-2 (see “Refining Guesses”). + // + // For the 2-by-1-wide-digit guess, instead of doing 2B-by-B-digit, + // we use a (2B+1)-by-(B+1) digit, which handles the possibility that + // the result has an extra leading 1 digit as well as guaranteeing + // that the computed q̂ will be off by at most 1 instead of 2. + + // s is the number of digits to drop from the 3B- and 2B-digit chunks. + // We drop B-1 to be left with 2B+1 and B+1. + s := (B - 1) + + // uu is the up-to-3B-digit section of u we are working on. + uu := u[j-B:] + + // Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n]. + qhat := *temps[depth] + qhat.clear() + qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps) + qhat = qhat.norm() + + // Extend to a 3-by-2 quotient and remainder. + // Because divRecursiveStep overwrote the top part of uu with + // the remainder r̂, the full uu already contains the equivalent + // of r̂·B + uₙ₋₂ from the “Refining Guesses” discussion. + // Subtracting q̂·vₙ₋₂ from it will compute the full-length remainder. + // If that subtraction underflows, q̂·v > u, which we fix up + // by decrementing q̂ and adding v back, same as in long division. + + // TODO(rsc): Instead of subtract and fix-up, this code is computing + // q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u. + // But we can do the subtraction directly, as in the comment above + // and in long division, because we know that q̂ is wrong by at most one. + qhatv := tmp.make(3 * n) + qhatv.clear() + qhatv = qhatv.mul(qhat, v[:s]) + for i := 0; i < 2; i++ { + e := qhatv.cmp(uu.norm()) + if e <= 0 { + break + } + subVW(qhat, qhat, 1) + c := subVV(qhatv[:s], qhatv[:s], v[:s]) + if len(qhatv) > s { + subVW(qhatv[s:], qhatv[s:], c) + } + addAt(uu[s:], v[s:], 0) + } + if qhatv.cmp(uu.norm()) > 0 { + panic("impossible") + } + c := subVV(uu[:len(qhatv)], uu[:len(qhatv)], qhatv) + if c > 0 { + subVW(uu[len(qhatv):], uu[len(qhatv):], c) + } + addAt(z, qhat, j-B) + j -= B + } + + // TODO(rsc): Rewrite loop as described above and delete all this code. + + // Now u < (v<<B), compute lower bits in the same way. + // Choose shift = B-1 again. + s := B - 1 + qhat := *temps[depth] + qhat.clear() + qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps) + qhat = qhat.norm() + qhatv := tmp.make(3 * n) + qhatv.clear() + qhatv = qhatv.mul(qhat, v[:s]) + // Set the correct remainder as before. + for i := 0; i < 2; i++ { + if e := qhatv.cmp(u.norm()); e > 0 { + subVW(qhat, qhat, 1) + c := subVV(qhatv[:s], qhatv[:s], v[:s]) + if len(qhatv) > s { + subVW(qhatv[s:], qhatv[s:], c) + } + addAt(u[s:], v[s:], 0) + } + } + if qhatv.cmp(u.norm()) > 0 { + panic("impossible") + } + c := subVV(u[0:len(qhatv)], u[0:len(qhatv)], qhatv) + if c > 0 { + c = subVW(u[len(qhatv):], u[len(qhatv):], c) + } + if c > 0 { + panic("impossible") + } + + // Done! + addAt(z, qhat.norm(), 0) +} diff --git a/src/math/big/prime.go b/src/math/big/prime.go new file mode 100644 index 0000000..26688bb --- /dev/null +++ b/src/math/big/prime.go @@ -0,0 +1,320 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import "math/rand" + +// ProbablyPrime reports whether x is probably prime, +// applying the Miller-Rabin test with n pseudorandomly chosen bases +// as well as a Baillie-PSW test. +// +// If x is prime, ProbablyPrime returns true. +// If x is chosen randomly and not prime, ProbablyPrime probably returns false. +// The probability of returning true for a randomly chosen non-prime is at most ¼ⁿ. +// +// ProbablyPrime is 100% accurate for inputs less than 2⁶⁴. +// See Menezes et al., Handbook of Applied Cryptography, 1997, pp. 145-149, +// and FIPS 186-4 Appendix F for further discussion of the error probabilities. +// +// ProbablyPrime is not suitable for judging primes that an adversary may +// have crafted to fool the test. +// +// As of Go 1.8, ProbablyPrime(0) is allowed and applies only a Baillie-PSW test. +// Before Go 1.8, ProbablyPrime applied only the Miller-Rabin tests, and ProbablyPrime(0) panicked. +func (x *Int) ProbablyPrime(n int) bool { + // Note regarding the doc comment above: + // It would be more precise to say that the Baillie-PSW test uses the + // extra strong Lucas test as its Lucas test, but since no one knows + // how to tell any of the Lucas tests apart inside a Baillie-PSW test + // (they all work equally well empirically), that detail need not be + // documented or implicitly guaranteed. + // The comment does avoid saying "the" Baillie-PSW test + // because of this general ambiguity. + + if n < 0 { + panic("negative n for ProbablyPrime") + } + if x.neg || len(x.abs) == 0 { + return false + } + + // primeBitMask records the primes < 64. + const primeBitMask uint64 = 1<<2 | 1<<3 | 1<<5 | 1<<7 | + 1<<11 | 1<<13 | 1<<17 | 1<<19 | 1<<23 | 1<<29 | 1<<31 | + 1<<37 | 1<<41 | 1<<43 | 1<<47 | 1<<53 | 1<<59 | 1<<61 + + w := x.abs[0] + if len(x.abs) == 1 && w < 64 { + return primeBitMask&(1<<w) != 0 + } + + if w&1 == 0 { + return false // x is even + } + + const primesA = 3 * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 37 + const primesB = 29 * 31 * 41 * 43 * 47 * 53 + + var rA, rB uint32 + switch _W { + case 32: + rA = uint32(x.abs.modW(primesA)) + rB = uint32(x.abs.modW(primesB)) + case 64: + r := x.abs.modW((primesA * primesB) & _M) + rA = uint32(r % primesA) + rB = uint32(r % primesB) + default: + panic("math/big: invalid word size") + } + + if rA%3 == 0 || rA%5 == 0 || rA%7 == 0 || rA%11 == 0 || rA%13 == 0 || rA%17 == 0 || rA%19 == 0 || rA%23 == 0 || rA%37 == 0 || + rB%29 == 0 || rB%31 == 0 || rB%41 == 0 || rB%43 == 0 || rB%47 == 0 || rB%53 == 0 { + return false + } + + return x.abs.probablyPrimeMillerRabin(n+1, true) && x.abs.probablyPrimeLucas() +} + +// probablyPrimeMillerRabin reports whether n passes reps rounds of the +// Miller-Rabin primality test, using pseudo-randomly chosen bases. +// If force2 is true, one of the rounds is forced to use base 2. +// See Handbook of Applied Cryptography, p. 139, Algorithm 4.24. +// The number n is known to be non-zero. +func (n nat) probablyPrimeMillerRabin(reps int, force2 bool) bool { + nm1 := nat(nil).sub(n, natOne) + // determine q, k such that nm1 = q << k + k := nm1.trailingZeroBits() + q := nat(nil).shr(nm1, k) + + nm3 := nat(nil).sub(nm1, natTwo) + rand := rand.New(rand.NewSource(int64(n[0]))) + + var x, y, quotient nat + nm3Len := nm3.bitLen() + +NextRandom: + for i := 0; i < reps; i++ { + if i == reps-1 && force2 { + x = x.set(natTwo) + } else { + x = x.random(rand, nm3, nm3Len) + x = x.add(x, natTwo) + } + y = y.expNN(x, q, n, false) + if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { + continue + } + for j := uint(1); j < k; j++ { + y = y.sqr(y) + quotient, y = quotient.div(y, y, n) + if y.cmp(nm1) == 0 { + continue NextRandom + } + if y.cmp(natOne) == 0 { + return false + } + } + return false + } + + return true +} + +// probablyPrimeLucas reports whether n passes the "almost extra strong" Lucas probable prime test, +// using Baillie-OEIS parameter selection. This corresponds to "AESLPSP" on Jacobsen's tables (link below). +// The combination of this test and a Miller-Rabin/Fermat test with base 2 gives a Baillie-PSW test. +// +// References: +// +// Baillie and Wagstaff, "Lucas Pseudoprimes", Mathematics of Computation 35(152), +// October 1980, pp. 1391-1417, especially page 1401. +// https://www.ams.org/journals/mcom/1980-35-152/S0025-5718-1980-0583518-6/S0025-5718-1980-0583518-6.pdf +// +// Grantham, "Frobenius Pseudoprimes", Mathematics of Computation 70(234), +// March 2000, pp. 873-891. +// https://www.ams.org/journals/mcom/2001-70-234/S0025-5718-00-01197-2/S0025-5718-00-01197-2.pdf +// +// Baillie, "Extra strong Lucas pseudoprimes", OEIS A217719, https://oeis.org/A217719. +// +// Jacobsen, "Pseudoprime Statistics, Tables, and Data", http://ntheory.org/pseudoprimes.html. +// +// Nicely, "The Baillie-PSW Primality Test", https://web.archive.org/web/20191121062007/http://www.trnicely.net/misc/bpsw.html. +// (Note that Nicely's definition of the "extra strong" test gives the wrong Jacobi condition, +// as pointed out by Jacobsen.) +// +// Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed. +// Springer, 2005. +func (n nat) probablyPrimeLucas() bool { + // Discard 0, 1. + if len(n) == 0 || n.cmp(natOne) == 0 { + return false + } + // Two is the only even prime. + // Already checked by caller, but here to allow testing in isolation. + if n[0]&1 == 0 { + return n.cmp(natTwo) == 0 + } + + // Baillie-OEIS "method C" for choosing D, P, Q, + // as in https://oeis.org/A217719/a217719.txt: + // try increasing P ≥ 3 such that D = P² - 4 (so Q = 1) + // until Jacobi(D, n) = -1. + // The search is expected to succeed for non-square n after just a few trials. + // After more than expected failures, check whether n is square + // (which would cause Jacobi(D, n) = 1 for all D not dividing n). + p := Word(3) + d := nat{1} + t1 := nat(nil) // temp + intD := &Int{abs: d} + intN := &Int{abs: n} + for ; ; p++ { + if p > 10000 { + // This is widely believed to be impossible. + // If we get a report, we'll want the exact number n. + panic("math/big: internal error: cannot find (D/n) = -1 for " + intN.String()) + } + d[0] = p*p - 4 + j := Jacobi(intD, intN) + if j == -1 { + break + } + if j == 0 { + // d = p²-4 = (p-2)(p+2). + // If (d/n) == 0 then d shares a prime factor with n. + // Since the loop proceeds in increasing p and starts with p-2==1, + // the shared prime factor must be p+2. + // If p+2 == n, then n is prime; otherwise p+2 is a proper factor of n. + return len(n) == 1 && n[0] == p+2 + } + if p == 40 { + // We'll never find (d/n) = -1 if n is a square. + // If n is a non-square we expect to find a d in just a few attempts on average. + // After 40 attempts, take a moment to check if n is indeed a square. + t1 = t1.sqrt(n) + t1 = t1.sqr(t1) + if t1.cmp(n) == 0 { + return false + } + } + } + + // Grantham definition of "extra strong Lucas pseudoprime", after Thm 2.3 on p. 876 + // (D, P, Q above have become Δ, b, 1): + // + // Let U_n = U_n(b, 1), V_n = V_n(b, 1), and Δ = b²-4. + // An extra strong Lucas pseudoprime to base b is a composite n = 2^r s + Jacobi(Δ, n), + // where s is odd and gcd(n, 2*Δ) = 1, such that either (i) U_s ≡ 0 mod n and V_s ≡ ±2 mod n, + // or (ii) V_{2^t s} ≡ 0 mod n for some 0 ≤ t < r-1. + // + // We know gcd(n, Δ) = 1 or else we'd have found Jacobi(d, n) == 0 above. + // We know gcd(n, 2) = 1 because n is odd. + // + // Arrange s = (n - Jacobi(Δ, n)) / 2^r = (n+1) / 2^r. + s := nat(nil).add(n, natOne) + r := int(s.trailingZeroBits()) + s = s.shr(s, uint(r)) + nm2 := nat(nil).sub(n, natTwo) // n-2 + + // We apply the "almost extra strong" test, which checks the above conditions + // except for U_s ≡ 0 mod n, which allows us to avoid computing any U_k values. + // Jacobsen points out that maybe we should just do the full extra strong test: + // "It is also possible to recover U_n using Crandall and Pomerance equation 3.13: + // U_n = D^-1 (2V_{n+1} - PV_n) allowing us to run the full extra-strong test + // at the cost of a single modular inversion. This computation is easy and fast in GMP, + // so we can get the full extra-strong test at essentially the same performance as the + // almost extra strong test." + + // Compute Lucas sequence V_s(b, 1), where: + // + // V(0) = 2 + // V(1) = P + // V(k) = P V(k-1) - Q V(k-2). + // + // (Remember that due to method C above, P = b, Q = 1.) + // + // In general V(k) = α^k + β^k, where α and β are roots of x² - Px + Q. + // Crandall and Pomerance (p.147) observe that for 0 ≤ j ≤ k, + // + // V(j+k) = V(j)V(k) - V(k-j). + // + // So in particular, to quickly double the subscript: + // + // V(2k) = V(k)² - 2 + // V(2k+1) = V(k) V(k+1) - P + // + // We can therefore start with k=0 and build up to k=s in log₂(s) steps. + natP := nat(nil).setWord(p) + vk := nat(nil).setWord(2) + vk1 := nat(nil).setWord(p) + t2 := nat(nil) // temp + for i := int(s.bitLen()); i >= 0; i-- { + if s.bit(uint(i)) != 0 { + // k' = 2k+1 + // V(k') = V(2k+1) = V(k) V(k+1) - P. + t1 = t1.mul(vk, vk1) + t1 = t1.add(t1, n) + t1 = t1.sub(t1, natP) + t2, vk = t2.div(vk, t1, n) + // V(k'+1) = V(2k+2) = V(k+1)² - 2. + t1 = t1.sqr(vk1) + t1 = t1.add(t1, nm2) + t2, vk1 = t2.div(vk1, t1, n) + } else { + // k' = 2k + // V(k'+1) = V(2k+1) = V(k) V(k+1) - P. + t1 = t1.mul(vk, vk1) + t1 = t1.add(t1, n) + t1 = t1.sub(t1, natP) + t2, vk1 = t2.div(vk1, t1, n) + // V(k') = V(2k) = V(k)² - 2 + t1 = t1.sqr(vk) + t1 = t1.add(t1, nm2) + t2, vk = t2.div(vk, t1, n) + } + } + + // Now k=s, so vk = V(s). Check V(s) ≡ ±2 (mod n). + if vk.cmp(natTwo) == 0 || vk.cmp(nm2) == 0 { + // Check U(s) ≡ 0. + // As suggested by Jacobsen, apply Crandall and Pomerance equation 3.13: + // + // U(k) = D⁻¹ (2 V(k+1) - P V(k)) + // + // Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n, + // or P V(k) - 2 V(k+1) == 0 mod n. + t1 := t1.mul(vk, natP) + t2 := t2.shl(vk1, 1) + if t1.cmp(t2) < 0 { + t1, t2 = t2, t1 + } + t1 = t1.sub(t1, t2) + t3 := vk1 // steal vk1, no longer needed below + vk1 = nil + _ = vk1 + t2, t3 = t2.div(t3, t1, n) + if len(t3) == 0 { + return true + } + } + + // Check V(2^t s) ≡ 0 mod n for some 0 ≤ t < r-1. + for t := 0; t < r-1; t++ { + if len(vk) == 0 { // vk == 0 + return true + } + // Optimization: V(k) = 2 is a fixed point for V(k') = V(k)² - 2, + // so if V(k) = 2, we can stop: we will never find a future V(k) == 0. + if len(vk) == 1 && vk[0] == 2 { // vk == 2 + return false + } + // k' = 2k + // V(k') = V(2k) = V(k)² - 2 + t1 = t1.sqr(vk) + t1 = t1.sub(t1, natTwo) + t2, vk = t2.div(vk, t1, n) + } + return false +} diff --git a/src/math/big/prime_test.go b/src/math/big/prime_test.go new file mode 100644 index 0000000..8596e33 --- /dev/null +++ b/src/math/big/prime_test.go @@ -0,0 +1,222 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "fmt" + "strings" + "testing" + "unicode" +) + +var primes = []string{ + "2", + "3", + "5", + "7", + "11", + + "13756265695458089029", + "13496181268022124907", + "10953742525620032441", + "17908251027575790097", + + // https://golang.org/issue/638 + "18699199384836356663", + + "98920366548084643601728869055592650835572950932266967461790948584315647051443", + "94560208308847015747498523884063394671606671904944666360068158221458669711639", + + // https://primes.utm.edu/lists/small/small3.html + "449417999055441493994709297093108513015373787049558499205492347871729927573118262811508386655998299074566974373711472560655026288668094291699357843464363003144674940345912431129144354948751003607115263071543163", + "230975859993204150666423538988557839555560243929065415434980904258310530753006723857139742334640122533598517597674807096648905501653461687601339782814316124971547968912893214002992086353183070342498989426570593", + "5521712099665906221540423207019333379125265462121169655563495403888449493493629943498064604536961775110765377745550377067893607246020694972959780839151452457728855382113555867743022746090187341871655890805971735385789993", + "203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", + + // ECC primes: https://tools.ietf.org/html/draft-ladd-safecurves-02 + "3618502788666131106986593281521497120414687020801267626233049500247285301239", // Curve1174: 2^251-9 + "57896044618658097711785492504343953926634992332820282019728792003956564819949", // Curve25519: 2^255-19 + "9850501549098619803069760025035903451269934817616361666987073351061430442874302652853566563721228910201656997576599", // E-382: 2^382-105 + "42307582002575910332922579714097346549017899709713998034217522897561970639123926132812109468141778230245837569601494931472367", // Curve41417: 2^414-17 + "6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151", // E-521: 2^521-1 +} + +var composites = []string{ + "0", + "1", + "21284175091214687912771199898307297748211672914763848041968395774954376176754", + "6084766654921918907427900243509372380954290099172559290432744450051395395951", + "84594350493221918389213352992032324280367711247940675652888030554255915464401", + "82793403787388584738507275144194252681", + + // Arnault, "Rabin-Miller Primality Test: Composite Numbers Which Pass It", + // Mathematics of Computation, 64(209) (January 1995), pp. 335-361. + "1195068768795265792518361315725116351898245581", // strong pseudoprime to prime bases 2 through 29 + // strong pseudoprime to all prime bases up to 200 + ` + 80383745745363949125707961434194210813883768828755814583748891752229 + 74273765333652186502336163960045457915042023603208766569966760987284 + 0439654082329287387918508691668573282677617710293896977394701670823 + 0428687109997439976544144845341155872450633409279022275296229414984 + 2306881685404326457534018329786111298960644845216191652872597534901`, + + // Extra-strong Lucas pseudoprimes. https://oeis.org/A217719 + "989", + "3239", + "5777", + "10877", + "27971", + "29681", + "30739", + "31631", + "39059", + "72389", + "73919", + "75077", + "100127", + "113573", + "125249", + "137549", + "137801", + "153931", + "155819", + "161027", + "162133", + "189419", + "218321", + "231703", + "249331", + "370229", + "429479", + "430127", + "459191", + "473891", + "480689", + "600059", + "621781", + "632249", + "635627", + + "3673744903", + "3281593591", + "2385076987", + "2738053141", + "2009621503", + "1502682721", + "255866131", + "117987841", + "587861", + + "6368689", + "8725753", + "80579735209", + "105919633", +} + +func cutSpace(r rune) rune { + if unicode.IsSpace(r) { + return -1 + } + return r +} + +func TestProbablyPrime(t *testing.T) { + nreps := 20 + if testing.Short() { + nreps = 1 + } + for i, s := range primes { + p, _ := new(Int).SetString(s, 10) + if !p.ProbablyPrime(nreps) || nreps != 1 && !p.ProbablyPrime(1) || !p.ProbablyPrime(0) { + t.Errorf("#%d prime found to be non-prime (%s)", i, s) + } + } + + for i, s := range composites { + s = strings.Map(cutSpace, s) + c, _ := new(Int).SetString(s, 10) + if c.ProbablyPrime(nreps) || nreps != 1 && c.ProbablyPrime(1) || c.ProbablyPrime(0) { + t.Errorf("#%d composite found to be prime (%s)", i, s) + } + } + + // check that ProbablyPrime panics if n <= 0 + c := NewInt(11) // a prime + for _, n := range []int{-1, 0, 1} { + func() { + defer func() { + if n < 0 && recover() == nil { + t.Fatalf("expected panic from ProbablyPrime(%d)", n) + } + }() + if !c.ProbablyPrime(n) { + t.Fatalf("%v should be a prime", c) + } + }() + } +} + +func BenchmarkProbablyPrime(b *testing.B) { + p, _ := new(Int).SetString("203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", 10) + for _, n := range []int{0, 1, 5, 10, 20} { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.ProbablyPrime(n) + } + }) + } + + b.Run("Lucas", func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.abs.probablyPrimeLucas() + } + }) + b.Run("MillerRabinBase2", func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.abs.probablyPrimeMillerRabin(1, true) + } + }) +} + +func TestMillerRabinPseudoprimes(t *testing.T) { + testPseudoprimes(t, "probablyPrimeMillerRabin", + func(n nat) bool { return n.probablyPrimeMillerRabin(1, true) && !n.probablyPrimeLucas() }, + // https://oeis.org/A001262 + []int{2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, 52633, 65281, 74665, 80581, 85489, 88357, 90751}) +} + +func TestLucasPseudoprimes(t *testing.T) { + testPseudoprimes(t, "probablyPrimeLucas", + func(n nat) bool { return n.probablyPrimeLucas() && !n.probablyPrimeMillerRabin(1, true) }, + // https://oeis.org/A217719 + []int{989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, 72389, 73919, 75077}) +} + +func testPseudoprimes(t *testing.T, name string, cond func(nat) bool, want []int) { + n := nat{1} + for i := 3; i < 100000; i += 2 { + if testing.Short() { + if len(want) == 0 { + break + } + if i < want[0]-2 { + i = want[0] - 2 + } + } + n[0] = Word(i) + pseudo := cond(n) + if pseudo && (len(want) == 0 || i != want[0]) { + t.Errorf("%s(%v, base=2) = true, want false", name, i) + } else if !pseudo && len(want) >= 1 && i == want[0] { + t.Errorf("%s(%v, base=2) = false, want true", name, i) + } + if len(want) > 0 && i == want[0] { + want = want[1:] + } + } + if len(want) > 0 { + t.Fatalf("forgot to test %v", want) + } +} diff --git a/src/math/big/rat.go b/src/math/big/rat.go new file mode 100644 index 0000000..700a643 --- /dev/null +++ b/src/math/big/rat.go @@ -0,0 +1,542 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements multi-precision rational numbers. + +package big + +import ( + "fmt" + "math" +) + +// A Rat represents a quotient a/b of arbitrary precision. +// The zero value for a Rat represents the value 0. +// +// Operations always take pointer arguments (*Rat) rather +// than Rat values, and each unique Rat value requires +// its own unique *Rat pointer. To "copy" a Rat value, +// an existing (or newly allocated) Rat must be set to +// a new value using the Rat.Set method; shallow copies +// of Rats are not supported and may lead to errors. +type Rat struct { + // To make zero values for Rat work w/o initialization, + // a zero value of b (len(b) == 0) acts like b == 1. At + // the earliest opportunity (when an assignment to the Rat + // is made), such uninitialized denominators are set to 1. + // a.neg determines the sign of the Rat, b.neg is ignored. + a, b Int +} + +// NewRat creates a new Rat with numerator a and denominator b. +func NewRat(a, b int64) *Rat { + return new(Rat).SetFrac64(a, b) +} + +// SetFloat64 sets z to exactly f and returns z. +// If f is not finite, SetFloat returns nil. +func (z *Rat) SetFloat64(f float64) *Rat { + const expMask = 1<<11 - 1 + bits := math.Float64bits(f) + mantissa := bits & (1<<52 - 1) + exp := int((bits >> 52) & expMask) + switch exp { + case expMask: // non-finite + return nil + case 0: // denormal + exp -= 1022 + default: // normal + mantissa |= 1 << 52 + exp -= 1023 + } + + shift := 52 - exp + + // Optimization (?): partially pre-normalise. + for mantissa&1 == 0 && shift > 0 { + mantissa >>= 1 + shift-- + } + + z.a.SetUint64(mantissa) + z.a.neg = f < 0 + z.b.Set(intOne) + if shift > 0 { + z.b.Lsh(&z.b, uint(shift)) + } else { + z.a.Lsh(&z.a, uint(-shift)) + } + return z.norm() +} + +// quotToFloat32 returns the non-negative float32 value +// nearest to the quotient a/b, using round-to-even in +// halfway cases. It does not mutate its arguments. +// Preconditions: b is non-zero; a and b have no common factors. +func quotToFloat32(a, b nat) (f float32, exact bool) { + const ( + // float size in bits + Fsize = 32 + + // mantissa + Msize = 23 + Msize1 = Msize + 1 // incl. implicit 1 + Msize2 = Msize1 + 1 + + // exponent + Esize = Fsize - Msize1 + Ebias = 1<<(Esize-1) - 1 + Emin = 1 - Ebias + Emax = Ebias + ) + + // TODO(adonovan): specialize common degenerate cases: 1.0, integers. + alen := a.bitLen() + if alen == 0 { + return 0, true + } + blen := b.bitLen() + if blen == 0 { + panic("division by zero") + } + + // 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1) + // (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B). + // This is 2 or 3 more than the float32 mantissa field width of Msize: + // - the optional extra bit is shifted away in step 3 below. + // - the high-order 1 is omitted in "normal" representation; + // - the low-order 1 will be used during rounding then discarded. + exp := alen - blen + var a2, b2 nat + a2 = a2.set(a) + b2 = b2.set(b) + if shift := Msize2 - exp; shift > 0 { + a2 = a2.shl(a2, uint(shift)) + } else if shift < 0 { + b2 = b2.shl(b2, uint(-shift)) + } + + // 2. Compute quotient and remainder (q, r). NB: due to the + // extra shift, the low-order bit of q is logically the + // high-order bit of r. + var q nat + q, r := q.div(a2, a2, b2) // (recycle a2) + mantissa := low32(q) + haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half + + // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1 + // (in effect---we accomplish this incrementally). + if mantissa>>Msize2 == 1 { + if mantissa&1 == 1 { + haveRem = true + } + mantissa >>= 1 + exp++ + } + if mantissa>>Msize1 != 1 { + panic(fmt.Sprintf("expected exactly %d bits of result", Msize2)) + } + + // 4. Rounding. + if Emin-Msize <= exp && exp <= Emin { + // Denormal case; lose 'shift' bits of precision. + shift := uint(Emin - (exp - 1)) // [1..Esize1) + lostbits := mantissa & (1<<shift - 1) + haveRem = haveRem || lostbits != 0 + mantissa >>= shift + exp = 2 - Ebias // == exp + shift + } + // Round q using round-half-to-even. + exact = !haveRem + if mantissa&1 != 0 { + exact = false + if haveRem || mantissa&2 != 0 { + if mantissa++; mantissa >= 1<<Msize2 { + // Complete rollover 11...1 => 100...0, so shift is safe + mantissa >>= 1 + exp++ + } + } + } + mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1. + + f = float32(math.Ldexp(float64(mantissa), exp-Msize1)) + if math.IsInf(float64(f), 0) { + exact = false + } + return +} + +// quotToFloat64 returns the non-negative float64 value +// nearest to the quotient a/b, using round-to-even in +// halfway cases. It does not mutate its arguments. +// Preconditions: b is non-zero; a and b have no common factors. +func quotToFloat64(a, b nat) (f float64, exact bool) { + const ( + // float size in bits + Fsize = 64 + + // mantissa + Msize = 52 + Msize1 = Msize + 1 // incl. implicit 1 + Msize2 = Msize1 + 1 + + // exponent + Esize = Fsize - Msize1 + Ebias = 1<<(Esize-1) - 1 + Emin = 1 - Ebias + Emax = Ebias + ) + + // TODO(adonovan): specialize common degenerate cases: 1.0, integers. + alen := a.bitLen() + if alen == 0 { + return 0, true + } + blen := b.bitLen() + if blen == 0 { + panic("division by zero") + } + + // 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1) + // (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B). + // This is 2 or 3 more than the float64 mantissa field width of Msize: + // - the optional extra bit is shifted away in step 3 below. + // - the high-order 1 is omitted in "normal" representation; + // - the low-order 1 will be used during rounding then discarded. + exp := alen - blen + var a2, b2 nat + a2 = a2.set(a) + b2 = b2.set(b) + if shift := Msize2 - exp; shift > 0 { + a2 = a2.shl(a2, uint(shift)) + } else if shift < 0 { + b2 = b2.shl(b2, uint(-shift)) + } + + // 2. Compute quotient and remainder (q, r). NB: due to the + // extra shift, the low-order bit of q is logically the + // high-order bit of r. + var q nat + q, r := q.div(a2, a2, b2) // (recycle a2) + mantissa := low64(q) + haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half + + // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1 + // (in effect---we accomplish this incrementally). + if mantissa>>Msize2 == 1 { + if mantissa&1 == 1 { + haveRem = true + } + mantissa >>= 1 + exp++ + } + if mantissa>>Msize1 != 1 { + panic(fmt.Sprintf("expected exactly %d bits of result", Msize2)) + } + + // 4. Rounding. + if Emin-Msize <= exp && exp <= Emin { + // Denormal case; lose 'shift' bits of precision. + shift := uint(Emin - (exp - 1)) // [1..Esize1) + lostbits := mantissa & (1<<shift - 1) + haveRem = haveRem || lostbits != 0 + mantissa >>= shift + exp = 2 - Ebias // == exp + shift + } + // Round q using round-half-to-even. + exact = !haveRem + if mantissa&1 != 0 { + exact = false + if haveRem || mantissa&2 != 0 { + if mantissa++; mantissa >= 1<<Msize2 { + // Complete rollover 11...1 => 100...0, so shift is safe + mantissa >>= 1 + exp++ + } + } + } + mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1. + + f = math.Ldexp(float64(mantissa), exp-Msize1) + if math.IsInf(f, 0) { + exact = false + } + return +} + +// Float32 returns the nearest float32 value for x and a bool indicating +// whether f represents x exactly. If the magnitude of x is too large to +// be represented by a float32, f is an infinity and exact is false. +// The sign of f always matches the sign of x, even if f == 0. +func (x *Rat) Float32() (f float32, exact bool) { + b := x.b.abs + if len(b) == 0 { + b = natOne + } + f, exact = quotToFloat32(x.a.abs, b) + if x.a.neg { + f = -f + } + return +} + +// Float64 returns the nearest float64 value for x and a bool indicating +// whether f represents x exactly. If the magnitude of x is too large to +// be represented by a float64, f is an infinity and exact is false. +// The sign of f always matches the sign of x, even if f == 0. +func (x *Rat) Float64() (f float64, exact bool) { + b := x.b.abs + if len(b) == 0 { + b = natOne + } + f, exact = quotToFloat64(x.a.abs, b) + if x.a.neg { + f = -f + } + return +} + +// SetFrac sets z to a/b and returns z. +// If b == 0, SetFrac panics. +func (z *Rat) SetFrac(a, b *Int) *Rat { + z.a.neg = a.neg != b.neg + babs := b.abs + if len(babs) == 0 { + panic("division by zero") + } + if &z.a == b || alias(z.a.abs, babs) { + babs = nat(nil).set(babs) // make a copy + } + z.a.abs = z.a.abs.set(a.abs) + z.b.abs = z.b.abs.set(babs) + return z.norm() +} + +// SetFrac64 sets z to a/b and returns z. +// If b == 0, SetFrac64 panics. +func (z *Rat) SetFrac64(a, b int64) *Rat { + if b == 0 { + panic("division by zero") + } + z.a.SetInt64(a) + if b < 0 { + b = -b + z.a.neg = !z.a.neg + } + z.b.abs = z.b.abs.setUint64(uint64(b)) + return z.norm() +} + +// SetInt sets z to x (by making a copy of x) and returns z. +func (z *Rat) SetInt(x *Int) *Rat { + z.a.Set(x) + z.b.abs = z.b.abs.setWord(1) + return z +} + +// SetInt64 sets z to x and returns z. +func (z *Rat) SetInt64(x int64) *Rat { + z.a.SetInt64(x) + z.b.abs = z.b.abs.setWord(1) + return z +} + +// SetUint64 sets z to x and returns z. +func (z *Rat) SetUint64(x uint64) *Rat { + z.a.SetUint64(x) + z.b.abs = z.b.abs.setWord(1) + return z +} + +// Set sets z to x (by making a copy of x) and returns z. +func (z *Rat) Set(x *Rat) *Rat { + if z != x { + z.a.Set(&x.a) + z.b.Set(&x.b) + } + if len(z.b.abs) == 0 { + z.b.abs = z.b.abs.setWord(1) + } + return z +} + +// Abs sets z to |x| (the absolute value of x) and returns z. +func (z *Rat) Abs(x *Rat) *Rat { + z.Set(x) + z.a.neg = false + return z +} + +// Neg sets z to -x and returns z. +func (z *Rat) Neg(x *Rat) *Rat { + z.Set(x) + z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign + return z +} + +// Inv sets z to 1/x and returns z. +// If x == 0, Inv panics. +func (z *Rat) Inv(x *Rat) *Rat { + if len(x.a.abs) == 0 { + panic("division by zero") + } + z.Set(x) + z.a.abs, z.b.abs = z.b.abs, z.a.abs + return z +} + +// Sign returns: +// +// -1 if x < 0 +// 0 if x == 0 +// +1 if x > 0 +func (x *Rat) Sign() int { + return x.a.Sign() +} + +// IsInt reports whether the denominator of x is 1. +func (x *Rat) IsInt() bool { + return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0 +} + +// Num returns the numerator of x; it may be <= 0. +// The result is a reference to x's numerator; it +// may change if a new value is assigned to x, and vice versa. +// The sign of the numerator corresponds to the sign of x. +func (x *Rat) Num() *Int { + return &x.a +} + +// Denom returns the denominator of x; it is always > 0. +// The result is a reference to x's denominator, unless +// x is an uninitialized (zero value) Rat, in which case +// the result is a new Int of value 1. (To initialize x, +// any operation that sets x will do, including x.Set(x).) +// If the result is a reference to x's denominator it +// may change if a new value is assigned to x, and vice versa. +func (x *Rat) Denom() *Int { + // Note that x.b.neg is guaranteed false. + if len(x.b.abs) == 0 { + // Note: If this proves problematic, we could + // panic instead and require the Rat to + // be explicitly initialized. + return &Int{abs: nat{1}} + } + return &x.b +} + +func (z *Rat) norm() *Rat { + switch { + case len(z.a.abs) == 0: + // z == 0; normalize sign and denominator + z.a.neg = false + fallthrough + case len(z.b.abs) == 0: + // z is integer; normalize denominator + z.b.abs = z.b.abs.setWord(1) + default: + // z is fraction; normalize numerator and denominator + neg := z.a.neg + z.a.neg = false + z.b.neg = false + if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 { + z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs) + z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs) + } + z.a.neg = neg + } + return z +} + +// mulDenom sets z to the denominator product x*y (by taking into +// account that 0 values for x or y must be interpreted as 1) and +// returns z. +func mulDenom(z, x, y nat) nat { + switch { + case len(x) == 0 && len(y) == 0: + return z.setWord(1) + case len(x) == 0: + return z.set(y) + case len(y) == 0: + return z.set(x) + } + return z.mul(x, y) +} + +// scaleDenom sets z to the product x*f. +// If f == 0 (zero value of denominator), z is set to (a copy of) x. +func (z *Int) scaleDenom(x *Int, f nat) { + if len(f) == 0 { + z.Set(x) + return + } + z.abs = z.abs.mul(x.abs, f) + z.neg = x.neg +} + +// Cmp compares x and y and returns: +// +// -1 if x < y +// 0 if x == y +// +1 if x > y +func (x *Rat) Cmp(y *Rat) int { + var a, b Int + a.scaleDenom(&x.a, y.b.abs) + b.scaleDenom(&y.a, x.b.abs) + return a.Cmp(&b) +} + +// Add sets z to the sum x+y and returns z. +func (z *Rat) Add(x, y *Rat) *Rat { + var a1, a2 Int + a1.scaleDenom(&x.a, y.b.abs) + a2.scaleDenom(&y.a, x.b.abs) + z.a.Add(&a1, &a2) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + return z.norm() +} + +// Sub sets z to the difference x-y and returns z. +func (z *Rat) Sub(x, y *Rat) *Rat { + var a1, a2 Int + a1.scaleDenom(&x.a, y.b.abs) + a2.scaleDenom(&y.a, x.b.abs) + z.a.Sub(&a1, &a2) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + return z.norm() +} + +// Mul sets z to the product x*y and returns z. +func (z *Rat) Mul(x, y *Rat) *Rat { + if x == y { + // a squared Rat is positive and can't be reduced (no need to call norm()) + z.a.neg = false + z.a.abs = z.a.abs.sqr(x.a.abs) + if len(x.b.abs) == 0 { + z.b.abs = z.b.abs.setWord(1) + } else { + z.b.abs = z.b.abs.sqr(x.b.abs) + } + return z + } + z.a.Mul(&x.a, &y.a) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + return z.norm() +} + +// Quo sets z to the quotient x/y and returns z. +// If y == 0, Quo panics. +func (z *Rat) Quo(x, y *Rat) *Rat { + if len(y.a.abs) == 0 { + panic("division by zero") + } + var a, b Int + a.scaleDenom(&x.a, y.b.abs) + b.scaleDenom(&y.a, x.b.abs) + z.a.abs = a.abs + z.b.abs = b.abs + z.a.neg = a.neg != b.neg + return z.norm() +} diff --git a/src/math/big/rat_test.go b/src/math/big/rat_test.go new file mode 100644 index 0000000..d98c89b --- /dev/null +++ b/src/math/big/rat_test.go @@ -0,0 +1,746 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "math" + "testing" +) + +func TestZeroRat(t *testing.T) { + var x, y, z Rat + y.SetFrac64(0, 42) + + if x.Cmp(&y) != 0 { + t.Errorf("x and y should be both equal and zero") + } + + if s := x.String(); s != "0/1" { + t.Errorf("got x = %s, want 0/1", s) + } + + if s := x.RatString(); s != "0" { + t.Errorf("got x = %s, want 0", s) + } + + z.Add(&x, &y) + if s := z.RatString(); s != "0" { + t.Errorf("got x+y = %s, want 0", s) + } + + z.Sub(&x, &y) + if s := z.RatString(); s != "0" { + t.Errorf("got x-y = %s, want 0", s) + } + + z.Mul(&x, &y) + if s := z.RatString(); s != "0" { + t.Errorf("got x*y = %s, want 0", s) + } + + // check for division by zero + defer func() { + if s := recover(); s == nil || s.(string) != "division by zero" { + panic(s) + } + }() + z.Quo(&x, &y) +} + +func TestRatSign(t *testing.T) { + zero := NewRat(0, 1) + for _, a := range setStringTests { + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + s := x.Sign() + e := x.Cmp(zero) + if s != e { + t.Errorf("got %d; want %d for z = %v", s, e, &x) + } + } +} + +var ratCmpTests = []struct { + rat1, rat2 string + out int +}{ + {"0", "0/1", 0}, + {"1/1", "1", 0}, + {"-1", "-2/2", 0}, + {"1", "0", 1}, + {"0/1", "1/1", -1}, + {"-5/1434770811533343057144", "-5/1434770811533343057145", -1}, + {"49832350382626108453/8964749413", "49832350382626108454/8964749413", -1}, + {"-37414950961700930/7204075375675961", "37414950961700930/7204075375675961", -1}, + {"37414950961700930/7204075375675961", "74829901923401860/14408150751351922", 0}, +} + +func TestRatCmp(t *testing.T) { + for i, test := range ratCmpTests { + x, _ := new(Rat).SetString(test.rat1) + y, _ := new(Rat).SetString(test.rat2) + + out := x.Cmp(y) + if out != test.out { + t.Errorf("#%d got out = %v; want %v", i, out, test.out) + } + } +} + +func TestIsInt(t *testing.T) { + one := NewInt(1) + for _, a := range setStringTests { + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + i := x.IsInt() + e := x.Denom().Cmp(one) == 0 + if i != e { + t.Errorf("got IsInt(%v) == %v; want %v", x, i, e) + } + } +} + +func TestRatAbs(t *testing.T) { + zero := new(Rat) + for _, a := range setStringTests { + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + e := new(Rat).Set(x) + if e.Cmp(zero) < 0 { + e.Sub(zero, e) + } + z := new(Rat).Abs(x) + if z.Cmp(e) != 0 { + t.Errorf("got Abs(%v) = %v; want %v", x, z, e) + } + } +} + +func TestRatNeg(t *testing.T) { + zero := new(Rat) + for _, a := range setStringTests { + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + e := new(Rat).Sub(zero, x) + z := new(Rat).Neg(x) + if z.Cmp(e) != 0 { + t.Errorf("got Neg(%v) = %v; want %v", x, z, e) + } + } +} + +func TestRatInv(t *testing.T) { + zero := new(Rat) + for _, a := range setStringTests { + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + if x.Cmp(zero) == 0 { + continue // avoid division by zero + } + e := new(Rat).SetFrac(x.Denom(), x.Num()) + z := new(Rat).Inv(x) + if z.Cmp(e) != 0 { + t.Errorf("got Inv(%v) = %v; want %v", x, z, e) + } + } +} + +type ratBinFun func(z, x, y *Rat) *Rat +type ratBinArg struct { + x, y, z string +} + +func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) { + x, _ := new(Rat).SetString(a.x) + y, _ := new(Rat).SetString(a.y) + z, _ := new(Rat).SetString(a.z) + out := f(new(Rat), x, y) + + if out.Cmp(z) != 0 { + t.Errorf("%s #%d got %s want %s", name, i, out, z) + } +} + +var ratBinTests = []struct { + x, y string + sum, prod string +}{ + {"0", "0", "0", "0"}, + {"0", "1", "1", "0"}, + {"-1", "0", "-1", "0"}, + {"-1", "1", "0", "-1"}, + {"1", "1", "2", "1"}, + {"1/2", "1/2", "1", "1/4"}, + {"1/4", "1/3", "7/12", "1/12"}, + {"2/5", "-14/3", "-64/15", "-28/15"}, + {"4707/49292519774798173060", "-3367/70976135186689855734", "84058377121001851123459/1749296273614329067191168098769082663020", "-1760941/388732505247628681598037355282018369560"}, + {"-61204110018146728334/3", "-31052192278051565633/2", "-215564796870448153567/6", "950260896245257153059642991192710872711/3"}, + {"-854857841473707320655/4237645934602118692642972629634714039", "-18/31750379913563777419", "-27/133467566250814981", "15387441146526731771790/134546868362786310073779084329032722548987800600710485341"}, + {"618575745270541348005638912139/19198433543745179392300736", "-19948846211000086/637313996471", "27674141753240653/30123979153216", "-6169936206128396568797607742807090270137721977/6117715203873571641674006593837351328"}, + {"-3/26206484091896184128", "5/2848423294177090248", "15310893822118706237/9330894968229805033368778458685147968", "-5/24882386581946146755650075889827061248"}, + {"26946729/330400702820", "41563965/225583428284", "1238218672302860271/4658307703098666660055", "224002580204097/14906584649915733312176"}, + {"-8259900599013409474/7", "-84829337473700364773/56707961321161574960", "-468402123685491748914621885145127724451/396955729248131024720", "350340947706464153265156004876107029701/198477864624065512360"}, + {"575775209696864/1320203974639986246357", "29/712593081308", "410331716733912717985762465/940768218243776489278275419794956", "808/45524274987585732633"}, + {"1786597389946320496771/2066653520653241", "6269770/1992362624741777", "3559549865190272133656109052308126637/4117523232840525481453983149257", "8967230/3296219033"}, + {"-36459180403360509753/32150500941194292113930", "9381566963714/9633539", "301622077145533298008420642898530153/309723104686531919656937098270", "-3784609207827/3426986245"}, +} + +func TestRatBin(t *testing.T) { + for i, test := range ratBinTests { + arg := ratBinArg{test.x, test.y, test.sum} + testRatBin(t, i, "Add", (*Rat).Add, arg) + + arg = ratBinArg{test.y, test.x, test.sum} + testRatBin(t, i, "Add symmetric", (*Rat).Add, arg) + + arg = ratBinArg{test.sum, test.x, test.y} + testRatBin(t, i, "Sub", (*Rat).Sub, arg) + + arg = ratBinArg{test.sum, test.y, test.x} + testRatBin(t, i, "Sub symmetric", (*Rat).Sub, arg) + + arg = ratBinArg{test.x, test.y, test.prod} + testRatBin(t, i, "Mul", (*Rat).Mul, arg) + + arg = ratBinArg{test.y, test.x, test.prod} + testRatBin(t, i, "Mul symmetric", (*Rat).Mul, arg) + + if test.x != "0" { + arg = ratBinArg{test.prod, test.x, test.y} + testRatBin(t, i, "Quo", (*Rat).Quo, arg) + } + + if test.y != "0" { + arg = ratBinArg{test.prod, test.y, test.x} + testRatBin(t, i, "Quo symmetric", (*Rat).Quo, arg) + } + } +} + +func TestIssue820(t *testing.T) { + x := NewRat(3, 1) + y := NewRat(2, 1) + z := y.Quo(x, y) + q := NewRat(3, 2) + if z.Cmp(q) != 0 { + t.Errorf("got %s want %s", z, q) + } + + y = NewRat(3, 1) + x = NewRat(2, 1) + z = y.Quo(x, y) + q = NewRat(2, 3) + if z.Cmp(q) != 0 { + t.Errorf("got %s want %s", z, q) + } + + x = NewRat(3, 1) + z = x.Quo(x, x) + q = NewRat(3, 3) + if z.Cmp(q) != 0 { + t.Errorf("got %s want %s", z, q) + } +} + +var setFrac64Tests = []struct { + a, b int64 + out string +}{ + {0, 1, "0"}, + {0, -1, "0"}, + {1, 1, "1"}, + {-1, 1, "-1"}, + {1, -1, "-1"}, + {-1, -1, "1"}, + {-9223372036854775808, -9223372036854775808, "1"}, +} + +func TestRatSetFrac64Rat(t *testing.T) { + for i, test := range setFrac64Tests { + x := new(Rat).SetFrac64(test.a, test.b) + if x.RatString() != test.out { + t.Errorf("#%d got %s want %s", i, x.RatString(), test.out) + } + } +} + +func TestIssue2379(t *testing.T) { + // 1) no aliasing + q := NewRat(3, 2) + x := new(Rat) + x.SetFrac(NewInt(3), NewInt(2)) + if x.Cmp(q) != 0 { + t.Errorf("1) got %s want %s", x, q) + } + + // 2) aliasing of numerator + x = NewRat(2, 3) + x.SetFrac(NewInt(3), x.Num()) + if x.Cmp(q) != 0 { + t.Errorf("2) got %s want %s", x, q) + } + + // 3) aliasing of denominator + x = NewRat(2, 3) + x.SetFrac(x.Denom(), NewInt(2)) + if x.Cmp(q) != 0 { + t.Errorf("3) got %s want %s", x, q) + } + + // 4) aliasing of numerator and denominator + x = NewRat(2, 3) + x.SetFrac(x.Denom(), x.Num()) + if x.Cmp(q) != 0 { + t.Errorf("4) got %s want %s", x, q) + } + + // 5) numerator and denominator are the same + q = NewRat(1, 1) + x = new(Rat) + n := NewInt(7) + x.SetFrac(n, n) + if x.Cmp(q) != 0 { + t.Errorf("5) got %s want %s", x, q) + } +} + +func TestIssue3521(t *testing.T) { + a := new(Int) + b := new(Int) + a.SetString("64375784358435883458348587", 0) + b.SetString("4789759874531", 0) + + // 0) a raw zero value has 1 as denominator + zero := new(Rat) + one := NewInt(1) + if zero.Denom().Cmp(one) != 0 { + t.Errorf("0) got %s want %s", zero.Denom(), one) + } + + // 1a) the denominator of an (uninitialized) zero value is not shared with the value + s := &zero.b + d := zero.Denom() + if d == s { + t.Errorf("1a) got %s (%p) == %s (%p) want different *Int values", d, d, s, s) + } + + // 1b) the denominator of an (uninitialized) value is a new 1 each time + d1 := zero.Denom() + d2 := zero.Denom() + if d1 == d2 { + t.Errorf("1b) got %s (%p) == %s (%p) want different *Int values", d1, d1, d2, d2) + } + + // 1c) the denominator of an initialized zero value is shared with the value + x := new(Rat) + x.Set(x) // initialize x (any operation that sets x explicitly will do) + s = &x.b + d = x.Denom() + if d != s { + t.Errorf("1c) got %s (%p) != %s (%p) want identical *Int values", d, d, s, s) + } + + // 1d) a zero value remains zero independent of denominator + x.Denom().Set(new(Int).Neg(b)) + if x.Cmp(zero) != 0 { + t.Errorf("1d) got %s want %s", x, zero) + } + + // 1e) a zero value may have a denominator != 0 and != 1 + x.Num().Set(a) + qab := new(Rat).SetFrac(a, b) + if x.Cmp(qab) != 0 { + t.Errorf("1e) got %s want %s", x, qab) + } + + // 2a) an integral value becomes a fraction depending on denominator + x.SetFrac64(10, 2) + x.Denom().SetInt64(3) + q53 := NewRat(5, 3) + if x.Cmp(q53) != 0 { + t.Errorf("2a) got %s want %s", x, q53) + } + + // 2b) an integral value becomes a fraction depending on denominator + x = NewRat(10, 2) + x.Denom().SetInt64(3) + if x.Cmp(q53) != 0 { + t.Errorf("2b) got %s want %s", x, q53) + } + + // 3) changing the numerator/denominator of a Rat changes the Rat + x.SetFrac(a, b) + a = x.Num() + b = x.Denom() + a.SetInt64(5) + b.SetInt64(3) + if x.Cmp(q53) != 0 { + t.Errorf("3) got %s want %s", x, q53) + } +} + +func TestFloat32Distribution(t *testing.T) { + // Generate a distribution of (sign, mantissa, exp) values + // broader than the float32 range, and check Rat.Float32() + // always picks the closest float32 approximation. + var add = []int64{ + 0, + 1, + 3, + 5, + 7, + 9, + 11, + } + var winc, einc = uint64(5), 15 // quick test (~60ms on x86-64) + if *long { + winc, einc = uint64(1), 1 // soak test (~1.5s on x86-64) + } + + for _, sign := range "+-" { + for _, a := range add { + for wid := uint64(0); wid < 30; wid += winc { + b := 1<<wid + a + if sign == '-' { + b = -b + } + for exp := -150; exp < 150; exp += einc { + num, den := NewInt(b), NewInt(1) + if exp > 0 { + num.Lsh(num, uint(exp)) + } else { + den.Lsh(den, uint(-exp)) + } + r := new(Rat).SetFrac(num, den) + f, _ := r.Float32() + + if !checkIsBestApprox32(t, f, r) { + // Append context information. + t.Errorf("(input was mantissa %#x, exp %d; f = %g (%b); f ~ %g; r = %v)", + b, exp, f, f, math.Ldexp(float64(b), exp), r) + } + + checkNonLossyRoundtrip32(t, f) + } + } + } + } +} + +func TestFloat64Distribution(t *testing.T) { + // Generate a distribution of (sign, mantissa, exp) values + // broader than the float64 range, and check Rat.Float64() + // always picks the closest float64 approximation. + var add = []int64{ + 0, + 1, + 3, + 5, + 7, + 9, + 11, + } + var winc, einc = uint64(10), 500 // quick test (~12ms on x86-64) + if *long { + winc, einc = uint64(1), 1 // soak test (~75s on x86-64) + } + + for _, sign := range "+-" { + for _, a := range add { + for wid := uint64(0); wid < 60; wid += winc { + b := 1<<wid + a + if sign == '-' { + b = -b + } + for exp := -1100; exp < 1100; exp += einc { + num, den := NewInt(b), NewInt(1) + if exp > 0 { + num.Lsh(num, uint(exp)) + } else { + den.Lsh(den, uint(-exp)) + } + r := new(Rat).SetFrac(num, den) + f, _ := r.Float64() + + if !checkIsBestApprox64(t, f, r) { + // Append context information. + t.Errorf("(input was mantissa %#x, exp %d; f = %g (%b); f ~ %g; r = %v)", + b, exp, f, f, math.Ldexp(float64(b), exp), r) + } + + checkNonLossyRoundtrip64(t, f) + } + } + } + } +} + +// TestSetFloat64NonFinite checks that SetFloat64 of a non-finite value +// returns nil. +func TestSetFloat64NonFinite(t *testing.T) { + for _, f := range []float64{math.NaN(), math.Inf(+1), math.Inf(-1)} { + var r Rat + if r2 := r.SetFloat64(f); r2 != nil { + t.Errorf("SetFloat64(%g) was %v, want nil", f, r2) + } + } +} + +// checkNonLossyRoundtrip32 checks that a float->Rat->float roundtrip is +// non-lossy for finite f. +func checkNonLossyRoundtrip32(t *testing.T, f float32) { + if !isFinite(float64(f)) { + return + } + r := new(Rat).SetFloat64(float64(f)) + if r == nil { + t.Errorf("Rat.SetFloat64(float64(%g) (%b)) == nil", f, f) + return + } + f2, exact := r.Float32() + if f != f2 || !exact { + t.Errorf("Rat.SetFloat64(float64(%g)).Float32() = %g (%b), %v, want %g (%b), %v; delta = %b", + f, f2, f2, exact, f, f, true, f2-f) + } +} + +// checkNonLossyRoundtrip64 checks that a float->Rat->float roundtrip is +// non-lossy for finite f. +func checkNonLossyRoundtrip64(t *testing.T, f float64) { + if !isFinite(f) { + return + } + r := new(Rat).SetFloat64(f) + if r == nil { + t.Errorf("Rat.SetFloat64(%g (%b)) == nil", f, f) + return + } + f2, exact := r.Float64() + if f != f2 || !exact { + t.Errorf("Rat.SetFloat64(%g).Float64() = %g (%b), %v, want %g (%b), %v; delta = %b", + f, f2, f2, exact, f, f, true, f2-f) + } +} + +// delta returns the absolute difference between r and f. +func delta(r *Rat, f float64) *Rat { + d := new(Rat).Sub(r, new(Rat).SetFloat64(f)) + return d.Abs(d) +} + +// checkIsBestApprox32 checks that f is the best possible float32 +// approximation of r. +// Returns true on success. +func checkIsBestApprox32(t *testing.T, f float32, r *Rat) bool { + if math.Abs(float64(f)) >= math.MaxFloat32 { + // Cannot check +Inf, -Inf, nor the float next to them (MaxFloat32). + // But we have tests for these special cases. + return true + } + + // r must be strictly between f0 and f1, the floats bracketing f. + f0 := math.Nextafter32(f, float32(math.Inf(-1))) + f1 := math.Nextafter32(f, float32(math.Inf(+1))) + + // For f to be correct, r must be closer to f than to f0 or f1. + df := delta(r, float64(f)) + df0 := delta(r, float64(f0)) + df1 := delta(r, float64(f1)) + if df.Cmp(df0) > 0 { + t.Errorf("Rat(%v).Float32() = %g (%b), but previous float32 %g (%b) is closer", r, f, f, f0, f0) + return false + } + if df.Cmp(df1) > 0 { + t.Errorf("Rat(%v).Float32() = %g (%b), but next float32 %g (%b) is closer", r, f, f, f1, f1) + return false + } + if df.Cmp(df0) == 0 && !isEven32(f) { + t.Errorf("Rat(%v).Float32() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f0, f0) + return false + } + if df.Cmp(df1) == 0 && !isEven32(f) { + t.Errorf("Rat(%v).Float32() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f1, f1) + return false + } + return true +} + +// checkIsBestApprox64 checks that f is the best possible float64 +// approximation of r. +// Returns true on success. +func checkIsBestApprox64(t *testing.T, f float64, r *Rat) bool { + if math.Abs(f) >= math.MaxFloat64 { + // Cannot check +Inf, -Inf, nor the float next to them (MaxFloat64). + // But we have tests for these special cases. + return true + } + + // r must be strictly between f0 and f1, the floats bracketing f. + f0 := math.Nextafter(f, math.Inf(-1)) + f1 := math.Nextafter(f, math.Inf(+1)) + + // For f to be correct, r must be closer to f than to f0 or f1. + df := delta(r, f) + df0 := delta(r, f0) + df1 := delta(r, f1) + if df.Cmp(df0) > 0 { + t.Errorf("Rat(%v).Float64() = %g (%b), but previous float64 %g (%b) is closer", r, f, f, f0, f0) + return false + } + if df.Cmp(df1) > 0 { + t.Errorf("Rat(%v).Float64() = %g (%b), but next float64 %g (%b) is closer", r, f, f, f1, f1) + return false + } + if df.Cmp(df0) == 0 && !isEven64(f) { + t.Errorf("Rat(%v).Float64() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f0, f0) + return false + } + if df.Cmp(df1) == 0 && !isEven64(f) { + t.Errorf("Rat(%v).Float64() = %g (%b); halfway should have rounded to %g (%b) instead", r, f, f, f1, f1) + return false + } + return true +} + +func isEven32(f float32) bool { return math.Float32bits(f)&1 == 0 } +func isEven64(f float64) bool { return math.Float64bits(f)&1 == 0 } + +func TestIsFinite(t *testing.T) { + finites := []float64{ + 1.0 / 3, + 4891559871276714924261e+222, + math.MaxFloat64, + math.SmallestNonzeroFloat64, + -math.MaxFloat64, + -math.SmallestNonzeroFloat64, + } + for _, f := range finites { + if !isFinite(f) { + t.Errorf("!IsFinite(%g (%b))", f, f) + } + } + nonfinites := []float64{ + math.NaN(), + math.Inf(-1), + math.Inf(+1), + } + for _, f := range nonfinites { + if isFinite(f) { + t.Errorf("IsFinite(%g, (%b))", f, f) + } + } +} + +func TestRatSetInt64(t *testing.T) { + var testCases = []int64{ + 0, + 1, + -1, + 12345, + -98765, + math.MaxInt64, + math.MinInt64, + } + var r = new(Rat) + for i, want := range testCases { + r.SetInt64(want) + if !r.IsInt() { + t.Errorf("#%d: Rat.SetInt64(%d) is not an integer", i, want) + } + num := r.Num() + if !num.IsInt64() { + t.Errorf("#%d: Rat.SetInt64(%d) numerator is not an int64", i, want) + } + got := num.Int64() + if got != want { + t.Errorf("#%d: Rat.SetInt64(%d) = %d, but expected %d", i, want, got, want) + } + } +} + +func TestRatSetUint64(t *testing.T) { + var testCases = []uint64{ + 0, + 1, + 12345, + ^uint64(0), + } + var r = new(Rat) + for i, want := range testCases { + r.SetUint64(want) + if !r.IsInt() { + t.Errorf("#%d: Rat.SetUint64(%d) is not an integer", i, want) + } + num := r.Num() + if !num.IsUint64() { + t.Errorf("#%d: Rat.SetUint64(%d) numerator is not a uint64", i, want) + } + got := num.Uint64() + if got != want { + t.Errorf("#%d: Rat.SetUint64(%d) = %d, but expected %d", i, want, got, want) + } + } +} + +func BenchmarkRatCmp(b *testing.B) { + x, y := NewRat(4, 1), NewRat(7, 2) + for i := 0; i < b.N; i++ { + x.Cmp(y) + } +} + +// TestIssue34919 verifies that a Rat's denominator is not modified +// when simply accessing the Rat value. +func TestIssue34919(t *testing.T) { + for _, acc := range []struct { + name string + f func(*Rat) + }{ + {"Float32", func(x *Rat) { x.Float32() }}, + {"Float64", func(x *Rat) { x.Float64() }}, + {"Inv", func(x *Rat) { new(Rat).Inv(x) }}, + {"Sign", func(x *Rat) { x.Sign() }}, + {"IsInt", func(x *Rat) { x.IsInt() }}, + {"Num", func(x *Rat) { x.Num() }}, + // {"Denom", func(x *Rat) { x.Denom() }}, TODO(gri) should we change the API? See issue #33792. + } { + // A denominator of length 0 is interpreted as 1. Make sure that + // "materialization" of the denominator doesn't lead to setting + // the underlying array element 0 to 1. + r := &Rat{Int{abs: nat{991}}, Int{abs: make(nat, 0, 1)}} + acc.f(r) + if d := r.b.abs[:1][0]; d != 0 { + t.Errorf("%s modified denominator: got %d, want 0", acc.name, d) + } + } +} + +func TestDenomRace(t *testing.T) { + x := NewRat(1, 2) + const N = 3 + c := make(chan bool, N) + for i := 0; i < N; i++ { + go func() { + // Denom (also used by Float.SetRat) used to mutate x unnecessarily, + // provoking race reports when run in the race detector. + x.Denom() + new(Float).SetRat(x) + c <- true + }() + } + for i := 0; i < N; i++ { + <-c + } +} diff --git a/src/math/big/ratconv.go b/src/math/big/ratconv.go new file mode 100644 index 0000000..8537a67 --- /dev/null +++ b/src/math/big/ratconv.go @@ -0,0 +1,380 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements rat-to-string conversion functions. + +package big + +import ( + "errors" + "fmt" + "io" + "strconv" + "strings" +) + +func ratTok(ch rune) bool { + return strings.ContainsRune("+-/0123456789.eE", ch) +} + +var ratZero Rat +var _ fmt.Scanner = &ratZero // *Rat must implement fmt.Scanner + +// Scan is a support routine for fmt.Scanner. It accepts the formats +// 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent. +func (z *Rat) Scan(s fmt.ScanState, ch rune) error { + tok, err := s.Token(true, ratTok) + if err != nil { + return err + } + if !strings.ContainsRune("efgEFGv", ch) { + return errors.New("Rat.Scan: invalid verb") + } + if _, ok := z.SetString(string(tok)); !ok { + return errors.New("Rat.Scan: invalid syntax") + } + return nil +} + +// SetString sets z to the value of s and returns z and a boolean indicating +// success. s can be given as a (possibly signed) fraction "a/b", or as a +// floating-point number optionally followed by an exponent. +// If a fraction is provided, both the dividend and the divisor may be a +// decimal integer or independently use a prefix of “0b”, “0” or “0o”, +// or “0x” (or their upper-case variants) to denote a binary, octal, or +// hexadecimal integer, respectively. The divisor may not be signed. +// If a floating-point number is provided, it may be in decimal form or +// use any of the same prefixes as above but for “0” to denote a non-decimal +// mantissa. A leading “0” is considered a decimal leading 0; it does not +// indicate octal representation in this case. +// An optional base-10 “e” or base-2 “p” (or their upper-case variants) +// exponent may be provided as well, except for hexadecimal floats which +// only accept an (optional) “p” exponent (because an “e” or “E” cannot +// be distinguished from a mantissa digit). If the exponent's absolute value +// is too large, the operation may fail. +// The entire string, not just a prefix, must be valid for success. If the +// operation failed, the value of z is undefined but the returned value is nil. +func (z *Rat) SetString(s string) (*Rat, bool) { + if len(s) == 0 { + return nil, false + } + // len(s) > 0 + + // parse fraction a/b, if any + if sep := strings.Index(s, "/"); sep >= 0 { + if _, ok := z.a.SetString(s[:sep], 0); !ok { + return nil, false + } + r := strings.NewReader(s[sep+1:]) + var err error + if z.b.abs, _, _, err = z.b.abs.scan(r, 0, false); err != nil { + return nil, false + } + // entire string must have been consumed + if _, err = r.ReadByte(); err != io.EOF { + return nil, false + } + if len(z.b.abs) == 0 { + return nil, false + } + return z.norm(), true + } + + // parse floating-point number + r := strings.NewReader(s) + + // sign + neg, err := scanSign(r) + if err != nil { + return nil, false + } + + // mantissa + var base int + var fcount int // fractional digit count; valid if <= 0 + z.a.abs, base, fcount, err = z.a.abs.scan(r, 0, true) + if err != nil { + return nil, false + } + + // exponent + var exp int64 + var ebase int + exp, ebase, err = scanExponent(r, true, true) + if err != nil { + return nil, false + } + + // there should be no unread characters left + if _, err = r.ReadByte(); err != io.EOF { + return nil, false + } + + // special-case 0 (see also issue #16176) + if len(z.a.abs) == 0 { + return z.norm(), true + } + // len(z.a.abs) > 0 + + // The mantissa may have a radix point (fcount <= 0) and there + // may be a nonzero exponent exp. The radix point amounts to a + // division by base**(-fcount), which equals a multiplication by + // base**fcount. An exponent means multiplication by ebase**exp. + // Multiplications are commutative, so we can apply them in any + // order. We only have powers of 2 and 10, and we split powers + // of 10 into the product of the same powers of 2 and 5. This + // may reduce the size of shift/multiplication factors or + // divisors required to create the final fraction, depending + // on the actual floating-point value. + + // determine binary or decimal exponent contribution of radix point + var exp2, exp5 int64 + if fcount < 0 { + // The mantissa has a radix point ddd.dddd; and + // -fcount is the number of digits to the right + // of '.'. Adjust relevant exponent accordingly. + d := int64(fcount) + switch base { + case 10: + exp5 = d + fallthrough // 10**e == 5**e * 2**e + case 2: + exp2 = d + case 8: + exp2 = d * 3 // octal digits are 3 bits each + case 16: + exp2 = d * 4 // hexadecimal digits are 4 bits each + default: + panic("unexpected mantissa base") + } + // fcount consumed - not needed anymore + } + + // take actual exponent into account + switch ebase { + case 10: + exp5 += exp + fallthrough // see fallthrough above + case 2: + exp2 += exp + default: + panic("unexpected exponent base") + } + // exp consumed - not needed anymore + + // apply exp5 contributions + // (start with exp5 so the numbers to multiply are smaller) + if exp5 != 0 { + n := exp5 + if n < 0 { + n = -n + if n < 0 { + // This can occur if -n overflows. -(-1 << 63) would become + // -1 << 63, which is still negative. + return nil, false + } + } + if n > 1e6 { + return nil, false // avoid excessively large exponents + } + pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs + if exp5 > 0 { + z.a.abs = z.a.abs.mul(z.a.abs, pow5) + z.b.abs = z.b.abs.setWord(1) + } else { + z.b.abs = pow5 + } + } else { + z.b.abs = z.b.abs.setWord(1) + } + + // apply exp2 contributions + if exp2 < -1e7 || exp2 > 1e7 { + return nil, false // avoid excessively large exponents + } + if exp2 > 0 { + z.a.abs = z.a.abs.shl(z.a.abs, uint(exp2)) + } else if exp2 < 0 { + z.b.abs = z.b.abs.shl(z.b.abs, uint(-exp2)) + } + + z.a.neg = neg && len(z.a.abs) > 0 // 0 has no sign + + return z.norm(), true +} + +// scanExponent scans the longest possible prefix of r representing a base 10 +// (“e”, “E”) or a base 2 (“p”, “P”) exponent, if any. It returns the +// exponent, the exponent base (10 or 2), or a read or syntax error, if any. +// +// If sepOk is set, an underscore character “_” may appear between successive +// exponent digits; such underscores do not change the value of the exponent. +// Incorrect placement of underscores is reported as an error if there are no +// other errors. If sepOk is not set, underscores are not recognized and thus +// terminate scanning like any other character that is not a valid digit. +// +// exponent = ( "e" | "E" | "p" | "P" ) [ sign ] digits . +// sign = "+" | "-" . +// digits = digit { [ '_' ] digit } . +// digit = "0" ... "9" . +// +// A base 2 exponent is only permitted if base2ok is set. +func scanExponent(r io.ByteScanner, base2ok, sepOk bool) (exp int64, base int, err error) { + // one char look-ahead + ch, err := r.ReadByte() + if err != nil { + if err == io.EOF { + err = nil + } + return 0, 10, err + } + + // exponent char + switch ch { + case 'e', 'E': + base = 10 + case 'p', 'P': + if base2ok { + base = 2 + break // ok + } + fallthrough // binary exponent not permitted + default: + r.UnreadByte() // ch does not belong to exponent anymore + return 0, 10, nil + } + + // sign + var digits []byte + ch, err = r.ReadByte() + if err == nil && (ch == '+' || ch == '-') { + if ch == '-' { + digits = append(digits, '-') + } + ch, err = r.ReadByte() + } + + // prev encodes the previously seen char: it is one + // of '_', '0' (a digit), or '.' (anything else). A + // valid separator '_' may only occur after a digit. + prev := '.' + invalSep := false + + // exponent value + hasDigits := false + for err == nil { + if '0' <= ch && ch <= '9' { + digits = append(digits, ch) + prev = '0' + hasDigits = true + } else if ch == '_' && sepOk { + if prev != '0' { + invalSep = true + } + prev = '_' + } else { + r.UnreadByte() // ch does not belong to number anymore + break + } + ch, err = r.ReadByte() + } + + if err == io.EOF { + err = nil + } + if err == nil && !hasDigits { + err = errNoDigits + } + if err == nil { + exp, err = strconv.ParseInt(string(digits), 10, 64) + } + // other errors take precedence over invalid separators + if err == nil && (invalSep || prev == '_') { + err = errInvalSep + } + + return +} + +// String returns a string representation of x in the form "a/b" (even if b == 1). +func (x *Rat) String() string { + return string(x.marshal()) +} + +// marshal implements String returning a slice of bytes +func (x *Rat) marshal() []byte { + var buf []byte + buf = x.a.Append(buf, 10) + buf = append(buf, '/') + if len(x.b.abs) != 0 { + buf = x.b.Append(buf, 10) + } else { + buf = append(buf, '1') + } + return buf +} + +// RatString returns a string representation of x in the form "a/b" if b != 1, +// and in the form "a" if b == 1. +func (x *Rat) RatString() string { + if x.IsInt() { + return x.a.String() + } + return x.String() +} + +// FloatString returns a string representation of x in decimal form with prec +// digits of precision after the radix point. The last digit is rounded to +// nearest, with halves rounded away from zero. +func (x *Rat) FloatString(prec int) string { + var buf []byte + + if x.IsInt() { + buf = x.a.Append(buf, 10) + if prec > 0 { + buf = append(buf, '.') + for i := prec; i > 0; i-- { + buf = append(buf, '0') + } + } + return string(buf) + } + // x.b.abs != 0 + + q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs) + + p := natOne + if prec > 0 { + p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false) + } + + r = r.mul(r, p) + r, r2 := r.div(nat(nil), r, x.b.abs) + + // see if we need to round up + r2 = r2.add(r2, r2) + if x.b.abs.cmp(r2) <= 0 { + r = r.add(r, natOne) + if r.cmp(p) >= 0 { + q = nat(nil).add(q, natOne) + r = nat(nil).sub(r, p) + } + } + + if x.a.neg { + buf = append(buf, '-') + } + buf = append(buf, q.utoa(10)...) // itoa ignores sign if q == 0 + + if prec > 0 { + buf = append(buf, '.') + rs := r.utoa(10) + for i := prec - len(rs); i > 0; i-- { + buf = append(buf, '0') + } + buf = append(buf, rs...) + } + + return string(buf) +} diff --git a/src/math/big/ratconv_test.go b/src/math/big/ratconv_test.go new file mode 100644 index 0000000..45a3560 --- /dev/null +++ b/src/math/big/ratconv_test.go @@ -0,0 +1,626 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "fmt" + "io" + "math" + "reflect" + "strconv" + "strings" + "testing" +) + +var exponentTests = []struct { + s string // string to be scanned + base2ok bool // true if 'p'/'P' exponents are accepted + sepOk bool // true if '_' separators are accepted + x int64 // expected exponent + b int // expected exponent base + err error // expected error + next rune // next character (or 0, if at EOF) +}{ + // valid, without separators + {"", false, false, 0, 10, nil, 0}, + {"1", false, false, 0, 10, nil, '1'}, + {"e0", false, false, 0, 10, nil, 0}, + {"E1", false, false, 1, 10, nil, 0}, + {"e+10", false, false, 10, 10, nil, 0}, + {"e-10", false, false, -10, 10, nil, 0}, + {"e123456789a", false, false, 123456789, 10, nil, 'a'}, + {"p", false, false, 0, 10, nil, 'p'}, + {"P+100", false, false, 0, 10, nil, 'P'}, + {"p0", true, false, 0, 2, nil, 0}, + {"P-123", true, false, -123, 2, nil, 0}, + {"p+0a", true, false, 0, 2, nil, 'a'}, + {"p+123__", true, false, 123, 2, nil, '_'}, // '_' is not part of the number anymore + + // valid, with separators + {"e+1_0", false, true, 10, 10, nil, 0}, + {"e-1_0", false, true, -10, 10, nil, 0}, + {"e123_456_789a", false, true, 123456789, 10, nil, 'a'}, + {"P+1_00", false, true, 0, 10, nil, 'P'}, + {"p-1_2_3", true, true, -123, 2, nil, 0}, + + // invalid: no digits + {"e", false, false, 0, 10, errNoDigits, 0}, + {"ef", false, false, 0, 10, errNoDigits, 'f'}, + {"e+", false, false, 0, 10, errNoDigits, 0}, + {"E-x", false, false, 0, 10, errNoDigits, 'x'}, + {"p", true, false, 0, 2, errNoDigits, 0}, + {"P-", true, false, 0, 2, errNoDigits, 0}, + {"p+e", true, false, 0, 2, errNoDigits, 'e'}, + {"e+_x", false, true, 0, 10, errNoDigits, 'x'}, + + // invalid: incorrect use of separator + {"e0_", false, true, 0, 10, errInvalSep, 0}, + {"e_0", false, true, 0, 10, errInvalSep, 0}, + {"e-1_2__3", false, true, -123, 10, errInvalSep, 0}, +} + +func TestScanExponent(t *testing.T) { + for _, a := range exponentTests { + r := strings.NewReader(a.s) + x, b, err := scanExponent(r, a.base2ok, a.sepOk) + if err != a.err { + t.Errorf("scanExponent%+v\n\tgot error = %v; want %v", a, err, a.err) + } + if x != a.x { + t.Errorf("scanExponent%+v\n\tgot z = %v; want %v", a, x, a.x) + } + if b != a.b { + t.Errorf("scanExponent%+v\n\tgot b = %d; want %d", a, b, a.b) + } + next, _, err := r.ReadRune() + if err == io.EOF { + next = 0 + err = nil + } + if err == nil && next != a.next { + t.Errorf("scanExponent%+v\n\tgot next = %q; want %q", a, next, a.next) + } + } +} + +type StringTest struct { + in, out string + ok bool +} + +var setStringTests = []StringTest{ + // invalid + {in: "1e"}, + {in: "1.e"}, + {in: "1e+14e-5"}, + {in: "1e4.5"}, + {in: "r"}, + {in: "a/b"}, + {in: "a.b"}, + {in: "1/0"}, + {in: "4/3/2"}, // issue 17001 + {in: "4/3/"}, + {in: "4/3."}, + {in: "4/"}, + {in: "13e-9223372036854775808"}, // CVE-2022-23772 + + // valid + {"0", "0", true}, + {"-0", "0", true}, + {"1", "1", true}, + {"-1", "-1", true}, + {"1.", "1", true}, + {"1e0", "1", true}, + {"1.e1", "10", true}, + {"-0.1", "-1/10", true}, + {"-.1", "-1/10", true}, + {"2/4", "1/2", true}, + {".25", "1/4", true}, + {"-1/5", "-1/5", true}, + {"8129567.7690E14", "812956776900000000000", true}, + {"78189e+4", "781890000", true}, + {"553019.8935e+8", "55301989350000", true}, + {"98765432109876543210987654321e-10", "98765432109876543210987654321/10000000000", true}, + {"9877861857500000E-7", "3951144743/4", true}, + {"2169378.417e-3", "2169378417/1000000", true}, + {"884243222337379604041632732738665534", "884243222337379604041632732738665534", true}, + {"53/70893980658822810696", "53/70893980658822810696", true}, + {"106/141787961317645621392", "53/70893980658822810696", true}, + {"204211327800791583.81095", "4084226556015831676219/20000", true}, + {"0e9999999999", "0", true}, // issue #16176 +} + +// These are not supported by fmt.Fscanf. +var setStringTests2 = []StringTest{ + // invalid + {in: "4/3x"}, + {in: "0/-1"}, + {in: "-1/-1"}, + + // invalid with separators + // (smoke tests only - a comprehensive set of tests is in natconv_test.go) + {in: "10_/1"}, + {in: "_10/1"}, + {in: "1/1__0"}, + + // valid + {"0b1000/3", "8/3", true}, + {"0B1000/0x8", "1", true}, + {"-010/1", "-8", true}, // 0-prefix indicates octal in this case + {"-010.0", "-10", true}, + {"-0o10/1", "-8", true}, + {"0x10/1", "16", true}, + {"0x10/0x20", "1/2", true}, + + {"0010", "10", true}, // 0-prefix is ignored in this case (not a fraction) + {"0x10.0", "16", true}, + {"0x1.8", "3/2", true}, + {"0X1.8p4", "24", true}, + {"0x1.1E2", "2289/2048", true}, // E is part of hex mantissa, not exponent + {"0b1.1E2", "150", true}, + {"0B1.1P3", "12", true}, + {"0o10e-2", "2/25", true}, + {"0O10p-3", "1", true}, + + // valid with separators + // (smoke tests only - a comprehensive set of tests is in natconv_test.go) + {"0b_1000/3", "8/3", true}, + {"0B_10_00/0x8", "1", true}, + {"0xdead/0B1101_1110_1010_1101", "1", true}, + {"0B1101_1110_1010_1101/0XD_E_A_D", "1", true}, + {"1_000.0", "1000", true}, + + {"0x_10.0", "16", true}, + {"0x1_0.0", "16", true}, + {"0x1.8_0", "3/2", true}, + {"0X1.8p0_4", "24", true}, + {"0b1.1_0E2", "150", true}, + {"0o1_0e-2", "2/25", true}, + {"0O_10p-3", "1", true}, +} + +func TestRatSetString(t *testing.T) { + var tests []StringTest + tests = append(tests, setStringTests...) + tests = append(tests, setStringTests2...) + + for i, test := range tests { + x, ok := new(Rat).SetString(test.in) + + if ok { + if !test.ok { + t.Errorf("#%d SetString(%q) expected failure", i, test.in) + } else if x.RatString() != test.out { + t.Errorf("#%d SetString(%q) got %s want %s", i, test.in, x.RatString(), test.out) + } + } else { + if test.ok { + t.Errorf("#%d SetString(%q) expected success", i, test.in) + } else if x != nil { + t.Errorf("#%d SetString(%q) got %p want nil", i, test.in, x) + } + } + } +} + +func TestRatSetStringZero(t *testing.T) { + got, _ := new(Rat).SetString("0") + want := new(Rat).SetInt64(0) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %#+v, want %#+v", got, want) + } +} + +func TestRatScan(t *testing.T) { + var buf bytes.Buffer + for i, test := range setStringTests { + x := new(Rat) + buf.Reset() + buf.WriteString(test.in) + + _, err := fmt.Fscanf(&buf, "%v", x) + if err == nil != test.ok { + if test.ok { + t.Errorf("#%d (%s) error: %s", i, test.in, err) + } else { + t.Errorf("#%d (%s) expected error", i, test.in) + } + continue + } + if err == nil && x.RatString() != test.out { + t.Errorf("#%d got %s want %s", i, x.RatString(), test.out) + } + } +} + +var floatStringTests = []struct { + in string + prec int + out string +}{ + {"0", 0, "0"}, + {"0", 4, "0.0000"}, + {"1", 0, "1"}, + {"1", 2, "1.00"}, + {"-1", 0, "-1"}, + {"0.05", 1, "0.1"}, + {"-0.05", 1, "-0.1"}, + {".25", 2, "0.25"}, + {".25", 1, "0.3"}, + {".25", 3, "0.250"}, + {"-1/3", 3, "-0.333"}, + {"-2/3", 4, "-0.6667"}, + {"0.96", 1, "1.0"}, + {"0.999", 2, "1.00"}, + {"0.9", 0, "1"}, + {".25", -1, "0"}, + {".55", -1, "1"}, +} + +func TestFloatString(t *testing.T) { + for i, test := range floatStringTests { + x, _ := new(Rat).SetString(test.in) + + if x.FloatString(test.prec) != test.out { + t.Errorf("#%d got %s want %s", i, x.FloatString(test.prec), test.out) + } + } +} + +// Test inputs to Rat.SetString. The prefix "long:" causes the test +// to be skipped except in -long mode. (The threshold is about 500us.) +var float64inputs = []string{ + // Constants plundered from strconv/testfp.txt. + + // Table 1: Stress Inputs for Conversion to 53-bit Binary, < 1/2 ULP + "5e+125", + "69e+267", + "999e-026", + "7861e-034", + "75569e-254", + "928609e-261", + "9210917e+080", + "84863171e+114", + "653777767e+273", + "5232604057e-298", + "27235667517e-109", + "653532977297e-123", + "3142213164987e-294", + "46202199371337e-072", + "231010996856685e-073", + "9324754620109615e+212", + "78459735791271921e+049", + "272104041512242479e+200", + "6802601037806061975e+198", + "20505426358836677347e-221", + "836168422905420598437e-234", + "4891559871276714924261e+222", + + // Table 2: Stress Inputs for Conversion to 53-bit Binary, > 1/2 ULP + "9e-265", + "85e-037", + "623e+100", + "3571e+263", + "81661e+153", + "920657e-023", + "4603285e-024", + "87575437e-309", + "245540327e+122", + "6138508175e+120", + "83356057653e+193", + "619534293513e+124", + "2335141086879e+218", + "36167929443327e-159", + "609610927149051e-255", + "3743626360493413e-165", + "94080055902682397e-242", + "899810892172646163e+283", + "7120190517612959703e+120", + "25188282901709339043e-252", + "308984926168550152811e-052", + "6372891218502368041059e+064", + + // Table 14: Stress Inputs for Conversion to 24-bit Binary, <1/2 ULP + "5e-20", + "67e+14", + "985e+15", + "7693e-42", + "55895e-16", + "996622e-44", + "7038531e-32", + "60419369e-46", + "702990899e-20", + "6930161142e-48", + "25933168707e+13", + "596428896559e+20", + + // Table 15: Stress Inputs for Conversion to 24-bit Binary, >1/2 ULP + "3e-23", + "57e+18", + "789e-35", + "2539e-18", + "76173e+28", + "887745e-11", + "5382571e-37", + "82381273e-35", + "750486563e-38", + "3752432815e-39", + "75224575729e-45", + "459926601011e+15", + + // Constants plundered from strconv/atof_test.go. + + "0", + "1", + "+1", + "1e23", + "1E23", + "100000000000000000000000", + "1e-100", + "123456700", + "99999999999999974834176", + "100000000000000000000001", + "100000000000000008388608", + "100000000000000016777215", + "100000000000000016777216", + "-1", + "-0.1", + "-0", // NB: exception made for this input + "1e-20", + "625e-3", + + // largest float64 + "1.7976931348623157e308", + "-1.7976931348623157e308", + // next float64 - too large + "1.7976931348623159e308", + "-1.7976931348623159e308", + // the border is ...158079 + // borderline - okay + "1.7976931348623158e308", + "-1.7976931348623158e308", + // borderline - too large + "1.797693134862315808e308", + "-1.797693134862315808e308", + + // a little too large + "1e308", + "2e308", + "1e309", + + // way too large + "1e310", + "-1e310", + "1e400", + "-1e400", + "long:1e400000", + "long:-1e400000", + + // denormalized + "1e-305", + "1e-306", + "1e-307", + "1e-308", + "1e-309", + "1e-310", + "1e-322", + // smallest denormal + "5e-324", + "4e-324", + "3e-324", + // too small + "2e-324", + // way too small + "1e-350", + "long:1e-400000", + // way too small, negative + "-1e-350", + "long:-1e-400000", + + // try to overflow exponent + // [Disabled: too slow and memory-hungry with rationals.] + // "1e-4294967296", + // "1e+4294967296", + // "1e-18446744073709551616", + // "1e+18446744073709551616", + + // https://www.exploringbinary.com/java-hangs-when-converting-2-2250738585072012e-308/ + "2.2250738585072012e-308", + // https://www.exploringbinary.com/php-hangs-on-numeric-value-2-2250738585072011e-308/ + "2.2250738585072011e-308", + + // A very large number (initially wrongly parsed by the fast algorithm). + "4.630813248087435e+307", + + // A different kind of very large number. + "22.222222222222222", + "long:2." + strings.Repeat("2", 4000) + "e+1", + + // Exactly halfway between 1 and math.Nextafter(1, 2). + // Round to even (down). + "1.00000000000000011102230246251565404236316680908203125", + // Slightly lower; still round down. + "1.00000000000000011102230246251565404236316680908203124", + // Slightly higher; round up. + "1.00000000000000011102230246251565404236316680908203126", + // Slightly higher, but you have to read all the way to the end. + "long:1.00000000000000011102230246251565404236316680908203125" + strings.Repeat("0", 10000) + "1", + + // Smallest denormal, 2^(-1022-52) + "4.940656458412465441765687928682213723651e-324", + // Half of smallest denormal, 2^(-1022-53) + "2.470328229206232720882843964341106861825e-324", + // A little more than the exact half of smallest denormal + // 2^-1075 + 2^-1100. (Rounds to 1p-1074.) + "2.470328302827751011111470718709768633275e-324", + // The exact halfway between smallest normal and largest denormal: + // 2^-1022 - 2^-1075. (Rounds to 2^-1022.) + "2.225073858507201136057409796709131975935e-308", + + "1152921504606846975", // 1<<60 - 1 + "-1152921504606846975", // -(1<<60 - 1) + "1152921504606846977", // 1<<60 + 1 + "-1152921504606846977", // -(1<<60 + 1) + + "1/3", +} + +// isFinite reports whether f represents a finite rational value. +// It is equivalent to !math.IsNan(f) && !math.IsInf(f, 0). +func isFinite(f float64) bool { + return math.Abs(f) <= math.MaxFloat64 +} + +func TestFloat32SpecialCases(t *testing.T) { + for _, input := range float64inputs { + if strings.HasPrefix(input, "long:") { + if !*long { + continue + } + input = input[len("long:"):] + } + + r, ok := new(Rat).SetString(input) + if !ok { + t.Errorf("Rat.SetString(%q) failed", input) + continue + } + f, exact := r.Float32() + + // 1. Check string -> Rat -> float32 conversions are + // consistent with strconv.ParseFloat. + // Skip this check if the input uses "a/b" rational syntax. + if !strings.Contains(input, "/") { + e64, _ := strconv.ParseFloat(input, 32) + e := float32(e64) + + // Careful: negative Rats too small for + // float64 become -0, but Rat obviously cannot + // preserve the sign from SetString("-0"). + switch { + case math.Float32bits(e) == math.Float32bits(f): + // Ok: bitwise equal. + case f == 0 && r.Num().BitLen() == 0: + // Ok: Rat(0) is equivalent to both +/- float64(0). + default: + t.Errorf("strconv.ParseFloat(%q) = %g (%b), want %g (%b); delta = %g", input, e, e, f, f, f-e) + } + } + + if !isFinite(float64(f)) { + continue + } + + // 2. Check f is best approximation to r. + if !checkIsBestApprox32(t, f, r) { + // Append context information. + t.Errorf("(input was %q)", input) + } + + // 3. Check f->R->f roundtrip is non-lossy. + checkNonLossyRoundtrip32(t, f) + + // 4. Check exactness using slow algorithm. + if wasExact := new(Rat).SetFloat64(float64(f)).Cmp(r) == 0; wasExact != exact { + t.Errorf("Rat.SetString(%q).Float32().exact = %t, want %t", input, exact, wasExact) + } + } +} + +func TestFloat64SpecialCases(t *testing.T) { + for _, input := range float64inputs { + if strings.HasPrefix(input, "long:") { + if !*long { + continue + } + input = input[len("long:"):] + } + + r, ok := new(Rat).SetString(input) + if !ok { + t.Errorf("Rat.SetString(%q) failed", input) + continue + } + f, exact := r.Float64() + + // 1. Check string -> Rat -> float64 conversions are + // consistent with strconv.ParseFloat. + // Skip this check if the input uses "a/b" rational syntax. + if !strings.Contains(input, "/") { + e, _ := strconv.ParseFloat(input, 64) + + // Careful: negative Rats too small for + // float64 become -0, but Rat obviously cannot + // preserve the sign from SetString("-0"). + switch { + case math.Float64bits(e) == math.Float64bits(f): + // Ok: bitwise equal. + case f == 0 && r.Num().BitLen() == 0: + // Ok: Rat(0) is equivalent to both +/- float64(0). + default: + t.Errorf("strconv.ParseFloat(%q) = %g (%b), want %g (%b); delta = %g", input, e, e, f, f, f-e) + } + } + + if !isFinite(f) { + continue + } + + // 2. Check f is best approximation to r. + if !checkIsBestApprox64(t, f, r) { + // Append context information. + t.Errorf("(input was %q)", input) + } + + // 3. Check f->R->f roundtrip is non-lossy. + checkNonLossyRoundtrip64(t, f) + + // 4. Check exactness using slow algorithm. + if wasExact := new(Rat).SetFloat64(f).Cmp(r) == 0; wasExact != exact { + t.Errorf("Rat.SetString(%q).Float64().exact = %t, want %t", input, exact, wasExact) + } + } +} + +func TestIssue31184(t *testing.T) { + var x Rat + for _, want := range []string{ + "-213.090", + "8.192", + "16.000", + } { + x.SetString(want) + got := x.FloatString(3) + if got != want { + t.Errorf("got %s, want %s", got, want) + } + } +} + +func TestIssue45910(t *testing.T) { + var x Rat + for _, test := range []struct { + input string + want bool + }{ + {"1e-1000001", false}, + {"1e-1000000", true}, + {"1e+1000000", true}, + {"1e+1000001", false}, + + {"0p1000000000000", true}, + {"1p-10000001", false}, + {"1p-10000000", true}, + {"1p+10000000", true}, + {"1p+10000001", false}, + {"1.770p02041010010011001001", false}, // test case from issue + } { + _, got := x.SetString(test.input) + if got != test.want { + t.Errorf("SetString(%s) got ok = %v; want %v", test.input, got, test.want) + } + } +} diff --git a/src/math/big/ratmarsh.go b/src/math/big/ratmarsh.go new file mode 100644 index 0000000..b69c59d --- /dev/null +++ b/src/math/big/ratmarsh.go @@ -0,0 +1,86 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements encoding/decoding of Rats. + +package big + +import ( + "encoding/binary" + "errors" + "fmt" + "math" +) + +// Gob codec version. Permits backward-compatible changes to the encoding. +const ratGobVersion byte = 1 + +// GobEncode implements the gob.GobEncoder interface. +func (x *Rat) GobEncode() ([]byte, error) { + if x == nil { + return nil, nil + } + buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b.abs))*_S) // extra bytes for version and sign bit (1), and numerator length (4) + i := x.b.abs.bytes(buf) + j := x.a.abs.bytes(buf[:i]) + n := i - j + if int(uint32(n)) != n { + // this should never happen + return nil, errors.New("Rat.GobEncode: numerator too large") + } + binary.BigEndian.PutUint32(buf[j-4:j], uint32(n)) + j -= 1 + 4 + b := ratGobVersion << 1 // make space for sign bit + if x.a.neg { + b |= 1 + } + buf[j] = b + return buf[j:], nil +} + +// GobDecode implements the gob.GobDecoder interface. +func (z *Rat) GobDecode(buf []byte) error { + if len(buf) == 0 { + // Other side sent a nil or default value. + *z = Rat{} + return nil + } + if len(buf) < 5 { + return errors.New("Rat.GobDecode: buffer too small") + } + b := buf[0] + if b>>1 != ratGobVersion { + return fmt.Errorf("Rat.GobDecode: encoding version %d not supported", b>>1) + } + const j = 1 + 4 + ln := binary.BigEndian.Uint32(buf[j-4 : j]) + if uint64(ln) > math.MaxInt-j { + return errors.New("Rat.GobDecode: invalid length") + } + i := j + int(ln) + if len(buf) < i { + return errors.New("Rat.GobDecode: buffer too small") + } + z.a.neg = b&1 != 0 + z.a.abs = z.a.abs.setBytes(buf[j:i]) + z.b.abs = z.b.abs.setBytes(buf[i:]) + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (x *Rat) MarshalText() (text []byte, err error) { + if x.IsInt() { + return x.a.MarshalText() + } + return x.marshal(), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +func (z *Rat) UnmarshalText(text []byte) error { + // TODO(gri): get rid of the []byte/string conversion + if _, ok := z.SetString(string(text)); !ok { + return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Rat", text) + } + return nil +} diff --git a/src/math/big/ratmarsh_test.go b/src/math/big/ratmarsh_test.go new file mode 100644 index 0000000..15c933e --- /dev/null +++ b/src/math/big/ratmarsh_test.go @@ -0,0 +1,138 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "encoding/xml" + "testing" +) + +func TestRatGobEncoding(t *testing.T) { + var medium bytes.Buffer + enc := gob.NewEncoder(&medium) + dec := gob.NewDecoder(&medium) + for _, test := range encodingTests { + medium.Reset() // empty buffer for each test case (in case of failures) + var tx Rat + tx.SetString(test + ".14159265") + if err := enc.Encode(&tx); err != nil { + t.Errorf("encoding of %s failed: %s", &tx, err) + continue + } + var rx Rat + if err := dec.Decode(&rx); err != nil { + t.Errorf("decoding of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("transmission of %s failed: got %s want %s", &tx, &rx, &tx) + } + } +} + +// Sending a nil Rat pointer (inside a slice) on a round trip through gob should yield a zero. +// TODO: top-level nils. +func TestGobEncodingNilRatInSlice(t *testing.T) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + dec := gob.NewDecoder(buf) + + var in = make([]*Rat, 1) + err := enc.Encode(&in) + if err != nil { + t.Errorf("gob encode failed: %q", err) + } + var out []*Rat + err = dec.Decode(&out) + if err != nil { + t.Fatalf("gob decode failed: %q", err) + } + if len(out) != 1 { + t.Fatalf("wrong len; want 1 got %d", len(out)) + } + var zero Rat + if out[0].Cmp(&zero) != 0 { + t.Fatalf("transmission of (*Int)(nil) failed: got %s want 0", out) + } +} + +var ratNums = []string{ + "-141592653589793238462643383279502884197169399375105820974944592307816406286", + "-1415926535897932384626433832795028841971", + "-141592653589793", + "-1", + "0", + "1", + "141592653589793", + "1415926535897932384626433832795028841971", + "141592653589793238462643383279502884197169399375105820974944592307816406286", +} + +var ratDenoms = []string{ + "1", + "718281828459045", + "7182818284590452353602874713526624977572", + "718281828459045235360287471352662497757247093699959574966967627724076630353", +} + +func TestRatJSONEncoding(t *testing.T) { + for _, num := range ratNums { + for _, denom := range ratDenoms { + var tx Rat + tx.SetString(num + "/" + denom) + b, err := json.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %s failed: %s", &tx, err) + continue + } + var rx Rat + if err := json.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("JSON encoding of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} + +func TestRatXMLEncoding(t *testing.T) { + for _, num := range ratNums { + for _, denom := range ratDenoms { + var tx Rat + tx.SetString(num + "/" + denom) + b, err := xml.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %s failed: %s", &tx, err) + continue + } + var rx Rat + if err := xml.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("XML encoding of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} + +func TestRatGobDecodeShortBuffer(t *testing.T) { + for _, tc := range [][]byte{ + []byte{0x2}, + []byte{0x2, 0x0, 0x0, 0x0, 0xff}, + []byte{0x2, 0xff, 0xff, 0xff, 0xff}, + } { + err := NewRat(1, 2).GobDecode(tc) + if err == nil { + t.Error("expected GobDecode to return error for malformed input") + } + } +} diff --git a/src/math/big/roundingmode_string.go b/src/math/big/roundingmode_string.go new file mode 100644 index 0000000..c7629eb --- /dev/null +++ b/src/math/big/roundingmode_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type=RoundingMode"; DO NOT EDIT. + +package big + +import "strconv" + +const _RoundingMode_name = "ToNearestEvenToNearestAwayToZeroAwayFromZeroToNegativeInfToPositiveInf" + +var _RoundingMode_index = [...]uint8{0, 13, 26, 32, 44, 57, 70} + +func (i RoundingMode) String() string { + if i >= RoundingMode(len(_RoundingMode_index)-1) { + return "RoundingMode(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _RoundingMode_name[_RoundingMode_index[i]:_RoundingMode_index[i+1]] +} diff --git a/src/math/big/sqrt.go b/src/math/big/sqrt.go new file mode 100644 index 0000000..b4b0374 --- /dev/null +++ b/src/math/big/sqrt.go @@ -0,0 +1,130 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "math" + "sync" +) + +var threeOnce struct { + sync.Once + v *Float +} + +func three() *Float { + threeOnce.Do(func() { + threeOnce.v = NewFloat(3.0) + }) + return threeOnce.v +} + +// Sqrt sets z to the rounded square root of x, and returns it. +// +// If z's precision is 0, it is changed to x's precision before the +// operation. Rounding is performed according to z's precision and +// rounding mode, but z's accuracy is not computed. Specifically, the +// result of z.Acc() is undefined. +// +// The function panics if z < 0. The value of z is undefined in that +// case. +func (z *Float) Sqrt(x *Float) *Float { + if debugFloat { + x.validate() + } + + if z.prec == 0 { + z.prec = x.prec + } + + if x.Sign() == -1 { + // following IEEE754-2008 (section 7.2) + panic(ErrNaN{"square root of negative operand"}) + } + + // handle ±0 and +∞ + if x.form != finite { + z.acc = Exact + z.form = x.form + z.neg = x.neg // IEEE754-2008 requires √±0 = ±0 + return z + } + + // MantExp sets the argument's precision to the receiver's, and + // when z.prec > x.prec this will lower z.prec. Restore it after + // the MantExp call. + prec := z.prec + b := x.MantExp(z) + z.prec = prec + + // Compute √(z·2**b) as + // √( z)·2**(½b) if b is even + // √(2z)·2**(⌊½b⌋) if b > 0 is odd + // √(½z)·2**(⌈½b⌉) if b < 0 is odd + switch b % 2 { + case 0: + // nothing to do + case 1: + z.exp++ + case -1: + z.exp-- + } + // 0.25 <= z < 2.0 + + // Solving 1/x² - z = 0 avoids Quo calls and is faster, especially + // for high precisions. + z.sqrtInverse(z) + + // re-attach halved exponent + return z.SetMantExp(z, b/2) +} + +// Compute √x (to z.prec precision) by solving +// +// 1/t² - x = 0 +// +// for t (using Newton's method), and then inverting. +func (z *Float) sqrtInverse(x *Float) { + // let + // f(t) = 1/t² - x + // then + // g(t) = f(t)/f'(t) = -½t(1 - xt²) + // and the next guess is given by + // t2 = t - g(t) = ½t(3 - xt²) + u := newFloat(z.prec) + v := newFloat(z.prec) + three := three() + ng := func(t *Float) *Float { + u.prec = t.prec + v.prec = t.prec + u.Mul(t, t) // u = t² + u.Mul(x, u) // = xt² + v.Sub(three, u) // v = 3 - xt² + u.Mul(t, v) // u = t(3 - xt²) + u.exp-- // = ½t(3 - xt²) + return t.Set(u) + } + + xf, _ := x.Float64() + sqi := newFloat(z.prec) + sqi.SetFloat64(1 / math.Sqrt(xf)) + for prec := z.prec + 32; sqi.prec < prec; { + sqi.prec *= 2 + sqi = ng(sqi) + } + // sqi = 1/√x + + // x/√x = √x + z.Mul(x, sqi) +} + +// newFloat returns a new *Float with space for twice the given +// precision. +func newFloat(prec2 uint32) *Float { + z := new(Float) + // nat.make ensures the slice length is > 0 + z.mant = z.mant.make(int(prec2/_W) * 2) + return z +} diff --git a/src/math/big/sqrt_test.go b/src/math/big/sqrt_test.go new file mode 100644 index 0000000..d314711 --- /dev/null +++ b/src/math/big/sqrt_test.go @@ -0,0 +1,126 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import ( + "fmt" + "math" + "math/rand" + "testing" +) + +// TestFloatSqrt64 tests that Float.Sqrt of numbers with 53bit mantissa +// behaves like float math.Sqrt. +func TestFloatSqrt64(t *testing.T) { + for i := 0; i < 1e5; i++ { + if i == 1e2 && testing.Short() { + break + } + r := rand.Float64() + + got := new(Float).SetPrec(53) + got.Sqrt(NewFloat(r)) + want := NewFloat(math.Sqrt(r)) + if got.Cmp(want) != 0 { + t.Fatalf("Sqrt(%g) =\n got %g;\nwant %g", r, got, want) + } + } +} + +func TestFloatSqrt(t *testing.T) { + for _, test := range []struct { + x string + want string + }{ + // Test values were generated on Wolfram Alpha using query + // 'sqrt(N) to 350 digits' + // 350 decimal digits give up to 1000 binary digits. + {"0.03125", "0.17677669529663688110021109052621225982120898442211850914708496724884155980776337985629844179095519659187673077886403712811560450698134215158051518713749197892665283324093819909447499381264409775757143376369499645074628431682460775184106467733011114982619404115381053858929018135497032545349940642599871090667456829147610370507757690729404938184321879"}, + {"0.125", "0.35355339059327376220042218105242451964241796884423701829416993449768311961552675971259688358191039318375346155772807425623120901396268430316103037427498395785330566648187639818894998762528819551514286752738999290149256863364921550368212935466022229965238808230762107717858036270994065090699881285199742181334913658295220741015515381458809876368643757"}, + {"0.5", "0.70710678118654752440084436210484903928483593768847403658833986899536623923105351942519376716382078636750692311545614851246241802792536860632206074854996791570661133296375279637789997525057639103028573505477998580298513726729843100736425870932044459930477616461524215435716072541988130181399762570399484362669827316590441482031030762917619752737287514"}, + {"2.0", "1.4142135623730950488016887242096980785696718753769480731766797379907324784621070388503875343276415727350138462309122970249248360558507372126441214970999358314132226659275055927557999505011527820605714701095599716059702745345968620147285174186408891986095523292304843087143214508397626036279952514079896872533965463318088296406206152583523950547457503"}, + {"3.0", "1.7320508075688772935274463415058723669428052538103806280558069794519330169088000370811461867572485756756261414154067030299699450949989524788116555120943736485280932319023055820679748201010846749232650153123432669033228866506722546689218379712270471316603678615880190499865373798593894676503475065760507566183481296061009476021871903250831458295239598"}, + {"4.0", "2.0"}, + + {"1p512", "1p256"}, + {"4p1024", "2p512"}, + {"9p2048", "3p1024"}, + + {"1p-1024", "1p-512"}, + {"4p-2048", "2p-1024"}, + {"9p-4096", "3p-2048"}, + } { + for _, prec := range []uint{24, 53, 64, 65, 100, 128, 129, 200, 256, 400, 600, 800, 1000} { + x := new(Float).SetPrec(prec) + x.Parse(test.x, 10) + + got := new(Float).SetPrec(prec).Sqrt(x) + want := new(Float).SetPrec(prec) + want.Parse(test.want, 10) + if got.Cmp(want) != 0 { + t.Errorf("prec = %d, Sqrt(%v) =\ngot %g;\nwant %g", + prec, test.x, got, want) + } + + // Square test. + // If got holds the square root of x to precision p, then + // got = √x + k + // for some k such that |k| < 2**(-p). Thus, + // got² = (√x + k)² = x + 2k√n + k² + // and the error must satisfy + // err = |got² - x| ≈ | 2k√n | < 2**(-p+1)*√n + // Ignoring the k² term for simplicity. + + // err = |got² - x| + // (but do intermediate steps with 32 guard digits to + // avoid introducing spurious rounding-related errors) + sq := new(Float).SetPrec(prec+32).Mul(got, got) + diff := new(Float).Sub(sq, x) + err := diff.Abs(diff).SetPrec(prec) + + // maxErr = 2**(-p+1)*√x + one := new(Float).SetPrec(prec).SetInt64(1) + maxErr := new(Float).Mul(new(Float).SetMantExp(one, -int(prec)+1), got) + + if err.Cmp(maxErr) >= 0 { + t.Errorf("prec = %d, Sqrt(%v) =\ngot err %g;\nwant maxErr %g", + prec, test.x, err, maxErr) + } + } + } +} + +func TestFloatSqrtSpecial(t *testing.T) { + for _, test := range []struct { + x *Float + want *Float + }{ + {NewFloat(+0), NewFloat(+0)}, + {NewFloat(-0), NewFloat(-0)}, + {NewFloat(math.Inf(+1)), NewFloat(math.Inf(+1))}, + } { + got := new(Float).Sqrt(test.x) + if got.neg != test.want.neg || got.form != test.want.form { + t.Errorf("Sqrt(%v) = %v (neg: %v); want %v (neg: %v)", + test.x, got, got.neg, test.want, test.want.neg) + } + } + +} + +// Benchmarks + +func BenchmarkFloatSqrt(b *testing.B) { + for _, prec := range []uint{64, 128, 256, 1e3, 1e4, 1e5, 1e6} { + x := NewFloat(2) + z := new(Float).SetPrec(prec) + b.Run(fmt.Sprintf("%v", prec), func(b *testing.B) { + b.ReportAllocs() + for n := 0; n < b.N; n++ { + z.Sqrt(x) + } + }) + } +} |