-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathdevits.cpp
More file actions
90 lines (53 loc) · 2.22 KB
/
devits.cpp
File metadata and controls
90 lines (53 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include "devits.h"
// Default constructor: intentionally empty. Model loading / initialization
// is done elsewhere (the Model member is populated before DoInferenceDE is called).
DEVITS::DEVITS()
{
}
#include <windows.h>
#include <string>
#include <vector>
#include <sstream>
// Debug helper: show the contents of an int vector in a modal Win32 message box.
// Elements are rendered space-separated (with a trailing space), e.g. "1 2 3 ".
void displayVectorInMessageBox(const std::vector<int>& numbers) {
    // Accumulate the elements into one space-separated string.
    std::string joined;
    for (std::size_t idx = 0; idx < numbers.size(); ++idx) {
        joined += std::to_string(numbers[idx]);
        joined += ' ';
    }
    // Pop up a blocking message box with the rendered list.
    MessageBoxA(NULL, joined.c_str(), "Vector Elements", MB_OK);
}
// Run TorchScript inference ("infer_ts") on the loaded VITS model and return the
// synthesized audio waveform. Also stores the attention matrix in the Attention
// member as a side effect.
//
// InputIDs  : token/phoneme IDs for the input text; zero-padded via ZeroPadVec before use.
// MojiIn    : TorchMoji hidden state; unsqueezed to add a batch dim.
//             NOTE(review): assumed to be a single vector [dim] -> [1, dim]; confirm.
// BERTIn    : BERT hidden states; reshaped to BERTIn.Shape, and Shape[1] is passed
//             to the model as its length (presumably the sequence dimension).
// ArgsFloat : ArgsFloat[0] is the length scale fed to the model; other entries are
//             not read here.
// ArgsInt   : unused in this implementation.
// SpeakerID : appended as an extra model input unless -1 (treated as "no speaker").
// EmotionID : unused in this implementation.
//
// Returns: audio tensor copied into a TFTensor<float>, squeezed to [frames].
TFTensor<float> DEVITS::DoInferenceDE(const std::vector<int32_t> &InputIDs, const TFTensor<float> &MojiIn, const TFTensor<float> &BERTIn, const std::vector<float> &ArgsFloat, const std::vector<int32_t> ArgsInt, int32_t SpeakerID, int32_t EmotionID)
{
// without this memory consumption is 4x
torch::NoGradGuard no_grad;
std::vector<int64_t> PaddedIDs;
// Zero-pad the input IDs (and widen int32 -> int64 for torch::tensor).
PaddedIDs = ZeroPadVec(InputIDs);
// displayVectorInMessageBox(InputIDs);
// Length of the (padded) ID sequence, passed to the model as a 1-element tensor.
std::vector<int64_t> inLen = { (int64_t)PaddedIDs.size() };
// ZDisket: Is this really necessary?
torch::TensorOptions Opts = torch::TensorOptions().requires_grad(false);
// [len] -> [1, len]: add a batch dimension for the model.
auto InIDS = torch::tensor(PaddedIDs, Opts).unsqueeze(0);
auto InLens = torch::tensor(inLen, Opts);
auto MojiHidden = torch::tensor(MojiIn.Data).unsqueeze(0);
// Flat data -> original shape recorded in the TFTensor.
auto BERTHidden = torch::tensor(BERTIn.Data).reshape(BERTIn.Shape);
// Shape[1] is used as the BERT sequence length -- presumably [batch, seq, ...]; verify.
std::vector<int64_t> BERTSz = {BERTIn.Shape[1]};
auto BERTLens = torch::tensor(BERTSz);
// Length scale (speaking-rate control) comes from the first float argument.
auto InLenScale = torch::tensor({ ArgsFloat[0]}, Opts);
std::vector<torch::jit::IValue> inputs{ InIDS,InLens, MojiHidden, BERTHidden, BERTLens, InLenScale };
// Optional speaker embedding index; -1 means the model is called without one.
if (SpeakerID != -1){
auto InSpkid = torch::tensor({SpeakerID},Opts);
inputs.push_back(InSpkid);
}
// Infer
c10::IValue Output = Model.get_method("infer_ts")(inputs);
// Output = tuple (audio,att)
auto OutputT = Output.toTuple();
// Grab audio
// [1, frames] -> [frames]
auto AuTens = OutputT.get()->elements()[0].toTensor().squeeze();
// Grab Attention
// [1, 1, x, y] -> [x, y] -> [y,x] -> [1, y, x]
auto AttTens = OutputT.get()->elements()[1].toTensor().squeeze().transpose(0,1).unsqueeze(0);
// Side effect: cache the attention map on the object for later visualization/use.
Attention = VoxUtil::CopyTensor<float>(AttTens);
return VoxUtil::CopyTensor<float>(AuTens);
}