LoRA低秩自适应微调技术原理及实践

大型语言模型的低秩自适应 (LoRA) 用于解决微调大型语言模型 (LLM) 的挑战。GPT 和 Llama 等模型拥有数十亿个参数,通常对于特定任务或领域进行微调的成本过高。LoRA 保留了预训练的模型权重,并在每个模型块中加入了可训练层。这显著减少了需要微调的参数数量,并大大降低了 GPU 内存需求。LoRA 的主要优势在于,它大幅减少了可训练参数的数量——有时最多可减少 10,000 倍——从而大大降低了 GPU 资源需求。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

1、LoRA 为何有效

预训练的 LLM 在适应新任务时具有较低的“固有维度”,这意味着数据可以通过低维空间有效地表示或近似,同时保留其大部分基本信息或结构。我们可以将适应任务的新权重矩阵分解为低维(较小)矩阵,而不会丢失大量重要信息。我们通过低秩近似实现这一点。

矩阵的秩是一个可以让你了解矩阵复杂度的值。矩阵的低秩近似旨在尽可能接近原始矩阵,但秩较低。低秩矩阵降低了计算复杂度,从而提高了矩阵乘法的效率。低秩分解是指通过推导矩阵 A 的低秩近似来有效近似矩阵 A 的过程。奇异值分解 (SVD) 是一种常用的低秩分解方法。

假设 W 表示给定神经网络层中的权重矩阵,假设 ΔW 是经过完全微调后 W 的权重更新。然后,我们可以将权重更新矩阵 ΔW 分解为两个较小的矩阵:ΔW = WA*WB,其中 WA 是 A × r 维矩阵,WB 是 r × B 维矩阵。在这里,我们保持原始权重 W 不变,只训练新矩阵 WA 和 WB。这总结了 LoRA 方法,如下图所示。

LoRA 的优势如下:

  • 减少资源消耗。对深度学习模型进行微调通常需要大量计算资源,这可能既昂贵又耗时。LoRA 可在保持高性能的同时减少对资源的需求。
  • 更快的迭代。LoRA 可实现快速迭代,从而更轻松地尝试不同的微调任务并快速调整模型。
  • 改进迁移学习。LoRA 提高了迁移学习的有效性,因为带有 LoRA 适配器的模型可以用更少的数据进行微调。这在标记数据稀缺的情况下尤其有价值。
  • 广泛适用。LoRA 用途广泛,可应用于自然语言处理、计算机视觉和语音识别等不同领域。
  • 降低碳排放。通过减少计算要求,LoRA 有助于实现更环保、更可持续的深度学习方法。

2、使用 LoRA 技术训练神经网络

在此博客中,我们利用 CIFAR-10 数据集,使用几个epoch从头开始训练基本图像分类器。之后,我们进一步使用 LoRA 训练模型,说明将 LoRA 纳入训练过程的优势。

2.1 设置

此演示使用以下设置创建。有关全面的支持详细信息,请参阅 ROCm 文档。

硬件和操作系统:

  • AMD Instinct GPU
  • Ubuntu 22.04.3 LTS

软件:

  • ROCm 5.7.0+
  • Pytorch 2.0+

2.2 训练初始模型

导入软件包:

import torch
import torchvision
import torchvision.transforms as transforms

加载数据集并设置目标设备:

# 10 classes from CIFAR10 dataset
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# batch size
batch_size = 8# image preprocessing
preprocessor = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# training dataset
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True,download=True, transform=preprocessor)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,shuffle=True, num_workers=8)
# test dataset
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False,download=True, transform=preprocessor)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,shuffle=False, num_workers=8)# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

展示数据集的样本:

import matplotlib.pyplot as plt
import numpy as np# helper function to display image
def image_display(images):# get the original imageimages = images * 0.5 + 0.5plt.imshow(np.transpose(images.numpy(), (1, 2, 0)))plt.axis('off')plt.show()# get a batch of images
images, labels = next(iter(train_loader))
# display images
image_display(torchvision.utils.make_grid(images))
# show ground truth labels
print('Ground truth labels: ', ' '.join(f'{classes[labels[j]]}' for j in range(images.shape[0])))

输出显示:

Ground truth labels:  cat ship ship airplane frog frog automobile frog

创建一个用于图像分类的基本三层神经网络,注重简单性以说明 LoRA 效果:

import torch.nn as nn
import torch.nn.functional as Fclass net(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(3*32*32, 4096)self.fc2 = nn.Linear(4096, 2048)self.fc3 = nn.Linear(2048, 10)def forward(self, x):x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# move the model to device
classifier = net().to(device)

接下来训练模型。

我们使用交叉熵损失和 Adam 作为损失函数和优化器。

import torch.optim as optimdef train(train_loader, classifier, start_epoch = 0, epochs=1, device="cuda:0"):classifier = classifier.to(device)classifier.train()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(classifier.parameters(), lr=0.001)for epoch in range(epochs):  # training looploss_log = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = data[0].to(device), data[1].to(device)# Resets the parameter gradientsoptimizer.zero_grad()outputs = classifier(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print loss after every 1000 mini-batchesloss_log += loss.item()if i % 2000 == 1999:    print(f'[{start_epoch + epoch}, {i+1:5d}] loss: {loss_log / 2000:.3f}')loss_log = 0.0

现在开始训练:

import timestart_epoch = 0
epochs = 1
# warm up the gpu with one epoch
train(train_loader, classifier, start_epoch=start_epoch, epochs=epochs, device=device)# run another epoch to record the time
start_epoch += epochs
epochs = 1
start = time.time()
train(train_loader, classifier, start_epoch=start_epoch, epochs=epochs, device=device)
torch.cuda.synchronize()
end = time.time()
train_time = (end - start)print(f"One epoch takes {train_time:.3f} seconds")

输出如下:

    [0,  2000] loss: 1.987[0,  4000] loss: 1.906[0,  6000] loss: 1.843[1,  2000] loss: 1.807[1,  4000] loss: 1.802[1,  6000] loss: 1.782One epoch takes 31.896 seconds

一个 epoch 大约需要 31 秒。

保存模型:

model_path = './classifier_cira10.pth'
torch.save(classifier.state_dict(), model_path)

稍后我们将使用 LoRA 训练相同的模型,并检查训练一个 epoch 需要多长时间。

加载保存的模型并进行快速测试:

# Prepare the test data.
images, labels = next(iter(test_loader))
# display the test images
image_display(torchvision.utils.make_grid(images))
# show ground truth labels
print('Ground truth labels: ', ' '.join(f'{classes[labels[j]]}' for j in range(images.shape[0])))# Load the saved model and have a test
model = net()
model.load_state_dict(torch.load(model_path))
model = model.to(device)
images = images.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join(f'{classes[predicted[j]]}'for j in range(images.shape[0])))

输出:

Ground truth labels:  cat ship ship airplane frog frog automobile frog
Predicted:  deer truck airplane ship deer frog automobile bird

我们观察到,仅对模型进行两个阶段的训练并不能产生令人满意的结果。让我们检查一下该模型在整个测试数据集上的表现。

def test(model, test_loader, device):model=model.to(device)model.eval()correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = data[0].to(device), data[1].to(device)# images = images.to(device)# labels = labels.to(device)# inferenceoutputs = model(images)# get the best prediction_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the given model on the {total} test images is {100 * correct // total} %')test(model, test_loader, device)

输出:

    Accuracy of the given model on the 10000 test images is 32 %

这一结果表明,通过进一步训练,模型有很大的改进潜力。在以下部分中,我们将把 LoRA 应用于模型,并继续使用这种方法进行训练。

2.3 将 LoRA 应用于模型

定义用于将 LoRA 应用于模型的辅助函数:

class ParametrizationWithLoRA(nn.Module):def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):super().__init__()# Create A B and scale used in ∆W = BA x α/rself.lora_weights_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))nn.init.normal_(self.lora_weights_A, mean=0, std=1)self.lora_weights_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))self.scale = alpha / rankself.enabled = Truedef forward(self, original_weights):if self.enabled:return original_weights + torch.matmul(self.lora_weights_B, self.lora_weights_A).view(original_weights.shape) * self.scaleelse:return original_weightsdef apply_parameterization_lora(layer, device, rank=1, alpha=1):"""Apply loRA to a given layer"""features_in, features_out = layer.weight.shapereturn ParametrizationWithLoRA(features_in, features_out, rank=rank, alpha=alpha, device=device)def enable_lora(model, enabled=True):"""enabled = True: incorporate the the lora parameters to the modelenabled = False: the lora parameters have no impact on the model"""for layer in [model.fc1, model.fc2, model.fc3]:layer.parametrizations["weight"][0].enabled = enabled

将 LoRA 应用于我们的模型:

import torch.nn.utils.parametrize as parametrize
parametrize.register_parametrization(model.fc1, "weight", apply_parameterization_lora(model.fc1, device))
parametrize.register_parametrization(model.fc2, "weight", apply_parameterization_lora(model.fc2, device))
parametrize.register_parametrization(model.fc3, "weight", apply_parameterization_lora(model.fc3, device))

现在,我们的模型参数由两部分组成:原始参数和 LoRA 引入的参数。由于我们尚未训练此更新后的模型,因此 LoRA 权重的初始化方式不会影响模型的准确性(请参阅“ParametrizationWithLoRA”)。因此,禁用或启用 LoRA 应该会导致模型的准确性相同。让我们来测试一下这个假设。

enable_lora(model, enabled=False)
test(model, test_loader, device)

输出:

    Accuracy of the network on the 10000 test images: 32 %
enable_lora(model, enabled=True)
test(model, test_loader, device)

输出:

    Accuracy of the network on the 10000 test images: 32 %

这正是我们所期望的。

现在让我们看看 LoRA 添加了多少参数。

total_lora_params = 0
total_original_params = 0
for index, layer in enumerate([model.fc1, model.fc2, model.fc3]):total_lora_params += layer.parametrizations["weight"][0].lora_weights_A.nelement() + layer.parametrizations["weight"][0].lora_weights_B.nelement()total_original_params += layer.weight.nelement() + layer.bias.nelement()print(f'Number of parameters in the model with LoRA: {total_lora_params + total_original_params:,}')
print(f'Parameters added by LoRA: {total_lora_params:,}')
params_increment = (total_lora_params / total_original_params) * 100
print(f'Parameters increment: {params_increment:.3f}%')

输出:

    Number of parameters in the model with LoRA: 21,013,524Parameters added by LoRA: 15,370Parameters increment: 0.073%

LoRA 只为我们的模型添加了 0.073% 的参数。

接下来继续使用 LoRA 训练模型。

在继续训练模型之前,我们希望冻结模型的所有原始参数,如论文所述。通过这样做,我们只更新 LoRA 引入的权重,这是原始模型参数数量的 0.073%。

for name, param in model.named_parameters():if 'lora' not in name:param.requires_grad = False

继续应用 LoRA 训练模型。

# make sure the loRA is enabled 
enable_lora(model, enabled=True)start_epoch += epochs
epochs = 1
# warm up the GPU with the new model (loRA enabled) one epoch for testing the training time
train(train_loader, model, start_epoch=start_epoch, epochs=epochs, device=device)start = time.time()
# run another epoch to record the time
start_epoch += epochs
epochs = 1
import time
start = time.time()
train(train_loader, model, start_epoch=start_epoch, epochs=epochs, device=device)
torch.cuda.synchronize()
end = time.time()
train_time = (end - start)
print(f"One epoch takes {train_time} seconds")

输出:

    [2,  2000] loss: 1.643[2,  4000] loss: 1.606[2,  6000] loss: 1.601[3,  2000] loss: 1.568[3,  4000] loss: 1.560[3,  6000] loss: 1.585One epoch takes 16.622623205184937 seconds

你可能会注意到,现在只需大约 16 秒即可完成一个 epoch 的训练,这大约是训练原始模型所需时间(31 秒)的 53%。

损失的减少表明模型已经从更新 LoRA 引入的参数中学习到了知识。现在,如果我们在启用 LoRA 的情况下测试模型,准确率应该高于我们之前使用原始模型实现的准确率(32%)。如果我们禁用 LoRA,该模型应该能够达到与原始模型相同的准确率。让我们继续进行这些测试。

enable_lora(model, enabled=True)
test(model, test_loader, device)
enable_lora(model, enabled=False)
test(model, test_loader, device)

输出:

    Accuracy of the given model on the 10000 test images is 42 %Accuracy of the given model on the 10000 test images is 32 %

使用之前的图像再次测试更新后的模型。

# display the test images
image_display(torchvision.utils.make_grid(images.cpu()))
# show ground truth labels
print('Ground truth labels: ', ' '.join(f'{classes[labels[j]]}' for j in range(images.shape[0])))# Load the saved model and have a test
enable_lora(model, enabled=True)
images = images.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join(f'{classes[predicted[j]]}'for j in range(images.shape[0])))

输出:

    Ground truth labels:  cat ship ship airplane frog frog automobile frogPredicted:  cat ship ship ship frog frog automobile frog

我们可以观察到,与步骤 6 中获得的结果相比,新模型的表现更好,表明参数确实学到了有意义的信息。

3、结束语

在这篇博文中,我们探索了 LoRA 算法,深入研究了它的原理以及在带有 ROCm 的 AMD GPU 上的实现。我们从头开始开发了一个基本网络和 LoRA 模块,以展示 LoRA 如何有效地减少可训练参数和训练时间。


原文链接:用LoRA进行高效微调 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/21419.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

k8s学习--ConfigMap详细解释与应用

文章目录 一 什么是configmapConfigMap 的好处ConfigMap 的限制 二.创建ConfigMap的4种方式1.在命令行指定参数创建2.在命令行通过多个文件创建3.在命令行通过文件提供多个键值对创建4.YAML资源清单文件创建 三 configmap的两种使用方法1.通过环境变量的方式传递给pod2.通过vol…

MySQL学习——影响选项文件处理的命令行选项和程序选项修改器

大多数支持选项文件的MySQL程序都处理以下选项。因为这些选项会影响选项文件的处理,所以必须在命令行上给出,而不是在选项文件中给出。为了正常工作,这些选项中的每一个都必须先于其他选项给出,但以下情况除外: -prin…

WLAN组网模型探究

目录 一、WLAN基本概念二、WLAN组网方式三、WLAN转发模型 随着信息技术的飞速发展,无线局域网(WLAN)已逐渐成为企业网络架构中不可或缺的一部分。不同的企业组织因其业务特性、规模大小及安全需求的不同,对WLAN的要求也各有侧重。…

物联网面试准备

volatile的作用 volatile关键字用于告诉编译器,该变量可能会在程序的执行过程中被意外更改,因此编译器不应该对该变量进行优化或者缓存。 这样可以确保每次访问该变量时都会从内存中读取最新的值,而不是使用缓存中的旧值。 在多线程编程中&…

Arduino IDE 2.3.2找不到端口解决方法

Arduino IDE 2.3.2找不到端口解决方法 问题描述 Arduino IDE 2.3.2 软件找不到端口(端口显示灰色),Arduino开发板连接电脑后,设备管理器能够看到端口信息,Arduino IDE软件中看不到端口。 设备管理器中能够看到端口信息 Arduino IDE中端口显…

植物大战僵尸杂交版(最新版)

杂交版1.0文件链接 链接:https://pan.baidu.com/s/1Ew6iTg0_d_Ut8N9_18KGLw 提取码:yspa 杂交版2.0文件链接 链接:https://pan.baidu.com/s/1tuchowb4C_oNT6EpqSvr_w?pwdy2fz 提取码:y2fz

HTML静态网页成品作业(HTML+CSS)—— 香奈儿香水介绍网页(1个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,未使用Javacsript代码,共有1个页面。 二、作品演示 三、代…

上位机图像处理和嵌入式模块部署(f407 mcu中tf卡模拟u盘)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 在f407开发板上面,本身是有一个usb接口的。这个usb接口也不仅仅是作为电源使用的,它还可以用来做很多的事情。一方面&#…

计算机网络错题答案汇总

王道学习 第1章 计算机网络体系结构 1.1 1.2

vue配置代理服务器解决跨域方法

一.vue配置代理服务器解决跨域方法一 过程如图: 1.在配置文件中设置代理服务器的地址 //vue.config.js module.exports{pages:{index:{// 入口entry:src/main.js,},},lintOnSave:false, //关闭语法检测// 开启代理服务器devServer:{proxy:http://localhost:8000//…

Java基础教程:算术运算符快速掌握

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

操作系统之银行家算法

目录 前言 银行家算法 定义 举例 策略 思路 结束语 前言 今天是坚持写博客的第16天,已经超过半个月了,希望可以继续坚持,不断积累与回顾,夯实基础知识体系的基础。我们今天来讲讲操作系统当中的另一个重要知识点——银行家…

vue2组件传参方法

一、父传子 1、$refs方法 <template><div class"father"><h1>我是父亲</h1><button click"getHeight">获取身高</button><ChildView ref"childRef"></ChildView></div> </template&…

第4章:车辆的横向优化控制

4.1 车辆动力学模型 注1&#xff1a;运动学模型和动力学模型最大的不同点在于 运动学模型是在我们不考虑车辆的受力情况下建立的&#xff08;回顾我们推导出运动学模型的过程&#xff0c;我们没有使用到任何车辆所受的外力作为公式中的已知量&#xff0c;而是直接通过 “ 车速…

云计算-云基础设施的配置 (Provisioning Cloud Infrastructure)

AWS CloudFormation (AWS CloudFormation) 它是一项服务&#xff0c;允许我们自动建模和配置所需的AWS资源。可以使用模板来实现这一目的。这个模板基本上是用JSON或YAML格式编写的。AWS CloudFormation会根据模板描述的内容来实施资源的配置和管理。我们可以成组配置和管理一组…

华为交换机的基本配置

实验拓扑&#xff1a; 实验目的&#xff1a;认识二层交换机和二层交换技术的工作原理&#xff1b;认识三层交换和三层交换技术。 三层功能简而言之就是了具有路由的功能&#xff0c;设备可以充当网关和路由器。 实验要求&#xff1a;公司的两个部门用vlan进行划分&#xff0c…

vs中.\ 与 ..\ 区别

100编程书屋_孔夫子旧书网 在 Visual Studio 中&#xff0c;. 和 .. 是表示相对路径的两个特殊符号。 . 表示当前目录&#xff0c;例如&#xff1a;.\file.txt 表示当前目录下的文件 "file.txt"。 .. 表示上一级目录&#xff0c;例如&#xff1a;..\file.txt 表示上…

喵星人必备!福派斯三文鱼猫粮,营养满分!

猫粮品牌&#xff1a;福派斯三文鱼猫粮测评体验 在快节奏的都市生活中&#xff0c;我们的宠物猫也需要适应当下的生活环境&#xff0c;并保持健康和活力。作为一名合格的铲屎官&#xff0c;我们总是关心如何为猫咪提供既健康又美味的饮食。今天&#xff0c;我有幸为大家带来一…

QT 如何在 QListWidget 的选项中插入自定义组件

有时我们需要 QListWidget 完成更复杂的操作&#xff0c;而不仅限于添加文本或者图标&#xff0c;那么就会使用到 setItemWidget 函数&#xff0c;但是这也会伴生一个问题&#xff0c;插入自定义组件后&#xff0c;QListWidget 对选项点击事件的获取会收到阻塞&#xff0c;因…

Docker安装启动Mysql

1、安装Docker&#xff08;省略&#xff09; 网上教程很多 2、下载Mysql5.7版本 docker pull mysql:5.7 3、查看镜像是够下载成功 docker images 4、启动镜像&#xff0c;生成容器 docker run --name mysql5.7 -p 13306:3306 -e MYSQL_ROOT_PASSWORD123456 -d mysql:5.7 5…