python - tf.keras.metrics.MeanSquaredError 结果与相同输入和权重的 tf.keras.losses.MeanSquaredError 不匹配

当使用下面的损失 MeanSquaredError 时,输出按预期进行

y_true = tf.constant([[1., 9.], [2., 5.]])
y_pred = tf.constant([[4., 8.], [12., 3.]])
sample_weight = tf.constant([1.2, 0.5])

mse = tf.keras.losses.MeanSquaredError()

print("mse w/o weights", mse(y_true, y_pred).numpy())
print("mse w weights", mse(y_true, y_pred, sample_weight=sample_weight).numpy())

输出 -

mse w/o weights 28.5
mse w weights 16.0

预期的 -

# mse = [((4 - 1)^2 + (8 - 9)^2) / 2, ((12 - 2)^2 + (3 - 5)^2) / 2]
# mse = [5, 52]
# weighted_mse = [5 * 1.2, 52 * 0.5] = [6, 26]
# reduced_weighted_mse = (6 + 26) / 2 = 16

当使用度量 MeanSquaredError 时,权重的结果会有所不同

mse = tf.keras.metrics.MeanSquaredError()
mse.update_state(y_true, y_pred)
print("mse w/o weights", mse.result().numpy())

mse = tf.keras.metrics.MeanSquaredError()
mse.update_state(y_true, y_pred, sample_weight=sample_weight)
print("mse w weights", mse.result().numpy())

输出 -

mse w/o weights 28.5
mse w weights 18.823528

为什么带有权重的指标输出与损失输出不同?

回答1

  1. 对于 tf.keras.losses.MeanSquaredErrorsample_weight 充当损失的系数。因此,它是 [5 * 1.2, 52 * 0.5]=16.0 的平均值。
  2. 对于 tf.keras.metrics.MeanSquaredErrorsample_weight 对样本进行加权平均。因此它是 (5 * 1.2 + 52 * 0.5)/(1.2+0.5)=32/1.7=18.823

相似文章

最新文章