I am working with an old TensorFlow (2.7) and TensorFlow Java, but the answer may still help others tracing similar issues. My model produces the following warning when saved with save() in the TF 2.7 format:
CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
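One way to find the object behind this warning is to escalate it to an exception while saving, then read the local variables of the frame that raised it. The following is a minimal sketch in plain Python. `find_warning_source` is a hypothetical helper, not a TensorFlow API. It rests on two assumptions. First, the Keras function that emits the warning holds the object being serialized in a local variable named `instance`; this is my reading of the Keras 2.7 source, not documented behavior. Second, nothing in the `save()` path catches the escalated exception.

```python
import warnings


def find_warning_source(action, message_regex, local_name):
    """Run action() with matching warnings turned into exceptions.

    Returns the value of local_name from the innermost traceback frame that
    defines it, or None if no matching warning was raised (or no frame
    defines local_name).
    """
    with warnings.catch_warnings():
        # Escalate only warnings whose message matches message_regex at its
        # start; catch_warnings() restores the previous filters on exit.
        warnings.filterwarnings("error", message=message_regex)
        try:
            action()
        except Warning as exc:
            frames = []
            tb = exc.__traceback__
            while tb is not None:
                frames.append(tb.tb_frame)
                tb = tb.tb_next
            # Innermost frames come last; search from there outward.
            for frame in reversed(frames):
                if local_name in frame.f_locals:
                    return frame.f_locals[local_name]
    return None
```

With a loaded model, something like `find_warning_source(lambda: model.save("export_dir"), "Custom mask layers", "instance")` should return the offending object, or `None` if the warning did not fire. The model and path here are placeholders. The escalated warning aborts that particular save, so save again normally afterwards.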
Originally I assumed this came from my custom metric and loss function below, since the model cannot save and load them, presumably because of some sort of serialization issue.
```python
import tensorflow as tf


def masked_mae(y_true, y_pred):
    # Replace NaN targets with the prediction, so those positions contribute 0 error
    y_true = tf.where(tf.math.is_nan(y_true), y_pred, y_true)
    # Calculate absolute differences
    absolute_differences = tf.abs(y_true - y_pred)
    # Mean over all elements; masked positions still count in the denominator
    mae = tf.reduce_mean(absolute_differences)
    return mae
```
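In masked_mae, a NaN target is replaced by the prediction. That position adds 0 to the sum but still counts in tf.reduce_mean's denominator, which pulls the reported MAE down when many targets are missing. The following plain-Python mirror needs no TensorFlow, and both function names are hypothetical. It sits next to a variant that averages only over positions with a non-NaN target. That variant is an alternative added for comparison, not what the model was trained with.

```python
import math


def masked_mae_as_written(y_true, y_pred):
    """Mirror of masked_mae: NaN targets become |p - p| = 0 but stay in the mean."""
    diffs = [0.0 if math.isnan(t) else abs(t - p) for t, p in zip(y_true, y_pred)]
    return sum(diffs) / len(diffs)


def masked_mae_valid_only(y_true, y_pred):
    """Variant: average only over positions whose target is not NaN."""
    diffs = [abs(t - p) for t, p in zip(y_true, y_pred) if not math.isnan(t)]
    # Return 0.0 when every target is NaN instead of dividing by zero.
    return sum(diffs) / len(diffs) if diffs else 0.0
```

For y_true = [1.0, nan, 3.0] and y_pred = [2.0, 5.0, 3.0], the first function returns 1/3 and the second returns 0.5.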
I don't know how to pass custom_objects from Java, and training is finished, so I recompiled the model with loss=None and metrics=None and saved it again. With this change the model loads, but the warning has not gone away. Is there a way to trace which layer is causing it, and why? Or is it just a nuisance? There are no obvious uses of masks in my model.
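A static check is another option. In my reading of the Keras 2.6/2.7 source, the warning is emitted while serializing an object that supports masking but still uses the base-class get_config. An object counts as supporting masking when `supports_masking` is true or `compute_mask` is overridden. Keras detects "not overridden" through a private `_is_default` attribute set by an internal decorator. The following sketch mirrors that check. `mask_warning_suspects` is a hypothetical helper, and relying on `_is_default` is an assumption about private Keras internals that may differ between versions. Under this reading, a plain Python function such as masked_mae has none of these attributes, so the loss and metric would not trigger the warning.

```python
def is_default(method):
    # Assumption: Keras marks base-class methods whose overrides it wants to
    # detect with a private `_is_default` attribute. Returns False for None.
    return getattr(method, "_is_default", False)


def mask_warning_suspects(objects):
    """Return the objects that would plausibly trigger CustomMaskWarning."""
    suspects = []
    for obj in objects:
        compute_mask = getattr(obj, "compute_mask", None)
        # "Supports masking": the flag is set, or compute_mask is overridden.
        supports_masking = bool(getattr(obj, "supports_masking", False)) or (
            compute_mask is not None and not is_default(compute_mask)
        )
        # Flag it only if get_config is still the default base-class method.
        if supports_masking and is_default(getattr(obj, "get_config", None)):
            suspects.append(obj)
    return suspects
```

With a loaded Keras model, `mask_warning_suspects([model, *model.submodules])` would list candidate layers. `submodules` is the standard tf.Module property that walks nested layers. Printing `type(obj).__name__` for each result shows which class to inspect.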
Original title: python - Tracing CustomMaskWarning to its source - Stack Overflow (user-contributed content; source page: http://www.betaflare.com/web/1741641223a2389928.html).