diff --git a/src/forwarddiff_dual.f90 b/src/forwarddiff_dual.f90 index 16dde93..d3d6925 100644 --- a/src/forwarddiff_dual.f90 +++ b/src/forwarddiff_dual.f90 @@ -61,11 +61,11 @@ module forwarddiff_dual end interface interface operator (*) - module procedure :: mul_dd + module procedure :: mult_dd module procedure :: mult_di - module procedure :: mul_dr + module procedure :: mult_dr module procedure :: mult_id - module procedure :: mul_rd + module procedure :: mult_rd end interface interface operator (/) @@ -382,12 +382,12 @@ elemental function minus_rd(r, v) result(res) !~~~ multiplication ~~~! - elemental function mul_dd(u, v) result(res) + elemental function mult_dd(u, v) result(res) type(dual), intent(in) :: u type(dual), intent(in) :: v type(dual) :: res res%val = u%val*v%val - res%der = u%val*v%der + u%der*v%val + res%der = u%val*v%der + v%val*u%der end function elemental function mult_di(u, i) result(res) @@ -400,7 +400,7 @@ elemental function mult_di(u, i) result(res) res%der = r * u%der end function - elemental function mul_dr(u, r) result(res) + elemental function mult_dr(u, r) result(res) type(dual), intent(in) :: u real(wp), intent(in) :: r type(dual) :: res @@ -418,7 +418,7 @@ elemental function mult_id(i, v) result(res) res%der = r * v%der end function - elemental function mul_rd(r, v) result(res) + elemental function mult_rd(r, v) result(res) real(wp), intent(in) :: r type(dual), intent(in) :: v type(dual) :: res @@ -488,7 +488,7 @@ elemental function pow_i(u, i) result(res) pow_x = u%val ** (i - 1) res%val = u%val * pow_x res%der = real(i,wp) * pow_x * u%der - end function pow_i + end function elemental function pow_dr(u, r) result(res) type(dual), intent(in) :: u @@ -520,7 +520,7 @@ elemental function eq_di(lhs, rhs) result(res) type(dual), intent(in) :: lhs integer, intent(in) :: rhs logical :: res - res = (lhs%val == real(rhs,wp)) + res = (lhs%val == rhs) end function eq_di elemental function eq_dr(lhs, rhs) result(res) @@ -757,9 +757,10 @@ elemental function acos_d(u) result(res) res%val = acos(u%val) if (u%val == 1.0_wp .or. u%val == -1.0_wp) then - res%der = ieee_value(1.0_wp, ieee_quiet_nan) + allocate(res%der(size(u%der))) + res%der = ieee_value(1.0_wp, ieee_quiet_nan) else - res%der = -u%der / sqrt(1.0_wp - u%val**2) + res%der = -u%der / sqrt(1.0_wp - u%val**2) end if end function @@ -770,6 +771,7 @@ elemental function asin_d(u) result(res) res%val = asin(u%val) if (u%val == 1.0_wp .or. u%val == -1.0_wp) then + allocate(res%der(size(u%der))) res%der = ieee_value(1.0_wp, ieee_quiet_nan) else res%der = u%der / sqrt(1.0_wp - u%val**2) @@ -1029,6 +1031,9 @@ function sum_d(u) result(res) integer :: i, j res%val = sum(u%val) + if (size(u) == 0) then + error stop "Can not sum a zero-length dual array." + endif allocate(res%der(size(u(1)%der))) res%der = 0.0_wp do j = 1,size(u)