-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdataload.py
More file actions
42 lines (37 loc) · 1.31 KB
/
dataload.py
File metadata and controls
42 lines (37 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :dataload.py
@description :
@time :2020/11/27 22:22:31
@author :wizz
@version :1.0
'''
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Sliding-window dataset over a sequence of time-step rows.

    In training mode (``TRAIN=True``) item ``i`` is a pair ``(sample, label)``:
    ``sample`` is the ``time_step`` consecutive rows starting at row ``i``, and
    ``label`` is the following ``pred_step`` rows restricted to the columns in
    ``label_cols``. Consecutive training windows overlap (stride 1).

    In test mode (``TRAIN=False``) item ``i`` is only the sample: the ``i``-th
    non-overlapping block of ``time_step`` rows. Trailing rows that do not fill
    a whole block are dropped.

    Args:
        data: Row-sliceable tensor; presumably shaped (num_steps, num_features)
            -- TODO confirm. Must be a ``torch.Tensor`` in training mode because
            labels are built with ``torch.index_select`` along dim 1.
        time_step: Number of input rows per sample.
        pred_step: Number of future rows per label (training mode only).
        TRAIN: True for (sample, label) training items, False for sample-only
            test items.
        label_cols: Column indices kept as the label. The default ``(2, 4)``
            corresponds, per the original author's note, to CPU_USAGE and
            LAUNCHING_JOB_NUMS.
    """

    def __init__(self, data, time_step=5, pred_step=5, TRAIN=True,
                 label_cols=(2, 4)):
        self.data = data
        self.time_step = time_step
        self.pred_step = pred_step
        self.TRAIN = TRAIN
        self.label_cols = tuple(label_cols)

    def __getitem__(self, index):
        """Return the window at ``index``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                This also lets plain ``for item in dataset`` iteration stop;
                without it, slicing past the end yields empty windows forever.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(
                f"index out of range for dataset of length {length}")

        if self.TRAIN:
            label_start = index + self.time_step
            sample = self.data[index:label_start]
            label = self.data[label_start:label_start + self.pred_step]
            # Keep only the columns to be predicted as the label; build the
            # index on the label's device so non-CPU tensors work too.
            cols = torch.as_tensor(self.label_cols, dtype=torch.long,
                                   device=label.device)
            label = torch.index_select(label, dim=1, index=cols)
            return sample, label

        # Test set: non-overlapping blocks of ``time_step`` rows.
        start = index * self.time_step
        return self.data[start:start + self.time_step]

    def __len__(self):
        """Return the number of complete windows that fit inside ``data``."""
        if self.TRAIN:
            # One window per start row that leaves room for sample + label.
            return max(0, len(self.data) - self.time_step - self.pred_step + 1)
        return len(self.data) // self.time_step