Skip to content

Commit e2912f0

Browse files
authored
Merge pull request #78 from JHay0112/autodiff-grad
Autodiff grad
2 parents 3a58c19 + c0004f6 commit e2912f0

File tree

4 files changed

+79
-9
lines changed

4 files changed

+79
-9
lines changed

jmath/autodiff.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
'''
22
Automatic Differentiation
3+
4+
Examples
5+
--------
6+
7+
>>> from jmath.autodiff import x, y
8+
>>> f = 6*x*y*2 + y + 9
9+
>>> f_prime = f.d(x)
10+
>>> f_prime(x = 2, y = 1)
11+
6
12+
13+
>>> from jmath.autodiff import x, y
14+
>>> f = x*y
15+
>>> grad_f = f.d(x, y)
16+
>>> grad_f(x = 3, y = 2)
17+
Vector(2, 3)
318
'''
419

520
# - Imports
@@ -10,11 +25,12 @@
1025
from functools import wraps
1126
from types import FunctionType
1227
from .uncertainties import Uncertainty
28+
from .linearalgebra import Vector
1329
from typing import Any, Union, Callable, Tuple
1430

1531
# - Typing
1632

17-
Supported = Union[int, float, Uncertainty, 'Function', 'Variable']
33+
Supported = Union[int, float, Uncertainty, 'Function', 'Variable', Vector]
1834
Numeric = Union[int, float, Uncertainty]
1935

2036
# - Classes
@@ -63,7 +79,10 @@ def __str__(self):
6379
# Standard function
6480
return f"{self.func.__name__}{str(params)[:-2]})"
6581

66-
def __call__(self, **kwargs):
82+
def __call__(self, *args, **kwargs):
83+
84+
if len(args) != 0:
85+
raise AttributeError("Inputs must be assigned to a variable e.g. f(x = 3) rather than f(3)!")
6786

6887
if not isinstance(self.func, Callable):
6988
return self.func
@@ -225,7 +244,7 @@ def register(self, *inputs: 'Function'):
225244
'''
226245
self.inputs = inputs
227246

228-
def differentiate(self, wrt: Union['Variable', str]) -> 'Function':
247+
def differentiate(self, *wrt: Union['Variable', str]) -> Union['Function', Vector]:
229248
'''
230249
Differentiates the function with respect to a variable.
231250
@@ -235,6 +254,15 @@ def differentiate(self, wrt: Union['Variable', str]) -> 'Function':
235254
wrt
236255
The variable to differentiate with respect to.
237256
'''
257+
if len(wrt) == 1:
258+
# Single differential case
259+
wrt = wrt[0]
260+
else:
261+
# Multiple case
262+
results = []
263+
for var in wrt:
264+
results.append(self.differentiate(var))
265+
return Vector(results)
238266
# The differentiated function
239267
func = 0
240268
# Move across inputs
@@ -248,8 +276,8 @@ def differentiate(self, wrt: Union['Variable', str]) -> 'Function':
248276
return func
249277

250278
@wraps(differentiate)
251-
def d(self, wrt: Union['Variable', str]) -> 'Function':
252-
return self.differentiate(wrt)
279+
def d(self, *wrt: Union['Variable', str]) -> 'Function':
280+
return self.differentiate(*wrt)
253281

254282
class Variable(Function):
255283
'''
@@ -272,9 +300,12 @@ def __str__(self):
272300

273301
return self.id
274302

275-
def __call__(self, input: Any) -> Any:
303+
def __call__(self, input: Any = None, **kwargs) -> Any:
276304

277-
return input
305+
if input is not None:
306+
return input
307+
else:
308+
return kwargs[self.id]
278309

279310
def differentiate(self, wrt: 'Variable') -> int:
280311

jmath/linearalgebra/vectors.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,17 @@ def inner(*args, **kwargs):
9393

9494
return inner
9595

96+
def __call__(self, *args, **kwargs) -> 'Vector':
97+
"""Call vector of functions."""
98+
results = []
99+
for component in self.components:
100+
# Only call callable functions
101+
if isinstance(component, Callable):
102+
results.append(component(*args, **kwargs))
103+
else:
104+
results.append(component)
105+
return Vector(results)
106+
96107
def __eq__(self, vector: "Vector") -> bool:
97108
"""Tests equality of vectors"""
98109
if isinstance(vector, Vector) or isinstance(vector, Point):
@@ -213,6 +224,34 @@ def angle_between(self, vector: "Vector") -> float:
213224
vector = vector.vector
214225

215226
return round(math.acos((self @ vector)/(self.magnitude() * vector.magnitude())), 5)
227+
228+
def differentiate(self, *wrt) -> 'Vector':
229+
'''
230+
Differentiate functions in vector with respect to a variable.
231+
232+
Parameters
233+
----------
234+
235+
wrt
236+
The variables to differentiate with respect to.
237+
'''
238+
results = []
239+
for component in self.components:
240+
# Check if is callable
241+
if isinstance(component, Callable):
242+
# Then differentiate it
243+
results.append(component.d(*wrt))
244+
else:
245+
# Non-differentiable
246+
# So just append a zero
247+
results.append(0)
248+
return Vector(results)
249+
250+
def d(self, *wrt) -> 'Vector':
251+
'''
252+
Differentiate short hand.
253+
'''
254+
return self.differentiate(*wrt)
216255

217256
class Point(Vector):
218257
"""

tests/test_unit_spaces.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_conversion_leak():
3131
space = UnitSpace()
3232
space.a = Unit("a")
3333
space.b = Unit("b")
34-
space.define_conversion(space.a, space.b, random_integer())
34+
space.define_conversion(space.a, space.b, random_integer(non_zero = True))
3535
assert space.a == space.a.convert_to(space.b).convert_to(space.a)
3636
assert universal.units == {}
3737

tests/test_units.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_conversion():
5050
"""Tests that unit conversion happens as expected."""
5151
a = Unit("a")
5252
b = Unit("b")
53-
coeffecient = random_integer()
53+
coeffecient = random_integer(non_zero = True)
5454
define_conversion(a, b, coeffecient)
5555

5656
assert a.convert_to(b).value == (a*coeffecient).value

0 commit comments

Comments
 (0)