diff --git a/.gitignore b/.gitignore
index 8c0f14f0..7269b452 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,3 +50,5 @@ checkpoints/
 !tests/sample_outputs/csv_attack_log.csv
 tests/test_command_line/attack_log.txt
 textattack/=22.3.0
+
+venv/
diff --git a/tests/test_metric_api.py b/tests/test_metric_api.py
index aaff527a..e833d4c5 100644
--- a/tests/test_metric_api.py
+++ b/tests/test_metric_api.py
@@ -1,5 +1,8 @@
+import pytest
+
+
 def test_perplexity():
-    from textattack.attack_results import SuccessfulAttackResult
+    from textattack.attack_results import FailedAttackResult, SuccessfulAttackResult
     from textattack.goal_function_results.classification_goal_function_result import (
         ClassificationGoalFunctionResult,
     )
@@ -15,7 +18,8 @@ def test_perplexity():
                 AttackedText(sample_text), None, None, None, None, None, None
             ),
             ClassificationGoalFunctionResult(
-                AttackedText(sample_atck_text), None, None, None, None, None, None
+                AttackedText(
+                    sample_atck_text), None, None, None, None, None, None
             ),
         )
     ]
@@ -23,6 +27,79 @@ def test_perplexity():
 
     assert int(ppl["avg_original_perplexity"]) == int(81.95)
 
+    results = [
+        FailedAttackResult(
+            ClassificationGoalFunctionResult(
+                AttackedText(sample_text), None, None, None, None, None, None
+            ),
+        )
+    ]
+
+    Perplexity(model_name="distilbert-base-uncased").calculate(results)
+
+    ppl = Perplexity(model_name="distilbert-base-uncased")
+    texts = [sample_text]
+    ppl.ppl_tokenizer.encode(" ".join(texts), add_special_tokens=True)
+
+    encoded = ppl.ppl_tokenizer.encode(" ".join([]), add_special_tokens=True)
+    assert len(encoded) > 0
+
+
+def test_perplexity_empty_results():
+    from textattack.metrics.quality_metrics import Perplexity
+
+    ppl = Perplexity()
+    with pytest.raises(ValueError):
+        ppl.calculate([])
+
+    ppl = Perplexity("gpt2")
+    with pytest.raises(ValueError):
+        ppl.calculate([])
+
+    ppl = Perplexity(model_name="distilbert-base-uncased")
+    ppl_values = ppl.calculate([])
+
+    assert "avg_original_perplexity" in ppl_values
+    assert "avg_attack_perplexity" in ppl_values
+
+
+def test_perplexity_no_model():
+    from textattack.attack_results import FailedAttackResult, SuccessfulAttackResult
+    from textattack.goal_function_results.classification_goal_function_result import (
+        ClassificationGoalFunctionResult,
+    )
+    from textattack.metrics.quality_metrics import Perplexity
+    from textattack.shared.attacked_text import AttackedText
+
+    sample_text = "hide new secretions from the parental units "
+    sample_atck_text = "Ehide enw secretions from the parental units "
+
+    results = [
+        SuccessfulAttackResult(
+            ClassificationGoalFunctionResult(
+                AttackedText(sample_text), None, None, None, None, None, None
+            ),
+            ClassificationGoalFunctionResult(
+                AttackedText(
+                    sample_atck_text), None, None, None, None, None, None
+            ),
+        )
+    ]
+
+    ppl = Perplexity()
+    ppl_values = ppl.calculate(results)
+
+    assert "avg_original_perplexity" in ppl_values
+    assert "avg_attack_perplexity" in ppl_values
+
+
+def test_perplexity_calc_ppl():
+    from textattack.metrics.quality_metrics import Perplexity
+
+    ppl = Perplexity("gpt2")
+    with pytest.raises(ValueError):
+        ppl.calc_ppl([])
+
 
 def test_use():
     import transformers
@@ -85,5 +162,19 @@ def test_metric_recipe():
 
     attacker = Attacker(attack, dataset, attack_args)
     results = attacker.attack_dataset()
-    adv_score = AdvancedAttackMetric(["meteor_score", "perplexity"]).calculate(results)
+    adv_score = AdvancedAttackMetric(
+        ["meteor_score", "perplexity"]).calculate(results)
     assert adv_score["avg_attack_meteor_score"] == 0.71
+
+
+def test_metric_ad_hoc():
+    from textattack.metrics.quality_metrics import Perplexity
+    from textattack.metrics.recipe import AdvancedAttackMetric
+
+    metrics = AdvancedAttackMetric()
+    metrics.add_metric("perplexity", Perplexity(
+        model_name="distilbert-base-uncased"))
+
+    metric_results = metrics.calculate([])
+
+    assert "perplexity" in metric_results
diff --git a/textattack/metrics/quality_metrics/perplexity.py b/textattack/metrics/quality_metrics/perplexity.py
index f1572591..0b8b4a3d 100644
--- a/textattack/metrics/quality_metrics/perplexity.py
+++ b/textattack/metrics/quality_metrics/perplexity.py
@@ -100,8 +100,11 @@ def calc_ppl(self, texts):
             input_ids = torch.tensor(
                 self.ppl_tokenizer.encode(text, add_special_tokens=True)
             ).unsqueeze(0)
+            if not (input_ids_size := input_ids.size(1)):
+                raise ValueError("No tokens recognized for input text")
+
             # Strided perplexity calculation from huggingface.co/transformers/perplexity.html
-            for i in range(0, input_ids.size(1), self.stride):
+            for i in range(0, input_ids_size, self.stride):
                 begin_loc = max(i + self.stride - self.max_length, 0)
                 end_loc = min(i + self.stride, input_ids.size(1))
                 trg_len = end_loc - i
diff --git a/textattack/metrics/recipe.py b/textattack/metrics/recipe.py
index 304bd538..24e0c633 100644
--- a/textattack/metrics/recipe.py
+++ b/textattack/metrics/recipe.py
@@ -17,19 +17,30 @@ class AdvancedAttackMetric(Metric):
 
     """Calculate a suite of advanced metrics to evaluate attackResults' quality."""
 
-    def __init__(self, choices=["use"]):
+    def __init__(self, choices: list[str] = ["use"]):
         self.achoices = choices
+        available_metrics = {
+            "use": USEMetric,
+            "perplexity": Perplexity,
+            "bert_score": BERTScoreMetric,
+            "meteor_score": MeteorMetric,
+            "sbert_score": SBERTMetric,
+        }
+        self.selected_metrics = {}
+        for choice in self.achoices:
+            if choice not in available_metrics:
+                raise KeyError(f"'{choice}' is not a valid metric name")
+            metric = available_metrics[choice]()
+            self.selected_metrics.update({choice: metric})
 
-    def calculate(self, results):
+    def add_metric(self, name: str, metric: Metric):
+        if not isinstance(metric, Metric):
+            raise ValueError(f"Object {metric} must be a subtype of Metric")
+        self.selected_metrics.update({name: metric})
+
+    def calculate(self, results) -> dict[str, float]:
         advanced_metrics = {}
-        if "use" in self.achoices:
-            advanced_metrics.update(USEMetric().calculate(results))
-        if "perplexity" in self.achoices:
-            advanced_metrics.update(Perplexity().calculate(results))
-        if "bert_score" in self.achoices:
-            advanced_metrics.update(BERTScoreMetric().calculate(results))
-        if "meteor_score" in self.achoices:
-            advanced_metrics.update(MeteorMetric().calculate(results))
-        if "sbert_score" in self.achoices:
-            advanced_metrics.update(SBERTMetric().calculate(results))
+        # TODO: Would like to guarantee unique keys from calls to calculate()
+        for metric in self.selected_metrics.values():
+            advanced_metrics.update(metric.calculate(results))
         return advanced_metrics
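
Usage note (not part of the patch): a minimal sketch of how the reworked AdvancedAttackMetric API composes, mirroring test_metric_recipe and test_metric_ad_hoc above. It assumes `results` is a list of AttackResult objects, e.g. as returned by Attacker.attack_dataset().

from textattack.metrics.quality_metrics import Perplexity
from textattack.metrics.recipe import AdvancedAttackMetric

# Metrics named in `choices` are instantiated once in __init__;
# unknown names now raise KeyError instead of being silently ignored.
metrics = AdvancedAttackMetric(choices=["meteor_score"])

# add_metric() registers an extra, pre-configured Metric instance after construction.
metrics.add_metric("perplexity", Perplexity(model_name="distilbert-base-uncased"))

# calculate() merges each selected metric's result dict into one flat dict,
# e.g. {"avg_attack_meteor_score": ..., "avg_original_perplexity": ..., ...}.
scores = metrics.calculate(results)  # `results`: list of AttackResult objects (assumed)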