summaryrefslogtreecommitdiffstats
path: root/src/math/rand/rand_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/math/rand/rand_test.go')
-rw-r--r--src/math/rand/rand_test.go681
1 files changed, 681 insertions, 0 deletions
diff --git a/src/math/rand/rand_test.go b/src/math/rand/rand_test.go
new file mode 100644
index 0000000..e037aae
--- /dev/null
+++ b/src/math/rand/rand_test.go
@@ -0,0 +1,681 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rand
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "math"
+ "os"
+ "runtime"
+ "testing"
+ "testing/iotest"
+)
+
+const (
+ numTestSamples = 10000
+)
+
+type statsResults struct {
+ mean float64
+ stddev float64
+ closeEnough float64
+ maxError float64
+}
+
+func max(a, b float64) float64 {
+ if a > b {
+ return a
+ }
+ return b
+}
+
+func nearEqual(a, b, closeEnough, maxError float64) bool {
+ absDiff := math.Abs(a - b)
+ if absDiff < closeEnough { // Necessary when one value is zero and one value is close to zero.
+ return true
+ }
+ return absDiff/max(math.Abs(a), math.Abs(b)) < maxError
+}
+
+var testSeeds = []int64{1, 1754801282, 1698661970, 1550503961}
+
+// checkSimilarDistribution returns success if the mean and stddev of the
+// two statsResults are similar.
+func (this *statsResults) checkSimilarDistribution(expected *statsResults) error {
+ if !nearEqual(this.mean, expected.mean, expected.closeEnough, expected.maxError) {
+ s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", this.mean, expected.mean, expected.closeEnough, expected.maxError)
+ fmt.Println(s)
+ return errors.New(s)
+ }
+ if !nearEqual(this.stddev, expected.stddev, expected.closeEnough, expected.maxError) {
+ s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", this.stddev, expected.stddev, expected.closeEnough, expected.maxError)
+ fmt.Println(s)
+ return errors.New(s)
+ }
+ return nil
+}
+
+func getStatsResults(samples []float64) *statsResults {
+ res := new(statsResults)
+ var sum, squaresum float64
+ for _, s := range samples {
+ sum += s
+ squaresum += s * s
+ }
+ res.mean = sum / float64(len(samples))
+ res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean)
+ return res
+}
+
+func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) {
+ t.Helper()
+ actual := getStatsResults(samples)
+ err := actual.checkSimilarDistribution(expected)
+ if err != nil {
+ t.Errorf(err.Error())
+ }
+}
+
+func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) {
+ t.Helper()
+ chunk := len(samples) / nslices
+ for i := 0; i < nslices; i++ {
+ low := i * chunk
+ var high int
+ if i == nslices-1 {
+ high = len(samples) - 1
+ } else {
+ high = (i + 1) * chunk
+ }
+ checkSampleDistribution(t, samples[low:high], expected)
+ }
+}
+
+//
+// Normal distribution tests
+//
+
+func generateNormalSamples(nsamples int, mean, stddev float64, seed int64) []float64 {
+ r := New(NewSource(seed))
+ samples := make([]float64, nsamples)
+ for i := range samples {
+ samples[i] = r.NormFloat64()*stddev + mean
+ }
+ return samples
+}
+
+func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed int64) {
+ //fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed);
+
+ samples := generateNormalSamples(nsamples, mean, stddev, seed)
+ errorScale := max(1.0, stddev) // Error scales with stddev
+ expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale}
+
+ // Make sure that the entire set matches the expected distribution.
+ checkSampleDistribution(t, samples, expected)
+
+ // Make sure that each half of the set matches the expected distribution.
+ checkSampleSliceDistributions(t, samples, 2, expected)
+
+ // Make sure that each 7th of the set matches the expected distribution.
+ checkSampleSliceDistributions(t, samples, 7, expected)
+}
+
+// Actual tests
+
+func TestStandardNormalValues(t *testing.T) {
+ for _, seed := range testSeeds {
+ testNormalDistribution(t, numTestSamples, 0, 1, seed)
+ }
+}
+
+func TestNonStandardNormalValues(t *testing.T) {
+ sdmax := 1000.0
+ mmax := 1000.0
+ if testing.Short() {
+ sdmax = 5
+ mmax = 5
+ }
+ for sd := 0.5; sd < sdmax; sd *= 2 {
+ for m := 0.5; m < mmax; m *= 2 {
+ for _, seed := range testSeeds {
+ testNormalDistribution(t, numTestSamples, m, sd, seed)
+ if testing.Short() {
+ break
+ }
+ }
+ }
+ }
+}
+
+//
+// Exponential distribution tests
+//
+
+func generateExponentialSamples(nsamples int, rate float64, seed int64) []float64 {
+ r := New(NewSource(seed))
+ samples := make([]float64, nsamples)
+ for i := range samples {
+ samples[i] = r.ExpFloat64() / rate
+ }
+ return samples
+}
+
+func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed int64) {
+ //fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed);
+
+ mean := 1 / rate
+ stddev := mean
+
+ samples := generateExponentialSamples(nsamples, rate, seed)
+ errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate
+ expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale}
+
+ // Make sure that the entire set matches the expected distribution.
+ checkSampleDistribution(t, samples, expected)
+
+ // Make sure that each half of the set matches the expected distribution.
+ checkSampleSliceDistributions(t, samples, 2, expected)
+
+ // Make sure that each 7th of the set matches the expected distribution.
+ checkSampleSliceDistributions(t, samples, 7, expected)
+}
+
+// Actual tests
+
+func TestStandardExponentialValues(t *testing.T) {
+ for _, seed := range testSeeds {
+ testExponentialDistribution(t, numTestSamples, 1, seed)
+ }
+}
+
+func TestNonStandardExponentialValues(t *testing.T) {
+ for rate := 0.05; rate < 10; rate *= 2 {
+ for _, seed := range testSeeds {
+ testExponentialDistribution(t, numTestSamples, rate, seed)
+ if testing.Short() {
+ break
+ }
+ }
+ }
+}
+
+//
+// Table generation tests
+//
+
+func initNorm() (testKn []uint32, testWn, testFn []float32) {
+ const m1 = 1 << 31
+ var (
+ dn float64 = rn
+ tn = dn
+ vn float64 = 9.91256303526217e-3
+ )
+
+ testKn = make([]uint32, 128)
+ testWn = make([]float32, 128)
+ testFn = make([]float32, 128)
+
+ q := vn / math.Exp(-0.5*dn*dn)
+ testKn[0] = uint32((dn / q) * m1)
+ testKn[1] = 0
+ testWn[0] = float32(q / m1)
+ testWn[127] = float32(dn / m1)
+ testFn[0] = 1.0
+ testFn[127] = float32(math.Exp(-0.5 * dn * dn))
+ for i := 126; i >= 1; i-- {
+ dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn)))
+ testKn[i+1] = uint32((dn / tn) * m1)
+ tn = dn
+ testFn[i] = float32(math.Exp(-0.5 * dn * dn))
+ testWn[i] = float32(dn / m1)
+ }
+ return
+}
+
+func initExp() (testKe []uint32, testWe, testFe []float32) {
+ const m2 = 1 << 32
+ var (
+ de float64 = re
+ te = de
+ ve float64 = 3.9496598225815571993e-3
+ )
+
+ testKe = make([]uint32, 256)
+ testWe = make([]float32, 256)
+ testFe = make([]float32, 256)
+
+ q := ve / math.Exp(-de)
+ testKe[0] = uint32((de / q) * m2)
+ testKe[1] = 0
+ testWe[0] = float32(q / m2)
+ testWe[255] = float32(de / m2)
+ testFe[0] = 1.0
+ testFe[255] = float32(math.Exp(-de))
+ for i := 254; i >= 1; i-- {
+ de = -math.Log(ve/de + math.Exp(-de))
+ testKe[i+1] = uint32((de / te) * m2)
+ te = de
+ testFe[i] = float32(math.Exp(-de))
+ testWe[i] = float32(de / m2)
+ }
+ return
+}
+
+// compareUint32Slices returns the first index where the two slices
+// disagree, or <0 if the lengths are the same and all elements
+// are identical.
+func compareUint32Slices(s1, s2 []uint32) int {
+ if len(s1) != len(s2) {
+ if len(s1) > len(s2) {
+ return len(s2) + 1
+ }
+ return len(s1) + 1
+ }
+ for i := range s1 {
+ if s1[i] != s2[i] {
+ return i
+ }
+ }
+ return -1
+}
+
+// compareFloat32Slices returns the first index where the two slices
+// disagree, or <0 if the lengths are the same and all elements
+// are identical.
+func compareFloat32Slices(s1, s2 []float32) int {
+ if len(s1) != len(s2) {
+ if len(s1) > len(s2) {
+ return len(s2) + 1
+ }
+ return len(s1) + 1
+ }
+ for i := range s1 {
+ if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-7) {
+ return i
+ }
+ }
+ return -1
+}
+
+func TestNormTables(t *testing.T) {
+ testKn, testWn, testFn := initNorm()
+ if i := compareUint32Slices(kn[0:], testKn); i >= 0 {
+ t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i])
+ }
+ if i := compareFloat32Slices(wn[0:], testWn); i >= 0 {
+ t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i])
+ }
+ if i := compareFloat32Slices(fn[0:], testFn); i >= 0 {
+ t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i])
+ }
+}
+
+func TestExpTables(t *testing.T) {
+ testKe, testWe, testFe := initExp()
+ if i := compareUint32Slices(ke[0:], testKe); i >= 0 {
+ t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i])
+ }
+ if i := compareFloat32Slices(we[0:], testWe); i >= 0 {
+ t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i])
+ }
+ if i := compareFloat32Slices(fe[0:], testFe); i >= 0 {
+ t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i])
+ }
+}
+
+func hasSlowFloatingPoint() bool {
+ switch runtime.GOARCH {
+ case "arm":
+ return os.Getenv("GOARM") == "5"
+ case "mips", "mipsle", "mips64", "mips64le":
+ // Be conservative and assume that all mips boards
+ // have emulated floating point.
+ // TODO: detect what it actually has.
+ return true
+ }
+ return false
+}
+
+func TestFloat32(t *testing.T) {
+ // For issue 6721, the problem came after 7533753 calls, so check 10e6.
+ num := int(10e6)
+ // But do the full amount only on builders (not locally).
+ // But ARM5 floating point emulation is slow (Issue 10749), so
+ // do less for that builder:
+ if testing.Short() && (testenv.Builder() == "" || hasSlowFloatingPoint()) {
+ num /= 100 // 1.72 seconds instead of 172 seconds
+ }
+
+ r := New(NewSource(1))
+ for ct := 0; ct < num; ct++ {
+ f := r.Float32()
+ if f >= 1 {
+ t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f)
+ }
+ }
+}
+
+func testReadUniformity(t *testing.T, n int, seed int64) {
+ r := New(NewSource(seed))
+ buf := make([]byte, n)
+ nRead, err := r.Read(buf)
+ if err != nil {
+ t.Errorf("Read err %v", err)
+ }
+ if nRead != n {
+ t.Errorf("Read returned unexpected n; %d != %d", nRead, n)
+ }
+
+ // Expect a uniform distribution of byte values, which lie in [0, 255].
+ var (
+ mean = 255.0 / 2
+ stddev = 256.0 / math.Sqrt(12.0)
+ errorScale = stddev / math.Sqrt(float64(n))
+ )
+
+ expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale}
+
+ // Cast bytes as floats to use the common distribution-validity checks.
+ samples := make([]float64, n)
+ for i, val := range buf {
+ samples[i] = float64(val)
+ }
+ // Make sure that the entire set matches the expected distribution.
+ checkSampleDistribution(t, samples, expected)
+}
+
+func TestReadUniformity(t *testing.T) {
+ testBufferSizes := []int{
+ 2, 4, 7, 64, 1024, 1 << 16, 1 << 20,
+ }
+ for _, seed := range testSeeds {
+ for _, n := range testBufferSizes {
+ testReadUniformity(t, n, seed)
+ }
+ }
+}
+
+func TestReadEmpty(t *testing.T) {
+ r := New(NewSource(1))
+ buf := make([]byte, 0)
+ n, err := r.Read(buf)
+ if err != nil {
+ t.Errorf("Read err into empty buffer; %v", err)
+ }
+ if n != 0 {
+ t.Errorf("Read into empty buffer returned unexpected n of %d", n)
+ }
+}
+
+func TestReadByOneByte(t *testing.T) {
+ r := New(NewSource(1))
+ b1 := make([]byte, 100)
+ _, err := io.ReadFull(iotest.OneByteReader(r), b1)
+ if err != nil {
+ t.Errorf("read by one byte: %v", err)
+ }
+ r = New(NewSource(1))
+ b2 := make([]byte, 100)
+ _, err = r.Read(b2)
+ if err != nil {
+ t.Errorf("read: %v", err)
+ }
+ if !bytes.Equal(b1, b2) {
+ t.Errorf("read by one byte vs single read:\n%x\n%x", b1, b2)
+ }
+}
+
+func TestReadSeedReset(t *testing.T) {
+ r := New(NewSource(42))
+ b1 := make([]byte, 128)
+ _, err := r.Read(b1)
+ if err != nil {
+ t.Errorf("read: %v", err)
+ }
+ r.Seed(42)
+ b2 := make([]byte, 128)
+ _, err = r.Read(b2)
+ if err != nil {
+ t.Errorf("read: %v", err)
+ }
+ if !bytes.Equal(b1, b2) {
+ t.Errorf("mismatch after re-seed:\n%x\n%x", b1, b2)
+ }
+}
+
+func TestShuffleSmall(t *testing.T) {
+ // Check that Shuffle allows n=0 and n=1, but that swap is never called for them.
+ r := New(NewSource(1))
+ for n := 0; n <= 1; n++ {
+ r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) })
+ }
+}
+
+// encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!).
+// See https://en.wikipedia.org/wiki/Lehmer_code.
+// encodePerm modifies the input slice.
+func encodePerm(s []int) int {
+ // Convert to Lehmer code.
+ for i, x := range s {
+ r := s[i+1:]
+ for j, y := range r {
+ if y > x {
+ r[j]--
+ }
+ }
+ }
+ // Convert to int in [0, n!).
+ m := 0
+ fact := 1
+ for i := len(s) - 1; i >= 0; i-- {
+ m += s[i] * fact
+ fact *= len(s) - i
+ }
+ return m
+}
+
+// TestUniformFactorial tests several ways of generating a uniform value in [0, n!).
+func TestUniformFactorial(t *testing.T) {
+ r := New(NewSource(testSeeds[0]))
+ top := 6
+ if testing.Short() {
+ top = 3
+ }
+ for n := 3; n <= top; n++ {
+ t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) {
+ // Calculate n!.
+ nfact := 1
+ for i := 2; i <= n; i++ {
+ nfact *= i
+ }
+
+ // Test a few different ways to generate a uniform distribution.
+ p := make([]int, n) // re-usable slice for Shuffle generator
+ tests := [...]struct {
+ name string
+ fn func() int
+ }{
+ {name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }},
+ {name: "int31n", fn: func() int { return int(r.int31n(int32(nfact))) }},
+ {name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }},
+ {name: "Shuffle", fn: func() int {
+ // Generate permutation using Shuffle.
+ for i := range p {
+ p[i] = i
+ }
+ r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] })
+ return encodePerm(p)
+ }},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Gather chi-squared values and check that they follow
+ // the expected normal distribution given n!-1 degrees of freedom.
+ // See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and
+ // https://www.johndcook.com/Beautiful_Testing_ch10.pdf.
+ nsamples := 10 * nfact
+ if nsamples < 200 {
+ nsamples = 200
+ }
+ samples := make([]float64, nsamples)
+ for i := range samples {
+ // Generate some uniformly distributed values and count their occurrences.
+ const iters = 1000
+ counts := make([]int, nfact)
+ for i := 0; i < iters; i++ {
+ counts[test.fn()]++
+ }
+ // Calculate chi-squared and add to samples.
+ want := iters / float64(nfact)
+ var χ2 float64
+ for _, have := range counts {
+ err := float64(have) - want
+ χ2 += err * err
+ }
+ χ2 /= want
+ samples[i] = χ2
+ }
+
+ // Check that our samples approximate the appropriate normal distribution.
+ dof := float64(nfact - 1)
+ expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)}
+ errorScale := max(1.0, expected.stddev)
+ expected.closeEnough = 0.10 * errorScale
+ expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211.
+ checkSampleDistribution(t, samples, expected)
+ })
+ }
+ })
+ }
+}
+
+// Benchmarks
+
+func BenchmarkInt63Threadsafe(b *testing.B) {
+ for n := b.N; n > 0; n-- {
+ Int63()
+ }
+}
+
+func BenchmarkInt63ThreadsafeParallel(b *testing.B) {
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ Int63()
+ }
+ })
+}
+
+func BenchmarkInt63Unthreadsafe(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Int63()
+ }
+}
+
+func BenchmarkIntn1000(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Intn(1000)
+ }
+}
+
+func BenchmarkInt63n1000(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Int63n(1000)
+ }
+}
+
+func BenchmarkInt31n1000(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Int31n(1000)
+ }
+}
+
+func BenchmarkFloat32(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Float32()
+ }
+}
+
+func BenchmarkFloat64(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Float64()
+ }
+}
+
+func BenchmarkPerm3(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Perm(3)
+ }
+}
+
+func BenchmarkPerm30(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Perm(30)
+ }
+}
+
+func BenchmarkPerm30ViaShuffle(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ p := make([]int, 30)
+ for i := range p {
+ p[i] = i
+ }
+ r.Shuffle(30, func(i, j int) { p[i], p[j] = p[j], p[i] })
+ }
+}
+
+// BenchmarkShuffleOverhead uses a minimal swap function
+// to measure just the shuffling overhead.
+func BenchmarkShuffleOverhead(b *testing.B) {
+ r := New(NewSource(1))
+ for n := b.N; n > 0; n-- {
+ r.Shuffle(52, func(i, j int) {
+ if i < 0 || i >= 52 || j < 0 || j >= 52 {
+ b.Fatalf("bad swap(%d, %d)", i, j)
+ }
+ })
+ }
+}
+
+func BenchmarkRead3(b *testing.B) {
+ r := New(NewSource(1))
+ buf := make([]byte, 3)
+ b.ResetTimer()
+ for n := b.N; n > 0; n-- {
+ r.Read(buf)
+ }
+}
+
+func BenchmarkRead64(b *testing.B) {
+ r := New(NewSource(1))
+ buf := make([]byte, 64)
+ b.ResetTimer()
+ for n := b.N; n > 0; n-- {
+ r.Read(buf)
+ }
+}
+
+func BenchmarkRead1000(b *testing.B) {
+ r := New(NewSource(1))
+ buf := make([]byte, 1000)
+ b.ResetTimer()
+ for n := b.N; n > 0; n-- {
+ r.Read(buf)
+ }
+}