隐私计算实训营第二期第十课:基于SPU机器学习建模实践

隐私计算实训营第二期-第十课

  • 第十课:基于SPU机器学习建模实践
    • 1 隐私保护机器学习背景
      • 1.1 机器学习中隐私保护的需求
      • 1.2 PPML提供的技术解决方案
    • 2 SPU架构
      • 2.1 SPU前端
      • 2.2 SPU编译器
      • 2.3 SPU运行时
      • 2.4 SPU目标
    • 3 密态训练与推理
      • 3.1 四个基本问题
      • 3.2 解决数据来源问题
      • 3.3 解决数据安全问题
      • 3.4 解决模型计算问题
      • 3.5 解决密态计算问题
      • 3.6 如何应对更复杂的模型
      • 3.7 已有模型的复用
    • 4 作业实践
      • 4.1 基础NN模型作业
      • 4.2 进阶Transformer模型作业

第十课:基于SPU机器学习建模实践

首先必须感谢蚂蚁集团及隐语社区带来的隐私计算实训第二期的学习机会!
本节课由蚂蚁隐私计算部算法工程师吴豪奇老师讲解。

在这里插入图片描述
本节课主要内容为:

  • 隐私保护机器学习背景
  • SPU架构简介
  • NN密态训练/推理示例

1 隐私保护机器学习背景

1.1 机器学习中隐私保护的需求

本节课前两个小节的内容,我们这之前的课程中已有一些了解,
本节课可以回顾一下。
数据和模型的隐私保护需求是产生隐私保护机器学习的根因。

在这里插入图片描述

1.2 PPML提供的技术解决方案

MPC提供了隐私保护的技术解决方案。

在这里插入图片描述

使用MPC结合机器学习,为模型训练和推理提供隐私保护。
问题
我们是否可以直接以 MPC 的方式高效地运行已有的机器学习程序?

在这里插入图片描述

2 SPU架构

SPU架构我们在之前已经学习过,宏观上主要分为三部分:

  1. 前端部分
  2. 编译器
  3. 运行时

2.1 SPU前端

SPU前端尽量支持原生的AI编程方式,支持JAX、TensorFlow,Pytorch
等典型的AI编程框架。

在这里插入图片描述

2.2 SPU编译器

SPU的编译器以优化方式生成SPU的密态中间语言。

SPu

2.3 SPU运行时

SPU的运行时支持多种并行模式(数据并行+指令并行),多种MPC协议
以及多种部署模式。

在这里插入图片描述

2.4 SPU目标

SPU的最终目标是实现易用、可扩展和高性能的密态计算虚拟设备。

在这里插入图片描述

3 密态训练与推理

3.1 四个基本问题

密态的训练和推理需要解决的四个问题:

  • 数据从哪来?
  • 如何加密保护数据?
  • 如何定义模型计算?
  • 如何执行密态模型计算?

在这里插入图片描述

3.2 解决数据来源问题

数据由数据个参与方以密态的形式提供。

在这里插入图片描述

3.3 解决数据安全问题

数据安全通过MPC协议或者同态加密等外部模式解决。

在这里插入图片描述

3.4 解决模型计算问题

NN模型的计算问题通过JAX实现前向和反向传播。

在这里插入图片描述

3.5 解决密态计算问题

NN模型的密态计算SPU的编译器转换为密态算子,然后按照MPC协议
进行计算。

在这里插入图片描述

密态的计算过程与明文类似,通过SPU密态计算配置实现密态训练。
在这里插入图片描述

3.6 如何应对更复杂的模型

对于复杂模型,使用stax和flax来进行实现。

在这里插入图片描述

3.7 已有模型的复用

已有模型的复用问题,根据明文实现来进行密态计算的迁移。
比如,明文实现的GPT2模型。
在这里插入图片描述

然后进行密态迁移:

在这里插入图片描述

在支持不同的模型方面,SPU还需要更新和优化自己的实现以满足不同
模型的需求。

在这里插入图片描述

4 作业实践

4.1 基础NN模型作业

本次课程有两个作业,一个是基础的NN模型。另一个是进阶的Transformer
模型。

在这里插入图片描述

完成步骤如下:

1、加载数据集

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizerdef breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):x, y = load_breast_cancer(return_X_y=True)x = (x - np.min(x)) / (np.max(x) - np.min(x))x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)if train:if party_id:if party_id == 1:return x_train[:, :15], _else:return x_train[:, 15:], y_trainelse:return x_train, y_trainelse:return x_test, y_test

2、定义模型

from typing import Sequence
import flax.linen as nnFEATURES = [30, 15, 8, 1]class MLP(nn.Module):features: Sequence[int]@nn.compactdef __call__(self, x):for feat in self.features[:-1]:x = nn.relu(nn.Dense(feat)(x))x = nn.Dense(self.features[-1])(x)return x

3、定义训练参数

import jax.numpy as jnpdef predict(params, x):# TODO(junfeng): investigate why need to have a duplicated definition in notebook,# which is not the case in a normal python program.from typing import Sequenceimport flax.linen as nnFEATURES = [30, 15, 8, 1]class MLP(nn.Module):features: Sequence[int]@nn.compactdef __call__(self, x):for feat in self.features[:-1]:x = nn.relu(nn.Dense(feat)(x))x = nn.Dense(self.features[-1])(x)return xreturn MLP(FEATURES).apply(params, x)def loss_func(params, x, y):pred = predict(params, x)def mse(y, pred):def squared_error(y, y_pred):return jnp.multiply(y - y_pred, y - y_pred) / 2.0return jnp.mean(squared_error(y, pred))return mse(y, pred)def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):x = jnp.concatenate((x1, x2), axis=1)xs = jnp.array_split(x, len(x) / n_batch, axis=0)ys = jnp.array_split(y, len(y) / n_batch, axis=0)def body_fun(_, loop_carry):params = loop_carryfor x, y in zip(xs, ys):_, grads = jax.value_and_grad(loss_func)(params, x, y)params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)return paramsparams = jax.lax.fori_loop(0, n_epochs, body_fun, params)return paramsdef model_init(n_batch=10):model = MLP(FEATURES)return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))

4、验证参数

from sklearn.metrics import roc_auc_score
def validate_model(params, X_test, y_test):y_pred = predict(params, X_test)return roc_auc_score(y_test, y_pred)

5、开始明文训练

import jax# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)# Hyperparameter
n_batch = 10
n_epochs = 10
step_size = 0.01# Train the model
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)# Test the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

这里输出的明文训练结果为:

在这里插入图片描述

6、开始密文训练

import secretflow as sf# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))# In case you have a running secretflow runtime already.
sf.shutdown()sf.init(['alice', 'bob'], address='local')alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

7、检查参数

params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params)
params = sf.reveal(params_spu)
print(params)

8、输出训练结果

X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

密文训练输出结果为:

在这里插入图片描述
可以看出,密文训练和明文训练的效果相同,本作业结束,

4.2 进阶Transformer模型作业

完成步骤如下:
1、安装Transformer模型

import sys
!{sys.executable} -m pip install transformers[flax] -i https://pypi.tuna.tsinghua.edu.cn/simple

2、设置镜像huggingface

import os
import sys
!{sys.executable} -m pip install huggingface_hub
os.environ['HF_ENDPOINT']='https://hf-mirror.com'

3、加载模型

from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

4、定义文本生成函数

def text_generation(input_ids, params):config = GPT2Config()model = FlaxGPT2LMHeadModel(config=config)for _ in range(10):outputs = model(input_ids=input_ids, params=params)next_token_logits = outputs[0][0, -1, :]next_token = jnp.argmax(next_token_logits)input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)return input_ids

5、进行明文的文本生成

import jax.numpy as jnpinputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

生成的输出结果为:

在这里插入图片描述
6、进行密文训练

import secretflow as sf# In case you have a running secretflow runtime already.
sf.shutdown()sf.init(['alice', 'bob', 'carol'], address='local')alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['fxp_exp_mode'] = 1
conf['runtime_config']['experimental_disable_mmul_split'] = True
spu = sf.SPU(conf)def get_model_params():pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")return pretrained_model.paramsdef get_token_ids():tokenizer = AutoTokenizer.from_pretrained("gpt2")return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)output_token_ids = spu(text_generation)(input_token_ids_, model_params_)

这里由于机器配置不够,内存不足,被系统kill进程,导致无法完成训练。小伙伴们机器好的应该可以跑完。

在这里插入图片描述

7、输出密文训练结果

outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

至此,本次作业全部结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全新升级!中央集中式架构功能测试为新车型保驾护航

“软件定义汽车”新时代下,整车电气电气架构向中央-区域集中式发展已成为行业共识,车型架构的变革带来更复杂的整车功能定义、更多的新技术的应用(如SOA服务化、智能配电等)和更短的车型研发周期,对整车和新产品研发的…

OkHttp的源码解读1

介绍 OkHttp 是 Square 公司开源的一款高效的 HTTP 客户端,用于与服务器进行 HTTP 请求和响应。它具有高效的连接池、透明的 GZIP 压缩和响应缓存等功能,是 Android 开发中广泛使用的网络库。 本文将详细解读 OkHttp 的源码,包括其主要组件…

Qt实现手动切换多种布局

引言 之前写了一个手动切换多个布局的程序,下面来记录一下。 程序运行效果如下: 示例 需求 通过点击程序界面上不同的布局按钮,使主工作区呈现出不同的页面布局,多个布局之间可以通过点击不同布局按钮切换。支持的最多的窗口…

如何使用 AppML

如何使用 AppML AppML(Application Markup Language)是一种轻量级的标记语言,旨在简化Web应用的创建和部署过程。它允许开发者通过XML或JSON格式的配置文件来定义应用的结构和行为,从而实现快速开发和灵活扩展。AppML特别适用于构建数据驱动的企业级应用,它可以与各种后端…

pytorch跑手写体实验

目录 1、环境条件 2、代码实现 3、总结 1、环境条件 pycharm编译器pytorch依赖matplotlib依赖numpy依赖等等 2、代码实现 import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms import matpl…

burpsuite 设置监听窗口 火狐利用插件快速切换代理状态

一、修改burpsuite监听端口 1、首先打开burpsuite,点击Proxy下的Options选项: 2、可以看到默认的监听端口为8080,首先选中我们想要修改的监听,点击Edit进行编辑 3、将端口改为9876,并保存 4、可以看到监听端口修改成功…

typescript学习回顾(五)

今天来分享一下ts的泛型,最后来做一个练习 泛型 有时候,我们在书写某些函数的时候,会丢失一些类型信息,比如我下面有一个例子,我想提取一个数组的某个索引之前的所有数据 function getArraySomeData(newArr, n:numb…

JVM原理(十):JVM虚拟机调优分析与实战

1. 大内存硬件上的程序部署策略 这是笔者很久之前处理过的一个案例,但今天仍然具有代表性。一个15万PV/日左右的在线文档类型网站最近更换了硬件系统,服务器的硬件为四路志强处理器、16GB物理内存,操作系统为64位CentOS5.4,Resin…

js数组方法归纳——concat、join、reverse

1、concat( ) 用途:可以连接两个或多个数组,并将新的数组返回该方法不会对原数组产生影响 var arr ["孙悟空","猪八戒","沙和尚"];var arr2 ["白骨精","玉兔精","蜘蛛精"];var arr3 [&…

Vue Router的深度解析

引言 在现代Web应用开发中,客户端路由已成为实现流畅用户体验的关键技术。与传统的服务器端路由不同,客户端路由通过JavaScript在浏览器中控制页面内容的更新,避免了页面的全量刷新。Vue Router作为Vue.js官方的路由解决方案,以其…

阿里云centos 取消硬盘挂载并重建数据盘信息再次挂载

一、取消挂载 umount [挂载点或设备] 如果要取消挂载/dev/sdb1分区,可以使用以下命令: umount /dev/sdb1 如果要取消挂载在/mnt/mydisk的挂载点,可以使用以下命令: umount /mnt/mydisk 如果设备正忙,无法立即取消…

【Spring Boot】简单了解spring boot支持的三种服务器

Tomcat 概述:Tomcat 是 Apache 软件基金会(Apache Software Foundation)的 Jakarta EE 项目中的一个核心项目,由 Apache、Sun 和其他一些公司及个人共同开发而成。它作为 Java Servlet、JSP、JavaServer Pages Expression Languag…

系统安全及应用(命令)

目录 一、账号安全控制 1.1 系统账号清理 1.2 密码安全控制 1.3 历史记录控制 1.4 终端自动注销 二、系统引导和登陆控制 2.1 限制su命令用户 2.2 PAM安全认证 示例一:通过pam 模块来防止暴力破解ssh 2.3 sudo机制提升权限 2.3.1 sudo命令(ro…

Java的日期类常用方法

Java_Date 第一代日期类 获取当前时间 Date date new Date(); System.out.printf("当前时间" date); 格式化时间信息 SimpleDateFormat simpleDateFormat new SimpleDateFormat("yyyy-mm-dd hh:mm:ss E); System.out.printf("格式化后时间" si…

【windows|012】光猫、路由器、交换机详解

🍁博主简介: 🏅云计算领域优质创作者 🏅2022年CSDN新星计划python赛道第一名 🏅2022年CSDN原力计划优质作者 ​ 🏅阿里云ACE认证高级工程师 ​ 🏅阿里云开发者社区专家博主 💊交流社…

windows USB 驱动开发-URB结构

通用串行总线 (USB) 客户端驱动程序无法直接与其设备通信。 相反,客户端驱动程序会创建请求并将其提交到 USB 驱动程序堆栈进行处理。 在每个请求中,客户端驱动程序提供一个可变长度的数据结构,称为 USB 请求块 (URB) ,URB 结构描…

ctfshow-web入门-命令执行(web75-web77)

目录 1、web75 2、web76 3、web77 1、web75 使用 glob 协议绕过 open_basedir&#xff0c;读取根目录下的文件&#xff0c;payload&#xff1a; c?><?php $anew DirectoryIterator("glob:///*"); foreach($a as $f) {echo($f->__toString(). ); } ex…

读书笔记-Java并发编程的艺术-第3章(Java内存模型)-第9节(Java内存模型综述)

3.9 Java内存模型综述 前面对Java内存模型的基础知识和内存模型的具体实现进行了说明。下面对Java内存模型的相关知识做一个总结。 3.9.1 处理器的内存模型 顺序一致性内存模型是一个理论参考模型&#xff0c;JMM和处理器内存模型在设计时通常会以顺序一致性内存模型为参照。…

ORB-SLAM2 安装编译运行(非 ROS)

安装编译 必备安装工具 主要包括 cmake 、 git 、 gcc 、 g gcc 的全称是 GNU Compiler Collection&#xff0c;它是由 GNU 推出的一款功能强大的、性能优越的 多平台编译器&#xff0c;是一个能够编译多种语言的编译器。最开始 gcc 是作为 C 语言的编译器&#xff08;GNU …

如何将等保2.0的要求融入日常安全运维实践中?

等保2.0的基本要求 等保2.0是中国网络安全领域的基本国策和基本制度&#xff0c;它要求网络运营商按照网络安全等级保护制度的要求&#xff0c;履行相关的安全保护义务。等保2.0的实施得到了《中华人民共和国网络安全法》等法律法规的支持&#xff0c;要求相关行业和单位必须按…