"""BLiMP (Benchmark of Linguistic Minimal Pairs) evaluation task.

Each BLiMP sample is a minimal pair: a grammatical sentence
("sentence_good") and a minimally different ungrammatical one
("sentence_bad").  The model is scored by whether it assigns a higher
log-probability to the grammatical sentence.
"""
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm

from evaluation.tasks.auto_task import AutoTask

from .task_names import blimp_task_names


class BLIMPDataset(Dataset):
    """Holds one HF dataset per BLiMP sub-task (paradigm).

    Note: __len__/__getitem__ index *sub-tasks*, not individual
    sentence pairs.
    """

    def __init__(self):
        super().__init__()

        # BLiMP publishes every paradigm as a separate config whose only
        # split is "train"; load them all up front.
        self.items = [
            load_dataset("blimp", task, split="train") for task in blimp_task_names
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


class BLIMPTask(AutoTask):
    @staticmethod
    def get_display_name() -> str:
        return "blimp"

    def _sentence_log_prob(self, sentence: str) -> torch.Tensor:
        """Return the total log-probability the LM assigns to *sentence*.

        The forward pass runs under no_grad (pure inference), and raw
        logits are normalized with log_softmax before indexing — summing
        unnormalized logits would NOT yield a log-probability and would
        make the good/bad comparison meaningless.
        """
        input_ids = self.tokenizer(sentence, return_tensors="pt")["input_ids"].to(
            self.device
        )
        with torch.no_grad():
            logits = self.model(input_ids=input_ids).logits
        log_probs = torch.log_softmax(logits, dim=-1)
        # Causal LM: the distribution at position i predicts token i+1,
        # so positions 0..T-2 score tokens 1..T-1.
        return log_probs[0, range(input_ids.shape[1] - 1), input_ids[0, 1:]].sum()

    def evaluate(self) -> None:
        """Compute overall BLiMP accuracy and store it in self.metrics.

        Accuracy is the fraction of minimal pairs (pooled across all
        sub-tasks) where the grammatical sentence scores strictly higher.
        """
        dataset = BLIMPDataset()
        num_correct = 0
        num_items = 0

        for task_dataset in dataset:
            for sample in tqdm(
                task_dataset,
                desc=f"Evaluating {self.get_display_name()} - {task_dataset.config_name}",
            ):
                log_prob_good = self._sentence_log_prob(sample["sentence_good"])
                log_prob_bad = self._sentence_log_prob(sample["sentence_bad"])

                if log_prob_good > log_prob_bad:
                    num_correct += 1

                num_items += 1

        self.metrics["accuracy"] = num_correct / num_items
"""Names of the BLiMP paradigms evaluated by the BLiMP task.

Each entry is a Hugging Face "blimp" dataset config name; the evaluation
loads one split per entry.  Order is preserved as the evaluation's
iteration order.
"""

blimp_task_names = [
    "adjunct_island",
    "anaphor_gender_agreement",
    "anaphor_number_agreement",
    "animate_subject_passive",
    "animate_subject_trans",
    "causative",
    "complex_NP_island",
    "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1",
    "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1",
    "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2",
    "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2",
    "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun",
    "distractor_agreement_relative_clause",
    "drop_argument",
    "ellipsis_n_bar_1",
    "ellipsis_n_bar_2",
    "existential_there_object_raising",
    "existential_there_quantifiers_1",
    "existential_there_quantifiers_2",
    "existential_there_subject_raising",
    "expletive_it_object_raising",
    "inchoative",
    "intransitive",
    "irregular_past_participle_adjectives",
    "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1",
    "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question",
    "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present",
    "npi_present_1",
    "npi_present_2",
    "only_npi_licensor_present",
    "only_npi_scope",
    "passive_1",
    "passive_2",
    "principle_A_c_command",
    "principle_A_case_1",
    "principle_A_case_2",
    "principle_A_domain_1",
    "principle_A_domain_2",
    "principle_A_domain_3",
    "principle_A_reconstruction",
    "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2",
    "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope",
    "sentential_subject_island",
    "superlative_quantifiers_1",
    "superlative_quantifiers_2",
    "tough_vs_raising_1",
    "tough_vs_raising_2",
    "transitive",
    "wh_island",
    "wh_questions_object_gap",
    "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance",
    "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance",
    "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]