Skip to content

Commit

Permalink
chore: temp
Browse files: browse the repository at this point in the history
  • Loading branch information
williamfzc committed Nov 30, 2023
1 parent c7c606a commit a3b0f83
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 5 deletions.
4 changes: 3 additions & 1 deletion srctag/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ def prepare():
@click.option("--output-path", default="", help="Output file path for CSV")
@click.option("--file-level", default=FileLevelEnum.FILE.value, help="Scan file level, FILE or DIR, default to FILE")
@click.option("--st-model", default="", help="Sentence Transformer Model")
def tag(repo_root, max_depth_limit, include_regex, tags_file, output_path, file_level, st_model):
@click.option("--commit-include-regex", default="", help="Commit message include regex pattern")
def tag(repo_root, max_depth_limit, include_regex, tags_file, output_path, file_level, st_model, commit_include_regex):
""" tag your repo """
collector = Collector()
collector.config.repo_root = repo_root
collector.config.max_depth_limit = max_depth_limit
collector.config.include_regex = include_regex
collector.config.file_level = file_level
collector.config.commit_include_regex = commit_include_regex

ctx = collector.collect_metadata()
storage = Storage()
Expand Down
2 changes: 1 addition & 1 deletion srctag/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def init_chroma(self):
model_name=self.config.st_model_name
),
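# distance range: [0, 1] was assumed for "cosine"; NOTE(review): with "l2" the distance is presumably unbounded above — confirm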
metadata={"hnsw:space": "cosine"}
metadata={"hnsw:space": "l2"}
)

def process_file_ctx(self, file: FileContext, collection: Collection):
Expand Down
9 changes: 6 additions & 3 deletions srctag/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ def tag(self, storage: Storage) -> TagResult:

metadatas: typing.List[Metadata] = query_result["metadatas"][0]
# https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/vectorstores/chroma.py
# https://stats.stackexchange.com/questions/158279/how-i-can-convert-distance-euclidean-to-similarity-score
distances: typing.List[float] = query_result["distances"][0]
normalized_scores = [1 - each for each in distances]
normalized_scores = [
1.0 / (1.0 + x) for x in distances
]

for each_metadata, each_score in zip(metadatas, normalized_scores):
each_file_name = each_metadata[MetadataConstant.KEY_SOURCE]
Expand All @@ -100,8 +103,8 @@ def tag(self, storage: Storage) -> TagResult:
each_file_tag_result[each_tag] = each_score
else:
# has been touched by other commits
# merge these scores
each_file_tag_result[each_tag] += each_score
# keep the closest one
each_file_tag_result[each_tag] = max(each_score, each_file_tag_result[each_tag])
# END tag_results

scores_df = pd.DataFrame.from_dict(ret, orient="index")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def test_query(setup_tagger):

tags_series = tag_result.tags_by_file("examples/write.py")
assert len(tags_series) == len(all_tags)
tags_series = tags_series[tags_series > 0.5][:5]
assert len(tags_series) == 1
for k, v in tags_series.items():
logger.info(f"tag: {k}, score: {v}")

Expand Down

0 comments on commit a3b0f83

Please sign in to comment.