Skip to content

Rewrite BigDecimal#sqrt in ruby with improved Newton's method #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 0 additions & 203 deletions ext/bigdecimal/bigdecimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -2245,32 +2245,6 @@ BigDecimal_abs(VALUE self)
return CheckGetValue(c);
}

/* call-seq:
* sqrt(n)
*
* Returns the square root of the value.
*
* Result has at least n significant digits.
*/
static VALUE
BigDecimal_sqrt(VALUE self, VALUE nFig)
{
BDVALUE c, a;
size_t mx, n;

a = GetBDValueMust(self);
mx = a.real->Prec * (VpBaseFig() + 1);

n = check_int_precision(nFig);
n += VpDblFig() + VpBaseFig();
if (mx <= n) mx = n;
c = NewZeroWrapLimited(1, mx);
VpSqrt(c.real, a.real);

RB_GC_GUARD(a.bigdecimal);
return CheckGetValue(c);
}

/* Return the integer part of the number, as a BigDecimal.
*/
static VALUE
Expand Down Expand Up @@ -3772,7 +3746,6 @@ Init_bigdecimal(void)
rb_define_method(rb_cBigDecimal, "dup", BigDecimal_clone, 0);
rb_define_method(rb_cBigDecimal, "to_f", BigDecimal_to_f, 0);
rb_define_method(rb_cBigDecimal, "abs", BigDecimal_abs, 0);
rb_define_method(rb_cBigDecimal, "sqrt", BigDecimal_sqrt, 1);
rb_define_method(rb_cBigDecimal, "fix", BigDecimal_fix, 0);
rb_define_method(rb_cBigDecimal, "round", BigDecimal_round, -1);
rb_define_method(rb_cBigDecimal, "frac", BigDecimal_frac, 0);
Expand Down Expand Up @@ -3844,9 +3817,6 @@ static int gfDebug = 1; /* Debug switch */
#endif /* BIGDECIMAL_DEBUG */

static Real *VpConstOne; /* constant 1.0 */
static Real *VpConstPt5; /* constant 0.5 */
#define maxnr 100UL /* Maximum iterations for calculating sqrt. */
/* used in VpSqrt() */

enum op_sw {
OP_SW_ADD = 1, /* + */
Expand Down Expand Up @@ -4252,11 +4222,6 @@ VpInit(DECDIG BaseVal)
/* Const 1.0 */
VpConstOne = NewOneNolimit(1, 1);

/* Const 0.5 */
VpConstPt5 = NewOneNolimit(1, 1);
VpConstPt5->exponent = 0;
VpConstPt5->frac[0] = 5*BASE1;

#ifdef BIGDECIMAL_DEBUG
gnAlloc = 0;
#endif /* BIGDECIMAL_DEBUG */
Expand Down Expand Up @@ -6085,174 +6050,6 @@ VpVtoD(double *d, SIGNED_VALUE *e, Real *m)
return f;
}

/*
* m <- d
*/
VP_EXPORT void
VpDtoV(Real *m, double d)
{
size_t ind_m, mm;
SIGNED_VALUE ne;
DECDIG i;
double val, val2;

if (isnan(d)) {
VpSetNaN(m);
goto Exit;
}
if (isinf(d)) {
if (d > 0.0) VpSetPosInf(m);
else VpSetNegInf(m);
goto Exit;
}

if (d == 0.0) {
VpSetZero(m, 1);
goto Exit;
}
val = (d > 0.) ? d : -d;
ne = 0;
if (val >= 1.0) {
while (val >= 1.0) {
val /= (double)BASE;
++ne;
}
}
else {
val2 = 1.0 / (double)BASE;
while (val < val2) {
val *= (double)BASE;
--ne;
}
}
/* Now val = 0.xxxxx*BASE**ne */

mm = m->MaxPrec;
memset(m->frac, 0, mm * sizeof(DECDIG));
for (ind_m = 0; val > 0.0 && ind_m < mm; ind_m++) {
val *= (double)BASE;
i = (DECDIG)val;
val -= (double)i;
m->frac[ind_m] = i;
}
if (ind_m >= mm) ind_m = mm - 1;
VpSetSign(m, (d > 0.0) ? 1 : -1);
m->Prec = ind_m + 1;
m->exponent = ne;

VpInternalRound(m, 0, (m->Prec > 0) ? m->frac[m->Prec-1] : 0,
(DECDIG)(val*(double)BASE));

Exit:
return;
}

/*
* y = SQRT(x), y*y - x =>0
*/
VP_EXPORT int
VpSqrt(Real *y, Real *x)
{
Real *f = NULL;
Real *r = NULL;
size_t y_prec;
SIGNED_VALUE n, e;
ssize_t nr;
double val;

/* Zero or +Infinity ? */
if (VpIsZero(x) || VpIsPosInf(x)) {
VpAsgn(y,x,1);
goto Exit;
}

/* Negative ? */
if (BIGDECIMAL_NEGATIVE_P(x)) {
VpSetNaN(y);
return VpException(VP_EXCEPTION_OP, "sqrt of negative value", 0);
}

/* NaN ? */
if (VpIsNaN(x)) {
VpSetNaN(y);
return VpException(VP_EXCEPTION_OP, "sqrt of 'NaN'(Not a Number)", 0);
}

/* One ? */
if (VpIsOne(x)) {
VpSetOne(y);
goto Exit;
}

n = (SIGNED_VALUE)y->MaxPrec;
if (x->MaxPrec > (size_t)n) n = (ssize_t)x->MaxPrec;

/* allocate temporally variables */
/* TODO: reconsider MaxPrec of f and r */
f = NewOneNolimit(1, y->MaxPrec * (BASE_FIG + 2));
r = NewOneNolimit(1, (n + n) * (BASE_FIG + 2));

nr = 0;
y_prec = y->MaxPrec;

VpVtoD(&val, &e, x); /* val <- x */
e /= (SIGNED_VALUE)BASE_FIG;
n = e / 2;
if (e - n * 2 != 0) {
val /= BASE;
n = (e + 1) / 2;
}
VpDtoV(y, sqrt(val)); /* y <- sqrt(val) */
y->exponent += n;
n = (SIGNED_VALUE)roomof(BIGDECIMAL_DOUBLE_FIGURES, BASE_FIG);
y->MaxPrec = Min((size_t)n , y_prec);
f->MaxPrec = y->MaxPrec + 1;
n = (SIGNED_VALUE)(y_prec * BASE_FIG);
if (n > (SIGNED_VALUE)maxnr) n = (SIGNED_VALUE)maxnr;

/*
* Perform: y_{n+1} = (y_n - x/y_n) / 2
*/
do {
y->MaxPrec *= 2;
if (y->MaxPrec > y_prec) y->MaxPrec = y_prec;
f->MaxPrec = y->MaxPrec;
VpDivd(f, r, x, y); /* f = x/y */
VpAddSub(r, f, y, -1); /* r = f - y */
VpMult(f, VpConstPt5, r); /* f = 0.5*r */
if (y_prec == y->MaxPrec && VpIsZero(f))
goto converge;
VpAddSub(r, f, y, 1); /* r = y + f */
VpAsgn(y, r, 1); /* y = r */
} while (++nr < n);

#ifdef BIGDECIMAL_DEBUG
if (gfDebug) {
printf("ERROR(VpSqrt): did not converge within %ld iterations.\n", nr);
}
#endif /* BIGDECIMAL_DEBUG */
y->MaxPrec = y_prec;

converge:
VpChangeSign(y, 1);
#ifdef BIGDECIMAL_DEBUG
if (gfDebug) {
VpMult(r, y, y);
VpAddSub(f, x, r, -1);
printf("VpSqrt: iterations = %"PRIdSIZE"\n", nr);
VPrint(stdout, " y =% \n", y);
VPrint(stdout, " x =% \n", x);
VPrint(stdout, " x-y*y = % \n", f);
}
#endif /* BIGDECIMAL_DEBUG */
y->MaxPrec = y_prec;

Exit:
rbd_free_struct(f);
rbd_free_struct(r);
return 1;
}

/*
* Round relatively from the decimal point.
* f: rounding mode
Expand Down
3 changes: 0 additions & 3 deletions ext/bigdecimal/bigdecimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ typedef struct {
*/

#define VpBaseFig() BIGDECIMAL_COMPONENT_FIGURES
#define VpDblFig() BIGDECIMAL_DOUBLE_FIGURES

/* Zero,Inf,NaN (isinf(),isnan() used to check) */
VP_EXPORT double VpGetDoubleNaN(void);
Expand Down Expand Up @@ -229,8 +228,6 @@ VP_EXPORT void VpToString(Real *a, char *buf, size_t bufsize, size_t fFmt, int f
VP_EXPORT void VpToFString(Real *a, char *buf, size_t bufsize, size_t fFmt, int fPlus);
VP_EXPORT int VpCtoV(Real *a, const char *int_chr, size_t ni, const char *frac, size_t nf, const char *exp_chr, size_t ne);
VP_EXPORT int VpVtoD(double *d, SIGNED_VALUE *e, Real *m);
VP_EXPORT void VpDtoV(Real *m,double d);
VP_EXPORT int VpSqrt(Real *y,Real *x);
VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il);
VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf);
VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf);
Expand Down
43 changes: 35 additions & 8 deletions lib/bigdecimal.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ def self.coerce_to_bigdecimal(x, prec, method_name) # :nodoc:
raise ArgumentError, "#{x.inspect} can't be coerced into BigDecimal"
end

def self.validate_prec(prec, method_name) # :nodoc:
def self.validate_prec(prec, method_name, accept_zero: false) # :nodoc:
raise ArgumentError, 'precision must be an Integer' unless Integer === prec
raise ArgumentError, "Zero or negative precision for #{method_name}" if prec <= 0
if accept_zero
raise ArgumentError, "Negative precision for #{method_name}" if prec < 0
else
raise ArgumentError, "Zero or negative precision for #{method_name}" if prec <= 0
end
end

def self.infinity_computation_result # :nodoc:
Expand Down Expand Up @@ -172,6 +176,34 @@ def power(y, prec = nil)
end
ans.mult(1, prec)
end

# Returns the square root of the value.
#
# Result has at least prec significant digits.
#
def sqrt(prec)
Internal.validate_prec(prec, :sqrt, accept_zero: true)
return Internal.infinity_computation_result if infinite? == 1

raise FloatDomainError, 'sqrt of negative value' if self < 0
raise FloatDomainError, "sqrt of 'NaN'(Not a Number)" if nan?
return self if zero?

# BigDecimal#sqrt calculates at least n_significant_digits precision.
# This feature maybe problematic for some cases.
n_digits = n_significant_digits
prec = [prec, n_digits].max

ex = exponent / 2
x = self * BigDecimal("1e#{-ex * 2}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tompng, I think this square root calculation can be faster if we use decimal shift operations here and on line 37, as we previously discussed. What about this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's good to use decimal shift operator. I've updated #324 as ready for review.

y = BigDecimal(Math.sqrt(x.to_f))
precs = [prec + BigDecimal.double_fig]
precs << 2 + precs.last / 2 while precs.last > BigDecimal.double_fig
precs.reverse_each do |p|
y = (y + x.div(y, p)).div(2, p)
end
y * BigDecimal("1e#{ex}")
end
end

# Core BigMath methods for BigDecimal (log, exp) are defined here.
Expand Down Expand Up @@ -215,12 +247,7 @@ def self.log(x, prec)
prec += BigDecimal.double_fig

# log(x) = log(sqrt(sqrt(sqrt(sqrt(x))))) * 2**sqrt_steps
sqrt_steps = [2 * Integer.sqrt(prec) + 3 * x_minus_one_exponent, 0].max

# Reduce sqrt_step until sqrt gets fast
# https://github.com/ruby/bigdecimal/pull/323
# https://github.com/ruby/bigdecimal/pull/343
sqrt_steps /= 10
sqrt_steps = [Integer.sqrt(prec) + 3 * x_minus_one_exponent, 0].max

lg2 = 0.3010299956639812
prec2 = prec + [-x_minus_one_exponent, 0].max + (sqrt_steps * lg2).ceil
Expand Down
32 changes: 27 additions & 5 deletions test/bigdecimal/test_bigdecimal.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1395,15 +1395,23 @@ def test_sqrt_bigdecimal

assert_in_delta(BigDecimal("4.0000000000000000000125"), BigDecimal("16.0000000000000000001").sqrt(100), BigDecimal("1e-40"))

BigDecimal.mode(BigDecimal::EXCEPTION_OVERFLOW, false)
BigDecimal.mode(BigDecimal::EXCEPTION_NaN, false)
assert_raise_with_message(FloatDomainError, "sqrt of 'NaN'(Not a Number)") { BigDecimal("NaN").sqrt(1) }
assert_raise_with_message(FloatDomainError, "sqrt of negative value") { BigDecimal("-Infinity").sqrt(1) }
assert_raise_with_message(FloatDomainError, "sqrt of 'NaN'(Not a Number)") { BigDecimal::NAN.sqrt(1) }
assert_raise_with_message(FloatDomainError, "sqrt of negative value") { NEGATIVE_INFINITY.sqrt(1) }

assert_equal(0, BigDecimal("0").sqrt(1))
assert_equal(0, BigDecimal("-0").sqrt(1))
assert_equal(1, BigDecimal("1").sqrt(1))
assert_positive_infinite(BigDecimal("Infinity").sqrt(1))
assert_positive_infinite_calculation { BigDecimal::INFINITY.sqrt(1) }

# Out of float range
assert_equal(BigDecimal('12e1024'), BigDecimal('144e2048').sqrt(10))
assert_equal(BigDecimal('12e-1024'), BigDecimal('144e-2048').sqrt(10))

sqrt2_300 = BigDecimal(2).sqrt(300)
(250..270).each do |prec|
sqrt_prec = prec + BigDecimal.double_fig - 1
assert_in_delta(sqrt2_300, BigDecimal(2).sqrt(prec), BigDecimal("1e#{-sqrt_prec}"))
end
end

def test_sqrt_5266
Expand All @@ -1420,6 +1428,20 @@ def test_sqrt_5266
x.sqrt(109).to_s(109).split(' ')[0])
end

def test_sqrt_minimum_precision
x = BigDecimal((2**200).to_s)
assert_equal(2**100, x.sqrt(1))

x = BigDecimal('1' * 60 + '.' + '1' * 40)
assert_in_delta(BigDecimal('3' * 30 + '.' + '3' * 70), x.sqrt(1), BigDecimal('1e-70'))

x = BigDecimal('1' * 40 + '.' + '1' * 60)
assert_in_delta(BigDecimal('3' * 20 + '.' + '3' * 80), x.sqrt(1), BigDecimal('1e-80'))

x = BigDecimal('0.' + '0' * 50 + '1' * 100)
assert_in_delta(BigDecimal('0.' + '0' * 25 + '3' * 100), x.sqrt(1), BigDecimal('1e-125'))
end

def test_fix
x = BigDecimal("1.1")
assert_equal(1, x.fix)
Expand Down
Loading