From 1da80f3a8b1102b84520b2d3eb5adfa28e3f59f3 Mon Sep 17 00:00:00 2001 From: Malinda Date: Tue, 28 Jun 2022 22:14:47 -0700 Subject: [PATCH] =?UTF-8?q?Using=C2=A0NumPy=20APIs=20to=C2=A0improve=20per?= =?UTF-8?q?formance=20and=20the=20code=20quality,=C2=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tensorflow_ranking/python/losses_test.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorflow_ranking/python/losses_test.py b/tensorflow_ranking/python/losses_test.py index 5864b16..0f0ea53 100644 --- a/tensorflow_ranking/python/losses_test.py +++ b/tensorflow_ranking/python/losses_test.py @@ -19,6 +19,8 @@ from __future__ import print_function import math + +import numpy as np import tensorflow as tf from tensorflow_ranking.python import losses as ranking_losses @@ -151,14 +153,11 @@ def _loss(si, sj, label_diff, delta): def _batch_aggregation(batch_loss_list, reduction=None): """Returns the aggregated loss.""" - loss_sum = 0. - weight_sum = 0. - for loss, weight, count in batch_loss_list: - loss_sum += loss - if reduction == 'mean': - weight_sum += weight - else: - weight_sum += count + loss_sum = np.sum([loss for loss, weight, count in batch_loss_list]) + if reduction == 'mean': + weight_sum = np.sum([weight for loss, weight, count in batch_loss_list]) + else: + weight_sum = np.sum([count for loss, weight, count in batch_loss_list]) return loss_sum / weight_sum