I'm working on replicating the VAE part of the paper. Right now I'm training a bidirectional Encoder from scratch; it produces an embedding of shape [batch, seq_len, embed_size]. I sum over the sequence dimension and apply unit normalization to get a sentence-level token (St) of shape [batch, embed_size]. (I'm not using a VAE approach right now, just a simple projection onto the unit sphere.)
St is expanded back to shape [batch, seq_len, embed_size] by repeating it across the seq_len dimension. This is passed as cross-attention context to a causal Decoder that has both self-attention and cross-attention and is trained with teacher forcing.
During inference, the Encoder first compresses the input text into a sentence-level token, which is processed as described above; the Decoder then reconstructs the input autoregressively by taking the argmax of the next-token distribution at each step.
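For reference, the pooling and broadcasting steps I describe look roughly like this (a minimal NumPy sketch of my own setup, not code from this repo; the shapes and function names are my assumptions):

```python
import numpy as np

def sentence_token(encoder_out: np.ndarray) -> np.ndarray:
    """Pool [batch, seq_len, embed_size] -> unit-norm [batch, embed_size]."""
    st = encoder_out.sum(axis=1)                          # sum over seq_len
    st = st / np.linalg.norm(st, axis=-1, keepdims=True)  # project onto unit sphere
    return st

def broadcast_context(st: np.ndarray, seq_len: int) -> np.ndarray:
    """Repeat St across seq_len -> [batch, seq_len, embed_size] cross-attn context."""
    return np.repeat(st[:, None, :], seq_len, axis=1)

enc = np.random.randn(2, 5, 8)             # [batch=2, seq_len=5, embed_size=8]
st = sentence_token(enc)                   # [2, 8], each row has unit L2 norm
ctx = broadcast_context(st, enc.shape[1])  # [2, 5, 8], same vector at every position
```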
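The argmax decoding I mean is the standard greedy loop; a minimal sketch with a stub decoder (the real model's decoder interface is an assumption here):

```python
import numpy as np

def greedy_decode(decoder, context, bos_id=1, eos_id=2, max_len=20):
    """Greedy autoregressive decoding: feed tokens so far, take argmax each step."""
    tokens = [bos_id]
    for _ in range(max_len):
        logits = decoder(tokens, context)  # [vocab] logits for the next token
        next_id = int(np.argmax(logits))   # greedy: pick the single most likely token
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens

def stub_decoder(tokens, context):
    """Toy decoder that deterministically continues 1 -> 3 -> 4 -> EOS."""
    plan = {1: 3, 3: 4, 4: 2}
    logits = np.zeros(10)
    logits[plan.get(tokens[-1], 2)] = 1.0
    return logits

print(greedy_decode(stub_decoder, context=None))  # [1, 3, 4, 2]
```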
Is my approach correct? I tried to run your svae demo, but it returned this error:
```
Traceback (most recent call last):
File "/media/ramdisk/SentenceVAE/tools/demo/demo_svae.py", line 33, in <module>
from sentence_vae.models import SentenceVAE
File "/media/ramdisk/SentenceVAE/sentence_vae/models/__init__.py", line 32, in <module>
from .sentence_vae_model import SentenceVAE
File "/media/ramdisk/SentenceVAE/sentence_vae/models/sentence_vae_model.py", line 31, in <module>
from mmengine.model import BaseModel
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/model/__init__.py", line 6, in <module>
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/model/base_model/__init__.py", line 2, in <module>
from .base_model import BaseModel
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 9, in <module>
from mmengine.optim import OptimWrapper
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/__init__.py", line 2, in <module>
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/optimizer/__init__.py", line 5, in <module>
from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/optimizer/builder.py", line 174, in <module>
TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/optimizer/builder.py", line 169, in register_transformers_optimizers
OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/registry/registry.py", line 661, in register_module
self._register_module(module=module, module_name=name, force=force)
File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/registry/registry.py", line 611, in _register_module
raise KeyError(f'{name} is already registered in {self.name} '
KeyError: 'Adafactor is already registered in optimizer at torch.optim'
```
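As far as I can tell, the KeyError happens because newer torch versions ship their own `Adafactor` in `torch.optim`, which mmengine registers first, and then mmengine also tries to register the `Adafactor` from transformers under the same name. A toy sketch of that registry behavior (my own minimal registry, not mmengine's code; I'm assuming a `force` flag like mmengine's `register_module(force=True)`):

```python
class Registry:
    """Toy name -> class registry mimicking mmengine's duplicate-registration check."""
    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, name, module, force=False):
        # Without force, a second registration under the same name raises KeyError.
        if not force and name in self._modules:
            raise KeyError(f'{name} is already registered in {self.name}')
        self._modules[name] = module

optim = Registry('optimizer')
optim.register_module('Adafactor', object)              # first (e.g. from torch.optim)
try:
    optim.register_module('Adafactor', object)          # second (transformers) -> KeyError
except KeyError as e:
    print(e)
optim.register_module('Adafactor', object, force=True)  # force=True overwrites instead
```

My guess (unverified) is that pinning an older torch version, or patching mmengine's registration call to pass `force=True`, would sidestep the crash.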
In your tests with the VAE, when the Decoder predicts a wrong token during decompression, does the rest of the output turn to gibberish? Which tokenizer did you use, OPT's?
In my tests I got these results:
Input Text: "The Starless Sea" is a fantasy novel by Erin Morgenstern, published in 2019. The story follows Zachary Ezra Rawlins, a graduate student who discovers a mysterious book titled "Sweet Sorrows" in his university library. This book recounts a childhood experience of
Generated Text: "The Starless Sea" is a fantasy novel work by Arranke,shire Leave in The Star sends page moOCle Citagen America, a Charlottes dream "Plets a royal word authoritys." Shek marked a Pwrannted demo inEconsult. In school music up
Sometimes it can reconstruct more tokens before getting one wrong, which then causes the rest of the chain to fail.