DeepFM算法代码

以下代码均采用Tensorflow1.15版本

数据集私聊我
import tensorflow as tf
import numpy as np
import pandas as pd# 定义特征列
def get_feature_columns():# 假设 Criteo 数据集有 10 个数值特征和 10 个类别特征numerical_feature_columns = [tf.feature_column.numeric_column("num_feature_{}".format(i)) for i in range(10)]categorical_feature_columns = [tf.feature_column.categorical_column_with_hash_bucket("cat_feature_{}".format(i), hash_bucket_size=100) for i in range(10)]return numerical_feature_columns + categorical_feature_columns# 定义 DeepFM 模型
def deep_fm_model(features, labels, mode):# 嵌入层embedding_list = []for column in get_feature_columns():if isinstance(column, tf.feature_column.categorical_column_with_hash_bucket):embedding = tf.feature_column.embedding_column(column, dimension=8)embedding_list.append(embedding)# FM 部分fm_input = tf.concat([tf.feature_column.input_layer(features, column) for column in get_feature_columns()], axis=1)linear_part = tf.layers.dense(fm_input, 1)sum_square = tf.square(tf.reduce_sum(fm_input, axis=1))square_sum = tf.reduce_sum(tf.square(fm_input), axis=1)fm_part = 0.5 * tf.reduce_sum(sum_square - square_sum, axis=1, keepdims=True)# Deep 部分deep_input = tf.concat([tf.feature_column.input_layer(features, column) for column in get_feature_columns()], axis=1)deep_hidden_1 = tf.layers.dense(deep_input, 128, activation=tf.nn.relu)deep_hidden_2 = tf.layers.dense(deep_hidden_1, 64, activation=tf.nn.relu)deep_output = tf.layers.dense(deep_hidden_2, 1)# 合并combined_output = linear_part + fm_part + deep_output# 预测和损失if mode == tf.estimator.ModeKeys.PREDICT:predictions = {'predictions': combined_output}return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)loss = tf.losses.mean_squared_error(labels, combined_output)# 优化器optimizer = tf.train.AdamOptimizer(learning_rate=0.001)# 训练和评估操作if mode == tf.estimator.ModeKeys.TRAIN:train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)if mode == tf.estimator.ModeKeys.EVAL:eval_metric_ops = {'mse': tf.metrics.mean_squared_error(labels, combined_output)}return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)# 输入函数
def input_fn(data_path, batch_size):data = pd.read_csv(data_path)labels = data['label']features = data.drop('label', axis=1)dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).repeat()iterator = dataset.make_one_shot_iterator()features, labels = iterator.get_next()return features, labels# 训练和评估
def train_and_evaluate():# 创建 Estimatorestimator = tf.estimator.Estimator(model_fn=deep_fm_model,model_dir='your_model_dir')# 训练estimator.train(input_fn=lambda: input_fn('train_data_path.csv', batch_size=128),steps=1000)# 评估estimator.evaluate(input_fn=lambda: input_fn('eval_data_path.csv', batch_size=128))if __name__ == '__main__':train_and_evaluate()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/53250.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024.9.3 作业

自己实现栈和队列 代码&#xff1a; /*******************************************/ 文件名&#xff1a;sq.h /*******************************************/ #ifndef SQ_H #define SQ_H #include <iostream> #include<cstring>using namespace std; class …

秋招突击——算法练习——8/26——图论——200-岛屿数量、994-腐烂的橘子、207-课程表、208-实现Trie

文章目录 引言正文200-岛屿数量个人实现 994、腐烂的橘子个人实现参考实现 207、课程表个人实现参考实现 208、实现Trie前缀树个人实现参考实现 总结 引言 正文 200-岛屿数量 题目链接 个人实现 我靠&#xff0c;这道题居然是腾讯一面的类似题&#xff0c;那道题是计算最…

[数据集][目标检测]智慧牧场猪只检测数据集VOC+YOLO格式16245张1类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;16245 标注数量(xml文件个数)&#xff1a;16245 标注数量(txt文件个数)&#xff1a;16245 标…

力扣SQL仅数据库(175~185)

175. 组合两个表 题目&#xff1a;编写解决方案&#xff0c;报告 Person 表中每个人的姓、名、城市和州。如果 personId 的地址不在 Address 表中&#xff0c;则报告为 null 准备数据&#xff1a; Create table If Not Exists Person (personId int, firstName varchar(255…

vue使用html2Canvas导出图片 input文字向上偏移

vue使用html2Canvas导出图片 input文字向上偏移 图中 用的是element的输入框 行高 32px,经常测试 你使用原生的input 还是会出现偏移。 解决方法&#xff1a;修改css样式 1.怎么实现导出 网上随便找很多 2.在第一步 获取你要导出的元素id 克隆后 修改他的样式或者 你直接在你需…

设计模式 19 观察者模式

设计模式 19 创建型模式&#xff08;5&#xff09;&#xff1a;工厂方法模式、抽象工厂模式、单例模式、建造者模式、原型模式结构型模式&#xff08;7&#xff09;&#xff1a;适配器模式、桥接模式、组合模式、装饰者模式、外观模式、享元模式、代理模式行为型模式&#xff…

十二、建立自已的北斗卫星实时定位基站

一、背景 连续运行卫星定位服务系统(Continuous Operational Reference System,简称CORS系统)是现代北斗/GNSS的发展热点之一。CORS系统将网络化概念引入到了大地测量应用中,该系统的建立不仅为测绘行业带来深刻的变革,而且也将为现代网络社会中的空间信息服务带来新的思维…

基于单片机的水箱水质监测系统设计

本设计基于STM32F103C8T6为核心控制器设计了水质监测系统&#xff0c;选用DS18B20温度传感器对水箱水体温度进行采集&#xff1b;E-201-C PH传感器获取水体PH值&#xff1b;选用TS-300B浊度传感器检测水体浊度&#xff1b;采用YW01液位传感器获取水位&#xff0c;当检测水位低于…

宽带和带宽分不清楚

如何理解带宽 我们平时经常听到的带宽其实是宽带&#xff0c;举个栗子&#xff1a;我家用的是xx运营商提供的&#xff0c;号称1000M宽带&#xff0c;这其实指是的网络数据传输的速率是&#xff1a;1000Mbs&#xff08;即125MBps&#xff09;。 那么既然有宽带&#xff0c;就有…

【LVGL- 组 lv_group_t】

lv_group_t ■ group■ 组api■ lv_group_create后后面的控件自动添加到group ■ group if (code LV_EVENT_SCREEN_LOADED) //一般放在loaded 事件中添加到lv_group_set_default(key_group); lv_indev_set_group(indev_keypad, key_group); //和输入设备关联。 }■ 组api…

MCU官方IDE软件安装及学习教程集合 — STM32CubeIDE(STM32)

简介 各MCU厂商为保证产品的市场地位以及用户体验&#xff0c;不断的完善自己的产品配套&#xff0c;搭建自己的开发生态&#xff0c;像国外ST公司&#xff0c;国内的GD&#xff08;兆易创新&#xff09;&#xff0c;AT&#xff08;雅特力&#xff09;等等。目前就开发生态而言…

09.定时器02

#include "reg52.h"sbit led P3^6;void delay10ms() { //1. 配置定时器0工作模式位16位计时TMOD 0x01;//2. 给初值&#xff0c;定一个10ms出来TL00x00;TH00xDC;//3. 开始计时TR0 1;TF0 0; } void main() {int cnt 0;led 1;while(1){if(TF0 1)//当爆表的时候&a…

【Qt】QLCDNumber | QProgressBar | QCalendarWidget

文章目录 QLCDNumber —— 显示数字QLCDNumber 的属性QLCDNumber 的使用 QProgressBar —— 进度条QProgressBar 的属性创建一个进度条修改为 红色的进度条 QCalendarWidget —— 日历QCalendarWidget 的属性QCalendarWidget 的使用 QLCDNumber —— 显示数字 QLCDNumber 的属…

UE4_后期处理_后期处理材质及后期处理体积一

后期处理效果 在渲染之前应用于整个渲染场景的效果。 后期处理效果&#xff08;Post-processing effect&#xff09;使美术师和设计师能够对影响颜色、色调映射、光照的属性和功能进行组合选择&#xff0c;从而定义场景的整体外观。要访问这些功能&#xff0c;可以将一种称为…

使用docker调试odoo

使用 Visual Studio Code (VSCode) 的 Dev Containers 进行 Odoo 开发和调试是一个高效的方法&#xff0c;尤其是当你希望在一个清洁且一致的开发环境中工作时。以下是设置和配置 Dev Container 以在 Docker 环境中单步调试 Odoo 系统的步骤&#xff1a; ### 步骤 1: 准备 Doc…

多角度解读WMS:探寻仓库管理系统的核心功能

多角度解读 WMS 仓库管理系统 1. 概述 WMS 在数字化工厂中具有举足轻重的地位&#xff0c;它不仅提高了仓储管理的效率与准确性&#xff0c;还能优化整个供应链的管理&#xff0c;支持灵活生产模式&#xff0c;并提供决策支持的关键数据。通过现代前后端技术的架构设计&#xf…

【Spring Boot 3】自定义拦截器

【Spring Boot 3】自定义拦截器 背景介绍开发环境开发步骤及源码工程目录结构总结背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工作经历中,每次学习新技术总是要花费或多或…

Prometheus_0基础_学习笔记

一、基本概念 Prometheus是由golang语言开发的一套开源的监控、报警、时间序列数据库的组合&#xff0c;是一款基于时序数据库的开源监控告警系统。 时间序列数据库&#xff1a;时间序列数据库&#xff08;Time Serires Database , TSDB&#xff09;不同于传统的关系型数据库。…

idea如何高亮、标记代码颜色的2种方式

zihao 第一种高亮方式 ctrlf 双击选择执行快捷键&#xff0c;所有被搜索的单词都会被搜索且高亮 第二种高亮方式 安装grep console 日志管理插件 ctrlaltf3 双击选择执行快捷键&#xff0c;所有被标记一个颜色高亮

银行卡二三四要素验证-银行卡二三四要素验证接口-银行卡二三四要素验证api

1、接口介绍 银行卡二三四要素验证接口是一种用于验证用户银行卡信息真实性和有效性的技术接口。这种接口在金融、电商等领域有着广泛的应用&#xff0c;旨在确保交易的安全性和合规性。 2、接口地址 全面覆盖&#xff0c;支持所有带银联标识的银行卡; 高准确性-验证结果实时返…