Skip to content

Commit

Permalink
fix min and mix and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholaswogan committed Mar 28, 2024
1 parent be872c8 commit bb22e84
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 61 deletions.
32 changes: 32 additions & 0 deletions src/forwarddif_derivative.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@ module forwarddif_derivative
private

public :: derivative, derivative_sig
public :: grad, grad_sig

abstract interface
function derivative_sig(x) result(res)
import :: dual
type(dual), intent(in) :: x
type(dual) :: res
end function

function grad_sig(x) result(res)
import :: dual
type(dual), intent(in) :: x(:)
type(dual) :: res
end function
end interface

contains
Expand All @@ -27,4 +34,29 @@ subroutine derivative(fcn, x, fcn_x, dfcn_x)
dfcn_x = f%der(1)
end subroutine

subroutine grad(fcn, x, fcn_x, dfcn_x)
procedure(grad_sig) :: fcn
real(wp), intent(in) :: x(:)
real(wp), intent(out) :: fcn_x
real(wp), intent(out) :: dfcn_x(:)

type(dual) :: xx(size(x))
type(dual) :: f
integer :: i

xx%val = x
do i = 1,size(x)
allocate(xx(i)%der(size(x)))
xx(i)%der = 0.0_wp
xx(i)%der(i) = 1.0_wp
enddo

f = fcn(xx)
fcn_x = f%val
dfcn_x = f%der(:)
end subroutine




end module
67 changes: 22 additions & 45 deletions src/forwarddif_dual.f90
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ module forwarddif_dual

public :: abs, acos, asin, atan, atan2
public :: cos, exp, int, log, log10
public :: max, maxval, min, minval, nint
public :: sign, sin, tan, sqrt, sum, maxloc
public :: max, maxval, min, minval
public :: sin, tan, sqrt, sum, maxloc

type :: dual
real(wp) :: val
Expand Down Expand Up @@ -173,25 +173,17 @@ module forwarddif_dual
end interface

interface maxval
module procedure :: maxval_d
module procedure :: maxval_d
end interface

interface min
module procedure :: min_dd
module procedure :: min_dr
module procedure :: min_dd
module procedure :: min_dr
module procedure :: min_rd
end interface

interface minval
module procedure :: minval_d
end interface

interface nint
module procedure :: nint_d
end interface

interface sign
module procedure :: sign_dd
module procedure :: sign_rd
module procedure :: minval_d
end interface

interface sin
Expand Down Expand Up @@ -850,6 +842,7 @@ elemental function max_di(u, i) result(res)
if (u%val > i) then
res = u
else
allocate(res%der(size(u%der)))
res = i
endif

Expand All @@ -863,6 +856,7 @@ elemental function max_dr(u, r) result(res)
if (u%val > r) then
res = u
else
allocate(res%der(size(u%der)))
res = r
endif

Expand All @@ -876,6 +870,7 @@ elemental function max_rd(n, u) result(res)
if (u%val > n) then
res = u
else
allocate(res%der(size(u%der)))
res = n
endif

Expand Down Expand Up @@ -918,51 +913,33 @@ elemental function min_dr(u, r) result(res)
if (u%val < r) then
res = u
else
allocate(res%der(size(u%der)))
res = r
endif

end function

function minval_d(u) result(res)
type(dual), intent(in) :: u(:)
type(dual) :: res
integer :: i

i = minloc(u%val, 1)
res = u(i)

end function

elemental function nint_d(u) result(res)
elemental function min_rd(n, u) result(res)
real(wp), intent(in) :: n
type(dual), intent(in) :: u
integer :: res

res = nint(u%val)

end function

elemental function sign_dd(val1, val2) result(res)
type(dual), intent(in) :: val1, val2
type(dual) :: res

if (val2%val < 0.0_wp) then
res = -abs(val1)
if (u%val < n) then
res = u
else
res = abs(val1)
allocate(res%der(size(u%der)))
res = n
endif

end function

elemental function sign_rd(val1, val2) result(res)
real(wp), intent(in) :: val1
type(dual), intent(in) :: val2
function minval_d(u) result(res)
type(dual), intent(in) :: u(:)
type(dual) :: res
integer :: i

if (val2%val < 0.0_wp) then
res = -abs(val1)
else
res = abs(val1)
endif
i = minloc(u%val, 1)
res = u(i)

end function

Expand Down
71 changes: 55 additions & 16 deletions test/test_forwarddif.f90
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
program test_forwarddif
use forwarddif, only: wp, derivative
use forwarddif, only: wp, derivative, grad
implicit none
call test()

contains

subroutine test()
real(wp) :: fcn_x, dfcn_x
real(wp) :: t(2)

call cpu_time(t(1))
call derivative(func_operators, 10.0_wp, fcn_x, dfcn_x)
call cpu_time(t(2))
print*,fcn_x, dfcn_x, (t(2) - t(1))

call cpu_time(t(1))
call derivative(func_intrinsics, 10.0_wp, fcn_x, dfcn_x)
call cpu_time(t(2))
print*,fcn_x, dfcn_x, (t(2) - t(1))

real(wp) :: x, f, dfdx
real(wp) :: xx(2), dfdx1(2)

x = 10.0_wp
call derivative(func_operators, x, f, dfdx)
print*,f, dfdx

x = 10.0_wp
call derivative(func_intrinsics1, x, f, dfdx)
print*,f, dfdx

x = 0.1_wp
call derivative(func_intrinsics2, x, f, dfdx)
print*,f, dfdx

xx = [1.0_wp, 2.0_wp]
call grad(func_grad1, xx, f, dfdx1)
print*,f, dfdx1

xx = [1.0_wp, 2.0_wp]
call grad(func_grad2, xx, f, dfdx1)
print*,f, dfdx1

end subroutine

function func_operators(x) result(res)
Expand All @@ -34,7 +44,7 @@ function func_operators(x) result(res)

end function

function func_intrinsics(x) result(res)
function func_intrinsics1(x) result(res)
use forwarddif_dual
type(dual), intent(in) :: x
type(dual) :: res
Expand All @@ -50,6 +60,35 @@ function func_intrinsics(x) result(res)

end function


function func_intrinsics2(x) result(res)
use forwarddif_dual
type(dual), intent(in) :: x
type(dual) :: res

res = acos(x)
res = res + asin(x)
res = res + atan(x)
res = max(res, x)
res = max(res, 1.0_wp)
res = max(1.0_wp, res)
res = min(res, res)
res = min(res, 2.0_wp)
res = min(2.0_wp, res)

end function

function func_grad1(x) result(res)
use forwarddif_dual
type(dual), intent(in) :: x(:)
type(dual) :: res
res = x(1)*x(1)*x(2) + x(1) + x(2)
end function

function func_grad2(x) result(res)
use forwarddif_dual
type(dual), intent(in) :: x(:)
type(dual) :: res
res = sum(x*3.14_wp)
end function

end program

0 comments on commit bb22e84

Please sign in to comment.