SEAN代码(1)

代码地址
首先定义一个trainer。

trainer = Pix2PixTrainer(opt)

在Pix2PixTrainer内部,首先定义Pix2PixModel模型。

self.pix2pix_model = Pix2PixModel(opt)

在Pix2PixModel内部定义生成器,判别器。

self.netG, self.netD, self.netE = self.initialize_networks(opt)

在initialize_networks内部定义功能。

netG = networks.define_G(opt)
netD = networks.define_D(opt) if opt.isTrain else None
netE = networks.define_E(opt) if opt.use_vae else None

首先看生成器:

def define_G(opt):netG_cls = find_network_using_name(opt.netG, 'generator')#netG=spadereturn create_network(netG_cls, opt)

输入的参数是opt.netG,在option中对应的是spade。在find_network_using_name中:

def find_network_using_name(target_network_name, filename):#spade,generatortarget_class_name = target_network_name + filename#spadegeneratormodule_name = 'models.networks.' + filename#models.networks.generatornetwork = util.find_class_in_module(target_class_name, module_name)#<class 'models.networks.generator.SPADEGenerator'>assert issubclass(network, BaseNetwork), \"Class %s should be a subclass of BaseNetwork" % networkreturn network

根据target_network_name和对应的filename输入到find_class_in_module中:

def find_class_in_module(target_cls_name, module):target_cls_name = target_cls_name.replace('_', '').lower()#spadegeneratorclslib = importlib.import_module(module)#import_module()返回指定的包或模块cls = Nonefor name, clsobj in clslib.__dict__.items():if name.lower() == target_cls_name:cls = clsobjif cls is None:print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))exit(0)return cls

我们通过import_module函数载入module这个模块,module对应的是models.networks.generator。即clslib 就是generator文件中的类。我们遍历clslib的字典,如果name等于spadegenerator,令cls = clsobj。
即network等于cls。

network = util.find_class_in_module(target_class_name, module_name)

这里有两个语法问题:
①:导入importlib,调用import_module()方法,根据输入的字符串可以获得模块clslib ,clslib 可以调用models.networks.generator文件下所有的属性和方法。
在这里插入图片描述
在generator内部是:
在这里插入图片描述可以通过clslib.SPADEGenerator来实例化SPADEGenerator,然后再调用SPADEGenerator内部的方法。
举个例子:新建三个文件。
在这里插入图片描述
train:
在这里插入图片描述
用不到test,在tt文件内部中导入train中的类s。
在这里插入图片描述
因为是同级目录,直接导入字符串train即可,如果不在同级目录,需要导入前一个目录。
接着a就会变成一个module,即train。然后实例化train文件夹下的类s。最后调用类s的方法kill和qqq。
输出:
在这里插入图片描述
②: dict,该属性可以用类名或者类的实例对象来调用,用**类名直接调用 dict,会输出该由类中所有类属性组成的字典;**而使用类的实例对象调用 dict,会输出由类中所有实例属性组成的字典。
参考
这里SPADEGenerator继承了BaseNetwork,对于具有继承关系的父类和子类来说,父类有自己的 dict,同样子类也有自己的 dict,它不会包含父类的 dict
例子:按上面的例子,a是一个module,查看a的__dict__:
在这里插入图片描述
输出:
在这里插入图片描述
回到代码中:我们输出的network就是类<class ‘models.networks.generator.SPADEGenerator’>。
下一步我们创建网络:在这里插入图片描述
在这里插入图片描述
cls对应的是SPADEGenerator网络。
在SPADE中:

"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import Zencoderclass SPADEGenerator(BaseNetwork):@staticmethoddef modify_commandline_options(parser, is_train):parser.set_defaults(norm_G='spectralspadesyncbatch3x3')parser.add_argument('--num_upsampling_layers',choices=('normal', 'more', 'most'), default='normal',help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")return parserdef __init__(self, opt):super().__init__()self.opt = optnf = opt.ngfself.sw, self.sh = self.compute_latent_vector_size(opt)self.Zencoder = Zencoder(3, 512)self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0')self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0')self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1')self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False)final_nc = nfif opt.num_upsampling_layers == 'most':self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4')final_nc = nf // 2self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)self.up = nn.Upsample(scale_factor=2)#self.up = nn.Upsample(scale_factor=2, mode='bilinear')def compute_latent_vector_size(self, opt):if opt.num_upsampling_layers == 'normal':#默认num_up_layers = 5elif opt.num_upsampling_layers == 'more':num_up_layers = 6elif opt.num_upsampling_layers == 'most':num_up_layers = 7else:raise ValueError('opt.num_upsampling_layers [%s] not recognized' %opt.num_upsampling_layers)sw = opt.crop_size // (2**num_up_layers)#256//32=16sh = round(sw / opt.aspect_ratio)#8return sw, shdef forward(self, input, rgb_img, obj_dic=None):seg = inputx = F.interpolate(seg, size=(self.sh, self.sw))#(16,16)x = self.fc(x)#(b,1024,16,16)style_codes = self.Zencoder(input=rgb_img, segmap=seg)x = self.head_0(x, seg, style_codes, obj_dic=obj_dic)x = self.up(x)x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic)if self.opt.num_upsampling_layers == 'more' or \self.opt.num_upsampling_layers == 'most':x = self.up(x)x = self.G_middle_1(x, seg, style_codes,  obj_dic=obj_dic)x = self.up(x)x = self.up_0(x, seg, style_codes, obj_dic=obj_dic)x = self.up(x)x = self.up_1(x, seg, style_codes, obj_dic=obj_dic)x = self.up(x)x = self.up_2(x, seg, style_codes, obj_dic=obj_dic)x = self.up(x)x = self.up_3(x, seg, style_codes,  obj_dic=obj_dic)# if self.opt.num_upsampling_layers == 'most':#     x = self.up(x)#     x= self.up_4(x, seg, style_codes,  obj_dic=obj_dic)x = self.conv_img(F.leaky_relu(x, 2e-1))x = F.tanh(x)return x

首先计算潜在空间向量的大小:
在这里插入图片描述
接着计算style matrixST。对应文章的 :
在这里插入图片描述
在代码中:通过卷积,下采样,下采样,上采样,卷积。输出一个通道为512的向量。
在这里插入图片描述
接着是连续的四个上采样模块:
在这里插入图片描述
对应于:
在这里插入图片描述
在SPADEResnetBlock内部:使用ACE类定义了SEAN块。
在这里插入图片描述
在ACE内部定义了归一化的参数和噪声等。
在这里插入图片描述
下面设计python正则表达式,没学过,下去补。只能先用debug获得结果。
在这里插入图片描述
这里使用SynchronizedBatchNorm2d进行归一化:
在这里插入图片描述
γ和β通过卷积获得:
在这里插入图片描述
执行完上采样的四个SEAN块之后,最后进过一个卷积输出合成图像。这就是整个network的流程。
生成器打印参数:
在这里插入图片描述
接着是判别器:
按照生成器的逻辑,target_class_name=multiscalediscriminator,module_name=models.networks.discriminator
然后我们导入判别器模块。
在这里插入图片描述
在多尺度判别器内部:创建两个single_discriminator。
在这里插入图片描述
在这里插入图片描述
在单个判别器内部定义参数:在这里插入图片描述
定义判别器的输入:将label通道和RGB图片拼接后输入。
在这里插入图片描述
接着经过一个4x4大小步长为2的卷积,再经过两个步长为2的卷积,最后再经过输出通道为1,步长为1的卷积。将每一个卷积都注册到模型中。
在这里插入图片描述
即判别器由五个卷积组成。
将单个判别器注册到判别器中。注册两次,这样盘比起由10个卷积组成,且都有对应的吗名称。
在这里插入图片描述

MultiscaleDiscriminator((discriminator_0): NLayerDiscriminator((model0): Sequential((0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))(1): LeakyReLU(negative_slope=0.2))(model1): Sequential((0): Sequential((0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False))(1): LeakyReLU(negative_slope=0.2))(model2): Sequential((0): Sequential((0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False))(1): LeakyReLU(negative_slope=0.2))(model3): Sequential((0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))))(discriminator_1): NLayerDiscriminator((model0): Sequential((0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))(1): LeakyReLU(negative_slope=0.2))(model1): Sequential((0): Sequential((0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False))(1): LeakyReLU(negative_slope=0.2))(model2): Sequential((0): Sequential((0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False))(1): LeakyReLU(negative_slope=0.2))(model3): Sequential((0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))))
)

这样生成器判别器都狗仔完毕,netE为空。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/65799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ansible学习笔记10

1、在group1的被管理机里的mariadb里创建一个abc库&#xff1b; 1&#xff09; 然后我们到agent主机上进行检查&#xff1a; 可以看到数据库已经创建成功。 再看几个其他命令&#xff1a; #a组主机重启mysql&#xff0c;并设置开机自启 ansible a -m service -a "namemy…

算法练习(10):牛客在线编程10 贪心算法

package jz.bm;import java.util.ArrayList; import java.util.Arrays;public class bm10 {/*** BM95 分糖果问题*/public int candy (int[] arr) {int res 0;int n arr.length;int[] nums new int[n];//每个人都分配一个糖果for (int i 0; i < n; i) {nums[i] 1;}//从…

HDMI 输出实验

FPGA教程学习 第十四章 HDMI 输出实验 文章目录 FPGA教程学习前言实验原理实验过程程序设计时钟模块&#xff08;video_pll&#xff09;彩条产生模块&#xff08;color_bar)配置数据查找表模块&#xff08;lut_adv7511&#xff09;I2C Master 寄存器配置模块&#xff08;i2c_c…

elasticSearch+kibana+logstash+filebeat集群改成https认证

文章目录 一、生成相关证书二、配置elasticSearh三、配置kibana四、配置logstash五、配置filebeat六、连接https es的java api 一、生成相关证书 ps&#xff1a;主节点操作 切换用户&#xff1a;su es 进入目录&#xff1a;cd /home/es/elasticsearch-7.6.2 创建文件&#x…

Pytest 框架执行用例流程浅谈

背景&#xff1a; 根据以下简单的代码示例&#xff0c;我们将从源码的角度分析其中的关键加载执行步骤&#xff0c;对pytest整体流程架构有个初步学习。 代码示例&#xff1a; import pytest def test_add(): assert 1 1 2 def test_sub(): assert 2 - 1 1 通过 pytes…

uniapp项目实践总结(八)自定义加载组件

有时候一个页面请求接口需要加载很长时间,这时候就需要一个加载页面来告知用户内容正在请求加载中,下面就写一个简单的自定义加载组件。 目录 准备工作逻辑思路实战演练效果预览准备工作 在之前的全局组件目录components下新建一个组件文件夹,命名为q-loading,组件为q-loa…

Adobe Illustrator 2023 for mac安装教程,可用。

Adobe Illustrator 是行业标准的矢量图形应用程序&#xff0c;可以为印刷、网络、视频和移动设备创建logos、图标、绘图、排版和插图。数以百万计的设计师和艺术家使用Illustrator CC创作&#xff0c;从网页图标和产品包装到书籍插图和广告牌。此版本是2023版本&#xff0c;适配…

LeetCode(力扣)236. 二叉树的最近公共祖先Python

LeetCode236. 二叉树的最近公共祖先 题目链接代码 题目链接 https://leetcode.cn/problems/lowest-common-ancestor-of-a-binary-tree/ 代码 # Definition for a binary tree node. # class TreeNode: # def __init__(self, x): # self.val x # self.…

C语言深入理解指针(非常详细)(二)

目录 指针运算指针-整数指针-指针指针的关系运算 野指针野指针成因指针未初始化指针越界访问指针指向的空间释放 如何规避野指针指针初始化注意指针越界指针不使用时就用NULL避免返回局部变量的地址 assert断言指针的使用和传址调用传址调用例子&#xff08;strlen函数的实现&a…

The Cherno——OpenGL

The Cherno——OpenGL 1. 欢迎来到OpenGL OpenGL是一种跨平台的图形接口&#xff08;API&#xff09;&#xff0c;就是一大堆我们能够调用的函数去做一些与图像相关的事情。特殊的是&#xff0c;OpenGL允许我们访问GPU&#xff08;Graphics Processing Unit 图像处理单元&…

pwngdb 中 b *$rebase(0x相对基址偏移) 是什么意思

pwngdb 中 b *$rebase(0x相对基址偏移) 是什么意思 pwngdb 是一个针对二进制漏洞利用的调试工具库&#xff0c;用于在 GDB 调试器中辅助进行漏洞开发和漏洞利用的调试。b *$rebase(0x相对基址偏移) 是 pwngdb 中的一个调试命令&#xff0c;用于在基地址重定位后设置断点。 在二…

Python小知识 - 如何使用Python的Flask框架快速开发Web应用

如何使用Python的Flask框架快速开发Web应用 现在越来越多的人把Python作为自己的第一语言来学习&#xff0c;Python的简洁易学的语法以及丰富的第三方库让人们越来越喜欢上了这门语言。本文将介绍如何使用Python的Flask框架快速开发Web应用。 Flask是一个使用Python编写的轻量级…

Spring Boot中通过maven进行多环境配置

上文 java Spring Boot将不同配置拆分入不同文件管理 中 我们说到了&#xff0c;多环境的多文件区分管理 说到多环境 其实不止我们 Spring Boot有 很多的东西都有 那么 这就有一个问题 如果 spring 和 maven 都配置了环境 而且他们配的不一样 那么 会用谁的呢&#xff1f; 此…

MySQL编写建表语句,如何优雅处理创建时间与更新时间

在 MySQL 中&#xff0c;可以使用 TIMESTAMP 或者 DATETIME 数据类型来存储日期和时间信息&#xff0c;并结合默认值和触发器来实现自动更新 createTime 和 updateTime 字段。 以下是一个示例建表语句&#xff0c;演示如何设置自动更新的 createTime 和 updateTime 字段&#…

《TCP/IP网络编程》阅读笔记--基于Windows实现Hello Word服务器端和客户端

目录 1--Hello Word服务器端 2--客户端 3--编译运行 3-1--编译服务器端 3-2--编译客户端 3-3--运行 1--Hello Word服务器端 // gcc hello_server_win.c -o hello_server_win -lwsock32 // hello_server_win 9190 #include <stdio.h> #include <stdlib.h> #i…

【算法刷题-双指针篇】

目录 1.leetcode-27. 移除元素2.leetcode-344. 反转字符串3.leetcode-剑指 Offer 05. 替换空格4.leetcode-206. 反转链表5.leetcode-19. 删除链表的倒数第 N 个结点6.leetcode-面试题 02.07. 链表相交7.leetcode-142. 环形链表 II8.leetcode-15. 三数之和9.leetcode-18. 四数之…

Git使用——GitHub项目回退版本

查看历史版本 使用git log命令查看项目的历史版本&#xff1a; 可以一直回车&#xff0c;直到找到想要的历史版本&#xff0c;复制commit后面的那一串id。 恢复历史版本 执行命令 git reset --hard 版本号&#xff1a; git reset --hard 39ac3ea2448e81ea992b7c4fdad9252983…

Ubuntu系统环境搭建(五)——Ubuntu安装maven

ubuntu环境搭建专栏&#x1f517;点击跳转 Ubuntu系统环境搭建&#xff08;五&#xff09;——Ubuntu安装maven 更新 sudo apt update安装 sudo apt install maven验证 mvn -version

ARM 汇编基础知识

1.为什么学习汇编&#xff1f; 我们在进行嵌入式 Linux 开发的时候是绝对要掌握基本的 ARM 汇编&#xff0c;因为 Cortex-A 芯片一 上电 SP 指针还没初始化&#xff0c; C 环境还没准备好&#xff0c;所以肯定不能运行 C 代码&#xff0c;必须先用汇编语言设置好 C 环境…

Python编程练习与解答 练习96:字符串是否表示整数

本练习将编写一个名为isInteger的函数&#xff0c;用于确定字符串中的字符是否代表有效整数&#xff0c;确定字符串是否表示整数时&#xff0c;则应忽略开通要或者结尾的任何空白。一旦这个空白被忽略&#xff0c;如果字符串的长度至少是1&#xff0c;且只包含数字&#xff0c;…