babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现

在进入到主程序前,我们还需要看一个Embedding的实现代码,这里的功能主要是为了计算代码之间的相关性。
embedding可以文本中的词语转化为低维实数向量的表示,来计算两段文字间的几何距离来判断词语的含义是否相近。

1. 源码阅读-初始化和计算代码库的嵌入值

这段代码主要是设定了初始化变量,包括使用的embedding的模型,以及tokenizer(分词器),分词器按照\n,作为分词符号和分词长度。

class Embeddings:def __init__(self, workspace_path: str):self.workspace_path = workspace_pathopenai.api_key = os.getenv("OPENAI_API_KEY", "")self.DOC_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.QUERY_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.SEPARATOR = "\n* "self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")self.separator_len = len(self.tokenizer.tokenize(self.SEPARATOR))

下面的代码用于计算整个代码库的embedding值,用于查找相关代码,实现了以下步骤

  • 删除playground_data 代码空间下的所有文件,不会有旧的数据重新计算
  • 将代码文件转换为特定格式,放入repository_info.csv文件中
  • 计算repository_info.csv中内容的嵌入值,放入到doc_embeddings.csv
def compute_repository_embeddings(self):try:playground_data_path = os.path.join(self.workspace_path, 'playground_data')# Delete the contents of the playground_data directory but not the directory itself# This is to ensure that we don't have any old data lying aroundfor filename in os.listdir(playground_data_path):file_path = os.path.join(playground_data_path, filename)try:if os.path.isfile(file_path) or os.path.islink(file_path):os.unlink(file_path)elif os.path.isdir(file_path):shutil.rmtree(file_path)except Exception as e:print(f"Failed to delete {file_path}. Reason: {str(e)}")except Exception as e:print(f"Error: {str(e)}")# extract and save info to csvinfo = self.extract_info(REPOSITORY_PATH)self.save_info_to_csv(info)df = pd.read_csv(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'))df = df.set_index(["filePath", "lineCoverage"])self.df = dfcontext_embeddings = self.compute_doc_embeddings(df)self.save_doc_embeddings_to_csv(context_embeddings, df, os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))try:self.document_embeddings = self.load_embeddings(os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))except:pass

下面是使用到的extract_info函数、save_info_to_csv函数、compute_doc_embeddings函数、save_doc_embeddings_to_csv函数,load_embeddings 函数。

1.1 extract_info 提取代码文件信息

这个函数的功能是从文件中获取信息,转化为特定的形式一个列表,包含三个信息

  • filePath 文件路径
  • lineCoverage 为一个元组,包含两个值 第一行位置和最后一行的位置
  • chunkContent 代码的内容
# Extract information from files in the repository in chunks
# Return a list of [filePath, lineCoverage, chunkContent]
def extract_info(self, REPOSITORY_PATH):# Initialize an empty list to store the informationinfo = []LINES_PER_CHUNK = 60# Iterate through the files in the repositoryfor root, dirs, files in os.walk(REPOSITORY_PATH):for file in files:file_path = os.path.join(root, file)# Read the contents of the filewith open(file_path, "r", encoding="utf-8") as f:try:contents = f.read()except:continue# Split the contents into lineslines = contents.split("\n")# Ignore empty lineslines = [line for line in lines if line.strip()]# Split the lines into chunks of LINES_PER_CHUNK lineschunks = [lines[i:i+LINES_PER_CHUNK]for i in range(0, len(lines), LINES_PER_CHUNK)]# Iterate through the chunksfor i, chunk in enumerate(chunks):# Join the lines in the chunk back into a single stringchunk = "\n".join(chunk)# Get the first and last line numbersfirst_line = i * LINES_PER_CHUNK + 1last_line = first_line + len(chunk.split("\n")) - 1line_coverage = (first_line, last_line)# Add the file path, line coverage, and content to the listinfo.append((os.path.join(root, file), line_coverage, chunk))# Return the list of informationreturn info

1.2 save_info_to_csv保存提取出的信息

这个函数的功能是将代码信息存放到csv文件中,使用pandas库

def save_info_to_csv(self, info):# Open a CSV file for writingos.makedirs(os.path.join(self.workspace_path, "playground_data"), exist_ok=True)with open(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'), "w", newline="") as csvfile:# Create a CSV writerwriter = csv.writer(csvfile)# Write the header rowwriter.writerow(["filePath", "lineCoverage", "content"])# Iterate through the infofor file_path, line_coverage, content in info:# Write a row for each chunk of datawriter.writerow([file_path, line_coverage, content])

1.3 compute_doc_embeddings计算文档的嵌入值信息

计算每个文件的嵌入值,并返回嵌入值字典

def compute_doc_embeddings(self, df: pd.DataFrame) -> dict[tuple[str, str], list[float]]:"""Create an embedding for each row in the dataframe using the OpenAI Embeddings API.Return a dictionary that maps between each embedding vector and the index of the row that it corresponds to."""embeddings = {}for idx, r in df.iterrows():# Wait one second before making the next call to the OpenAI Embeddings API# print("Waiting one second before embedding next row\n")time.sleep(1)embeddings[idx] = self.get_doc_embedding(r.content.replace("\n", " "))return embeddings

1.4 save_doc_embeddings_to_csv 保存嵌入值到文件中

该函数从文件中读取已经保存的embbeding信息,转换为一个dict

  • key为一个元组(filePath, lineCoverage)
  • value为一个数组,把其余列存放至后面
def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

1.5 save_doc_embbedings_to_csv将嵌入值保存到csv文件中

这里处理了一下,不是讲整个嵌入值放到数组中,而是更具嵌入的维度放入到列中,不同的维度有不同的嵌入值

def save_doc_embeddings_to_csv(self, doc_embeddings: dict, df: pd.DataFrame, csv_filepath: str):# Get the dimensionality of the embedding vectors from the first element in the doc_embeddings dictionaryif len(doc_embeddings) == 0:returnEMBEDDING_DIM = len(list(doc_embeddings.values())[0])# Create a new dataframe with the filePath, lineCoverage, and embedding vector columnsembeddings_df = pd.DataFrame(columns=["filePath", "lineCoverage"] + [f"{i}" for i in range(EMBEDDING_DIM)])# Iterate over the rows in the original dataframefor idx, _ in df.iterrows():# Get the embedding vector for the current rowembedding = doc_embeddings[idx]# Create a new row in the embeddings dataframe with the filePath, lineCoverage, and embedding vector valuesrow = [idx[0], idx[1]] + embeddingembeddings_df.loc[len(embeddings_df)] = row# Save the embeddings dataframe to a CSV fileembeddings_df.to_csv(csv_filepath, index=False)

1.6 load_embeddings加载嵌入值,从文件中

def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

2. embedding第二部分-获取代码相关性

获取相关的代码段,根据

  • 任务描述
  • 任务上下文
    获取相关的代码,相似度最高的两个代码块
def get_relevant_code_chunks(self, task_description: str, task_context: str):query = task_description + "\n" + task_contextmost_relevant_document_sections = self.order_document_sections_by_query_similarity(query, self.document_embeddings)selected_chunks = []for _, section_index in most_relevant_document_sections:try:document_section = self.df.loc[section_index]selected_chunks.append(self.SEPARATOR + document_section['content'].replace("\n", " "))if len(selected_chunks) >= 2:breakexcept:passreturn selected_chunks

这个函数有两个参数,他会根据相似度对整个文件排序

  • query 请求的文本,用作计算嵌入值
  • context 上下文,用作查找和这段文本的相似度,就是上文的字典
def order_document_sections_by_query_similarity(self, query: str, contexts: dict[(str, str), np.array]) -> list[(float, (str, str))]:"""Find the query embedding for the supplied query, and compare it against all of the pre-calculated document embeddingsto find the most relevant sections. Return the list of document sections, sorted by relevance in descending order."""query_embedding = self.get_query_embedding(query)document_similarities = sorted([(self.vector_similarity(query_embedding, doc_embedding), doc_index) for doc_index, doc_embedding in contexts.items()], reverse=True)return document_similarities

用数量积的方式计算两个向量之间的相似度,这里数量积可以表示为 a ⋅ b = ∣ a ∣ ∣ b ∣ c o s θ a{\cdot}b=|a||b|cos{\theta} ab=a∣∣bcosθ,当两个向量垂直时,计算的值为0,计算的值越大说明相似度越高

def vector_similarity(self, x: list[float], y: list[float]) -> float:return np.dot(np.array(x), np.array(y))

3. 通过OpenAI获取嵌入值相关函数

计算这段文字的嵌入值

def get_query_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.QUERY_EMBEDDINGS_MODEL)

计算文档相关性

def get_doc_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.DOC_EMBEDDINGS_MODEL)

处理openAi返回

    def get_embedding(self, text: str, model: str) -> list[float]:result = openai.Embedding.create(model=model,input=text)return result["data"][0]["embedding"]

下一篇文章,我们将进入主程序的阅读,看看embedding是如何和主程序结合的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/804405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HJ1 字符串最后一个单词的长度(字符串,import java.util.HashSet;)

import java.util.Scanner;// 注意类名必须为 Main, 不要有任何 package xxx 信息 public class Main {public static void main(String[] args) {Scanner sc new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别int num sc.nextInt();boolean[] in new boolean[…

解析View树、apk安装

apk安装 https://github.com/sucese/android-open-source-project-analysis/blob/master/doc/Android%E7%B3%BB%E7%BB%9F%E5%BA%94%E7%94%A8%E6%A1%86%E6%9E%B6%E7%AF%87/Android%E5%8C%85%E7%AE%A1%E7%90%86%E6%A1%86%E6%9E%B6/02Android%E5%8C%85%E7%AE%A1%E7%90%86%E6%A1%8…

docker部署coredns服务器

创建文件夹 mkdir /coredns/config/添加一个CoreDNS配置文件 cat >/coredns/config/Corefile<<EOF.:53 {forward . 114.114.114.114:53log}EOF启动docker docker run -d --name coredns --restartalways \-v /coredns/config:/etc/coredns \-p 53:53/udp \regist…

Android 接入MQTT服务器

加入MQTT库 加入库可以直接下载对应的jar包&#xff0c;也可以在build.gradle里导入&#xff0c;然后加载进入。 这里直接在build.gradle加库 dependencies {implementation(libs.appcompat)implementation(libs.material)implementation(libs.activity)implementation(libs…

clickhouse深入浅出

基础知识原理 极致压缩率 极速查询性能 列式数据库管理 &#xff0c;读请求多 大批次更新或无更新 读很多但用很少 大量的列 列的值小数值/短字符串 一致性要求低 DBMS&#xff1a;动态创建/修改/删除库 表 视图&#xff0c;动态查/增/修/删&#xff0c;用户粒度设库…

控ID型sd生成 - AI写真

1.lora 2.dreambooth 3.串联型&#xff1a;facechain 4.串联型&#xff1a;easyphoto 5.instanceid 6.photomaker 7.ip-adapter-faceid

7天八股速记之Java后端——Day 7

讲一讲 JVM 启动时都有哪些参数 JVM&#xff08;Java Virtual Machine&#xff09;启动时可以通过命令行参数来配置其行为。这些参数通常可以分为以下几类&#xff1a; 标准参数&#xff08;Standard Options&#xff09;&#xff1a; 这些参数是所有 JVM 实现都必须支持的&am…

python项目练习——20、图片浏览器

这个项目允许用户浏览本地计算机上的图片文件,并在界面上显示图片,以及提供一些基本的操作,比如上一张、下一张、放大、缩小等。它涉及到文件操作、图像处理和用户界面设计等方面的技术。 示例 import os # 导入 os 模块 import tkinter as tk # 导入 Tkinter 库 from PI…

关于Java 中的Optional的一些事

开始 定义一个实体类User,实现get、set方法 public class User {String name;String sex;public User(String name, String sex) {this.name name;this.sex sex;}public String getName() {return name;}public void setName(String name) {this.name name;}public String…

Xinstall:专业的App下载量统计工具,让推广效果可衡量

在移动互联网时代&#xff0c;App的下载量是衡量一个应用受欢迎程度的重要指标。然而&#xff0c;很多开发者和广告主在推广App时&#xff0c;都面临着一个共同的问题&#xff1a;如何准确统计App的下载量&#xff1f;这不仅关系到推广效果的评估&#xff0c;还直接影响到广告R…

DC-1渗透测试复现

DC-1渗透测试复现 目的&#xff1a; 获取最高权限以及5个flag 过程&#xff1a; 信息打点-cms框架漏洞利用-数据库-登入admin-提权 环境&#xff1a; 攻击机&#xff1a;kali(192.168.85.136) 靶机&#xff1a;DC_1(192.168.85.131) 复现&#xff1a; 一.信息收集 扫…

对文件内容特殊关键字做高亮处理

效果&#xff1a; 对文件中指定的关键字&#xff08;内容&#xff09;做标记&#xff0c;适用于日志系统特殊化处理。比如对出现Error字段所在的行进行标红高亮 同时支持对关键字的管理以及关键在属性的设置 下面是对内容高亮&#xff1a; void MainWindow::displayDecodeResi…

Python爬虫基础快速入门

目录 前言一、什么是爬虫二、快速编写一个爬虫2.1 爬虫需要用到的库2.2 搭建项目工程2.3 安装三方库2.4 案例编写 三、爬虫实战3.1 目标分析3.2 清洗数据 四、代码改进 前言 本博客旨在分享爬虫技术相关知识&#xff0c;仅供学习和研究之用。使用者在阅读本博客的内容时&#…

Qt/C++推流组件使用说明

2.1 网络推流 公众号&#xff1a;Qt实战&#xff0c;各种开源作品、经验整理、项目实战技巧&#xff0c;专注Qt/C软件开发&#xff0c;视频监控、物联网、工业控制、嵌入式软件、国产化系统应用软件开发。 公众号&#xff1a;Qt入门和进阶&#xff0c;专门介绍Qt/C相关知识点学…

镗床工作台开槽的作用

镗床工作台开槽的作用主要有以下几点&#xff1a; 改善工作台的刚度和稳定性&#xff1a;开槽可以增加工作台的刚度&#xff0c;使其能够承受更大的切削力和振动力&#xff0c;提高工作台的稳定性。 方便工件夹紧和定位&#xff1a;开槽可用于夹紧和定位工件&#xff0c;使其能…

ChatGPT利器:让论文写作更高效更精准

ChatGPT无限次数:点击直达 ChatGPT利器&#xff1a;让论文写作更高效更精准 引言 在当今信息爆炸的时代&#xff0c;论文写作是许多学者和专业人士必不可少的任务。然而&#xff0c;即使对于有经验的专业人士&#xff0c;写作仍然是一个繁琐且耗时的过程。在这样的背景下&…

【DM8】序列

创建序列 图形化界面创建 DDL CREATE SEQUENCE "TEST"."S1" INCREMENT BY 1 START WITH 1 MAXVALUE 100 MINVALUE 1;参数&#xff1a; INCREMENT BY < 增量值 >| START WITH < 初值 >| MAXVALUE < 最大值 >| MINVALUE < 最小值…

PostgreSQL入门到实战-第十弹

PostgreSQL入门到实战 PostgreSQL数据过滤(三)官网地址PostgreSQL概述PostgreSQL中OR操作理论PostgreSQL中OR实操更新计划 PostgreSQL数据过滤(三) 了解PostgreSQL OR逻辑运算符以及如何使用它来组合多个布尔表达式。 官网地址 声明: 由于操作系统, 版本更新等原因, 文章所列…

MQ之————如何保证消息的可靠性

MQ之保证消息的可靠性 1.消费端消息可靠性保证&#xff1a; 1.1 消息确认&#xff08;Acknowledgements&#xff09;&#xff1a; 消费者在接收到消息后&#xff0c;默认情况下RabbitMQ会自动确认消息&#xff08;autoAcktrue&#xff09;。为保证消息可靠性&#xff0c;可以…

Thingsboard PE智慧运维仪表板实例(二)【智慧排口截污实例】

ThingsBoard 的仪表板是一个用于可视化和监控物联网数据的重要组件。 它具有以下特点: 1. 可定制性:用户可以根据自己的需求创建各种类型的图表、表格和指标。 2. 数据可视化:以直观的方式展示设备数据,帮助用户快速了解系统状态。 3. 实时更新:实时反映设备的最新数据…