1818from dataclasses import dataclass
1919from typing import Callable , Dict , Optional
2020
21+ import paddle .jit .dy2static .utils as jit_utils
2122import paddle .nn .layer
2223from paddle .device .cuda import graphs
2324
@@ -51,27 +52,24 @@ class ConcreteSizeEntry:
5152
5253class Dy2StCudaGraphManager :
def __init__(self):
    """Set up the manager with CUDA graph execution switched off.

    The manager starts in ``CUDAGraphState.DISABLE`` with no batch size
    recorded as captured and ``batch_size`` set to ``-1``.
    """
    # Batch sizes recorded by run_impl while in the CAPTURE state.
    self.captured_batch_size = set()
    # Overwritten by the caller before each capture/replay run.
    self.batch_size = -1
    self.state = jit_utils.CUDAGraphState.DISABLE
6059
def run_impl(self, original_run_impl, inputs, parameters, attrs):
    """Call ``original_run_impl`` with CUDA graph attributes filled in.

    Args:
        original_run_impl: Callable invoked as
            ``original_run_impl(inputs, parameters, attrs)``.
        inputs: Forwarded unchanged.
        parameters: Forwarded unchanged.
        attrs: A pair ``(prog_attrs, cuda_graph_attrs)``. The second
            element is updated in place with the state and dispatch key.

    Returns:
        Whatever ``original_run_impl`` returns.
    """
    effective_state = self.state
    prog_attrs, cuda_graph_attrs = attrs
    graph_state = jit_utils.CUDAGraphState

    if effective_state == graph_state.REPLAY:
        # A batch size that was never captured cannot be replayed;
        # run this call with the graph disabled instead.
        if self.batch_size not in self.captured_batch_size:
            effective_state = graph_state.DISABLE
    elif effective_state == graph_state.CAPTURE:
        # Record the batch size so later REPLAY calls accept it.
        self.captured_batch_size.add(self.batch_size)

    # The dispatch key is the batch size unless the graph is disabled.
    if effective_state != graph_state.DISABLE:
        dispatch_key = self.batch_size
    else:
        dispatch_key = 0

    cuda_graph_attrs |= {
        "cuda_graph_state": effective_state,
        "cuda_graph_dispatch_key": dispatch_key,
    }
    return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
7775
@@ -104,7 +102,6 @@ def __init__(
104102 self .cuda_graph_manager = Dy2StCudaGraphManager ()
105103
106104 def run_static_model (self , entry : ConcreteSizeEntry , ** kwargs ):
107- from paddle .jit .dy2static .utils import CUDAGraphState
108105
109106 if not entry .captured :
110107 # Warmup the model
@@ -121,14 +118,14 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
121118 entry .input_addresses = input_addresses
122119
123120 # Capture
124- self .cuda_graph_manager .state = CUDAGraphState .CAPTURE
121+ self .cuda_graph_manager .state = jit_utils . CUDAGraphState .CAPTURE
125122 self .cuda_graph_manager .batch_size = entry .real_shape
126123 entry .captured = True
127124 with self .cuda_graph_manager .run_impl_guard ():
128125 entry .runnable (** kwargs )
129126
130127 # Replay
131- self .cuda_graph_manager .state = CUDAGraphState .REPLAY
128+ self .cuda_graph_manager .state = jit_utils . CUDAGraphState .REPLAY
132129 self .cuda_graph_manager .batch_size = entry .real_shape
133130 with self .cuda_graph_manager .run_impl_guard ():
134131 return entry .runnable (** kwargs )
0 commit comments