当我读到Cartoongan 卡通络:用于照片卡通化的生成对抗性网络的报纸时,我对vgg在整个网络中的位置感到非常困惑。
根据关于GAN 生成对抗性网的文章,我看不到网络中存在vgg的任何地方。
我觉得我看不清卡通根的整个建筑。我知道GAN是如何用发电机和识别器构造的,但卡通根不是。如果有人能帮忙,我很感激。
发布于 2022-09-20 12:53:24
很大的问题是,VGG主要用于特征提取。因此,我们可以比较两种图像的风格之间的相似性。通常情况下,真实的一批图像和生成的图像。
这主要用作损失函数中的附加项。如下所示,s_loss和c_loss
这可以在所涉论文的GitHub回购上看到:
if self.content_lambda != 0. or self.style_lambda != 0.:
vgg_generated_images = self.pass_to_vgg(generated_images)
if self.content_lambda != 0.:
c_loss = self.content_lambda * self.content_loss(
self.pass_to_vgg(source_images), vgg_generated_images)
g_total_loss = g_total_loss + c_loss
if self.style_lambda != 0.:
s_loss = self.style_lambda * self.style_loss(
self.pass_to_vgg(target_images[:vgg_generated_images.shape[0]]),
vgg_generated_images)
g_total_loss = g_total_loss + s_loss
d_grads = d_tape.gradient(d_total_loss, discriminator.trainable_variables)
g_grads = g_tape.gradient(g_total_loss, generator.trainable_variables) 当VGG被初始化并在请求培训时加载时,如下所示:
base_model = VGG19(weights="imagenet", include_top=False, input_shape=input_shape)
tmp_vgg_output = base_model.get_layer("block4_conv3").output
tmp_vgg_output = Conv2D(512, (3, 3), activation='linear', padding='same',
name='block4_conv4')(tmp_vgg_output)
self.vgg = tf.keras.Model(inputs=base_model.input, outputs=tmp_vgg_output)
self.vgg.load_weights(os.path.expanduser(os.path.join(
"~", ".keras", "models",
"vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5")), by_name=True)https://datascience.stackexchange.com/questions/114552
复制相似问题