Skip to content

Commit

Permalink
Small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscilloscope98 committed Jan 15, 2025
1 parent 3ddfffe commit ec344d8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ def pre_compute_inv_freq(module: torch.nn.Module):
if hasattr(module, "scaling_factor"):
module.inv_freq_scaled = module.inv_freq / module.scaling_factor
elif hasattr(module, "rope_type"):
if hasattr(module.rope_kwargs, "factor"):
if "factor" in module.rope_kwargs:
module.inv_freq_scaled = module.inv_freq / module.rope_kwargs["factor"]
elif module.config is not None:
module.inv_freq_scaled = module.inv_freq / module.config.rope_scaling["factor"]

elif module.__class__.__name__ == "LlamaRotaryEmbedding":
if hasattr(module, "rope_type") and module.rope_type == "linear":
module.register_buffer("inv_freq_scaled", None, persistent=False)
if hasattr(module.rope_kwargs, "factor"):
if "factor" in module.rope_kwargs:
module.inv_freq_scaled = module.inv_freq / module.rope_kwargs["factor"]
elif module.config is not None:
module.inv_freq_scaled = module.inv_freq / module.config.rope_scaling["factor"]
Expand Down

0 comments on commit ec344d8

Please sign in to comment.