-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathminibatch.py
More file actions
61 lines (51 loc) · 2.31 KB
/
minibatch.py
File metadata and controls
61 lines (51 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import torch
import dgl
class NetworkData(torch.utils.data.Dataset):
    """In-memory dataset of DGL graphs loaded from a directory and split by filename prefix.

    Every file in ``fn_dir`` is loaded with ``dgl.load_graphs`` and only the first
    graph stored in each file is kept. Files whose names start with ``'train'`` go
    to the training split, names starting with ``'valid'`` go to the validation
    split, and every other file goes to the test split.

    All graphs live in one flat list, ``self.data``, ordered train, then valid,
    then test. The per-split samplers returned by ``get_sampler`` draw random
    indices from the matching contiguous range of that list.
    """

    def __init__(self, fn_dir):
        """Load every graph file in ``fn_dir`` and build one sampler per split.

        Args:
            fn_dir: Directory containing saved DGL graph files. Paths are built
                with ``os.path.join``, so a trailing separator is optional.
        """
        train = []
        valid = []
        test = []
        # Sort so the dataset order (and thus the index -> graph mapping) does not
        # depend on the arbitrary order in which os.listdir returns entries.
        for fn in sorted(os.listdir(fn_dir)):
            # load_graphs returns (graph_list, label_dict); keep the first graph only.
            g = dgl.load_graphs(os.path.join(fn_dir, fn))[0][0]
            # Optional feature ablations: uncomment a group to zero that feature column.
            # no sinr
            # g.nodes['sta'].data['feat'][:, 20] = 0
            # no airtime
            # g.nodes['ap'].data['feat'][:, 19] = 0
            # no rssi
            # g.edges['sta_ap'].data['feat'][:, 2] = 0
            # g.edges['ap_sta'].data['feat'][:, 2] = 0
            # no location
            # g.nodes['sta'].data['feat'][:, 1] = 0
            # g.nodes['sta'].data['feat'][:, 2] = 0
            # g.nodes['ap'].data['feat'][:, 1] = 0
            # g.nodes['ap'].data['feat'][:, 2] = 0
            if fn.startswith('train'):
                train.append(g)
            elif fn.startswith('valid'):
                valid.append(g)
            else:
                # Anything without a train/valid prefix is treated as test data.
                test.append(g)
        self.data = train + valid + test
        # Split boundaries inside self.data (train | valid | test).
        valid_start = len(train)
        test_start = valid_start + len(valid)
        end = test_start + len(test)
        self.train_sampler = torch.utils.data.sampler.SubsetRandomSampler(torch.arange(valid_start))
        self.valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(torch.arange(valid_start, test_start))
        self.test_sampler = torch.utils.data.sampler.SubsetRandomSampler(torch.arange(test_start, end))

    def __len__(self):
        """Return the total number of graphs across all three splits."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(graph, idx)``.

        The index is returned alongside the graph, presumably so callers can map
        batch entries back to dataset positions (verify against callers).
        """
        return self.data[idx], idx

    def to_cuda(self):
        """Move every stored graph to the ``'cuda'`` device, replacing ``self.data``."""
        self.data = [d.to('cuda') for d in self.data]

    def get_sampler(self):
        """Return the ``(train, valid, test)`` samplers built in ``__init__``."""
        return self.train_sampler, self.valid_sampler, self.test_sampler
def get_dataloader(fn_dir, batchsize, all_cuda=False):
    """Build train/valid/test ``GraphDataLoader`` objects over the graphs in ``fn_dir``.

    All three loaders wrap the same ``NetworkData`` instance. Each one is
    restricted to its split by the matching sampler from
    ``NetworkData.get_sampler``.

    Args:
        fn_dir: Directory of saved graph files, forwarded to ``NetworkData``.
        batchsize: Batch size used by every loader.
        all_cuda: If true, move the whole dataset to the GPU before building
            the loaders.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple.
    """
    print('loading data...')
    network_data = NetworkData(fn_dir)
    if all_cuda:
        network_data.to_cuda()
    # get_sampler yields samplers in (train, valid, test) order, which fixes the
    # order of the returned loaders.
    return tuple(
        dgl.dataloading.GraphDataLoader(network_data, sampler=split_sampler, batch_size=batchsize)
        for split_sampler in network_data.get_sampler()
    )