diff --git a/calculator.go b/calculator.go index 480862c..2fba307 100644 --- a/calculator.go +++ b/calculator.go @@ -24,12 +24,12 @@ func (c *calculator) modulus(a Amount, d int64) Amount { return a % d } -func (c *calculator) allocate(a Amount, r, s uint) Amount { +func (c *calculator) allocate(a Amount, r, s int64) Amount { if a == 0 || s == 0 { return 0 } - return a * int64(r) / int64(s) + return a * r / s } func (c *calculator) absolute(a Amount) Amount { diff --git a/money.go b/money.go index 51c623f..9e8ecd7 100644 --- a/money.go +++ b/money.go @@ -10,8 +10,9 @@ import ( // Injection points for backward compatibility. // If you need to keep your JSON marshal/unmarshal way, overwrite them like below. -// money.UnmarshalJSON = func (m *Money, b []byte) error { ... } -// money.MarshalJSON = func (m Money) ([]byte, error) { ... } +// +// money.UnmarshalJSON = func (m *Money, b []byte) error { ... } +// money.MarshalJSON = func (m Money) ([]byte, error) { ... } var ( // UnmarshalJSON is injection point of json.Unmarshaller for money.Money UnmarshalJSON = defaultUnmarshalJSON @@ -265,19 +266,22 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) { } // Calculate sum of ratios. - var sum uint + var sum int64 for _, r := range rs { if r < 0 { return nil, errors.New("negative ratios not allowed") } - sum += uint(r) + if int64(r) > (math.MaxInt64 - sum) { + return nil, errors.New("sum of given ratios exceeds max int") + } + sum += int64(r) } var total int64 ms := make([]*Money, 0, len(rs)) for _, r := range rs { party := &Money{ - amount: mutate.calc.allocate(m.amount, uint(r), sum), + amount: mutate.calc.allocate(m.amount, int64(r), sum), currency: m.currency, } @@ -329,9 +333,11 @@ func (m Money) MarshalJSON() ([]byte, error) { } // Compare function compares two money of the same type -// if m.amount > om.amount returns (1, nil) -// if m.amount == om.amount returns (0, nil -// if m.amount < om.amount returns (-1, nil) +// +// if m.amount > om.amount returns (1, nil) +// if m.amount == om.amount returns (0, nil +// if m.amount < om.amount returns (-1, nil) +// // If compare moneys from distinct currency, return (m.amount, ErrCurrencyMismatch) func (m *Money) Compare(om *Money) (int, error) { if err := m.assertSameCurrency(om); err != nil { diff --git a/money_test.go b/money_test.go index d687222..fe309c4 100644 --- a/money_test.go +++ b/money_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "reflect" "testing" ) @@ -318,7 +319,6 @@ func TestMoney_Add(t *testing.T) { m := New(tc.amount1, EUR) om := New(tc.amount2, EUR) r, err := m.Add(om) - if err != nil { t.Error(err) } @@ -355,7 +355,6 @@ func TestMoney_Subtract(t *testing.T) { m := New(tc.amount1, EUR) om := New(tc.amount2, EUR) r, err := m.Subtract(om) - if err != nil { t.Error(err) } @@ -521,6 +520,19 @@ func TestMoney_Allocate2(t *testing.T) { } } +func TestAllocateOverflow(t *testing.T) { + m := New(math.MaxInt64, EUR) + _, err := m.Allocate(math.MaxInt, 1) + if err == nil { + t.Fatalf("expected an error, but got nil") + } + + expectedErrorMessage := "sum of given ratios exceeds max int" + if err.Error() != expectedErrorMessage { + t.Fatalf("expected error message %q, but got %q", expectedErrorMessage, err.Error()) + } +} + func TestMoney_Format(t *testing.T) { tcs := []struct { amount int64 @@ -583,7 +595,6 @@ func TestMoney_AsMajorUnits(t *testing.T) { func TestMoney_Allocate3(t *testing.T) { pound := New(100, GBP) parties, err := pound.Allocate(33, 33, 33) - if err != nil { t.Error(err) } @@ -655,7 +666,6 @@ func TestMoney_Comparison(t *testing.T) { t.Errorf("Expected %d Equals to %d == %d got %d", anotherTwoEuros.amount, twoEuros.amount, 0, r) } - } func TestMoney_Currency(t *testing.T) { @@ -717,7 +727,6 @@ func TestDefaultMarshal(t *testing.T) { expected := `{"amount":12345,"currency":"IQD"}` b, err := json.Marshal(given) - if err != nil { t.Error(err) } @@ -730,7 +739,6 @@ func TestDefaultMarshal(t *testing.T) { expected = `{"amount":0,"currency":""}` b, err = json.Marshal(given) - if err != nil { t.Error(err) } @@ -749,7 +757,6 @@ func TestCustomMarshal(t *testing.T) { } b, err := json.Marshal(given) - if err != nil { t.Error(err) }