@@ -442,7 +442,7 @@ class Transformer(Layer):
         - **blinding**: bool. Whether or not use blinding.
         - **seed**: A Python integer to use as random seed.
         - **supports_masking**: bool. Whether or not support masking.
-        - **attention_type**: str, Type of attention, the value must be one of { ``'scaled_dot_product'`` , ``'additive'`` }.
+        - **attention_type**: str, Type of attention, the value must be one of { ``'scaled_dot_product'`` , ``'cos'`` , ``'ln'`` , ``'additive'`` }.
         - **output_type**: ``'mean'`` , ``'sum'`` or `None`. Whether or not use average/sum pooling for output.

      References
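The expanded `attention_type` option is selected at construction time. A minimal sketch, assuming the layer is imported from `deepctr.layers.sequence` and keeps its usual `att_embedding_size` / `head_num` keyword arguments (neither of which appears in this hunk):

```python
from deepctr.layers.sequence import Transformer  # import path assumed

# Each value picks one scoring branch in this change;
# all other constructor arguments keep their defaults.
attn_dot = Transformer(att_embedding_size=8, head_num=2, attention_type='scaled_dot_product')
attn_cos = Transformer(att_embedding_size=8, head_num=2, attention_type='cos')
attn_ln  = Transformer(att_embedding_size=8, head_num=2, attention_type='ln')
attn_add = Transformer(att_embedding_size=8, head_num=2, attention_type='additive')
```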
@@ -490,6 +490,9 @@ def build(self, input_shape):
                                      initializer=glorot_uniform(seed=self.seed))
             self.v = self.add_weight('v', shape=[self.att_embedding_size], dtype=tf.float32,
                                      initializer=glorot_uniform(seed=self.seed))
+        elif self.attention_type == "ln":
+            self.att_ln_q = LayerNormalization()
+            self.att_ln_k = LayerNormalization()
         # if self.use_res:
         #     self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
         #                                  initializer=TruncatedNormal(seed=self.seed))
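What the new ``'ln'`` branch computes, as a standalone sketch: queries and keys are layer-normalized and then scored with the usual scaled dot product. `tf.keras.layers.LayerNormalization` stands in here for the project's own `LayerNormalization`, and the shapes are illustrative only:

```python
import tensorflow as tf

ln_q = tf.keras.layers.LayerNormalization()  # stand-in for self.att_ln_q
ln_k = tf.keras.layers.LayerNormalization()  # stand-in for self.att_ln_k

Q_ = tf.random.normal([4, 5, 8])  # (head_num * batch, T_q, D)
K_ = tf.random.normal([4, 7, 8])  # (head_num * batch, T_k, D)

# Normalize per feature, score, then rescale by sqrt(D) as in the call() hunk below.
scores = tf.matmul(ln_q(Q_), ln_k(K_), transpose_b=True)  # (h*N, T_q, T_k)
scores = scores / (K_.get_shape().as_list()[-1] ** 0.5)
weights = tf.nn.softmax(scores, axis=-1)
```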
@@ -529,28 +532,42 @@ def call(self, inputs, mask=None, training=None, **kwargs):
             queries = self.query_pe(queries)
             keys = self.key_pe(queries)

-        querys = tf.tensordot(queries, self.W_Query,
-                              axes=(-1, 0))  # None T_q D*head_num
-        keys = tf.tensordot(keys, self.W_key, axes=(-1, 0))
-        values = tf.tensordot(keys, self.W_Value, axes=(-1, 0))
+        Q = tf.tensordot(queries, self.W_Query,
+                         axes=(-1, 0))  # N T_q D*h
+        K = tf.tensordot(keys, self.W_key, axes=(-1, 0))
+        V = tf.tensordot(keys, self.W_Value, axes=(-1, 0))

-        # head_num*None T_q D
-        querys = tf.concat(tf.split(querys, self.head_num, axis=2), axis=0)
-        keys = tf.concat(tf.split(keys, self.head_num, axis=2), axis=0)
-        values = tf.concat(tf.split(values, self.head_num, axis=2), axis=0)
+        # h*N T_q D
+        Q_ = tf.concat(tf.split(Q, self.head_num, axis=2), axis=0)
+        K_ = tf.concat(tf.split(K, self.head_num, axis=2), axis=0)
+        V_ = tf.concat(tf.split(V, self.head_num, axis=2), axis=0)

         if self.attention_type == "scaled_dot_product":
-            # head_num*None T_q T_k
-            outputs = tf.matmul(querys, keys, transpose_b=True)
+            # h*N T_q T_k
+            outputs = tf.matmul(Q_, K_, transpose_b=True)

-            outputs = outputs / (keys.get_shape().as_list()[-1] ** 0.5)
+            outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
+        elif self.attention_type == "cos":
+            Q_cos = tf.nn.l2_normalize(Q_, dim=-1)
+            K_cos = tf.nn.l2_normalize(K_, dim=-1)
+
+            outputs = tf.matmul(Q_cos, K_cos, transpose_b=True)  # h*N T_q T_k
+
+            outputs = outputs * 20  # Scale
+        elif self.attention_type == 'ln':
+            Q_ = self.att_ln_q(Q_)
+            K_ = self.att_ln_k(K_)
+
+            outputs = tf.matmul(Q_, K_, transpose_b=True)  # h*N T_q T_k
+            # Scale
+            outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
         elif self.attention_type == "additive":
-            querys_reshaped = tf.expand_dims(querys, axis=-2)
-            keys_reshaped = tf.expand_dims(keys, axis=-3)
-            outputs = tf.tanh(tf.nn.bias_add(querys_reshaped + keys_reshaped, self.b))
+            Q_reshaped = tf.expand_dims(Q_, axis=-2)
+            K_reshaped = tf.expand_dims(K_, axis=-3)
+            outputs = tf.tanh(tf.nn.bias_add(Q_reshaped + K_reshaped, self.b))
             outputs = tf.squeeze(tf.tensordot(outputs, tf.expand_dims(self.v, axis=-1), axes=[-1, 0]), axis=-1)
         else:
-            raise ValueError("attention_type must be scaled_dot_product or additive")
+            raise ValueError("attention_type must be one of ['scaled_dot_product', 'cos', 'ln', 'additive']")

         key_masks = tf.tile(key_masks, [self.head_num, 1])

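Likewise, a standalone sketch of the ``'cos'`` branch in this hunk: queries and keys are L2-normalized so the dot product becomes a cosine similarity in [-1, 1], and the fixed factor of 20 acts as a temperature so the softmax over near-unit scores does not come out too flat (shapes illustrative only):

```python
import tensorflow as tf

Q_ = tf.random.normal([4, 5, 8])  # (head_num * batch, T_q, D)
K_ = tf.random.normal([4, 7, 8])  # (head_num * batch, T_k, D)

Q_cos = tf.nn.l2_normalize(Q_, axis=-1)  # unit-length query vectors
K_cos = tf.nn.l2_normalize(K_, axis=-1)  # unit-length key vectors

scores = 20.0 * tf.matmul(Q_cos, K_cos, transpose_b=True)  # (h*N, T_q, T_k), values in [-20, 20]
weights = tf.nn.softmax(scores, axis=-1)
```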
@@ -583,7 +600,7 @@ def call(self, inputs, mask=None, training=None, **kwargs):
         outputs = self.dropout(outputs, training=training)
         # Weighted sum
         # ( h*N, T_q, C/h)
-        result = tf.matmul(outputs, values)
+        result = tf.matmul(outputs, V_)
         result = tf.concat(tf.split(result, self.head_num, axis=0), axis=2)

         if self.use_res:
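For reference, the head-stacking convention behind the renamed `V_`: heads live on the batch axis as (h*N, T, D/h), the attention weights multiply `V_`, and the heads are concatenated back along the feature axis. A small sketch with assumed shapes:

```python
import tensorflow as tf

head_num = 2
V = tf.random.normal([4, 7, 8])  # (batch, T_k, D) with D = 8

# (h*N, T_k, D/h): split the feature axis into heads, stack them on the batch axis.
V_ = tf.concat(tf.split(V, head_num, axis=2), axis=0)

# Attention weights per head/batch pair, (h*N, T_q, T_k).
weights = tf.nn.softmax(tf.random.normal([head_num * 4, 5, 7]), axis=-1)

result = tf.matmul(weights, V_)                                 # (h*N, T_q, D/h)
result = tf.concat(tf.split(result, head_num, axis=0), axis=2)  # (batch, T_q, D)
```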