Skip to content
1 change: 0 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
ENVIRONMENT=development
DEBUG=TRUE
# Disables the lookup of language maps on startup (speeds up dev boot)
NO_LM=TRUE
SERVICE_NAME=oclapi2

# -----------------------------------------------------------------------------
Expand Down
138 changes: 90 additions & 48 deletions core/common/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import time
import urllib

import requests
from cid.locals import get_cid
from django.conf import settings
from django.db.models import Case, When, IntegerField
from elasticsearch_dsl import FacetedSearch, Q
from pydash import compact, get, has, set_
from sentence_transformers import CrossEncoder

from core.common import ERRBIT_LOGGER
from core.common.constants import ES_REQUEST_TIMEOUT
from core.common.utils import is_url_encoded_string

Expand Down Expand Up @@ -336,42 +338,54 @@ def __get_response(self, exact_count=True, load_fields=False):
return self._dsl_search, None, total


class VectorEmbed:
_LOCAL_MODELS = {}
SERVICE_TIMEOUT = 3

def __init__(self, model_name=None):
self.model_name = model_name or settings.LM_MODEL_NAME

def embed(self, txt):
if settings.ENV == 'ci':
return None
if settings.EMBEDDING_SERVICE_URL:
return self._get_embedding_from_service(txt)
return self._get_embedding_locally(txt)

def _get_embedding_from_service(self, txt):
try:
response = requests.post(
f'{settings.EMBEDDING_SERVICE_URL}/embeddings',
headers={'Authorization': f'Bearer {settings.INFINITY_API_KEY}'},
json={'model': self.model_name, 'input': str(txt)},
timeout=self.SERVICE_TIMEOUT
)
response.raise_for_status()
return response.json()['data'][0]['embedding']
except Exception as ex: # pylint: disable=broad-except
ERRBIT_LOGGER.log(ex)
return self._get_embedding_locally(txt)

def _get_embedding_locally(self, txt):
try:
model = self._LOCAL_MODELS.get(self.model_name)
if model is None:
from sentence_transformers import SentenceTransformer
model = self._LOCAL_MODELS[self.model_name] = SentenceTransformer(self.model_name)
return model.encode(str(txt)).tolist()
except Exception as ex: # pylint: disable=broad-except
ERRBIT_LOGGER.log(ex)
return None


class Reranker:
ENCODERS = [
# Best and Fastest overall lightweight medical reranker
# Size: ~110M
# Speed: similar to MiniLM CrossEncoder
# Training: includes clinical, medical, question-answering datasets
# Output: positive similarity scores (not raw logits!)
# 0.6B params
# https://huggingface.co/BAAI/bge-reranker-v2-m3
"BAAI/bge-reranker-v2-m3",

# Model: jinhybr/OA-MedBERT-cross-encoder or similar
# Size: ~110M
# Domain: PubMed abstracts, biomedical QA
# Type: binary classifier (logits)
# Not huggin face model -- ???
# "jinhybr/OA-MedBERT-cross-encoder",

# Model: microsoft/BioLinkBERT-base
# Type: CrossEncoder
# Size: ~120M
# Domain: UMLS, PubMed, MeSH, SNOMED (closest to OCL)
# Not huggin face model -- doesn't work with sentence_transformers
# "microsoft/BioLinkBERT-base",

# 22.7M params
# https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2
# doesn't work with logits, so not between 0-1
"cross-encoder/ms-marco-MiniLM-L-6-v2",
]
SCORE_KEY = 'search_rerank_score'
MISSING_SCORE = -1000000.0
SERVICE_TIMEOUT = 60
_LOCAL_MODELS = {}

def __init__(self, model_name=None):
self.model_name = model_name
self.encoder = self._get_encoder(self.model_name)
self.model_name = model_name or self.default_model

def rerank( # pylint: disable=too-many-arguments
self, hits, txt, name_key='name', source_attr=None, should_convert_source_to_dict=True,
Expand All @@ -393,18 +407,56 @@ def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_sourc
return scores_full

docs = [get(self._get_source(hit, source_attr, should_convert_source_to_dict), name_key) for hit in hits]
valid = []
valid_docs = []
for i, d in enumerate(docs):
if isinstance(d, str) and d.strip():
valid.append((i, d.strip()))
if not valid:
valid_docs.append((i, d.strip()))
if not valid_docs:
return scores_full
scores = self.encoder.predict([(txt, d) for _, d in valid])
for (i, _), s in zip(valid, scores):

scores = self._get_rerank_scores(txt, valid_docs)
for (i, _), s in zip(valid_docs, scores):
scores_full[i] = float(s)

return scores_full

def _get_rerank_scores(self, txt, docs):
if settings.ENV == 'ci' or not self.model_name:
return [self.MISSING_SCORE] * len(docs)
if settings.EMBEDDING_SERVICE_URL:
return self._get_rerank_scores_from_service(txt, docs)
return self._get_rerank_scores_locally(txt, docs)

def _get_rerank_scores_from_service(self, txt, docs):
try:
response = requests.post(
f'{settings.EMBEDDING_SERVICE_URL}/rerank',
headers={'Authorization': f'Bearer {settings.INFINITY_API_KEY}'},
json={
'model': self.model_name or self.default_model,
'query': txt,
'documents': [d for _, d in docs],
},
timeout=self.SERVICE_TIMEOUT
)
response.raise_for_status()
results = response.json()['results']
# results is a list of {index, relevance_score} sorted by index
return [r['relevance_score'] for r in sorted(results, key=lambda r: r['index'])]
except Exception as ex: # pylint: disable=broad-except
ERRBIT_LOGGER.log(ex)
return self._get_rerank_scores_locally(txt, docs)

def _get_rerank_scores_locally(self, txt, docs):
try:
encoder = self._LOCAL_MODELS.get(self.model_name)
if encoder is None:
encoder = self._LOCAL_MODELS[self.model_name] = self._get_encoder()
return encoder.predict([(txt, d) for _, d in docs])
except Exception as ex: # pylint: disable=broad-except
ERRBIT_LOGGER.log(ex)
return [self.MISSING_SCORE] * len(docs)

def _assign_score(self, hits, scores, score_key, order_results):
score_key = score_key or self.SCORE_KEY
key_to_set = score_key
Expand All @@ -420,18 +472,8 @@ def _assign_score(self, hits, scores, score_key, order_results):
def _order(hits, key_to_order):
return sorted(hits, key=lambda hit: get(hit, key_to_order), reverse=True)

def _get_encoder(self, model_name):
if model_name and model_name != self.default_model:
return self._load_encoder(model_name)
return self._load_default_encoder()

@staticmethod
def _load_encoder(model_name):
return CrossEncoder(model_name, device="cpu", max_length=128)

@staticmethod
def _load_default_encoder():
return settings.ENCODER
def _get_encoder(self):
return CrossEncoder(self.model_name, device="cpu", max_length=128)

@staticmethod
def _get_source(data, source_attr, should_convert_source_to_dict):
Expand Down
Loading