import logging
import time
import warnings

import elastic_transport
import urllib3
from elasticsearch import Elasticsearch, NotFoundError, helpers

from lib.flask import cache

# Module logger: a child of the 'SERVER' logger (full name 'SERVER.ELASTIC').
_logger = logging.getLogger('SERVER').getChild('ELASTIC')

# SentenceTransformer instance used by ElasticClient.natural_query; it is loaded on first use.
# Deliberately left un-annotated: a real type hint would force importing the heavy ML libraries
# at module import time.
_MODEL = None

# Silence urllib3's InsecureRequestWarning for the whole process; the Elasticsearch client in
# ElasticClient.initialise is created with verify_certs=False.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class ElasticClient:
    """Shared, class-level wrapper around one Elasticsearch client and one index.

    All state lives on the class itself. ``initialise`` swaps the placeholder client below for
    one connected to the configured host. Every classmethod that talks to Elasticsearch then
    operates on ``_elastic_index`` through that client.
    """

    # Placeholder client (localhost) until initialise() replaces it.
    _Es = Elasticsearch(['https://localhost:9200'])
    _elastic_host = ''
    _elastic_index = ''
    _api_key = ''

    # insert_json: number of attempts on a connection timeout, and the pause (seconds) between them.
    _INSERT_ATTEMPTS = 3
    _INSERT_RETRY_DELAY = 10

    @property
    def client(self):
        """The shared Elasticsearch client (read-only)."""
        return self._Es

    @client.setter
    def client(self, client):
        # The client is shared class state; it is only replaced through initialise().
        raise AttributeError('ElasticClient.client is read-only; use ElasticClient.initialise()')

    @classmethod
    def initialise(cls, elastic_host: str, elastic_index: str, api_key: str):
        """Bind the shared client to ``elastic_host``/``elastic_index`` and check the connection.

        Certificate verification is disabled (``verify_certs=False``) and warnings raised while
        connecting are suppressed.

        Args:
            elastic_host: URL of the Elasticsearch node.
            elastic_index: index used by every other method.
            api_key: API key passed to the client.

        Raises:
            Whatever the client raises if the ``info()`` call fails.
        """
        cls._elastic_host = elastic_host
        cls._elastic_index = elastic_index
        cls._api_key = api_key

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            cls._Es = Elasticsearch(
                [cls._elastic_host],
                api_key=cls._api_key,
                verify_certs=False
            )
            # The f-string is evaluated eagerly, so info() runs regardless of the log level.
            # This makes the line double as a connectivity check.
            _logger.debug(f'Connected to Elastic: {cls._Es.info()}')

    @classmethod
    def create_index(cls, template: dict = None):
        """Create the configured index with ``template`` as the request body, unless it exists.

        Args:
            template: index settings/mappings body; ``None`` sends no body.
        """
        with warnings.catch_warnings():
            # Suppress client warnings, as the other methods do.
            warnings.simplefilter('ignore')
            if not cls._Es.indices.exists(index=cls._elastic_index):
                cls._Es.indices.create(index=cls._elastic_index, body=template)

    @classmethod
    def insert_json(cls, json_data: dict, item_id: str = None):
        """Index ``json_data`` into the configured index, retrying on connection timeouts.

        Up to ``_INSERT_ATTEMPTS`` attempts are made, pausing ``_INSERT_RETRY_DELAY`` seconds
        between them.

        Args:
            json_data: document body.
            item_id: document id passed to ``index()``; may be ``None``.

        Returns:
            The client's index response, or ``None`` if every attempt timed out.

        Raises:
            Any error other than ``elastic_transport.ConnectionTimeout``, unchanged.
        """
        for attempt in range(cls._INSERT_ATTEMPTS):
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore')
                    return cls._Es.index(index=cls._elastic_index, body=json_data, id=item_id)
            except elastic_transport.ConnectionTimeout:
                _logger.warning(f'Elastic timeout. Retry: {attempt}')
                # Back off before the next attempt; sleeping after the final one is wasted time.
                if attempt + 1 < cls._INSERT_ATTEMPTS:
                    time.sleep(cls._INSERT_RETRY_DELAY)
        return None

    @classmethod
    def aggs(cls, term: str):
        """Collect every distinct value of the ``<term>.keyword`` field with its document count.

        Pages through a ``composite`` terms aggregation, 10000 buckets per request, until a page
        comes back empty.

        Args:
            term: field name; its ``.keyword`` sub-field is aggregated.

        Returns:
            A set of ``(value, doc_count)`` tuples.
        """
        composite = {
            "size": 10000,
            "sources": [
                {"tags": {"terms": {"field": f"{term}.keyword"}}}
            ]
        }
        body = {
            "size": 0,
            "aggs": {
                "tags": {
                    "composite": composite
                }
            }
        }

        tags = set()
        while True:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                response = cls._Es.search(index=cls._elastic_index, body=body)
            agg = response['aggregations']['tags']
            buckets = agg['buckets']
            if not buckets:
                break

            for bucket in buckets:
                tags.add((bucket['key']['tags'], bucket['doc_count']))

            # Resume from the server-supplied `after_key`, which is the documented cursor; it is
            # not guaranteed to equal the last bucket's key. Fall back to that key only if absent.
            composite['after'] = agg.get('after_key', buckets[-1]['key'])

        return tags

    @classmethod
    def delete_by_id(cls, doc_id: str):
        """Delete the document ``doc_id`` from the configured index.

        Returns:
            The client's delete response.

        Raises:
            elastic_transport.ConnectionTimeout: logged as a warning, then re-raised.
            Exception: any other client error, logged, then re-raised.
        """
        try:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                return cls._Es.delete(index=cls._elastic_index, id=doc_id)
        except elastic_transport.ConnectionTimeout:
            _logger.warning(f'Elastic timeout while deleting document with id: {doc_id}')
            raise
        except Exception as e:
            _logger.error(f'Error while deleting document with id: {doc_id}. Error: {str(e)}')
            raise

    @classmethod
    def fetch_all_data(cls, page_size: int = 1000):
        """Return the ``_source`` of every document in the configured index.

        Documents are read with ``helpers.scan`` using a ``match_all`` query. ``page_size`` is
        passed as the scan's ``size``.

        Returns:
            A list of ``_source`` dicts, fully materialised in memory.
        """
        query = {
            "query": {
                "match_all": {}
            }
        }

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return [
                hit['_source']
                for hit in helpers.scan(cls._Es, query=query, index=cls._elastic_index, size=page_size)
            ]

    @classmethod
    def document_exists(cls, doc_id: str) -> bool:
        """Return whether a document with id ``doc_id`` exists in the configured index."""
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            response = cls._Es.exists(index=cls._elastic_index, id=doc_id)
            # The client returns a response wrapper rather than a plain bool; coerce it.
            return bool(response)

    @classmethod
    @cache.cached(timeout=86400, query_string=True)
    def get_mapping(cls):
        """Return the ``properties`` section of the configured index's mapping.

        Only top-level field names appear as keys. Results are cached for 24 h (86400 s).

        NOTE(review): ``cache`` comes from ``lib.flask``. If it is Flask-Caching's view decorator,
        the cache key is derived from the current request (path + query string), not from
        ``cls._elastic_index``. Confirm this is the intended key.
        """
        mapping = cls._Es.indices.get_mapping(index=cls._elastic_index)
        return mapping[cls._elastic_index]['mappings']['properties']

    @classmethod
    def get_document_by_id(cls, doc_id: str):
        """Return the ``_source`` of document ``doc_id``, or ``None`` if it is not found.

        Raises:
            elastic_transport.ConnectionTimeout: logged as a warning, then re-raised.
            Exception: any other client error, logged, then re-raised.
        """
        try:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                response = cls._Es.get(index=cls._elastic_index, id=doc_id)
                return response['_source']
        except elastic_transport.ConnectionTimeout:
            _logger.warning(f'Elastic timeout while retrieving document with id: {doc_id}')
            raise
        except NotFoundError:
            # The client raises elasticsearch.NotFoundError for a 404.
            _logger.warning(f'Document with id {doc_id} not found')
            return None
        except Exception as e:
            _logger.error(f'Error while retrieving document with id: {doc_id}. Error: {str(e)}')
            raise

    # ====================================================================================================================================

    @classmethod
    def build_query_body(cls, additional_fields: dict = None, exclude_fields: list = None):
        """Build a ``bool`` query from required phrases and excluded terms.

        Args:
            additional_fields: mapping of field -> value, or field -> list/tuple of values.
                Every value becomes its own ``match_phrase`` clause under ``must``, so all
                values for a field are required.
            exclude_fields: iterable of ``(field, value)`` pairs. Each becomes a ``term``
                clause under ``must_not``.

        Returns:
            ``{'query': {'bool': {'must': [...], 'must_not': [...]}}}``
        """
        must = []
        for field, value in (additional_fields or {}).items():
            values = value if isinstance(value, (list, tuple)) else [value]
            must.extend({'match_phrase': {field: val}} for val in values)

        must_not = [{'term': {field: value}} for field, value in (exclude_fields or ())]

        return {
            'query': {
                'bool': {
                    'must': must,
                    'must_not': must_not
                }
            }
        }

    @staticmethod
    def _total_pages(total_hits: int, page_size: int) -> int:
        """Return how many pages of ``page_size`` are needed to hold ``total_hits`` (ceiling division)."""
        return (total_hits + page_size - 1) // page_size

    @classmethod
    def execute_search(cls, body: dict, page: int, page_size: int, timeout: int, sort_field: str = None, sort_order: str = None):
        """Run ``body`` as one page of a search on the configured index.

        Args:
            body: search body. It is shallow-copied, so the caller's dict is left untouched.
            page: 1-based page number (``from = (page - 1) * page_size``).
            page_size: hits per page.
            timeout: passed to the client as ``request_timeout``.
            sort_field: optional top-level field to sort on. Fields mapped as ``text`` are sorted
                via their ``.keyword`` sub-field.
            sort_order: sort order for ``sort_field``; defaults to ``'asc'``.

        Returns:
            ``(hits, total_pages)``, where ``total_pages`` is derived from ``hits.total.value``.

        Raises:
            CustomElasticException: ``sort_field`` is not a top-level field of the mapping.
        """
        body = dict(body)

        if sort_field:
            sort_field_mapping = cls.get_mapping().get(sort_field)
            if not sort_field_mapping:
                raise CustomElasticException(f"no mapping found for key [{sort_field}]")
            # 'text' fields are not sortable; their '.keyword' sub-field is used instead.
            if sort_field_mapping['type'] == 'text':
                sort_key = f"{sort_field}.keyword"
            else:
                sort_key = str(sort_field)
            body['sort'] = [{sort_key: {'order': sort_order if sort_order else 'asc'}}]

        body['from'] = (page - 1) * page_size
        body['size'] = page_size

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            response = cls._Es.search(index=cls._elastic_index, body=body, request_timeout=timeout)

        total_hits = response['hits']['total']['value']
        return response['hits']['hits'], cls._total_pages(total_hits, page_size)

    # ====================================================================================================================================

    @classmethod
    def natural_query(cls, query_str: str, page: int = 1, page_size: int = 20, timeout: int = 5, exclude_fields: list = None):
        """Semantic (kNN) search: embed ``query_str`` and match it against the ``embedding`` field.

        The ``all-MiniLM-L6-v2`` SentenceTransformer is loaded on CPU the first time this runs
        and is cached in the module-level ``_MODEL``.

        Args:
            query_str: free-text query to embed.
            page: 1-based page number.
            page_size: hits per page; also used as the kNN ``k``.
            timeout: passed to the client as ``request_timeout``.
            exclude_fields: iterable of ``(field, value)`` pairs added as ``term`` clauses under
                ``must_not`` around the kNN query.

        Returns:
            ``(hits, total_pages)``.
        """
        global _MODEL
        if _MODEL is None:
            # Imported lazily: the ML stack is heavy and only needed for this search mode.
            from sentence_transformers import SentenceTransformer
            _MODEL = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')

        query_embedding = _MODEL.encode(query_str).tolist()

        # https://www.elastic.co/search-labs/blog/simplifying-knn-search
        # NOTE(review): with k == page_size the kNN query yields at most page_size hits per
        # shard, so pages beyond the first may come back empty. Confirm whether k should grow
        # with the page (e.g. page * page_size) when paging is required.
        knn_query = {
            "knn": {
                "field": "embedding",
                "query_vector": query_embedding,
                "k": page_size
            }
        }

        if exclude_fields:
            query = {
                "bool": {
                    "must": knn_query,
                    "must_not": [
                        {"term": {field: value}} for field, value in exclude_fields
                    ]
                }
            }
        else:
            query = knn_query

        body = {
            "size": page_size,
            "from": (page - 1) * page_size,
            "query": query
        }

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            response = cls._Es.search(
                index=cls._elastic_index,
                body=body,
                request_timeout=timeout
            )

        total_hits = response['hits']['total']['value']
        return response['hits']['hits'], cls._total_pages(total_hits, page_size)

    # ====================================================================================================================================

    @classmethod
    def simple_query_search(cls, simple_query_str, page: int = 1, page_size: int = 20, timeout: int = 5,
                            additional_fields: dict = None, exclude_fields: list = None, sort_field: str = None, sort_order: str = None):
        """Search with a ``simple_query_string`` clause plus the filters of ``build_query_body``.

        Paging, sorting and the return value are as in ``execute_search``.
        """
        body = cls.build_query_body(additional_fields, exclude_fields)

        body['query']['bool']['must'].append({
            'simple_query_string': {
                'query': simple_query_str,
            }
        })

        return cls.execute_search(body, page, page_size, timeout, sort_field, sort_order)

    @classmethod
    def multi_match_search(cls, query: str, query_fields: list, query_fields_exclude: list = None, page: int = 1, page_size: int = 20, timeout: int = 5,
                           additional_fields: dict = None, exclude_fields: list = None, sort_field: str = None, sort_order: str = None):
        """Search ``query`` across ``query_fields`` with a lenient ``multi_match`` clause.

        Args:
            query: text to match. If empty, only the ``additional_fields``/``exclude_fields``
                filters apply.
            query_fields: fields to search. ``['*']`` expands to every top-level field of the
                index mapping.
            query_fields_exclude: fields removed from the searched set.
            Remaining arguments are as in ``build_query_body`` / ``execute_search``.

        Returns:
            ``(hits, total_pages)`` from ``execute_search``.
        """
        body = cls.build_query_body(additional_fields, exclude_fields)

        if query:  # we might have an empty query string and instead be searching via `additional_fields`
            selected_fields = query_fields
            if len(query_fields) == 1 and query_fields[0] == '*':
                selected_fields = cls.get_mapping()  # iterating the mapping yields its field names
            excluded = set(query_fields_exclude or ())
            body['query']['bool']['must'].append({
                'multi_match': {
                    'query': query,
                    'fields': [field for field in selected_fields if field not in excluded],
                    'lenient': True
                }
            })

        return cls.execute_search(body, page, page_size, timeout, sort_field, sort_order)

    # ====================================================================================================================================


class CustomElasticException(Exception):
    """Error raised by ElasticClient for problems it detects itself (e.g. an unmapped sort field).

    The original message is kept on ``message`` and is what ``str()`` returns.
    """

    def __init__(self, message):
        # Forward to Exception so `args`, `repr()` and pickling carry the message.
        super().__init__(message)
        self.message = message

    def __str__(self):
        # str() guards against non-string messages; __str__ must return a str.
        return str(self.message)
