22 changes: 14 additions & 8 deletions scripts/train.sh
@@ -4,9 +4,14 @@ export SEED=15270
export PYTORCH_SEED=`expr $SEED / 10`
export NUMPY_SEED=`expr $PYTORCH_SEED / 10`

# path to bert vocab and weights
export BERT_VOCAB=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scivocab_uncased.vocab
export BERT_WEIGHTS=https://ai2-s2-research.s3-us-west-2.amazonaws.com/scibert/allennlp_files/scibert_scivocab_uncased.tar.gz
# transformer model name, separator token, and model type
export BERT_MODEL=allenai/scibert_scivocab_uncased
export TOKEN="[SEP]"
export MODEL_TYPE=bert

# export BERT_MODEL=roberta-base
# export TOKEN="</s>"
# export MODEL_TYPE=roberta

# path to dataset files
export TRAIN_PATH=data/CSAbstruct/train.jsonl
@@ -19,10 +24,11 @@ export WITH_CRF=false # CRF only works for the baseline

# training params
export cuda_device=0
export BATCH_SIZE=4
export LR=5e-5
export TRAINING_DATA_INSTANCES=1668
export NUM_EPOCHS=2
export BATCH_SIZE=4 # set to 1 for roberta
export LR=1e-5
#export TRAINING_DATA_INSTANCES=1668
export TRAINING_STEPS=52
export NUM_EPOCHS=20

# limit the number of sentences per example, and the number of words per sentence. This is dataset dependent
export MAX_SENT_PER_EXAMPLE=10
@@ -35,4 +41,4 @@ export SCI_SUM_FAKE_SCORES=false # use fake scores for testing

CONFIG_FILE=sequential_sentence_classification/config.jsonnet

python -m allennlp.run train $CONFIG_FILE --include-package sequential_sentence_classification -s $SERIALIZATION_DIR "$@"
python3 -m allennlp train $CONFIG_FILE --include-package sequential_sentence_classification -s $SERIALIZATION_DIR "$@"
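Because the script now selects the model by Hugging Face name, the exported TOKEN must match that model's separator ([SEP] for the BERT-family SciBERT, </s> for RoBERTa). A minimal sketch, assuming the Hugging Face `transformers` package is available, for confirming the separator before setting TOKEN:

```python
# Sketch: look up the separator token for the models referenced in train.sh.
# Assumes the Hugging Face `transformers` package is installed.
from transformers import AutoTokenizer

for model_name in ["allenai/scibert_scivocab_uncased", "roberta-base"]:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # sep_token is "[SEP]" for BERT-style vocabularies and "</s>" for RoBERTa.
    print(f"{model_name}: TOKEN={tokenizer.sep_token}")
```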
95 changes: 40 additions & 55 deletions sequential_sentence_classification/config.jsonnet
@@ -12,95 +12,80 @@ local boolToInt(s) =
"random_seed": std.parseInt(std.extVar("SEED")),
"pytorch_seed": std.parseInt(std.extVar("PYTORCH_SEED")),
"numpy_seed": std.parseInt(std.extVar("NUMPY_SEED")),
"dataset_reader":{
"type":"SeqClassificationReader",
"lazy": false,
"sent_max_len": std.extVar("SENT_MAX_LEN"),
"word_splitter": "bert-basic",
"max_sent_per_example": std.extVar("MAX_SENT_PER_EXAMPLE"),
"token_indexers": {
"bert": {
"type": "bert-pretrained",
"pretrained_model": std.extVar("BERT_VOCAB"),
"do_lowercase": true,
"use_starting_offsets": false
},
"dataset_reader" : {
"type": "SeqClassificationReader",
"tokenizer": {
"type": "pretrained_transformer",
"model_name": std.extVar("BERT_MODEL"),
"tokenizer_kwargs": {"truncation_strategy" : 'do_not_truncate'},
},
"token_indexers": {
"bert": {
"type": "pretrained_transformer",
"model_name": std.extVar("BERT_MODEL"),
"tokenizer_kwargs": {"truncation_strategy" : 'do_not_truncate'},
}
},
"sent_max_len": std.parseInt(std.extVar("SENT_MAX_LEN")),
"max_sent_per_example": 10,
"use_sep": stringToBool(std.extVar("USE_SEP")),
"sci_sum": stringToBool(std.extVar("SCI_SUM")),
"use_abstract_scores": stringToBool(std.extVar("USE_ABSTRACT_SCORES")),
"sci_sum_fake_scores": stringToBool(std.extVar("SCI_SUM_FAKE_SCORES")),
},
"use_sep": std.extVar("USE_SEP"),
"sci_sum": stringToBool(std.extVar("SCI_SUM")),
"use_abstract_scores": stringToBool(std.extVar("USE_ABSTRACT_SCORES")),
"sci_sum_fake_scores": stringToBool(std.extVar("SCI_SUM_FAKE_SCORES")),
},

"train_data_path": std.extVar("TRAIN_PATH"),
"validation_data_path": std.extVar("DEV_PATH"),
"test_data_path": std.extVar("TEST_PATH"),
"evaluate_on_test": true,
"model": {
"type": "SeqClassificationModel",
"text_field_embedder": {
"allow_unmatched_keys": true,
"embedder_to_indexer_map": {
"bert": if stringToBool(std.extVar("USE_SEP")) then ["bert"] else ["bert", "bert-offsets"],
"tokens": ["tokens"],
},
"text_field_embedder": {
"token_embedders": {
"bert": {
"type": "bert-pretrained",
"pretrained_model": std.extVar("BERT_WEIGHTS"),
"requires_grad": 'all',
"top_layer_only": false,
}
"type": "pretrained_transformer",
"model_name": std.extVar("BERT_MODEL"),
"train_parameters": 1,
"last_layer_only": 1,

}
}
},
"use_sep": std.extVar("USE_SEP"),
"with_crf": std.extVar("WITH_CRF"),
"use_sep": stringToBool(std.extVar("USE_SEP")),
"with_crf": stringToBool(std.extVar("WITH_CRF")),
"intersentence_token":std.extVar("TOKEN"),
"model_type":std.extVar("MODEL_TYPE"),
"bert_dropout": 0.1,
"sci_sum": stringToBool(std.extVar("SCI_SUM")),
"additional_feature_size": boolToInt(stringToBool(std.extVar("USE_ABSTRACT_SCORES"))),
"self_attn": {
"type": "stacked_self_attention",
"type": "pytorch_transformer",
"input_dim": 768,
"projection_dim": 100,
"feedforward_hidden_dim": 50,
"num_layers": 2,
"num_attention_heads": 2,
"hidden_dim": 100,
},
},
"iterator": {
"type": "bucket",
"sorting_keys": [["sentences", "num_fields"]],
"batch_size" : std.parseInt(std.extVar("BATCH_SIZE")),
"cache_instances": true,
"biggest_batch_first": true
"data_loader": {
"batch_size": std.parseInt(std.extVar("BATCH_SIZE")),
"shuffle": true,
},

"trainer": {
"num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
"grad_clipping": 1.0,
"patience": 5,
"model_save_interval": 3600,
"validation_metric": if stringToBool(std.extVar("SCI_SUM")) then "-loss" else '+acc',
"min_delta": 0.001,
"cuda_device": std.parseInt(std.extVar("cuda_device")),
"gradient_accumulation_batch_size": 32,
"num_gradient_accumulation_steps": 32,
"optimizer": {
"type": "bert_adam",
"lr": std.extVar("LR"),
"t_total": -1,
"max_grad_norm": 1.0,
"type": "huggingface_adamw",
"lr": std.parseJson(std.extVar("LR")),
"weight_decay": 0.01,
"parameter_groups": [
[["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight"], {"weight_decay": 0.0}],
],
},
"should_log_learning_rate": true,
"learning_rate_scheduler": {
"type": "slanted_triangular",
"num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")),
"num_steps_per_epoch": std.parseInt(std.extVar("TRAINING_DATA_INSTANCES")) / 32,
"num_steps_per_epoch": std.parseInt(std.extVar("TRAINING_STEPS")),
"cut_frac": 0.1,
},
}
}
}
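The slanted-triangular scheduler now reads num_steps_per_epoch directly from TRAINING_STEPS instead of deriving it from the instance count. A rough sketch of the arithmetic behind the value 52, assuming roughly 1,668 training instances (the old TRAINING_DATA_INSTANCES), BATCH_SIZE=1 as the RoBERTa comment in train.sh suggests, and the 32-step gradient accumulation configured in the trainer block:

```python
# Assumed values; adjust to the actual dataset and config.
num_train_instances = 1668   # old TRAINING_DATA_INSTANCES for CSAbstruct
batch_size = 1               # "set to 1 for roberta" per train.sh
grad_accumulation = 32       # num_gradient_accumulation_steps in the trainer

effective_batch = batch_size * grad_accumulation
steps_per_epoch = num_train_instances // effective_batch
print(steps_per_epoch)       # 52, matching TRAINING_STEPS
```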
75 changes: 53 additions & 22 deletions sequential_sentence_classification/dataset_reader.py
@@ -4,17 +4,17 @@
from overrides import overrides

import numpy as np
import copy

from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data import Tokenizer
from allennlp.data import TokenIndexer, Tokenizer
from allennlp.data.instance import Instance
from allennlp.data.fields.field import Field
from allennlp.data.fields import TextField, LabelField, ListField, ArrayField, MultiLabelField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.word_splitter import SimpleWordSplitter, WordSplitter, SpacyWordSplitter
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WhitespaceTokenizer
from allennlp.data.tokenizers.token_class import Token


@DatasetReader.register("SeqClassificationReader")
@@ -31,9 +31,7 @@ class SeqClassificationReader(DatasetReader):
"""

def __init__(self,
lazy: bool = False,
token_indexers: Dict[str, TokenIndexer] = None,
word_splitter: WordSplitter = None,
tokenizer: Tokenizer = None,
sent_max_len: int = 100,
max_sent_per_example: int = 20,
@@ -43,8 +41,9 @@ def __init__(self,
sci_sum_fake_scores: bool = True,
predict: bool = False,
) -> None:
super().__init__(lazy)
self._tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=False))
super().__init__(manual_distributed_sharding=True,
manual_multiprocess_sharding=True)
self._tokenizer = tokenizer or WhitespaceTokenizer()
self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
self.sent_max_len = sent_max_len
self.use_sep = use_sep
@@ -53,13 +52,19 @@ def __init__(self,
self.max_sent_per_example = max_sent_per_example
self.use_abstract_scores = use_abstract_scores
self.sci_sum_fake_scores = sci_sum_fake_scores
print("*********************************")
print("start token : ", self._tokenizer.sequence_pair_start_tokens)
print("middle token : ", self._tokenizer.sequence_pair_mid_tokens)
print("end token : ", self._tokenizer.sequence_pair_end_tokens)
print("*********************************")



@overrides
def _read(self, file_path: str):
file_path = cached_path(file_path)

with open(file_path) as f:
for line in f:
for line in self.shard_iterable(f):
json_dict = json.loads(line)
instances = self.read_one_example(json_dict)
for instance in instances:
@@ -173,7 +178,6 @@ def filter_bad_sci_sum_sentences(self, sentences, labels):

return sentences, labels

@overrides
def text_to_instance(self,
sentences: List[str],
labels: List[str] = None,
@@ -188,18 +192,31 @@ def text_to_instance(self,
assert len(sentences) == len(additional_features)

if self.use_sep:
tokenized_sentences = [self._tokenizer.tokenize(s)[:self.sent_max_len] + [Token("[SEP]")] for s in sentences]
sentences = [list(itertools.chain.from_iterable(tokenized_sentences))[:-1]]
origin_sent = copy.deepcopy(sentences)
sentences = self.shorten_sentences(sentences, self.sent_max_len)

max_len=self.sent_max_len
while len(sentences[0]) > 512:
n = int((len(sentences[0])-512)/ len(origin_sent))+1

max_len -= n
sentences = self.shorten_sentences(origin_sent, max_len )

assert len(sentences[0]) <= 512

else:
# Tokenize the sentences
sentences = [
self._tokenizer.tokenize(sentence_text)[:self.sent_max_len]
for sentence_text in sentences
]

tok_sentences = []
for sentence_text in sentences:
if len(self._tokenizer.tokenize(sentence_text)) > self.sent_max_len:
tok_sentences.append(self._tokenizer.tokenize(sentence_text)[:self.sent_max_len]+self._tokenizer.sequence_pair_end_tokens)
else:
tok_sentences.append(self._tokenizer.tokenize(sentence_text))

sentences = tok_sentences

fields: Dict[str, Field] = {}
fields["sentences"] = ListField([
TextField(sentence, self._token_indexers)
TextField(sentence)
for sentence in sentences
])

@@ -223,4 +240,18 @@ def text_to_instance(self,
if additional_features is not None:
fields["additional_features"] = ArrayField(np.array(additional_features))

return Instance(fields)
return Instance(fields)

def apply_token_indexers(self, instance: Instance) -> None:
for text_field in instance["sentences"].field_list:
text_field.token_indexers = self._token_indexers

def shorten_sentences(self, origin_sent, max_len):
tokenized_sentences = [self._tokenizer.sequence_pair_start_tokens]
for s in origin_sent:
if len(self._tokenizer.tokenize(s)) > (max_len):
tokenized_sentences.append(self._tokenizer.tokenize(s)[1:(max_len)]+self._tokenizer.sequence_pair_mid_tokens)
else:
tokenized_sentences.append(self._tokenizer.tokenize(s)[1:-1]+self._tokenizer.sequence_pair_mid_tokens)
mid_tok_len = len(self._tokenizer.sequence_pair_mid_tokens)
return [list(itertools.chain.from_iterable(tokenized_sentences))[:-mid_tok_len]+self._tokenizer.sequence_pair_end_tokens]
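The new use_sep path keeps retrying truncation with a smaller per-sentence budget until the separator-joined sequence fits the transformer's 512-token window. A simplified, self-contained sketch of that retry loop, operating on plain token lists rather than AllenNLP tokenizer output (the function and token names here are illustrative, not part of the reader's API):

```python
from typing import List


def fit_to_window(sentences: List[List[str]], sent_max_len: int,
                  window: int = 512) -> List[str]:
    """Illustrative version of the shorten/retry loop in text_to_instance.

    Each sentence is cut to at most max_len tokens and joined with a single
    separator; if the flattened sequence still exceeds the window, the budget
    is reduced (mirroring the diff's n = overflow / num_sentences + 1 step)
    and truncation is retried from the original sentences.
    """
    max_len = sent_max_len
    while True:
        flat: List[str] = []
        for sentence in sentences:
            flat.extend(sentence[:max_len] + ["[SEP]"])
        flat = flat[:-1]                      # no separator after the last sentence
        if len(flat) <= window:
            return flat
        overflow = len(flat) - window
        max_len -= overflow // len(sentences) + 1


# Example: 12 sentences of 60 tokens each flatten to 731 tokens,
# so the budget shrinks until the joined sequence fits in 512.
tokens = [["tok"] * 60 for _ in range(12)]
print(len(fit_to_window(tokens, sent_max_len=60)))  # 503
```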