Paper companion notes: LLM-Mob code walkthrough

Companion paper note: Where Would I Go Next? Large Language Models as Human Mobility Predictor (CSDN blog)

1 Main function

1.1 Imports

import os
import pickle
import time
import ast
import logging
from datetime import datetime
import pandas as pd
from openai import OpenAI

client = OpenAI(api_key=...)
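The api_key argument is elided above. A common pattern for filling it in without hard-coding the key is to read it from an environment variable; this is my own suggestion, not something shown in the original code:

import os
from openai import OpenAI

# Assumes the key has been exported as OPENAI_API_KEY in the shell environment.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])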

1.2 Parameter setup

dataname = "geolife"  
# 数据集名称num_historical_stay = 40  
# 长期mobility 的跨度num_context_stay = 5  
# 短期mobility的跨度top_k = 10  
# 输出location的数量with_time = False  
# 是否将目标stay的时间信息融入进来sleep_single_query = 0.1 
''' 
the sleep time between queries 
after the recent updates, the reliability of the API is greatly improved
so we can reduce the sleep time
'''sleep_if_crash = 1  
'''
the sleep time if the server crashes
'''output_dir = f"output/{dataname}/top10_wot"  
'''
the output path
'''log_dir = f"logs/{dataname}/top10_wot"  
'''
the log dir
'''

1.3 Loading the data

tv_data, test_file = get_dataset(dataname)
# Number of total test sample:  3459
'''
This count is smaller than the number of rows in f"data/{dataname}/{dataname}_test.csv"
'''
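The note above says the test pickle has fewer samples than the raw test CSV. A quick check sketch, assuming the file paths used in get_dataset below (the exact CSV row count is not given in the note):

import pickle
import pandas as pd

dataname = "geolife"
with open(f"data/{dataname}/{dataname}_testset.pk", "rb") as f:
    test_file = pickle.load(f)                               # list of per-sample dicts -> 3459 entries
test_csv = pd.read_csv(f"data/{dataname}/{dataname}_test.csv")  # raw stay records

print(len(test_file), len(test_csv))  # the pickle has fewer entries than the raw CSV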

1.4 Logger setup

logger = get_logger('my_logger', log_dir=log_dir)

1.5 Extracting user ids

uid_list = get_unqueried_user(dataname, output_dir)
print(f"uid_list: {uid_list}")'''
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
Number of the remaining id: 45
uid_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
'''

1.6 Generating the queries

query_all_user(client, dataname, uid_list, logger, tv_data,
               num_historical_stay, num_context_stay, test_file,
               output_dir=output_dir, top_k=top_k, is_wt=with_time,
               sleep_query=sleep_single_query, sleep_crash=sleep_if_crash)

2 get_dataset

def get_dataset(dataname):
    # Get training and validation set and merge them
    train_data = pd.read_csv(f"data/{dataname}/{dataname}_train.csv")
    valid_data = pd.read_csv(f"data/{dataname}/{dataname}_valid.csv")
    # load the train + validation records

    # Get test data
    with open(f"data/{dataname}/{dataname}_testset.pk", "rb") as f:
        test_file = pickle.load(f)  # test_file is a list of dict
    # the test set

    # merge train and valid data
    tv_data = pd.concat([train_data, valid_data], ignore_index=True)
    tv_data.sort_values(['user_id', 'start_day', 'start_min'], inplace=True)
    if dataname == 'geolife':
        tv_data['duration'] = tv_data['duration'].astype(int)
    # merged train + validation set

    print("Number of total test sample: ", len(test_file))
    return tv_data, test_file

3 get_logger

def get_logger(logger_name, log_dir='logs/'):
    # Create log dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Create a logger instance named logger_name and set its level to DEBUG
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)

    # Create a console handler that prints received log messages to the console
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)

    # Create a file handler: take the current date and time, format it as
    # "YYYYMMDD_HHMMSS", and append it to "log_file" to build the log file name,
    # e.g. log_file20230424_153000.log
    current_datetime = datetime.now()
    formatted_datetime = current_datetime.strftime("%Y%m%d_%H%M%S")
    log_file = 'log_file' + formatted_datetime + '.log'
    # os.path.join(log_dir, log_file) gives the full path of the log file
    log_file_path = os.path.join(log_dir, log_file)
    # The file handler writes log messages into that file
    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logging.DEBUG)

    # Create a formatter that keeps only the message body, i.e. '%(message)s',
    # and attach it to both handlers
    formatter = logging.Formatter('%(message)s')
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)

    # Add the handlers to the logger
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    return logger
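One side note (my own observation, not from the original post): logging.getLogger returns the same logger object for a given name, so calling get_logger twice in one process attaches duplicate handlers and every message gets logged twice. A minimal guarded variant, sketched under that assumption (get_logger_safe is a hypothetical name; the file handler is omitted for brevity):

import logging

def get_logger_safe(logger_name, log_dir='logs/'):
    # Same idea as get_logger above, but guarded against duplicate handlers
    logger = logging.getLogger(logger_name)   # returns the SAME object for the same name
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:                   # only attach handlers on the first call
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(console_handler)
    return logger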

4 get_unqueried_user

Extracts the user ids of the dataset that have not been queried yet.

def get_unqueried_user(dataname, output_dir='output/'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if dataname == "geolife":
        all_user_id = [i+1 for i in range(45)]
    elif dataname == "fsq":
        all_user_id = [i+1 for i in range(535)]
    # Users that already have a result CSV in output_dir are considered processed
    processed_id = [int(file.split('.')[0]) for file in os.listdir(output_dir) if file.endswith('.csv')]
    remain_id = [i for i in all_user_id if i not in processed_id]
    print(remain_id)
    print(f"Number of the remaining id: {len(remain_id)}")
    return remain_id
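This is what makes the pipeline resumable: every finished user leaves a "<uid>.csv" in output_dir (written by single_user_query below), and only the ids without such a file are returned. A small illustration; the directory contents here are hypothetical:

# Suppose output/geolife/top10_wot already contains 01.csv and 02.csv from an earlier run.
uid_list = get_unqueried_user("geolife", "output/geolife/top10_wot")
# processed_id -> [1, 2]
# uid_list     -> [3, 4, ..., 45], so querying resumes from user 3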

5 query_all_user

def query_all_user(client, dataname, uid_list, logger, train_data, num_historical_stay,
                   num_context_stay, test_file, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    for uid in uid_list:
        logger.info(f"=================Processing user {uid}==================")
        # long-term historical mobility (the last M stays) of the current uid
        user_train = get_user_data(train_data, uid, num_historical_stay, logger)
        # organise_data returns, for this user id:
        #   - the long-term mobility (shared by all of the user's test samples)
        #   - the short-term mobility (the 5 most recent context stays)
        #   - the ground truth
        # each record looks like ('09:08 PM', 'Wednesday', 466, 10)
        historical_data, predict_X, predict_y = organise_data(dataname, user_train, test_file, uid,
                                                              logger, num_context_stay)
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y, logger,
                          top_k=top_k, is_wt=is_wt, output_dir=output_dir,
                          sleep_query=sleep_query, sleep_crash=sleep_crash)

5.1 get_user_data

Extracts the long-term historical mobility (the last M stays) of the current uid.

def get_user_data(train_data, uid, num_historical_stay, logger):
    # All records belonging to the current user id
    user_train = train_data[train_data['user_id']==uid]
    logger.info(f"Length of user {uid} train data: {len(user_train)}")  # total number of records for this user
    # Keep only the last num_historical_stay records as the long-term mobility
    user_train = user_train.tail(num_historical_stay)
    logger.info(f"Number of user historical stays: {len(user_train)}")
    return user_train

5.2 organise_data

Returns, for this user id:

- the long-term mobility (shared by all of the user's test samples)
- the short-term mobility (the 5 most recent context stays)
- the ground truth

Each record has the form ('09:08 PM', 'Wednesday', 466, 10).

def organise_data(dataname, user_train, test_file, uid, logger, num_context_stay=5):
    # Use another way of organising data
    # user_train only contains the most recent M records
    historical_data = []
    if dataname == 'geolife':
        for _, row in user_train.iterrows():
            historical_data.append((convert_to_12_hour_clock(int(row['start_min'])),
                                    int2dow(row['weekday']),
                                    int(row['duration']),
                                    row['location_id']))
    elif dataname == 'fsq':
        for _, row in user_train.iterrows():
            historical_data.append((convert_to_12_hour_clock(int(row['start_min'])),
                                    int2dow(row['weekday']),
                                    row['location_id']))
    '''
    Each appended tuple contains
    time-of-day : start time converted to hh:mm AM/PM
    day-of-week : day converted to its weekday name
    duration    : duration cast to int
    location id : the id of the location
    e.g.
    [('09:08 PM', 'Wednesday', 466, 10), ('04:58 AM', 'Thursday', 187, 17),
     ('08:07 AM', 'Thursday', 146, 1), ('10:35 AM', 'Thursday', 193, 17),
     ('01:54 PM', 'Thursday', 556, 10)]
    '''
    logger.info(f"historical_data: {historical_data}")
    logger.info(f"Number of historical_data: {len(historical_data)}")

    # Get user ith test data
    list_user_dict = []
    for i_dict in test_file:
        if dataname == 'geolife':
            i_uid = i_dict['user_X'][0]
        elif dataname == 'fsq':
            i_uid = i_dict['user_X']
        if i_uid == uid:
            list_user_dict.append(i_dict)
    # test samples whose user id matches the current uid are collected in list_user_dict;
    # these are the trajectories this user needs to be tested on

    predict_X = []
    predict_y = []
    for i_dict in list_user_dict:
        construct_dict = {}
        if dataname == 'geolife':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]],
                               [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]],
                               [int(i) for i in i_dict['dur_X'][-num_context_stay:]],
                               i_dict['X'][-num_context_stay:]))
        elif dataname == 'fsq':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]],
                               [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]],
                               i_dict['X'][-num_context_stay:]))
        # For geolife, context is a list of five elements,
        # each in the same format as the tuples appended to historical_data
        target = (convert_to_12_hour_clock(int(i_dict['start_min_Y'])), int2dow(i_dict['weekday_Y']), None, "<next_place_id>")
        # e.g. ('12:36 AM', 'Friday', None, '<next_place_id>')
        construct_dict['context_stay'] = context
        construct_dict['target_stay'] = target
        # constructed input: the N most recent context stays + the target stay's time and weekday
        predict_y.append(i_dict['Y'])     # ground-truth location id
        predict_X.append(construct_dict)  # constructed input
    logger.info(f"Number of predict_data: {len(predict_X)}")  # how many test records this user_id has
    logger.info(f"predict_y: {predict_y}")
    logger.info(f"Number of predict_y: {len(predict_y)}")     # should equal len(predict_X)
    # Returns, for this user id: the long-term mobility (shared across test samples),
    # the short-term mobility (the 5 most recent context stays), and the ground truth
    return historical_data, predict_X, predict_y
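To make the returned structure concrete, one geolife element of predict_X would look roughly like the following; the times and ids are taken from the example tuples quoted above, and the ground-truth value is made up for illustration:

predict_X_example = {
    'context_stay': [            # the 5 most recent stays before the target
        ('09:08 PM', 'Wednesday', 466, 10),
        ('04:58 AM', 'Thursday', 187, 17),
        ('08:07 AM', 'Thursday', 146, 1),
        ('10:35 AM', 'Thursday', 193, 17),
        ('01:54 PM', 'Thursday', 556, 10),
    ],
    # target stay: time and weekday are known, duration and place id are to be predicted
    'target_stay': ('12:36 AM', 'Friday', None, '<next_place_id>'),
}
predict_y_example = 10           # hypothetical ground-truth location id for this sample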

5.2.1 convert_to_12_hour_clock

# Converts minutes-since-midnight into hh:mm AM/PM form

def convert_to_12_hour_clock(minutes):
    # In the raw data, `minutes` counts from midnight (0:00) of the same day
    if minutes < 0 or minutes >= 1440:
        return "Invalid input. Minutes should be between 0 and 1439."
    hours = minutes // 60
    minutes %= 60
    period = "AM"
    if hours >= 12:
        period = "PM"
    if hours == 0:
        hours = 12
    elif hours > 12:
        hours -= 12
    return f"{hours:02d}:{minutes:02d} {period}"  # hh:mm AM/PM form
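A few spot checks of the conversion (my own examples; 21 * 60 + 8 = 1268 matches the '09:08 PM' record quoted earlier):

print(convert_to_12_hour_clock(1268))  # '09:08 PM'
print(convert_to_12_hour_clock(0))     # '12:00 AM'  (midnight)
print(convert_to_12_hour_clock(720))   # '12:00 PM'  (noon)
print(convert_to_12_hour_clock(1500))  # 'Invalid input. Minutes should be between 0 and 1439.'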

5.2.2 int2dow

# Converts an integer day-of-week into its weekday name

def int2dow(int_day):
    tmp = {0: 'Monday', 1: 'Tuesday', 2: 'Wednesday',
           3: 'Thursday', 4: 'Friday', 5: 'Saturday', 6: 'Sunday'}
    return tmp[int_day]
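A quick check of the mapping (0 is Monday, the same convention as pandas' dayofweek):

print(int2dow(2))  # 'Wednesday'
print(int2dow(6))  # 'Sunday'
# int2dow(7) raises KeyError: valid inputs are 0..6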

5.3 single_user_query

Queries the model sample by sample and saves the location prediction results for one user.

def single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                      logger, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    # Initialize variables
    total_queries = len(predict_X)
    logger.info(f"Total_queries: {total_queries}")  # how many queries this user id has in total
    processed_queries = 0
    current_results = pd.DataFrame({'user_id': None, 'ground_truth': None,
                                    'prediction': None, 'reason': None}, index=[])

    out_filename = f"{uid:02d}" + ".csv"
    out_filepath = os.path.join(output_dir, out_filename)

    try:
        # Attempt to load previous results if available
        # i.e. the prediction results already processed for this user
        current_results = load_results(out_filepath)
        processed_queries = len(current_results)
        logger.info(f"Loaded {processed_queries} previous results.")
    except FileNotFoundError:
        logger.info("No previous results found. Starting from scratch.")

    # Process remaining queries
    for i in range(processed_queries, total_queries):
        # predict the user's remaining queries
        logger.info(f'The {i+1}th sample: ')

        # Pick the prompt variant according to the dataset, time usage and top_k
        if dataname == 'geolife':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        elif dataname == 'fsq':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        # `completions` is the complete GPT response for the chosen setting

        response = completions.choices[0].message.content  # the GPT response text

        # Log the prediction results and usage.
        logger.info(f"Pred results: {response}")
        logger.info(f"Ground truth: {predict_y[i]}")
        logger.info(dict(completions).get('usage'))  # number of tokens used

        try:
            res_dict = ast.literal_eval(response)  # parse the GPT output into a dict
            if top_k != 1:
                res_dict['prediction'] = str(res_dict['prediction'])
            res_dict['user_id'] = uid
            res_dict['ground_truth'] = predict_y[i]
        except Exception as e:
            # If any of the steps above fails, the prediction is treated as failed
            res_dict = {'user_id': uid, 'ground_truth': predict_y[i], 'prediction': -100, 'reason': None}
            logger.info(e)
            logger.info(f"API request failed for the {i+1}th query")
            # time.sleep(sleep_crash)
        finally:
            new_row = pd.DataFrame(res_dict, index=[0])  # a dataframe with only one record: the current prediction
            current_results = pd.concat([current_results, new_row], ignore_index=True)  # accumulated predictions of this user id

        # Save the current results, i.e. this user's location prediction results so far
        current_results.to_csv(out_filepath, index=False)
        # save_results(current_results, out_filename)
        logger.info(f"Saved {len(current_results)} results to {out_filepath}")

    # Continue processing remaining queries
    if len(current_results) < total_queries:
        # remaining_predict_X = predict_X[len(current_results):]
        # remaining_predict_y = predict_y[len(current_results):]
        # remaining_queries = queries[len(current_results):]
        logger.info("Restarting queries from the last successful point.")
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                          logger, top_k, is_wt, output_dir, sleep_query, sleep_crash)
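For reference, the parsing step expects the model to answer with a dict literal containing 'prediction' and 'reason' keys (that much follows from the code above). A hypothetical well-formed response and how it flows through the parsing, as a sketch — the ids and the reason text are made up:

response = "{'prediction': [10, 17, 1, 40, 5, 32, 27, 2, 49, 11], 'reason': 'The user usually returns to place 10 around this time.'}"

res_dict = ast.literal_eval(response)                 # -> dict with 'prediction' and 'reason'
res_dict['prediction'] = str(res_dict['prediction'])  # top_k != 1: the list is stored as a string
res_dict['user_id'] = 1                               # hypothetical uid
res_dict['ground_truth'] = 10                         # hypothetical ground-truth location id
# res_dict is then appended to the per-user CSV as one row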

5.3.1 load_results

Simply loads the prediction records that were previously saved for this user.

def load_results(filename):
    # Load previously saved results from a CSV file
    results = pd.read_csv(filename)
    return results

5.3.2 single_query_top1_wot

5.3.3 single_query_top1

5.3.4 single_query_top10

The top-10 queries differ very little from the top-1 ones (they are almost identical); each just adds one extra sentence to the prompt.

5.3.5 get_chat_completion

Sends the prompt to GPT and returns the corresponding completion.

def get_chat_completion(client, prompt, model="gpt-3.5-turbo-0613", json_mode=False, max_tokens=1200):
    """
    args:
        client: the openai client object (new in 1.x version)
        prompt: the prompt to be completed
        model: specify the model to use
        json_mode: whether to return the response in json format (new in 1.x version)
    """
    messages = [{"role": "user", "content": prompt}]
    if json_mode:
        completion = client.chat.completions.create(
            model=model,
            response_format={"type": "json_object"},
            messages=messages,
            temperature=0,          # the degree of randomness of the model's output
            max_tokens=max_tokens   # the maximum number of tokens to generate
        )
    else:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_tokens=max_tokens
        )
    # res_content = response.choices[0].message["content"]
    # token_usage = response.usage
    return completion
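A minimal usage sketch, assuming the prompt string has already been assembled by one of the single_query_* helpers above (the prompt text here is a placeholder, not the paper's actual prompt):

prompt = "..."  # placeholder: the mobility-prediction prompt built from historical_data and predict_X[i]
completion = get_chat_completion(client, prompt, model="gpt-3.5-turbo-0613")
response_text = completion.choices[0].message.content  # what single_user_query parses with ast.literal_eval
token_usage = completion.usage                         # prompt/completion token counts
print(response_text, token_usage)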
