-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSymmetry_HACSurv_competing_syn_SurvivalL1.py
More file actions
157 lines (130 loc) · 6.45 KB
/
Symmetry_HACSurv_competing_syn_SurvivalL1.py
File metadata and controls
157 lines (130 loc) · 6.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import sys
# sys.path.append('/home/liuxin/HACSurv')
import warnings
warnings.filterwarnings("ignore")
import os
# 导入您需要的模块
from synthetic_dgp import linear_dgp_hac
from truth_net import Weibull_linear
from metric import surv_diff
from survival import MixExpPhiStochastic, HACSurv_4D_Sym_shared
import torch.optim as optim
# Runtime configuration: prefer the first CUDA device, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
torch.set_num_threads(24)
# All tensors default to float64 — the copula likelihood is numerically delicate.
torch.set_default_tensor_type(torch.DoubleTensor)

# One full experiment is run per seed; per-seed survival-L1 results accumulate here.
seeds_list = [41, 42, 43, 44, 45]
all_survival_l1 = []
# Run the full generate/train/evaluate cycle once per seed.
for seeds in seeds_list:
    print(f"Running experiment with seed: {seeds}")

    # --- Synthetic data generation (competing risks under a HAC copula) ---
    # The data-generating RNG is fixed, so every seed sees the SAME dataset;
    # only the train/val/test split below varies with `seeds`.
    copula_form = 'Clayton'
    sample_size = 80000
    rng = np.random.default_rng(142857)
    X, observed_time, event_indicator, _, _, beta_e1, beta_e2, beta_e3 = linear_dgp_hac(
        copula_name=copula_form, covariate_dim=10, theta=1, sample_size=sample_size, rng=rng
    )

    # --- Split: 64% train / 16% val / 20% test, stratified on the event label ---
    X_train, X_test, y_train, y_test, indicator_train, indicator_test = train_test_split(
        X, observed_time, event_indicator, test_size=0.2, stratify=event_indicator, random_state=seeds
    )
    X_train, X_val, y_train, y_val, indicator_train, indicator_val = train_test_split(
        X_train, y_train, indicator_train, test_size=0.2, stratify=indicator_train, random_state=seeds
    )

    # --- Move every split onto the target device as float64 tensors ---
    covariate_tensor_train = torch.tensor(X_train, dtype=torch.float64).to(device)
    covariate_tensor_val = torch.tensor(X_val, dtype=torch.float64).to(device)
    covariate_tensor_test = torch.tensor(X_test, dtype=torch.float64).to(device)
    times_tensor_train = torch.tensor(y_train, dtype=torch.float64).to(device)
    event_indicator_tensor_train = torch.tensor(indicator_train, dtype=torch.float64).to(device)
    times_tensor_val = torch.tensor(y_val, dtype=torch.float64).to(device)
    event_indicator_tensor_val = torch.tensor(indicator_val, dtype=torch.float64).to(device)
    times_tensor_test = torch.tensor(y_test, dtype=torch.float64).to(device)
    event_indicator_tensor_test = torch.tensor(indicator_test, dtype=torch.float64).to(device)

    # --- Model: symmetric 4D HACSurv with a shared covariate embedding ---
    phi = MixExpPhiStochastic(device)
    model = HACSurv_4D_Sym_shared(phi, device=device, num_features=10, tol=1e-14, hidden_size=100).to(device)
    # NOTE: the copula generator (model.phi) is deliberately excluded from the
    # optimizer — only the marginal survival networks and the shared embedding
    # are trained here.
    optimizer_out = optim.Adam([
        {"params": model.shared_embedding.parameters(), "lr": 1e-4},
        {"params": model.sumo_e1.parameters(), "lr": 1e-4},
        {"params": model.sumo_e2.parameters(), "lr": 1e-4},
        {"params": model.sumo_e3.parameters(), "lr": 1e-4},
        {"params": model.sumo_c.parameters(), "lr": 1e-4},
        # {"params": model.phi.parameters(), "lr": 8e-4},
    ], weight_decay=0)

    from datetime import datetime

    def current_time():
        # Timestamp suffix used to make checkpoint filenames unique.
        return datetime.now().strftime('%Y%m%d_%H%M%S')

    # --- Training loop with best-checkpoint saving and early stopping ---
    best_val_loglikelihood = float('-inf')
    epochs_no_improve = 0
    num_epochs = 10000
    early_stop_epochs = 1600
    base_path = "./Competing_SYN/checkpoint"
    os.makedirs(base_path, exist_ok=True)  # fix: torch.save below fails if the directory is missing
    best_model_filename = ""
    for epoch in range(num_epochs):
        model.phi.resample_M(100)
        optimizer_out.zero_grad()
        logloss = model(covariate_tensor_train, times_tensor_train, event_indicator_tensor_train, max_iter=1000)
        # Maximize the log-likelihood by descending on its negative.
        (-logloss).backward(retain_graph=True)
        optimizer_out.step()

        # Validate every 80 epochs.
        if epoch % 80 == 0:
            model.eval()
            # NOTE(review): the validation pass is intentionally NOT wrapped in
            # torch.no_grad() — the model may differentiate its copula generator
            # internally via autograd; confirm before disabling gradients here.
            val_loglikelihood = model(covariate_tensor_val, times_tensor_val, event_indicator_tensor_val, max_iter=1000)
            print(f"Epoch {epoch}: Train loglikelihood {logloss.item()}, Val likelihood {val_loglikelihood.item()}")
            # Only count an improvement if validation log-likelihood beats the
            # previous best by more than 1 (a coarse min-delta threshold).
            if val_loglikelihood > (best_val_loglikelihood + 1):
                # fix: store a plain float, not a live tensor — a tensor would
                # pin the (retained) autograd graph in memory across epochs.
                best_val_loglikelihood = val_loglikelihood.item()
                if best_model_filename:
                    # Drop the stale checkpoint before writing the new best.
                    os.remove(os.path.join(base_path, best_model_filename))
                best_model_filename = f"SurvL1_Independent_BestModel_loglik_{best_val_loglikelihood:.4f}_{current_time()}_seed{seeds}.pth"
                torch.save(model.state_dict(), os.path.join(base_path, best_model_filename))
                epochs_no_improve = 0
                print('Best model updated and saved.')
            else:
                # fix: count actual epochs since the last improvement — checks
                # happen every 80 epochs (the original added 100 per check,
                # which triggered early stop after 1280 real epochs instead of
                # the intended early_stop_epochs=1600).
                epochs_no_improve += 80
            # Early stopping once no improvement for early_stop_epochs epochs.
            if epochs_no_improve >= early_stop_epochs:
                print(f'Early stopping triggered at epoch: {epoch}')
                break
            model.train()

    # --- Evaluation: reload the best checkpoint and score the test set ---
    model.load_state_dict(torch.load(os.path.join(base_path, best_model_filename)))
    model.eval()
    test_loglikelihood = model(covariate_tensor_test, times_tensor_test, event_indicator_tensor_test, max_iter=1000)
    print(f"Test loglikelihood: {test_loglikelihood.item()}")

    # Ground-truth marginal models; shape/scale/coeff mirror the synthetic DGP.
    truth_model1 = Weibull_linear(num_feature=X_test.shape[1], shape=6, scale=15, device=torch.device("cpu"), coeff=beta_e1)
    truth_model2 = Weibull_linear(num_feature=X_test.shape[1], shape=5, scale=14, device=torch.device("cpu"), coeff=beta_e2)
    truth_model3 = Weibull_linear(num_feature=X_test.shape[1], shape=4, scale=19, device=torch.device("cpu"), coeff=beta_e3)

    # L1 distance between true and estimated survival curves per competing
    # event, evaluated on a 1000-point time grid spanning the test horizon.
    steps = np.linspace(y_test.min(), y_test.max(), 1000)
    survival_l1 = [
        surv_diff(truth_model, model, X_test, steps, event_index)
        for event_index, truth_model in enumerate([truth_model1, truth_model2, truth_model3])
    ]

    # Per-seed report.
    for i, performance in enumerate(survival_l1, 1):
        print(f"Survival L1 difference for Model {i} with seed {seeds}: {performance}")

    # Accumulate for the cross-seed summary printed after the loop.
    all_survival_l1.append({
        'seed': seeds,
        'survival_l1': survival_l1
    })
# Cross-seed summary: print every recorded survival-L1 difference.
# fix: corrected user-facing typo "Copual" -> "Copula" in the header line.
print("\nAll survival L1 differences Independent Copula:")
for result in all_survival_l1:
    seed = result['seed']
    survival_l1 = result['survival_l1']
    print(f"\nSeed: {seed}")
    for i, performance in enumerate(survival_l1, 1):
        print(f"Survival L1 difference for Model {i}: {performance}")