diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 399fe963..b8e3d2b9 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -24,9 +24,8 @@ import torch.nn.functional as F from crossfit import op from crossfit.backend.torch.hf.model import HFModel -from huggingface_hub import hf_hub_download +from huggingface_hub import PyTorchModelHubMixin from peft import PeftModel -from safetensors.torch import load_file from torch.nn import Dropout, Linear from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer @@ -75,7 +74,7 @@ class AegisConfig: ] -class InstructionDataGuardNet(torch.nn.Module): +class InstructionDataGuardNet(torch.nn.Module, PyTorchModelHubMixin): def __init__(self, input_dim, dropout=0.7): super().__init__() self.input_dim = input_dim @@ -180,12 +179,14 @@ def load_model(self, device: str = "cuda"): add_instruction_data_guard=self.config.add_instruction_data_guard, ) if self.config.add_instruction_data_guard: - weights_path = hf_hub_download( - repo_id=self.config.instruction_data_guard_path, - filename="model.safetensors", + model.instruction_data_guard_net = ( + model.instruction_data_guard_net.from_pretrained( + self.config.instruction_data_guard_path + ) + ) + model.instruction_data_guard_net = model.instruction_data_guard_net.to( + device ) - state_dict = load_file(weights_path) - model.instruction_data_guard_net.load_state_dict(state_dict) model.instruction_data_guard_net.eval() model = model.to(device)