kv_cache = None
prev_tokens = set()
max_length = int(max_length)
for step in range(max_length):
dec_feat_seq = self.model(alive_seq,
encoder_hidden_states = src_features,
encoder_attention_mask = attention_mask,
past_key_values = None,
use_cache = False,
return_dict = True,
reduction = 'none')
kv_cache = dec_feat_seq.past_key_values
I change the code like this, and it didn't work.