diff --git a/src/forwarddif_derivative.f90 b/src/forwarddif_derivative.f90 index bb07722..d534dfa 100644 --- a/src/forwarddif_derivative.f90 +++ b/src/forwarddif_derivative.f90 @@ -5,6 +5,7 @@ module forwarddif_derivative private public :: derivative, derivative_sig + public :: grad, grad_sig abstract interface function derivative_sig(x) result(res) @@ -12,6 +13,12 @@ function derivative_sig(x) result(res) type(dual), intent(in) :: x type(dual) :: res end function + + function grad_sig(x) result(res) + import :: dual + type(dual), intent(in) :: x(:) + type(dual) :: res + end function end interface contains @@ -27,4 +34,29 @@ subroutine derivative(fcn, x, fcn_x, dfcn_x) dfcn_x = f%der(1) end subroutine + subroutine grad(fcn, x, fcn_x, dfcn_x) + procedure(grad_sig) :: fcn + real(wp), intent(in) :: x(:) + real(wp), intent(out) :: fcn_x + real(wp), intent(out) :: dfcn_x(:) + + type(dual) :: xx(size(x)) + type(dual) :: f + integer :: i + + xx%val = x + do i = 1,size(x) + allocate(xx(i)%der(size(x))) + xx(i)%der = 0.0_wp + xx(i)%der(i) = 1.0_wp + enddo + + f = fcn(xx) + fcn_x = f%val + dfcn_x = f%der(:) + end subroutine + + + + end module \ No newline at end of file diff --git a/src/forwarddif_dual.f90 b/src/forwarddif_dual.f90 index aa93d98..44429a4 100644 --- a/src/forwarddif_dual.f90 +++ b/src/forwarddif_dual.f90 @@ -16,8 +16,8 @@ module forwarddif_dual public :: abs, acos, asin, atan, atan2 public :: cos, exp, int, log, log10 - public :: max, maxval, min, minval, nint - public :: sign, sin, tan, sqrt, sum, maxloc + public :: max, maxval, min, minval + public :: sin, tan, sqrt, sum, maxloc type :: dual real(wp) :: val @@ -173,25 +173,17 @@ module forwarddif_dual end interface interface maxval - module procedure :: maxval_d + module procedure :: maxval_d end interface interface min - module procedure :: min_dd - module procedure :: min_dr + module procedure :: min_dd + module procedure :: min_dr + module procedure :: min_rd end interface interface minval - module procedure :: minval_d - end interface - - interface nint - module procedure :: nint_d - end interface - - interface sign - module procedure :: sign_dd - module procedure :: sign_rd + module procedure :: minval_d end interface interface sin @@ -850,6 +842,7 @@ elemental function max_di(u, i) result(res) if (u%val > i) then res = u else + allocate(res%der(size(u%der))) res = i endif @@ -863,6 +856,7 @@ elemental function max_dr(u, r) result(res) if (u%val > r) then res = u else + allocate(res%der(size(u%der))) res = r endif @@ -876,6 +870,7 @@ elemental function max_rd(n, u) result(res) if (u%val > n) then res = u else + allocate(res%der(size(u%der))) res = n endif @@ -918,51 +913,33 @@ elemental function min_dr(u, r) result(res) if (u%val < r) then res = u else + allocate(res%der(size(u%der))) res = r endif end function - function minval_d(u) result(res) - type(dual), intent(in) :: u(:) - type(dual) :: res - integer :: i - - i = minloc(u%val, 1) - res = u(i) - - end function - - elemental function nint_d(u) result(res) + elemental function min_rd(n, u) result(res) + real(wp), intent(in) :: n type(dual), intent(in) :: u - integer :: res - - res = nint(u%val) - - end function - - elemental function sign_dd(val1, val2) result(res) - type(dual), intent(in) :: val1, val2 type(dual) :: res - if (val2%val < 0.0_wp) then - res = -abs(val1) + if (u%val < n) then + res = u else - res = abs(val1) + allocate(res%der(size(u%der))) + res = n endif end function - elemental function sign_rd(val1, val2) result(res) - real(wp), intent(in) :: val1 - type(dual), intent(in) :: val2 + function minval_d(u) result(res) + type(dual), intent(in) :: u(:) type(dual) :: res + integer :: i - if (val2%val < 0.0_wp) then - res = -abs(val1) - else - res = abs(val1) - endif + i = minloc(u%val, 1) + res = u(i) end function diff --git a/test/test_forwarddif.f90 b/test/test_forwarddif.f90 index 5b93e0c..e27af4a 100644 --- a/test/test_forwarddif.f90 +++ b/test/test_forwarddif.f90 @@ -1,24 +1,34 @@ program test_forwarddif - use forwarddif, only: wp, derivative + use forwarddif, only: wp, derivative, grad implicit none call test() contains subroutine test() - real(wp) :: fcn_x, dfcn_x - real(wp) :: t(2) - - call cpu_time(t(1)) - call derivative(func_operators, 10.0_wp, fcn_x, dfcn_x) - call cpu_time(t(2)) - print*,fcn_x, dfcn_x, (t(2) - t(1)) - - call cpu_time(t(1)) - call derivative(func_intrinsics, 10.0_wp, fcn_x, dfcn_x) - call cpu_time(t(2)) - print*,fcn_x, dfcn_x, (t(2) - t(1)) - + real(wp) :: x, f, dfdx + real(wp) :: xx(2), dfdx1(2) + + x = 10.0_wp + call derivative(func_operators, x, f, dfdx) + print*,f, dfdx + + x = 10.0_wp + call derivative(func_intrinsics1, x, f, dfdx) + print*,f, dfdx + + x = 0.1_wp + call derivative(func_intrinsics2, x, f, dfdx) + print*,f, dfdx + + xx = [1.0_wp, 2.0_wp] + call grad(func_grad1, xx, f, dfdx1) + print*,f, dfdx1 + + xx = [1.0_wp, 2.0_wp] + call grad(func_grad2, xx, f, dfdx1) + print*,f, dfdx1 + end subroutine function func_operators(x) result(res) @@ -34,7 +44,7 @@ function func_operators(x) result(res) end function - function func_intrinsics(x) result(res) + function func_intrinsics1(x) result(res) use forwarddif_dual type(dual), intent(in) :: x type(dual) :: res @@ -50,6 +60,35 @@ function func_intrinsics(x) result(res) end function - + function func_intrinsics2(x) result(res) + use forwarddif_dual + type(dual), intent(in) :: x + type(dual) :: res + + res = acos(x) + res = res + asin(x) + res = res + atan(x) + res = max(res, x) + res = max(res, 1.0_wp) + res = max(1.0_wp, res) + res = min(res, res) + res = min(res, 2.0_wp) + res = min(2.0_wp, res) + + end function + + function func_grad1(x) result(res) + use forwarddif_dual + type(dual), intent(in) :: x(:) + type(dual) :: res + res = x(1)*x(1)*x(2) + x(1) + x(2) + end function + + function func_grad2(x) result(res) + use forwarddif_dual + type(dual), intent(in) :: x(:) + type(dual) :: res + res = sum(x*3.14_wp) + end function end program \ No newline at end of file