Skip to content

Commit

Permalink
Adjust LayerNormANE bias to match torch.nn.LayerNorm equation and pin…
Browse files Browse the repository at this point in the history
… torch version
  • Loading branch information
atiorh authored and Atila Orhon committed Jul 30, 2022
1 parent 2050f58 commit 23c2259
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ane_transformers/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.1"
__version__ = "0.1.2"
20 changes: 20 additions & 0 deletions ane_transformers/huggingface/distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@
"coremltools does not support dict outputs. Please set return_dict=False"


# Note: torch.nn.LayerNorm and ane_transformers.reference.layer_norm.LayerNormANE
# apply scale and bias terms in opposite orders. In order to accurately restore a
# state_dict trained using the former into the the latter, we adjust the bias term
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs):
state_dict[prefix +
'bias'] = state_dict[prefix + 'bias'] / state_dict[prefix +
'weight']
return state_dict


class LayerNormANE(LayerNormANE):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._register_load_state_dict_pre_hook(
correct_for_bias_scale_order_inversion)


class Embeddings(modeling_distilbert.Embeddings):
""" Embeddings module optimized for Apple Neural Engine
"""
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch>=1.10.0
torch>=1.10.0,<=1.11.0
transformers>=4.18.0
coremltools>=5.2.0
yapf
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
long_description_content_type='text/markdown',
author='Apple Inc.',
install_requires=[
"torch>=1.10.0",
"torch>=1.10.0,<=1.11.0",
"coremltools>=5.2.0",
"transformers>=4.18.0",
"protobuf>=3.1.0,<=3.20.1",
Expand Down

0 comments on commit 23c2259

Please sign in to comment.