1313from itertools import chain
1414from typing import Any , Counter , DefaultDict , Dict , List , Optional , Tuple
1515
16- import h5py
17- import numpy as np
16+ import torch
1817
1918from torchbiggraph .config import (
2019 ConfigFileLoader ,
2423 override_config_dict ,
2524)
2625from torchbiggraph .converters .dictionary import Dictionary
26+ from torchbiggraph .edgelist import EdgeList
27+ from torchbiggraph .entitylist import EntityList
2728from torchbiggraph .graph_storages import (
29+ AbstractEdgeStorage ,
2830 AbstractEntityStorage ,
2931 AbstractRelationTypeStorage ,
32+ EDGE_STORAGES ,
3033 ENTITY_STORAGES ,
3134 RELATION_TYPE_STORAGES ,
3235)
@@ -160,6 +163,8 @@ def generate_entity_path_files(
160163
161164def generate_edge_path_files (
162165 edge_file_in : str ,
166+ edge_path_out : str ,
167+ edge_storage : AbstractEdgeStorage ,
163168 entities_by_type : Dict [str , Dictionary ],
164169 relation_types : Dictionary ,
165170 relation_configs : List [RelationSchema ],
@@ -168,21 +173,16 @@ def generate_edge_path_files(
168173 rhs_col : int ,
169174 rel_col : Optional [int ],
170175) -> None :
171-
172- basename , _ = os .path .splitext (edge_file_in )
173- edge_path_out = basename + '_partitioned'
174-
175- print ("Preparing edge path %s, out of the edges found in %s"
176- % (edge_path_out , edge_file_in ))
177- os .makedirs (edge_path_out , exist_ok = True )
176+ print (f"Preparing edge path { edge_path_out } , "
177+ f"out of the edges found in { edge_file_in } " )
178+ edge_storage .prepare ()
178179
179180 num_lhs_parts = max (entities_by_type [rconfig .lhs ].num_parts
180181 for rconfig in relation_configs )
181182 num_rhs_parts = max (entities_by_type [rconfig .rhs ].num_parts
182183 for rconfig in relation_configs )
183184
184- print ("- Edges will be partitioned in %d x %d buckets."
185- % (num_lhs_parts , num_rhs_parts ))
185+ print (f"- Edges will be partitioned in { num_lhs_parts } x { num_rhs_parts } buckets." )
186186
187187 buckets : DefaultDict [Tuple [int , int ], List [Tuple [int , int , int ]]] = \
188188 DefaultDict (list )
@@ -198,8 +198,8 @@ def generate_edge_path_files(
198198 rel_word = words [rel_col ] if rel_col is not None else None
199199 except IndexError :
200200 raise RuntimeError (
201- "Line %d of %s has only %d words"
202- % ( line_num , edge_file_in , len ( words )) ) from None
201+ f "Line { line_num } of { edge_file_in } has only { len ( words ) } words"
202+ ) from None
203203
204204 if rel_col is None :
205205 rel_id = 0
@@ -232,26 +232,24 @@ def generate_edge_path_files(
232232
233233 processed = processed + 1
234234 if processed % 100000 == 0 :
235- print ("- Processed %d edges so far..." % processed )
235+ print (f "- Processed { processed } edges so far..." )
236236
237- print ("- Processed %d edges in total" % processed )
237+ print (f "- Processed { processed } edges in total" )
238238 if skipped > 0 :
239- print ("- Skipped %d edges because their relation type or entities were "
240- " unknown (either not given in the config or filtered out as too "
241- " rare)." % skipped )
239+ print (f "- Skipped { skipped } edges because their relation type or "
240+ f"entities were unknown (either not given in the config or "
241+ f"filtered out as too rare)." )
242242
243243 for i in range (num_lhs_parts ):
244244 for j in range (num_rhs_parts ):
245- print ("- Writing bucket (%d, %d), containing %d edges..."
246- % (i , j , len (buckets [i , j ])))
247- edges = np .array (buckets [i , j ], dtype = np .int64 ).reshape ((- 1 , 3 ))
248- with h5py .File (os .path .join (
249- edge_path_out , "edges_%d_%d.h5" % (i , j )
250- ), "w" ) as hf :
251- hf .attrs ["format_version" ] = 1
252- hf .create_dataset ("lhs" , data = edges [:, 0 ])
253- hf .create_dataset ("rhs" , data = edges [:, 1 ])
254- hf .create_dataset ("rel" , data = edges [:, 2 ])
245+ print (f"- Writing bucket ({ i } , { j } ), "
246+ f"containing { len (buckets [i , j ])} edges..." )
247+ edges = torch .tensor (buckets [i , j ], dtype = torch .long ).view ((- 1 , 3 ))
248+ edge_storage .save_edges (i , j , EdgeList (
249+ EntityList .from_tensor (edges [:, 0 ]),
250+ EntityList .from_tensor (edges [:, 1 ]),
251+ edges [:, 2 ],
252+ ))
255253
256254
257255def convert_input_data (
@@ -268,6 +266,8 @@ def convert_input_data(
268266) -> None :
269267 entity_storage = ENTITY_STORAGES .make_instance (entity_path )
270268 relation_type_storage = RELATION_TYPE_STORAGES .make_instance (entity_path )
269+ edge_paths_out = [os .path .splitext (ep )[0 ] + "_partitioned" for ep in edge_paths ]
270+ edge_storages = [EDGE_STORAGES .make_instance (ep ) for ep in edge_paths_out ]
271271
272272 some_files_exists = []
273273 some_files_exists .extend (
@@ -282,8 +282,7 @@ def convert_input_data(
282282 some_files_exists .append (relation_type_storage .has_count ())
283283 some_files_exists .append (relation_type_storage .has_names ())
284284 some_files_exists .extend (
285- os .path .exists (os .path .join (os .path .splitext (edge_file )[0 ] + "_partitioned" , "edges_0_0.h5" ))
286- for edge_file in edge_paths )
285+ edge_storage .has_edges (0 , 0 ) for edge_storage in edge_storages )
287286
288287 if all (some_files_exists ):
289288 print ("Found some files that indicate that the input data "
@@ -319,9 +318,12 @@ def convert_input_data(
319318 dynamic_relations ,
320319 )
321320
322- for edge_path in edge_paths :
321+ for edge_path , edge_path_out , edge_storage \
322+ in zip (edge_paths , edge_paths_out , edge_storages ):
323323 generate_edge_path_files (
324324 edge_path ,
325+ edge_path_out ,
326+ edge_storage ,
325327 entities_by_type ,
326328 relation_types ,
327329 relation_configs ,
0 commit comments