pytorch教材
3.4. softmax回归 — 动手学深度学习 2.0.0 documentation
c++实现代码
代码太长了就没整理了,也暂时没有运行效果截图
同样没有本文也没有实现反向自动求导
超长代码警告,757行。不过可能注释占一半
#include <bits/stdc++.h>
using namespace std;
// reverseInt 函数:将32位整数的大小端进行转换
// 参数:
// x: 需要进行大小端转换的32位整数
// 返回值:
// 转换后(即小端转大端或大端转小端)的32位整数
int reverseInt(int x)
{// 定义四个无符号字符变量,用于存储整数x的四个字节 unsigned char a, b, c, d;// 获取整数x的最低8位(即第一个字节) // (int)255的二进制是00000000 00000000 00000000 11111111,与操作后只保留最低8位 a = x & 255;// 获取整数x的第二个字节(即第9-16位)b = (x>>8) & 255;// 获取整数x的第三个字节(即第17-24位)c = (x>>16) & 255;// 获取整数x的最高字节(即第25-32位) d = (x>>24) & 255;// 将这四个字节按照相反的顺序重新组合成一个整数,实现大端序和小端序的转换 int ans = ((int)a<<24) + ((int)b<<16) + ((int)c<<8) + d;return ans;
}/** * @brief 获取最大值 * * 从给定的双精度浮点数数组中找出最大值并返回。 * * @param a 指向双精度浮点数数组的指针 * @param len 数组的长度(元素的数量) * * @return 数组中的最大值 */
double getMax(double* a, int len)
{ double smax = -DBL_MAX; // 初始化最大值为 double max,确保即使数组中包含负数,该函数仍然会返回最大的那个数assert(len>0); // 断言数组长度必须大于 0 for (int i = 0; i < len; i++) { // 使用三元运算符更新最大值 smax = a[i] > smax ? a[i] : smax; } return smax;
}
/** * @brief 计算 Softmax 函数值 * * 对于给定的实数数组,计算其 Softmax 函数值,并返回一个新的数组,其中每个元素是输入数组中对应元素的 Softmax 值。 * * @param num 输入的实数数组 * @param len 数组的长度(元素的数量) * * @return 指向计算得到的 Softmax 值数组的指针 * * @note 返回的数组需要调用者在使用完毕后手动释放内存。 * 为了数值稳定性,在计算 Softmax 之前,先对数组中的最大值进行减去操作(称为 Shifted Softmax)。 * 此外,如果数组中包含极大的正数或极小的负数,可能会导致溢出或下溢,但在此实现中,通过减去最大值来减少溢出的可能性。 */
double* softmax(double* num, int len)
{ // 分配一个新的双精度浮点数数组来存储 Softmax 值 double* ans = new double[len]; // 断言数组长度必须大于 0 assert(len > 0); // 复制输入数组到输出数组(初始时,两者相同) for (int i = 0; i < len; i++){ ans[i] = num[i]; } // 数组元素的总和 与 最大值 double sum = 0, smax = getMax(ans, len); // 对每个元素应用 Shifted Softmax 公式 for (int i = 0; i < len; i++){ // 减去最大值后计算指数函数,避免上溢 ans[i] = exp(ans[i] - smax); // 累加所有 exp() 的值到 sum 中 sum += ans[i]; } // 归一化 Softmax 值 for (int i = 0; i < len; i++){ ans[i] /= sum; } // 返回计算得到的 Softmax 值数组 return ans;
}
/** * @brief 矩阵乘法 * * 执行两个二维数组的矩阵乘法运算,并返回结果矩阵。 * * @param X 第一个矩阵,一个指向指针的指针,表示二维数组 * @param W 第二个矩阵,一个指向指针的指针,表示二维数组 * @param xrow 矩阵X的行数 * @param xcol 矩阵X的列数,同时也是矩阵W的行数(由断言保证) * @param wrow 矩阵W的行数(实际上与xcol相同,但此参数在此函数中不使用) * @param wcol 矩阵W的列数 * * @return 指向结果矩阵的指针,一个指向指针的指针,表示二维数组 * * @note 调用此函数前,应确保矩阵X和W的维度匹配(即X的列数等于W的行数)。 * 此外,返回的结果矩阵需要调用者在使用完毕后手动释放内存。 * 这个函数使用了断言来确保矩阵X的列数等于矩阵W的行数。 */
double** matmul(double** X, double** W, int xrow, int xcol, int wrow, int wcol)
{ // 断言以确保矩阵X的列数等于矩阵W的行数 assert(xcol == wrow); // 分配结果矩阵的内存 double** ans = new double*[xrow]; for (int i = 0; i < xrow; i++) { ans[i] = new double[wcol]; } // 遍历计算结果矩阵的每个元素 for(int i = 0; i < xrow; i++) { for (int j = 0; j < wcol; j++) { double sum = 0; // 初始化累加器 // 遍历矩阵X的第i行和矩阵W的第j列对应的元素,执行乘法并累加 for (int k = 0; k < xcol; k++) { double x = X[i][k]; // 从矩阵X中取出元素 sum += x * W[k][j]; // 累加乘法结果 } // 将累加结果存储到结果矩阵的对应位置 ans[i][j] = sum; } } // 返回结果矩阵 return ans;
}
/** * @brief 矩阵乘法与偏置项相加 * * 对给定的输入矩阵X、权重矩阵W和偏置项b进行线性变换,即执行X*W+b的操作, * 并返回结果矩阵。 * * @param X 输入矩阵,大小为[batch_size, num_input] * @param W 权重矩阵,大小通常为[num_input, num_output]* @param b 偏置项,大小为[num_output] * @param batch_size 批量大小,即输入矩阵X的行数 * @param num_input 输入特征的维度 * @param num_output 输出特征的维度 * * @return 指向结果矩阵的指针,大小为[batch_size, num_output] * * @note 调用此函数前,应确保输入矩阵X、权重矩阵W和偏置项b的维度正确匹配。 * 此外,返回的结果矩阵需要调用者在使用完毕后手动释放内存。 */
double** xwpb(double** X, double** W, double* b, int batch_size, int num_input, int num_output)
{ // 执行矩阵乘法X*W double** o = matmul(X, W, batch_size, num_input, num_input, num_output); // 将偏置项b加到结果矩阵o的每一行上 for (int i = 0; i < batch_size; i++) // 遍历批量中的每个样本 { for(int j = 0; j < num_output; j++) // 遍历输出特征的每个维度 { o[i][j] += b[j]; // 将偏置项加到结果矩阵的对应位置上 } } // 返回结果矩阵 return o;
}/** * @brief Softmax回归函数 * * 对给定的输入矩阵X、权重矩阵W和偏置项b执行线性变换(即XW+b), * 然后对每个样本的输出应用Softmax函数,并返回包含Softmax结果的向量。 * * @param X 输入矩阵,大小为[batch_size, num_input] * @param W 权重矩阵,大小为[num_input, num_output] * @param b 偏置项,大小为[num_output] * @param batch_size 批量大小,即输入矩阵的行数 * @param num_input 输入特征的维度 * @param num_output 输出特征的维度(同时也是类别数) * * @return 返回一个向量,其中每个元素是一个指向double数组的指针,表示每个样本的Softmax输出 * * @note 调用此函数前,应确保输入矩阵X、权重矩阵W和偏置项b的维度正确匹配。 * 返回的向量中的double指针数组(即Softmax结果)在使用完毕后需要手动释放内存。 * 函数内部调用了xwpb函数进行线性变换,并调用了softmax函数对每个样本的输出应用Softmax。 */
vector<double*> sofreg(double** X, double** W, double* b, int batch_size, int num_input, int num_output)
{ // 执行线性变换XW+b,并返回结果矩阵o double** o = xwpb(X, W, b, batch_size, num_input, num_output); // 创建一个大小为batch_size的向量y_hat,用于存储每个样本的Softmax输出 vector<double*> y_hat(batch_size); // 遍历每个样本 for (int i = 0; i < batch_size; i++) { // 对当前样本的输出应用Softmax函数,并返回结果指针so double* so = softmax(o[i], num_output); // 将Softmax结果存储到y_hat向量的对应位置 y_hat[i] = so; } // 释放内存 for(int i=0; i<batch_size; i++) delete[] o[i];delete[] o;// 返回包含每个样本Softmax输出的向量 return y_hat;
}
/** * @brief 交叉熵损失函数 * * 计算给定预测值(经过Softmax处理后的概率分布)y_hat和实际标签y之间的交叉熵损失。 * * @param y_hat 预测值向量,每个元素是一个指向double数组的指针,表示每个样本的Softmax输出 * @param y 实际标签数组,为0到9之间的整数* @param batch_size 批量大小,即y_hat和y中元素的数量 * @param num_output 输出特征的维度(同时也是类别数),在此为10(0-9的10个类别) * * @return 返回一个指向double数组的指针,数组大小为batch_size,表示每个样本的交叉熵损失 * * @note 调用此函数前,应确保y_hat和y的长度相等,并且与batch_size匹配。 * 此外,y中的每个标签值应为0到num_output-1之间的整数。 * 函数内部使用了assert来检查y中的值是否在有效范围内,以及y_hat中对应位置的预测值是否在(0,1)之间。 * 返回的double数组需要调用者在使用完毕后手动释放内存。 */
double* cross_entropy(vector<double*> y_hat, char* y, int batch_size, int num_output)
{ // 分配一个大小为batch_size的double数组,用于存储每个样本的交叉熵损失 double* loss = new double[batch_size]; // 遍历每个样本 for (int i = 0; i < batch_size; i++) { int yi = y[i];// 使用assert断言来检查标签值是否在有效范围内(0-9) assert(yi >= 0 && yi <= 9); // 使用assert断言来检查y_hat中对应位置的预测值是否在(0,1)之间 assert(y_hat[i][yi] > 0 && y_hat[i][yi] < 1); // 计算交叉熵损失,这里只考虑了单标签的情况,即每个样本只有一个类别标签 loss[i] = -log(y_hat[i][yi]); } // 返回包含每个样本交叉熵损失的double数组 return loss;
}
/*** @brief sgd 函数用于执行随机梯度下降(Stochastic Gradient Descent)算法
// 来更新神经网络中的权重 W 和偏置 b // 参数说明:
// X: 输入数据,是一个二维数组(指针的指针),大小为 [batch_size][num_input]
// y: 标签数据,是一个字符串(但实际上是标签的索引数组),大小为 [batch_size]
// W: 权重矩阵,是一个二维数组(指针的指针),大小为 [num_input][num_output]
// b: 偏置向量,是一个一维数组,大小为 [num_output]
// lr: 学习率,用于控制权重更新的步长
// batch_size: 批量大小,即每次用于梯度计算的样本数量
// num_input: 输入数据的特征数量
// num_output: 输出数据的类别数量(或神经元的数量) */
void sgd(double** X, const char* y, double** W, double* b, double lr, int batch_size, int num_input, int num_output)
{
// vector<double*> y_hat = sofreg(X, W, b, batch_size, num_input, num_output);// 计算线性组合的结果(未经过激活函数) double** o=xwpb(X, W, b, batch_size, num_input, num_output);// 为权重梯度 gradw 和偏置梯度 gradb 分配内存 double** gradw=new double*[num_input];double* gradb=new double[num_output];// 初始化权重梯度 gradw 为 0 for (int i=0; i<num_input; i++){gradw[i] = new double[num_output];for (int j=0; j<num_output; j++)gradw[i][j]=0.0;}// 初始化偏置梯度 gradb 为 0for (int j=0; j<num_output; j++)gradb[j]=0.0;// 遍历批量中的每个样本,计算梯度 for (int i=0; i<batch_size; i++){int yi = y[i];// 计算 softmax 函数的结果 double* so=softmax(o[i], num_output);// 计算 cross entropy 对 小批量的未规范化预测 O 的导数 // softmax(o)[j]-y[j], 将 y 视为独热标签向量 double grad[num_output];for(int j=0; j<num_output; j++){grad[j] = so[j];}grad[yi]-=1;// 计算 gradb , cross entropy 对 b 的导数,链式求导 // o = X * W + b for (int j=0; j<num_output; j++){gradb[j]+=grad[j];}// 计算 gradw ,cross entropy 对 W 的导数,链式求导 // o = X * W + b for (int j=0; j<num_input; j++){for (int k=0; k<num_output; k++){double x=X[i][j];gradw[j][k] += grad[k]*x;}}delete[] so;}// 使用计算得到的梯度来更新权重 W 和偏置 b for(int i=0; i<num_input; i++){for (int j=0; j<num_output; j++){W[i][j] = W[i][j] - lr * gradw[i][j] / batch_size;}}for (int i=0; i<num_output; i++){b[i] = b[i] - lr * gradb[i]/ batch_size;}for (int i=0; i<batch_size; i++) delete[] o[i];delete[] o;for (int i=0; i<num_input; i++) delete[] gradw[i];delete[] gradw;delete[] gradb;
}
/** * @brief 计算平均值 * * 计算给定双精度浮点数数组的平均值。 * * @param loss 包含要计算平均值的双精度浮点数的数组 * @param len 数组的长度(元素的数量) * * @return 数组 `loss` 中所有元素的平均值 * */
double mean(double* loss, int len)
{ double ans = 0; // 初始化累加器为 0 assert(len>0); // 断言数组长度必须大于 0 // 遍历数组 `loss` 中的每个元素 for (int i = 0; i < len; i++) { // 将当前元素加到累加器 `ans` 上 ans += loss[i]; } // 返回累加器 `ans` 除以数组长度 `len` 的结果,即平均值 return ans / len;
}unsigned char** read_mnist_image(string file_name, int& num_image, int& num_row, int& num_col, const int check_number);
char* read_mnist_label(string file_name, const int num_image, const int check_number);
unsigned char** get_image(string path, int& num_image, int& num_row, int& num_col, bool is_train);
char* get_label(string path, int num_image, bool is_train);
/** * @brief 归一化图像数据 * * 将输入的二维无符号字符数组(通常是灰度图像)归一化到 0 到 1 的范围内, * 并返回一个二维双精度浮点数数组,其中包含了归一化后的图像数据。 * * @param cX 输入的二维无符号字符数组,代表原始图像数据 * @param row 图像的行数 * @param col 图像的列数 * * @return 指向归一化后二维双精度浮点数数组的指针 * * @note 调用者需要确保输入的 cX 数组是有效且已经分配了足够的内存。 * 返回的 X 数组需要调用者在使用完毕后手动释放内存。 */
double** normalization(unsigned char** cX, int row, int col)
{ // 创建一个新的二维双精度浮点数数组 X 来存储归一化后的图像数据 double** X = new double*[row]; for(int i=0; i<row; i++) { X[i] = new double[col]; } // 遍历原始图像数据的每个像素,并进行归一化 for (int i=0; i<row; i++) { for (int j=0; j<col; j++) { // 读取原始图像数据中的像素值 int x = cX[i][j]; // 归一化到 0 到 1 的范围 X[i][j] = x * 1.0 / 255.0;} } // 返回归一化后的图像数据 return X;
}/** * @brief 打乱图像数据和标签的顺序 * * 使用 Fisher-Yates 洗牌算法(也被称为 Knuth 洗牌)结合一个随机数生成器来 * 打乱传入的图像数据和对应的标签。 * * @param X 指向图像数据的指针数组,每个元素指向一个图像(一维数组) * @param y 指向标签数据的指针,每个元素表示一个标签 * @param num_image 图像和标签的数量 * * @note 此函数会直接修改传入的 X 和 y,而不需要额外的存储空间。 */
void shuffle(unsigned char** X, char* y, int num_image)
{ // 创建一个整数向量 num,用于存储原始索引 vector<int> num(num_image); for(int i = 0; i < num_image; i++) num[i] = i; // 使用当前时间作为随机数生成器的种子 // 这样可以确保每次调用 shuffle 函数时都能得到不同的随机序列 random_device rd; mt19937 g(rd()); // 使用 Mersenne Twister 算法来生成随机数 // 打乱整数向量 num 中的元素顺序 shuffle(num.begin(), num.end(), g); // 使用 Fisher-Yates 洗牌算法来打乱图像数据和标签的顺序 unsigned char* tmpcp; // 临时指针,用于交换图像数据 char tmpc; // 临时字符,用于交换标签 for (int i = 0; i < num_image; i++) { // 交换图像数据 tmpcp = X[i]; X[i] = X[num[i]]; X[num[i]] = tmpcp; // 交换标签数据 tmpc = y[i]; y[i] = y[num[i]]; y[num[i]] = tmpc; }
}int main()
{// 定义数据集的路径 string path="../data/MNIST/raw/";// 定义变量来存储图像和标签的数量以及尺寸 // 训练图像的数量、像素行数和列数(高和宽) int num_image, num_row, num_col;// 测试图像的数量、像素行数和列数 int num_test_image, num_test_row, num_test_col;// 从指定路径读取训练集与测试集图像,并返回图像数据和图像数量以及像素宽高 unsigned char** cX = get_image(path, num_image, num_row, num_col, true);unsigned char** test_cX = get_image(path, num_test_image, num_test_row, num_test_col, false);// 从指定路径加载标签 char* y = get_label(path, num_image, true);char* test_y = get_label(path, num_test_image, false); // 对训练数据和标签进行随机打乱 shuffle(cX, y, num_image);// 对图像数据进行归一化处理,并返回处理后的数据 double** X=normalization(cX, num_image, num_row*num_col);double** test_X=normalization(test_cX, num_test_image, num_test_row*num_test_col); // 定义超参数 const double lr = 0.01;// 学习率 const int num_epochs = 10;// 训练轮数 const int num_output = 10;// 输出层神经元数量(对应MNIST的10个类别) const int batch_size = 256;// 批量大小const int num_sample = num_image;// 总样本数(这里等于训练样本数) const int num_input = num_row * num_col; // 输入层神经元数量(等于图像的像素数) // 初始化权重矩阵W和偏置向量b double** W=new double* [num_input];for (int i=0; i<num_input; i++) W[i]=new double[num_output];double* b=new double[num_output];// 将W和b的所有元素初始化为0.0 for(int i=0; i<num_input; i++){for (int j=0; j<num_output; j++){W[i][j]=0.0;}}for (int j=0; j<num_output; j++){b[j]=0.0;}// 开始进行训练循环,迭代num_epochs次 for (int epoch=0; epoch<num_epochs; epoch++){// 对所有样本进行迭代,每次处理batch_size个样本 for (int j=0; j<num_sample; j+=batch_size){// 确保每一批量获得正确的样本个数 int batch = min(batch_size, num_sample-j);// 对当前batch的数据进行softmax回归计算,得到预测结果y_hat vector<double*> y_hat = sofreg(X+j, W, b, batch, num_input, num_output);// 计算当前batch的交叉熵损失 double* loss = cross_entropy(y_hat, y+j, batch, num_output);// 使用随机梯度下降(SGD)更新权重W和偏置b sgd(X+j, y+j, W, b, lr, batch, num_input, num_output);delete[] loss;for (auto i:y_hat) delete[] i;y_hat.clear();}// 在每个epoch结束后,测试模型在测试集上的性能 {// 初始化索引和当前batch的大小(对于测试集,这里通常使用整个测试集) int j=0;// 但因为测试集通常全部使用,所以batch_size可能不会被限制 int batch = min(batch_size, num_test_image-j);// 对测试集进行softmax回归计算,得到预测结果y_hat vector<double*> y_hat = sofreg(test_X+j, W, b, batch, num_input, num_output);// 初始化预测正确的样本数 int right_num=0;// 遍历当前batch的所有样本 for (int i=0; i<batch; i++){// 获取当前样本的预测结果 double* yy = y_hat[i];double mm=0, id=-1;// 找到预测概率最大的类别 for (int j=0; j<num_output; j++){if (yy[j]>mm) mm=yy[j], id=j;}// 检查预测类别是否与实际类别相同,如果相同则增加正确数if (id == (test_y+j)[i]) right_num++;}// 计算并打印当前epoch的测试集准确率double* loss = cross_entropy(y_hat, test_y+j, batch, num_output);printf("in epoch %d, accuracy is %.4Lf\n", epoch+1, right_num*1.0/batch*1.0);delete[] loss;for (auto i:y_hat) delete[] i;y_hat.clear();} }// 累了,交给操作系统自己释放吧 //delete cX, test_cX, y, test_y, X, test_X, w, b;
}/*******************************************
// 读取MNIST数据集图像的函数
// 参数:
// file_name: 图像文件的名字,需要绝对或相对路径
// num_image: 读取的图像数量(引用传递,用于修改外部变量)
// num_row: 每张图像的行数(引用传递,用于修改外部变量)
// num_col: 每张图像的列数(引用传递,用于修改外部变量)
// check_number: 用于检查文件头部magic number的期望值
// 返回值:
// 返回一个二维指针,指向由unsigned char数组组成的图像数组
// 第一个维度是图片数量,第二个维度是单张图片大小
// 注意:调用此函数的代码应确保在适当的时候释放images指向的内存,避免内存泄漏
********************************************/
unsigned char** read_mnist_image(string file_name, int& num_image, int& num_row, int& num_col, const int check_number)
{// 以二进制读模式打开文件 FILE *fp = fopen(file_name.c_str(), "rb");// 如果文件打开失败,退出程序if (!fp){printf("file open fail!\n");exit(0);}// 读取magic number、图像数量、图像的行数和列数 int magic_number;fread((char*)&magic_number, sizeof(magic_number), 1, fp);fread((char*)&num_image, sizeof(num_image), 1, fp);fread((char*)&num_row, sizeof(num_row), 1, fp);fread((char*)&num_col, sizeof(num_col), 1, fp);//由于MNIST文件是以大端字节序存储的,所以需要转换为小端序 magic_number=reverseInt(magic_number);num_image=reverseInt(num_image);num_row=reverseInt(num_row);num_col=reverseInt(num_col);// 检查magic number是否匹配 if (check_number != magic_number){printf("magic number is error, this is not the right image file\n");fclose(fp); // 关闭文件句柄 exit(0); // 退出程序 }// 分配二维数组以存储图像 unsigned char** images=new unsigned char*[num_image];// 读取所有图像for(int i=0; i<num_image; i++){// 为每个图像分配内存 unsigned char* image=new unsigned char[num_row * num_col];// 读取图像数据,fread(image, sizeof(unsigned char), num_row * num_col, fp);// 将图像数据存入二维数组 images[i]=image;}// 关闭文件句柄 fclose(fp);// 返回二维图像指针 return images;// 示例,使用delete[]来释放每个图像的内存,并最后释放images本身// for (int i = 0; i < num_image; ++i) { // delete[] images[i]; // } // delete[] images;
}/*************************************
// 读取MNIST数据集标签的函数
// 参数:
// file_name: 标签文件的名字,需要绝对或相对路径
// num_image: 预期读取的标签数量(应与文件内标签数量一致)
// check_number: 用于检查文件头部magic number的期望值,检查文件是否正确
// 返回值:
// 返回一个包含标签的char数组指针,此处char应理解为单字节类型整数
// 注意:调用此函数的代码应确保在适当的时候释放labels指向的内存,避免内存泄漏
***************************************/
char* read_mnist_label(string file_name, const int num_image, const int check_number)
{// 以二进制读模式打开文件 FILE *fp = fopen(file_name.c_str(), "rb");// 如果文件打开失败,退出程序if (!fp){printf("file open fail!\n");exit(-1);}// 定义并读取magic number和标签数量 int magic_number, num_label;fread((char*)&magic_number, sizeof(magic_number), 1, fp);fread((char*)&num_label, sizeof(num_label), 1, fp);//由于MNIST文件是以大端字节序存储的,所以需要转换为小端序 magic_number=reverseInt(magic_number);num_label=reverseInt(num_label);// 检查magic number是否匹配 if (check_number != magic_number){printf("magic number is error, this is not the right label file!\n");fclose(fp);exit(-1);}// 检查标签数量是否与预期一致 if (num_label!=num_image){printf("num_label not equal num_image!\n");fclose(fp);exit(-1);}// 动态分配内存以存储标签 char* labels=new char[num_label];// 读取所有标签 for(int i=0; i<num_label; i++){fread(&labels[i], sizeof(char), 1, fp);}// 关闭文件句柄 fclose(fp);// 返回标签数组指针return labels;// 示例,使用delete[]来释放标签的内存//delete[] labels;
}/** * @brief 获取 MNIST 数据集的图像数据 * * 根据指定的文件路径和是否训练数据集的标志,从 MNIST 数据集中加载图像数据, * 并返回指向图像数据的指针(二维数组)。同时,更新图像数量、行数和列数的引用参数。 * * @param path 数据集所在的路径 * @param num_image 引用参数,用于返回图像数量 * @param num_row 引用参数,用于返回每个图像的行数 * @param num_col 引用参数,用于返回每个图像的列数 * @param is_train 是否为训练数据集的标志,true 为训练数据,false 为测试数据 * * @return 指向图像数据的指针(二维数组),每个元素为 unsigned char 类型 */
unsigned char** get_image(string path, int& num_image, int& num_row, int& num_col, bool is_train)
{// 定义 MNIST 数据集的文件名 string name_train_image="train-images-idx3-ubyte";string name_train_label="train-labels-idx1-ubyte";string name_test_image="t10k-images-idx3-ubyte";string name_test_label="t10k-labels-idx1-ubyte";// 根据是否训练数据集的标志,选择加载训练或测试数据集的图像文件 if (is_train) {// 加载训练数据集的图像文件 return read_mnist_image(path+name_train_image, num_image, num_row, num_col, 2051);} else {// 加载测试数据集的图像文件 return read_mnist_image(path+name_test_image, num_image, num_row, num_col, 2051);}
}/** * @brief 获取 MNIST 数据集的标签数据 * * 根据给定的路径、图像数量和是否训练数据集的标志,从 MNIST 数据集中加载标签数据, * 并返回指向标签数据的指针(一维字符数组)。 * * @param path 数据集所在的路径 * @param num_image 预期的标签数量,用于检查文件标签数量是否与预期一致* @param is_train 是否为训练数据集的标志,true 为训练数据,false 为测试数据 * * @return 指向标签数据的指针(一维字符数组),每个元素表示一个标签 */
char* get_label(string path, const int num_image, bool is_train)
{// 定义 MNIST 数据集的文件名 string name_train_image="train-images-idx3-ubyte";string name_train_label="train-labels-idx1-ubyte";string name_test_image="t10k-images-idx3-ubyte";string name_test_label="t10k-labels-idx1-ubyte";// 根据是否训练数据集的标志,选择加载训练或测试数据集的标签文件if (is_train) {// 加载训练数据集的标签文件return read_mnist_label(path+name_train_label, num_image, 2049);} else {// 加载测试数据集的标签文件 return read_mnist_label(path+name_test_label, num_image, 2049);}
}