Based on the error message, the problem is the `axis` argument of the `softmax` operation. Try modifying the `MultiHeadAttention` layer so that the `_masked_softmax` method applies softmax over the last axis by passing `axis=-1`:
```python
from tensorflow import keras


class MultiHeadAttention(keras.layers.Layer):
    def __init__(self, num_heads, key_dim, name="multi_head_attention"):
        super(MultiHeadAttention, self).__init__(name=name)
        self.num_heads = num_heads
        self.key_dim = key_dim
        # Linear projections for queries, keys, and values
        self.query_dense = keras.layers.Dense(units=key_dim, name="query")
        self.key_dense = keras.layers.Dense(units=key_dim, name="key")
        self.value_dense = keras.layers.Dense(units=key_dim, name="value")
        self.combine_heads = keras.layers.Dense(units=key_dim, name="combine_heads")

    def _masked_softmax(self, attention_scores, attention_mask):
        """Softmax with a mask to prevent attention to padding tokens."""
        if attention_mask is not None:
            # Push masked positions to a large negative value so they
            # receive ~0 weight after the softmax
            attention_scores = attention_scores - 1e9 * (1.0 - attention_mask)
        # Compute softmax over the last axis (the key sequence length)
        attention_weights = keras.activations.softmax(attention_scores, axis=-1)
        return attention_weights
```
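As a quick sanity check, a sketch like the following (assuming TensorFlow 2.x, the layer definition above, and illustrative tensor shapes) can confirm that masked key positions end up with near-zero attention weight while each row still sums to 1:

```python
import tensorflow as tf

# Hypothetical shapes: batch of 2, 4 query positions, 5 key positions
scores = tf.random.normal((2, 4, 5))

# Mask out the last two key positions (1.0 = keep, 0.0 = mask),
# reshaped so it broadcasts over batch and query dimensions
mask = tf.reshape(tf.constant([1.0, 1.0, 1.0, 0.0, 0.0]), (1, 1, 5))

layer = MultiHeadAttention(num_heads=2, key_dim=8)
weights = layer._masked_softmax(scores, mask)

print(weights.shape)                             # (2, 4, 5)
print(weights[0, 0].numpy())                     # last two entries are ~0
print(tf.reduce_sum(weights, axis=-1).numpy())   # each row sums to ~1
```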
If you still run into problems, please share more of your code and the full error message so we can help you further.