1
// Copyright (C) Moondance Labs Ltd.
2
// This file is part of Tanssi.
3

            
4
// Tanssi is free software: you can redistribute it and/or modify
5
// it under the terms of the GNU General Public License as published by
6
// the Free Software Foundation, either version 3 of the License, or
7
// (at your option) any later version.
8

            
9
// Tanssi is distributed in the hope that it will be useful,
10
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
// GNU General Public License for more details.
13

            
14
// You should have received a copy of the GNU General Public License
15
// along with Tanssi.  If not, see <http://www.gnu.org/licenses/>
16

            
17
#![cfg_attr(not(feature = "std"), no_std)]
18

            
19
use {
20
    sp_core::U256,
21
    sp_runtime::traits::{CheckedAdd, CheckedMul, CheckedSub, Zero},
22
};
23

            
24
/// Error returned by math operations which can overflow.
25
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
26
pub struct OverflowError;
27

            
28
/// Error returned by math operations which can underflow.
29
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
30
pub struct UnderflowError;
31

            
32
/// Helper to compute ratios by multiplying then dividing by some values, while
33
/// performing the intermediary computation using a bigger type to avoid
34
/// overflows.
35
pub trait MulDiv: Sized {
36
    /// Multiply self by `a` then divide the result by `b`.
37
    /// Computation will be performed in a bigger type to avoid overflows.
38
    /// After the division, will return `None` if the result is to big for
39
    /// the real type or if `b` is zero.
40
    fn mul_div(self, a: Self, b: Self) -> Result<Self, OverflowError>;
41
}
42

            
43
macro_rules! impl_mul_div {
44
    ($type:ty, $bigger:ty) => {
45
        impl MulDiv for $type {
46
26240
            fn mul_div(self, a: Self, b: Self) -> Result<Self, OverflowError> {
47
26240
                if b.is_zero() {
48
2
                    return Err(OverflowError);
49
26238
                }
50
26238

            
51
26238
                if self.is_zero() {
52
2598
                    return Ok(<$type>::zero());
53
23640
                }
54
23640

            
55
23640
                let s: $bigger = self.into();
56
23640
                let a: $bigger = a.into();
57
23640
                let b: $bigger = b.into();
58
23640

            
59
23640
                let r: $bigger = s * a / b;
60
23640

            
61
23640
                r.try_into().map_err(|_| OverflowError)
62
26240
            }
63
        }
64
    };
65
}
66

            
67
impl_mul_div!(u8, u16);
68
impl_mul_div!(u16, u32);
69
impl_mul_div!(u32, u64);
70
impl_mul_div!(u64, u128);
71
impl_mul_div!(u128, U256);
72

            
73
/// Returns directly an error on overflow.
74
pub trait ErrAdd: CheckedAdd {
75
    /// Returns directly an error on overflow.
76
6447
    fn err_add(&self, v: &Self) -> Result<Self, OverflowError> {
77
6447
        self.checked_add(v).ok_or(OverflowError)
78
6447
    }
79
}
80

            
81
impl<T: CheckedAdd> ErrAdd for T {}
82

            
83
/// Returns directly an error on underflow.
84
pub trait ErrSub: CheckedSub {
85
    /// Returns directly an error on underflow.
86
4451
    fn err_sub(&self, v: &Self) -> Result<Self, UnderflowError> {
87
4451
        self.checked_sub(v).ok_or(UnderflowError)
88
4451
    }
89
}
90

            
91
impl<T: CheckedSub> ErrSub for T {}
92

            
93
/// Returns directly an error on overflow.
94
pub trait ErrMul: CheckedMul {
95
    /// Returns directly an error on overflow.
96
467
    fn err_mul(&self, v: &Self) -> Result<Self, OverflowError> {
97
467
        self.checked_mul(v).ok_or(OverflowError)
98
467
    }
99
}
100

            
101
impl<T: CheckedMul> ErrMul for T {}
102

            
103
#[cfg(test)]
104
mod tests {
105
    use super::*;
106

            
107
    #[test]
108
1
    fn mul_div() {
109
1
        assert_eq!(42u128.mul_div(0, 0), Err(OverflowError));
110
1
        assert_eq!(42u128.mul_div(1, 0), Err(OverflowError));
111

            
112
1
        assert_eq!(u128::MAX.mul_div(2, 4), Ok(u128::MAX / 2));
113
1
        assert_eq!(u128::MAX.mul_div(2, 2), Ok(u128::MAX));
114
1
        assert_eq!(u128::MAX.mul_div(4, 2), Err(OverflowError));
115
1
    }
116
}