import tensorflow as tf


class AttentionHead(tf.keras.layers.Layer):
    """
    A single attention head of a multi-head attention mechanism.

    Args:
        dmodel (int): The dimensionality of the model, i.e. the size of the
            query, key, and value projections produced by this head.
    """

    def __init__(self, dmodel):
        super(AttentionHead, self).__init__()
        self.dims = dmodel
        # q, k, v correspond to the query, key, and value projections.
        self.qw = tf.keras.layers.Dense(dmodel)
        self.kw = tf.keras.layers.Dense(dmodel)
        self.vw = tf.keras.layers.Dense(dmodel)

    def call(self, inputs):
        # Project the inputs into query, key, and value spaces.
        query = self.qw(inputs)
        key = self.kw(inputs)
        vals = self.vw(inputs)
        # Scaled dot-product attention: scores are divided by sqrt(dmodel)
        # before the softmax to keep the logits well-scaled. Casting to the
        # score's dtype avoids a float32/float64 mismatch.
        score = tf.matmul(query, key, transpose_b=True)
        scaled_score = score / tf.math.sqrt(tf.cast(self.dims, score.dtype))
        weights = tf.nn.softmax(scaled_score, axis=-1)
        return tf.matmul(weights, vals)
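
# A minimal shape check for AttentionHead (illustrative sketch only; the batch
# and sequence sizes below are arbitrary assumptions, not part of this module):
#
#   head = AttentionHead(dmodel=64)
#   x = tf.random.normal((2, 10, 64))   # (batch, seq_len, dmodel)
#   out = head(x)                        # -> shape (2, 10, 64)
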
class MultiHeadAttention(tf.keras.layers.Layer):
    """
    A multi-head attention mechanism with a residual connection.

    Args:
        h (int): Number of attention heads.
        dmodel (int): The dimensionality of the model, i.e. the size of the
            output vectors.
    """

    def __init__(self, h, dmodel):
        super(MultiHeadAttention, self).__init__()
        self.h = h
        self.dims = dmodel
        self.heads = [AttentionHead(dmodel) for _ in range(h)]
        # Projects the concatenated head outputs (h * dmodel) back to dmodel.
        self.linear = tf.keras.layers.Dense(dmodel)
        self.add = tf.keras.layers.Add()

    def call(self, inputs):
        # Run every head on the same inputs and concatenate along the feature axis.
        res = tf.concat([head(inputs) for head in self.heads], axis=-1)
        res = self.linear(res)
        # Residual connection; assumes the last dimension of `inputs` equals dmodel.
        res = self.add([res, inputs])
        return res
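
# Hedged usage sketch: exercise MultiHeadAttention end to end on dummy data.
# The shapes and hyperparameters below are illustrative assumptions, not values
# taken from this module. Note that the residual Add requires the input's last
# dimension to equal dmodel.
if __name__ == "__main__":
    batch, seq_len, dmodel, num_heads = 2, 10, 64, 4
    x = tf.random.normal((batch, seq_len, dmodel))  # dummy input batch
    mha = MultiHeadAttention(h=num_heads, dmodel=dmodel)
    y = mha(x)
    print(y.shape)  # expected: (2, 10, 64)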