
Commit 7231909

csferng authored and tensorflow-copybara committed
Adds model_kwargs argument to nsl.keras.adversarial_loss.
The argument allows passing additional keyword arguments to the base model, so that layers whose behavior depends on such arguments can function properly. For example, a `BatchNormalization` layer requires `training=True` during the training phase.

Fixes #28

PiperOrigin-RevId: 273618791
1 parent 9758519 commit 7231909
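
For context, a minimal sketch of how a caller might use the new argument (the toy model and data here are illustrative only, not part of the commit; the API shape follows the diff below):

import tensorflow as tf
import neural_structured_learning as nsl

# A toy model containing a layer whose behavior depends on `training`.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(2,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1),
])

features = {'feature': tf.constant([[1.0, 2.0], [3.0, 4.0]])}
labels = tf.constant([[0.0], [1.0]])

# `model_kwargs` is forwarded to every call of the model inside
# `adversarial_loss`, so the BatchNormalization layer runs in
# training mode for both the original and the perturbed inputs.
adv_loss = nsl.keras.adversarial_loss(
    features,
    labels,
    model,
    loss_fn=tf.keras.losses.MeanSquaredError(),
    model_kwargs={'training': True})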

2 files changed: +28 -4 lines


neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 11 additions & 4 deletions
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 import collections
+import functools
 import types
 
 import attr
@@ -36,7 +37,8 @@ def adversarial_loss(features,
                      adv_config=None,
                      predictions=None,
                      labeled_loss=None,
-                     gradient_tape=None):
+                     gradient_tape=None,
+                     model_kwargs=None):
   """Computes the adversarial loss for `model` given `features` and `labels`.
 
   This utility function adds adversarial perturbations to the input `features`,
@@ -109,6 +111,8 @@ def adversarial_loss(features,
       adversarial regularization. In eager mode, the `gradient_tape` has to be
       set as well.
     gradient_tape: (optional) A `tf.GradientTape` object watching `features`.
+    model_kwargs: (optional) A dictionary of additional keyword arguments to be
+      passed to the `model`.
 
   Returns:
     A `Tensor` for adversarial regularization loss, i.e. labeled loss on
@@ -118,6 +122,9 @@ def adversarial_loss(features,
   if adv_config is None:
     adv_config = nsl_configs.AdvRegConfig()
 
+  if model_kwargs is not None:
+    model = functools.partial(model, **model_kwargs)
+
   # Calculates labeled_loss if not provided.
   if labeled_loss is None:
     # Reuses the tape if provided; otherwise creates a new tape.
@@ -626,16 +633,16 @@ def call(self, inputs, **kwargs):
       self.add_metric(value, aggregation=aggregation, name=name)
 
     # Adversarial loss.
-    base_model_fn = lambda inputs: self.base_model(inputs, **kwargs)
     adv_loss = adversarial_loss(
         inputs,
         labels,
-        base_model_fn,
+        self.base_model,
         self._compute_total_loss,
         sample_weights=sample_weights,
         adv_config=self.adv_config,
         labeled_loss=labeled_loss,
-        gradient_tape=tape)
+        gradient_tape=tape,
+        model_kwargs=kwargs)
     self.add_loss(self.adv_config.multiplier * adv_loss)
     self.add_metric(adv_loss, name='adversarial_loss', aggregation='mean')
     return outputs
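
The forwarding mechanism above is plain `functools.partial`: it returns a new callable with the given keyword arguments pre-bound, leaving the original model object untouched. A standalone sketch of the pattern (not NSL code):

import functools

def model(inputs, training=False):
    # Stand-in for a Keras model whose behavior depends on `training`.
    return 'train mode' if training else 'inference mode'

bound = functools.partial(model, training=True)
print(model([1.0]))  # inference mode
print(bound([1.0]))  # train mode

Because `model` is rebound before any call is made, the kwargs apply to every invocation of the model within `adversarial_loss`, for both the original and the perturbed inputs.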

neural_structured_learning/keras/adversarial_regularization_test.py

Lines changed: 17 additions & 0 deletions
@@ -202,6 +202,23 @@ def wrapped_loss_fn(*args, **kwargs):
     self.assertEqual(1, call_count['model'])
     self.assertEqual(1, call_count['loss_fn'])
 
+  def test_with_model_kwargs(self):
+    w = np.array([[4.0], [-3.0]])
+    x0 = np.array([[2.0, 3.0]])
+    y0 = np.array([[0.0]])
+    model = build_linear_keras_sequential_model(input_shape=(2,), weights=w)
+    model.add(tf.keras.layers.BatchNormalization())
+
+    adv_loss = adversarial_regularization.adversarial_loss(
+        features={'feature': tf.constant(x0)},
+        labels=tf.constant(y0),
+        model=model,
+        loss_fn=keras.losses.MeanSquaredError(),
+        adv_config=self.adv_config,
+        model_kwargs={'training': True})
+    # BatchNormalization returns 0 for a single-example batch when training=True.
+    self.assertAllClose(0.0, self.evaluate(adv_loss))
+
 
 class AdversarialRegularizationTest(tf.test.TestCase, parameterized.TestCase):