Skip to content

Commit

Permalink
Merge pull request #18 from alexminnaar/readme
Browse files Browse the repository at this point in the history
adding more treesitter parsers
  • Loading branch information
alexminnaar authored Oct 7, 2023
2 parents 959b7b0 + f9dc6b1 commit bde3139
Show file tree
Hide file tree
Showing 16 changed files with 731 additions and 226 deletions.
33 changes: 24 additions & 9 deletions repogpt/crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import DeepLake
from repogpt.parsers.python_parser import PythonParser
from repogpt.parsers.treesitter_parser import TreeSitterParser
from repogpt.parsers.treesitter import TreeSitterParser
from repogpt.parsers.cpp_treesitter_parser import CppTreeSitterParser
from repogpt.parsers.java_treesitter_parser import JavaTreeSitterParser
from repogpt.parsers.js_treesitter_parser import JsTreeSitterParser
from repogpt.parsers.go_treesitter_parser import GoTreeSitterParser
from repogpt.parsers.treesitter import FileSummary
from tqdm import tqdm
from typing import List, Optional
import os
Expand Down Expand Up @@ -68,16 +73,28 @@ def process_file(
) -> List[Document]:
"""For a given file, get the summary, split into chunks and create context document chunks to be indexed"""
file_doc = file_contents[0]

# get file summary for raw file
# TODO: Add parsers for more languages
if extension == '.py':
file_summary = PythonParser.get_file_summary(file_doc.page_content, file_name)
elif extension == '.cpp':
file_summary = CppTreeSitterParser.get_file_summary(file_doc.page_content, file_name)
elif extension == '.java':
file_summary = JavaTreeSitterParser.get_file_summary(file_doc.page_content, file_name)
elif extension == '.js':
file_summary = JsTreeSitterParser.get_file_summary(file_doc.page_content, file_name)
elif extension == '.go':
file_summary = GoTreeSitterParser.get_file_summary(file_doc.page_content, file_name)
else:
file_summary = TreeSitterParser.get_file_summary(file_doc.page_content, file_name)
file_summary = FileSummary()

# split file contents based on file extension
splitter = RecursiveCharacterTextSplitter.from_language(
language=LANG_MAPPING[extension], chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
)
language=LANG_MAPPING[extension],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=True)
split_docs = splitter.split_documents(file_contents)

# add file path, starting line and summary to each chunk
Expand All @@ -88,19 +105,17 @@ def process_file(
doc.metadata['ending_line'] = ending_line

# get methods and classes associated with chunk
if extension == '.py':
method_class_summary = PythonParser.get_closest_method_class_in_snippet(file_summary, starting_line,
method_class_summary = TreeSitterParser.get_closest_method_class_in_snippet(file_summary,
starting_line,
ending_line)
else:
method_class_summary = TreeSitterParser.get_closest_method_class_in_snippet(file_summary, starting_line,
ending_line)
doc.page_content = f"The following code snippet is from a file at location " \
f"{os.path.join(dir_path, file_name)} " \
f"starting at line {starting_line} and ending at line {ending_line}. " \
f"{method_class_summary} " \
f"The code snippet starting at line {starting_line} and ending at line " \
f"{ending_line} is \n ```\n{doc.page_content}\n``` "

print(doc.page_content)
return split_docs


Expand Down
3 changes: 2 additions & 1 deletion repogpt/parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def add_method(self, method_name: str, method_start_line: int, method_end_line:


class CodeParser(ABC):

@staticmethod
def get_summary_from_position(
summary_positions: List[SummaryPosition],
Expand Down Expand Up @@ -93,5 +94,5 @@ def get_closest_method_class_in_snippet(

@staticmethod
@abstractmethod
def get_file_summary(code: str, file_name: str) -> FileSummary:
def get_file_summary(code: str) -> FileSummary:
"""Given a code string, parse and return important aspects of the code"""
43 changes: 43 additions & 0 deletions repogpt/parsers/cpp_treesitter_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition
from tree_sitter import Language, Parser


class CppTreeSitterParser(TreeSitterParser):
@staticmethod
def get_file_summary(code: str, file_name: str) -> FileSummary:

if not TreeSitterParser.loaded:
TreeSitterParser.initialize_treesitter()

file_summary = FileSummary()
parser = Parser()
parser.set_language(TreeSitterParser.languages['cpp'])
tree = parser.parse(bytes(code, "utf-8"))

def traverse(node, current_line):
if node.type == 'function_declarator':
for child in node.children:
if child.type == 'identifier' or child.type == 'field_identifier':
function_name = code[child.start_byte: child.end_byte]
file_summary.methods.append(
SummaryPosition(function_name, node.start_point[0], node.end_point[0]))

if node.type == 'class_specifier':
for child in node.children:
if child.type == 'type_identifier':
class_name = code[child.start_byte: child.end_byte]
file_summary.classes.append(
SummaryPosition(class_name, node.start_point[0], node.end_point[0]))

for child in node.children:
traverse(child, current_line)

root_node = tree.root_node

traverse(root_node, 0)

# methods and classes are not in order so sort
file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line)
file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line)

return file_summary
51 changes: 51 additions & 0 deletions repogpt/parsers/go_treesitter_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition
from tree_sitter import Language, Parser


class GoTreeSitterParser(TreeSitterParser):
@staticmethod
def get_file_summary(code: str, file_name:str) -> FileSummary:

if not TreeSitterParser.loaded:
TreeSitterParser.initialize_treesitter()

file_summary = FileSummary()
parser = Parser()
parser.set_language(TreeSitterParser.languages['go'])
tree = parser.parse(bytes(code, "utf-8"))

def traverse(node, current_line):

if node.type == 'type_spec':
for child in node.children:
if child.type == 'type_identifier':
class_name = code[child.start_byte: child.end_byte]
file_summary.classes.append(
SummaryPosition(class_name, node.start_point[0], node.end_point[0]))

if node.type == 'method_declaration':
for child in node.children:
if child.type == 'field_identifier':
function_name = code[child.start_byte: child.end_byte]
file_summary.methods.append(
SummaryPosition(function_name, node.start_point[0], node.end_point[0]))

if node.type == 'function_declaration':
for child in node.children:
if child.type == 'identifier':
function_name = code[child.start_byte: child.end_byte]
file_summary.methods.append(
SummaryPosition(function_name, node.start_point[0], node.end_point[0]))

for child in node.children:
traverse(child, current_line)

root_node = tree.root_node

traverse(root_node, 0)

# methods and classes are not in order so sort
file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line)
file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line)

return file_summary
50 changes: 50 additions & 0 deletions repogpt/parsers/java_treesitter_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition
from tree_sitter import Language, Parser


class JavaTreeSitterParser(TreeSitterParser):
@staticmethod
def get_file_summary(code: str, file_name: str) -> FileSummary:

if not TreeSitterParser.loaded:
TreeSitterParser.initialize_treesitter()

file_summary = FileSummary()
parser = Parser()
parser.set_language(TreeSitterParser.languages['java'])
tree = parser.parse(bytes(code, "utf-8"))

def traverse(node, current_line):
if node.type == 'constructor_declaration':
for child in node.children:
if child.type == 'identifier':
function_name = code[child.start_byte: child.end_byte]
file_summary.methods.append(
SummaryPosition(function_name, node.start_point[0], node.end_point[0]))

if node.type == 'method_declaration':
for child in node.children:
if child.type == 'identifier':
function_name = code[child.start_byte: child.end_byte]
file_summary.methods.append(
SummaryPosition(function_name, node.start_point[0], node.end_point[0]))

if node.type == 'class_declaration':
for child in node.children:
if child.type == 'identifier':
class_name = code[child.start_byte: child.end_byte]
file_summary.classes.append(
SummaryPosition(class_name, node.start_point[0], node.end_point[0]))

for child in node.children:
traverse(child, current_line)

root_node = tree.root_node

traverse(root_node, 0)

# methods and classes are not in order so sort
file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line)
file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line)

return file_summary
41 changes: 41 additions & 0 deletions repogpt/parsers/js_treesitter_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition
from tree_sitter import Language, Parser


class JsTreeSitterParser(TreeSitterParser):
@staticmethod
def get_file_summary(code: str, file_name: str) -> FileSummary:

if not TreeSitterParser.loaded:
TreeSitterParser.initialize_treesitter()

file_summary = FileSummary()
parser = Parser()
parser.set_language(TreeSitterParser.languages['cpp'])
tree = parser.parse(bytes(code, "utf-8"))

def traverse(node, current_line):
if node.type == 'function_declarator':
for child in node.children:
if child.type == 'identifier' or child.type == 'field_identifier':
function_name = code[child.start_byte: child.end_byte]
file_summary.methods.append(
SummaryPosition(function_name, node.start_point[0], node.end_point[0]))

if node.type == 'class_specifier':
class_name_node = node.children[1]
class_name = code[class_name_node.start_byte: class_name_node.end_byte]
file_summary.classes.append(SummaryPosition(class_name, node.start_point[0], node.end_point[0]))

for child in node.children:
traverse(child, current_line)

root_node = tree.root_node

traverse(root_node, 0)

# methods and classes are not in order so sort
file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line)
file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line)

return file_summary
16 changes: 8 additions & 8 deletions repogpt/parsers/python_parser.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from repogpt.parsers.base import CodeParser, FileSummary, SummaryPosition
from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition
import ast
from typing import List, Tuple


class PythonParser(CodeParser):
class PythonParser(TreeSitterParser):

@staticmethod
def get_file_summary(code: str, file_name: str) -> FileSummary:
"""Get the classes and methods in python code. Here we can get the end lines too."""
def get_file_summary(code: str, file_name:str) -> FileSummary:
"""Get the classes and methods in python code."""
parsed_tree = ast.parse(code)

file_summary = FileSummary()

# Traverse the AST to find function and class definitions
Expand All @@ -18,13 +16,15 @@ def get_file_summary(code: str, file_name: str) -> FileSummary:
function_name = node.name
function_start_line = node.lineno
function_end_line = node.end_lineno
file_summary.methods.append(SummaryPosition(function_name, function_start_line, function_end_line))
file_summary.methods.append(
SummaryPosition(function_name, function_start_line, function_end_line))

elif isinstance(node, ast.ClassDef):
class_name = node.name
class_start_line = node.lineno
class_end_line = node.end_lineno
file_summary.classes.append(SummaryPosition(class_name, class_start_line, class_end_line))
file_summary.classes.append(
SummaryPosition(class_name, class_start_line, class_end_line))

# methods and classes are not in order so sort
file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line)
Expand Down
Loading

0 comments on commit bde3139

Please sign in to comment.