c++ 读取MNIST数据集实现softmax回归

pytorch教材

3.4. softmax回归 — 动手学深度学习 2.0.0 documentation

c++实现代码

代码太长了就没整理了,也暂时没有运行效果截图

同样没有本文也没有实现反向自动求导

超长代码警告,757行。不过可能注释占一半

#include <bits/stdc++.h>
using namespace std;
// reverseInt 函数:将32位整数的大小端进行转换  
// 参数:  
// x: 需要进行大小端转换的32位整数  
// 返回值:  
// 转换后(即小端转大端或大端转小端)的32位整数 
int reverseInt(int x)
{// 定义四个无符号字符变量,用于存储整数x的四个字节  unsigned char a, b, c, d;// 获取整数x的最低8位(即第一个字节) // (int)255的二进制是00000000 00000000 00000000 11111111,与操作后只保留最低8位  a = x & 255;// 获取整数x的第二个字节(即第9-16位)b = (x>>8) & 255;// 获取整数x的第三个字节(即第17-24位)c = (x>>16) & 255;// 获取整数x的最高字节(即第25-32位)  d = (x>>24) & 255;// 将这四个字节按照相反的顺序重新组合成一个整数,实现大端序和小端序的转换 int ans = ((int)a<<24) + ((int)b<<16) + ((int)c<<8) + d;return ans;
}/**  * @brief 获取最大值  *  * 从给定的双精度浮点数数组中找出最大值并返回。  *  * @param a 指向双精度浮点数数组的指针  * @param len 数组的长度(元素的数量)  *  * @return 数组中的最大值   */  
double getMax(double* a, int len)  
{  double smax = -DBL_MAX; // 初始化最大值为 double max,确保即使数组中包含负数,该函数仍然会返回最大的那个数assert(len>0); // 断言数组长度必须大于 0  for (int i = 0; i < len; i++)  {  // 使用三元运算符更新最大值  smax = a[i] > smax ? a[i] : smax;  }   return smax;  
}
/**  * @brief 计算 Softmax 函数值  *  * 对于给定的实数数组,计算其 Softmax 函数值,并返回一个新的数组,其中每个元素是输入数组中对应元素的 Softmax 值。  *  * @param num 输入的实数数组  * @param len 数组的长度(元素的数量)  *  * @return 指向计算得到的 Softmax 值数组的指针  *  * @note 返回的数组需要调用者在使用完毕后手动释放内存。  *       为了数值稳定性,在计算 Softmax 之前,先对数组中的最大值进行减去操作(称为 Shifted Softmax)。  *       此外,如果数组中包含极大的正数或极小的负数,可能会导致溢出或下溢,但在此实现中,通过减去最大值来减少溢出的可能性。  */  
double* softmax(double* num, int len)  
{  // 分配一个新的双精度浮点数数组来存储 Softmax 值  double* ans = new double[len];  // 断言数组长度必须大于 0  assert(len > 0);  // 复制输入数组到输出数组(初始时,两者相同)  for (int i = 0; i < len; i++){  ans[i] = num[i];  }  // 数组元素的总和 与 最大值  double sum = 0, smax = getMax(ans, len);  // 对每个元素应用 Shifted Softmax 公式  for (int i = 0; i < len; i++){  // 减去最大值后计算指数函数,避免上溢 ans[i] = exp(ans[i] - smax);  // 累加所有 exp() 的值到 sum 中  sum += ans[i];  }  // 归一化 Softmax 值  for (int i = 0; i < len; i++){  ans[i] /= sum;  }  // 返回计算得到的 Softmax 值数组  return ans;  
}
/**  * @brief 矩阵乘法  *  * 执行两个二维数组的矩阵乘法运算,并返回结果矩阵。  *  * @param X 第一个矩阵,一个指向指针的指针,表示二维数组  * @param W 第二个矩阵,一个指向指针的指针,表示二维数组  * @param xrow 矩阵X的行数  * @param xcol 矩阵X的列数,同时也是矩阵W的行数(由断言保证)  * @param wrow 矩阵W的行数(实际上与xcol相同,但此参数在此函数中不使用)  * @param wcol 矩阵W的列数  *  * @return 指向结果矩阵的指针,一个指向指针的指针,表示二维数组  *  * @note 调用此函数前,应确保矩阵X和W的维度匹配(即X的列数等于W的行数)。  *       此外,返回的结果矩阵需要调用者在使用完毕后手动释放内存。  *       这个函数使用了断言来确保矩阵X的列数等于矩阵W的行数。  */  
double** matmul(double** X, double** W, int xrow, int xcol, int wrow, int wcol)  
{  // 断言以确保矩阵X的列数等于矩阵W的行数  assert(xcol == wrow);  // 分配结果矩阵的内存  double** ans = new double*[xrow];  for (int i = 0; i < xrow; i++)  {  ans[i] = new double[wcol];  }  // 遍历计算结果矩阵的每个元素  for(int i = 0; i < xrow; i++)  {  for (int j = 0; j < wcol; j++)  {  double sum = 0; // 初始化累加器  // 遍历矩阵X的第i行和矩阵W的第j列对应的元素,执行乘法并累加  for (int k = 0; k < xcol; k++)  {  double x = X[i][k]; // 从矩阵X中取出元素  sum += x * W[k][j]; // 累加乘法结果  }  // 将累加结果存储到结果矩阵的对应位置  ans[i][j] = sum;  }  }  // 返回结果矩阵  return ans;  
}
/**  * @brief 矩阵乘法与偏置项相加  *  * 对给定的输入矩阵X、权重矩阵W和偏置项b进行线性变换,即执行X*W+b的操作,  * 并返回结果矩阵。  *  * @param X 输入矩阵,大小为[batch_size, num_input]  * @param W 权重矩阵,大小通常为[num_input, num_output]* @param b 偏置项,大小为[num_output]  * @param batch_size 批量大小,即输入矩阵X的行数  * @param num_input 输入特征的维度  * @param num_output 输出特征的维度  *  * @return 指向结果矩阵的指针,大小为[batch_size, num_output]  *  * @note 调用此函数前,应确保输入矩阵X、权重矩阵W和偏置项b的维度正确匹配。  *       此外,返回的结果矩阵需要调用者在使用完毕后手动释放内存。  */  
double** xwpb(double** X, double** W, double* b, int batch_size, int num_input, int num_output)  
{  // 执行矩阵乘法X*W  double** o = matmul(X, W, batch_size, num_input, num_input, num_output); // 将偏置项b加到结果矩阵o的每一行上  for (int i = 0; i < batch_size; i++) // 遍历批量中的每个样本  {  for(int j = 0; j < num_output; j++) // 遍历输出特征的每个维度  {  o[i][j] += b[j]; // 将偏置项加到结果矩阵的对应位置上  }  }  // 返回结果矩阵  return o;  
}/**  * @brief Softmax回归函数  *  * 对给定的输入矩阵X、权重矩阵W和偏置项b执行线性变换(即XW+b),  * 然后对每个样本的输出应用Softmax函数,并返回包含Softmax结果的向量。  *  * @param X 输入矩阵,大小为[batch_size, num_input]  * @param W 权重矩阵,大小为[num_input, num_output]  * @param b 偏置项,大小为[num_output]  * @param batch_size 批量大小,即输入矩阵的行数  * @param num_input 输入特征的维度  * @param num_output 输出特征的维度(同时也是类别数)  *  * @return 返回一个向量,其中每个元素是一个指向double数组的指针,表示每个样本的Softmax输出  *  * @note 调用此函数前,应确保输入矩阵X、权重矩阵W和偏置项b的维度正确匹配。  *       返回的向量中的double指针数组(即Softmax结果)在使用完毕后需要手动释放内存。  *       函数内部调用了xwpb函数进行线性变换,并调用了softmax函数对每个样本的输出应用Softmax。  */  
vector<double*> sofreg(double** X, double** W, double* b, int batch_size, int num_input, int num_output)  
{  // 执行线性变换XW+b,并返回结果矩阵o  double** o = xwpb(X, W, b, batch_size, num_input, num_output);  // 创建一个大小为batch_size的向量y_hat,用于存储每个样本的Softmax输出  vector<double*> y_hat(batch_size);  // 遍历每个样本  for (int i = 0; i < batch_size; i++)  {  // 对当前样本的输出应用Softmax函数,并返回结果指针so  double* so = softmax(o[i], num_output);  // 将Softmax结果存储到y_hat向量的对应位置  y_hat[i] = so;  } // 释放内存 for(int i=0; i<batch_size; i++) delete[] o[i];delete[] o;// 返回包含每个样本Softmax输出的向量  return y_hat;  
}
/**  * @brief 交叉熵损失函数  *  * 计算给定预测值(经过Softmax处理后的概率分布)y_hat和实际标签y之间的交叉熵损失。  *  * @param y_hat 预测值向量,每个元素是一个指向double数组的指针,表示每个样本的Softmax输出  * @param y 实际标签数组,为0到9之间的整数* @param batch_size 批量大小,即y_hat和y中元素的数量  * @param num_output 输出特征的维度(同时也是类别数),在此为10(0-9的10个类别)  *  * @return 返回一个指向double数组的指针,数组大小为batch_size,表示每个样本的交叉熵损失  *  * @note 调用此函数前,应确保y_hat和y的长度相等,并且与batch_size匹配。  *       此外,y中的每个标签值应为0到num_output-1之间的整数。  *       函数内部使用了assert来检查y中的值是否在有效范围内,以及y_hat中对应位置的预测值是否在(0,1)之间。  *       返回的double数组需要调用者在使用完毕后手动释放内存。  */  
double* cross_entropy(vector<double*> y_hat, char* y, int batch_size, int num_output)  
{  // 分配一个大小为batch_size的double数组,用于存储每个样本的交叉熵损失  double* loss = new double[batch_size];  // 遍历每个样本  for (int i = 0; i < batch_size; i++)  {  int yi = y[i];// 使用assert断言来检查标签值是否在有效范围内(0-9)  assert(yi >= 0 && yi <= 9);  // 使用assert断言来检查y_hat中对应位置的预测值是否在(0,1)之间  assert(y_hat[i][yi] > 0 && y_hat[i][yi] < 1);  // 计算交叉熵损失,这里只考虑了单标签的情况,即每个样本只有一个类别标签  loss[i] = -log(y_hat[i][yi]);  }  // 返回包含每个样本交叉熵损失的double数组  return loss;  
}
/*** @brief sgd 函数用于执行随机梯度下降(Stochastic Gradient Descent)算法  
// 来更新神经网络中的权重 W 和偏置 b  // 参数说明:  
// X: 输入数据,是一个二维数组(指针的指针),大小为 [batch_size][num_input]  
// y: 标签数据,是一个字符串(但实际上是标签的索引数组),大小为 [batch_size]  
// W: 权重矩阵,是一个二维数组(指针的指针),大小为 [num_input][num_output]  
// b: 偏置向量,是一个一维数组,大小为 [num_output]  
// lr: 学习率,用于控制权重更新的步长  
// batch_size: 批量大小,即每次用于梯度计算的样本数量  
// num_input: 输入数据的特征数量  
// num_output: 输出数据的类别数量(或神经元的数量)  */
void sgd(double** X, const char* y, double** W, double* b, double lr, int batch_size, int num_input, int num_output)
{
//	vector<double*> y_hat = sofreg(X, W, b, batch_size, num_input, num_output);// 计算线性组合的结果(未经过激活函数)  double** o=xwpb(X, W, b, batch_size, num_input, num_output);// 为权重梯度 gradw 和偏置梯度 gradb 分配内存  	double** gradw=new double*[num_input];double* gradb=new double[num_output];// 初始化权重梯度 gradw 为 0  for (int i=0; i<num_input; i++){gradw[i] = new double[num_output];for (int j=0; j<num_output; j++)gradw[i][j]=0.0;}// 初始化偏置梯度 gradb 为 0for (int j=0; j<num_output; j++)gradb[j]=0.0;// 遍历批量中的每个样本,计算梯度  for (int i=0; i<batch_size; i++){int yi = y[i];// 计算 softmax 函数的结果  double* so=softmax(o[i], num_output);// 计算 cross entropy 对 小批量的未规范化预测 O 的导数 // softmax(o)[j]-y[j], 将 y 视为独热标签向量 double grad[num_output];for(int j=0; j<num_output; j++){grad[j] = so[j];}grad[yi]-=1;// 计算 gradb , cross entropy 对 b 的导数,链式求导 // o = X * W + b for (int j=0; j<num_output; j++){gradb[j]+=grad[j];}// 计算 gradw ,cross entropy 对 W 的导数,链式求导 // o = X * W + b for (int j=0; j<num_input; j++){for (int k=0; k<num_output; k++){double x=X[i][j];gradw[j][k] += grad[k]*x;}}delete[] so;}// 使用计算得到的梯度来更新权重 W 和偏置 b  for(int i=0; i<num_input; i++){for (int j=0; j<num_output; j++){W[i][j] = W[i][j] - lr * gradw[i][j] / batch_size;}}for (int i=0; i<num_output; i++){b[i] = b[i] - lr * gradb[i]/ batch_size;}for (int i=0; i<batch_size; i++) delete[] o[i];delete[] o;for (int i=0; i<num_input; i++) delete[] gradw[i];delete[] gradw;delete[] gradb; 
}
/**  * @brief 计算平均值  *  * 计算给定双精度浮点数数组的平均值。  *  * @param loss 包含要计算平均值的双精度浮点数的数组  * @param len 数组的长度(元素的数量)  *  * @return 数组 `loss` 中所有元素的平均值  *  */  
double mean(double* loss, int len)  
{  double ans = 0; // 初始化累加器为 0  assert(len>0); // 断言数组长度必须大于 0  // 遍历数组 `loss` 中的每个元素  for (int i = 0; i < len; i++)  {  // 将当前元素加到累加器 `ans` 上  ans += loss[i];  }  // 返回累加器 `ans` 除以数组长度 `len` 的结果,即平均值  return ans / len;  
}unsigned char** read_mnist_image(string file_name, int& num_image, int& num_row, int& num_col, const int check_number);
char* read_mnist_label(string file_name, const int num_image, const int check_number);
unsigned char** get_image(string path, int& num_image, int& num_row, int& num_col, bool is_train);
char* get_label(string path, int num_image, bool is_train);
/**  * @brief 归一化图像数据  *  * 将输入的二维无符号字符数组(通常是灰度图像)归一化到 0 到 1 的范围内,  * 并返回一个二维双精度浮点数数组,其中包含了归一化后的图像数据。  *  * @param cX 输入的二维无符号字符数组,代表原始图像数据  * @param row 图像的行数  * @param col 图像的列数  *  * @return 指向归一化后二维双精度浮点数数组的指针  *  * @note 调用者需要确保输入的 cX 数组是有效且已经分配了足够的内存。  *       返回的 X 数组需要调用者在使用完毕后手动释放内存。  */  
double** normalization(unsigned char** cX, int row, int col)  
{  // 创建一个新的二维双精度浮点数数组 X 来存储归一化后的图像数据  double** X = new double*[row];  for(int i=0; i<row; i++)  {  X[i] = new double[col];  }  // 遍历原始图像数据的每个像素,并进行归一化  for (int i=0; i<row; i++)  {  for (int j=0; j<col; j++)  {  // 读取原始图像数据中的像素值  int x = cX[i][j];  // 归一化到 0 到 1 的范围  X[i][j] = x * 1.0 / 255.0;}  }  // 返回归一化后的图像数据  return X;  
}/**  * @brief 打乱图像数据和标签的顺序  *  * 使用 Fisher-Yates 洗牌算法(也被称为 Knuth 洗牌)结合一个随机数生成器来  * 打乱传入的图像数据和对应的标签。 *  * @param X 指向图像数据的指针数组,每个元素指向一个图像(一维数组)  * @param y 指向标签数据的指针,每个元素表示一个标签  * @param num_image 图像和标签的数量  *  * @note  此函数会直接修改传入的 X 和 y,而不需要额外的存储空间。  */  
void shuffle(unsigned char** X, char* y, int num_image)  
{  // 创建一个整数向量 num,用于存储原始索引  vector<int> num(num_image);  for(int i = 0; i < num_image; i++) num[i] = i;  // 使用当前时间作为随机数生成器的种子  // 这样可以确保每次调用 shuffle 函数时都能得到不同的随机序列  random_device rd;    mt19937 g(rd()); // 使用 Mersenne Twister 算法来生成随机数  // 打乱整数向量 num 中的元素顺序  shuffle(num.begin(), num.end(), g);  // 使用 Fisher-Yates 洗牌算法来打乱图像数据和标签的顺序  unsigned char* tmpcp; // 临时指针,用于交换图像数据  char tmpc;            // 临时字符,用于交换标签  for (int i = 0; i < num_image; i++)  {  // 交换图像数据  tmpcp = X[i];  X[i] = X[num[i]];  X[num[i]] = tmpcp;  // 交换标签数据  tmpc = y[i];  y[i] = y[num[i]];  y[num[i]] = tmpc;  }  
}int main()
{// 定义数据集的路径  string path="../data/MNIST/raw/";// 定义变量来存储图像和标签的数量以及尺寸  // 训练图像的数量、像素行数和列数(高和宽) int num_image, num_row, num_col;// 测试图像的数量、像素行数和列数  int num_test_image, num_test_row, num_test_col;// 从指定路径读取训练集与测试集图像,并返回图像数据和图像数量以及像素宽高 unsigned char**      cX = get_image(path, num_image,      num_row,      num_col,      true);unsigned char** test_cX = get_image(path, num_test_image, num_test_row, num_test_col, false);// 从指定路径加载标签  char*      y = get_label(path, num_image,      true);char* test_y = get_label(path, num_test_image, false);	// 对训练数据和标签进行随机打乱  shuffle(cX, y, num_image);// 对图像数据进行归一化处理,并返回处理后的数据  double** X=normalization(cX, num_image, num_row*num_col);double** test_X=normalization(test_cX, num_test_image, num_test_row*num_test_col);	// 定义超参数  const double lr = 0.01;// 学习率 const int num_epochs = 10;// 训练轮数  const int num_output = 10;// 输出层神经元数量(对应MNIST的10个类别) const int batch_size = 256;// 批量大小const int num_sample = num_image;// 总样本数(这里等于训练样本数) const int num_input = num_row * num_col;	// 输入层神经元数量(等于图像的像素数) 	// 初始化权重矩阵W和偏置向量b  double** W=new double* [num_input];for (int i=0; i<num_input; i++) W[i]=new double[num_output];double* b=new double[num_output];// 将W和b的所有元素初始化为0.0  for(int i=0; i<num_input; i++){for (int j=0; j<num_output; j++){W[i][j]=0.0;}}for (int j=0; j<num_output; j++){b[j]=0.0;}// 开始进行训练循环,迭代num_epochs次  for (int epoch=0; epoch<num_epochs; epoch++){// 对所有样本进行迭代,每次处理batch_size个样本  for (int j=0; j<num_sample; j+=batch_size){// 确保每一批量获得正确的样本个数 int batch = min(batch_size, num_sample-j);// 对当前batch的数据进行softmax回归计算,得到预测结果y_hat  vector<double*> y_hat = sofreg(X+j, W, b, batch, num_input, num_output);// 计算当前batch的交叉熵损失 double* loss = cross_entropy(y_hat, y+j, batch, num_output);// 使用随机梯度下降(SGD)更新权重W和偏置b  sgd(X+j, y+j, W, b, lr, batch, num_input, num_output);delete[] loss;for (auto i:y_hat) delete[] i;y_hat.clear();}// 在每个epoch结束后,测试模型在测试集上的性能  {// 初始化索引和当前batch的大小(对于测试集,这里通常使用整个测试集)  int j=0;// 但因为测试集通常全部使用,所以batch_size可能不会被限制 int batch = min(batch_size, num_test_image-j);// 对测试集进行softmax回归计算,得到预测结果y_hat  vector<double*> y_hat = sofreg(test_X+j, W, b, batch, num_input, num_output);// 初始化预测正确的样本数  int right_num=0;// 遍历当前batch的所有样本 for (int i=0; i<batch; i++){// 获取当前样本的预测结果 double* yy = y_hat[i];double mm=0, id=-1;// 找到预测概率最大的类别  for (int j=0; j<num_output; j++){if (yy[j]>mm) mm=yy[j], id=j;}// 检查预测类别是否与实际类别相同,如果相同则增加正确数if (id == (test_y+j)[i]) right_num++;}// 计算并打印当前epoch的测试集准确率double* loss = cross_entropy(y_hat, test_y+j, batch, num_output);printf("in epoch %d, accuracy is %.4Lf\n", epoch+1, right_num*1.0/batch*1.0);delete[] loss;for (auto i:y_hat) delete[] i;y_hat.clear();}	}// 累了,交给操作系统自己释放吧 //delete cX, test_cX, y, test_y, X, test_X, w, b;
}/*******************************************
// 读取MNIST数据集图像的函数  
// 参数:  
// file_name: 图像文件的名字,需要绝对或相对路径    
// num_image: 读取的图像数量(引用传递,用于修改外部变量)  
// num_row: 每张图像的行数(引用传递,用于修改外部变量)  
// num_col: 每张图像的列数(引用传递,用于修改外部变量)  
// check_number: 用于检查文件头部magic number的期望值  
// 返回值:  
// 返回一个二维指针,指向由unsigned char数组组成的图像数组 
// 第一个维度是图片数量,第二个维度是单张图片大小 
// 注意:调用此函数的代码应确保在适当的时候释放images指向的内存,避免内存泄漏 
********************************************/
unsigned char** read_mnist_image(string file_name, int& num_image, int& num_row, int& num_col, const int check_number)
{// 以二进制读模式打开文件  FILE *fp = fopen(file_name.c_str(), "rb");// 如果文件打开失败,退出程序if (!fp){printf("file open fail!\n");exit(0);}// 读取magic number、图像数量、图像的行数和列数 int magic_number;fread((char*)&magic_number, sizeof(magic_number), 1, fp);fread((char*)&num_image, sizeof(num_image), 1, fp);fread((char*)&num_row, sizeof(num_row), 1, fp);fread((char*)&num_col, sizeof(num_col), 1, fp);//由于MNIST文件是以大端字节序存储的,所以需要转换为小端序 magic_number=reverseInt(magic_number);num_image=reverseInt(num_image);num_row=reverseInt(num_row);num_col=reverseInt(num_col);// 检查magic number是否匹配  if (check_number != magic_number){printf("magic number is error, this is not the right image file\n");fclose(fp); // 关闭文件句柄  exit(0); // 退出程序  }// 分配二维数组以存储图像  unsigned char** images=new unsigned char*[num_image];// 读取所有图像for(int i=0; i<num_image; i++){// 为每个图像分配内存  unsigned char* image=new unsigned char[num_row * num_col];// 读取图像数据,fread(image, sizeof(unsigned char), num_row * num_col, fp);// 将图像数据存入二维数组 images[i]=image;}// 关闭文件句柄  fclose(fp);// 返回二维图像指针 return images;// 示例,使用delete[]来释放每个图像的内存,并最后释放images本身// for (int i = 0; i < num_image; ++i) {  //     delete[] images[i];  // }  // delete[] images;  
}/*************************************
// 读取MNIST数据集标签的函数  
// 参数:  
// file_name: 标签文件的名字,需要绝对或相对路径  
// num_image: 预期读取的标签数量(应与文件内标签数量一致)  
// check_number: 用于检查文件头部magic number的期望值,检查文件是否正确  
// 返回值:  
// 返回一个包含标签的char数组指针,此处char应理解为单字节类型整数  
// 注意:调用此函数的代码应确保在适当的时候释放labels指向的内存,避免内存泄漏 
***************************************/
char* read_mnist_label(string file_name, const int num_image, const int check_number)
{// 以二进制读模式打开文件  FILE *fp = fopen(file_name.c_str(), "rb");// 如果文件打开失败,退出程序if (!fp){printf("file open fail!\n");exit(-1);}// 定义并读取magic number和标签数量  int magic_number, num_label;fread((char*)&magic_number, sizeof(magic_number), 1, fp);fread((char*)&num_label, sizeof(num_label), 1, fp);//由于MNIST文件是以大端字节序存储的,所以需要转换为小端序 magic_number=reverseInt(magic_number);num_label=reverseInt(num_label);// 检查magic number是否匹配  if (check_number != magic_number){printf("magic number is error, this is not the right label file!\n");fclose(fp);exit(-1);}// 检查标签数量是否与预期一致  if (num_label!=num_image){printf("num_label not equal num_image!\n");fclose(fp);exit(-1);}// 动态分配内存以存储标签  char* labels=new char[num_label];// 读取所有标签 for(int i=0; i<num_label; i++){fread(&labels[i], sizeof(char), 1, fp);}// 关闭文件句柄  fclose(fp);// 返回标签数组指针return labels;// 示例,使用delete[]来释放标签的内存//delete[] labels; 
}/**  * @brief 获取 MNIST 数据集的图像数据  *  * 根据指定的文件路径和是否训练数据集的标志,从 MNIST 数据集中加载图像数据,  * 并返回指向图像数据的指针(二维数组)。同时,更新图像数量、行数和列数的引用参数。  *  * @param path 数据集所在的路径  * @param num_image 引用参数,用于返回图像数量  * @param num_row 引用参数,用于返回每个图像的行数  * @param num_col 引用参数,用于返回每个图像的列数  * @param is_train 是否为训练数据集的标志,true 为训练数据,false 为测试数据  *  * @return 指向图像数据的指针(二维数组),每个元素为 unsigned char 类型  */  
unsigned char** get_image(string path, int& num_image, int& num_row, int& num_col, bool is_train)
{// 定义 MNIST 数据集的文件名  string name_train_image="train-images-idx3-ubyte";string name_train_label="train-labels-idx1-ubyte";string name_test_image="t10k-images-idx3-ubyte";string name_test_label="t10k-labels-idx1-ubyte";// 根据是否训练数据集的标志,选择加载训练或测试数据集的图像文件  if (is_train) {// 加载训练数据集的图像文件  return read_mnist_image(path+name_train_image, num_image, num_row, num_col, 2051);} else {// 加载测试数据集的图像文件  return read_mnist_image(path+name_test_image, num_image, num_row, num_col, 2051);}
}/**  * @brief 获取 MNIST 数据集的标签数据  *  * 根据给定的路径、图像数量和是否训练数据集的标志,从 MNIST 数据集中加载标签数据,  * 并返回指向标签数据的指针(一维字符数组)。  *  * @param path 数据集所在的路径  * @param num_image 预期的标签数量,用于检查文件标签数量是否与预期一致* @param is_train 是否为训练数据集的标志,true 为训练数据,false 为测试数据  *  * @return 指向标签数据的指针(一维字符数组),每个元素表示一个标签  */  
char* get_label(string path, const int num_image, bool is_train)
{// 定义 MNIST 数据集的文件名 string name_train_image="train-images-idx3-ubyte";string name_train_label="train-labels-idx1-ubyte";string name_test_image="t10k-images-idx3-ubyte";string name_test_label="t10k-labels-idx1-ubyte";// 根据是否训练数据集的标志,选择加载训练或测试数据集的标签文件if (is_train) {// 加载训练数据集的标签文件return read_mnist_label(path+name_train_label, num_image, 2049);} else {// 加载测试数据集的标签文件  return read_mnist_label(path+name_test_label, num_image, 2049);}
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/16626.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HOOK定义

什么是hook HOOK&#xff0c;中文译为“挂钩”或“钩子”。在iOS逆向中是指改变程序运行流程的一种技术。 例如&#xff0c;一个正常的程序运行流程是A->B->C,通过hook技术可以让程序的执行变成A->我们自己的代码->B->C。在这个过程中&#xff0c;我们的代码可…

04Django项目基本运行逻辑及模板资源套用

对应视频链接点击直达 Django项目用户管理及模板资源 对应视频链接点击直达1.基本运行逻辑Django的基本运行路线&#xff1a;视图views.py中的 纯操作、数据返回、页面渲染 2.模版套用1.寻找一个好的模版2.模板部署--修改适配联动 OVER&#xff0c;不会有人不会吧不会的加Q1394…

Java 类加载过程和双亲委派模型

Java 类加载过程概述 在 Java 中&#xff0c;类装载器把一个类装入 Java 虚拟机中&#xff0c;要经过三个步骤来完成&#xff1a;装载、链接和初始化&#xff0c;其中链接又可以分成校验、准备、解析 Java类加载过程分为如下步骤&#xff1a; 1.装载&#xff08; 加载&#xf…

Python编程-后端开发之Django5应用请求处理与模板基础

Python编程-后端开发之Django5应用请求处理与模板基础 最近写项目&#xff0c;刚好用到了Django&#xff0c;现在差不多闲下来&#xff0c;个人觉得单体项目来讲django确实舒服&#xff0c;故写此总结 模板语法了解即可&#xff0c;用到了再看&#xff0c;毕竟分离已经是主流操…

LeetCode300:最长递增子序列

题目描述 给你一个整数数组 nums &#xff0c;找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。例如&#xff0c;[3,6,2,7] 是数组 [0,3,1,6,2,2,7] 的 子序列 代码…

react 函数组件 开发模式默认被渲染两次

这是 React 刻意为之&#xff0c;函数式组件应当遵从函数式编程风格&#xff0c;每次执行应该是无副作用的(no sideEffect)&#xff0c;在 dev 下多次渲染组件&#xff0c;是为了防止开发者写出有问题的代码。 用 React 写函数组件&#xff0c;如何避免重复渲染&#xff1f; -…

Java学习【面向对象综合练习——实现图书管理系统】

Java学习【面向对象综合练习——实现图书管理系统】 前期效果图书的创建用户的创建操作的实现完善操作显示图书查找图书新增图书借阅图书归还图书删除图书 前期效果 用户分为普通用户和管理员&#xff0c;登录进系统之后可以对图书进行一系列操作&#xff0c;此时我们要明白&am…

斐讯N1刷OpenWRT并安装内网穿透服务实现远程管理旁路由

文章目录 前言1. 制作刷机固件U盘1.1 制作刷机U盘需要准备以下软件&#xff1a;1.2 制作步骤 2. N1盒子降级与U盘启动2.1 N1盒子降级2.2 N1盒子U盘启动设置2.3 使用U盘刷入OpenWRT2.4 OpenWRT后台IP地址修改2.5 设置旁路由&无线上网 3. 安装cpolar内网穿透3.1 下载公钥3.2 …

时空数据治理白皮书(2024)

来源&#xff1a;泰伯智库&#xff1a; 近期历史回顾&#xff1a;

企业微信修改主体花了大几千的踩坑经验,家人们避雷

企业微信变更主体有什么作用&#xff1f;如果原有的公司注销了&#xff0c;或者要更换一家公司主体来运营企业微信&#xff0c;那么就可以进行变更主体&#xff0c;变更主体后才可以保留原来企业微信上的所有用户&#xff0c;否则就只能重新申请重新积累用户了。企业微信变更主…

什么情况下数据库事务会失效?

事务失效&#xff1a; 在Java中&#xff0c;数据库事务失效通常指的是事务无法保证其ACID特性&#xff0c;即原子性、一致性、隔离性和持久性。事务失效可能导致数据不一致&#xff0c;影响系统的可靠性和稳定性。以下是一些常见的事务失效情况&#xff1a; 1、未启用Spring事…

运维专题.Docker功能权限(Capabilities)管理和查看

运维专题 Docker功能权限&#xff08;Capabilities&#xff09; - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:htt…

MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model 论文总结

题目&#xff1a;MedSegDiff: Medical Image Segmentation&#xff08;图像分割&#xff09;with Diffusion Probabilistic Model&#xff08;扩散概率模型&#xff09; 论文&#xff08;MIDL会议&#xff09;&#xff1a;MedSegDiff: Medical Image Segmentation with Diffusi…

python-13(案例讲解)

目录 抓取链家前十页的数据 计算均价和总价 计算的类型&#xff08;整租&#xff0c;合租&#xff09; 计算的房型 抓取boss直聘前十页的数据 抓取boss直聘前十页的数据 将获取数据本地序列化 计算每个区的需求个数与均价 抓取链家前十页的数据 链家网址&#xff1a;长…

海思SD3403,SS928/926,hi3519dv500,hi3516dv500移植yolov7,yolov8(18)-Yolov8改进

yolov8进行二次改进后进行了量化和速度测试 &#xff0c;没有明显速度增加。对比一下模型的性能。 分别用原始模型和改后的模型进行了100 epochs训练。 以下是原始模型的结果。 class P R map0.5 map.95 1 0.79 0.49 0.571 0.316 2 0.851 0.738 0.801 0.538 …

勇于创新,勤于探索 —— 我的创作纪念日

作者主页&#xff1a;爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?typeblog个…

【Java】全局统一异常处理类封装

文章目录 全局异常处理类自定义异常如何使用&#xff08;手动抛出异常&#xff09; 我是一名立志把细节说清楚的博主&#xff0c;欢迎【关注】&#x1f389; ~ 原创不易&#xff0c; 如果有帮助 &#xff0c;记得【点赞】【收藏】 哦~ ❥(^_-)~ 如有错误、疑惑&#xff0c;欢…

纯CSS丝滑边框线条动画

在这个网站&#xff08;minimal-portfolio-swart.vercel.app&#xff09;发现一个不错的交互效果&#xff0c;用户体验效果很不错。如封面图所示&#xff0c;这个卡片上有一根白色的线条围绕着卡片移动&#xff0c;且在线条的卡片内部跟随这一块模糊阴影&#xff0c;特别是在线…

关于Nginx热部署的细节分析

文章目录 前言一、环境准备二、热部署步骤总结 前言 Nginx由于其高并发、高性能、可扩展性好、高可靠性、热部署、BSD许可证等优势被广泛使用&#xff0c;本人主要针对热部署的部分展开说明热部署的具体步骤以及步骤背后发生的具体事情。 本次热部署采用的Nginx版本号为&…

在CentOS 7上配置Elasticsearch开启自启动需要通过systemd服务管理器来实现

在CentOS 7上配置Elasticsearch开启自启动需要通过systemd服务管理器来实现。 1. 安装Elasticsearch 首先,确保你已经安装了Elasticsearch。如果还没有安装,可以按照以下步骤进行安装: # Import the Elasticsearch PGP key sudo rpm --import https://artifacts.elastic.…