
Commit

Hutchinson estimator
mieskolainen committed Dec 21, 2024
1 parent 7798483 commit 3fef7e0
Showing 2 changed files with 133 additions and 79 deletions.
176 changes: 111 additions & 65 deletions icenet/deep/autogradxgb.py
@@ -15,42 +15,67 @@
# ------------------------------------------

class XgboostObjective():
def __init__(self, loss_func: Callable[[Tensor, Tensor], Tensor], mode='train',
flatten_grad=False, hessian_mode='constant', hessian_const=1.0,
smoothing = 0.1, device='cpu'):

self.mode = mode
self.loss_func = loss_func
self.device = device
self.hessian_mode = hessian_mode
self.hessian_const = hessian_const
self.flatten_grad = flatten_grad

# For the iterative Hessian estimation algorithm
"""
Args:
loss_func: Loss function handle
mode: 'train' or 'eval'
flatten_grad: For vector-valued model output [experimental]
hessian_mode: 'constant', 'squared_approx', 'iterative', 'hutchinson', 'exact'
hessian_const: Scalar value used when hessian_mode is 'constant'
hessian_gamma: Hessian momentum smoothing parameter for 'iterative' mode
hessian_slices: Number of MC slice samples for the 'hutchinson' Hessian diagonal estimator
device: Torch device
"""
def __init__(self,
loss_func: Callable[[Tensor, Tensor], Tensor],
mode: str='train',
flatten_grad: bool=False,
hessian_mode: str='constant',
hessian_const: float=1.0,
hessian_gamma: float=0.9,
hessian_slices: int=10,
device: torch.device='cpu'
):

self.mode = mode
self.loss_func = loss_func
self.device = device
self.hessian_mode = hessian_mode
self.hessian_const = hessian_const
self.hessian_gamma = hessian_gamma
self.hessian_slices = hessian_slices
self.flatten_grad = flatten_grad

# For the optimization algorithms
self.hess_diag = None
self.grad_prev = None
self.preds_prev = None
self.smoothing = smoothing

txt = f'Using device: {self.device} | hessian_mode = {self.hessian_mode}'

match self.hessian_mode:
case 'constant':
print(f'{txt} | hessian_const = {self.hessian_const}')
case 'iterative':
print(f'{txt} | hessian_gamma = {self.hessian_gamma}')
case 'hutchinson':
print(f'{txt} | hessian_slices = {self.hessian_slices}')
case _:
print(f'{txt}')

if self.hessian_mode == 'constant':
print(f'Using device: {self.device} | hessian_mode = {self.hessian_mode} | hessian_const = {self.hessian_const}')
elif self.hessian_mode == 'iterative':
print(f'Using device: {self.device} | hessian_mode = {self.hessian_mode} | smoothing = {self.smoothing}')
else:
print(f'Using device: {self.device} | hessian_mode = {self.hessian_mode}')

def __call__(self, preds: np.ndarray, targets: xgboost.DMatrix):

preds_, targets_, weights_ = self.torch_conversion(preds=preds, targets=targets)

if self.mode == 'train':
loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
return self.derivatives(loss=loss, preds=preds_)
elif self.mode == 'eval':
loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
return 'custom', loss.detach().cpu().numpy()
else:
raise Exception('Unknown mode (set either "train" or "eval")')
match self.mode:
case 'train':
loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
return self.derivatives(loss=loss, preds=preds_)
case 'eval':
loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
return 'custom', loss.detach().cpu().numpy()
case _:
raise Exception('Unknown mode (set either "train" or "eval")')

def torch_conversion(self, preds: np.ndarray, targets: xgboost.DMatrix):
"""
@@ -67,7 +92,8 @@ def torch_conversion(self, preds: np.ndarray, targets: xgboost.DMatrix):

return preds, targets, weights

def iterative_hessian_update(self, grad: Tensor, preds: Tensor, absMax=10, EPS=1e-8):
def iterative_hessian_update(self,
grad: Tensor, preds: Tensor, absMax: float=10, EPS: float=1e-8):
"""
Iterative Hessian (diagonal) approximation update using finite differences
@@ -78,7 +104,7 @@ def iterative_hessian_update(self, grad: Tensor, preds: Tensor, absMax=10, EPS=1e-8):
preds: Current prediction vector
"""

if self.hess_diag == None:
if self.hess_diag is None:
self.hess_diag = torch.ones_like(grad)
hess_diag_new = torch.ones_like(grad)

@@ -91,64 +117,84 @@ def iterative_hessian_update(self, grad: Tensor, preds: Tensor, absMax=10, EPS=1e-8):
hess_diag_new = torch.clamp(hess_diag_new, min=-absMax, max=absMax)

# Running smoothing update to stabilize
self.hess_diag = (1 - self.smoothing) * self.hess_diag + self.smoothing * hess_diag_new
self.hess_diag = self.hessian_gamma * self.hess_diag + (1-self.hessian_gamma) * hess_diag_new

# Save the gradient vector and predictions
self.grad_prev = grad.clone()
self.preds_prev = preds.clone()

def derivatives(self, loss: Tensor, preds: Tensor):
def derivatives(self, loss: Tensor, preds: Tensor) -> Tuple[Tensor, Tensor]:
"""
Gradient and Hessian diagonal
Args:
loss: loss function values
preds: model predictions
Returns:
gradient vector, hessian diagonal vector
"""

## Gradient
grad1 = torch.autograd.grad(loss, preds, create_graph=True)[0]

## Diagonal elements of the Hessian matrix

# Constant curvature
if self.hessian_mode == 'constant':
grad2 = self.hessian_const * torch.ones_like(grad1)

# Squared derivative based [uncontrolled] approximation (always positive curvature)
elif self.hessian_mode == 'squared_approx':
grad2 = grad1 * grad1

# BFGS style iterative updates
elif self.hessian_mode == 'iterative':

self.iterative_hessian_update(grad=grad1, preds=preds)
grad2 = self.hess_diag

# Exact autograd
elif self.hessian_mode == 'exact':

print('Computing Hessian diagonal with exact autograd ...')
match self.hessian_mode:

"""
for i in tqdm(range(len(preds))):
grad2_i = torch.autograd.grad(grad1[i], preds, retain_graph=True)[0]
grad2[i] = grad2_i[i]
"""
# Constant curvature
case 'constant':
grad2 = self.hessian_const * torch.ones_like(grad1)

# Squared derivative based [uncontrolled] approximation (always positive curvature)
case 'squared_approx':
grad2 = grad1 * grad1

# BFGS style iterative updates
case 'iterative':
self.iterative_hessian_update(grad=grad1, preds=preds)
grad2 = self.hess_diag

grad2 = torch.zeros_like(preds)

for i in tqdm(range(len(preds))):
# Hutchinson MC approximator ~ O(slices)
case 'hutchinson':

print('Computing Hessian diagonal with approximate Hutchinson estimator ...')

grad2 = torch.zeros_like(preds)

for _ in tqdm(range(self.hessian_slices)):

# Generate a Rademacher vector (each element +-1 with probability 0.5)
v = torch.empty_like(preds).uniform_(-1, 1)
v = torch.sign(v)

# Compute Hessian-vector product H * v
Hv = torch.autograd.grad(grad1, preds, grad_outputs=v, retain_graph=True)[0]

# Accumulate element-wise product v * Hv to get the diagonal
grad2 += v * Hv

# Average over all samples
grad2 /= self.hessian_slices

# Exact autograd (slow)
case 'exact':

# A basis vector
e_i = torch.zeros_like(preds)
e_i[i] = 1.0
print('Computing Hessian diagonal with exact autograd ...')

# Compute the Hessian-vector product H e_i
grad2[i] = torch.autograd.grad(grad1, preds, grad_outputs=e_i, retain_graph=True)[0][i]
grad2 = torch.zeros_like(preds)

for i in tqdm(range(len(preds))):

# A basis vector
e_i = torch.zeros_like(preds)
e_i[i] = 1.0

# Compute the Hessian-vector product H e_i
grad2[i] = torch.autograd.grad(grad1, preds, grad_outputs=e_i, retain_graph=True)[0][i]

else:
raise Exception(f'Unknown "hessian_mode" {self.hessian_mode}')
case _:
raise Exception(f'Unknown "hessian_mode" {self.hessian_mode}')

grad1, grad2 = grad1.detach().cpu().numpy(), grad2.detach().cpu().numpy()

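For reference, below is a minimal standalone sketch of the Hutchinson diagonal estimator that the new 'hutchinson' branch above implements. The function name, toy loss, and sample sizes are illustrative and not part of the repository; only the probe-and-average idea (Rademacher vectors, Hessian-vector products via autograd) mirrors the diff.

import torch

def hutchinson_hessian_diag(loss, preds, num_slices=10):
    # First-order gradient, keeping the graph so Hessian-vector products are possible
    grad1 = torch.autograd.grad(loss, preds, create_graph=True)[0]
    diag = torch.zeros_like(preds)
    for _ in range(num_slices):
        # Rademacher probe: each entry is +1 or -1 with probability 0.5
        v = torch.sign(torch.empty_like(preds).uniform_(-1, 1))
        # Hessian-vector product H v via a second differentiation pass
        Hv = torch.autograd.grad(grad1, preds, grad_outputs=v, retain_graph=True)[0]
        # For Rademacher v, E[v * (H v)] equals the Hessian diagonal
        diag += v * Hv
    return diag / num_slices

# Toy check (invented): loss with a non-diagonal Hessian, exact diagonal is 2 + 12 * preds^2
preds = torch.randn(5, requires_grad=True)
loss = preds.sum() ** 2 + (preds ** 4).sum()
print(hutchinson_hessian_diag(loss, preds, num_slices=200))
print(2 + 12 * preds.detach() ** 2)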
36 changes: 22 additions & 14 deletions icenet/deep/iceboost.py
@@ -596,27 +596,35 @@ def train_xgb(config={'params': {}}, data_trn=None, data_val=None, y_soft=None,
## Hessian treatment

# Default values
smoothing = 0.1
hessian_const = 1.0
hessian_mode = 'constant'
hessian_mode = 'constant'
hessian_const = 1.0
hessian_gamma = 0.9
hessian_slices = 10

# For example: 'hessian:constant:1.0', 'hessian:iterative:0.1' or 'hessian:exact'
# E.g. 'hessian:constant:1.0', 'hessian:iterative:0.9', 'hessian:hutchinson:10' or 'hessian:exact'
if 'hessian' in strs:
hessian_mode = strs[strs.index('hessian')+1]

# Pick parameters
if hessian_mode == 'constant':
hessian_const = float(strs[strs.index('hessian')+2])
# Pick additional parameters
try:
if hessian_mode == 'constant':
hessian_const = float(strs[strs.index('hessian')+2])

elif hessian_mode == 'iterative':
hessian_gamma = float(strs[strs.index('hessian')+2])

elif hessian_mode == 'iterative':
smoothing = float(strs[strs.index('hessian')+2])
elif hessian_mode == 'hutchinson':
hessian_slices = int(strs[strs.index('hessian')+2])
except:
print('Using default Hessian estimator parameters')

autogradObj = autogradxgb.XgboostObjective(
loss_func = loss_func,
hessian_mode = hessian_mode,
hessian_const = hessian_const,
smoothing = smoothing,
device = device
loss_func = loss_func,
hessian_mode = hessian_mode,
hessian_const = hessian_const,
hessian_gamma = hessian_gamma,
hessian_slices = hessian_slices,
device = device
)

for epoch in range(0, num_epochs):
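As a usage illustration (not taken from the repository), the objective constructed above can be passed as the custom obj of xgboost.train. The loss function and toy data below are invented for the example; the (preds, targets, weights) keyword signature is assumed from the call made inside XgboostObjective.__call__, and only constructor arguments shown in this diff are used.

import numpy as np
import torch
import xgboost
from icenet.deep import autogradxgb

# Illustrative per-event loss; weights may be None depending on the DMatrix
def bce_loss(preds, targets, weights=None):
    per_event = torch.nn.functional.binary_cross_entropy_with_logits(
        preds, targets, reduction='none')
    if weights is not None:
        per_event = weights * per_event
    return per_event.sum()

# Toy binary-classification data (invented for the example)
X = np.random.randn(1000, 10)
y = np.random.randint(0, 2, size=1000).astype(np.float32)
dtrain = xgboost.DMatrix(X, label=y)

obj = autogradxgb.XgboostObjective(
    loss_func      = bce_loss,
    hessian_mode   = 'hutchinson',
    hessian_slices = 10,
    device         = 'cpu'
)

model = xgboost.train(
    params          = {'tree_method': 'hist', 'disable_default_eval_metric': 1},
    dtrain          = dtrain,
    num_boost_round = 20,
    obj             = obj
)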
