diff --git a/scripts/train.sh b/scripts/train.sh
index 6ed1c61..75dfe3b 100755
--- a/scripts/train.sh
+++ b/scripts/train.sh
@@ -4,9 +4,14 @@ export SEED=15270
 export PYTORCH_SEED=`expr $SEED / 10`
 export NUMPY_SEED=`expr $PYTORCH_SEED / 10`
 
-# path to bert vocab and weights
-export BERT_VOCAB=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scivocab_uncased.vocab
-export BERT_WEIGHTS=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scibert_scivocab_uncased.tar.gz
+# transformer model name and type
+export BERT_MODEL=allenai/scibert_scivocab_uncased
+export TOKEN=[SEP]
+export MODEL_TYPE=bert
+
+# export BERT_MODEL=roberta-base
+# export TOKEN=""
+# export MODEL_TYPE=roberta
 
 # path to dataset files
 export TRAIN_PATH=data/CSAbstruct/train.jsonl
@@ -19,10 +24,11 @@ export WITH_CRF=false  # CRF only works for the baseline
 
 # training params
 export cuda_device=0
-export BATCH_SIZE=4
-export LR=5e-5
-export TRAINING_DATA_INSTANCES=1668
-export NUM_EPOCHS=2
+export BATCH_SIZE=4  # set to 1 for RoBERTa
+export LR=1e-5
+#export TRAINING_DATA_INSTANCES=1668
+export TRAINING_STEPS=52
+export NUM_EPOCHS=20
 
 # limit number of sentneces per examples, and number of words per sentence. This is dataset dependant
 export MAX_SENT_PER_EXAMPLE=10
@@ -35,4 +41,4 @@ export SCI_SUM_FAKE_SCORES=false  # use fake scores for testing
 
 CONFIG_FILE=sequential_sentence_classification/config.jsonnet
 
-python -m allennlp.run train $CONFIG_FILE --include-package sequential_sentence_classification -s $SERIALIZATION_DIR "$@"
+python3 -m allennlp train $CONFIG_FILE --include-package sequential_sentence_classification -s $SERIALIZATION_DIR "$@"
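MODEL_TYPE and TOKEN above, together with the sequence_pair_* handling in the dataset reader and model further down, follow from the special tokens each tokenizer places between paired segments: BERT-style models use a single [SEP], while RoBERTa inserts two </s> tokens. A quick inspection sketch (illustrative only, not part of the patch; it assumes AllenNLP 2.x and the two model names used in this script):

from allennlp.data.tokenizers import PretrainedTransformerTokenizer

for name in ["allenai/scibert_scivocab_uncased", "roberta-base"]:
    tok = PretrainedTransformerTokenizer(name)
    print(name)
    print("  pair start:", [t.text for t in tok.sequence_pair_start_tokens])  # [CLS] vs <s>
    print("  pair mid:  ", [t.text for t in tok.sequence_pair_mid_tokens])    # [SEP] vs </s> </s>
    print("  pair end:  ", [t.text for t in tok.sequence_pair_end_tokens])    # [SEP] vs </s>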
diff --git a/sequential_sentence_classification/config.jsonnet b/sequential_sentence_classification/config.jsonnet
index d42b0c8..987875c 100644
--- a/sequential_sentence_classification/config.jsonnet
+++ b/sequential_sentence_classification/config.jsonnet
@@ -12,95 +12,80 @@ local boolToInt(s) =
   "random_seed": std.parseInt(std.extVar("SEED")),
   "pytorch_seed": std.parseInt(std.extVar("PYTORCH_SEED")),
   "numpy_seed": std.parseInt(std.extVar("NUMPY_SEED")),
-  "dataset_reader":{
-    "type":"SeqClassificationReader",
-    "lazy": false,
-    "sent_max_len": std.extVar("SENT_MAX_LEN"),
-    "word_splitter": "bert-basic",
-    "max_sent_per_example": std.extVar("MAX_SENT_PER_EXAMPLE"),
-    "token_indexers": {
-      "bert": {
-        "type": "bert-pretrained",
-        "pretrained_model": std.extVar("BERT_VOCAB"),
-        "do_lowercase": true,
-        "use_starting_offsets": false
-      },
+  "dataset_reader" : {
+    "type": "SeqClassificationReader",
+    "tokenizer": {
+      "type": "pretrained_transformer",
+      "model_name": std.extVar("BERT_MODEL"),
+      "tokenizer_kwargs": {"truncation_strategy": 'do_not_truncate'},
+    },
+    "token_indexers": {
+      "bert": {
+        "type": "pretrained_transformer",
+        "model_name": std.extVar("BERT_MODEL"),
+        "tokenizer_kwargs": {"truncation_strategy": 'do_not_truncate'},
+      }
+    },
+    "sent_max_len": std.parseInt(std.extVar("SENT_MAX_LEN")),
+    "max_sent_per_example": 10,
+    "use_sep": stringToBool(std.extVar("USE_SEP")),
+    "sci_sum": stringToBool(std.extVar("SCI_SUM")),
+    "use_abstract_scores": stringToBool(std.extVar("USE_ABSTRACT_SCORES")),
+    "sci_sum_fake_scores": stringToBool(std.extVar("SCI_SUM_FAKE_SCORES")),
   },
-    "use_sep": std.extVar("USE_SEP"),
-    "sci_sum": stringToBool(std.extVar("SCI_SUM")),
-    "use_abstract_scores": stringToBool(std.extVar("USE_ABSTRACT_SCORES")),
-    "sci_sum_fake_scores": stringToBool(std.extVar("SCI_SUM_FAKE_SCORES")),
-  },
   "train_data_path": std.extVar("TRAIN_PATH"),
   "validation_data_path": std.extVar("DEV_PATH"),
   "test_data_path": std.extVar("TEST_PATH"),
   "evaluate_on_test": true,
   "model": {
     "type": "SeqClassificationModel",
-    "text_field_embedder": {
-      "allow_unmatched_keys": true,
-      "embedder_to_indexer_map": {
-        "bert": if stringToBool(std.extVar("USE_SEP")) then ["bert"] else ["bert", "bert-offsets"],
-        "tokens": ["tokens"],
-      },
+    "text_field_embedder": {
       "token_embedders": {
         "bert": {
-          "type": "bert-pretrained",
-          "pretrained_model": std.extVar("BERT_WEIGHTS"),
-          "requires_grad": 'all',
-          "top_layer_only": false,
-        }
+          "type": "pretrained_transformer",
+          "model_name": std.extVar("BERT_MODEL"),
+          "train_parameters": 1,
+          "last_layer_only": 1,
+
+        }
       }
     },
-    "use_sep": std.extVar("USE_SEP"),
-    "with_crf": std.extVar("WITH_CRF"),
+    "use_sep": stringToBool(std.extVar("USE_SEP")),
+    "with_crf": stringToBool(std.extVar("WITH_CRF")),
+    "intersentence_token": std.extVar("TOKEN"),
+    "model_type": std.extVar("MODEL_TYPE"),
     "bert_dropout": 0.1,
     "sci_sum": stringToBool(std.extVar("SCI_SUM")),
     "additional_feature_size": boolToInt(stringToBool(std.extVar("USE_ABSTRACT_SCORES"))),
     "self_attn": {
-      "type": "stacked_self_attention",
+      "type": "pytorch_transformer",
       "input_dim": 768,
-      "projection_dim": 100,
       "feedforward_hidden_dim": 50,
       "num_layers": 2,
       "num_attention_heads": 2,
-      "hidden_dim": 100,
     },
   },
-  "iterator": {
-    "type": "bucket",
-    "sorting_keys": [["sentences", "num_fields"]],
-    "batch_size" : std.parseInt(std.extVar("BATCH_SIZE")),
-    "cache_instances": true,
-    "biggest_batch_first": true
+  "data_loader": {
+    "batch_size": std.parseInt(std.extVar("BATCH_SIZE")),
+    "shuffle": true,
   },
-
   "trainer": {
     "num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
     "grad_clipping": 1.0,
     "patience": 5,
-    "model_save_interval": 3600,
     "validation_metric": if stringToBool(std.extVar("SCI_SUM")) then "-loss" else '+acc',
-    "min_delta": 0.001,
    "cuda_device": std.parseInt(std.extVar("cuda_device")),
-    "gradient_accumulation_batch_size": 32,
+    "num_gradient_accumulation_steps": 32,
    "optimizer": {
-      "type": "bert_adam",
-      "lr": std.extVar("LR"),
-      "t_total": -1,
-      "max_grad_norm": 1.0,
+      "type": "huggingface_adamw",
+      "lr": std.parseJson(std.extVar("LR")),
       "weight_decay": 0.01,
-      "parameter_groups": [
-        [["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight"], {"weight_decay": 0.0}],
-      ],
     },
-    "should_log_learning_rate": true,
    "learning_rate_scheduler": {
      "type": "slanted_triangular",
      "num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
-      "num_steps_per_epoch": std.parseInt(std.extVar("TRAINING_DATA_INSTANCES")) / 32,
+      "num_steps_per_epoch": std.parseInt(std.extVar("TRAINING_STEPS")),
      "cut_frac": 0.1,
    },
  }
-}
\ No newline at end of file
+}
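The scheduler now takes its per-epoch step count straight from TRAINING_STEPS instead of deriving it from TRAINING_DATA_INSTANCES as before. A minimal sketch of the arithmetic (illustrative only, not part of the patch; it assumes the 1668 CSAbstruct training instances and the batch/accumulation settings above):

import math

train_instances = 1668    # TRAINING_DATA_INSTANCES, now commented out in train.sh
batch_size = 4            # BATCH_SIZE
accumulation = 32         # trainer.num_gradient_accumulation_steps

# Old formula: num_steps_per_epoch = TRAINING_DATA_INSTANCES / 32
print(train_instances // accumulation)                             # 52 -> TRAINING_STEPS

# If num_gradient_accumulation_steps counts *batches* (AllenNLP 2.x trainer
# semantics), the optimizer would instead step roughly this often per epoch:
print(math.ceil(train_instances / (batch_size * accumulation)))    # 14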
diff --git a/sequential_sentence_classification/dataset_reader.py b/sequential_sentence_classification/dataset_reader.py
index 789ecf2..7cbf970 100644
--- a/sequential_sentence_classification/dataset_reader.py
+++ b/sequential_sentence_classification/dataset_reader.py
@@ -4,17 +4,17 @@
 
 from overrides import overrides
 import numpy as np
+import copy
 
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
 from allennlp.common.file_utils import cached_path
-from allennlp.data import Tokenizer
+from allennlp.data import TokenIndexer, Tokenizer
 from allennlp.data.instance import Instance
 from allennlp.data.fields.field import Field
 from allennlp.data.fields import TextField, LabelField, ListField, ArrayField, MultiLabelField
-from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
-from allennlp.data.tokenizers import WordTokenizer
-from allennlp.data.tokenizers.token import Token
-from allennlp.data.tokenizers.word_splitter import SimpleWordSplitter, WordSplitter, SpacyWordSplitter
+from allennlp.data.token_indexers import SingleIdTokenIndexer
+from allennlp.data.tokenizers import WhitespaceTokenizer
+from allennlp.data.tokenizers.token_class import Token
 
 
 @DatasetReader.register("SeqClassificationReader")
@@ -31,9 +31,7 @@ class SeqClassificationReader(DatasetReader):
     """
 
     def __init__(self,
-                 lazy: bool = False,
                  token_indexers: Dict[str, TokenIndexer] = None,
-                 word_splitter: WordSplitter = None,
                  tokenizer: Tokenizer = None,
                  sent_max_len: int = 100,
                  max_sent_per_example: int = 20,
@@ -43,8 +41,9 @@ def __init__(self,
                 sci_sum_fake_scores: bool = True,
                 predict: bool = False,
                 ) -> None:
-        super().__init__(lazy)
-        self._tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=False))
+        super().__init__(manual_distributed_sharding=True,
+                         manual_multiprocess_sharding=True)
+        self._tokenizer = tokenizer or WhitespaceTokenizer()
         self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
         self.sent_max_len = sent_max_len
         self.use_sep = use_sep
@@ -53,13 +52,19 @@ def __init__(self,
         self.max_sent_per_example = max_sent_per_example
         self.use_abstract_scores = use_abstract_scores
         self.sci_sum_fake_scores = sci_sum_fake_scores
+        print("*********************************")
+        print("start token : ", self._tokenizer.sequence_pair_start_tokens)
+        print("middle token : ", self._tokenizer.sequence_pair_mid_tokens)
+        print("end token : ", self._tokenizer.sequence_pair_end_tokens)
+        print("*********************************")
+
+
 
-    @overrides
     def _read(self, file_path: str):
         file_path = cached_path(file_path)
 
         with open(file_path) as f:
-            for line in f:
+            for line in self.shard_iterable(f):
                 json_dict = json.loads(line)
                 instances = self.read_one_example(json_dict)
                 for instance in instances:
@@ -173,7 +178,6 @@ def filter_bad_sci_sum_sentences(self, sentences, labels):
 
         return sentences, labels
 
-    @overrides
     def text_to_instance(self,
                          sentences: List[str],
                          labels: List[str] = None,
@@ -188,18 +192,31 @@ def text_to_instance(self,
             assert len(sentences) == len(additional_features)
 
         if self.use_sep:
-            tokenized_sentences = [self._tokenizer.tokenize(s)[:self.sent_max_len] + [Token("[SEP]")] for s in sentences]
-            sentences = [list(itertools.chain.from_iterable(tokenized_sentences))[:-1]]
+            origin_sent = copy.deepcopy(sentences)
+            sentences = self.shorten_sentences(sentences, self.sent_max_len)
+
+            max_len = self.sent_max_len
+            while len(sentences[0]) > 512:
+                n = int((len(sentences[0]) - 512) / len(origin_sent)) + 1
+
+                max_len -= n
+                sentences = self.shorten_sentences(origin_sent, max_len)
+
+            assert len(sentences[0]) <= 512
+
         else:
-            # Tokenize the sentences
-            sentences = [
-                self._tokenizer.tokenize(sentence_text)[:self.sent_max_len]
-                for sentence_text in sentences
-            ]
-
+            tok_sentences = []
+            for sentence_text in sentences:
+                if len(self._tokenizer.tokenize(sentence_text)) > self.sent_max_len:
+                    tok_sentences.append(self._tokenizer.tokenize(sentence_text)[:self.sent_max_len] + self._tokenizer.sequence_pair_end_tokens)
+                else:
+                    tok_sentences.append(self._tokenizer.tokenize(sentence_text))
+
+            sentences = tok_sentences
+
 
         fields: Dict[str, Field] = {}
         fields["sentences"] = ListField([
-            TextField(sentence, self._token_indexers)
+            TextField(sentence)
             for sentence in sentences
         ])
@@ -223,4 +240,18 @@
         if additional_features is not None:
             fields["additional_features"] = ArrayField(np.array(additional_features))
 
-        return Instance(fields)
\ No newline at end of file
+        return Instance(fields)
+
+    def apply_token_indexers(self, instance: Instance) -> None:
+        for text_field in instance["sentences"].field_list:
+            text_field.token_indexers = self._token_indexers
+
+    def shorten_sentences(self, origin_sent, max_len):
+        tokenized_sentences = [self._tokenizer.sequence_pair_start_tokens]
+        for s in origin_sent:
+            if len(self._tokenizer.tokenize(s)) > max_len:
+                tokenized_sentences.append(self._tokenizer.tokenize(s)[1:max_len] + self._tokenizer.sequence_pair_mid_tokens)
+            else:
+                tokenized_sentences.append(self._tokenizer.tokenize(s)[1:-1] + self._tokenizer.sequence_pair_mid_tokens)
+        mid_tok_len = len(self._tokenizer.sequence_pair_mid_tokens)
+        return [list(itertools.chain.from_iterable(tokenized_sentences))[:-mid_tok_len] + self._tokenizer.sequence_pair_end_tokens]
\ No newline at end of file
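The use_sep branch of text_to_instance above packs every sentence of an abstract into a single sequence, then repeatedly lowers the per-sentence budget until the packed sequence fits the 512-wordpiece limit. A simplified mirror of that loop (illustrative only, not part of the patch; it counts each separator as one token and ignores the exact [CLS]/[SEP] bookkeeping that shorten_sentences does):

def fit_budget(sentence_lengths, sent_max_len, limit=512, sep_cost=1):
    # Trim every sentence to `max_len` pieces, measure the packed length, and
    # shrink max_len by the per-sentence share of the overflow until it fits.
    def packed(m):
        return sum(min(l, m) for l in sentence_lengths) + sep_cost * len(sentence_lengths)
    max_len = sent_max_len
    while packed(max_len) > limit:
        n = int((packed(max_len) - limit) / len(sentence_lengths)) + 1
        max_len -= n
    return max_len, packed(max_len)

# Six long sentences at a 100-piece budget overflow 512, so the budget drops to 84:
print(fit_budget([150, 140, 130, 160, 120, 110], sent_max_len=100))  # (84, 510)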
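In model.py below, forward() recovers one vector per sentence by masking the separator positions of the flattened token embeddings; for RoBERTa, which places two </s> tokens at each sentence boundary, it then keeps every second hit, which only lines up when the batch holds a single example. A toy illustration (not part of the patch; all shapes and ids are invented):

import torch

num_sent, sent_len, dim = 3, 6, 4
token_ids = torch.randint(10, 100, (1, num_sent, sent_len))  # batch_size == 1
sep_id = 103                                                 # whatever the vocab lookup returns
token_ids[:, :, -1] = sep_id            # pretend each sentence ends with the separator
embedded = torch.randn(1, num_sent, sent_len, dim)

sep_mask = token_ids == sep_id          # same idea as `sentences_mask` in forward()
sentence_vecs = embedded[sep_mask]      # -> (3, 4): one vector per separator hit

# With RoBERTa's doubled </s>, the mask fires twice per boundary, so the model keeps
# every second hit; the even/odd pattern only holds within a single example, hence
# the batch_size == 1 assertion in forward().
even = torch.arange(sentence_vecs.shape[0])[::2]
roberta_style = sentence_vecs.index_select(0, even)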
diff --git a/sequential_sentence_classification/model.py b/sequential_sentence_classification/model.py
index 6e7e9f8..1acb90b 100644
--- a/sequential_sentence_classification/model.py
+++ b/sequential_sentence_classification/model.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Dict
 
+import numpy as np
 import torch
 from torch.nn import Linear
 from allennlp.data import Vocabulary
@@ -25,6 +26,8 @@ def __init__(self, vocab: Vocabulary,
                  self_attn: Seq2SeqEncoder = None,
                  bert_dropout: float = 0.1,
                  sci_sum: bool = False,
+                 intersentence_token: str = "[SEP]",
+                 model_type: str = "bert",
                  additional_feature_size: int = 0,
                  ) -> None:
         super(SeqClassificationModel, self).__init__(vocab)
@@ -36,7 +39,8 @@ def __init__(self, vocab: Vocabulary,
         self.sci_sum = sci_sum
         self.self_attn = self_attn
         self.additional_feature_size = additional_feature_size
-
+        self.token = intersentence_token
+        self.model_type = model_type
         self.dropout = torch.nn.Dropout(p=bert_dropout)
 
         # define loss
@@ -57,12 +61,12 @@ def __init__(self, vocab: Vocabulary,
             label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index)
             self.label_f1_metrics[label_name] = F1Measure(label_index)
 
-        encoded_senetence_dim = text_field_embedder._token_embedders['bert'].output_dim
+        encoded_senetence_dim = text_field_embedder._token_embedders['bert'].get_output_dim()
 
         ff_in_dim = encoded_senetence_dim if self.use_sep else self_attn.get_output_dim()
         ff_in_dim += self.additional_feature_size
 
-        self.time_distributed_aggregate_feedforward = TimeDistributed(Linear(ff_in_dim, self.num_labels))
+        self.time_distributed_aggregate_feedforward = Linear(ff_in_dim, self.num_labels)
 
         if self.with_crf:
             self.crf = ConditionalRandomField(
@@ -82,19 +86,13 @@ def forward(self,  # type: ignore
         ----------
         TODO: add description
 
-        Returns
-        -------
-        An output dictionary consisting of:
-        loss : torch.FloatTensor, optional
-            A scalar loss to be optimised.
         """
 
         # ===========================================================================================================
-        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
         # Input: sentences
         # Output: embedded_sentences
 
         # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
-        embedded_sentences = self.text_field_embedder(sentences)
+        embedded_sentences = self.text_field_embedder(sentences, num_wrapping_dims=1)
         mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
         batch_size, num_sentences, _, _ = embedded_sentences.size()
@@ -102,9 +100,18 @@
             # The following code collects vectors of the SEP tokens from all the examples in the batch,
             # and arrange them in one list. It does the same for the labels and confidences.
             # TODO: replace 103 with '[SEP]'
-            sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
+            index_sep = int(self.vocab.get_token_index(token=self.token, namespace="tags"))
+            sentences_mask = sentences['bert']["token_ids"] == index_sep  # mask for all the SEP tokens in the batch
             embedded_sentences = embedded_sentences[sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
                                                                      # returns num_sentences_per_batch x vector_len
+            ## RoBERTa works only if batch size == 1
+            if self.model_type == "roberta":
+                assert batch_size == 1, "set batch size to 1 for RoBERTa"
+                indx = np.arange(embedded_sentences.shape[0])
+                device = "cuda"
+                sel_idx = torch.from_numpy(indx[indx % 2 == 0]).to(device)  # select only second intersentence marker
+                embedded_sentences = torch.index_select(embedded_sentences, 0, sel_idx)
+
             assert embedded_sentences.dim() == 2
             num_sentences = embedded_sentences.shape[0]
             # for the rest of the code in this model to work, think of the data we have as one example
@@ -197,7 +204,10 @@
             flattened_gold = labels.contiguous().view(-1)
 
             if not self.with_crf:
-                label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
+                if flattened_logits.shape[0] == 1:
+                    label_loss = self.loss(flattened_logits, flattened_gold)
+                else:
+                    label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
                 if confidences is not None:
                     label_loss = label_loss * confidences.type_as(label_loss).view(-1)
                 label_loss = label_loss.mean()
@@ -215,7 +225,10 @@
 
             if not self.labels_are_scores:
                 evaluation_mask = (flattened_gold != -1)
-                self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask)
+                if flattened_probs.shape[0] == 1:
+                    self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold, mask=evaluation_mask)
+                else:
+                    self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask)
 
                 # compute F1 per label
                 for label_index in range(self.num_labels):
@@ -238,8 +251,8 @@ def get_metrics(self, reset: bool = False):
         average_F1 = 0.0
         for name, metric in self.label_f1_metrics.items():
             metric_val = metric.get_metric(reset)
-            metric_dict[name + 'F'] = metric_val[2]
-            average_F1 += metric_val[2]
+            metric_dict[name + 'F'] = metric_val["f1"]
+            average_F1 += metric_val["f1"]
         average_F1 /= len(self.label_f1_metrics.items())
 
         metric_dict['avgF'] = average_F1