-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathdp_rob.py
More file actions
607 lines (468 loc) · 29.7 KB
/
dp_rob.py
File metadata and controls
607 lines (468 loc) · 29.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single Process Actor
"""
import itertools
from typing import Iterable, Tuple
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, log_probs_from_logits_all_rmpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
import verl.utils.torch_functional as verl_F
from codetiming import Timer
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
# Public surface of this module.
__all__ = [
    "RobDataParallelPPOActor",
]
class RobDataParallelPPOActor(BasePPOActor):
    """PPO actor for VLA policies (``config.vla`` is ``"openvla-oft"`` or ``"openvla"``).

    Every per-step tensor in a micro-batch has leading ``(batch, traj_len)``
    dimensions. The model is run on the flattened ``batch * traj_len`` steps, and
    the resulting per-token log-probs / entropies are zeroed after each
    trajectory's ``finish_step``.
    """

    def __init__(
        self,
        config,
        actor_module: nn.Module,
        actor_optimizer: torch.optim.Optimizer = None,
    ):
        """Create the actor.

        Args:
            config: actor configuration (read via attribute access and ``.get``).
            actor_module: the policy model, optionally wrapped in FSDP.
            actor_optimizer: optimizer for ``actor_module``. When it is ``None``,
                this instance serves as a reference policy (no updates).
        """
        super().__init__(config)
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer
        self.use_remove_padding = self.config.get('use_remove_padding', False)
        print(f'Actor use_remove_padding={self.use_remove_padding}')
        print(f'PRM use dynamic bsz={self.config.get("use_dynamic_bsz", False)}')
        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
        # Ulysses sequence parallelism is deliberately disabled for this actor.
        self.use_ulysses_sp = False  # self.ulysses_sequence_parallel_size > 1
        self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

    def process_tensor(self, tensor, pad_id):
        """Drop the padding columns shared by every row of a 2-D ``tensor``.

        Args:
            tensor: 2-D tensor whose rows must all have ``pad_id`` in exactly the
                same column positions.
            pad_id: value that marks a padding entry.

        Returns:
            ``(tensor[:, keep], valid_len)``, where ``keep`` marks the non-pad
            columns of the first row and ``valid_len`` is how many were kept.

        Raises:
            ValueError: if the rows do not share one padding pattern.
        """
        mask = tensor != pad_id
        if not torch.all(mask == mask[0:1], dim=1).all():
            raise ValueError("Padding error!")
        base_mask = mask[0]
        valid_len = base_mask.sum().item()
        return tensor[:, base_mask], valid_len

    def generate_traj_mask(self, end_step, traj_len):
        """Build a boolean "before end" mask.

        Args:
            end_step: (batch_size,) tensor; the first index that is masked out.
            traj_len: number of positions per row.

        Returns:
            mask: (batch_size, traj_len) bool tensor, True where ``index < end_step``.
        """
        steps = torch.arange(traj_len, device=end_step.device)  # (traj_len,)
        steps_expanded = steps.unsqueeze(0).expand(end_step.size(0), -1)
        mask = steps_expanded < end_step.unsqueeze(1)  # (batch_size, traj_len)
        return mask

    def apply_mask_with_grad_control(self, log_probs, entropy, mask):
        """Replace masked-out entries of ``log_probs`` and ``entropy`` with zeros.

        ``torch.where`` routes no gradient to the zero branch, so masked
        positions contribute nothing to the backward pass.

        Args:
            log_probs: (batch_size, traj_len, ...)
            entropy: (batch_size, traj_len, ...)
            mask: (batch_size, traj_len) bool

        Returns:
            (log_probs_masked, entropy_masked) with the input shapes.
        """
        mask_expanded = mask.unsqueeze(-1)
        log_probs_masked = torch.where(
            mask_expanded,
            log_probs,
            torch.zeros_like(log_probs, requires_grad=False)
        )
        entropy_masked = torch.where(
            mask_expanded,
            entropy,
            torch.zeros_like(entropy, requires_grad=False)
        )
        return log_probs_masked, entropy_masked

    # ------------------------------------------------------------------ helpers

    def _check_micro_batch_shapes(self, micro_batch):
        """Assert that a micro-batch's leading dims agree.

        Returns:
            ``(batch_size, traj_len)`` taken from ``responses``.
        """
        batch_size = micro_batch['responses'].size(0)
        traj_len = micro_batch['responses'].size(1)
        tot_pad_len = micro_batch['input_ids'].size(2)
        step_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values']
        assert all(micro_batch[key].size(0) == batch_size for key in step_keys)
        assert all(micro_batch[key].size(1) == traj_len for key in step_keys)
        assert all(micro_batch[key].size(2) == tot_pad_len for key in ['input_ids', 'attention_mask'])
        if self.config.use_proprio:
            proprio = micro_batch["proprio"]
            assert (proprio.size(0) == batch_size and proprio.size(1) == traj_len
                    and proprio.size(2) == self.config.action_token_len)
        return batch_size, traj_len

    def _flatten_inputs(self, micro_batch):
        """Merge the ``(batch, traj_len)`` dims of the model inputs into one.

        Returns:
            ``(input_ids, attention_mask, pixel_values, proprio)``. ``proprio`` is
            ``None`` unless ``config.use_proprio`` is set.
        """
        input_ids = micro_batch['input_ids'].flatten(0, 1)
        attention_mask = micro_batch['attention_mask'].flatten(0, 1)
        pixel_values = micro_batch['pixel_values'].flatten(0, 1)
        proprio = micro_batch['proprio'].flatten(0, 1) if self.config.use_proprio else None
        return input_ids, attention_mask, pixel_values, proprio

    def _action_logits(self, input_ids, attention_mask, pixel_values, proprio, response_length):
        """Run the policy on flattened steps and return ``(logits, token_offset)``.

        ``logits`` is restricted to the positions / vocabulary slice that the loss
        is computed over. ``token_offset`` must be subtracted from response token
        ids so they index into that slice; it is 0 for ``openvla``.

        Raises:
            ValueError: if padding is inconsistent across rows, or if
                ``config.vla`` is not a supported model family.
        """
        input_ids_unpad, _ = self.process_tensor(input_ids, self.pad_token_id)
        attention_mask_unpad, _ = self.process_tensor(attention_mask, 0)
        if self.config.vla == "openvla-oft":
            logits = self.actor_module(input_ids=input_ids_unpad,
                                       attention_mask=attention_mask_unpad,
                                       pixel_values=pixel_values,
                                       proprio=proprio,
                                       )
            assert self.actor_module.vocab_size == 32000
            # Response ids in [vocab_size - 256, vocab_size) are shifted to
            # [0, 256) to index this 256-wide logit slice. The 64 trailing logits
            # that are skipped are presumably vocabulary padding -- TODO confirm.
            return logits[..., -256 - 64:-64], self.actor_module.vocab_size - 256
        if self.config.vla == "openvla":
            output = self.actor_module(input_ids=input_ids_unpad,
                                       attention_mask=attention_mask_unpad,
                                       pixel_values=pixel_values,
                                       use_cache=False)  # prevent model thinks we are generating
            # Take the response_length positions just before the last one. This is
            # a one-position shift relative to the response tokens, as for a causal LM.
            return output.logits[:, -response_length - 1:-1], 0
        raise ValueError(f"Unsupported vla type: {self.config.vla!r}")

    def _mask_finished_steps(self, values, finish_step, batch_size, traj_len):
        """Zero per-token ``values`` after each trajectory's ``finish_step``.

        Args:
            values: (batch * traj_len, response_len) per-token tensor.
            finish_step: (batch,) end index. For ``openvla-oft`` it is compared
                against chunk indices (``action_chunks_len`` chunks of
                ``action_token_len`` tokens per step); for ``openvla`` it is
                compared against step indices.

        Returns:
            (batch, traj_len * response_len) tensor with masked entries set to 0.
        """
        assert values.dim() == 2
        if self.config.vla == "openvla-oft":
            num_steps = traj_len * self.config.action_chunks_len
            values = values.reshape((batch_size, num_steps, self.config.action_token_len))
        else:
            num_steps = traj_len
            values = values.reshape((batch_size, traj_len) + values.shape[1:])
        mask = self.generate_traj_mask(finish_step, num_steps)
        # Same masking as apply_mask_with_grad_control, applied to a single tensor.
        values = torch.where(mask.unsqueeze(-1), values, torch.zeros_like(values))
        return values.reshape((batch_size, -1))

    def _response_token_mask(self, responses, finish_step):
        """Return the (batch, traj_len * response_len) token-level response mask.

        Entry ``t`` of a row is True when ``t < finish_step * config.action_token_len``.
        """
        response_length = responses.size(1) * responses.size(2)
        return self.generate_traj_mask(finish_step * self.config.action_token_len, response_length)

    # ------------------------------------------------------------ forward passes

    def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute masked per-token entropy and log-probs of ``responses``.

        Args:
            micro_batch: mapping with ``responses``, ``input_ids``,
                ``attention_mask``, ``pixel_values`` and ``finish_step``, plus
                ``proprio`` when ``config.use_proprio`` is set. Per-step tensors
                have leading ``(batch, traj_len)`` dims.
            temperature: logits are divided by this value.

        Returns:
            ``(entropy, log_probs)``, each (batch, traj_len * response_len), where
            ``response_len = responses.size(-1)``. Entries past ``finish_step``
            are zero.
        """
        batch_size, traj_len = self._check_micro_batch_shapes(micro_batch)
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids, attention_mask, pixel_values, proprio = self._flatten_inputs(micro_batch)
            responses = micro_batch['responses'].flatten(0, 1)
            logits, token_offset = self._action_logits(
                input_ids, attention_mask, pixel_values, proprio, response_length)
            logits = logits.div(temperature)
            log_probs = logprobs_from_logits(logits, responses - token_offset)
            entropy = verl_F.entropy_from_logits(logits)  # (bsz * traj_len, response_length)
            finish_step = micro_batch['finish_step']
            log_probs = self._mask_finished_steps(log_probs, finish_step, batch_size, traj_len)
            entropy = self._mask_finished_steps(entropy, finish_step, batch_size, traj_len)
        return entropy, log_probs

    def _forward_micro_batch_update(self, input_ids, attention_mask, pixel_values, responses, temperature, proprio) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass with gradients for a chunk of already-flattened steps.

        Returns:
            ``(entropy, log_probs)``, each reshaped to ``(1, -1)``. No finish-step
            masking is applied here; the caller masks inside the loss.
        """
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            logits, token_offset = self._action_logits(
                input_ids, attention_mask, pixel_values, proprio, responses.size(-1))
            if self.config.vla == "openvla-oft":
                assert logits.requires_grad
            logits = logits.div(temperature)
            log_probs = logprobs_from_logits(logits, responses - token_offset)
            entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
            return entropy.reshape((1, -1)), log_probs.reshape((1, -1))

    def _forward_micro_batch_entropy(self, micro_batch, temperature) -> torch.Tensor:
        """Compute the masked per-token entropy for a micro-batch.

        Returns:
            (batch, traj_len * response_len) entropy; entries past ``finish_step``
            are zero.
        """
        batch_size, traj_len = self._check_micro_batch_shapes(micro_batch)
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids, attention_mask, pixel_values, proprio = self._flatten_inputs(micro_batch)
            logits, _ = self._action_logits(
                input_ids, attention_mask, pixel_values, proprio, response_length)
            entropy = verl_F.entropy_from_logits(logits.div(temperature))
            return self._mask_finished_steps(entropy, micro_batch['finish_step'], batch_size, traj_len)

    def _optimizer_step(self):
        """Clip gradients to ``config.grad_clip`` and step the optimizer.

        If the gradient norm is NaN/inf, the step is skipped and gradients are
        cleared. Stepping would otherwise write non-finite values into the weights.

        Returns:
            the (pre-clipping) total gradient norm.
        """
        assert self.config.grad_clip is not None
        if isinstance(self.actor_module, FSDP):
            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)
        if not torch.isfinite(grad_norm):
            print(f"WARN: grad_norm is not finite: {grad_norm}; skipping optimizer step")
            self.actor_optimizer.zero_grad()
        else:
            self.actor_optimizer.step()
        return grad_norm

    # -------------------------------------------------------------- public API

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute masked log-probabilities of ``responses`` without gradients.

        Args:
            data (DataProto): batch keys ``responses`` (batch, traj_len, response_len),
                ``input_ids`` / ``attention_mask`` (batch, traj_len, seq_len),
                ``pixel_values`` (batch, traj_len, ...), ``finish_step`` (batch,),
                and ``proprio`` when ``config.use_proprio`` is set. ``meta_info``
                must contain ``micro_batch_size``, ``temperature``,
                ``use_dynamic_bsz`` and ``pad_token_id``, plus ``max_token_len``
                when dynamic batching is used.

        Returns:
            torch.Tensor: (batch, traj_len * response_len) log-probs, in the
            original sample order.
        """
        self.actor_module.eval()
        micro_batch_size = data.meta_info['micro_batch_size']
        temperature = data.meta_info['temperature']  # temperature must be in the data.meta_info to avoid slient error
        use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
        self.pad_token_id = data.meta_info['pad_token_id']
        select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values', "finish_step"]
        if self.config.use_proprio:
            select_keys.append("proprio")
        batch = data.select(batch_keys=select_keys).batch
        if use_dynamic_bsz:
            # split using dynamic bsz
            max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
        else:
            micro_batches = batch.split(micro_batch_size)
        log_probs_lst = []
        with torch.no_grad():
            for micro_batch in micro_batches:
                _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
                log_probs_lst.append(log_probs)
        log_probs = torch.concat(log_probs_lst, dim=0)
        if use_dynamic_bsz:
            # Undo the reordering done by rearrange_micro_batches.
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            log_probs = log_probs[revert_indices]
        return log_probs

    def _update_micro_batch(self, micro_batch, temperature, num_micro_batches):
        """Backpropagate the PPO policy loss for one trajectory.

        The trajectory is processed in chunks of ``config.traj_mini_batch_size``
        steps (presumably to bound activation memory). Each chunk's
        ``pg_loss`` / ``pg_clipfrac`` / ``ppo_kl`` is weighted by its share of
        valid (pre-finish) tokens. Assuming ``core_algos.compute_policy_loss``
        returns a masked mean over ``eos_mask``, the summed chunk values equal the
        token-level mean over the whole trajectory. The backward loss is
        additionally divided by ``num_micro_batches`` for gradient accumulation.

        Returns:
            dict with accumulated ``actor/pg_loss``, ``actor/pg_clipfrac``,
            ``actor/ppo_kl`` (Python floats).

        Raises:
            ValueError: if the micro-batch holds more than one trajectory.
        """
        responses = micro_batch['responses']
        batch_size, traj_len = responses.size(0), responses.size(1)
        if batch_size != 1:
            # The chunked slicing below indexes the flattened (batch * traj) dim by
            # step and reshapes log-probs to (1, -1). With batch > 1 the loss would
            # silently broadcast against old_log_probs / advantages.
            raise ValueError(
                f"update_policy expects one trajectory per micro-batch, got {batch_size}; "
                "use ppo_micro_batch_size=1 and a max token length that fits one trajectory")
        traj_chunk = self.config.traj_mini_batch_size
        assert traj_len % traj_chunk == 0
        tokens_per_step = self.config.action_token_len * self.config.action_chunks_len

        response_mask = self._response_token_mask(responses, micro_batch['finish_step'])
        response_mask_sum = response_mask.sum()
        old_log_prob = micro_batch['old_log_probs']
        advantages = micro_batch['advantages']
        input_ids, attention_mask, pixel_values, proprio = self._flatten_inputs(micro_batch)
        flat_responses = responses.flatten(0, 1)

        loss_info = {
            'actor/pg_loss': 0,
            'actor/pg_clipfrac': 0,
            'actor/ppo_kl': 0,
        }
        for start in range(0, traj_len, traj_chunk):
            stop = start + traj_chunk
            _, log_prob = self._forward_micro_batch_update(
                input_ids=input_ids[start:stop],
                attention_mask=attention_mask[start:stop],
                pixel_values=pixel_values[start:stop],
                responses=flat_responses[start:stop],
                temperature=temperature,
                proprio=proprio[start:stop] if proprio is not None else None,
            )
            tokens = slice(start * tokens_per_step, stop * tokens_per_step)
            response_mask_tmp = response_mask[:, tokens]
            pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
                old_log_prob=old_log_prob[:, tokens],
                log_prob=log_prob,
                advantages=advantages[:, tokens],
                eos_mask=response_mask_tmp,
                clip_ratio_high=self.config.clip_ratio_high,
                clip_ratio_low=self.config.clip_ratio_low,
            )
            response_mask_tmp_sum = response_mask_tmp.sum()
            policy_loss = pg_loss * response_mask_tmp_sum / response_mask_sum
            pg_clipfrac = pg_clipfrac * response_mask_tmp_sum / response_mask_sum
            ppo_kl = ppo_kl * response_mask_tmp_sum / response_mask_sum
            loss = policy_loss / num_micro_batches
            loss.backward()
            loss_info['actor/pg_loss'] += policy_loss.detach().item()
            loss_info['actor/pg_clipfrac'] += pg_clipfrac.detach().item()
            loss_info['actor/ppo_kl'] += ppo_kl.detach().item()
        return loss_info

    def update_policy(self, data: DataProto):
        """Run one PPO update pass over ``data`` and return training metrics.

        ``data`` is split into mini-batches of ``config.ppo_mini_batch_size``
        trajectories. Each mini-batch accumulates gradients over its
        micro-batches (one trajectory each) and then takes one optimizer step.

        Returns:
            dict mapping metric names to lists. Loss terms get one value per
            micro-batch; ``actor/grad_norm`` gets one value per mini-batch.
        """
        self.actor_module.train()
        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
        temperature = data.meta_info['temperature']  # temperature must be in the data.meta_info to avoid slient error
        if 'pad_token_id' in data.meta_info:
            # pad_token_id is otherwise only set by compute_log_prob; refresh it when
            # provided so this method does not depend on call order.
            self.pad_token_id = data.meta_info['pad_token_id']
        select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values', 'old_log_probs', 'advantages', "finish_step"]
        if self.config.use_proprio:
            select_keys.append("proprio")
        batch = data.select(batch_keys=select_keys).batch
        # _update_micro_batch processes exactly one trajectory at a time.
        assert self.config.ppo_micro_batch_size == 1

        metrics = {}
        # Split to make minibatch iterator for updating the actor
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        for mini_batch in batch.split(self.config.ppo_mini_batch_size):
            if self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
            else:
                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size)
            # Average over the micro-batches actually present. This equals
            # gradient_accumulation except for a shorter trailing mini-batch.
            num_micro_batches = len(micro_batches)
            self.actor_optimizer.zero_grad()
            for micro_batch in micro_batches:
                micro_batch = micro_batch.cuda()  # actor device is cpu when using offload
                loss_info = self._update_micro_batch(micro_batch, temperature, num_micro_batches)
                append_to_dict(metrics, loss_info)
            grad_norm = self._optimizer_step()
            append_to_dict(metrics, {'actor/grad_norm': grad_norm.detach().item()})
            torch.cuda.empty_cache()
        self.actor_optimizer.zero_grad()
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return metrics

    def compute_entropy(self, bacth_data: DataProto):
        """Log the masked mean token entropy of the policy, one value per micro-batch.

        ``meta_info['train_mode']`` selects ``train()`` / ``eval()`` mode for the
        forward pass. ``meta_info['is_filtered']`` selects the ``actor_after`` vs
        ``actor_before`` metric prefix.

        Returns:
            dict with a single key ``actor_{before|after}/entropy_loss_{train|eval}``
            mapping to a list of floats.
        """
        meta_info = bacth_data.meta_info
        train_mode = meta_info['train_mode']
        if train_mode:
            self.actor_module.train()
            print("train mode")
        else:
            self.actor_module.eval()
            print("eval mode")
        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
        temperature = meta_info['temperature']  # temperature must be in the data.meta_info to avoid slient error
        if 'pad_token_id' in meta_info:
            self.pad_token_id = meta_info['pad_token_id']
        phase = 'actor_after' if meta_info['is_filtered'] else 'actor_before'
        metric_key = f"{phase}/entropy_loss_{'train' if train_mode else 'eval'}"

        select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values', "finish_step"]
        if self.config.use_proprio:
            select_keys.append("proprio")
        batch = bacth_data.select(batch_keys=select_keys).batch
        dataloader = batch.split(self.config.ppo_mini_batch_size)
        print("dataloader_length:", len(dataloader))
        metrics = {}
        for mini_batch in dataloader:
            if self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
            else:
                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size)
            for micro_batch in micro_batches:
                micro_batch = micro_batch.cuda()  # actor device is cpu when using offload
                response_mask = self._response_token_mask(micro_batch['responses'], micro_batch['finish_step'])
                with torch.no_grad():
                    entropy = self._forward_micro_batch_entropy(micro_batch=micro_batch, temperature=temperature)
                    entropy_loss = verl_F.masked_mean(entropy, response_mask)
                append_to_dict(metrics, {metric_key: entropy_loss.detach().item()})
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return metrics