tensorflow 实现逻辑回归——原以为TensorFlow不擅长做线性回归或者逻辑回归,原来是这么简单哇!...

实现的是预测 低 出生 体重 的 概率。
尼克·麦克卢尔(Nick McClure). TensorFlow机器学习实战指南 (智能系统与技术丛书) (Kindle 位置 1060-1061). Kindle 版本.

# Logistic Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
#  y = 0 or 1 = low birth weight
#  x = demographic and medical history dataimport matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csvops.reset_default_graph()# Create graph
sess = tf.Session()###
# Obtain and prepare data for modeling
#### Set name of data file
birth_weight_file = 'birth_weight.csv'# Download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'birth_file = requests.get(birthdata_url)birth_data = birth_file.text.split('\r\n')birth_header = birth_data[0].split('\t')birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]with open(birth_weight_file, 'w', newline='') as f:writer = csv.writer(f)writer.writerow(birth_header)writer.writerows(birth_data)f.close()# Read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:csv_reader = csv.reader(csvfile)birth_header = next(csv_reader)for row in csv_reader:birth_data.append(row)birth_data = [[float(x) for x in row] for row in birth_data]# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])# Set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)# Split data into train/test = 80%/20%
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]# Normalize by column (min-max norm)
def normalize_cols(m):col_max = m.max(axis=0)col_min = m.min(axis=0)return (m-col_min) / (col_max - col_min)x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))###
# Define Tensorflow computational graph¶
#### Declare batch size
batch_size = 25# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)###
# Train model
#### Initialize variables
init = tf.global_variables_initializer()
sess.run(init)# Actual Prediction
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)# Training loop
loss_vec = []
train_acc = []
test_acc = []
for i in range(15000):rand_index = np.random.choice(len(x_vals_train), size=batch_size)rand_x = x_vals_train[rand_index]rand_y = np.transpose([y_vals_train[rand_index]])sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})loss_vec.append(temp_loss)temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})train_acc.append(temp_acc_train)temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})test_acc.append(temp_acc_test)if (i+1)%300==0:print('Loss = ' + str(temp_loss))###
# Display model performance
#### Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/390510.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sdlc 瀑布式 生命周期_SDLC指南–软件开发生命周期的阶段和方法

sdlc 瀑布式 生命周期When I decided to teach myself how to code almost four years ago I had never heard of, let alone thought about, the software development life cycle.当我差不多四年前决定教自己如何编码时,我从未听说过软件开发生命周期,…

剑指 Offer 48. 最长不含重复字符的子字符串

请从字符串中找出一个最长的不包含重复字符的子字符串,计算该最长子字符串的长度。 示例 1: 输入: “abcabcbb” 输出: 3 解释: 因为无重复字符的最长子串是 “abc”,所以其长度为 3。 示例 2: 输入: “bbbbb” 输出: 1 解释: 因为无重复字符的最长子…

Mysql-my-innodb-heavy-4G.cnf配置文件注解

Mysql-同Nginx等一样具备多实例的特点,简单的讲就是在一台服务器上同时开启多个不同的服务端口(3306,3307)同时运行多个Mysql服务进程,这些服务进程通过不同的socket监听不同的服务端口来提供服务。这些Mysql多实例公用一套Mysql安…

is 和 == 的区别

is 和 操作符的区别 python官方解释: 的meaning为equal; is的meaning为object identity; is 判断对象是否相等,即身份是否相同,使用id值判断; 判断对象的值是否相等。id值是什么?id()函数官网…

win10管理凌乱桌面_用于管理凌乱的开源存储库的命令行技巧

win10管理凌乱桌面Effective collaboration, especially in open source software development, starts with effective organization. To make sure that nothing gets missed, the general rule, “one issue, one pull request” is a nice rule of thumb.有效的协作(特别是…

JAVA数组Java StringBuffer 和 StringBuilder 类

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_34173549/article/details/80215173 Java StringBuffer 和 StringBuilder 类 当对字符串进行修改的时候,需要使用 StringBuffer 和 StringBuilder 类。 和 Str…

strlen和sizeof的长度区别

strlen返回字符长度 而sizeof返回整个数组占多长,字符串的\0也会计入一个长度转载于:https://www.cnblogs.com/DawaTech/p/8086055.html

了解如何使用Yii2 PHP框架创建YouTube克隆

Yii is a fast, secure, and efficient PHP framework used to create all kinds of web apps. Weve released a full video course on how to use the Yii2 framework.Yii是一个快速,安全,高效PHP框架,用于创建各种Web应用程序。 我们已经发…

剑指 Offer 66. 构建乘积数组

给定一个数组 A[0,1,…,n-1],请构建一个数组 B[0,1,…,n-1],其中 B[i] 的值是数组 A 中除了下标 i 以外的元素的积, 即 B[i]A[0]A[1]…A[i-1]A[i1]…A[n-1]。不能使用除法。 示例: 输入: [1,2,3,4,5] 输出: [120,60,40,30,24] 提示: 所有…

Statement与PreparedStatement的区别

Statement与PreparedStatement的区别 PreparedStatement预编译SQL语句,性能好。 PreparedStatement无序拼接SQL语句,编程更简单. PreparedStatement可以防止SQL注入,安全性好。 Statement由方法createStatement()创建,该对象用于发…

剑指 Offer 45. 把数组排成最小的数

输入一个非负整数数组&#xff0c;把数组里所有数字拼接起来排成一个数&#xff0c;打印能拼接出的所有数字中最小的一个。 示例 1: 输入: [10,2] 输出: “102” 示例 2: 输入: [3,30,34,5,9] 输出: “3033459” 提示: 0 < nums.length < 100 说明: 输出结果可能非…

python 科学计算机_在这个免费的虚拟俱乐部中学习计算机科学和Python的基础知识

python 科学计算机Are you learning how to code in 2020? 您是否正在学习2020年编码&#xff1f; Or are you already working as a developer but want to learn computer science fundamentals? 还是您已经在从事开发人员工作&#xff0c;但想学习计算机科学基础知识&…

Struts2框架使用(十)之struts2的上传和下载

Struts2 文件上传 首先是Struts2的上传&#xff0c;Struts2 文件上传是基于 Struts2 拦截器实现的&#xff0c;使用的是fileupload组件&#xff1b; 首先如果想要上传文件&#xff0c;则需要在表单处添加 enctype"multipart/form-data" 属性。 <% page language&…

module_param 用于动态开启/关闭 驱动打印信息

1.定义模块参数的方法: module_param(name, type, perm); 其中,name:表示参数的名字; type:表示参数的类型; perm:表示参数的访问权限; type参数设定的类型和perm的访问权限具体数值数值请参考内核定义。 2、可以在insmod&#xff08;装载模块&#xff09;的时候为参…

超链接href属性_如何使用标签上的HREF属性制作HTML超链接

超链接href属性A website is a collection of web pages. And these pages need to be linked or connected by something. And to do so, we need to use a tag provided by HTML: the a tag. 网站是网页的集合。 这些页面需要通过某种方式链接或连接。 为此&#xff0c;我们需…

剑指 Offer 42. 连续子数组的最大和

输入一个整型数组&#xff0c;数组中的一个或连续多个整数组成一个子数组。求所有子数组的和的最大值。 要求时间复杂度为O(n)。 示例1: 输入: nums [-2,1,-3,4,-1,2,1,-5,4] 输出: 6 解释: 连续子数组 [4,-1,2,1] 的和最大&#xff0c;为 6。 解题思路 对于一个数组&…

centos 7安装配置vsftpd

yum install -y vsftpd #安装vsftpd yum install -y psmisc net-tools systemd-devel libdb-devel perl-DBI #安装vsftpd虚拟用户配置依赖包 systemctl enable vsftpd.service #设置vsftpd开机启动 cp /etc/vsftpd/vsftpd.conf /etc/vsftpd/vsftpd.conf-bak #备份默认配置文…

amazeui学习笔记--css(基本样式3)--文字排版Typography

amazeui学习笔记--css&#xff08;基本样式3&#xff09;--文字排版Typography 一、总结 1、字体&#xff1a;amaze默认非 衬线字体&#xff08;sans-serif&#xff09; 2、引用块blockquote和定义列表&#xff1a;引用块blockquote和定义列表&#xff08;dl dt&#xff09;注意…

剑指 Offer 46. 把数字翻译成字符串

给定一个数字&#xff0c;我们按照如下规则把它翻译为字符串&#xff1a;0 翻译成 “a” &#xff0c;1 翻译成 “b”&#xff0c;……&#xff0c;11 翻译成 “l”&#xff0c;……&#xff0c;25 翻译成 “z”。一个数字可能有多个翻译。请编程实现一个函数&#xff0c;用来计…

Zend Guard 7 , Zend Guard Loader处理PHP加密

环境&#xff1a;使用Zend Guard 7 软件加密。 PHP 5.6 LNMP 一键安装&#xff0c;PHP5.6Zend Guard Loader &#xff08;对应的版本文件&#xff09;是已经安装好了&#xff0c;还要安装 opcache.so ,直接在lnmp 安装教程中有。因为自动安装 的 版本并不对应&#xff0c;于…