Skip to content

Commit

Permalink
Merge pull request #10 from alexminnaar/readme
Browse files Browse the repository at this point in the history
adding treesitter parser
  • Loading branch information
alexminnaar authored Aug 3, 2023
2 parents 019fa70 + 9a6e53c commit e3e874e
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 123 deletions.
62 changes: 32 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,38 @@ RepoGPT adds additional context to the chunk including
* A summary of the classes and methods contained in the file.
* The line number where the chunk appears in the file.

## Usage

### 1. Create a config.ini File
The `config.ini` file sets the parameters that RepoGPT needs to run. They are:

* `REPO_PATH`: The path to the root directory of the git repo.
* `VS_PATH`: The path where the vector store will be created.
* `NUM_RESULTS`: The number of search results returned by the vector store for a given query.
* `EMBEDDING_TYPE`: The name of the embedding being used.
* `MODEL_NAME`: The name of the LLM to use.

Example `config.ini` files can be found in the `example_config_files` directory in this repo.


### 2. Initialize Repo
This step crawls and indexes the repo specified in `example_config.ini`.
```commandline
python cli.py --init example_config.ini
```

### 3. Ask Questions
Run the command
```commandline
python cli.py example_config.ini
```
You should then see:

```commandline
Ask a question:
```
Then ask your question and wait for the response. To exit, type 'exit'.

## Demo

In this demo, the [Pandas](https://github.com/pandas-dev/pandas/tree/main) python library repo has been crawled and
Expand Down Expand Up @@ -84,36 +116,6 @@ The following languages/file types can be crawled with RepoGPT

⚠️ Warning: Crawling a large repo while using OpenAI embeddings could result in many thousands of embedding requests ⚠️

## Usage

### 1. Create a config.ini File
The `config.ini` file sets the parameters that RepoGPT needs to run. They are:

* `REPO_PATH`: The path to the root directory of the git repo.
* `VS_PATH`: The path where the vector store will be created.
* `NUM_RESULTS`: The number of search results returned by the vector store for a given query.
* `EMBEDDING_TYPE`: The name of the embedding being used.
* `MODEL_NAME`: The name of the LLM to use.

Example `config.ini` files can be found in the `example_config_files` directory in this repo.


### 2. Initialize Repo
This step crawls and indexes the repo specified in `example_config.ini`.
```commandline
python cli.py --init example_config.ini
```

### 3. Ask Questions
Run the command
```commandline
python cli.py example_config.ini
```
You should then see:

```commandline
Ask a question:
```
Then ask your question and wait for the response. To exit, type 'exit'.


26 changes: 12 additions & 14 deletions repogpt/crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import DeepLake
from repogpt.parsers.pygments_parser import PygmentsParser
from repogpt.parsers.python_parser import PythonParser
from multiprocessing import Pool
from repogpt.parsers.treesitter_parser import TreeSitterParser
from tqdm import tqdm
from typing import List, Optional
import os
Expand Down Expand Up @@ -73,7 +72,7 @@ def process_file(
if extension == '.py':
file_summary = PythonParser.get_file_summary(file_doc.page_content, file_name)
else:
file_summary = PygmentsParser.get_file_summary(file_doc.page_content, file_name)
file_summary = TreeSitterParser.get_file_summary(file_doc.page_content, file_name)

# split file contents based on file extension
splitter = RecursiveCharacterTextSplitter.from_language(
Expand All @@ -85,22 +84,21 @@ def process_file(
for doc in split_docs:
starting_line = file_doc.page_content[:doc.metadata['start_index']].count('\n') + 1
ending_line = starting_line + doc.page_content.count('\n')
doc.metadata['starting_line'] = starting_line
doc.metadata['ending_line'] = ending_line

# get methods and classes associated with chunk
if extension == '.py':
method_class_summary = PythonParser.get_closest_method_class_in_snippet(file_summary, starting_line,
ending_line)
else:
method_class_summary = PygmentsParser.get_closest_method_class_in_snippet(file_summary, starting_line,
ending_line)
method_class_summary = TreeSitterParser.get_closest_method_class_in_snippet(file_summary, starting_line,
ending_line)
doc.page_content = f"The following code snippet is from a file at location " \
f"{os.path.join(dir_path, file_name)} " \
f"starting at line {starting_line} and ending at line {ending_line}. " \
f"{method_class_summary} " \
f"The code snippet starting at line {starting_line} and ending at line " \
f"{ending_line} is \n ```\n{doc.page_content}\n``` "

return split_docs


Expand Down Expand Up @@ -138,13 +136,13 @@ def crawl_and_split(root_dir: str, chunk_size: int = 3000, chunk_overlap: int =

process_and_split_partial_function = partial(process_and_split, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

with Pool(processes=os.cpu_count()) as pool:
split_docs = []
with tqdm(total=len(filtered_files), desc='Chunking documents...', ncols=80) as pbar:
for i, docs in enumerate(pool.imap_unordered(process_and_split_partial_function, filtered_files)):
if docs:
split_docs.extend(docs)
pbar.update()
split_docs = []
with tqdm(total=len(filtered_files), desc='Chunking documents...', ncols=80) as pbar:
for ff in filtered_files:
docs = process_and_split_partial_function(ff)
if docs:
split_docs.extend(docs)
pbar.update()

return split_docs

Expand Down
73 changes: 65 additions & 8 deletions repogpt/parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,73 @@ def add_method(self, method_name: str, method_start_line: int, method_end_line:

class CodeParser(ABC):
@staticmethod
@abstractmethod
def get_summary_from_position(summary_positions: List[SummaryPosition], start_line: int,
end_line: int) -> Tuple[List[SummaryPosition], List[SummaryPosition]]:
"""Helper function to get object positions within snippet"""
def get_summary_from_position(
        summary_positions: List[SummaryPosition],
        start_line: int,
        end_line: int
) -> Tuple[List[SummaryPosition], List[SummaryPosition]]:
    """Partition summary positions relative to a snippet's line range.

    Assumes `summary_positions` is ordered by start line (the early exit
    below relies on this).

    Returns:
        A tuple ``(last_obj, current_obj)`` where ``last_obj`` holds
        positions that start before `start_line` but are still open at it,
        and ``current_obj`` holds positions whose start falls within
        ``[start_line, end_line]``.
    """
    spanning_objs = []
    contained_objs = []

    # TODO: binary search-ify this
    for pos in summary_positions:
        if pos.start_line > end_line:
            # positions are ordered, so nothing later can touch the snippet
            break
        if pos.start_line < start_line and pos.end_line >= start_line:
            # opened before the snippet and still open inside it
            spanning_objs.append(pos)
        if start_line <= pos.start_line <= end_line:
            # declared within the snippet itself
            contained_objs.append(pos)
    return spanning_objs, contained_objs

@staticmethod
@abstractmethod
def get_closest_method_class_in_snippet(file_summary: FileSummary, snippet_start_line: int,
snippet_end_line: int) -> str:
"""Get the relevent methods and classes in a snippet and convert to prompt"""
def get_closest_method_class_in_snippet(
        file_summary: FileSummary,
        snippet_start_line: int,
        snippet_end_line: int
) -> str:
    """Describe, in prose, the classes and methods relevant to a snippet.

    Combines the class/method positions in `file_summary` with the snippet's
    line range and returns an English sentence (or empty string) naming the
    last class/method opened before the snippet and any defined inside it,
    for inclusion in the chunk's prompt context.
    """
    closest_method_class_summary = ""

    last_class, current_class = CodeParser.get_summary_from_position(file_summary.classes, snippet_start_line,
                                                                     snippet_end_line)

    if last_class:
        closest_method_class_summary += f" The last class defined before this snippet was called " \
                                        f"`{last_class[-1].name}` starting at line {last_class[-1].start_line} " \
                                        f"and ending at line {last_class[-1].end_line}."
    if len(current_class) == 1:
        # fix: space after the closing backtick so the class name and
        # "starting" no longer run together in the generated text
        closest_method_class_summary += f" The class defined in this snippet is called `{current_class[0].name}` " \
                                        f"starting at line {current_class[0].start_line} and ending at line " \
                                        f"{current_class[0].end_line}."
    elif len(current_class) > 1:
        multi_class_summary = " and ".join(
            [f"`{c.name}` starting at line {c.start_line} and ending at line {c.end_line}" for c in current_class])
        closest_method_class_summary += f" The classes defined in this snippet are {multi_class_summary}."

    last_method, current_method = CodeParser.get_summary_from_position(file_summary.methods, snippet_start_line,
                                                                       snippet_end_line)

    if last_method:
        closest_method_class_summary += f" The last method starting before this snippet is called " \
                                        f"`{last_method[-1].name}` which starts on line " \
                                        f"{last_method[-1].start_line} and ends at line {last_method[-1].end_line}."
    if len(current_method) == 1:
        closest_method_class_summary += f" The method defined in this snippet is called " \
                                        f"`{current_method[0].name}` starting at line " \
                                        f"{current_method[0].start_line} and ending at line " \
                                        f"{current_method[0].end_line}."
    elif len(current_method) > 1:
        multi_method_summary = " and ".join(
            [f"`{meth.name}` starting at line {meth.start_line} and ending at line {meth.end_line}" for meth in
             current_method])
        closest_method_class_summary += f" The methods defined in this snippet are {multi_method_summary}."

    return closest_method_class_summary

@staticmethod
@abstractmethod
Expand Down
68 changes: 0 additions & 68 deletions repogpt/parsers/python_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,74 +4,6 @@


class PythonParser(CodeParser):
@staticmethod
def get_summary_from_position(
        summary_positions: List[SummaryPosition],
        start_line: int,
        end_line: int
) -> Tuple[List[SummaryPosition], List[SummaryPosition]]:
    """Partition summary positions relative to a snippet's line range.

    NOTE(review): the early `break` below assumes `summary_positions` is
    sorted by start line — confirm callers provide sorted positions.

    Returns:
        (last_obj, current_obj): `last_obj` are positions that start before
        `start_line` but end on/after it; `current_obj` are positions whose
        start falls within [start_line, end_line].
    """

    last_obj = []
    current_obj = []

    # TODO: binary search-ify this
    for s_pos in summary_positions:
        # get last defined method before the snippet
        if s_pos.start_line < start_line and s_pos.end_line >= start_line:
            last_obj.append(s_pos)

        # get any method defined in this snippet
        if start_line <= s_pos.start_line <= end_line:
            current_obj.append(s_pos)

        # ignore everything past this snippet
        if s_pos.start_line > end_line:
            break
    return last_obj, current_obj

@staticmethod
def get_closest_method_class_in_snippet(
        file_summary: FileSummary,
        snippet_start_line: int,
        snippet_end_line: int
) -> str:
    """Describe, in prose, the classes and methods relevant to a snippet.

    Combines the class/method positions in `file_summary` with the snippet's
    line range and returns an English sentence (or empty string) naming the
    last class/method opened before the snippet and any defined inside it.
    """
    closest_method_class_summary = ""

    last_class, current_class = PythonParser.get_summary_from_position(file_summary.classes, snippet_start_line,
                                                                       snippet_end_line)

    if last_class:
        closest_method_class_summary += f" The last class defined before this snippet was called " \
                                        f"`{last_class[-1].name}` starting at line {last_class[-1].start_line} " \
                                        f"and ending at line {last_class[-1].end_line}."
    if len(current_class) == 1:
        # fix: space after the closing backtick so the class name and
        # "starting" no longer run together in the generated text
        closest_method_class_summary += f" The class defined in this snippet is called `{current_class[0].name}` " \
                                        f"starting at line {current_class[0].start_line} and ending at line " \
                                        f"{current_class[0].end_line}."
    elif len(current_class) > 1:
        multi_class_summary = " and ".join(
            [f"`{c.name}` starting at line {c.start_line} and ending at line {c.end_line}" for c in current_class])
        closest_method_class_summary += f" The classes defined in this snippet are {multi_class_summary}."

    last_method, current_method = PythonParser.get_summary_from_position(file_summary.methods, snippet_start_line,
                                                                         snippet_end_line)

    if last_method:
        closest_method_class_summary += f" The last method starting before this snippet is called " \
                                        f"`{last_method[-1].name}` which starts on line " \
                                        f"{last_method[-1].start_line} and ends at line {last_method[-1].end_line}."
    if len(current_method) == 1:
        closest_method_class_summary += f" The method defined in this snippet is called " \
                                        f"`{current_method[0].name}` starting at line " \
                                        f"{current_method[0].start_line} and ending at line " \
                                        f"{current_method[0].end_line}."
    elif len(current_method) > 1:
        multi_method_summary = " and ".join(
            [f"`{meth.name}` starting at line {meth.start_line} and ending at line {meth.end_line}" for meth in
             current_method])
        closest_method_class_summary += f" The methods defined in this snippet are {multi_method_summary}."

    return closest_method_class_summary

@staticmethod
def get_file_summary(code: str, file_name: str) -> FileSummary:
Expand Down
78 changes: 78 additions & 0 deletions repogpt/parsers/treesitter_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import shutil
import subprocess

from tree_sitter import Language, Parser

from repogpt.parsers.base import CodeParser, FileSummary, SummaryPosition


class TreeSitterParser(CodeParser):
    """CodeParser backed by tree-sitter grammars for several languages.

    On first use the official tree-sitter grammar repositories are cloned,
    compiled into shared libraries and loaded; the resulting `Language`
    objects are cached in the class-level `languages` dict so the expensive
    bootstrap happens at most once per process.
    """

    # grammars fetched from https://github.com/tree-sitter/tree-sitter-<name>
    LANGUAGE_NAMES = ["python", "java", "cpp", "go", "rust", "ruby", "php"]

    # Node types that declare a function/method or a class across the
    # grammars above (python: function_definition/class_definition,
    # C-family: function_definition/class_specifier, java/php:
    # method_declaration/class_declaration, go: function_declaration/
    # method_declaration, rust: function_item, ruby: method/class).
    # NOTE(review): check each grammar's node-types.json when adding one.
    FUNCTION_NODE_TYPES = {
        "function_definition", "function_declaration", "method_declaration",
        "function_item", "method",
    }
    CLASS_NODE_TYPES = {
        "class_specifier", "class_definition", "class_declaration", "class",
    }

    # Maps file extensions to the grammar used to parse them; C headers and
    # sources are approximated with the cpp grammar.
    _EXTENSION_TO_LANGUAGE = {
        ".py": "python",
        ".rs": "rust",
        ".go": "go",
        ".java": "java",
        ".cpp": "cpp",
        ".cc": "cpp",
        ".cxx": "cpp",
        ".c": "cpp",
        ".h": "cpp",
        ".hpp": "cpp",
        ".rb": "ruby",
        ".php": "php",
    }

    # language name -> tree_sitter.Language, filled by initialize_treesitter
    languages = {}
    # True once the shared objects have been built and loaded
    loaded = False

    @staticmethod
    def initialize_treesitter():
        """Clone, build and load every grammar in LANGUAGE_NAMES.

        Requires network access (git clone) and a C compiler on first run.
        Sets `loaded` so `get_file_summary` triggers the bootstrap only once.
        """
        for language in TreeSitterParser.LANGUAGE_NAMES:
            repo_dir = f"cache/tree-sitter-{language}"
            if not os.path.isdir(repo_dir):
                # argv list with shell=False: no shell quoting/injection risk
                subprocess.run(
                    ["git", "clone",
                     f"https://github.com/tree-sitter/tree-sitter-{language}",
                     repo_dir],
                    check=True)
        for language in TreeSitterParser.LANGUAGE_NAMES:
            built_so = f"cache/build/{language}.so"
            Language.build_library(built_so, [f"cache/tree-sitter-{language}"])
            shutil.copyfile(built_so, f"/tmp/{language}.so")
        TreeSitterParser.languages = {
            language: Language(f"/tmp/{language}.so", language)
            for language in TreeSitterParser.LANGUAGE_NAMES
        }
        TreeSitterParser.loaded = True

    @staticmethod
    def _definition_name(node, code_bytes: bytes) -> str:
        """Extract a definition's identifier from its AST node.

        Slices the UTF-8 source *bytes* (tree-sitter offsets are byte
        offsets), so non-ASCII files are handled correctly.
        """
        # most grammars expose the identifier under the 'name' field;
        # fall back to the original second-child heuristic otherwise
        name_node = node.child_by_field_name("name")
        if name_node is None and len(node.children) > 1:
            name_node = node.children[1]
        if name_node is None:
            name_node = node
        return code_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace")

    @staticmethod
    def _collect_definitions(node, code_bytes: bytes, file_summary: FileSummary) -> None:
        """Recursively record every class/method definition under `node`."""
        # +1: tree-sitter rows are 0-based, but the crawler's snippet
        # positions are 1-based, so convert at the boundary
        if node.type in TreeSitterParser.FUNCTION_NODE_TYPES:
            file_summary.methods.append(
                SummaryPosition(TreeSitterParser._definition_name(node, code_bytes),
                                node.start_point[0] + 1, node.end_point[0] + 1))
        elif node.type in TreeSitterParser.CLASS_NODE_TYPES:
            file_summary.classes.append(
                SummaryPosition(TreeSitterParser._definition_name(node, code_bytes),
                                node.start_point[0] + 1, node.end_point[0] + 1))
        for child in node.children:
            TreeSitterParser._collect_definitions(child, code_bytes, file_summary)

    @staticmethod
    def get_file_summary(code: str, file_name: str) -> FileSummary:
        """Return a FileSummary of the classes and methods defined in `code`.

        Line numbers are 1-based, matching the snippet line numbers computed
        by the crawler. Files with an unsupported extension yield an empty
        summary.
        """
        # download and build treesitter shared objects if not already done
        if not TreeSitterParser.loaded:
            TreeSitterParser.initialize_treesitter()

        file_summary = FileSummary()
        _, extension = os.path.splitext(file_name)

        language_name = TreeSitterParser._EXTENSION_TO_LANGUAGE.get(extension)
        if language_name is not None:
            parser = Parser()
            parser.set_language(TreeSitterParser.languages[language_name])
            code_bytes = bytes(code, "utf-8")
            tree = parser.parse(code_bytes)

            TreeSitterParser._collect_definitions(tree.root_node, code_bytes, file_summary)

            # traversal order is not sorted by line, so sort explicitly
            file_summary.methods.sort(key=lambda pos: pos.start_line)
            file_summary.classes.sort(key=lambda pos: pos.start_line)

        return file_summary
1 change: 0 additions & 1 deletion repogpt/qa/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def get_resp(self, query_str: str) -> str:
for chunk in similar_chunks:
print(Fore.RED +
f"{chunk.metadata['source']} - lines {chunk.metadata['starting_line']} - {chunk.metadata['ending_line']}")

qa_prompt = self.create_prompt(query_str, similar_chunks)
print("Computing response...")
return self.llm(qa_prompt)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ tqdm
deeplake
gpt4all
llama-cpp-python
colorama
colorama
tree_sitter
Loading

0 comments on commit e3e874e

Please sign in to comment.