Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit 50c9038

Browse files
lwfacebook-github-bot
authored and committed
Fix to actually be able to use non-file edge paths in import_from_tsv
Summary: In `import_from_tsv` we constructed the output edge paths by converting the paths passed as command-line args (stripping the extension and appending `_partitioned`). Thus, the output edge paths are always file-based. In order to be able to use different schemas we need to be able to separately specify the output edge paths: it makes perfect sense to use the config file for this (we are already using it for the entity path, and may one day use it for the initial data). The input edge paths (from the command line) and the output ones (from the config) are then matched (i.e., zipped). Thus order now matters. Therefore in the README we explicitly list them, rather than using a glob wildcard. Reviewed By: adamlerer Differential Revision: D17502000 fbshipit-source-id: 9cd2e534ab70c600a35e28c4dae57a6015c2bc1e
1 parent f6b309c commit 50c9038

7 files changed

Lines changed: 70 additions & 59 deletions

File tree

README.md

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,9 @@ Luckily, there is a command that does all of this:
7676
torchbiggraph_import_from_tsv \
7777
--lhs-col=0 --rel-col=1 --rhs-col=2 \
7878
torchbiggraph/examples/configs/fb15k_config.py \
79-
data/FB15k/freebase_mtr100_mte100-*.txt
79+
data/FB15k/freebase_mtr100_mte100-train.txt \
80+
data/FB15k/freebase_mtr100_mte100-valid.txt \
81+
data/FB15k/freebase_mtr100_mte100-test.txt
8082
```
8183
The outputs will be stored next to the inputs in the `data/FB15k` directory.
8284

torchbiggraph/converters/import_from_tsv.py

Lines changed: 25 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
override_config_dict,
2323
)
2424
from torchbiggraph.converters.dictionary import Dictionary
25-
from torchbiggraph.converters.utils import convert_path
2625
from torchbiggraph.edgelist import EdgeList
2726
from torchbiggraph.entitylist import EntityList
2827
from torchbiggraph.graph_storages import (
@@ -254,18 +253,23 @@ def convert_input_data(
254253
entity_configs: Dict[str, EntitySchema],
255254
relation_configs: List[RelationSchema],
256255
entity_path: str,
257-
edge_paths: List[Path],
256+
edge_paths_out: List[str],
257+
edge_paths_in: List[Path],
258258
lhs_col: int,
259259
rhs_col: int,
260260
rel_col: Optional[int] = None,
261261
entity_min_count: int = 1,
262262
relation_type_min_count: int = 1,
263263
dynamic_relations: bool = False,
264264
) -> None:
265+
if len(edge_paths_in) != len(edge_paths_out):
266+
raise ValueError(
267+
f"The edge paths passed as inputs ({edge_paths_in}) don't match "
268+
f"the ones specified as outputs ({edge_paths_out})")
269+
265270
entity_storage = ENTITY_STORAGES.make_instance(entity_path)
266271
relation_type_storage = RELATION_TYPE_STORAGES.make_instance(entity_path)
267-
edge_paths_out = [convert_path(ep) for ep in edge_paths]
268-
edge_storages = [EDGE_STORAGES.make_instance(str(ep)) for ep in edge_paths_out]
272+
edge_storages = [EDGE_STORAGES.make_instance(ep) for ep in edge_paths_out]
269273

270274
some_files_exists = []
271275
some_files_exists.extend(
@@ -291,7 +295,7 @@ def convert_input_data(
291295

292296
relation_types = collect_relation_types(
293297
relation_configs,
294-
edge_paths,
298+
edge_paths_in,
295299
dynamic_relations,
296300
rel_col,
297301
relation_type_min_count,
@@ -301,7 +305,7 @@ def convert_input_data(
301305
relation_types,
302306
entity_configs,
303307
relation_configs,
304-
edge_paths,
308+
edge_paths_in,
305309
dynamic_relations,
306310
lhs_col,
307311
rhs_col,
@@ -317,10 +321,10 @@ def convert_input_data(
317321
dynamic_relations,
318322
)
319323

320-
for edge_path, edge_path_out, edge_storage \
321-
in zip(edge_paths, edge_paths_out, edge_storages):
324+
for edge_path_in, edge_path_out, edge_storage \
325+
in zip(edge_paths_in, edge_paths_out, edge_storages):
322326
generate_edge_path_files(
323-
edge_path,
327+
edge_path_in,
324328
edge_path_out,
325329
edge_storage,
326330
entities_by_type,
@@ -339,6 +343,7 @@ def parse_config_partial(
339343
entities_config = config_dict.get("entities")
340344
relations_config = config_dict.get("relations")
341345
entity_path = config_dict.get("entity_path")
346+
edge_paths = config_dict.get("edge_paths")
342347
dynamic_relations = config_dict.get("dynamic_relations", False)
343348
if not isinstance(entities_config, dict):
344349
raise TypeError("Config entities is not of type dict")
@@ -348,6 +353,10 @@ def parse_config_partial(
348353
raise TypeError("Config relations is not of type list")
349354
if not isinstance(entity_path, str):
350355
raise TypeError("Config entity_path is not of type str")
356+
if not isinstance(edge_paths, list):
357+
raise TypeError("Config edge_paths is not of type list")
358+
if any(not isinstance(p, str) for p in edge_paths):
359+
raise TypeError("Config edge_paths has some items that are not of type str")
351360
if not isinstance(dynamic_relations, bool):
352361
raise TypeError("Config dynamic_relations is not of type bool")
353362

@@ -358,7 +367,7 @@ def parse_config_partial(
358367
for relation in relations_config:
359368
relations.append(RelationSchema.from_dict(relation))
360369

361-
return entities, relations, entity_path, dynamic_relations
370+
return entities, relations, entity_path, edge_paths, dynamic_relations
362371

363372

364373
def main():
@@ -390,13 +399,18 @@ def main():
390399
overrides = chain.from_iterable(opt.param) # flatten
391400
config_dict = override_config_dict(config_dict, overrides)
392401

393-
entity_configs, relation_configs, entity_path, dynamic_relations = \
402+
entity_configs, relation_configs, entity_path, edge_paths, dynamic_relations = \
394403
parse_config_partial(config_dict)
395404

405+
if len(opt.edge_paths) != len(edge_paths):
406+
print(f"The edge paths provided on the command line ({opt.edge_paths}) "
407+
f"don't match the ones found in the config file ({edge_paths})")
408+
396409
convert_input_data(
397410
entity_configs,
398411
relation_configs,
399412
entity_path,
413+
edge_paths,
400414
opt.edge_paths,
401415
opt.lhs_col,
402416
opt.rhs_col,

torchbiggraph/converters/utils.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,6 @@
1717
from tqdm import tqdm
1818

1919

20-
def convert_path(fname: Path) -> Path:
21-
return fname.parent / f"{fname.stem}_partitioned"
22-
23-
2420
def extract_gzip(gzip_path: Path, remove_finished: bool = False) -> str:
2521
print(f"Extracting {gzip_path}")
2622
if gzip_path.suffix != ".gz":

torchbiggraph/examples/configs/fb15k_config.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,16 +6,18 @@
66
# This source code is licensed under the BSD-style license found in the
77
# LICENSE.txt file in the root directory of this source tree.
88

9-
entity_base = "data/FB15k"
10-
119

1210
def get_torchbiggraph_config():
1311

1412
config = dict(
1513
# I/O data
16-
entity_path=entity_base,
17-
edge_paths=[],
18-
checkpoint_path='model/fb15k',
14+
entity_path="data/FB15k",
15+
edge_paths=[
16+
"data/FB15k/freebase_mtr100_mte100-train_partitioned",
17+
"data/FB15k/freebase_mtr100_mte100-valid_partitioned",
18+
"data/FB15k/freebase_mtr100_mte100-test_partitioned",
19+
],
20+
checkpoint_path="model/fb15k",
1921

2022
# Graph structure
2123
entities={

torchbiggraph/examples/configs/livejournal_config.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,16 +6,17 @@
66
# This source code is licensed under the BSD-style license found in the
77
# LICENSE.txt file in the root directory of this source tree.
88

9-
entities_base = 'data/livejournal'
10-
119

1210
def get_torchbiggraph_config():
1311

1412
config = dict(
1513
# I/O data
16-
entity_path=entities_base,
17-
edge_paths=[],
18-
checkpoint_path='model/livejournal',
14+
entity_path="data/livejournal",
15+
edge_paths=[
16+
"data/train_partitioned",
17+
"data/test_partitioned",
18+
],
19+
checkpoint_path="model/livejournal",
1920

2021
# Graph structure
2122
entities={

torchbiggraph/examples/fb15k.py

Lines changed: 14 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
import attr
1414
import pkg_resources
1515

16-
from torchbiggraph.converters.utils import convert_path, download_url, extract_tar
16+
from torchbiggraph.converters.utils import download_url, extract_tar
1717
from torchbiggraph.config import add_to_sys_path, ConfigFileLoader
1818
from torchbiggraph.converters.import_from_tsv import convert_input_data
1919
from torchbiggraph.eval import do_eval
@@ -27,11 +27,11 @@
2727

2828

2929
FB15K_URL = 'https://dl.fbaipublicfiles.com/starspace/fb15k.tgz'
30-
FILENAMES = {
31-
'train': 'FB15k/freebase_mtr100_mte100-train.txt',
32-
'valid': 'FB15k/freebase_mtr100_mte100-valid.txt',
33-
'test': 'FB15k/freebase_mtr100_mte100-test.txt',
34-
}
30+
FILENAMES = [
31+
"FB15k/freebase_mtr100_mte100-train.txt",
32+
"FB15k/freebase_mtr100_mte100-valid.txt",
33+
"FB15k/freebase_mtr100_mte100-test.txt",
34+
]
3535

3636
# Figure out the path where the sample config was installed by the package manager.
3737
# This can be overridden with --config.
@@ -68,33 +68,29 @@ def main():
6868
subprocess_init = SubprocessInitializer()
6969
subprocess_init.register(setup_logging, config.verbose)
7070
subprocess_init.register(add_to_sys_path, loader.config_dir.name)
71-
edge_paths = [data_dir / name for name in FILENAMES.values()]
71+
input_edge_paths = [data_dir / name for name in FILENAMES]
72+
output_train_path, output_valid_path, output_test_path = config.edge_paths
7273

7374
convert_input_data(
7475
config.entities,
7576
config.relations,
7677
config.entity_path,
77-
edge_paths,
78+
config.edge_paths,
79+
input_edge_paths,
7880
lhs_col=0,
7981
rhs_col=2,
8082
rel_col=1,
8183
dynamic_relations=config.dynamic_relations,
8284
)
8385

84-
train_path = [str(convert_path(data_dir / FILENAMES['train']))]
85-
train_config = attr.evolve(config, edge_paths=train_path)
86-
86+
train_config = attr.evolve(config, edge_paths=[output_train_path])
8787
train(train_config, subprocess_init=subprocess_init)
8888

89-
eval_path = [str(convert_path(data_dir / FILENAMES['test']))]
9089
relations = [attr.evolve(r, all_negs=True) for r in config.relations]
91-
eval_config = attr.evolve(config, edge_paths=eval_path, relations=relations, num_uniform_negs=0)
90+
eval_config = attr.evolve(
91+
config, edge_paths=[output_test_path], relations=relations, num_uniform_negs=0)
9292
if args.filtered:
93-
filter_paths = [
94-
str(convert_path(data_dir / FILENAMES['test'])),
95-
str(convert_path(data_dir / FILENAMES['valid'])),
96-
str(convert_path(data_dir / FILENAMES['train'])),
97-
]
93+
filter_paths = [output_test_path, output_valid_path, output_train_path]
9894
do_eval(
9995
eval_config,
10096
evaluator=FilteredRankingEvaluator(eval_config, filter_paths),

torchbiggraph/examples/livejournal.py

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
from torchbiggraph.config import add_to_sys_path, ConfigFileLoader
1818
from torchbiggraph.converters.import_from_tsv import convert_input_data
19-
from torchbiggraph.converters.utils import convert_path, download_url, extract_gzip
19+
from torchbiggraph.converters.utils import download_url, extract_gzip
2020
from torchbiggraph.eval import do_eval
2121
from torchbiggraph.train import train
2222
from torchbiggraph.util import (
@@ -27,10 +27,12 @@
2727

2828

2929
URL = 'https://snap.stanford.edu/data/soc-LiveJournal1.txt.gz'
30-
FILENAMES = {
31-
'train': 'train.txt',
32-
'test': 'test.txt',
33-
}
30+
TRAIN_FILENAME = "train.txt"
31+
TEST_FILENAME = "test.txt"
32+
FILENAMES = [
33+
TRAIN_FILENAME,
34+
TEST_FILENAME,
35+
]
3436
TRAIN_FRACTION = 0.75
3537

3638
# Figure out the path where the sample config was installed by the package manager.
@@ -40,8 +42,8 @@
4042

4143

4244
def random_split_file(fpath: Path) -> None:
43-
train_file = fpath.parent / FILENAMES['train']
44-
test_file = fpath.parent / FILENAMES['test']
45+
train_file = fpath.parent / TRAIN_FILENAME
46+
test_file = fpath.parent / TEST_FILENAME
4547

4648
if train_file.exists() and test_file.exists():
4749
print("Found some files that indicate that the input data "
@@ -103,27 +105,25 @@ def main():
103105
subprocess_init = SubprocessInitializer()
104106
subprocess_init.register(setup_logging, config.verbose)
105107
subprocess_init.register(add_to_sys_path, loader.config_dir.name)
106-
edge_paths = [data_dir / name for name in FILENAMES.values()]
108+
input_edge_paths = [data_dir / name for name in FILENAMES]
109+
output_train_path, output_test_path = config.edge_paths
107110

108111
convert_input_data(
109112
config.entities,
110113
config.relations,
111114
config.entity_path,
112-
edge_paths,
115+
config.edge_paths,
116+
input_edge_paths,
113117
lhs_col=0,
114118
rhs_col=1,
115119
rel_col=None,
116120
dynamic_relations=config.dynamic_relations,
117121
)
118122

119-
train_path = [str(convert_path(data_dir / FILENAMES['train']))]
120-
train_config = attr.evolve(config, edge_paths=train_path)
121-
123+
train_config = attr.evolve(config, edge_paths=[output_train_path])
122124
train(train_config, subprocess_init=subprocess_init)
123125

124-
eval_path = [str(convert_path(data_dir / FILENAMES['test']))]
125-
eval_config = attr.evolve(config, edge_paths=eval_path)
126-
126+
eval_config = attr.evolve(config, edge_paths=[output_test_path])
127127
do_eval(eval_config, subprocess_init=subprocess_init)
128128

129129

0 commit comments

Comments (0)