【Pytorch】Fizz Buzz

在这里插入图片描述

文章目录

  • 1 数据编码
  • 2 网络搭建
  • 3 网络配置,训练
  • 4 结果预测
  • 5 翻车现场

学习参考来自:

  • Fizz Buzz in Tensorflow
  • https://github.com/wmn7/ML_Practice/tree/master/2019_06_10
  • Fizz Buzz in Pytorch

I need you to print the numbers from 1 to 100, except that if the number is divisible by 3 print “fizz”, if it’s divisible by 5 print “buzz”, and if it’s divisible by 15 print “fizzbuzz”.

编程题很简单,我们用 MLP 实现试试

思路,训练集数据101~1024,对其进行某种规则的编码,标签为经分类 one-hot 编码后的标签
测试集,1~100

don’t say so much, show me the code.

1 数据编码

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Datadef binary_encode(i, num_digits):"""将每个input转换为binary digits(转换为二进制的表示, 最多可是表示2^num_digits):param i::param num_digits::return:"""return np.array([i >> d & 1 for d in range(num_digits)])

编码形式,依次除以 2 0 , 1 , 2 , 3 , . . . 2^{0,1,2,3,...} 20,1,2,3,...,结果按位与 1

m & 1,结果为 0 表示 m 为偶数, 结果为 1 表示 m 为奇数

> > m >> m >>m 右移表示除以 2 m 2^m 2m

第一位就能表示奇偶了,所有数字编码都不一样

eg,101 进行 num_digits=10 编码后结果为 1 0 1 0 0 1 1 0 0 0

步骤

101 / 1 = 101 奇数 1
101 / 2 = 50 偶数 0
101 / 4 = 25 奇数 1
101 / 8 = 12 偶数 0
101 / 16 = 6 偶数 0
101 / 32 = 3 奇数 1
101 / 64 = 1 奇数 1
101 / 128 = 0 偶数 0
101 / 256= 0 偶数 0
101 / 512= 0 偶数 0

标签,0,1,2,3 四个类别

def fizz_buzz_encode(i):"""将output转换为lebel:param i::return:"""if i % 15 == 0:  # fizzbuzzreturn 3elif i % 5 == 0:  # buzzreturn 2elif i % 3 == 0:  # fizzreturn 1else:return 0

编码长度设定,数据集 101 ~ 1024

NUM_DIGITS = 10
trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2**NUM_DIGITS)])  # 101~1024
trY = np.array([fizz_buzz_encode(i) for i in range(101, 2**NUM_DIGITS)])# print(len(trX), len(trY))  # 923 923
# print(trX[:5])
"""
[[1 0 1 0 0 1 1 0 0 0][0 1 1 0 0 1 1 0 0 0][1 1 1 0 0 1 1 0 0 0][0 0 0 1 0 1 1 0 0 0][1 0 0 1 0 1 1 0 0 0]]
"""
# print(trY[:5])  # [0 1 0 0 3]

2 网络搭建

搭建简单的 MLP 网络

class FizzBuzzModel(nn.Module):def __init__(self, in_features, out_classes, hidden_size, n_hidden_layers):super(FizzBuzzModel,self).__init__()layers = []for i in range(n_hidden_layers):layers.append(nn.Linear(hidden_size,hidden_size))# layers.append(nn.Dropout(0.5))layers.append(nn.BatchNorm1d(hidden_size))layers.append(nn.ReLU())self.inputLayer = nn.Linear(in_features, hidden_size)self.relu = nn.ReLU()self.layers = nn.Sequential(*layers)  # 重复的搭建隐藏层self.outputLayer = nn.Linear(hidden_size, out_classes)def forward(self, x):x = self.inputLayer(x)x = self.relu(x)x = self.layers(x)out = self.outputLayer(x)return out

初始化网络,看看网络结构

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# define the model
simpleModel = FizzBuzzModel(NUM_DIGITS, 4, 150, 3).to(device)
print(simpleModel)
"""
FizzBuzzModel((inputLayer): Linear(in_features=10, out_features=150, bias=True)(relu): ReLU()(layers): Sequential((0): Linear(in_features=150, out_features=150, bias=True)(1): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU()(3): Linear(in_features=150, out_features=150, bias=True)(4): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU()(6): Linear(in_features=150, out_features=150, bias=True)(7): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(8): ReLU())(outputLayer): Linear(in_features=150, out_features=4, bias=True)
)
"""

输入 10, 输出4,隐藏层维度 150,隐藏层重复了 3 次

3 网络配置,训练

定义下超参数,损失函数,优化器,载入数据训练,输出训练精度与损失

# Loss and optimizer
learning_rate = 0.05
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(simpleModel.parameters(), lr=learning_rate)
optimizer = torch.optim.SGD(simpleModel.parameters(), lr=learning_rate)# 使用batch进行训练
FizzBuzzDataset = Data.TensorDataset(torch.from_numpy(trX).float().to(device),torch.from_numpy(trY).long().to(device))loader = Data.DataLoader(dataset=FizzBuzzDataset,batch_size=128*5,shuffle=True)# 进行训练
simpleModel.train()
epochs = 3000for epoch in range(1, epochs):for step, (batch_x, batch_y) in enumerate(loader):out = simpleModel(batch_x)  # 前向传播loss = criterion(out, batch_y)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播optimizer.step()  # 随机梯度下降correct = 0total = 0_, predicted = torch.max(out.data, 1)total += batch_y.size(0)correct += (predicted == batch_y).sum().item()acc = 100*correct/totalprint('Epoch : {:0>4d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, loss, acc))"""
Epoch : 0001 | Loss : 1.5343 | Train Accuracy : 14.63 %
Epoch : 0002 | Loss : 1.9779 | Train Accuracy : 42.58 %
Epoch : 0003 | Loss : 2.4198 | Train Accuracy : 53.41 %
Epoch : 0004 | Loss : 1.7360 | Train Accuracy : 53.41 %
Epoch : 0005 | Loss : 1.3161 | Train Accuracy : 49.73 %
Epoch : 0006 | Loss : 1.4866 | Train Accuracy : 22.75 %
Epoch : 0007 | Loss : 1.3993 | Train Accuracy : 25.57 %
Epoch : 0008 | Loss : 1.2428 | Train Accuracy : 28.49 %
Epoch : 0009 | Loss : 1.1906 | Train Accuracy : 44.31 %
Epoch : 0010 | Loss : 1.1929 | Train Accuracy : 52.44 %
...
Epoch : 2990 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2991 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2992 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2993 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2994 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2995 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2996 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2997 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2998 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2999 | Loss : 0.0000 | Train Accuracy : 100.00%
"""

训练集上精度是 OK 的,能到 100%,下面看看测试集上的精度

4 结果预测

把 one-hot 标签转化成 fizz buzz 的形式

def fizz_buzz_decode(i, prediction):return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

载入测试集,开始预测

simpleModel.eval()
# 进行预测
testX = np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
predicts = simpleModel(torch.from_numpy(testX).float().to(device))
# 预测的结果
_, res = torch.max(predicts, 1)
print(res)
"""
tensor([0, 0, 0, 1, 0, 0, 0, 2, 1, 0, 1, 3, 3, 1, 1, 0, 0, 0, 0, 0, 0, 3, 1, 0,0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 1, 1, 1, 0,0, 0, 0, 0], device='cuda:0')
"""# 格式的转换
predictions = [fizz_buzz_decode(i, prediction) for (i, prediction) in zip(range(1, 101), res)]
print(predictions)
"""
['1', '2', '3', 'fizz', '5', '6', '7', 'buzz', 'fizz', '10', 'fizz', 'fizzbuzz', 'fizzbuzz', 'fizz', 'fizz', '16', '17', '18', '19', '20', '21', 'fizzbuzz', 'fizz', '24', '25', '26', '27', '28', '29', '30', 'fizz', '32', '33', '34', '35', '36', '37', '38', '39', 'fizz', '41', 'fizz', '43', '44', '45', 'fizz', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', 'fizz', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', 'fizz', 'fizz', 'fizz', 'fizz', 'buzz', 'buzz', 'fizz', '80', '81', '82', '83', '84', '85', '86', '87', 'fizzbuzz', '89', '90', 'fizz', '92', 'fizz', 'fizz', 'fizz', '96', '97', '98', '99', '100']
"""

5 翻车现场

对比下标签

labels = []
for i in range(1, 101):if i % 15 == 0:  # fizzbuzzlabels.append("fizzbuzz")elif i % 5 == 0:  # buzzlabels.append("buzz")elif i % 3 == 0:  # fizzlabels.append("fizz")else:labels.append(str(i))
print(labels)
print(labels == predictions)"""
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
False
"""

哈哈哈, False 翻车了,尝试了很多次,很难 True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/224354.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

牛客网BC92逆序输出

答案&#xff1a; #include <stdio.h>int main() {int i0, j0;int arr[10]{0};for(i0;i<10;i) //将10个整数存进数组里{scanf("%d",&arr[i]);}for(j9;j>0;j--) //逆序打印{printf("%d ",arr[j]); //若要求最后一个数后面不打印空格…

【Hive】——CLI客户端(bin/beeline,bin/hive)

1 HiveServer、HiveServer2 2 bin/hive 、bin/beeline 区别 3 bin/hive 客户端 hive-site.xml 配置远程 MateStore 地址 XML <?xml version"1.0" encoding"UTF-8" standalone"no"?> <?xml-stylesheet type"text/xsl" hre…

C# WPF上位机开发(利用tcp/ip网络访问plc)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 c# wpf如果是用来开发非标上位机的&#xff0c;那么和plc的通信肯定是少不了的。而且&#xff0c;大部分plc都支持modbus协议&#xff0c;所以这个…

neo4j安装报错:neo4j.bat : 无法将“neo4j.bat”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。

neo4j安装报错&#xff1a; neo4j.bat : 无法将“neo4j.bat”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。请检查名称的拼写&#xff0c;如果包括路径&#xff0c;请确 保路径正确&#xff0c;然后再试一次。 解决办法&#xff1a; 在环境变量中的&#xff0c;用户…

Shopee ERP:提升电商管理效率的终极解决方案

Shopee ERP&#xff08;Enterprise Resource Planning&#xff0c;企业资源规划&#xff09;是一款专为Shopee卖家设计的集成化电商管理软件。通过使用Shopee ERP系统&#xff0c;卖家可以更高效地管理他们的在线商店&#xff0c;实现库存管理、订单处理、物流跟踪、财务管理、…

优先考虑类型安全的异构容器

在Java中&#xff0c;异构容器是一种可以存储不同类型元素的容器。为了提高类型安全性&#xff0c;可以使用泛型和类型安全的异构容器&#xff0c;而不是传统的非类型安全容器。下面是一个例子&#xff0c;演示如何使用类型安全的异构容器 import java.util.HashMap; import j…

alpine linux 之嵌入式搭建

目录 序启动修改源安装 openssh设置开机网络 ip参考 序 最近发现了 alpine linux 这个文件系统&#xff0c;这是一个基于 musl libc 和 busybox 的面向安全的轻量级 Linux 发行版。 下载了他的文件系统&#xff0c;只有 3M 多的压缩包&#xff0c;非常适合嵌入式系统。 地址…

AIGC专题报告:ChatGPT的工作原理

今天分享的AIGC系列深度研究报告&#xff1a;《AIGC专题报告&#xff1a;ChatGPT的工作原理》。 &#xff08;报告出品方&#xff1a;省时查&#xff09; 报告共计&#xff1a;107页 前言 ChatGPT 能够自动生成一些读起来表面上甚至像人写的文字的东西&#xff0c;这非常了不…

计算机毕业设计 基于SpringBoot的日常办公用品直售推荐系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

AWTK 串口屏开发(2) - 数据绑定高级用法

AWTK 串口屏 智能家居示例 1. 功能 这个例子稍微复杂一点&#xff0c;界面这里直接使用了 立功科技 ZDP1440 HMI 显示驱动芯片 例子中的 UI 文件和资源&#xff0c;重点关注数据绑定。在这里例子中&#xff0c;模型&#xff08;也就是数据&#xff09;里包括一台空调和一台咖…

申请香港高才通计划有什么好处和优势?

申请香港高才通计划有什么好处和优势&#xff1f; 据香港特区政府最新消息&#xff0c;截至今年11月底&#xff0c;各项输入人才计划共收到超过20万宗申请&#xff0c;是2022年申请数目的接近四倍。 在香港特区政府积极引进人才的推动下&#xff0c;今年首11个月&#xff0c;超…

FreeModbus--学习函数指针

目录 函数指针 最简单的例子 稍作修改例子 引入协议栈的函数指针 引入协议栈第二处函数指针 函数指针 该协议栈中使用到函数指针&#xff0c;现开展一篇专门存放函数指针的文章。 C语言的函数指针是指向函数的指针变量&#xff0c;可以用来存储和调用函数的地址。在C语言中…

Vue 组件传参 emit

emit 属性&#xff1a;用于创建自定义事件&#xff0c;接收子组件传递过来的数据。 注意&#xff1a;如果自定义事件的名称&#xff0c;和原生事件的名称一样&#xff0c;那么只会触发自定义事件。 setup 语法糖写法请见&#xff1a;《Vue3 子传父 组件传参 defineEmits》 语…

Qt容器QMdiArea 小部件提供一个显示 MDI 窗口的区域

## QMdiArea ## 控件简介 QMdiArea 继承 QAbstractScrollArea。QMdiArea 小部件提供一个显示 MDI 窗口的区域。QMdiArea的功能本质上类似于MDI窗口的窗口管理器。大多数复杂的程序,都使用MDI框架,在 Qt designer 中可以直接将控件 MDI Area 拖入使用。 ## 用法示例 例 qm…

luceda ipkiss教程 49:以pcell的方式定义线路

在ipkiss中&#xff0c;通常以i3.Circuit来设计线路&#xff08;见教程2&#xff09;&#xff0c;以i3.Pcell的框架也可以来设计线路&#xff1a; 以SplitterTree为例&#xff1a; 线路仿真结果&#xff1a; 所有代码如下&#xff1a; from si_fab import all as pdk import…

ShellCode注入程序

程序功能是利用NtQueueApcThreadEx注入ShellCode到一个进程中&#xff0c;程序运行后会让你选择模式&#xff0c;按1为普通模式&#xff0c;所需的常规API接口都是使用Windows原本正常的API&#xff1b;在有游戏保护的进程中Windows原本正常的API无法使用&#xff0c;这时候需要…

【Stable Diffusion】在windows环境下部署并使用Stable Diffusion Web UI---通过 Conda

本专栏主要记录人工智能的应用方面的内容&#xff0c;包括chatGPT、AI绘图等等&#xff1b; 在当今AI的热潮下&#xff0c;不学习AI&#xff0c;就要被AI淘汰&#xff1b;所以欢迎小伙伴加入本专栏和我一起探索AI的应用&#xff0c;通过AI来帮助自己提升生产力&#xff1b; 订阅…

计算机网络:物理层(三种数据交换方式)

今天又学到一个知识&#xff0c;加油&#xff01; 目录 前言 一、电路交换 二、报文交换 三、分组交换 1、数据报方式 2、虚电路方式 3、比较 总结 前言 为什么要进行数据交换&#xff1f; 一、电路交换 电路交换原理&#xff1a;在数据传输期间&#xff0c;源结点与…

无机物及分析化学3d虚拟实验室软件提高教学效果

VR化学虚拟仿真实验室软件可以解决以下难题&#xff1a; 实验场地限制&#xff1a;传统的化学实验室需要占用大量的物理空间&#xff0c;并且需要严格的安全措施。而VR技术可以提供一个虚拟的实验室环境&#xff0c;不受空间限制&#xff0c;可以同时容纳更多的学生参与实验。 …

分类信息网商业运营版源码系统:适合各类行业分类站点建站 带安装部署教程

随着互联网的快速发展&#xff0c;信息分类网站在各个行业中得到了广泛应用。为了满足不同行业的需求&#xff0c;罗峰给大家分享一款适合各类行业分类站点建站的商业运营版源码系统。该系统旨在提供一套完整的解决方案&#xff0c;帮助用户快速搭建自己的分类信息网站&#xf…