【Pytorch】搭建一个简单的泰坦尼克号预测模型

介绍

本文使用PyTorch构建一个简单而有效的泰坦尼克号生存预测模型。通过这个项目,你会学到如何使用PyTorch框架创建神经网络、进行数据预处理和训练模型。我们将探讨如何处理泰坦尼克号数据集,设计并训练一个神经网络,以预测乘客是否在灾难中幸存。

主要内容包括:

  1. 数据准备:介绍如何加载和预处理泰坦尼克号数据集,包括处理缺失值、对类别特征进行编码等。
  2. 构建神经网络模型:定义一个简单的神经网络模型,包括输入层、隐藏层和输出层,并选择适当的激活函数。
  3. 模型训练与评估:通过将数据集划分为训练集和测试集,展示如何训练模型并评估其性能。
  4. 结果预测:对测试集数据进行处理和预测,并将最终结果导出。

通过这个简单的项目,展示如何构建一个简单但实用的预测模型。

目录

    • 介绍
    • 1. 数据准备
      • 数据导入
      • 特征转换
      • 缺失值处理
      • 删除多余数据
    • 2. 模型搭建
    • 3. 模型训练
    • 4. 结果预测
      • 测试集数据处理
      • 预测计算
      • 结果导出

1. 数据准备

import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F 
import os 
from scipy import stats
import pandas as pd 

数据导入

titanic_data =pd.read_csv('train.csv')
titanic_data.columns

在这里插入图片描述

特征转换

df=pd.concat([titanic_data,pd.get_dummies(titanic_data['Sex']).astype(int),pd.get_dummies(titanic_data['Embarked'],prefix='Embarked').astype(int),pd.get_dummies(titanic_data['Pclass'],prefix='class').astype(int)],axis=1)
df.head()

在这里插入图片描述

缺失值处理

df['Age']=df['Age'].fillna(df.Age.mean())
df['Fare']=df['Fare'].fillna(df.Fare.mean())

删除多余数据

df_clean=df.drop(['PassengerId','Name','Ticket','Cabin','Sex','Embarked','Pclass'],axis=1)
df_clean.head()

在这里插入图片描述### 数据切分

labels=df_clean['Survived'].to_numpy()df_clean=df_clean.drop(['Survived'],axis=1)
data=df_clean.to_numpy()feature_names=list(df_clean.columns)np.random.seed(10)
train_indices=np.random.choice(len(labels),int(0.7*len(labels)),replace=False)
test_indices=list(set(range(len(labels)))-set(train_indices))train_features=data[train_indices]
train_labels=labels[train_indices]test_features=data[test_indices]
test_labels=labels[test_indices]len(test_labels)

2. 模型搭建

# 定义Mish激活函数
class Mish(nn.Module):def __init__(self):super().__init__()def forward(self, x):# Mish激活函数的前向传播过程x = x * (torch.tanh(F.softplus(x)))# 返回经过Mish激活函数的结果return x# 设置随机种子
torch.manual_seed(0)# 定义ThreelinearModel模型
class ThreelinearModel(nn.Module):def __init__(self):super().__init__()# 定义三个线性层,用于处理输入特征self.linear1 = nn.Linear(12, 12)self.mish1 = Mish()  # 使用自定义激活函数Mishself.linear2 = nn.Linear(12, 8)self.mish2 = Mish()  # 使用Mish作为第二个激活函数self.linear3 = nn.Linear(8, 2)  # 输出层,用于生成分类结果self.softmax = nn.Softmax(dim=1)  # 对输出进行Softmax,将结果转为概率分布self.criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,用于计算模型误差def forward(self, x):# 模型的前向传播过程lin1_out = self.linear1(x)out1 = self.mish1(lin1_out)out2 = self.mish2(self.linear2(out1))# 经过线性层和激活函数后,通过Softmax得到最终的概率分布return self.softmax(self.linear3(out2))def getloss(self, x, y):# 计算模型预测值与实际标签之间的交叉熵损失y_pred = self.forward(x)loss = self.criterion(y_pred, y)# 返回计算得到的损失值return loss

3. 模型训练

if __name__ == '__main__':# 创建 ThreelinearModel 的神经网络模型net = ThreelinearModel()# 设置训练轮数为200次,选择Adam优化器,学习率为0.04num_epochs = 200optimizer = torch.optim.Adam(net.parameters(), lr=0.04)# 将训练数据转换为PyTorch张量格式input_tensor = torch.from_numpy(train_features).type(torch.FloatTensor)label_tensor = torch.from_numpy(train_labels)# 用于存储每轮训练的损失值losses = []# 开始训练循环for epoch in range(num_epochs):# 计算当前模型在训练数据上的损失值loss = net.getloss(input_tensor, label_tensor)# 记录损失值losses.append(loss.item())# 清零梯度,防止梯度累积optimizer.zero_grad()# 反向传播,计算梯度loss.backward()# 更新模型参数optimizer.step()# 每20轮打印一次训练损失if epoch % 20 == 0:print('Epoch {}/{} => Loss: {:.2f}'.format(epoch + 1, num_epochs, loss.item()))# 创建'models'文件夹(如果不存在),保存训练好的模型参数os.makedirs('models', exist_ok=True)torch.save(net.state_dict(), 'models/titanic_model.pt')# 使用训练好的模型进行训练集的预测out_probs = net(input_tensor).detach().numpy()out_classes = np.argmax(out_probs, axis=1)# 输出训练集准确率print("Training Accuracy: ", sum(out_classes == train_labels) / len(train_labels))# 使用训练好的模型进行测试集的预测test_input_tensor = torch.from_numpy(test_features).type(torch.FloatTensor)out_probs = net(test_input_tensor).detach().numpy()out_classes = np.argmax(out_probs, axis=1)# 输出测试集准确率print("Testing Accuracy: ", sum(out_classes == test_labels) / len(test_labels))

在这里插入图片描述

4. 结果预测

测试集数据处理

test=pd.read_csv('/kaggle/input/titanic/test.csv')
test_df=pd.concat([test,pd.get_dummies(test['Sex']).astype(int),pd.get_dummies(test['Embarked'],prefix='Embarked').astype(int),pd.get_dummies(test['Pclass'],prefix='class').astype(int)],axis=1)test_df['Age']=test_df['Age'].fillna(df.Age.mean())
test_df['Fare']=test_df['Fare'].fillna(df.Fare.mean())Id=test_df['PassengerId']
test_df_clean=test_df.drop(['PassengerId','Name','Ticket','Cabin','Sex','Embarked','Pclass'],axis=1)
pred_features=test_df_clean.to_numpy()

预测计算

pred_input_tensor=torch.from_numpy(pred_features).type(torch.FloatTensor)
pred_out_probs=net(pred_input_tensor).detach().numpy()
pred_classes=np.argmax(pred_out_probs,axis=1)

结果导出

submission= pd.DataFrame({'PassengerId': Id,'Survived': pred_classes[:],
})
# Save the submission file
submission.to_csv('submission.csv', index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/641538.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL数据库查询语句之组函数,子查询语句

组函数 以组为操作单位,一组数据得到一个结果。 在没有手动分组的前提下,整张表默认为一组数据 max(列名):获取最大值 min(列名):获取最小值 sum(列名):获取总和 avg(列名):获取平均值 count(列名)&a…

20.云原生之GitLab CICD实战

云原生专栏大纲 文章目录 GitLab RunnerGitLab Runner 介绍Gitlab Runner工作流程 Gitlab集成Gitlab RunnerGitLab Runner 版本选择Gitlab Runner部署docker-compose方式安装kubesphere中可视化方式安装helm方式安装 配置gitlab-runner配置gitlab-ci.ymlgitlab-ci.yml 介绍编写…

基于FPGA的高效乘法器

1、设计思路 二进制的乘法运算与十进制的乘法运算相似,如下图所示,二进制数据6’b110010乘以二进制数据4’b1011,得到乘积结果10’b1000100110。 图1 二进制乘法运算 仔细观察上图发现,乘数最低位为1(上图紫色数据位&a…

机器学习:什么是监督学习和无监督学习

目录 一、监督学习 (一)回归 (二)分类 二、无监督学习 聚类 一、监督学习 介绍:监督学习是指学习输入到输出(x->y)映射的机器学习算法,监督即理解为:已知正确答案…

期末考试发等级发成绩,就用易查分!

期末考试后,学校老师如何发布私密成绩?易查分可以轻松创建等级、成绩查询系统,让家长仅看到自己孩子成绩。 支持查询后留言反馈,电子签名确认签收等高级功能,节省老师沟通时间,大大提升工作效率。 &#x1…

linux安装docker(入门一)

环境:centos 7(linux) 网站 官网: https://docs.docker.com/ Docker Hub 网站: https://hub.docker.com/ 容器官方概述 一句话概括容器:容器就是将软件打包成标准化单元,以用于开发、交付和部署。 容器镜像是轻量的、可执行的独立软件包 &…

【百面机器学习】读书笔记(一)

本文系列主要作用就是读书笔记,自己看的话比较杂,没怎么归类过,所以现在跟着这个分类走一遍。本文主要内容为前两章,特征工程和模型评估。 如果我想起一些相关的内容也会做适当的补充,主打就是一个intuition&#xff…

OpenCV书签 #直方图算法的原理与相似图片搜索实验

1. 介绍 直方图算法(Image Histogram Algorithm) 通过统计图像中各个颜色值的分布情况来提供关于图像颜色特征的信息,它可以用来衡量两张图片在颜色分布上的相似度,进而可以用来进行图像相似度的比较,因此&#xff0c…

电脑录屏软件大比拼,哪个最适合你?

现如今,电脑录屏软件成为了许多用户记录、分享和教学的重要工具。从游戏玩家到专业制作人员,都需要高效的录屏软件。本文将介绍三款优秀的电脑录屏软件,通过详细的步骤和简洁的介绍,帮助用户轻松掌握这些工具的使用方法。 电脑录屏…

附1:k8s服务器初始化

转载说明:如果您喜欢这篇文章并打算转载它,请私信作者取得授权。感谢您喜爱本文,请文明转载,谢谢。 关联文章: 《RKE快速搭建离线k8s集群并用rancher管理界面》 1.创建普通用户sre并赋予sudo权限 # adduser sre # ec…

【动态规划】879. 盈利计划

作者推荐 【动态规划】【广度优先搜索】【状态压缩】847 访问所有节点的最短路径 本文涉及知识点 动态规划汇总 LeetCode879. 盈利计划 集团里有 n 名员工,他们可以完成各种各样的工作创造利润。 第 i 种工作会产生 profit[i] 的利润,它要求 group[…

大模型独立解答30道国际奥数难题,能力接近金牌选手!

谷歌旗下的AI研究机构DeepMind和纽约大学的研究人员联合开发了一个AI模型——AlphaGeometry。 AlphaGeometr是一种神经符号模型,内置了大语言模型和符号推理引擎等功能,主要用于解决各种超难几何数学题,同时可以自动生成易于查看的解题原理。 为了验证AlphaGeomet…

React Router v6 改变页面Title

先说正事再闲聊 1、在路由表加个title字段 2、在index包裹路由 3、在App设置title 闲聊: 看到小黄波浪线了没 就是说默认不支持title字段了 出来的提示, 所以我本来是像下面这样搞的,就是感觉有点难维护,就还是用上面的方法了 …

Linux配置yum源以及基本yum指令

文章目录 一、yum介绍二、什么是软件包三、配置yum源四、一键配置yum源【三步走】五、yum指令搜索软件安装软件卸载软件 六、其他yum指令更新内核更新软件更新指定软件显示所有可更新的软件清单卸载指定包并自动移除依赖包删除软件包,以及软件包数据和配置文件 一、…

快速上手MyBatis Plus:简化CRUD操作,提高开发效率!

MyBatisPlus 1,MyBatisPlus入门案例与简介1.1 入门案例步骤1:创建数据库及表步骤2:创建SpringBoot工程步骤3:勾选配置使用技术步骤4:pom.xml补全依赖步骤5:添加MP的相关配置信息步骤6:根据数据库表创建实体类步骤7:创建Dao接口步骤8:编写引导类步骤9:编写测试类 1.2…

Redis常见类型及常用命令

目录 常见的数据类型 一、String类型 1、简介 2、常用命令 (1)新建key (2)设值取值 ​编辑 (3)批量操作 (4)递增递减 3、原子性操作 4、数据结构 二、list类型 1、list常…

Pytest中conftest.py的用法

Pytest中conftest.py的用法 ​ 在官方文档中,描述conftest.py是一个本地插件的文件,简单的说就是在这个文件中编写的方法,可以在其他地方直接进行调用。 注意事项 只能在根目录编写conftest.py 插件加载顺序在搜集用例之前 基础用法 这里…

centos 启动nacos pg版本

背景:支持国产化需求,不再使用mysql 1.修改插件 git clone https://github.com/wuchubuzai2018/nacos-datasource-extend-plugins.git cd nacos-datasource-extend-plugins/nacos-postgresql-datasource-plugin-ext mvn package编译成功后,…

原来岳云鹏背后的女人竟然是她?有她,岳云鹏红遍大江南北。

♥ 为方便您进行讨论和分享,同时也为能带给您不一样的参与感。请您在阅读本文之前,点击一下“关注”,非常感谢您的支持! 文 |猴哥聊娱乐 编 辑|徐 婷 校 对|侯欢庭 岳云鹏,一个出身于农村的普通孩子,曾经…

springboot小白入门

创建启动 省略。。。 第二章 springboot接口 本章学习: 1.接口定义 2.接收数据 3.返回数据 RestController注解,相当于ResponseBody + ControllerController负责接收用户的请求ResponseBody把数据写入到HTTP响应体的body部分RequestMappin…