-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update code to pass entire py file as an input
- Loading branch information
Showing
5 changed files
with
165 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,58 @@ | ||
import argparse | ||
from DocuFlow.file_traverser import run_doc_generator | ||
import os | ||
import sys | ||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | ||
|
||
def summarize_code(code_snippet, tokenizer, model): | ||
inputs = tokenizer.encode(code_snippet, return_tensors='pt', truncation=True, max_length=512) | ||
outputs = model.generate( | ||
inputs, | ||
max_length=150, | ||
num_beams=4, | ||
early_stopping=True, | ||
no_repeat_ngram_size=2 | ||
) | ||
summary = tokenizer.decode(outputs[0], skip_special_tokens=True) | ||
return summary.strip() | ||
|
||
def generate_description(file_path, model_name='Salesforce/codet5-small'): | ||
# Load the tokenizer and model | ||
tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | ||
|
||
# Read the entire code file | ||
with open(file_path, 'r', encoding='utf-8') as file: | ||
code = file.read() | ||
|
||
# Summarize the code | ||
print("Summarizing the code...") | ||
module_summary = summarize_code(code, tokenizer, model) | ||
|
||
# Prepare markdown content | ||
md_content = f"# Module Summary\n\n{module_summary}\n" | ||
|
||
return md_content | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Generate Markdown documentation from Python code.") | ||
parser.add_argument('project_path', help="Path to the Python project") | ||
parser.add_argument('--output', default='docs', help="Output directory for documentation") | ||
|
||
parser = argparse.ArgumentParser(description='Generate markdown description from Python file.') | ||
parser.add_argument('file_path', help='Path to the Python (.py) file.') | ||
parser.add_argument('--model-name', default='Salesforce/codet5-small', help='Name of the open-source LLM to use.') | ||
args = parser.parse_args() | ||
run_doc_generator(args.project_path, args.output) | ||
|
||
if __name__ == "__main__": | ||
file_path = args.file_path | ||
model_name = args.model_name | ||
|
||
if not os.path.isfile(file_path): | ||
print(f"File {file_path} does not exist.") | ||
sys.exit(1) | ||
|
||
description = generate_description(file_path, model_name) | ||
|
||
md_file_path = os.path.splitext(file_path)[0] + '.md' | ||
with open(md_file_path, 'w', encoding='utf-8') as md_file: | ||
md_file.write(description) | ||
|
||
print(f"Markdown description generated and saved to {md_file_path}") | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import requests | ||
from abc import ABC, abstractmethod | ||
|
||
class DataFetcher(ABC): | ||
"""Abstract base class for data fetchers.""" | ||
|
||
@abstractmethod | ||
def fetch(self, source): | ||
"""Fetch data from a source.""" | ||
pass | ||
|
||
class HTTPDataFetcher(DataFetcher): | ||
"""Fetch data over HTTP.""" | ||
|
||
def fetch(self, url): | ||
"""Fetch data from a URL.""" | ||
try: | ||
response = requests.get(url) | ||
response.raise_for_status() | ||
return response.text | ||
except requests.RequestException as e: | ||
print(f"Error fetching data from {url}: {e}") | ||
return None | ||
|
||
class FileDataFetcher(DataFetcher): | ||
"""Fetch data from a local file.""" | ||
|
||
def fetch(self, file_path): | ||
"""Fetch data from a file.""" | ||
try: | ||
with open(file_path, 'r') as file: | ||
return file.read() | ||
except IOError as e: | ||
print(f"Error reading file {file_path}: {e}") | ||
return None | ||
|
||
class DataProcessor: | ||
"""Process data fetched from a source.""" | ||
|
||
def __init__(self, fetcher): | ||
self.fetcher = fetcher | ||
|
||
def process(self, source): | ||
"""Fetch and process data from the source.""" | ||
data = self.fetcher.fetch(source) | ||
if data: | ||
print(f"Processing data from {source}") | ||
# Add processing logic here | ||
return data | ||
else: | ||
print(f"No data to process from {source}") | ||
return None | ||
|
||
def main(): | ||
url = 'https://api.github.com' | ||
file_path = 'data.txt' | ||
|
||
http_fetcher = HTTPDataFetcher() | ||
file_fetcher = FileDataFetcher() | ||
|
||
http_processor = DataProcessor(http_fetcher) | ||
file_processor = DataProcessor(file_fetcher) | ||
|
||
http_data = http_processor.process(url) | ||
file_data = file_processor.process(file_path) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import json | ||
|
||
class Person: | ||
"""A class representing a person.""" | ||
|
||
def __init__(self, name, age): | ||
"""Initialize the person's name and age.""" | ||
self.name = name | ||
self.age = age | ||
|
||
def to_json(self): | ||
"""Convert the person's data to JSON.""" | ||
try: | ||
return json.dumps({'name': self.name, 'age': self.age}) | ||
except TypeError as e: | ||
print(f"Error converting to JSON: {e}") | ||
return None | ||
|
||
def load_people(file_path): | ||
"""Load a list of people from a JSON file.""" | ||
try: | ||
with open(file_path, 'r') as file: | ||
people_data = json.load(file) | ||
return [Person(**data) for data in people_data] | ||
except (IOError, json.JSONDecodeError) as e: | ||
print(f"Error loading people from {file_path}: {e}") | ||
return [] | ||
|
||
if __name__ == '__main__': | ||
alice = Person('Alice', 30) | ||
print(alice.to_json()) | ||
|
||
people = load_people('people.json') | ||
for person in people: | ||
print(f"{person.name} is {person.age} years old.") |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import math | ||
|
||
def calculate_area(radius): | ||
"""Calculate the area of a circle given its radius.""" | ||
return math.pi * radius ** 2 | ||
|
||
if __name__ == '__main__': | ||
r = 5 | ||
area = calculate_area(r) | ||
print(f"The area of a circle with radius {r} is {area:.2f}") |