ColorJitter in PyTorch

Buy Me a Coffee

ColorJitter() can randomly change the brightness, contrast, saturation and hue of zero or more images, as shown below:

*Memos:

  • The 1st argument for initialization is brightness(Optional-Default:0-Type:float or tuple/list(float)): *Memos:
    • It’s the range of the brightness [min, max].
    • It must be 0 <= x.
    • A single value is converted to [max(0, 1-brightness), 1+brightness].
    • A tuple or list must be 1D with 2 elements. *The 1st element must be less than or equal to the 2nd element.
  • The 2nd argument for initialization is contrast(Optional-Default:0-Type:float or tuple/list(float)): *Memos:
    • It’s the range of the contrast [min, max].
    • It must be 0 <= x.
    • A single value is converted to [max(0, 1-contrast), 1+contrast].
    • A tuple or list must be 1D with 2 elements. *The 1st element must be less than or equal to the 2nd element.
  • The 3rd argument for initialization is saturation(Optional-Default:0-Type:float or tuple/list(float)): *Memos:
    • It’s the range of the saturation [min, max].
    • It must be 0 <= x.
    • A single value is converted to [max(0, 1-saturation), 1+saturation].
    • A tuple or list must be 1D with 2 elements. *The 1st element must be less than or equal to the 2nd element.
  • The 4th argument for initialization is hue(Optional-Default:0-Type:float or tuple/list(float)): *Memos:
    • It’s the range of the hue [min, max].
    • A single value must be 0 <= x <= 0.5; for a tuple or list, each element must be -0.5 <= x <= 0.5.
    • A single value is converted to [-hue, +hue].
    • A tuple or list must be 1D with 2 elements. *The 1st element must be less than or equal to the 2nd element.
  • The 1st argument is img(Required-Type:PIL Image or tensor/tuple/list(int or float)): *Memos:
    • It must be 2D or 3D. For 3D, the deepest D must have one element.
    • Don’t use img=.
  • v2 is recommended, according to the torchvision guide "V1 or V2? Which one should I use?".
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import ColorJitter

# Three equivalent ways to build a ColorJitter that leaves images unchanged:
# no arguments, explicit zeros, or explicit "identity" ranges (factor 1.0 for
# brightness/contrast/saturation, shift 0.0 for hue). Each assignment replaces
# the previous one; only the last is inspected below.
colorjitter = ColorJitter()
colorjitter = ColorJitter(brightness=0,
                          contrast=0,
                          saturation=0,
                          hue=0)
# (1.0, 1.0) / (0.0, 0.0) ranges collapse to the identity, so ColorJitter stores
# them as None -- which is what the printed results below show.
colorjitter = ColorJitter(brightness=(1.0, 1.0),
                          contrast=(1.0, 1.0),
                          saturation=(1.0, 1.0),
                          hue=(0.0, 0.0))
print(colorjitter)
# ColorJitter()
print(colorjitter.brightness)
# None
print(colorjitter.contrast)
# None
print(colorjitter.saturation)
# None
print(colorjitter.hue)
# None

# All datasets below read Oxford-IIIT Pet from "data"; OxfordIIITPet only
# downloads when download=True is passed, so the files must already be there.
origin_data = OxfordIIITPet(
    root="data",
    transform=None
    # transform=ColorJitter()
    # transform=ColorJitter(brightness=0,
    #                       contrast=0,
    #                       saturation=0,
    #                       hue=0)
    # transform=ColorJitter(brightness=(1.0, 1.0),
    #                       contrast=(1.0, 1.0),
    #                       saturation=(1.0, 1.0),
    #                       hue=(0.0, 0.0))
)

p2bright_data = OxfordIIITPet(  # `p` is plus.
    root="data",
    transform=ColorJitter(brightness=2.0)
    # transform=ColorJitter(brightness=(0.0, 3.0))
)

p2p2bright_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(brightness=(2.0, 2.0))
)

p05p05bright_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(brightness=(0.5, 0.5))
)

p2contra_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(contrast=2.0)
    # transform=ColorJitter(contrast=(0.0, 3.0))
)

p2p2contra_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(contrast=(2.0, 2.0))
)

p05p05contra_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(contrast=(0.5, 0.5))
)

p2satura_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(saturation=2.0)
    # transform=ColorJitter(saturation=(0.0, 3.0))
)

p2p2satura_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(saturation=(2.0, 2.0))
)

p05p05satura_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(saturation=(0.5, 0.5))
)

p05hue_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(hue=0.5)
    # transform=ColorJitter(hue=(-0.5, 0.5))
)

p025p025hue_data = OxfordIIITPet(
    root="data",
    transform=ColorJitter(hue=(0.25, 0.25))
)

m025m025hue_data = OxfordIIITPet(  # `m` is minus.
    root="data",
    transform=ColorJitter(hue=(-0.25, -0.25))
)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Draw the first five images of ``data`` in a single row.

    Args:
        data: Iterable of ``(image, target)`` pairs, e.g. an OxfordIIITPet
            dataset; the target is ignored.
        main_title: Optional figure title placed above the row.
    """
    columns = 5
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    samples = iter(data)
    for column in range(1, columns + 1):
        # Pull at most `columns` samples so no extra image is loaded/transformed.
        try:
            image, _target = next(samples)
        except StopIteration:
            break
        plt.subplot(1, columns, column)
        plt.imshow(X=image)
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Show each jitter experiment next to the untouched images: every group starts
# with origin_data, followed by the three variants for one property
# (brightness, contrast, saturation, then hue).
for jittered_group in (
    (("p2bright_data", p2bright_data),
     ("p2p2bright_data", p2p2bright_data),
     ("p05p05bright_data", p05p05bright_data)),
    (("p2contra_data", p2contra_data),
     ("p2p2contra_data", p2p2contra_data),
     ("p05p05contra_data", p05p05contra_data)),
    (("p2satura_data", p2satura_data),
     ("p2p2satura_data", p2p2satura_data),
     ("p05p05satura_data", p05p05satura_data)),
    (("p05hue_data", p05hue_data),
     ("p025p025hue_data", p025p025hue_data),
     ("m025m025hue_data", m025m025hue_data)),
):
    show_images(data=origin_data, main_title="origin_data")
    for title, dataset in jittered_group:
        show_images(data=dataset, main_title=title)

Enter fullscreen mode Exit fullscreen mode

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import ColorJitter, RandomRotation

# No dataset-level transform here: show_images() builds and applies a
# ColorJitter itself, so ColorJitter must be imported in this snippet.
# RandomRotation is not used by this snippet.
my_data = OxfordIIITPet(
    root="data",
    transform=None
)

import matplotlib.pyplot as plt

def show_images(data, main_title=None, b=0.0, c=0.0, s=0.0, h=0.0):
    """Draw the first five images of ``data`` in one row, color-jittered.

    Args:
        data: Iterable of ``(image, target)`` pairs (e.g. ``my_data``); the
            target is ignored.
        main_title: Optional figure title placed above the row.
        b, c, s, h: Passed to ``ColorJitter`` as ``brightness``, ``contrast``,
            ``saturation`` and ``hue``. The 0.0 defaults leave images
            unchanged.
    """
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # Build the transform once; ColorJitter samples new random factors on
    # every call, so each image is still jittered independently.
    cj = ColorJitter(brightness=b, contrast=c, saturation=s, hue=h)
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=cj(im))  # Here: jitter is applied at plot time.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Compare the untouched images with jittered ones, applying the jitter at plot
# time: each sweep names the show_images keyword it sets (b/c/s/h) and the
# (title, value) pairs to try.
for keyword, variants in (
    ("b", (("p2bright_data", 2.0),
           ("p2p2bright_data", (2.0, 2.0)),
           ("p05p05bright_data", (0.5, 0.5)))),
    ("c", (("p2contra_data", 2.0),
           ("p2p2contra_data", (2.0, 2.0)),
           ("p05p05contra_data", (0.5, 0.5)))),
    ("s", (("p2satura_data", 2.0),
           ("p2p2satura_data", (2.0, 2.0)),
           ("p05p05satura_data", (0.5, 0.5)))),
    ("h", (("p05hue_data", 0.5),
           ("p025p025hue_data", (0.25, 0.25)),
           ("m025m025hue_data", (-0.25, -0.25)))),
):
    show_images(data=my_data, main_title="origin_data")
    for title, value in variants:
        show_images(data=my_data, main_title=title, **{keyword: value})

Enter fullscreen mode Exit fullscreen mode

原文链接:ColorJitter in PyTorch

© 版权声明
THE END
喜欢就支持一下吧
点赞9 分享
评论 抢沙发

请登录后发表评论

    暂无评论内容