
Commit

Merge pull request #162 from linto-ai/bugfix/load_transformers_shards
fix #160: Load (transformers) model when split into different shards
Jeronymous authored Jan 15, 2024
2 parents bb99dba + 82150b2 commit c1cf345
Showing 2 changed files with 71 additions and 10 deletions.
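For context on the fix: when a saved transformers checkpoint exceeds max_shard_size, its weights are split across numbered shard files, and an index file maps every parameter name to the shard that stores it. A minimal sketch of reading such an index (parameter and shard file names are illustrative):

    import json

    index = json.load(open("pytorch_model.bin.index.json"))
    # index["weight_map"] maps parameter names to shard files, e.g.:
    #   {"model.encoder.conv1.weight": "pytorch_model-00001-of-00002.bin",
    #    "model.decoder.embed_tokens.weight": "pytorch_model-00002-of-00002.bin",
    #    ...}
    shards = sorted(set(index["weight_map"].values()))  # unique shard files to load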
20 changes: 20 additions & 0 deletions tests/test_transcribe.py
@@ -660,6 +660,26 @@ def test_hugging_face_model(self):
             device_specific=True,
         )
 
+        import tempfile
+        from transformers import WhisperForConditionalGeneration
+        tempfolder = os.path.join(tempfile.gettempdir(), "tmp_whisper-tiny-french-cased")
+
+        for safe_serialization in False, True,:
+            for max_shard_size in "100MB", "10GB",:
+                shutil.rmtree(tempfolder, ignore_errors=True)
+                model = WhisperForConditionalGeneration.from_pretrained("qanastek/whisper-tiny-french-cased")
+                try:
+                    model.save_pretrained(tempfolder, safe_serialization=safe_serialization, max_shard_size=max_shard_size)
+                    self._test_cli_(
+                        ["--model", tempfolder, "--verbose", "True"],
+                        "verbose", files=["bonjour.wav"], extensions=None,
+                        prefix="hf",
+                        device_specific=True,
+                    )
+                finally:
+                    shutil.rmtree(tempfolder)
+
+
 # "ZZZ" to run this test at last (because it will fill the CUDA with some memory)
 class TestZZZPythonImport(TestHelper):
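The two max_shard_size values should exercise both code paths: "10GB" keeps the tiny checkpoint in a single file, while "100MB" forces it into several shards. A minimal sketch of the folder layout a sharded save produces (folder and file names are illustrative and depend on the serialization format):

    import os
    import tempfile
    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("qanastek/whisper-tiny-french-cased")
    folder = os.path.join(tempfile.gettempdir(), "tmp_sharded_whisper")
    model.save_pretrained(folder, safe_serialization=False, max_shard_size="100MB")
    print(sorted(os.listdir(folder)))
    # Expect something like:
    # ['config.json', 'generation_config.json',
    #  'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin',
    #  'pytorch_model.bin.index.json']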
61 changes: 51 additions & 10 deletions whisper_timestamped/transcribe.py
@@ -3,7 +3,7 @@
 __author__ = "Jérôme Louradour"
 __credits__ = ["Jérôme Louradour"]
 __license__ = "GPLv3"
-__version__ = "1.14.3"
+__version__ = "1.14.4"
 
 # Set some environment variables
 import os
@@ -37,6 +37,7 @@
 import copy
 import re
 import shutil
+import json
 
 # Constant variables
 from whisper.utils import format_timestamp
@@ -2257,7 +2258,12 @@ def _get_alignment_heads(model_name, num_layers, num_heads):
     return alignment_heads
 
 def _get_number_of_parameters(model):
-    return sum(p.numel() for p in model.parameters())
+    num_parameters = 0
+    for name, p in model.named_parameters():
+        if name in ["decoder.proj_out.weight"]:
+            continue
+        num_parameters += p.numel()
+    return num_parameters
 
 from typing import Optional, Union
 def load_model(
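The decoder.proj_out.weight exclusion is presumably because transformers' Whisper ties the output projection to the decoder token embedding, a weight the original OpenAI checkpoints do not store separately; counting it would inflate the total. A minimal sketch of the tying (assuming a transformers install and the openai/whisper-tiny checkpoint):

    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    # Tied weights share the same storage, so this should print True:
    print(model.proj_out.weight.data_ptr()
          == model.model.decoder.embed_tokens.weight.data_ptr())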
@@ -2282,18 +2288,30 @@ def load_model(
             raise ImportError(f"If you are trying to download a HuggingFace model with {name}, please install first the transformers library")
         from transformers.utils import cached_file
 
+        kwargs = dict(cache_dir=download_root, use_auth_token=None, revision=None)
         try:
-            model_path = cached_file(name, "pytorch_model.bin", cache_dir=download_root, use_auth_token=None, revision=None)
-        except Exception as e:
+            model_path = cached_file(name, "pytorch_model.bin", **kwargs)
+        except OSError as err:
             try:
-                if isinstance(e, OSError):
-                    model_path = cached_file(name, "whisper.ckpt", cache_dir=download_root, use_auth_token=None, revision=None)
-                else:
-                    raise e
+                model_path = None
+                for candidate in ["whisper.ckpt", "pytorch_model.bin.index.json", "model.safetensors", "model.safetensors.index.json"]:
+                    try:
+                        model_path = cached_file(name, candidate, **kwargs)
+                    except OSError:
+                        continue
+                    if candidate.endswith("index.json"):
+                        index_file = model_path
+                        mapping = json.load(open(index_file))
+                        assert "weight_map" in mapping
+                        assert isinstance(mapping["weight_map"], dict)
+                        model_path = list(set(mapping["weight_map"].values()))
+                        folder = os.path.dirname(index_file)
+                        model_path = [os.path.join(folder, p) for p in model_path]
+                assert model_path is not None
             except:
-                raise RuntimeError(f"Original error: {e}\nCould not find model {name} from HuggingFace nor local folders.")
+                raise RuntimeError(f"Original error: {err}\nCould not find model {name} from HuggingFace nor local folders.")
         # Load HF Model
-        hf_state_dict = torch.load(model_path, map_location="cpu")
+        hf_state_dict = torch_load(model_path)
 
         # Rename layers
         for key in list(hf_state_dict.keys())[:]:
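The fallback loop above relies on cached_file raising an OSError when the requested file is absent from the repository; on success it returns a local path, resolving either a Hub repo id or a local folder. A minimal usage sketch (repo id is illustrative):

    from transformers.utils import cached_file

    # Returns the local path of the file, downloading it from the Hub if needed;
    # raises an OSError when the repository exists but the file does not.
    path = cached_file("qanastek/whisper-tiny-french-cased", "pytorch_model.bin",
                       cache_dir=None, use_auth_token=None, revision=None)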
@@ -2321,6 +2339,29 @@ def load_model(
     whisper_model = whisper_model.to(device)
     return whisper_model
 
+def torch_load(model_path):
+    if isinstance(model_path, list):
+        hf_state_dict = {}
+        for p in model_path:
+            d = torch_load(p)
+            for k in d:
+                assert k not in hf_state_dict, f"Found duplicate key {k} in {p}"
+            hf_state_dict.update(d)
+    else:
+        assert isinstance(model_path, str)
+        if model_path.endswith(".safetensors"):
+            from safetensors import safe_open
+            hf_state_dict = {}
+            with safe_open(model_path, framework="pt", device="cpu") as f:
+                for k in f.keys():
+                    hf_state_dict[k] = f.get_tensor(k)
+        else:
+            hf_state_dict = torch.load(model_path, map_location="cpu")
+    return hf_state_dict
+
+
+
+
 # Credit: https://github.com/openai/whisper/discussions/830
 def hf_to_whisper_states(text):
     # From Speechbrain
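Usage of the new helper: torch_load accepts either a single checkpoint path (.safetensors read via safetensors, anything else via torch.load) or a list of shard paths, which it merges into one state dict while asserting that no key appears in two shards. A sketch (paths are illustrative):

    state_dict = torch_load("model.safetensors")
    state_dict = torch_load(["pytorch_model-00001-of-00002.bin",
                             "pytorch_model-00002-of-00002.bin"])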
