
Commit 9c9e809

lw authored and facebook-github-bot committed
Auto-tune num_edge_chunks
Summary: One reason we allow chunking edge lists is to be able to load, piece by piece, edge lists that are too big to fit in memory all at once. (There are other reasons too, for example more frequent mixing of edges from different buckets.) The former goal can be achieved automatically: given a maximum size that an edge list may occupy in memory, PBG can figure out the smallest number of edge chunks that achieves it.

Reviewed By: adamlerer

Differential Revision: D17571778

fbshipit-source-id: e4977078c35f4bdb212ad163acf137eb94d33994
1 parent 7f98961 commit 9c9e809
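
To make the auto-tuning concrete, here is a minimal runnable sketch of the arithmetic involved (the 3.2-billion-edge bucket is a made-up figure; the actual logic lives in the new get_num_edge_chunks in torchbiggraph/train.py, shown in the diff below):

    import math

    # Assumed example: the largest bucket holds 3.2 billion edges, and the new
    # max_edges_per_chunk option keeps its default of 1 billion edges.
    max_edges_per_bucket = 3_200_000_000
    max_edges_per_chunk = 1_000_000_000

    # Smallest number of equally-sized chunks such that none exceeds the cap.
    num_edge_chunks = max(1, math.ceil(max_edges_per_bucket / max_edges_per_chunk))
    print(num_edge_chunks)  # 4, i.e. chunks of about 800 million edges each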

4 files changed: 66 additions & 8 deletions


docs/source/configuration_file.rst (6 additions & 2 deletions)

@@ -139,9 +139,13 @@ See :ref:`batch-preparation` for more details.
 
     The number of times the training loop iterates over all the edges.
 
-- ``num_edge_chunks`` (type: integer; default: ``1``)
+- ``num_edge_chunks`` (type: integer or null; default: ``null``)
 
-    The number of equally-sized parts each bucket will be split into. Training will first proceed over all the first chunks of all buckets, then over all the second chunks, and so on. A higher value allows better mixing of partitions, at the cost of more time spent on I/O.
+    The number of equally-sized parts each bucket will be split into. Training will first proceed over all the first chunks of all buckets, then over all the second chunks, and so on. A higher value allows better mixing of partitions, at the cost of more time spent on I/O. If unset, will be automatically calculated so that no chunk has more than max_edges_per_chunk edges.
+
+- ``max_edges_per_chunk`` (type: integer, default: ``1000000000``)
+
+    The maximum number of edges that each edge chunk should contain if the number of edge chunks is left unspecified and has to be automatically figured out. Each edge takes up at least 12 bytes (3 int64s), more if using featurized entities.
 
 - ``bucket_order`` (type: string, either ``"random"``, ``"affinity"``, ``"inside_out"`` or ``"outside_in"``; default: ``"inside_out"``)
 
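For illustration, a hypothetical config fragment exercising the new option could look as follows (a sketch only: the entity, relation, and path values are placeholders, and only num_edge_chunks and max_edges_per_chunk relate to this commit):

    def get_torchbiggraph_config():
        return dict(
            entity_path="data/example",  # placeholder paths
            edge_paths=["data/example/edges"],
            checkpoint_path="model/example",
            entities={"all": {"num_partitions": 4}},
            relations=[
                {"name": "all_edges", "lhs": "all", "rhs": "all", "operator": "none"},
            ],
            dimension=400,
            num_epochs=10,
            # num_edge_chunks is deliberately left unset (null): PBG will pick
            # the smallest chunk count so that no chunk exceeds this cap.
            max_edges_per_chunk=500_000_000,
        )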

torchbiggraph/config.py (15 additions & 4 deletions)

@@ -222,15 +222,26 @@ class ConfigSchema(Schema):
         metadata={'help': "The number of times the training loop iterates over "
                           "all the edges."},
     )
-    num_edge_chunks: int = attr.ib(
-        default=1,
-        validator=positive,
+    num_edge_chunks: Optional[int] = attr.ib(
+        default=None,
+        validator=optional(positive),
         metadata={'help': "The number of equally-sized parts each bucket will "
                           "be split into. Training will first proceed over all "
                           "the first chunks of all buckets, then over all the "
                           "second chunks, and so on. A higher value allows "
                           "better mixing of partitions, at the cost of more "
-                          "time spent on I/O."},
+                          "time spent on I/O. If unset, will be automatically "
+                          "calculated so that no chunk has more than "
+                          "max_edges_per_chunk edges."},
+    )
+    max_edges_per_chunk: int = attr.ib(
+        default=1_000_000_000,  # Each edge having 3 int64s, this is 12GB.
+        validator=positive,
+        metadata={'help': "The maximum number of edges that each edge chunk "
+                          "should contain if the number of edge chunks is left "
+                          "unspecified and has to be automatically figured "
+                          "out. Each edge takes up at least 12 bytes (3 "
+                          "int64s), more if using featurized entities."},
     )
     bucket_order: BucketOrder = attr.ib(
         default=BucketOrder.INSIDE_OUT,
torchbiggraph/graph_storages.py (13 additions & 0 deletions)

@@ -122,6 +122,10 @@ def has_edges(self, lhs_p: int, rhs_p: int) -> bool:
     def load_edges(self, lhs_p: int, rhs_p: int) -> EdgeList:
         return self.load_chunk_of_edges(lhs_p, rhs_p, chunk_idx=0, num_chunks=1)
 
+    @abstractmethod
+    def get_number_of_edges(self, lhs_p: int, rhs_p: int) -> int:
+        pass
+
     @abstractmethod
     def load_chunk_of_edges(
         self,
@@ -388,6 +392,15 @@ def has_edges(
     ) -> bool:
         return self.get_edges_file(lhs_p, rhs_p).is_file()
 
+    def get_number_of_edges(self, lhs_p: int, rhs_p: int) -> int:
+        file_path = self.get_edges_file(lhs_p, rhs_p)
+        if not file_path.is_file():
+            raise RuntimeError(f"{file_path} does not exist")
+        with h5py.File(file_path, "r") as hf:
+            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
+                raise RuntimeError(f"Version mismatch in edge file {file_path}")
+            return hf["rel"].len()
+
     def load_chunk_of_edges(
         self,
         lhs_p: int,
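
Note that get_number_of_edges only inspects the length of the "rel" dataset, so it never loads a bucket's edges into memory: HDF5 keeps dataset shapes in the file metadata. A hypothetical standalone equivalent (the file name is an assumed example of PBG's edges_<lhs>_<rhs>.h5 layout, and this sketch skips the format-version check):

    import h5py

    def count_edges(path: str) -> int:
        # Reading a dataset's shape touches only HDF5 metadata, not edge data.
        with h5py.File(path, "r") as hf:
            return hf["rel"].shape[0]

    print(count_edges("data/example/edges_0_0.h5"))  # assumed file name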

torchbiggraph/train.py (32 additions & 2 deletions)

@@ -8,6 +8,7 @@
 
 import argparse
 import logging
+import math
 import time
 from abc import ABC, abstractmethod
 from functools import partial
@@ -283,6 +284,30 @@ def should_preserve_old_checkpoint(
     return is_checkpoint_epoch and is_first_edge_path and is_first_edge_chunk
 
 
+def get_num_edge_chunks(
+    edge_paths: List[str],
+    nparts_lhs: int,
+    nparts_rhs: int,
+    max_edges_per_chunk: int,
+) -> int:
+    max_edges_per_bucket = 0
+    # We should check all edge paths, all lhs partitions and all rhs partitions,
+    # but the combinatorial explosion could lead to thousands of checks. Let's
+    # assume that edges are uniformly distributed among buckets (this is not
+    # exactly the case, as it's the entities that are uniformly distributed
+    # among the partitions, and edge assignments to buckets are a function of
+    # that, thus, for example, very high degree entities could skew this), and
+    # use the size of bucket (0, 0) as an estimate of the average bucket size.
+    # We still do it for all edge paths as there could be semantic differences
+    # between them which lead to different sizes.
+    for edge_path in edge_paths:
+        edge_storage = EDGE_STORAGES.make_instance(edge_path)
+        max_edges_per_bucket = max(
+            max_edges_per_bucket,
+            edge_storage.get_number_of_edges(0, 0))
+    return max(1, math.ceil(max_edges_per_bucket / max_edges_per_chunk))
+
+
 def train_and_report_stats(
     config: ConfigSchema,
     model: Optional[MultiRelationEmbedder] = None,
@@ -446,8 +471,13 @@ def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimizer:
     checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
     checkpoint_manager.write_config(config)
 
+    if config.num_edge_chunks is not None:
+        num_edge_chunks = config.num_edge_chunks
+    else:
+        num_edge_chunks = get_num_edge_chunks(
+            config.edge_paths, nparts_lhs, nparts_rhs, config.max_edges_per_chunk)
     iteration_manager = IterationManager(
-        config.num_epochs, config.edge_paths, config.num_edge_chunks,
+        config.num_epochs, config.edge_paths, num_edge_chunks,
         iteration_idx=checkpoint_manager.checkpoint_version)
     checkpoint_manager.register_metadata_provider(iteration_manager)
 
@@ -680,7 +710,7 @@ def swap_partitioned_embeddings(
 
     bucket_logger.debug("Loading edges")
     edges = edge_storage.load_chunk_of_edges(
-        cur_b.lhs, cur_b.rhs, edge_chunk_idx, config.num_edge_chunks)
+        cur_b.lhs, cur_b.rhs, edge_chunk_idx, iteration_manager.num_edge_chunks)
     num_edges = len(edges)
     # this might be off in the case of tensorlist or extra edge fields
     io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
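
A usage note on the heuristic above: regardless of how many partitions there are, get_num_edge_chunks opens only bucket (0, 0) of each edge path, so auto-tuning costs one metadata read per path. A hedged example call (paths, partition counts, and bucket size are invented):

    # If the larger of the two (0, 0) buckets held 2.5 billion edges, this
    # would return ceil(2.5e9 / 1e9) = 3.
    num_edge_chunks = get_num_edge_chunks(
        edge_paths=["data/train_2019", "data/train_2020"],  # invented paths
        nparts_lhs=16,
        nparts_rhs=16,
        max_edges_per_chunk=1_000_000_000,
    )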
