2024Datawhale AI夏令营---Inclusion・The Global Multimedia Deepfake Detection--学习笔记

赛题背景:

        其实总结起来就是一句话,这个项目是基于目前的深度伪装技术,就是通过大量人脸的原数据集进行模型训练之后,能够生成伪造的人脸视频。这项目就是教我们如何去实现这个DeepFake技术。

Task1:了解Deepfake和跑通baseline

代码架构如下:

  1. 模型定义:使用timm库创建一个预训练的resnet18模型。

  2. 训练/验证数据加载:使用torch.utils.data.DataLoader来加载训练集和验证集数据,并通过定义的transforms进行数据增强。

  3. 训练与验证过程

    1. 定义了train函数来执行模型在一个epoch上的训练过程,包括前向传播、损失计算、反向传播和参数更新。

    2. 定义了validate函数来评估模型在验证集上的性能,计算准确率。

  4. 性能评估:使用准确率(Accuracy)作为性能评估的主要指标,并在每个epoch后输出验证集上的准确率。

  5. 提交:最后,将预测结果保存到CSV文件中,准备提交到Kaggle比赛。

代码解释如下:

详见代码注释吧

这份代码后续还是要好好精读理解一下的吧,好好分析一下,顺便提升一下代码能力。--7.15

from PIL import Image
Image.open('/kaggle/input/deepfake/phase1/trainset/63fee8a89581307c0b4fd05a48e0ff79.jpg')import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = Trueimport torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import timm
import timeimport pandas as pd
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm_notebooktrain_label = pd.read_csv('/kaggle/input/deepfake/phase1/trainset_label.txt')
val_label = pd.read_csv('/kaggle/input/deepfake/phase1/valset_label.txt')train_label['path'] = '/kaggle/input/deepfake/phase1/trainset/' + train_label['img_name']
val_label['path'] = '/kaggle/input/deepfake/phase1/valset/' + val_label['img_name']train_label['target'].value_counts()val_label['target'].value_counts()train_label.head(10)class AverageMeter(object):"""Computes and stores the average and current value"""def __init__(self, name, fmt=':f'):self.name = nameself.fmt = fmtself.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.countdef __str__(self):fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'return fmtstr.format(**self.__dict__)class ProgressMeter(object):def __init__(self, num_batches, *meters):self.batch_fmtstr = self._get_batch_fmtstr(num_batches)self.meters = metersself.prefix = ""def pr2int(self, batch):entries = [self.prefix + self.batch_fmtstr.format(batch)]entries += [str(meter) for meter in self.meters]print('\t'.join(entries))def _get_batch_fmtstr(self, num_batches):num_digits = len(str(num_batches // 1))fmt = '{:' + str(num_digits) + 'd}'return '[' + fmt + '/' + fmt.format(num_batches) + ']'def validate(val_loader, model, criterion):batch_time = AverageMeter('Time', ':6.3f')losses = AverageMeter('Loss', ':.4e')top1 = AverageMeter('Acc@1', ':6.2f')progress = ProgressMeter(len(val_loader), batch_time, losses, top1)# switch to evaluate modemodel.eval()with torch.no_grad():end = time.time()for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):input = input.cuda()target = target.cuda()# compute outputoutput = model(input)loss = criterion(output, target)# measure accuracy and record lossacc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100losses.update(loss.item(), input.size(0))top1.update(acc, input.size(0))# measure elapsed timebatch_time.update(time.time() - end)end = time.time()# TODO: this should also be done with the ProgressMeterprint(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))return top1def predict(test_loader, model, tta=10):# switch to evaluate modemodel.eval()test_pred_tta = Nonefor _ in range(tta):test_pred = []with torch.no_grad():end = time.time()for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):input = input.cuda()target = target.cuda()# compute outputoutput = model(input)output = F.softmax(output, dim=1)output = output.data.cpu().numpy()test_pred.append(output)test_pred = np.vstack(test_pred)if test_pred_tta is None:test_pred_tta = test_predelse:test_pred_tta += test_predreturn test_pred_ttadef train(train_loader, model, criterion, optimizer, epoch):batch_time = AverageMeter('Time', ':6.3f')losses = AverageMeter('Loss', ':.4e')top1 = AverageMeter('Acc@1', ':6.2f')progress = ProgressMeter(len(train_loader), batch_time, losses, top1)# switch to train modemodel.train()end = time.time()for i, (input, target) in enumerate(train_loader):input = input.cuda(non_blocking=True)target = target.cuda(non_blocking=True)# compute outputoutput = model(input)loss = criterion(output, target)# measure accuracy and record losslosses.update(loss.item(), input.size(0))acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100top1.update(acc, input.size(0))# compute gradient and do SGD stepoptimizer.zero_grad()loss.backward()optimizer.step()# measure elapsed timebatch_time.update(time.time() - end)end = time.time()if i % 100 == 0:progress.pr2int(i)class FFDIDataset(Dataset):def __init__(self, img_path, img_label, transform=None):self.img_path = img_pathself.img_label = img_labelif transform is not None:self.transform = transformelse:self.transform = Nonedef __getitem__(self, index):img = Image.open(self.img_path[index]).convert('RGB')if self.transform is not None:img = self.transform(img)return img, torch.from_numpy(np.array(self.img_label[index]))def __len__(self):return len(self.img_path)import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=2)
model = model.cuda()train_loader = torch.utils.data.DataLoader(FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000), transforms.Compose([transforms.Resize((256, 256)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)val_loader = torch.utils.data.DataLoader(FFDIDataset(val_label['path'].head(1000), val_label['target'].head(1000), transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
best_acc = 0.0
for epoch in range(2):scheduler.step()print('Epoch: ', epoch)train(train_loader, model, criterion, optimizer, epoch)val_acc = validate(val_loader, model, criterion)if val_acc.avg.item() > best_acc:best_acc = round(val_acc.avg.item(), 2)torch.save(model.state_dict(), f'./model_{best_acc}.pt')test_loader = torch.utils.data.DataLoader(FFDIDataset(val_label['path'], val_label['target'], transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)val_label['y_pred'] = predict(test_loader, model, 1)[:, 1]
val_label[['img_name', 'y_pred']].to_csv('submit.csv', index=None)

        本来是想直接在本地运行的,但是这个数据集实在是太大了,受限于操作和设备,只能在kaggle云运行这个代码咯,结果如下:

        提交上kaggle进行评分:

Inclusion・The Global Multimedia Deepfake Detection | Kaggle

        结果挺差的,毕竟这就是个普通的原始代码,啥都没优化的,参数也没调,只能先这样咯,后续task再来优化调整咔咔上分吧。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/46984.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv8目标跟踪deepsort

原文:YOLOv8目标跟踪deepsort - 知乎 (zhihu.com) 一、YOLOV8 tracking 参考文章: Ctrl CV:YoloV5 + DeepSort1 赞同 0 评论文章 二、行人重识别(ReID) ——Market-1501 数据集 2.1、数据集简介 Market-1501 数据集在清华大学校园中采集,夏天拍摄,在 2015 年构建并…

【PyTorch快速入门教程】02 Jupyter notebook安装及配置

文章目录 1 安装 Jupyter notebook2 安装 ipykernel3 更改 jupyter 默认配置3.1 生成配置文件3.2 关键配置信息 4 扩展插件推荐参考 1 安装 Jupyter notebook 一行命令搞定 python -m pip install jupyter 现在就可以打开Jupyter notebook来运行python啦。 jupyter notebook…

去除重复字母

题目链接 去除重复字母 题目描述 注意点 s 由小写英文字母组成1 < s.length < 10^4需保证 返回结果的字典序最小&#xff08;要求不能打乱其他字符的相对位置&#xff09; 解答思路 本题与移掉 K 位数字类似&#xff0c;需要注意的是&#xff0c;并不是每个字母都能…

Windows安装Pycharm及汉化教程

在安装好了Python之后呢&#xff0c;我们需要更方便的进行编写代码&#xff0c;使用Python自带的IDLE和命令行是不太友好的。 那么有没有一款免费好用的写代码工具呢&#xff1f;答案是有的&#xff01; PyCharm 是由 JetBrains 打造的一款 Python IDE&#xff0c;提供代码分析…

SQL常用数据过滤---IN操作符

在SQL中&#xff0c;IN操作符常用于过滤数据&#xff0c;允许在WHERE子句中指定多个可能的值。如果列中的值匹配IN操作符后面括号中的任何一个值&#xff0c;那么该行就会被选中。 以下是使用IN操作符的基本语法&#xff1a; SELECT column1, column2, ... FROM table_name WH…

本地多模态看图说话-llava

其中图片为bast64转码&#xff0c;方便json序列化。 其中模型llava为本地ollama运行的模型&#xff0c;如&#xff1a;ollama run llava 还有其它的模型如&#xff1a;llava-phi3&#xff0c;通过phi3微调过的版本。 实际测试下来&#xff0c;发现本地多模型的性能不佳&…

怎么将几个pdf合成为一个pdf?几个合并PDF文件的方法

怎么将几个pdf合成为一个pdf&#xff1f;当需要将多个PDF文件合并成一个单一的PDF文件时&#xff0c;这种操作不仅能够提高文件管理的效率&#xff0c;还能使得相关文档更加集中和易于访问。合并PDF的过程不仅仅是简单地将几个文件结合在一起&#xff0c;更是将信息整合成一个更…

遥感降水评估

遥感降水可以作为地面雨量计和雷达观测降水的补充&#xff0c;在偏远山区和缺资料地区更为适合。目前&#xff0c;学界有多种降水数据&#xff0c;每一种降水数据都有独特的方法制作。因此&#xff0c;在使用前需要对这些降水的可靠性进行评估。在获得误差基础上&#xff0c;方…

Apollo 常见math库学习

1 Vec2d 向量表示point vec2d.h #pragma once // 定义二维向量类 #include <cmath> #include <string>/*** namespace apollo::common::math* brief apollo::common::math*/ namespace apollo { namespace common { namespace math {constexpr double kMathEpsil…

刷题日志——模拟专题(python实现)

模拟往往不需要设计太多的算法&#xff0c;而是要按照题目的要求尽可能用代码表示出题目的旨意。 以下是蓝桥杯官网模拟专题的选题&#xff0c;大多数比较基础&#xff0c;但是十分适合新手入门&#xff1a; 一. 可链接在线OJ题 饮料换购图像模糊螺旋矩阵冰雹数回文日期长草最…

华为以客户为中心的战略

2005年&#xff0c;伴随着国际化步伐的加快&#xff0c;华为重新梳理了自己的愿景、使命和发展战略&#xff0c;提出了以客户为中心的战略定位&#xff1a; 为客户服务是华为存在的唯一理由&#xff1b;客户需求是华为发展的原动力。质量好、服务好、运作成本低&#xff0c;优…

mac安装win10到外接固态硬盘

1、制作win10系统 1.1 下载 winToUSB&#xff0c;打开后选择第一个 1.2 选择本地下载镜像&#xff0c; 我用的分区方案是适用于UEFI的GPT模式 1.3 点右下角执行&#xff0c;等待执行完成即可 2、mac系统下载win驱动 2.1 comman空格 搜索启动转换助理&#xff0c;打开后选择…

前端框架入门之Vue _el和data的两种写法 分析MVVM模型

目录 _el与data的两种写法 MVVM模型 _el与data的两种写法 查看vue的实例对象 我们在这边注释掉了el属性 这样的话div容器就绑定不了vue实例 当我们可以在这里写一个定时任务 然后再回头指定 这个mount有挂载的意思 就是把容器对象交给vue实例后 去给他挂载指定的对象 &…

深入解析HTTPS与HTTP

在当今数字化时代&#xff0c;网络安全已成为社会各界关注的焦点。随着互联网技术的飞速发展&#xff0c;个人和企业的数据安全问题日益凸显。在此背景下&#xff0c;HTTPS作为一种更加安全的通信协议&#xff0c;逐渐取代了传统的HTTP协议&#xff0c;成为保护网络安全的重要屏…

【概率论三】参数估计

文章目录 一. 点估计1. 矩估计法2. 极大似然法1. 似然函数2. 极大似然估计 3. 评价估计量的标准2.1. 无偏性2.2. 有效性2.3. 一致性 三. 区间估计1. 区间估计的概念2. 正态总体参数的区间估计 参数估计讲什么 由样本来确定未知参数参数估计分为点估计与区间估计 一. 点估计 所…

IDEA启动Web项目总是提示端口占用

文章目录 IDEA启动Web项目总是提示端口占用一、前言1.场景2.环境 二、正文1.场景一:真端口占用2. 场景二:假端口占用 IDEA启动Web项目总是提示端口占用 一、前言 1.场景 IDEA启动Web项目总是提示端口占用&#xff1a; 确实是端口被占用&#xff0c;比如&#xff1a;没有正常…

clion中建立c文件工程,读取或创建sqlite3数据库文件

1.首先前往SQLite官网下载sqlite3所需文件 SQLite Download Page 2.解压文件&#xff0c;将其中的sqlite3.c和sqlite3.h拷贝到你对应的文件工程中 3.修改CMakeLists.txt文件&#xff0c;添加编译选项及连接文件 4.运行代码及查询数据库文件

【NLP自然语言处理】基于BERT实现文本情感分类

Bert概述 BERT&#xff08;Bidirectional Encoder Representations from Transformers&#xff09;是一种深度学习模型&#xff0c;用于自然语言处理&#xff08;NLP&#xff09;任务。BERT的核心是由一种强大的神经网络架构——Transformer驱动的。这种架构包含了一种称为自注…

【Mamba】Mamba的部署

ubuntu系统安装11.6版本的cuda 可以参考这两篇博客 ubuntu22.04多版本安装cuda及快速切换&#xff08;cuda11.1和11.8&#xff09;_ubuntu调整cuda版本 【Linux】在一台机器上同时安装多个版本的CUDA&#xff08;切换CUDA版本&#xff09;_linux安装多个cuda 安装CUDA https…

【React打卡学习第一天】

React入门 一、简介二、基本使用1.引入相关js库2.babel.js的作用 二、创建虚拟DOM三、JSX&#xff08;JavaScript XML&#xff09;1.本质2.作用3.基本语法规则定义虚拟DOM时&#xff0c;不要写引号。标签中混入JS表达式时要用{}。样式的类名指定不要用class,要用className.内联…