From cf1786718b4ff329426aa6bac8ccaad620b39751 Mon Sep 17 00:00:00 2001 From: hieuddo Date: Tue, 2 Jun 2026 15:48:24 +0800 Subject: [PATCH 1/4] Transformer-based NextItemRecommender models --- README.md | 3 + cornac/models/__init__.py | 3 + cornac/models/bert4rec/__init__.py | 16 ++ cornac/models/bert4rec/bert4rec.py | 117 +++++++++ cornac/models/bert4rec/recom_bert4rec.py | 227 ++++++++++++++++++ cornac/models/bert4rec/requirements.txt | 2 + cornac/models/gpt2rec/__init__.py | 16 ++ cornac/models/gpt2rec/gpt2rec.py | 112 +++++++++ cornac/models/gpt2rec/recom_gpt2rec.py | 227 ++++++++++++++++++ cornac/models/gpt2rec/requirements.txt | 2 + cornac/models/sasrec/__init__.py | 16 ++ cornac/models/sasrec/recom_sasrec.py | 290 +++++++++++++++++++++++ cornac/models/sasrec/requirements.txt | 1 + cornac/models/sasrec/sasrec.py | 186 +++++++++++++++ examples/transformer_rec_diginetica.py | 108 +++++++++ 15 files changed, 1326 insertions(+) create mode 100644 cornac/models/bert4rec/__init__.py create mode 100644 cornac/models/bert4rec/bert4rec.py create mode 100644 cornac/models/bert4rec/recom_bert4rec.py create mode 100644 cornac/models/bert4rec/requirements.txt create mode 100644 cornac/models/gpt2rec/__init__.py create mode 100644 cornac/models/gpt2rec/gpt2rec.py create mode 100644 cornac/models/gpt2rec/recom_gpt2rec.py create mode 100644 cornac/models/gpt2rec/requirements.txt create mode 100644 cornac/models/sasrec/__init__.py create mode 100644 cornac/models/sasrec/recom_sasrec.py create mode 100644 cornac/models/sasrec/requirements.txt create mode 100644 cornac/models/sasrec/sasrec.py create mode 100644 examples/transformer_rec_diginetica.py diff --git a/README.md b/README.md index f543ff43a..52055439b 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,7 @@ The table below lists the recommendation models/algorithms featured in Cornac. E | 2023 | [Scalable Approximate NonSymmetric Autoencoder (SANSA)](cornac/models/sansa), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.sansa.recom_sansa), [paper](https://dl.acm.org/doi/10.1145/3604915.3608827) | Collaborative Filtering | [requirements](cornac/models/sansa/requirements.txt), CPU | [quick-start](examples/sansa_movielens.py), [150k-items](examples/sansa_tradesy.py) | 2022 | [Disentangled Multimodal Representation Learning for Recommendation (DMRL)](cornac/models/dmrl), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.dmrl.recom_dmrl), [paper](https://arxiv.org/pdf/2203.05406.pdf) | Content-Based / Text & Image | [requirements](cornac/models/dmrl/requirements.txt), CPU / GPU | [quick-start](examples/dmrl_example.py) | 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.bivaecf.recom_bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | Collaborative Filtering / Content-Based | [requirements](cornac/models/bivaecf/requirements.txt), CPU / GPU | [quick-start](https://github.com/PreferredAI/bi-vae), [deep-dive](https://github.com/recommenders-team/recommenders/blob/main/examples/02_model_collaborative_filtering/cornac_bivae_deep_dive.ipynb) +| | [GPT-2 for Sequential Recommendation (GPT2Rec)](cornac/models/gpt2rec), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.gpt2rec.recom_gpt2rec), [paper](https://dl.acm.org/doi/10.1145/3460231.3474255) | Next-Item | [requirements](cornac/models/gpt2rec/requirements.txt), CPU / GPU | [quick-start](examples/transformer_rec_diginetica.py) | | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.causalrec.recom_causalrec), [paper](https://arxiv.org/abs/2107.02390) | Content-Based / Image | [requirements](cornac/models/causalrec/requirements.txt), CPU / GPU | [quick-start](examples/causalrec_clothing.py) | | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.comparer.recom_comparer_sub), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | Explainable | CPU | [quick-start](https://github.com/PreferredAI/ComparER) | 2020 | [Adversarial Multimedia Recommendation (AMR)](cornac/models/amr), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.amr.recom_amr), [paper](https://ieeexplore.ieee.org/document/8618394) | Content-Based / Image | [requirements](cornac/models/amr/requirements.txt), CPU / GPU | [quick-start](examples/amr_clothing.py) @@ -166,10 +167,12 @@ The table below lists the recommendation models/algorithms featured in Cornac. E | | [Temporal-Item-Frequency-based User-KNN (TIFUKNN)](cornac/models/tifuknn), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.tifuknn.recom_tifuknn), [paper](https://arxiv.org/pdf/2006.00556.pdf) | Next-Basket | CPU | [quick-start](examples/tifuknn_tafeng.py) | | [Variational Autoencoder for Top-N Recommendations (RecVAE)](cornac/models/recvae), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.recvae.recom_recvae), [paper](https://doi.org/10.1145/3336191.3371831) | Collaborative Filtering | [requirements](cornac/models/recvae/requirements.txt), CPU / GPU | [quick-start](examples/recvae_example.py) | 2019 | [Correlation-Sensitive Next-Basket Recommendation (Beacon)](cornac/models/beacon), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#correlation-sensitive-next-basket-recommendation-beacon), [paper](https://www.ijcai.org/proceedings/2019/0389.pdf) | Next-Basket | [requirements](cornac/models/beacon/requirements.txt), CPU / GPU | [quick-start](examples/beacon_tafeng.py) +| | [BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer (BERT4Rec)](cornac/models/bert4rec), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.bert4rec.recom_bert4rec), [paper](https://arxiv.org/pdf/1904.06690.pdf) | Next-Item | [requirements](cornac/models/bert4rec/requirements.txt), CPU / GPU | [quick-start](examples/transformer_rec_diginetica.py) | | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.ease.recom_ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | Collaborative Filtering | CPU | [quick-start](examples/ease_movielens.py) | | [Neural Graph Collaborative Filtering (NGCF)](cornac/models/ngcf), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.ngcf.recom_ngcf), [paper](https://arxiv.org/pdf/1905.08108.pdf) | Collaborative Filtering | [requirements](cornac/models/ngcf/requirements.txt), CPU / GPU | [quick-start](examples/ngcf_example.py) | | [Sampler Design for Bayesian Personalized Ranking by Leveraging View Data (VEBPR)](cornac/models/bpr), [paper](https://arxiv.org/pdf/1809.08162) | Collaborative Filtering | CPU | [quick-start](examples/vebpr_example.py) | 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.c2pf.recom_c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | Content-Based / Graph | CPU | [quick-start](examples/c2pf_example.py) +| | [Self-Attentive Sequential Recommendation (SASRec)](cornac/models/sasrec), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.sasrec.recom_sasrec), [paper](https://arxiv.org/pdf/1808.09781.pdf) | Next-Item | [requirements](cornac/models/sasrec/requirements.txt), CPU / GPU | [quick-start](examples/transformer_rec_diginetica.py) | | [Graph Convolutional Matrix Completion (GCMC)](cornac/models/gcmc), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.gcmc.recom_gcmc), [paper](https://www.kdd.org/kdd2018/files/deep-learning-day/DLDay18_paper_32.pdf) | Collaborative Filtering | [requirements](cornac/models/gcmc/requirements.txt), CPU / GPU | [quick-start](examples/gcmc_example.py) | | [Multi-Task Explainable Recommendation (MTER)](cornac/models/mter), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.mter.recom_mter), [paper](https://arxiv.org/pdf/1806.03568.pdf) | Explainable | CPU | [quick-start](examples/mter_example.py), [deep-dive](https://github.com/PreferredAI/tutorials/blob/master/recommender-systems/07_explanations.ipynb) | | [Neural Attention Rating Regression with Review-level Explanations (NARRE)](cornac/models/narre), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.narre.recom_narre), [paper](http://www.thuir.cn/group/~YQLiu/publications/WWW2018_CC.pdf) | Explainable / Content-Based | [requirements](cornac/models/narre/requirements.txt), CPU / GPU | [quick-start](examples/narre_example.py) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 278fed13c..52ad564ed 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -24,6 +24,7 @@ from .ann import ScaNNANN from .baseline_only import BaselineOnly from .beacon import Beacon +from .bert4rec import BERT4Rec from .bivaecf import BiVAECF from .bpr import BPR from .bpr import WBPR @@ -49,6 +50,7 @@ from .gcmc import GCMC from .global_avg import GlobalAvg from .gp_top import GPTop +from .gpt2rec import GPT2Rec from .gru4rec import GRU4Rec from .hft import HFT from .hpf import HPF @@ -75,6 +77,7 @@ from .pmf import PMF from .recvae import RecVAE from .sansa import SANSA +from .sasrec import SASRec from .sbpr import SBPR from .skm import SKMeans from .sorec import SoRec diff --git a/cornac/models/bert4rec/__init__.py b/cornac/models/bert4rec/__init__.py new file mode 100644 index 000000000..af46ce60a --- /dev/null +++ b/cornac/models/bert4rec/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from .recom_bert4rec import BERT4Rec diff --git a/cornac/models/bert4rec/bert4rec.py b/cornac/models/bert4rec/bert4rec.py new file mode 100644 index 000000000..5fc1d77c2 --- /dev/null +++ b/cornac/models/bert4rec/bert4rec.py @@ -0,0 +1,117 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import torch +import torch.nn as nn + + +class BERT4RecModel(nn.Module): + """BERT4Rec: bidirectional transformer encoder for sequence rec. + + Uses HuggingFace's :class:`~transformers.BertModel` as the backbone. The + sequence-final hidden state is used to score candidate items via the dot + product with their output embeddings (separate "head" linear or tied to + ``item_emb``). + + Returns ``(B, B+N)`` score matrices when called as + ``forward(_, hist_iids, out_iids, return_hidden=False)``. + """ + + def __init__( + self, + item_num, + embedding_dim=100, + maxlen=20, + n_layers=2, + n_heads=1, + dropout=0.1, + pad_idx=-1, + tie_weights=False, + init_std=0.02, + device="cpu", + ): + super().__init__() + from transformers.models.bert import BertConfig, BertModel + + self.item_num = item_num + self.pad_idx = pad_idx if pad_idx >= 0 else item_num + self.maxlen = maxlen + self.dev = device + self.init_std = init_std + self.tie_weights = tie_weights + + config = BertConfig( + vocab_size=item_num + 1, + hidden_size=embedding_dim, + num_hidden_layers=n_layers, + num_attention_heads=n_heads, + intermediate_size=embedding_dim * 4, + hidden_act="gelu", + hidden_dropout_prob=dropout, + attention_probs_dropout_prob=dropout, + max_position_embeddings=maxlen + 1, + initializer_range=init_std, + pad_token_id=self.pad_idx, + layer_norm_eps=1e-12, + use_cache=False, + ) + + self.item_emb = nn.Embedding( + num_embeddings=item_num + 1, + embedding_dim=embedding_dim, + padding_idx=self.pad_idx, + ) + self.transformer_model = BertModel(config) + self.item_biases = nn.Embedding(item_num + 1, 1, padding_idx=self.pad_idx) + self._init_weights() + self.to(device) + + def _init_weights(self): + self.item_emb.weight.data.normal_(mean=0.0, std=self.init_std) + self.item_emb.weight.data[self.pad_idx].zero_() + self.item_biases.weight.data.zero_() + + def _encode(self, hist_iids): + attention_mask = (hist_iids != self.pad_idx).long() + embeds = self.item_emb(hist_iids) + out = self.transformer_model( + inputs_embeds=embeds, attention_mask=attention_mask + ) + return out.last_hidden_state[:, -1, :] + + def forward(self, user_ids, hist_iids, out_iids, return_hidden=False): + hidden = self._encode(hist_iids) + item_e = self.item_emb(out_iids) + bias = self.item_biases(out_iids) + if return_hidden: + return hidden, item_e, bias + scores = torch.mm(hidden, item_e.T) + bias.T + return scores + + @torch.no_grad() + def predict(self, user_ids, log_seqs, item_indices=None): + if item_indices is None: + item_indices = torch.arange(self.item_num, device=self.dev) + else: + item_indices = torch.as_tensor( + item_indices, dtype=torch.long, device=self.dev + ) + if not isinstance(log_seqs, torch.Tensor): + log_seqs = torch.as_tensor(log_seqs, dtype=torch.long, device=self.dev) + hidden = self._encode(log_seqs) + item_e = self.item_emb(item_indices) + bias = self.item_biases(item_indices) + scores = torch.mm(hidden, item_e.T) + bias.T + return scores.squeeze().detach().cpu().numpy() diff --git a/cornac/models/bert4rec/recom_bert4rec.py b/cornac/models/bert4rec/recom_bert4rec.py new file mode 100644 index 000000000..ac716280e --- /dev/null +++ b/cornac/models/bert4rec/recom_bert4rec.py @@ -0,0 +1,227 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from tqdm.auto import trange + +from cornac.models.recommender import NextItemRecommender + +from ...utils import get_rng +from ..seq_utils import session_seq_iter, val_score + +SUPPORTED_LOSSES = ( + "bce", + "ce", + "bpr", + "bpr-max", + "softmax", + "cross-entropy", + "xe_softmax", + "top1", +) + + +class BERT4Rec(NextItemRecommender): + """BERT4Rec: a bidirectional transformer encoder for sequential rec. + + Wraps HuggingFace's :class:`~transformers.BertModel` as the sequence + encoder; the last-position hidden state scores candidate items by dot + product, sharing the ``(B, B+N)`` loss contract of + :mod:`cornac.models.seq_utils`. Parameters mirror + :class:`cornac.models.SASRec` (minus ``use_pos_emb`` — the backbone + provides its own positional embeddings); see the SASRec docstring for + details about ``loss``, ``model_selection``, and the rest. + + Note + ---- + This uses the next-item-at-last-position objective shared by the + transformer family in Cornac, *not* the canonical masked-language-model + (MLM) objective of the original paper. + + References + ---------- + Sun, F., Liu, J., Wu, J., Pei, C., Lin, X., Ou, W., & Jiang, P. (2019). + BERT4Rec: Sequential recommendation with bidirectional encoder + representations from transformer. CIKM. + """ + + def __init__( + self, + name="BERT4Rec", + embedding_dim=100, + loss="ce", + batch_size=512, + learning_rate=0.001, + n_sample=2048, + sample_alpha=0.5, + n_epochs=10, + max_len=50, + num_blocks=2, + num_heads=1, + dropout=0.2, + l2_reg=0.0, + bpreg=1.0, + elu_param=0.5, + device="cpu", + trainable=True, + verbose=False, + seed=None, + model_selection="last", + val_eval_every=5, + val_k=20, + val_metric="recall", + ): + super().__init__(name, trainable=trainable, verbose=verbose) + if loss not in SUPPORTED_LOSSES: + raise ValueError( + f"loss='{loss}' not supported; choose from {SUPPORTED_LOSSES}" + ) + if model_selection not in ("last", "best"): + raise ValueError( + f"model_selection='{model_selection}' not supported; choose 'last' or 'best'" + ) + self.embedding_dim = embedding_dim + self.loss = loss + self.batch_size = batch_size + self.learning_rate = learning_rate + self.n_sample = n_sample + self.sample_alpha = sample_alpha + self.n_epochs = n_epochs + self.max_len = max_len + self.num_blocks = num_blocks + self.num_heads = num_heads + self.dropout = dropout + self.l2_reg = l2_reg + self.bpreg = bpreg + self.elu_param = elu_param + self.device = device + self.seed = seed + self.rng = get_rng(seed) + self.model_selection = model_selection + self.val_eval_every = val_eval_every + self.val_k = val_k + self.val_metric = val_metric + + def fit(self, train_set, val_set=None): + super().fit(train_set, val_set) + if not self.trainable: + return self + + import torch + + from .bert4rec import BERT4RecModel + from ..seq_utils.losses import get_loss_function + + torch.manual_seed(self.seed if self.seed is not None else 0) + + self.pad_idx = self.total_items + self.model = BERT4RecModel( + item_num=self.total_items, + embedding_dim=self.embedding_dim, + maxlen=self.max_len, + n_layers=self.num_blocks, + n_heads=self.num_heads, + dropout=self.dropout, + pad_idx=self.pad_idx, + device=self.device, + ) + + loss_fn = get_loss_function(self.loss) + loss_kwargs = dict( + bpreg=self.bpreg, elu_param=self.elu_param, n_sample=self.n_sample + ) + opt = torch.optim.Adam( + self.model.parameters(), lr=self.learning_rate, betas=(0.9, 0.98) + ) + + best_val = -float("inf") + best_state = None + progress_bar = trange(1, self.n_epochs + 1, disable=not self.verbose) + for epoch_id in progress_bar: + self.model.train() + total_loss = 0.0 + cnt = 0 + for inc, (in_uids, hist_iids, out_iids) in enumerate( + session_seq_iter( + self.train_set, + pad_index=self.pad_idx, + batch_size=self.batch_size, + max_len=self.max_len, + n_sample=self.n_sample, + sample_alpha=self.sample_alpha, + rng=self.rng, + shuffle=True, + ) + ): + if len(hist_iids) < 2: + continue + hist_iids_t = torch.tensor( + hist_iids, dtype=torch.long, device=self.device, requires_grad=False + ) + out_iids_t = torch.tensor( + out_iids, dtype=torch.long, device=self.device, requires_grad=False + ) + + self.model.zero_grad() + item_scores = self.model(None, hist_iids_t, out_iids_t) + L = loss_fn( + item_scores, + out_iids=out_iids_t, + batch_size=len(hist_iids), + **loss_kwargs, + ) + if self.l2_reg > 0: + for p in self.model.parameters(): + L = L + self.l2_reg * torch.norm(p) + + L.backward() + opt.step() + + total_loss += L.cpu().detach().numpy() * len(hist_iids) + cnt += len(hist_iids) + if inc % 10 == 0 and cnt > 0: + progress_bar.set_postfix(loss=(total_loss / cnt)) + + if ( + self.model_selection == "best" + and val_set is not None + and epoch_id % self.val_eval_every == 0 + ): + score = val_score( + self, self.train_set, val_set, metric=self.val_metric, k=self.val_k + ) + if score is not None and score > best_val: + best_val = score + best_state = { + n: p.detach().clone() + for n, p in self.model.state_dict().items() + } + + if self.model_selection == "best" and best_state is not None: + self.model.load_state_dict(best_state) + return self + + def score(self, user_idx, history_items, **kwargs): + import torch + + if len(history_items) == 0: + return np.ones(self.total_items, dtype="float") + log_seq = [self.pad_idx] * (self.max_len - len(history_items)) + list( + history_items + ) + log_seq = log_seq[-self.max_len :] + log_seq_t = torch.tensor([log_seq], dtype=torch.long, device=self.device) + self.model.eval() + return self.model.predict(user_idx, log_seq_t) diff --git a/cornac/models/bert4rec/requirements.txt b/cornac/models/bert4rec/requirements.txt new file mode 100644 index 000000000..a808c3f6f --- /dev/null +++ b/cornac/models/bert4rec/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.12.0 +transformers>=4.30.0 diff --git a/cornac/models/gpt2rec/__init__.py b/cornac/models/gpt2rec/__init__.py new file mode 100644 index 000000000..ed6d4676c --- /dev/null +++ b/cornac/models/gpt2rec/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from .recom_gpt2rec import GPT2Rec diff --git a/cornac/models/gpt2rec/gpt2rec.py b/cornac/models/gpt2rec/gpt2rec.py new file mode 100644 index 000000000..93d4e56e4 --- /dev/null +++ b/cornac/models/gpt2rec/gpt2rec.py @@ -0,0 +1,112 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import torch +import torch.nn as nn + + +class GPT2RecModel(nn.Module): + """GPT2-based causal transformer for next-item recommendation. + + Same input/output contract as :class:`SASRecModel` and + :class:`BERT4RecModel`. Returns a ``(B, B+N)`` score matrix. + """ + + def __init__( + self, + item_num, + embedding_dim=100, + maxlen=20, + n_layers=2, + n_heads=1, + dropout=0.1, + pad_idx=-1, + tie_weights=False, + init_std=0.02, + device="cpu", + ): + super().__init__() + from transformers.models.gpt2 import GPT2Config, GPT2Model + + self.item_num = item_num + self.pad_idx = pad_idx if pad_idx >= 0 else item_num + self.maxlen = maxlen + self.dev = device + self.init_std = init_std + self.tie_weights = tie_weights + + config = GPT2Config( + vocab_size=item_num + 1, + n_positions=maxlen + 1, + n_embd=embedding_dim, + n_layer=n_layers, + n_head=n_heads, + n_inner=embedding_dim * 4, + activation_function="gelu_new", + resid_pdrop=dropout, + embd_pdrop=dropout, + attn_pdrop=dropout, + initializer_range=init_std, + pad_token_id=self.pad_idx, + layer_norm_epsilon=1e-12, + use_cache=False, + ) + + self.item_emb = nn.Embedding( + item_num + 1, embedding_dim, padding_idx=self.pad_idx + ) + self.transformer_model = GPT2Model(config) + self.item_biases = nn.Embedding(item_num + 1, 1, padding_idx=self.pad_idx) + + self._init_weights() + self.to(device) + + def _init_weights(self): + self.item_emb.weight.data.normal_(mean=0.0, std=self.init_std) + self.item_emb.weight.data[self.pad_idx].zero_() + self.item_biases.weight.data.zero_() + + def _encode(self, hist_iids): + attention_mask = (hist_iids != self.pad_idx).long() + embeds = self.item_emb(hist_iids) + out = self.transformer_model( + inputs_embeds=embeds, attention_mask=attention_mask + ) + return out.last_hidden_state[:, -1, :] + + def forward(self, user_ids, hist_iids, out_iids, return_hidden=False): + hidden = self._encode(hist_iids) + item_e = self.item_emb(out_iids) + bias = self.item_biases(out_iids) + if return_hidden: + return hidden, item_e, bias + scores = torch.mm(hidden, item_e.T) + bias.T + return scores + + @torch.no_grad() + def predict(self, user_ids, log_seqs, item_indices=None): + if item_indices is None: + item_indices = torch.arange(self.item_num, device=self.dev) + else: + item_indices = torch.as_tensor( + item_indices, dtype=torch.long, device=self.dev + ) + if not isinstance(log_seqs, torch.Tensor): + log_seqs = torch.as_tensor(log_seqs, dtype=torch.long, device=self.dev) + hidden = self._encode(log_seqs) + item_e = self.item_emb(item_indices) + bias = self.item_biases(item_indices) + scores = torch.mm(hidden, item_e.T) + bias.T + return scores.squeeze().detach().cpu().numpy() diff --git a/cornac/models/gpt2rec/recom_gpt2rec.py b/cornac/models/gpt2rec/recom_gpt2rec.py new file mode 100644 index 000000000..df8412c3a --- /dev/null +++ b/cornac/models/gpt2rec/recom_gpt2rec.py @@ -0,0 +1,227 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from tqdm.auto import trange + +from cornac.models.recommender import NextItemRecommender + +from ...utils import get_rng +from ..seq_utils import session_seq_iter, val_score + +SUPPORTED_LOSSES = ( + "bce", + "ce", + "bpr", + "bpr-max", + "softmax", + "cross-entropy", + "xe_softmax", + "top1", +) + + +class GPT2Rec(NextItemRecommender): + """GPT2Rec: a causal (GPT-2) transformer for sequential recommendation. + + Wraps HuggingFace's :class:`~transformers.GPT2Model` as the sequence + encoder; the last-position hidden state scores candidate items by dot + product, sharing the ``(B, B+N)`` loss contract of + :mod:`cornac.models.seq_utils`. Parameters mirror + :class:`cornac.models.SASRec` (minus ``use_pos_emb`` — the backbone + provides its own positional embeddings); see the SASRec docstring for + details about ``loss``, ``model_selection``, and the rest. + + Note + ---- + This uses the next-item-at-last-position objective shared by the + transformer family in Cornac, *not* a canonical causal-language-model + (CLM) loss at every position. + + References + ---------- + de Souza Pereira Moreira, G., Rabhi, S., Lee, J. M., Ak, R., & Oldridge, E. + (2021). Transformers4Rec: Bridging the gap between NLP and sequential / + session-based recommendation. RecSys. + """ + + def __init__( + self, + name="GPT2Rec", + embedding_dim=100, + loss="ce", + batch_size=512, + learning_rate=0.001, + n_sample=2048, + sample_alpha=0.5, + n_epochs=10, + max_len=50, + num_blocks=2, + num_heads=1, + dropout=0.2, + l2_reg=0.0, + bpreg=1.0, + elu_param=0.5, + device="cpu", + trainable=True, + verbose=False, + seed=None, + model_selection="last", + val_eval_every=5, + val_k=20, + val_metric="recall", + ): + super().__init__(name, trainable=trainable, verbose=verbose) + if loss not in SUPPORTED_LOSSES: + raise ValueError( + f"loss='{loss}' not supported; choose from {SUPPORTED_LOSSES}" + ) + if model_selection not in ("last", "best"): + raise ValueError( + f"model_selection='{model_selection}' not supported; choose 'last' or 'best'" + ) + self.embedding_dim = embedding_dim + self.loss = loss + self.batch_size = batch_size + self.learning_rate = learning_rate + self.n_sample = n_sample + self.sample_alpha = sample_alpha + self.n_epochs = n_epochs + self.max_len = max_len + self.num_blocks = num_blocks + self.num_heads = num_heads + self.dropout = dropout + self.l2_reg = l2_reg + self.bpreg = bpreg + self.elu_param = elu_param + self.device = device + self.seed = seed + self.rng = get_rng(seed) + self.model_selection = model_selection + self.val_eval_every = val_eval_every + self.val_k = val_k + self.val_metric = val_metric + + def fit(self, train_set, val_set=None): + super().fit(train_set, val_set) + if not self.trainable: + return self + + import torch + + from .gpt2rec import GPT2RecModel + from ..seq_utils.losses import get_loss_function + + torch.manual_seed(self.seed if self.seed is not None else 0) + + self.pad_idx = self.total_items + self.model = GPT2RecModel( + item_num=self.total_items, + embedding_dim=self.embedding_dim, + maxlen=self.max_len, + n_layers=self.num_blocks, + n_heads=self.num_heads, + dropout=self.dropout, + pad_idx=self.pad_idx, + device=self.device, + ) + + loss_fn = get_loss_function(self.loss) + loss_kwargs = dict( + bpreg=self.bpreg, elu_param=self.elu_param, n_sample=self.n_sample + ) + opt = torch.optim.Adam( + self.model.parameters(), lr=self.learning_rate, betas=(0.9, 0.98) + ) + + best_val = -float("inf") + best_state = None + progress_bar = trange(1, self.n_epochs + 1, disable=not self.verbose) + for epoch_id in progress_bar: + self.model.train() + total_loss = 0.0 + cnt = 0 + for inc, (in_uids, hist_iids, out_iids) in enumerate( + session_seq_iter( + self.train_set, + pad_index=self.pad_idx, + batch_size=self.batch_size, + max_len=self.max_len, + n_sample=self.n_sample, + sample_alpha=self.sample_alpha, + rng=self.rng, + shuffle=True, + ) + ): + if len(hist_iids) < 2: + continue + hist_iids_t = torch.tensor( + hist_iids, dtype=torch.long, device=self.device, requires_grad=False + ) + out_iids_t = torch.tensor( + out_iids, dtype=torch.long, device=self.device, requires_grad=False + ) + + self.model.zero_grad() + item_scores = self.model(None, hist_iids_t, out_iids_t) + L = loss_fn( + item_scores, + out_iids=out_iids_t, + batch_size=len(hist_iids), + **loss_kwargs, + ) + if self.l2_reg > 0: + for p in self.model.parameters(): + L = L + self.l2_reg * torch.norm(p) + + L.backward() + opt.step() + + total_loss += L.cpu().detach().numpy() * len(hist_iids) + cnt += len(hist_iids) + if inc % 10 == 0 and cnt > 0: + progress_bar.set_postfix(loss=(total_loss / cnt)) + + if ( + self.model_selection == "best" + and val_set is not None + and epoch_id % self.val_eval_every == 0 + ): + score = val_score( + self, self.train_set, val_set, metric=self.val_metric, k=self.val_k + ) + if score is not None and score > best_val: + best_val = score + best_state = { + n: p.detach().clone() + for n, p in self.model.state_dict().items() + } + + if self.model_selection == "best" and best_state is not None: + self.model.load_state_dict(best_state) + return self + + def score(self, user_idx, history_items, **kwargs): + import torch + + if len(history_items) == 0: + return np.ones(self.total_items, dtype="float") + log_seq = [self.pad_idx] * (self.max_len - len(history_items)) + list( + history_items + ) + log_seq = log_seq[-self.max_len :] + log_seq_t = torch.tensor([log_seq], dtype=torch.long, device=self.device) + self.model.eval() + return self.model.predict(user_idx, log_seq_t) diff --git a/cornac/models/gpt2rec/requirements.txt b/cornac/models/gpt2rec/requirements.txt new file mode 100644 index 000000000..a808c3f6f --- /dev/null +++ b/cornac/models/gpt2rec/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.12.0 +transformers>=4.30.0 diff --git a/cornac/models/sasrec/__init__.py b/cornac/models/sasrec/__init__.py new file mode 100644 index 000000000..e163b9a7d --- /dev/null +++ b/cornac/models/sasrec/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from .recom_sasrec import SASRec diff --git a/cornac/models/sasrec/recom_sasrec.py b/cornac/models/sasrec/recom_sasrec.py new file mode 100644 index 000000000..f9f44758c --- /dev/null +++ b/cornac/models/sasrec/recom_sasrec.py @@ -0,0 +1,290 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from tqdm.auto import trange + +from cornac.models.recommender import NextItemRecommender + +from ...utils import get_rng +from ..seq_utils import session_seq_iter, val_score + +SUPPORTED_LOSSES = ( + "bce", + "ce", + "bpr", + "bpr-max", + "softmax", + "cross-entropy", + "xe_softmax", + "top1", +) + + +class SASRec(NextItemRecommender): + """SASRec: Self-Attentive Sequential Recommendation. + + A self-attention (transformer) encoder over the current session's item + history; the last-position representation scores candidate items by dot + product. Training reuses the shared session iterator (session-based) and + the ``(B, B+N)`` score-matrix loss contract from + :mod:`cornac.models.seq_utils`. + + Parameters + ---------- + name: string, default: 'SASRec' + The name of the recommender model. + + embedding_dim: int, optional, default: 100 + Item embedding dimension. + + loss: str, optional, default: 'ce' + Loss function. Supported: 'bce', 'ce', 'bpr', 'bpr-max', 'softmax' + (a.k.a 'cross-entropy' / 'xe_softmax'), 'top1'. + + batch_size: int, optional, default: 512 + + learning_rate: float, optional, default: 0.001 + + n_sample: int, optional, default: 2048 + Number of negative samples shared per mini-batch. + + sample_alpha: float, optional, default: 0.5 + Popularity-based negative sampling exponent. + + n_epochs: int, optional, default: 10 + + max_len: int, optional, default: 50 + Maximum history length fed to the encoder. + + num_blocks: int, optional, default: 2 + Number of self-attention blocks. + + num_heads: int, optional, default: 1 + Number of attention heads. + + dropout: float, optional, default: 0.2 + + l2_reg: float, optional, default: 0.0 + + bpreg: float, optional, default: 1.0 + Regularization coefficient for the 'bpr-max' loss. + + elu_param: float, optional, default: 0.5 + ELU parameter for the 'bpr-max' loss. + + device: str, optional, default: 'cpu' + Set to 'cuda' for GPU support. + + use_pos_emb: bool, optional, default: True + Whether to add learned positional embeddings. + + model_selection: str, optional, default: 'last' + One of 'last' or 'best'. When 'best', the model with the highest + validation score (evaluated every ``val_eval_every`` epochs) is + restored at the end of ``fit``. + + val_eval_every: int, optional, default: 5 + val_k: int, optional, default: 20 + val_metric: str, optional, default: 'recall' + Cutoff and metric used for best-on-val selection. See + :func:`cornac.models.seq_utils.val_score`. + + trainable: bool, optional, default: True + When False, the model will not be re-trained. + + verbose: bool, optional, default: False + When True, running logs are displayed. + + seed: int, optional, default: None + Random seed for weight initialization. + + References + ---------- + Kang, W.-C., & McAuley, J. (2018). Self-attentive sequential + recommendation. ICDM. + """ + + def __init__( + self, + name="SASRec", + embedding_dim=100, + loss="ce", + batch_size=512, + learning_rate=0.001, + n_sample=2048, + sample_alpha=0.5, + n_epochs=10, + max_len=50, + num_blocks=2, + num_heads=1, + dropout=0.2, + l2_reg=0.0, + bpreg=1.0, + elu_param=0.5, + device="cpu", + use_pos_emb=True, + trainable=True, + verbose=False, + seed=None, + model_selection="last", + val_eval_every=5, + val_k=20, + val_metric="recall", + ): + super().__init__(name, trainable=trainable, verbose=verbose) + if loss not in SUPPORTED_LOSSES: + raise ValueError( + f"loss='{loss}' not supported; choose from {SUPPORTED_LOSSES}" + ) + if model_selection not in ("last", "best"): + raise ValueError( + f"model_selection='{model_selection}' not supported; choose 'last' or 'best'" + ) + self.embedding_dim = embedding_dim + self.loss = loss + self.batch_size = batch_size + self.learning_rate = learning_rate + self.n_sample = n_sample + self.sample_alpha = sample_alpha + self.n_epochs = n_epochs + self.max_len = max_len + self.num_blocks = num_blocks + self.num_heads = num_heads + self.dropout = dropout + self.l2_reg = l2_reg + self.bpreg = bpreg + self.elu_param = elu_param + self.device = device + self.use_pos_emb = use_pos_emb + self.seed = seed + self.rng = get_rng(seed) + self.model_selection = model_selection + self.val_eval_every = val_eval_every + self.val_k = val_k + self.val_metric = val_metric + + def fit(self, train_set, val_set=None): + super().fit(train_set, val_set) + if not self.trainable: + return self + + import torch + + from .sasrec import SASRecModel + from ..seq_utils.losses import get_loss_function + + torch.manual_seed(self.seed if self.seed is not None else 0) + + self.pad_idx = self.total_items + self.model = SASRecModel( + item_num=self.total_items, + embedding_dim=self.embedding_dim, + maxlen=self.max_len, + n_layers=self.num_blocks, + n_heads=self.num_heads, + use_pos_emb=self.use_pos_emb, + dropout=self.dropout, + pad_idx=self.pad_idx, + device=self.device, + ) + + loss_fn = get_loss_function(self.loss) + loss_kwargs = dict( + bpreg=self.bpreg, elu_param=self.elu_param, n_sample=self.n_sample + ) + opt = torch.optim.Adam( + self.model.parameters(), lr=self.learning_rate, betas=(0.9, 0.98) + ) + + best_val = -float("inf") + best_state = None + progress_bar = trange(1, self.n_epochs + 1, disable=not self.verbose) + for epoch_id in progress_bar: + self.model.train() + total_loss = 0.0 + cnt = 0 + for inc, (in_uids, hist_iids, out_iids) in enumerate( + session_seq_iter( + self.train_set, + pad_index=self.pad_idx, + batch_size=self.batch_size, + max_len=self.max_len, + n_sample=self.n_sample, + sample_alpha=self.sample_alpha, + rng=self.rng, + shuffle=True, + ) + ): + if len(hist_iids) < 2: + continue + hist_iids_t = torch.tensor( + hist_iids, dtype=torch.long, device=self.device, requires_grad=False + ) + out_iids_t = torch.tensor( + out_iids, dtype=torch.long, device=self.device, requires_grad=False + ) + + self.model.zero_grad() + item_scores = self.model(None, hist_iids_t, out_iids_t) + L = loss_fn( + item_scores, + out_iids=out_iids_t, + batch_size=len(hist_iids), + **loss_kwargs, + ) + if self.l2_reg > 0: + for p in self.model.parameters(): + L = L + self.l2_reg * torch.norm(p) + + L.backward() + opt.step() + + total_loss += L.cpu().detach().numpy() * len(hist_iids) + cnt += len(hist_iids) + if inc % 10 == 0 and cnt > 0: + progress_bar.set_postfix(loss=(total_loss / cnt)) + + if ( + self.model_selection == "best" + and val_set is not None + and epoch_id % self.val_eval_every == 0 + ): + score = val_score( + self, self.train_set, val_set, metric=self.val_metric, k=self.val_k + ) + if score is not None and score > best_val: + best_val = score + best_state = { + n: p.detach().clone() + for n, p in self.model.state_dict().items() + } + + if self.model_selection == "best" and best_state is not None: + self.model.load_state_dict(best_state) + return self + + def score(self, user_idx, history_items, **kwargs): + import torch + + if len(history_items) == 0: + return np.ones(self.total_items, dtype="float") + log_seq = [self.pad_idx] * (self.max_len - len(history_items)) + list( + history_items + ) + log_seq = log_seq[-self.max_len :] + log_seq_t = torch.tensor([log_seq], dtype=torch.long, device=self.device) + self.model.eval() + return self.model.predict(user_idx, log_seq_t) diff --git a/cornac/models/sasrec/requirements.txt b/cornac/models/sasrec/requirements.txt new file mode 100644 index 000000000..be222b022 --- /dev/null +++ b/cornac/models/sasrec/requirements.txt @@ -0,0 +1 @@ +torch>=1.12.0 diff --git a/cornac/models/sasrec/sasrec.py b/cornac/models/sasrec/sasrec.py new file mode 100644 index 000000000..e785d0f2e --- /dev/null +++ b/cornac/models/sasrec/sasrec.py @@ -0,0 +1,186 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import torch +import torch.nn as nn + + +class PointWiseFeedForward(nn.Module): + def __init__(self, hidden_units, dropout_rate): + super(PointWiseFeedForward, self).__init__() + conv1 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1) + dropout1 = nn.Dropout(p=dropout_rate) + relu = nn.ReLU() + conv2 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1) + dropout2 = nn.Dropout(p=dropout_rate) + self.process = nn.Sequential(conv1, dropout1, relu, conv2, dropout2) + + def forward(self, inputs): + outputs = self.process(inputs.transpose(-1, -2)) + outputs = outputs.transpose(-1, -2) + outputs += inputs + return outputs + + +class SASRecModel(nn.Module): + """SASRec self-attention model (Kang & McAuley, 2018). + + Operates on sequences of past item ids, returning the last-position + representation. Item ids are integers in ``[0, item_num)`` with + ``item_num`` being used as the padding index. + + The model produces a ``(B, B+N)`` score matrix when called as + ``forward(_, hist_iids, out_iids, return_hidden=False)`` where + ``out_iids`` contains the ``B`` in-batch positives followed by ``N`` + shared negatives, matching the contract expected by the loss functions + in :mod:`cornac.models.seq_utils`. + """ + + def __init__( + self, + item_num, + embedding_dim=100, + maxlen=20, + n_layers=2, + n_heads=1, + use_pos_emb=True, + use_biases=True, + dropout=0.2, + pad_idx=-1, + init_std=0.02, + device="cpu", + ): + super(SASRecModel, self).__init__() + self.item_num = item_num + self.pad_idx = pad_idx if pad_idx >= 0 else item_num + self.maxlen = maxlen + self.dev = device + self.init_std = init_std + + # +1 row for the padding entry at pad_idx + self.item_emb = nn.Embedding( + self.item_num + 1, embedding_dim, padding_idx=self.pad_idx + ) + if use_pos_emb: + self.pos_emb = nn.Embedding(maxlen + 1, embedding_dim) + if use_biases: + self.item_biases = nn.Embedding( + self.item_num + 1, 1, padding_idx=self.pad_idx + ) + self.emb_dropout = nn.Dropout(p=dropout) + + self.attention_layernorms = nn.ModuleList() + self.attention_layers = nn.ModuleList() + self.forward_layernorms = nn.ModuleList() + self.forward_layers = nn.ModuleList() + self.last_layernorm = nn.LayerNorm(embedding_dim, eps=1e-8) + + for _ in range(n_layers): + self.attention_layernorms.append(nn.LayerNorm(embedding_dim, eps=1e-8)) + self.attention_layers.append( + nn.MultiheadAttention(embedding_dim, n_heads, dropout) + ) + self.forward_layernorms.append(nn.LayerNorm(embedding_dim, eps=1e-8)) + self.forward_layers.append(PointWiseFeedForward(embedding_dim, dropout)) + + self.apply(self._init_weights) + self.to(device) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=self.init_std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.init_std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _score_items(self, hidden, cand_items, biases=None): + scores = torch.mm(hidden, cand_items.T) + if biases is not None: + return scores + biases.T + return scores + + def _encode(self, hist_iids): + # hist_iids: (B, T) + seqs = self.item_emb(hist_iids) + seqs = seqs * (self.item_emb.embedding_dim**0.5) + positions = np.tile(np.arange(hist_iids.shape[1]), [hist_iids.shape[0], 1]) + if hasattr(self, "pos_emb"): + seqs = seqs + self.pos_emb( + torch.tensor(positions, dtype=torch.long, device=seqs.device) + ) + seqs = self.emb_dropout(seqs) + + timeline_mask = (hist_iids == self.pad_idx).to( + dtype=seqs.dtype, device=seqs.device + ) + seqs = seqs * (1.0 - timeline_mask).unsqueeze(-1) + + tl = seqs.shape[1] + attention_mask = ~torch.tril( + torch.ones((tl, tl), dtype=torch.bool, device=seqs.device) + ) + + for i in range(len(self.attention_layers)): + seqs_t = torch.transpose(seqs, 0, 1) + Q = self.attention_layernorms[i](seqs_t) + mha_out, _ = self.attention_layers[i]( + Q, seqs_t, seqs_t, attn_mask=attention_mask + ) + seqs_t = Q + mha_out + seqs = torch.transpose(seqs_t, 0, 1) + seqs = self.forward_layernorms[i](seqs) + seqs = self.forward_layers[i](seqs) + seqs = seqs * (1.0 - timeline_mask).unsqueeze(-1) + + log_feats = self.last_layernorm(seqs) + return log_feats[:, -1, :] + + def forward(self, user_ids, hist_iids, out_iids, return_hidden=False): + hidden = self._encode(hist_iids) + item_emb = self.item_emb(out_iids) + biases = self.item_biases(out_iids) if hasattr(self, "item_biases") else None + if return_hidden: + return hidden, item_emb, biases + return self._score_items(hidden, item_emb, biases) + + @torch.no_grad() + def predict(self, user_ids, log_seqs, item_indices=None): + """Score all real items for a single padded sequence. + + Returns a 1-D numpy array of size ``item_num`` (padding column is + already stripped). + """ + if item_indices is None: + item_indices = torch.arange(self.item_num, device=self.dev) + else: + item_indices = torch.as_tensor( + item_indices, dtype=torch.long, device=self.dev + ) + if not isinstance(log_seqs, torch.Tensor): + log_seqs = torch.as_tensor(log_seqs, dtype=torch.long, device=self.dev) + hidden = self._encode(log_seqs) + item_emb = self.item_emb(item_indices) + biases = ( + self.item_biases(item_indices) if hasattr(self, "item_biases") else None + ) + scores = self._score_items(hidden, item_emb, biases) + return scores.squeeze().detach().cpu().numpy() diff --git a/examples/transformer_rec_diginetica.py b/examples/transformer_rec_diginetica.py new file mode 100644 index 000000000..28564f6fe --- /dev/null +++ b/examples/transformer_rec_diginetica.py @@ -0,0 +1,108 @@ +# Copyright 2026 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer-based next-item recommenders on Diginetica. + +SASRec, BERT4Rec, and GPT2Rec share one scoring head (encode the current +session, take the last-position hidden state, dot-product against item +embeddings) and differ only in the sequence encoder: + +- SASRec : its own causal self-attention stack (torch only) +- BERT4Rec : a HuggingFace BERT encoder +- GPT2Rec : a HuggingFace GPT-2 decoder + +BERT4Rec and GPT2Rec require the ``transformers`` package (see each model's +requirements.txt). All three use the next-item-at-last-position objective, not +the canonical MLM/CLM losses in Transformers4Rec paper. +""" + +import torch + +import cornac +from cornac.datasets import diginetica +from cornac.eval_methods import NextItemEvaluation +from cornac.metrics import MRR, NDCG, Recall +from cornac.models import BERT4Rec, GPT2Rec, GRU4Rec, SASRec + +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" +print(f"using device: {DEVICE}") + +train_data = diginetica.load_train() +val_data = diginetica.load_val() +test_data = diginetica.load_test() +print("data loaded") + +next_item_eval = NextItemEvaluation.from_splits( + train_data=train_data, + val_data=val_data, + test_data=test_data, + exclude_unknowns=True, + verbose=True, + fmt="USIT", +) + +transformer = dict( + embedding_dim=64, + loss="cross-entropy", + n_sample=512, + batch_size=128, + n_epochs=100, + max_len=20, + num_blocks=2, + num_heads=2, + model_selection="best", + val_eval_every=5, + val_metric="ndcg", + val_k=10, + device=DEVICE, + verbose=True, + seed=123, +) + +models = [ + GRU4Rec( + layers=[100], + loss="cross-entropy", + dropout_p_hidden=0.3, + sample_alpha=0.75, + n_sample=512, + batch_size=64, + learning_rate=0.1, + n_epochs=50, + model_selection="best", + val_eval_every=5, + val_metric="recall", + val_k=20, + device=DEVICE, + verbose=True, + seed=123, + ), + SASRec(learning_rate=0.01, **transformer), + BERT4Rec(learning_rate=0.01, **transformer), + GPT2Rec(learning_rate=0.001, **transformer), +] + +metrics = [ + NDCG(k=10), + NDCG(k=50), + Recall(k=10), + Recall(k=50), + MRR(), +] + +cornac.Experiment( + eval_method=next_item_eval, + models=models, + metrics=metrics, +).run() From fd8005f4b93d0f62bc0396e1e0476a639fa01a4d Mon Sep 17 00:00:00 2001 From: hieuddo Date: Tue, 2 Jun 2026 17:09:22 +0800 Subject: [PATCH 2/4] remove tie_weights --- cornac/models/bert4rec/bert4rec.py | 2 -- cornac/models/gpt2rec/gpt2rec.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/cornac/models/bert4rec/bert4rec.py b/cornac/models/bert4rec/bert4rec.py index 5fc1d77c2..3213059ba 100644 --- a/cornac/models/bert4rec/bert4rec.py +++ b/cornac/models/bert4rec/bert4rec.py @@ -38,7 +38,6 @@ def __init__( n_heads=1, dropout=0.1, pad_idx=-1, - tie_weights=False, init_std=0.02, device="cpu", ): @@ -50,7 +49,6 @@ def __init__( self.maxlen = maxlen self.dev = device self.init_std = init_std - self.tie_weights = tie_weights config = BertConfig( vocab_size=item_num + 1, diff --git a/cornac/models/gpt2rec/gpt2rec.py b/cornac/models/gpt2rec/gpt2rec.py index 93d4e56e4..5cde864c5 100644 --- a/cornac/models/gpt2rec/gpt2rec.py +++ b/cornac/models/gpt2rec/gpt2rec.py @@ -33,7 +33,6 @@ def __init__( n_heads=1, dropout=0.1, pad_idx=-1, - tie_weights=False, init_std=0.02, device="cpu", ): @@ -45,7 +44,6 @@ def __init__( self.maxlen = maxlen self.dev = device self.init_std = init_std - self.tie_weights = tie_weights config = GPT2Config( vocab_size=item_num + 1, From 82a5aaa155617c13c383988f4110d543cec0e2a8 Mon Sep 17 00:00:00 2001 From: hieuddo Date: Tue, 2 Jun 2026 17:12:05 +0800 Subject: [PATCH 3/4] mask pad_idx --- cornac/models/sasrec/sasrec.py | 47 ++++++++++++---------------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/cornac/models/sasrec/sasrec.py b/cornac/models/sasrec/sasrec.py index e785d0f2e..42d2a2985 100644 --- a/cornac/models/sasrec/sasrec.py +++ b/cornac/models/sasrec/sasrec.py @@ -71,15 +71,11 @@ def __init__( self.init_std = init_std # +1 row for the padding entry at pad_idx - self.item_emb = nn.Embedding( - self.item_num + 1, embedding_dim, padding_idx=self.pad_idx - ) + self.item_emb = nn.Embedding(self.item_num + 1, embedding_dim, padding_idx=self.pad_idx) if use_pos_emb: self.pos_emb = nn.Embedding(maxlen + 1, embedding_dim) if use_biases: - self.item_biases = nn.Embedding( - self.item_num + 1, 1, padding_idx=self.pad_idx - ) + self.item_biases = nn.Embedding(self.item_num + 1, 1, padding_idx=self.pad_idx) self.emb_dropout = nn.Dropout(p=dropout) self.attention_layernorms = nn.ModuleList() @@ -90,9 +86,7 @@ def __init__( for _ in range(n_layers): self.attention_layernorms.append(nn.LayerNorm(embedding_dim, eps=1e-8)) - self.attention_layers.append( - nn.MultiheadAttention(embedding_dim, n_heads, dropout) - ) + self.attention_layers.append(nn.MultiheadAttention(embedding_dim, n_heads, dropout)) self.forward_layernorms.append(nn.LayerNorm(embedding_dim, eps=1e-8)) self.forward_layers.append(PointWiseFeedForward(embedding_dim, dropout)) @@ -124,32 +118,29 @@ def _encode(self, hist_iids): seqs = seqs * (self.item_emb.embedding_dim**0.5) positions = np.tile(np.arange(hist_iids.shape[1]), [hist_iids.shape[0], 1]) if hasattr(self, "pos_emb"): - seqs = seqs + self.pos_emb( - torch.tensor(positions, dtype=torch.long, device=seqs.device) - ) + seqs = seqs + self.pos_emb(torch.tensor(positions, dtype=torch.long, device=seqs.device)) seqs = self.emb_dropout(seqs) - timeline_mask = (hist_iids == self.pad_idx).to( - dtype=seqs.dtype, device=seqs.device - ) - seqs = seqs * (1.0 - timeline_mask).unsqueeze(-1) + pad_mask = hist_iids == self.pad_idx # (B, T) + seqs = seqs.masked_fill(pad_mask.unsqueeze(-1), 0.0) - tl = seqs.shape[1] - attention_mask = ~torch.tril( - torch.ones((tl, tl), dtype=torch.bool, device=seqs.device) - ) + B, tl, _ = seqs.shape + future = torch.triu(torch.ones(tl, tl, dtype=torch.bool, device=seqs.device), diagonal=1) + block = future.unsqueeze(0) | pad_mask.unsqueeze(1) # (B, T, T) + block = block & ~torch.eye(tl, dtype=torch.bool, device=seqs.device) + attn_mask = torch.zeros(B, tl, tl, dtype=seqs.dtype, device=seqs.device).masked_fill(block, float("-inf")) + n_heads = self.attention_layers[0].num_heads + attn_mask = attn_mask.repeat_interleave(n_heads, dim=0) # (B*n_heads, T, T) for i in range(len(self.attention_layers)): seqs_t = torch.transpose(seqs, 0, 1) Q = self.attention_layernorms[i](seqs_t) - mha_out, _ = self.attention_layers[i]( - Q, seqs_t, seqs_t, attn_mask=attention_mask - ) + mha_out, _ = self.attention_layers[i](Q, seqs_t, seqs_t, attn_mask=attn_mask) seqs_t = Q + mha_out seqs = torch.transpose(seqs_t, 0, 1) seqs = self.forward_layernorms[i](seqs) seqs = self.forward_layers[i](seqs) - seqs = seqs * (1.0 - timeline_mask).unsqueeze(-1) + seqs = seqs.masked_fill(pad_mask.unsqueeze(-1), 0.0) log_feats = self.last_layernorm(seqs) return log_feats[:, -1, :] @@ -172,15 +163,11 @@ def predict(self, user_ids, log_seqs, item_indices=None): if item_indices is None: item_indices = torch.arange(self.item_num, device=self.dev) else: - item_indices = torch.as_tensor( - item_indices, dtype=torch.long, device=self.dev - ) + item_indices = torch.as_tensor(item_indices, dtype=torch.long, device=self.dev) if not isinstance(log_seqs, torch.Tensor): log_seqs = torch.as_tensor(log_seqs, dtype=torch.long, device=self.dev) hidden = self._encode(log_seqs) item_emb = self.item_emb(item_indices) - biases = ( - self.item_biases(item_indices) if hasattr(self, "item_biases") else None - ) + biases = self.item_biases(item_indices) if hasattr(self, "item_biases") else None scores = self._score_items(hidden, item_emb, biases) return scores.squeeze().detach().cpu().numpy() From 8179d5b9c872cbf1b80a6a7d5fe71acf3f12c1bf Mon Sep 17 00:00:00 2001 From: hieuddo Date: Thu, 4 Jun 2026 12:50:58 +0800 Subject: [PATCH 4/4] docs: add Transformer recommendation example to README --- examples/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/README.md b/examples/README.md index b7be32f0e..0228cdb64 100644 --- a/examples/README.md +++ b/examples/README.md @@ -128,6 +128,8 @@ [fpmc_diginetica.py](fpmc_diginetica.py) - Example of Factorizing Personalized Markov Chains (FPMC) with Diginetica dataset. +[transformer_rec_diginetica.py](transformer_rec_diginetica.py) - Example of Transformer-based Recommendation models (SASRec, BERT4Rec, GPT2Rec) with Diginetica dataset. + ---- ## Next-Basket Algorithms