
Compatibility with timm augmentation? #195

@ardasahiner


Hi,

I was attempting to use FFCV with timm, relying on the fact that torch.nn.Module instances should be accepted in the pipelines argument of FFCV's Loader. However, I am getting some strange errors and would like some clarification on what is going wrong.

Please see my simple reproducible implementation below. I use CIFAR-100 images and timm's create_transform function. Since not every transform it returns is an instance of nn.Module, I wrap those that are not in a simple module, CustomClass. However, I get the error shown below.
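The wrap-or-pass-through approach described above can be sketched without PyTorch. In this sketch, Module and Sequential are minimal hypothetical stand-ins for torch.nn.Module and torch.nn.Sequential, and CallableWrapper plays the role of CustomClass. It only illustrates the delegation structure, not FFCV or timm behavior.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module: calling it runs forward()."""
    def __call__(self, x):
        return self.forward(x)


class CallableWrapper(Module):
    """Plays the role of CustomClass: gives a plain callable a forward()."""
    def __init__(self, fn):
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


class Sequential(Module):
    """Hypothetical stand-in for torch.nn.Sequential: applies stages in order."""
    def __init__(self, *stages):
        self.stages = stages

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


class Double(Module):
    """An object that is already a Module, so it should pass through as-is."""
    def forward(self, x):
        return x * 2


def as_modules(transforms):
    """Keep Module instances unchanged; wrap every other callable."""
    return [t if isinstance(t, Module) else CallableWrapper(t) for t in transforms]


double = Double()
stages = as_modules([lambda x: x + 1, double])  # only the lambda gets wrapped
pipeline = Sequential(*stages)
print([type(s).__name__ for s in stages])  # ['CallableWrapper', 'Double']
print(pipeline(3))                         # (3 + 1) * 2 = 8
```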

Do you have any suggestions as to what is causing this issue, or any ideas for a simpler integration with timm? Any help is appreciated.

Error:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'self': Cannot determine Numba type of <class 'ffcv.transforms.module.ModuleWrapper'>

File "../anaconda3/envs/ffcv/lib/python3.9/site-packages/ffcv/transforms/module.py", line 25:
        def apply_module(inp, _):
            res = self.module(inp)
            ^

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fef91639670>))
During: typing of call at  (2)

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fef91639670>))
During: typing of call at  (2)
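The traceback points at a closure: apply_module is defined inside generate_code and refers to self. In nopython mode, Numba has to infer a type for every value the function uses, including a captured self, and the message says it cannot determine a Numba type for ModuleWrapper. The sketch below uses a simplified, hypothetical Wrapper class that mirrors the lines quoted in the traceback (it is not FFCV's actual code). It shows that plain Python records self as a free variable of the generated function and runs it without complaint; only the Numba compilation step has trouble with it.

```python
class Wrapper:
    """Hypothetical, simplified mirror of the pattern quoted in the traceback."""
    def __init__(self, module):
        self.module = module

    def generate_code(self):
        # apply_module refers to `self`, so Python stores `self` as a
        # closure (free) variable of the returned function.
        def apply_module(inp, _):
            return self.module(inp)
        return apply_module


fn = Wrapper(lambda x: x * 2).generate_code()
print(fn.__code__.co_freevars)  # ('self',)
print(fn(21, None))             # 42
```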

Implementation:

import torch
import numpy as np
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, NormalizeImage, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze

class CustomClass(torch.nn.Module):
    def __init__(self, transform):
        super().__init__()
        self.transform = transform

    # Instance method (not a @staticmethod), because it needs self.transform
    def get_params(self, img, scale, ratio):
        return self.transform.get_params(img, scale, ratio)

    def forward(self, img):
        return self.transform(img)

    def __repr__(self):
        return self.transform.__repr__()

is_train = True
imnet_mean, imnet_std = np.array(IMAGENET_DEFAULT_MEAN)*256, np.array(IMAGENET_DEFAULT_STD)*256

paths = {
    'train': 'cifar100_train.beton',
    'test': 'cifar100_test.beton'
}

to_module = create_transform(
                input_size=224,
                is_training=True,
                color_jitter=0.4,
                auto_augment='rand-m9-mstd0.5-inc1',
                interpolation='bicubic',
                re_prob=0.25,
                re_mode='pixel',
                re_count=False,
                mean = imnet_mean,
                std = imnet_std,
            )

module_list = []
for t in to_module.transforms:
    if isinstance(t, torch.nn.Module):
        module_list.append(t)
    else:
        t_new = CustomClass(t)
        module_list.append(t_new)

transform = torch.nn.Sequential(*module_list)

label_pipeline = [IntDecoder(), ToTensor(), Squeeze()]
image_pipeline = [SimpleRGBImageDecoder(), transform]

ordering = OrderOption.QUASI_RANDOM if is_train else OrderOption.SEQUENTIAL
dataset = Loader(paths['train'] if is_train else paths['test'], batch_size=10, num_workers=2,
                 order=ordering, drop_last=is_train, os_cache=True, distributed=False,
                 pipelines={'image': image_pipeline, 'label': label_pipeline})

for i, (image, label) in enumerate(dataset):
    if i == 1:
        break
    print('loaded one image')
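For reference, the imnet_mean, imnet_std line in the code above scales timm's ImageNet statistics by 256. The sketch below copies the constants from timm.data.constants and shows the resulting per-channel values next to a factor of 255, which matches the 0 to 255 range of 8-bit pixels. Whether the scaled or the unscaled statistics are appropriate depends on whether the normalization step sees floats in [0, 1] or raw 0 to 255 pixel values; the scale helper is hypothetical and used only for display.

```python
# Per-channel ImageNet statistics, copied from timm.data.constants.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def scale(stats, factor):
    """Multiply each per-channel statistic by `factor`, rounded for display."""
    return [round(v * factor, 3) for v in stats]


# What the snippet computes (factor 256).
print(scale(IMAGENET_DEFAULT_MEAN, 256))  # [124.16, 116.736, 103.936]
print(scale(IMAGENET_DEFAULT_STD, 256))   # [58.624, 57.344, 57.6]
# The same mean mapped onto the 0..255 range of 8-bit pixels.
print(scale(IMAGENET_DEFAULT_MEAN, 255))  # [123.675, 116.28, 103.53]
```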
