Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit f488504

Browse files
lw authored and facebook-github-bot committed
Make edgelist readers also responsible for writing, and call them storages
Summary: Same thing as the previous commit, but for edgelists. Reviewed By: adamlerer Differential Revision: D17183916 fbshipit-source-id: e00c8707cadba9dea8bcd1e239ae0a5a3d8daeec
1 parent 3edcc3e commit f488504

7 files changed

Lines changed: 427 additions & 179 deletions

File tree

test/test_graph_storages.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# All rights reserved.
5+
#
6+
# This source code is licensed under the BSD-style license found in the
7+
# LICENSE.txt file in the root directory of this source tree.
8+
9+
import tempfile
10+
from unittest import TestCase, main
11+
12+
import h5py
13+
import numpy as np
14+
import torch
15+
from torch_extensions.tensorlist.tensorlist import TensorList
16+
17+
from torchbiggraph.graph_storages import FileEdgeAppender
18+
19+
20+
class TestFileEdgeAppender(TestCase):
    """Exercise FileEdgeAppender's buffered, append-only writes to HDF5.

    Both tests follow the same shape: append several chunks (interleaved
    across two dataset names) through the appender, close it, then reopen
    the file read-only and check each dataset equals the concatenation of
    its chunks.
    """

    def test_tensors(self):
        # Interleaved plain-tensor appends, including an empty chunk and a
        # large (~1M element) chunk to force the buffer to flush.
        chunks = [
            ("foo", torch.tensor([1, 2, 3], dtype=torch.long)),
            ("bar", torch.tensor([10, 11], dtype=torch.long)),
            ("foo", torch.tensor([4], dtype=torch.long)),
            ("foo", torch.tensor([], dtype=torch.long)),
            ("bar", torch.arange(12, 1_000_000, dtype=torch.long)),
            ("foo", torch.tensor([5, 6], dtype=torch.long)),
        ]
        with tempfile.NamedTemporaryFile() as bf:
            with h5py.File(bf.name, "w") as hf, FileEdgeAppender(hf) as appender:
                for name, chunk in chunks:
                    appender.append_tensor(name, chunk)

            with h5py.File(bf.name, "r") as hf:
                np.testing.assert_equal(
                    hf["foo"],
                    np.array([1, 2, 3, 4, 5, 6], dtype=np.int64),
                )
                np.testing.assert_equal(
                    hf["bar"],
                    np.arange(10, 1_000_000, dtype=np.int64),
                )

    def test_tensor_list(self):
        # TensorList appends produce paired "<name>_offsets"/"<name>_data"
        # datasets; later appends must have their offsets shifted so they
        # continue from the end of the existing data.
        with tempfile.NamedTemporaryFile() as bf:
            with h5py.File(bf.name, "w") as hf, FileEdgeAppender(hf) as appender:
                appender.append_tensor_list(
                    "foo",
                    TensorList(
                        torch.tensor([0, 3, 5], dtype=torch.long),
                        torch.tensor([1, 2, 3, 4, 5], dtype=torch.long),
                    ),
                )
                appender.append_tensor_list(
                    "bar",
                    TensorList(
                        torch.tensor([0, 1_000_000], dtype=torch.long),
                        torch.arange(1_000_000, dtype=torch.long),
                    ),
                )
                # Second "foo" append: its offsets [0, 1, 1, 3] get shifted
                # by the 5 elements already written, yielding [6, 6, 8].
                appender.append_tensor_list(
                    "foo",
                    TensorList(
                        torch.tensor([0, 1, 1, 3], dtype=torch.long),
                        torch.tensor([6, 7, 8], dtype=torch.long),
                    ),
                )

            with h5py.File(bf.name, "r") as hf:
                np.testing.assert_equal(
                    hf["foo_offsets"],
                    np.array([0, 3, 5, 6, 6, 8], dtype=np.int64),
                )
                np.testing.assert_equal(
                    hf["foo_data"],
                    np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int64),
                )
                np.testing.assert_equal(
                    hf["bar_offsets"],
                    np.array([0, 1_000_000], dtype=np.int64),
                )
                np.testing.assert_equal(
                    hf["bar_data"],
                    np.arange(1_000_000, dtype=np.int64),
                )
# Allow running this test module directly as a script.
if __name__ == "__main__":
    main()

torchbiggraph/converters/import_from_tsv.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
from itertools import chain
1414
from typing import Any, Counter, DefaultDict, Dict, List, Optional, Tuple
1515

16-
import h5py
17-
import numpy as np
16+
import torch
1817

1918
from torchbiggraph.config import (
2019
ConfigFileLoader,
@@ -24,9 +23,13 @@
2423
override_config_dict,
2524
)
2625
from torchbiggraph.converters.dictionary import Dictionary
26+
from torchbiggraph.edgelist import EdgeList
27+
from torchbiggraph.entitylist import EntityList
2728
from torchbiggraph.graph_storages import (
29+
AbstractEdgeStorage,
2830
AbstractEntityStorage,
2931
AbstractRelationTypeStorage,
32+
EDGE_STORAGES,
3033
ENTITY_STORAGES,
3134
RELATION_TYPE_STORAGES,
3235
)
@@ -160,6 +163,8 @@ def generate_entity_path_files(
160163

161164
def generate_edge_path_files(
162165
edge_file_in: str,
166+
edge_path_out: str,
167+
edge_storage: AbstractEdgeStorage,
163168
entities_by_type: Dict[str, Dictionary],
164169
relation_types: Dictionary,
165170
relation_configs: List[RelationSchema],
@@ -168,21 +173,16 @@ def generate_edge_path_files(
168173
rhs_col: int,
169174
rel_col: Optional[int],
170175
) -> None:
171-
172-
basename, _ = os.path.splitext(edge_file_in)
173-
edge_path_out = basename + '_partitioned'
174-
175-
print("Preparing edge path %s, out of the edges found in %s"
176-
% (edge_path_out, edge_file_in))
177-
os.makedirs(edge_path_out, exist_ok=True)
176+
print(f"Preparing edge path {edge_path_out}, "
177+
f"out of the edges found in {edge_file_in}")
178+
edge_storage.prepare()
178179

179180
num_lhs_parts = max(entities_by_type[rconfig.lhs].num_parts
180181
for rconfig in relation_configs)
181182
num_rhs_parts = max(entities_by_type[rconfig.rhs].num_parts
182183
for rconfig in relation_configs)
183184

184-
print("- Edges will be partitioned in %d x %d buckets."
185-
% (num_lhs_parts, num_rhs_parts))
185+
print(f"- Edges will be partitioned in {num_lhs_parts} x {num_rhs_parts} buckets.")
186186

187187
buckets: DefaultDict[Tuple[int, int], List[Tuple[int, int, int]]] = \
188188
DefaultDict(list)
@@ -198,8 +198,8 @@ def generate_edge_path_files(
198198
rel_word = words[rel_col] if rel_col is not None else None
199199
except IndexError:
200200
raise RuntimeError(
201-
"Line %d of %s has only %d words"
202-
% (line_num, edge_file_in, len(words))) from None
201+
f"Line {line_num} of {edge_file_in} has only {len(words)} words"
202+
) from None
203203

204204
if rel_col is None:
205205
rel_id = 0
@@ -232,26 +232,24 @@ def generate_edge_path_files(
232232

233233
processed = processed + 1
234234
if processed % 100000 == 0:
235-
print("- Processed %d edges so far..." % processed)
235+
print(f"- Processed {processed} edges so far...")
236236

237-
print("- Processed %d edges in total" % processed)
237+
print(f"- Processed {processed} edges in total")
238238
if skipped > 0:
239-
print("- Skipped %d edges because their relation type or entities were "
240-
"unknown (either not given in the config or filtered out as too "
241-
"rare)." % skipped)
239+
print(f"- Skipped {skipped} edges because their relation type or "
240+
f"entities were unknown (either not given in the config or "
241+
f"filtered out as too rare).")
242242

243243
for i in range(num_lhs_parts):
244244
for j in range(num_rhs_parts):
245-
print("- Writing bucket (%d, %d), containing %d edges..."
246-
% (i, j, len(buckets[i, j])))
247-
edges = np.array(buckets[i, j], dtype=np.int64).reshape((-1, 3))
248-
with h5py.File(os.path.join(
249-
edge_path_out, "edges_%d_%d.h5" % (i, j)
250-
), "w") as hf:
251-
hf.attrs["format_version"] = 1
252-
hf.create_dataset("lhs", data=edges[:, 0])
253-
hf.create_dataset("rhs", data=edges[:, 1])
254-
hf.create_dataset("rel", data=edges[:, 2])
245+
print(f"- Writing bucket ({i}, {j}), "
246+
f"containing {len(buckets[i, j])} edges...")
247+
edges = torch.tensor(buckets[i, j], dtype=torch.long).view((-1, 3))
248+
edge_storage.save_edges(i, j, EdgeList(
249+
EntityList.from_tensor(edges[:, 0]),
250+
EntityList.from_tensor(edges[:, 1]),
251+
edges[:, 2],
252+
))
255253

256254

257255
def convert_input_data(
@@ -268,6 +266,8 @@ def convert_input_data(
268266
) -> None:
269267
entity_storage = ENTITY_STORAGES.make_instance(entity_path)
270268
relation_type_storage = RELATION_TYPE_STORAGES.make_instance(entity_path)
269+
edge_paths_out = [os.path.splitext(ep)[0] + "_partitioned" for ep in edge_paths]
270+
edge_storages = [EDGE_STORAGES.make_instance(ep) for ep in edge_paths_out]
271271

272272
some_files_exists = []
273273
some_files_exists.extend(
@@ -282,8 +282,7 @@ def convert_input_data(
282282
some_files_exists.append(relation_type_storage.has_count())
283283
some_files_exists.append(relation_type_storage.has_names())
284284
some_files_exists.extend(
285-
os.path.exists(os.path.join(os.path.splitext(edge_file)[0] + "_partitioned", "edges_0_0.h5"))
286-
for edge_file in edge_paths)
285+
edge_storage.has_edges(0, 0) for edge_storage in edge_storages)
287286

288287
if all(some_files_exists):
289288
print("Found some files that indicate that the input data "
@@ -319,9 +318,12 @@ def convert_input_data(
319318
dynamic_relations,
320319
)
321320

322-
for edge_path in edge_paths:
321+
for edge_path, edge_path_out, edge_storage \
322+
in zip(edge_paths, edge_paths_out, edge_storages):
323323
generate_edge_path_files(
324324
edge_path,
325+
edge_path_out,
326+
edge_storage,
325327
entities_by_type,
326328
relation_types,
327329
relation_configs,

torchbiggraph/edgelist_reader.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

0 commit comments

Comments
 (0)