28 changes: 19 additions & 9 deletions retnet/configuration_retnet.py
@@ -27,16 +27,18 @@ class RetNetConfig(PretrainedConfig):
decoder_normalize_before: bool = True # apply layernorm before each decoder block
layernorm_embedding: bool = False # add layernorm to embedding
no_scale_embedding: bool = True # if True, dont scale embeddings
recurrent_chunk_size: int = 512
use_lm_decay: bool = False
recurrent_chunk_size: int = None
use_lm_decay: int = 0
use_glu: bool = True # use GLU instead of FFN
z_loss_coeff: float = 0.0 # coefficient for z loss: TODO: 1e-4
deepnorm: bool = False
subln: bool = True
use_ffn_rms_norm: bool = False
layernorm_eps: float = 1e-6
tie_word_embeddings: bool = False

disable_all_bias: bool = False
qkv_scalering_mode: str = 'qk.k'
normalize_at_end: bool = True
def __init__(
self,
vocab_size: int = 50257,
@@ -59,15 +61,20 @@ def __init__(
decoder_normalize_before: bool = True, # apply layernorm before each decoder block
layernorm_embedding: bool = False, # add layernorm to embedding
no_scale_embedding: bool = True, # if True, dont scale embeddings
recurrent_chunk_size: int = 512,
use_glu: bool = True, # use GLU instead of FFN
recurrent_chunk_size: int = None,
use_glu: bool = False, # use GLU instead of FFN
z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4
use_lm_decay: bool = False,
use_lm_decay: int = 0,
deepnorm: bool = False,
subln: bool = True,
use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN
layernorm_eps: float = 1e-6,
tie_word_embeddings: bool = False,
use_flash_retention: bool = False,
normlize_for_stable='auto',
disable_all_bias: bool = False,
qkv_scalering_mode: str = 'qk.k',
normalize_at_end: bool = True,
**kwargs):
self.vocab_size = vocab_size
self.initializer_range = initializer_range
@@ -96,14 +103,17 @@ def __init__(
# Blockwise
self.recurrent_chunk_size = recurrent_chunk_size
self.forward_impl = forward_impl

self.use_flash_retention = use_flash_retention
if self.deepnorm:
self.decoder_normalize_before = False
self.subln = False
if self.subln:
self.decoder_normalize_before = True
self.deepnorm = False

self.normlize_for_stable = normlize_for_stable
self.disable_all_bias = disable_all_bias
self.qkv_scalering_mode = qkv_scalering_mode
self.normalize_at_end = normalize_at_end
super().__init__(is_decoder=is_decoder,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@@ -114,4 +124,4 @@ def __init__(
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
self.__dict__[hp] = getattr(args, hp, None)
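A minimal usage sketch (not part of the diff) of the updated config, showing the fields this PR adds or changes and the `override` helper. Field names and spellings follow the diff verbatim; the specific values passed to `override` below are illustrative assumptions.

```python
# Hypothetical usage of the updated RetNetConfig; values are examples only.
import argparse

from retnet.configuration_retnet import RetNetConfig

config = RetNetConfig(
    vocab_size=50257,
    use_glu=False,                 # __init__ default changed from True to False in this diff
    recurrent_chunk_size=None,     # default changed from 512 to None
    use_lm_decay=0,                # type changed from bool to int
    use_flash_retention=False,     # new flag
    normlize_for_stable='auto',    # new flag (name as spelled in the diff)
    disable_all_bias=False,        # new flag
    qkv_scalering_mode='qk.k',     # new flag (name as spelled in the diff)
    normalize_at_end=True,         # new flag
)

# `override` copies any attribute present (and not None) on an argparse-style
# namespace onto the config, matching by attribute name.
args = argparse.Namespace(recurrent_chunk_size=256, use_flash_retention=True)
config.override(args)
print(config.recurrent_chunk_size, config.use_flash_retention)  # 256 True
```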