【域适应】基于域分离网络的MNIST数据10分类典型方法实现

关于

大规模数据收集和注释的成本通常使得将机器学习算法应用于新任务或数据集变得异常昂贵。规避这一成本的一种方法是在合成数据上训练模型,其中自动提供注释。尽管它们很有吸引力,但此类模型通常无法从合成图像推广到真实图像,因此需要域适应算法来操纵这些模型,然后才能成功应用。现有的方法要么侧重于将表示从一个域映射到另一个域,要么侧重于学习提取对于提取它们的域而言不变的特征。然而,通过只关注在两个域之间创建映射或共享表示,他们忽略了每个域的单独特征。域分离网络可以实现对每个域的独特之处进行特征建模,,同时进行模型域不变特征的提取

参考文章: https://arxiv.org/abs/1608.06019

工具

 

方法实现

数据集定义
import torch.utils.data as data
from PIL import Image
import osclass GetLoader(data.Dataset):def __init__(self, data_root, data_list, transform=None):self.root = data_rootself.transform = transformf = open(data_list, 'r')data_list = f.readlines()f.close()self.n_data = len(data_list)self.img_paths = []self.img_labels = []for data in data_list:self.img_paths.append(data[:-3])self.img_labels.append(data[-2])def __getitem__(self, item):img_paths, labels = self.img_paths[item], self.img_labels[item]imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')if self.transform is not None:imgs = self.transform(imgs)labels = int(labels)return imgs, labelsdef __len__(self):return self.n_data
模型搭建
import torch.nn as nn
from functions import ReverseLayerFclass DSN(nn.Module):def __init__(self, code_size=100, n_class=10):super(DSN, self).__init__()self.code_size = code_size########################################### private source encoder##########################################self.source_encoder_conv = nn.Sequential()self.source_encoder_conv.add_module('conv_pse1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,padding=2))self.source_encoder_conv.add_module('ac_pse1', nn.ReLU(True))self.source_encoder_conv.add_module('pool_pse1', nn.MaxPool2d(kernel_size=2, stride=2))self.source_encoder_conv.add_module('conv_pse2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,padding=2))self.source_encoder_conv.add_module('ac_pse2', nn.ReLU(True))self.source_encoder_conv.add_module('pool_pse2', nn.MaxPool2d(kernel_size=2, stride=2))self.source_encoder_fc = nn.Sequential()self.source_encoder_fc.add_module('fc_pse3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))self.source_encoder_fc.add_module('ac_pse3', nn.ReLU(True))########################################## private target encoder#########################################self.target_encoder_conv = nn.Sequential()self.target_encoder_conv.add_module('conv_pte1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,padding=2))self.target_encoder_conv.add_module('ac_pte1', nn.ReLU(True))self.target_encoder_conv.add_module('pool_pte1', nn.MaxPool2d(kernel_size=2, stride=2))self.target_encoder_conv.add_module('conv_pte2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,padding=2))self.target_encoder_conv.add_module('ac_pte2', nn.ReLU(True))self.target_encoder_conv.add_module('pool_pte2', nn.MaxPool2d(kernel_size=2, stride=2))self.target_encoder_fc = nn.Sequential()self.target_encoder_fc.add_module('fc_pte3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))self.target_encoder_fc.add_module('ac_pte3', nn.ReLU(True))################################# shared encoder (dann_mnist)################################self.shared_encoder_conv = nn.Sequential()self.shared_encoder_conv.add_module('conv_se1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,padding=2))self.shared_encoder_conv.add_module('ac_se1', nn.ReLU(True))self.shared_encoder_conv.add_module('pool_se1', nn.MaxPool2d(kernel_size=2, stride=2))self.shared_encoder_conv.add_module('conv_se2', nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5,padding=2))self.shared_encoder_conv.add_module('ac_se2', nn.ReLU(True))self.shared_encoder_conv.add_module('pool_se2', nn.MaxPool2d(kernel_size=2, stride=2))self.shared_encoder_fc = nn.Sequential()self.shared_encoder_fc.add_module('fc_se3', nn.Linear(in_features=7 * 7 * 48, out_features=code_size))self.shared_encoder_fc.add_module('ac_se3', nn.ReLU(True))# classify 10 numbersself.shared_encoder_pred_class = nn.Sequential()self.shared_encoder_pred_class.add_module('fc_se4', nn.Linear(in_features=code_size, out_features=100))self.shared_encoder_pred_class.add_module('relu_se4', nn.ReLU(True))self.shared_encoder_pred_class.add_module('fc_se5', nn.Linear(in_features=100, out_features=n_class))self.shared_encoder_pred_domain = nn.Sequential()self.shared_encoder_pred_domain.add_module('fc_se6', nn.Linear(in_features=100, out_features=100))self.shared_encoder_pred_domain.add_module('relu_se6', nn.ReLU(True))# classify two domainself.shared_encoder_pred_domain.add_module('fc_se7', nn.Linear(in_features=100, out_features=2))####################################### shared decoder (small decoder)######################################self.shared_decoder_fc = nn.Sequential()self.shared_decoder_fc.add_module('fc_sd1', nn.Linear(in_features=code_size, out_features=588))self.shared_decoder_fc.add_module('relu_sd1', nn.ReLU(True))self.shared_decoder_conv = nn.Sequential()self.shared_decoder_conv.add_module('conv_sd2', nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5,padding=2))self.shared_decoder_conv.add_module('relu_sd2', nn.ReLU())self.shared_decoder_conv.add_module('conv_sd3', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5,padding=2))self.shared_decoder_conv.add_module('relu_sd3', nn.ReLU())self.shared_decoder_conv.add_module('us_sd4', nn.Upsample(scale_factor=2))self.shared_decoder_conv.add_module('conv_sd5', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3,padding=1))self.shared_decoder_conv.add_module('relu_sd5', nn.ReLU(True))self.shared_decoder_conv.add_module('conv_sd6', nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3,padding=1))def forward(self, input_data, mode, rec_scheme, p=0.0):result = []if mode == 'source':# source private encoderprivate_feat = self.source_encoder_conv(input_data)private_feat = private_feat.view(-1, 64 * 7 * 7)private_code = self.source_encoder_fc(private_feat)elif mode == 'target':# target private encoderprivate_feat = self.target_encoder_conv(input_data)private_feat = private_feat.view(-1, 64 * 7 * 7)private_code = self.target_encoder_fc(private_feat)result.append(private_code)# shared encodershared_feat = self.shared_encoder_conv(input_data)shared_feat = shared_feat.view(-1, 48 * 7 * 7)shared_code = self.shared_encoder_fc(shared_feat)result.append(shared_code)reversed_shared_code = ReverseLayerF.apply(shared_code, p)domain_label = self.shared_encoder_pred_domain(reversed_shared_code)result.append(domain_label)if mode == 'source':class_label = self.shared_encoder_pred_class(shared_code)result.append(class_label)# shared decoderif rec_scheme == 'share':union_code = shared_codeelif rec_scheme == 'all':union_code = private_code + shared_codeelif rec_scheme == 'private':union_code = private_coderec_vec = self.shared_decoder_fc(union_code)rec_vec = rec_vec.view(-1, 3, 14, 14)rec_code = self.shared_decoder_conv(rec_vec)result.append(rec_code)return result
 模型训练
import random
import os
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
from model_compat import DSN
from data_loader import GetLoader
from functions import SIMSE, DiffLoss, MSE
from test import test######################
# params             #
######################source_image_root = os.path.join('.', 'dataset', 'mnist')
target_image_root = os.path.join('.', 'dataset', 'mnist_m')
model_root = 'model'
cuda = True
cudnn.benchmark = True
lr = 1e-2
batch_size = 32
image_size = 28
n_epoch = 100
step_decay_weight = 0.95
lr_decay_step = 20000
active_domain_loss_step = 10000
weight_decay = 1e-6
alpha_weight = 0.01
beta_weight = 0.075
gamma_weight = 0.25
momentum = 0.9manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)#######################
# load data           #
#######################img_transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])dataset_source = datasets.MNIST(root=source_image_root,train=True,transform=img_transform
)dataloader_source = torch.utils.data.DataLoader(dataset=dataset_source,batch_size=batch_size,shuffle=True,num_workers=8
)train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')dataset_target = GetLoader(data_root=os.path.join(target_image_root, 'mnist_m_train'),data_list=train_list,transform=img_transform
)dataloader_target = torch.utils.data.DataLoader(dataset=dataset_target,batch_size=batch_size,shuffle=True,num_workers=8
)#####################
#  load model       #
#####################my_net = DSN()#####################
# setup optimizer   #
#####################def exp_lr_scheduler(optimizer, step, init_lr=lr, lr_decay_step=lr_decay_step, step_decay_weight=step_decay_weight):# Decay learning rate by a factor of step_decay_weight every lr_decay_stepcurrent_lr = init_lr * (step_decay_weight ** (step / lr_decay_step))if step % lr_decay_step == 0:print 'learning rate is set to %f' % current_lrfor param_group in optimizer.param_groups:param_group['lr'] = current_lrreturn optimizeroptimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)loss_classification = torch.nn.CrossEntropyLoss()
loss_recon1 = MSE()
loss_recon2 = SIMSE()
loss_diff = DiffLoss()
loss_similarity = torch.nn.CrossEntropyLoss()if cuda:my_net = my_net.cuda()loss_classification = loss_classification.cuda()loss_recon1 = loss_recon1.cuda()loss_recon2 = loss_recon2.cuda()loss_diff = loss_diff.cuda()loss_similarity = loss_similarity.cuda()for p in my_net.parameters():p.requires_grad = True#############################
# training network          #
#############################
 MNIST数据重建/共有部分特征/私有数据特征可视化

 MNIST_m数据重建/共有部分特征/私有数据特征可视化

代码获取

相关问题和项目开发,欢迎私信交流和沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/815827.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在Mac上更好的运行Windows,推荐这几款Mac虚拟机 mac运行windows虚拟机性能

想要在Mac OS上更好的运行Windows系统吗?推荐你使用mac虚拟机。虚拟机通过生成现有操作系统的全新虚拟镜像,它具有真实windows系统完全一样的功能,进入虚拟系统后,所有操作都是在这个全新的独立的虚拟系统里面进行,可以…

vue列表列表过滤

对已知的列表进行数据过滤(根据输入框里面的内容进行数据过滤) 编写案例 通过案例来演示说明 效果就是这样的 输入框是模糊查询 想要实现功能,其实就两大步,1获取输入框内容 2根据输入内容进行数据过滤 绑定收集数据 我们可以使用v-model去双向绑定 …

深入理解Cortex-M7 SVC和PendSV

1前言 1.1 PendSV 在ARM V7上,PendSV用来作为RTOS调度器的御用通道,上下文切换,任务调度都是在其ISR中实现的。所谓pend,字面意思即有悬起等待的意思,ARM官方也明确说明,PendSV应该在其他异常处理完毕后执…

python的算术运算符

python常用算术运算符代码如下: #算术运算符操作 x 10 y 20 z 30 #加法运算 a x y print("a的值为:", a) #减法运算 a x - y print("a的值为:", a) #乘法运算 a x*y print("a的值为:", a) …

计算机网络——ARP协议

前言 本博客是博主用于复习计算机网络的博客,如果疏忽出现错误,还望各位指正。 这篇博客是在B站掌芝士zzs这个UP主的视频的总结,讲的非常好。 可以先去看一篇视频,再来参考这篇笔记(或者说直接偷走)。 …

OpenCV4.9​​​​基本阈值操作

目标 在本教程中,您将学习如何: 使用 OpenCV 函数 cv::threshold 执行基本阈值操作 理论依据 注意 下面的解释属于 Bradski 和 Kaehler 的 Learning OpenCV 一书 阈值? 最简单的分割方法应用示例:分…

步骤大全:网站建设3个基本流程详解

一.领取一个免费域名和SSL证书,和CDN 1.打开网站链接:https://www.rainyun.com/z22_ 2.在网站主页上,您会看到一个"登陆/注册"的选项。 3.点击"登陆/注册",然后选择"微信登录"选项。 4.使用您的…

Claude3和GPT4哪个强?

在短短两个月内,全球最强人工智能的桂冠再次易主。此前,Claude3 Opus以其卓越的表现超越了GPT-4,吸引了无数用户抛弃GPT,转而拥抱Claude3。然而,OpenAI近日强势回归,用实力证明了GPT依然是人工智能领域的霸…

Jmeter杂记:测试计划参数详解

测试计划各参数详解 1,用户自定义变量,是全局变量,供所有线程组使用,可用配置元件:用户自定义变量替代 2,连续的运行线程组,默认不勾选,则随机的运行多个线程组中的取样器&#xff…

图机器学习NetworkX代码实战-创建图和可视化

完整代码见资源,下面列举了其中的几个图 安装networkX及相应工具包 pip install numpy pandas matplotlib tqdm networkx 当安装完成后,输入如下代码验证版本及是否安装成功 import networkx as nxnx.__version__ import matplotlib.pyplot as plt …

国内ai人工智能软件大全

很多人一直在寻找一个稳定且可靠的全球AI大模型测试平台,希望它不仅真实可信,而且能提供稳定、快速的服务,不会频繁出现故障或响应缓慢。迄今为止,我已经尝试了国内外至少10个不同的服务站点。不幸的是,这些站点总是存…

Linux 文件页反向映射

0. 引言 操作系统中与匿名页相对的是文件页,文件页的反向映射对比匿名页的反向映射更为简单。如果还不清楚匿名页反向映射逻辑的,请移步 匿名页反向映射 1. 文件页反向映射数据结构 struct file: 用户进程每open()一次文件,则会生…

分享一个 git stash 的实际使用场景。

当我将新的变更记录提交为 git commit --amend 后,发现这需要修改云端上的提交记录,也就是 vscode 中会出现这张图 于是,我通过 git reset head^ 撤销掉刚刚的提交。 reset 前: reset 后: 但在撤销的同时&#xf…

深入理解计算机网络分层结构

一、 为什么要分层? 计算机网络分层的主要目的是将复杂的网络通信过程分解为多个相互独立的层次,每个层次负责特定的功能。这样做有以下几个好处: 模块化设计:每个层次都有清晰定义的功能和接口,使得网络系统更易于设…

解决Xshell登录云服务器的免密码和云服务器生成子用户问题

Xshell登录云服务器的免密码问题 前言一、Xshell登录云服务器的免密码操作实践 二、centos创建用户创建用户实操删除用户更改用户密码直接删除子用户 前言 Xshell登录云服务器免密码问题的解决方案通常涉及使用SSH密钥对。用户生成一对密钥(公钥和私钥)…

Spring源码刨析之配置文件的解析和bean的创建以及生命周期

public void test1(){XmlBeanFactory xmlBeanFactory new XmlBeanFactory(new ClassPathResource("applicationContext.xml"));user u xmlBeanFactory.getBean("user",org.xhpcd.user.class);// System.out.println(u.getStu());}先介绍一个类XmlBeanFac…

Linux —— FTP服务【从0-1】

目录 一、介绍 1.概述 2.FTP的传输模式 PORT 主动模式 PASV 被动模式 3.FTP服务的作用 二、搭建FTP服务器 FTP服务端配置 1.安装vsftpd文件服务 2.启动服务 3.防火墙配置 4.FTP服务相关文件说明 FTP客户端配置 1.安装FTP客户端工具 lftp 2.访问FTP服务器 Linux系…

深度学习图像处理基础工具——opencv 实战2 文档扫描OCR

输入一个文档,怎么进行文档扫描,输出扫描后的图片呢? 今天学习了 opencv实战项目 文档扫描OCR 问题重构:输入图像 是一个含有文档的图像——> 目标是将其转化为 规则的扫描图片 那么怎么实现呢? 问题分解&#…

Python 复杂密码图形化生成工具,支持选择生成10位和12位复杂密码(初版)

代码 #!/usr/bin/env python # -*- coding: utf-8 -*- # Time : 2024/3/26 15:22 # Author : wyq # File : 部署测试.py import random import string from tkinter import *def generate_password(length):characters string.ascii_letters string.digits string.p…

【vue】Vue3开发中常用的VSCode插件

Vue - Official:vue的语法特性,如代码高亮,自动补全等 Vue VSCode Snippets:自定义一些代码片段 v3单文件组件vdata数据vmethod方法 别名路径跳转 参考 https://www.bilibili.com/video/BV1nV411Q7RX