#include <assert.h>
#include <iostream>
#include <algorithm>
#include <functional>
#include "print.h"
#include "ast.h"
#include "package.h"

// Fallback dump for nodes without a specialised printer: a single indented
// "empty node" line.
void AST::print( uint32_t level ) const
{
   const char* label = "empty node";

   indent( level );
   std::cout << label << std::endl;
}

// Start a (possibly colon-qualified) base type name from its first token,
// which is recorded through append().
ASTBaseType::ASTBaseType( const Token& token ) : AST( token )
{
   append( token );
}

// Write the name components joined by ':' with no trailing newline.
void ASTBaseType::print_name() const
{
   for( size_t n = 0; n < _list.size(); ++n )
   {
      // separator goes before every component except the first
      if( n > 0 )
      {
         std::cout << ":";
      }

      std::cout << _list[n];
   }
}

// Write the qualified name on its own indented line.
void ASTBaseType::print( uint32_t level ) const
{
   indent( level );
   print_name();
   std::cout << std::endl;
}

ASTType::ASTType( ASTBaseType::Ptr base ) :
   AST( *base ),
   _base( base ),
   _ref( REF_STRONG )
{
}

void ASTType::bind( ASTType::Ptr parameter )
{
   _parameters.push_back( parameter );
}

void ASTType::set_ref( ref_t ref )
{
   _ref = ref;
}

void ASTType::print_name() const
{
   _base->print_name();

   ASTType::Vector::const_iterator i = _parameters.begin();
   ASTType::Vector::const_iterator end = _parameters.end();

   if( i != end )
   {
      std::cout << "<";

      for( ; i != end; ++i )
      {
         (*i)->print_name();

         if( i != (end - 1) )
         {
            std::cout << ", ";
         }
      }

      std::cout << ">";
   }

   switch( _ref )
   {
   case REF_STRONG:
      break;

   case REF_NIL:
      std::cout << "*";
      break;

   case REF_WEAK:
      std::cout << "**";
      break;
   }
}

void ASTType::print( uint32_t level ) const
{
   indent( level );
   print_name();
   std::cout << std::endl;
}

// A named type definition; the name is the token's text.
ASTTypeDefinition::ASTTypeDefinition( const Token& name ) :
   AST( name ),
   _name( name.string() )
{
}

// Set the syntactic base type. Any previously resolved base definition is
// dropped, since it was resolved from whatever base was set before.
void ASTTypeDefinition::set_base( ASTType::Ptr base )
{
   _base_def.reset();
   _base = base;
}

// Record the definition the base type resolved to.
void ASTTypeDefinition::set_base_def( ASTTypeDefinition::Ptr def )
{
   _base_def = def;
}

// An import statement naming a qualified base type.
ASTImport::ASTImport( ASTBaseType::Ptr type ) :
   AST( *type ),
   _type( type )
{
}

// Record the definition the imported name resolved to.
void ASTImport::set_type( ASTTypeDefinition::Ptr type )
{
   _type_def = type;
}

// True when `name` equals the final component of the imported name.
// NOTE(review): relies on the type having at least one component (the
// ASTBaseType constructor appends its first token).
bool ASTImport::match( const std::string& name ) const
{
   return _type->get( _type->count() - 1 ) == name;
}

// Write "import <qualified name>" on an indented line.
void ASTImport::print( uint32_t level ) const
{
   indent( level );
   std::cout << "import ";
   _type->print_name();
   std::cout << std::endl;
}

// Build a boolean literal from a T_TRUE or T_FALSE token.
// Any other token type is a caller bug: it trips the assert in debug builds.
// With NDEBUG the assert vanishes, so the value falls back to false instead of
// being left indeterminate.
ASTBoolean::ASTBoolean( const Token& token ) : ASTExpr( token )
{
   switch( token.type() )
   {
   case Token::T_TRUE:
      _value = true;
      break;

   case Token::T_FALSE:
      _value = false;
      break;

   default:
      assert( 0 );
      _value = false;
      break;
   }
}

// Write "true" or "false" on an indented line.
void ASTBoolean::print( uint32_t level ) const
{
   indent( level );
   std::cout << (_value ? "true" : "false") << std::endl;
}

// Build an integer literal by parsing the token text with strtoll.
// The radix depends on the token type: 2 for T_BINARY, 10 for T_INTEGER,
// 16 for T_HEXADECIMAL.
// Any other token type trips the assert in debug builds. With NDEBUG the value
// falls back to 0 instead of being left indeterminate.
//
// NOTE(review): strtoll(..., 16) accepts an optional "0x" prefix, but base 2
// only accepts a "0b" prefix in C23-era libraries. If the lexer keeps such a
// prefix in T_BINARY text, this may parse as 0 — confirm the token text.
// Out-of-range values saturate (strtoll clamps to LLONG_MIN/LLONG_MAX).
ASTInteger::ASTInteger( const Token& token ) : ASTExpr( token )
{
   switch( token.type() )
   {
   case Token::T_BINARY:
      _value = strtoll( token.string().c_str(), NULL, 2 );
      break;

   case Token::T_INTEGER:
      _value = strtoll( token.string().c_str(), NULL, 10 );
      break;

   case Token::T_HEXADECIMAL:
      _value = strtoll( token.string().c_str(), NULL, 16 );
      break;

   default:
      assert( 0 );
      _value = 0;
      break;
   }
}

// Write the integer value on an indented line.
void ASTInteger::print( uint32_t level ) const
{
   indent( level );
   std::cout << _value << std::endl;
}

// Build a real literal by parsing a T_REAL token's text with strtod.
ASTReal::ASTReal( const Token& token ) : ASTExpr( token )
{
   assert( token.type() == Token::T_REAL );

   const std::string& text = token.string();
   _value = strtod( text.c_str(), NULL );
}

// Write the value (default stream formatting) on an indented line.
void ASTReal::print( uint32_t level ) const
{
   indent( level );
   std::cout << _value << std::endl;
}

// Build an imaginary literal. strtod parses the leading numeric part of the
// T_IMAGINARY token text and stops at the first character it cannot use.
ASTImaginary::ASTImaginary( const Token& token ) : ASTExpr( token )
{
   assert( token.type() == Token::T_IMAGINARY );

   const std::string& text = token.string();
   _value = strtod( text.c_str(), NULL );
}

// Write the magnitude followed by "i" on an indented line.
void ASTImaginary::print( uint32_t level ) const
{
   indent( level );
   std::cout << _value << "i" << std::endl;
}

// Build a string literal; the stored value is the T_STRING token's text as-is.
ASTString::ASTString( const Token& token ) : ASTExpr( token )
{
   assert( token.type() == Token::T_STRING );
   _value = token.string();
}

// Write the value wrapped in double quotes (no escaping) on an indented line.
void ASTString::print( uint32_t level ) const
{
   const char* quote = "\"";

   indent( level );
   std::cout << quote << _value << quote << std::endl;
}

// Replace the list of element expressions.
void ASTArray::set_list( ASTExpr::VectorPtr list )
{
   _list = list;
}

// Write an "array" header, then the elements one level deeper. This matches
// how the other composite nodes (call, array index, assignment, ...) nest
// their children. Previously the elements were printed at the header's own
// level, which made them look like siblings of the array.
void ASTArray::print( uint32_t level ) const
{
   indent( level );
   std::cout << "array" << std::endl;
   print_vectorptr( _list, level + 1 );
}

// The `this` expression has no children: just an indented "this" line.
void ASTThis::print( uint32_t level ) const
{
   const char* keyword = "this";

   indent( level );
   std::cout << keyword << std::endl;
}

// An identifier expression; its name is the token's text.
ASTIdentifier::ASTIdentifier( const Token& token ) :
   ASTExpr( token )
{
   _name = token.string();
}

// Write "id <name>" on an indented line.
void ASTIdentifier::print( uint32_t level ) const
{
   indent( level );
   std::cout << "id " << _name << std::endl;
}

// A unary operation; the operator is remembered as the token's type.
ASTUnary::ASTUnary( const Token& token ) : ASTExpr( token )
{
   _type = token.type();
}

// Set the operand.
void ASTUnary::set_rhs( ASTExpr::Ptr rhs )
{
   _rhs = rhs;
}

// Write "unary <op>", then the operand one level deeper.
void ASTUnary::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "unary " << Token::text( _type ) << std::endl;
   _rhs->print( child );
}

// A binary operation. The operator token and right operand are handled by
// ASTUnary; the left operand is supplied here.
ASTBinary::ASTBinary( const Token& token, ASTExpr::Ptr lhs ) : ASTUnary( token )
{
   _lhs = lhs;
}

// Write "binary <op>", then the left and right operands one level deeper.
void ASTBinary::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "binary " << Token::text( _type ) << std::endl;
   _lhs->print( child );
   _rhs->print( child );
}

// Index `array` by a list of dimension expressions. Source position is taken
// from the array expression.
ASTArrayIndex::ASTArrayIndex( ASTExpr::Ptr array, ASTExpr::VectorPtr dimensions ) :
   ASTExpr( *array ),
   _array( array ),
   _dimensions( dimensions )
{
}

// Write "array index", then the array and its dimensions one level deeper.
void ASTArrayIndex::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "array index" << std::endl;
   _array->print( child );
   print_vectorptr( _dimensions, child );
}

// A call of `method`. Source position is taken from the method expression.
ASTCall::ASTCall( ASTExpr::Ptr method ) :
   ASTExpr( *method ),
   _method( method )
{
}

// Set the argument expression list.
void ASTCall::set_args( ASTExpr::VectorPtr args )
{
   _args = args;
}

// Write "call", then the callee and the arguments one level deeper.
void ASTCall::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "call" << std::endl;
   _method->print( child );
   print_vectorptr( _args, child );
}

// Same layout as a plain call, labelled "super call".
void ASTSuperCall::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "super call" << std::endl;
   _method->print( child );
   print_vectorptr( _args, child );
}

// A method call on an explicit receiver object.
ASTObjectCall::ASTObjectCall( ASTExpr::Ptr object, ASTExpr::Ptr method ) :
   ASTCall( method ),
   _object( object )
{
}

// Write "object call", then the receiver, the method and the arguments one
// level deeper.
void ASTObjectCall::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "object call" << std::endl;
   _object->print( child );
   _method->print( child );
   print_vectorptr( _args, child );
}

// A (type, identifier) pair. Source position is taken from the type.
ASTTypeIdentifier::ASTTypeIdentifier( ASTType::Ptr type ) :
   ASTExpr( *type ),
   _type( type )
{
}

// Attach the identifier half of the pair.
void ASTTypeIdentifier::set_identifier( ASTIdentifier::Ptr identifier )
{
   _identifier = identifier;
}

// Write "type identifier", then the type and the identifier one level deeper.
void ASTTypeIdentifier::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "type identifier" << std::endl;
   _type->print( child );
   _identifier->print( child );
}

// A variable declaration command of the given type.
ASTVariable::ASTVariable( ASTType::Ptr type ) :
   ASTCommand( *type ),
   _type( type )
{
}

// Append one declared identifier.
void ASTVariable::add_identifier( ASTIdentifier::Ptr id )
{
   _identifiers.push_back( id );
}

// Set the initialiser expression list.
void ASTVariable::set_initialiser( ASTExpr::VectorPtr initialiser )
{
   _initialiser = initialiser;
}

// Write "var <type>", then the identifiers and initialisers one level deeper.
void ASTVariable::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "var ";
   _type->print_name();
   std::cout << std::endl;

   print_vector( _identifiers, child );
   print_vectorptr( _initialiser, child );
}

// A command made of one or more expressions. Source position is taken from
// the first expression. NOTE(review): at( 0 ) is bounds-checked if VectorPtr
// points at a std::vector, so an empty list would throw — confirm callers
// never pass one.
ASTExprCommand::ASTExprCommand( ASTExpr::VectorPtr list ) :
   ASTCommand( *list->at( 0 ) )
{
   _list = list;
}

// Write "expression", then each expression one level deeper.
void ASTExprCommand::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "expression" << std::endl;
   print_vectorptr( _list, child );
}

// An assignment. The targets are given here; source position is taken from
// the first target.
ASTAssign::ASTAssign( ASTExpr::VectorPtr lhs ) :
   ASTCommand( *lhs->at( 0 ) ),
   _lhs( lhs )
{
}

// Set the value expression list.
void ASTAssign::set_rhs( ASTExpr::VectorPtr rhs )
{
   _rhs = rhs;
}

// Write "assignment", then the targets and values one level deeper.
void ASTAssign::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "assignment" << std::endl;
   print_vectorptr( _lhs, child );
   print_vectorptr( _rhs, child );
}

// Append a command to the end of the block.
void ASTBlock::add_command( ASTCommand::Ptr command )
{
   _commands.push_back( command );
}

// A block prints no header of its own: its commands appear at the caller's
// level.
void ASTBlock::print( uint32_t level ) const
{
   const uint32_t same_level = level;
   print_vector( _commands, same_level );
}

ASTIf::ASTIf( ASTExpr::Ptr condition ) :
   ASTCommand( *condition ),
   _condition( condition )
{
}

void ASTIf::set_true( ASTBlock::Ptr on_true )
{
   _on_true = on_true;
}

void ASTIf::set_false( ASTBlock::Ptr on_false )
{
   _on_false = on_false;
}

void ASTIf::print( uint32_t level ) const
{
   indent( level );
   std::cout << "if" << std::endl;
   _on_true->print( level + 1 );
   _on_false->print( level + 1 );
}

// Write a bare "return" line; no operands are shown.
void ASTReturn::print( uint32_t level ) const
{
   const char* keyword = "return";

   indent( level );
   std::cout << keyword << std::endl;
}

// A named class literal. Source position is taken from its identifier.
ASTLiteral::ASTLiteral( ASTIdentifier::Ptr name ) :
   AST( *name ),
   _name( name )
{
}

// Set the literal's value expression.
void ASTLiteral::set_value( ASTExpr::Ptr value )
{
   _value = value;
}

// Write "literal", then the name and the value one level deeper.
void ASTLiteral::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "literal" << std::endl;
   _name->print( child );
   _value->print( child );
}

// A class member (field) of the given type.
ASTMember::ASTMember( ASTType::Ptr type ) :
   AST( *type ),
   _type( type )
{
}

// Attach the member's name.
void ASTMember::set_identifier( ASTIdentifier::Ptr id )
{
   _name = id;
}

// Write "member", then the type and the name one level deeper.
void ASTMember::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "member" << std::endl;
   _type->print( child );
   _name->print( child );
}

// A method declaration. The visibility is remembered as the token's type; the
// method starts out as a non-constructor.
ASTMethod::ASTMethod( const Token& visibility ) :
   AST( visibility ),
   _visibility( visibility.type() ),
   _constructor( false )
{
}

// Mark (or unmark) this method as a constructor.
void ASTMethod::set_constructor( bool constructor )
{
   _constructor = constructor;
}

// Attach the method's name.
void ASTMethod::set_identifier( ASTIdentifier::Ptr id )
{
   _name = id;
}

// Set the formal parameter list.
void ASTMethod::set_params( ASTTypeIdentifier::VectorPtr params )
{
   _params = params;
}

// Set the named result list.
void ASTMethod::set_results( ASTTypeIdentifier::VectorPtr results )
{
   _results = results;
}

// Append one declared thrown type.
void ASTMethod::add_throws( ASTType::Ptr throws )
{
   _throws.push_back( throws );
}

// Set the call recorded in the method's `extends` clause.
void ASTMethod::set_extends( ASTCall::Ptr extends )
{
   _extends = extends;
}

// Set the method body.
void ASTMethod::set_body( ASTBlock::Ptr body )
{
   _body = body;
}

// Dump a method declaration, in this order:
//   <visibility> constructor|method
//     name, parameters
//     ->
//     results
//     extends (if any), with its call nested one more level
//     throws  (if any), with the thrown types nested one more level
//     body
void ASTMethod::print( uint32_t level ) const
{
   indent( level );

   // A space separates the visibility keyword from the kind. Previously the
   // two were written back to back (e.g. "publicmethod"). Other printers
   // such as ASTUnary also supply their own space around Token::text().
   std::cout
      << Token::text( _visibility )
      << " "
      << (_constructor ? "constructor" : "method" )
      << std::endl;

   _name->print( level + 1 );
   print_vectorptr( _params, level + 1 );

   indent( level + 1 );
   std::cout << "->" << std::endl;

   print_vectorptr( _results, level + 1 );

   if( _extends )
   {
      indent( level + 1 );
      std::cout << "extends" << std::endl;
      _extends->print( level + 2 );
   }

   if( !_throws.empty() )
   {
      indent( level + 1 );
      std::cout << "throws" << std::endl;
      print_vector( _throws, level + 2 );
   }

   _body->print( level + 1 );
}

// A generic formal parameter named after the token's text.
ASTGeneric::ASTGeneric( const Token& name ) :
   AST( name ),
   _name( name.string() )
{
}

// Write "generic <name>", then its constraints one level deeper.
void ASTGeneric::print( uint32_t level ) const
{
   const uint32_t child = level + 1;

   indent( level );
   std::cout << "generic " << _name << std::endl;
   print_vector( _constraints, child );
}

// Report enum values declared more than once in this enum, then values that
// redeclare a value from an ancestor enum.
//
// Duplicates are found pairwise: every later repeat of a name produces a
// report against the earlier identifier. If a base definition has been
// resolved and is not an enum, that is reported instead of walking the chain.
void ASTEnum::shadow_check( Errors& errors )
{
   for( size_t a = 0; a < _list.size(); ++a )
   {
      ASTIdentifier::Ptr id = _list[a];

      for( size_t b = a + 1; b < _list.size(); ++b )
      {
         if( id->name() == _list[b]->name() )
         {
            errors.report( *id ) << "enum value '" << id->name() << "' is multiply defined";
         }
      }
   }

   if( !_base_def )
   {
      return;
   }

   ASTEnum* base = dynamic_cast<ASTEnum*>( _base_def.get() );

   if( base == NULL )
   {
      errors.report( *this ) << "base type is not an enum";
      return;
   }

   base->shadow_chain( *this, errors );
}

// Write "enum <name>", plus " includes <base>" when a base enum is given,
// then the enum values one level deeper.
void ASTEnum::print( uint32_t level ) const
{
   indent( level );
   std::cout << "enum " << _name;

   if( _base )
   {
      // Use std::cout like the rest of the line. The previous printf() wrote
      // through the separate C stdio buffer, which can reorder output if stdio
      // synchronisation is ever disabled. <cstdio> is also not included in
      // this file.
      std::cout << " includes ";
      _base->print_name();
   }

   std::cout << std::endl;
   print_vector( _list, level + 1 );
}

// For each value declared in this (ancestor) enum, report the first value in
// `descendent` with the same name. Then continue up this enum's own base
// chain, stopping silently at a base that is not an enum.
void ASTEnum::shadow_chain( ASTEnum& descendent, Errors& errors )
{
   for( size_t a = 0; a < _list.size(); ++a )
   {
      for( size_t b = 0; b < descendent._list.size(); ++b )
      {
         ASTIdentifier::Ptr id = descendent._list[b];

         if( _list[a]->name() == id->name() )
         {
            errors.report( *id ) << "enum value '" << id->name() << "' is defined in a base type";
            // only the first clash per ancestor value is reported
            break;
         }
      }
   }

   if( !_base_def )
   {
      return;
   }

   ASTEnum* base = dynamic_cast<ASTEnum*>( _base_def.get() );

   if( base != NULL )
   {
      base->shadow_chain( descendent, errors );
   }
}

// A class-like definition. `phylum` distinguishes kinds such as
// PHYLUM_INTERFACE.
ASTClass::ASTClass( phylum_t phylum, const Token& name ) :
   ASTTypeDefinition( name ),
   _phylum( phylum )
{
}

// Look up a generic formal parameter by name. Returns the first match, or an
// empty pointer when none exists.
ASTGeneric::Ptr ASTClass::generic( const std::string& name )
{
   ASTGeneric::Ptr found;
   ASTGeneric::Vector::iterator it = _generics.begin();

   while( !found && it != _generics.end() )
   {
      if( (*it)->name() == name )
      {
         found = *it;
      }

      ++it;
   }

   return found;
}

// Append a generic formal parameter.
void ASTClass::add_generic( ASTGeneric::Ptr generic )
{
   _generics.push_back( generic );
}

// Append an implemented interface type.
void ASTClass::add_implements( ASTType::Ptr interface )
{
   _implements.push_back( interface );
}

// Append a class literal.
void ASTClass::add_literal( ASTLiteral::Ptr literal )
{
   _literals.push_back( literal );
}

// Append a member (field).
void ASTClass::add_member( ASTMember::Ptr member )
{
   _members.push_back( member );
}

// Append a method.
void ASTClass::add_method( ASTMethod::Ptr method )
{
   _methods.push_back( method );
}

// Report names declared more than once directly in this class, then members
// that redeclare a member of an ancestor class.
//
// Literals and members are each checked against their own list. Every later
// repeat of a name produces a report against the earlier declaration. If a
// base definition has been resolved and is not a class, that is reported
// instead of walking the chain.
void ASTClass::shadow_check( Errors& errors )
{
   // check for multiply defined literals
   {
      ASTLiteral::Vector::iterator i = _literals.begin();
      ASTLiteral::Vector::iterator end = _literals.end();

      for( ; i != end; ++i )
      {
         ASTLiteral::Vector::iterator j = i + 1;
         ASTLiteral::Ptr lit = *i;

         for( ; j != end; ++j )
         {
            if( lit->identifier()->name() == (*j)->identifier()->name() )
            {
               errors.report( *lit ) << "literal '" << lit->identifier()->name() << "' is multiply defined";
            }
         }
      }
   }

   // check for multiply defined members
   {
      ASTMember::Vector::iterator i = _members.begin();
      // Must be end(). This was begin(), so the loop never ran and duplicate
      // members were never reported.
      ASTMember::Vector::iterator end = _members.end();

      for( ; i != end; ++i )
      {
         ASTMember::Vector::iterator j = i + 1;
         ASTMember::Ptr mem = *i;

         for( ; j != end; ++j )
         {
            if( mem->identifier()->name() == (*j)->identifier()->name() )
            {
               errors.report( *mem ) << "member '" << mem->identifier()->name() << "' is multiply defined";
            }
         }
      }
   }

   // check for shadowed members
   if( _base_def )
   {
      ASTClass* base = dynamic_cast<ASTClass*>( _base_def.get() );

      if( base != NULL )
      {
         base->shadow_chain( *this, errors );
      } else {
         errors.report( *this ) << "base type is not a class";
      }
   }
}

// Semantic checks for a class or interface definition. This is currently
// partial (see the FIX list below). It enforces the restrictions on interfaces
// and makes a first pass over member types via Module::resolve_type().
//
// module: used to resolve each member's base type.
// errors: receives diagnostics via errors.report().
void ASTClass::type_check( Module& module, Errors& errors )
{
   /* FIX: class type check
    * when resolving types, check module and generic formal parameters
    * resolve member types
    * resolve method signature types
    * make sure overrides have the correct signature
    * make sure all interface methods are implemented
    * make sure no interface is implemented more than once (why?)
    * type check method bodies
    */

   // interfaces can't have a base class, members or method implementations
   if( _phylum == PHYLUM_INTERFACE )
   {
      if( _base )
      {
         errors.report( *this ) << "Interface '" << _name << "' has a base class";
      }

      if( !_members.empty() )
      {
         errors.report( *this ) << "Interface '" << _name << "' has members";
      }

      // NOTE(review): assumes every method has a non-null body (presumably an
      // empty block when only declared) — confirm in the parser.
      ASTMethod::Vector::iterator i = _methods.begin();
      ASTMethod::Vector::iterator end = _methods.end();

      for( ; i != end; ++i )
      {
         if( !(*i)->body()->empty() )
         {
            // reported against the class, not the offending method
            errors.report( *this )
               << "Interface '"
               << _name
               << "' has an implementation for method "
               << (*i)->identifier()->name();
         }
      }
   }

   // resolve member types
   // NOTE(review): the `list` filled by resolve_type() is local and never read
   // afterwards, and any return value is ignored. Resolution therefore has no
   // lasting effect from here yet.
   ASTMember::Vector::iterator i = _members.begin();
   ASTMember::Vector::iterator end = _members.end();

   for( ; i != end; ++i )
   {
      ASTTypeDefinition::Vector list;
      ASTBaseType::Ptr base = (*i)->type()->base();

      // an unqualified (single-component) name
      if( base->count() == 1 )
      {
         // FIX: could be a generic formal parameter
      }

      // FIX: this isn't finished?
      module.resolve_type( *base, list );
   }
}

// For each member declared in this (ancestor) class, report the first member
// of `descendent` with the same name. Then continue up this class's own base
// chain, stopping silently at a base that is not a class.
void ASTClass::shadow_chain( ASTClass& descendent, Errors& errors )
{
   for( size_t a = 0; a < _members.size(); ++a )
   {
      for( size_t b = 0; b < descendent._members.size(); ++b )
      {
         ASTMember::Ptr mem = descendent._members[b];
         const std::string& name = mem->identifier()->name();

         if( _members[a]->identifier()->name() == name )
         {
            errors.report( *mem ) << "member '" << name << "' is defined in a base type";
            // only the first clash per ancestor member is reported
            break;
         }
      }
   }

   if( !_base_def )
   {
      return;
   }

   ASTClass* base = dynamic_cast<ASTClass*>( _base_def.get() );

   if( base != NULL )
   {
      base->shadow_chain( descendent, errors );
   }
}

void ASTClass::print( uint32_t level ) const
{
   indent( level );
   std::cout << "class " << _name;

   if( _base )
   {
      std::cout << " extends ";
      _base->print_name();
   }

   ASTType::Vector::const_iterator i = _implements.begin();
   ASTType::Vector::const_iterator end = _implements.end();

   if( i != end )
   {
      std::cout << " implements ";

      for( ; i != end; ++i )
      {
         (*i)->print_name();

         if( i != (end - 1) )
         {
            std::cout << ", ";
         }
      }
   }

   std::cout << std::endl;

   print_vector( _generics, level + 1 );
   print_vector( _literals, level + 1 );
   print_vector( _members, level + 1 );
   print_vector( _methods, level + 1 );
}
