Source code for flowvision.data.auto_augment

"""
OneFlow implementation of AutoAugment, RandAugment and AugMix
Modified from https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py

Papers:
    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
"""

import random
import math
import re
from PIL import Image, ImageOps, ImageEnhance, ImageChops
import PIL
import numpy as np


_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])

_FILL = (128, 128, 128)

_LEVEL_DENOM = 10.0  # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments

_HPARAMS_DEFAULT = dict(translate_const=250, img_mean=_FILL,)

_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
    interpolation = kwargs.pop("resample", Image.BILINEAR)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    else:
        return interpolation


def _check_args_tf(kwargs):
    if "fillcolor" in kwargs and _PIL_VER < (5, 0):
        kwargs.pop("fillcolor")
    kwargs["resample"] = _interpolation(kwargs)
    if "fillcolor" in kwargs:
        kwargs["fillcolor"] = tuple(kwargs["fillcolor"])


def shear_x(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        # align the center pos with the original one
        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs["resample"])


def auto_contrast(img, **__):
    return ImageOps.autocontrast(img)


def invert(img, **__):
    return ImageOps.invert(img)


def equalize(img, **__):
    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    lut = []
    for i in range(256):
        if i < thresh:
            lut.append(min(255, i + add))
        else:
            lut.append(i)
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB" and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
    else:
        return img


def posterize(img, bits_to_keep, **__):
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value"""
    return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level, _hparams):
    # range [-30, 30]
    level = (level / _LEVEL_DENOM) * 30.0
    level = _randomly_negate(level)
    return (level,)


def _enhance_level_to_arg(level, _hparams):
    # range [0.1, 1.9]
    return ((level / _LEVEL_DENOM) * 1.8 + 0.1,)


def _enhance_increasing_level_to_arg(level, _hparams):
    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
    # range [0.1, 1.9] if level <= _LEVEL_DENOM
    level = (level / _LEVEL_DENOM) * 0.9
    level = max(0.1, 1.0 + _randomly_negate(level))  # keep it >= 0.1
    return (level,)


def _shear_level_to_arg(level, _hparams):
    # range [-0.3, 0.3]
    level = (level / _LEVEL_DENOM) * 0.3
    level = _randomly_negate(level)
    return (level,)


def _translate_abs_level_to_arg(level, hparams):
    translate_const = hparams["translate_const"]
    level = (level / _LEVEL_DENOM) * float(translate_const)
    level = _randomly_negate(level)
    return (level,)


def _translate_rel_level_to_arg(level, hparams):
    # default range [-0.45, 0.45]
    translate_pct = hparams.get("translate_pct", 0.45)
    level = (level / _LEVEL_DENOM) * translate_pct
    level = _randomly_negate(level)
    return (level,)


def _posterize_level_to_arg(level, _hparams):
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _LEVEL_DENOM) * 4),)


def _posterize_increasing_level_to_arg(level, hparams):
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image',
    # intensity/severity of augmentation increases with level
    return (4 - _posterize_level_to_arg(level, hparams)[0],)


def _posterize_original_level_to_arg(level, _hparams):
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _LEVEL_DENOM) * 4) + 4,)


def _solarize_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation decreases with level
    return (int((level / _LEVEL_DENOM) * 256),)


def _solarize_increasing_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation increases with level
    return (256 - _solarize_level_to_arg(level, _hparams)[0],)


def _solarize_add_level_to_arg(level, _hparams):
    # range [0, 110]
    return (int((level / _LEVEL_DENOM) * 110),)


LEVEL_TO_ARG = {
    "AutoContrast": None,
    "Equalize": None,
    "Invert": None,
    "Rotate": _rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    "Posterize": _posterize_level_to_arg,
    "PosterizeIncreasing": _posterize_increasing_level_to_arg,
    "PosterizeOriginal": _posterize_original_level_to_arg,
    "Solarize": _solarize_level_to_arg,
    "SolarizeIncreasing": _solarize_increasing_level_to_arg,
    "SolarizeAdd": _solarize_add_level_to_arg,
    "Color": _enhance_level_to_arg,
    "ColorIncreasing": _enhance_increasing_level_to_arg,
    "Contrast": _enhance_level_to_arg,
    "ContrastIncreasing": _enhance_increasing_level_to_arg,
    "Brightness": _enhance_level_to_arg,
    "BrightnessIncreasing": _enhance_increasing_level_to_arg,
    "Sharpness": _enhance_level_to_arg,
    "SharpnessIncreasing": _enhance_increasing_level_to_arg,
    "ShearX": _shear_level_to_arg,
    "ShearY": _shear_level_to_arg,
    "TranslateX": _translate_abs_level_to_arg,
    "TranslateY": _translate_abs_level_to_arg,
    "TranslateXRel": _translate_rel_level_to_arg,
    "TranslateYRel": _translate_rel_level_to_arg,
}


NAME_TO_OP = {
    "AutoContrast": auto_contrast,
    "Equalize": equalize,
    "Invert": invert,
    "Rotate": rotate,
    "Posterize": posterize,
    "PosterizeIncreasing": posterize,
    "PosterizeOriginal": posterize,
    "Solarize": solarize,
    "SolarizeIncreasing": solarize,
    "SolarizeAdd": solarize_add,
    "Color": color,
    "ColorIncreasing": color,
    "Contrast": contrast,
    "ContrastIncreasing": contrast,
    "Brightness": brightness,
    "BrightnessIncreasing": brightness,
    "Sharpness": sharpness,
    "SharpnessIncreasing": sharpness,
    "ShearX": shear_x,
    "ShearY": shear_y,
    "TranslateX": translate_x_abs,
    "TranslateY": translate_y_abs,
    "TranslateXRel": translate_x_rel,
    "TranslateYRel": translate_y_rel,
}
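
# For reference, the two tables above are used together: NAME_TO_OP maps an op name to its PIL-based
# function, while LEVEL_TO_ARG maps it to the magnitude-to-argument converter. Illustrative example:
#
#   fn = NAME_TO_OP["Rotate"]                           # rotate(img, degrees, **kwargs)
#   args = LEVEL_TO_ARG["Rotate"](9, _HPARAMS_DEFAULT)  # e.g. (27.0,) or (-27.0,) after random negation
#   rotated = fn(img, *args, fillcolor=_FILL, resample=_RANDOM_INTERPOLATION)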


class AugmentOp:
    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        """
        set magnitude to get different level of transforms, e.g., magnitude = 10 means there are 10 levels of transforms for each op
        """
        hparams = hparams or _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = dict(
            fillcolor=hparams["img_mean"] if "img_mean" in hparams else _FILL,
            resample=hparams["interpolation"]
            if "interpolation" in hparams
            else _RANDOM_INTERPOLATION,
        )

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        # If magnitude_std is inf, we sample magnitude from a uniform distribution
        self.magnitude_std = self.hparams.get("magnitude_std", 0)
        self.magnitude_max = self.hparams.get("magnitude_max", None)

    def __call__(self, img):
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std > 0:
            # magnitude randomization enabled
            if self.magnitude_std == float("inf"):
                magnitude = random.uniform(0, magnitude)
            elif self.magnitude_std > 0:
                magnitude = random.gauss(magnitude, self.magnitude_std)
        # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
        # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
        upper_bound = self.magnitude_max or _LEVEL_DENOM
        magnitude = max(0.0, min(magnitude, upper_bound))
        level_args = (
            self.level_fn(magnitude, self.hparams)
            if self.level_fn is not None
            else tuple()
        )
        return self.aug_fn(img, *level_args, **self.kwargs)

    def __repr__(self):
        fs = self.__class__.__name__ + f"(name={self.name}, p={self.prob}"
        fs += f", m={self.magnitude}, mstd={self.magnitude_std}"
        if self.magnitude_max is not None:
            fs += f", mmax={self.magnitude_max}"
        fs += ")"
        return fs
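
# Minimal illustrative example of a single op (assumes a PIL image `img`):
#
#   op = AugmentOp("Solarize", prob=0.8, magnitude=6)  # uses _HPARAMS_DEFAULT
#   out = op(img)  # applies ImageOps.solarize(img, 153) with probability 0.8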


def auto_augment_policy_v0(hparams):
    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
    policy = [
        [("Equalize", 0.8, 1), ("ShearY", 0.8, 4)],
        [("Color", 0.4, 9), ("Equalize", 0.6, 3)],
        [("Color", 0.4, 1), ("Rotate", 0.6, 8)],
        [("Solarize", 0.8, 3), ("Equalize", 0.4, 7)],
        [("Solarize", 0.4, 2), ("Solarize", 0.6, 2)],
        [("Color", 0.2, 0), ("Equalize", 0.8, 8)],
        [("Equalize", 0.4, 8), ("SolarizeAdd", 0.8, 3)],
        [("ShearX", 0.2, 9), ("Rotate", 0.6, 8)],
        [("Color", 0.6, 1), ("Equalize", 1.0, 2)],
        [("Invert", 0.4, 9), ("Rotate", 0.6, 0)],
        [("Equalize", 1.0, 9), ("ShearY", 0.6, 3)],
        [("Color", 0.4, 7), ("Equalize", 0.6, 0)],
        [("Posterize", 0.4, 6), ("AutoContrast", 0.4, 7)],
        [("Solarize", 0.6, 8), ("Color", 0.6, 9)],
        [("Solarize", 0.2, 4), ("Rotate", 0.8, 9)],
        [("Rotate", 1.0, 7), ("TranslateYRel", 0.8, 9)],
        [("ShearX", 0.0, 0), ("Solarize", 0.8, 4)],
        [("ShearY", 0.8, 0), ("Color", 0.6, 4)],
        [("Color", 1.0, 0), ("Rotate", 0.6, 2)],
        [("Equalize", 0.8, 4), ("Equalize", 0.0, 8)],
        [("Equalize", 1.0, 4), ("AutoContrast", 0.6, 2)],
        [("ShearY", 0.4, 7), ("SolarizeAdd", 0.6, 7)],
        [
            ("Posterize", 0.8, 2),
            ("Solarize", 0.6, 10),
        ],  # This results in a black image with TPU posterize
        [("Solarize", 0.6, 8), ("Equalize", 0.6, 1)],
        [("Color", 0.8, 6), ("Rotate", 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_v0r(hparams):
    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
    # in Google research implementation (number of bits discarded increases with magnitude)
    policy = [
        [("Equalize", 0.8, 1), ("ShearY", 0.8, 4)],
        [("Color", 0.4, 9), ("Equalize", 0.6, 3)],
        [("Color", 0.4, 1), ("Rotate", 0.6, 8)],
        [("Solarize", 0.8, 3), ("Equalize", 0.4, 7)],
        [("Solarize", 0.4, 2), ("Solarize", 0.6, 2)],
        [("Color", 0.2, 0), ("Equalize", 0.8, 8)],
        [("Equalize", 0.4, 8), ("SolarizeAdd", 0.8, 3)],
        [("ShearX", 0.2, 9), ("Rotate", 0.6, 8)],
        [("Color", 0.6, 1), ("Equalize", 1.0, 2)],
        [("Invert", 0.4, 9), ("Rotate", 0.6, 0)],
        [("Equalize", 1.0, 9), ("ShearY", 0.6, 3)],
        [("Color", 0.4, 7), ("Equalize", 0.6, 0)],
        [("PosterizeIncreasing", 0.4, 6), ("AutoContrast", 0.4, 7)],
        [("Solarize", 0.6, 8), ("Color", 0.6, 9)],
        [("Solarize", 0.2, 4), ("Rotate", 0.8, 9)],
        [("Rotate", 1.0, 7), ("TranslateYRel", 0.8, 9)],
        [("ShearX", 0.0, 0), ("Solarize", 0.8, 4)],
        [("ShearY", 0.8, 0), ("Color", 0.6, 4)],
        [("Color", 1.0, 0), ("Rotate", 0.6, 2)],
        [("Equalize", 0.8, 4), ("Equalize", 0.0, 8)],
        [("Equalize", 1.0, 4), ("AutoContrast", 0.6, 2)],
        [("ShearY", 0.4, 7), ("SolarizeAdd", 0.6, 7)],
        [("PosterizeIncreasing", 0.8, 2), ("Solarize", 0.6, 10)],
        [("Solarize", 0.6, 8), ("Equalize", 0.6, 1)],
        [("Color", 0.8, 6), ("Rotate", 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_original(hparams):
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    policy = [
        [("PosterizeOriginal", 0.4, 8), ("Rotate", 0.6, 9)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
        [("PosterizeOriginal", 0.6, 7), ("PosterizeOriginal", 0.6, 6)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Equalize", 0.4, 4), ("Rotate", 0.8, 8)],
        [("Solarize", 0.6, 3), ("Equalize", 0.6, 7)],
        [("PosterizeOriginal", 0.8, 5), ("Equalize", 1.0, 2)],
        [("Rotate", 0.2, 3), ("Solarize", 0.6, 8)],
        [("Equalize", 0.6, 8), ("PosterizeOriginal", 0.4, 6)],
        [("Rotate", 0.8, 8), ("Color", 0.4, 0)],
        [("Rotate", 0.4, 9), ("Equalize", 0.6, 2)],
        [("Equalize", 0.0, 7), ("Equalize", 0.8, 8)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Rotate", 0.8, 8), ("Color", 1.0, 2)],
        [("Color", 0.8, 8), ("Solarize", 0.8, 7)],
        [("Sharpness", 0.4, 7), ("Invert", 0.6, 8)],
        [("ShearX", 0.6, 5), ("Equalize", 1.0, 9)],
        [("Color", 0.4, 0), ("Equalize", 0.6, 3)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_originalr(hparams):
    # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
    policy = [
        [("PosterizeIncreasing", 0.4, 8), ("Rotate", 0.6, 9)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
        [("PosterizeIncreasing", 0.6, 7), ("PosterizeIncreasing", 0.6, 6)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Equalize", 0.4, 4), ("Rotate", 0.8, 8)],
        [("Solarize", 0.6, 3), ("Equalize", 0.6, 7)],
        [("PosterizeIncreasing", 0.8, 5), ("Equalize", 1.0, 2)],
        [("Rotate", 0.2, 3), ("Solarize", 0.6, 8)],
        [("Equalize", 0.6, 8), ("PosterizeIncreasing", 0.4, 6)],
        [("Rotate", 0.8, 8), ("Color", 0.4, 0)],
        [("Rotate", 0.4, 9), ("Equalize", 0.6, 2)],
        [("Equalize", 0.0, 7), ("Equalize", 0.8, 8)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Rotate", 0.8, 8), ("Color", 1.0, 2)],
        [("Color", 0.8, 8), ("Solarize", 0.8, 7)],
        [("Sharpness", 0.4, 7), ("Invert", 0.6, 8)],
        [("ShearX", 0.6, 5), ("Equalize", 1.0, 9)],
        [("Color", 0.4, 0), ("Equalize", 0.6, 3)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy(name="v0", hparams=None):
    hparams = hparams or _HPARAMS_DEFAULT
    if name == "original":
        return auto_augment_policy_original(hparams)
    elif name == "originalr":
        return auto_augment_policy_originalr(hparams)
    elif name == "v0":
        return auto_augment_policy_v0(hparams)
    elif name == "v0r":
        return auto_augment_policy_v0r(hparams)
    else:
        assert False, "Unknown AA policy (%s)" % name



class AutoAugment:
    """
    Auto Augmentation

    From paper `AutoAugment: Learning Augmentation Policies from Data <https://arxiv.org/abs/1805.09501>`_
    """

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        sub_policy = random.choice(self.policy)
        for op in sub_policy:
            img = op(img)
        return img

    def __repr__(self):
        fs = self.__class__.__name__ + f"(policy="
        for p in self.policy:
            fs += "\n\t["
            fs += ", ".join([str(op) for op in p])
            fs += "]"
        fs += ")"
        return fs
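
# Illustrative sketch (not part of the module logic): an AutoAugment transform can be built directly
# from a named policy, which is roughly what auto_augment_transform below does for a config string
# such as 'original':
#
#   policy = auto_augment_policy("original")
#   aa = AutoAugment(policy)
#   img_aug = aa(img)  # applies the ops of one randomly chosen sub-policy, each op with its own probability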


def auto_augment_transform(config_str, hparams):
    """
    Create an AutoAugment transform

    :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated
        by dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original',
        'originalr'). The remaining sections, not order specific, determine

        * 'mstd' - float std deviation of magnitude noise applied

        Example: 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5

    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme

    :return: A OneFlow compatible Transform
    """
    config = config_str.split("-")
    policy_name = config[0]
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param injected via hparams for now
            hparams.setdefault("magnitude_std", float(val))
        else:
            assert False, "Unknown AutoAugment config section"
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)
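
# Example config string (illustrative call; hparams values are just placeholders):
#
#   transform = auto_augment_transform(
#       "v0r-mstd0.5",  # 'v0r' policy with magnitude_std 0.5 injected via hparams
#       hparams=dict(translate_const=250, img_mean=(128, 128, 128)),
#   )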


_RAND_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "Posterize",
    "Solarize",
    "SolarizeAdd",
    "Color",
    "Contrast",
    "Brightness",
    "Sharpness",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
    # 'Cutout'  # NOTE I've implemented this as random erasing separately
]


_RAND_INCREASING_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "PosterizeIncreasing",
    "SolarizeIncreasing",
    "SolarizeAdd",
    "ColorIncreasing",
    "ContrastIncreasing",
    "BrightnessIncreasing",
    "SharpnessIncreasing",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
]


# These experimental weights are based loosely on the relative improvements mentioned in the paper.
# They may not result in increased performance, but could likely be tuned to do so.
_RAND_CHOICE_WEIGHTS_0 = {
    "Rotate": 0.3,
    "ShearX": 0.2,
    "ShearY": 0.2,
    "TranslateXRel": 0.1,
    "TranslateYRel": 0.1,
    "Color": 0.025,
    "Sharpness": 0.025,
    "AutoContrast": 0.025,
    "Solarize": 0.005,
    "SolarizeAdd": 0.005,
    "Contrast": 0.005,
    "Brightness": 0.005,
    "Equalize": 0.005,
    "Posterize": 0,
    "Invert": 0,
}


def _select_rand_weights(weight_idx=0, transforms=None):
    transforms = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    rand_weights = _RAND_CHOICE_WEIGHTS_0
    probs = [rand_weights[k] for k in transforms]
    probs /= np.sum(probs)
    return probs


def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [
        AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
        for name in transforms
    ]


class RandAugment:
    """
    Random Augmentation

    From paper `RandAugment: Practical automated data augmentation <https://arxiv.org/abs/1909.13719>`_
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops,
            self.num_layers,
            replace=self.choice_weights is None,
            p=self.choice_weights,
        )
        for op in ops:
            img = op(img)
        return img

    def __repr__(self):
        fs = self.__class__.__name__ + f"(n={self.num_layers}, ops="
        for op in self.ops:
            fs += f"\n\t{op}"
        fs += ")"
        return fs


def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated
        by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
        The remaining sections, not order specific, determine

        * 'm' - integer magnitude of rand augment
        * 'n' - integer num layers (number of transform ops selected per image)
        * 'w' - integer probability weight index (index of a set of weights to influence choice of op)
        * 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
        * 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
        * 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)

        Example: 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
        'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme

    :return: A OneFlow compatible Transform
    """
    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split("-")
    assert config[0] == "rand"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param / randomization of magnitude values
            mstd = float(val)
            if mstd > 100:
                # use uniform sampling in 0 to magnitude if mstd is > 100
                mstd = float("inf")
            hparams.setdefault("magnitude_std", mstd)
        elif key == "mmax":
            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
            hparams.setdefault("magnitude_max", int(val))
        elif key == "inc":
            if bool(val):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == "m":
            magnitude = int(val)
        elif key == "n":
            num_layers = int(val)
        elif key == "w":
            weight_idx = int(val)
        else:
            assert False, "Unknown RandAugment config section"
    ra_ops = rand_augment_ops(
        magnitude=magnitude, hparams=hparams, transforms=transforms
    )
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
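
# Example config string (illustrative call; hparams values are just placeholders):
#
#   transform = rand_augment_transform(
#       "rand-m7-n2-mstd1-inc1",  # magnitude 7, 2 layers, magnitude_std 1.0, 'increasing severity' op set
#       hparams=dict(translate_const=100, img_mean=(124, 116, 104)),
#   )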


_AUGMIX_TRANSFORMS = [
    "AutoContrast",
    "ColorIncreasing",  # not in paper
    "ContrastIncreasing",  # not in paper
    "BrightnessIncreasing",  # not in paper
    "SharpnessIncreasing",  # not in paper
    "Equalize",
    "Rotate",
    "PosterizeIncreasing",
    "SolarizeIncreasing",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
]


def augmix_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _AUGMIX_TRANSFORMS
    return [
        AugmentOp(name, prob=1.0, magnitude=magnitude, hparams=hparams)
        for name in transforms
    ]


class AugMixAugment:
    """
    AugMix Transform

    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper `AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty <https://arxiv.org/abs/1912.02781>`_
    """

    def __init__(self, ops, alpha=1.0, width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    def _calc_blended_weights(self, ws, m):
        ws = ws * m
        cump = 1.0
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= 1 - alpha
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        # This is my first crack at implementing a slightly faster mixed augmentation. Instead
        # of accumulating the mix for each chain in a Numpy array and then blending with original,
        # it recomputes the blending coefficients and applies one PIL image blend per chain.
        # TODO the results appear in the right ballpark but they differ by more than rounding.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img_orig  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        # This is a literal adaptation of the paper/official implementation without normalizations and
        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
        # typical augmentation transforms, could use a GPU / Kornia implementation.
        img_shape = img.size[0], img.size[1], len(img.getbands())
        mixed = np.zeros(img_shape, dtype=np.float32)
        for mw in mixing_weights:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255.0, out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed

    def __repr__(self):
        fs = (
            self.__class__.__name__
            + f"(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops="
        )
        for op in self.ops:
            fs += f"\n\t{op}"
        fs += ")"
        return fs


def augment_and_mix_transform(config_str, hparams):
    """
    Create an AugMix OneFlow transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated
        by dashes ('-'). The first section defines the AugMix variant (currently only 'augmix').
        The remaining sections, not order specific, determine

        * 'm' - integer magnitude (severity) of augmentation mix (default: 3)
        * 'w' - integer width of augmentation chain (default: 3)
        * 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
        * 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
        * 'mstd' - float std deviation of magnitude noise applied (default: 0)

        Example: 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

    :param hparams: Other hparams (kwargs) for the Augmentation transforms

    :return: A OneFlow compatible Transform
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.0
    blended = False
    config = config_str.split("-")
    assert config[0] == "augmix"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param injected via hparams for now
            hparams.setdefault("magnitude_std", float(val))
        elif key == "m":
            magnitude = int(val)
        elif key == "w":
            width = int(val)
        elif key == "d":
            depth = int(val)
        elif key == "a":
            alpha = float(val)
        elif key == "b":
            blended = bool(val)
        else:
            assert False, "Unknown AugMix config section"
    hparams.setdefault(
        "magnitude_std", float("inf")
    )  # default to uniform sampling (if not set via mstd arg)
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
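
# Example config string (illustrative call; hparams values are just placeholders):
#
#   transform = augment_and_mix_transform(
#       "augmix-m5-w4-d2",  # severity 5, chain width 4, chain depth 2; magnitude_std defaults to inf
#       hparams=dict(translate_const=250, img_mean=(128, 128, 128)),
#   )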