import logging
import os
import platform
import socket
import sys
import time
import sentry_sdk
import torch
from dotmap import DotMap
from icecream import ic
from sentry_sdk import capture_exception
from sentry_sdk.integrations.threading import ThreadingIntegration

from lm_node.base import GPTModel
from lm_node.tensorizer.tensorizer import load_model, get_tokenizer


def _detect_cpu_model() -> str:
    """Return the CPU model string, preferring `/proc/cpuinfo` when readable."""
    fallback = platform.processor()
    try:
        with open("/proc/cpuinfo", "r") as cpuinfo:
            for line in cpuinfo:
                if 'model name' in line:
                    return line.rstrip().split(': ')[-1]
    except OSError:
        # No /proc/cpuinfo (non-Linux host) or it is unreadable.
        return fallback
    # Some kernels expose /proc/cpuinfo without a "model name" entry.
    return fallback


def _resolve_max_tokens(raw, default: int = 2048) -> int:
    """Parse the `MODEL_MAX_TOKENS` value, returning `default` if unusable."""
    if raw is None or raw.strip() == "":
        return default
    try:
        return int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Invalid MODEL_MAX_TOKENS value %r, falling back to %d",
            raw, default)
        return default


def init_config() -> (GPTModel, DotMap):
    """
    Build the node configuration from the environment and load the model.

    Initializes Sentry, configures this module's logger, records CPU/GPU
    details of the host, optionally loads tensorized model assets, and
    finally instantiates `GPTModel` with the assembled configuration. When
    the model has been constructed, an empty `/tmp/health_startup` file is
    written as a startup marker.

    Environment variables read:
        DEV: when exactly "True", the node runs as "staging" and the queue
            name gets a "_dev" suffix; any other value (or unset) means
            "production".
        SENTRY_DSN / SENTRY_URL: Sentry endpoint override.
        MODEL: model name (required).
        MODEL_MAX_TOKENS: context limit; defaults to 2048 when unset, empty
            or not an integer.
        MODEL_URI / MODEL_PATH / MODEL_TENSORIZED: where model assets come
            from. At least one of MODEL_URI or MODEL_PATH must be set.
        RMQ_USERNAME, RMQ_PASSWORD, RMQ_HOST, MODEL_ALIAS, PREFIX_PATH,
        HYPER_PATH, MODULES, DEEPSPEED, PROM_ADDR: copied into the config
            as-is (None when unset).

    :return: a `(model, config)` tuple; `config.model` is the same object
        as `model`.
    :raises KeyError: if `MODEL` is not set.
    :raises SystemExit: with status 1 if neither `MODEL_URI` nor
        `MODEL_PATH` is set, or if constructing `GPTModel` fails.
    """
    config = DotMap()
    config.model_type = "GPT"

    # Production is the default; only an explicit DEV=True switches to
    # staging. A missing DEV variable no longer aborts startup.
    is_dev = ""
    environment = "production"
    if os.getenv('DEV') == "True":
        is_dev = "_dev"
        environment = "staging"
    config.is_dev = is_dev

    # So we know all about the errors.
    sentry_url = os.getenv(
        "SENTRY_DSN",
        os.getenv("SENTRY_URL",
                  "https://cfa309d38ebf4cb48280f449ed5bae34@o846434.ingest.sentry.io/5987497"))
    sentry_sdk.init(
        sentry_url,
        server_name=socket.gethostname(),
        traces_sample_rate=0.002,
        environment=environment,
        integrations=[ThreadingIntegration(propagate_hub=True)],
    )

    # Setup logger
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    # Attach the stream handler only once so repeated calls do not
    # duplicate every log line.
    if not logger.handlers:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter(
            "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
        ))
        logger.addHandler(stream_handler)
    config.logger = logger

    # Gather node information
    config.cuda_dev = torch.cuda.current_device()
    config.cpu_id = _detect_cpu_model()
    config.gpu_id = torch.cuda.get_device_name(config.cuda_dev)
    config.node_id = platform.node()

    # Report on our CUDA memory and model (decimal gigabytes, truncated).
    gb_gpu = int(torch.cuda.get_device_properties(
        config.cuda_dev).total_memory / (1000 * 1000 * 1000))
    logger.info(f"CPU: {config.cpu_id}")
    logger.info(f"GPU: {config.gpu_id}")
    logger.info(f"GPU RAM: {gb_gpu}gb")

    config.model_name = os.environ['MODEL']
    logger.info(f"MODEL: {config.model_name}")

    # Unfortunately, lots of people push the limits!
    config.model_max_tokens = _resolve_max_tokens(
        os.getenv('MODEL_MAX_TOKENS', None))
    logger.info(f"MODEL CONTEXT LIMIT: {config.model_max_tokens}")

    # Resolve where we get our model and data from.
    config.model_path = os.getenv('MODEL_PATH', None)
    config.model_uri = os.getenv('MODEL_URI', None)
    # NOTE(review): this is the raw environment string, so any non-empty
    # value (including "False" or "0") is treated as true below.
    config.model_tensorized = os.getenv('MODEL_TENSORIZED', False)
    if config.model_uri:
        config.model_uri = config.model_uri.strip()
        logger.info(f"Loading assets from {config.model_uri}")
        config.model = load_model(config.model_uri)
        config.tokenizer, config.unitrim, config.word_tokens = \
            get_tokenizer(config.model_uri)
    elif not config.model_path:
        logger.fatal("You must have either `MODEL_URI` or `MODEL_PATH` set.")
        sys.exit(1)
    elif config.model_tensorized:
        config.model = load_model(config.model_path)
        config.tokenizer, config.unitrim, config.word_tokens = \
            get_tokenizer(config.model_path)

    # Get kube secrets
    config.rmq_username = os.getenv("RMQ_USERNAME", None)
    config.rmq_password = os.getenv("RMQ_PASSWORD", None)
    config.rmq_host = os.getenv("RMQ_HOST", None)
    config.queue_name = "generation_jobs_" + config.model_name + config.is_dev

    # Misc settings
    config.model_alias = os.getenv('MODEL_ALIAS')
    config.prefix_path = os.getenv('PREFIX_PATH')
    config.hyper_path = os.getenv('HYPER_PATH')
    config.user_module_path = os.getenv('MODULES')
    config.deepspeed_enabled = os.getenv('DEEPSPEED')
    config.version = "0.0.0.2smhs"
    config.prom_addr = os.getenv('PROM_ADDR')

    # Instantiate our actual model.
    load_time = time.time()

    try:
        model = GPTModel(config)
    except Exception as e:
        ic(e)
        capture_exception(e)
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Failed to load model, restarting.")
        sys.exit(1)

    # Replaces any pre-loaded assets stored in config.model above with the
    # constructed GPTModel instance.
    config.model = model

    # Mark that our model is loaded by creating an empty marker file.
    with open("/tmp/health_startup", "w"):
        pass
    time_load = time.time() - load_time
    logger.info(f"Models loaded in {time_load:.2f}s")

    return model, config
