Paper companion notes: LLM-Mob code walkthrough

Companion paper note: Where Would I Go Next? Large Language Models as Human Mobility Predictor (CSDN blog)

1 Main function

1.1 Importing libraries

import os
import pickle
import time
import ast
import logging
from datetime import datetime
import pandas as pd
from openai import OpenAI

client = OpenAI(api_key=...)

1.2 Parameter setup

dataname = "geolife"  
# 数据集名称num_historical_stay = 40  
# 长期mobility 的跨度num_context_stay = 5  
# 短期mobility的跨度top_k = 10  
# 输出location的数量with_time = False  
# 是否将目标stay的时间信息融入进来sleep_single_query = 0.1 
''' 
the sleep time between queries 
after the recent updates, the reliability of the API is greatly improved
so we can reduce the sleep time
'''sleep_if_crash = 1  
'''
the sleep time if the server crashes
'''output_dir = f"output/{dataname}/top10_wot"  
'''
the output path
'''log_dir = f"logs/{dataname}/top10_wot"  
'''
the log dir
'''

1.3 Loading the dataset

tv_data, test_file = get_dataset(dataname)
# Number of total test sample:  3459
'''
This count is smaller than the number of rows in f"data/{dataname}/{dataname}_test.csv".
'''

1.4 Creating the log file

logger = get_logger('my_logger', log_dir=log_dir)

1.5 Extracting user IDs

uid_list = get_unqueried_user(dataname, output_dir)
print(f"uid_list: {uid_list}")'''
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
Number of the remaining id: 45
uid_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
'''

1.6 Issuing the queries

query_all_user(client, dataname, uid_list, logger, tv_data,
               num_historical_stay, num_context_stay, test_file,
               output_dir=output_dir, top_k=top_k, is_wt=with_time,
               sleep_query=sleep_single_query, sleep_crash=sleep_if_crash)

2 The get_dataset function

def get_dataset(dataname):
    # Get training and validation set and merge them
    train_data = pd.read_csv(f"data/{dataname}/{dataname}_train.csv")
    valid_data = pd.read_csv(f"data/{dataname}/{dataname}_valid.csv")
    # Load the train and validation sets

    # Get test data
    with open(f"data/{dataname}/{dataname}_testset.pk", "rb") as f:
        test_file = pickle.load(f)  # test_file is a list of dict

    # merge train and valid data
    tv_data = pd.concat([train_data, valid_data], ignore_index=True)
    tv_data.sort_values(['user_id', 'start_day', 'start_min'], inplace=True)
    if dataname == 'geolife':
        tv_data['duration'] = tv_data['duration'].astype(int)
    print("Number of total test sample: ", len(test_file))
    return tv_data, test_file
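As a reading aid (this schema is inferred from how the columns are used later in the code, not copied from the dataset), one row of the merged tv_data is expected to look roughly like:

# Illustrative tv_data row (column names are the ones used downstream; values are made up)
#   user_id  start_day  start_min  duration  weekday  location_id
#         1         12       1268       466        2           10
# start_min=1268 and weekday=2 become '09:08 PM' and 'Wednesday' in Section 5.2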

 3 get_logger

def get_logger(logger_name, log_dir='logs/'):
    # Create log dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Create a logger instance named logger_name and set its level to DEBUG
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)

    # Create a console handler that prints log messages to the console
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)

    # Create a file handler: the current date and time, formatted as "YYYYMMDD_HHMMSS",
    # is appended to "log_file" to build the log file name, e.g. log_file20230424_153000.log
    current_datetime = datetime.now()
    formatted_datetime = current_datetime.strftime("%Y%m%d_%H%M%S")
    log_file = 'log_file' + formatted_datetime + '.log'
    log_file_path = os.path.join(log_dir, log_file)
    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logging.DEBUG)

    # Create a formatter that keeps only the message body ('%(message)s')
    # and attach it to both handlers
    formatter = logging.Formatter('%(message)s')
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)

    # Add the handlers to the logger
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    return logger

4 get_unqueried_user

Extracts the user IDs of the dataset that have not been queried yet.

def get_unqueried_user(dataname, output_dir='output/'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # All user ids in the dataset: 45 for geolife, 535 for fsq
    if dataname == "geolife":
        all_user_id = [i+1 for i in range(45)]
    elif dataname == "fsq":
        all_user_id = [i+1 for i in range(535)]
    # Users whose result CSVs already exist in output_dir have been processed
    processed_id = [int(file.split('.')[0]) for file in os.listdir(output_dir) if file.endswith('.csv')]
    remain_id = [i for i in all_user_id if i not in processed_id]
    print(remain_id)
    print(f"Number of the remaining id: {len(remain_id)}")
    return remain_id

5 query_all_user

def query_all_user(client, dataname, uid_list, logger, train_data, num_historical_stay,
                   num_context_stay, test_file, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    for uid in uid_list:
        logger.info(f"=================Processing user {uid}==================")
        user_train = get_user_data(train_data, uid, num_historical_stay, logger)
        # Long-term historical mobility of the current uid (the M most recent stays)
        historical_data, predict_X, predict_y = organise_data(dataname, user_train, test_file, uid, logger, num_context_stay)
        '''
        Returns, for this user id:
        - long-term mobility (shared across this user's test samples)
        - short-term mobility (the 5 most recent context stays)
        - ground truth
        Each record has the form ('09:08 PM', 'Wednesday', 466, 10).
        '''
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                          logger, top_k=top_k, is_wt=is_wt, output_dir=output_dir,
                          sleep_query=sleep_query, sleep_crash=sleep_crash)

5.1 get_user_data

Extracts the long-term historical mobility (the M most recent stays) of the uid currently being processed.

def get_user_data(train_data, uid, num_historical_stay, logger):
    user_train = train_data[train_data['user_id']==uid]
    # All records belonging to the current user id
    logger.info(f"Length of user {uid} train data: {len(user_train)}")
    # Total number of records for this user
    user_train = user_train.tail(num_historical_stay)
    logger.info(f"Number of user historical stays: {len(user_train)}")
    # Keep only the last num_historical_stay records as the long-term mobility history
    return user_train

5.2 organise_data

Returns, for this user id:

- long-term mobility (shared across this user's test samples)
- short-term mobility (the 5 most recent context stays)
- ground truth

Each record has the form ('09:08 PM', 'Wednesday', 466, 10).

def organise_data(dataname, user_train, test_file, uid, logger, num_context_stay=5):
    # Use another way of organising data
    # user_train only contains the M most recent records
    historical_data = []
    if dataname == 'geolife':
        for _, row in user_train.iterrows():
            historical_data.append(
                (convert_to_12_hour_clock(int(row['start_min'])),
                 int2dow(row['weekday']),
                 int(row['duration']),
                 row['location_id']))
    elif dataname == 'fsq':
        for _, row in user_train.iterrows():
            historical_data.append(
                (convert_to_12_hour_clock(int(row['start_min'])),
                 int2dow(row['weekday']),
                 row['location_id']))
    # Each appended tuple contains:
    #   time-of-day : start_min converted to an hh:mm AM/PM string
    #   day-of-week : weekday converted to 'Monday' ... 'Sunday'
    #   duration    : duration cast to int (geolife only)
    #   location id : the id of the location
    # e.g.
    # [('09:08 PM', 'Wednesday', 466, 10), ('04:58 AM', 'Thursday', 187, 17),
    #  ('08:07 AM', 'Thursday', 146, 1), ('10:35 AM', 'Thursday', 193, 17),
    #  ('01:54 PM', 'Thursday', 556, 10)]
    logger.info(f"historical_data: {historical_data}")
    logger.info(f"Number of historical_data: {len(historical_data)}")

    # Get user ith test data
    list_user_dict = []
    for i_dict in test_file:
        if dataname == 'geolife':
            i_uid = i_dict['user_X'][0]
        elif dataname == 'fsq':
            i_uid = i_dict['user_X']
        if i_uid == uid:
            list_user_dict.append(i_dict)
    # Test records whose user id matches uid go into list_user_dict:
    # these are the trajectories this user needs to be tested on

    predict_X = []
    predict_y = []
    for i_dict in list_user_dict:
        construct_dict = {}
        if dataname == 'geolife':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]],
                               [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]],
                               [int(i) for i in i_dict['dur_X'][-num_context_stay:]],
                               i_dict['X'][-num_context_stay:]))
        elif dataname == 'fsq':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]],
                               [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]],
                               i_dict['X'][-num_context_stay:]))
        # For geolife, context is a list of five elements, each with the same
        # format as the tuples appended to historical_data above
        target = (convert_to_12_hour_clock(int(i_dict['start_min_Y'])), int2dow(i_dict['weekday_Y']), None, "<next_place_id>")
        # e.g. ('12:36 AM', 'Friday', None, '<next_place_id>')
        construct_dict['context_stay'] = context
        construct_dict['target_stay'] = target
        # Build the input: the N most recent context stays plus the target's time-of-day and day-of-week
        predict_y.append(i_dict['Y'])
        # Ground-truth location id
        predict_X.append(construct_dict)
        # Constructed input
    logger.info(f"Number of predict_data: {len(predict_X)}")
    # Number of test records for this user_id
    logger.info(f"predict_y: {predict_y}")
    logger.info(f"Number of predict_y: {len(predict_y)}")
    # Should be the same count as predict_X
    return historical_data, predict_X, predict_y
    # Returns, for this user id:
    # - long-term mobility (shared across this user's test samples)
    # - short-term mobility (the 5 most recent context stays)
    # - ground truth
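To make the indexing in organise_data easier to follow, here is a hedged sketch of a single geolife entry of test_file; the key names are taken from the code above, while all values are invented for illustration:

# Illustrative geolife test sample (keys from organise_data; values made up)
i_dict = {
    'user_X': [1, 1, 1, 1, 1],                  # user id, one entry per context stay
    'start_min_X': [1268, 298, 487, 635, 834],  # minutes since midnight of each context stay
    'weekday_X': [2, 3, 3, 3, 3],               # day-of-week index of each context stay
    'dur_X': [466, 187, 146, 193, 556],         # duration of each context stay
    'X': [10, 17, 1, 17, 10],                   # location ids of the context stays
    'start_min_Y': 36,                          # target stay: minutes since midnight
    'weekday_Y': 4,                             # target stay: day-of-week index
    'Y': 10,                                    # ground-truth next location id
}

With num_context_stay=5, such an entry would yield five context tuples in the format shown above and the target ('12:36 AM', 'Friday', None, '<next_place_id>').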

5.2.1  convert_to_12_hour_clock

# Converts minutes-since-midnight into an hh:mm AM/PM string

def convert_to_12_hour_clock(minutes):
    # In the raw data, minutes counts from 00:00 of the day (minutes since midnight)
    if minutes < 0 or minutes >= 1440:
        return "Invalid input. Minutes should be between 0 and 1439."
    hours = minutes // 60
    minutes %= 60
    period = "AM"
    if hours >= 12:
        period = "PM"
    if hours == 0:
        hours = 12
    elif hours > 12:
        hours -= 12
    return f"{hours:02d}:{minutes:02d} {period}"
    # Returns the time as an hh:mm AM/PM string
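A quick sanity check of the conversion (the inputs are just example values):

convert_to_12_hour_clock(1268)  # -> '09:08 PM'  (21:08)
convert_to_12_hour_clock(36)    # -> '12:36 AM'  (00:36)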

5.2.2 int2dow

# Converts a weekday index (0-6) into the day-of-week name

def int2dow(int_day):
    tmp = {0: 'Monday', 1: 'Tuesday', 2: 'Wednesday',
           3: 'Thursday', 4: 'Friday', 5: 'Saturday', 6: 'Sunday'}
    return tmp[int_day]

5.3 single_user_query

Queries the LLM for every test sample of this user and saves the location prediction results to a per-user CSV.

def single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                      logger, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    # Initialize variables
    total_queries = len(predict_X)
    logger.info(f"Total_queries: {total_queries}")
    # Total number of queries for this user id
    processed_queries = 0
    current_results = pd.DataFrame({'user_id': None, 'ground_truth': None, 'prediction': None, 'reason': None}, index=[])
    out_filename = f"{uid:02d}" + ".csv"
    out_filepath = os.path.join(output_dir, out_filename)

    try:
        # Attempt to load previous results if available
        current_results = load_results(out_filepath)
        processed_queries = len(current_results)
        logger.info(f"Loaded {processed_queries} previous results.")
    except FileNotFoundError:
        logger.info("No previous results found. Starting from scratch.")
    # Load the prediction results already produced for this user

    # Process remaining queries
    for i in range(processed_queries, total_queries):
        # Predict the remaining queries of this user
        logger.info(f'The {i+1}th sample: ')
        if dataname == 'geolife':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        elif dataname == 'fsq':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        # completions is GPT's full response for the chosen prompt variant
        response = completions.choices[0].message.content
        # GPT's response text

        # Log the prediction results and usage.
        logger.info(f"Pred results: {response}")
        logger.info(f"Ground truth: {predict_y[i]}")
        logger.info(dict(completions).get('usage'))
        # Number of tokens used

        try:
            res_dict = ast.literal_eval(response)  # Parse GPT's output into a dict
            if top_k != 1:
                res_dict['prediction'] = str(res_dict['prediction'])
            res_dict['user_id'] = uid
            res_dict['ground_truth'] = predict_y[i]
        except Exception as e:
            # If any of the above steps fails, the prediction is considered failed
            res_dict = {'user_id': uid, 'ground_truth': predict_y[i], 'prediction': -100, 'reason': None}
            logger.info(e)
            logger.info(f"API request failed for the {i+1}th query")
            # time.sleep(sleep_crash)
        finally:
            new_row = pd.DataFrame(res_dict, index=[0])  # A dataframe with only one record
            # The prediction for the current location
            current_results = pd.concat([current_results, new_row], ignore_index=True)  # Add new row to the current df
            # Accumulated location predictions for this user id

        # Save the current results
        current_results.to_csv(out_filepath, index=False)
        # save_results(current_results, out_filename)
        logger.info(f"Saved {len(current_results)} results to {out_filepath}")
        # Save this user's location prediction results

    # Continue processing remaining queries
    if len(current_results) < total_queries:
        # remaining_predict_X = predict_X[len(current_results):]
        # remaining_predict_y = predict_y[len(current_results):]
        # remaining_queries = queries[len(current_results):]
        logger.info("Restarting queries from the last successful point.")
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                          logger, top_k, is_wt, output_dir, sleep_query, sleep_crash)

5.3.1 load_results

Simply loads the prediction records previously saved for this user.

def load_results(filename):
    # Load previously saved results from a CSV file
    results = pd.read_csv(filename)
    return results

5.3.2 single_query_top1_wot

5.3.3 single_query_top1

5.3.4 single_query_top10

The top-10 prompt differs very little from the top-1 one (they are almost identical); it only adds one extra sentence, which is not reproduced here — see the hedged sketch below.
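The prompt-building functions themselves are not shown in this note. Purely as a hedged sketch of how such a function plugs into the rest of the code (the prompt wording below is invented; only the function name, the (historical_data, predict_X[i]) inputs, and the 'prediction'/'reason' keys parsed in single_user_query come from the code in this note):

def single_query_top10_wot(client, historical_data, X):
    # Hypothetical sketch of a top-10, without-target-time query
    prompt = (
        "Your task is to predict a user's next location based on their activity pattern.\n"
        f"<history>: {historical_data}\n"
        f"<context>: {X['context_stay']}\n"
        f"<target_stay>: {X['target_stay']}\n"
        "Answer with a Python dictionary containing two keys: 'prediction' "
        "(a ranked list of the 10 most likely place ids) and 'reason' (a brief explanation)."
    )
    return get_chat_completion(client, prompt)

The top-1 variants would presumably ask for a single place id instead of a ranked list, and the variants without the _wot suffix additionally feed in the target stay's time information (with_time=True).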

5.3.5 get_chat_completion

Sends the prompt to GPT and returns the corresponding completion.

def get_chat_completion(client, prompt, model="gpt-3.5-turbo-0613", json_mode=False, max_tokens=1200):
    """
    args:
        client: the openai client object (new in 1.x version)
        prompt: the prompt to be completed
        model: specify the model to use
        json_mode: whether return the response in json format (new in 1.x version)
    """
    messages = [{"role": "user", "content": prompt}]
    if json_mode:
        completion = client.chat.completions.create(
            model=model,
            response_format={"type": "json_object"},
            messages=messages,
            temperature=0,  # the degree of randomness of the model's output
            max_tokens=max_tokens  # the maximum number of tokens to generate
        )
    else:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_tokens=max_tokens
        )
    # res_content = response.choices[0].message["content"]
    # token_usage = response.usage
    return completion
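For reference, this is how single_user_query consumes the returned completion object (the prompt string below is just a placeholder):

completions = get_chat_completion(client, prompt="... assembled prompt ...")
response = completions.choices[0].message.content  # the model's text answer
usage = dict(completions).get('usage')             # token usage, as logged above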
