
missing Modelfile when save_pretrained_gguf #6

@oteroantoniogom

Description

I am having trouble saving to GGUF so that I can use the model afterwards with Ollama.

if True: model.save_pretrained_gguf(f"{saved_name}_gguf", tokenizer, quantization_method = "q4_k_m")
That code runs and converts the model; however, it does not write a Modelfile into the output directory.
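
As a stopgap I write the Modelfile by hand. A minimal sketch of what I use, assuming the quantized file ends up as unsloth.Q4_K_M.gguf inside the output folder and the model expects a ChatML-style template (both are my assumptions, not anything Unsloth generates):

# Hand-written workaround: drop a minimal Modelfile next to the GGUF file.
# "unsloth.Q4_K_M.gguf" and the ChatML template below are assumptions based
# on my output folder and the Qwen2 chat format; adjust to your model.
from pathlib import Path

gguf_dir = Path(f"{saved_name}_gguf")
modelfile = """FROM ./unsloth.Q4_K_M.gguf

TEMPLATE \"\"\"{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
\"\"\"

PARAMETER stop "<|im_start|>"
PARAMETER stop "<|im_end|>"
"""
(gguf_dir / "Modelfile").write_text(modelfile)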

Some issues suggest following other Colab notebooks, but I am not sure how I would add apply_chat_template in this scenario.
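
For reference, this is the kind of call I mean; a minimal sketch using the standard transformers tokenizer API (the messages are placeholders):

# Render a conversation with the tokenizer's built-in chat template.
# The messages here are placeholders for illustration only.
messages = [
    {"role": "user", "content": "Hello, who are you?"},
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize = False,
    add_generation_prompt = True,
)
print(prompt)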

Thanks in advance.

import torch
if True:
  import os
  import warnings
  from collections import OrderedDict
  from typing import Any, Dict, List, Mapping, NamedTuple, Tuple, Union

  import numpy as np
  import pandas as pd
  import matplotlib.pyplot as plt
  from sklearn.model_selection import train_test_split

  import datasets
  from datasets import load_dataset
  from torch.nn import Module
  from transformers import (
      DataCollatorForLanguageModeling,
      Qwen2ForCausalLM,
      Qwen2Tokenizer,
      Trainer,
      TrainingArguments,
  )
  from trl import SFTTrainer
  from unsloth import FastLanguageModel, tokenizer_utils

  # No-op patch so Unsloth's fix_untrained_tokens does not touch the model
  # during loading/saving.
  def do_nothing(*args, **kwargs):
      pass
  tokenizer_utils.fix_untrained_tokens = do_nothing



  # PEFT monkey patch: skip the mismatched-key check entirely so adapter
  # weights with resized shapes are loaded as-is instead of being dropped.
  def _find_mismatched_keys(
      model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = True
  ) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
      return peft_model_state_dict, []

  # Monkey patch the original function
  import peft.utils.save_and_load
  peft.utils.save_and_load._find_mismatched_keys = _find_mismatched_keys

  class _IncompatibleKeys(NamedTuple):
      missing_keys: List[str]
      unexpected_keys: List[str]

  def patched_load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
      if not isinstance(state_dict, Mapping):
          raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

      missing_keys: List[str] = []
      unexpected_keys: List[str] = []
      error_msgs: List[str] = []

      # copy state_dict so _load_from_state_dict can modify it
      metadata = getattr(state_dict, "_metadata", None)
      state_dict = OrderedDict(state_dict)
      if metadata is not None:
          state_dict._metadata = metadata  # type: ignore[attr-defined]

      def load(module, local_state_dict, prefix=""):
          local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
          if assign:
              local_metadata["assign_to_params_buffers"] = assign
          module._load_from_state_dict(
              local_state_dict,
              prefix,
              local_metadata,
              True,
              missing_keys,
              unexpected_keys,
              error_msgs,
          )
          for name, child in module._modules.items():
              if child is not None:
                  child_prefix = prefix + name + "."
                  child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                  load(child, child_state_dict, child_prefix)

          incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
          for hook in module._load_state_dict_post_hooks.values():
              out = hook(module, incompatible_keys)
              assert out is None, (
                  "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                  "expected to return new values; if incompatible_keys need to be modified, "
                  "it should be done inplace."
              )

      def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
          for name, param in self._parameters.items():
              key = prefix + name
              if key in state_dict:
                  input_param = state_dict[key]
                  if param.shape != input_param.shape:
                      print(f"Shape mismatch for {key}, creating new tensor. Old shape: {param.shape}, New shape: {input_param.shape}")
                      # Create a new parameter with the shape from state_dict
                      new_param = torch.nn.Parameter(torch.empty_like(input_param), requires_grad=param.requires_grad)
                      new_param.data.copy_(input_param)
                      setattr(self, name, new_param)
                  else:
                      param.data.copy_(input_param)
              elif strict:
                  missing_keys.append(key)

          for name, buf in self._buffers.items():
              key = prefix + name
              if key in state_dict:
                  input_buf = state_dict[key]
                  if buf.shape != input_buf.shape:
                      print(f"Shape mismatch for buffer {key}, creating new tensor. Old shape: {buf.shape}, New shape: {input_buf.shape}")
                      # Create a new buffer with the shape from state_dict
                      new_buf = torch.empty_like(input_buf)
                      new_buf.copy_(input_buf)
                      setattr(self, name, new_buf)
                  else:
                      buf.copy_(input_buf)
              elif strict:
                  missing_keys.append(key)

      # Monkey patch Module._load_from_state_dict globally so every submodule
      # tolerates shape mismatches by re-creating the tensor with the new shape
      Module._load_from_state_dict = _load_from_state_dict

      load(self, state_dict)
      del load

      if strict:
          if len(unexpected_keys) > 0:
              error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in unexpected_keys)))
          if len(missing_keys) > 0:
              error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in missing_keys)))

      if len(error_msgs) > 0:
          raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs)))
      
      return _IncompatibleKeys(missing_keys, unexpected_keys)

  # Apply the monkey patch
  Module.load_state_dict = patched_load_state_dict

  # Reload the fine-tuned model previously saved under `saved_name`
  model, tokenizer = FastLanguageModel.from_pretrained(saved_name)

if True: model.save_pretrained_gguf(f"{saved_name}_gguf", tokenizer, quantization_method = "q4_k_m")
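
Once a Modelfile exists, I would then register the export with Ollama; a hypothetical follow-up, assuming the ollama CLI is on PATH and using my-finetune as a placeholder name:

# Hypothetical: register the exported GGUF with Ollama. Assumes
# {saved_name}_gguf/Modelfile exists (see the hand-written workaround above).
import subprocess
subprocess.run(
    ["ollama", "create", "my-finetune", "-f", f"{saved_name}_gguf/Modelfile"],
    check = True,
)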
