|
1 | 1 | import numpy as np |
| 2 | +import sys |
2 | 3 | import tensorflow as tf |
3 | 4 |
|
4 | | -from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat,get_feature_names |
5 | | -from deepctr.models import DIEN |
| 5 | + |
| 6 | +def version_tuple(version): |
| 7 | + return tuple(int(part) for part in version.split(".")[:3] if part.isdigit()) |
6 | 8 |
|
7 | 9 |
|
8 | 10 | def get_xy_fd(use_neg=False, hash_flag=False): |
@@ -49,8 +51,22 @@ def get_xy_fd(use_neg=False, hash_flag=False): |
49 | 51 |
|
50 | 52 |
|
51 | 53 | if __name__ == "__main__": |
| 54 | + if version_tuple(tf.__version__) >= (1, 14, 0): |
| 55 | + print( |
| 56 | + "run_dien.py skipped: this DIEN example enables AUGRU with negative " |
| 57 | + "sampling, which depends on legacy TensorFlow private RNN APIs. " |
| 58 | + "Please run it with TensorFlow < 1.14, or modify the example to " |
| 59 | + "disable negative sampling/use a supported DIEN configuration. " |
| 60 | + "Detected TensorFlow %s." % tf.__version__ |
| 61 | + ) |
| 62 | + sys.exit(0) |
| 63 | + |
52 | 64 | if tf.__version__ >= '2.0.0': |
53 | 65 | tf.compat.v1.disable_eager_execution() |
| 66 | + |
| 67 | + from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat,get_feature_names |
| 68 | + from deepctr.models import DIEN |
| 69 | + |
54 | 70 | USE_NEG = True |
55 | 71 | x, y, feature_columns, behavior_feature_list = get_xy_fd(use_neg=USE_NEG) |
56 | 72 |
|
|
0 commit comments