A simple gRPC Python service demonstrating streaming text generation with HuggingFace Transformers and PyTorch.

The sample TextGenerator gRPC service provides two methods:
```proto
service TextGenerator {
  rpc Generate(GenerateRequest) returns (GenerateResponse) {}
  rpc GenerateStreamed(GenerateStreamedRequest) returns (stream GenerateStreamedResponse) {}
}
```

The Generate RPC generates text from a given starting phrase (GenerateRequest.text) and returns the final result once the maximum number of tokens (GenerateRequest.max_length) is reached. max_length is an optional field; if not provided, a default value is used.
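Since proto3 scalar fields carry no presence information, an omitted max_length arrives at the server as 0, so the server can treat 0 as "not provided". A minimal sketch of how the default might be applied (effective_max_length and DEFAULT_MAX_LENGTH are hypothetical names, not taken from the sample):

```python
DEFAULT_MAX_LENGTH = 50  # assumed default value; the sample may use a different one

def effective_max_length(requested: int) -> int:
    # In proto3, an unset int32 field is received as 0.
    return requested if requested > 0 else DEFAULT_MAX_LENGTH

print(effective_max_length(0))  # falls back to the default
print(effective_max_length(7))  # explicit value wins
```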
```proto
message GenerateRequest {
  string text = 1;
  int32 max_length = 2;
}
```

The result contains a single field, GenerateResponse.text, holding the product of the generation process.
```proto
message GenerateResponse {
  string text = 1;
}
```

The GenerateStreamed RPC performs the same work, but streams intermediate results during the generation process. The optional GenerateStreamedRequest.intermediate_result_interval_ms field specifies the minimal time span between intermediate results. For instance, if the interval is set to 500 ms and the total generation takes longer than that, an intermediate result is returned every 500 ms.
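The throttling rule described above can be sketched as a small helper that a streaming servicer might consult before emitting each intermediate result. IntervalGate is a hypothetical name, not part of the sample, and the choice to emit the first result immediately is an assumption:

```python
class IntervalGate:
    """Tracks whether the minimal interval has elapsed since the last emitted result."""

    def __init__(self, interval_ms: int):
        self.interval_s = interval_ms / 1000.0
        self.last_emit = None  # no intermediate result emitted yet

    def ready(self, now: float) -> bool:
        # Emit the first result immediately, then only after interval_s has elapsed.
        if self.last_emit is None or now - self.last_emit >= self.interval_s:
            self.last_emit = now
            return True
        return False

gate = IntervalGate(500)
print(gate.ready(0.0))  # first result is emitted immediately
print(gate.ready(0.3))  # too early, suppressed
print(gate.ready(0.6))  # 600 ms elapsed, emitted
```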
```proto
message GenerateStreamedRequest {
  string text = 1;
  int32 max_length = 2;
  int32 intermediate_result_interval_ms = 3;
}
```

Each message in the response stream contains a single field, GenerateStreamedResponse.text_fragment, holding the next portion of the generated text. The final value is the concatenation of all text fragments in the order they were received.
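Client-side reassembly is therefore a plain concatenation. A minimal sketch (assemble is a hypothetical helper, not part of the sample clients):

```python
def assemble(fragments):
    # Concatenate text fragments in the order they arrived on the stream.
    return "".join(fragments)

print(assemble(["Let me ", "say ", "something"]))  # prints "Let me say something"
```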
```proto
message GenerateStreamedResponse {
  string text_fragment = 1;
}
```

In this example the GPT-2 model is used, but you're free to use any other text-generation model.
```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
```

To collect intermediate results, it is suggested to use a fake stopping criterion that never stops generation:

```python
def custom_stopping_criteria(input_ids, scores) -> bool:
    return False
```

The first parameter, input_ids, contains the intermediate result tokens; the tokenizer can be used to decode the tokens back to text.
```python
tokenizer.decode(input_ids[0], skip_special_tokens=True)
```

To start the text-generation process, we encode the text into tokens wrapped in PyTorch tensors, then invoke model generation, providing custom_stopping_criteria:

```python
inputs = tokenizer.encode("Let me say something", return_tensors='pt')
outputs = model.generate(inputs, max_length=7, do_sample=True, stopping_criteria=[custom_stopping_criteria])
```

Putting it all together:
```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Encode input tokens.
inputs = tokenizer.encode("Let me say something", return_tensors='pt')

def custom_stopping_criteria(input_ids, unused_scores) -> bool:
    # Print the intermediate result.
    print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
    return False

outputs = model.generate(inputs, max_length=7, do_sample=True, stopping_criteria=[custom_stopping_criteria])

# Print the final result.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

You need to have PyTorch and HuggingFace Transformers installed. To run the service, gRPC also needs to be installed.
Please refer to the official installation instructions, but in the general case the following pip command can be used:

```shell
pip install grpcio grpcio-tools torch transformers
```

Start the server, then run the clients:

```shell
python text_generator_server.py
python text_generator_client.py
python text_generator_streamed_client.py
```

The gRPC wrappers need to be regenerated if any changes to protos/text_generator.proto are made:

```shell
python -m grpc_tools.protoc -I./protos --python_out=. --pyi_out=. --grpc_python_out=. protos/text_generator.proto
```

This command generates and overwrites the following files:

- text_generator_pb.py
- text_generator_pb.pyi
- text_generator_pb_grpc.py