# This code uses BART. It still needs to be reworked, because it was copied directly from Mosaic's API.


import json
import torch
import argparse
from tqdm import tqdm
from pathlib import Path
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def use_task_specific_params(model, task):
    """Apply the per-task settings stored on ``model.config`` to the config itself.

    Reads ``model.config.task_specific_params``. When it is ``None`` nothing
    happens. Otherwise the entry for ``task`` is passed to
    ``model.config.update``; if there is no entry, an empty dict is passed.
    """
    config = model.config
    per_task_settings = config.task_specific_params
    if per_task_settings is None:
        return
    config.update(per_task_settings.get(task, {}))

def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Drop every column that holds nothing but ``pad_token_id``.

    A column is kept if at least one row has a non-pad token in it. The same
    column selection is applied to ``attention_mask`` when one is given.
    Assumes 2-D ``(batch, seq_len)`` tensors, as produced by the tokenizer;
    confirm before using this with other shapes.

    Returns the trimmed ``input_ids``, or the tuple
    ``(input_ids, attention_mask)`` when a mask was supplied.
    """
    column_has_content = (input_ids != pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, column_has_content]
    if attention_mask is not None:
        return trimmed_ids, attention_mask[:, column_has_content]
    return trimmed_ids


def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` that are ``n`` items long.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    slice_starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in slice_starts)


class Comet:
    """Thin wrapper around a Hugging Face seq2seq checkpoint (BART per the file header).

    Loads the model and tokenizer from a local path or hub id. ``generate``
    then produces beam-search completions for a list of query strings.
    """

    def __init__(self, model_path):
        """Load the model and tokenizer from ``model_path``.

        The model goes to CUDA when it is available, otherwise to CPU. Its
        config is then updated with the checkpoint's ``"summarization"``
        task-specific params, if it has any.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        use_task_specific_params(self.model, "summarization")
        # Queries are tokenized and generated one at a time.
        self.batch_size = 1
        # Passed straight to ``model.generate``. None presumably lets it fall
        # back to the model config's value; confirm against the transformers version.
        self.decoder_start_token_id = None

    def generate(
            self,
            queries,
            decode_method="beam",
            num_generate=5,
            ):
        """Generate ``num_generate`` beam-search completions for each query.

        Args:
            queries: list of input strings. A single ``str`` is treated as a
                one-element list.
            decode_method: currently ignored. Beam search with
                ``num_beams=num_generate`` is always used.
            num_generate: beam width, and the number of sequences returned per
                input.

        Returns:
            A list with one entry per batch of ``self.batch_size`` queries.
            Each entry is the list of decoded strings for that batch, holding
            ``num_generate`` strings per query.
        """
        if isinstance(queries, str):
            # Chunking a bare string would split it into single characters
            # and run one generation per character.
            queries = [queries]

        decs = []
        with torch.no_grad():
            for batch in chunks(queries, self.batch_size):
                encoded = self.tokenizer(
                    batch, return_tensors="pt", truncation=True, padding="max_length"
                ).to(self.device)
                # max_length padding adds many all-pad columns; strip them so
                # generation does not run on dead positions.
                input_ids, attention_mask = trim_batch(
                    **encoded, pad_token_id=self.tokenizer.pad_token_id
                )

                summaries = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_start_token_id=self.decoder_start_token_id,
                    num_beams=num_generate,
                    num_return_sequences=num_generate,
                )

                decs.append(
                    self.tokenizer.batch_decode(
                        summaries,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )
                )

        return decs


# Relation labels: 51 entries, kept in ASCII-sorted order. They are the same
# labels, in the same order, as the keys of ``question_relation_tuple``.
# Both "HasSubEvent" and "HasSubevent" appear, deliberately preserved as-is.
all_relations = [
    # Capitalised relation names.
    "AtLocation", "CapableOf", "Causes", "CausesDesire", "CreatedBy",
    "DefinedAs", "DesireOf", "Desires",
    "HasA", "HasFirstSubevent", "HasLastSubevent", "HasPainCharacter",
    "HasPainIntensity", "HasPrerequisite", "HasProperty", "HasSubEvent",
    "HasSubevent", "HinderedBy",
    "InheritsFrom", "InstanceOf", "IsA",
    "LocatedNear", "LocationOfAction",
    "MadeOf", "MadeUpOf", "MotivatedByGoal",
    # Negated relations.
    "NotCapableOf", "NotDesires", "NotHasA", "NotHasProperty", "NotIsA",
    "NotMadeOf",
    "ObjectUse", "PartOf", "ReceivesAction", "RelatedTo", "SymbolOf",
    "UsedFor",
    # Lower-case-prefixed relations ("is*", "o*", "x*").
    "isAfter", "isBefore", "isFilledBy",
    "oEffect", "oReact", "oWant",
    "xAttr", "xEffect", "xIntent", "xNeed", "xReact", "xReason", "xWant",
]
# Maps each relation label to an English question prefix. Despite the name,
# this is a dict, not a tuple; the name is kept for existing callers. Every
# prefix ends with a space, presumably so a head phrase can be appended
# directly (confirm with callers).
# NOTE(review): "oEffect" and "xEffect" share the same prompt text.
question_relation_tuple = {
    "AtLocation": "What is at location ",
    "CapableOf": "What is capable of ",
    "Causes": "What causes ",
    "CausesDesire": "What causes desire ",
    "CreatedBy": "What is created by ",
    "DefinedAs": "What is defined as ",
    "DesireOf": "What is a desire of ",
    "Desires": "What desires ",
    "HasA": "What has a ",
    "HasFirstSubevent": "What has a first subevent of ",
    "HasLastSubevent": "What has a last subevent of ",
    "HasPainCharacter": "What has a pain characteristic of ",
    "HasPainIntensity": "What has a pain intensity of ",
    "HasPrerequisite": "What has a prerequisite of ",
    "HasProperty": "What has property ",
    "HasSubEvent": "What has subevent ",
    "HasSubevent": "What has subevent ",
    "HinderedBy": "What is hindered by ",
    "InheritsFrom": "What inherits from ",
    "InstanceOf": "What is an instance of ",
    "IsA": "What is a ",
    "LocatedNear": "What is located near ",
    "LocationOfAction": "What is located at action ",
    "MadeOf": "What is made of ",
    "MadeUpOf": "What is made up of ",
    "MotivatedByGoal": "What is motivated by ",
    "NotCapableOf": "What is not capable of ",
    "NotDesires": "What does not desire ",
    "NotHasA": "What does not have a ",
    "NotHasProperty": "What does not have property ",
    "NotIsA": "What is not a ",
    "NotMadeOf": "What is not made of ",
    "ObjectUse": "What object uses ",
    "PartOf": "What is part of ",
    "ReceivesAction": "What receives action ",
    "RelatedTo": "What is related to ",
    "SymbolOf": "What is a symbol of ",
    "UsedFor": "What is used for ",
    "isAfter": "What is after ",
    "isBefore": "What is before ",
    "isFilledBy": "What is filled by ",
    "oEffect": "Who has the effect ",
    "oReact": "Who reacts ",
    "oWant": "Who wants ",
    "xAttr": "Who has the attribute ",
    "xEffect": "Who has the effect ",
    "xIntent": "Who has the intent ",
    "xNeed": "Who has the need ",
    "xReact": "Who has the reaction ",
    "xReason": "Who has the reason ",
    "xWant": "Who has the want ",
}

# Subset of the relation labels: 35 entries in the same ASCII-sorted order as
# ``all_relations``. It leaves out the "Not*" relations and the lower-case
# "o*" / "x*" relations, but keeps the "is*" ones.
most_relations = [
    "AtLocation", "CapableOf", "Causes", "CausesDesire", "CreatedBy",
    "DefinedAs", "DesireOf", "Desires",
    "HasA", "HasFirstSubevent", "HasLastSubevent", "HasPainCharacter",
    "HasPainIntensity", "HasPrerequisite", "HasProperty", "HasSubEvent",
    "HasSubevent", "HinderedBy",
    "InheritsFrom", "InstanceOf", "IsA",
    "LocatedNear", "LocationOfAction",
    "MadeOf", "MadeUpOf", "MotivatedByGoal",
    "ObjectUse", "PartOf", "ReceivesAction", "RelatedTo", "SymbolOf",
    "UsedFor",
    "isAfter", "isBefore", "isFilledBy",
]