-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path howManyTokens.py
More file actions
64 lines (51 loc) · 1.57 KB
/
howManyTokens.py
File metadata and controls
64 lines (51 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import random
import torch.nn.functional as F
import jieba
from collections import Counter
import time
from datetime import datetime
# Prefer GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 1. Load the training corpus lower-cased, segment each non-empty line with
#    jieba, and mark every line boundary with an <END> token.
with open("train.txt", "r", encoding="utf-8") as f:
    text = f.read().lower()

tokens = []
for raw_line in text.splitlines():
    stripped = raw_line.strip()
    if stripped:
        tokens.extend(jieba.cut(stripped))
        tokens.append("<END>")

print("词数:", len(tokens))
# print("示例:", tokens[:30])
# Corpus statistics.
print(f"训练文本总长度(字符数):{len(text)}")
print(f"训练文本总行数:{len(text.splitlines())}")
# Curated tokens that must always be in the vocabulary regardless of corpus
# frequency (tech terms, key names, ASCII and full-width punctuation).
# BUG FIX: the original list contained "*" twice (and possibly other repeated
# punctuation variants); duplicates propagated into `vocab`, where later copies
# overwrote earlier ids in `stoi`, left `itos` missing an index, and inflated
# `vocab_size`. Deduplicate while preserving first-seen order.
COMMON_TOKENS = list(dict.fromkeys([
    "minecraft", "ai", "cpu", "gpu",
    "ctrl", "shift", "alt", "cmd",
    "+", "-", "*", "/", "=", "==",
    "(", ")", "[", "]", "{", "}",
    "(", ")", "【", "】", "「", "」",
    "《", "》", "<", ">", "?", "!",
    "@", "#", "$", "%", "^", "&",
    "*", "_", "¥", "、", "“", "”",
    "≠", "±", ":", ";", "‘", "’"
]))
# Token frequencies over the corpus; tokens rarer than min_freq map to <UNK>.
word_counts = Counter(tokens)
min_freq = 3

SPECIAL_TOKENS = ["<PAD>", "<END>", "<UNK>"]

# Build the vocabulary: specials first, then the curated common tokens, then
# every corpus token seen at least min_freq times.
# BUG FIX: the original filtered only `w not in COMMON_TOKENS`, so "<END>"
# (already in SPECIAL_TOKENS, and appended once per corpus line so its count
# easily reaches min_freq) was added a second time. Duplicates made later
# copies overwrite earlier ids in `stoi`, left `itos` with a hole, and
# inflated `vocab_size`. Exclude specials too and dedupe, preserving order.
_reserved = set(SPECIAL_TOKENS) | set(COMMON_TOKENS)
vocab = list(dict.fromkeys(
    SPECIAL_TOKENS
    + COMMON_TOKENS
    + [w for w, c in word_counts.items() if c >= min_freq and w not in _reserved]
))
print("词表大小(含特殊符号):", len(vocab))

# token -> id and id -> token lookup tables.
stoi = {w: i for i, w in enumerate(vocab)}
itos = {i: w for w, i in stoi.items()}

# Encode the whole corpus as a 1-D LongTensor of token ids; out-of-vocabulary
# tokens fall back to the <UNK> id.
data = torch.tensor(
    [stoi.get(w, stoi["<UNK>"]) for w in tokens],
    dtype=torch.long,
)
vocab_size = len(vocab)