Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit 75885ed

Browse files
lw authored and facebook-github-bot committed
Fix things that bothered me while testing
Reviewed By: adamlerer
Differential Revision: D17226664
fbshipit-source-id: ee83bf7f498eb5da21f8d0c6f6b4fdc9cad44d62
1 parent f488504 commit 75885ed

9 files changed

Lines changed: 94 additions & 109 deletions

File tree

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ During preprocessing, the entities and relation types had their identifiers conv
139139
```bash
140140
torchbiggraph_export_to_tsv \
141141
torchbiggraph/examples/configs/fb15k_config.py \
142-
--checkpoint model/fb15k \
143142
--entities-output entity_embeddings.tsv \
144143
--relation-types-output relation_types_parameters.tsv
145144
```

torchbiggraph/checkpoint_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ class CheckpointManager:
209209
def __init__(
210210
self,
211211
url: str,
212-
rank: Rank = -1,
212+
rank: Rank = 0,
213213
num_machines: int = 1,
214214
background: bool = False,
215215
partition_client: Optional[PartitionClient] = None,

torchbiggraph/checkpoint_storage.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import os
1212
from abc import ABC, abstractmethod
1313
from pathlib import Path
14-
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type
14+
from typing import Any, Dict, NamedTuple, Optional, Tuple
1515

1616
import h5py
1717
import numpy as np
@@ -411,8 +411,11 @@ def save_config(self, config_json: str) -> None:
411411
tf.write(config_json)
412412

413413
def load_config(self) -> str:
414-
with self.get_config_file().open("rt") as tf:
415-
return tf.read()
414+
try:
415+
with self.get_config_file().open("rt") as tf:
416+
return tf.read()
417+
except FileNotFoundError as err:
418+
raise CouldNotLoadData() from err
416419

417420
def prepare_snapshot(self, version: int, epoch_idx: int) -> None:
418421
self.get_snapshot_path(epoch_idx).mkdir(parents=True, exist_ok=True)

torchbiggraph/converters/export_to_tsv.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def write(outf: TextIO, key: Iterable[str], value: Iterable[float]) -> None:
2727

2828
def make_tsv(
2929
config: ConfigSchema,
30-
checkpoint: str,
3130
entities_tf: TextIO,
3231
relation_types_tf: TextIO,
3332
) -> None:
@@ -39,7 +38,7 @@ def make_tsv(
3938
model = make_model(config)
4039

4140
print("Loading model check point...")
42-
checkpoint_manager = CheckpointManager(checkpoint)
41+
checkpoint_manager = CheckpointManager(config.checkpoint_path)
4342
state_dict, _ = checkpoint_manager.read_model()
4443
if state_dict is not None:
4544
model.load_state_dict(state_dict, strict=False)
@@ -136,7 +135,6 @@ def main():
136135
)
137136
parser.add_argument('config', help="Path to config file")
138137
parser.add_argument('-p', '--param', action='append', nargs='*')
139-
parser.add_argument('--checkpoint')
140138
parser.add_argument('--entities-output', required=True)
141139
parser.add_argument('--relation-types-output', required=True)
142140
opt = parser.parse_args()
@@ -152,7 +150,6 @@ def main():
152150
open(opt.relation_types_output, "xt") as relation_types_tf:
153151
make_tsv(
154152
config,
155-
opt.checkpoint,
156153
entities_tf,
157154
relation_types_tf,
158155
)

torchbiggraph/converters/import_from_tsv.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
# LICENSE.txt file in the root directory of this source tree.
88

99
import argparse
10-
import os
11-
import os.path
1210
import random
1311
from itertools import chain
12+
from pathlib import Path
1413
from typing import Any, Counter, DefaultDict, Dict, List, Optional, Tuple
1514

1615
import torch
@@ -23,6 +22,7 @@
2322
override_config_dict,
2423
)
2524
from torchbiggraph.converters.dictionary import Dictionary
25+
from torchbiggraph.converters.utils import convert_path
2626
from torchbiggraph.edgelist import EdgeList
2727
from torchbiggraph.entitylist import EntityList
2828
from torchbiggraph.graph_storages import (
@@ -37,7 +37,7 @@
3737

3838
def collect_relation_types(
3939
relation_configs: List[RelationSchema],
40-
edge_paths: List[str],
40+
edge_paths: List[Path],
4141
dynamic_relations: bool,
4242
rel_col: Optional[int],
4343
relation_type_min_count: int,
@@ -49,30 +49,29 @@ def collect_relation_types(
4949
print("Looking up relation types in the edge files...")
5050
counter: Counter[str] = Counter()
5151
for edgepath in edge_paths:
52-
with open(edgepath, "rt") as tf:
52+
with edgepath.open("rt") as tf:
5353
for line_num, line in enumerate(tf, start=1):
5454
words = line.split()
5555
try:
5656
rel_word = words[rel_col]
5757
except IndexError:
5858
raise RuntimeError(
59-
"Line %d of %s has only %d words"
60-
% (line_num, edgepath, len(words))) from None
59+
f"Line {line_num} of {edgepath} has only {len(words)} words"
60+
) from None
6161
counter[rel_word] += 1
62-
print("- Found %d relation types" % len(counter))
62+
print(f"- Found {len(counter)} relation types")
6363
if relation_type_min_count > 0:
64-
print("- Removing the ones with fewer than %d occurrences..."
65-
% relation_type_min_count)
64+
print(f"- Removing the ones with fewer than {relation_type_min_count} occurrences...")
6665
counter = Counter({k: c for k, c in counter.items()
6766
if c >= relation_type_min_count})
68-
print("- Left with %d relation types" % len(counter))
67+
print(f"- Left with {len(counter)} relation types")
6968
print("- Shuffling them...")
7069
names = list(counter.keys())
7170
random.shuffle(names)
7271

7372
else:
7473
names = [rconfig.name for rconfig in relation_configs]
75-
print("Using the %d relation types given in the config" % len(names))
74+
print(f"Using the {len(names)} relation types given in the config")
7675

7776
return Dictionary(names)
7877

@@ -81,7 +80,7 @@ def collect_entities_by_type(
8180
relation_types: Dictionary,
8281
entity_configs: Dict[str, EntitySchema],
8382
relation_configs: List[RelationSchema],
84-
edge_paths: List[str],
83+
edge_paths: List[Path],
8584
dynamic_relations: bool,
8685
lhs_col: int,
8786
rhs_col: int,
@@ -95,7 +94,7 @@ def collect_entities_by_type(
9594

9695
print("Searching for the entities in the edge files...")
9796
for edgepath in edge_paths:
98-
with open(edgepath, "rt") as tf:
97+
with edgepath.open("rt") as tf:
9998
for line_num, line in enumerate(tf, start=1):
10099
words = line.split()
101100
try:
@@ -120,14 +119,13 @@ def collect_entities_by_type(
120119

121120
entities_by_type: Dict[str, Dictionary] = {}
122121
for entity_name, counter in counters.items():
123-
print("Entity type %s:" % entity_name)
124-
print("- Found %d entities" % len(counter))
122+
print(f"Entity type {entity_name}:")
123+
print(f"- Found {len(counter)} entities")
125124
if entity_min_count > 0:
126-
print("- Removing the ones with fewer than %d occurrences..."
127-
% entity_min_count)
125+
print(f"- Removing the ones with fewer than {entity_min_count} occurrences...")
128126
counter = Counter({k: c for k, c in counter.items()
129127
if c >= entity_min_count})
130-
print("- Left with %d entities" % len(counter))
128+
print(f"- Left with {len(counter)} entities")
131129
print("- Shuffling them...")
132130
names = list(counter.keys())
133131
random.shuffle(names)
@@ -162,8 +160,8 @@ def generate_entity_path_files(
162160

163161

164162
def generate_edge_path_files(
165-
edge_file_in: str,
166-
edge_path_out: str,
163+
edge_file_in: Path,
164+
edge_path_out: Path,
167165
edge_storage: AbstractEdgeStorage,
168166
entities_by_type: Dict[str, Dictionary],
169167
relation_types: Dictionary,
@@ -189,7 +187,7 @@ def generate_edge_path_files(
189187
processed = 0
190188
skipped = 0
191189

192-
with open(edge_file_in, "rt") as tf:
190+
with edge_file_in.open("rt") as tf:
193191
for line_num, line in enumerate(tf, start=1):
194192
words = line.split()
195193
try:
@@ -256,7 +254,7 @@ def convert_input_data(
256254
entity_configs: Dict[str, EntitySchema],
257255
relation_configs: List[RelationSchema],
258256
entity_path: str,
259-
edge_paths: List[str],
257+
edge_paths: List[Path],
260258
lhs_col: int,
261259
rhs_col: int,
262260
rel_col: Optional[int] = None,
@@ -266,8 +264,8 @@ def convert_input_data(
266264
) -> None:
267265
entity_storage = ENTITY_STORAGES.make_instance(entity_path)
268266
relation_type_storage = RELATION_TYPE_STORAGES.make_instance(entity_path)
269-
edge_paths_out = [os.path.splitext(ep)[0] + "_partitioned" for ep in edge_paths]
270-
edge_storages = [EDGE_STORAGES.make_instance(ep) for ep in edge_paths_out]
267+
edge_paths_out = [convert_path(ep) for ep in edge_paths]
268+
edge_storages = [EDGE_STORAGES.make_instance(str(ep)) for ep in edge_paths_out]
271269

272270
some_files_exists = []
273271
some_files_exists.extend(
@@ -287,7 +285,8 @@ def convert_input_data(
287285
if all(some_files_exists):
288286
print("Found some files that indicate that the input data "
289287
"has already been preprocessed, not doing it again.")
290-
print(f"These files are in {entity_path} and {edge_paths}")
288+
all_paths = ", ".join(str(p) for p in [entity_path] + edge_paths_out)
289+
print(f"These files are in: {all_paths}")
291290
return
292291

293292
relation_types = collect_relation_types(
@@ -371,7 +370,7 @@ def main():
371370
)
372371
parser.add_argument('config', help='Path to config file')
373372
parser.add_argument('-p', '--param', action='append', nargs='*')
374-
parser.add_argument('edge_paths', nargs='*', help='Input file paths')
373+
parser.add_argument('edge_paths', type=Path, nargs='*', help='Input file paths')
375374
parser.add_argument('-l', '--lhs-col', type=int, required=True,
376375
help='Column index for source entity')
377376
parser.add_argument('-r', '--rhs-col', type=int, required=True,

torchbiggraph/converters/utils.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,44 @@
77
# LICENSE.txt file in the root directory of this source tree.
88

99
import gzip
10-
import os
1110
import shutil
1211
import tarfile
13-
import urllib.request
12+
from pathlib import Path
1413
from typing import Callable, Optional
14+
from urllib.parse import urlparse
15+
from urllib.request import urlretrieve
1516

1617
from tqdm import tqdm
1718

1819

19-
def extract_gzip(gzip_path: str, remove_finished: bool = False) -> str:
20-
print('Extracting %s' % gzip_path)
21-
fpath, ext = os.path.splitext(gzip_path)
22-
if ext != ".gz":
20+
def convert_path(fname: Path) -> Path:
21+
return fname.parent / f"{fname.stem}_partitioned"
22+
23+
24+
def extract_gzip(gzip_path: Path, remove_finished: bool = False) -> str:
25+
print(f"Extracting {gzip_path}")
26+
if gzip_path.suffix != ".gz":
2327
raise RuntimeError("Not a gzipped file")
28+
fpath = gzip_path.with_suffix("")
2429

25-
if os.path.exists(fpath):
30+
if fpath.exists():
2631
print("Found a file that indicates that the input data "
2732
"has already been extracted, not doing it again.")
28-
print("This file is: %s" % fpath)
33+
print(f"This file is: {fpath}")
2934
return fpath
3035

31-
with open(fpath, "wb") as out_bf, gzip.GzipFile(gzip_path) as zip_f:
36+
with fpath.open("wb") as out_bf, gzip.GzipFile(gzip_path) as zip_f:
3237
shutil.copyfileobj(zip_f, out_bf)
3338
if remove_finished:
34-
os.unlink(gzip_path)
39+
gzip_path.unlink()
3540

3641
return fpath
3742

3843

39-
def extract_tar(fpath: str) -> None:
44+
def extract_tar(fpath: Path) -> None:
4045
# extract file
41-
root = os.path.dirname(fpath)
4246
with tarfile.open(fpath, "r:gz") as tar:
43-
tar.extractall(path=root)
47+
tar.extractall(path=fpath.parent)
4448

4549

4650
def gen_bar_updater(pbar: tqdm) -> Callable[[int, int, int], None]:
@@ -53,7 +57,7 @@ def bar_update(count: int, block_size: int, total_size: int) -> None:
5357
return bar_update
5458

5559

56-
def download_url(url: str, root: str, filename: Optional[str] = None) -> str:
60+
def download_url(url: str, root: Path, filename: Optional[str] = None) -> str:
5761
"""Download a file from a url and place it in root.
5862
Args:
5963
url (str): URL to download file from
@@ -62,24 +66,24 @@ def download_url(url: str, root: str, filename: Optional[str] = None) -> str:
6266
If None, use the basename of the URL
6367
"""
6468

65-
root = os.path.expanduser(root)
66-
if not filename:
67-
filename = os.path.basename(url)
68-
fpath = os.path.join(root, filename)
69-
if not os.path.exists(root):
70-
os.makedirs(root)
69+
root = root.expanduser()
70+
if filename is None:
71+
filename = Path(urlparse(url).path).name
72+
fpath = root / filename
73+
if not root.exists():
74+
root.mkdir(parents=True, exist_ok=True)
7175

7276
# downloads file
73-
if os.path.isfile(fpath):
74-
print('Using downloaded and verified file: ' + fpath)
77+
if fpath.is_file():
78+
print(f"Using downloaded and verified file: {fpath}")
7579
else:
7680
try:
77-
print('Downloading ' + url + ' to ' + fpath)
78-
urllib.request.urlretrieve(
79-
url, fpath,
81+
print(f"Downloading {url} to {fpath}")
82+
urlretrieve(
83+
url, str(fpath),
8084
reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
8185
)
8286
except OSError:
83-
print('Failed to download from url: ' + url)
87+
print(f"Failed to download from url: {url}")
8488

8589
return fpath

0 commit comments

Comments (0)