
Commit

Change FineTuneGuardClassifier to InstructionDataGuardClassifier (NVIDIA#402)

* change name

Signed-off-by: Sarah Yurick <[email protected]>

* run black

Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: Sarah Yurick <[email protected]>
sarahyurick authored Dec 2, 2024
1 parent 110cede commit d14ac42
Showing 2 changed files with 26 additions and 26 deletions.
4 changes: 2 additions & 2 deletions nemo_curator/classifiers/__init__.py
@@ -15,7 +15,7 @@
import os

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
-from .aegis import AegisClassifier, FineTuneGuardClassifier
+from .aegis import AegisClassifier, InstructionDataGuardClassifier
from .domain import DomainClassifier
from .fineweb_edu import FineWebEduClassifier
from .quality import QualityClassifier
@@ -24,6 +24,6 @@
"DomainClassifier",
"QualityClassifier",
"AegisClassifier",
"FineTuneGuardClassifier",
"InstructionDataGuardClassifier",
"FineWebEduClassifier",
]
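
With this rename, downstream code imports the classifier under its new name; a minimal sketch of the updated import, assuming no other changes to the package layout (the old FineTuneGuardClassifier symbol is no longer exported):

# The classifier is now re-exported under its new name.
from nemo_curator.classifiers import InstructionDataGuardClassifier

# The previous import no longer works, since the old symbol was removed from __all__:
# from nemo_curator.classifiers import FineTuneGuardClassifier
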
48 changes: 24 additions & 24 deletions nemo_curator/classifiers/aegis.py
@@ -45,8 +45,8 @@ class AegisConfig:
pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b"
dtype: torch.dtype = torch.bfloat16
max_length: int = 4096
-add_finetune_guard: bool = False
-finetune_guard_path: str = "nvidia/FineTune-Guard"
+add_instruction_data_guard: bool = False
+instruction_data_guard_path: str = "nvidia/instruction-data-guard"


ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace.
@@ -75,7 +75,7 @@ class AegisConfig:
]


-class FineTuneGuardNet(torch.nn.Module):
+class InstructionDataGuardNet(torch.nn.Module):
def __init__(self, input_dim, dropout=0.7):
super().__init__()
self.input_dim = input_dim
@@ -108,7 +108,7 @@ def __init__(
peft_model_name_or_path: str,
dtype: torch.dtype,
token: Optional[Union[str, bool]],
-add_finetune_guard: bool = False,
+add_instruction_data_guard: bool = False,
autocast: bool = False,
):
super().__init__()
@@ -117,13 +117,13 @@
)
self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path)
self.autocast = autocast
-self.add_finetune_guard = add_finetune_guard
-if self.add_finetune_guard:
-self.finetune_guard_net = FineTuneGuardNet(4096)
+self.add_instruction_data_guard = add_instruction_data_guard
+if self.add_instruction_data_guard:
+self.instruction_data_guard_net = InstructionDataGuardNet(4096)

@torch.no_grad()
def _forward(self, batch):
-if self.add_finetune_guard:
+if self.add_instruction_data_guard:
response = self.model.generate(
**batch,
max_new_tokens=1,
Expand All @@ -132,13 +132,13 @@ def _forward(self, batch):
return_dict_in_generate=True,
)
# Access the hidden state of the last non-generated token from the last layer
-finetune_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to(
-torch.float
-)
-finetune_guard_output_tensor = self.finetune_guard_net(
-finetune_guard_input_tensor
+instruction_data_guard_input_tensor = response.hidden_states[0][32][
+:, -1, :
+].to(torch.float)
+instruction_data_guard_output_tensor = self.instruction_data_guard_net(
+instruction_data_guard_input_tensor
).flatten()
-return finetune_guard_output_tensor
+return instruction_data_guard_output_tensor
else:
response = self.model.generate(
**batch,
@@ -177,16 +177,16 @@ def load_model(self, device: str = "cuda"):
peft_model_name_or_path=self.config.peft_model_name_or_path,
dtype=self.config.dtype,
token=self.config.token,
-add_finetune_guard=self.config.add_finetune_guard,
+add_instruction_data_guard=self.config.add_instruction_data_guard,
)
-if self.config.add_finetune_guard:
+if self.config.add_instruction_data_guard:
weights_path = hf_hub_download(
-repo_id=self.config.finetune_guard_path,
+repo_id=self.config.instruction_data_guard_path,
filename="model.safetensors",
)
state_dict = load_file(weights_path)
-model.finetune_guard_net.load_state_dict(state_dict)
-model.finetune_guard_net.eval()
+model.instruction_data_guard_net.load_state_dict(state_dict)
+model.instruction_data_guard_net.eval()

model = model.to(device)
model.eval()
@@ -375,9 +375,9 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
return DocumentDataset(ddf)


-class FineTuneGuardClassifier(DistributedDataClassifier):
+class InstructionDataGuardClassifier(DistributedDataClassifier):
"""
-FineTune-Guard is a classification model designed to detect LLM poisoning trigger attacks.
+Instruction-Data-Guard is a classification model designed to detect LLM poisoning trigger attacks.
These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors
that only activate when specific trigger phrases are used. For example, attackers might
train an LLM to generate malicious code or show biased responses, but only when certain
@@ -420,7 +420,7 @@ def __init__(
batch_size: int = 64,
text_field: str = "text",
pred_column: str = "is_poisoned",
prob_column: str = "finetune_guard_poisoning_score",
prob_column: str = "instruction_data_guard_poisoning_score",
max_chars: int = 6000,
autocast: bool = True,
device_type: str = "cuda",
@@ -452,7 +452,7 @@ def __init__(
config = AegisConfig(
peft_model_name_or_path=_aegis_variant,
token=token,
-add_finetune_guard=True,
+add_instruction_data_guard=True,
)

self.text_field = text_field
@@ -480,7 +480,7 @@ def __init__(
)

def _run_classifier(self, dataset: DocumentDataset):
print("Starting FineTune-Guard classifier inference", flush=True)
print("Starting Instruction-Data-Guard classifier inference", flush=True)
ddf = dataset.df
columns = ddf.columns.tolist()
tokenizer = op.Tokenizer(

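For reference, a minimal usage sketch of the renamed classifier, assuming the surrounding NeMo Curator API (DocumentDataset I/O and the DistributedDataClassifier call interface) is unchanged by this commit; the column defaults come from the diff above, while the file paths, token value, and read/write calls are illustrative assumptions:

from nemo_curator.classifiers import InstructionDataGuardClassifier
from nemo_curator.datasets import DocumentDataset  # assumed import path for the dataset wrapper

# Aegis-based backbone: requires Hugging Face access to meta-llama/LlamaGuard-7b.
classifier = InstructionDataGuardClassifier(
    token="hf_...",  # illustrative placeholder for a Hugging Face token
    pred_column="is_poisoned",  # default shown in the diff
    prob_column="instruction_data_guard_poisoning_score",  # renamed default shown in the diff
)

# Illustrative I/O; the classifier runs over a GPU-backed (cudf) Dask DataFrame.
dataset = DocumentDataset.read_json("input_data/*.jsonl", backend="cudf")
scored = classifier(dataset)  # dispatches to the _run_classifier method shown in the diff
scored.to_json("instruction_data_guard_scores/")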