Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shopspring/decimal support to avoid float precisions #140

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 16 additions & 37 deletions calculator.go
Original file line number Diff line number Diff line change
@@ -1,73 +1,52 @@
package money

import "math"
import (
"github.com/shopspring/decimal"
)

type calculator struct{}

func (c *calculator) add(a, b Amount) Amount {
return a + b
return a.Add(b)
}

func (c *calculator) subtract(a, b Amount) Amount {
return a - b
return a.Sub(b)
}

func (c *calculator) multiply(a Amount, m int64) Amount {
return a * m
return a.Mul(decimal.NewFromInt(m))
}

func (c *calculator) divide(a Amount, d int64) Amount {
return a / d
return a.Div(decimal.NewFromInt(d))
}

func (c *calculator) modulus(a Amount, d int64) Amount {
return a % d
return a.Mod(decimal.NewFromInt(d))
}

func (c *calculator) allocate(a Amount, r, s uint) Amount {
if a == 0 || s == 0 {
return 0
if a.IsZero() || s == 0 {
return decimal.Zero
}

return a * int64(r) / int64(s)
res := a.Mul(decimal.NewFromInt(int64(r))).Div(decimal.NewFromInt(int64(s))).IntPart()
return decimal.NewFromInt(res)
}

func (c *calculator) absolute(a Amount) Amount {
if a < 0 {
return -a
}

return a
return a.Abs()
}

func (c *calculator) negative(a Amount) Amount {
if a > 0 {
return -a
if a.IsPositive() {
return a.Mul(decimal.NewFromInt(-1))
}

return a
}

func (c *calculator) round(a Amount, e int) Amount {
if a == 0 {
return 0
}

absam := c.absolute(a)
exp := int64(math.Pow(10, float64(e)))
m := absam % exp

if m > (exp / 2) {
absam += exp
}

absam = (absam / exp) * exp

if a < 0 {
a = -absam
} else {
a = absam
}

return a
return a.Round(int32(e * -1))
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/Rhymond/go-money

go 1.13

require github.com/shopspring/decimal v1.3.1 // indirect
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
72 changes: 35 additions & 37 deletions money.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import (
"encoding/json"
"errors"
"fmt"
"math"

"github.com/shopspring/decimal"
)

// Injection points for backward compatibility.
// If you need to keep your JSON marshal/unmarshal way, overwrite them like below.
// money.UnmarshalJSON = func (m *Money, b []byte) error { ... }
// money.MarshalJSON = func (m Money) ([]byte, error) { ... }
//
// money.UnmarshalJSON = func (m *Money, b []byte) error { ... }
// money.MarshalJSON = func (m Money) ([]byte, error) { ... }
var (
// UnmarshalJSON is injection point of json.Unmarshaller for money.Money
UnmarshalJSON = defaultUnmarshalJSON
Expand Down Expand Up @@ -69,7 +71,7 @@ func defaultMarshalJSON(m Money) ([]byte, error) {
}

// Amount is a data structure that stores the amount being used for calculations.
type Amount = int64
type Amount = decimal.Decimal

// Money represents monetary value information, stores
// currency and amount value.
Expand All @@ -81,16 +83,17 @@ type Money struct {
// New creates and returns new instance of Money.
func New(amount int64, code string) *Money {
return &Money{
amount: amount,
amount: decimal.NewFromInt(amount),
currency: newCurrency(code).get(),
}
}

// NewFromFloat creates and returns new instance of Money from a float64.
// Always rounding trailing decimals down.
func NewFromFloat(amount float64, currency string) *Money {
currencyDecimals := math.Pow10(GetCurrency(currency).Fraction)
return New(int64(amount*currencyDecimals), currency)
c := GetCurrency(currency)
amt := decimal.NewFromFloat(amount).Mul(decimal.New(1, int32(c.Fraction)))
return New(amt.IntPart(), currency)
}

// Currency returns the currency used by Money.
Expand All @@ -100,7 +103,7 @@ func (m *Money) Currency() *Currency {

// Amount returns a copy of the internal monetary value as an int64.
func (m *Money) Amount() int64 {
return m.amount
return m.amount.IntPart()
}

// SameCurrency check if given Money is equals by currency.
Expand All @@ -117,14 +120,7 @@ func (m *Money) assertSameCurrency(om *Money) error {
}

func (m *Money) compare(om *Money) int {
switch {
case m.amount > om.amount:
return 1
case m.amount < om.amount:
return -1
}

return 0
return m.amount.Cmp(om.amount)
}

// Equals checks equality between two Money types.
Expand Down Expand Up @@ -174,17 +170,17 @@ func (m *Money) LessThanOrEqual(om *Money) (bool, error) {

// IsZero returns boolean of whether the value of Money is equals to zero.
func (m *Money) IsZero() bool {
return m.amount == 0
return m.amount.IsZero()
}

// IsPositive returns boolean of whether the value of Money is positive.
func (m *Money) IsPositive() bool {
return m.amount > 0
return m.amount.IsPositive()
}

// IsNegative returns boolean of whether the value of Money is negative.
func (m *Money) IsNegative() bool {
return m.amount < 0
return m.amount.IsNegative()
}

// Absolute returns new Money struct from given Money using absolute monetary value.
Expand Down Expand Up @@ -245,12 +241,12 @@ func (m *Money) Split(n int) ([]*Money, error) {
// Add leftovers to the first parties.

v := int64(1)
if m.amount < 0 {
if m.amount.IsNegative() {
v = -1
}
for p := 0; l != 0; p++ {
ms[p].amount = mutate.calc.add(ms[p].amount, v)
l--
for p := 0; !l.IsZero(); p++ {
ms[p].amount = mutate.calc.add(ms[p].amount, decimal.NewFromInt(v))
l = l.Sub(decimal.NewFromInt(1))
}

return ms, nil
Expand All @@ -273,7 +269,7 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) {
sum += uint(r)
}

var total int64
var total decimal.Decimal
ms := make([]*Money, 0, len(rs))
for _, r := range rs {
party := &Money{
Expand All @@ -282,7 +278,7 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) {
}

ms = append(ms, party)
total += party.amount
total = total.Add(party.amount)
}

// if the sum of all ratios is zero, then we just returns zeros and don't do anything
Expand All @@ -292,15 +288,15 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) {
}

// Calculate leftover value and divide to first parties.
lo := m.amount - total
sub := int64(1)
if lo < 0 {
sub = -sub
lo := m.amount.Sub(total)
sub := decimal.NewFromInt(1)
if lo.IsNegative() {
sub = sub.Mul(decimal.NewFromInt(-1))
}

for p := 0; lo != 0; p++ {
for p := 0; !lo.IsZero(); p++ {
ms[p].amount = mutate.calc.add(ms[p].amount, sub)
lo -= sub
lo = lo.Sub(sub)
}

return ms, nil
Expand All @@ -309,13 +305,13 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) {
// Display lets represent Money struct as string in given Currency value.
func (m *Money) Display() string {
c := m.currency.get()
return c.Formatter().Format(m.amount)
return c.Formatter().Format(m.amount.IntPart())
}

// AsMajorUnits lets represent Money struct as subunits (float64) in given Currency value
func (m *Money) AsMajorUnits() float64 {
c := m.currency.get()
return c.Formatter().ToMajorUnits(m.amount)
return c.Formatter().ToMajorUnits(m.amount.IntPart())
}

// UnmarshalJSON is implementation of json.Unmarshaller
Expand All @@ -329,13 +325,15 @@ func (m Money) MarshalJSON() ([]byte, error) {
}

// Compare function compares two money of the same type
// if m.amount > om.amount returns (1, nil)
// if m.amount == om.amount returns (0, nil
// if m.amount < om.amount returns (-1, nil)
//
// if m.amount > om.amount returns (1, nil)
// if m.amount == om.amount returns (0, nil
// if m.amount < om.amount returns (-1, nil)
//
// If compare moneys from distinct currency, return (m.amount, ErrCurrencyMismatch)
func (m *Money) Compare(om *Money) (int, error) {
if err := m.assertSameCurrency(om); err != nil {
return int(m.amount), err
return int(m.amount.IntPart()), err
}

return m.compare(om), nil
Expand Down
8 changes: 8 additions & 0 deletions money_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ func ExampleNew() {
// £1.00
}

func ExampleNewFromFloat() {
amount := 136.98
fmt.Println(money.NewFromFloat(amount, "SGD").Display())

// Output:
// $136.98
}

func ExampleMoney_comparisons() {
pound := money.New(100, "GBP")
twoPounds := money.New(200, "GBP")
Expand Down
Loading