From 3fef7e0b0c114b1a0c48cd2c995f1ac0ff71b8a3 Mon Sep 17 00:00:00 2001
From: Mikael Mieskolainen
Date: Sat, 21 Dec 2024 02:34:36 +0000
Subject: [PATCH] Hutchinson estimator

---
 icenet/deep/autogradxgb.py | 176 +++++++++++++++++++++++--------------
 icenet/deep/iceboost.py    |  36 +++++---
 2 files changed, 133 insertions(+), 79 deletions(-)

diff --git a/icenet/deep/autogradxgb.py b/icenet/deep/autogradxgb.py
index 8c419be4..9a7d1411 100644
--- a/icenet/deep/autogradxgb.py
+++ b/icenet/deep/autogradxgb.py
@@ -15,42 +15,67 @@

 # ------------------------------------------
 class XgboostObjective():
-    def __init__(self, loss_func: Callable[[Tensor, Tensor], Tensor], mode='train',
-                 flatten_grad=False, hessian_mode='constant', hessian_const=1.0,
-                 smoothing = 0.1, device='cpu'):
-
-        self.mode          = mode
-        self.loss_func     = loss_func
-        self.device        = device
-        self.hessian_mode  = hessian_mode
-        self.hessian_const = hessian_const
-        self.flatten_grad  = flatten_grad
-
-        # For the iterative Hessian estimation algorithm
+    """
+    Args:
+        loss_func:       Loss function handle
+        mode:            'train' or 'eval'
+        flatten_grad:    For vector-valued model output [experimental]
+        hessian_mode:    'constant', 'squared_approx', 'iterative', 'hutchinson', 'exact'
+        hessian_const:   Scalar value used when hessian_mode is 'constant'
+        hessian_gamma:   Hessian momentum smoothing parameter for the 'iterative' mode
+        hessian_slices:  Number of MC samples (slices) for the 'hutchinson' Hessian diagonal estimator
+        device:          Torch device
+    """
+    def __init__(self,
+                 loss_func: Callable[[Tensor, Tensor], Tensor],
+                 mode: str='train',
+                 flatten_grad: bool=False,
+                 hessian_mode: str='constant',
+                 hessian_const: float=1.0,
+                 hessian_gamma: float=0.9,
+                 hessian_slices: int=10,
+                 device: torch.device='cpu'
+        ):
+
+        self.mode            = mode
+        self.loss_func       = loss_func
+        self.device          = device
+        self.hessian_mode    = hessian_mode
+        self.hessian_const   = hessian_const
+        self.hessian_gamma   = hessian_gamma
+        self.hessian_slices  = hessian_slices
+        self.flatten_grad    = flatten_grad
+
+        # For the optimization algorithms
         self.hess_diag  = None
         self.grad_prev  = None
         self.preds_prev = None
-        self.smoothing  = smoothing
+
+        txt = f'Using device: {self.device} | hessian_mode = {self.hessian_mode}'
+
+        match self.hessian_mode:
+            case 'constant':
+                print(f'{txt} | hessian_const = {self.hessian_const}')
+            case 'iterative':
+                print(f'{txt} | hessian_gamma = {self.hessian_gamma}')
+            case 'hutchinson':
+                print(f'{txt} | hessian_slices = {self.hessian_slices}')
+            case _:
+                print(f'{txt}')

-        if self.hessian_mode == 'constant':
-            print(f'Using device: {self.device} | hessian_mode = {self.hessian_mode} | hessian_const = {self.hessian_const}')
-        elif self.hessian_mode == 'iterative':
-            print(f'Using device: {self.device} | hessian_mode = {self.hessian_mode} | smoothing = {self.smoothing}')
-        else:
-            print(f'Using device: {self.device} | hessian_mode = {self.hessian_mode}')
-
     def __call__(self, preds: np.ndarray, targets: xgboost.DMatrix):

         preds_, targets_, weights_ = self.torch_conversion(preds=preds, targets=targets)

-        if self.mode == 'train':
-            loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
-            return self.derivatives(loss=loss, preds=preds_)
-        elif self.mode == 'eval':
-            loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
-            return 'custom', loss.detach().cpu().numpy()
-        else:
-            raise Exception('Unknown mode (set either "train" or "eval")')
+        match self.mode:
+            case 'train':
+                loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
+                return self.derivatives(loss=loss, preds=preds_)
+            case 'eval':
+                loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
+                return 'custom', loss.detach().cpu().numpy()
+            case _:
+                raise Exception('Unknown mode (set either "train" or "eval")')

     def torch_conversion(self, preds: np.ndarray, targets: xgboost.DMatrix):
         """
@@ -67,7 +92,8 @@ def torch_conversion(self, preds: np.ndarray, targets: xgboost.DMatrix):

         return preds, targets, weights

-    def iterative_hessian_update(self, grad: Tensor, preds: Tensor, absMax=10, EPS=1e-8):
+    def iterative_hessian_update(self,
+            grad: Tensor, preds: Tensor, absMax: float=10, EPS: float=1e-8):
         """
         Iterative Hessian (diagonal) approximation update using finite differences

@@ -78,7 +104,7 @@ def iterative_hessian_update(self, grad: Tensor, preds: Tensor, absMax=10, EPS=1
             preds: Current prediction vector
         """

-        if self.hess_diag == None:
+        if self.hess_diag is None:
             self.hess_diag = torch.ones_like(grad)

         hess_diag_new = torch.ones_like(grad)
@@ -91,19 +117,22 @@ def iterative_hessian_update(self, grad: Tensor, preds: Tensor, absMax=10, EPS=1
         hess_diag_new = torch.clamp(hess_diag_new, min=-absMax, max=absMax)

         # Running smoothing update to stabilize
-        self.hess_diag = (1 - self.smoothing) * self.hess_diag + self.smoothing * hess_diag_new
+        self.hess_diag = self.hessian_gamma * self.hess_diag + (1-self.hessian_gamma) * hess_diag_new

         # Save the gradient vector and predictions
         self.grad_prev  = grad.clone()
         self.preds_prev = preds.clone()

-    def derivatives(self, loss: Tensor, preds: Tensor):
+    def derivatives(self, loss: Tensor, preds: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Gradient and Hessian diagonal

         Args:
             loss:  loss function values
             preds: model predictions
+
+        Returns:
+            gradient vector, hessian diagonal vector
         """

         ## Gradient
@@ -111,44 +140,61 @@ def derivatives(self, loss: Tensor, preds: Tensor):

         ## Diagonal elements of the Hessian matrix

-        # Constant curvature
-        if self.hessian_mode == 'constant':
-            grad2 = self.hessian_const * torch.ones_like(grad1)
-
-        # Squared derivative based [uncontrolled] approximation (always positive curvature)
-        elif self.hessian_mode == 'squared_approx':
-            grad2 = grad1 * grad1
-
-        # BFGS style iterative updates
-        elif self.hessian_mode == 'iterative':
-
-            self.iterative_hessian_update(grad=grad1, preds=preds)
-            grad2 = self.hess_diag
-
-        # Exact autograd
-        elif self.hessian_mode == 'exact':
-
-            print('Computing Hessian diagonal with exact autograd ...')
+        match self.hessian_mode:

-            """
-            for i in tqdm(range(len(preds))):
-                grad2_i  = torch.autograd.grad(grad1[i], preds, retain_graph=True)[0]
-                grad2[i] = grad2_i[i]
-            """
+            # Constant curvature
+            case 'constant':
+                grad2 = self.hessian_const * torch.ones_like(grad1)
+
+            # Squared derivative based [uncontrolled] approximation (always positive curvature)
+            case 'squared_approx':
+                grad2 = grad1 * grad1
+
+            # BFGS style iterative updates
+            case 'iterative':
+                self.iterative_hessian_update(grad=grad1, preds=preds)
+                grad2 = self.hess_diag

-            grad2 = torch.zeros_like(preds)
-
-            for i in tqdm(range(len(preds))):
+            # Hutchinson MC approximator ~ O(slices)
+            case 'hutchinson':
+
+                print('Computing Hessian diagonal with approximate Hutchinson estimator ...')
+
+                grad2 = torch.zeros_like(preds)
+
+                for _ in tqdm(range(self.hessian_slices)):
+
+                    # Generate a Rademacher vector (each element +-1 with probability 0.5)
+                    v = torch.empty_like(preds).uniform_(-1, 1)
+                    v = torch.sign(v)
+
+                    # Compute Hessian-vector product H * v
+                    Hv = torch.autograd.grad(grad1, preds, grad_outputs=v, retain_graph=True)[0]
+
+                    # Accumulate element-wise product v * Hv to get the diagonal
+                    grad2 += v * Hv
+
+                # Average over all samples
+                grad2 /= self.hessian_slices
+
+            # Exact autograd (slow)
+            case 'exact':

-                # A basis vector
-                e_i = torch.zeros_like(preds)
-                e_i[i] = 1.0
+                print('Computing Hessian diagonal with exact autograd ...')

-                # Compute the Hessian-vector product H e_i
-                grad2[i] = torch.autograd.grad(grad1, preds, grad_outputs=e_i, retain_graph=True)[0][i]
+                grad2 = torch.zeros_like(preds)
+
+                for i in tqdm(range(len(preds))):
+
+                    # A basis vector
+                    e_i = torch.zeros_like(preds)
+                    e_i[i] = 1.0
+
+                    # Compute the Hessian-vector product H e_i
+                    grad2[i] = torch.autograd.grad(grad1, preds, grad_outputs=e_i, retain_graph=True)[0][i]

-        else:
-            raise Exception(f'Unknown "hessian_mode" {self.hessian_mode}')
+            case _:
+                raise Exception(f'Unknown "hessian_mode" {self.hessian_mode}')

         grad1, grad2 = grad1.detach().cpu().numpy(), grad2.detach().cpu().numpy()

diff --git a/icenet/deep/iceboost.py b/icenet/deep/iceboost.py
index 4b4c81c2..4a98eb21 100644
--- a/icenet/deep/iceboost.py
+++ b/icenet/deep/iceboost.py
@@ -596,27 +596,35 @@ def train_xgb(config={'params': {}}, data_trn=None, data_val=None, y_soft=None,
         ## Hessian treatment

         # Default values
-        smoothing       = 0.1
-        hessian_const   = 1.0
-        hessian_mode    = 'constant'
+        hessian_mode    = 'constant'
+        hessian_const   = 1.0
+        hessian_gamma   = 0.9
+        hessian_slices  = 10

-        # For example: 'hessian:constant:1.0', 'hessian:iterative:0.1' or 'hessian:exact'
+        # E.g. 'hessian:constant:1.0', 'hessian:iterative:0.9', 'hessian:hutchinson:10' or 'hessian:exact'
         if 'hessian' in strs:
             hessian_mode = strs[strs.index('hessian')+1]

-            # Pick parameters
-            if hessian_mode == 'constant':
-                hessian_const = float(strs[strs.index('hessian')+2])
+            # Pick additional parameters
+            try:
+                if hessian_mode == 'constant':
+                    hessian_const = float(strs[strs.index('hessian')+2])
+
+                elif hessian_mode == 'iterative':
+                    hessian_gamma = float(strs[strs.index('hessian')+2])

-            elif hessian_mode == 'iterative':
-                smoothing = float(strs[strs.index('hessian')+2])
+                elif hessian_mode == 'hutchinson':
+                    hessian_slices = int(strs[strs.index('hessian')+2])
+            except Exception:
+                print('Using default Hessian estimator parameters')

         autogradObj = autogradxgb.XgboostObjective(
-            loss_func     = loss_func,
-            hessian_mode  = hessian_mode,
-            hessian_const = hessian_const,
-            smoothing     = smoothing,
-            device        = device
+            loss_func      = loss_func,
+            hessian_mode   = hessian_mode,
+            hessian_const  = hessian_const,
+            hessian_gamma  = hessian_gamma,
+            hessian_slices = hessian_slices,
+            device         = device
         )

     for epoch in range(0, num_epochs):
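
Note (not part of the patch): a minimal self-contained sketch of the Hutchinson diagonal estimator that the 'hutchinson' mode above uses, checked against the exact Hessian diagonal of a toy loss. The helper name hutchinson_hessian_diag and the toy loss are illustrative assumptions, not icenet API.

# Illustrative sketch only -- hutchinson_hessian_diag is a hypothetical helper, not part of icenet.
import torch

def hutchinson_hessian_diag(loss: torch.Tensor, x: torch.Tensor, num_slices: int = 100) -> torch.Tensor:
    # First derivative with create_graph=True so it can be differentiated again
    grad1 = torch.autograd.grad(loss, x, create_graph=True)[0]
    diag  = torch.zeros_like(x)
    for _ in range(num_slices):
        # Rademacher probe vector: each element is +1 or -1 with probability 0.5
        v = torch.sign(torch.empty_like(x).uniform_(-1, 1))
        # Hessian-vector product H v via a second autograd pass
        Hv = torch.autograd.grad(grad1, x, grad_outputs=v, retain_graph=True)[0]
        # E[v * (H v)] equals diag(H) for Rademacher v, so accumulate v * Hv
        diag += v * Hv
    return diag / num_slices

x    = torch.randn(5, requires_grad=True)
loss = (x ** 4).sum()                      # exact Hessian diagonal is 12 * x**2
print(hutchinson_hessian_diag(loss, x, num_slices=500))
print(12 * x.detach() ** 2)

As in the patch, each slice costs one Hessian-vector product and more slices reduce the variance of the estimate, whereas the 'exact' mode needs one Hessian-vector product per element of preds.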