Skip to content

Commit

Permalink
add Root and RootRound methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Brian P committed Oct 13, 2019
1 parent 0ea7e08 commit 86d84cb
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 0 deletions.
137 changes: 137 additions & 0 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ import (
//
var DivisionPrecision = 16

// RootPrecision is the number of decimal places in the result from the Root
// method.
var RootPrecision = 16

// MarshalJSONWithoutQuotes should be set to true if you want the decimal to
// be JSON marshaled as a number, instead of as a string.
// WARNING: this is dangerous for decimals with many digits, since many JSON
Expand All @@ -55,9 +59,15 @@ var MarshalJSONWithoutQuotes = false
// Zero constant, to make computations faster.
var Zero = New(0, 1)

// oneDec used for incrementing/decrementing
var oneDec = New(1, 0)

// fiveDec used in Cash Rounding
var fiveDec = New(5, 0)

// used for shifting digits
var tenDec = New(10, 0)

var zeroInt = big.NewInt(0)
var oneInt = big.NewInt(1)
var twoInt = big.NewInt(2)
Expand Down Expand Up @@ -557,6 +567,133 @@ func (d Decimal) Pow(d2 Decimal) Decimal {
return temp.Mul(temp).Div(d)
}

// factorial assumes d is a positive whole number
func (d Decimal) factorial() Decimal {
out := oneDec
for i := New(2, 0); i.LessThanOrEqual(d); i = i.Add(oneDec) {
out = out.Mul(i)
}
return out
}

func (d Decimal) combination(r Decimal) Decimal {
top := d.factorial()
bottom := r.factorial().Mul(d.Sub(r).factorial())
return top.Div(bottom)
}

func reverseDecimalSlice(slice []Decimal) {
for i := len(slice)/2 - 1; i >= 0; i-- {
opp := len(slice) - 1 - i
slice[i], slice[opp] = slice[opp], slice[i]
}
}

// returns groups of n digits from the original Decimal, centered and split on
// the position of the decimal place, for purpose of computing the nth root via
// the shifting nth root algo.
//
// 0.303 (n:2) -> [], [30, 30]
// 1000 (n:2) -> [10, 00], []
// 2310.0241 (n:3) -> [2, 310], [024, 100]
func (d Decimal) rootDigitGroups(n Decimal) (left, right []Decimal) {
e := tenDec.Pow(n)

dLeft := d.rescale(0)
for {
if dLeft.Equal(Zero) {
break
}
next := dLeft.Div(e).rescale(0)
left = append(left, dLeft.Sub(next.Mul(e)))
dLeft = next
}
reverseDecimalSlice(left)

dRight := d.Sub(d.rescale(0))
for {
if dRight.Equal(Zero) {
break
}
dRightE := dRight.Mul(e)
dRightETrunc := dRightE.rescale(0)
right = append(right, dRightETrunc)
dRight = dRightE.Sub(dRightETrunc)
}

return left, right
}

// Root returns the nth root of d.
func (d Decimal) Root(n Decimal) Decimal {
return d.RootRound(n, int32(RootPrecision))
}

// RootRound returns the nth root of d. If the result is not a whole number then
// the given precision determines the number of decimal places to calculate the
// result to.
func (d Decimal) RootRound(n Decimal, prec int32) Decimal {
// The nth root is calculated via the shifting root algorithm, which was
// chosen for being definitely correct up to an arbitrary precision, and not
// because it's particularly fast.
//
// https://en.wikipedia.org/wiki/Shifting_nth_root_algorithm
// https://www.wikihow.com/Find-Nth-Roots-by-Hand

leftGroups, rightGroups := d.rootDigitGroups(n)
numLeftGroups := len(leftGroups)

answer, answerNoDecimal, target := Zero, Zero, Zero
var numDigits int32
for {
if numDigits-int32(numLeftGroups) >= prec {
break
}

group := Zero
if len(leftGroups) > 0 {
group, leftGroups = leftGroups[0], leftGroups[1:]
} else if len(rightGroups) > 0 {
group, rightGroups = rightGroups[0], rightGroups[1:]
}
target = target.Mul(tenDec.Pow(n)).Add(group)

nextDigit, nextSub := Zero, Zero
for ; nextDigit.LessThan(tenDec); nextDigit = nextDigit.Add(oneDec) {
tryNextSub := Zero
for i := Zero; i.LessThan(n); i = i.Add(oneDec) {
sumStep := n.combination(i.Add(oneDec))
sumStep = sumStep.Mul(nextDigit.Pow(i))
sumStep = sumStep.Mul(answerNoDecimal.Mul(tenDec).Pow(n.Sub(i).Sub(oneDec)))
tryNextSub = tryNextSub.Add(sumStep)
}
tryNextSub = tryNextSub.Mul(nextDigit)
if tryNextSub.GreaterThan(target) {
break
}
nextSub = tryNextSub
}
nextDigit = nextDigit.Sub(oneDec)

answerNoDecimal = answerNoDecimal.Mul(tenDec).Add(nextDigit)
if numDigits < int32(numLeftGroups) {
answer = answerNoDecimal
} else {
shift := tenDec.Pow(New(int64(numDigits)-int64(numLeftGroups)+1, 0))
answer = answer.Add(nextDigit.DivRound(shift, prec))
}

target = target.Sub(nextSub)
if target.Equal(Zero) && len(leftGroups) == 0 && len(rightGroups) == 0 {
break
}

numDigits++
}

return answer
}

// Cmp compares the numbers represented by d and d2 and returns:
//
// -1 if d < d2
Expand Down
123 changes: 123 additions & 0 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2094,6 +2094,129 @@ func TestNegativePow(t *testing.T) {
}
}

func TestFactorial(t *testing.T) {
tests := [][2]string{
{"0", "1"},
{"1", "1"},
{"2", "2"},
{"3", "6"},
{"4", "24"},
}

for _, test := range tests {
in, out := RequireFromString(test[0]), RequireFromString(test[1])
if f := in.factorial(); !f.Equal(out) {
t.Errorf("!%v should be %v, got %v", in, out, f)
}
}
}

func TestCombination(t *testing.T) {
tests := [][3]string{
{"4", "1", "4"},
{"4", "2", "6"},
{"4", "3", "4"},
{"4", "4", "1"},
}

for _, test := range tests {
n, r, out := RequireFromString(test[0]), RequireFromString(test[1]), RequireFromString(test[2])
if c := n.combination(r); !c.Equal(out) {
t.Errorf("C(%v,%v) should be %v, got %v", n, r, out, c)
}
}
}

func TestRootDigitGroups(t *testing.T) {
rng := rand.New(rand.NewSource(0xdead1337))
for i := 0; i < 5e4; i++ {
dIV, dIE := rng.Int63n(1e7), rng.Int31n(10)
if rng.Intn(2) == 0 {
dIE = -dIE
}

d := New(dIV, dIE)
n := New(rng.Int63n(3)+1, 0)
nI := int(n.IntPart())
dStr := d.String()
dStrSplit := strings.Split(dStr, ".")

var left, right string
left = dStrSplit[0]
if len(dStrSplit) == 2 {
right = dStrSplit[1]
}

var expLeft, expRight []Decimal
for left != "" && left != "0" {
if len(left) < nI {
expLeft = append(expLeft, RequireFromString(left))
left = ""
} else {
part := left[len(left)-nI:]
expLeft = append(expLeft, RequireFromString(part))
left = left[:len(left)-nI]
}
}
reverseDecimalSlice(expLeft)

for right != "" {
if len(right) < nI {
right += strings.Repeat("0", nI-len(right))
expRight = append(expRight, RequireFromString(right))
right = ""
} else {
part := right[:nI]
expRight = append(expRight, RequireFromString(part))
right = right[nI:]
}
}

gotLeft, gotRight := d.rootDigitGroups(n)
fatalStr := "(%v).rootDigitGroups(%v)\nexpLeft:%v\ngotLeft:%v\nexpRight:%v\ngotRight:%v"
fatalArgs := []interface{}{d, n, expLeft, gotLeft, expRight, gotRight}

if len(expLeft) != len(gotLeft) || len(expRight) != len(gotRight) {
t.Fatalf(fatalStr, fatalArgs...)
}

for j := range expLeft {
if !expLeft[j].Equal(gotLeft[j]) {
t.Fatalf(fatalStr, fatalArgs...)
}
}

for j := range expRight {
if !expRight[j].Equal(gotRight[j]) {
t.Fatalf(fatalStr, fatalArgs...)
}
}
}
}

func TestRoot(t *testing.T) {
rng := rand.New(rand.NewSource(0xdead1337))
for i := 0; i < 2e3; i++ {
rootIV, rootIE := rng.Int63n(1e7), rng.Int31n(10)
if rng.Intn(2) == 0 {
rootIE = -rootIE
}

root := New(rootIV, rootIE)
n := New(rng.Int63n(3)+2, 0) // TODO +1, not +2

d := root
for i := int64(1); i < n.IntPart(); i++ {
d = d.Mul(root)
}

gotRoot := d.RootRound(n, 32)
if !strings.HasPrefix(gotRoot.String(), root.String()) {
t.Fatalf("%v root of %v\nexpected:%v\n got:%v", n, d, root, gotRoot)
}
}
}

func TestDecimal_Sign(t *testing.T) {
if Zero.Sign() != 0 {
t.Errorf("%q should have sign 0", Zero)
Expand Down

0 comments on commit 86d84cb

Please sign in to comment.