【人工智能】Python常用库-TensorFlow常用方法教程

TensorFlow 是一个广泛应用的开源深度学习框架,支持多种机器学习任务,如深度学习、神经网络、强化学习等。以下是 TensorFlow 的详细教程,涵盖基础使用方法和示例代码。


1. 安装与导入

安装 TensorFlow:

pip install tensorflow

导入 TensorFlow:

import tensorflow as tf
import numpy as np

验证安装:

print(tf.__version__)  # 查看 TensorFlow 版本

2. TensorFlow 基础

2.1 张量(Tensor)

TensorFlow 的核心数据结构是张量,它是一个多维数组。

# 创建张量
a = tf.constant([1, 2, 3], dtype=tf.float32)  # 常量张量
b = tf.Variable([4, 5, 6], dtype=tf.float32)  # 可变张量# 基本运算
c = a + b
print(c.numpy())  # 转换为 NumPy 数组输出

输出结果

[5. 7. 9.]
2.2 自动求导

TensorFlow 支持自动计算梯度。

x = tf.Variable(3.0)with tf.GradientTape() as tape:y = x**2  # 定义目标函数dy_dx = tape.gradient(y, x)  # 自动求导
print(dy_dx.numpy())

输出结果

6.0

3. 构建模型

3.1 使用 Sequential API
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense# 构建简单神经网络
model = Sequential([Dense(64, activation='relu', input_shape=(10,)),Dense(32, activation='relu'),Dense(1, activation='sigmoid')
])# 查看模型结构
model.summary()

输出结果

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense (Dense)               (None, 64)                704       dense_1 (Dense)             (None, 32)                2080      dense_2 (Dense)             (None, 1)                 33        =================================================================
Total params: 2,817
Trainable params: 2,817
Non-trainable params: 0
_________________________________________________________________
3.2 自定义模型
import tensorflow as tf
from tensorflow.keras.layers import Denseclass MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = Dense(64, activation='relu')self.dense2 = Dense(32, activation='relu')self.output_layer = Dense(1, activation='sigmoid')def call(self, inputs):x = self.dense1(inputs)x = self.dense2(x)return self.output_layer(x)model = MyModel()input_shape = (None, 128, 128, 3)
model.build(input_shape)
model.summary()

输出结果

Model: "my_model"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense (Dense)               multiple                  256       dense_1 (Dense)             multiple                  2080      dense_2 (Dense)             multiple                  33        =================================================================
Total params: 2,369
Trainable params: 2,369
Non-trainable params: 0
_________________________________________________________________

4. 数据处理

4.1 数据加载
from tensorflow.keras.datasets import mnist# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理
x_train = x_train / 255.0  # 归一化
x_test = x_test / 255.0
x_train = x_train.reshape(-1, 28*28)  # 展平
x_test = x_test.reshape(-1, 28*28)
4.2 创建数据管道
# 使用 Dataset API 创建数据管道
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(10000).batch(32).prefetch(tf.data.AUTOTUNE)

5. 模型训练与评估

5.1 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
5.2 训练模型
history = model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)
 5.3 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

完整代码

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.datasets import mnist# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理:归一化到 [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0# 构建模型
model = Sequential([Flatten(input_shape=(28, 28)),  # 将28x28的图像展平为1维Dense(128, activation='relu'),  # 全连接层,128个神经元Dense(10, activation='softmax')  # 输出层,10个类别
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 模型训练
history = model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)# 模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

输出结果

Epoch 1/10
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2894 - accuracy: 0.9178 - val_loss: 0.1607 - val_accuracy: 0.9547
Epoch 2/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.1301 - accuracy: 0.9614 - val_loss: 0.1131 - val_accuracy: 0.9656
Epoch 3/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0875 - accuracy: 0.9736 - val_loss: 0.1000 - val_accuracy: 0.9683
Epoch 4/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0658 - accuracy: 0.9804 - val_loss: 0.0934 - val_accuracy: 0.9728
Epoch 5/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0506 - accuracy: 0.9852 - val_loss: 0.0893 - val_accuracy: 0.9715
Epoch 6/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0397 - accuracy: 0.9878 - val_loss: 0.0908 - val_accuracy: 0.9731
Epoch 7/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0311 - accuracy: 0.9906 - val_loss: 0.0882 - val_accuracy: 0.9749
Epoch 8/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0251 - accuracy: 0.9924 - val_loss: 0.0801 - val_accuracy: 0.9777
Epoch 9/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0196 - accuracy: 0.9945 - val_loss: 0.0866 - val_accuracy: 0.9755
Epoch 10/10
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0166 - accuracy: 0.9949 - val_loss: 0.0980 - val_accuracy: 0.9735
313/313 [==============================] - 0s 863us/step - loss: 0.0886 - accuracy: 0.9758
Test accuracy: 0.9757999777793884

代码说明

  1. 数据加载与预处理

    • mnist.load_data():加载手写数字数据集。
    • 数据归一化:将像素值从 0-255 归一化到 0-1,有助于加速训练。
  2. 模型构建

    • Flatten 层:将二维的图像数据展平为一维数组,便于输入全连接层。
    • Dense 层:
      • 第一层使用 ReLU 激活函数。
      • 第二层是输出层,使用 Softmax 激活函数,用于多分类任务。
  3. 模型编译

    • 优化器:adam 是一种适用于大多数情况的优化算法。
    • 损失函数:sparse_categorical_crossentropy,用于分类任务。
  4. 训练

    • validation_split=0.2:从训练数据中划分 20% 用作验证集。
    • epochs=10:训练 10 个轮次。
  5. 评估

    • model.evaluate():评估模型在测试集上的性能,返回损失值和准确率。

6. 可视化

6.1 绘制训练过程
import matplotlib.pyplot as plt# 绘制训练与验证准确率
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Accuracy over Epochs')
plt.show()
6.2 绘制模型预测
# 显示预测结果
predictions = model.predict(x_test[:10])
print("Predicted labels:", np.argmax(predictions, axis=1))
print("True labels:", y_test[:10])

输出结果

Predicted labels: [7 2 1 0 4 1 4 9 5 9]
True labels: [7 2 1 0 4 1 4 9 5 9]

7. 高级功能

7.1 保存与加载模型
# 保存模型
model.save('my_model.h5')# 加载模型
loaded_model = tf.keras.models.load_model('my_model.h5')
7.2 自定义训练过程
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()for epoch in range(5):for x_batch, y_batch in dataset:with tf.GradientTape() as tape:predictions = model(x_batch, training=True)loss = loss_fn(y_batch, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))print(f"Epoch {epoch+1} Loss: {loss.numpy()}")

完整代码 

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 将标签二值化(偶数为 1,奇数为 0)
y_train = (y_train % 2 == 0).astype(int)
y_test = (y_test % 2 == 0).astype(int)# 数据预处理
x_train = x_train / 255.0  # 归一化
x_test = x_test / 255.0
x_train = x_train.reshape(-1, 28 * 28)  # 展平
x_test = x_test.reshape(-1, 28 * 28)# 使用 Dataset API 创建数据管道
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(10000).batch(32).prefetch(tf.data.AUTOTUNE)# 定义模型
class MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = Dense(64, activation='relu')self.dense2 = Dense(32, activation='relu')self.output_layer = Dense(1, activation='sigmoid')  # 输出单个概率def call(self, inputs):x = self.dense1(inputs)x = self.dense2(x)return self.output_layer(x)model = MyModel()# 自定义训练模型# 优化器和损失函数
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.BinaryCrossentropy()# 模型训练
for epoch in range(5):for x_batch, y_batch in dataset:with tf.GradientTape() as tape:predictions = model(x_batch, training=True)loss = loss_fn(y_batch, predictions)  # 使用二分类损失函数gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))print(f"Epoch {epoch + 1} Loss: {loss.numpy()}")

输出结果

Epoch 1 Loss: 0.14392520487308502
Epoch 2 Loss: 0.013877220451831818
Epoch 3 Loss: 0.006577217951416969
Epoch 4 Loss: 0.004411072935909033
Epoch 5 Loss: 0.0037908260710537434

8. 实际应用案例

8.1 图像分类
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential# 加载数据集
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()# 数据预处理
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = to_categorical(y_train), to_categorical(y_test)# 模型构建与训练
model = Sequential([Dense(128, activation='relu', input_shape=(28*28,)),Dense(64, activation='relu'),Dense(10, activation='softmax')
])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train.reshape(-1, 28*28), y_train, epochs=5, batch_size=32, validation_split=0.2)

输出结果

Epoch 1/5
1500/1500 [==============================] - 3s 2ms/step - loss: 0.5128 - accuracy: 0.8172 - val_loss: 0.3955 - val_accuracy: 0.8561
Epoch 2/5
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3794 - accuracy: 0.8621 - val_loss: 0.3925 - val_accuracy: 0.8546
Epoch 3/5
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3403 - accuracy: 0.8741 - val_loss: 0.3721 - val_accuracy: 0.8661
Epoch 4/5
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3158 - accuracy: 0.8826 - val_loss: 0.3390 - val_accuracy: 0.8767
Epoch 5/5
1500/1500 [==============================] - 2s 2ms/step - loss: 0.3011 - accuracy: 0.8883 - val_loss: 0.3292 - val_accuracy: 0.8790

总结

TensorFlow 提供了从数据处理到模型训练和部署的完整解决方案。其灵活的 API 和强大的功能使得研究人员和工程师可以快速实现复杂的机器学习和深度学习任务。通过不断实践,可以深入了解 TensorFlow 的更多特性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62525.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot教程之十一:获取Request 请求 和 Put请求

如何在 Spring Boot 中获取Request Body? Java 语言是所有编程语言中最流行的语言之一。使用 Java 编程语言有几个优点,无论是出于安全目的还是构建大型分发项目。使用 Java 的优点之一是 Java 试图借助类、继承、多态等概念将语言中的每个概念与现实世…

uniapp实现组件竖版菜单

社区图片页面 scroll-view scroll-view | uni-app官网 (dcloud.net.cn) 可滚动视图区域。用于区域滚动。 需注意在webview渲染的页面中&#xff0c;区域滚动的性能不及页面滚动。 <template><view class"pics"><scroll-view class"left"…

Vue教程|搭建vue项目|Vue-CLI2.x 模板脚手架

一、项目构建环境准备 在构建Vue项目之前&#xff0c;需要搭建Node环境以及Vue-CLI脚手架&#xff0c;由于本篇文章为上一篇文章的补充&#xff0c;也是为了给大家分享更为完整的搭建vue项目方式&#xff0c;所以环境准备部分采用Vue教程&#xff5c;搭建vue项目&#xff5c;V…

Rust面向对象特性

Rust的面向对象特性 本文已同步至自建博客 Rust在设计的时候受到很多编程范式的影响&#xff0c;包括面向对象。面向对象的语言共有一些共同的特征&#xff0c;即对象、封装和继承。 封装 一个对象的实现细节对使用该对象的代码不可访问。因此&#xff0c;对象交互的唯一方…

前海湾地铁的腾通数码大厦背后的临时免费停车点探寻

临时免费停车点&#xff1a;前海湾地铁的腾通数码大厦背后的桂湾大街&#xff0c;目前看不仅整条桂湾大街停了​车&#xff0c;而且还有工地餐点。可能是这个区域还是半工地状态&#xff0c;故暂时还不会有​罚单的情况出现。 中建三局腾讯数码大厦项目部A栋 广东省深圳市南山…

Python知识分享第十五天

“”" 细节: 1.如下定义的类的几种写法 并无任何区别 最终效果都一样 只是写法不同 2.所有的类都直接或间接继承自object 它是所有类的父类 定义类的格式 格式1 class 类名: pass 格式2 class 类名(): pass 格式3 class 类名(父类名): pass “”" # 需求: 通过上述…

遥感数据集:FTW全球农田边界和对应影像数据,约160万田块边界及7万多个样本

Fields of The World (FTW) 是一个面向农业田地边界实例分割的基准数据集&#xff0c;旨在推动机器学习模型的发展&#xff0c;满足全球农业监测对高精度、可扩展的田地边界数据的需求。该数据集由kerner-lab提供&#xff0c;于2024年8月28日发布&#xff0c;主要特征包括&…

USB Type-C一线通扩展屏:多场景应用,重塑高效办公与极致娱乐体验

在追求高效与便捷的时代&#xff0c;启明智显USB Type-C一线通扩展屏方案正以其独特的优势&#xff0c;成为众多职场人士、娱乐爱好者和游戏玩家的首选。这款扩展屏不仅具备卓越的性能和广泛的兼容性&#xff0c;更能在多个应用场景中发挥出其独特的价值。 USB2.0显卡&#xff…

项目二技巧一

目录 nginx实现根据域名来访问不同的ip端口 配置Maven私服 快照版和发布版的区别 快照版本&#xff08;Snapshot&#xff09; 发布版本&#xff08;Release&#xff09; 导入发布版的父工程 理清楚授权规则 一.首先浏览器发送/manager/**路径请求 第二步&#xff1a;构造…

如何更好地设计SaaS系统架构

SaaS&#xff08;Software as a Service&#xff09;架构设计的核心目标是满足多租户需求、支持弹性扩展和高性能&#xff0c;同时保持低成本和高可靠性。一个成功的SaaS系统需要兼顾技术架构、资源利用、用户体验和商业目标。本文从以下几个方面探讨如何更好地设计SaaS系统架构…

手搓一个不用中间件的分表策略

场景&#xff1a;针对一些特别的项目&#xff0c;不用中间件&#xff0c;以月为维度进行分表&#xff0c;代码详细设计方案 1. 定义分片策略 首先&#xff0c;定义一个分片策略类&#xff0c;用于决定数据存储在哪个分表中 import java.time.LocalDate; import java.time.fo…

详解SpringCloud集成Camunda7.19实现工作流审批(二)

本章将分享的是camunda流程设计器--Camunda Modeler的基本使用&#xff08;对应camunda版本是7.19&#xff09;&#xff0c;包括bpmn流程图画法&#xff0c;各种控件使用以及一些日常业务场景的流程图的实现 参考资料&#xff1a; Camunda BPMN 基础组件-CSDN博客 Camunda: Exe…

webpack(react)基本构建

文章目录 概要整体架构流程技术名词解释技术细节小结 概要 Webpack 是一个现代 JavaScript 应用程序的静态模块打包工具。它的主要功能是将各种资源&#xff08;如 JavaScript、CSS、图片等&#xff09;视为模块&#xff0c;并将它们打包成一个或多个输出文件&#xff0c;以便…

html select下拉多选 修改yselect.js插件实现下拉多选,搜索,限制选中,默认回显等操作

需求&#xff1a;要在select标签实现下拉多选&#xff0c;搜索&#xff0c;限制选中&#xff0c;默认回显等操作&#xff0c;之前同事用的yselect.js&#xff0c;网上用的简直是寥寥无几&#xff0c;找了半天没找到限制选中的方法&#xff0c;看了源代码才发现根本没有&#xf…

c++哈希表(原理、实现、开放寻址法)适合新手

c系列哈希的原理及实现&#xff08;上&#xff09; 文章目录 c系列哈希的原理及实现&#xff08;上&#xff09;前言一、哈希的概念二、哈希冲突三、哈希冲突解决3.1、开放寻址法3.2、删除操作3.3、负载因子四、代码实现 总结 前言 红黑树平衡树和哈希有不同的用途。 红黑树、…

了解HTTPS以及CA在其中的作用

在这个信息爆炸的时代&#xff0c;每一次指尖轻触屏幕&#xff0c;都是一次数据的旅行。但您是否真正了解&#xff0c;这些数据在通往目的地的旅途中&#xff0c;是如何被保护的呢&#xff1f; HTTPS&#xff08;HyperText Transfer Protocol Secure&#xff09;是一种安全的网…

electron-vite_14窗口默认全屏铺满

有时候应用打包后&#xff0c;希望全屏显示;而默认的宽度和高度,是无法满足的;这时需要单独处理; 核心代码 // 1.引入screen对象 import { BrowserWindow, screen } from electron; function createWindow(): void {// 2.获取屏幕尺寸const { width, height } screen.getPrim…

mysql-为什么需要线程池

mysql-为什么需要线程池 MySQL线程池的概述与应用 MySQL线程池是MySQL数据库中的一个重要组件&#xff0c;旨在提高数据库的性能、吞吐量和可伸缩性。它通过管理数据库服务器的线程生命周期&#xff0c;减少了线程的创建和销毁的开销&#xff0c;并通过优化资源使用&#xff…

【接口封装】——10、系统托盘

解释&#xff1a; 1、定义好按钮的状态&#xff1a;创建 map 映射关系&#xff0c;即 一个名字对应一个按钮 2、对不同按钮实现不同的信号槽函数 头文件&#xff1a; #include "SysTrayIcon.h" #include <qwidget.h> #include "define.h" #include &…

Nginx——配置部署域名服务器路由nginx

文章目录 基本配置报错解决只能通过[域名]:[端口]/[API路径]的方式请求 基本配置 user www-data; worker_processes auto;error_log /var/log/nginx/error.log notice; pid /run/nginx.pid;events {worker_connections 1024; }http {include /etc/nginx/mime…