@@ -205,9 +205,9 @@ def _clip_none(x, val, out, order, _binary_fn):
205205 order = order ,
206206 )
207207 if x_shape != res_shape :
208- x = dpt .broadcast_to (x , res_shape )
208+ x = dpt_ext .broadcast_to (x , res_shape )
209209 if val_ary .shape != res_shape :
210- val_ary = dpt .broadcast_to (val_ary , res_shape )
210+ val_ary = dpt_ext .broadcast_to (val_ary , res_shape )
211211 _manager = SequentialOrderManager [exec_q ]
212212 dep_evs = _manager .submitted_events
213213 ht_binary_ev , binary_ev = _binary_fn (
@@ -251,8 +251,8 @@ def _clip_none(x, val, out, order, _binary_fn):
251251 )
252252
253253 if x_shape != res_shape :
254- x = dpt .broadcast_to (x , res_shape )
255- buf = dpt .broadcast_to (buf , res_shape )
254+ x = dpt_ext .broadcast_to (x , res_shape )
255+ buf = dpt_ext .broadcast_to (buf , res_shape )
256256 ht_binary_ev , binary_ev = _binary_fn (
257257 src1 = x ,
258258 src2 = buf ,
@@ -580,11 +580,11 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
580580 order = order ,
581581 )
582582 if x_shape != res_shape :
583- x = dpt .broadcast_to (x , res_shape )
583+ x = dpt_ext .broadcast_to (x , res_shape )
584584 if a_min .shape != res_shape :
585- a_min = dpt .broadcast_to (a_min , res_shape )
585+ a_min = dpt_ext .broadcast_to (a_min , res_shape )
586586 if a_max .shape != res_shape :
587- a_max = dpt .broadcast_to (a_max , res_shape )
587+ a_max = dpt_ext .broadcast_to (a_max , res_shape )
588588 _manager = SequentialOrderManager [exec_q ]
589589 dep_ev = _manager .submitted_events
590590 ht_binary_ev , binary_ev = ti ._clip (
@@ -639,10 +639,10 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
639639 order = order ,
640640 )
641641
642- x = dpt .broadcast_to (x , res_shape )
642+ x = dpt_ext .broadcast_to (x , res_shape )
643643 if a_min .shape != res_shape :
644- a_min = dpt .broadcast_to (a_min , res_shape )
645- buf2 = dpt .broadcast_to (buf2 , res_shape )
644+ a_min = dpt_ext .broadcast_to (a_min , res_shape )
645+ buf2 = dpt_ext .broadcast_to (buf2 , res_shape )
646646 ht_binary_ev , binary_ev = ti ._clip (
647647 src = x ,
648648 min = a_min ,
@@ -695,10 +695,10 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
695695 order = order ,
696696 )
697697
698- x = dpt .broadcast_to (x , res_shape )
699- buf1 = dpt .broadcast_to (buf1 , res_shape )
698+ x = dpt_ext .broadcast_to (x , res_shape )
699+ buf1 = dpt_ext .broadcast_to (buf1 , res_shape )
700700 if a_max .shape != res_shape :
701- a_max = dpt .broadcast_to (a_max , res_shape )
701+ a_max = dpt_ext .broadcast_to (a_max , res_shape )
702702 ht_binary_ev , binary_ev = ti ._clip (
703703 src = x ,
704704 min = buf1 ,
@@ -766,9 +766,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
766766 order = order ,
767767 )
768768
769- x = dpt .broadcast_to (x , res_shape )
770- buf1 = dpt .broadcast_to (buf1 , res_shape )
771- buf2 = dpt .broadcast_to (buf2 , res_shape )
769+ x = dpt_ext .broadcast_to (x , res_shape )
770+ buf1 = dpt_ext .broadcast_to (buf1 , res_shape )
771+ buf2 = dpt_ext .broadcast_to (buf2 , res_shape )
772772 ht_ , clip_ev = ti ._clip (
773773 src = x ,
774774 min = buf1 ,
0 commit comments