用神经网络预测三角形的面积

周末遛狗时,我想起一个老问题:神经网络能预测三角形的面积吗?

神经网络非常擅长分类,例如根据花瓣长度和宽度以及萼片长度和宽度预测鸢尾花的种类(setosa、versicolor 或 virginica)。神经网络还擅长一些回归问题,例如根据城镇房屋的平均面积、城镇的税率、与最近大城市的距离等预测城镇的房价中位数。

但神经网络并不适用于普通的数学计算,例如根据底边和高计算三角形的面积。如果你的小学数学有点生疏,我会提醒你,三角形的面积是底边乘以高的 1/2。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

我在一家大型科技公司工作,PyTorch 是官方首选的神经网络代码库,也是我个人首选的库。周末遛狗时,我决定使用 PyTorch 1.6 版(当前版本)预测三角形面积。

我编写了一个程序,以编程方式生成 10,000 个训练样本,其中底边和高边的值是 0.1 到 0.9 之间的随机值(因此面积在 0.005 到 0.405 之间)。我创建了一个 2-(100-100-100-100)-1 神经网络 — 2 个输入节点、4 个隐藏层(每个隐藏层有 100 个节点)和一个输出节点。我在隐藏节点上使用了 tanh 激活,在输出节点上没有使用激活。

我使用 10 个样本作为批次对网络进行了 1,000 个周期的训练。

训练后,网络正确地预测了 100% 的训练项目在正确区域的 10% 以内,100% 的训练项目在正确区域的 5% 以内,82% 的训练项目在正确区域的 1% 以内。这是否是一个好结果取决于你的观点。

很有趣。深度学习引起了很多关注,并且有大量关于该主题的研究活动。但这不是魔术。

在我思考三角形的那个周末,我看了一部 1967 年的老间谍电影《比男杀手更致命》,里面的女杀手都留着蜂窝发型。左图:女演员 Elke Sommer 扮演主要杀手。我不知道这种发型是怎么回事。中图和右图:互联网图片搜索返回了不少这样的图片,所以我猜蜂窝发型现在有时仍然在使用。

我的代码如下:

# triangle_area_nn.py
# predict area of triangle using PyTorch NNimport numpy as np
import torch as T
device = T.device("cpu")class TriangleDataset(T.utils.data.Dataset):# 0.40000, 0.80000, 0.16000 #   [0]      [1]      [2]def __init__(self, src_file, num_rows=None):all_data = np.loadtxt(src_file, max_rows=num_rows,usecols=range(0,3), delimiter=",", skiprows=0,dtype=np.float32)self.x_data = T.tensor(all_data[:,0:2],dtype=T.float32).to(device)self.y_data = T.tensor(all_data[:,2],dtype=T.float32).to(device)self.y_data = self.y_data.reshape(-1,1)def __len__(self):return len(self.x_data)def __getitem__(self, idx):if T.is_tensor(idx):idx = idx.tolist()base_ht = self.x_data[idx,:]  # idx rows, all 4 colsarea = self.y_data[idx,:]    # idx rows, the 1 colsample = { 'base_ht' : base_ht, 'area' : area }return sample# ---------------------------------------------------------def accuracy(model, ds):# ds is a iterable Dataset of Tensorsn_correct10 = 0; n_wrong10 = 0n_correct05 = 0; n_wrong05 = 0n_correct01 = 0; n_wrong01 = 0# alt: create DataLoader and then enumerate itfor i in range(len(ds)):inpts = ds[i]['base_ht']tri_area = ds[i]['area']    # float32  [0.0] or [1.0]with T.no_grad():oupt = model(inpts)delta = tri_area.item() - oupt.item()if delta < 0.10 * tri_area.item():n_correct10 += 1else:n_wrong10 += 1if delta < 0.05 * tri_area.item():n_correct05 += 1else:n_wrong05 += 1if delta < 0.01 * tri_area.item():n_correct01 += 1else:n_wrong01 += 1acc10 = (n_correct10 * 1.0) / (n_correct10 + n_wrong10)acc05 = (n_correct05 * 1.0) / (n_correct05 + n_wrong05)acc01 = (n_correct01 * 1.0) / (n_correct01 + n_wrong01)return (acc10, acc05, acc01)# ----------------------------------------------------------class Net(T.nn.Module):def __init__(self):super(Net, self).__init__()self.hid1 = T.nn.Linear(2, 100)  # 2-(100-100-100-100)-1self.hid2 = T.nn.Linear(100, 100)self.hid3 = T.nn.Linear(100, 100)self.hid4 = T.nn.Linear(100, 100)self.oupt = T.nn.Linear(100, 1)T.nn.init.xavier_uniform_(self.hid1.weight)  # glorotT.nn.init.zeros_(self.hid1.bias)T.nn.init.xavier_uniform_(self.hid2.weight)  # glorotT.nn.init.zeros_(self.hid2.bias)T.nn.init.xavier_uniform_(self.hid3.weight)  # glorotT.nn.init.zeros_(self.hid3.bias)T.nn.init.xavier_uniform_(self.hid4.weight)  # glorotT.nn.init.zeros_(self.hid4.bias)T.nn.init.xavier_uniform_(self.oupt.weight)  # glorotT.nn.init.zeros_(self.oupt.bias)def forward(self, x):z = T.tanh(self.hid1(x))  # or T.nn.Tanh()z = T.tanh(self.hid2(z))z = T.tanh(self.hid3(z))z = T.tanh(self.hid4(z))z = self.oupt(z)          # no activationreturn z# ----------------------------------------------------------def main():# 0. make training data filenp.random.seed(1)T.manual_seed(1)hi = 0.9; lo = 0.1train_f = open("area_train.txt", "w")for i in range(10000):base = (hi - lo) * np.random.random() + loheight = (hi - lo) * np.random.random() + loarea = 0.5 * base * heights = "%0.5f, %0.5f, %0.5f \n" % (base, height, area)train_f.write(s)train_f.close()# 1. create Dataset and DataLoader objectsprint("Creating Triangle Area train DataLoader ")train_file = ".\\area_train.txt"train_ds = TriangleDataset(train_file)  # all rowsbat_size = 10train_ldr = T.utils.data.DataLoader(train_ds,batch_size=bat_size, shuffle=True)# 2. create neural networkprint("Creating 2-(100-100-100-100)-1 regression NN ")net = Net()# 3. train networkprint("\nPreparing training")net = net.train()  # set training modelrn_rate = 0.01loss_func = T.nn.MSELoss()optimizer = T.optim.SGD(net.parameters(),lr=lrn_rate)max_epochs = 1000ep_log_interval = 100print("Loss function: " + str(loss_func))print("Optimizer: SGD")print("Learn rate: 0.01")print("Batch size: 10")print("Max epochs: " + str(max_epochs))print("\nStarting training")for epoch in range(0, max_epochs):epoch_loss = 0.0            # for one full epochepoch_loss_custom = 0.0num_lines_read = 0for (batch_idx, batch) in enumerate(train_ldr):X = batch['base_ht']  # [10,4]  base, height inputsY = batch['area']     # [10,1]  correct area to predictoptimizer.zero_grad()oupt = net(X)            # [10,1]  computed loss_obj = loss_func(oupt, Y)  # a tensorepoch_loss += loss_obj.item()  # accumulateloss_obj.backward()optimizer.step()if epoch % ep_log_interval == 0:print("epoch = %4d   loss = %0.4f" % \(epoch, epoch_loss))print("Done ")# 4. evaluate modelnet = net.eval()(acc10, acc05, acc01) = accuracy(net, train_ds)print("\nAccuracy (.10) on train data = %0.2f%%" % \(acc10 * 100))print("\nAccuracy (.05) on train data = %0.2f%%" % \(acc05 * 100))print("\nAccuracy (.01) on train data = %0.2f%%" % \(acc01 * 100))if __name__ == "__main__":main()

原文链接:用神经网络预测三角形面积 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/13957.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024中青杯A题数学建模成品文章数据代码分享

人工智能视域下养老辅助系统的构建 摘要 随着全球人口老龄化的加剧&#xff0c;养老问题已经成为一个世界性的社会问题&#xff0c;对社会各个方面产生了深远影响&#xff0c;包括劳动力市场、医疗保健和养老金制度等。人口结构变化对养老服务的质量和覆盖面提出了更高要求。特…

Python 爬虫编写入门

一、爬虫概述 网络爬虫&#xff08;Web Crawler&#xff09;或称为网络蜘蛛&#xff08;Web Spider&#xff09;&#xff0c;是一种按照一定规则&#xff0c;自动抓取互联网信息的程序或者脚本。它们可以自动化地浏览网络中的信息&#xff0c;通过解析网页内容&#xff0c;提取…

3台机器快速安装ELK集群

安装和配置 Elasticsearch、Kibana 和 Logstash 以下是安装和配置 Elasticsearch、Kibana 和 Logstash 的详细步骤&#xff0c;并设置开机自启。 步骤 1&#xff1a;修改系统参数 编辑系统参数并使其生效&#xff1a; vim /etc/sysctl.conf添加以下行&#xff1a; vm.max_…

Xpath元素定位和三大等待,三大切换

在页面的操作过程当中&#xff0c;都需要适当的等待。特别是&#xff1a; 候【发生了页面切换的时候】。而我们接下来的操作都是在变化的内容上。 代码就要等等页面的加载&#xff0c;等等页面的渲染。代码是非常快的&#xff0c;页面加载跟不 上&#xff0c;就需要等待。 三大…

getaway基本配置

Getaway 是一款用于容器化应用的轻量级 API 网关。它提供了一种简单的方式来管理和路由 API 请求&#xff0c;通常用于微服务架构中。以下是 Getaway 的基本配置指南&#xff0c;包括安装、配置文件示例、以及一些常见的配置选项。 ### 1. 安装 Getaway 通常通过 Docker 容器…

用友开发平台调用审核提示U8授权失败可能原因

U8授权失败可能有多种原因&#xff0c;这里有几个可能的解决方案供您参考&#xff1a; 登录接口未调用&#xff1a;在调用审核接口&#xff08;如audit、abandon、verify、unverify&#xff09;之前&#xff0c;请确保已经调用了登录接口&#xff08;login&#xff09;。如果登…

红队攻防渗透技术实战流程:云安全之云原生安全:K8s安全

红队云攻防实战 1.云原生-K8s安全-名词架构&各攻击点1.1 云原生-K8s安全-概念1.2 云原生-K8s安全-K8S集群架构解释1.2.1 K8s安全-K8S集群架构-Master节点1.2.2 K8s安全-K8S集群架构-Node节点1.2.3 K8s安全-K8S集群架构-Pod容器1.3 云原生安全-K8s安全-K8S集群攻击点 `(重点…

ARP基本原理

相关概念 ARP报文 ARP报文分为ARP请求报文和ARP应答报文&#xff0c;报文格式如图1所示。 图1 ARP报文格式 Ethernet Address of destination&#xff08;0–31&#xff09;和Ethernet Address of destination&#xff08;32–47&#xff09;分别表示Ethernet Address of dest…

【算法】前缀和——除自身以外数组的乘积

本节博客是用前缀和算法求解“除自身以外数组的乘积”&#xff0c;有需要借鉴即可。 目录 1.题目2.前缀和算法3.变量求解4.总结 1.题目 题目链接&#xff1a;LINK 2.前缀和算法 1.创建两个数组 第一个数组第i位置表示原数组[0,i-1]之积第二个数组第i位置表示原数组[i1,n-1]…

laravel8 JWT配置

一、安装JWT composer require tymon/jwt-auth二、config/app.php 注册服务提供者 providers > [Tymon\JWTAuth\Providers\LaravelServiceProvider::class, ]aliases > [JWTAuth > Tymon\JWTAuth\Facades\JWTAuth::class,JWTFactory > Tymon\JWTAuth\Facades\JWT…

Hadoop 客户端 FileSystem加载过程

如何使用hadoop客户端 public class testCreate {public static void main(String[] args) throws IOException {System.setProperty("HADOOP_USER_NAME", "hdfs");String pathStr "/home/hdp/shanshajia";Path path new Path(pathStr);Confi…

AWS安全性身份和合规性之Amazon Detective

分析和直观呈现安全数据&#xff0c;以调查潜在的安全问题。 Amazon Detective使您可以更轻松地分析、调查和快速确定潜在安全问题或可疑活动的根本原因。Amazon Detective会自动从您地AWS资源中收集日志数据并使用机器学习、统计分析和图论来构建一组关联的数据&#xff0c;使…

在DAYU200上实现OpenHarmony跳转拨号界面

一、简介 日常生活中&#xff0c;打电话是最常见的交流方式之一&#xff0c;那么如何在OpenAtom OpenHarmony&#xff08;简称“OpenHarmony”&#xff09;中进行电话服务相关的开发呢&#xff1f;今天我们可以一起来了解一下如何通过电话服务系统支持的API实现拨打电话的功能…

ECMAScript 详解

ECMAScript 是一种脚本语言规范&#xff0c;由欧洲计算机制造商协会&#xff08;ECMA&#xff09;通过 ECMA-262 标准化&#xff0c;广泛用于客户端脚本编程。它最著名的实现是 JavaScript&#xff0c;主要用于 Web 开发。以下是 ECMAScript 的详细解析&#xff1a; ### 1. 历…

C#中System.Threading.Timer的使用

文章速览 概述创建计时器对象循环执行的方法停止计时器参考文章 坚持记录实属不易&#xff0c;希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区&#xff01; 谢谢~ 概述 本文着重于System.Threading.Timer的简单使用方法。 由于在实际开发过程中&…

LabVIEW机械臂自动化在精密制造中的应用

精密制造是现代工业中的关键环节&#xff0c;要求高精度、高效率以及一致性。机械臂自动化技术结合LabVIEW软件&#xff0c;提供了强大的控制、数据处理和用户界面设计能力&#xff0c;使其在精密制造中得到了广泛应用。以下是几个具体的应用实例&#xff1a; 1. 电路板焊接 …

C#-根据日志等级进行日志的过滤输出

文章速览 概要具体实施创建Log系统动态修改日志等级 坚持记录实属不易&#xff0c;希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区&#xff01; 谢谢~ 概要 方便后期对软件进行维护&#xff0c;需要在一些关键处添加log日志输出&#xff0c;但时间长…

【408精华知识】指令周期的数据流

文章目录 一、取指周期二、间址周期三、执行周期&#xff08;一&#xff09;数据传送类指令(mov/load/store)&#xff08;二&#xff09;运算类指令(加/减/乘/除/移位/与/或)&#xff08;三&#xff09;转移类指令(jmp/jxxx) 四、中断周期 CPU每取出并且执行一条指令所需要的全…

二叉数之插入操作

首先是题目 给定二叉搜索树&#xff08;BST&#xff09;的根节点 root 和要插入树中的值 value &#xff0c;将值插入二叉搜索树。 返回插入后二叉搜索树的根节点。 输入数据 保证 &#xff0c;新值和原始二叉搜索树中的任意节点值都不同。 注意&#xff0c;可能存在多种有效…

AcWing 217:绿豆蛙的归宿 ← 搜索算法

【题目来源】https://www.acwing.com/problem/content/219/【题目描述】 给出一个有向无环的连通图&#xff0c;起点为 1&#xff0c;终点为 N&#xff0c;每条边都有一个长度。 数据保证从起点出发能够到达图中所有的点&#xff0c;图中所有的点也都能够到达终点。 绿豆蛙从起…