Commits (64)
ba51314  testing spec (Eyoel-gebre, Apr 9, 2025)
63cb8ff  testing spec (Eyoel-gebre, Apr 10, 2025)
24fe394  added test (Eyoel-gebre, Apr 10, 2025)
d258577  update (Eyoel-gebre, Apr 10, 2025)
607a649  fix (Eyoel-gebre, Apr 10, 2025)
cf840ad  loading fixes (Eyoel-gebre, Apr 11, 2025)
0412e6f  fix (Eyoel-gebre, Apr 14, 2025)
7dbd2d8  fixed (Eyoel-gebre, Apr 14, 2025)
a5ac7cd  fixed (Eyoel-gebre, Apr 14, 2025)
a667e1e  test (Eyoel-gebre, Apr 15, 2025)
9220303  test (Eyoel-gebre, Apr 15, 2025)
80e6b7b  test (Eyoel-gebre, Apr 15, 2025)
e501642  test (Eyoel-gebre, Apr 15, 2025)
e17f23f  test (Eyoel-gebre, Apr 15, 2025)
eb83092  test (Eyoel-gebre, Apr 15, 2025)
f786c6d  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
e7cfdad  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
a6a4c42  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
ff9a3af  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
b8d0c7a  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
566e35b  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
5230ee4  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
220ae07  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
5259c69  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
c016c41  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
bfd3386  test_axis_1 (Eyoel-gebre, Apr 17, 2025)
58d2b5a  logging (Eyoel-gebre, Apr 21, 2025)
2b0fbf1  logging (Eyoel-gebre, Apr 21, 2025)
4aa77d8  logging (Eyoel-gebre, Apr 21, 2025)
84f54d4  logging (Eyoel-gebre, Apr 21, 2025)
e08a06e  logging (Eyoel-gebre, Apr 21, 2025)
ab31afc  logging (Eyoel-gebre, Apr 21, 2025)
d6b924b  logging (Eyoel-gebre, Apr 21, 2025)
58c7547  logging (Eyoel-gebre, Apr 21, 2025)
94e6b34  logging (Eyoel-gebre, Apr 21, 2025)
97aeeeb  logging (Eyoel-gebre, Apr 21, 2025)
b113821  logging (Eyoel-gebre, Apr 21, 2025)
77837b9  logging (Eyoel-gebre, Apr 21, 2025)
98e8941  logging (Eyoel-gebre, Apr 21, 2025)
b6c8edc  logging (Eyoel-gebre, Apr 21, 2025)
c1c5c47  spec (Eyoel-gebre, Apr 21, 2025)
99ab963  remove-test (Eyoel-gebre, Apr 22, 2025)
367b6c6  remove-garbage (Eyoel-gebre, Apr 22, 2025)
d7dc107  remove-garbage (Eyoel-gebre, Apr 22, 2025)
93e4a70  updated model executor for lite-whisper (Eyoel-gebre, Apr 30, 2025)
a29eef3  debugging (Eyoel-gebre, Apr 30, 2025)
0b5fa40  debugging (Eyoel-gebre, Apr 30, 2025)
7266663  debugging (Eyoel-gebre, Apr 30, 2025)
dacb949  debugging (Eyoel-gebre, Apr 30, 2025)
ee0527c  debugging (Eyoel-gebre, Apr 30, 2025)
5cf90b9  debugging (Eyoel-gebre, Apr 30, 2025)
3098c95  naming fix (Eyoel-gebre, May 1, 2025)
5caaee5  . (Eyoel-gebre, May 2, 2025)
23df460  shapes (Eyoel-gebre, May 2, 2025)
6bdce85  shapes2 (Eyoel-gebre, May 2, 2025)
2821169  shape3 (Eyoel-gebre, May 2, 2025)
b4b35b3  preprocessor (Eyoel-gebre, May 3, 2025)
f2b7b22  small (Eyoel-gebre, May 3, 2025)
af19fb4  small (Eyoel-gebre, May 3, 2025)
84e346a  dims (Eyoel-gebre, May 4, 2025)
6ac5325  tweaks (Eyoel-gebre, May 4, 2025)
e6a83fc  error handling (Eyoel-gebre, May 4, 2025)
76c0bb4  minor (Eyoel-gebre, May 4, 2025)
132258e  fix: issues with lite-whisper models (kamahori, Aug 19, 2025)
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
 *.pyc
 /.vs
+.vscode
 /build
 
 CMake*.json
1 change: 1 addition & 0 deletions include/ctranslate2/layers/attention_layer.h
@@ -52,6 +52,7 @@ namespace ctranslate2 {
                        const bool multi_query = false);

     protected:
+      bool _is_low_rank;
       const bool _tensor_parallel;
       const dim_t _num_heads;
       const bool _self_attention;
3 changes: 3 additions & 0 deletions include/ctranslate2/layers/common.h
@@ -135,7 +135,9 @@ namespace ctranslate2 {
       void select_weights(const StorageView* index, const StorageView* extra_bias = nullptr);
     private:
       bool _packed_weight;
+      bool _is_low_rank;
       const StorageView& _weight;
+      const StorageView* _weight2;
       const StorageView* _bias;
       const StorageView* _qscale;
       const StorageView* _qzero;
@@ -148,6 +150,7 @@ namespace ctranslate2 {
       const models::QUANTIZATION_TYPE _quant_method;
       const bool _quantized_gemm;
       const ops::Gemm _gemm_op;
+      const ops::Gemm _gemm_op_low_rank;
       const ops::Quantize _quantize_op;
       const ops::Dequantize _dequantize_op;
       const ops::ActivationType* _activation_type;
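These members back the low-rank path in the Dense layer: `_weight2` holds the second factor and `_gemm_op_low_rank` runs the extra product. A NumPy sketch of the resulting forward pass, with assumed shapes (the real layer also handles bias fusion, quantization, and activation types):

```python
import numpy as np

def dense_forward(x, weight, weight2=None, bias=None):
    """Sketch of Dense::operator() with the low-rank branch added."""
    if weight2 is not None:
        # Low-rank: two thin GEMMs, (x @ W1) @ W2, instead of one square GEMM.
        y = (x @ weight) @ weight2
    else:
        y = x @ weight
    if bias is not None:
        y = y + bias
    return y

x = np.random.randn(8, 512).astype(np.float32)
w1 = np.random.randn(512, 64).astype(np.float32)   # rank-64 first factor
w2 = np.random.randn(64, 512).astype(np.float32)   # second factor (_weight2)
y = dense_forward(x, w1, w2)  # shape (8, 512)
```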
143 changes: 140 additions & 3 deletions python/ctranslate2/converters/transformers.py
@@ -119,7 +119,14 @@ def _load(self):
                 % (config_name, ", ".join(sorted(_MODEL_LOADERS.keys())))
             )

-        model_class = getattr(transformers, loader.architecture_name)
+        # If lite-whisper, use the corresponding OpenAI tokenizer.
+        if config.model_type == "lite-whisper":
+            base_name = self._model_name_or_path.split("/")[-1]  # e.g. "lite-whisper-large-v3"
+            base_name = base_name.replace("lite-", "")  # e.g. "whisper-large-v3"
+            tokenizer_path = f"openai/{base_name}"
+        else:
+            tokenizer_path = self._model_name_or_path
+
         tokenizer_class = transformers.AutoTokenizer

         kwargs = {
@@ -137,14 +144,18 @@ def _load(self):
         if self._trust_remote_code:
             kwargs["trust_remote_code"] = self._trust_remote_code

-        model = self.load_model(model_class, self._model_name_or_path, **kwargs)
+        if hasattr(transformers, loader.architecture_name):
+            model_class = getattr(transformers, loader.architecture_name)
+            model = self.load_model(model_class, self._model_name_or_path, **kwargs)
+        else:
+            model = transformers.AutoModel.from_pretrained(self._model_name_or_path, **kwargs)

         tokenizer_kwargs = {}
         if self._trust_remote_code:
             tokenizer_kwargs["trust_remote_code"] = self._trust_remote_code

         tokenizer = self.load_tokenizer(
-            tokenizer_class, self._model_name_or_path, **tokenizer_kwargs
+            tokenizer_class, tokenizer_path, **tokenizer_kwargs
         )

         spec = loader(model, tokenizer)
@@ -996,6 +1007,119 @@ def set_conv1d(self, spec, module):
         spec.weight = module.weight
         spec.bias = module.bias

+
+@register_loader("LiteWhisperConfig")
+class LiteWhisperLoader(WhisperLoader):
+    @property
+    def architecture_name(self):
+        return "LiteWhisperForConditionalGeneration"
+
+    def get_model_spec(self, model):
+        spec = whisper_spec.WhisperSpec(
+            model.config.encoder_layers,
+            model.config.encoder_attention_heads,
+            model.config.decoder_layers,
+            model.config.decoder_attention_heads,
+            low_rank=True,
+        )
+
+        self.set_encoder(spec.encoder, model.model.encoder)
+        self.set_decoder(spec.decoder, model.model.decoder)
+        self.set_linear(spec.decoder.projection, model.proj_out)
+
+        return spec
+
+    def set_config(self, config, model, tokenizer):
+        gen_config = getattr(model, "generation_config", None)
+
+        if gen_config is not None:
+            config.suppress_ids = gen_config.suppress_tokens
+            config.suppress_ids_begin = gen_config.begin_suppress_tokens
+            if hasattr(gen_config, "alignment_heads"):
+                config.alignment_heads = gen_config.alignment_heads
+            if hasattr(gen_config, "lang_to_id"):
+                config.lang_ids = sorted(gen_config.lang_to_id.values())
+        else:
+            config.suppress_ids = model.config.suppress_tokens
+            config.suppress_ids_begin = model.config.begin_suppress_tokens
+            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
+
+        if getattr(config, "lang_ids", None) is None:
+            config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)
+
+        if config.alignment_heads is None:
+            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
+        if config.alignment_heads is None:
+            # Use the last half layers for alignment by default.
+            num_layers = model.config.decoder_layers
+            num_heads = model.config.decoder_attention_heads
+            config.alignment_heads = list(
+                itertools.product(
+                    range(num_layers // 2, num_layers),
+                    range(num_heads),
+                )
+            )
+
+    def set_encoder(self, spec, encoder):
+        """Overrides the encoder mapping for LiteWhisper."""
+        self.set_conv1d(spec.conv1, encoder.conv1)
+        self.set_conv1d(spec.conv2, encoder.conv2)
+
+        self.set_common_layers(spec, encoder)
+
+        for layer_spec, layer in zip(spec.layer, encoder.layers):
+            self.set_low_rank_attention(
+                layer_spec.self_attention,
+                layer.self_attn,
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm,
+                layer.self_attn_layer_norm,
+            )
+
+            if hasattr(layer.fc1, "weight1"):
+                # Low-rank layer.
+                self.set_low_rank_linear(layer_spec.ffn.linear_0, layer.fc1)
+            else:
+                layer_spec.ffn.linear_0 = common_spec.LinearSpec()
+                self.set_linear(layer_spec.ffn.linear_0, layer.fc1)
+
+            if hasattr(layer.fc2, "weight1"):
+                # Low-rank layer.
+                self.set_low_rank_linear(layer_spec.ffn.linear_1, layer.fc2)
+            else:
+                layer_spec.ffn.linear_1 = common_spec.LinearSpec()
+                self.set_linear(layer_spec.ffn.linear_1, layer.fc2)
+
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)
+
+    def set_low_rank_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
+        if quant_type == common_spec.Quantization.CT2:
+            spec.low_rank_weight_1 = module.weight1.transpose(0, 1).contiguous()
+            spec.low_rank_weight_2 = module.weight2.transpose(0, 1).contiguous()
+        else:
+            spec.low_rank_weight_1 = module.qweight1.transpose(0, 1).contiguous()
+            spec.low_rank_weight_2 = module.qweight2.transpose(0, 1).contiguous()
+            spec.weight_scale = module.scales
+            spec.weight_zero = module.qzeros
+
+        if module.bias is not None:
+            spec.bias = module.bias
+
+    def set_low_rank_or_linear_router(self, spec, module, i):
+        if hasattr(module, "weight1"):
+            self.set_low_rank_linear(spec.linear[i], module)
+        else:
+            spec.linear[i] = common_spec.LinearSpec()
+            self.set_linear(spec.linear[i], module)
+
+    def set_low_rank_attention(self, spec, attention):
+        self.set_low_rank_or_linear_router(spec, attention.q_proj, 0)
+        self.set_low_rank_or_linear_router(spec, attention.k_proj, 1)
+        self.set_low_rank_or_linear_router(spec, attention.v_proj, 2)
+        self.set_low_rank_or_linear_router(spec, attention.out_proj, 3)
+
+
 @register_loader("Wav2Vec2Config")
 class Wav2Vec2Loader(BartLoader):
@@ -2908,6 +3032,7 @@ def main():
         (3, 4),
     ],
     "openai/whisper-tiny": [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)],
+    "efficient-speech/whisper-tiny": [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)],
     "openai/whisper-base.en": [(3, 3), (4, 7), (5, 1), (5, 5), (5, 7)],
     "openai/whisper-base": [
         (3, 1),
@@ -3021,4 +3146,16 @@ def main():
         (24, 1),
         (25, 6),
     ],
+    "efficient-speech/whisper-large-v3": [
+        (7, 0),
+        (10, 17),
+        (12, 18),
+        (13, 12),
+        (16, 1),
+        (17, 14),
+        (19, 11),
+        (21, 4),
+        (24, 1),
+        (25, 6),
+    ],
 }
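With the loader registered, converting one of these checkpoints should reduce to the usual converter call. A hedged sketch using the Python API (the checkpoint name is only an example of a lite-whisper repository; `trust_remote_code` is needed because `LiteWhisperForConditionalGeneration` ships as custom code):

```python
import ctranslate2

converter = ctranslate2.converters.TransformersConverter(
    "efficient-speech/lite-whisper-large-v3",  # example lite-whisper checkpoint
    trust_remote_code=True,  # the architecture is not bundled with transformers
)
converter.convert("lite-whisper-large-v3-ct2", force=True)
```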
7 changes: 4 additions & 3 deletions python/ctranslate2/specs/attention_spec.py
@@ -32,13 +32,14 @@ def __init__(
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
+        low_rank=False,
     ):
         self.queries_scale = model_spec.OPTIONAL

         self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
-        self.linear = [
-            common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
-        ]
+        linear_cls = common_spec.LinearLowRankSpec if low_rank else common_spec.LinearSpec
+        count = 4 if low_rank else (2 if self_attention else 3)
+        self.linear = [linear_cls() for _ in range(count)]

         if relative_position:
             self.relative_position_keys = None
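The changed count encodes a layout switch: dense self-attention stores a fused QKV projection plus the output projection (2 linears), cross-attention stores query, fused key/value, and output (3), while the low-rank variant cannot fuse projections that each carry their own factor pair, so Q, K, V, and output stay separate (4). The same rule as a standalone sketch:

```python
def attention_linear_count(self_attention: bool, low_rank: bool) -> int:
    # Low-rank layers cannot be fused into one QKV matrix because each
    # projection has its own factor pair, so q, k, v and out stay separate.
    if low_rank:
        return 4
    # Dense layout: fused QKV + out for self-attention; q, fused KV, out otherwise.
    return 2 if self_attention else 3

assert attention_linear_count(self_attention=True, low_rank=False) == 2
assert attention_linear_count(self_attention=False, low_rank=False) == 3
assert attention_linear_count(self_attention=True, low_rank=True) == 4
```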
12 changes: 12 additions & 0 deletions python/ctranslate2/specs/common_spec.py
@@ -64,3 +64,15 @@ def __init__(self):
         self.weight = None
         self.weight_scale = model_spec.OPTIONAL
         self.multiply_by_sqrt_depth = model_spec.OPTIONAL
+
+
+class LinearLowRankSpec(model_spec.LayerSpec):
+    def __init__(self):
+        self.low_rank_weight_1 = None
+        self.low_rank_weight_2 = None
+        self.weight_scale = model_spec.OPTIONAL
+        self.weight_zero = model_spec.OPTIONAL
+        self.bias = model_spec.OPTIONAL
+
+    def has_bias(self):
+        return not isinstance(self.bias, str)
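The spec stores the factor pair directly, so the size win can be read off the shapes: a square d x d projection factored at rank r drops from d^2 to 2dr parameters. A quick check with a Whisper-large-like width (d = 1280) and an assumed rank of 320:

```python
d, r = 1280, 320  # model width; the rank is an illustrative value
dense_params = d * d          # 1,638,400
low_rank_params = 2 * d * r   # 819,200
print(low_rank_params / dense_params)  # 0.5 -> half the weight storage
```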
11 changes: 7 additions & 4 deletions python/ctranslate2/specs/transformer_spec.py
@@ -253,6 +253,7 @@ def __init__(
         rms_norm=False,
         num_heads_kv=None,
         sliding_window=None,
+        low_rank=False,
     ):
         self.self_attention = attention_spec.MultiHeadAttentionSpec(
             self_attention=True,
@@ -261,8 +262,9 @@ def __init__(
             rms_norm=rms_norm,
             num_heads_kv=num_heads_kv,
             sliding_window=sliding_window,
+            low_rank=low_rank,
         )
-        self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
+        self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm, low_rank=low_rank)


 class TransformerDecoderLayerSpec(model_spec.LayerSpec):
@@ -340,10 +342,11 @@ def __init__(


 class FeedForwardSpec(model_spec.LayerSpec):
-    def __init__(self, glu=False, rms_norm=False):
+    def __init__(self, glu=False, rms_norm=False, low_rank=False):
         self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
-        self.linear_0 = common_spec.LinearSpec()
-        self.linear_1 = common_spec.LinearSpec()
+        linear_cls = common_spec.LinearLowRankSpec if low_rank else common_spec.LinearSpec
+        self.linear_0 = linear_cls()
+        self.linear_1 = linear_cls()
         if glu:
             self.linear_0_noact = common_spec.LinearSpec()
8 changes: 5 additions & 3 deletions python/ctranslate2/specs/whisper_spec.py
@@ -32,6 +32,7 @@ def __init__(
         num_encoder_heads,
         num_decoder_layers,
         num_decoder_heads,
+        low_rank=False,
     ):
         """Initializes the model specification.

@@ -40,9 +41,10 @@ def __init__(
           num_encoder_heads: The number of encoder attention heads.
           num_decoder_layers: The number of decoder layers.
           num_decoder_heads: The number of decoder attention heads.
+          low_rank: Whether the encoder uses low-rank (lite-whisper) weights.
         """
         super().__init__()
-        self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
+        self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads, low_rank=low_rank)
         self.decoder = transformer_spec.TransformerDecoderSpec(
             num_decoder_layers,
             num_decoder_heads,
@@ -66,12 +68,12 @@ def get_vocabulary_size(self):


 class WhisperEncoderSpec(model_spec.LayerSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, num_layers, num_heads, low_rank=False):
         self.num_heads = np.dtype("int16").type(num_heads)
         self.conv1 = common_spec.Conv1DSpec()
         self.conv2 = common_spec.Conv1DSpec()
         self.position_encodings = transformer_spec.PositionEncoderSpec()
         self.layer_norm = common_spec.LayerNormSpec()
         self.layer = [
-            transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
+            transformer_spec.TransformerEncoderLayerSpec(low_rank=low_rank) for _ in range(num_layers)
         ]
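A short sketch of what the new flag changes in practice, using whisper-tiny-like dimensions (4 layers and 6 heads are illustrative; real values come from the Hugging Face config):

```python
from ctranslate2.specs import whisper_spec

spec = whisper_spec.WhisperSpec(
    num_encoder_layers=4,
    num_encoder_heads=6,
    num_decoder_layers=4,
    num_decoder_heads=6,
    low_rank=True,
)

# Encoder layers now carry LinearLowRankSpec slots for attention and FFN;
# the decoder is unchanged and keeps dense LinearSpec weights.
layer = spec.encoder.layer[0]
print(type(layer.ffn.linear_0).__name__)  # LinearLowRankSpec
print(len(layer.self_attention.linear))   # 4 separate projections
```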
21 changes: 19 additions & 2 deletions src/layers/attention.cc
@@ -362,11 +362,19 @@ namespace ctranslate2 {

       _linear[0](*q, fused_proj);

+      if (_is_low_rank) {  // low-rank projections are stored separately
+        _linear[1](*q, keys_proj);
+        _linear[2](*q, values_proj);
+        queries_proj = std::move(fused_proj);
+      }
+
       dim_t beam_size = 1;

       bool prefilling = (_sliding_window > 0 && values_lengths);

       if (!_self_attention) {
+        if (_is_low_rank)
+          throw std::invalid_argument("lite-whisper does not use low-rank weights for cross-attention");
         queries_proj = std::move(fused_proj);

         if (cached_keys == nullptr || cached_keys->empty()) {
@@ -401,6 +409,8 @@ namespace ctranslate2 {
       } else {

         if (_num_heads_kv < _num_heads) {
+          if (_is_low_rank)
+            throw std::invalid_argument("lite-whisper does not use low-rank weights for multi-query or GQA");
           if (queries_padder)
             queries_padder->add_padding(fused_proj);

@@ -419,8 +429,15 @@ namespace ctranslate2 {
         }

       } else {
-        split_heads(fused_proj, 3 * _num_heads, queries_padder);
-        ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
+        if (!_is_low_rank) {
+          split_heads(fused_proj, 3 * _num_heads, queries_padder);
+          ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
+        }
+        else {
+          split_heads(queries_proj, _num_heads, queries_padder);
+          split_heads(keys_proj, _num_heads_kv, queries_padder);
+          split_heads(values_proj, _num_heads_kv, queries_padder);
+        }
       }

       if (_rotary_embeddings) {
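A NumPy sketch of the projection control flow this hunk adds, under simplified assumptions (no padding, no cache, no GQA): the dense path runs one fused QKV GEMM and splits the result, while the low-rank path keeps three separate projections and splits heads on each.

```python
import numpy as np

def self_attention_projections(x, linear, num_heads, is_low_rank):
    """Sketch of the Q/K/V paths in MultiHeadAttention::operator()."""
    def split_heads(t):
        time, depth = t.shape
        return t.reshape(time, num_heads, depth // num_heads).transpose(1, 0, 2)

    if is_low_rank:
        # Low-rank layout: three separate projections, one split_heads each.
        q, k, v = linear[0](x), linear[1](x), linear[2](x)
    else:
        # Dense layout: one fused QKV GEMM, then a single split into three.
        fused = linear[0](x)
        q, k, v = np.split(fused, 3, axis=-1)
    return split_heads(q), split_heads(k), split_heads(v)

d, h, t = 512, 8, 10
weights = [np.random.randn(d, d) for _ in range(3)]
linear = [lambda x, W=W: x @ W for W in weights]  # stand-ins for _linear[0..2]
q, k, v = self_attention_projections(np.random.randn(t, d), linear, h, is_low_rank=True)
print(q.shape)  # (8, 10, 64)
```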