You can make this change where the `MultiHeadAttention` layer is defined. Find the `MultiHeadAttention` class definition, then set the `axis` argument used inside its `_masked_softmax` method to `-1`, so that the softmax is computed over the last axis:
```python
import tensorflow as tf
from tensorflow import keras  # assuming the TensorFlow-bundled Keras


class MultiHeadAttention(keras.layers.Layer):
    def __init__(self, num_heads, key_dim, name="multi_head_attention"):
        super().__init__(name=name)
        self.num_heads = num_heads
        self.key_dim = key_dim
        # Linear projections for queries, keys, and values
        self.query_dense = keras.layers.Dense(units=key_dim, name="query")
        self.key_dense = keras.layers.Dense(units=key_dim, name="key")
        self.value_dense = keras.layers.Dense(units=key_dim, name="value")
        # Final projection applied after the heads are recombined
        self.combine_heads = keras.layers.Dense(units=key_dim, name="combine_heads")

    def _masked_softmax(self, attention_scores, attention_mask):
        """
        Softmax with a mask to prevent attention from falling on padding tokens.
        """
        if attention_mask is not None:
            # Push masked positions toward -inf so they receive ~0 weight after softmax
            attention_scores = attention_scores - 1e9 * (1.0 - attention_mask)
        # Compute the softmax over the last axis (the key / sequence_length dimension)
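        # A hedged sketch of the change described above, assuming the original
        # implementation calls TensorFlow's softmax at this point (the call itself
        # is not shown in this snippet), e.g.:
        #     return tf.nn.softmax(attention_scores, axis=-1)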