BatchNormalization mismatch between PyTorch and Keras

I have built a U-Net-like architecture called DF-Net for image manipulation localization. I implemented the model in PyTorch, but later found the original model's saved weights and metadata in Keras format, so I loaded all of the weights from the Keras model into my PyTorch model.
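
For context, the BatchNorm parameters were transferred with the usual Keras-to-PyTorch mapping (gamma -> weight, beta -> bias, moving_mean -> running_mean, moving_variance -> running_var). A minimal sketch of the idea; copy_bn is a hypothetical helper, not my exact porting code:

import torch

def copy_bn(keras_bn, torch_bn):
    # Keras get_weights() returns [gamma, beta, moving_mean, moving_variance]
    gamma, beta, mean, var = keras_bn.get_weights()
    with torch.no_grad():
        torch_bn.weight.copy_(torch.from_numpy(gamma))
        torch_bn.bias.copy_(torch.from_numpy(beta))
        torch_bn.running_mean.copy_(torch.from_numpy(mean))
        torch_bn.running_var.copy_(torch.from_numpy(var))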

The output of my PyTorch implementation is vastly different from that of the loaded Keras model. I started debugging and traced the discrepancy back to the first layer of the encoder block. I inspected the first Conv2D block and found that its outputs are identical in both models. The problem arises with the BatchNormalization that follows: its outputs are very different. I even inspected the BN parameters in both models, and they are identical:

Keras BN gamma: [1.1910659  0.64996934 1.3252271  1.1093639  0.71945614]
Keras BN beta: [ 0.268981   -0.26507446  0.09138709 -0.01533892 -0.22096473]
Keras BN Moving Mean: [ 0.05168979  0.12713721 -0.01842033  0.00101717 -0.52817714]
Keras BN Moving Variance: [0.008371   0.00603298 0.0031222  0.00025985 0.09115636]
PyTorch BN gamma: [1.1910659  0.64996934 1.3252271  1.1093639  0.71945614]
PyTorch BN beta: [ 0.268981   -0.26507446  0.09138709 -0.01533892 -0.22096473]
PyTorch BN Running Mean: [ 0.05168979  0.12713721 -0.01842033  0.00101717 -0.52817714]
PyTorch BN Running Variance: [0.008371   0.00603298 0.0031222  0.00025985 0.09115636]
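
The parameters above were printed side by side along these lines (a minimal sketch, assuming the Keras layer name "batch_normalization_81" and the PyTorch module path m1.down_convs[0].parallel_conv_blocks[0].bn used later in this post):

keras_bn = M1_tf.get_layer("batch_normalization_81")
gamma, beta, mean, var = keras_bn.get_weights()
print("Keras BN gamma:", gamma[:5])
print("Keras BN beta:", beta[:5])
print("Keras BN Moving Mean:", mean[:5])
print("Keras BN Moving Variance:", var[:5])

torch_bn = m1.down_convs[0].parallel_conv_blocks[0].bn
print("PyTorch BN gamma:", torch_bn.weight.detach().numpy()[:5])
print("PyTorch BN beta:", torch_bn.bias.detach().numpy()[:5])
print("PyTorch BN Running Mean:", torch_bn.running_mean.numpy()[:5])
print("PyTorch BN Running Variance:", torch_bn.running_var.numpy()[:5])
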
To compare layer outputs, I fed an all-ones input to both models and extracted the activation after that first BatchNormalization:

import numpy as np
import torch
from tensorflow.keras.models import Model

input_shape = (1, 256, 256, 3)  # NHWC for Keras
keras_input = np.ones(input_shape, dtype=np.float32)
torch_input = torch.from_numpy(keras_input).permute(0, 3, 1, 2)  # NHWC -> NCHW
keras_output = M1_tf.predict(keras_input)

# Create a sub-model that outputs activations from a specific layer
layer_name = "batch_normalization_81"
intermediate_model = Model(inputs=M1_tf.input, outputs=M1_tf.get_layer(layer_name).output)
keras_layer_output = intermediate_model.predict(keras_input)

m1.eval()  # Ensure inference mode
layer_input = None  # To store the input
layer_output = None  # To store the output

def hook_fn(module, input, output):
    global layer_input, layer_output
    layer_input = input[0].detach()  # Capture input
    layer_output = output.detach()   # Capture output

target_layer = m1.down_convs[0].parallel_conv_blocks[0].bn  # First BN in the encoder
hook = target_layer.register_forward_hook(hook_fn)

# Run inference
with torch.no_grad():
    m1(torch_input)

hook.remove()  # Remove hook after use

print(60*"=")
print("PyTorch layer output:", layer_output[0].permute(1,2,0).numpy().flatten()[:50])
print("Keras layer output:", keras_layer_output.flatten()[:50])

Output:

PyTorch layer output: [ 6.37458324e-01  3.82063180e-01  4.29470986e-01  2.57231295e-05
 -3.16684663e-01  1.11704357e-01  8.17334801e-02  3.11658066e-02
  5.07744670e-01  4.63411123e-01 -2.69585550e-01 -6.68283165e-01
 -2.28013784e-01 -3.12905580e-01  3.43340598e-02  2.58536279e-01
  7.83286989e-03 -2.22982496e-01  2.51532018e-01 -6.86605215e-01
  2.96584144e-02 -4.80362698e-02 -1.08390920e-01  2.83417434e-01
  3.13855149e-02 -2.57040292e-01 -1.84278190e-02 -3.31664622e-01
  7.47844353e-02 -3.62884812e-03  4.47052151e-01 -6.04453266e-01
  9.07126606e-01  5.73347270e-01 -1.01024106e-01 -1.66961960e-02
 -8.40807796e-01  9.38138887e-02 -1.55476332e-01  2.54854243e-02
  4.34181899e-01 -6.21834695e-02 -7.02126846e-02 -4.76066291e-01
  1.82371408e-01 -3.82577702e-02  1.82515644e-02 -3.12020183e-01
 -2.49998122e-02 -2.88213879e-01]
Keras layer output: [ 7.4762206e+00  1.7106993e+00  9.3362055e+00 -4.6326146e-02
  2.8026474e-01  1.6977308e+00  1.3332834e+00  4.3128264e-01
  5.5462074e+00  8.9298620e+00 -2.7876496e+00 -6.4391127e+00
 -4.5693979e+00 -1.9292536e+00  4.6420819e-01  3.2088432e+00
  7.8105974e-01 -1.3556652e+00  5.0551910e+00 -1.6469488e+01
  2.7070484e-01 -3.2417399e-01 -2.4771357e+00 -6.1011815e-01
  5.4895997e-02 -2.7302039e+00  4.6613199e-01 -5.2117767e+00
 -9.5965725e-01 -1.1980299e+00  7.1799951e+00 -1.0363155e+01
  1.0794193e+01  3.1932242e+00 -1.6136179e+00 -5.6896377e-01
 -9.6188796e-01  1.3953602e+00 -1.2235196e+00  2.9183704e-01
  4.6142755e+00 -2.0011232e+00  1.5729040e-02 -4.3455467e+00
  5.2770085e+00  3.3058962e-01  3.2997942e-01 -3.1017118e+00
 -6.5815002e-02 -1.9106138e+00]
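
The magnitudes differ by roughly an order of magnitude. Since the Conv2D outputs and the BN parameters match, one check I considered is recomputing inference-mode BN by hand with each framework's default epsilon (Keras BatchNormalization defaults to epsilon=1e-3, torch.nn.BatchNorm2d to eps=1e-5), which could matter here because the moving variances are tiny. A minimal sketch, assuming conv_out holds the (identical) Conv2D output as a NumPy array in NHWC layout, and gamma, beta, mean, var are the BN parameters printed above:

import numpy as np

# Inference-mode BN: y = gamma * (x - mean) / sqrt(var + eps) + beta
def manual_bn(x, gamma, beta, mean, var, eps):
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

for eps in (1e-3, 1e-5):  # Keras default vs PyTorch default
    ref = manual_bn(conv_out, gamma, beta, mean, var, eps)
    print(f"eps={eps}:",
          "max|ref - Keras| =", np.abs(ref - keras_layer_output).max(),
          "max|ref - PyTorch| =",
          np.abs(ref - layer_output.permute(0, 2, 3, 1).numpy()).max())

What could cause such a large BatchNormalization output mismatch when the preceding Conv2D outputs and all BN parameters are identical?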
