From 36ef5ce3221e1b194b137cf0de59c0cbebabe0e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Mart=C3=ADnez?= <26169771+miguelusque@users.noreply.github.com> Date: Sun, 2 Jul 2023 09:07:09 +0200 Subject: [PATCH] Remove parenthesis from asserts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 'assert' syntax is that this statement doesn’t require a pair of parentheses to group the expression and the optional message. In Python, 'assert' is a statement instead of a function. The parentheses turn the assertion expression and message into a two-item tuple, which always evaluates to true, making the assertion useless. --- examples/pytorch/gpt/utils/gpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/gpt/utils/gpt.py b/examples/pytorch/gpt/utils/gpt.py index 20d90b45f..2c75d971b 100644 --- a/examples/pytorch/gpt/utils/gpt.py +++ b/examples/pytorch/gpt/utils/gpt.py @@ -38,7 +38,7 @@ def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, has_post_decoder_layernorm: bool = True, int8_mode: int = 0, inter_size: int = 0): - assert(head_num % tensor_para_size == 0) + assert head_num % tensor_para_size == 0 if int8_mode == 1: torch_infer_dtype = str_type_map[inference_data_type] @@ -218,7 +218,7 @@ def __len__(self): return len(self.w) def _map(self, func): - assert(self.pre_embed_idx < self.post_embed_idx, "Pre decoder embedding index should be lower than post decoder embedding index.") + assert self.pre_embed_idx < self.post_embed_idx, "Pre decoder embedding index should be lower than post decoder embedding index." for i in range(len(self.w)): if isinstance(self.w[i], list): for j in range(len(self.w[i])):