@@ -828,7 +828,7 @@ def __init__(self, args):
     def log(self, logs, step):
         pass

-    def close(self, model_path):
+    def close(self, model_path, early_stop):
         pass

 class NeptuneLogger:
@@ -859,7 +859,8 @@ def log(self, logs, step):
     def upload_model(self, model_path):
         self.neptune['model'].track_files(model_path)

-    def close(self, model_path):
+    def close(self, model_path, early_stop):
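+        # Record whether the run was stopped early as run metadata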
+        self.neptune['early_stop'] = early_stop
         if self.should_upload_model:
             self.upload_model(model_path)
         self.neptune.stop()
@@ -894,7 +895,8 @@ def upload_model(self, model_path):
         artifact.add_file(model_path)
         self.wandb.run.log_artifact(artifact)

-    def close(self, model_path):
+    def close(self, model_path, early_stop):
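+        # Record the early-stop flag in the run summary before finishing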
+        self.wandb.run.summary['early_stop'] = early_stop
         if self.should_upload_model:
             self.upload_model(model_path)
         self.wandb.finish()
@@ -905,7 +907,7 @@ def download(self):
         model_file = max(os.listdir(data_dir))
         return f'{data_dir}/{model_file}'

-def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_stop_early=None):
+def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop_fn=None):
     args = args or load_config(env_name)

     # Assume TorchRun DDP is used if LOCAL_RANK is set
@@ -944,7 +946,10 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_sto
     train_config = {**args['train'], 'env': env_name}
     pufferl = PuffeRL(train_config, vecenv, policy, logger)

+    # Sweeps need data from early-stopped runs, so start logging once steps
+    # pass 20% of total_timesteps, capped at 100M
+    logging_threshold = min(0.20 * train_config['total_timesteps'], 100_000_000)
     all_logs = []
+
     while pufferl.global_step < train_config['total_timesteps']:
         if train_config['device'] == 'cuda':
             torch.compiler.cudagraph_mark_step_begin()
@@ -954,12 +959,19 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_sto
         logs = pufferl.train()

         if logs is not None:
-            if pufferl.global_step > 0.20 * train_config['total_timesteps']:
+            should_stop_early = False
+            if early_stop_fn is not None:
+                should_stop_early = early_stop_fn(logs)
+                # This is hacky, but we need to see whether the threshold looks reasonable
+                if 'early_stop_threshold' in logs:
+                    pufferl.logger.log({'environment/early_stop_threshold': logs['early_stop_threshold']}, logs['agent_steps'])
+
+            if pufferl.global_step > logging_threshold:
                 all_logs.append(logs)

-            if should_stop_early is not None and should_stop_early(logs):
+            if should_stop_early:
                 model_path = pufferl.close()
-                pufferl.logger.close(model_path)
+                pufferl.logger.close(model_path, early_stop=True)
                 return all_logs

     # Final eval. You can reset the env here, but depending on
@@ -976,7 +988,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_sto

     pufferl.print_dashboard()
     model_path = pufferl.close()
-    pufferl.logger.close(model_path)
+    pufferl.logger.close(model_path, early_stop=False)
     return all_logs

 def eval(env_name, args=None, vecenv=None, policy=None):
@@ -1053,6 +1065,30 @@ def sweep(args=None, env_name=None):
     sweep = sweep_cls(args['sweep'])
     points_per_run = args['sweep']['downsample']
     target_key = f'environment/{args["sweep"]["metric"]}'
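+    # Mean over the last 30 logged metric values smooths out per-log noise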
+    running_target_buffer = deque(maxlen=30)
+
+    def stop_if_perf_below(logs):
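+        # Early-stop hook for train(): always stop on NaN loss; under the
+        # Protein sweep, also stop runs whose smoothed metric falls below
+        # the sweep's cost-dependent threshold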
+        if stop_if_loss_nan(logs):
+            logs['is_loss_nan'] = True
+            return True
+
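+        # Performance-based early stopping is only implemented for the Protein sweep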
+        if method != 'Protein':
+            return False
+
+        if 'uptime' in logs and target_key in logs:
+            metric_val, cost = logs[target_key], logs['uptime']
+            running_target_buffer.append(metric_val)
+            target_running_mean = np.mean(running_target_buffer)
+
+            # If the metric distribution is a percentile, the threshold is also logit-transformed
+            threshold = sweep.get_early_stop_threshold(cost)
+            logs['early_stop_threshold'] = max(threshold, -5)  # clipped for visualization
+
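+            # Take the max of the smoothed and latest metric so one strong
+            # recent reading keeps the run alive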
+            if sweep.should_stop(max(target_running_mean, metric_val), cost):
+                logs['is_loss_nan'] = False
+                return True
+        return False
+
     for i in range(args['max_runs']):
         seed = time.time_ns() & 0xFFFFFFFF
         random.seed(seed)
@@ -1063,7 +1099,7 @@ def sweep(args=None, env_name=None):
         if i > 0:
             sweep.suggest(args)

-        all_logs = train(env_name, args=args, should_stop_early=stop_if_loss_nan)
+        all_logs = train(env_name, args=args, early_stop_fn=stop_if_perf_below)
         all_logs = [e for e in all_logs if target_key in e]

         if not all_logs:
@@ -1076,7 +1112,8 @@ def sweep(args=None, env_name=None):
         costs = downsample([log['uptime'] for log in all_logs], points_per_run)
         timesteps = downsample([log['agent_steps'] for log in all_logs], points_per_run)

-        if len(timesteps) > 0 and timesteps[-1] < 0.7 * total_timesteps:  # 0.7 is arbitrary
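+        # Diverged runs are now flagged explicitly by the early-stop hook
+        # instead of being inferred from truncated timesteps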
+        is_final_loss_nan = all_logs[-1].get('is_loss_nan', False)
+        if is_final_loss_nan:
             s = scores.pop()
             c = costs.pop()
             args['train']['total_timesteps'] = timesteps.pop()