import torch

from allennlp.common.checks import ConfigurationError
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler


@LearningRateScheduler.register("polynomial_decay")
class PolynomialDecay(LearningRateScheduler):
    """
    Implements polynomial decay learning rate scheduling.

    The learning rate is first increased linearly from `0` to the initial
    learning rate over the first `warmup_steps` training steps. It is then
    decayed over `total_steps` - `warmup_steps` steps, from the initial
    learning rate to `end_learning_rate`, using a polynomial of degree
    `power`. Here `total_steps` = `num_epochs` * `num_steps_per_epoch`.
    From `total_steps` onward the learning rate stays at `end_learning_rate`.

    Formally, for each parameter group with initial learning rate
    `initial_lr`, the cases below are checked in this order:

    - if `warmup_steps` > 0 and `steps` < `warmup_steps`:

        `lr` = `initial_lr` * `steps` / `warmup_steps`

    - else, if `steps` >= `total_steps`:

        `lr` = `end_learning_rate`

    - otherwise:

        `lr` = (`initial_lr` - `end_learning_rate`) *
               ((`total_steps` - `steps`)/(`total_steps` - `warmup_steps`)) ** `power`
               + `end_learning_rate`

    # Parameters

    optimizer : `torch.optim.Optimizer`
        This argument does not get an entry in a configuration file for the
        object.
    num_epochs : `int`
        The number of epochs in the experiment. This does *NOT* get an entry
        in the config.
    num_steps_per_epoch : `int`
        The number of steps per epoch. This does *NOT* get an entry in the
        config.
    power : `float`, optional (default = `1.0`)
        The power of the polynomial used for decaying.
    warmup_steps : `int`, optional (default = `0`)
        The number of steps over which to linearly increase the learning rate.
    end_learning_rate : `float`, optional (default = `0.0`)
        Final learning rate to decay towards.
    last_epoch : `int`, optional (default = `-1`)
        Passed through unchanged to the base `LearningRateScheduler`.

    # Raises

    `ConfigurationError`
        If `num_steps_per_epoch` is `None`.

    # Example

    Config for using the `PolynomialDecay` learning rate scheduler with
    `warmup_steps` set to `100`, `power` set to `2`, and `end_learning_rate`
    set to `1e-10`.

    ```json
    {
        ...
        "trainer": {
            ...
            "learning_rate_scheduler": {
                "type": "polynomial_decay",
                "power": 2,
                "warmup_steps": 100,
                "end_learning_rate": 1e-10
            },
            ...
        }
    }
    ```

    Note that you do NOT pass an `optimizer`, `num_epochs`, or
    `num_steps_per_epoch` key to the learning rate scheduler.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_epochs: int,
        num_steps_per_epoch: int,
        power=1.0,
        warmup_steps=0,
        end_learning_rate=0.0,
        last_epoch: int = -1,
    ):
        super().__init__(optimizer, last_epoch)

        # The decay length depends on num_steps_per_epoch, so fail early with
        # an actionable message instead of a TypeError on the multiplication below.
        if num_steps_per_epoch is None:
            raise ConfigurationError(
                "'num_steps_per_epoch' is required for this LR scheduler.\n\n"
                "If you know how many batches per epoch for your training data, you can set this value "
                "directly in your config. Otherwise you'll need to use compatible settings with your data loader "
                "so that it can report an accurate number of batches per epoch. "
                "If you're using the MultiProcessDataLoader, "
                "this means you either need to set 'batches_per_epoch' "
                "or leave 'max_instances_in_memory' as None (if your entire dataset can fit into memory)."
            )

        self.power = power
        self.warmup_steps = warmup_steps
        self.total_steps = num_epochs * num_steps_per_epoch
        self.end_learning_rate = end_learning_rate

        self.steps = 0

        # Write the step-0 learning rate into the optimizer immediately.
        self.step_batch(0)

    def get_values(self):
        """
        Return the learning rate for each entry of `self.base_values` at the
        current `self.steps`, following the piecewise schedule in the class
        docstring.
        """
        # Warmup is checked first, so it takes precedence over the other cases.
        if self.warmup_steps > 0 and self.steps < self.warmup_steps:
            f = self.steps / self.warmup_steps
            return [f * lr for lr in self.base_values]

        # Past the end of the schedule: hold at the floor value.
        if self.steps >= self.total_steps:
            return [self.end_learning_rate for _ in self.base_values]

        # Reaching this point implies warmup_steps <= steps < total_steps, so
        # total_decay_steps is strictly positive (no division by zero).
        current_decay_steps = self.total_steps - self.steps
        total_decay_steps = self.total_steps - self.warmup_steps
        f = (current_decay_steps / total_decay_steps) ** self.power
        return [
            f * (lr - self.end_learning_rate) + self.end_learning_rate for lr in self.base_values
        ]

    def step(self, metric: float = None) -> None:
        """
        No-op. This schedule is advanced per batch via `step_batch`, not per
        epoch.
        """
        pass

    def step_batch(self, batch_num_total: int = None) -> None:
        """
        Advance the schedule and write the new learning rates into the
        optimizer's parameter groups.

        If `batch_num_total` is `None`, the internal step counter is
        incremented by one. Otherwise the counter is set to `batch_num_total`.
        """
        if batch_num_total is None:
            self.steps += 1
        else:
            self.steps = batch_num_total

        for param_group, lr in zip(self.optimizer.param_groups, self.get_values()):
            param_group[self.param_group_field] = lr
