Skip to content

Commit

Permalink
chore: more API for query
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Nov 23, 2023
1 parent fcb65f7 commit 3a093cf
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
3 changes: 3 additions & 0 deletions examples/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@
tagger.config.tags = ["storage"]
tag_dict = tagger.tag(storage)
print(tag_dict)
tag_dict.export_csv()
topics = tag_dict.top_n_tags("srctag/storage.py", 5)
print(topics)
2 changes: 1 addition & 1 deletion examples/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

collector = Collector()
collector.config.repo_root = os.path.dirname(os.path.dirname(__file__))
collector.config.file_level = FileLevelEnum.DIR
collector.config.file_level = FileLevelEnum.FILE
collector.config.max_depth_limit = 16

ctx = collector.collect_metadata()
Expand Down
16 changes: 11 additions & 5 deletions srctag/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import pandas as pd
from chromadb import QueryResult, Metadata
from pandas import Index, DataFrame
from pandas import Index
from pydantic_settings import BaseSettings
from tqdm import tqdm
from loguru import logger

from srctag.storage import Storage

Expand All @@ -19,9 +21,11 @@ def export_csv(self, path: str = "srctag-output.csv") -> None:
def tags(self) -> Index:
return self.scores_df.columns

def top_n(self, path: str, n: int) -> DataFrame:
row = self.scores_df.loc[path]
return row.nlargest(n)
def top_n_tags(self, file_name, n) -> typing.List[str]:
return self.scores_df.loc[file_name].nlargest(n).index.tolist()

def top_n_files(self, tag_name, n) -> typing.List[str]:
return self.scores_df.nlargest(n, tag_name).index.tolist()


class TaggerConfig(BaseSettings):
Expand All @@ -44,8 +48,9 @@ def tag(self, storage: Storage) -> TagResult:
file_count = storage.chromadb_collection.count()
n_results = int(file_count * self.config.n_percent)

logger.info(f"start tagging source files ...")
ret = dict()
for each_tag in self.config.tags:
for each_tag in tqdm(self.config.tags):
query_result: QueryResult = storage.chromadb_collection.query(
query_texts=each_tag,
n_results=n_results,
Expand Down Expand Up @@ -73,5 +78,6 @@ def tag(self, storage: Storage) -> TagResult:
# END file loop
# END tag loop

logger.info(f"tag finished")
scores_df = pd.DataFrame.from_dict(ret, orient="index")
return TagResult(scores_df=scores_df)
12 changes: 10 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ def test_api():
tagger.config.tags = [
"embedding",
"search",
"tag",
"test",
"example"
]
tag_result = tagger.tag(storage)
logger.info(f"tags: {tag_result.tags().array}")
assert tag_result.tags().array

assert tag_result.top_n_tags("srctag/storage.py", 1)
assert tag_result.top_n_files("embedding", 1)

# result check
assert tag_result.top_n_tags("srctag/storage.py", 1)[0] == "embedding"
assert tag_result.top_n_tags("examples/read.py", 1)[0] == "example"

0 comments on commit 3a093cf

Please sign in to comment.