Skip to content

Commit 5fde2fd

Browse files
committed
Fix shared embeddings and string sparse validation
1 parent 694d6d9 commit 5fde2fd

4 files changed

Lines changed: 97 additions & 14 deletions

File tree

deepctr/feature_column.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,23 @@
1414
DEFAULT_GROUP_NAME = "default_group"
1515

1616

17+
def _is_string_dtype(dtype):
    """Return True when ``dtype`` denotes the TensorFlow string dtype.

    Accepts either a ``tf.DType`` or anything ``tf.as_dtype`` understands;
    values that cannot be converted fall back to a plain string comparison.
    """
    try:
        converted = tf.as_dtype(dtype)
    except TypeError:
        # Not convertible to a tf.DType (e.g. an arbitrary object) —
        # compare against the literal dtype name instead.
        return dtype == "string"
    return converted == tf.string
22+
23+
24+
def _check_sparse_feature_dtype(fc):
    """Reject string-typed sparse features that do not enable hashing.

    String ids cannot feed an embedding lookup directly, so a string dtype
    is only valid together with ``use_hash=True``.

    Raises:
        ValueError: if ``fc`` has a string dtype but ``use_hash`` is False.
    """
    if not _is_string_dtype(fc.dtype):
        return
    if fc.use_hash:
        return
    raise ValueError(
        "SparseFeat(name='{}', dtype='string') requires use_hash=True "
        "so string ids can be converted before embedding lookup. "
        "Alternatively, encode the feature values to integer ids before "
        "passing them to DeepCTR.".format(fc.name)
    )
32+
33+
1734
class SparseFeat(namedtuple('SparseFeat',
1835
['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'vocabulary_path', 'dtype', 'embeddings_initializer',
1936
'embedding_name',
@@ -129,12 +146,14 @@ def build_input_features(feature_columns, prefix=''):
129146
input_features = OrderedDict()
130147
for fc in feature_columns:
131148
if isinstance(fc, SparseFeat):
149+
_check_sparse_feature_dtype(fc)
132150
input_features[fc.name] = Input(
133151
shape=(1,), name=prefix + fc.name, dtype=fc.dtype)
134152
elif isinstance(fc, DenseFeat):
135153
input_features[fc.name] = Input(
136154
shape=(fc.dimension,), name=prefix + fc.name, dtype=fc.dtype)
137155
elif isinstance(fc, VarLenSparseFeat):
156+
_check_sparse_feature_dtype(fc)
138157
input_features[fc.name] = Input(shape=(fc.maxlen,), name=prefix + fc.name,
139158
dtype=fc.dtype)
140159
if fc.weight_name is not None:

deepctr/inputs.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,58 @@
1616
from .layers.utils import Hash
1717

1818

19+
def _create_embedding_layer(feat, l2_reg, prefix, name_suffix, mask_zero=False):
    """Build a Keras Embedding layer for one feature column.

    The layer is named ``<prefix>_<name_suffix>_<embedding_name>`` and
    inherits the column's initializer, L2 regularization strength and
    trainable flag.
    """
    layer = Embedding(
        feat.vocabulary_size,
        feat.embedding_dim,
        embeddings_initializer=feat.embeddings_initializer,
        embeddings_regularizer=l2(l2_reg),
        name=prefix + '_' + name_suffix + '_' + feat.embedding_name,
        mask_zero=mask_zero,
    )
    # trainable is set after construction, mirroring the feature column's flag.
    layer.trainable = feat.trainable
    return layer
27+
28+
29+
def _check_embedding_compatible(embedding_name, existing_feat, feat):
    """Validate that two columns sharing an embedding table agree on its spec.

    Compares vocabulary_size, embedding_dim and trainable between the column
    that first created the table and a later column reusing the same
    ``embedding_name``.

    Raises:
        ValueError: on the first attribute that differs.
    """
    for attr in ('vocabulary_size', 'embedding_dim', 'trainable'):
        existing_value = getattr(existing_feat, attr)
        new_value = getattr(feat, attr)
        if existing_value == new_value:
            continue
        raise ValueError(
            "Feature columns with the same embedding_name must share the same "
            "{}. embedding_name='{}' has {} and {}.".format(
                attr, embedding_name, existing_value, new_value
            )
        )
38+
39+
1940
def get_inputs_list(inputs):
    """Flatten the values of every non-None mapping in *inputs* into one list."""
    return [tensor
            for mapping in inputs
            if mapping is not None
            for tensor in mapping.values()]
2142

2243

2344
def create_embedding_dict(sparse_feature_columns, varlen_sparse_feature_columns, seed, l2_reg,
                          prefix='sparse_', seq_mask_zero=True):
    """Build one Embedding layer per distinct ``embedding_name``.

    Columns that share an ``embedding_name`` reuse a single layer; a reused
    name is validated for spec compatibility instead of creating a duplicate
    table. A fixed-length column whose table is also used by a
    variable-length column gets ``mask_zero`` enabled (when
    ``seq_mask_zero`` is on) so downstream sequence masking keeps working.

    Returns:
        dict mapping embedding_name -> Embedding layer.
    """
    # NOTE(review): `seed` is accepted for interface compatibility but is
    # not used inside this function.
    table_by_name = {}
    # Remember which column created each table so reuse can be validated.
    owner_by_name = {}
    if varlen_sparse_feature_columns:
        seq_table_names = {feat.embedding_name for feat in varlen_sparse_feature_columns}
    else:
        seq_table_names = set()

    for feat in sparse_feature_columns:
        key = feat.embedding_name
        if key in table_by_name:
            # Shared table: check the specs agree, do not build a second layer.
            _check_embedding_compatible(key, owner_by_name[key], feat)
            continue
        needs_mask = seq_mask_zero and key in seq_table_names
        table_by_name[key] = _create_embedding_layer(feat, l2_reg, prefix, 'emb', needs_mask)
        owner_by_name[key] = feat

    for feat in (varlen_sparse_feature_columns or []):
        key = feat.embedding_name
        if key in table_by_name:
            _check_embedding_compatible(key, owner_by_name[key], feat)
            continue
        table_by_name[key] = _create_embedding_layer(feat, l2_reg, prefix, 'seq_emb', seq_mask_zero)
        owner_by_name[key] = feat
    return table_by_name
4672

4773

tests/feature_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from deepctr.models import DeepFM
22
from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat, get_feature_names
3+
from deepctr.inputs import create_embedding_matrix
34
import numpy as np
5+
import pytest
46

57

68
def test_long_dense_vector():
@@ -28,3 +30,28 @@ def test_feature_column_sparsefeat_vocabulary_path():
2830
vlsf = VarLenSparseFeat(sf, 6)
2931
if vlsf.vocabulary_path != vocab_path:
3032
raise ValueError("vlsf.vocabulary_path is invalid")
33+
34+
35+
def test_create_embedding_matrix_reuses_same_embedding_name():
    """Columns sharing embedding_name='item_id' must map to one masked table."""
    base = SparseFeat('item_id', 4, embedding_dim=8)
    alias = SparseFeat('item_id_copy', 4, embedding_dim=8, embedding_name='item_id')
    hist = VarLenSparseFeat(
        SparseFeat('hist_item_id', 4, embedding_dim=8, embedding_name='item_id'), maxlen=3)
    neg_hist = VarLenSparseFeat(
        SparseFeat('neg_hist_item_id', 4, embedding_dim=8, embedding_name='item_id'), maxlen=3)

    embedding_dict = create_embedding_matrix([base, alias, hist, neg_hist], l2_reg=0, seed=1024)

    # One shared table, created with the sparse naming scheme and masking
    # enabled because the table is also used by variable-length columns.
    assert list(embedding_dict) == ['item_id']
    shared_layer = embedding_dict['item_id']
    assert shared_layer.name == 'sparse_emb_item_id'
    assert shared_layer.mask_zero is True
48+
49+
50+
def test_create_embedding_matrix_rejects_inconsistent_shared_embedding():
    """A vocabulary_size mismatch on a shared embedding_name must raise."""
    mismatched = [
        SparseFeat('item_id', 4, embedding_dim=8),
        VarLenSparseFeat(
            SparseFeat('hist_item_id', 5, embedding_dim=8, embedding_name='item_id'), maxlen=3),
    ]

    with pytest.raises(ValueError, match="same embedding_name"):
        create_embedding_matrix(mismatched, l2_reg=0, seed=1024)

tests/models/MTL_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import tensorflow as tf
33

4+
from deepctr.feature_column import SparseFeat
45
from deepctr.models.multitask import SharedBottom, ESMM, MMOE, PLE
56
from ..utils_mtl import get_mtl_test_data, check_mtl_model
67

@@ -27,6 +28,16 @@ def test_ESMM():
2728
check_mtl_model(model, model_name, x, y_list, task_types=['binary', 'binary'])
2829

2930

31+
def test_ESMM_string_sparse_requires_hash():
    """A string-dtype SparseFeat without use_hash must be rejected at build time."""
    unhashed = SparseFeat('user_id', 10, dtype='string')
    with pytest.raises(ValueError, match="use_hash=True"):
        ESMM([unhashed], tower_dnn_hidden_units=(8,))
34+
35+
36+
def test_ESMM_string_sparse_with_hash():
    """With use_hash=True a string-dtype feature builds a two-output ESMM model."""
    hashed = SparseFeat('user_id', 10, use_hash=True, dtype='string')
    model = ESMM([hashed], tower_dnn_hidden_units=(8,))
    assert len(model.outputs) == 2
39+
40+
3041
def test_MMOE():
3142
if tf.__version__ == "1.15.0": # slow in tf 1.15
3243
return

0 commit comments

Comments
 (0)