conditionally compute loss so interpretability methods don't need to fake a label tensor
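
A minimal sketch of the pattern this change describes, assuming the loss is gated on an optional label argument. The names `forward`, `softmax`, `features`, `weights`, and `label` are hypothetical and not taken from the actual code; the tiny pure-Python classifier only stands in for the real model.

```python
import math
from typing import Dict, List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(features: Sequence[float],
            weights: Sequence[Sequence[float]],
            label: Optional[int] = None) -> Dict[str, object]:
    """Hypothetical forward pass: the loss is computed only when a label is given."""
    # One logit per class: dot product of the features with that class's weight row.
    logits = [sum(w * x for w, x in zip(row, features)) for row in weights]
    output: Dict[str, object] = {"logits": logits}
    if label is not None:
        # Cross-entropy for a single example; skipped entirely without a label.
        output["loss"] = -math.log(softmax(logits)[label])
    return output
```

With this shape, an interpretability method could call `forward(features, weights)` and read `output["logits"]` directly, while a training loop passes `label=...` to get `output["loss"]` as well.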