PyTorch Personal Learning Notes: A Deep Learning Mini-Project That Plays Tetris

Table of Contents

Preface

Model Results Demo

Training Process Demo

Code Implementation

deep_network

tetris

test

train


Preface

Deep learning has shown remarkable potential in one field after another, and game development is no exception. Tetris is a classic puzzle game that players have loved for decades. In this project I combine deep learning with game development and use PyTorch to teach Tetris to play itself.

The goal of the project is to train a model that plays Tetris automatically and achieves a high score. The neural network takes the game state as input and predicts the best placement, so each piece lands well and clears lines. Through repeated training and tuning, I hope the model can approach the level of a professional player and pick up a few advanced strategies along the way.
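
To make the idea concrete, the decision loop boils down to scoring every reachable placement of the current piece and picking the best one. The following is a minimal sketch; env and model refer to the environment and network defined in the code sections below:

next_steps = env.get_next_states()          # {(x, rotation): 4 board features}
actions, states = zip(*next_steps.items())
values = model(torch.stack(states))[:, 0]   # one predicted value per placement
best_action = actions[torch.argmax(values).item()]
score, done = env.step(best_action)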

This post describes the deep learning methods and techniques used in the project. I share the full code and explain the problems I ran into during training and how I solved them. Whether you are interested in deep learning or simply fond of Tetris, I hope the project offers some inspiration and food for thought.

I believe that combining deep learning with game development opens up new possibilities for games. Let's walk through the project and see what deep learning can do with this classic game.

Model Results Demo

(Video: the trained model playing Tetris)

Training Process Demo

(Video: a recording of the training process)

Code Implementation

deep_network

import torch.nn as nn


class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

        # Three fully connected blocks: 4 state features -> 64 -> 64 -> 1 value
        self.conv1 = nn.Sequential(nn.Linear(4, 64), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Linear(64, 1))

        self._create_weights()

    def _create_weights(self):
        # Xavier-initialize every linear layer and zero its bias
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x
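
Despite the conv1/conv2/conv3 names, this is a small fully connected MLP: it maps the 4 board features (lines cleared, holes, bumpiness, total height) to a single scalar value per candidate state. A quick shape check, as a throwaway sketch rather than part of the project files:

import torch

net = DeepQNetwork()
dummy = torch.rand(7, 4)      # 7 candidate placements, 4 features each
print(net(dummy).shape)       # torch.Size([7, 1]): one predicted value per candidate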

tetris

import numpy as np
from PIL import Image
import cv2
from matplotlib import style
import torch
import random

style.use("ggplot")


class Tetris:
    piece_colors = [
        (0, 0, 0),
        (255, 255, 0),
        (147, 88, 254),
        (54, 175, 144),
        (255, 0, 0),
        (102, 217, 238),
        (254, 151, 32),
        (0, 0, 255)
    ]

    pieces = [
        [[1, 1],
         [1, 1]],

        [[0, 2, 0],
         [2, 2, 2]],

        [[0, 3, 3],
         [3, 3, 0]],

        [[4, 4, 0],
         [0, 4, 4]],

        [[5, 5, 5, 5]],

        [[0, 0, 6],
         [6, 6, 6]],

        [[7, 0, 0],
         [7, 7, 7]]
    ]

    def __init__(self, height=20, width=10, block_size=20):
        self.height = height
        self.width = width
        self.block_size = block_size
        self.extra_board = np.ones((self.height * self.block_size, self.width * int(self.block_size / 2), 3),
                                   dtype=np.uint8) * np.array([204, 204, 255], dtype=np.uint8)
        self.text_color = (200, 20, 220)
        self.reset()

    def reset(self):
        self.board = [[0] * self.width for _ in range(self.height)]
        self.score = 0
        self.tetrominoes = 0
        self.cleared_lines = 0
        self.bag = list(range(len(self.pieces)))
        random.shuffle(self.bag)
        self.ind = self.bag.pop()
        self.piece = [row[:] for row in self.pieces[self.ind]]
        self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2, "y": 0}
        self.gameover = False
        return self.get_state_properties(self.board)

    def rotate(self, piece):
        num_rows_orig = num_cols_new = len(piece)
        num_rows_new = len(piece[0])
        rotated_array = []

        for i in range(num_rows_new):
            new_row = [0] * num_cols_new
            for j in range(num_cols_new):
                new_row[j] = piece[(num_rows_orig - 1) - j][i]
            rotated_array.append(new_row)
        return rotated_array

    def get_state_properties(self, board):
        lines_cleared, board = self.check_cleared_rows(board)
        holes = self.get_holes(board)
        bumpiness, height = self.get_bumpiness_and_height(board)

        return torch.FloatTensor([lines_cleared, holes, bumpiness, height])

    def get_holes(self, board):
        num_holes = 0
        for col in zip(*board):
            row = 0
            while row < self.height and col[row] == 0:
                row += 1
            num_holes += len([x for x in col[row + 1:] if x == 0])
        return num_holes

    def get_bumpiness_and_height(self, board):
        board = np.array(board)
        mask = board != 0
        invert_heights = np.where(mask.any(axis=0), np.argmax(mask, axis=0), self.height)
        heights = self.height - invert_heights
        total_height = np.sum(heights)
        currs = heights[:-1]
        nexts = heights[1:]
        diffs = np.abs(currs - nexts)
        total_bumpiness = np.sum(diffs)
        return total_bumpiness, total_height

    def get_next_states(self):
        states = {}
        piece_id = self.ind
        curr_piece = [row[:] for row in self.piece]
        if piece_id == 0:  # O piece
            num_rotations = 1
        elif piece_id == 2 or piece_id == 3 or piece_id == 4:
            num_rotations = 2
        else:
            num_rotations = 4

        for i in range(num_rotations):
            valid_xs = self.width - len(curr_piece[0])
            for x in range(valid_xs + 1):
                piece = [row[:] for row in curr_piece]
                pos = {"x": x, "y": 0}
                while not self.check_collision(piece, pos):
                    pos["y"] += 1
                self.truncate(piece, pos)
                board = self.store(piece, pos)
                states[(x, i)] = self.get_state_properties(board)
            curr_piece = self.rotate(curr_piece)
        return states

    def get_current_board_state(self):
        board = [x[:] for x in self.board]
        for y in range(len(self.piece)):
            for x in range(len(self.piece[y])):
                board[y + self.current_pos["y"]][x + self.current_pos["x"]] = self.piece[y][x]
        return board

    def new_piece(self):
        if not len(self.bag):
            self.bag = list(range(len(self.pieces)))
            random.shuffle(self.bag)
        self.ind = self.bag.pop()
        self.piece = [row[:] for row in self.pieces[self.ind]]
        self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2,
                            "y": 0}
        if self.check_collision(self.piece, self.current_pos):
            self.gameover = True

    def check_collision(self, piece, pos):
        future_y = pos["y"] + 1
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if future_y + y > self.height - 1 or self.board[future_y + y][pos["x"] + x] and piece[y][x]:
                    return True
        return False

    def truncate(self, piece, pos):
        gameover = False
        last_collision_row = -1
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x]:
                    if y > last_collision_row:
                        last_collision_row = y

        if pos["y"] - (len(piece) - last_collision_row) < 0 and last_collision_row > -1:
            while last_collision_row >= 0 and len(piece) > 1:
                gameover = True
                last_collision_row = -1
                del piece[0]
                for y in range(len(piece)):
                    for x in range(len(piece[y])):
                        if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x] and y > last_collision_row:
                            last_collision_row = y
        return gameover

    def store(self, piece, pos):
        board = [x[:] for x in self.board]
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if piece[y][x] and not board[y + pos["y"]][x + pos["x"]]:
                    board[y + pos["y"]][x + pos["x"]] = piece[y][x]
        return board

    def check_cleared_rows(self, board):
        to_delete = []
        for i, row in enumerate(board[::-1]):
            if 0 not in row:
                to_delete.append(len(board) - 1 - i)
        if len(to_delete) > 0:
            board = self.remove_row(board, to_delete)
        return len(to_delete), board

    def remove_row(self, board, indices):
        for i in indices[::-1]:
            del board[i]
            board = [[0 for _ in range(self.width)]] + board
        return board

    def step(self, action, render=True, video=None):
        x, num_rotations = action
        self.current_pos = {"x": x, "y": 0}
        for _ in range(num_rotations):
            self.piece = self.rotate(self.piece)

        while not self.check_collision(self.piece, self.current_pos):
            self.current_pos["y"] += 1
            if render:
                self.render(video)

        overflow = self.truncate(self.piece, self.current_pos)
        if overflow:
            self.gameover = True

        self.board = self.store(self.piece, self.current_pos)

        lines_cleared, self.board = self.check_cleared_rows(self.board)
        score = 1 + (lines_cleared ** 2) * self.width
        self.score += score
        self.tetrominoes += 1
        self.cleared_lines += lines_cleared
        if not self.gameover:
            self.new_piece()
        if self.gameover:
            self.score -= 2

        return score, self.gameover

    def render(self, video=None):
        if not self.gameover:
            img = [self.piece_colors[p] for row in self.get_current_board_state() for p in row]
        else:
            img = [self.piece_colors[p] for row in self.board for p in row]
        img = np.array(img).reshape((self.height, self.width, 3)).astype(np.uint8)
        img = img[..., ::-1]
        img = Image.fromarray(img, "RGB")
        img = img.resize((self.width * self.block_size, self.height * self.block_size), 0)
        img = np.array(img)
        img[[i * self.block_size for i in range(self.height)], :, :] = 0
        img[:, [i * self.block_size for i in range(self.width)], :] = 0

        img = np.concatenate((img, self.extra_board), axis=1)

        cv2.putText(img, "Score:", (self.width * self.block_size + int(self.block_size / 2), self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.score),
                    (self.width * self.block_size + int(self.block_size / 2), 2 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        cv2.putText(img, "Pieces:", (self.width * self.block_size + int(self.block_size / 2), 4 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.tetrominoes),
                    (self.width * self.block_size + int(self.block_size / 2), 5 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        cv2.putText(img, "Lines:", (self.width * self.block_size + int(self.block_size / 2), 7 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.cleared_lines),
                    (self.width * self.block_size + int(self.block_size / 2), 8 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        if video:
            video.write(img)

        cv2.imshow("Deep Q-Learning Tetris", img)
        cv2.waitKey(1)
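
The state handed to the network is the 4-tuple returned by get_state_properties (lines cleared, holes, bumpiness, total column height), and each step rewards 1 + lines_cleared squared times the board width, with a 2-point penalty when the game ends. A quick way to sanity-check the environment is to drive it with random actions; the sketch below is throwaway code and assumes the class above is saved as src/tetris.py, matching the import used by the test and train scripts that follow:

import random
from src.tetris import Tetris

env = Tetris(width=10, height=20, block_size=20)
env.reset()
done = False
while not done:
    action = random.choice(list(env.get_next_states().keys()))   # random (x, rotation)
    reward, done = env.step(action, render=False)
print("score:", env.score, "pieces:", env.tetrominoes, "lines:", env.cleared_lines)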

test

import argparse
import torch
import cv2
from src.tetris import Tetris


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Tetris""")
    parser.add_argument("--width", type=int, default=10, help="The common width for all images")
    parser.add_argument("--height", type=int, default=20, help="The common height for all images")
    parser.add_argument("--block_size", type=int, default=30, help="Size of a block")
    parser.add_argument("--fps", type=int, default=300, help="frames per second")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--output", type=str, default="output.mp4")

    args = parser.parse_args()
    return args


def run_test(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if torch.cuda.is_available():
        model = torch.load("{}/tetris".format(opt.saved_path))
    else:
        model = torch.load("{}/tetris".format(opt.saved_path), map_location=lambda storage, loc: storage)
    model.eval()
    env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)
    env.reset()
    if torch.cuda.is_available():
        model.cuda()
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps,
                          (int(1.5 * opt.width * opt.block_size), opt.height * opt.block_size))
    while True:
        next_steps = env.get_next_states()
        next_actions, next_states = zip(*next_steps.items())
        next_states = torch.stack(next_states)
        if torch.cuda.is_available():
            next_states = next_states.cuda()
        predictions = model(next_states)[:, 0]
        index = torch.argmax(predictions).item()
        action = next_actions[index]
        _, done = env.step(action, render=True, video=out)

        if done:
            out.release()
            break


if __name__ == "__main__":
    opt = get_args()
    run_test(opt)
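
A couple of practical notes on this script. torch.load restores a fully pickled model object here, so the DeepQNetwork class must be importable under the same module path it was saved from; the map_location argument handles loading on a CPU-only machine. Assuming the script is saved as test.py in the project root (the post does not show the file name), a run might look like this:

python test.py --saved_path trained_models --output output.mp4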

train

import argparse
import os
import shutil
from random import random, randint, sample

import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

from src.deep_q_network import DeepQNetwork
from src.tetris import Tetris
from collections import deque


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Tetris""")
    parser.add_argument("--width", type=int, default=10, help="The common width for all images")
    parser.add_argument("--height", type=int, default=20, help="The common height for all images")
    parser.add_argument("--block_size", type=int, default=30, help="Size of a block")
    parser.add_argument("--batch_size", type=int, default=512, help="The number of images per batch")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=1)
    parser.add_argument("--final_epsilon", type=float, default=1e-3)
    parser.add_argument("--num_decay_epochs", type=float, default=2000)
    parser.add_argument("--num_epochs", type=int, default=3000)
    parser.add_argument("--save_interval", type=int, default=1000)
    parser.add_argument("--replay_memory_size", type=int, default=30000,
                        help="Number of epoches between testing phases")
    parser.add_argument("--log_path", type=str, default="tensorboard")
    parser.add_argument("--saved_path", type=str, default="trained_models")

    args = parser.parse_args()
    return args


def train(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    writer = SummaryWriter(opt.log_path)
    env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)
    model = DeepQNetwork()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.MSELoss()

    state = env.reset()
    if torch.cuda.is_available():
        model.cuda()
        state = state.cuda()

    replay_memory = deque(maxlen=opt.replay_memory_size)
    epoch = 0
    while epoch < opt.num_epochs:
        next_steps = env.get_next_states()
        # Exploration or exploitation
        epsilon = opt.final_epsilon + (max(opt.num_decay_epochs - epoch, 0) * (
                opt.initial_epsilon - opt.final_epsilon) / opt.num_decay_epochs)
        u = random()
        random_action = u <= epsilon
        next_actions, next_states = zip(*next_steps.items())
        next_states = torch.stack(next_states)
        if torch.cuda.is_available():
            next_states = next_states.cuda()
        model.eval()
        with torch.no_grad():
            predictions = model(next_states)[:, 0]
        model.train()
        if random_action:
            index = randint(0, len(next_steps) - 1)
        else:
            index = torch.argmax(predictions).item()

        next_state = next_states[index, :]
        action = next_actions[index]

        reward, done = env.step(action, render=True)

        if torch.cuda.is_available():
            next_state = next_state.cuda()
        replay_memory.append([state, reward, next_state, done])
        if done:
            final_score = env.score
            final_tetrominoes = env.tetrominoes
            final_cleared_lines = env.cleared_lines
            state = env.reset()
            if torch.cuda.is_available():
                state = state.cuda()
        else:
            state = next_state
            continue
        if len(replay_memory) < opt.replay_memory_size / 10:
            continue
        epoch += 1
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, reward_batch, next_state_batch, done_batch = zip(*batch)
        state_batch = torch.stack(tuple(state for state in state_batch))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.stack(tuple(state for state in next_state_batch))

        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()

        q_values = model(state_batch)
        model.eval()
        with torch.no_grad():
            next_prediction_batch = model(next_state_batch)
        model.train()

        y_batch = torch.cat(
            tuple(reward if done else reward + opt.gamma * prediction for reward, done, prediction in
                  zip(reward_batch, done_batch, next_prediction_batch)))[:, None]

        optimizer.zero_grad()
        loss = criterion(q_values, y_batch)
        loss.backward()
        optimizer.step()

        print("Epoch: {}/{}, Action: {}, Score: {}, Tetrominoes {}, Cleared lines: {}".format(
            epoch,
            opt.num_epochs,
            action,
            final_score,
            final_tetrominoes,
            final_cleared_lines))
        writer.add_scalar('Train/Score', final_score, epoch - 1)
        writer.add_scalar('Train/Tetrominoes', final_tetrominoes, epoch - 1)
        writer.add_scalar('Train/Cleared lines', final_cleared_lines, epoch - 1)

        if epoch > 0 and epoch % opt.save_interval == 0:
            torch.save(model, "{}/tetris_{}".format(opt.saved_path, epoch))

    torch.save(model, "{}/tetris".format(opt.saved_path))


if __name__ == "__main__":
    opt = get_args()
    train(opt)
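
Two details of the training loop are easy to miss. Gradient updates only start once the replay memory holds at least replay_memory_size / 10 transitions, and each "epoch" is a single gradient step on a batch of at most batch_size sampled transitions; the target for a non-terminal transition is reward plus gamma times the network's value of the stored next state, while terminal transitions use the raw reward. Exploration follows a linear epsilon decay, which can be reproduced in isolation like this (an illustrative snippet using the default hyperparameters above):

# Linear epsilon decay from initial_epsilon to final_epsilon over num_decay_epochs,
# then held constant (same formula as in train() above).
def epsilon_at(epoch, initial=1.0, final=1e-3, decay_epochs=2000):
    return final + max(decay_epochs - epoch, 0) * (initial - final) / decay_epochs

print(epsilon_at(0))      # ~1.0   -> almost always explore at the start
print(epsilon_at(1000))   # ~0.5
print(epsilon_at(2500))   # ~0.001 -> almost always exploit after decay ends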
