# infer.py
import os
import sys

import torch
import torch.nn as nn
MODEL_DIR = "models"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Positional Encoding (exactly as in train.py) ---
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=512):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
(-torch.log(torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0)) # (1, max_len, d_model)
def forward(self, x):
return x + self.pe[:, :x.size(1), :].to(x.device)
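
# Example: the encoding is truncated to the input length and broadcast over the
# batch, so shapes are preserved:
#   PositionalEncoding(128)(torch.zeros(2, 10, 128)).shape  # torch.Size([2, 10, 128])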
# --- Transformer (same as train.py) ---
class TransformerModel(nn.Module):
def __init__(self, vocab_size, hidden_size, max_len):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
self.pos_enc = PositionalEncoding(hidden_size, max_len)
self.transformer = nn.Transformer(
d_model=hidden_size,
nhead=4,
num_encoder_layers=3,
num_decoder_layers=3,
dim_feedforward=256,
dropout=0.1,
batch_first=True
)
self.out = nn.Linear(hidden_size, vocab_size)
def generate_square_subsequent_mask(self, sz):
return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1).to(DEVICE)
def forward(self, src, tgt):
# src: (B, S), tgt: (B, T)
src_emb = self.pos_enc(self.embedding(src))
tgt_emb = self.pos_enc(self.embedding(tgt))
tgt_mask = self.generate_square_subsequent_mask(tgt.size(1))
out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
return self.out(out) # (B, T, vocab)
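
# Example: the causal mask blocks attention to future positions;
# generate_square_subsequent_mask(3) evaluates to
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so decoder position t can only attend to positions <= t.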
# --- Load vocab ---
def load_vocab():
char2idx = torch.load(os.path.join(MODEL_DIR, "char2idx.pth"))
idx2char = torch.load(os.path.join(MODEL_DIR, "idx2char.pth"))
return char2idx, idx2char
# --- Caesar encode (same as train) ---
def obfuscate_caesar(text, shift=3):
result = ''
for ch in text:
if ch.isalpha():
base = ord('A') if ch.isupper() else ord('a')
result += chr((ord(ch) - base + shift) % 26 + base)
else:
result += ch
return result
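
# Example: with the default shift of 3, letters rotate forward and
# non-alphabetic characters pass through unchanged:
#   obfuscate_caesar("Hello, World!")  # -> 'Khoor, Zruog!'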
# --- Encode with <sos>/<eos> + pad to max_len ---
def encode_with_tokens(text, char2idx, max_len):
    # truncate so that <sos> + text + <eos> still fits within max_len
    text = text[:max_len - 2]
    seq = [char2idx['<sos>']]
    seq += [char2idx.get(c, 0) for c in text]
    seq += [char2idx['<eos>']]
    # pad to max_len
    seq += [char2idx['<pad>']] * (max_len - len(seq))
    return seq
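
# Sketch of the resulting layout (token ids are illustrative; the real ids come
# from the vocab saved by train.py):
#   encode_with_tokens("ab", char2idx, max_len=6)
#   -> [<sos>, id('a'), id('b'), <eos>, <pad>, <pad>]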
# --- Greedy inference ---
def infer(obf_text, epoch):
    char2idx, idx2char = load_vocab()
    vocab_size = len(char2idx)
    hidden_size = 128  # must match HIDDEN_SIZE in train.py

    # Load the checkpoint once and recover max_len from the saved positional
    # encoding buffer; fall back to 50 (the training default) if it is missing.
    state = torch.load(os.path.join(MODEL_DIR, f"epoch{epoch}.pth"), map_location=DEVICE)
    max_len = state['pos_enc.pe'].shape[1] if 'pos_enc.pe' in state else 50

    # instantiate & load the trained weights
    model = TransformerModel(vocab_size, hidden_size, max_len).to(DEVICE)
    model.load_state_dict(state)
    model.eval()
# prepare source
src_idxs = encode_with_tokens(obf_text, char2idx, max_len)
src = torch.tensor([src_idxs], dtype=torch.long, device=DEVICE)
# start decoding
sos_idx = char2idx['<sos>']
eos_idx = char2idx['<eos>']
generated = [sos_idx]
with torch.no_grad():
for _ in range(max_len):
tgt = torch.tensor([generated], dtype=torch.long, device=DEVICE)
out = model(src, tgt) # (1, len(generated), vocab)
            next_logits = out[0, -1, :]  # logits for the last time-step
next_idx = next_logits.argmax().item()
if next_idx == eos_idx:
break
generated.append(next_idx)
# skip the first <sos>
decoded = ''.join(idx2char[i] for i in generated[1:])
return decoded
if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python infer.py <obfuscated_word> <epoch>")
        sys.exit(1)
    obf, ep = sys.argv[1], int(sys.argv[2])
    print("Decoded:", infer(obf, ep))