DQN Reinforcement Learning

This is more or less the first reinforcement learning environment I have written myself. It still has plenty of rough edges and is being improved step by step; I hope to have it finished within two weeks.


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import pandas as pd


def moving_average(data, window_size):
    """
    Smoothing function (simple moving average).
    :param data: sequence of values to smooth
    :param window_size: size of the averaging window
    :return: list of moving-average values
    """
    if window_size <= 0:
        raise ValueError("Window size should be greater than 0.")
    if window_size > len(data):
        raise ValueError("Window size should not be greater than the length of data.")
    # Cumulative sum of data elements
    cumsum = [0]
    for i, x in enumerate(data):
        cumsum.append(cumsum[i] + x)
    # Compute moving averages
    ma_values = []
    for i in range(len(data) - window_size + 1):
        average = (cumsum[i + window_size] - cumsum[i]) / window_size
        ma_values.append(average)
    return ma_values


def plot_data(data, title="Data Plot", x_label="X-axis", y_label="Y-axis"):
    """
    Plots a simple line graph based on the provided data.

    Parameters:
    - data (list): A list of integers or floats to be plotted.
    - title (str): The title of the plot.
    - x_label (str): The label for the x-axis.
    - y_label (str): The label for the y-axis.
    """
    plt.figure(figsize=(10, 5))  # Set the figure size
    plt.plot(data)  # Plot the data
    plt.title(title)  # Set the title
    plt.xlabel(x_label)  # Set x-axis label
    plt.ylabel(y_label)  # Set y-axis label
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)  # Add a grid
    plt.tight_layout()  # Adjust subplot parameters to give specified padding
    plt.show()


class TransportMatchingEnv:
    def __init__(self, num_drivers=5, num_goods=5, max_price=10, max_time=5):
        """
        :param num_drivers: number of trucks/drivers
        :param num_goods: number of goods
        :param max_price: maximum price
        :param max_time: maximum delivery time
        """
        self.num_drivers = num_drivers
        self.num_goods = num_goods
        self.max_price = max_price
        self.max_time = max_time
        # Action space: one flat index per (good, driver, price, time) combination
        self.action_dim = self.num_drivers * self.num_goods * self.max_price * self.max_time
        # Current negotiation state  TODO: the state still needs a lot more information
        self.current_negotiation = None
        # State (reset() also initialises all the environment parameters below)
        self.combined_state = self.reset()
        # Distance matrix: distance between each good and each driver
        self.distance_matrix = np.random.randint(0, 100, (self.num_goods, self.num_drivers))
        # Shipper's preferred arrival time for each good
        self.goods_time_preferences = np.random.randint(0, self.max_time, self.num_goods)
        # Shipper's expected price for each good
        self.goods_expected_prices = np.random.randint(0, self.max_price, self.num_goods)
        # Whether each driver is currently available
        self.driver_availabilities = np.random.choice([0, 1], self.num_drivers)
        # Whether each good has special requirements
        self.goods_special_requirements = np.random.choice([0, 1], self.num_goods)
        # Whether each driver can handle goods with special requirements (one flag per driver)
        self.driver_special_capabilities = np.random.choice([0, 1], self.num_drivers)

    def decode_action(self, encoded_action):
        """
        Decode a flat action index into a human-readable (driver, good, price, time) tuple.
        :param encoded_action: flat action index
        :return: (driver_index, good_index, price, time)
        """
        total_actions_for_price_time = self.max_price * self.max_time
        total_actions_per_good = self.num_drivers * total_actions_for_price_time
        total_actions = self.num_goods * total_actions_per_good
        if encoded_action >= total_actions:
            raise ValueError("Encoded action is out of bounds!")
        good_index = encoded_action // total_actions_per_good
        residual = encoded_action % total_actions_per_good
        driver_index = residual // total_actions_for_price_time
        residual = residual % total_actions_for_price_time
        price = residual // self.max_time
        time = residual % self.max_time
        return driver_index, good_index, price, time

    def compute_reward(self, driver_index, good_index, price, time):
        """
        Compute the reward for a proposed (driver, good, price, time) match.
        :param driver_index:
        :param good_index:
        :param price:
        :param time:
        :return: total reward
        """
        # 1. Distance factor
        distance = self.distance_matrix[good_index][driver_index]
        distance_factor = -distance  # negative reward for longer distances
        # 2. Time factor
        delivery_time_preference = self.goods_time_preferences[good_index]
        time_penalty = -abs(delivery_time_preference - time) * 2  # penalize based on how far from preferred time
        # 3. Price factor
        expected_price = self.goods_expected_prices[good_index]
        price_difference = price - expected_price
        price_factor = -abs(price_difference)  # prefer prices close to expected
        # 4. Availability of the driver
        driver_availability = self.driver_availabilities[driver_index]  # 0 for not available, 1 for available
        availability_factor = driver_availability * 10  # give a bonus for available drivers
        # 5. Special requirements
        good_requirement = self.goods_special_requirements[good_index]  # 0 for no requirement, 1 for special storage
        driver_capability = self.driver_special_capabilities[driver_index]  # 0 for no capability, 1 for special storage
        requirement_factor = 0
        if good_requirement > 0 and driver_capability < good_requirement:
            requirement_factor = -20  # huge penalty if driver can't meet the special requirement
        total_reward = distance_factor + time_penalty + price_factor + availability_factor + requirement_factor
        return total_reward

    def reset(self):
        """
        Reset the environment and return the initial flattened state.
        :return: combined state vector
        """
        # Note: this seeds Python's random module only; the NumPy draws below are not affected.
        random.seed(0)
        self.current_negotiation = np.zeros((self.num_goods, self.num_drivers))
        # Refresh all the parameters every time you reset the environment
        self.distance_matrix = np.random.randint(0, 100, (self.num_goods, self.num_drivers))
        self.goods_time_preferences = np.random.randint(0, self.max_time, self.num_goods)
        self.goods_expected_prices = np.random.randint(0, self.max_price, self.num_goods)
        self.driver_availabilities = np.random.choice([0, 1], self.num_drivers)
        self.goods_special_requirements = np.random.choice([0, 1], self.num_goods)
        self.driver_special_capabilities = np.random.choice([0, 1], self.num_drivers)
        # Combine everything into a single flattened state
        combined_state = np.concatenate((self.current_negotiation.flatten(),
                                         self.distance_matrix.flatten(),
                                         self.goods_time_preferences,
                                         self.goods_expected_prices,
                                         self.driver_availabilities,
                                         self.goods_special_requirements,
                                         self.driver_special_capabilities))
        return combined_state

    def driver_satisfaction(self, fee_received, expected_fee, distance_travelled, max_distance,
                            wait_time, max_wait_time, goods_condition):
        """
        Satisfaction score for the driver (carrier).
        :param fee_received: fee actually received
        :param expected_fee: expected fee
        :param distance_travelled: distance travelled
        :param max_distance: maximum distance
        :param wait_time: waiting time
        :param max_wait_time: maximum waiting time
        :param goods_condition: condition of the goods
        :return: total satisfaction score
        """
        # Price satisfaction (max weightage of 40)
        price_satisfaction = (fee_received / expected_fee) * 40
        # Distance satisfaction (max weightage of 30)
        distance_satisfaction = ((max_distance - distance_travelled) / max_distance) * 30
        # Waiting-time satisfaction (max weightage of 20)
        wait_satisfaction = ((max_wait_time - wait_time) / max_wait_time) * 20
        # Goods-condition satisfaction (max weightage of 10)
        goods_satisfaction = 10 if goods_condition == 'good' else 0
        # Total satisfaction
        total_satisfaction = price_satisfaction + distance_satisfaction + wait_satisfaction + goods_satisfaction
        return total_satisfaction

    def shipper_satisfaction(self, fee_paid, expected_fee, delivery_time, expected_delivery_time,
                             goods_condition, driver_service_quality):
        """
        Satisfaction score for the shipper.
        :param fee_paid: fee actually paid
        :param expected_fee: expected fee
        :param delivery_time: actual delivery time
        :param expected_delivery_time: expected delivery time
        :param goods_condition: condition of the goods
        :param driver_service_quality: driver service quality (0-100)
        :return: total satisfaction score
        """
        # Price satisfaction (max weightage of 30)
        price_satisfaction = (expected_fee / fee_paid) * 30
        # Time satisfaction (max weightage of 30)
        time_satisfaction = ((expected_delivery_time - delivery_time) / expected_delivery_time) * 30
        # Goods-condition satisfaction (max weightage of 20)
        goods_satisfaction = 20 if goods_condition == 'good' else 0
        # Service satisfaction (max weightage of 20)
        service_satisfaction = driver_service_quality * 20 / 100
        # Total satisfaction
        total_satisfaction = price_satisfaction + time_satisfaction + goods_satisfaction + service_satisfaction
        return total_satisfaction

    def successOrFailure(self):
        # Decide whether the negotiation succeeds, based on the satisfaction of both parties.
        # 1 means the negotiation succeeded; other values mean it failed.
        return 1

    def step(self, encoded_action):
        """
        TODO: core logic. The first thing to pin down is when a negotiation
        succeeds and when it fails.
        :param encoded_action: flat action index to be decoded
        :return: (state, reward, done, info)
        """
        driver_index, good_index, price, time = self.decode_action(encoded_action)
        # Planned negotiation logic (not implemented yet):
        #   if self.successOrFailure() == 1:   negotiation succeeded
        #   elif self.successOrFailure() == 2: negotiation failed, exchange offer and counter-offer
        #   else:                              negotiation failed, end immediately
        # decode_action() guarantees price < max_price and time < max_time, so this branch is always taken.
        if price <= self.max_price and time <= self.max_time:
            self.current_negotiation[good_index][driver_index] = 1
            reward = self.compute_reward(driver_index, good_index, price, time)
            combined_state = np.concatenate((self.current_negotiation.flatten(),
                                             self.distance_matrix.flatten(),
                                             self.goods_time_preferences,
                                             self.goods_expected_prices,
                                             self.driver_availabilities,
                                             self.goods_special_requirements,
                                             self.driver_special_capabilities))
            done = np.sum(self.current_negotiation) == self.num_goods
            return combined_state, reward, done, {}

    def render(self):
        print(self.current_negotiation)


# Simple random agent for testing
class RandomAgent:
    def __init__(self, action_dim):
        self.action_dim = action_dim

    def act(self):
        return np.random.choice(self.action_dim)


class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.fc(x)


class DQNAgent:
    def __init__(self, input_dim, action_dim, gamma=0.99, epsilon=0.99, lr=0.001):
        self.input_dim = input_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.network = DQN(input_dim, action_dim).float().to(device)
        self.target_network = DQN(input_dim, action_dim).float().to(device)
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(), lr=self.lr)
        # Replay buffer
        self.memory = deque(maxlen=2000)

    def act(self, state):
        # Epsilon-greedy action selection
        if np.random.random() > self.epsilon:
            state = torch.tensor(np.array([state]), dtype=torch.float32).to(device)
            with torch.no_grad():
                action = self.network(state).argmax().item()
            return action
        else:
            return np.random.choice(self.action_dim)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train(self, batch_size=64):
        if len(self.memory) < batch_size:
            return
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack into arrays first to avoid the slow list-of-ndarray conversion
        states = torch.tensor(np.array(states), dtype=torch.float32).to(device)
        actions = torch.tensor(actions, dtype=torch.int64).to(device)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32).to(device)
        dones = torch.tensor(dones, dtype=torch.float32).to(device)
        current_values = self.network(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
        next_values = self.target_network(next_states).max(1)[0].detach()
        target_values = rewards + self.gamma * next_values * (1 - dones)
        loss = nn.MSELoss()(current_values, target_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_network.load_state_dict(self.network.state_dict())

    def decrease_epsilon(self, decrement_value=0.001, min_epsilon=0.1):
        self.epsilon = max(self.epsilon - decrement_value, min_epsilon)


if __name__ == '__main__':
    start = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rewards = []
    env = TransportMatchingEnv(num_drivers=10, num_goods=10)
    agent = DQNAgent(env.combined_state.flatten().shape[0], env.action_dim)
    # Number of training episodes
    episodes = 2000
    for episode in tqdm(range(episodes)):
        state = env.reset()
        done = False
        episode_reward = 0
        total_reward = 0
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            agent.train()
            episode_reward += reward
            total_reward += reward
            state = next_state
            # step() returns a NumPy bool; convert it to a plain Python bool
            done = done.item()
        agent.decrease_epsilon()
        rewards.append(total_reward)
        if episode % 50 == 0:
            agent.update_target_network()
        # print(f"Episode {episode + 1}/{episodes} - Reward: {episode_reward}")
    # Save the per-episode rewards to an Excel file
    df = pd.DataFrame(data=rewards)
    df.to_excel('sample.xlsx', index=True)
    plot_data(moving_average(data=rewards, window_size=1), title='reward', x_label='epoch', y_label='reward')
    end = time.time()
    print(f'device: {device}')
    print(f'time: {end - start}')
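The flat action encoding is the part that is easiest to get wrong, so here is a minimal sanity-check sketch. The sanity_check helper below is hypothetical (it is not part of the script above); it assumes the classes above are in scope, verifies that every flat index decodes to in-range (driver, good, price, time) values, and runs one episode with the RandomAgent as a baseline to compare against the DQN reward curve.

def sanity_check():
    # Small instance so the checks below stay fast.
    env = TransportMatchingEnv(num_drivers=3, num_goods=3, max_price=4, max_time=2)

    # Every flat action index should decode to in-range values.
    for a in range(env.action_dim):
        driver, good, price, t = env.decode_action(a)
        assert 0 <= driver < env.num_drivers
        assert 0 <= good < env.num_goods
        assert 0 <= price < env.max_price
        assert 0 <= t < env.max_time

    # One episode with the random baseline.
    agent = RandomAgent(env.action_dim)
    env.reset()
    done, total_reward = False, 0
    while not done:
        _, reward, done, _ = env.step(agent.act())
        total_reward += reward
    print(f'random-agent episode reward: {total_reward}')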
