From c9cbb627cb3e15e04515a51eed62c6873adecdb6 Mon Sep 17 00:00:00 2001 From: "Jean A. Senellart" Date: Tue, 11 Apr 2017 14:36:32 +0200 Subject: [PATCH 1/2] update type name --- NCEModule.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NCEModule.lua b/NCEModule.lua index 881cf5a..1a8e317 100644 --- a/NCEModule.lua +++ b/NCEModule.lua @@ -99,7 +99,7 @@ function NCEModule:updateOutput(inputTable) elseif self.batchnoise then self.output = (torch.type(self.output) == 'table' and #self.output == 4) and self.output or {input.new(), input.new(), input.new(), input.new()} - assert(torch.type(target) == 'torch.CudaTensor' or torch.type(target) == 'torch.LongTensor') + assert(torch.type(target) == 'torch.CudaLongTensor' or torch.type(target) == 'torch.LongTensor') self.sampleidx = self.sampleidx or target.new() -- the last elements contain the target indices From 2d78116c395addc1a97de1b2e11bc995e3cd7878 Mon Sep 17 00:00:00 2001 From: "Jean A. Senellart" Date: Tue, 11 Apr 2017 14:39:04 +0200 Subject: [PATCH 2/2] keep CudaTensor for compatibility --- NCEModule.lua | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/NCEModule.lua b/NCEModule.lua index 1a8e317..3acedaa 100644 --- a/NCEModule.lua +++ b/NCEModule.lua @@ -99,7 +99,8 @@ function NCEModule:updateOutput(inputTable) elseif self.batchnoise then self.output = (torch.type(self.output) == 'table' and #self.output == 4) and self.output or {input.new(), input.new(), input.new(), input.new()} - assert(torch.type(target) == 'torch.CudaLongTensor' or torch.type(target) == 'torch.LongTensor') + assert(torch.type(target) == 'torch.CudaLongTensor' or torch.type(target) == 'torch.CudaTensor' + or torch.type(target) == 'torch.LongTensor') self.sampleidx = self.sampleidx or target.new() -- the last elements contain the target indices