From 1ca6d67564c508c13db69573ec7075259293d484 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Tue, 9 Aug 2022 19:49:02 +0000 Subject: [PATCH 01/16] updated dataset.py to download unlabelled dataset --- ssl_study/data/dataset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ssl_study/data/dataset.py b/ssl_study/data/dataset.py index f47f036..6e44b8a 100644 --- a/ssl_study/data/dataset.py +++ b/ssl_study/data/dataset.py @@ -34,6 +34,10 @@ def download_dataset(dataset_name: str, data_df = pd.read_csv(save_at+'train.csv') elif dataset_name == 'val' and os.path.exists(save_at+'valid.csv'): data_df = pd.read_csv(save_at+'valid.csv') + elif dataset_name == 'in-class' and os.path.exists(save_at+'in-class.csv'): + data_df = pd.read_csv(save_at+'in-class.csv') + elif dataset_name == 'out-class' and os.path.exists(save_at+'out-class.csv'): + data_df = pd.read_csv(save_at+'out-class.csv') else: data_df = None print('Downloading dataset...') @@ -82,6 +86,12 @@ def download_dataset(dataset_name: str, if dataset_name == 'val' and not os.path.exists(save_at+'valid.csv'): data_df.to_csv(save_at+'valid.csv', index=False) + if dataset_name == 'in-class' and not os.path.exists(save_at+'in-class.csv'): + data_df.to_csv(save_at+'in-class.csv', index=False) + + if dataset_name == 'out-class' and not os.path.exists(save_at+'out-class.csv'): + data_df.to_csv(save_at+'out-class.csv', index=False) + return data_df From 7b796a033e676a1d56df6cc48cfee922978fe8ac Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Tue, 9 Aug 2022 20:22:33 +0000 Subject: [PATCH 02/16] simclrv1 structure set --- ssl_study/simclrv1/data_aug/__init__.py | 0 ssl_study/simclrv1/model/__init__.py | 0 ssl_study/simclrv1/pipeline/__init__.py | 0 ssl_study/simclrv1/simclrv1_train.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 ssl_study/simclrv1/data_aug/__init__.py create mode 100644 ssl_study/simclrv1/model/__init__.py create mode 100644 ssl_study/simclrv1/pipeline/__init__.py create mode 100644 ssl_study/simclrv1/simclrv1_train.py diff --git a/ssl_study/simclrv1/data_aug/__init__.py b/ssl_study/simclrv1/data_aug/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ssl_study/simclrv1/model/__init__.py b/ssl_study/simclrv1/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ssl_study/simclrv1/pipeline/__init__.py b/ssl_study/simclrv1/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ssl_study/simclrv1/simclrv1_train.py b/ssl_study/simclrv1/simclrv1_train.py new file mode 100644 index 0000000..e69de29 From ba6cbde7821d1183899be3aa9e0ed5a0729b56d9 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Tue, 9 Aug 2022 20:48:12 +0000 Subject: [PATCH 03/16] updated directory structure --- ssl_study/simclrv1/{ => pretext_task}/data_aug/__init__.py | 0 ssl_study/simclrv1/{ => pretext_task}/model/__init__.py | 0 ssl_study/simclrv1/{ => pretext_task}/pipeline/__init__.py | 0 ssl_study/simclrv1/{ => pretext_task}/simclrv1_train.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename ssl_study/simclrv1/{ => pretext_task}/data_aug/__init__.py (100%) rename ssl_study/simclrv1/{ => pretext_task}/model/__init__.py (100%) rename ssl_study/simclrv1/{ => pretext_task}/pipeline/__init__.py (100%) rename ssl_study/simclrv1/{ => pretext_task}/simclrv1_train.py (100%) diff --git a/ssl_study/simclrv1/data_aug/__init__.py b/ssl_study/simclrv1/pretext_task/data_aug/__init__.py similarity index 100% rename from ssl_study/simclrv1/data_aug/__init__.py rename to ssl_study/simclrv1/pretext_task/data_aug/__init__.py diff --git a/ssl_study/simclrv1/model/__init__.py b/ssl_study/simclrv1/pretext_task/model/__init__.py similarity index 100% rename from ssl_study/simclrv1/model/__init__.py rename to ssl_study/simclrv1/pretext_task/model/__init__.py diff --git a/ssl_study/simclrv1/pipeline/__init__.py b/ssl_study/simclrv1/pretext_task/pipeline/__init__.py similarity index 100% rename from ssl_study/simclrv1/pipeline/__init__.py rename to ssl_study/simclrv1/pretext_task/pipeline/__init__.py diff --git a/ssl_study/simclrv1/simclrv1_train.py b/ssl_study/simclrv1/pretext_task/simclrv1_train.py similarity index 100% rename from ssl_study/simclrv1/simclrv1_train.py rename to ssl_study/simclrv1/pretext_task/simclrv1_train.py From e047d0b756de2e5ee3e17dc67ff5fdadbf9228a5 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Tue, 9 Aug 2022 20:49:33 +0000 Subject: [PATCH 04/16] updated directory structure --- ssl_study/simclrv1/downstream_task/simclrv1_train.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 ssl_study/simclrv1/downstream_task/simclrv1_train.py diff --git a/ssl_study/simclrv1/downstream_task/simclrv1_train.py b/ssl_study/simclrv1/downstream_task/simclrv1_train.py new file mode 100644 index 0000000..e69de29 From 04dc1eebf1c5a347bcb8a32ccb688aa40abd6be7 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Tue, 16 Aug 2022 22:42:55 +0000 Subject: [PATCH 05/16] added data augmentation pipeline --- .../simclrv1/pretext_task/configs/__init__.py | 5 ++ .../simclrv1/pretext_task/configs/config.py | 35 +++++++++++++ .../{data_aug => data}/__init__.py | 0 .../simclrv1/pretext_task/data/data_aug.py | 52 +++++++++++++++++++ 4 files changed, 92 insertions(+) create mode 100644 ssl_study/simclrv1/pretext_task/configs/__init__.py create mode 100644 ssl_study/simclrv1/pretext_task/configs/config.py rename ssl_study/simclrv1/pretext_task/{data_aug => data}/__init__.py (100%) create mode 100644 ssl_study/simclrv1/pretext_task/data/data_aug.py diff --git a/ssl_study/simclrv1/pretext_task/configs/__init__.py b/ssl_study/simclrv1/pretext_task/configs/__init__.py new file mode 100644 index 0000000..27a3232 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/configs/__init__.py @@ -0,0 +1,5 @@ +from .config import get_config + +__all__ = [ + 'get_config' +] \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/configs/config.py b/ssl_study/simclrv1/pretext_task/configs/config.py new file mode 100644 index 0000000..ff2bff3 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/configs/config.py @@ -0,0 +1,35 @@ +import os +import ml_collections + + +def get_wandb_configs() -> ml_collections.ConfigDict: + configs = ml_collections.ConfigDict() + configs.project = "ssl-study" + configs.entity = "wandb_fc" + + return configs + +def get_augment_configs() -> ml_collections.ConfigDict: + configs = ml_collections.ConfigDict() + configs.image_height = 224 #default - 224 + configs.image_width = 224 #default - 224 + configs.cropscale = (0.08, 1.0) + configs.cropratio = (0.75, 1.3333333333333333) + configs.jitterbrightness = 0.2 + configs.jittercontrast = 0.2 + configs.jittersaturation = 0.2 + configs.jitterhue = 0.2 + configs.gaussianblurlimit = (3, 7) + configs.gaussiansigmalimit = 0 + configs.alwaysapply = False + configs.probability = 0.5 + + return configs + +def get_config() -> ml_collections.ConfigDict: + config = ml_collections.ConfigDict() + config.seed = 0 + config.wandb_config = get_wandb_configs() + config.augmentation_config = get_augment_configs() + + return config \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data_aug/__init__.py b/ssl_study/simclrv1/pretext_task/data/__init__.py similarity index 100% rename from ssl_study/simclrv1/pretext_task/data_aug/__init__.py rename to ssl_study/simclrv1/pretext_task/data/__init__.py diff --git a/ssl_study/simclrv1/pretext_task/data/data_aug.py b/ssl_study/simclrv1/pretext_task/data/data_aug.py new file mode 100644 index 0000000..dec4be0 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/data/data_aug.py @@ -0,0 +1,52 @@ +import numpy as np +import tensorflow as tf +from functools import partial +import albumentations as A + +AUTOTUNE = tf.data.AUTOTUNE + +class Augment(): + def __init__(self, args): + self.args = args + + def build_augmentation(self): + transform = A.Compose([ + A.RandomResizedCrop(self.args.augmentation_config["img_height"], + self.args.augmentation_config["img_width"], + self.args.augmentation_config["cropscale"], + self.args.augmentation_config["cropratio"], + self.args.augmentation_config["probability"]), + A.ColorJitter (self.args.augmentation_config["jitterbrightness"], + self.args.augmentation_config["jittercontrast"], + self.args.augmentation_config["jittersaturation"], + self.args.augmentation_config["hue"], + self.args.augmentation_config["alwaysapply"], + self.args.augmentation_config["probablility"]), + A.GaussianBlur (self.args.augmentation_config["gaussianblurlimit"], + self.args.augmentation_config["gaussiansigmalimit"], + self.args.augmentation_config["alwaysapply"], + self.args.augmentation_config["probability"]) + ]) + return transform + + def augmentation(self, image, label): + aug_img = tf.numpy_function(func=self.aug_fn, inp=[image], Tout=tf.float32) + aug_img.set_shape((self.args.train_config["img_height"], + self.args.train_config["img_width"], 3)) + + aug_img = tf.image.random_flip_left_right(aug_img) + aug_img = tf.image.resize(aug_img, + [self.args.train_config["img_height"], + self.args.train_config["img_width"]], + method='bicubic', + preserve_aspect_ratio=False) + aug_img = tf.clip_by_value(aug_img, 0.0, 1.0) + + return aug_img, label + + def aug_fn(self, image): + data = {"image":image} + aug_data = self.transform(**data) + aug_img = aug_data["image"] + + return aug_img.astype(np.float32) \ No newline at end of file From b0c5ae8663348dd3f8376de410d67e58f9e49b1d Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Wed, 17 Aug 2022 00:03:41 +0000 Subject: [PATCH 06/16] added data augmentation pipeline --- .../simclrv1_pretext_config.py | 1 - simclrv1_pretext.py | 39 +++++++++++++++++++ ssl_study/data/__init__.py | 4 +- ssl_study/data/dataset.py | 14 +++++-- .../simclrv1/pretext_task/configs/__init__.py | 5 --- .../simclrv1/pretext_task/data/__init__.py | 5 +++ .../simclrv1/pretext_task/data/preprocess.py | 26 +++++++++++++ .../simclrv1/pretext_task/simclrv1_train.py | 0 8 files changed, 83 insertions(+), 11 deletions(-) rename ssl_study/simclrv1/pretext_task/configs/config.py => configs/simclrv1_pretext_config.py (99%) create mode 100644 simclrv1_pretext.py delete mode 100644 ssl_study/simclrv1/pretext_task/configs/__init__.py create mode 100644 ssl_study/simclrv1/pretext_task/data/preprocess.py delete mode 100644 ssl_study/simclrv1/pretext_task/simclrv1_train.py diff --git a/ssl_study/simclrv1/pretext_task/configs/config.py b/configs/simclrv1_pretext_config.py similarity index 99% rename from ssl_study/simclrv1/pretext_task/configs/config.py rename to configs/simclrv1_pretext_config.py index ff2bff3..9d3352f 100644 --- a/ssl_study/simclrv1/pretext_task/configs/config.py +++ b/configs/simclrv1_pretext_config.py @@ -1,7 +1,6 @@ import os import ml_collections - def get_wandb_configs() -> ml_collections.ConfigDict: configs = ml_collections.ConfigDict() configs.project = "ssl-study" diff --git a/simclrv1_pretext.py b/simclrv1_pretext.py new file mode 100644 index 0000000..fb00775 --- /dev/null +++ b/simclrv1_pretext.py @@ -0,0 +1,39 @@ +# General imports +import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +import glob +import wandb +from absl import app +from absl import flags +import numpy as np +import tensorflow as tf +from ml_collections.config_flags import config_flags + +# Import modules +from ssl_study.data import download_dataset, preprocess_dataframe_unlabelled + +FLAGS = flags.FLAGS +CONFIG = config_flags.DEFINE_config_file("config") + +def main(_): + with wandb.init( + entity=CONFIG.value.wandb_config.entity, + project=CONFIG.value.wandb_config.project, + job_type='simclrv1_pretext', + config=CONFIG.value.to_dict(), + ): + # Access all hyperparameter values through wandb.config + config = wandb.config + # Seed Everything + tf.random.set_seed(config.seed) + + # Load the dataframes + inclass_df = download_dataset('in-class', 'unlabelled-dataset') + + # Preprocess the DataFrames + inclass_paths = preprocess_dataframe_unlabelled(inclass_df) + + print(inclass_paths) + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/ssl_study/data/__init__.py b/ssl_study/data/__init__.py index 64df8f1..2f18d70 100644 --- a/ssl_study/data/__init__.py +++ b/ssl_study/data/__init__.py @@ -1,6 +1,6 @@ -from .dataset import download_dataset, preprocess_dataset +from .dataset import download_dataset, preprocess_dataframe_labelled, preprocess_dataframe_unlabelled from .dataloader import GetDataloader __all__ = [ - 'download_dataset', 'preprocess_dataset', 'GetDataloader' + 'download_dataset', 'preprocess_dataframe_labelled', 'preprocess_dataframe_unlabelled', 'GetDataloader' ] \ No newline at end of file diff --git a/ssl_study/data/dataset.py b/ssl_study/data/dataset.py index 6e44b8a..238a66e 100644 --- a/ssl_study/data/dataset.py +++ b/ssl_study/data/dataset.py @@ -94,8 +94,7 @@ def download_dataset(dataset_name: str, return data_df - -def preprocess_dataset(df): +def preprocess_dataframe_labelled(df): # TODO: take care of df without labels. # Remove unnecessary columns df = df.drop(['image_id', 'width', 'height'], axis=1) @@ -107,4 +106,13 @@ def preprocess_dataset(df): image_paths = df.image_path.values labels = df.label.values - return image_paths, labels \ No newline at end of file + return image_paths, labels + +def preprocess_dataframe_unlabelled(df): + # Remove unnecessary columns + df = df.drop(['image_id', 'width', 'height'], axis=1) + + # Fix types + image_paths = df.image_path.values + + return image_paths \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/configs/__init__.py b/ssl_study/simclrv1/pretext_task/configs/__init__.py deleted file mode 100644 index 27a3232..0000000 --- a/ssl_study/simclrv1/pretext_task/configs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .config import get_config - -__all__ = [ - 'get_config' -] \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/__init__.py b/ssl_study/simclrv1/pretext_task/data/__init__.py index e69de29..dcfbc2e 100644 --- a/ssl_study/simclrv1/pretext_task/data/__init__.py +++ b/ssl_study/simclrv1/pretext_task/data/__init__.py @@ -0,0 +1,5 @@ +from .data_aug import Augment + +__all__ = [ + 'Augment' +] \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/preprocess.py b/ssl_study/simclrv1/pretext_task/data/preprocess.py new file mode 100644 index 0000000..4c57808 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/data/preprocess.py @@ -0,0 +1,26 @@ +import numpy as np +import tensorflow as tf + +AUTOTUNE = tf.data.AUTOTUNE + +class PreprocessDataset(): + def __init__(self, args): + self.args = args + + def preprocess_for_inclass(self): + + + + def parse_data(self, path, label, dataloader_type='train'): + # Parse Image + image_string = tf.io.read_file(path) + image = tf.image.decode_jpeg(image_string, channels=3) + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + image = tf.image.resize(image, + [self.args.augmentation_config["img_height"], + self.args.augmentation_config["img_width"]], + method='bicubic', + preserve_aspect_ratio=False) + image = tf.clip_by_value(image, 0.0, 1.0) + + return image \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/simclrv1_train.py b/ssl_study/simclrv1/pretext_task/simclrv1_train.py deleted file mode 100644 index e69de29..0000000 From b028474e1464cd091e5819b7abf07fe5d793fca4 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Wed, 17 Aug 2022 01:51:42 +0000 Subject: [PATCH 07/16] added dataloader for in-class dataset --- configs/simclrv1_pretext_config.py | 13 +++++ simclrv1_pretext.py | 5 +- .../simclrv1/pretext_task/data/__init__.py | 3 +- .../simclrv1/pretext_task/data/data_aug.py | 14 ++--- .../simclrv1/pretext_task/data/dataloader.py | 53 +++++++++++++++++++ .../simclrv1/pretext_task/data/preprocess.py | 26 --------- 6 files changed, 79 insertions(+), 35 deletions(-) create mode 100644 ssl_study/simclrv1/pretext_task/data/dataloader.py delete mode 100644 ssl_study/simclrv1/pretext_task/data/preprocess.py diff --git a/configs/simclrv1_pretext_config.py b/configs/simclrv1_pretext_config.py index 9d3352f..b65ab47 100644 --- a/configs/simclrv1_pretext_config.py +++ b/configs/simclrv1_pretext_config.py @@ -8,6 +8,18 @@ def get_wandb_configs() -> ml_collections.ConfigDict: return configs +def get_dataset_configs() -> ml_collections.ConfigDict: + configs = ml_collections.ConfigDict() + configs.image_height = 224 #default - 224 + configs.image_width = 224 #default - 224 + configs.channels = 3 + configs.apply_resize = True + configs.batch_size = 64 + configs.num_classes = 200 + configs.do_cache = False + + return configs + def get_augment_configs() -> ml_collections.ConfigDict: configs = ml_collections.ConfigDict() configs.image_height = 224 #default - 224 @@ -29,6 +41,7 @@ def get_config() -> ml_collections.ConfigDict: config = ml_collections.ConfigDict() config.seed = 0 config.wandb_config = get_wandb_configs() + config.dataset_config = get_dataset_configs() config.augmentation_config = get_augment_configs() return config \ No newline at end of file diff --git a/simclrv1_pretext.py b/simclrv1_pretext.py index fb00775..e94e00d 100644 --- a/simclrv1_pretext.py +++ b/simclrv1_pretext.py @@ -11,6 +11,7 @@ # Import modules from ssl_study.data import download_dataset, preprocess_dataframe_unlabelled +from ssl_study.simclrv1.pretext_task.data import GetDataloader FLAGS = flags.FLAGS CONFIG = config_flags.DEFINE_config_file("config") @@ -33,7 +34,9 @@ def main(_): # Preprocess the DataFrames inclass_paths = preprocess_dataframe_unlabelled(inclass_df) - print(inclass_paths) + # Build dataloaders + dataset = GetDataloader(config) + inclassloader = dataset.dataloader(inclass_paths) if __name__ == "__main__": app.run(main) \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/__init__.py b/ssl_study/simclrv1/pretext_task/data/__init__.py index dcfbc2e..1ae1309 100644 --- a/ssl_study/simclrv1/pretext_task/data/__init__.py +++ b/ssl_study/simclrv1/pretext_task/data/__init__.py @@ -1,5 +1,6 @@ from .data_aug import Augment +from .dataloader import GetDataloader __all__ = [ - 'Augment' + 'Augment', 'GetDataloader' ] \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/data_aug.py b/ssl_study/simclrv1/pretext_task/data/data_aug.py index dec4be0..73bd4ee 100644 --- a/ssl_study/simclrv1/pretext_task/data/data_aug.py +++ b/ssl_study/simclrv1/pretext_task/data/data_aug.py @@ -29,24 +29,24 @@ def build_augmentation(self): ]) return transform - def augmentation(self, image, label): + def augmentation(self, image): aug_img = tf.numpy_function(func=self.aug_fn, inp=[image], Tout=tf.float32) - aug_img.set_shape((self.args.train_config["img_height"], - self.args.train_config["img_width"], 3)) + aug_img.set_shape((self.args.augmentation_config["img_height"], + self.args.augmentation_config["img_width"], 3)) aug_img = tf.image.random_flip_left_right(aug_img) aug_img = tf.image.resize(aug_img, - [self.args.train_config["img_height"], - self.args.train_config["img_width"]], + [self.args.augmentation_config["img_height"], + self.args.augmentation_config["img_width"]], method='bicubic', preserve_aspect_ratio=False) aug_img = tf.clip_by_value(aug_img, 0.0, 1.0) - return aug_img, label + return aug_img def aug_fn(self, image): data = {"image":image} - aug_data = self.transform(**data) + aug_data = self.build_augmentation(**data) aug_img = aug_data["image"] return aug_img.astype(np.float32) \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/dataloader.py b/ssl_study/simclrv1/pretext_task/data/dataloader.py new file mode 100644 index 0000000..acfd712 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/data/dataloader.py @@ -0,0 +1,53 @@ +import numpy as np +import tensorflow as tf +import albumentations as A + +AUTOTUNE = tf.data.AUTOTUNE + +class GetDataloader(): + def __init__(self, args): + self.args = args + + def dataloader(self, paths): + ''' + Args: + paths: List of strings, where each string is path to the image. + + Return: + dataloader: in-class dataloader + ''' + # Consume dataframe + dataloader = tf.data.Dataset.from_tensor_slices(paths) + + # Load the image + dataloader = ( + dataloader + .map(self.parse_data, num_parallel_calls=AUTOTUNE) + ) + + if self.args.dataset_config["do_cache"]: + dataloader = dataloader.cache() + + # Add general stuff + dataloader = ( + dataloader + .shuffle(self.args.dataset_config["batch_size"]) + .batch(self.args.dataset_config["batch_size"]) + .prefetch(AUTOTUNE) + ) + + return dataloader + + def parse_data(self, path): + # Parse Image + image_string = tf.io.read_file(path) + image = tf.image.decode_jpeg(image_string, channels=3) + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + image = tf.image.resize(image, + [self.args.dataset_config["image_height"], + self.args.dataset_config["image_width"]], + method='bicubic', + preserve_aspect_ratio=False) + image = tf.clip_by_value(image, 0.0, 1.0) + + return image \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/preprocess.py b/ssl_study/simclrv1/pretext_task/data/preprocess.py deleted file mode 100644 index 4c57808..0000000 --- a/ssl_study/simclrv1/pretext_task/data/preprocess.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np -import tensorflow as tf - -AUTOTUNE = tf.data.AUTOTUNE - -class PreprocessDataset(): - def __init__(self, args): - self.args = args - - def preprocess_for_inclass(self): - - - - def parse_data(self, path, label, dataloader_type='train'): - # Parse Image - image_string = tf.io.read_file(path) - image = tf.image.decode_jpeg(image_string, channels=3) - image = tf.image.convert_image_dtype(image, dtype=tf.float32) - image = tf.image.resize(image, - [self.args.augmentation_config["img_height"], - self.args.augmentation_config["img_width"]], - method='bicubic', - preserve_aspect_ratio=False) - image = tf.clip_by_value(image, 0.0, 1.0) - - return image \ No newline at end of file From 31a0766efca682288f26e32b1343b8c8d92335d2 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Mon, 22 Aug 2022 02:48:02 +0000 Subject: [PATCH 08/16] updated data augmentation for simclrv1 --- .../simclrv1/pretext_task/data/data_aug.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/ssl_study/simclrv1/pretext_task/data/data_aug.py b/ssl_study/simclrv1/pretext_task/data/data_aug.py index 73bd4ee..2a97437 100644 --- a/ssl_study/simclrv1/pretext_task/data/data_aug.py +++ b/ssl_study/simclrv1/pretext_task/data/data_aug.py @@ -1,3 +1,4 @@ +from albumentations.augmentations.transforms import ToGray import numpy as np import tensorflow as tf from functools import partial @@ -16,16 +17,18 @@ def build_augmentation(self): self.args.augmentation_config["cropscale"], self.args.augmentation_config["cropratio"], self.args.augmentation_config["probability"]), - A.ColorJitter (self.args.augmentation_config["jitterbrightness"], - self.args.augmentation_config["jittercontrast"], - self.args.augmentation_config["jittersaturation"], - self.args.augmentation_config["hue"], + A.HorizontalFlip(self.args.augmentation_config["probability"]), + A.ColorJitter(self.args.augmentation_config["jitterbrightness"], + self.args.augmentation_config["jittercontrast"], + self.args.augmentation_config["jittersaturation"], + self.args.augmentation_config["hue"], + self.args.augmentation_config["alwaysapply"], + self.args.augmentation_config["probablility"]), + A.ToGray(self.args.augmentation_config["probability"]), + A.GaussianBlur(self.args.augmentation_config["gaussianblurlimit"], + self.args.augmentation_config["gaussiansigmalimit"], self.args.augmentation_config["alwaysapply"], - self.args.augmentation_config["probablility"]), - A.GaussianBlur (self.args.augmentation_config["gaussianblurlimit"], - self.args.augmentation_config["gaussiansigmalimit"], - self.args.augmentation_config["alwaysapply"], - self.args.augmentation_config["probability"]) + self.args.augmentation_config["probability"]) ]) return transform From 24635800532db7d0428b7dc4c988c244d3880a11 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Mon, 22 Aug 2022 03:17:56 +0000 Subject: [PATCH 09/16] updated model and config for simclrv1 --- configs/simclrv1_pretext_config.py | 9 ++++- .../simclrv1/pretext_task/data/dataloader.py | 15 ++++---- .../simclrv1/pretext_task/model/__init__.py | 0 .../simclrv1/pretext_task/models/__init__.py | 5 +++ .../simclrv1/pretext_task/models/model.py | 38 +++++++++++++++++++ 5 files changed, 58 insertions(+), 9 deletions(-) delete mode 100644 ssl_study/simclrv1/pretext_task/model/__init__.py create mode 100644 ssl_study/simclrv1/pretext_task/models/__init__.py create mode 100644 ssl_study/simclrv1/pretext_task/models/model.py diff --git a/configs/simclrv1_pretext_config.py b/configs/simclrv1_pretext_config.py index b65ab47..16ee633 100644 --- a/configs/simclrv1_pretext_config.py +++ b/configs/simclrv1_pretext_config.py @@ -13,10 +13,8 @@ def get_dataset_configs() -> ml_collections.ConfigDict: configs.image_height = 224 #default - 224 configs.image_width = 224 #default - 224 configs.channels = 3 - configs.apply_resize = True configs.batch_size = 64 configs.num_classes = 200 - configs.do_cache = False return configs @@ -37,11 +35,18 @@ def get_augment_configs() -> ml_collections.ConfigDict: return configs +def get_bool_configs() -> ml_collections.ConfigDict: + configs = ml_collections.ConfigDict() + configs.backbone = "resnet50" + configs.apply_resize = True + configs.do_cache = False + def get_config() -> ml_collections.ConfigDict: config = ml_collections.ConfigDict() config.seed = 0 config.wandb_config = get_wandb_configs() config.dataset_config = get_dataset_configs() config.augmentation_config = get_augment_configs() + config.bool_config = get_bool_configs() return config \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/dataloader.py b/ssl_study/simclrv1/pretext_task/data/dataloader.py index acfd712..3178c4b 100644 --- a/ssl_study/simclrv1/pretext_task/data/dataloader.py +++ b/ssl_study/simclrv1/pretext_task/data/dataloader.py @@ -25,7 +25,7 @@ def dataloader(self, paths): .map(self.parse_data, num_parallel_calls=AUTOTUNE) ) - if self.args.dataset_config["do_cache"]: + if self.args.bool_config["do_cache"]: dataloader = dataloader.cache() # Add general stuff @@ -43,11 +43,12 @@ def parse_data(self, path): image_string = tf.io.read_file(path) image = tf.image.decode_jpeg(image_string, channels=3) image = tf.image.convert_image_dtype(image, dtype=tf.float32) - image = tf.image.resize(image, - [self.args.dataset_config["image_height"], - self.args.dataset_config["image_width"]], - method='bicubic', - preserve_aspect_ratio=False) - image = tf.clip_by_value(image, 0.0, 1.0) + if self.args.bool_config["apply_resize"]: + image = tf.image.resize(image, + [self.args.dataset_config["image_height"], + self.args.dataset_config["image_width"]], + method='bicubic', + preserve_aspect_ratio=False) + image = tf.clip_by_value(image, 0.0, 1.0) return image \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/model/__init__.py b/ssl_study/simclrv1/pretext_task/model/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ssl_study/simclrv1/pretext_task/models/__init__.py b/ssl_study/simclrv1/pretext_task/models/__init__.py new file mode 100644 index 0000000..c89fb18 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/models/__init__.py @@ -0,0 +1,5 @@ +from .model import SimCLRv1Model + +__all__ = [ + 'SimCLRv1Model' +] \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/models/model.py b/ssl_study/simclrv1/pretext_task/models/model.py new file mode 100644 index 0000000..5d64264 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/models/model.py @@ -0,0 +1,38 @@ +import tensorflow as tf + +class SimCLRv1Model(): + def __init__(self, args): + self.args = args + + def get_backbone(self): + """Get backbone for the model.""" + weights = None + + if self.args.train_config["backbone"] == 'resnet50': + base_encoder = tf.keras.applications.ResNet50(include_top=False, weights=weights) + base_encoder.trainabe = True + else: + raise NotImplementedError("Not implemented for this backbone.") + + return base_encoder + + def get_model(self, hidden_1, hidden_2, hidden_3): + """Get model.""" + # Backbone + base_encoder = self.get_backbone() + + # Stack layers + inputs = tf.keras.layers.Input( + (self.args.train_config["model_img_height"], + self.args.train_config["model_img_width"], + self.args.train_config["model_img_channels"])) + + x = base_encoder(inputs, training=True) + + projection_1 = tf.keras.layers.Dense(hidden_1)(x) + projection_1 = tf.keras.layers.Activation("relu")(projection_1) + projection_2 = tf.keras.layers.Dense(hidden_2)(projection_1) + projection_2 = tf.keras.layers.Activation("relu")(projection_2) + projection_3 = tf.keras.layers.Dense(hidden_3)(projection_2) + + return tf.keras.models.Model(inputs, projection_3) \ No newline at end of file From 368711d1b0a1141bdd7b64034f88815ce630ea52 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Mon, 22 Aug 2022 04:08:48 +0000 Subject: [PATCH 10/16] updated utils+configs+preprocess_dataframe --- configs/simclrv1_pretext_config.py | 7 ++++ simclrv1_pretext.py | 4 +-- ssl_study/data/__init__.py | 4 +-- ssl_study/data/dataset.py | 28 +++++----------- .../simclrv1/pretext_task/utils/__init__.py | 6 ++++ .../simclrv1/pretext_task/utils/helpers.py | 12 +++++++ .../simclrv1/pretext_task/utils/losses.py | 33 +++++++++++++++++++ 7 files changed, 71 insertions(+), 23 deletions(-) create mode 100644 ssl_study/simclrv1/pretext_task/utils/__init__.py create mode 100644 ssl_study/simclrv1/pretext_task/utils/helpers.py create mode 100644 ssl_study/simclrv1/pretext_task/utils/losses.py diff --git a/configs/simclrv1_pretext_config.py b/configs/simclrv1_pretext_config.py index 16ee633..b4df82b 100644 --- a/configs/simclrv1_pretext_config.py +++ b/configs/simclrv1_pretext_config.py @@ -40,6 +40,13 @@ def get_bool_configs() -> ml_collections.ConfigDict: configs.backbone = "resnet50" configs.apply_resize = True configs.do_cache = False + configs.use_cosine_similarity = True + +def get_train_configs() -> ml_collections.ConfigDict: + configs = ml_collections.ConfigDict() + configs.epochs = 30 + configs.temperature = 0.5 + configs.s = 1 def get_config() -> ml_collections.ConfigDict: config = ml_collections.ConfigDict() diff --git a/simclrv1_pretext.py b/simclrv1_pretext.py index e94e00d..d46facd 100644 --- a/simclrv1_pretext.py +++ b/simclrv1_pretext.py @@ -10,7 +10,7 @@ from ml_collections.config_flags import config_flags # Import modules -from ssl_study.data import download_dataset, preprocess_dataframe_unlabelled +from ssl_study.data import download_dataset, preprocess_dataframe from ssl_study.simclrv1.pretext_task.data import GetDataloader FLAGS = flags.FLAGS @@ -32,7 +32,7 @@ def main(_): inclass_df = download_dataset('in-class', 'unlabelled-dataset') # Preprocess the DataFrames - inclass_paths = preprocess_dataframe_unlabelled(inclass_df) + inclass_paths = preprocess_dataframe(inclass_df) # Build dataloaders dataset = GetDataloader(config) diff --git a/ssl_study/data/__init__.py b/ssl_study/data/__init__.py index 2f18d70..60f4e6d 100644 --- a/ssl_study/data/__init__.py +++ b/ssl_study/data/__init__.py @@ -1,6 +1,6 @@ -from .dataset import download_dataset, preprocess_dataframe_labelled, preprocess_dataframe_unlabelled +from .dataset import download_dataset, preprocess_dataframe from .dataloader import GetDataloader __all__ = [ - 'download_dataset', 'preprocess_dataframe_labelled', 'preprocess_dataframe_unlabelled', 'GetDataloader' + 'download_dataset', 'preprocess_dataframe', 'GetDataloader' ] \ No newline at end of file diff --git a/ssl_study/data/dataset.py b/ssl_study/data/dataset.py index 238a66e..4102abc 100644 --- a/ssl_study/data/dataset.py +++ b/ssl_study/data/dataset.py @@ -94,25 +94,15 @@ def download_dataset(dataset_name: str, return data_df -def preprocess_dataframe_labelled(df): - # TODO: take care of df without labels. - # Remove unnecessary columns +def preprocess_dataframe(df, is_labelled=True): df = df.drop(['image_id', 'width', 'height'], axis=1) - assert len(df.columns) == 2 - - # Fix types - df[['label']] = df[['label']].apply(pd.to_numeric) - - image_paths = df.image_path.values - labels = df.label.values - - return image_paths, labels - -def preprocess_dataframe_unlabelled(df): - # Remove unnecessary columns - df = df.drop(['image_id', 'width', 'height'], axis=1) - - # Fix types image_paths = df.image_path.values - return image_paths \ No newline at end of file + if is_labelled: + assert len(df.columns) == 2 + # Fix types + df[['label']] = df[['label']].apply(pd.to_numeric) + labels = df.label.values + return image_paths, labels + else: + return image_paths \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/utils/__init__.py b/ssl_study/simclrv1/pretext_task/utils/__init__.py new file mode 100644 index 0000000..15d527f --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/utils/__init__.py @@ -0,0 +1,6 @@ +from .helpers import get_negative_mask +from .losses import _cosine_simililarity_dim1, _cosine_simililarity_dim2, _dot_simililarity_dim1, _dot_simililarity_dim2 + +__all__ = [ + 'get_negative_mask', '_cosine_simililarity_dim1', '_cosine_simililarity_dim2', '_dot_simililarity_dim1', '_dot_simililarity_dim2' +] \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/utils/helpers.py b/ssl_study/simclrv1/pretext_task/utils/helpers.py new file mode 100644 index 0000000..41fa14c --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/utils/helpers.py @@ -0,0 +1,12 @@ +import tensorflow as tf +import numpy as np + +def get_negative_mask(batch_size): + # return a mask that removes the similarity score of equal/similar images. + # this function ensures that only distinct pair of images get their similarity scores + # passed as negative examples + negative_mask = np.ones((batch_size, 2 * batch_size), dtype=bool) + for i in range(batch_size): + negative_mask[i, i] = 0 + negative_mask[i, i + batch_size] = 0 + return tf.constant(negative_mask) \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/utils/losses.py b/ssl_study/simclrv1/pretext_task/utils/losses.py new file mode 100644 index 0000000..764b280 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/utils/losses.py @@ -0,0 +1,33 @@ +import tensorflow as tf + +cosine_sim_1d = tf.keras.losses.CosineSimilarity(axis=1, reduction=tf.keras.losses.Reduction.NONE) +cosine_sim_2d = tf.keras.losses.CosineSimilarity(axis=2, reduction=tf.keras.losses.Reduction.NONE) + + +def _cosine_simililarity_dim1(x, y): + v = cosine_sim_1d(x, y) + return v + + +def _cosine_simililarity_dim2(x, y): + # x shape: (N, 1, C) + # y shape: (1, 2N, C) + # v shape: (N, 2N) + v = cosine_sim_2d(tf.expand_dims(x, 1), tf.expand_dims(y, 0)) + return v + + +def _dot_simililarity_dim1(x, y): + # x shape: (N, 1, C) + # y shape: (N, C, 1) + # v shape: (N, 1, 1) + v = tf.matmul(tf.expand_dims(x, 1), tf.expand_dims(y, 2)) + return v + + +def _dot_simililarity_dim2(x, y): + v = tf.tensordot(tf.expand_dims(x, 1), tf.expand_dims(tf.transpose(y), 0), axes=2) + # x shape: (N, 1, C) + # y shape: (1, C, 2N) + # v shape: (N, 2N) + return v \ No newline at end of file From 6231c6c679d4471a0d6ffd7ab6239bc7f2c1b087 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Mon, 22 Aug 2022 04:43:37 +0000 Subject: [PATCH 11/16] updated dataloader --- configs/simclrv1_pretext_config.py | 1 + ssl_study/simclrv1/pretext_task/data/data_aug.py | 6 ++++++ ssl_study/simclrv1/pretext_task/data/dataloader.py | 7 ++++++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/configs/simclrv1_pretext_config.py b/configs/simclrv1_pretext_config.py index b4df82b..b0e71fe 100644 --- a/configs/simclrv1_pretext_config.py +++ b/configs/simclrv1_pretext_config.py @@ -55,5 +55,6 @@ def get_config() -> ml_collections.ConfigDict: config.dataset_config = get_dataset_configs() config.augmentation_config = get_augment_configs() config.bool_config = get_bool_configs() + config.train_config = get_train_configs() return config \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/data/data_aug.py b/ssl_study/simclrv1/pretext_task/data/data_aug.py index 2a97437..c3ae8b9 100644 --- a/ssl_study/simclrv1/pretext_task/data/data_aug.py +++ b/ssl_study/simclrv1/pretext_task/data/data_aug.py @@ -1,3 +1,4 @@ +from albumentations import augmentations from albumentations.augmentations.transforms import ToGray import numpy as np import tensorflow as tf @@ -47,6 +48,11 @@ def augmentation(self, image): return aug_img + def simclrv1_augmentation(self, image): + a1 = self.augmentation(image) + a2 = self.augmentation(image) + return a1, a2 + def aug_fn(self, image): data = {"image":image} aug_data = self.build_augmentation(**data) diff --git a/ssl_study/simclrv1/pretext_task/data/dataloader.py b/ssl_study/simclrv1/pretext_task/data/dataloader.py index 3178c4b..eab095e 100644 --- a/ssl_study/simclrv1/pretext_task/data/dataloader.py +++ b/ssl_study/simclrv1/pretext_task/data/dataloader.py @@ -1,6 +1,7 @@ import numpy as np import tensorflow as tf -import albumentations as A + +from .data_aug import Augment AUTOTUNE = tf.data.AUTOTUNE @@ -28,9 +29,13 @@ def dataloader(self, paths): if self.args.bool_config["do_cache"]: dataloader = dataloader.cache() + # Add simclrv1 augmentaion + dataloader = dataloader.map(Augment.simclrv1_augmentation(), num_parallel_calls=AUTOTUNE) + # Add general stuff dataloader = ( dataloader + .repeat(self.args.train_config['epochs']) .shuffle(self.args.dataset_config["batch_size"]) .batch(self.args.dataset_config["batch_size"]) .prefetch(AUTOTUNE) From 656d501c5dcfdc1f29ded5fab064e02b4b427104 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Mon, 22 Aug 2022 05:36:22 +0000 Subject: [PATCH 12/16] minor corrections --- configs/simclrv1_pretext_config.py | 4 ++++ simclrv1_pretext.py | 2 +- ssl_study/simclrv1/pretext_task/data/data_aug.py | 6 ------ ssl_study/simclrv1/pretext_task/data/dataloader.py | 4 ---- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/configs/simclrv1_pretext_config.py b/configs/simclrv1_pretext_config.py index b0e71fe..e6cb1e2 100644 --- a/configs/simclrv1_pretext_config.py +++ b/configs/simclrv1_pretext_config.py @@ -42,11 +42,15 @@ def get_bool_configs() -> ml_collections.ConfigDict: configs.do_cache = False configs.use_cosine_similarity = True + return configs + def get_train_configs() -> ml_collections.ConfigDict: configs = ml_collections.ConfigDict() configs.epochs = 30 configs.temperature = 0.5 configs.s = 1 + + return configs def get_config() -> ml_collections.ConfigDict: config = ml_collections.ConfigDict() diff --git a/simclrv1_pretext.py b/simclrv1_pretext.py index d46facd..046a56d 100644 --- a/simclrv1_pretext.py +++ b/simclrv1_pretext.py @@ -32,7 +32,7 @@ def main(_): inclass_df = download_dataset('in-class', 'unlabelled-dataset') # Preprocess the DataFrames - inclass_paths = preprocess_dataframe(inclass_df) + inclass_paths = preprocess_dataframe(inclass_df, is_labelled=False) # Build dataloaders dataset = GetDataloader(config) diff --git a/ssl_study/simclrv1/pretext_task/data/data_aug.py b/ssl_study/simclrv1/pretext_task/data/data_aug.py index c3ae8b9..e79b063 100644 --- a/ssl_study/simclrv1/pretext_task/data/data_aug.py +++ b/ssl_study/simclrv1/pretext_task/data/data_aug.py @@ -38,7 +38,6 @@ def augmentation(self, image): aug_img.set_shape((self.args.augmentation_config["img_height"], self.args.augmentation_config["img_width"], 3)) - aug_img = tf.image.random_flip_left_right(aug_img) aug_img = tf.image.resize(aug_img, [self.args.augmentation_config["img_height"], self.args.augmentation_config["img_width"]], @@ -48,11 +47,6 @@ def augmentation(self, image): return aug_img - def simclrv1_augmentation(self, image): - a1 = self.augmentation(image) - a2 = self.augmentation(image) - return a1, a2 - def aug_fn(self, image): data = {"image":image} aug_data = self.build_augmentation(**data) diff --git a/ssl_study/simclrv1/pretext_task/data/dataloader.py b/ssl_study/simclrv1/pretext_task/data/dataloader.py index eab095e..d7c72a8 100644 --- a/ssl_study/simclrv1/pretext_task/data/dataloader.py +++ b/ssl_study/simclrv1/pretext_task/data/dataloader.py @@ -29,13 +29,9 @@ def dataloader(self, paths): if self.args.bool_config["do_cache"]: dataloader = dataloader.cache() - # Add simclrv1 augmentaion - dataloader = dataloader.map(Augment.simclrv1_augmentation(), num_parallel_calls=AUTOTUNE) - # Add general stuff dataloader = ( dataloader - .repeat(self.args.train_config['epochs']) .shuffle(self.args.dataset_config["batch_size"]) .batch(self.args.dataset_config["batch_size"]) .prefetch(AUTOTUNE) From e883c163c7bec16f66655a487a2c51c2efebf677 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Mon, 22 Aug 2022 05:59:07 +0000 Subject: [PATCH 13/16] updated pipeline.py --- .../pretext_task/pipeline/pipeline.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 ssl_study/simclrv1/pretext_task/pipeline/pipeline.py diff --git a/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py new file mode 100644 index 0000000..ee21285 --- /dev/null +++ b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py @@ -0,0 +1,77 @@ +import os +import json +import tempfile +import numpy as np +from sklearn.metrics import accuracy_score +import wandb +from tqdm import tqdm +import tensorflow as tf + +from ssl_study.simclrv1.pretext_task.utils import _dot_simililarity_dim1 as sim_func_dim1, _dot_simililarity_dim2 as sim_func_dim2, get_negative_mask +from ssl_study.simclrv1.pretext_task.data import Augment + +class SupervisedPipeline(): + def __init__(self, model, args, class_weights=None, callbacks=[]): + self.args = args + self.model = model + self.class_weights = class_weights + self.callbacks = callbacks + + @tf.function + def train_step(self, xis, xjs, model, optimizer, criterion, temperature): + with tf.GradientTape() as tape: + zis = model(xis) + zjs = model(xjs) + + # normalize projection feature vectors + zis = tf.math.l2_normalize(zis, axis=1) + zjs = tf.math.l2_normalize(zjs, axis=1) + + l_pos = sim_func_dim1(zis, zjs) + l_pos = tf.reshape(l_pos, (self.args.dataset_config["batch_size"], 1)) + l_pos /= temperature + + negatives = tf.concat([zjs, zis], axis=0) + + loss = 0 + + for positives in [zis, zjs]: + l_neg = sim_func_dim2(positives, negatives) + + labels = tf.zeros(self.args.dataset_config["batch_size"], dtype=tf.int32) + + l_neg = tf.boolean_mask(l_neg, get_negative_mask(self.args.dataset_config["batch_size"])) + l_neg = tf.reshape(l_neg, (self.args.dataset_config["batch_size"], -1)) + l_neg /= temperature + + logits = tf.concat([l_pos, l_neg], axis=1) + loss += criterion(y_pred=logits, y_true=labels) + + loss = loss / (2 * self.args.dataset_config["batch_size"]) + + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + + return loss + + def train_simclr(self, model, dataset, optimizer, criterion, temperature=0.1, epochs=100): + step_wise_loss = [] + epoch_wise_loss = [] + + augment = Augment.augmentation() + + for epoch in tqdm(range(epochs)): + for image_batch in dataset: + a = augment(image_batch) + b = augment(image_batch) + + loss = self.train_step(a, b, model, optimizer, criterion, temperature) + step_wise_loss.append(loss) + + epoch_wise_loss.append(np.mean(step_wise_loss)) + wandb.log({"nt_xentloss": np.mean(step_wise_loss)}) + + if epoch % 10 == 0: + print("epoch: {} loss: {:.3f}".format(epoch + 1, np.mean(step_wise_loss))) + + return epoch_wise_loss, model \ No newline at end of file From cc8eeb67686c61879b0f74c7278df03803cb8bbe Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Wed, 24 Aug 2022 07:46:49 +0000 Subject: [PATCH 14/16] updated pipeline+model+helpers+config+script --- configs/simclrv1_pretext_config.py | 17 ++++++++++++++++- simclrv1_pretext.py | 7 +++++++ ssl_study/simclrv1/pretext_task/models/model.py | 8 ++++---- .../simclrv1/pretext_task/pipeline/__init__.py | 5 +++++ .../simclrv1/pretext_task/pipeline/pipeline.py | 6 ++---- .../simclrv1/pretext_task/utils/helpers.py | 15 ++++++++++++++- 6 files changed, 48 insertions(+), 10 deletions(-) diff --git a/configs/simclrv1_pretext_config.py b/configs/simclrv1_pretext_config.py index e6cb1e2..98569c2 100644 --- a/configs/simclrv1_pretext_config.py +++ b/configs/simclrv1_pretext_config.py @@ -37,13 +37,21 @@ def get_augment_configs() -> ml_collections.ConfigDict: def get_bool_configs() -> ml_collections.ConfigDict: configs = ml_collections.ConfigDict() - configs.backbone = "resnet50" configs.apply_resize = True configs.do_cache = False configs.use_cosine_similarity = True return configs +def get_model_configs() -> ml_collections.ConfigDict: + configs = ml_collections.ConfigDict() + configs.backbone = "resnet50" + configs.hidden1 = 256 + configs.hidden2 = 128 + configs.hidden3 = 50 + + return configs + def get_train_configs() -> ml_collections.ConfigDict: configs = ml_collections.ConfigDict() configs.epochs = 30 @@ -51,6 +59,11 @@ def get_train_configs() -> ml_collections.ConfigDict: configs.s = 1 return configs + +def get_learning_rate_configs() -> ml_collections.ConfigDict: + configs = ml_collections.ConfigDict() + configs.decay_steps = 1000 + configs.initial_learning_rate = 0.1 def get_config() -> ml_collections.ConfigDict: config = ml_collections.ConfigDict() @@ -60,5 +73,7 @@ def get_config() -> ml_collections.ConfigDict: config.augmentation_config = get_augment_configs() config.bool_config = get_bool_configs() config.train_config = get_train_configs() + config.learning_rate_config = get_learning_rate_configs() + config.model_config = get_model_configs() return config \ No newline at end of file diff --git a/simclrv1_pretext.py b/simclrv1_pretext.py index 046a56d..25295c3 100644 --- a/simclrv1_pretext.py +++ b/simclrv1_pretext.py @@ -12,6 +12,8 @@ # Import modules from ssl_study.data import download_dataset, preprocess_dataframe from ssl_study.simclrv1.pretext_task.data import GetDataloader +from ssl_study.simclrv1.pretext_task.models import SimCLRv1Model +from ssl_study.simclrv1.pretext_task.pipeline import SimCLRv1Pipeline FLAGS = flags.FLAGS CONFIG = config_flags.DEFINE_config_file("config") @@ -38,5 +40,10 @@ def main(_): dataset = GetDataloader(config) inclassloader = dataset.dataloader(inclass_paths) + # Model + tf.keras.backend.clear_session() + model = SimCLRv1Model(config).get_model(config.model_config["hidden1"], config.model_config["hidden2"], config.model_config["hidden3"]) + model.summary() + if __name__ == "__main__": app.run(main) \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/models/model.py b/ssl_study/simclrv1/pretext_task/models/model.py index 5d64264..fe20640 100644 --- a/ssl_study/simclrv1/pretext_task/models/model.py +++ b/ssl_study/simclrv1/pretext_task/models/model.py @@ -8,7 +8,7 @@ def get_backbone(self): """Get backbone for the model.""" weights = None - if self.args.train_config["backbone"] == 'resnet50': + if self.args.model_config["backbone"] == 'resnet50': base_encoder = tf.keras.applications.ResNet50(include_top=False, weights=weights) base_encoder.trainabe = True else: @@ -23,9 +23,9 @@ def get_model(self, hidden_1, hidden_2, hidden_3): # Stack layers inputs = tf.keras.layers.Input( - (self.args.train_config["model_img_height"], - self.args.train_config["model_img_width"], - self.args.train_config["model_img_channels"])) + (self.args.dataset_config["image_height"], + self.args.dataset_config["image_width"], + self.args.dataset_config["channels"])) x = base_encoder(inputs, training=True) diff --git a/ssl_study/simclrv1/pretext_task/pipeline/__init__.py b/ssl_study/simclrv1/pretext_task/pipeline/__init__.py index e69de29..6e24ee2 100644 --- a/ssl_study/simclrv1/pretext_task/pipeline/__init__.py +++ b/ssl_study/simclrv1/pretext_task/pipeline/__init__.py @@ -0,0 +1,5 @@ +from .pipeline import SimCLRv1Pipeline + +__all__ = [ + 'SimCLRv1Pipeline' +] \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py index ee21285..42a13a3 100644 --- a/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py +++ b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py @@ -10,12 +10,10 @@ from ssl_study.simclrv1.pretext_task.utils import _dot_simililarity_dim1 as sim_func_dim1, _dot_simililarity_dim2 as sim_func_dim2, get_negative_mask from ssl_study.simclrv1.pretext_task.data import Augment -class SupervisedPipeline(): - def __init__(self, model, args, class_weights=None, callbacks=[]): +class SimCLRv1Pipeline(): + def __init__(self, model, args): self.args = args self.model = model - self.class_weights = class_weights - self.callbacks = callbacks @tf.function def train_step(self, xis, xjs, model, optimizer, criterion, temperature): diff --git a/ssl_study/simclrv1/pretext_task/utils/helpers.py b/ssl_study/simclrv1/pretext_task/utils/helpers.py index 41fa14c..a662508 100644 --- a/ssl_study/simclrv1/pretext_task/utils/helpers.py +++ b/ssl_study/simclrv1/pretext_task/utils/helpers.py @@ -9,4 +9,17 @@ def get_negative_mask(batch_size): for i in range(batch_size): negative_mask[i, i] = 0 negative_mask[i, i + batch_size] = 0 - return tf.constant(negative_mask) \ No newline at end of file + return tf.constant(negative_mask) + +def get_criterion(): + return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM) + +class Optimizer(): + def __init__(self, args): + self.args = args + + def get_optimizer(self): + learning_rate = tf.keras.experimental.CosineDecay(initial_learning_rate= self.args.learning_rate_config['initial_learning_rate'], decay_steps=self.args.learning_rate_config['decay_steps']) + optimizer = tf.keras.optimizers.SGD(learning_rate) + + return optimizer \ No newline at end of file From 809277037f5779916102e1df97d4f793f9a80bae Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Wed, 24 Aug 2022 08:12:19 +0000 Subject: [PATCH 15/16] updated script and pipeline --- simclrv1_pretext.py | 8 ++++++++ .../simclrv1/pretext_task/pipeline/pipeline.py | 17 +++++++++++++---- .../simclrv1/pretext_task/utils/helpers.py | 15 +-------------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/simclrv1_pretext.py b/simclrv1_pretext.py index 25295c3..181e7a0 100644 --- a/simclrv1_pretext.py +++ b/simclrv1_pretext.py @@ -45,5 +45,13 @@ def main(_): model = SimCLRv1Model(config).get_model(config.model_config["hidden1"], config.model_config["hidden2"], config.model_config["hidden3"]) model.summary() + # Build the pipeline + pipeline = SimCLRv1Pipeline(model, config) + optimizer = pipeline.get_optimizer + criterion = pipeline.get_criterion + + epoch_wise_loss, resnet_simclr = pipeline.train_simclr(model, inclassloader, optimizer, criterion, temperature=config.train_config["temperature"], epochs=config.train_config["epochs"]) + + if __name__ == "__main__": app.run(main) \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py index 42a13a3..0fa9e54 100644 --- a/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py +++ b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py @@ -56,12 +56,12 @@ def train_simclr(self, model, dataset, optimizer, criterion, temperature=0.1, ep step_wise_loss = [] epoch_wise_loss = [] - augment = Augment.augmentation() + augment = Augment(self.args) for epoch in tqdm(range(epochs)): for image_batch in dataset: - a = augment(image_batch) - b = augment(image_batch) + a = augment.augmentation(image_batch) + b = augment.augmentation(image_batch) loss = self.train_step(a, b, model, optimizer, criterion, temperature) step_wise_loss.append(loss) @@ -72,4 +72,13 @@ def train_simclr(self, model, dataset, optimizer, criterion, temperature=0.1, ep if epoch % 10 == 0: print("epoch: {} loss: {:.3f}".format(epoch + 1, np.mean(step_wise_loss))) - return epoch_wise_loss, model \ No newline at end of file + return epoch_wise_loss, model + + def get_criterion(self): + return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM) + + def get_optimizer(self): + learning_rate = tf.keras.experimental.CosineDecay(initial_learning_rate= self.args.learning_rate_config['initial_learning_rate'], decay_steps=self.args.learning_rate_config['decay_steps']) + optimizer = tf.keras.optimizers.SGD(learning_rate) + + return optimizer \ No newline at end of file diff --git a/ssl_study/simclrv1/pretext_task/utils/helpers.py b/ssl_study/simclrv1/pretext_task/utils/helpers.py index a662508..41fa14c 100644 --- a/ssl_study/simclrv1/pretext_task/utils/helpers.py +++ b/ssl_study/simclrv1/pretext_task/utils/helpers.py @@ -9,17 +9,4 @@ def get_negative_mask(batch_size): for i in range(batch_size): negative_mask[i, i] = 0 negative_mask[i, i + batch_size] = 0 - return tf.constant(negative_mask) - -def get_criterion(): - return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM) - -class Optimizer(): - def __init__(self, args): - self.args = args - - def get_optimizer(self): - learning_rate = tf.keras.experimental.CosineDecay(initial_learning_rate= self.args.learning_rate_config['initial_learning_rate'], decay_steps=self.args.learning_rate_config['decay_steps']) - optimizer = tf.keras.optimizers.SGD(learning_rate) - - return optimizer \ No newline at end of file + return tf.constant(negative_mask) \ No newline at end of file From 879d2d1810ca1fcadc00880478245998c849a6cb Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Wed, 24 Aug 2022 08:32:30 +0000 Subject: [PATCH 16/16] minor typos and correction in data_aug --- .../simclrv1/pretext_task/data/data_aug.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ssl_study/simclrv1/pretext_task/data/data_aug.py b/ssl_study/simclrv1/pretext_task/data/data_aug.py index e79b063..73554b7 100644 --- a/ssl_study/simclrv1/pretext_task/data/data_aug.py +++ b/ssl_study/simclrv1/pretext_task/data/data_aug.py @@ -11,10 +11,10 @@ class Augment(): def __init__(self, args): self.args = args - def build_augmentation(self): + def build_augmentation(self, image): transform = A.Compose([ - A.RandomResizedCrop(self.args.augmentation_config["img_height"], - self.args.augmentation_config["img_width"], + A.RandomResizedCrop(self.args.augmentation_config["image_height"], + self.args.augmentation_config["image_width"], self.args.augmentation_config["cropscale"], self.args.augmentation_config["cropratio"], self.args.augmentation_config["probability"]), @@ -22,9 +22,9 @@ def build_augmentation(self): A.ColorJitter(self.args.augmentation_config["jitterbrightness"], self.args.augmentation_config["jittercontrast"], self.args.augmentation_config["jittersaturation"], - self.args.augmentation_config["hue"], + self.args.augmentation_config["jitterhue"], self.args.augmentation_config["alwaysapply"], - self.args.augmentation_config["probablility"]), + self.args.augmentation_config["probability"]), A.ToGray(self.args.augmentation_config["probability"]), A.GaussianBlur(self.args.augmentation_config["gaussianblurlimit"], self.args.augmentation_config["gaussiansigmalimit"], @@ -35,12 +35,12 @@ def build_augmentation(self): def augmentation(self, image): aug_img = tf.numpy_function(func=self.aug_fn, inp=[image], Tout=tf.float32) - aug_img.set_shape((self.args.augmentation_config["img_height"], - self.args.augmentation_config["img_width"], 3)) + aug_img.set_shape((self.args.augmentation_config["image_height"], + self.args.augmentation_config["image_width"], 3)) aug_img = tf.image.resize(aug_img, - [self.args.augmentation_config["img_height"], - self.args.augmentation_config["img_width"]], + [self.args.augmentation_config["image_height"], + self.args.augmentation_config["image_width"]], method='bicubic', preserve_aspect_ratio=False) aug_img = tf.clip_by_value(aug_img, 0.0, 1.0)