Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit 53c9ce2

Browse files
lwfacebook-github-bot
authored andcommitted
Have edge storage raise proper exception
Summary: The interface prescribes raising CouldNotLoadData when a file is missing, we were raising RuntimeError. Reviewed By: adamlerer Differential Revision: D17571776 fbshipit-source-id: 04af21aa26a38fed235a8084c863fc6b5a4c0cee
1 parent 9c9e809 commit 53c9ce2

1 file changed

Lines changed: 47 additions & 36 deletions

File tree

torchbiggraph/graph_storages.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# This source code is licensed under the BSD-style license found in the
77
# LICENSE.txt file in the root directory of this source tree.
88

9+
import errno
910
import json
1011
import logging
1112
from abc import ABC, abstractmethod
@@ -394,12 +395,17 @@ def has_edges(
394395

395396
def get_number_of_edges(self, lhs_p: int, rhs_p: int) -> int:
396397
file_path = self.get_edges_file(lhs_p, rhs_p)
397-
if not file_path.is_file():
398-
raise RuntimeError(f"{file_path} does not exist")
399-
with h5py.File(file_path, "r") as hf:
400-
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
401-
raise RuntimeError(f"Version mismatch in edge file {file_path}")
402-
return hf["rel"].len()
398+
try:
399+
with h5py.File(file_path, "r") as hf:
400+
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
401+
raise RuntimeError(f"Version mismatch in edge file {file_path}")
402+
return hf["rel"].len()
403+
except OSError as err:
404+
# h5py refuses to make it easy to figure out what went wrong. The errno
405+
# attribute is set to None. See https://github.com/h5py/h5py/issues/493.
406+
if f"errno = {errno.ENOENT}" in str(err):
407+
raise CouldNotLoadData() from err
408+
raise err
403409

404410
def load_chunk_of_edges(
405411
self,
@@ -409,36 +415,41 @@ def load_chunk_of_edges(
409415
num_chunks: int = 1,
410416
) -> EdgeList:
411417
file_path = self.get_edges_file(lhs_p, rhs_p)
412-
if not file_path.is_file():
413-
raise RuntimeError(f"{file_path} does not exist")
414-
with h5py.File(file_path, 'r') as hf:
415-
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
416-
raise RuntimeError(f"Version mismatch in edge file {file_path}")
417-
lhs_ds = hf['lhs']
418-
rhs_ds = hf['rhs']
419-
rel_ds = hf['rel']
420-
421-
num_edges = rel_ds.len()
422-
begin = int(chunk_idx * num_edges / num_chunks)
423-
end = int((chunk_idx + 1) * num_edges / num_chunks)
424-
chunk_size = end - begin
425-
426-
lhs = torch.empty((chunk_size,), dtype=torch.long)
427-
rhs = torch.empty((chunk_size,), dtype=torch.long)
428-
rel = torch.empty((chunk_size,), dtype=torch.long)
429-
430-
# Needed because https://github.com/h5py/h5py/issues/870.
431-
if chunk_size > 0:
432-
lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
433-
rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
434-
rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])
435-
436-
lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
437-
rhsd = self.read_dynamic(hf, 'rhsd', begin, end)
438-
439-
return EdgeList(EntityList(lhs, lhsd),
440-
EntityList(rhs, rhsd),
441-
rel)
418+
try:
419+
with h5py.File(file_path, "r") as hf:
420+
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
421+
raise RuntimeError(f"Version mismatch in edge file {file_path}")
422+
lhs_ds = hf["lhs"]
423+
rhs_ds = hf["rhs"]
424+
rel_ds = hf["rel"]
425+
426+
num_edges = rel_ds.len()
427+
begin = int(chunk_idx * num_edges / num_chunks)
428+
end = int((chunk_idx + 1) * num_edges / num_chunks)
429+
chunk_size = end - begin
430+
431+
lhs = torch.empty((chunk_size,), dtype=torch.long)
432+
rhs = torch.empty((chunk_size,), dtype=torch.long)
433+
rel = torch.empty((chunk_size,), dtype=torch.long)
434+
435+
# Needed because https://github.com/h5py/h5py/issues/870.
436+
if chunk_size > 0:
437+
lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
438+
rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
439+
rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])
440+
441+
lhsd = self.read_dynamic(hf, "lhsd", begin, end)
442+
rhsd = self.read_dynamic(hf, "rhsd", begin, end)
443+
444+
return EdgeList(EntityList(lhs, lhsd),
445+
EntityList(rhs, rhsd),
446+
rel)
447+
except OSError as err:
448+
# h5py refuses to make it easy to figure out what went wrong. The errno
449+
# attribute is set to None. See https://github.com/h5py/h5py/issues/493.
450+
if f"errno = {errno.ENOENT}" in str(err):
451+
raise CouldNotLoadData() from err
452+
raise err
442453

443454
@staticmethod
444455
def read_dynamic(

0 commit comments

Comments
 (0)