Pull to refresh

Self-Supervised Learning. Contrastive learning

Reading time5 min
Views6.6K

В предыдущих статьях мы разобрали много аспектов, связанных с SSL. Теперь пришло время посмотреть на методы, которые используют достаточно очевидное, на первый взгляд, знание - одна и та же картинка похожа, а разные картинки - разные. Это основная идея методов с contrastive подходом. Ниже мы разберём более подробно как эту идею можно использовать при конфигурации фреймворка обучения.

Напомню, что это четвертая статья из цикла про SSL в Computer Vision.

InstDisc (Instance Discrimination)

📋Z. Wu, Y. Xiong, S. Yu et al. Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination(Май 2018)
Основная идея метода - попытаться учить модель располагать эмбеддинги на единичной сфере (норма эмбеддинга = 1). Для этого они используют непарметрический софтмакс:

P(i \mid \mathbf{v})=\frac{\exp \left(\mathbf{v}_i^T \mathbf{v} / \tau\right)}{\sum_{j=1}^n \exp \left(\mathbf{v}_j^T \mathbf{v} / \tau\right)}

т.е. они максимизируют вероятность эмбеддинга на себя относительно всех других изображений. Таким образом противоположные картинки отталкиваются. Похожие картинки тоже отталкиваются, но поскольку софтмакс считается сразу по всему датасету, то отталкиваются ещё и от противоположных, поэтому всё-таки немного притягиваются🙃. Для работы такой штуки авторы снизили размерность эмбеддинга до 128. \tau в формуле - это параметр температуры (по аналогии со статистической механикой), при температуре близкой к 0, вероятность максимального значения после софтмакса стремится к 1, при температуре близкой к бесконечности - распределение близко к равномерному. Подробнее про лосс почитайте в самой статье, там есть ещё оптимизация для расчёта софтмакса по всему датасету и регуляризация.

Также они привносят понятие memory bank - хранилище эмбеддингов для всех картинок. Оно необходимо, чтобы не пересчитывать эмбеддинги после каждого апдейта сети.
Это один из первых методов, где удалось достичь улучшение метрик на последнем свёрточном слое (conv5) относительно предпоследнего (conv4) (мы на это обращали внимание во второй статье). И отсюда можно проследить развитие по двум направлениям - кластерные методы и методы с притягиванием и отталкиванием различных изображений.

Contrastive learning

Contrastive learning - подход при котором обучение происходит не только по принципу близости, но и по принципы различия. Следовательно, для описания нам необходимы как позитивные примеры, так и негативные.

Существуют различные формулировки contrastive loss

  1. триплет лосс - выбирается 1 изображение как якорь, ему ставят в соответствие 1 позитивный и 1 негативный пример

  2. InfoNCE (Information Noise-Contrastive Estimation) - то же самое, но N негативных примеров (по сути мы уже рассмотрели его выше на примере InstDisc)

Основной сложностью использования такого лосса можно назвать правильный подбор негативных примеров (hard negative mining). Поскольку разметки по классам у нас нет, то различные изображения енотов будут учитываться как негативные пример, хотя семантически они близки.

Больше подробностей по contrastive лоссам можно поискать здесь

MoCo, MoCo v2 (Momentum Contrast)

📋K. He, H. Fan, Y. Wu et al. Momentum Contrast for Unsupervised Visual Representation Learning(Ноябрь 2019)
📋X. Chen, H. Fan, R. Girshick, et al. Improved Baselines with Momentum Contrastive Learning(Март 2020)

Статья наследует идею contrastive лосс и memory bank, добавляя momentum encoder. Мы разбирали эту идею ещё при рассмотрении BYOL, глобально она перекочевала сюда из reinforcement learning(RL) и заключается в том, что есть 2 сети одна из которых обновляется через механизм обратного распространения, а другая - скользящим средним от уже обновлённых весов и предыдущего состояния. Также авторы модифицировали концепцию memory bank - раньше он содержал информацию об эмбеддингах всего датасета и обновлялся 1 раз за эпоху (следовательно, содержал много устаревших представлений). Теперь же он представляет собой очередь из предыдущих N батчей, следовательно динамически обновляется.

MoCo v2 - небольшое обновление MoCo, в котором авторы добавили к себе парочку идей из SimCLR (см ниже) и показали, что у них всё ещё SOTA😜

PIRL (Pretext-Invariant Representation Learning)

📋I. Misra, L. Maaten. Self-Supervised Learning of Pretext-Invariant Representations(Декабрь 2019)
Авторы в этом методе совмещают концепцию contrastive лосс, memory bank и Jigsaw. Идея состоит в том, что представление всего изображения и комбинации его кусочков должны быть похожи и контрастировать с представлениями других изображений из memory bank. Ещё одно нововведение - что эмбеддинг в memory bank обновляется скользящим средним, а не полной заменой. Это играет большую роль в стабилизации обучения наряду с регуляризацией в InstDisc и с momentum encoder в MoCo.

SimCLR (Simple Framework for Contrastive Learning)

📋T. Chen, S. Kornblith, M. Norouzi et al. A Simple Framework for Contrastive Learning of Visual Representations [Google blogpost] (Июль 2020)

Статья во многом оказала сильное влияние на дальнейшее развитие направления. Рассмотрим ключевые тезисы отдельно:

  1. Авторы анализируют влияние выбора аугментаций на финальное качество модели. Они указывают на важность кропов изображений и дисторсии цвета (иначе как минимум цветовые гистораммы 2 аугментаций очень похожи, что позволяет модели выучивать эти признаки вместо семантики). Авторы приходят к выводу, что использование более суровых аугментаций приводит к лучшим результатам. Даже те, которые не улучшали качество при обычном supervised обучении.

  2. Авторы применяют промежуточный MLP слой, называемый projector, перед рассчётом лосса. Добавление такого слоя докидывает аж +10% к финальному качеству. Тут надо прояснить терминологию - в дальнейшее обучение downstream задачи всё так же идёт энкодер f(), который преобразует картинку в эмбеддинг y=f(x), просто для обучения используется дополнительный MLP слой projector g(), который порождает проекцию z=g(y)=g(f(x)) и именно она передаётся в лосс. Этот хак потом подтянули к себе многие последующие фреймворки, а MoCo даже выпустили ещё одну статью.

  3. Ещё один важный аспект - авторы избавились от memory bank. Негативный примеры теперь полностью берутся из батча. Авторы приводят ablation study, где показывают, что чем больше размер батча - тем лучше. Правда размер батча у них 8192, что могут себе позволить не только лишь все, но Гугл.

  4. Ещё один рецепт успеха - долгое обучение. 1000 эпох процентов на 10 пунктов лучше, чем 100 эпох.

Основная идея добавления блока projector в том, что он помогает адаптировать эмбединги из экстрактора под конкретно решаемую задачу (в данном случае - удовлетворить contrastive лосс). У этой задачи есть определённая специфика, она отличается от downstream задачи. При такой постановке задачи projector забирает часть специфики, позволяя обучать более обобщённое представление. Авторы в статье показывают, что проекция содержит гораздо меньше информации о применённых аугментация, нежели эмбеддинги. Впрочем это ничего не доказывает и, как всегда, не хватает теоретических доказательств. Пока есть только практические, этот трюк работает.

NNCLR (Nearest-Neighbor Contrastive Learning of visual Representations)

📋D. Dwibedi, Y. Aytar, J. Tompson et al. With a Little Help from My Friends: Nearest-Neighbor Contrastive Learning of Visual Representations(Апрель 2021)

Авторы развивают идею SimCLR, только в качестве похожих изображений предлагают брать не 2 эмбеддинга от 2-х разных аугментаций, а эмбеддинг и ближайшего к нему соседа из memory bank (у них это называется Support set, формируется как в MoCo - очередь из предыдущих батчей). Интересно, что такой подход позволяет ослабить тяжесть применяемых аугментаций, скорее всего за счёт того, что мы всё равно будем использовать соседа (который в семантическом пространстве находится дальше от картинки, чем её аугментация).

Вместо итога

В статье мы рассмотрели основные методы обучения SSL с использованием Contrastive loss. Это эффективный подход, однако требовательный к большому количеству негативных примеров при расчёте лосса (для этого используются memory bank либо большой размер батча).
В следующеё статье мы посмотрим как сравнивать не только с несколькими примерами, но сразу со всем доступным датасетом!

Список статей в цикле

  1. SSL. Проблематика и постановка задачи

  2. SSL. Метрики и первые pretext tasks

  3. SSL. Обучение на изображении и его аугментациях

  4. SSL. Contrastive learning

  5. SSL. Кластеризация как лосс

  6. SSL. Результаты и основные фреймворки

Tags:
Hubs:
Total votes 6: ↑6 and ↓0+6
Comments0

Articles