from cmath import exp
import torch
import torch.nn as nn
import torch.nn.functional as F

from basedformer.models import base_image

class ResBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus a shortcut connection.

    When ``in_channels`` differs from ``out_channels`` the block also halves
    the spatial resolution: the first convolution uses stride 2 and the
    shortcut becomes a strided 1x1 convolution followed by batch norm.
    Otherwise the shortcut is an empty ``nn.Sequential`` (identity).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A change in width is the only signal used to trigger downsampling.
        changes_width = in_channels != out_channels
        first_stride = 2 if changes_width else 1

        if changes_width:
            # Projection shortcut so the skip path matches the main path
            # in both channel count and spatial size.
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels),
            )
        else:
            shortcut = nn.Sequential()
        self.residual = shortcut

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            stride=first_stride,
        )
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + residual(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        # The activation is applied only after the skip connection is added.
        return F.relu(hidden + self.residual(x))

class ResBlockBottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand, plus shortcut.

    The block outputs ``out_channels * expansion`` channels. The shortcut is
    a 1x1 convolution + batch norm whenever the block downsamples or the
    input channel count differs from the output channel count; otherwise it
    is an empty ``nn.Sequential`` (identity).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Width of the inner (bottleneck) convolutions.
        expansion: Multiplier applied to ``out_channels`` for the block output.
        needs_downsample: When true, the 3x3 convolution and the shortcut
            projection both use stride 2, halving the spatial resolution.
    """

    def __init__(self, in_channels, out_channels, expansion, needs_downsample=False) -> None:
        super().__init__()
        self.residual = nn.Sequential()
        self.expansion = expansion
        expanded_channels = out_channels * self.expansion
        stride = 2 if needs_downsample else 1

        if needs_downsample or in_channels != expanded_channels:
            # Projection shortcut so the skip path matches the main path
            # in both channel count and spatial size.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(expanded_channels),
            )
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, stride=1)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.bn3 = nn.BatchNorm2d(expanded_channels)

    def forward(self, x):
        """Apply the bottleneck stack and add the shortcut before the final ReLU."""
        residual = self.residual(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        # No activation here: applying ReLU before the addition would clip the
        # residual branch to non-negative values, so the block could only ever
        # add to the shortcut. The single ReLU belongs after the sum.
        out = self.bn3(self.conv3(out))
        return F.relu(out + residual)

        
class ResNet(base_image.BaseVisionModel):
    """ResNet image classifier assembled from ``ResBlock`` / ``ResBlockBottleNeck``.

    Config keys (defaults in ``default_config``):
        in_channels: Channels of the input image.
        network_size: One of 18, 34, 50, 101, 152. Sizes 18/34 use basic
            blocks; 50/101/152 use bottleneck blocks with expansion 4.
        n_class: Number of output logits.

    NOTE(review): ``default_config`` is set before ``super().__init__`` and
    ``self.config`` is read afterwards, so the base class presumably merges
    ``user_config`` over ``default_config`` -- confirm against
    ``BaseVisionModel``.
    """

    def __init__(self, user_config, **kwargs) -> None:
        self.default_config = {
            'in_channels': 3,
            'network_size': 18, #ResNet18/34/50/101/152
            'n_class': 100
        }
        super().__init__(user_config, **kwargs)
        # network_size -> (uses bottleneck blocks, number of blocks per stage)
        network_config_dict = {
            18: (False, (2, 2, 2, 2)),
            34: (False, (3, 4, 6, 3)),
            50: (True, (3, 4, 6, 3)),
            101: (True, (3, 4, 23, 3)),
            152: (True, (3, 4, 36, 3))
        }
        # Stem. The layer order (pool before batch norm) is kept exactly as it
        # was so that the index-based state_dict keys of this Sequential do
        # not change.
        self.layerin = nn.Sequential(
                nn.Conv2d(self.config.in_channels, 64, kernel_size=7, stride=2, padding=3),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU()
            )
        self.resblocks = nn.ModuleList()
        self.network_config = network_config_dict[self.config.network_size]
        is_bottleneck, blocks_per_stage = self.network_config

        # Bottleneck blocks emit `width * expansion` channels, basic blocks `width`.
        expansion = 4 if is_bottleneck else 1
        in_chan = 64   # channels entering the next block
        width = 64     # base width of the current stage
        for stage_idx, n_blocks in enumerate(blocks_per_stage):
            for block_idx in range(n_blocks):
                if is_bottleneck:
                    # Only the first block of stages 2+ halves the resolution.
                    # Previously the flag was reset to True on every block, so
                    # all bottleneck blocks but the very first used stride 2.
                    needs_downsample = stage_idx > 0 and block_idx == 0
                    resblock = ResBlockBottleNeck(in_chan, width, expansion, needs_downsample)
                else:
                    # ResBlock downsamples by itself whenever the width changes.
                    resblock = ResBlock(in_chan, width)
                self.resblocks.append(resblock)
                in_chan = width * expansion
            width *= 2

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_chan, self.config.n_class)

    def forward(self, x):
        """Run the stem, all residual blocks, global average pooling and the classifier."""
        out = self.layerin(x)
        for layer in self.resblocks:
            out = layer(out)
        out = self.avgpool(out)
        # Collapse the pooled (1, 1) spatial dims into a flat feature vector.
        out = out.view(out.size(0), -1)
        return self.fc(out)