bugfix: fix the issue of ineffective input embedding transmission. #490
base: main

Conversation
if (input_embeddings_vec_.size() > 0) {
  torch::Tensor input_embeddings = torch::cat(input_embeddings_vec_);
  raw_forward_input.embeddings = tensor_to_2d_float_vector(input_embeddings);
}
nit: hmm, it seems that input_embedding is unused now; maybe @wly-115 @yiming-l21 @RobbieLeung can help confirm this. If so, we can delete all the input_embedding-related code.
input_embedding will be used in generative recommendation; don't delete it for now.
  const auto& input_embedding = sequence->get_input_embedding();
- if (input_embedding.defined())
+ if (sequence->stage() == SequenceStage::PREFILL && input_embedding.defined())
will sequence be nullptr? If not, the signature should be `Sequence&`; if yes, the nullptr case should be handled.
At the beginning of the function, `CHECK(sequence != nullptr)` ensures that sequence can never be nullptr here. Using a pointer instead of a reference is necessary so that the same sequence object can be shared between the batch and the request.
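A minimal sketch of the pattern described above, assuming a hypothetical helper named prepare_sequence and a glog-style CHECK; Sequence, SequenceStage::PREFILL, and get_input_embedding() are taken from the diff:

#include <glog/logging.h>
#include <torch/torch.h>
#include <vector>

// Hypothetical helper; the real call site in the PR may differ.
void prepare_sequence(Sequence* sequence,
                      std::vector<torch::Tensor>& input_embeddings_vec) {
  // validated once up front, so every later dereference is safe
  CHECK(sequence != nullptr);
  const auto& input_embedding = sequence->get_input_embedding();
  // forward the embedding only while the sequence is still prefilling
  if (sequence->stage() == SequenceStage::PREFILL && input_embedding.defined()) {
    input_embeddings_vec.push_back(input_embedding);
  }
}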
  return tensor;
};

inline std::vector<std::vector<float>> tensor_to_2d_float_vector(
this will lead to a copy... please fix it.
- Passing `const torch::Tensor& tensor` as a function parameter does not copy the tensor data itself; only a reference to the tensor object is passed. PyTorch's torch::Tensor is itself a reference-counted smart-pointer container, so even passing torch::Tensor by value would not copy the underlying data.
- Returning `std::vector<std::vector<float>>` from the function does not copy data either. First, std::vector is a dynamic-array container whose move constructor only transfers ownership of the metadata. Second, the compiler applies Named Return Value Optimization (NRVO), constructing the vector directly at the caller's memory address and skipping the copy/move entirely.
- Going back inside the function: if the input tensor satisfies `tensor.device().type() == torch::kCPU && tensor.scalar_type() == torch::kFloat32 && tensor.is_contiguous()`, there is no data copy there either; if any one of those conditions is not met, the data is copied, which is the expected behavior.
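For reference, a minimal sketch of a conversion along these lines (the PR's actual tensor_to_2d_float_vector may differ); the .to()/.contiguous() calls are no-ops on the fast path described above:

#include <torch/torch.h>
#include <vector>

// Hypothetical sketch, not the PR's exact implementation.
inline std::vector<std::vector<float>> tensor_to_2d_float_vector(
    const torch::Tensor& tensor) {
  // no-op when the tensor is already CPU / float32 / contiguous;
  // otherwise this is where the one-time copy happens
  torch::Tensor t = tensor.to(torch::kCPU, torch::kFloat32).contiguous();
  const int64_t rows = t.size(0);
  const int64_t cols = t.size(1);
  const float* data = t.data_ptr<float>();

  std::vector<std::vector<float>> result(rows);
  for (int64_t i = 0; i < rows; ++i) {
    // fill each output row from the flat tensor buffer
    result[i].assign(data + i * cols, data + (i + 1) * cols);
  }
  return result;  // NRVO / move: no element-wise copy on return
}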