-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathfsdp_workers.py
More file actions
1728 lines (1419 loc) · 84.1 KB
/
fsdp_workers.py
File metadata and controls
1728 lines (1419 loc) · 84.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""
import os
import logging
import warnings
import ray
import torch
import torch.distributed
from omegaconf import DictConfig, open_dict
from transformers import AutoModelForCausalLM
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.utils.model import compute_position_id_with_mask
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, load_fsdp_grad, offload_fsdp_grad, init_fn, get_init_weight_context_manager, get_fsdp_wrap_policy_vla
from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_param_and_grad, load_fsdp_optimizer, load_fsdp_param_and_grad
from verl.utils.import_utils import import_external_libs
from verl.utils.debug import log_gpu_memory_usage
import verl.utils.hdfs_io as hdfs_io
from verl.utils import hf_tokenizer
from ..trainer.ppo import core_algos
from verl.utils.py_functional import append_to_dict
from codetiming import Timer
from verl.utils.openvla_utils import update_auto_map , check_model_logic_mismatch
from peft import LoraConfig, PeftModel, get_peft_model, TaskType
import json
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
def convert_to_regular_types(obj):
    """Recursively convert Hydra/OmegaConf configs into plain Python containers.

    ``DictConfig`` and ``dict`` become ``dict``; ``ListConfig``, ``list`` and
    ``tuple`` become ``list``. Containers are converted at every nesting depth,
    so no OmegaConf node survives inside the result. Any other object is
    returned unchanged.
    """
    from omegaconf import ListConfig, DictConfig

    if isinstance(obj, (DictConfig, dict)):
        return {k: convert_to_regular_types(v) for k, v in obj.items()}
    if isinstance(obj, (ListConfig, list, tuple)):
        # Recurse for ListConfig as well: iterating it can yield nested
        # DictConfig/ListConfig nodes that a shallow list(obj) would leak.
        return [convert_to_regular_types(x) for x in obj]
    return obj
class RobActorRolloutRefWorker(Worker):
    """
    FSDP worker for OpenVLA-family policies ('openvla' or 'openvla-oft', selected by config.model.vla).

    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
    or a hybrid engine based on the config.rollout. When ``config.model.lora_rank > 0`` LoRA adapters are
    injected into the model before it is wrapped with FSDP.
    """

    def __init__(self, config: DictConfig, role: str):
        """Initialise the process group, the FSDP device mesh and the role/offload flags.

        Args:
            config: Worker config; its ``model``, ``actor``, ``rollout`` and ``ref`` sections are read,
                and the per-role batch sizes in it are divided by the mesh size in place.
            role: One of 'actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref'.
        """
        super().__init__()
        self.config = config
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")

        # build device mesh
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh
        # TODO(sgm): support FSDP hybrid shard for larger model
        self.device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])

        self._is_lora = self.config.model.get('lora_rank', 0) > 0
        self.role = role
        assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']

        self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
        self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
        self._is_ref = self.role in ['ref', 'actor_rollout_ref']

        self._is_offload_param = False
        self._is_offload_grad = False
        self._is_offload_optimizer = False
        if self._is_actor:
            self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False)
            self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False)
            self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False)
        elif self._is_ref:
            # TODO: it seems that manual offload is slowly than FSDP offload
            self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)

        # normalize config: divide the configured batch sizes by the number of ranks in the mesh
        dp_size = self.device_mesh.shape[0]
        if self._is_actor:
            self.config.actor.ppo_mini_batch_size //= dp_size
            self.config.actor.ppo_micro_batch_size //= dp_size
        if self._is_rollout:
            self.config.rollout.log_prob_micro_batch_size //= dp_size
        if self._is_ref:
            self.config.ref.log_prob_micro_batch_size //= dp_size

    def _register_vla_classes(self, local_path):
        """Register the Prismatic/OpenVLA classes for ``config.model.vla`` with the HF Auto* factories.

        Rank 0 additionally runs ``update_auto_map`` and ``check_model_logic_mismatch`` on
        ``local_path``; all ranks then wait on a barrier.

        Raises:
            ValueError: if ``config.model.vla`` is neither 'openvla-oft' nor 'openvla'.
        """
        from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor

        vla = self.config.model.vla
        if vla == "openvla-oft":
            from verl.utils.vla_utils.openvla_oft.configuration_prismatic import OpenVLAConfig
            from verl.utils.vla_utils.openvla_oft.modeling_prismatic import OpenVLAForActionPrediction
            from verl.utils.vla_utils.openvla_oft.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
        elif vla == "openvla":
            from verl.utils.vla_utils.openvla.configuration_prismatic import OpenVLAConfig
            from verl.utils.vla_utils.openvla.modeling_prismatic import OpenVLAForActionPrediction
            from verl.utils.vla_utils.openvla.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
        else:
            # Previously an unknown value surfaced much later as an UnboundLocalError on `actor_module`.
            raise ValueError(f"Unsupported config.model.vla={vla!r}; expected 'openvla-oft' or 'openvla'")

        AutoConfig.register("openvla", OpenVLAConfig)
        AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
        AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
        AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
        if self.rank == 0:
            update_auto_map(local_path)
            check_model_logic_mismatch(local_path)
        torch.distributed.barrier()

    def _build_model_optimizer(self,
                               model_path,
                               fsdp_config,
                               optim_config,
                               override_model_config,
                               enable_gradient_checkpointing=False,
                               trust_remote_code=False):
        """Load the VLA from ``model_path``, optionally add LoRA, wrap it in FSDP and build its optimizer.

        Args:
            model_path: Local or HDFS path of the HF checkpoint.
            fsdp_config: FSDP section (``model_dtype``, ``mixed_precision``, ``wrap_policy``).
            optim_config: Optimizer section, or ``None`` when no optimizer should be built.
            override_model_config: Extra kwargs applied on top of the HF model config.
            enable_gradient_checkpointing: Whether to enable HF gradient checkpointing.
            trust_remote_code: Forwarded to the tokenizer and ``AutoConfig`` loaders.

        Returns:
            ``(fsdp_module, optimizer, lr_scheduler, model_config)``; optimizer and scheduler are
            ``None`` unless this worker is an actor and ``optim_config`` is given.

        Raises:
            ValueError: if ``config.model.vla`` is not a supported VLA family.
        """
        from verl.utils.model import print_model_size, update_model_config
        from verl.utils.torch_dtypes import PrecisionType
        from transformers import AutoConfig, AutoModelForVision2Seq
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
        from torch import optim

        log_gpu_memory_usage('Before init from HF AutoModel', logger=logger)
        local_path = copy_local_path_from_hdfs(model_path)
        vla = self.config.model.vla
        self._register_vla_classes(local_path)

        # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
        # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code, model=vla)

        # NOTE(review): with role 'actor_rollout_ref' the reference model is also built in fp32 and
        # without mixed precision, because _is_actor/_is_ref describe the worker, not this model.
        torch_dtype = fsdp_config.get('model_dtype', None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        # override model kwargs
        actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        if self.config.model.use_remove_padding:
            from verl.models.registry import check_model_support_rmpad
            check_model_support_rmpad(actor_model_config.model_type)

        override_config_kwargs = {
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
            'use_proprio': bool(self.config.rollout.use_proprio),
            # proprio_dim is set regardless of use_proprio, always from the action token length
            'proprio_dim': self.config.model.action_token_len,
        }
        override_config_kwargs.update(override_model_config)
        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
        if self.rank == 0:
            print(f'Model config after override: {actor_model_config}')

        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if vla == "openvla-oft":
                actor_module = AutoModelForVision2Seq.from_pretrained(
                    pretrained_model_name_or_path=local_path,
                    torch_dtype=torch_dtype,
                    # attn_implementation="flash_attention_2" is not passed for OFT
                    config=actor_model_config,
                    trust_remote_code=True,
                )
                if self.config.rollout.use_proprio and self.config.model.resume == False:  # noqa: E712
                    # Load proprio projector weights if available
                    actor_module.load_proprio_projector_weights(local_path)
                    print("******Loaded pre-trained proprio projector weights*********")
                actor_module.vision_backbone.set_num_images_in_input(self.config.actor.num_images_in_input)
                # Attach the checkpoint's dataset statistics as `norm_stats` when the file exists.
                dataset_statistics_path = os.path.join(local_path, "dataset_statistics.json")
                if os.path.isfile(dataset_statistics_path):
                    with open(dataset_statistics_path, "r") as f:
                        actor_module.norm_stats = json.load(f)
                else:
                    print(
                        "WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
                        "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint.\n"
                        "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
                    )
            else:  # "openvla": any other value was rejected by _register_vla_classes
                actor_module = AutoModelForVision2Seq.from_pretrained(
                    pretrained_model_name_or_path=local_path,
                    torch_dtype=torch_dtype,
                    attn_implementation="flash_attention_2",
                    config=actor_model_config,
                    trust_remote_code=True,
                )
            # cast any parameters that were not loaded in torch_dtype
            actor_module.to(torch_dtype)

            if enable_gradient_checkpointing:
                actor_module.gradient_checkpointing_enable()

            if self._is_lora:
                print("Applying LoRA to actor module")
                lora_config = {
                    'r': self.config.model.lora_rank,
                    'lora_alpha': self.config.model.lora_alpha,
                    'lora_dropout': 0,
                    'target_modules': convert_to_regular_types(self.config.model.target_modules),
                    'init_lora_weights': "gaussian",
                }
                actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
                actor_module.print_trainable_parameters()
        torch.distributed.barrier()

        if self.rank == 0:
            print_model_size(actor_module)
        log_gpu_memory_usage('After init from HF AutoModel', logger=logger)

        # We wrap FSDP for rollout as well
        mixed_precision_config = fsdp_config.get('mixed_precision', None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32
        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
        if self._is_ref:
            mixed_precision = None

        auto_wrap_policy = get_fsdp_wrap_policy_vla(module=actor_module,
                                                    config=fsdp_config.get('wrap_policy', None),
                                                    is_lora=self._is_lora)
        print(f'wrap_policy: {auto_wrap_policy}')

        # TODO(sgm): support hybrid
        if auto_wrap_policy is None:
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        else:
            sharding_strategy = ShardingStrategy.FULL_SHARD

        # TODO: add transformer policy
        actor_module_fsdp = FSDP(
            actor_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,  # zero3
            mixed_precision=mixed_precision,
            sync_module_states=True,
            device_mesh=self.device_mesh)
        log_gpu_memory_usage('After Actor FSDP init', logger=logger)

        # TODO: add more optimizer args into config
        # The reference model is built with optim_config=None; for the 'actor_rollout_ref' role
        # _is_actor is still True, so guard on optim_config to avoid `None.lr`.
        if self._is_actor and optim_config is not None:
            from verl.utils.torch_functional import get_constant_schedule_with_warmup
            actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(),
                                          lr=optim_config.lr,
                                          betas=optim_config.get('betas', (0.9, 0.999)),
                                          weight_decay=optim_config.get('weight_decay', 1e-2))
            total_steps = optim_config.get('total_training_steps', 0)
            num_warmup_steps_ratio = optim_config.get('lr_warmup_steps_ratio', 0.)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
            print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')
            actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer,
                                                                   num_warmup_steps=num_warmup_steps)
        else:
            actor_optimizer = None
            actor_lr_scheduler = None

        log_gpu_memory_usage('After actor optimizer init', logger=logger)
        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

    def _build_rollout(self):
        """Build the rollout engine and its sharding manager.

        Returns:
            ``(rollout, sharding_manager)``.

        Raises:
            ValueError: if ``config.rollout.name`` is not 'hf' (the vLLM path is not supported here).
        """
        rollout_name = self.config.rollout.name
        if rollout_name != 'hf':
            raise ValueError(f"Unsupported rollout backend {rollout_name!r} for RobActorRolloutRefWorker; "
                             f"only 'hf' is supported")
        from verl.workers.rollout import RobHFRollout
        from verl.workers.hybrid_engine import BaseShardingManager
        rollout = RobHFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
        # TODO: presumably a pass-through manager, since the HF rollout runs on the FSDP module directly.
        sharding_manager = BaseShardingManager()
        return rollout, sharding_manager

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the model(s) this worker's role needs: actor/rollout FSDP model, rollout engine, reference policy."""
        from verl.workers.actor import RobDataParallelPPOActor
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get('external_lib', None))

        from omegaconf import OmegaConf
        override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

        if self._is_actor or self._is_rollout:
            # we need the model for actor and rollout
            if self._is_actor:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()
            # trust_remote_code is forced to True (config.model.trust_remote_code is not consulted)
            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=fsdp_config,
                optim_config=optim_config,
                override_model_config=override_model_config,
                enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
                trust_remote_code=True)

            # get the original unwrapped module
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_param:
                # param is require during state_dict in sharding manager
                offload_fsdp_grad(module=self.actor_module_fsdp)
                log_gpu_memory_usage('After offload actor grad during init', logger=logger)
            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)

        # load from checkpoint
        if self._is_actor:
            OmegaConf.set_struct(self.config.actor, True)
            self.actor = RobDataParallelPPOActor(config=self.config.actor,
                                                 actor_module=self.actor_module_fsdp,
                                                 actor_optimizer=self.actor_optimizer)

        if self._is_rollout:
            self.rollout, self.sharding_manager = self._build_rollout()

        if self._is_ref:
            self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
                                                               fsdp_config=self.config.ref.fsdp_config,
                                                               optim_config=None,
                                                               override_model_config=override_model_config,
                                                               trust_remote_code=True)[0]
            if self._is_offload_param:
                offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

            OmegaConf.set_struct(self.config.ref, True)
            self.ref_policy = RobDataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        """Run one PPO policy update on ``data`` and return its metrics (plus current LR) in ``meta_info``."""
        # NOTE: `data` is not moved to GPU here; presumably the actor moves micro-batches itself — TODO confirm.
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

        log_gpu_memory_usage('Before update policy', logger=logger)
        metrics = self.actor.update_policy(data=data)

        self.actor_lr_scheduler.step()
        lr = self.actor_lr_scheduler.get_last_lr()[0]
        metrics['actor/lr(1e-4)'] = lr * 1e4

        log_gpu_memory_usage('After update policy', logger=logger)

        # TODO: here, we should return all metrics
        output = DataProto(meta_info={'metrics': metrics})
        output = output.to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_entropy(self, data: DataProto):
        """Compute entropy metrics of the actor on ``data``; returns them in ``meta_info['metrics']``."""
        data = data.to('cuda')
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        data.batch = data.batch.cuda()
        log_gpu_memory_usage('Before compute entropy', logger=logger)
        # NOTE(review): the keyword `bacth_data` must match the actor's parameter name; verify before renaming.
        metrics = self.actor.compute_entropy(bacth_data=data)
        log_gpu_memory_usage('After compute entropy', logger=logger)

        # TODO: here, we should return all metrics
        output = DataProto(meta_info={'metrics': metrics})
        output = output.to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts):
        """Roll out ``prompts`` and, for actor workers, attach ``old_log_probs`` unless disabled.

        ``prompts.meta_info['recompute_log_prob']`` (default True) is set to False for validation.
        """
        prompts = prompts.to('cuda')
        # set to False if it is validation
        recompute_log_prob = prompts.meta_info.get('recompute_log_prob', True)

        assert self._is_rollout
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        prompts.batch = prompts.batch.cuda()
        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
        prompts.meta_info.update(meta_info)

        with self.sharding_manager:
            log_gpu_memory_usage('After entering sharding manager', logger=logger)
            prompts = self.sharding_manager.preprocess_data(prompts)
            output = self.rollout.generate_sequences(prompts=prompts)
            log_gpu_memory_usage('After rollout generation', logger=logger)
            output = self.sharding_manager.postprocess_data(output)
        torch.cuda.synchronize()

        if self._is_actor and recompute_log_prob:
            # we should always recompute old_log_probs when it is HybridEngine
            output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
            output.meta_info['temperature'] = self.config.rollout.temperature
            output.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
            output.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
            output.meta_info['pad_token_id'] = self.tokenizer.pad_token_id
            old_log_probs = self.actor.compute_log_prob(data=output)
            output.batch['old_log_probs'] = old_log_probs

        output = output.to('cpu')

        if self._is_offload_param:
            # NOTE(sgm): the grad is already in CPU, only offload param here
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        # clear kv cache
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        log_gpu_memory_usage('After recompute log prob', logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        """Compute reference-policy log-probs for ``data``; returned as the ``ref_log_prob`` tensor."""
        assert self._is_ref

        data = data.to('cuda')

        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.ref_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        micro_batch_size = self.config.ref.log_prob_micro_batch_size
        data.meta_info['micro_batch_size'] = micro_batch_size
        data.meta_info['temperature'] = self.config.rollout.temperature
        data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info['use_dynamic_bsz'] = self.config.ref.log_prob_use_dynamic_bsz
        data.meta_info['pad_token_id'] = self.tokenizer.pad_token_id
        output = self.ref_policy.compute_log_prob(data=data)
        output = DataProto.from_dict(tensors={'ref_log_prob': output})

        output = output.to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None):
        """Save the actor as a HF checkpoint under ``local_path``; must be called on every rank.

        With LoRA, the adapter goes to ``<local_path>/lora_adapter`` and a merged full model to
        ``local_path``; otherwise the full FSDP state dict is gathered to rank 0 and saved.
        If ``hdfs_path`` is given, rank 0 uploads ``local_path`` there.
        """
        assert self._is_actor
        from peft import PeftModel

        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        if self._is_lora and isinstance(self.actor_module, PeftModel):
            self._save_lora_checkpoint(local_path, hdfs_path)
        else:
            # TODO: support DCP and save sharded checkpoints
            self._save_full_checkpoint(local_path, hdfs_path)

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)

    def _save_lora_checkpoint(self, local_path, hdfs_path=None):
        """Save the LoRA adapter, then (rank 0 only) a merged full-weight model and tokenizer."""
        import collections
        import torch.distributed as dist
        import transformers
        from peft import PeftModel
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        is_rank0 = dist.get_rank() == 0
        if is_rank0:
            os.makedirs(local_path, exist_ok=True)
        lora_save_path = os.path.join(local_path, "lora_adapter")

        if isinstance(self.actor_module_fsdp, FSDP):
            # summon_full_params is entered by every rank; only rank 0 writes the adapter.
            with FSDP.summon_full_params(self.actor_module_fsdp, writeback=False, offload_to_cpu=True):
                if is_rank0:
                    lora_params = collections.OrderedDict()
                    model = self.actor_module_fsdp._fsdp_wrapped_module.base_model.model
                    for name, param in model.named_parameters():
                        if ".lora_" in name:
                            # drop FSDP wrapper segments from the parameter name and add the PEFT prefix
                            name = "base_model.model." + name.replace("._fsdp_wrapped_module.", ".")
                            lora_params[name] = param
                    self.actor_module_fsdp.save_pretrained(
                        lora_save_path,
                        state_dict=lora_params,
                        safe_serialization=True
                    )
        else:
            self.actor_module.save_pretrained(lora_save_path, safe_serialization=True)
        dist.barrier()

        if is_rank0:
            print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}")
            # save total model: only rank 0 writes it, so only rank 0 loads the base model on CPU
            # (loading it on every rank multiplies host-memory use by the world size).
            base_path = copy_local_path_from_hdfs(self.config.model.path)
            base_vla = transformers.AutoModelForVision2Seq.from_pretrained(
                base_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map="cpu"
            )
            merged_vla = PeftModel.from_pretrained(base_vla, lora_save_path).merge_and_unload()
            merged_vla.save_pretrained(local_path)
            self.tokenizer.save_pretrained(local_path)
            print(f"Saved merged model at: {local_path}")
            self._upload_checkpoint(local_path, hdfs_path)
        # Wait for merged model to be saved
        dist.barrier()

    def _save_full_checkpoint(self, local_path, hdfs_path=None):
        """Gather the full FSDP state dict to rank 0 and save it with the tokenizer."""
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.actor.actor_module, StateDictType.FULL_STATE_DICT, cfg):
            state_dict = self.actor.actor_module.state_dict()
        if self.rank == 0:
            print(f'Saving actor checkpoint to {local_path}')
            os.makedirs(local_path, exist_ok=True)
            self.actor_module.save_pretrained(local_path, state_dict=state_dict)
            self.tokenizer.save_pretrained(local_path)
            self._upload_checkpoint(local_path, hdfs_path)
        torch.distributed.barrier()

    @staticmethod
    def _upload_checkpoint(local_path, hdfs_path):
        """Copy ``local_path`` to ``hdfs_path`` when one is given; no-op otherwise."""
        if hdfs_path is None:
            return
        print(f'Uploading actor checkpoint to {hdfs_path}')
        hdfs_io.makedirs(hdfs_path, exist_ok=True)
        hdfs_io.copy(src=local_path, dst=hdfs_path)
class ActorRolloutRefWorker(Worker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str):
super().__init__()
self.config = config
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
# build device mesh
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
# TODO(sgm): support FSDP hybrid shard for larger model
self.device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
self.role = role
assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']
self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
self._is_ref = self.role in ['ref', 'actor_rollout_ref']
self._is_offload_param = False
self._is_offload_grad = False
self._is_offload_optimizer = False
if self._is_actor:
self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False)
self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False)
self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False)
elif self._is_ref:
# TODO: it seems that manual offload is slowly than FSDP offload
self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)
# normalize config
if self._is_actor:
self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0]
self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0]
if self._is_rollout:
self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0]
if self._is_ref:
self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0]
def _build_model_optimizer(self,
                           model_path,
                           fsdp_config,
                           optim_config,
                           override_model_config,
                           enable_gradient_checkpointing=False,
                           trust_remote_code=False):
    """Load a HF causal LM, wrap it in FSDP and, for the actor role, build its optimizer.

    Args:
        model_path: model location; resolved to a local path with
            ``copy_local_path_from_hdfs``.
        fsdp_config: FSDP options. Keys read here: ``model_dtype``,
            ``mixed_precision`` (``param_dtype`` / ``reduce_dtype`` /
            ``buffer_dtype``) and ``wrap_policy``.
        optim_config: optimizer options (``lr``, ``betas``, ``weight_decay``,
            ``total_training_steps``, ``lr_warmup_steps_ratio``). Only read when
            ``self._is_actor``; may be ``None`` otherwise.
        override_model_config: mapping merged over the tokenizer-derived token
            ids and applied to the HF model config.
        enable_gradient_checkpointing: call ``gradient_checkpointing_enable()``
            on the HF module before wrapping.
        trust_remote_code: forwarded to the HF tokenizer/config/model loaders.

    Returns:
        ``(fsdp_module, optimizer, lr_scheduler, model_config)``. ``optimizer``
        and ``lr_scheduler`` are ``None`` unless ``self._is_actor``.

    Side effects:
        Sets ``self.tokenizer``. Contains a ``torch.distributed.barrier()``, so
        every rank must call this method.
    """
    from verl.utils.model import print_model_size, update_model_config
    from verl.utils.torch_dtypes import PrecisionType
    from transformers import AutoConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
    from torch import optim

    log_gpu_memory_usage('Before init from HF AutoModel', logger=logger)
    local_path = copy_local_path_from_hdfs(model_path)

    # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
    # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
    self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    # dtype the HF weights are materialised in: fp32 for the actor (so the
    # optimizer state is fp32), bf16 for non-actor copies, unless overridden.
    torch_dtype = fsdp_config.get('model_dtype', None)
    if torch_dtype is None:
        torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
    else:
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

    # override model kwargs
    actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
    if self.config.model.use_remove_padding:
        from verl.models.registry import check_model_support_rmpad
        check_model_support_rmpad(actor_model_config.model_type)

    # Token ids default to the tokenizer's; user-supplied overrides win.
    override_config_kwargs = {
        'bos_token_id': self.tokenizer.bos_token_id,
        'eos_token_id': self.tokenizer.eos_token_id,
        'pad_token_id': self.tokenizer.pad_token_id,
    }
    override_config_kwargs.update(override_model_config)
    update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
    if self.rank == 0:
        print(f'Model config after override: {actor_model_config}')

    # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
    init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)
    with init_context(), warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from liger_kernel.transformers import AutoLigerKernelForCausalLM
        actor_module = AutoLigerKernelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                  torch_dtype=torch_dtype,
                                                                  config=actor_model_config,
                                                                  attn_implementation='flash_attention_2',
                                                                  trust_remote_code=trust_remote_code)
        # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
        actor_module.to(torch_dtype)

        if enable_gradient_checkpointing:
            actor_module.gradient_checkpointing_enable()
    torch.distributed.barrier()

    if self.rank == 0:
        print_model_size(actor_module)

    log_gpu_memory_usage('After init from HF AutoModel', logger=logger)

    # We wrap FSDP for rollout as well
    mixed_precision_config = fsdp_config.get('mixed_precision', None)
    if mixed_precision_config is not None:
        param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
        reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
        buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
    else:
        param_dtype = torch.bfloat16
        reduce_dtype = torch.float32
        buffer_dtype = torch.float32

    mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

    # The reference model is wrapped without an FSDP mixed-precision policy.
    if self._is_ref:
        mixed_precision = None

    auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None))

    if self._is_rollout and self.config.rollout.name == 'hf':
        # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
        auto_wrap_policy = None

    # Rank-0 only, consistent with the other diagnostic prints in this method.
    if self.rank == 0:
        print(f'wrap_policy: {auto_wrap_policy}')

    # TODO(sgm): support hybrid
    if auto_wrap_policy is None:
        sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
    else:
        sharding_strategy = ShardingStrategy.FULL_SHARD

    # TODO: add transformer policy
    actor_module_fsdp = FSDP(
        actor_module,
        param_init_fn=init_fn,
        use_orig_params=False,
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        sharding_strategy=sharding_strategy,  # zero3
        mixed_precision=mixed_precision,
        sync_module_states=True,
        device_mesh=self.device_mesh)

    log_gpu_memory_usage('After Actor FSDP init', logger=logger)

    # TODO: add more optimizer args into config
    if self._is_actor:
        from verl.utils.torch_functional import get_constant_schedule_with_warmup
        actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(),
                                      lr=optim_config.lr,
                                      betas=optim_config.get('betas', (0.9, 0.999)),
                                      weight_decay=optim_config.get('weight_decay', 1e-2))

        total_steps = optim_config.get('total_training_steps', 0)
        num_warmup_steps_ratio = optim_config.get('lr_warmup_steps_ratio', 0.)
        num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        if self.rank == 0:
            print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')

        actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer,
                                                               num_warmup_steps=num_warmup_steps)
    else:
        actor_optimizer = None
        actor_lr_scheduler = None

    log_gpu_memory_usage('After actor optimizer init', logger=logger)

    return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
def _build_rollout(self):
    """Create the rollout engine and the sharding manager that feeds it weights.

    Backend is selected by ``self.config.rollout.name``:
      * ``'hf'``   -- ``HFRollout`` over ``self.actor_module_fsdp`` with a
                      ``BaseShardingManager``.
      * ``'vllm'`` -- ``vLLMRollout`` plus an ``FSDPVLLMShardingManager`` that
                      syncs weights from ``self.actor_module_fsdp``.

    Returns:
        ``(rollout, sharding_manager)``.

    Raises:
        NotImplementedError: for any other backend name. Previously such a name
            fell through both branches and failed with ``UnboundLocalError``.
    """
    backend = self.config.rollout.name
    if backend == 'hf':
        from verl.workers.rollout import HFRollout
        from verl.workers.hybrid_engine import BaseShardingManager
        rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
        sharding_manager = BaseShardingManager()
        # TODO: a sharding manager that do nothing?
    elif backend == 'vllm':
        from verl.workers.rollout.vllm_rollout import vLLMRollout
        from verl.workers.hybrid_engine import FSDPVLLMShardingManager
        log_gpu_memory_usage('Before building vllm rollout', logger=None)
        rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
                              config=self.config.rollout,
                              tokenizer=self.tokenizer,
                              model_hf_config=self.actor_model_config)
        log_gpu_memory_usage('After building vllm rollout', logger=None)
        # With a single process, force an 'hf' load format so the sharding
        # manager below is built with full_params=True. This runs after
        # vLLMRollout was constructed, so only later readers of the config see it.
        if torch.distributed.get_world_size() == 1:
            self.config.rollout.load_format = 'dummy_hf'
        sharding_manager = FSDPVLLMShardingManager(module=self.actor_module_fsdp,
                                                   inference_engine=rollout.inference_engine,
                                                   model_config=self.actor_model_config,
                                                   full_params='hf' in self.config.rollout.load_format)
        log_gpu_memory_usage('After building sharding manager', logger=None)
    else:
        raise NotImplementedError(f"Unsupported rollout backend: {backend!r}; expected 'hf' or 'vllm'")

    return rollout, sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build every model component required by this worker's role(s).

    Depending on ``self._is_actor`` / ``self._is_rollout`` / ``self._is_ref`` this
    sets, in order:
      * ``self.actor_module_fsdp``, ``self.actor_optimizer``,
        ``self.actor_lr_scheduler``, ``self.actor_model_config`` and
        ``self.actor_module`` (actor and/or rollout roles share this module);
      * ``self.actor`` (a ``DataParallelPPOActor``) for the actor role;
      * ``self.rollout`` and ``self.sharding_manager`` for the rollout role;
      * ``self.ref_module_fsdp`` and ``self.ref_policy`` for the reference role.
    Finishes with a CUDA synchronize, a distributed barrier and a cache flush.
    """
    from verl.workers.actor import DataParallelPPOActor
    # This is used to import external_lib into the huggingface systems
    import_external_libs(self.config.model.get('external_lib', None))

    from omegaconf import OmegaConf
    # Plain-container copy of the optional model-config overrides (empty if unset).
    override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

    if self._is_actor or self._is_rollout:
        # we need the model for actor and rollout
        if self._is_actor:
            optim_config = self.config.actor.optim
            fsdp_config = self.config.actor.fsdp_config
        else:
            # Rollout-only worker: no optimizer and an empty FSDP config.
            optim_config = None
            fsdp_config = OmegaConf.create()
        self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
            model_path=self.config.model.path,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            override_model_config=override_model_config,
            enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
            trust_remote_code=self.config.model.get('trust_remote_code', False))

        # get the original unwrapped module
        self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

        if self._is_offload_param:
            # Params are required during state_dict in the sharding manager, so
            # only gradients are offloaded at init time.
            offload_fsdp_grad(module=self.actor_module_fsdp)
            log_gpu_memory_usage('After offload actor grad during init', logger=logger)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)
    # Wrap the FSDP module in the PPO actor. No checkpoint is loaded here.
    if self._is_actor:
        # Put the actor config in OmegaConf struct mode before handing it over.
        OmegaConf.set_struct(self.config.actor, True)
        self.actor = DataParallelPPOActor(config=self.config.actor,
                                          actor_module=self.actor_module_fsdp,
                                          actor_optimizer=self.actor_optimizer)

    # Must run after the actor module exists: the rollout is built on top of it.
    if self._is_rollout:
        self.rollout, self.sharding_manager = self._build_rollout()

    if self._is_ref:
        # Separate FSDP copy of the same checkpoint; optimizer/scheduler are discarded.
        self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
                                                           fsdp_config=self.config.ref.fsdp_config,
                                                           optim_config=None,
                                                           override_model_config=override_model_config,
                                                           trust_remote_code=self.config.model.get(
                                                               'trust_remote_code', False))[0]
        if self._is_offload_param:
            # NOTE(review): _is_offload_grad is only assigned in the actor branch of the
            # visible __init__ code; presumably it has a default set earlier — confirm.
            offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

        OmegaConf.set_struct(self.config.ref, True)
        self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Update the actor policy on ``data`` via ``self.actor.update_policy``.

    When offloading is enabled, parameters/gradients and optimizer state are
    loaded onto the current CUDA device for the update and offloaded again
    afterwards. The LR scheduler is stepped once per call.

    Returns:
        A CPU ``DataProto`` whose ``meta_info['metrics']`` holds the metrics
        from ``update_policy`` plus ``'actor/lr(1e-4)'``, the post-step
        learning rate multiplied by 1e4.
    """
    data = data.to('cuda')
    assert self._is_actor

    # Bring any offloaded training state back onto the GPU.
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before update policy', logger=logger)
    train_stats = self.actor.update_policy(data=data)
    scheduler = self.actor_lr_scheduler
    scheduler.step()
    train_stats['actor/lr(1e-4)'] = scheduler.get_last_lr()[0] * 1e4
    log_gpu_memory_usage('After update policy', logger=logger)

    # TODO: here, we should return all metrics
    result = DataProto(meta_info={'metrics': train_stats}).to('cpu')

    # Push training state back to CPU to free GPU memory between calls.
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)

    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return result
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_entropy(self, data: DataProto):
    """Compute entropy metrics for ``data`` via ``self.actor.compute_entropy``.

    When parameter offloading is enabled, the actor's parameters/gradients are
    loaded onto the current CUDA device for the computation and offloaded
    afterwards. The optimizer is not touched.

    Returns:
        A CPU ``DataProto`` whose ``meta_info['metrics']`` holds whatever
        ``self.actor.compute_entropy`` returned.
    """
    data = data.to('cuda')
    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before compute entropy', logger=logger)
    # Pass the batch positionally. The previous keyword, ``bacth_data``, is a
    # misspelling that only works if the actor's parameter has the same typo.
    metrics = self.actor.compute_entropy(data)
    log_gpu_memory_usage('After compute entropy', logger=logger)

    # TODO: here, we should return all metrics
    output = DataProto(meta_info={'metrics': metrics})
    output = output.to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    # The optimizer is not loaded in this method (init_model and update_actor
    # leave it offloaded), so there is nothing to offload here.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
prompts = prompts.to('cuda')
# set to False if it is validation
recompute_log_prob = prompts.meta_info.get('recompute_log_prob', True)
assert self._is_rollout
if self._is_offload_param:
load_fsdp_param_and_grad(module=self.actor_module_fsdp,
device_id=torch.cuda.current_device(),
load_grad=self._is_offload_grad)
prompts.batch = prompts.batch.cuda()
meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
prompts.meta_info.update(meta_info)
with self.sharding_manager:
log_gpu_memory_usage('After entering sharding manager', logger=logger)