Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit 3edcc3e

Browse files
lw authored and facebook-github-bot committed
Make entity count readers also responsible for writing, and call them storages
Summary: Also make them handle the dynamic relation count. By enclosing all the I/O code that actually deals with files inside a single class, it is now enough to register a new implementation of that class to transparently use a new storage backend across all of PBG, including import and export! Reviewed By: adamlerer Differential Revision: D17183917 fbshipit-source-id: ae0a1bebeb93b8868acea8827a5ef139b60703ea
1 parent f4a819b commit 3edcc3e

12 files changed

Lines changed: 309 additions & 158 deletions

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ During preprocessing, the entities and relation types had their identifiers conv
139139
```bash
140140
torchbiggraph_export_to_tsv \
141141
torchbiggraph/examples/configs/fb15k_config.py \
142-
--dict data/FB15k/dictionary.json \
143142
--checkpoint model/fb15k \
144143
--entities-output entity_embeddings.tsv \
145144
--relation-types-output relation_types_parameters.tsv

docs/source/downstream_tasks.rst

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,20 @@ Reading the HDF5 format
3030
Suppose that you have completed the training of the ``torchbiggraph_example_fb15k`` command and want to now
3131
look up the embedding of some entity. For that, we'll need to read:
3232

33-
- the embeddings, from the checkpoint files (the :file:`.h5` files in the `model/fb15k` directory, or
33+
- the embeddings, from the checkpoint files (the :file:`.h5` files in the :file:`model/fb15k` directory, or
3434
whatever directory was specified as the ``checkpoint_path``); and
35-
- the mapping from entity names to their partitions and offsets, from the :file:`data/FB15k/dictionary.json`
36-
file created by the ``torchbiggraph_import_from_tsv`` command.
35+
- the names of the entities of a certain type and partition (ordered by their offset), from the files in the
36+
:file:`data/FB15k` directory (or an alternative directory given as the ``entity_path``), created by the
37+
``torchbiggraph_import_from_tsv`` command.
3738

3839
The embedding of, say, entity ``/m/05hf_5`` can be found as follows::
3940

4041
import json
4142
import h5py
4243

43-
with open("data/FB15k/dictionary.json", "rt") as tf:
44-
dictionary = json.load(tf)
45-
offset = dictionary["entities"]["all"].index("/m/05hf_5")
44+
with open("data/FB15k/entity_names_all_0.json", "rt") as tf:
45+
names = json.load(tf)
46+
offset = names.index("/m/05hf_5")
4647

4748
with h5py.File("model/fb15k/embeddings_all_0.v50.h5", "r") as hf:
4849
embedding = hf["embeddings"][offset, :]
@@ -162,12 +163,16 @@ being the capital of France::
162163
operator.load_state_dict(operator_state_dict)
163164
comparator = DotComparator()
164165

165-
# Load the offsets of the entities and the index of the relation type
166-
with open("data/FB15k/dictionary.json", "rt") as tf:
167-
dictionary = json.load(tf)
168-
src_entity_offset = dictionary["entities"]["all"].index("/m/0f8l9c") # France
169-
dest_entity_offset = dictionary["entities"]["all"].index("/m/05qtj") # Paris
170-
rel_type_index = dictionary["relations"].index("/location/country/capital")
166+
# Load the names of the entities, ordered by offset.
167+
with open("data/FB15k/entity_names_all_0.json", "rt") as tf:
168+
entity_names = json.load(tf)
169+
src_entity_offset = entity_names.index("/m/0f8l9c") # France
170+
dest_entity_offset = entity_names.index("/m/05qtj") # Paris
171+
172+
# Load the names of the relation types, ordered by index.
173+
with open("data/FB15k/dynamic_rel_names.json", "rt") as tf:
174+
rel_type_names = json.load(tf)
175+
rel_type_index = rel_type_names.index("/location/country/capital")
171176

172177
# Load the trained embeddings
173178
with h5py.File("model/fb15k/embeddings_all_0.v50.h5", "r") as hf:
@@ -220,10 +225,12 @@ entities are most likely to be the capital of France::
220225
comparator = DotComparator()
221226

222227
# Load the offsets of the entities and the index of the relation type
223-
with open("data/FB15k/dictionary.json", "rt") as tf:
224-
dictionary = json.load(tf)
225-
src_entity_offset = dictionary["entities"]["all"].index("/m/0f8l9c") # France
226-
rel_type_index = dictionary["relations"].index("/location/country/capital")
228+
with open("data/FB15k/entity_names_all_0.json", "rt") as tf:
229+
entity_names = json.load(tf)
230+
src_entity_offset = entity_names.index("/m/0f8l9c") # France
231+
with open("data/FB15k/dynamic_rel_names.json", "rt") as tf:
232+
rel_type_names = json.load(tf)
233+
rel_type_index = rel_type_names.index("/location/country/capital")
227234

228235
# Load the trained embeddings
229236
with h5py.File("model/fb15k/embeddings_all_0.v50.h5", "r") as hf:
@@ -245,7 +252,7 @@ entities are most likely to be the capital of France::
245252

246253
# Sort the entities by their score
247254
permutation = scores.flatten().argsort(descending=True)
248-
top5_entities = [dictionary["entities"]["all"][index] for index in permutation[:5]]
255+
top5_entities = [entity_names[index] for index in permutation[:5]]
249256

250257
print(top5_entities)
251258

@@ -271,17 +278,17 @@ library. The following code looks for the entities that are closest to Paris::
271278
index.add(hf["embeddings"][...])
272279

273280
# Get trained embedding of Paris
274-
with open("data/FB15k/dictionary.json", "rt") as f:
275-
dictionary = json.load(f)
276-
target_entity_offset = dictionary["entities"]["all"].index("/m/05qtj") # Paris
281+
with open("data/FB15k/entity_names_all_0.json", "rt") as tf:
282+
entity_names = json.load(tf)
283+
target_entity_offset = entity_names.index("/m/05qtj") # Paris
277284
with h5py.File("model/fb15k/embeddings_all_0.v50.h5", "r") as hf:
278285
target_embedding = hf["embeddings"][target_entity_offset, :]
279286

280287
# Search nearest neighbors
281288
_, neighbors = index.search(target_embedding.reshape((1, 400)), 5)
282289

283290
# Map back to entity names
284-
top5_entities = [dictionary["entities"]["all"][index] for index in neighbors[0]]
291+
top5_entities = [entity_names[index] for index in neighbors[0]]
285292

286293
print(top5_entities)
287294

torchbiggraph/checkpoint_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from torchbiggraph.checkpoint_storage import (
2424
AbstractCheckpointStorage,
2525
CHECKPOINT_STORAGES,
26-
CouldNotLoadData,
2726
ModelParameter,
2827
)
2928
from torchbiggraph.config import ConfigSchema
@@ -37,7 +36,7 @@
3736
Partition,
3837
Rank,
3938
)
40-
from torchbiggraph.util import create_pool, get_async_result
39+
from torchbiggraph.util import CouldNotLoadData, create_pool, get_async_result
4140

4241

4342
logger = logging.getLogger("torchbiggraph")

torchbiggraph/checkpoint_storage.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,12 @@
2424
ModuleStateDict,
2525
Partition,
2626
)
27+
from torchbiggraph.util import CouldNotLoadData
2728

2829

2930
logger = logging.getLogger("torchbiggraph")
3031

3132

32-
class CouldNotLoadData(Exception):
33-
pass
34-
35-
3633
class ModelParameter(NamedTuple):
3734
# This is the "internal" name, the one of the model's state dict, which is
3835
# considered an implementation detail. Thus the parameters are stored under

torchbiggraph/converters/dictionary.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# This source code is licensed under the BSD-style license found in the
77
# LICENSE.txt file in the root directory of this source tree.
88

9+
import math
910
from typing import Dict, List, Tuple
1011

1112

@@ -27,18 +28,27 @@ def get_id(self, word: str) -> int:
2728
def size(self) -> int:
2829
return len(self.ix_to_word)
2930

31+
def get_list(self) -> List[str]:
32+
return self.ix_to_word
33+
34+
def part_start(self, part: int) -> int:
35+
return math.ceil(part / self.num_parts * self.size())
36+
37+
def part_end(self, part: int) -> int:
38+
return self.part_start(part + 1)
39+
3040
def part_size(self, part: int) -> int:
3141
if not 0 <= part < self.num_parts:
32-
raise ValueError("%d not in [0, %d)" % (part, self.num_parts))
33-
part_begin = (part * self.size() - 1) // self.num_parts + 1
34-
part_end = ((part + 1) * self.size() - 1) // self.num_parts
35-
return part_end - part_begin + 1
42+
raise ValueError(f"{part} not in [0, {self.num_parts})")
43+
return self.part_end(part) - self.part_start(part)
3644

3745
def get_partition(self, word: str) -> Tuple[int, int]:
3846
idx = self.get_id(word)
39-
part = idx * self.num_parts // self.size()
40-
part_begin = (part * self.size() - 1) // self.num_parts + 1
41-
return part, idx - part_begin
47+
part = math.floor(idx / self.size() * self.num_parts)
48+
assert self.part_start(part) <= idx < self.part_end(part)
49+
return part, idx - self.part_start(part)
4250

43-
def get_list(self) -> List[str]:
44-
return self.ix_to_word
51+
def get_part_list(self, part: int) -> List[str]:
52+
if not 0 <= part < self.num_parts:
53+
raise ValueError(f"{part} not in [0, {self.num_parts})")
54+
return self.ix_to_word[self.part_start(part):self.part_end(part)]

torchbiggraph/converters/export_to_tsv.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@
77
# LICENSE.txt file in the root directory of this source tree.
88

99
import argparse
10-
import json
1110
from itertools import chain
12-
from typing import Dict, Iterable, List, TextIO
11+
from typing import Iterable, TextIO
1312

1413
from torchbiggraph.checkpoint_manager import CheckpointManager
1514
from torchbiggraph.config import ConfigFileLoader, ConfigSchema
15+
from torchbiggraph.graph_storages import (
16+
AbstractEntityStorage,
17+
AbstractRelationTypeStorage,
18+
ENTITY_STORAGES,
19+
RELATION_TYPE_STORAGES,
20+
)
1621
from torchbiggraph.model import MultiRelationEmbedder, make_model
1722

1823

@@ -23,11 +28,13 @@ def write(outf: TextIO, key: Iterable[str], value: Iterable[float]) -> None:
2328
def make_tsv(
2429
config: ConfigSchema,
2530
checkpoint: str,
26-
entities_by_type: Dict[str, List[str]],
27-
relation_types: List[str],
2831
entities_tf: TextIO,
2932
relation_types_tf: TextIO,
3033
) -> None:
34+
print("Loading relation types and entities...")
35+
entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
36+
relation_type_storage = RELATION_TYPE_STORAGES.make_instance(config.entity_path)
37+
3138
print("Initializing model...")
3239
model = make_model(config)
3340

@@ -40,29 +47,28 @@ def make_tsv(
4047
make_tsv_for_entities(
4148
model,
4249
checkpoint_manager,
43-
entities_by_type,
50+
entity_storage,
4451
entities_tf,
4552
)
4653
make_tsv_for_relation_types(
4754
model,
48-
relation_types,
55+
relation_type_storage,
4956
relation_types_tf,
5057
)
5158

5259

5360
def make_tsv_for_entities(
5461
model: MultiRelationEmbedder,
5562
checkpoint_manager: CheckpointManager,
56-
entities_by_type: Dict[str, List[str]],
63+
entity_storage: AbstractEntityStorage,
5764
entities_tf: TextIO,
5865
) -> None:
5966
print("Writing entity embeddings...")
6067
for ent_t_name, ent_t_config in model.entities.items():
61-
entities = entities_by_type[ent_t_name]
62-
partition_offset = 0
6368
for partition in range(ent_t_config.num_partitions):
6469
print(f"Reading embeddings for entity type {ent_t_name} partition "
6570
f"{partition} from checkpoint...")
71+
entities = entity_storage.load_names(ent_t_name, partition)
6672
embeddings, _ = checkpoint_manager.read(ent_t_name, partition)
6773

6874
if model.global_embs is not None:
@@ -71,23 +77,22 @@ def make_tsv_for_entities(
7177
print(f"Writing embeddings for entity type {ent_t_name} partition "
7278
f"{partition} to output file...")
7379
for ix in range(len(embeddings)):
74-
write(entities_tf, (entities[partition_offset + ix],), embeddings[ix])
80+
write(entities_tf, (entities[ix],), embeddings[ix])
7581
if (ix + 1) % 5000 == 0:
7682
print(f"- Processed {ix+1}/{len(embeddings)} entities so far...")
7783
print(f"- Processed all {len(embeddings)} entities")
7884

79-
partition_offset += len(embeddings)
80-
8185
entities_output_filename = getattr(entities_tf, "name", "the output file")
8286
print(f"Done exporting entity data to {entities_output_filename}")
8387

8488

8589
def make_tsv_for_relation_types(
8690
model: MultiRelationEmbedder,
87-
relation_types: List[str],
91+
relation_type_storage: AbstractRelationTypeStorage,
8892
relation_types_tf: TextIO,
8993
) -> None:
9094
print("Writing relation type parameters...")
95+
relation_types = relation_type_storage.load_names()
9196
if model.num_dynamic_rels > 0:
9297
rel_t_config, = model.relations
9398
op_name = rel_t_config.operator
@@ -132,7 +137,6 @@ def main():
132137
parser.add_argument('config', help="Path to config file")
133138
parser.add_argument('-p', '--param', action='append', nargs='*')
134139
parser.add_argument('--checkpoint')
135-
parser.add_argument('--dict', required=True)
136140
parser.add_argument('--entities-output', required=True)
137141
parser.add_argument('--relation-types-output', required=True)
138142
opt = parser.parse_args()
@@ -144,17 +148,11 @@ def main():
144148
loader = ConfigFileLoader()
145149
config = loader.load_config(opt.config, overrides)
146150

147-
print("Loading relation types and entities...")
148-
with open(opt.dict, "rt") as tf:
149-
dump = json.load(tf)
150-
151151
with open(opt.entities_output, "xt") as entities_tf, \
152152
open(opt.relation_types_output, "xt") as relation_types_tf:
153153
make_tsv(
154154
config,
155155
opt.checkpoint,
156-
dump["entities"],
157-
dump["relations"],
158156
entities_tf,
159157
relation_types_tf,
160158
)

0 commit comments

Comments (0)