利用tree-sitter提取代码文件中的函数和注释
- 1. 需求
- 2. 工具
- 3. 实现
1. 需求
提取.c或.cpp文件中的带有注释的函数,作为训练数据喂给大语言模型。要求是能够批量处理,提取函数前带有注释的函数和注释,并将函数中的注释同样提取出来作为辅助训练数据,结果保存在JSON文件中。
2. 工具
tree-sitter
如何配置使用环境见https://blog.csdn.net/sluck_0430/article/details/134194493
pycharm
如何将conda的虚拟python环境添加到pycharm中见https://blog.csdn.net/weixin_62783109/article/details/129962054?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171402346916800178588080%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=171402346916800178588080&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allfirst_rank_ecpm_v1~rank_v31_ecpm-4-129962054-null-null.142%5Ev100%5Econtrol&utm_term=conda%E5%90%8E%E7%9A%84%E7%8E%AF%E5%A2%83%E5%A6%82%E4%BD%95%E6%B7%BB%E5%8A%A0%E5%88%B0pycharm%E4%B8%AD&spm=1018.2226.3001.4187
3. 实现
from tree_sitter import Language, Parser
import json
import os
import re# 加载C语言模块
Language.build_library('build/my-languages.so',['vendor/tree-sitter-c']
)C_LANGUAGE = Language('build/my-languages.so', 'c')
parser = Parser()
parser.set_language(C_LANGUAGE)# 提取代码信息
def extract_code_information(node, code):functions = [] # 存放最终的代码提取结果comment = '' # 存放函数前的注释in_comment = '' # 存放函数中的注释function = '' # 存放函数for child in node.children:# 只保存函数前存在注释的函数及其注释if child.type == 'function_definition' and child.prev_sibling and child.prev_sibling.type == 'comment':# 首先处理函数function = extract_node_information(child, code)# 然后处理函数中的注释in_comment = traverse_children(child, code)# 最后处理函数前的注释temp_node = child.prev_siblingwhile temp_node.type == 'comment':comment += extract_node_information(temp_node, code)if temp_node.prev_sibling:temp_node = temp_node.prev_siblingelse:break# 将函数和其注释保存到最终的结果中functions.append({'comment_before_function': comment,'comment_in_function': in_comment,'function': function})comment = ''in_comment = ''function = ''return functions# 深度优先遍历节点的全部孩子节点
def traverse_children(node, code):if node is None:return ''comment = ''if node.type == 'comment':comment += extract_node_information(node, code)for child in node.children:comment += traverse_children(child, code)return comment# 提取节点信息
def extract_node_information(node, code):try:start_row, start_col = node.start_pointend_row, end_col = node.end_point# 将源代码按行进行拆分code_lines = code.split('\n')# 如果起始行和结束行在同一行if start_row == end_row:extracted_code = code_lines[start_row][start_col:end_col]else:# 提取起始行到结束行中的内容extracted_code = code_lines[start_row][start_col:]for i in range(start_row + 1, end_row):extracted_code += code_lines[i] + '\n'extracted_code += code_lines[end_row][:end_col]return extracted_codeexcept AttributeError as e:return ''# 查找文件夹中的.c和.cpp文件
def get_c_files(folder):c_files = []for root, dirs, files in os.walk(folder):for file in files:if re.search(r'\.c$|\.cpp$', file):c_files.append(os.path.join(root, file))return c_files# 处理文件夹中的.c和.cpp文件
def pipeline(folder_path):c_files = get_c_files(folder_path)functions = []for c_file in c_files:print(c_file)temp = []try:try:with open(c_file, 'r', encoding='gbk') as file:code = file.read()tree = parser.parse(bytes(code, 'gbk'))root_node = tree.root_nodetemp = extract_code_information(root_node, code)functions.append(temp)except UnicodeDecodeError as e:with open(c_file, 'r', encoding='utf8') as file:code = file.read()tree = parser.parse(bytes(code, 'utf8'))root_node = tree.root_nodetemp = extract_code_information(root_node, code)functions.append(temp)except UnicodeDecodeError as e:print("UnicodeDecodeError!")# 将结果保存在functions.json中with open('functions.json', 'w', encoding='utf8') as json_file:json.dump(functions, json_file, indent=4, ensure_ascii=False)if __name__ == '__main__':folder_path = '文件夹的绝对路径'pipeline(folder_path)