summaryrefslogtreecommitdiffstats
path: root/src/cmd/compile/internal/syntax/testdata/linalg.go
blob: 822d0287e7490ada0681ce74a10de3ab0122c765 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package linalg

import "math"

// Numeric is a type bound that matches any numeric type, including
// defined types whose underlying type is one of the listed numeric types.
// It would likely be in a constraints package in the standard library.
type Numeric interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr |
		~float32 | ~float64 |
		~complex64 | ~complex128
}

// DotProduct returns the sum of the element-wise products of s1 and s2.
// The result is the zero value of T when both slices are empty.
// It panics if s1 and s2 do not have the same length.
func DotProduct[T Numeric](s1, s2 []T) T {
	if len(s1) != len(s2) {
		panic("DotProduct: slices of unequal length")
	}
	var sum T
	for i, x := range s1 {
		sum += x * s2[i]
	}
	return sum
}

// NumericAbs is a type bound satisfied by numeric types that also
// provide an Abs method returning a value of type T.
type NumericAbs[T any] interface {
	Numeric // restricts the type set to numeric types
	Abs() T // reports the absolute value as a T
}

// AbsDifference returns the absolute value of a - b. The absolute value
// is obtained by calling the Abs method on the difference.
func AbsDifference[T NumericAbs[T]](a, b T) T {
	return (a - b).Abs()
}

// OrderedNumeric is a type bound that matches numeric types that support
// the < operator, including defined types whose underlying type is one of
// the listed integer or floating-point types.
type OrderedNumeric interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr |
		~float32 | ~float64
}

// Complex is a type bound satisfied by types whose underlying type is
// complex64 or complex128. These types do not have a < operator.
type Complex interface{ ~complex64 | ~complex128 }

// OrderedAbs is a helper type that defines an Abs method for
// ordered numeric types. It is declared with the type parameter T
// itself as its defined type.
// NOTE(review): the Go spec does not permit a type parameter as the type
// in a type definition, so this declaration is presumably only expected
// to parse (the file lives under syntax/testdata) — confirm before reuse.
type OrderedAbs[T OrderedNumeric] T

// Abs returns the negation of a when a compares less than zero and
// returns a unchanged otherwise.
func (a OrderedAbs[T]) Abs() OrderedAbs[T] {
	abs := a
	if a < 0 {
		abs = -a
	}
	return abs
}

// ComplexAbs is a helper type that defines an Abs method for
// complex types. It is declared with the type parameter T itself
// as its defined type.
// NOTE(review): the Go spec does not permit a type parameter as the type
// in a type definition, so this declaration is presumably only expected
// to parse (the file lives under syntax/testdata) — confirm before reuse.
type ComplexAbs[T Complex] T

// Abs returns the magnitude of a as a ComplexAbs[T] whose real part is
// the magnitude and whose imaginary part is zero.
//
// The real and imaginary parts are widened to float64 and combined with
// math.Hypot rather than by squaring them directly, which avoids the
// unnecessary overflow or underflow that the intermediate squares can
// cause for very large or very small components.
func (a ComplexAbs[T]) Abs() ComplexAbs[T] {
	re := float64(real(a))
	im := float64(imag(a))
	return ComplexAbs[T](complex(math.Hypot(re, im), 0))
}

// OrderedAbsDifference returns the absolute value of a - b for ordered
// numeric types. It wraps both operands in OrderedAbs, delegates to
// AbsDifference, and converts the result back to T.
func OrderedAbsDifference[T OrderedNumeric](a, b T) T {
	x, y := OrderedAbs[T](a), OrderedAbs[T](b)
	return T(AbsDifference(x, y))
}

// ComplexAbsDifference returns the absolute value of a - b for complex
// types. It wraps both operands in ComplexAbs, delegates to
// AbsDifference, and converts the result back to T.
func ComplexAbsDifference[T Complex](a, b T) T {
	x, y := ComplexAbs[T](a), ComplexAbs[T](b)
	return T(AbsDifference(x, y))
}