-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpredict_and_memory.py
More file actions
60 lines (53 loc) · 1.93 KB
/
predict_and_memory.py
File metadata and controls
60 lines (53 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from data.transforms import preprocess_text
import os
# 모델 및 토크나이저 로드 (학습된 모델 불러오기)
model_path = "./checkpoints/best_model_v14" # 학습된 모델이 저장된 경로
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 디바이스 설정 (GPU 사용 가능 시 GPU, 그렇지 않으면 CPU 사용)
device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
# 예측 함수
def sentence_predict(sent):
model.eval()
sent = preprocess_text(sent)
tokenized_sent = tokenizer(
sent,
return_tensors="pt",
truncation=True,
add_special_tokens=True,
max_length=128
)
tokenized_sent.to(device)
with torch.no_grad():
outputs = model(
input_ids=tokenized_sent["input_ids"],
attention_mask=tokenized_sent["attention_mask"],
token_type_ids=tokenized_sent.get("token_type_ids")
)
logits = outputs.logits
logits = logits.detach().cpu()
result = logits.argmax(-1)
return result.item()
# JSON 파일에서 예측 데이터 읽기
json_file_path = 'input_data.json' # JSON 파일 경로를 하드코딩
with open(json_file_path, 'r') as f:
input_data = json.load(f)
# JSON 데이터의 각 문장에 대해 예측 수행
predictions = []
for data in input_data:
sentence = data.get("message")
if sentence:
predicted_label = sentence_predict(sentence)
predictions.append({
"sentence": sentence,
"predicted_label": predicted_label
})
print(f"문장: {sentence}")
print(f"예측된 레이블: {predicted_label}")
else:
print("유효한 문장이 아닙니다.")