Skip to content

Commit

Permalink
Check ratios for integer overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
pendolf authored and Rhymond committed Jul 29, 2024
1 parent 997612c commit bc15d4d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
4 changes: 2 additions & 2 deletions calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ func (c *calculator) modulus(a Amount, d int64) Amount {
return a % d
}

func (c *calculator) allocate(a Amount, r, s uint) Amount {
func (c *calculator) allocate(a Amount, r, s int64) Amount {
if a == 0 || s == 0 {
return 0
}

return a * int64(r) / int64(s)
return a * r / s
}

func (c *calculator) absolute(a Amount) Amount {
Expand Down
22 changes: 14 additions & 8 deletions money.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (

// Injection points for backward compatibility.
// If you need to keep your JSON marshal/unmarshal way, overwrite them like below.
// money.UnmarshalJSON = func (m *Money, b []byte) error { ... }
// money.MarshalJSON = func (m Money) ([]byte, error) { ... }
//
// money.UnmarshalJSON = func (m *Money, b []byte) error { ... }
// money.MarshalJSON = func (m Money) ([]byte, error) { ... }
var (
// UnmarshalJSON is injection point of json.Unmarshaller for money.Money
UnmarshalJSON = defaultUnmarshalJSON
Expand Down Expand Up @@ -265,19 +266,22 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) {
}

// Calculate sum of ratios.
var sum uint
var sum int64
for _, r := range rs {
if r < 0 {
return nil, errors.New("negative ratios not allowed")
}
sum += uint(r)
if int64(r) > (math.MaxInt64 - sum) {
return nil, errors.New("sum of given ratios exceeds max int")
}
sum += int64(r)
}

var total int64
ms := make([]*Money, 0, len(rs))
for _, r := range rs {
party := &Money{
amount: mutate.calc.allocate(m.amount, uint(r), sum),
amount: mutate.calc.allocate(m.amount, int64(r), sum),
currency: m.currency,
}

Expand Down Expand Up @@ -329,9 +333,11 @@ func (m Money) MarshalJSON() ([]byte, error) {
}

// Compare function compares two money of the same type
// if m.amount > om.amount returns (1, nil)
// if m.amount == om.amount returns (0, nil
// if m.amount < om.amount returns (-1, nil)
//
// if m.amount > om.amount returns (1, nil)
// if m.amount == om.amount returns (0, nil
// if m.amount < om.amount returns (-1, nil)
//
// If compare moneys from distinct currency, return (m.amount, ErrCurrencyMismatch)
func (m *Money) Compare(om *Money) (int, error) {
if err := m.assertSameCurrency(om); err != nil {
Expand Down
21 changes: 14 additions & 7 deletions money_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
"testing"
)
Expand Down Expand Up @@ -318,7 +319,6 @@ func TestMoney_Add(t *testing.T) {
m := New(tc.amount1, EUR)
om := New(tc.amount2, EUR)
r, err := m.Add(om)

if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -355,7 +355,6 @@ func TestMoney_Subtract(t *testing.T) {
m := New(tc.amount1, EUR)
om := New(tc.amount2, EUR)
r, err := m.Subtract(om)

if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -521,6 +520,19 @@ func TestMoney_Allocate2(t *testing.T) {
}
}

func TestAllocateOverflow(t *testing.T) {
m := New(math.MaxInt64, EUR)
_, err := m.Allocate(math.MaxInt, 1)
if err == nil {
t.Fatalf("expected an error, but got nil")
}

expectedErrorMessage := "sum of given ratios exceeds max int"
if err.Error() != expectedErrorMessage {
t.Fatalf("expected error message %q, but got %q", expectedErrorMessage, err.Error())
}
}

func TestMoney_Format(t *testing.T) {
tcs := []struct {
amount int64
Expand Down Expand Up @@ -583,7 +595,6 @@ func TestMoney_AsMajorUnits(t *testing.T) {
func TestMoney_Allocate3(t *testing.T) {
pound := New(100, GBP)
parties, err := pound.Allocate(33, 33, 33)

if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -655,7 +666,6 @@ func TestMoney_Comparison(t *testing.T) {
t.Errorf("Expected %d Equals to %d == %d got %d", anotherTwoEuros.amount,
twoEuros.amount, 0, r)
}

}

func TestMoney_Currency(t *testing.T) {
Expand Down Expand Up @@ -717,7 +727,6 @@ func TestDefaultMarshal(t *testing.T) {
expected := `{"amount":12345,"currency":"IQD"}`

b, err := json.Marshal(given)

if err != nil {
t.Error(err)
}
Expand All @@ -730,7 +739,6 @@ func TestDefaultMarshal(t *testing.T) {
expected = `{"amount":0,"currency":""}`

b, err = json.Marshal(given)

if err != nil {
t.Error(err)
}
Expand All @@ -749,7 +757,6 @@ func TestCustomMarshal(t *testing.T) {
}

b, err := json.Marshal(given)

if err != nil {
t.Error(err)
}
Expand Down

0 comments on commit bc15d4d

Please sign in to comment.