Pytorch个人学习记录总结 玩俄罗斯方块の深度学习小项目

目录

前言

模型成果演示

训练过程演示

 代码实现

deep_network

tetris

test

train


前言

当今,深度学习在各个领域展现出了惊人的应用潜力,而游戏开发领域也不例外。俄罗斯方块作为经典的益智游戏,一直以来深受玩家喜爱。在这个项目中,我将深度学习与游戏开发相结合,通过使用PyTorch,为俄罗斯方块赋予了智能化的能力。

这个深度学习项目的目标是训练一个模型,使其能够自动玩俄罗斯方块,并且在游戏中取得高分。通过使用神经网络,我以游戏的状态作为输入,然后模型将预测最佳的移动策略,从而使方块能够正确地落下并消除行。通过反复训练和优化,我希望能够让模型达到专业玩家的水平,并且掌握一些高级策略。

本博客将详细介绍我在这个项目中所采用的深度学习方法和技术。我将分享我的代码实现,并解释我在训练过程中所遇到的挑战和解决方案。无论你是对深度学习感兴趣还是对俄罗斯方块情有独钟,这个项目都能够给你带来一些启发和思考。

我相信通过将深度学习和游戏开发相结合,我们能够为游戏带来全新的可能性。让我们一起探索这个项目,看看深度学习如何在俄罗斯方块这个经典游戏中展现其强大的应用能力吧!

模型成果演示

Pytorch个人学习记录总结 俄罗斯方块の深度学习小项目

训练过程演示

Pytorch个人学习记录总结 俄罗斯方块の深度学习小项目

 代码实现

deep_network

import torch.nn as nnclass DeepQNetwork(nn.Module):def __init__(self):super(DeepQNetwork, self).__init__()self.conv1 = nn.Sequential(nn.Linear(4, 64), nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU(inplace=True))self.conv3 = nn.Sequential(nn.Linear(64, 1))self._create_weights()def _create_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.constant_(m.bias, 0)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)return x

tetris

import numpy as np
from PIL import Image
import cv2
from matplotlib import style
import torch
import randomstyle.use("ggplot")class Tetris:piece_colors = [(0, 0, 0),(255, 255, 0),(147, 88, 254),(54, 175, 144),(255, 0, 0),(102, 217, 238),(254, 151, 32),(0, 0, 255)]pieces = [[[1, 1],[1, 1]],[[0, 2, 0],[2, 2, 2]],[[0, 3, 3],[3, 3, 0]],[[4, 4, 0],[0, 4, 4]],[[5, 5, 5, 5]],[[0, 0, 6],[6, 6, 6]],[[7, 0, 0],[7, 7, 7]]]def __init__(self, height=20, width=10, block_size=20):self.height = heightself.width = widthself.block_size = block_sizeself.extra_board = np.ones((self.height * self.block_size, self.width * int(self.block_size / 2), 3),dtype=np.uint8) * np.array([204, 204, 255], dtype=np.uint8)self.text_color = (200, 20, 220)self.reset()def reset(self):self.board = [[0] * self.width for _ in range(self.height)]self.score = 0self.tetrominoes = 0self.cleared_lines = 0self.bag = list(range(len(self.pieces)))random.shuffle(self.bag)self.ind = self.bag.pop()self.piece = [row[:] for row in self.pieces[self.ind]]self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2, "y": 0}self.gameover = Falsereturn self.get_state_properties(self.board)def rotate(self, piece):num_rows_orig = num_cols_new = len(piece)num_rows_new = len(piece[0])rotated_array = []for i in range(num_rows_new):new_row = [0] * num_cols_newfor j in range(num_cols_new):new_row[j] = piece[(num_rows_orig - 1) - j][i]rotated_array.append(new_row)return rotated_arraydef get_state_properties(self, board):lines_cleared, board = self.check_cleared_rows(board)holes = self.get_holes(board)bumpiness, height = self.get_bumpiness_and_height(board)return torch.FloatTensor([lines_cleared, holes, bumpiness, height])def get_holes(self, board):num_holes = 0for col in zip(*board):row = 0while row < self.height and col[row] == 0:row += 1num_holes += len([x for x in col[row + 1:] if x == 0])return num_holesdef get_bumpiness_and_height(self, board):board = np.array(board)mask = board != 0invert_heights = np.where(mask.any(axis=0), np.argmax(mask, axis=0), self.height)heights = self.height - invert_heightstotal_height = np.sum(heights)currs = heights[:-1]nexts = heights[1:]diffs = np.abs(currs - nexts)total_bumpiness = np.sum(diffs)return total_bumpiness, total_heightdef get_next_states(self):states = {}piece_id = self.indcurr_piece = [row[:] for row in self.piece]if piece_id == 0:  # O piecenum_rotations = 1elif piece_id == 2 or piece_id == 3 or piece_id == 4:num_rotations = 2else:num_rotations = 4for i in range(num_rotations):valid_xs = self.width - len(curr_piece[0])for x in range(valid_xs + 1):piece = [row[:] for row in curr_piece]pos = {"x": x, "y": 0}while not self.check_collision(piece, pos):pos["y"] += 1self.truncate(piece, pos)board = self.store(piece, pos)states[(x, i)] = self.get_state_properties(board)curr_piece = self.rotate(curr_piece)return statesdef get_current_board_state(self):board = [x[:] for x in self.board]for y in range(len(self.piece)):for x in range(len(self.piece[y])):board[y + self.current_pos["y"]][x + self.current_pos["x"]] = self.piece[y][x]return boarddef new_piece(self):if not len(self.bag):self.bag = list(range(len(self.pieces)))random.shuffle(self.bag)self.ind = self.bag.pop()self.piece = [row[:] for row in self.pieces[self.ind]]self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2,"y": 0}if self.check_collision(self.piece, self.current_pos):self.gameover = Truedef check_collision(self, piece, pos):future_y = pos["y"] + 1for y in range(len(piece)):for x in range(len(piece[y])):if future_y + y > self.height - 1 or self.board[future_y + y][pos["x"] + x] and piece[y][x]:return Truereturn Falsedef truncate(self, piece, pos):gameover = Falselast_collision_row = -1for y in range(len(piece)):for x in range(len(piece[y])):if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x]:if y > last_collision_row:last_collision_row = yif pos["y"] - (len(piece) - last_collision_row) < 0 and last_collision_row > -1:while last_collision_row >= 0 and len(piece) > 1:gameover = Truelast_collision_row = -1del piece[0]for y in range(len(piece)):for x in range(len(piece[y])):if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x] and y > last_collision_row:last_collision_row = yreturn gameoverdef store(self, piece, pos):board = [x[:] for x in self.board]for y in range(len(piece)):for x in range(len(piece[y])):if piece[y][x] and not board[y + pos["y"]][x + pos["x"]]:board[y + pos["y"]][x + pos["x"]] = piece[y][x]return boarddef check_cleared_rows(self, board):to_delete = []for i, row in enumerate(board[::-1]):if 0 not in row:to_delete.append(len(board) - 1 - i)if len(to_delete) > 0:board = self.remove_row(board, to_delete)return len(to_delete), boarddef remove_row(self, board, indices):for i in indices[::-1]:del board[i]board = [[0 for _ in range(self.width)]] + boardreturn boarddef step(self, action, render=True, video=None):x, num_rotations = actionself.current_pos = {"x": x, "y": 0}for _ in range(num_rotations):self.piece = self.rotate(self.piece)while not self.check_collision(self.piece, self.current_pos):self.current_pos["y"] += 1if render:self.render(video)overflow = self.truncate(self.piece, self.current_pos)if overflow:self.gameover = Trueself.board = self.store(self.piece, self.current_pos)lines_cleared, self.board = self.check_cleared_rows(self.board)score = 1 + (lines_cleared ** 2) * self.widthself.score += scoreself.tetrominoes += 1self.cleared_lines += lines_clearedif not self.gameover:self.new_piece()if self.gameover:self.score -= 2return score, self.gameoverdef render(self, video=None):if not self.gameover:img = [self.piece_colors[p] for row in self.get_current_board_state() for p in row]else:img = [self.piece_colors[p] for row in self.board for p in row]img = np.array(img).reshape((self.height, self.width, 3)).astype(np.uint8)img = img[..., ::-1]img = Image.fromarray(img, "RGB")img = img.resize((self.width * self.block_size, self.height * self.block_size), 0)img = np.array(img)img[[i * self.block_size for i in range(self.height)], :, :] = 0img[:, [i * self.block_size for i in range(self.width)], :] = 0img = np.concatenate((img, self.extra_board), axis=1)cv2.putText(img, "Score:", (self.width * self.block_size + int(self.block_size / 2), self.block_size),fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)cv2.putText(img, str(self.score),(self.width * self.block_size + int(self.block_size / 2), 2 * self.block_size),fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)cv2.putText(img, "Pieces:", (self.width * self.block_size + int(self.block_size / 2), 4 * self.block_size),fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)cv2.putText(img, str(self.tetrominoes),(self.width * self.block_size + int(self.block_size / 2), 5 * self.block_size),fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)cv2.putText(img, "Lines:", (self.width * self.block_size + int(self.block_size / 2), 7 * self.block_size),fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)cv2.putText(img, str(self.cleared_lines),(self.width * self.block_size + int(self.block_size / 2), 8 * self.block_size),fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)if video:video.write(img)cv2.imshow("Deep Q-Learning Tetris", img)cv2.waitKey(1)

test

import argparse
import torch
import cv2
from src.tetris import Tetrisdef get_args():parser = argparse.ArgumentParser("""Implementation of Deep Q Network to play Tetris""")parser.add_argument("--width", type=int, default=10, help="The common width for all images")parser.add_argument("--height", type=int, default=20, help="The common height for all images")parser.add_argument("--block_size", type=int, default=30, help="Size of a block")parser.add_argument("--fps", type=int, default=300, help="frames per second")parser.add_argument("--saved_path", type=str, default="trained_models")parser.add_argument("--output", type=str, default="output.mp4")args = parser.parse_args()return argsdef run_test(opt):if torch.cuda.is_available():torch.cuda.manual_seed(123)else:torch.manual_seed(123)if torch.cuda.is_available():model = torch.load("{}/tetris".format(opt.saved_path))else:model = torch.load("{}/tetris".format(opt.saved_path), map_location=lambda storage, loc: storage)model.eval()env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)env.reset()if torch.cuda.is_available():model.cuda()out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps,(int(1.5*opt.width*opt.block_size), opt.height*opt.block_size))while True:next_steps = env.get_next_states()next_actions, next_states = zip(*next_steps.items())next_states = torch.stack(next_states)if torch.cuda.is_available():next_states = next_states.cuda()predictions = model(next_states)[:, 0]index = torch.argmax(predictions).item()action = next_actions[index]_, done = env.step(action, render=True, video=out)if done:out.release()breakif __name__ == "__main__":opt = get_args()run_test(opt)

train

import argparse
import os
import shutil
from random import random, randint, sampleimport numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriterfrom src.deep_q_network import DeepQNetwork
from src.tetris import Tetris
from collections import dequedef get_args():parser = argparse.ArgumentParser("""Implementation of Deep Q Network to play Tetris""")parser.add_argument("--width", type=int, default=10, help="The common width for all images")parser.add_argument("--height", type=int, default=20, help="The common height for all images")parser.add_argument("--block_size", type=int, default=30, help="Size of a block")parser.add_argument("--batch_size", type=int, default=512, help="The number of images per batch")parser.add_argument("--lr", type=float, default=1e-3)parser.add_argument("--gamma", type=float, default=0.99)parser.add_argument("--initial_epsilon", type=float, default=1)parser.add_argument("--final_epsilon", type=float, default=1e-3)parser.add_argument("--num_decay_epochs", type=float, default=2000)parser.add_argument("--num_epochs", type=int, default=3000)parser.add_argument("--save_interval", type=int, default=1000)parser.add_argument("--replay_memory_size", type=int, default=30000,help="Number of epoches between testing phases")parser.add_argument("--log_path", type=str, default="tensorboard")parser.add_argument("--saved_path", type=str, default="trained_models")args = parser.parse_args()return argsdef train(opt):if torch.cuda.is_available():torch.cuda.manual_seed(123)else:torch.manual_seed(123)if os.path.isdir(opt.log_path):shutil.rmtree(opt.log_path)os.makedirs(opt.log_path)writer = SummaryWriter(opt.log_path)env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)model = DeepQNetwork()optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)criterion = nn.MSELoss()state = env.reset()if torch.cuda.is_available():model.cuda()state = state.cuda()replay_memory = deque(maxlen=opt.replay_memory_size)epoch = 0while epoch < opt.num_epochs:next_steps = env.get_next_states()# Exploration or exploitationepsilon = opt.final_epsilon + (max(opt.num_decay_epochs - epoch, 0) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_decay_epochs)u = random()random_action = u <= epsilonnext_actions, next_states = zip(*next_steps.items())next_states = torch.stack(next_states)if torch.cuda.is_available():next_states = next_states.cuda()model.eval()with torch.no_grad():predictions = model(next_states)[:, 0]model.train()if random_action:index = randint(0, len(next_steps) - 1)else:index = torch.argmax(predictions).item()next_state = next_states[index, :]action = next_actions[index]reward, done = env.step(action, render=True)if torch.cuda.is_available():next_state = next_state.cuda()replay_memory.append([state, reward, next_state, done])if done:final_score = env.scorefinal_tetrominoes = env.tetrominoesfinal_cleared_lines = env.cleared_linesstate = env.reset()if torch.cuda.is_available():state = state.cuda()else:state = next_statecontinueif len(replay_memory) < opt.replay_memory_size / 10:continueepoch += 1batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))state_batch, reward_batch, next_state_batch, done_batch = zip(*batch)state_batch = torch.stack(tuple(state for state in state_batch))reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])next_state_batch = torch.stack(tuple(state for state in next_state_batch))if torch.cuda.is_available():state_batch = state_batch.cuda()reward_batch = reward_batch.cuda()next_state_batch = next_state_batch.cuda()q_values = model(state_batch)model.eval()with torch.no_grad():next_prediction_batch = model(next_state_batch)model.train()y_batch = torch.cat(tuple(reward if done else reward + opt.gamma * prediction for reward, done, prediction inzip(reward_batch, done_batch, next_prediction_batch)))[:, None]optimizer.zero_grad()loss = criterion(q_values, y_batch)loss.backward()optimizer.step()print("Epoch: {}/{}, Action: {}, Score: {}, Tetrominoes {}, Cleared lines: {}".format(epoch,opt.num_epochs,action,final_score,final_tetrominoes,final_cleared_lines))writer.add_scalar('Train/Score', final_score, epoch - 1)writer.add_scalar('Train/Tetrominoes', final_tetrominoes, epoch - 1)writer.add_scalar('Train/Cleared lines', final_cleared_lines, epoch - 1)if epoch > 0 and epoch % opt.save_interval == 0:torch.save(model, "{}/tetris_{}".format(opt.saved_path, epoch))torch.save(model, "{}/tetris".format(opt.saved_path))if __name__ == "__main__":opt = get_args()train(opt)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/13910.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python web实战 | 用 Flask 框架快速构建 Web 应用【实战】

概要 Python web 开发已经有了相当长的历史&#xff0c;从最早的 CGI 脚本到现在的全栈 Web 框架&#xff0c;现在已经成为了一种非常流行的方式。 Python 最早被用于 Web 开发是在 1995 年&#xff08;90年代早期&#xff09;&#xff0c;当时使用 CGI 脚本编写动态 Web 页面…

spring启动流程 (6完结) springmvc启动流程

SpringMVC的启动入口在SpringServletContainerInitializer类&#xff0c;它是ServletContainerInitializer实现类(Servlet3.0新特性)。在实现方法中使用WebApplicationInitializer创建ApplicationContext、创建注册DispatcherServlet、初始化ApplicationContext等。 SpringMVC…

68. 文本左右对齐

题目链接&#xff1a;力扣 解题思路&#xff1a;遍历单词数组&#xff0c;确定每一行的单词数量&#xff0c; 之后就可以得到每一个需要补充的空格数量。从而得到单词之间需要补充的空格数量。具体算法如下&#xff1a; 确定每一行的单词数量 初始值&#xff1a; num 0&…

【JavaWeb】正则表达式

&#x1f384;欢迎来到边境矢梦的csdn博文&#xff0c;本文主要讲解Java 中正则表达式 的相关知识&#x1f384; &#x1f308;我是边境矢梦&#xff0c;一个正在为秋招和算法竞赛做准备的学生&#x1f308; &#x1f386;喜欢的朋友可以关注一下&#x1faf0;&#x1faf0;&am…

2023年的深度学习入门指南(22) - 百川大模型13B的运行及量化

2023年的深度学习入门指南(22) - 百川大模型13B的运行及量化 不知道上一讲的大段代码大家看晕了没有。但是如果你仔细看了会发现&#xff0c;其实代码还是不全的。比如分词器我们就没讲。 另外&#xff0c;13B比7B的改进点也没有讲。 再有&#xff0c;对于13B需要多少显存我们…

ios 查看模拟器沙盒的路径

打一个断点运行程序&#xff0c;在xcode consol底部控制台输入&#xff1a; po NSHomeDirectory() 复制路径粘帖到前往文件夹打开沙盒缓存文件夹

golang pprof

pprof是一个用于分析数据的可视化和分析工具&#xff0c;由谷歌公司的开发团队使用go语言编写成的。一般用于对golang资源占用进行分析。不是原创&#xff0c;参考&#xff1a;https://juejin.cn/post/7122473470424219656 1. 通过页面查看golang运行情况 访问 http://127.0.0…

ppt怎么压缩到10m以内?分享好用的压缩方法

PPT是一种常见的演示文稿格式&#xff0c;有时候文件过大&#xff0c;我们会遇到无法发送、上传的现象&#xff0c;这时候简单的解决方法就是压缩其大小&#xff0c;那怎么才能将PPT压缩到10M以内呢&#xff1f; PPT文件大小受到影响的主要因素就是以下几点&#xff1a; 1、图…

Keepalived 在CentOS安装

下载 有两种下载方式&#xff0c;一种为yum源下载&#xff0c;另一种通过源代码下载&#xff0c;本文章使用源代码编译下载。 官网下载地址&#xff1a;https://www.keepalived.org/download.html wget https://www.keepalived.org/software/keepalived-2.0.20.tar.gz --no-…

CNN卷积详解

转载自&#xff1a;https://blog.csdn.net/yilulvxing/article/details/107452153 仅用于自己学习过程中经典文章讲解的记录&#xff0c;防止原文失效。 1&#xff1a;单通道卷积 以单通道卷积为例&#xff0c;输入为&#xff08;1,5,5&#xff09;&#xff0c;分别表示1个通道…

支配树学习笔记

学习链接【学习笔记】支配树_cz_xuyixuan的博客-CSDN博客 主要的求法是最后两个结论&#xff1a; 定理4用来求sdom&#xff0c;先搞一个dfs树&#xff0c;然后将点按dfs序从大到小加入&#xff0c;对每个点维护到当前根&#xff08;即已加入点&#xff09;路径上sdom最小是哪个…

sky-notes-01

1、DTO类 DTO&#xff08;Data Transfer Object&#xff09;&#xff1a;数据传输对象&#xff0c;Service 或 Manager 向外传输的对象。 详见阿里巴巴Java开发手册中的DO、DTO、BO、AO、VO、POJO定义 当前端提交的数据和实体类中对应的属性差别比较大时&#xff0c;建议使用…

级联选择框

文章目录 实现级联选择框效果图实现前端工具版本添加依赖main.js导入依赖级联选择框样式 后端数据库设计 实现级联选择框 效果图 实现 前端 工具版本 node.js v16.6.0vue3 级联选择框使用 Element-Plus 实现 添加依赖 在 package.json 添加依赖&#xff0c;并 npm i 导入…

【LeetCode】28. 找出字符串中第一个匹配项的下标

题目&#xff1a; 28. 找出字符串中第一个匹配项的下标 这道题一看就是经典的KMP算法求解字符串模式匹配问题。 但这里我用了java里自带的字符串匹配函数 indexOf(),虽然有点偷懒&#xff0c;但运行结果还不错。主要是怕有时候竞赛会突然忘了一些算法&#xff0c;不过有时候多…

Minio在windows环境配置https访问

minio启动后&#xff0c;默认访问方式为http&#xff0c;但是有的时候我们的访问场景必须是https&#xff0c;浏览器有的会默认以https进行访问&#xff0c;这个时候就需要我们进行配置上的调整&#xff0c;将minio从http访问升级到https。而查看minio的官方文档&#xff0c;并…

【Spring Cloud Alibaba】限流--Sentinel

文章目录 概述一、Sentinel 是啥&#xff1f;二、Sentinel 的生态环境三、Sentinel 核心概念3.1、资源3.2、规则 四、Sentinel 限流4.1、单机限流4.1.1、引入依赖4.1.2、定义限流规则4.1.3、定义限流资源4.1.4、运行结果 4.2、控制台限流4.2.1、客户端接入控制台4.2.2、引入依赖…

【Qt】QML-02:QQuickView用法

1、先看demo QtCreator自动生成的工程是使用QQmlApplicationEngine来加载qml文件&#xff0c;下面的demo将使用QQuickView来加载qml文件 #include <QGuiApplication> #include <QtQuick/QQuickView>int main(int argc, char *argv[]) {QGuiApplication app(argc,…

layui各种事件无效(例如表格重载或 分页插件按钮失效)的解决方法

下图是我一个系统的操作日志&#xff0c;在分页插件右下角嵌入了一个导出所有数据的按钮 &#xff0c;代码没有任何问题&#xff0c;点击导出按钮却失效 排查之后&#xff0c;发现表格标签table定义了ID又定义了lay-filter&#xff0c;因我使用的layui从2.7.6升级到2.8.11&…

JavaScript高级——ES6基础入门

目录 前言let 和 const块级作用域模板字符串一.模板字符串是什么二.模板字符串的注意事项三. 模板字符串的应用 箭头函数一.箭头函数是什么二.普通函数与箭头函数的转换三.this指向1. 全局作用域中的 this 指向2. 一般函数&#xff08;非箭头函数&#xff09;中的this指向3.箭头…

SSL 证书过期巡检脚本

哈喽大家好&#xff0c;我是咸鱼 我们知道 SSL 证书是会过期的&#xff0c;一旦过期之后需要重新申请。如果没有及时更换证书的话&#xff0c;就有可能导致网站出问题&#xff0c;给公司业务带来一定的影响 所以说我们要每隔一定时间去检查网站上的 SSL 证书是否过期 如果公…