diff --git a/CsiNet_train.py b/CsiNet_train.py index acaf5ec..37f84d0 100644 --- a/CsiNet_train.py +++ b/CsiNet_train.py @@ -19,28 +19,31 @@ encoded_dim = 512 #compress rate=1/4->dim.=512, compress rate=1/16->dim.=128, compress rate=1/32->dim.=64, compress rate=1/64->dim.=32 # Bulid the autoencoder model of CsiNet -def residual_network(x, residual_num, encoded_dim): - def add_common_layers(y): +def residual_block_decoded(): + conv1 = Conv2D(8, kernel_size=(3, 3), padding='same') + conv2 = Conv2D(16, kernel_size=(3, 3), padding='same') + conv3 = Conv2D(2, kernel_size=(3, 3), padding='same') + bn = BatchNormalization() + + def block(x): + shortcut = x + y = conv1(x) y = BatchNormalization()(y) y = LeakyReLU()(y) - return y - def residual_block_decoded(y): - shortcut = y - y = Conv2D(8, kernel_size=(3, 3), padding='same', data_format='channels_first')(y) - y = add_common_layers(y) - - y = Conv2D(16, kernel_size=(3, 3), padding='same', data_format='channels_first')(y) - y = add_common_layers(y) - - y = Conv2D(2, kernel_size=(3, 3), padding='same', data_format='channels_first')(y) + + y = conv2(y) y = BatchNormalization()(y) + y = LeakyReLU()(y) + + y = conv3(y) + y = bn(y) y = add([shortcut, y]) y = LeakyReLU()(y) - return y + return block - x = Conv2D(2, (3, 3), padding='same', data_format="channels_first")(x) + x = Conv2D(2, (3, 3), padding='same')(x) x = add_common_layers(x) @@ -48,15 +51,15 @@ def residual_block_decoded(y): encoded = Dense(encoded_dim, activation='linear')(x) x = Dense(img_total, activation='linear')(encoded) - x = Reshape((img_channels, img_height, img_width,))(x) + x = Reshape((img_height, img_width, img_channels))(x) for i in range(residual_num): x = residual_block_decoded(x) - x = Conv2D(2, (3, 3), activation='sigmoid', padding='same', data_format="channels_first")(x) + x = Conv2D(2, (3, 3), activation='sigmoid', padding='same')(x) return x -image_tensor = Input(shape=(img_channels, img_height, img_width)) +image_tensor = Input(shape=(img_height, img_width, img_channels)) network_output = residual_network(image_tensor, residual_num, encoded_dim) autoencoder = Model(inputs=[image_tensor], outputs=[network_output]) autoencoder.compile(optimizer='adam', loss='mse') @@ -200,4 +203,4 @@ def on_epoch_end(self, epoch, logs={}): json_file.write(model_json) # serialize weights to HDF5 outfile = "result/model_%s.h5"%file -autoencoder.save_weights(outfile) \ No newline at end of file +autoencoder.save_weights(outfile)