summaryrefslogtreecommitdiffstats
path: root/third_party/rust/rust_decimal/src/ops/mul.rs
blob: b36729599da004bf16e0e2202785c2a06e8dfb17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
use crate::constants::{BIG_POWERS_10, MAX_I64_SCALE, MAX_PRECISION_U32, U32_MAX};
use crate::decimal::{CalculationResult, Decimal};
use crate::ops::common::Buf24;

pub(crate) fn mul_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult {
    if d1.is_zero() || d2.is_zero() {
        // We should think about this - does zero need to maintain precision? This treats it like
        // an absolute which I think is ok, especially since we have is_zero() functions etc.
        return CalculationResult::Ok(Decimal::ZERO);
    }

    let mut scale = d1.scale() + d2.scale();
    let negative = d1.is_sign_negative() ^ d2.is_sign_negative();
    let mut product = Buf24::zero();

    // See if we can optimize this calculation depending on whether the hi bits are set
    if d1.hi() | d1.mid() == 0 {
        if d2.hi() | d2.mid() == 0 {
            // We're multiplying two 32 bit integers, so we can take some liberties to optimize this.
            let mut low64 = d1.lo() as u64 * d2.lo() as u64;
            if scale > MAX_PRECISION_U32 {
                // We've exceeded maximum scale so we need to start reducing the precision (aka
                // rounding) until we have something that fits.
                // If we're too big then we effectively round to zero.
                if scale > MAX_PRECISION_U32 + MAX_I64_SCALE {
                    return CalculationResult::Ok(Decimal::ZERO);
                }

                scale -= MAX_PRECISION_U32 + 1;
                let mut power = BIG_POWERS_10[scale as usize];

                let tmp = low64 / power;
                let remainder = low64 - tmp * power;
                low64 = tmp;

                // Round the result. Since the divisor was a power of 10, it's always even.
                power >>= 1;
                if remainder >= power && (remainder > power || (low64 as u32 & 1) > 0) {
                    low64 += 1;
                }

                scale = MAX_PRECISION_U32;
            }

            // Early exit
            return CalculationResult::Ok(Decimal::from_parts(
                low64 as u32,
                (low64 >> 32) as u32,
                0,
                negative,
                scale,
            ));
        }

        // We know that the left hand side is just 32 bits but the right hand side is either
        // 64 or 96 bits.
        mul_by_32bit_lhs(d1.lo() as u64, d2, &mut product);
    } else if d2.mid() | d2.hi() == 0 {
        // We know that the right hand side is just 32 bits.
        mul_by_32bit_lhs(d2.lo() as u64, d1, &mut product);
    } else {
        // We know we're not dealing with simple 32 bit operands on either side.
        // We compute and accumulate the 9 partial products using long multiplication

        // 1: ll * rl
        let mut tmp = d1.lo() as u64 * d2.lo() as u64;
        product.data[0] = tmp as u32;

        // 2: ll * rm
        let mut tmp2 = (d1.lo() as u64 * d2.mid() as u64).wrapping_add(tmp >> 32);

        // 3: lm * rl
        tmp = d1.mid() as u64 * d2.lo() as u64;
        tmp = tmp.wrapping_add(tmp2);
        product.data[1] = tmp as u32;

        // Detect if carry happened from the wrapping add
        if tmp < tmp2 {
            tmp2 = (tmp >> 32) | (1u64 << 32);
        } else {
            tmp2 = tmp >> 32;
        }

        // 4: lm * rm
        tmp = (d1.mid() as u64 * d2.mid() as u64) + tmp2;

        // If the high bit isn't set then we can stop here. Otherwise, we need to continue calculating
        // using the high bits.
        if (d1.hi() | d2.hi()) > 0 {
            // 5. ll * rh
            tmp2 = d1.lo() as u64 * d2.hi() as u64;
            tmp = tmp.wrapping_add(tmp2);
            // Detect if we carried
            let mut tmp3 = if tmp < tmp2 { 1 } else { 0 };

            // 6. lh * rl
            tmp2 = d1.hi() as u64 * d2.lo() as u64;
            tmp = tmp.wrapping_add(tmp2);
            product.data[2] = tmp as u32;
            // Detect if we carried
            if tmp < tmp2 {
                tmp3 += 1;
            }
            tmp2 = (tmp3 << 32) | (tmp >> 32);

            // 7. lm * rh
            tmp = d1.mid() as u64 * d2.hi() as u64;
            tmp = tmp.wrapping_add(tmp2);
            // Check for carry
            tmp3 = if tmp < tmp2 { 1 } else { 0 };

            // 8. lh * rm
            tmp2 = d1.hi() as u64 * d2.mid() as u64;
            tmp = tmp.wrapping_add(tmp2);
            product.data[3] = tmp as u32;
            // Check for carry
            if tmp < tmp2 {
                tmp3 += 1;
            }
            tmp = (tmp3 << 32) | (tmp >> 32);

            // 9. lh * rh
            product.set_high64(d1.hi() as u64 * d2.hi() as u64 + tmp);
        } else {
            product.set_mid64(tmp);
        }
    }

    // We may want to "rescale". This is the case if the mantissa is > 96 bits or if the scale
    // exceeds the maximum precision.
    let upper_word = product.upper_word();
    if upper_word > 2 || scale > MAX_PRECISION_U32 {
        scale = if let Some(new_scale) = product.rescale(upper_word, scale) {
            new_scale
        } else {
            return CalculationResult::Overflow;
        }
    }

    CalculationResult::Ok(Decimal::from_parts(
        product.data[0],
        product.data[1],
        product.data[2],
        negative,
        scale,
    ))
}

/// Multiplies the 32 bit value `d1` (passed widened to `u64`) by the mantissa of `d2`, writing
/// the result into `product`.
///
/// The words of `d2` are processed from least to most significant, each time carrying the
/// upper half of the 64 bit intermediate into the next word. The high word of `d2` is only
/// multiplied when it is non-zero.
#[inline(always)]
fn mul_by_32bit_lhs(d1: u64, d2: &Decimal, product: &mut Buf24) {
    // Lowest word.
    let low = d1 * d2.lo() as u64;
    product.data[0] = low as u32;

    // Middle word, plus the carry out of the lowest word.
    let middle = (d1 * d2.mid() as u64).wrapping_add(low >> 32);
    product.data[1] = middle as u32;

    let carry = middle >> 32;
    if d2.hi() == 0 {
        // The right hand side is at most 64 bits wide: the carry is the final word.
        product.data[2] = carry as u32;
        return;
    }

    // The right hand side uses its high word, so fold that partial product in too.
    let high = carry.wrapping_add(d1 * d2.hi() as u64);
    if high > U32_MAX {
        // More than 32 bits remain; store the full 64 bit value via set_mid64
        // (presumably words 2 and 3 of the buffer - confirm against Buf24).
        product.set_mid64(high);
    } else {
        product.data[2] = high as u32;
    }
}