Skip to content

Commit 33661e7

Browse files
authored
Merge pull request #425 from kywch/stop-train
Adaptive early-stop thresholding in sweep
2 parents 8f88d9f + 12b284d commit 33661e7

5 files changed

Lines changed: 258 additions & 39 deletions

File tree

pufferlib/config/default.ini

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,13 @@ prio_beta0 = 0.2
6363
[sweep]
6464
method = Protein
6565
metric = score
66+
metric_distribution = linear
6667
goal = maximize
6768
max_suggestion_cost = 3600
6869
downsample = 5
6970
use_gpu = True
7071
prune_pareto = True
72+
early_stop_quantile = 0.3
7173

7274
#[sweep.vec.num_envs]
7375
#distribution = uniform_pow2
@@ -100,6 +102,12 @@ min = 0.00001
100102
max = 0.1
101103
scale = 0.5
102104

105+
[sweep.train.min_lr_ratio]
106+
distribution = uniform
107+
min = 0.0
108+
max = 0.5
109+
scale = auto
110+
103111
[sweep.train.ent_coef]
104112
distribution = log_normal
105113
min = 0.00001

pufferlib/config/ocean/breakout.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ vf_coef = 1.6832989594296321
5353
vtrace_c_clip = 2.878171091654008
5454
vtrace_rho_clip = 0.7876748061547312
5555

56+
[sweep]
57+
5658
[sweep.train.total_timesteps]
5759
distribution = log_normal
5860
min = 3e7

pufferlib/config/ocean/tower_climb.ini

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,53 +5,72 @@ policy_name = TowerClimb
55
rnn_name = TowerClimbLSTM
66

77
[vec]
8-
num_envs = 8
8+
num_envs = 4
99

1010
[env]
1111
num_envs = 1024
12-
num_maps = 50
13-
reward_climb_row = 0.636873185634613
14-
reward_fall_row = -0.15898257493972778
15-
reward_illegal_move = -0.003928301855921745
16-
reward_move_block = 0.235064297914505
12+
num_maps = 200
13+
reward_climb_row = 0.27
14+
reward_fall_row = 0
15+
reward_illegal_move = 0
16+
reward_move_block = 0.18
1717

1818
[train]
19-
total_timesteps = 150_000_000
20-
#gamma = 0.98
21-
#learning_rate = 0.05
22-
minibatch_size = 32768
19+
# https://wandb.ai/kywch/pufferlib/runs/8r3l9l1h?nw=nwuserkywch
20+
total_timesteps = 30_000_000
21+
anneal_lr = True
22+
batch_size = auto
23+
bptt_horizon = 64
24+
minibatch_size = 65536
25+
26+
clip_coef = 1.0
27+
ent_coef = 0.2
28+
gae_lambda = 0.96
29+
gamma = 0.92
30+
vf_clip_coef = 0.1
31+
vf_coef = 0.34
32+
33+
learning_rate = 0.029
34+
max_grad_norm = 0.8
35+
36+
adam_beta1 = 0.89
37+
adam_beta2 = 0.999
38+
adam_eps = 2e-11
39+
prio_alpha = 0.86
40+
prio_beta0 = 0.30
41+
vtrace_c_clip = 0.92
42+
vtrace_rho_clip = 1.44
43+
44+
[sweep]
45+
metric = perf
46+
metric_distribution = percentile
2347

2448
[sweep.train.total_timesteps]
2549
distribution = uniform
26-
min = 50_000_000
50+
min = 10_000_000
2751
max = 200_000_000
28-
mean = 100_000_000
2952
scale = 0.5
3053

3154
[sweep.env.reward_climb_row]
3255
distribution = uniform
3356
min = 0.0
3457
max = 1.0
35-
mean = 0.5
3658
scale = auto
3759

3860
[sweep.env.reward_fall_row]
3961
distribution = uniform
4062
min = -1.0
4163
max = 0.0
42-
mean = -0.5
4364
scale = auto
4465

4566
[sweep.env.reward_illegal_move]
4667
distribution = uniform
4768
min = -1e-2
4869
max = -1e-4
49-
mean = -1e-3
5070
scale = auto
5171

5272
[sweep.env.reward_move_block]
5373
distribution = uniform
5474
min = 0.0
5575
max = 1.0
56-
mean = 0.5
5776
scale = auto

pufferlib/pufferl.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def __init__(self, args):
828828
def log(self, logs, step):
829829
pass
830830

831-
def close(self, model_path):
831+
def close(self, model_path, early_stop):
832832
pass
833833

834834
class NeptuneLogger:
@@ -859,7 +859,8 @@ def log(self, logs, step):
859859
def upload_model(self, model_path):
860860
self.neptune['model'].track_files(model_path)
861861

862-
def close(self, model_path):
862+
def close(self, model_path, early_stop):
863+
self.neptune['early_stop'] = early_stop
863864
if self.should_upload_model:
864865
self.upload_model(model_path)
865866
self.neptune.stop()
@@ -894,7 +895,8 @@ def upload_model(self, model_path):
894895
artifact.add_file(model_path)
895896
self.wandb.run.log_artifact(artifact)
896897

897-
def close(self, model_path):
898+
def close(self, model_path, early_stop):
899+
self.wandb.run.summary['early_stop'] = early_stop
898900
if self.should_upload_model:
899901
self.upload_model(model_path)
900902
self.wandb.finish()
@@ -905,7 +907,7 @@ def download(self):
905907
model_file = max(os.listdir(data_dir))
906908
return f'{data_dir}/{model_file}'
907909

908-
def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_stop_early=None):
910+
def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop_fn=None):
909911
args = args or load_config(env_name)
910912

911913
# Assume TorchRun DDP is used if LOCAL_RANK is set
@@ -944,7 +946,10 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_sto
944946
train_config = { **args['train'], 'env': env_name }
945947
pufferl = PuffeRL(train_config, vecenv, policy, logger)
946948

949+
# Sweep needs data even from early-stopped runs, so start collecting logs once
# steps exceed min(20% of total_timesteps, 100M)
950+
logging_threshold = min(0.20*train_config['total_timesteps'], 100_000_000)
947951
all_logs = []
952+
948953
while pufferl.global_step < train_config['total_timesteps']:
949954
if train_config['device'] == 'cuda':
950955
torch.compiler.cudagraph_mark_step_begin()
@@ -954,12 +959,19 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_sto
954959
logs = pufferl.train()
955960

956961
if logs is not None:
957-
if pufferl.global_step > 0.20*train_config['total_timesteps']:
962+
should_stop_early = False
963+
if early_stop_fn is not None:
964+
should_stop_early = early_stop_fn(logs)
965+
# This is hacky, but we need to check whether the threshold looks reasonable
966+
if 'early_stop_threshold' in logs:
967+
pufferl.logger.log({'environment/early_stop_threshold': logs['early_stop_threshold']}, logs['agent_steps'])
968+
969+
if pufferl.global_step > logging_threshold:
958970
all_logs.append(logs)
959971

960-
if should_stop_early is not None and should_stop_early(logs):
972+
if should_stop_early:
961973
model_path = pufferl.close()
962-
pufferl.logger.close(model_path)
974+
pufferl.logger.close(model_path, early_stop=True)
963975
return all_logs
964976

965977
# Final eval. You can reset the env here, but depending on
@@ -976,7 +988,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_sto
976988

977989
pufferl.print_dashboard()
978990
model_path = pufferl.close()
979-
pufferl.logger.close(model_path)
991+
pufferl.logger.close(model_path, early_stop=False)
980992
return all_logs
981993

982994
def eval(env_name, args=None, vecenv=None, policy=None):
@@ -1053,6 +1065,30 @@ def sweep(args=None, env_name=None):
10531065
sweep = sweep_cls(args['sweep'])
10541066
points_per_run = args['sweep']['downsample']
10551067
target_key = f'environment/{args["sweep"]["metric"]}'
1068+
running_target_buffer = deque(maxlen=30)
1069+
1070+
def stop_if_perf_below(logs):
1071+
if stop_if_loss_nan(logs):
1072+
logs['is_loss_nan'] = True
1073+
return True
1074+
1075+
if method != 'Protein':
1076+
return False
1077+
1078+
if ('uptime' in logs and target_key in logs):
1079+
metric_val, cost = logs[target_key], logs['uptime']
1080+
running_target_buffer.append(metric_val)
1081+
target_running_mean = np.mean(running_target_buffer)
1082+
1083+
# If the metric distribution is percentile, the threshold is also logit-transformed
1084+
threshold = sweep.get_early_stop_threshold(cost)
1085+
logs['early_stop_threshold'] = max(threshold, -5) # clipping for visualization
1086+
1087+
if sweep.should_stop(max(target_running_mean, metric_val), cost):
1088+
logs['is_loss_nan'] = False
1089+
return True
1090+
return False
1091+
10561092
for i in range(args['max_runs']):
10571093
seed = time.time_ns() & 0xFFFFFFFF
10581094
random.seed(seed)
@@ -1063,7 +1099,7 @@ def sweep(args=None, env_name=None):
10631099
if i > 0:
10641100
sweep.suggest(args)
10651101

1066-
all_logs = train(env_name, args=args, should_stop_early=stop_if_loss_nan)
1102+
all_logs = train(env_name, args=args, early_stop_fn=stop_if_perf_below)
10671103
all_logs = [e for e in all_logs if target_key in e]
10681104

10691105
if not all_logs:
@@ -1076,7 +1112,8 @@ def sweep(args=None, env_name=None):
10761112
costs = downsample([log['uptime'] for log in all_logs], points_per_run)
10771113
timesteps = downsample([log['agent_steps'] for log in all_logs], points_per_run)
10781114

1079-
if len(timesteps) > 0 and timesteps[-1] < 0.7 * total_timesteps: # 0.7 is arbitrary
1115+
is_final_loss_nan = all_logs[-1].get('is_loss_nan', False)
1116+
if is_final_loss_nan:
10801117
s = scores.pop()
10811118
c = costs.pop()
10821119
args['train']['total_timesteps'] = timesteps.pop()

0 commit comments

Comments
 (0)