77# LICENSE.txt file in the root directory of this source tree.
88
99import argparse
10- import os
11- import os .path
1210import random
1311from itertools import chain
12+ from pathlib import Path
1413from typing import Any , Counter , DefaultDict , Dict , List , Optional , Tuple
1514
1615import torch
2322 override_config_dict ,
2423)
2524from torchbiggraph .converters .dictionary import Dictionary
25+ from torchbiggraph .converters .utils import convert_path
2626from torchbiggraph .edgelist import EdgeList
2727from torchbiggraph .entitylist import EntityList
2828from torchbiggraph .graph_storages import (
3737
3838def collect_relation_types (
3939 relation_configs : List [RelationSchema ],
40- edge_paths : List [str ],
40+ edge_paths : List [Path ],
4141 dynamic_relations : bool ,
4242 rel_col : Optional [int ],
4343 relation_type_min_count : int ,
@@ -49,30 +49,29 @@ def collect_relation_types(
4949 print ("Looking up relation types in the edge files..." )
5050 counter : Counter [str ] = Counter ()
5151 for edgepath in edge_paths :
52- with open (edgepath , "rt" ) as tf :
52+ with edgepath . open ("rt" ) as tf :
5353 for line_num , line in enumerate (tf , start = 1 ):
5454 words = line .split ()
5555 try :
5656 rel_word = words [rel_col ]
5757 except IndexError :
5858 raise RuntimeError (
59- "Line %d of %s has only %d words"
60- % ( line_num , edgepath , len ( words )) ) from None
59+ f "Line { line_num } of { edgepath } has only { len ( words ) } words"
60+ ) from None
6161 counter [rel_word ] += 1
62- print ("- Found %d relation types" % len ( counter ) )
62+ print (f "- Found { len ( counter ) } relation types" )
6363 if relation_type_min_count > 0 :
64- print ("- Removing the ones with fewer than %d occurrences..."
65- % relation_type_min_count )
64+ print (f"- Removing the ones with fewer than { relation_type_min_count } occurrences..." )
6665 counter = Counter ({k : c for k , c in counter .items ()
6766 if c >= relation_type_min_count })
68- print ("- Left with %d relation types" % len ( counter ) )
67+ print (f "- Left with { len ( counter ) } relation types" )
6968 print ("- Shuffling them..." )
7069 names = list (counter .keys ())
7170 random .shuffle (names )
7271
7372 else :
7473 names = [rconfig .name for rconfig in relation_configs ]
75- print ("Using the %d relation types given in the config" % len ( names ) )
74+ print (f "Using the { len ( names ) } relation types given in the config" )
7675
7776 return Dictionary (names )
7877
@@ -81,7 +80,7 @@ def collect_entities_by_type(
8180 relation_types : Dictionary ,
8281 entity_configs : Dict [str , EntitySchema ],
8382 relation_configs : List [RelationSchema ],
84- edge_paths : List [str ],
83+ edge_paths : List [Path ],
8584 dynamic_relations : bool ,
8685 lhs_col : int ,
8786 rhs_col : int ,
@@ -95,7 +94,7 @@ def collect_entities_by_type(
9594
9695 print ("Searching for the entities in the edge files..." )
9796 for edgepath in edge_paths :
98- with open (edgepath , "rt" ) as tf :
97+ with edgepath . open ("rt" ) as tf :
9998 for line_num , line in enumerate (tf , start = 1 ):
10099 words = line .split ()
101100 try :
@@ -120,14 +119,13 @@ def collect_entities_by_type(
120119
121120 entities_by_type : Dict [str , Dictionary ] = {}
122121 for entity_name , counter in counters .items ():
123- print ("Entity type %s:" % entity_name )
124- print ("- Found %d entities" % len (counter ))
122+ print (f "Entity type { entity_name } :" )
123+ print (f "- Found { len (counter )} entities" )
125124 if entity_min_count > 0 :
126- print ("- Removing the ones with fewer than %d occurrences..."
127- % entity_min_count )
125+ print (f"- Removing the ones with fewer than { entity_min_count } occurrences..." )
128126 counter = Counter ({k : c for k , c in counter .items ()
129127 if c >= entity_min_count })
130- print ("- Left with %d entities" % len (counter ))
128+ print (f "- Left with { len (counter )} entities" )
131129 print ("- Shuffling them..." )
132130 names = list (counter .keys ())
133131 random .shuffle (names )
@@ -162,8 +160,8 @@ def generate_entity_path_files(
162160
163161
164162def generate_edge_path_files (
165- edge_file_in : str ,
166- edge_path_out : str ,
163+ edge_file_in : Path ,
164+ edge_path_out : Path ,
167165 edge_storage : AbstractEdgeStorage ,
168166 entities_by_type : Dict [str , Dictionary ],
169167 relation_types : Dictionary ,
@@ -189,7 +187,7 @@ def generate_edge_path_files(
189187 processed = 0
190188 skipped = 0
191189
192- with open (edge_file_in , "rt" ) as tf :
190+ with edge_file_in . open ("rt" ) as tf :
193191 for line_num , line in enumerate (tf , start = 1 ):
194192 words = line .split ()
195193 try :
@@ -256,7 +254,7 @@ def convert_input_data(
256254 entity_configs : Dict [str , EntitySchema ],
257255 relation_configs : List [RelationSchema ],
258256 entity_path : str ,
259- edge_paths : List [str ],
257+ edge_paths : List [Path ],
260258 lhs_col : int ,
261259 rhs_col : int ,
262260 rel_col : Optional [int ] = None ,
@@ -266,8 +264,8 @@ def convert_input_data(
266264) -> None :
267265 entity_storage = ENTITY_STORAGES .make_instance (entity_path )
268266 relation_type_storage = RELATION_TYPE_STORAGES .make_instance (entity_path )
269- edge_paths_out = [os . path . splitext (ep )[ 0 ] + "_partitioned" for ep in edge_paths ]
270- edge_storages = [EDGE_STORAGES .make_instance (ep ) for ep in edge_paths_out ]
267+ edge_paths_out = [convert_path (ep ) for ep in edge_paths ]
268+ edge_storages = [EDGE_STORAGES .make_instance (str ( ep ) ) for ep in edge_paths_out ]
271269
272270 some_files_exists = []
273271 some_files_exists .extend (
@@ -287,7 +285,8 @@ def convert_input_data(
287285 if all (some_files_exists ):
288286 print ("Found some files that indicate that the input data "
289287 "has already been preprocessed, not doing it again." )
290- print (f"These files are in { entity_path } and { edge_paths } " )
288+ all_paths = ", " .join (str (p ) for p in [entity_path ] + edge_paths_out )
289+ print (f"These files are in: { all_paths } " )
291290 return
292291
293292 relation_types = collect_relation_types (
@@ -371,7 +370,7 @@ def main():
371370 )
372371 parser .add_argument ('config' , help = 'Path to config file' )
373372 parser .add_argument ('-p' , '--param' , action = 'append' , nargs = '*' )
374- parser .add_argument ('edge_paths' , nargs = '*' , help = 'Input file paths' )
373+ parser .add_argument ('edge_paths' , type = Path , nargs = '*' , help = 'Input file paths' )
375374 parser .add_argument ('-l' , '--lhs-col' , type = int , required = True ,
376375 help = 'Column index for source entity' )
377376 parser .add_argument ('-r' , '--rhs-col' , type = int , required = True ,
0 commit comments