
Commit

feat: use pandas for result management
williamfzc committed Nov 23, 2023
1 parent e70b1b6 commit fcb65f7
Showing 6 changed files with 195 additions and 33 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
@@ -33,3 +33,4 @@ jobs:
         poetry install --all-extras
         poetry run python3 examples/write.py
         poetry run python3 examples/read.py
+        poetry run pytest
150 changes: 149 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
@@ -9,6 +9,9 @@ packages = [
     { include = "srctag" }
 ]
 
+[tool.poetry.group.dev.dependencies]
+pytest = "^7.4.3"
+
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
@@ -21,6 +24,7 @@ pydantic-settings = "*"
 pydantic = "*"
 tqdm = "*"
 loguru = "^0.7.2"
+pandas = "^2.0.3"
 
 # actually srctag still requires `sentence_transformers` here
 # but pytorch is a large dep which I don't want to manage here
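
The new pandas dependency backs the reworked TagResult in srctag/tagger.py below. As a quick illustration (the file paths, tags, and scores here are made up), the nested dict built in Tagger.tag() maps each file to per-tag scores and is turned into a DataFrame with from_dict:

import pandas as pd

# hypothetical nested mapping: file path -> {tag -> normalized score},
# mirroring the `ret` dict assembled in Tagger.tag()
ret = {
    "srctag/storage.py": {"embedding": 0.91, "search": 0.40},
    "srctag/tagger.py": {"embedding": 0.73, "search": 0.88},
}

# rows = file paths, columns = tags, as in pd.DataFrame.from_dict(ret, orient="index")
scores_df = pd.DataFrame.from_dict(ret, orient="index")
print(scores_df)
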
4 changes: 4 additions & 0 deletions srctag/storage.py
@@ -42,6 +42,10 @@ def init_chroma(self):
         )
 
     def embed_file(self, file: FileContext):
+        if not file.commits:
+            logger.warning(f"no related commits found: {file.name}")
+            return
+
         self.init_chroma()
         sentences = [each.message.split(os.linesep)[0] for each in file.commits]
 
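
With this guard, files that have no related commits are skipped with a warning instead of being embedded from an empty sentence list. A minimal sketch of the calling side, based on the previous per-file test flow (the repo path is illustrative):

from srctag.collector import Collector
from srctag.storage import Storage

collector = Collector()
collector.config.repo_root = "."  # illustrative: point at any git repo
ctx = collector.collect_metadata()

storage = Storage()
for each_file in ctx.files.values():
    # files without commit history now log a warning and are skipped
    storage.embed_file(each_file)
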
40 changes: 17 additions & 23 deletions srctag/tagger.py
@@ -1,35 +1,27 @@
-import csv
 import typing
 from collections import OrderedDict
 
+import pandas as pd
 from chromadb import QueryResult, Metadata
-from pydantic import BaseModel
+from pandas import Index, DataFrame
 from pydantic_settings import BaseSettings
 
 from srctag.storage import Storage
 
-# tag name -> distance score
-SingleTagResult = typing.Dict[str, float]
-
 
-class TagResult(BaseModel):
-    # todo: need some other ways for querying flexibly
-    files: typing.Dict[str, SingleTagResult] = dict()
+class TagResult(object):
+    def __init__(self, scores_df: pd.DataFrame):
+        self.scores_df = scores_df
 
     def export_csv(self, path: str = "srctag-output.csv") -> None:
-        file_list = self.files.keys()
-        col_list = set().union(*[d.keys() for d in self.files.values()])
-        with open(path, "w", newline="", encoding="utf-8-sig") as file:
-            writer = csv.writer(file)
+        self.scores_df.to_csv(path)
 
-            header = [""] + list(col_list)
-            writer.writerow(header)
+    def tags(self) -> Index:
+        return self.scores_df.columns
 
-            for each_file in file_list:
-                row = [each_file] + [
-                    self.files[each_file].get(subkey, "-1") for subkey in col_list
-                ]
-                writer.writerow(row)
+    def top_n(self, path: str, n: int) -> DataFrame:
+        row = self.scores_df.loc[path]
+        return row.nlargest(n)
 
 
 class TaggerConfig(BaseSettings):
@@ -59,7 +51,8 @@ def tag(self, storage: Storage) -> TagResult:
                 n_results=n_results,
                 include=["metadatas", "distances"],
             )
-            files: typing.List[Metadata] = query_result["metadatas"][0]
+
+            metadatas: typing.List[Metadata] = query_result["metadatas"][0]
             distances: typing.List[float] = query_result["distances"][0]
 
             minimum = min(distances)
@@ -72,12 +65,13 @@ def tag(self, storage: Storage) -> TagResult:
                 1 - ((x - minimum) / (maximum - minimum)) for x in distances
             ]
 
-            for each_file, each_score in zip(files, normalized_scores):
-                each_file_name = each_file["source"]
+            for each_metadata, each_score in zip(metadatas, normalized_scores):
+                each_file_name = each_metadata["source"]
                 if each_file_name not in ret:
                     ret[each_file_name] = OrderedDict()
                 ret[each_file_name][each_tag] = each_score
             # END file loop
         # END tag loop
 
-        return TagResult(files=ret)
+        scores_df = pd.DataFrame.from_dict(ret, orient="index")
+        return TagResult(scores_df=scores_df)
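
With the scores held in a DataFrame (rows are file paths, columns are tags, values are min-max normalized so 1.0 means closest), downstream querying becomes plain pandas. A rough usage sketch built from the pieces in this commit (the repo path, tag list, and queried path are illustrative):

from srctag.collector import Collector
from srctag.storage import Storage
from srctag.tagger import Tagger

collector = Collector()
collector.config.repo_root = "."  # illustrative repo path
ctx = collector.collect_metadata()

storage = Storage()
storage.embed_ctx(ctx)

tagger = Tagger()
tagger.config.tags = ["embedding", "search"]
tag_result = tagger.tag(storage)

print(tag_result.tags())  # pandas Index of tag names
# highest-scoring tag(s) for one file, via Series.nlargest;
# the path key must match an indexed file, this one is illustrative
print(tag_result.top_n("srctag/storage.py", 1))
tag_result.export_csv("srctag-output.csv")  # CSV export now delegates to DataFrame.to_csv
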
29 changes: 20 additions & 9 deletions tests/test_api.py
@@ -1,13 +1,24 @@
-from srctag import CollectorLayer, Storage
+import os
 
-def test_abc():
-    collector = CollectorLayer()
-    collector.config.repo_root = "."
-    collector.config.max_depth_limit = 1
+from loguru import logger
+
+from srctag.collector import Collector
+from srctag.storage import Storage
+from srctag.tagger import Tagger
+
+
+def test_api():
+    collector = Collector()
+    collector.config.repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
     ctx = collector.collect_metadata()
     storage = Storage()
-    for each_file in ctx.files.values():
-        storage.embed_file(each_file)
+    storage.embed_ctx(ctx)
 
-    result = storage.chromadb_collection.query(query_texts=["docs"])
-    print(result)
+    tagger = Tagger()
+    tagger.config.tags = [
+        "embedding",
+        "search",
+    ]
+    tag_result = tagger.tag(storage)
+    logger.info(f"tags: {tag_result.tags().array}")
+    assert tag_result.tags().array
