Skip to content

Commit

Permalink
Merge pull request #39 from egokick/negative-query-text
Browse files Browse the repository at this point in the history
negative text query
  • Loading branch information
deepfates authored Oct 3, 2023
2 parents 1cd98e9 + 803433d commit 761ba17
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
11 changes: 8 additions & 3 deletions memery/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(self, root: str = '.'):
self.db = None
self.model = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using {self.device} for computation.")

def index_flow(self, root: str, num_workers=0) -> tuple[str, str]:
'''Indexes images in path, returns the location of save files'''

Expand Down Expand Up @@ -75,13 +76,14 @@ def index_flow(self, root: str, num_workers=0) -> tuple[str, str]:

return(save_paths)

def query_flow(self, root: str, query: str=None, image_query: str=None, reindex: bool=False) -> list[str]:
def query_flow(self, root: str, query: str=None, negative_query: str=None, image_query: str=None, reindex: bool=False) -> list[str]:
'''
Indexes a folder and returns file paths ranked by query.
Parameters:
path (str): Folder to search
query (str): Search query text
query (str): Positive search query text
negative_query (str): Negative search query text
image_query (Tensor): Search query image(s)
reindex (bool): Reindex the folder if True
Returns:
Expand Down Expand Up @@ -123,6 +125,9 @@ def query_flow(self, root: str, query: str=None, image_query: str=None, reindex:
query_vec = text_vec + image_vec
elif query:
query_vec = encoder.text_encoder(query, device, model)
if negative_query:
negative_query_vec = encoder.text_encoder(negative_query, self.device, model)
query_vec = query_vec - negative_query_vec # Subtract negative query vector from positive query vector
elif image_query:
query_vec = encoder.image_query_encoder(img, device, model)
else:
Expand Down
7 changes: 4 additions & 3 deletions memery/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def parse_args(args: list[str]):
search_l, search_r = st.sidebar.columns([3,1])
with search_l:
text_query = st.text_input(label='Text query', value='')
negative_text_query = st.text_input(label='Negative Text query', value='')
with search_r:
st.title("")
search_button = st.button(label="Search", key="search_button")
Expand Down Expand Up @@ -93,15 +94,15 @@ def clear_cache(root, logbox):
print("Cleaned database and index files")

# Runs a search
def search(root, text_query, image_query, image_display_zone, skipped_files_box, num_images, captions_on, sizes, size_choice):
def search(root, text_query, negative_text_query, image_query, image_display_zone, skipped_files_box, num_images, captions_on, sizes, size_choice):
if not Path(path).exists():
with logbox:
with st_stdout('warning'):
print(f'{path} does not exist!')
return
with logbox:
with st_stdout('info'):
ranked = memery.query_flow(root, text_query, image_query)
ranked = memery.query_flow(root, text_query, negative_text_query, image_query) # Modified line
ims_to_display = {}
size = sizes[size_choice]
for o in ranked[:num_images]:
Expand Down Expand Up @@ -158,5 +159,5 @@ def st_stderr(dst):
elif do_index:
index(logbox, path, num_workers)
elif search_button or text_query or image_query:
search(path, text_query, image_query, image_display_zone, skipped_files_box, num_images, captions_on, sizes, size_choice)
search(path, text_query, negative_text_query, image_query, image_display_zone, skipped_files_box, num_images, captions_on, sizes, size_choice) # Modified line

0 comments on commit 761ba17

Please sign in to comment.