# Source code for flowvision.data.random_erasing

"""OneFlow implementation of Random Erasing(Cutout)
Modified from https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
"""

import random
import math
import oneflow as flow


def _get_pixels(per_pixel, rand_color, patch_size, dtype=flow.float32, device="cuda"):
    """Create the tensor used to fill an erased patch.

    Args:
        per_pixel: If truthy, return a tensor of shape ``patch_size`` filled
            by ``normal_()``, i.e. an independent random value per element.
        rand_color: Consulted only when ``per_pixel`` is falsy. If truthy,
            return a tensor of shape ``(patch_size[0], 1, 1)`` filled by
            ``normal_()``.
        patch_size: Shape of the erased region; the caller in this file
            passes ``(chan, h, w)``.
        dtype: dtype of the returned tensor.
        device: Device on which the tensor is allocated.

    Returns:
        The fill tensor. When both flags are falsy this is a zero tensor of
        shape ``(patch_size[0], 1, 1)``.
    """
    if not (per_pixel or rand_color):
        # Constant mode: a single zero per leading-dimension entry.
        return flow.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)

    # per_pixel takes precedence over rand_color, exactly as an if/elif chain.
    fill_shape = patch_size if per_pixel else (patch_size[0], 1, 1)
    noise = flow.empty(fill_shape, dtype=dtype, device=device)
    return noise.normal_()


class RandomErasing:
    """Randomly selects a rectangle region in an image and erases its pixels.

    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf

    This variant of RandomErasing is intended to be applied to either a batch
    or single image tensor after it has been normalized by dataset mean and std.

    Args:
        probability: Probability that the Random Erasing operation will be
            performed.
        min_area: Minimum percentage of erased area wrt input image area.
        max_area: Maximum percentage of erased area wrt input image area.
        min_aspect: Minimum aspect ratio of erased area.
        max_aspect: Maximum aspect ratio of erased area. When falsy
            (e.g. ``None``) it defaults to ``1 / min_aspect``.
        mode: Pixel color mode, one of 'const', 'rand', or 'pixel'
            (case-insensitive).

            * 'const' - erase block is constant color of 0 for all channels
            * 'rand'  - erase block is same per-channel random (normal) color
            * 'pixel' - erase block is per-pixel random (normal) color
        min_count: Minimum number of erasing blocks per image.
        max_count: Maximum number of erasing blocks per image; area per box
            is scaled by count. The per-image count is randomly chosen
            between ``min_count`` and this value. When falsy it defaults to
            ``min_count``.
        num_splits: When greater than 1, the first
            ``batch_size // num_splits`` samples of a batch are skipped
            (kept as the clean portion of the batch).
        device: Device on which the fill tensors for erased patches are
            created.
    """

    def __init__(
        self,
        probability=0.5,
        min_area=0.02,
        max_area=1 / 3,
        min_aspect=0.3,
        max_aspect=None,
        mode="const",
        min_count=1,
        max_count=None,
        num_splits=0,
        device="cuda",
    ):
        self.probability = probability
        self.min_area = min_area
        self.max_area = max_area
        max_aspect = max_aspect or 1 / min_aspect
        # Stored in log space so aspect ratios can be sampled log-uniformly.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.num_splits = num_splits
        self.mode = mode.lower()
        self.rand_color = False
        self.per_pixel = False
        if self.mode == "rand":
            self.rand_color = True  # per block random normal
        elif self.mode == "pixel":
            self.per_pixel = True  # per pixel random normal
        else:
            # An empty mode string is accepted and behaves like "const".
            assert not self.mode or self.mode == "const", (
                f"unknown mode: {self.mode!r}"
            )
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        """Erase up to ``count`` random rectangles of ``img`` in place.

        Args:
            img: Image indexed as ``img[:, rows, cols]``; modified in place
                by slice assignment.
            chan: Size of the leading dimension of ``img``.
            img_h: Image height.
            img_w: Image width.
            dtype: dtype used for the fill tensor.

        Returns:
            None. The call is a no-op when the random draw exceeds
            ``self.probability``.
        """
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = (
            self.min_count
            if self.min_count == self.max_count
            else random.randint(self.min_count, self.max_count)
        )
        for _ in range(count):
            # Up to 10 tries to sample a box that fits strictly inside the
            # image; if none fits, this block is silently skipped.
            for _attempt in range(10):
                target_area = (
                    random.uniform(self.min_area, self.max_area) * area / count
                )
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    img[:, top : top + h, left : left + w] = _get_pixels(
                        self.per_pixel,
                        self.rand_color,
                        (chan, h, w),
                        dtype=dtype,
                        device=self.device,
                    )
                    break

    def __call__(self, input):
        """Apply random erasing to a single image or to a batch.

        Args:
            input: Tensor whose ``size()`` has either 3 entries (treated as
                one ``(chan, height, width)`` image) or 4 entries (treated
                as ``(batch, chan, height, width)``).

        Returns:
            The same ``input`` object that was passed in.
        """
        if len(input.size()) == 3:
            self._erase(input, *input.size(), input.dtype)
        else:
            batch_size, chan, img_h, img_w = input.size()
            # skip first slice of batch if num_splits is set (for clean portion of samples)
            batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
            for i in range(batch_start, batch_size):
                # NOTE(review): relies on input[i] sharing storage with the
                # batch so the in-place erase is visible in `input` - confirm
                # for the tensor library in use.
                self._erase(input[i], chan, img_h, img_w, input.dtype)
        return input

    def __repr__(self):
        """Return a short summary with probability, mode and count range."""
        # NOTE simplified state for repr
        fs = self.__class__.__name__ + f"(p={self.probability}, mode={self.mode}"
        fs += f", count=({self.min_count}, {self.max_count}))"
        return fs