@@ -22,6 +22,8 @@ public SGD(float learning_rate,
2222 _set_hyper ( "decay" , decay ) ;
2323
2424 _momentum = momentum > 0 ;
25+ if ( momentum < 0 || momentum > 1 )
26+ throw new ValueError ( $ "momentum must be a number between 0 and 1, got { momentum } .") ;
2527
2628 _set_hyper ( "momentum" , momentum ) ;
2729
@@ -30,6 +32,13 @@ public SGD(float learning_rate,
3032#pragma warning restore CS1717 // Assignment made to same variable
3133 }
3234
/// <summary>
/// Calls <c>add_slot(variable, "momentum")</c> for every entry of
/// <paramref name="var_list"/> when the <c>_momentum</c> flag is set;
/// does nothing otherwise.
/// </summary>
/// <param name="var_list">Variables to register a "momentum" slot for.</param>
protected override void _create_slots(IVariableV1[] var_list)
{
    // Nothing to register unless the momentum flag is set.
    if (!_momentum)
        return;

    foreach (var variable in var_list)
        add_slot(variable, "momentum");
}
41+
3342 protected override void _prepare_local ( DeviceDType device_dtype ,
3443 Dictionary < DeviceDType , Dictionary < string , Tensor > > _apply_state )
3544 {
@@ -43,7 +52,15 @@ protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad,
4352 {
4453 if ( _momentum )
4554 {
46- throw new NotImplementedException ( "_resource_apply_dense" ) ;
55+ var momentum_var = get_slot ( var , "momentum" ) ;
56+ return gen_training_ops . resource_apply_keras_momentum (
57+ var . Handle ,
58+ momentum_var . Handle ,
59+ _get_hyper ( "learning_rate" , var . dtype ) ,
60+ grad ,
61+ _get_hyper ( "momentum" , var . dtype ) ,
62+ use_locking : _use_locking ,
63+ use_nesterov : nesterov ) ;
4764 }
4865 var device_dtype = _apply_state . Keys . FirstOrDefault ( x => x . Device == var . Device && x . DType == var . dtype . as_base_dtype ( ) ) ;
4966
0 commit comments