diff --git a/repogpt/crawler.py b/repogpt/crawler.py index 1750e20..14d7611 100644 --- a/repogpt/crawler.py +++ b/repogpt/crawler.py @@ -4,7 +4,12 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores import DeepLake from repogpt.parsers.python_parser import PythonParser -from repogpt.parsers.treesitter_parser import TreeSitterParser +from repogpt.parsers.treesitter import TreeSitterParser +from repogpt.parsers.cpp_treesitter_parser import CppTreeSitterParser +from repogpt.parsers.java_treesitter_parser import JavaTreeSitterParser +from repogpt.parsers.js_treesitter_parser import JsTreeSitterParser +from repogpt.parsers.go_treesitter_parser import GoTreeSitterParser +from repogpt.parsers.treesitter import FileSummary from tqdm import tqdm from typing import List, Optional import os @@ -68,16 +73,28 @@ def process_file( ) -> List[Document]: """For a given file, get the summary, split into chunks and create context document chunks to be indexed""" file_doc = file_contents[0] + # get file summary for raw file + # TODO: Add parsers for more languages if extension == '.py': file_summary = PythonParser.get_file_summary(file_doc.page_content, file_name) + elif extension == '.cpp': + file_summary = CppTreeSitterParser.get_file_summary(file_doc.page_content, file_name) + elif extension == '.java': + file_summary = JavaTreeSitterParser.get_file_summary(file_doc.page_content, file_name) + elif extension == '.js': + file_summary = JsTreeSitterParser.get_file_summary(file_doc.page_content, file_name) + elif extension == '.go': + file_summary = GoTreeSitterParser.get_file_summary(file_doc.page_content, file_name) else: - file_summary = TreeSitterParser.get_file_summary(file_doc.page_content, file_name) + file_summary = FileSummary() # split file contents based on file extension splitter = RecursiveCharacterTextSplitter.from_language( - language=LANG_MAPPING[extension], chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True - ) + 
language=LANG_MAPPING[extension], + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + add_start_index=True) split_docs = splitter.split_documents(file_contents) # add file path, starting line and summary to each chunk @@ -88,12 +105,9 @@ def process_file( doc.metadata['ending_line'] = ending_line # get methods and classes associated with chunk - if extension == '.py': - method_class_summary = PythonParser.get_closest_method_class_in_snippet(file_summary, starting_line, + method_class_summary = TreeSitterParser.get_closest_method_class_in_snippet(file_summary, + starting_line, ending_line) - else: - method_class_summary = TreeSitterParser.get_closest_method_class_in_snippet(file_summary, starting_line, - ending_line) doc.page_content = f"The following code snippet is from a file at location " \ f"{os.path.join(dir_path, file_name)} " \ f"starting at line {starting_line} and ending at line {ending_line}. " \ @@ -101,6 +115,7 @@ def process_file( f"The code snippet starting at line {starting_line} and ending at line " \ f"{ending_line} is \n ```\n{doc.page_content}\n``` " + print(doc.page_content) return split_docs diff --git a/repogpt/parsers/base.py b/repogpt/parsers/base.py index 1dee99f..26b7f5d 100644 --- a/repogpt/parsers/base.py +++ b/repogpt/parsers/base.py @@ -22,6 +22,7 @@ def add_method(self, method_name: str, method_start_line: int, method_end_line: class CodeParser(ABC): + @staticmethod def get_summary_from_position( summary_positions: List[SummaryPosition], @@ -93,5 +94,5 @@ def get_closest_method_class_in_snippet( @staticmethod @abstractmethod - def get_file_summary(code: str, file_name: str) -> FileSummary: + def get_file_summary(code: str) -> FileSummary: """Given a code string, parse and return important aspects of the code""" diff --git a/repogpt/parsers/cpp_treesitter_parser.py b/repogpt/parsers/cpp_treesitter_parser.py new file mode 100644 index 0000000..134e825 --- /dev/null +++ b/repogpt/parsers/cpp_treesitter_parser.py @@ -0,0 +1,43 @@ 
+from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition +from tree_sitter import Language, Parser + + +class CppTreeSitterParser(TreeSitterParser): + @staticmethod + def get_file_summary(code: str, file_name: str) -> FileSummary: + + if not TreeSitterParser.loaded: + TreeSitterParser.initialize_treesitter() + + file_summary = FileSummary() + parser = Parser() + parser.set_language(TreeSitterParser.languages['cpp']) + tree = parser.parse(bytes(code, "utf-8")) + + def traverse(node, current_line): + if node.type == 'function_declarator': + for child in node.children: + if child.type == 'identifier' or child.type == 'field_identifier': + function_name = code[child.start_byte: child.end_byte] + file_summary.methods.append( + SummaryPosition(function_name, node.start_point[0], node.end_point[0])) + + if node.type == 'class_specifier': + for child in node.children: + if child.type == 'type_identifier': + class_name = code[child.start_byte: child.end_byte] + file_summary.classes.append( + SummaryPosition(class_name, node.start_point[0], node.end_point[0])) + + for child in node.children: + traverse(child, current_line) + + root_node = tree.root_node + + traverse(root_node, 0) + + # methods and classes are not in order so sort + file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line) + file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line) + + return file_summary diff --git a/repogpt/parsers/go_treesitter_parser.py b/repogpt/parsers/go_treesitter_parser.py new file mode 100644 index 0000000..5f7d21f --- /dev/null +++ b/repogpt/parsers/go_treesitter_parser.py @@ -0,0 +1,51 @@ +from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition +from tree_sitter import Language, Parser + + +class GoTreeSitterParser(TreeSitterParser): + @staticmethod + def get_file_summary(code: str, file_name:str) -> FileSummary: + + if not TreeSitterParser.loaded: + 
TreeSitterParser.initialize_treesitter() + + file_summary = FileSummary() + parser = Parser() + parser.set_language(TreeSitterParser.languages['go']) + tree = parser.parse(bytes(code, "utf-8")) + + def traverse(node, current_line): + + if node.type == 'type_spec': + for child in node.children: + if child.type == 'type_identifier': + class_name = code[child.start_byte: child.end_byte] + file_summary.classes.append( + SummaryPosition(class_name, node.start_point[0], node.end_point[0])) + + if node.type == 'method_declaration': + for child in node.children: + if child.type == 'field_identifier': + function_name = code[child.start_byte: child.end_byte] + file_summary.methods.append( + SummaryPosition(function_name, node.start_point[0], node.end_point[0])) + + if node.type == 'function_declaration': + for child in node.children: + if child.type == 'identifier': + function_name = code[child.start_byte: child.end_byte] + file_summary.methods.append( + SummaryPosition(function_name, node.start_point[0], node.end_point[0])) + + for child in node.children: + traverse(child, current_line) + + root_node = tree.root_node + + traverse(root_node, 0) + + # methods and classes are not in order so sort + file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line) + file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line) + + return file_summary \ No newline at end of file diff --git a/repogpt/parsers/java_treesitter_parser.py b/repogpt/parsers/java_treesitter_parser.py new file mode 100644 index 0000000..62b6154 --- /dev/null +++ b/repogpt/parsers/java_treesitter_parser.py @@ -0,0 +1,50 @@ +from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition +from tree_sitter import Language, Parser + + +class JavaTreeSitterParser(TreeSitterParser): + @staticmethod + def get_file_summary(code: str, file_name: str) -> FileSummary: + + if not TreeSitterParser.loaded: + TreeSitterParser.initialize_treesitter() + + file_summary 
= FileSummary() + parser = Parser() + parser.set_language(TreeSitterParser.languages['java']) + tree = parser.parse(bytes(code, "utf-8")) + + def traverse(node, current_line): + if node.type == 'constructor_declaration': + for child in node.children: + if child.type == 'identifier': + function_name = code[child.start_byte: child.end_byte] + file_summary.methods.append( + SummaryPosition(function_name, node.start_point[0], node.end_point[0])) + + if node.type == 'method_declaration': + for child in node.children: + if child.type == 'identifier': + function_name = code[child.start_byte: child.end_byte] + file_summary.methods.append( + SummaryPosition(function_name, node.start_point[0], node.end_point[0])) + + if node.type == 'class_declaration': + for child in node.children: + if child.type == 'identifier': + class_name = code[child.start_byte: child.end_byte] + file_summary.classes.append( + SummaryPosition(class_name, node.start_point[0], node.end_point[0])) + + for child in node.children: + traverse(child, current_line) + + root_node = tree.root_node + + traverse(root_node, 0) + + # methods and classes are not in order so sort + file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line) + file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line) + + return file_summary diff --git a/repogpt/parsers/js_treesitter_parser.py b/repogpt/parsers/js_treesitter_parser.py new file mode 100644 index 0000000..436f744 --- /dev/null +++ b/repogpt/parsers/js_treesitter_parser.py @@ -0,0 +1,41 @@ +from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition +from tree_sitter import Language, Parser + + +class JsTreeSitterParser(TreeSitterParser): + @staticmethod + def get_file_summary(code: str, file_name: str) -> FileSummary: + + if not TreeSitterParser.loaded: + TreeSitterParser.initialize_treesitter() + + file_summary = FileSummary() + parser = Parser() + parser.set_language(TreeSitterParser.languages['cpp']) 
+ tree = parser.parse(bytes(code, "utf-8")) + + def traverse(node, current_line): + if node.type == 'function_declarator': + for child in node.children: + if child.type == 'identifier' or child.type == 'field_identifier': + function_name = code[child.start_byte: child.end_byte] + file_summary.methods.append( + SummaryPosition(function_name, node.start_point[0], node.end_point[0])) + + if node.type == 'class_specifier': + class_name_node = node.children[1] + class_name = code[class_name_node.start_byte: class_name_node.end_byte] + file_summary.classes.append(SummaryPosition(class_name, node.start_point[0], node.end_point[0])) + + for child in node.children: + traverse(child, current_line) + + root_node = tree.root_node + + traverse(root_node, 0) + + # methods and classes are not in order so sort + file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line) + file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line) + + return file_summary diff --git a/repogpt/parsers/python_parser.py b/repogpt/parsers/python_parser.py index 55b1cf7..3a6c654 100644 --- a/repogpt/parsers/python_parser.py +++ b/repogpt/parsers/python_parser.py @@ -1,15 +1,13 @@ -from repogpt.parsers.base import CodeParser, FileSummary, SummaryPosition +from repogpt.parsers.treesitter import TreeSitterParser, FileSummary, SummaryPosition import ast -from typing import List, Tuple -class PythonParser(CodeParser): +class PythonParser(TreeSitterParser): @staticmethod - def get_file_summary(code: str, file_name: str) -> FileSummary: - """Get the classes and methods in python code. 
Here we can get the end lines too.""" + def get_file_summary(code: str, file_name:str) -> FileSummary: + """Get the classes and methods in python code.""" parsed_tree = ast.parse(code) - file_summary = FileSummary() # Traverse the AST to find function and class definitions @@ -18,13 +16,15 @@ def get_file_summary(code: str, file_name: str) -> FileSummary: function_name = node.name function_start_line = node.lineno function_end_line = node.end_lineno - file_summary.methods.append(SummaryPosition(function_name, function_start_line, function_end_line)) + file_summary.methods.append( + SummaryPosition(function_name, function_start_line, function_end_line)) elif isinstance(node, ast.ClassDef): class_name = node.name class_start_line = node.lineno class_end_line = node.end_lineno - file_summary.classes.append(SummaryPosition(class_name, class_start_line, class_end_line)) + file_summary.classes.append( + SummaryPosition(class_name, class_start_line, class_end_line)) # methods and classes are not in order so sort file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line) diff --git a/repogpt/parsers/treesitter.py b/repogpt/parsers/treesitter.py new file mode 100644 index 0000000..f99d5ef --- /dev/null +++ b/repogpt/parsers/treesitter.py @@ -0,0 +1,129 @@ +from abc import ABC, abstractmethod +import subprocess +from tree_sitter import Language, Parser +from typing import List, Tuple + + +class SummaryPosition: + def __init__(self, name: str, start_line: int, end_line: int = None): + self.name = name + self.start_line = start_line + self.end_line = end_line + + +class FileSummary: + def __init__(self): + self.classes = [] + self.methods = [] + + def add_class(self, class_name: str, class_start_line: int, class_end_line: int = None): + self.classes.append(SummaryPosition(class_name, class_start_line, class_end_line)) + + def add_method(self, method_name: str, method_start_line: int, method_end_line: int = None): + 
self.methods.append(SummaryPosition(method_name, method_start_line, method_end_line)) + + +class TreeSitterParser(ABC): + languages = {} + loaded = False + + @staticmethod + def initialize_treesitter(): + # download tree-sitter shared objects according to + # https://github.com/sweepai/sweep/blob/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/utils/utils.py#L169 + LANGUAGE_NAMES = ["python", "java", "cpp", "go", "rust", "ruby", "php"] + for language in LANGUAGE_NAMES: + subprocess.run( + f"git clone https://github.com/tree-sitter/tree-sitter-{language} cache/tree-sitter-{language}", + shell=True) + for language in LANGUAGE_NAMES: + Language.build_library(f'cache/build/{language}.so', [f"cache/tree-sitter-{language}"]) + subprocess.run(f"cp cache/build/{language}.so /tmp/{language}.so", shell=True) + TreeSitterParser.languages = {language: Language(f"/tmp/{language}.so", language) for language in + LANGUAGE_NAMES} + subprocess.run(f"git clone https://github.com/tree-sitter/tree-sitter-c-sharp cache/tree-sitter-c-sharp", + shell=True) + Language.build_library(f'cache/build/c-sharp.so', [f"cache/tree-sitter-c-sharp"]) + subprocess.run(f"cp cache/build/c-sharp.so /tmp/c-sharp.so", shell=True) + TreeSitterParser.languages["c-sharp"] = Language("/tmp/c-sharp.so", "c_sharp") + + subprocess.run(f"git clone https://github.com/tree-sitter/tree-sitter-typescript cache/tree-sitter-typescript", + shell=True) + Language.build_library(f'cache/build/typescript.so', [f"cache/tree-sitter-typescript/tsx"]) + subprocess.run(f"cp cache/build/typescript.so /tmp/typescript.so", shell=True) + TreeSitterParser.languages["tsx"] = Language("/tmp/typescript.so", "tsx") + TreeSitterParser.loaded = True + + @staticmethod + def get_summary_from_position( + summary_positions: List[SummaryPosition], + start_line: int, + end_line: int + ) -> Tuple[List[SummaryPosition], List[SummaryPosition]]: + + last_obj = [] + current_obj = [] + + # TODO: binary search-ify this + for s_pos in 
summary_positions: + # get last defined method before the snippet + if s_pos.start_line < start_line and s_pos.end_line >= start_line: + last_obj.append(s_pos) + + # get any method defined in this snippet + if start_line <= s_pos.start_line <= end_line: + current_obj.append(s_pos) + + # ignore everything past this snippet + if s_pos.start_line > end_line: + break + return last_obj, current_obj + + @staticmethod + def get_closest_method_class_in_snippet( + file_summary: FileSummary, + snippet_start_line: int, + snippet_end_line: int + ) -> str: + closest_method_class_summary = "" + + last_class, current_class = TreeSitterParser.get_summary_from_position(file_summary.classes, snippet_start_line, + snippet_end_line) + + if last_class: + closest_method_class_summary += f" The last class defined before this snippet was called " \ + f"`{last_class[-1].name}` starting at line {last_class[-1].start_line} " \ + f"and ending at line {last_class[-1].end_line}." + if len(current_class) == 1: + closest_method_class_summary += f" The class defined in this snippet is called `{current_class[0].name}`" \ + f"starting at line {current_class[0].start_line} and ending at line " \ + f"{current_class[0].end_line}." + elif len(current_class) > 1: + multi_class_summary = " and ".join( + [f"`{c.name}` starting at line {c.start_line} and ending at line {c.end_line}" for c in current_class]) + closest_method_class_summary += f" The classes defined in this snippet are {multi_class_summary}." + + last_method, current_method = TreeSitterParser.get_summary_from_position(file_summary.methods, snippet_start_line, + snippet_end_line) + + if last_method: + closest_method_class_summary += f" The last method starting before this snippet is called " \ + f"`{last_method[-1].name}` which starts on line " \ + f"{last_method[-1].start_line} and ends at line {last_method[-1].end_line}." 
+ if len(current_method) == 1: + closest_method_class_summary += f" The method defined in this snippet is called " \ + f"`{current_method[0].name}` starting at line " \ + f"{current_method[0].start_line} and ending at line " \ + f"{current_method[0].end_line}." + elif len(current_method) > 1: + multi_method_summary = " and ".join( + [f"`{meth.name}` starting at line {meth.start_line} and ending at line {meth.end_line}" for meth in + current_method]) + closest_method_class_summary += f" The methods defined in this snippet are {multi_method_summary}." + + return closest_method_class_summary + + @staticmethod + @abstractmethod + def get_file_summary(code: str, file_name:str) -> FileSummary: + """Given a code string, parse and return important aspects of the code""" \ No newline at end of file diff --git a/repogpt/parsers/treesitter_parser.py b/repogpt/parsers/treesitter_parser.py deleted file mode 100644 index 94e7cc8..0000000 --- a/repogpt/parsers/treesitter_parser.py +++ /dev/null @@ -1,125 +0,0 @@ -from repogpt.parsers.base import CodeParser, FileSummary, SummaryPosition -import subprocess -from tree_sitter import Language, Parser -import os - - -class TreeSitterParser(CodeParser): - languages = {} - loaded = False - - @staticmethod - def initialize_treesitter(): - # download tree-sitter shared objects according to - # https://github.com/sweepai/sweep/blob/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/utils/utils.py#L169 - LANGUAGE_NAMES = ["python", "java", "cpp", "go", "rust", "ruby", "php"] - for language in LANGUAGE_NAMES: - subprocess.run( - f"git clone https://github.com/tree-sitter/tree-sitter-{language} cache/tree-sitter-{language}", - shell=True) - for language in LANGUAGE_NAMES: - Language.build_library(f'cache/build/{language}.so', [f"cache/tree-sitter-{language}"]) - subprocess.run(f"cp cache/build/{language}.so /tmp/{language}.so", shell=True) - TreeSitterParser.languages = {language: Language(f"/tmp/{language}.so", language) for language in - 
LANGUAGE_NAMES} - subprocess.run(f"git clone https://github.com/tree-sitter/tree-sitter-c-sharp cache/tree-sitter-c-sharp", - shell=True) - Language.build_library(f'cache/build/c-sharp.so', [f"cache/tree-sitter-c-sharp"]) - subprocess.run(f"cp cache/build/c-sharp.so /tmp/c-sharp.so", shell=True) - TreeSitterParser.languages["c-sharp"] = Language("/tmp/c-sharp.so", "c_sharp") - - subprocess.run(f"git clone https://github.com/tree-sitter/tree-sitter-typescript cache/tree-sitter-typescript", - shell=True) - Language.build_library(f'cache/build/typescript.so', [f"cache/tree-sitter-typescript/tsx"]) - subprocess.run(f"cp cache/build/typescript.so /tmp/typescript.so", shell=True) - TreeSitterParser.languages["tsx"] = Language("/tmp/typescript.so", "tsx") - TreeSitterParser.loaded = True - - @staticmethod - def get_file_summary(code: str, file_name: str) -> FileSummary: - - # download and build treesitter shared objects if not already done. - if not TreeSitterParser.loaded: - TreeSitterParser.initialize_treesitter() - - extension_to_language = { - ".py": "python", - ".rs": "rust", - ".go": "go", - ".java": "java", - ".cpp": "cpp", - ".cc": "cpp", - ".cxx": "cpp", - ".c": "cpp", - ".h": "cpp", - ".hpp": "cpp", - ".rb": "ruby", - ".php": "php", - ".js": "tsx", - ".jsx": "tsx", - ".ts": "tsx", - ".tsx": "tsx", - ".mjs": "tsx", - ".cs": "c-sharp", - - } - - file_summary = FileSummary() - _, extension = os.path.splitext(file_name) - - if extension in extension_to_language: - language = TreeSitterParser.languages[extension_to_language[extension]] - parser = Parser() - parser.set_language(language) - tree = parser.parse(bytes(code, "utf-8")) - - def traverse(node, current_line): - - if node.type == 'function_definition': - function_name_node = node.children[1] - function_name = code[function_name_node.start_byte: function_name_node.end_byte] - file_summary.methods.append(SummaryPosition(function_name, node.start_point[0], node.end_point[0])) - - if node.type == 
'function_declaration': - for child in node.children: - if child.type == 'identifier': - function_name = code[child.start_byte: child.end_byte] - file_summary.methods.append( - SummaryPosition(function_name, node.start_point[0], node.end_point[0])) - - if node.type == 'constructor_declaration': - function_name_node = node.children[1] - function_name = code[function_name_node.start_byte: function_name_node.end_byte] - file_summary.methods.append(SummaryPosition(function_name, node.start_point[0], node.end_point[0])) - - if node.type == 'method_declaration': - for child in node.children: - if child.type == 'identifier': - function_name = code[child.start_byte: child.end_byte] - file_summary.methods.append( - SummaryPosition(function_name, node.start_point[0], node.end_point[0])) - - if node.type == 'class_specifier': - class_name_node = node.children[1] - class_name = code[class_name_node.start_byte: class_name_node.end_byte] - file_summary.classes.append(SummaryPosition(class_name, node.start_point[0], node.end_point[0])) - - if node.type == 'class_declaration': - for child in node.children: - if child.type == 'identifier': - class_name = code[child.start_byte: child.end_byte] - file_summary.classes.append( - SummaryPosition(class_name, node.start_point[0], node.end_point[0])) - - for child in node.children: - traverse(child, current_line) - - root_node = tree.root_node - - traverse(root_node, 0) - - # methods and classes are not in order so sort - file_summary.methods = sorted(file_summary.methods, key=lambda x: x.start_line) - file_summary.classes = sorted(file_summary.classes, key=lambda x: x.start_line) - - return file_summary diff --git a/tests/crawler/test_process_file.py b/tests/crawler/test_process_file.py index 5d7396c..5d3848b 100644 --- a/tests/crawler/test_process_file.py +++ b/tests/crawler/test_process_file.py @@ -14,7 +14,8 @@ def hello_world(): hello_world() """ - docs = process_file([Document(page_content=PYTHON_CODE)], "/my/file/path/", "hello.py", 
".py", 100, 0) + docs = process_file([Document(page_content=PYTHON_CODE)], "/my/file/path/", "hello.py", + ".py", 100, 0) expected_docs = [Document(page_content='The following code snippet is from a file at location ' '/my/file/path/hello.py starting at line 2 and ending at line 6. ' diff --git a/tests/parser/test_code_files/test_cpp_file.cpp b/tests/parser/test_code_files/test_cpp_file.cpp new file mode 100644 index 0000000..020ddab --- /dev/null +++ b/tests/parser/test_code_files/test_cpp_file.cpp @@ -0,0 +1,91 @@ +#include +#include + +// Using a class declaration +class RectangleClass { +private: + double length; + double width; + +public: + RectangleClass(double len, double wid) : length(len), width(wid) {} + + double calculateArea() { + return length * width; + } + + double calculatePerimeter() { + return 2 * (length + width); + } +}; + +// Using a separate function to create an object (Factory function) +class Square { +private: + double side; + +public: + Square(double s) : side(s) {} + + static Square createSquare(double side) { + return Square(side); + } +}; + +double calculateSquareArea(const Square& square) { + return square.side * square.side; +} + +double calculateSquarePerimeter(const Square& square) { + return 4 * square.side; +} + +// Using a class with static member functions +class CircleStatic { +private: + double radius; + +public: + CircleStatic(double r) : radius(r) {} + + static double calculateArea(double radius) { + return 3.14159 * radius * radius; + } + + static double calculateCircumference(double radius) { + return 2 * 3.14159 * radius; + } +}; + +// Using a class with instance member functions +class CircleInstance { +private: + double radius; + +public: + CircleInstance(double r) : radius(r) {} + + double calculateArea() { + return 3.14159 * radius * radius; + } + + double calculateCircumference() { + return 2 * 3.14159 * radius; + } +}; + +int main() { + // Creating objects and using functions/classes + RectangleClass rectangle(5.0, 
3.0); + Square square = Square::createSquare(4.0); + CircleStatic circleStatic(2.0); + CircleInstance circleInstance(3.0); + + // Calculating and printing results + std::cout << "Rectangle Area: " << rectangle.calculateArea() << std::endl; + std::cout << "Square Area: " << calculateSquareArea(square) << std::endl; + std::cout << "Circle (Static) Area: " << CircleStatic::calculateArea(circleStatic) << std::endl; + std::cout << "Circle (Instance) Area: " << circleInstance.calculateArea() << std::endl; + + return 0; +} \ No newline at end of file diff --git a/tests/parser/test_code_files/test_go_file.go b/tests/parser/test_code_files/test_go_file.go new file mode 100644 index 0000000..a6853ce --- /dev/null +++ b/tests/parser/test_code_files/test_go_file.go @@ -0,0 +1,49 @@ +package main + +import ( + "fmt" +) + +// Define a struct called "Person" +type Person struct { + FirstName string + LastName string + Age int +} + +// Define a method called "FullName" for the "Person" struct +func (p Person) FullName() string { + return p.FirstName + " " + p.LastName +} + +type Shape interface { + Area() float64 +} + +// Define another struct called "Employee" that embeds "Person" +type Employee struct { + Person // Embedded "Person" struct + EmployeeID int +} + +// Define a method called "PrintEmployeeInfo" for the "Employee" struct +func (e Employee) PrintEmployeeInfo() { + fmt.Printf("Employee ID: %d\n", e.EmployeeID) + fmt.Printf("Full Name: %s\n", e.FullName()) // Accessing the method from the embedded struct +} + +func main() { + // Create an instance of "Employee" and initialize its fields + employee := Employee{ + Person: Person{ + FirstName: "John", + LastName: "Doe", + Age: 30, + }, + EmployeeID: 12345, + } + + // Access methods and fields of the "Employee" struct + fmt.Println("Employee Information:") + employee.PrintEmployeeInfo() +} diff --git a/tests/parser/test_code_files/test_java_file.java b/tests/parser/test_code_files/test_java_file.java new file mode 100644 
index 0000000..001ea2b --- /dev/null +++ b/tests/parser/test_code_files/test_java_file.java @@ -0,0 +1,81 @@ +// Using a class declaration +class RectangleClass { + double length; + double width; + + RectangleClass(double length, double width) { + this.length = length; + this.width = width; + } + + double calculateArea() { + return length * width; + } + + double calculatePerimeter() { + return 2 * (length + width); + } +} + +// Using a separate function to create an object (Factory function) +class Square { + double side; + + Square(double side) { + this.side = side; + } + + static Square createSquare(double side) { + return new Square(side); + } +} + +// Using a class with static methods +class CircleStatic { + double radius; + + CircleStatic(double radius) { + this.radius = radius; + } + + static double calculateArea(double radius) { + return Math.PI * radius * radius; + } + + static double calculateCircumference(double radius) { + return 2 * Math.PI * radius; + } +} + +// Using a class with instance methods +class CircleInstance { + double radius; + + CircleInstance(double radius) { + this.radius = radius; + } + + double calculateArea() { + return Math.PI * radius * radius; + } + + double calculateCircumference() { + return 2 * Math.PI * radius; + } +} + +public class Main { + public static void main(String[] args) { + // Creating objects and using functions/classes + RectangleClass rectangle = new RectangleClass(5.0, 3.0); + Square square = Square.createSquare(4.0); + CircleStatic circleStatic = new CircleStatic(2.0); + CircleInstance circleInstance = new CircleInstance(3.0); + + // Calculating and printing results + System.out.println("Rectangle Area: " + rectangle.calculateArea()); + System.out.println("Square Area: " + calculateSquareArea(square)); + System.out.println("Circle (Static) Area: " + CircleStatic.calculateArea(circleStatic.radius)); + System.out.println("Circle (Instance) Area: " + circleInstance.calculateArea()); + } +} diff --git 
a/tests/parser/test_code_files/test_js_file.js b/tests/parser/test_code_files/test_js_file.js new file mode 100644 index 0000000..7bca653 --- /dev/null +++ b/tests/parser/test_code_files/test_js_file.js @@ -0,0 +1,15 @@ +// Using a class declaration +class Rectangle { + constructor(length, width) { + this.length = length; + this.width = width; + } + + calculateArea() { + return this.length * this.width; + } + + calculatePerimeter() { + return 2 * (this.length + this.width); + } +} \ No newline at end of file diff --git a/tests/parser/test_code_files/test_python_file.py b/tests/parser/test_code_files/test_python_file.py new file mode 100644 index 0000000..1d22404 --- /dev/null +++ b/tests/parser/test_code_files/test_python_file.py @@ -0,0 +1,52 @@ +# Using a class definition +class RectangleClass: + def __init__(self, length, width): + self.length = length + self.width = width + + def calculate_area(self): + return self.length * self.width + + def calculate_perimeter(self): + return 2 * (self.length + self.width) + + +# Using a function to create an object (Factory function) +def create_square(side_length): + return {"side_length": side_length} + +def square_area(square): + return square["side_length"] ** 2 + +def square_perimeter(square): + return 4 * square["side_length"] + + +# Using a class with @staticmethod decorator +class CircleStatic: + def __init__(self, radius): + self.radius = radius + + @staticmethod + def calculate_area(radius): + return 3.14159 * radius * radius + + @staticmethod + def calculate_circumference(radius): + return 2 * 3.14159 * radius + + +# Using a class with @classmethod decorator +class CircleClassMethod: + def __init__(self, radius): + self.radius = radius + + @classmethod + def create_circle(cls, diameter): + return cls(diameter / 2) + + def calculate_area(self): + return 3.14159 * self.radius * self.radius + + def calculate_circumference(self): + return 2 * 3.14159 * self.radius \ No newline at end of file diff --git 
a/tests/parser/test_treesitter_parser.py b/tests/parser/test_treesitter_parser.py
index 1e9e7a5..862c75c 100644
--- a/tests/parser/test_treesitter_parser.py
+++ b/tests/parser/test_treesitter_parser.py
@@ -1,87 +1,91 @@
 import unittest
-from repogpt.parsers.treesitter_parser import TreeSitterParser
-from repogpt.parsers.base import SummaryPosition
+from repogpt.parsers.java_treesitter_parser import JavaTreeSitterParser
+from repogpt.parsers.cpp_treesitter_parser import CppTreeSitterParser
+from repogpt.parsers.go_treesitter_parser import GoTreeSitterParser
+from repogpt.parsers.js_treesitter_parser import JsTreeSitterParser
+import os
 
 
 class TreeSitterParserTest(unittest.TestCase):
+    """Check the methods and classes each tree-sitter parser finds in a fixture.
+
+    Fixtures live in ``test_code_files`` next to this module. The expected
+    values are ``(name, start_line, end_line)`` triples, so an edit that shifts
+    lines in a fixture must be mirrored in the expectations below.
+    """
 
-    def test_treesitter_parser_cpp(self):
-        cpp_code = '''
-        #include
-        using namespace std;
-
-        class MyClass {
-        public:
-            void greet() {
-                cout << "Hello!" << endl;
-            }
-        };
-        '''
-
-        tsp = TreeSitterParser()
-        fs = tsp.get_file_summary(cpp_code, "test.cpp")
-
-        expected_methods = [SummaryPosition("greet()", 6, 8)]
-        expected_classes = [SummaryPosition("MyClass", 4, 9)]
-
-        assert expected_methods[0].name == fs.methods[0].name \
-            and expected_methods[0].start_line == fs.methods[0].start_line \
-            and expected_methods[0].end_line == fs.methods[0].end_line \
-            and expected_classes[0].name == fs.classes[0].name \
-            and expected_classes[0].start_line == fs.classes[0].start_line \
-            and expected_classes[0].end_line == fs.classes[0].end_line
-
-    def test_treesitter_parser_java(self):
-        java_code = """
-public class IntegerSequenceTest {
-    @Test
-    public void testRangeMultipleIterations() {
-        // Check that we can iterate several times using the same instance.
-        final int start = 1;
-        final int max = 7;
-        final int step = 2;
-
-        final List seq = new ArrayList<>();
-        final IntegerSequence.Range r = IntegerSequence.range(start, max, step);
-
-        final int numTimes = 3;
-        for (int n = 0; n < numTimes; n++) {
-            seq.clear();
-            for (Integer i : r) {
-                seq.add(i);
-            }
-            Assert.assertEquals(4, seq.size());
-            Assert.assertEquals(seq.size(), r.size());
-        }
-    }
-    """
-
-        tsp = TreeSitterParser()
-        fs = tsp.get_file_summary(java_code, "test.java")
-
-        expected_methods = [SummaryPosition("testRangeMultipleIterations", 2, 21)]
-
-        assert expected_methods[0].name == fs.methods[0].name \
-            and expected_methods[0].start_line == fs.methods[0].start_line \
-            and expected_methods[0].end_line == fs.methods[0].end_line
-
-    def test_treesitter_parser_js(self):
-        js_code = """
-export default function enqueueTask(task: () => void): void {
-  const channel = new MessageChannel();
-  channel.port1.onmessage = () => {
-    channel.port1.close();
-    task();
-  };
-  channel.port2.postMessage(undefined);
-}
-    """
-
-        tsp = TreeSitterParser()
-        fs = tsp.get_file_summary(js_code, "test.js")
-
-        expected_methods = [SummaryPosition("enqueueTask", 1, 8)]
-
-        assert expected_methods[0].name == fs.methods[0].name \
-            and expected_methods[0].start_line == fs.methods[0].start_line \
-            and expected_methods[0].end_line == fs.methods[0].end_line
+    def _assert_file_summary(self, parser_cls, fixture, file_name,
+                             expected_methods, expected_classes):
+        """Parse a fixture with ``parser_cls`` and compare the reported positions.
+
+        Args:
+            parser_cls: Parser class; instantiated without arguments.
+            fixture: Name of the fixture file inside ``test_code_files``.
+            file_name: File name handed to ``get_file_summary``.
+            expected_methods: Expected triples for ``fs.methods``, in order.
+            expected_classes: Expected triples for ``fs.classes``, in order.
+        """
+        test_dir = os.path.dirname(os.path.abspath(__file__))
+        file_path = os.path.join(test_dir, "test_code_files", fixture)
+
+        # Decode explicitly so the result does not depend on the locale default.
+        with open(file_path, 'r', encoding='utf-8') as file:
+            file_contents = file.read()
+
+        fs = parser_cls().get_file_summary(file_contents, file_name)
+
+        actual_methods = [(m.name, m.start_line, m.end_line) for m in fs.methods]
+        actual_classes = [(c.name, c.start_line, c.end_line) for c in fs.classes]
+
+        # Two assertEqual calls instead of one `assert a and b`: a failure names
+        # the list that differs and shows the mismatching elements.
+        self.assertEqual(actual_methods, expected_methods)
+        self.assertEqual(actual_classes, expected_classes)
+
+    def test_treesitter_parser_java_file(self):
+        """Check the summary reported for ``test_java_file.java``."""
+        expected_methods = [('RectangleClass', 5, 8), ('calculateArea', 10, 12),
+                            ('calculatePerimeter', 14, 16), ('Square', 23, 25),
+                            ('createSquare', 27, 29), ('CircleStatic', 36, 38),
+                            ('calculateArea', 40, 42), ('calculateCircumference', 44, 46),
+                            ('CircleInstance', 53, 55), ('calculateArea', 57, 59),
+                            ('calculateCircumference', 61, 63), ('main', 67, 79)]
+        expected_classes = [('RectangleClass', 1, 17), ('Square', 20, 30), ('CircleStatic', 33, 47),
+                            ('CircleInstance', 50, 64), ('Main', 66, 80)]
+
+        self._assert_file_summary(JavaTreeSitterParser, "test_java_file.java", "test.java",
+                                  expected_methods, expected_classes)
+
+    def test_treesitter_parser_cpp_file(self):
+        """Check the summary reported for ``test_cpp_file.cpp``."""
+        expected_methods = [('RectangleClass', 10, 10), ('calculateArea', 12, 12),
+                            ('calculatePerimeter', 16, 16), ('Square', 27, 27),
+                            ('createSquare', 29, 29), ('calculateSquareArea', 34, 34),
+                            ('calculateSquarePerimeter', 38, 38), ('CircleStatic', 48, 48),
+                            ('calculateArea', 50, 50), ('calculateCircumference', 54, 54),
+                            ('CircleInstance', 65, 65), ('calculateArea', 67, 67),
+                            ('calculateCircumference', 71, 71), ('main', 76, 76)]
+        expected_classes = [('RectangleClass', 4, 19), ('Square', 22, 32), ('CircleStatic', 43, 57),
+                            ('CircleInstance', 60, 74)]
+
+        self._assert_file_summary(CppTreeSitterParser, "test_cpp_file.cpp", "test.cpp",
+                                  expected_methods, expected_classes)
+
+    def test_treesitter_parser_js_file(self):
+        """Check the summary reported for ``test_js_file.js``."""
+        expected_methods = [('constructor', 2, 2), ('calculateArea', 7, 7),
+                            ('calculatePerimeter', 11, 11)]
+        expected_classes = [('Rectangle', 1, 14)]
+
+        self._assert_file_summary(JsTreeSitterParser, "test_js_file.js", "test.js",
+                                  expected_methods, expected_classes)
+
+    def test_treesitter_parser_go_file(self):
+        """Check the summary reported for ``test_go_file.go``."""
+        expected_methods = [('FullName', 14, 16), ('PrintEmployeeInfo', 29, 32), ('main', 34, 48)]
+        expected_classes = [('Person', 7, 11), ('Shape', 18, 20), ('Employee', 23, 26)]
+
+        self._assert_file_summary(GoTreeSitterParser, "test_go_file.go", "test.go",
+                                  expected_methods, expected_classes)