import random
import numpy as np
import pandas as pd
import torch
from transformers import BertModel,BertTokenizer
from tqdm.auto import tqdm
from torch.utils.data import Dataset
import os
import re
"""参考Game-On论文"""
"""util.py"""
def set_seed(seed_value=42):
    """Seed all RNGs (Python, NumPy, PyTorch CPU and GPU) for reproducibility."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
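# Usage sketch (not in the original script): call set_seed() once, before the
# model and datasets are built, so runs are repeatable, e.g.:
#   set_seed(42)
#   torch.rand(3)  # identical values across runs with the same seed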
"""util.py""""""文本预处理-textGraph.py"""
def text_preprocessing(text):
    """
    - Remove entity mentions (e.g. '@united'), URLs, and '#' characters
    - Correct HTML escapes (e.g. '&amp;' to '&')
    @param text (str): a string to be processed.
    @return text (str): the processed string.
    """
    text = re.sub(r'(@.*?)[\s]', ' ', text)
    text = re.sub(r'&amp;', '&', text)
    text = re.sub(r'\s+', ' ', text).strip()
    text = re.sub(r'(?P<url>https?://[^\s]+)', r'', text)
    text = re.sub(r"\@(\w+)", "", text)
    text = text.replace('#', '')
    return text


class TextDataset(Dataset):
    """Wraps a tweet DataFrame; each item is one tokenized tweet."""

    def __init__(self, df, tokenizer):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        text = self.df['tweetText'][idx]
        unique_id = self.df['tweetId'][idx]
        encoded_sent = self.tokenizer.encode_plus(
            text=text_preprocessing(text),
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
        )
        input_ids = torch.tensor(encoded_sent.get('input_ids'))
        attention_mask = torch.tensor(encoded_sent.get('attention_mask'))
        return {'input_ids': input_ids,
                'attention_mask': attention_mask,
                'unique_id': unique_id}


def store_data(bert, device, df, dataset, store_dir):
    """Run BERT over every tweet in df and save its per-token and [CLS]
    embeddings as .npy files under store_dir; return the token counts."""
    lengths = []
    bert.eval()
    for idx in tqdm(range(len(df))):
        sample = dataset[idx]
        print('Shapes of sample["input_ids"] and sample["attention_mask"]:',
              sample['input_ids'].shape, sample['attention_mask'].shape)
        input_ids = sample['input_ids'].unsqueeze(0).to(device)
        attention_mask = sample['attention_mask'].unsqueeze(0).to(device)
        unique_id = sample['unique_id']
        num_tokens = attention_mask.sum().detach().cpu().item()
        # Inference only: torch.no_grad() disables gradient tracking, so no
        # computation graph is built (and no weights are updated).
        with torch.no_grad():
            out = bert(input_ids=input_ids, attention_mask=attention_mask)
        # Per-token embeddings of the real tokens, excluding the leading [CLS].
        out_tokens = out.last_hidden_state[:, 1:num_tokens, :].detach().cpu().squeeze(0).numpy()
        filename = f'{store_dir}{unique_id}.npy'
        np.save(filename, out_tokens)
        print(f'Saved {filename}')
        lengths.append(num_tokens)
        # [CLS] embedding as a representation of the full text.
        out_cls = out.last_hidden_state[:, 0, :].detach().cpu().numpy()
        filename = f'{store_dir}{unique_id}_full_text.npy'
        np.save(filename, out_cls)
        print(f'Saved {filename}')
    return lengths


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    root_dir = "./dataset/image-verification-corpus-master/image-verification-corpus-master/mediaeval2015/"
    emed_dir = './Embedding_File/'
    os.makedirs(emed_dir, exist_ok=True)  # create the output directory up front
    train_csv_name = "tweetsTrain.csv"
    test_csv_name = "tweetsTest.csv"

    tokenizer = BertTokenizer.from_pretrained('./bert/')
    bert = BertModel.from_pretrained('./bert/', return_dict=True)
    bert = bert.to(device)

    df_train = pd.read_csv(f'{root_dir}{train_csv_name}')
    df_train = df_train.dropna().reset_index(drop=True)
    train_dataset = TextDataset(df_train, tokenizer)
    lengths = store_data(bert, device, df_train, train_dataset, emed_dir)

    df_test = pd.read_csv(f'{root_dir}{test_csv_name}')
    df_test = df_test.dropna().reset_index(drop=True)
    test_dataset = TextDataset(df_test, tokenizer)
    lengths = store_data(bert, device, df_test, test_dataset, emed_dir)
"""Text preprocessing - textGraph.py"""
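# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script): how a downstream component
# could read the files store_data() writes. `example_id` is a hypothetical
# tweetId; the path assumes the emed_dir used above.
# ---------------------------------------------------------------------------
#   import numpy as np
#   example_id = '12345'
#   tokens  = np.load(f'./Embedding_File/{example_id}.npy')            # (num_tokens - 1, 768)
#   cls_vec = np.load(f'./Embedding_File/{example_id}_full_text.npy')  # (1, 768)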