diff --git a/examples/seq2seq/task_question_answer_generation_by_seq2seq.py b/examples/seq2seq/task_question_answer_generation_by_seq2seq.py index bb43c7d8..3e4a4718 100644 --- a/examples/seq2seq/task_question_answer_generation_by_seq2seq.py +++ b/examples/seq2seq/task_question_answer_generation_by_seq2seq.py @@ -145,7 +145,7 @@ def predict(self, inputs, output_ids, states): def generate(self, passage, topk=1, topp=0.95): token_ids, segment_ids = tokenizer.encode(passage, maxlen=max_p_len) a_ids = self.random_sample([token_ids, segment_ids], n=1, topp=topp)[0] # 基于随机采样 - token_ids += list(a_ids) + token_ids += list(a_ids.cpu().numpy()) segment_ids += [1] * len(a_ids) q_ids = self.beam_search([token_ids, segment_ids], topk=topk)[0] # 基于beam search return (tokenizer.decode(q_ids.cpu().numpy()), tokenizer.decode(a_ids.cpu().numpy()))