|
1 | | -#=============================================================================== |
| 1 | +# =============================================================================== |
2 | 2 | # Copyright 2020-2021 Intel Corporation |
3 | 3 | # |
4 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | | -#=============================================================================== |
| 15 | +# =============================================================================== |
16 | 16 |
|
17 | | -import sys |
18 | | -import os |
19 | 17 | import argparse |
20 | | -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
21 | | -import bench |
| 18 | +import warnings |
| 19 | +from typing import Any |
22 | 20 |
|
| 21 | +import bench |
23 | 22 | import numpy as np |
24 | 23 | from cuml import KMeans |
25 | | -import warnings |
26 | 24 | from sklearn.metrics.cluster import davies_bouldin_score |
27 | 25 |
|
28 | 26 | warnings.filterwarnings('ignore', category=FutureWarning) |
|
41 | 39 | # Load and convert generated data |
42 | 40 | X_train, X_test, _, _ = bench.load_data(params) |
43 | 41 |
|
| 42 | +X_init: Any |
44 | 43 | if params.filei == 'k-means++': |
45 | 44 | X_init = 'k-means++' |
46 | 45 | # Load initial centroids from specified path |
47 | 46 | elif params.filei is not None: |
48 | | - X_init = np.load(params.filei).astype(params.dtype) |
49 | | - params.n_clusters = X_init.shape[0] |
| 47 | + X_init = {k: v.astype(params.dtype) for k, v in np.load(params.filei).items()} |
| 48 | + if isinstance(X_init, np.ndarray): |
| 49 | + params.n_clusters = X_init.shape[0] |
50 | 50 | # or choose random centroids from training data |
51 | 51 | else: |
52 | 52 | np.random.seed(params.seed) |
53 | | - centroids_idx = np.random.randint(0, X_train.shape[0], |
| 53 | + centroids_idx = np.random.randint(low=0, high=X_train.shape[0], |
54 | 54 | size=params.n_clusters) |
55 | 55 | if hasattr(X_train, "iloc"): |
56 | 56 | X_init = X_train.iloc[centroids_idx].to_pandas().values |
|
0 commit comments