隐私计算实训营第二期第十课:基于SPU机器学习建模实践

隐私计算实训营第二期-第十课

  • 第十课:基于SPU机器学习建模实践
    • 1 隐私保护机器学习背景
      • 1.1 机器学习中隐私保护的需求
      • 1.2 PPML提供的技术解决方案
    • 2 SPU架构
      • 2.1 SPU前端
      • 2.2 SPU编译器
      • 2.3 SPU运行时
      • 2.4 SPU目标
    • 3 密态训练与推理
      • 3.1 四个基本问题
      • 3.2 解决数据来源问题
      • 3.3 解决数据安全问题
      • 3.4 解决模型计算问题
      • 3.5 解决密态计算问题
      • 3.6 如何应对更复杂的模型
      • 3.7 已有模型的复用
    • 4 作业实践
      • 4.1 基础NN模型作业
      • 4.2 进阶Transformer模型作业

第十课:基于SPU机器学习建模实践

首先必须感谢蚂蚁集团及隐语社区带来的隐私计算实训第二期的学习机会!
本节课由蚂蚁隐私计算部算法工程师吴豪奇老师讲解。

在这里插入图片描述
本节课主要内容为:

  • 隐私保护机器学习背景
  • SPU架构简介
  • NN密态训练/推理示例

1 隐私保护机器学习背景

1.1 机器学习中隐私保护的需求

本节课前两个小节的内容,我们这之前的课程中已有一些了解,
本节课可以回顾一下。
数据和模型的隐私保护需求是产生隐私保护机器学习的根因。

在这里插入图片描述

1.2 PPML提供的技术解决方案

MPC提供了隐私保护的技术解决方案。

在这里插入图片描述

使用MPC结合机器学习,为模型训练和推理提供隐私保护。
问题
我们是否可以直接以 MPC 的方式高效地运行已有的机器学习程序?

在这里插入图片描述

2 SPU架构

SPU架构我们在之前已经学习过,宏观上主要分为三部分:

  1. 前端部分
  2. 编译器
  3. 运行时

2.1 SPU前端

SPU前端尽量支持原生的AI编程方式,支持JAX、TensorFlow,Pytorch
等典型的AI编程框架。

在这里插入图片描述

2.2 SPU编译器

SPU的编译器以优化方式生成SPU的密态中间语言。

SPu

2.3 SPU运行时

SPU的运行时支持多种并行模式(数据并行+指令并行),多种MPC协议
以及多种部署模式。

在这里插入图片描述

2.4 SPU目标

SPU的最终目标是实现易用、可扩展和高性能的密态计算虚拟设备。

在这里插入图片描述

3 密态训练与推理

3.1 四个基本问题

密态的训练和推理需要解决的四个问题:

  • 数据从哪来?
  • 如何加密保护数据?
  • 如何定义模型计算?
  • 如何执行密态模型计算?

在这里插入图片描述

3.2 解决数据来源问题

数据由数据个参与方以密态的形式提供。

在这里插入图片描述

3.3 解决数据安全问题

数据安全通过MPC协议或者同态加密等外部模式解决。

在这里插入图片描述

3.4 解决模型计算问题

NN模型的计算问题通过JAX实现前向和反向传播。

在这里插入图片描述

3.5 解决密态计算问题

NN模型的密态计算SPU的编译器转换为密态算子,然后按照MPC协议
进行计算。

在这里插入图片描述

密态的计算过程与明文类似,通过SPU密态计算配置实现密态训练。
在这里插入图片描述

3.6 如何应对更复杂的模型

对于复杂模型,使用stax和flax来进行实现。

在这里插入图片描述

3.7 已有模型的复用

已有模型的复用问题,根据明文实现来进行密态计算的迁移。
比如,明文实现的GPT2模型。
在这里插入图片描述

然后进行密态迁移:

在这里插入图片描述

在支持不同的模型方面,SPU还需要更新和优化自己的实现以满足不同
模型的需求。

在这里插入图片描述

4 作业实践

4.1 基础NN模型作业

本次课程有两个作业,一个是基础的NN模型。另一个是进阶的Transformer
模型。

在这里插入图片描述

完成步骤如下:

1、加载数据集

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizerdef breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):x, y = load_breast_cancer(return_X_y=True)x = (x - np.min(x)) / (np.max(x) - np.min(x))x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)if train:if party_id:if party_id == 1:return x_train[:, :15], _else:return x_train[:, 15:], y_trainelse:return x_train, y_trainelse:return x_test, y_test

2、定义模型

from typing import Sequence
import flax.linen as nnFEATURES = [30, 15, 8, 1]class MLP(nn.Module):features: Sequence[int]@nn.compactdef __call__(self, x):for feat in self.features[:-1]:x = nn.relu(nn.Dense(feat)(x))x = nn.Dense(self.features[-1])(x)return x

3、定义训练参数

import jax.numpy as jnpdef predict(params, x):# TODO(junfeng): investigate why need to have a duplicated definition in notebook,# which is not the case in a normal python program.from typing import Sequenceimport flax.linen as nnFEATURES = [30, 15, 8, 1]class MLP(nn.Module):features: Sequence[int]@nn.compactdef __call__(self, x):for feat in self.features[:-1]:x = nn.relu(nn.Dense(feat)(x))x = nn.Dense(self.features[-1])(x)return xreturn MLP(FEATURES).apply(params, x)def loss_func(params, x, y):pred = predict(params, x)def mse(y, pred):def squared_error(y, y_pred):return jnp.multiply(y - y_pred, y - y_pred) / 2.0return jnp.mean(squared_error(y, pred))return mse(y, pred)def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):x = jnp.concatenate((x1, x2), axis=1)xs = jnp.array_split(x, len(x) / n_batch, axis=0)ys = jnp.array_split(y, len(y) / n_batch, axis=0)def body_fun(_, loop_carry):params = loop_carryfor x, y in zip(xs, ys):_, grads = jax.value_and_grad(loss_func)(params, x, y)params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)return paramsparams = jax.lax.fori_loop(0, n_epochs, body_fun, params)return paramsdef model_init(n_batch=10):model = MLP(FEATURES)return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))

4、验证参数

from sklearn.metrics import roc_auc_score
def validate_model(params, X_test, y_test):y_pred = predict(params, X_test)return roc_auc_score(y_test, y_pred)

5、开始明文训练

import jax# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)# Hyperparameter
n_batch = 10
n_epochs = 10
step_size = 0.01# Train the model
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)# Test the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

这里输出的明文训练结果为:

在这里插入图片描述

6、开始密文训练

import secretflow as sf# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))# In case you have a running secretflow runtime already.
sf.shutdown()sf.init(['alice', 'bob'], address='local')alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

7、检查参数

params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params)
params = sf.reveal(params_spu)
print(params)

8、输出训练结果

X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

密文训练输出结果为:

在这里插入图片描述
可以看出,密文训练和明文训练的效果相同,本作业结束,

4.2 进阶Transformer模型作业

完成步骤如下:
1、安装Transformer模型

import sys
!{sys.executable} -m pip install transformers[flax] -i https://pypi.tuna.tsinghua.edu.cn/simple

2、设置镜像huggingface

import os
import sys
!{sys.executable} -m pip install huggingface_hub
os.environ['HF_ENDPOINT']='https://hf-mirror.com'

3、加载模型

from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

4、定义文本生成函数

def text_generation(input_ids, params):config = GPT2Config()model = FlaxGPT2LMHeadModel(config=config)for _ in range(10):outputs = model(input_ids=input_ids, params=params)next_token_logits = outputs[0][0, -1, :]next_token = jnp.argmax(next_token_logits)input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)return input_ids

5、进行明文的文本生成

import jax.numpy as jnpinputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

生成的输出结果为:

在这里插入图片描述
6、进行密文训练

import secretflow as sf# In case you have a running secretflow runtime already.
sf.shutdown()sf.init(['alice', 'bob', 'carol'], address='local')alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['fxp_exp_mode'] = 1
conf['runtime_config']['experimental_disable_mmul_split'] = True
spu = sf.SPU(conf)def get_model_params():pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")return pretrained_model.paramsdef get_token_ids():tokenizer = AutoTokenizer.from_pretrained("gpt2")return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)output_token_ids = spu(text_generation)(input_token_ids_, model_params_)

这里由于机器配置不够,内存不足,被系统kill进程,导致无法完成训练。小伙伴们机器好的应该可以跑完。

在这里插入图片描述

7、输出密文训练结果

outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

至此,本次作业全部结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全新升级!中央集中式架构功能测试为新车型保驾护航

“软件定义汽车”新时代下,整车电气电气架构向中央-区域集中式发展已成为行业共识,车型架构的变革带来更复杂的整车功能定义、更多的新技术的应用(如SOA服务化、智能配电等)和更短的车型研发周期,对整车和新产品研发的…

OkHttp的源码解读1

介绍 OkHttp 是 Square 公司开源的一款高效的 HTTP 客户端,用于与服务器进行 HTTP 请求和响应。它具有高效的连接池、透明的 GZIP 压缩和响应缓存等功能,是 Android 开发中广泛使用的网络库。 本文将详细解读 OkHttp 的源码,包括其主要组件…

Qt实现手动切换多种布局

引言 之前写了一个手动切换多个布局的程序,下面来记录一下。 程序运行效果如下: 示例 需求 通过点击程序界面上不同的布局按钮,使主工作区呈现出不同的页面布局,多个布局之间可以通过点击不同布局按钮切换。支持的最多的窗口…

burpsuite 设置监听窗口 火狐利用插件快速切换代理状态

一、修改burpsuite监听端口 1、首先打开burpsuite,点击Proxy下的Options选项: 2、可以看到默认的监听端口为8080,首先选中我们想要修改的监听,点击Edit进行编辑 3、将端口改为9876,并保存 4、可以看到监听端口修改成功…

typescript学习回顾(五)

今天来分享一下ts的泛型,最后来做一个练习 泛型 有时候,我们在书写某些函数的时候,会丢失一些类型信息,比如我下面有一个例子,我想提取一个数组的某个索引之前的所有数据 function getArraySomeData(newArr, n:numb…

JVM原理(十):JVM虚拟机调优分析与实战

1. 大内存硬件上的程序部署策略 这是笔者很久之前处理过的一个案例,但今天仍然具有代表性。一个15万PV/日左右的在线文档类型网站最近更换了硬件系统,服务器的硬件为四路志强处理器、16GB物理内存,操作系统为64位CentOS5.4,Resin…

阿里云centos 取消硬盘挂载并重建数据盘信息再次挂载

一、取消挂载 umount [挂载点或设备] 如果要取消挂载/dev/sdb1分区,可以使用以下命令: umount /dev/sdb1 如果要取消挂载在/mnt/mydisk的挂载点,可以使用以下命令: umount /mnt/mydisk 如果设备正忙,无法立即取消…

系统安全及应用(命令)

目录 一、账号安全控制 1.1 系统账号清理 1.2 密码安全控制 1.3 历史记录控制 1.4 终端自动注销 二、系统引导和登陆控制 2.1 限制su命令用户 2.2 PAM安全认证 示例一:通过pam 模块来防止暴力破解ssh 2.3 sudo机制提升权限 2.3.1 sudo命令(ro…

Java的日期类常用方法

Java_Date 第一代日期类 获取当前时间 Date date new Date(); System.out.printf("当前时间" date); 格式化时间信息 SimpleDateFormat simpleDateFormat new SimpleDateFormat("yyyy-mm-dd hh:mm:ss E); System.out.printf("格式化后时间" si…

【windows|012】光猫、路由器、交换机详解

🍁博主简介: 🏅云计算领域优质创作者 🏅2022年CSDN新星计划python赛道第一名 🏅2022年CSDN原力计划优质作者 ​ 🏅阿里云ACE认证高级工程师 ​ 🏅阿里云开发者社区专家博主 💊交流社…

windows USB 驱动开发-URB结构

通用串行总线 (USB) 客户端驱动程序无法直接与其设备通信。 相反,客户端驱动程序会创建请求并将其提交到 USB 驱动程序堆栈进行处理。 在每个请求中,客户端驱动程序提供一个可变长度的数据结构,称为 USB 请求块 (URB) ,URB 结构描…

ctfshow-web入门-命令执行(web75-web77)

目录 1、web75 2、web76 3、web77 1、web75 使用 glob 协议绕过 open_basedir&#xff0c;读取根目录下的文件&#xff0c;payload&#xff1a; c?><?php $anew DirectoryIterator("glob:///*"); foreach($a as $f) {echo($f->__toString(). ); } ex…

读书笔记-Java并发编程的艺术-第3章(Java内存模型)-第9节(Java内存模型综述)

3.9 Java内存模型综述 前面对Java内存模型的基础知识和内存模型的具体实现进行了说明。下面对Java内存模型的相关知识做一个总结。 3.9.1 处理器的内存模型 顺序一致性内存模型是一个理论参考模型&#xff0c;JMM和处理器内存模型在设计时通常会以顺序一致性内存模型为参照。…

C#/WPF 自制白板工具

随着电子屏幕技术的发展&#xff0c;普通的黑板已不再适用现在的教学和演示环境&#xff0c;电子白板应运而生。本篇使用WPF开发了一个电子白板工具&#xff0c;功能丰富&#xff0c;非常使用日常免费使用&#xff0c;或者进行再次开发。 示例代码如下&#xff1a; Stack<St…

拓扑排序[讲课留档]

拓扑排序 拓扑排序要解决的问题是给一个有向无环图的所有节点排序。 即在 A O E AOE AOE网中找关键路径。 前置芝士&#xff01; 有向图&#xff1a;有向图中的每一个边都是有向边&#xff0c;即其中的每一个元素都是有序二元组。在一条有向边 ( u , v ) (u,v) (u,v)中&…

ChatGPT 官方发布桌面端,向所有用户免费开放

Open AI 官方已经发布了适用于 macOS 的 ChatGPT 桌面端应用。 此前&#xff0c;该应用一直处于测试阶段&#xff0c;仅 Plus 付费订阅用户可以使用。 目前已面向所有用户开放&#xff0c;所有 Mac 用户均可免费下载使用。 我们可以访问官网下载安装包&#xff1a;https://op…

2024 年江西省研究生数学建模竞赛题目 B题投标中的竞争策略问题--完整思路、代码结果分享(仅供学习)

招投标问题是企业运营过程中必须面对的基本问题之一。现有的招投标平台有国家级的&#xff0c;也有地方性的。在招投标过程中&#xff0c;企业需要全面了解招标公告中的相关信息&#xff0c;在遵守招投标各种规范和制度的基础上&#xff0c;选择有效的竞争策略和技巧&#xff0…

基于JSP技术的校园餐厅管理系统

开头语&#xff1a; 你好呀&#xff0c;我是计算机学长猫哥&#xff01;如果您对校园餐厅管理系统感兴趣或有相关需求&#xff0c;欢迎随时联系我。我的联系方式在文末&#xff0c;期待与您交流&#xff01; 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#x…

QT的编译过程

qmake -project 用于从源代码生成项目文件&#xff0c;qmake 用于从项目文件生成 Makefile&#xff0c;而 make 用于根据 Makefile 构建项目。 详细解释&#xff1a; qmake -project 这个命令用于从源代码目录生成一个初始的 Qt 项目文件&#xff08;.pro 文件&#xff09;。它…

Keil5中:出现:failed to execute ‘...\ARMCC\bin\ArmCC‘

点三个点&#xff0c;去自己的磁盘找自己的ARM\ARMCC\bin