diff --git a/cornac/models/gru4rec/gru4rec.py b/cornac/models/gru4rec/gru4rec.py index 2ea880b58..4126a1d45 100644 --- a/cornac/models/gru4rec/gru4rec.py +++ b/cornac/models/gru4rec/gru4rec.py @@ -1,95 +1,32 @@ -from collections import Counter +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ import numpy as np import torch from torch import nn from torch.autograd import Variable -from torch.optim import Optimizer -from cornac.utils.common import get_rng +from ..seq_utils.iterators import io_iter +from ..seq_utils.optim import IndexedAdagradM -def init_parameter_matrix( - tensor: torch.Tensor, dim0_scale: int = 1, dim1_scale: int = 1 -): - sigma = np.sqrt( - 6.0 / float(tensor.size(0) / dim0_scale + tensor.size(1) / dim1_scale) - ) +def init_parameter_matrix(tensor: torch.Tensor, dim0_scale: int = 1, dim1_scale: int = 1): + sigma = np.sqrt(6.0 / float(tensor.size(0) / dim0_scale + tensor.size(1) / dim1_scale)) return nn.init._no_grad_uniform_(tensor, -sigma, sigma) -class IndexedAdagradM(Optimizer): - def __init__(self, params, lr=0.05, momentum=0.0, eps=1e-6): - if lr <= 0.0: - raise ValueError("Invalid learning rate: {}".format(lr)) - if momentum < 0.0: - raise ValueError("Invalid momentum value: {}".format(momentum)) - if eps <= 0.0: - raise ValueError("Invalid epsilon value: {}".format(eps)) - - defaults = dict(lr=lr, momentum=momentum, eps=eps) - super(IndexedAdagradM, self).__init__(params, defaults) - - for group in self.param_groups: - for p in group["params"]: - state = self.state[p] - state["acc"] = torch.full_like( - p, 0, memory_format=torch.preserve_format - ) - if momentum > 0: - state["mom"] = torch.full_like( - p, 0, memory_format=torch.preserve_format - ) - - def share_memory(self): - for group in self.param_groups: - for p in group["params"]: - state = self.state[p] - state["acc"].share_memory_() - if group["momentum"] > 0: - state["mom"].share_memory_() - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad - state = self.state[p] - clr = group["lr"] - momentum = group["momentum"] - if grad.is_sparse: - grad = grad.coalesce() - grad_indices = grad._indices()[0] - grad_values = grad._values() - accs = state["acc"][grad_indices] + grad_values.pow(2) - state["acc"].index_copy_(0, grad_indices, accs) - accs.add_(group["eps"]).sqrt_().mul_(-1 / clr) - if momentum > 0: - moma = state["mom"][grad_indices] - moma.mul_(momentum).add_(grad_values / accs) - state["mom"].index_copy_(0, grad_indices, moma) - p.index_add_(0, grad_indices, moma) - else: - p.index_add_(0, grad_indices, grad_values / accs) - else: - state["acc"].add_(grad.pow(2)) - accs = state["acc"].add(group["eps"]) - accs.sqrt_() - if momentum > 0: - mom = state["mom"] - mom.mul_(momentum).addcdiv_(grad, accs, value=-clr) - p.add_(mom) - else: - p.addcdiv_(grad, accs, value=-clr) - return loss - - class GRUEmbedding(nn.Module): def __init__(self, dim_in, dim_out): super(GRUEmbedding, self).__init__() @@ -118,6 +55,21 @@ def forward(self, X, H): class GRU4RecModel(nn.Module): + """GRU4Rec PyTorch architecture. + + The model computes a hidden state from the latest item id (and previous + hidden state) and scores all candidate items via either: + + - ``constrained_embedding=True``: tied input/output item embeddings. + - ``embedding > 0``: separate input embedding of given size. + - otherwise: a custom :class:`GRUEmbedding` directly producing the hidden. + + The model returns a ``(B, B+N)`` score matrix where columns 0..B-1 are + the in-batch positives and columns B..B+N-1 are shared sampled negatives. + Padding row ``n_items`` is included so it is safe to pass ``n_items`` as + a fake input id for cold-start fallback at score-time. + """ + def __init__( self, n_items, @@ -146,12 +98,11 @@ def __init__( self.elu_param = elu_param self.bpreg = bpreg self.loss = loss - self.set_loss_function(self.loss) self.start = 0 if constrained_embedding: n_input = layers[-1] elif embedding: - self.E = nn.Embedding(n_items, embedding, sparse=True) + self.E = nn.Embedding(n_items + 1, embedding, sparse=True, padding_idx=n_items) n_input = embedding else: self.GE = GRUEmbedding(n_items, layers[0]) @@ -165,8 +116,8 @@ def __init__( self.D.append(nn.Dropout(dropout_p_hidden)) self.G = nn.ModuleList(self.G) self.D = nn.ModuleList(self.D) - self.Wy = nn.Embedding(n_items, layers[-1], sparse=True) - self.By = nn.Embedding(n_items, 1, sparse=True) + self.Wy = nn.Embedding(n_items + 1, layers[-1], sparse=True, padding_idx=n_items) + self.By = nn.Embedding(n_items + 1, 1, sparse=True, padding_idx=n_items) self.reset_parameters() @torch.no_grad() @@ -182,73 +133,30 @@ def reset_parameters(self): nn.init.zeros_(self.G[i].bias_hh) init_parameter_matrix(self.Wy.weight) nn.init.zeros_(self.By.weight) - - def set_loss_function(self, loss): - if loss == "cross-entropy": - self.loss_function = self.xe_loss_with_softmax - elif loss == "bpr-max": - self.loss_function = self.bpr_max_loss_with_elu - elif loss == "top1": - self.loss_function = self.top1 - else: - raise NotImplementedError - - def xe_loss_with_softmax(self, O, Y, M): - if self.logq > 0: - O = O - self.logq * torch.log( - torch.cat([self.P0[Y[:M]], self.P0[Y[M:]] ** self.sample_alpha]) - ) - X = torch.exp(O - O.max(dim=1, keepdim=True)[0]) - X = X / X.sum(dim=1, keepdim=True) - return -torch.sum(torch.log(torch.diag(X) + 1e-24)) - - def softmax_neg(self, X): - hm = 1.0 - torch.eye(*X.shape, out=torch.empty_like(X)) - X = X * hm - e_x = torch.exp(X - X.max(dim=1, keepdim=True)[0]) * hm - return e_x / e_x.sum(dim=1, keepdim=True) - - def bpr_max_loss_with_elu(self, O, Y, M): - if self.elu_param > 0: - O = nn.functional.elu(O, self.elu_param) - softmax_scores = self.softmax_neg(O) - target_scores = torch.diag(O) - target_scores = target_scores.reshape(target_scores.shape[0], -1) - return torch.sum( - ( - -torch.log( - torch.sum(torch.sigmoid(target_scores - O) * softmax_scores, dim=1) - + 1e-24 - ) - + self.bpreg * torch.sum((O**2) * softmax_scores, dim=1) - ) - ) - - def top1(self, O, Y, M): - target_scores = torch.diag(O) - target_scores = target_scores.reshape(target_scores.shape[0], -1) - return torch.sum( - ( - torch.mean( - torch.sigmoid(O - target_scores) + torch.sigmoid(O**2), axis=1 - ) - - torch.sigmoid(target_scores**2) / (M + self.n_sample) - ) - ) + if self.Wy.padding_idx is not None: + self.Wy.weight.data[self.Wy.padding_idx].zero_() + if self.By.padding_idx is not None: + self.By.weight.data[self.By.padding_idx].zero_() def _init_numpy_weights(self, shape): - sigma = np.sqrt(6.0 / (shape[0] + shape[1])) - m = np.random.rand(*shape).astype("float32") * 2 * sigma - sigma + sigma = float(np.sqrt(6.0 / (shape[0] + shape[1]))) + m = (np.random.rand(*shape) * 2 * sigma - sigma).astype("float32") return m @torch.no_grad() def _reset_weights_to_compatibility_mode(self): + """Reset weights using numpy RNG with seed 42 for reproducibility. + + Note: when ``constrained_embedding=False`` and ``embedding > 0`` the + ``E`` embedding includes a padding row (``n_items``), and ``Wy``/``By`` + also include padding rows. We only set the first ``n_items`` rows. + """ np.random.seed(42) if self.constrained_embedding: n_input = self.layers[-1] elif self.embedding: n_input = self.embedding - self.E.weight.set_( + self.E.weight.data[: self.n_items].copy_( torch.tensor( self._init_numpy_weights((self.n_items, n_input)), device=self.E.weight.device, @@ -256,56 +164,58 @@ def _reset_weights_to_compatibility_mode(self): ) else: n_input = self.n_items - m = [] - m.append(self._init_numpy_weights((n_input, self.layers[0]))) - m.append(self._init_numpy_weights((n_input, self.layers[0]))) - m.append(self._init_numpy_weights((n_input, self.layers[0]))) - self.GE.Wx0.weight.set_( - torch.tensor(np.hstack(m), device=self.GE.Wx0.weight.device) - ) - m2 = [] - m2.append(self._init_numpy_weights((self.layers[0], self.layers[0]))) - m2.append(self._init_numpy_weights((self.layers[0], self.layers[0]))) - self.GE.Wrz0.set_(torch.tensor(np.hstack(m2), device=self.GE.Wrz0.device)) + m = [ + self._init_numpy_weights((n_input, self.layers[0])), + self._init_numpy_weights((n_input, self.layers[0])), + self._init_numpy_weights((n_input, self.layers[0])), + ] + self.GE.Wx0.weight.set_(torch.tensor(np.hstack(m), dtype=torch.float32, device=self.GE.Wx0.weight.device)) + m2 = [ + self._init_numpy_weights((self.layers[0], self.layers[0])), + self._init_numpy_weights((self.layers[0], self.layers[0])), + ] + self.GE.Wrz0.set_(torch.tensor(np.hstack(m2), dtype=torch.float32, device=self.GE.Wrz0.device)) self.GE.Wh0.set_( torch.tensor( self._init_numpy_weights((self.layers[0], self.layers[0])), + dtype=torch.float32, device=self.GE.Wh0.device, ) ) - self.GE.Bh0.set_( - torch.zeros((self.layers[0] * 3,), device=self.GE.Bh0.device) - ) + self.GE.Bh0.set_(torch.zeros((self.layers[0] * 3,), device=self.GE.Bh0.device)) for i in range(self.start, len(self.layers)): - m = [] - m.append(self._init_numpy_weights((n_input, self.layers[i]))) - m.append(self._init_numpy_weights((n_input, self.layers[i]))) - m.append(self._init_numpy_weights((n_input, self.layers[i]))) - self.G[i].weight_ih.set_( - torch.tensor(np.vstack(m), device=self.G[i].weight_ih.device) - ) - m2 = [] - m2.append(self._init_numpy_weights((self.layers[i], self.layers[i]))) - m2.append(self._init_numpy_weights((self.layers[i], self.layers[i]))) - m2.append(self._init_numpy_weights((self.layers[i], self.layers[i]))) + m = [ + self._init_numpy_weights((n_input, self.layers[i])), + self._init_numpy_weights((n_input, self.layers[i])), + self._init_numpy_weights((n_input, self.layers[i])), + ] + self.G[i].weight_ih.set_(torch.tensor(np.vstack(m), dtype=torch.float32, device=self.G[i].weight_ih.device)) + m2 = [ + self._init_numpy_weights((self.layers[i], self.layers[i])), + self._init_numpy_weights((self.layers[i], self.layers[i])), + self._init_numpy_weights((self.layers[i], self.layers[i])), + ] self.G[i].weight_hh.set_( - torch.tensor(np.vstack(m2), device=self.G[i].weight_hh.device) - ) - self.G[i].bias_hh.set_( - torch.zeros((self.layers[i] * 3,), device=self.G[i].bias_hh.device) - ) - self.G[i].bias_ih.set_( - torch.zeros((self.layers[i] * 3,), device=self.G[i].bias_ih.device) + torch.tensor( + np.vstack(m2), + dtype=torch.float32, + device=self.G[i].weight_hh.device, + ) ) - self.Wy.weight.set_( + self.G[i].bias_hh.set_(torch.zeros((self.layers[i] * 3,), device=self.G[i].bias_hh.device)) + self.G[i].bias_ih.set_(torch.zeros((self.layers[i] * 3,), device=self.G[i].bias_ih.device)) + self.Wy.weight.data[: self.n_items].copy_( torch.tensor( self._init_numpy_weights((self.n_items, self.layers[-1])), + dtype=torch.float32, device=self.Wy.weight.device, ) ) - self.By.weight.set_( - torch.zeros((self.n_items, 1), device=self.By.weight.device) - ) + self.By.weight.data[: self.n_items].copy_(torch.zeros((self.n_items, 1), device=self.By.weight.device)) + if self.Wy.padding_idx is not None: + self.Wy.weight.data[self.Wy.padding_idx].zero_() + if self.By.padding_idx is not None: + self.By.weight.data[self.By.padding_idx].zero_() def embed_constrained(self, X, Y=None): if Y is not None: @@ -359,8 +269,8 @@ def hidden_step(self, X, H, training=False): return X def score_items(self, X, O, B): - O = torch.mm(X, O.T) + B.T - return O + out = torch.mm(X, O.T) + B.T + return out def forward(self, X, H, Y, training=False): E, O, B = self.embed(X, H, Y) @@ -373,126 +283,18 @@ def forward(self, X, H, Y, training=False): return R -def io_iter( - s_iter, uir_tuple, n_sample=0, sample_alpha=0, rng=None, batch_size=1, shuffle=False -): - """Paralellize mini-batch of input-output items. Create an iterator over data yielding batch of input item indices, batch of output item indices, - batch of start masking, batch of end masking, and batch of valid ids (relative positions of current sequences in the last batch). - - Parameters - ---------- - batch_size: int, optional, default = 1 - - shuffle: bool, optional, default: False - If `True`, orders of triplets will be randomized. If `False`, default orders kept. - - Returns - ------- - iterator : batch of input item indices, batch of output item indices, batch of starting sequence mask, batch of ending sequence mask, batch of valid ids +def score(model, layers, device, history_items): + """Score all items given a flat ``history_items`` list of integers. + Returns a numpy array of length ``n_items`` (or ``n_items + 1`` if the + output embedding includes a padding row; in that case the caller is + responsible for trimming). """ - rng = rng if rng is not None else get_rng(None) - start_mask = np.zeros(batch_size, dtype="int") - end_mask = np.ones(batch_size, dtype="int") - input_iids = None - output_iids = None - l_pool = [] - c_pool = [None for _ in range(batch_size)] - sizes = np.zeros(batch_size, dtype="int") - if n_sample > 0: - item_count = Counter(uir_tuple[1]) - item_indices = np.array( - [iid for iid, _ in item_count.most_common()], dtype="int" - ) - item_dist = ( - np.array([cnt for _, cnt in item_count.most_common()], dtype="float") - ** sample_alpha - ) - item_dist = item_dist / item_dist.sum() - for _, batch_mapped_ids in s_iter(batch_size, shuffle): - l_pool += batch_mapped_ids - while len(l_pool) > 0: - if end_mask.sum() == 0: - input_iids = uir_tuple[1][ - [mapped_ids[-sizes[idx]] for idx, mapped_ids in enumerate(c_pool)] - ] - output_iids = uir_tuple[1][ - [ - mapped_ids[-sizes[idx] + 1] - for idx, mapped_ids in enumerate(c_pool) - ] - ] - sizes -= 1 - for idx, size in enumerate(sizes): - if size == 1: - end_mask[idx] = 1 - if n_sample > 0: - negative_samples = rng.choice( - item_indices, size=n_sample, replace=True, p=item_dist - ) - output_iids = np.concatenate([output_iids, negative_samples]) - yield input_iids, output_iids, start_mask, np.arange( - batch_size, dtype="int" - ) - start_mask.fill(0) # reset start masking - while end_mask.sum() > 0 and len(l_pool) > 0: - next_seq = l_pool.pop() - if len(next_seq) > 1: - idx = np.nonzero(end_mask)[0][0] - end_mask[idx] = 0 - start_mask[idx] = 1 - c_pool[idx] = next_seq - sizes[idx] = len(c_pool[idx]) - - valid_id = np.ones(batch_size, dtype="int") - while True: - for idx, size in enumerate(sizes): - if size == 1: - end_mask[idx] = 1 - valid_id[idx] = 0 - input_iids = uir_tuple[1][ - [ - mapped_ids[-sizes[idx]] - for idx, mapped_ids in enumerate(c_pool) - if sizes[idx] > 1 - ] - ] - output_iids = uir_tuple[1][ - [ - mapped_ids[-sizes[idx] + 1] - for idx, mapped_ids in enumerate(c_pool) - if sizes[idx] > 1 - ] - ] - sizes -= 1 - for idx, size in enumerate(sizes): - if size == 1: - end_mask[idx] = 1 - start_mask = start_mask[np.nonzero(valid_id)[0]] - end_mask = end_mask[np.nonzero(valid_id)[0]] - sizes = sizes[np.nonzero(valid_id)[0]] - c_pool = [_ for _, valid in zip(c_pool, valid_id) if valid > 0] - if n_sample > 0: - negative_samples = rng.choice( - item_indices, size=n_sample, replace=True, p=item_dist - ) - output_iids = np.concatenate([output_iids, negative_samples]) - yield input_iids, output_iids, start_mask, np.nonzero(valid_id)[0] - valid_id = np.ones(len(input_iids), dtype="int") - if end_mask.sum() == len(input_iids): - break - start_mask.fill(0) # reset start masking - - -def score(model, layers, device, history_items): model.eval() H = [] for i in range(len(layers)): - H.append( - torch.zeros( - (1, layers[i]), dtype=torch.float32, requires_grad=False, device=device - ) - ) + H.append(torch.zeros((1, layers[i]), dtype=torch.float32, requires_grad=False, device=device)) + O = None for iid in history_items: O = model.forward( torch.tensor([iid], requires_grad=False, device=device), @@ -500,4 +302,6 @@ def score(model, layers, device, history_items): None, training=False, ) + if O is None: + return None return O.squeeze().cpu().detach().numpy() diff --git a/cornac/models/gru4rec/recom_gru4rec.py b/cornac/models/gru4rec/recom_gru4rec.py index 7e161d924..d48bf4607 100644 --- a/cornac/models/gru4rec/recom_gru4rec.py +++ b/cornac/models/gru4rec/recom_gru4rec.py @@ -1,4 +1,4 @@ -# Copyright 2023 The Cornac Authors. All Rights Reserved. +# Copyright 2026 The Cornac Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +13,30 @@ # limitations under the License. # ============================================================================ +from collections import Counter + import numpy as np from tqdm.auto import trange -from collections import Counter from cornac.models.recommender import NextItemRecommender from ...utils import get_rng +from ..seq_utils import io_iter + +SUPPORTED_LOSSES = ( + "cross-entropy", + "xe_softmax", + "softmax", + "bpr", + "bpr-max", + "top1", + "bce", + "ce", +) class GRU4Rec(NextItemRecommender): - """Session-based Recommendations with Recurrent Neural Networks + """Session-based Recommendations with Recurrent Neural Networks. Parameters ---------- @@ -31,35 +44,40 @@ class GRU4Rec(NextItemRecommender): The name of the recommender model. layers: list of int, optional, default: [100] - The number of hidden units in each layer + The number of hidden units in each layer. loss: str, optional, default: 'cross-entropy' - Select the loss function. + Loss function. Supported: 'cross-entropy', 'bpr', 'bpr-max', 'top1', + 'bce', 'ce'. batch_size: int, optional, default: 512 - Batch size + Batch size. dropout_p_embed: float, optional, default: 0.0 - Dropout ratio for embedding layers + Dropout ratio for embedding layer. dropout_p_hidden: float, optional, default: 0.0 - Dropout ratio for hidden layers + Dropout ratio for hidden layers. learning_rate: float, optional, default: 0.05 - Learning rate for the optimizer + Learning rate for the optimizer. momentum: float, optional, default: 0.0 - Momentum for adaptive learning rate + Momentum for the adaptive learning rate optimizer. sample_alpha: float, optional, default: 0.5 - Tradeoff factor controls the contribution of negative sample towards final loss + Tradeoff factor controls the contribution of negative samples + towards the final loss (popularity-based sampling exponent). n_sample: int, optional, default: 2048 - Number of negative samples + Number of additional shared negative samples per mini-batch. embedding: int, optional, default: 0 + Size of the separate input embedding. ``0`` means no separate + embedding; use ``"layersize"`` to set it to ``layers[0]``. constrained_embedding: bool, optional, default: True + Whether input and output item embeddings are tied. n_epochs: int, optional, default: 10 @@ -67,18 +85,19 @@ class GRU4Rec(NextItemRecommender): Regularization coefficient for 'bpr-max' loss. elu_param: float, optional, default: 0.5 - Elu param for 'bpr-max' loss + ELU parameter for 'bpr-max' loss. - logq: float, optional, default: 0, - LogQ correction to offset the sampling bias affecting 'cross-entropy' loss. + logq: float, optional, default: 0.0 + LogQ correction strength to offset sampling bias for the + cross-entropy loss. device: str, optional, default: 'cpu' Set to 'cuda' for GPU support. - trainable: boolean, optional, default: True - When False, the model will not be re-trained, and input of pre-trained parameters are required. + trainable: bool, optional, default: True + When False, the model will not be re-trained. - verbose: boolean, optional, default: True + verbose: bool, optional, default: False When True, running logs are displayed. seed: int, optional, default: None @@ -89,7 +108,6 @@ class GRU4Rec(NextItemRecommender): Hidasi, B., Karatzoglou, A., Baltrunas, L., & Tikk, D. (2015). Session-based recommendations with recurrent neural networks. arXiv preprint arXiv:1511.06939. - """ def __init__( @@ -116,6 +134,8 @@ def __init__( seed=None, ): super().__init__(name, trainable=trainable, verbose=verbose) + if loss not in SUPPORTED_LOSSES: + raise ValueError(f"loss='{loss}' not supported; choose from {SUPPORTED_LOSSES}") self.layers = layers self.loss = loss self.batch_size = batch_size @@ -137,20 +157,29 @@ def __init__( def fit(self, train_set, val_set=None): super().fit(train_set, val_set) + if not self.trainable: + return self + import torch - from .gru4rec import GRU4RecModel, IndexedAdagradM, io_iter + from .gru4rec import GRU4RecModel + from ..seq_utils.losses import get_loss_function + from ..seq_utils.optim import IndexedAdagradM item_freq = Counter(self.train_set.uir_tuple[1]) - P0 = torch.tensor( - [item_freq[iid] for (_, iid) in self.train_set.iid_map.items()], - dtype=torch.float32, - device=self.device, - ) if self.logq > 0 else None + self.P0 = ( + torch.tensor( + [item_freq[iid] for (_, iid) in self.train_set.iid_map.items()], + dtype=torch.float32, + device=self.device, + ) + if self.logq > 0 + else None + ) self.model = GRU4RecModel( n_items=self.total_items, - P0=P0, + P0=self.P0, layers=self.layers, dropout_p_embed=self.dropout_p_embed, dropout_p_hidden=self.dropout_p_hidden, @@ -162,26 +191,33 @@ def fit(self, train_set, val_set=None): elu_param=self.elu_param, loss=self.loss, ).to(self.device) - self.model._reset_weights_to_compatibility_mode() - opt = IndexedAdagradM( - self.model.parameters(), self.learning_rate, self.momentum + loss_fn = get_loss_function(self.loss) + loss_kwargs = dict( + P0=self.P0, + logq=self.logq, + sample_alpha=self.sample_alpha, + batch_size=None, # filled per-batch + bpreg=self.bpreg, + elu_param=self.elu_param, + n_sample=self.n_sample, ) + opt = IndexedAdagradM(self.model.parameters(), self.learning_rate, self.momentum) + progress_bar = trange(1, self.n_epochs + 1, disable=not self.verbose) for _ in progress_bar: - H = [] - for i in range(len(self.layers)): - H.append( - torch.zeros( - (self.batch_size, self.layers[i]), - dtype=torch.float32, - requires_grad=False, - device=self.device, - ) + H = [ + torch.zeros( + (self.batch_size, self.layers[i]), + dtype=torch.float32, + requires_grad=False, + device=self.device, ) - total_loss = 0 + for i in range(len(self.layers)) + ] + total_loss = 0.0 cnt = 0 for inc, (in_iids, out_iids, start_mask, valid_id) in enumerate( io_iter( @@ -198,24 +234,30 @@ def fit(self, train_set, val_set=None): H[i][np.nonzero(start_mask)[0], :] = 0 H[i].detach_() H[i] = H[i][valid_id] - in_iids = torch.tensor(in_iids, requires_grad=False, device=self.device) - out_iids = torch.tensor( - out_iids, requires_grad=False, device=self.device - ) + in_iids_t = torch.tensor(in_iids, dtype=torch.long, requires_grad=False, device=self.device) + out_iids_t = torch.tensor(out_iids, dtype=torch.long, requires_grad=False, device=self.device) self.model.zero_grad() - R = self.model.forward(in_iids, H, out_iids, training=True) - L = self.model.loss_function(R, out_iids, len(in_iids)) / len(in_iids) + R = self.model.forward(in_iids_t, H, out_iids_t, training=True) + loss_kwargs["batch_size"] = len(in_iids) + loss_kwargs["out_iids"] = out_iids_t + L = loss_fn(R, **loss_kwargs) L.backward() opt.step() total_loss += L.cpu().detach().numpy() * len(in_iids) cnt += len(in_iids) - if inc % 10 == 0: + if inc % 10 == 0 and cnt > 0: progress_bar.set_postfix(loss=(total_loss / cnt)) return self def score(self, user_idx, history_items, **kwargs): - from .gru4rec import score - - if len(history_items) > 0: - return score(self.model, self.layers, self.device, history_items) - return np.ones(self.total_items, dtype="float") + from .gru4rec import score as _score + + if len(history_items) == 0: + return np.ones(self.total_items, dtype="float") + scores = _score(self.model, self.layers, self.device, history_items) + if scores is None: + return np.ones(self.total_items, dtype="float") + # The output embedding has a +1 padding row; trim if present. + if scores.shape[-1] == self.total_items + 1: + scores = scores[: self.total_items] + return scores diff --git a/cornac/models/seq_utils/__init__.py b/cornac/models/seq_utils/__init__.py new file mode 100644 index 000000000..6d8023684 --- /dev/null +++ b/cornac/models/seq_utils/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Shared utilities for session-based sequential recommendation models. + +* **Model forward output**: every sequential model returns an + ``item_scores`` matrix of shape ``(B, B + N)`` where the diagonal holds + the positive target scores and the remaining columns are in-batch + + sampled negatives. +* **Loss functions** (see :mod:`.losses`) all consume that ``(B, B+N)`` + matrix and return a scalar. +* **Iterators** (see :mod:`.iterators`) yield uniform tuples: + + ========================= ================================================= + Iterator Yielded tuple + ========================= ================================================= + ``io_iter`` ``(in_iids, out_iids, start_mask, valid_id)`` + -- per-item RNN, session-based. + ``session_seq_iter`` ``(in_uids, hist_iids, out_iids)`` -- sequence + models, session-based. + ========================= ================================================= +""" + +from .iterators import io_iter, session_seq_iter + +__all__ = [ + "io_iter", + "session_seq_iter", +] diff --git a/cornac/models/seq_utils/iterators.py b/cornac/models/seq_utils/iterators.py new file mode 100644 index 000000000..5e23a4c83 --- /dev/null +++ b/cornac/models/seq_utils/iterators.py @@ -0,0 +1,186 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Mini-batch iterators for session-based training. +""" + +from collections import Counter + +import numpy as np + +from ...utils.common import get_rng + + +def _build_neg_sampler(uir_tuple, sample_alpha): + """Precompute popularity-based sampling distribution over items.""" + item_count = Counter(uir_tuple[1]) + item_indices = np.array([iid for iid, _ in item_count.most_common()], dtype="int") + item_dist = np.array([cnt for _, cnt in item_count.most_common()], dtype="float") ** sample_alpha + item_dist = item_dist / item_dist.sum() + return item_indices, item_dist + + +def io_iter(s_iter, uir_tuple, n_sample=0, sample_alpha=0, rng=None, batch_size=1, shuffle=False): + """Session-based per-item iterator (parallel sessions). + + Yields per training step a 4-tuple + ``(in_iids, out_iids, start_mask, valid_id)`` where: + + - ``in_iids``: current input item id in each slot. Shape ``(B',)``. + - ``out_iids``: target item ids (followed by ``n_sample`` shared + negatives). Shape ``(B' + N,)``. + - ``start_mask``: 1 in slots that just started a new session (used to + reset RNN hidden state). Shape ``(B',)``. + - ``valid_id``: indices of the slots that remain valid in the current + batch (used to trim the hidden state when sessions end at different + times). Shape ``(B',)``. + + ``B'`` equals ``batch_size`` for the main loop and shrinks during the + drain phase as sessions are exhausted. + """ + rng = rng if rng is not None else get_rng(None) + start_mask = np.zeros(batch_size, dtype="int") + end_mask = np.ones(batch_size, dtype="int") + input_iids = None + output_iids = None + l_pool = [] # pending sessions (list of mapped-id lists) + c_pool = [None for _ in range(batch_size)] + sizes = np.zeros(batch_size, dtype="int") + if n_sample > 0: + item_indices, item_dist = _build_neg_sampler(uir_tuple, sample_alpha) + + for _, batch_mapped_ids in s_iter(batch_size, shuffle): + l_pool += batch_mapped_ids + while len(l_pool) > 0: + if end_mask.sum() == 0: + input_iids = uir_tuple[1][[mapped_ids[-sizes[idx]] for idx, mapped_ids in enumerate(c_pool)]] + output_iids = uir_tuple[1][[mapped_ids[-sizes[idx] + 1] for idx, mapped_ids in enumerate(c_pool)]] + sizes -= 1 + for idx, size in enumerate(sizes): + if size == 1: + end_mask[idx] = 1 + if n_sample > 0: + negatives = rng.choice(item_indices, size=n_sample, replace=True, p=item_dist) + output_iids = np.concatenate([output_iids, negatives]) + yield ( + input_iids, + output_iids, + start_mask.copy(), + np.arange(batch_size, dtype="int"), + ) + start_mask.fill(0) + while end_mask.sum() > 0 and len(l_pool) > 0: + next_seq = l_pool.pop() + if len(next_seq) > 1: + idx = np.nonzero(end_mask)[0][0] + end_mask[idx] = 0 + start_mask[idx] = 1 + c_pool[idx] = next_seq + sizes[idx] = len(c_pool[idx]) + + valid_id = np.ones(batch_size, dtype="int") + while True: + for idx, size in enumerate(sizes): + if size == 1: + end_mask[idx] = 1 + valid_id[idx] = 0 + keep = [idx for idx in range(len(c_pool)) if sizes[idx] > 1] + if not keep: + break + input_iids = uir_tuple[1][[c_pool[idx][-sizes[idx]] for idx in keep]] + output_iids = uir_tuple[1][[c_pool[idx][-sizes[idx] + 1] for idx in keep]] + sizes -= 1 + for idx, size in enumerate(sizes): + if size == 1: + end_mask[idx] = 1 + keep_mask = np.nonzero(valid_id)[0] + start_mask = start_mask[keep_mask] + end_mask = end_mask[keep_mask] + sizes = sizes[keep_mask] + c_pool = [_ for _, valid in zip(c_pool, valid_id) if valid > 0] + if n_sample > 0: + negatives = rng.choice(item_indices, size=n_sample, replace=True, p=item_dist) + output_iids = np.concatenate([output_iids, negatives]) + yield input_iids, output_iids, start_mask.copy(), np.nonzero(valid_id)[0] + valid_id = np.ones(len(input_iids), dtype="int") + if end_mask.sum() == len(input_iids): + break + start_mask.fill(0) + + +def session_seq_iter( + train_set, + pad_index, + batch_size=64, + max_len=20, + n_sample=2048, + sample_alpha=0.5, + rng=None, + shuffle=True, +): + """Session-based sequence iterator for transformer/seq models. + + Iterates over sessions (each session = one training sequence). For a + session ``[i0, i1, ..., iT]`` it yields ``num_sessions * (T)`` training + triples ``(uid, hist[max_len], target)`` where ``hist`` is the + left-padded prefix and ``target`` is the next item. + """ + rng = rng if rng is not None else get_rng(None) + uir_tuple = train_set.uir_tuple + sessions = train_set.sessions + sids = list(sessions.keys()) + if shuffle: + rng.shuffle(sids) + if n_sample > 0: + item_indices, item_dist = _build_neg_sampler(uir_tuple, sample_alpha) + + buffer_uids, buffer_hist, buffer_target = [], [], [] + for sid in sids: + mapped_ids = sessions[sid] + items = list(uir_tuple[1][mapped_ids]) + if len(items) < 2: + continue + uid = int(uir_tuple[0][mapped_ids[0]]) + for t in range(1, len(items)): + hist = items[:t][-max_len:] + hist = [pad_index] * (max_len - len(hist)) + list(hist) + buffer_uids.append(uid) + buffer_hist.append(hist) + buffer_target.append(items[t]) + if len(buffer_uids) == batch_size: + target = np.array(buffer_target, dtype="int") + if n_sample > 0: + negatives = rng.choice(item_indices, size=n_sample, replace=True, p=item_dist) + out_iids = np.concatenate([target, negatives]) + else: + out_iids = target + yield ( + np.array(buffer_uids, dtype="int"), + np.array(buffer_hist, dtype="int"), + out_iids, + ) + buffer_uids, buffer_hist, buffer_target = [], [], [] + if len(buffer_uids) > 1: + target = np.array(buffer_target, dtype="int") + if n_sample > 0: + negatives = rng.choice(item_indices, size=n_sample, replace=True, p=item_dist) + out_iids = np.concatenate([target, negatives]) + else: + out_iids = target + yield ( + np.array(buffer_uids, dtype="int"), + np.array(buffer_hist, dtype="int"), + out_iids, + ) diff --git a/cornac/models/seq_utils/losses.py b/cornac/models/seq_utils/losses.py new file mode 100644 index 000000000..e3483091f --- /dev/null +++ b/cornac/models/seq_utils/losses.py @@ -0,0 +1,122 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Loss functions for sequential recommendation models. +""" + +import torch +import torch.nn.functional as F + + +def softmax_neg(X): + """Softmax over negatives, masking out the diagonal (positives).""" + hm = 1.0 - torch.eye(*X.shape, out=torch.empty_like(X)) + X = X * hm + e_x = torch.exp(X - X.max(dim=1, keepdim=True)[0]) * hm + if e_x.size(0) == 1: + return e_x + return e_x / (e_x.sum(dim=1, keepdim=True) + 1e-24) + + +def bpr_loss(item_scores, **kwargs): + """BPR pairwise logsigmoid loss against in-batch negatives. + + Parameters + ---------- + item_scores: torch.Tensor of shape (B, B+N) + Score matrix with positives on the diagonal. + """ + pos = torch.diag(item_scores) + pos = pos.reshape(pos.shape[0], -1) + logits = F.logsigmoid(pos - item_scores) + mask = 1.0 - torch.eye(*logits.shape, out=torch.empty_like(logits)) + loss = -torch.sum(logits * mask) + return loss / logits.size(0) / max(logits.size(1) - 1, 1) + + +def top1_loss(item_scores, n_sample=0, **kwargs): + """TOP1 ranking loss from Hidasi et al. (2015).""" + target = torch.diag(item_scores) + target = target.reshape(target.shape[0], -1) + return torch.sum( + torch.mean( + torch.sigmoid(item_scores - target) + torch.sigmoid(item_scores**2), + dim=1, + ) + - torch.sigmoid(target**2) / (item_scores.size(0) + n_sample) + ) / item_scores.size(0) + + +def xe_softmax_loss(item_scores, out_iids=None, P0=None, logq=0.0, sample_alpha=0.5, batch_size=None, **kwargs): + """Cross-entropy with softmax over in-batch + sampled negatives. + + Supports an optional logQ correction (Hidasi & Karatzoglou, 2018) when + ``P0`` (item popularity prior) and ``logq > 0`` are provided. + """ + if logq > 0 and P0 is not None and out_iids is not None and batch_size is not None: + item_scores = item_scores - logq * torch.log( + torch.cat([P0[out_iids[:batch_size]], P0[out_iids[batch_size:]] ** sample_alpha]) + ) + X = torch.exp(item_scores - item_scores.max(dim=1, keepdim=True)[0]) + X = X / (X.sum(dim=1, keepdim=True) + 1e-24) + return -torch.sum(torch.log(torch.diag(X) + 1e-24)) / item_scores.size(0) + + +def bpr_max_loss(item_scores, bpreg=1.0, elu_param=0.5, **kwargs): + """BPR-max with softmax-weighted negatives and L2 regularisation on scores.""" + if elu_param > 0: + item_scores = F.elu(item_scores, elu_param) + softmax_scores = softmax_neg(item_scores) + target = torch.diag(item_scores) + target = target.reshape(target.shape[0], -1) + return torch.sum( + -torch.log(torch.sum(torch.sigmoid(target - item_scores) * softmax_scores, dim=1) + 1e-24) + + bpreg * torch.sum((item_scores**2) * softmax_scores, dim=1) + ) / item_scores.size(0) + + +def bce_loss(item_scores, **kwargs): + """Binary cross-entropy treating the diagonal as positive and all other + columns (in-batch negatives + sampled negatives) as negatives. + """ + B, N = item_scores.shape + targets = torch.zeros_like(item_scores) + targets[torch.arange(B), torch.arange(B)] = 1.0 + return F.binary_cross_entropy_with_logits(item_scores, targets) + + +def ce_loss(item_scores, **kwargs): + """Standard cross-entropy where the target class is the in-batch diagonal.""" + targets = torch.arange(item_scores.size(0), device=item_scores.device, dtype=torch.long) + return F.cross_entropy(item_scores, targets) + + +LOSS_FUNCTIONS = { + "bpr": bpr_loss, + "top1": top1_loss, + "cross-entropy": xe_softmax_loss, + "xe_softmax": xe_softmax_loss, + "softmax": xe_softmax_loss, + "bpr-max": bpr_max_loss, + "bce": bce_loss, + "ce": ce_loss, +} + + +def get_loss_function(name): + """Look up a loss function by name. Raises ``ValueError`` if unknown.""" + if name not in LOSS_FUNCTIONS: + raise ValueError(f"Unknown loss '{name}'. Supported: {sorted(set(LOSS_FUNCTIONS))}") + return LOSS_FUNCTIONS[name] diff --git a/cornac/models/seq_utils/optim.py b/cornac/models/seq_utils/optim.py new file mode 100644 index 000000000..588a1128e --- /dev/null +++ b/cornac/models/seq_utils/optim.py @@ -0,0 +1,90 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Custom optimizer(s) for sequential recommendation models. +""" + +import torch +from torch.optim import Optimizer + + +class IndexedAdagradM(Optimizer): + """Sparse-aware Adagrad with momentum, used by GRU4Rec and FPMC.""" + + def __init__(self, params, lr=0.05, momentum=0.0, eps=1e-6): + if lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if eps <= 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + + defaults = dict(lr=lr, momentum=momentum, eps=eps) + super(IndexedAdagradM, self).__init__(params, defaults) + + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["acc"] = torch.full_like(p, 0, memory_format=torch.preserve_format) + if momentum > 0: + state["mom"] = torch.full_like(p, 0, memory_format=torch.preserve_format) + + def share_memory(self): + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["acc"].share_memory_() + if group["momentum"] > 0: + state["mom"].share_memory_() + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + state = self.state[p] + clr = group["lr"] + momentum = group["momentum"] + if grad.is_sparse: + grad = grad.coalesce() + grad_indices = grad._indices()[0] + grad_values = grad._values() + accs = state["acc"][grad_indices] + grad_values.pow(2) + state["acc"].index_copy_(0, grad_indices, accs) + accs.add_(group["eps"]).sqrt_().mul_(-1 / clr) + if momentum > 0: + moma = state["mom"][grad_indices] + moma.mul_(momentum).add_(grad_values / accs) + state["mom"].index_copy_(0, grad_indices, moma) + p.index_add_(0, grad_indices, moma) + else: + p.index_add_(0, grad_indices, grad_values / accs) + else: + state["acc"].add_(grad.pow(2)) + accs = state["acc"].add(group["eps"]) + accs.sqrt_() + if momentum > 0: + mom = state["mom"] + mom.mul_(momentum).addcdiv_(grad, accs, value=-clr) + p.add_(mom) + else: + p.addcdiv_(grad, accs, value=-clr) + return loss