TensorFlow 实现 Mixture Density Network (MDN) 的完整说明

本文档详细解释了一段使用 TensorFlow 构建和训练混合密度网络(Mixture Density Network, MDN)的代码,涵盖数据生成、模型构建、自定义损失函数与预测可视化等各个环节。


1. 导入库与设置超参数

import numpy as np 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import math

说明

  • 引入用于数值运算(NumPy)、构建深度学习模型(TensorFlow/Keras)和绘图(Matplotlib)的基础工具包。

超参数定义

N_HIDDEN = 15         # 隐藏层神经元数量
N_MIXES = 10          # GMM 中混合成分数量
OUTPUT_DIMS = 1       # 输出维度(目标变量维度)

2. 自定义 MDN 层

class MDN(layers.Layer):def __init__(self, output_dims, num_mixtures, **kwargs):super(MDN, self).__init__(**kwargs)self.output_dims = output_dimsself.num_mixtures = num_mixturesself.params = self.num_mixtures * (2 * self.output_dims + 1)  # pi, mu, sigmaself.dense = layers.Dense(self.params)def call(self, inputs):output = self.dense(inputs)return output

说明

  • params 表示 GMM 每个分量包含 mu(均值)、sigma(标准差)和 pi(权重),共 2*D + 1 个参数。
  • 输出维度为 (batch_size, num_mixtures * (2*output_dims + 1))

3. 自定义 MDN 损失函数

def get_mixture_loss_func(output_dims, num_mixtures):def mdn_loss(y_true, y_pred):y_true = tf.reshape(y_true, [-1, 1])out_mu = y_pred[:, :num_mixtures * output_dims]out_sigma = y_pred[:, num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[:, -num_mixtures:]mu = tf.reshape(out_mu, [-1, num_mixtures, output_dims])sigma = tf.exp(tf.reshape(out_sigma, [-1, num_mixtures, output_dims]))pi = tf.nn.softmax(out_pi)y_true = tf.tile(y_true[:, tf.newaxis, :], [1, num_mixtures, 1])normal_dist = tf.exp(-0.5 * tf.square((y_true - mu) / sigma)) / (sigma * tf.sqrt(2.0 * np.pi))prob = tf.reduce_prod(normal_dist, axis=2)weighted_prob = prob * piloss = -tf.math.log(tf.reduce_sum(weighted_prob, axis=1) + 1e-8)return tf.reduce_mean(loss)return mdn_loss

说明

  • 通过概率密度函数计算目标值属于 GMM 各个分布的概率,并取加权平均。
  • 对数似然函数取负作为损失。

4. 从输出分布中采样

def sample_from_output(y_pred, output_dims, num_mixtures, temp=1.0):out_mu = y_pred[:num_mixtures * output_dims]out_sigma = y_pred[num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[-num_mixtures:]out_sigma = np.exp(out_sigma)out_pi = np.exp(out_pi / temp)out_pi /= np.sum(out_pi)mixture_idx = np.random.choice(np.arange(num_mixtures), p=out_pi)mu = out_mu[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sigma = out_sigma[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sample = np.random.normal(mu, sigma)return sample

说明

  • 使用 softmax 处理 pi,选择一个分布后按对应的 musigma 采样。
  • temp 控制采样温度(温度越高分布越平坦)。

5. 生成训练数据

NSAMPLE = 3000
y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))
y_data = y_data.reshape((NSAMPLE, 1))

说明

  • 构造非线性映射关系的合成数据:x = sin(0.75y)*7 + 0.5y + 噪声
  • x 是输入,y 是目标。

6. 构建模型

model = keras.Sequential([layers.Dense(N_HIDDEN, input_shape=(1,), activation='relu'),layers.Dense(N_HIDDEN, activation='relu'),MDN(OUTPUT_DIMS, N_MIXES)
])
model.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

说明

  • 构建一个两层隐层的前馈神经网络,输出 MDN 层。
  • 使用自定义的 MDN 损失函数训练模型。

7. 模型训练

model.fit(x_data, y_data, batch_size=128, epochs=200, validation_split=0.15, verbose=1)
  • 批量大小 128,训练 200 个 epoch,保留 15% 数据用于验证。

8. 模型测试与预测可视化

x_test = np.linspace(-15, 15, 1000).astype(np.float32).reshape(-1, 1)
y_pred = model.predict(x_test)
y_samples = np.array([sample_from_output(p, OUTPUT_DIMS, N_MIXES) for p in y_pred])
  • 对连续输入进行预测并从预测的 GMM 中采样。

可视化预测结果

plt.figure()
plt.scatter(x_test, y_samples, alpha=0.3, s=10)
plt.title("MDN Predictions")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

原始数据与预测对比

plt.figure(figsize=(8, 5))
plt.scatter(x_data, y_data, label="Original Data", alpha=0.2, s=10)
plt.scatter(x_test, y_samples, label="MDN Samples", alpha=0.5, s=10, color='r')
plt.title("MDN Prediction vs Training Data")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

总代码如下

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import math# 超参数
N_HIDDEN = 15
N_MIXES = 10
OUTPUT_DIMS = 1# === 1. 自定义 MDN 层 ===
class MDN(layers.Layer):def __init__(self, output_dims, num_mixtures, **kwargs):super(MDN, self).__init__(**kwargs)self.output_dims = output_dimsself.num_mixtures = num_mixturesself.params = self.num_mixtures * (2 * self.output_dims + 1)  # pi, mu, sigmaself.dense = layers.Dense(self.params)def call(self, inputs):output = self.dense(inputs)return output# === 2. 自定义损失函数 ===
def get_mixture_loss_func(output_dims, num_mixtures):def mdn_loss(y_true, y_pred):y_true = tf.reshape(y_true, [-1, 1])out_mu = y_pred[:, :num_mixtures * output_dims]out_sigma = y_pred[:, num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[:, -num_mixtures:]mu = tf.reshape(out_mu, [-1, num_mixtures, output_dims])sigma = tf.exp(tf.reshape(out_sigma, [-1, num_mixtures, output_dims]))pi = tf.nn.softmax(out_pi)y_true = tf.tile(y_true[:, tf.newaxis, :], [1, num_mixtures, 1])normal_dist = tf.exp(-0.5 * tf.square((y_true - mu) / sigma)) / (sigma * tf.sqrt(2.0 * np.pi))prob = tf.reduce_prod(normal_dist, axis=2)weighted_prob = prob * piloss = -tf.math.log(tf.reduce_sum(weighted_prob, axis=1) + 1e-8)return tf.reduce_mean(loss)return mdn_loss# === 3. 从输出采样函数 ===
def sample_from_output(y_pred, output_dims, num_mixtures, temp=1.0):out_mu = y_pred[:num_mixtures * output_dims]out_sigma = y_pred[num_mixtures * output_dims:2 * num_mixtures * output_dims]out_pi = y_pred[-num_mixtures:]out_sigma = np.exp(out_sigma)out_pi = np.exp(out_pi / temp)out_pi /= np.sum(out_pi)mixture_idx = np.random.choice(np.arange(num_mixtures), p=out_pi)mu = out_mu[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sigma = out_sigma[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]sample = np.random.normal(mu, sigma)return sample# === 4. 生成训练数据 ===
NSAMPLE = 3000
y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))
y_data = y_data.reshape((NSAMPLE, 1))plt.figure()
plt.scatter(x_data, y_data, alpha=0.3)
plt.title("Training Data")
plt.show()# === 5. 构建模型 ===
model = keras.Sequential([layers.Dense(N_HIDDEN, input_shape=(1,), activation='relu'),layers.Dense(N_HIDDEN, activation='relu'),MDN(OUTPUT_DIMS, N_MIXES)
])
model.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()# === 6. 模型训练 ===
model.fit(x_data, y_data, batch_size=128, epochs=200, validation_split=0.15, verbose=1)# === 7. 测试与可视化 ===
x_test = np.linspace(-15, 15, 1000).astype(np.float32).reshape(-1, 1)
y_pred = model.predict(x_test)
y_samples = np.array([sample_from_output(p, OUTPUT_DIMS, N_MIXES) for p in y_pred])plt.figure()
plt.scatter(x_test, y_samples, alpha=0.3, s=10)
plt.title("MDN Predictions")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
# === 8. 测试数据与预测对比图 ===plt.figure(figsize=(8, 5))
plt.scatter(x_data, y_data, label="Original Data", alpha=0.2, s=10)
plt.scatter(x_test, y_samples, label="MDN Samples", alpha=0.5, s=10, color='r')
plt.title("MDN Prediction vs Training Data")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

总结

本项目展示了如何使用 TensorFlow 构建混合密度网络,用以建模复杂的条件分布。相比传统回归模型,MDN 能够生成多峰预测结果,适用于不确定性高、输出存在多解的场景。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/79663.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构实验7.2:二叉树的基本运算

文章目录 一,实验目的二,问题描述三,基本要求四,实验操作五,示例代码六,运行效果 一,实验目的 深入理解树与二叉树的基本概念,包括节点、度、层次、深度等,清晰区分二叉…

直线轴承常规分类知多少?

直线轴承的分类方式多样,以下是从材质、结构形状和常规系列三个维度进行的具体分类: 按主要材质分类 外壳材质:常见的有不锈钢,具有良好的耐腐蚀性,适用于一些对环境要求较高、易受腐蚀的工作场景;轴承…

websocket和SSE学习记录

websocket学习记录 websocket使用场景 即时聊天在线文档协同编辑实施地图位置 从开发角度来学习websocket开发 即使通信项目 通过node建立简单的后端接口,利用fs, path, express app.get(*, (req, res) > {const assetsType req.url.split(/)[…

CUDA编程中影响性能的小细节总结

一、内存访问优化 合并内存访问:确保相邻线程访问连续内存地址(全局内存对齐访问)。优先使用共享内存(Shared Memory)减少全局内存访问。避免共享内存的Bank Conflict(例如,使用padding或调整访…

【双指针】对撞指针 快慢指针 移动零

文章目录 双指针介绍对撞指针快慢指针283. 移动零解题思路算法思路算法流程双指针介绍 ​ 算法中的双指针,并不一定是指我们平常在 c/c++ 使用的指针类型,更多时候其实是数组的下标等,因为它们也是有标识某个元素的功能,通常我们也就顺其自然地称其为 “指针” ! ​ 常见…

数据结构0基础学习堆

文章目录 简介公式建立堆函数解释 堆排序O(n logn)topk问题 简介 堆是一种重要的数据结构,是一种完全二叉树,(二叉树的内容后面会出), 堆分为大小堆,大堆,左右结点都小于根节点,&am…

4.17--4.19刷题记录(贪心)

第一部分:准备工作 代码随想录中解释为:贪心的本质是选择每一阶段的局部最优,从而达到全局最优。 而我的理解为:贪心实质上是具有最优子结构的一种算法。所有的解都能由当前最优的解组成。 第二部分:开始刷题 &…

学习笔记十七——Rust 支持面向对象编程吗?

🧠 Rust 支持面向对象编程吗? Rust 是一门多范式语言,主要以 安全、并发、函数式、系统级编程为核心目标,但它同时也支持面向对象的一些关键特性,比如: 特性传统 OOP(如 Java/C)Ru…

【Linux】43.网络基础(2.5)

文章目录 2.4 TCP/UDP对比2.4.1 用UDP实现可靠传输(经典面试题) 2.5 TCP 相关实验2.5.1 理解 listen 的第二个参数 2.4 TCP/UDP对比 我们说了TCP是可靠连接, 那么是不是TCP一定就优于UDP呢? TCP和UDP之间的优点和缺点, 不能简单, 绝对的进行比较TCP用于可靠传输的情况, 应用于…

three.js与webgl在buffer上的对应关系

一、three.js的类名 最近开始接触three.js 看到three.js中的一些类名和webgl的很相似 不自觉的就想对比一下 二、three.js中绘制4个点 // 创建点的几何体 const vertices new Float32Array([0.0, 0.0, 0.0, // 点11.0, 0.0, 0.0, // 点20.0, 1.0, 0.0, // 点30.…

DataWhale AI春训营 问题汇总

1.没用下载训练集导致出错,爆错如下。 这个时候需要去比赛官网下载对应的初赛训练集 unzip -d /mnt/workspace/sais_third_new_energy_baseline/data /mnt/workspace/sais_third_new_energy_baseline/初赛训练集.zip 在命令行执行这个命令解压 2.没定义测试集 te…

CANFD技术在新能源汽车通信网络中的应用与可靠性分析

一、引言 新能源汽车产业正处于快速发展阶段,其电子系统复杂度不断攀升,涵盖众多传感器、控制器与执行器。高效通信网络成为确保新能源汽车安全运行与智能功能实现的核心要素。传统CAN总线因带宽限制,难以满足高级驾驶辅助系统(A…

Python字典深度解析:高效键值对数据管理指南

一、字典核心概念解析 1. 字典定义与特征 字典(Dictionary)是Python中​​基于哈希表实现​​的无序可变容器,通过键值对存储数据,具有以下核心特性: ​​键值对结构​​:{key: value}形式存储数据​​快…

C++中unique_lock和lock_guard区别

目录 1.自动锁定与解锁机制 2.灵活性 3.所有权转移 4.可与条件变量配合使用 5.性能开销 在 C 中&#xff0c;std::unique_lock 和 std::lock_guard 都属于标准库 <mutex> 中的互斥锁管理工具&#xff0c;用于简化互斥锁的使用并确保线程安全。但它们存在一些显著区别…

Nvidia显卡架构演进

1 简介 显示卡&#xff08;英语&#xff1a;Display Card&#xff09;简称显卡&#xff0c;也称图形卡&#xff08;Graphics Card&#xff09;&#xff0c;是个人电脑上以图形处理器&#xff08;GPU&#xff09;为核心的扩展卡&#xff0c;用途是提供中央处理器以外的微处理器帮…

下载electron 22.3.27 源码错误集锦

下载步骤同 electron源码下载及编译_electron源码编译-CSDN博客 问题1 从github 下载 dugite超时&#xff0c;原因没有找到 Validation failed. Expected 8ea2d0d3c9d9e4615069913207371ffe892dc10fb93975972f2f6e668f2e3b3a but got e3b0c44298fc1c149afbf4c8996fb92427ae41e…

洛谷P1120 小木棍

#算法/进阶搜索 思路: 首先,最初始想法,将我们需要枚举的长木棍个数计算出来,在dfs中,我们先判断,此时枚举这根长木棍需要的长度是否为0,如果为0,我们就枚举下一个根木棍,接着再判断,此时仍需要枚举的木棍个数是否为0,如果为0,代表我们这种方案可行,直接打印长木棍长度,接着我们…

Linux教程-常用命令系列二

文章目录 1. 系统管理常用命令1. useradd - 创建用户账户功能基本用法常用选项示例 2. passwd - 管理用户密码功能基本用法常用选项示例 3. kill - 终止进程功能基本用法常用信号示例 4. date - 显示和设置系统时间功能基本用法常用选项时间格式示例 5. bc - 高精度计算器功能基…

18、TimeDiff论文笔记

TimeDiff **1. 背景与动机****2. 扩散模型基础****3. TimeDiff 模型****3.1 前向扩散过程****3.2 后向去噪过程** 4、TimeDiff&#xff08;架构&#xff09;原理训练推理其他关键点解释 DDPM&#xff08;相关数学&#xff09;1、正态分布2、条件概率1. **与多个条件相关**&…

整合SSM——(SpringMVC+Spring+Mybatis)

目录 SSM整合 创建项目 导入依赖 配置文件 SpringConfig MyBatisConfig JdbcConfig ServletConfig SpringMvcConfig 功能模块 测试 业务层接口测试 控制层测试 SSM是Java Web开发中常用的三个主流框架组合的缩写&#xff0c;分别对应Spring、Spring MVC、MyBatis…