Skip to content

Commit

Permalink
Move weighted tsvector to indexing step
Browse files Browse the repository at this point in the history
  • Loading branch information
Tschuppi81 committed Dec 23, 2024
1 parent cad0655 commit 1f142d9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 42 deletions.
30 changes: 0 additions & 30 deletions src/onegov/org/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from elasticsearch_dsl.query import MatchPhrase
from elasticsearch_dsl.query import MultiMatch
from functools import cached_property
from itertools import chain, repeat
from sedate import utcnow
from sqlalchemy import func
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -272,35 +271,6 @@ def load_batch_results(self) -> list[Any]:

return sorted_events + other

Check warning on line 272 in src/onegov/org/models/search.py

View check run for this annotation

Codecov / codecov/patch

src/onegov/org/models/search.py#L272

Added line #L272 was not covered by tests

def _create_weighted_vector(
    self,
    model: Any,
    language: str = 'simple'
) -> Any:
    """ Builds a combined, weighted tsvector expression for *model*.

    The first property is weighted 'A', every subsequent one 'B'.
    Properties prefixed with 'es_' are skipped (they still consume a
    weight from the sequence, matching the zip-then-filter order).
    Falls back to an empty tsvector when no property qualifies.
    """
    weights = chain('A', repeat('B'))
    vectors = []
    for field, weight in zip(model.es_properties.keys(), weights):
        if field.startswith('es_'):  # TODO: rename to fts_
            continue
        vectors.append(
            func.setweight(
                func.to_tsvector(
                    language,
                    getattr(model.fts_idx_data, field, '')),
                weight
            )
        )

    # no searchable fields: return an empty tsvector expression
    if not vectors:
        return func.to_tsvector(language, '')

    # concatenate all weighted vectors with the tsvector '||' operator
    combined = vectors[0]
    for extra in vectors[1:]:
        combined = combined.op('||')(extra)
    return combined

def filter_user_level(self, model: Any, query: Any) -> Any:
""" Filters search content according user level """

Expand Down
45 changes: 33 additions & 12 deletions src/onegov/search/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from elasticsearch.exceptions import NotFoundError
from elasticsearch.helpers import streaming_bulk
from langdetect.lang_detect_exception import LangDetectException
from itertools import groupby
from itertools import groupby, chain, repeat
from operator import itemgetter
from queue import Queue, Empty, Full

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from unidecode import unidecode

Expand Down Expand Up @@ -401,7 +403,8 @@ def index(
:param session: Supply an active session
:return: True if the indexing was successful, False otherwise
"""
content = []
combined_vector = None
params = []

if not isinstance(tasks, list):
tasks = [tasks]
Expand All @@ -419,8 +422,31 @@ def index(
data[k] = task['properties'][k]
_id = task['id']

content.append(
{'language': language, 'data': data, '_id': _id})
# prefix all fields with '_' to avoid conflicts with
# reserved names in `bindparam` function below
params.append(
{'language': language, 'data': data, '_id': _id, **{
f'_{k}': v for k, v in data.items()
}})

weighted_vector = [
func.setweight(
func.to_tsvector(
sqlalchemy.bindparam('language',
type_=sqlalchemy.String),
sqlalchemy.bindparam(f'_{field}',
type_=sqlalchemy.String)
),
weight
)
for field, weight in zip(
task['properties'].keys(),
chain('A', repeat('B')))
if not field.startswith('es_') # TODO: rename to fts_
]
combined_vector = weighted_vector[0]
for vector in weighted_vector[1:]:
combined_vector = combined_vector.op('||')(vector)

schema = tasks[0]['schema']
tablename = tasks[0]['tablename']
Expand All @@ -433,17 +459,12 @@ def index(
sqlalchemy.column(self.TEXT_SEARCH_DATA_COLUMN_NAME),
schema=schema # type: ignore
)
tsvector_expr = sqlalchemy.text(
'to_tsvector(:language, :data)').bindparams(
sqlalchemy.bindparam('language', type_=sqlalchemy.String),
sqlalchemy.bindparam('data', type_=JSONB)
)

stmt = (
sqlalchemy.update(table)
.where(id_col == sqlalchemy.bindparam('_id'))
.values({
self.TEXT_SEARCH_COLUMN_NAME: tsvector_expr,
self.TEXT_SEARCH_COLUMN_NAME: combined_vector,
self.TEXT_SEARCH_DATA_COLUMN_NAME:
sqlalchemy.bindparam('data', type_=JSONB)
})
Expand All @@ -452,11 +473,11 @@ def index(
if session is None:
connection = self.engine.connect()
with connection.begin():
connection.execute(stmt, content)
connection.execute(stmt, params)
else:
# use a savepoint instead
with session.begin_nested():
session.execute(stmt, content)
session.execute(stmt, params)
except Exception as ex:
index_log.error(f"Error '{ex}' indexing schema "
f'{tasks[0]["schema"]} table '
Expand Down
1 change: 1 addition & 0 deletions src/onegov/search/integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import certifi
import morepath
import ssl
Expand Down

0 comments on commit 1f142d9

Please sign in to comment.