Keras实现SegNet

我真服了原来我之前用tf复现SegNet给复现错了
在网上试了多个版本代码,折腾了好久,现在终于复现对了,代码也跑通了
SegNet的架构比较老了,这几年都没人更新代码了,我这里算是提供一个最近能跑通的版本的代码吧

tf版本2.4.1

首先主要是构建两个类来实现池化索引,这里经过反复尝试我懵懵懂懂地解决了其它代码直接搬运过来导致的各种报错

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layerclass MaxPoolingWithArgmax2D(Layer):def __init__(self, pool_size=(2, 2), strides=(2, 2), padding='same', **kwargs):super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)self.padding = paddingself.pool_size = pool_sizeself.strides = stridesdef call(self, inputs, **kwargs):padding = self.paddingpool_size = self.pool_sizestrides = self.stridesif K.backend() == 'tensorflow':ksize = [1, pool_size[0], pool_size[1], 1]padding = padding.upper()strides = [1, strides[0], strides[1], 1]output, argmax = tf.nn.max_pool_with_argmax(inputs, ksize=ksize, strides=strides, padding=padding)else:errmsg = '{} backend is not supported for layer {}'.format(K.backend(), type(self).__name__)raise NotImplementedError(errmsg)argmax = K.cast(argmax, K.floatx())return [output, argmax]def compute_output_shape(self, input_shape):ratio = (1, 2, 2, 1)output_shape = [dim // ratio[idx] if dim is not None else None for idx, dim in enumerate(input_shape)]output_shape = tuple(output_shape)return [output_shape, output_shape]def compute_mask(self, inputs, mask=None):return 2 * [None]def get_config(self):config = super(MaxPoolingWithArgmax2D, self).get_config()config.update({"pool_size": self.pool_size,"strides": self.strides,"padding": self.padding,})return configclass MaxUnpooling2D(Layer):def __init__(self, size=(2, 2), **kwargs):super(MaxUnpooling2D, self).__init__(**kwargs)self.size = sizedef call(self, inputs, output_shape=None):updates, mask = inputs[0], inputs[1]with tf.compat.v1.variable_scope(self.name):mask = K.cast(mask, 'int32')input_shape = tf.shape(updates, out_type='int32')#  calculation new shapeif output_shape is None:output_shape = (input_shape[0], input_shape[1] * self.size[0], input_shape[2] * self.size[1], input_shape[3])self.output_shape1 = output_shape# calculation indices for batch, height, width and feature mapsone_like_mask = K.ones_like(mask, dtype='int32')batch_shape = K.concatenate([[input_shape[0]], [1], [1], [1]], axis=0)batch_range = K.reshape(tf.range(output_shape[0], dtype='int32'), shape=batch_shape)b = one_like_mask * batch_rangey = mask // (output_shape[2] * output_shape[3])x = (mask // output_shape[3]) % output_shape[2]feature_range = tf.range(output_shape[3], dtype='int32')f = one_like_mask * feature_range# transpose indices & reshape update values to one dimensionupdates_size = tf.size(updates)indices = K.transpose(K.reshape(K.stack([b, y, x, f]), [4, updates_size]))values = K.reshape(updates, [updates_size])ret = tf.scatter_nd(indices, values, output_shape)input_shape = updates.shapeout_shape = [-1,input_shape[1] * self.size[0],input_shape[2] * self.size[1],input_shape[3]]return K.reshape(ret, out_shape)def compute_output_shape(self, input_shape):mask_shape = input_shape[1]return mask_shape[0], mask_shape[1] * self.size[0], mask_shape[2] * self.size[1], mask_shape[3]def get_config(self):config = super(MaxUnpooling2D, self).get_config()config.update({"size": self.size,})return config

另外SegNet网络主体部分,注意池化和反池化的时候filters数量要对得上

def SegNet(fNum, dates, lossweights, filters=64):inputs = keras.layers.Input((fNum*dates, img_h, img_w))inputs0 = keras.layers.Lambda(reshapes2)(inputs) # 针对我数据的reshape# Encoderconv1 = keras.layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(inputs0)conv1 = keras.layers.BatchNormalization()(conv1)conv1 = keras.layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(conv1)conv1 = keras.layers.BatchNormalization()(conv1)pool1, idx1 = MaxPoolingWithArgmax2D(pool_size=(2, 2))(conv1)conv2 = keras.layers.Conv2D(filters*2, (3, 3), activation='relu', padding='same')(pool1)conv2 = keras.layers.BatchNormalization()(conv2)conv2 = keras.layers.Conv2D(filters*2, (3, 3), activation='relu', padding='same')(conv2)conv2 = keras.layers.BatchNormalization()(conv2)pool2, idx2 = MaxPoolingWithArgmax2D(pool_size=(2, 2))(conv2)conv3 = keras.layers.Conv2D(filters*4, (3, 3), activation='relu', padding='same')(pool2)conv3 = keras.layers.BatchNormalization()(conv3)conv3 = keras.layers.Conv2D(filters*4, (3, 3), activation='relu', padding='same')(conv3)conv3 = keras.layers.BatchNormalization()(conv3)pool3, idx3 = MaxPoolingWithArgmax2D(pool_size=(2, 2))(conv3)conv4 = keras.layers.Conv2D(filters*8, (3, 3), activation='relu', padding='same')(pool3)conv4 = keras.layers.BatchNormalization()(conv4)conv4 = keras.layers.Conv2D(filters*8, (3, 3), activation='relu', padding='same')(conv4)conv4 = keras.layers.BatchNormalization()(conv4)pool4, idx4 = MaxPoolingWithArgmax2D(pool_size=(2, 2))(conv4)# Decoderup5 = MaxUnpooling2D((2,2))([pool4, idx4])conv5 = keras.layers.Conv2D(filters*4, (3, 3), activation='relu', padding='same')(up5)conv5 = keras.layers.BatchNormalization()(conv5)conv5 = keras.layers.Conv2D(filters*4, (3, 3), activation='relu', padding='same')(conv5)conv5 = keras.layers.BatchNormalization()(conv5)up6 = MaxUnpooling2D(size=(2, 2))([conv5, idx3])conv6 = keras.layers.Conv2D(filters*2, (3, 3), activation='relu', padding='same')(up6)conv6 = keras.layers.BatchNormalization()(conv6)conv6 = keras.layers.Conv2D(filters*2, (3, 3), activation='relu', padding='same')(conv6)conv6 = keras.layers.BatchNormalization()(conv6)up7 = MaxUnpooling2D(size=(2, 2))([conv6, idx2])conv7 = keras.layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(up7)conv7 = keras.layers.BatchNormalization()(conv7)conv7 = keras.layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(conv7)conv7 = keras.layers.BatchNormalization()(conv7)up8 = MaxUnpooling2D(size=(2, 2))([conv7, idx1])conv8 = keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same')(up8)conv8 = keras.layers.BatchNormalization()(conv8)conv8 = keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same')(conv8)conv8 = keras.layers.BatchNormalization()(conv8)outputs = keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(conv8)model = keras.models.Model(inputs=inputs, outputs=outputs)return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/841502.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elasticsearch集群许可证过期问题解决方法汇总

最近在使用elasticsearch的过程中,使用elastic-head进行可视化展示集群的状态和信息,从2024年5月18日突然elastic-head无法现在集群的状态界面啦,elasticsearch集群状态是正常,命令如下: curl -X GET "localhost:9200/_cluster/health?pretty" 在google页面上通过…

引流500+创业粉,抖音口播工具

在抖音平台运营一个专注于口播的工具号,旨在集结超过500位热衷于创业的粉丝,这需要精心筹划的内容策略和周到的运营计划。首先,明确你的口播工具号所专注的领域,无论是分享创业经验、财务管理技巧还是案例分析,确保你所…

Axmol 2.1.3 发布

我们非常荣幸,axmol 能在发布此版本之前被 awsome-cpp 收录! The 2.1.3 release is a minor LTS release for bugfixes and improvements, thanks to iAndyHD3 add axmol to awsome-cpp The axmol home page was change to https://axmol.dev Signifi…

引入Dao

1.crm和数据库的结合 我们先前实现的crm项目的数据都是自定义的 而非数据库获取 因此现在我们应该实现crm和数据库的集成 ListServlet.java doPost方法中在处理异常的选项中 并没有发现throws方式 而只有try-catch方式 这是因为子类throws的异常必须和父类throws异常一致或者是…

【电子元件】TL431 电压基准

TL431(C23892)是一种常用的可调节精密电压基准和电压调节器。它广泛应用于电源管理、精密参考电压和稳压电路等领域。以下是TL431的一些关键特点和使用方法: 关键特点 可调输出电压:TL431的输出电压可以通过外部电阻网络在2.495V到36V范围内调整。精度高…

淘宝x5sec

声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!wx a15018601872 本文章未…

vuedraggable插件 修改元素首次拖拽进入占位样式

vuedraggable是一款适用于vue3 的可拖拽插件。 通过配置ghost-class“ghost” 属性,可以对组件内元素拖拽过程中的占位符进行修改。但是无法根据ghost这一class对元素首次拖拽进组件内的占位元素进行样式修改 解决方法:元素首次拖拽进vuedraggable 中时…

python 面对对象 类 继承

继承 继承就是为了解决两个有大量重复性代码的类,抽象出一个更抽象的类放公共代码,主要是代码复用,方便代码的管理与修改 类的继承包括属性和方法,私有属性也可继承 class Person(): # 默认是继承object超类pass…

【数据结构(邓俊辉)学习笔记】二叉树03——重构

0 .概述 介绍下二叉树重构 1. 遍历序列 任何一棵二叉树我们都可以导出先序、中序、后序遍历序列。这三个序列的长度相同,他们都是由树中的所有节点依照相应的遍历策略所确定的次序,依次排列而成。 若已知某棵树的遍历序列是否可以忠实地还原出这棵树…

ic基础|时钟篇05:芯片中buffer到底是干嘛的?一文带你了解buffer的作用

大家好,我是数字小熊饼干,一个练习时长两年半的ic打工人。我在两年前通过自学跨行社招加入了IC行业。现在我打算将这两年的工作经验和当初面试时最常问的一些问题进行总结,并通过汇总成文章的形式进行输出,相信无论你是在职的还是…

图片AI高效生成惊艳之作,一键解锁无限创意,轻松打造概念艺术新纪元!

在数字化时代,图片已经成为我们表达创意、传递信息的重要载体。然而,传统的图片生成方式往往耗时耗力,无法满足我们对于高效、创意的需求。幸运的是,现在有了图片AI,它以其高效、智能的特点,为我们带来了全…

微服务-系统架构

微服务: 系统架构的演变 单一应用架构 早期的互联网应用架构,大量应用服务 功能 集中在一个包里,把大量的应用打包为一个jar包,部署在一台服务器,例如tomcat上部署Javaweb项目 缺点:耦合度高,一台服务器…

一千题,No.0014(素数对猜想)

让我们定义dn​为&#xff1a;dn​pn1​−pn​&#xff0c;其中pi​是第i个素数。显然有d1​1&#xff0c;且对于n>1有dn​是偶数。“素数对猜想”认为“存在无穷多对相邻且差为2的素数”。 现给定任意正整数N(<105)&#xff0c;请计算不超过N的满足猜想的素数对的个数。…

分布式缓存:探讨如何在Java中使用分布式缓存解决方案,比如Redis或Hazelcast等

分布式缓存简介 分布式缓存是一种数据管理策略,它可以帮助我们更有效地使用网络中多台服务器的存储资源,从而提高数据获取的速度。我们可以把数据(如数据库查询结果、计算结果等)存储在这种缓存中,从而提供更快的数据访问速度,减少对原始数据源的访问并降低网络负载。 …

Point-Nerf 理论笔记和理解

文章目录 什么是point nerf 和Nerf 有什么区别Point Nerf 核心结构有哪些&#xff1f;什么是point-based radiance field? 点云位置以及置信度是怎么来Point pruning 和 Point Growing 什么是point nerf 和Nerf 有什么区别 基本的nerf 是通过过拟合MLP来完成任意视角场景的重…

缓存穿透、击穿、雪崩的解决方法

一、缓存穿透指的是查询一个不存在的数据&#xff0c;由于缓存中没有对应的值&#xff0c;每次请求都要查询数据库&#xff0c;容易导致数据库压力过大。 解决方法&#xff1a; 使用布隆过滤器等手段可以在请求到达后台处理之前就过滤掉这些不存在的请求&#xff0c;避免了对数…

gazebo中通过编写插件发布随动关节的角度值到话题

1. cpp 编写 #include <gazebo/gazebo.hh> #include <gazebo/physics/physics.hh> #include <gazebo/common/common.hh> #include <ros/ros.h> #include <std_msgs/Float64.h>namespace gazebo {class PoleJointAnglePublisher : public ModelP…

KAFKA消费者-进阶用法

Apache Kafka 是一个分布式流处理平台&#xff0c;用于构建实时流数据管道和应用程序。在 Kafka 中&#xff0c;消费者&#xff08;Consumer&#xff09;用于从 Kafka 主题&#xff08;Topic&#xff09;中读取消息并进行处理。本文将介绍 Kafka 消费者的进阶用法&#xff0c;包…

Linux(六)

Linux&#xff08;六&#xff09; 自定义头文件自定义头文件中写什么如何引入头文件条件编译条件编译作用 gcc工作原理Make 工作管理器什么是Make什么是Makefile/makefileMakefile假目标Makefile中的变量自定义变量预定义变量自动变量 Makefile中变量展开方式递归展开方式简单展…

正运动机器视觉运动控制一体机应用例程

机器视觉运动控制一体机应用例程-多目标形状匹配-正运动技术 (zmotion.com.cn) 机器视觉运动控制一体机应用例程&#xff08;二&#xff09; 颜色识别-正运动技术 (zmotion.com.cn) 机器视觉运动控制一体机应用例程&#xff08;三&#xff09; 基于BLOB分析的多圆定位-正运动…