Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit cb7830b

Browse files
Xin Jin authored and facebook-github-bot committed
Ability to Disable One Side's Negative Sampling
Summary: As titled.
Reviewed By: lerks
Differential Revision: D17337389
fbshipit-source-id: fdf9884b51d08dd4bb9ffde45d0e1926d0d7c390
1 parent 14395ca commit cb7830b

3 files changed

Lines changed: 51 additions & 17 deletions

File tree

torchbiggraph/config.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -259,6 +259,14 @@ class ConfigSchema(Schema):
259259
metadata={'help': "The number of negatives uniformly sampled from the "
260260
"currently active partition, per positive edge."},
261261
)
262+
disable_lhs_negs : bool = attr.ib(
263+
default=False,
264+
metadata={'help': "Disable negative sampling on the left-hand side."},
265+
)
266+
disable_rhs_negs : bool = attr.ib(
267+
default=False,
268+
metadata={'help': "Disable negative sampling on the right-hand side."},
269+
)
262270
lr: float = attr.ib(
263271
default=1e-2,
264272
validator=non_negative,
@@ -373,6 +381,9 @@ def __attrs_post_init__(self):
373381
if self.loss_fn == "logistic" and self.comparator == "cos":
374382
logger.warning("You have logistic loss and cosine distance. Are you sure?")
375383

384+
if self.disable_lhs_negs and self.disable_rhs_negs:
385+
raise ValueError("Cannot disable negative sampling on both sides.")
386+
376387

377388
# TODO make this a non-inplace operation
378389
def override_config_dict(config_dict: Any, overrides: List[str]) -> Any:

torchbiggraph/eval.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -65,21 +65,28 @@ def eval(
6565
) -> Stats:
6666
batch_size = len(batch_edges)
6767

68-
lhs_rank = (scores.lhs_neg >= scores.lhs_pos.unsqueeze(1)).sum(1) + 1
69-
rhs_rank = (scores.rhs_neg >= scores.rhs_pos.unsqueeze(1)).sum(1) + 1
70-
71-
lhs_auc = compute_randomized_auc(scores.lhs_pos, scores.lhs_neg, batch_size)
72-
rhs_auc = compute_randomized_auc(scores.rhs_pos, scores.rhs_neg, batch_size)
68+
ranks = []
69+
aucs = []
70+
if scores.lhs_neg.nelement() > 0:
71+
lhs_rank = (scores.lhs_neg >= scores.lhs_pos.unsqueeze(1)).sum(1) + 1
72+
lhs_auc = compute_randomized_auc(scores.lhs_pos, scores.lhs_neg, batch_size)
73+
ranks.append(lhs_rank)
74+
aucs.append(lhs_auc)
75+
76+
if scores.rhs_neg.nelement() > 0:
77+
rhs_rank = (scores.rhs_neg >= scores.rhs_pos.unsqueeze(1)).sum(1) + 1
78+
rhs_auc = compute_randomized_auc(scores.rhs_pos, scores.rhs_neg, batch_size)
79+
ranks.append(rhs_rank)
80+
aucs.append(rhs_auc)
7381

7482
return Stats(
75-
pos_rank=average_of_sums(lhs_rank, rhs_rank),
76-
mrr=average_of_sums(lhs_rank.float().reciprocal(),
77-
rhs_rank.float().reciprocal()),
78-
r1=average_of_sums(lhs_rank.le(1), rhs_rank.le(1)),
79-
r10=average_of_sums(lhs_rank.le(10), rhs_rank.le(10)),
80-
r50=average_of_sums(lhs_rank.le(50), rhs_rank.le(50)),
83+
pos_rank=average_of_sums(*ranks),
84+
mrr=average_of_sums(*(rank.float().reciprocal() for rank in ranks)),
85+
r1=average_of_sums(*(rank.le(1) for rank in ranks)),
86+
r10=average_of_sums(*(rank.le(10) for rank in ranks)),
87+
r50=average_of_sums(*(rank.le(50) for rank in ranks)),
8188
# At the end the AUC will be averaged over count.
82-
auc=batch_size * (lhs_auc + rhs_auc) / 2,
89+
auc=batch_size * sum(aucs) / len(aucs),
8390
count=batch_size)
8491

8592

torchbiggraph/model.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -772,6 +772,8 @@ def __init__(
772772
entities: Dict[str, EntitySchema],
773773
num_batch_negs: int,
774774
num_uniform_negs: int,
775+
disable_lhs_negs: bool,
776+
disable_rhs_negs: bool,
775777
lhs_operators: Sequence[Optional[Union[AbstractOperator, AbstractDynamicOperator]]],
776778
rhs_operators: Sequence[Optional[Union[AbstractOperator, AbstractDynamicOperator]]],
777779
comparator: AbstractComparator,
@@ -795,6 +797,9 @@ def __init__(
795797
self.num_batch_negs: int = num_batch_negs
796798
self.num_uniform_negs: int = num_uniform_negs
797799

800+
self.disable_lhs_negs = disable_lhs_negs
801+
self.disable_rhs_negs = disable_rhs_negs
802+
798803
self.comparator = comparator
799804

800805
self.lhs_embs: nn.ParameterDict = nn.ModuleDict()
@@ -1000,6 +1005,14 @@ def forward(
10001005
chunk_size = self.num_batch_negs
10011006
negative_sampling_method = Negatives.BATCH_UNIFORM
10021007

1008+
lhs_negative_sampling_method = negative_sampling_method
1009+
rhs_negative_sampling_method = negative_sampling_method
1010+
1011+
if self.disable_lhs_negs:
1012+
lhs_negative_sampling_method = Negatives.NONE
1013+
if self.disable_rhs_negs:
1014+
rhs_negative_sampling_method = Negatives.NONE
1015+
10031016
if self.num_dynamic_rels == 0:
10041017
# In this case the operator is only applied to the RHS. This means
10051018
# that an edge (u, r, v) is scored with c(u, f_r(v)), whereas the
@@ -1012,7 +1025,8 @@ def forward(
10121025
raise RuntimeError("In non-dynamic relation mode there should "
10131026
"be only a right-hand side operator")
10141027

1015-
# Apply operator to right-hand side, sample negatives on both sides.
1028+
# Apply operator to right-hand side, sample negatives on both sides unless
1029+
# one side is disabled.
10161030
pos_scores, lhs_neg_scores, rhs_neg_scores = self.forward_direction_agnostic(
10171031
edges.lhs,
10181032
edges.rhs,
@@ -1026,8 +1040,8 @@ def forward(
10261040
lhs_pos,
10271041
rhs_pos,
10281042
chunk_size,
1029-
negative_sampling_method,
1030-
negative_sampling_method,
1043+
lhs_negative_sampling_method,
1044+
rhs_negative_sampling_method,
10311045
)
10321046
lhs_pos_scores = rhs_pos_scores = pos_scores
10331047

@@ -1061,7 +1075,7 @@ def forward(
10611075
lhs_pos,
10621076
rhs_pos,
10631077
chunk_size,
1064-
negative_sampling_method,
1078+
lhs_negative_sampling_method,
10651079
Negatives.NONE,
10661080
)
10671081
# "Reverse" edges: apply operator to lhs, sample negatives on rhs.
@@ -1078,7 +1092,7 @@ def forward(
10781092
rhs_pos,
10791093
lhs_pos,
10801094
chunk_size,
1081-
negative_sampling_method,
1095+
rhs_negative_sampling_method,
10821096
Negatives.NONE,
10831097
)
10841098

@@ -1187,6 +1201,8 @@ def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
11871201
config.entities,
11881202
num_uniform_negs=config.num_uniform_negs,
11891203
num_batch_negs=config.num_batch_negs,
1204+
disable_lhs_negs=config.disable_lhs_negs,
1205+
disable_rhs_negs=config.disable_rhs_negs,
11901206
lhs_operators=lhs_operators,
11911207
rhs_operators=rhs_operators,
11921208
comparator=comparator,

0 commit comments

Comments
 (0)