Source code for mmpose.core.utils.regularizations

from abc import ABCMeta, abstractmethod

import torch


class PytorchModuleHook(metaclass=ABCMeta):
    """Base class for PyTorch module hook registers.

    An instance of a subclass of PytorchModuleHook can be used to
    register a hook to a PyTorch module using the `register` method, e.g.:
        hook_register.register(module)

    Subclasses should add/overwrite the following methods:
        - __init__
        - hook
        - hook_type
    """

    @abstractmethod
    def hook(self, *args, **kwargs):
        """Hook function."""

    @property
    @abstractmethod
    def hook_type(self) -> str:
        """Hook type.

        Subclasses should overwrite this property to return a string value
        in {`forward`, `forward_pre`, `backward`}.
        """

    def register(self, module):
        """Register the hook function to the module.

        Args:
            module (torch.nn.Module): The module to register the hook to.

        Returns:
            handle (torch.utils.hooks.RemovableHandle): a handle to remove
                the hook by calling handle.remove()
        """
        assert isinstance(module, torch.nn.Module)

        if self.hook_type == 'forward':
            h = module.register_forward_hook(self.hook)
        elif self.hook_type == 'forward_pre':
            h = module.register_forward_pre_hook(self.hook)
        elif self.hook_type == 'backward':
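            # Note: newer PyTorch versions deprecate register_backward_hook
            # in favor of register_full_backward_hook.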
            h = module.register_backward_hook(self.hook)
        else:
            raise ValueError(f'Invalid hook type {self.hook_type}')

        return h
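

# A minimal sketch of a subclass (not part of the original module), showing
# how the base class is meant to be used. The name `ShapeLoggerHook` and its
# behavior are hypothetical; only the `hook`/`hook_type`/`register` contract
# comes from the code above.
class ShapeLoggerHook(PytorchModuleHook):
    """Print the output shape of a module after every forward pass."""

    @property
    def hook_type(self):
        return 'forward'

    def hook(self, module, _input, output):
        # Forward hooks are called as hook(module, input, output).
        print(f'{type(module).__name__} output shape: {tuple(output.shape)}')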


class WeightNormClipHook(PytorchModuleHook):
    """Apply weight norm clip regularization.

    The module's parameter will be clipped to a given maximum norm before
    each forward pass.

    Args:
        max_norm (float): The maximum norm of the parameter.
        module_param_names (str|list): The parameter name (or name list)
            to apply weight norm clip.
    """

    def __init__(self, max_norm=1.0, module_param_names='weight'):
        self.module_param_names = module_param_names if isinstance(
            module_param_names, list) else [module_param_names]
        self.max_norm = max_norm

    @property
    def hook_type(self):
        return 'forward_pre'
    def hook(self, module, _input):
        for name in self.module_param_names:
            assert name in module._parameters, \
                f'{name} is not a parameter of the module {type(module)}'
            param = module._parameters[name]

            with torch.no_grad():
                m = param.norm().item()
                if m > self.max_norm:
                    param.mul_(self.max_norm / (m + 1e-6))
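

# A hedged usage example (assumption: run as a script; the layer sizes and
# values below are illustrative). It inflates a linear layer's weight norm,
# registers the clip hook, and checks that the norm is clipped before the
# forward pass.
if __name__ == '__main__':
    layer = torch.nn.Linear(4, 4)
    with torch.no_grad():
        layer.weight.mul_(100.0)  # push the weight norm well above max_norm

    handle = WeightNormClipHook(max_norm=1.0).register(layer)
    layer(torch.randn(2, 4))  # the forward pass triggers the clip
    print(layer.weight.norm().item())  # <= 1.0 (up to the 1e-6 slack)
    handle.remove()  # detach the hook when it is no longer needed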