MeanIoU for sparse_categorical_crossentropy(Tensorflow2.3.0驗證成功)

之前在tensorflow2.0版本使用以下搭配可以成功

class MeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU for sparse_categorical_crossentropy"""
    def __call__(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)

model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', 
               metrics=['accuracy', MeanIoU(num_classes=CLASSES, name='mIOU')])

但是升級到了Tensorflow2.3.0版本一直出現以下維度的問題
Shapes of all inputs must match: values[0].shape = [80] != values[1].shape = [400] (num_class=5)
更新后可以正常使用了。

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    def __init__(self,
               y_true=None,
               y_pred=None,
               num_classes=None,
               name=None,
               dtype=None):
        super(UpdatedMeanIoU, self).__init__(num_classes = num_classes,name=name, dtype=dtype)
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.math.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)

    
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', 
               metrics=['accuracy', UpdatedMeanIoU(num_classes=CLASSES, name='mIOU')])
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容