-
Notifications
You must be signed in to change notification settings - Fork 0
WIP: add state to logits processing #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
WIP: add state to logits processing #10
Conversation
|
|
||
| logits = enforce_token(logits, next_enforced_token_id) | ||
|
|
||
| next_enforced_token_id = Nx.add(next_enforced_token_id, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nx.Defn.Evaluator doesn't like adding on vectorized next_enforced_token_id
| # | ||
| # Now, with the processor below, we expect the sequence of [79, 80, 81 ..] | ||
|
|
||
| %{token_ids: token_ids} = generate.(params, inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the test passes when using EXLA as compiler
| %{token_ids: token_ids} = generate.(params, inputs) | |
| %{token_ids: token_ids} = Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) |
| ] | ||
| ) | ||
|
|
||
| %{token_ids: token_ids} = generate.(params, inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the test passes when using EXLA as compiler
| %{token_ids: token_ids} = generate.(params, inputs) | |
| %{token_ids: token_ids} = Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) |
this is the current state.
mix test test/bumblebee/text/generation_test.exs:135fails withHowever, this seems to be
Nx.Defn.Evaluatorspecific (which is used in the tests). When switching toEXLAas compiler the test passes.