论文笔记 Where Would I Go Next? Large Language Models as Human Mobility Predictor-CSDN博客
1 主函数
1.1 导入库
import os
import pickle
import time
import ast
import logging
from datetime import datetime
import pandas as pd
from openai import OpenAI

# OpenAI client (openai>=1.x). With no api_key argument the SDK reads the key
# from the OPENAI_API_KEY environment variable, so no secret lives in the source.
client = OpenAI()
1.2 参数设置
dataname = "geolife"  # dataset name: "geolife" or "fsq"
num_historical_stay = 40  # span of long-term mobility (number of historical stays, M)
num_context_stay = 5  # span of short-term mobility (number of context stays, N)
top_k = 10  # number of candidate locations the model is asked to output
with_time = False  # whether to include the target stay's time information in the prompt

# Sleep time (seconds) between queries. After recent updates the API is much
# more reliable, so this can be kept small.
sleep_single_query = 0.1

# Sleep time (seconds) before retrying when the server / API call fails.
sleep_if_crash = 1

# Run tag derived from the settings above, so changing top_k / with_time cannot
# silently write into (or resume from) another configuration's directory.
# With the defaults this is "top10_wot" ("wot" = without time).
run_tag = f"top{top_k}_{'wt' if with_time else 'wot'}"

# Output path for the per-user prediction CSVs.
output_dir = f"output/{dataname}/{run_tag}"

# Directory for log files.
log_dir = f"logs/{dataname}/{run_tag}"
1.3 读取数据集(训练+验证集与测试集)
# Load the merged train+valid DataFrame and the pickled test samples.
# For geolife get_dataset prints: "Number of total test sample: 3459".
# Per the original note, this count is smaller than the number of rows in
# f"data/{dataname}/{dataname}_test.csv".
tv_data, test_file = get_dataset(dataname)
1.4 日志文件生成
# Logger that writes bare messages to the console and to a timestamped file under log_dir.
logger = get_logger(logger_name='my_logger', log_dir=log_dir)
1.5 user id 提取
# User ids that do not yet have a result CSV in output_dir.
uid_list = get_unqueried_user(dataname, output_dir)
print(f"uid_list: {uid_list}")
# Example console output for geolife on a fresh run (empty output_dir):
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
# Number of the remaining id: 45
# uid_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
1.6 为所有待处理用户生成并发送 query
# Query the model for every remaining user; each user's results go to output_dir/{uid:02d}.csv.
query_all_user(
    client,
    dataname,
    uid_list,
    logger,
    tv_data,
    num_historical_stay,
    num_context_stay,
    test_file,
    top_k=top_k,
    is_wt=with_time,
    output_dir=output_dir,
    sleep_query=sleep_single_query,
    sleep_crash=sleep_if_crash,
)
2 get_dataset函数
def get_dataset(dataname):
    """Load the merged train+valid data and the test samples for *dataname*.

    Reads, relative to the current working directory:
      - data/{dataname}/{dataname}_train.csv
      - data/{dataname}/{dataname}_valid.csv
      - data/{dataname}/{dataname}_testset.pk  (pickled test samples)

    Args:
        dataname: dataset name, e.g. "geolife" or "fsq".

    Returns:
        (tv_data, test_file):
          tv_data   -- train and valid rows concatenated and sorted by
                       ['user_id', 'start_day', 'start_min']; for geolife the
                       'duration' column is cast to int.
          test_file -- the unpickled test set (a list of dicts per the original notes).
    """
    # Training + validation set.
    train_data = pd.read_csv(f"data/{dataname}/{dataname}_train.csv")
    valid_data = pd.read_csv(f"data/{dataname}/{dataname}_valid.csv")

    # Test set. NOTE(review): pickle.load can execute arbitrary code -- only
    # load test sets that come from a trusted source.
    with open(f"data/{dataname}/{dataname}_testset.pk", "rb") as f:
        test_file = pickle.load(f)

    # Merge train and valid, ordered per user chronologically (day, then minute).
    tv_data = pd.concat([train_data, valid_data], ignore_index=True)
    tv_data.sort_values(['user_id', 'start_day', 'start_min'], inplace=True)
    if dataname == 'geolife':
        tv_data['duration'] = tv_data['duration'].astype(int)

    print("Number of total test sample: ", len(test_file))
    return tv_data, test_file
3 get_logger
def get_logger(logger_name, log_dir='logs/'):
    """Return a DEBUG-level logger that writes bare messages to the console and to a file.

    The file is log_dir/log_file<YYYYMMDD_HHMMSS>.log, e.g. log_file20230424_153000.log.
    log_dir is created if missing.

    Calling this again with the same *logger_name* reuses the same logger
    (logging.getLogger returns one instance per name). Its existing handlers
    are closed and replaced instead of being stacked, so messages are not
    duplicated.

    Args:
        logger_name: name passed to logging.getLogger.
        log_dir: directory for the log file.

    Returns:
        The configured logging.Logger.
    """
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)

    # Drop handlers left over from a previous call with the same name;
    # otherwise every message would be printed/written once per call.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Message body only, no timestamp/level prefix.
    formatter = logging.Formatter('%(message)s')

    # Console handler.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(formatter)

    # File handler with a timestamped file name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(log_dir, 'log_file' + timestamp + '.log')
    file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    return logger
4 get_unqueried_user
提取数据集中尚未处理的 user id(即输出目录中还没有对应 `{uid:02d}.csv` 结果文件的用户)
def get_unqueried_user(dataname, output_dir='output/'):
    """Return the user ids of *dataname* that have no result CSV in *output_dir* yet.

    User ids run from 1 to 45 (geolife) or 1 to 535 (fsq). A file named
    "<digits>.csv" in output_dir (e.g. "07.csv") marks that user as processed.
    Other files are ignored. output_dir is created if missing. The remaining
    ids and their count are printed.

    Args:
        dataname: "geolife" or "fsq".
        output_dir: directory holding per-user result CSVs.

    Returns:
        Sorted list of remaining user ids.

    Raises:
        ValueError: if dataname is not a known dataset.
    """
    os.makedirs(output_dir, exist_ok=True)

    num_users_by_dataset = {"geolife": 45, "fsq": 535}
    if dataname not in num_users_by_dataset:
        raise ValueError(
            f"Unknown dataname {dataname!r}; expected one of {sorted(num_users_by_dataset)}"
        )
    all_user_id = range(1, num_users_by_dataset[dataname] + 1)

    # NOTE(review): an interrupted run leaves a partial CSV, and that user is
    # then treated as processed here and never resumed. Consider also
    # comparing row counts against the number of test samples.
    processed_id = {
        int(stem)
        for stem in (file.split('.')[0] for file in os.listdir(output_dir) if file.endswith('.csv'))
        if stem.isdigit()  # skip unrelated CSVs instead of crashing on int()
    }
    remain_id = [i for i in all_user_id if i not in processed_id]
    print(remain_id)
    print(f"Number of the remaining id: {len(remain_id)}")
    return remain_id
5 query_all_user
def query_all_user(client, dataname, uid_list, logger, train_data, num_historical_stay,
                   num_context_stay, test_file, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    """Run next-location prediction queries for every user in *uid_list*.

    For each user:
      1. get_user_data takes the user's last *num_historical_stay* rows of
         train_data (long-term mobility, M stays).
      2. organise_data builds, from those rows and *test_file*:
           historical_data -- long-term stays, shared by all of this user's test samples;
           predict_X       -- per test sample, the last *num_context_stay*
                              context stays plus the target stay's time;
           predict_y       -- per test sample, the ground-truth location id.
         Each stay is a tuple such as ('09:08 PM', 'Wednesday', 466, 10).
      3. single_user_query queries the model and saves the user's results
         under *output_dir*.

    Args:
        client: OpenAI client passed through to the query functions.
        dataname: "geolife" or "fsq".
        uid_list: user ids to process.
        logger: logger for progress messages.
        train_data: DataFrame of historical records (in this script, train+valid).
        num_historical_stay: M, number of long-term stays per user.
        num_context_stay: N, number of short-term context stays per test sample.
        test_file: list of test-sample dicts.
        top_k, is_wt, output_dir, sleep_query, sleep_crash: forwarded to single_user_query.
    """
    for uid in uid_list:
        logger.info(f"=================Processing user {uid}==================")
        user_train = get_user_data(train_data, uid, num_historical_stay, logger)
        historical_data, predict_X, predict_y = organise_data(
            dataname, user_train, test_file, uid, logger, num_context_stay
        )
        single_user_query(
            client, dataname, uid, historical_data, predict_X, predict_y, logger,
            top_k=top_k, is_wt=is_wt, output_dir=output_dir,
            sleep_query=sleep_query, sleep_crash=sleep_crash,
        )
5.1 get_user_data
提取当前研究的uid的长期历史mobility(M条)
def get_user_data(train_data, uid, num_historical_stay, logger):
    """Return the most recent *num_historical_stay* records of user *uid* (long-term mobility).

    tail() keeps the last rows in the frame's current order, so train_data
    should already be sorted chronologically per user.

    Args:
        train_data: DataFrame with a 'user_id' column.
        uid: user id to select.
        num_historical_stay: M, maximum number of records to keep.
        logger: logger for the total and kept record counts.

    Returns:
        DataFrame with at most num_historical_stay rows for this user
        (empty if the user has no records).
    """
    user_rows = train_data[train_data['user_id'] == uid]
    logger.info(f"Length of user {uid} train data: {len(user_rows)}")
    recent_rows = user_rows.tail(num_historical_stay)
    logger.info(f"Number of user historical stays: {len(recent_rows)}")
    return recent_rows
5.2 organise_data
返回这个user id的:
```长期mobility(不同的test数据共享)
```短期mobility(临近5段location)
```ground truth
每一条记录的格式是:('09:08 PM', 'Wednesday', 466, 10),
def organise_data(dataname, user_train, test_file, uid, logger, num_context_stay=5):
    """Build the prompt inputs for one user.

    Args:
        dataname: "geolife" or "fsq".
        user_train: DataFrame of the user's recent records (columns
            'start_min', 'weekday', 'location_id', plus 'duration' for geolife).
        test_file: list of test-sample dicts (keys 'user_X', 'start_min_X',
            'weekday_X', 'X', 'start_min_Y', 'weekday_Y', 'Y', plus 'dur_X' for geolife).
        uid: user id whose test samples are selected.
        logger: logger for the intermediate results.
        num_context_stay: N, number of most recent context stays per test sample.

    Returns:
        (historical_data, predict_X, predict_y):
          historical_data -- one tuple per row of user_train (long-term mobility,
              shared by all of this user's test samples):
              geolife: (time_of_day, day_of_week, duration, location_id)
              fsq:     (time_of_day, day_of_week, location_id)
              e.g. ('09:08 PM', 'Wednesday', 466, 10)
          predict_X -- one dict per test sample with 'context_stay' (last N
              stays, same tuple layout as historical_data) and 'target_stay'
              ((time_of_day, day_of_week, None, "<next_place_id>")).
          predict_y -- ground-truth location id per test sample, aligned with predict_X.

    Raises:
        ValueError: if dataname is neither "geolife" nor "fsq".
    """
    if dataname not in ('geolife', 'fsq'):
        raise ValueError(f"Unknown dataname {dataname!r}; expected 'geolife' or 'fsq'")
    is_geolife = dataname == 'geolife'

    # Long-term mobility: time as "hh:mm AM/PM", weekday name, [int duration], location id.
    historical_data = []
    for _, row in user_train.iterrows():
        time_of_day = convert_to_12_hour_clock(int(row['start_min']))
        day_of_week = int2dow(row['weekday'])
        if is_geolife:
            historical_data.append((time_of_day, day_of_week, int(row['duration']), row['location_id']))
        else:
            historical_data.append((time_of_day, day_of_week, row['location_id']))
    logger.info(f"historical_data: {historical_data}")
    logger.info(f"Number of historical_data: {len(historical_data)}")

    # This user's test samples. For geolife 'user_X' is indexed with [0]
    # (presumably a per-stay sequence of the same id); for fsq it is compared directly.
    list_user_dict = [
        i_dict for i_dict in test_file
        if (i_dict['user_X'][0] if is_geolife else i_dict['user_X']) == uid
    ]

    # Select the last num_context_stay entries. A plain [-n:] returns the WHOLE
    # sequence when n == 0, so that case is handled explicitly.
    last_n = slice(-num_context_stay, None) if num_context_stay else slice(0, 0)

    predict_X = []
    predict_y = []
    for i_dict in list_user_dict:
        times = [convert_to_12_hour_clock(int(m)) for m in i_dict['start_min_X'][last_n]]
        days = [int2dow(d) for d in i_dict['weekday_X'][last_n]]
        if is_geolife:
            durations = [int(d) for d in i_dict['dur_X'][last_n]]
            context = list(zip(times, days, durations, i_dict['X'][last_n]))
        else:
            context = list(zip(times, days, i_dict['X'][last_n]))
        # Target stay: known time and weekday, unknown duration (None), location to predict,
        # e.g. ('12:36 AM', 'Friday', None, '<next_place_id>').
        # NOTE(review): the same 4-tuple is used for fsq, whose context tuples have
        # 3 fields -- confirm the fsq prompts expect this layout.
        target = (
            convert_to_12_hour_clock(int(i_dict['start_min_Y'])),
            int2dow(i_dict['weekday_Y']),
            None,
            "<next_place_id>",
        )
        predict_X.append({'context_stay': context, 'target_stay': target})
        predict_y.append(i_dict['Y'])  # ground-truth next location id

    logger.info(f"Number of predict_data: {len(predict_X)}")
    logger.info(f"predict_y: {predict_y}")
    logger.info(f"Number of predict_y: {len(predict_y)}")
    return historical_data, predict_X, predict_y
5.2.1 convert_to_12_hour_clock
#转化成几点几分 AM/PM的形式
def convert_to_12_hour_clock(minutes):
    """Format minutes since midnight (0-1439) as a 12-hour clock string, e.g. "09:08 PM".

    Midnight is "12:00 AM" and noon is "12:00 PM".

    For out-of-range input the sentinel string
    "Invalid input. Minutes should be between 0 and 1439." is returned rather
    than raising. This preserves the original contract for existing callers.

    Args:
        minutes: int minute of the day, counted from 00:00.

    Returns:
        "hh:mm AM" / "hh:mm PM", or the sentinel string above.
    """
    if minutes < 0 or minutes >= 1440:
        return "Invalid input. Minutes should be between 0 and 1439."
    hours, mins = divmod(minutes, 60)
    period = "PM" if hours >= 12 else "AM"
    hours = hours % 12 or 12  # 0 -> 12 AM, 12 -> 12 PM, 13..23 -> 1..11 PM
    return f"{hours:02d}:{mins:02d} {period}"
5.2.2 int2dow
#转化成星期几的形式
# Day-of-week index (0 = Monday) -> English day name.
_DOW_NAMES = {0: 'Monday', 1: 'Tuesday', 2: 'Wednesday', 3: 'Thursday',
              4: 'Friday', 5: 'Saturday', 6: 'Sunday'}


def int2dow(int_day):
    """Return the English day name for *int_day* (0 = Monday ... 6 = Sunday).

    Raises:
        KeyError: if int_day is not in 0..6.
    """
    return _DOW_NAMES[int_day]
5.3 single_user_query
逐条调用 LLM 查询该用户的测试样本,并把 location 预测结果保存到 CSV(已保存的结果会被读取,从中断处继续)
def _pick_single_query_fn(dataname, is_wt, top_k):
    """Return the single_query_* prompt function for this dataset / time-flag / top_k combination.

    Raises:
        ValueError: if top_k is not 1 or 10, or dataname is not 'geolife' or 'fsq'.
    """
    if top_k not in (1, 10):
        raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
    use_time = is_wt is True  # only an explicit True selects the with-time prompts
    if dataname == 'geolife':
        if use_time:
            return single_query_top1 if top_k == 1 else single_query_top10
        return single_query_top1_wot if top_k == 1 else single_query_top10_wot
    if dataname == 'fsq':
        if use_time:
            return single_query_top1_fsq if top_k == 1 else single_query_top10_fsq
        return single_query_top1_wot_fsq if top_k == 1 else single_query_top10_wot_fsq
    raise ValueError(f"Unknown dataname {dataname!r}; expected 'geolife' or 'fsq'")


def _query_with_retry(query_fn, client, historical_data, query_input, logger, sleep_crash, max_retries):
    """Call query_fn; on failure sleep *sleep_crash* seconds and retry, re-raising after *max_retries* attempts."""
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return query_fn(client, historical_data, query_input)
        except Exception as e:  # API/network errors vary by type; re-raised after the last attempt
            if attempt == attempts:
                raise
            logger.info(f"Query failed (attempt {attempt}/{attempts}): {e}. Retrying in {sleep_crash}s")
            time.sleep(sleep_crash)


def _parse_response(response, uid, ground_truth, top_k, logger, query_idx):
    """Parse the model reply (a Python-literal dict) into a result row.

    For top_k != 1 the 'prediction' value is stringified so that a list fits
    in a single CSV cell. If parsing fails, the row gets prediction=-100 and
    reason=None.
    """
    try:
        res_dict = ast.literal_eval(response)  # literal_eval only evaluates literals, never code
        if top_k != 1:
            res_dict['prediction'] = str(res_dict['prediction'])
        res_dict['user_id'] = uid
        res_dict['ground_truth'] = ground_truth
    except Exception as e:  # malformed/None reply or a non-dict literal
        logger.info(e)
        logger.info(f"API request failed for the {query_idx + 1}th query")
        res_dict = {'user_id': uid, 'ground_truth': ground_truth, 'prediction': -100, 'reason': None}
    return res_dict


def single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                      logger, top_k, is_wt, output_dir, sleep_query, sleep_crash, max_retries=3):
    """Query the model for each of one user's test samples and save the predictions.

    Results go to output_dir/{uid:02d}.csv. The file is rewritten after every
    query, so progress survives a crash. If the file already exists, its rows
    are loaded and the first len(rows) queries are skipped (resume).

    Args:
        client: OpenAI client passed to the single_query_* function.
        dataname: "geolife" or "fsq".
        uid: user id (also used for the output file name).
        historical_data: long-term stays shared by all queries of this user.
        predict_X: per-query prompt inputs (context stays + target stay).
        predict_y: ground-truth location ids aligned with predict_X.
        logger: logger for responses, token usage and progress.
        top_k: 1 or 10, number of predicted locations.
        is_wt: True selects the prompts that include the target time.
        output_dir: directory for the per-user CSV.
        sleep_query: seconds to sleep between consecutive queries.
        sleep_crash: seconds to sleep before retrying a failed API call.
        max_retries: attempts per query before the API error is re-raised.

    Raises:
        ValueError: for an unsupported top_k or dataname.
        Exception: the last API error once max_retries attempts have failed
            (results saved so far stay on disk).
    """
    total_queries = len(predict_X)
    logger.info(f"Total_queries: {total_queries}")

    out_filepath = os.path.join(output_dir, f"{uid:02d}" + ".csv")
    try:
        # Resume from previously saved results for this user, if any.
        current_results = load_results(out_filepath)
        processed_queries = len(current_results)
        logger.info(f"Loaded {processed_queries} previous results.")
    except FileNotFoundError:
        logger.info("No previous results found. Starting from scratch.")
        current_results = pd.DataFrame(
            {'user_id': None, 'ground_truth': None, 'prediction': None, 'reason': None}, index=[]
        )
        processed_queries = 0

    if processed_queries >= total_queries:
        return

    query_fn = _pick_single_query_fn(dataname, is_wt, top_k)
    for i in range(processed_queries, total_queries):
        logger.info(f'The {i+1}th sample: ')
        completions = _query_with_retry(
            query_fn, client, historical_data, predict_X[i], logger, sleep_crash, max_retries
        )
        response = completions.choices[0].message.content

        logger.info(f"Pred results: {response}")
        logger.info(f"Ground truth: {predict_y[i]}")
        logger.info(dict(completions).get('usage'))  # token usage

        res_dict = _parse_response(response, uid, predict_y[i], top_k, logger, i)
        new_row = pd.DataFrame(res_dict, index=[0])
        current_results = pd.concat([current_results, new_row], ignore_index=True)

        # Persist after every query so an interruption loses at most one query.
        current_results.to_csv(out_filepath, index=False)
        logger.info(f"Saved {len(current_results)} results to {out_filepath}")

        if i + 1 < total_queries:
            time.sleep(sleep_query)
5.3.1 load_results
就是读取这个user 之前已经保存的预测记录
def load_results(filename):
    """Load a user's previously saved prediction results from the CSV at *filename*.

    Returns:
        DataFrame with one row per saved query result.

    Raises:
        FileNotFoundError: if *filename* does not exist (single_user_query
            catches this to start from scratch).
    """
    return pd.read_csv(filename)
5.3.2 single_query_top1_wot
5.3.3 single_query_top1
5.3.4 single_query_top10
top10的区别不大(几乎一模一样),就多了这样一句话:
5.3.5 get_chat_completion
向 GPT 发送 prompt(单条 user 消息),返回完整的 completion 对象
def get_chat_completion(client, prompt, model="gpt-3.5-turbo-0613", json_mode=False, max_tokens=1200):
    """Send *prompt* as a single user message and return the raw completion object.

    Args:
        client: the OpenAI client object (new in openai 1.x).
        prompt: the prompt to be completed.
        model: the model to use. NOTE(review): this default is a dated
            snapshot; confirm it is still served before relying on it.
        json_mode: if True, request response_format={"type": "json_object"}
            (new in openai 1.x).
        max_tokens: maximum number of tokens to generate.

    Returns:
        Whatever client.chat.completions.create returns; the text is at
        completion.choices[0].message.content and token usage at completion.usage.
    """
    request = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,  # minimize randomness of the model's output
        "max_tokens": max_tokens,  # cap on generated tokens
    }
    if json_mode:
        request["response_format"] = {"type": "json_object"}
    return client.chat.completions.create(**request)