About VAE performance on long token sequences, and a code error #1

@QLaHPD

Description

I'm working on replicating the VAE part of the paper. Right now I'm training a bidirectional Encoder from scratch; it produces an embedding of shape [batch, seq_len, embed_size], and I apply sum pooling + unit norm to it to generate a sentence-level token (St) of shape [batch, embed_size] (I'm not using a VAE approach right now, just a simple projection onto the unit sphere).

St is converted back to shape [batch, seq_len, embed_size] by repeating it across the seq_len dimension. This is passed as cross-attention context to a causal Decoder, which is trained with teacher forcing and has both self-attention and cross-attention.
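To make the shapes concrete, here is a minimal NumPy sketch of the pooling-and-repeat step I mean (my own illustration, not code from the repo; all names are made up):

```python
import numpy as np

batch, seq_len, embed_size = 2, 8, 16
h = np.random.randn(batch, seq_len, embed_size)   # bidirectional Encoder output

# Sum-pool over the sequence, then project onto the unit sphere.
st = h.sum(axis=1)                                 # [batch, embed_size]
st = st / np.linalg.norm(st, axis=-1, keepdims=True)

# Repeat St across seq_len to recover [batch, seq_len, embed_size],
# used as cross-attention context for the Decoder.
ctx = np.repeat(st[:, None, :], seq_len, axis=1)

assert ctx.shape == (batch, seq_len, embed_size)
```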

During inference the Encoder first compresses the input text into a sentence-level token, which is processed as described above; the Decoder then reconstructs the input autoregressively by taking the argmax over the next-token distribution.
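The reconstruction is just a greedy argmax loop; sketched below with a hypothetical `decoder_step` function standing in for my Decoder's forward pass (the vocabulary size and special-token ids are placeholders):

```python
import numpy as np

VOCAB, BOS, EOS = 100, 1, 2

def decoder_step(tokens, ctx):
    # Hypothetical stand-in for the causal Decoder: returns logits over
    # the vocabulary for the next token, given the tokens generated so far
    # and the repeated sentence-level token as cross-attention context.
    rng = np.random.default_rng(len(tokens))
    return rng.standard_normal(VOCAB)

def greedy_decode(ctx, max_len=16):
    tokens = [BOS]
    for _ in range(max_len):
        logits = decoder_step(tokens, ctx)
        nxt = int(np.argmax(logits))   # greedy: always take the argmax
        tokens.append(nxt)
        if nxt == EOS:
            break
    return tokens

out = greedy_decode(ctx=None)
```

One consequence of this greedy loop is exactly the failure mode I ask about below: a single wrong argmax feeds back into every later step.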

Is my approach correct? I tried to run your svae demo, but it returned this error:

Traceback (most recent call last):
  File "/media/ramdisk/SentenceVAE/tools/demo/demo_svae.py", line 33, in <module>
    from sentence_vae.models import SentenceVAE
  File "/media/ramdisk/SentenceVAE/sentence_vae/models/__init__.py", line 32, in <module>
    from .sentence_vae_model import SentenceVAE
  File "/media/ramdisk/SentenceVAE/sentence_vae/models/sentence_vae_model.py", line 31, in <module>
    from mmengine.model import BaseModel
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/model/__init__.py", line 6, in <module>
    from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/model/base_model/__init__.py", line 2, in <module>
    from .base_model import BaseModel
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 9, in <module>
    from mmengine.optim import OptimWrapper
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/__init__.py", line 2, in <module>
    from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/optimizer/__init__.py", line 5, in <module>
    from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/optimizer/builder.py", line 174, in <module>
    TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/optim/optimizer/builder.py", line 169, in register_transformers_optimizers
    OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/registry/registry.py", line 661, in register_module
    self._register_module(module=module, module_name=name, force=force)
  File "/home/QLaHPD/miniconda3/envs/SentenceVAE/lib/python3.10/site-packages/mmengine/registry/registry.py", line 611, in _register_module
    raise KeyError(f'{name} is already registered in {self.name} '
KeyError: 'Adafactor is already registered in optimizer at torch.optim'
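This KeyError looks like a version clash rather than a bug in the demo itself: newer PyTorch releases ship their own Adafactor in torch.optim, so when mmengine then tries to register the transformers Adafactor under the same name, the registry refuses. A possible workaround (the version numbers are my guess, not verified against this repo):

```shell
# Assumption: the clash comes from torch>=2.5 shipping torch.optim.Adafactor,
# which the registry picks up before mmengine registers the transformers one
# under the same name. Either pin an older torch:
pip install "torch<2.5"
# ...or try a newer mmengine that may already guard against the double
# registration:
pip install -U mmengine
```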

In your tests with the VAE, when the Decoder predicts a wrong token during decompression, does the rest of the output turn into gibberish? Which tokenizer did you use, OPT's?

In my tests I got these results:

Input Text: "The Starless Sea" is a fantasy novel by Erin Morgenstern, published in 2019. The story follows Zachary Ezra Rawlins, a graduate student who discovers a mysterious book titled "Sweet Sorrows" in his university library. This book recounts a childhood experience of

Generated Text: "The Starless Sea" is a fantasy novel work by Arranke,shire Leave in The Star sends page moOCle Citagen America, a Charlottes dream "Plets a royal word authoritys." Shek marked a Pwrannted demo inEconsult. In school music up 

Sometimes it can correctly reconstruct more tokens before losing one, which causes the rest of the chain to fail.
