Pull to refresh

Обзор архитектур image-to-image translation

Level of difficultyMedium
Reading time13 min
Views6.1K

Привет, Хабр! Я работаю инженером компьютерного зрения в направлении искусственного интеллекта компании Норникель. Мы разрабатываем и внедряем модели с применением машинного обучения на наши производственные площадки.

В скоуп наших проектов попадают как системы, управляющие (или частично управляющие) технологическим процессом (например, флотация или плавка), так и системы промышленного машинного зрения, которые по сути представляют из себя одну из разновидностей датчиков.

В этой статье я расскажу про основные архитектуры генеративных сетей для задачи перевода изображения из одного домена в другой (image-to-image translation). В конце расскажу, для чего именно мы применяем синтетические данные и приведу примеры изображений, которых нам удалось достичь. Но перед погружением в данную тему рекомендую ознакомиться с тем, что такое свёрточная сеть, U-Net и генеративная сеть. Если же Вы готовы, поехали.

Каждое производство в той или иной степени индивидуально. Картинки на разных площадках могут сильно отличаться. Поэтому хочется сократить время, которое уходит на разметку данных. Например, при помощи генерации изображений по бинарной маске.

Слева - конвейер с дроблёным файнштейном, справа - семантическая маска
Слева - конвейер с дроблёным файнштейном, справа - семантическая маска

Существующие архитектуры можно разделить на 2 типа: supervised и unsupervised. Для supervised подхода входная маска и целевое изображение должны соответствовать друг другу, как на картинке выше. Для unsupervised такого условия не требуется. В обоих случаях мы хотим сгенерировать изображение таким образом, чтобы сохранилось содержимое маски (расположение объектов), при этом стиль сгенерированного изображения совпадал с целевым.

Supervised image-to-image translation

pix2pix

[Статья, реализация]

Обучение pix2pix
Обучение pix2pix

Про архитектуру

Генератор представляет собой U-Net подобную архитектуру, на вход которой подаётся бинарная маска. Дискриминатор - это классификатор типа PatchGAN, на вход которому подаются бинарная маска и целевое изображение.

Как и в задаче сегментации, в архитектуре генератора нужны skip connections из энкодер блоков в декодер блоки. Но есть и некоторые отличия от архитектур сегментации.

Например, при обучении ГАНов для уменьшения размерности рекомендуется заменить операции max pooling на свёртки со страйдом. А в качестве финальной активации используется гиперболический тангенс, т.к. сгенерированные изображение мы затем подаём на вход дискриминатору, т.е. мат. ожидание выходного изображения должно быть нулевым, а дисперсия единичной.

PatchGAN - это свёрточный классификатор, который классифицирует каждый патч (реальный или фейковый). Реализовано это следующим образом.

NLayerDiscriminator
import functools

import torch.nn as nn


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
      

В дискриминаторе ReLU активации заменены на LeakyRELU, удалены все fully connected слои, а на выходе нет никакой активации, т.е. выходные значения ничем не ограничены. Теоретически это может привести к gradient explosion, однако в данной конфигурации это работает стабильно.

Про обучение и борьбу с затуханием градиентов

Модели обучаются в генеративно-состязательном режиме. Т.е. цель дискриминатора - научиться отличать оригинальные изображения от сгенерированных, цель генератора - научиться генерировать такие изображения, что дискриминатор не сможет их отличить от реальных.

\min\limits_{G} \max\limits_{D} V(D, G)=\mathbb{E}_{x\sim p_{data}(x)}[logD(x)] + \mathbb{E}_{z\sim p_{z}(z)}[log(1-D(G(z)))].

Такой способ обучения называется minimax game. В этой игре дискриминатор минимизирует кросс-энтропию между целевым классом и предсказанным распределением а генератор наоборот максимизирует кросс-энтропию. Это приводит к тому, что, когда дискриминатор слишком уверенно отклоняет сгенерированные изображения, градиенты генератора затухают.

Один из способов для борьбы с затуханием градиентов генератора - non-saturating game, который был найден эвристическим путём:

Вместо минимизации

\mathbb{E}_{z\sim p_{z}(z)}[log(1-D(G(z)))]

для генератора максимизируем

\mathbb{E}_{z}logD(G(z))

Авторы pix2pix и используют non-saturating game.

Также они добавляют к лосс-функции (L2 loss даёт более размытые результаты):

\mathcal{L}_{L1}(G)=   \mathbb{E}_{x, y,z}[\|y - G(x, z)\|_{1}].

И L1 и L2 loss между реальным и сгенерированным изображения хорошо улавливают низкочастотные детали, но не высокочастотные детали.

cGAN - non-saturating game; L1 - minimax game + L1 loss between real and fake images
cGAN - non-saturating game; L1 - minimax game + L1 loss between real and fake images

Для увеличения фокуса на высокочастотных деталях, картинку делят на патчи, как это и сделано в PatchGAN.

Архитектура pix2pix является базовой для supervised image-to-image translation. Pix2pixHD и SPADE, которые мы рассмотрим дальше, улучшают её идеи.

pix2pixHD

[Статья, реализация]

Данная архитектура, как следует из названия, предназначена для генерации изображений более высокого разрешения, чем её предшественница.

Про архитектуру

Можно выделить следующие модификации:

Multi-scale discriminators

Исходная картинка уменьшается в 2 и 4 раза - получаем 3 изображения разного масштаба. На каждый масштаб создаётся свой дискриминатор PatchGAN, которые имеют одинаковую структуру. Реализация.

MultiscaleDiscriminator
class MultiscaleDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
     
        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:                                
                for j in range(n_layers+2):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))                                   
            else:
                setattr(self, 'layer'+str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):        
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
            else:
                model = getattr(self, 'layer'+str(num_D-1-i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D-1):
                input_downsampled = self.downsample(input_downsampled)
        return result

Progressive generator

Архитектура генератора pix2pixHD
Архитектура генератора pix2pixHD

Сперва сеть G1 обучается на низком разрешении, потом добавляются слои G2 и модель дообучается на высоком разрешении.

Про обучение

Feature matching + perceptual loss

Во время обучения из дискриминатора извлекаются промежуточные признаки с разных слоёв и считается L1 loss между признаками реального изображения и сгенерированного. Реализация.

Perceptual loss похож на Feature matching, но вместо дискриминатора используется предобученная на Imagenet модель, например, VGG16 (во время обучения GANов она не обучается). Реализация.

Обе лосс-функции стабилизируют обучение, т.к. генератор должен производить правдоподобные статистики на разных масштабах изображения.

Lsgan

Вместо cross-entropy в minimax game считаем L2 loss [lsgan].

Про instance-wise генерацию. Boundary map

Пример (рис. (а) внизу) сгенерированного изображения по семантической маске и пример (рис. (b) внизу) сгенерированного изображения по семантической маске и boundary map
Пример (рис. (а) внизу) сгенерированного изображения по семантической маске и пример (рис. (b) внизу) сгенерированного изображения по семантической маске и boundary map

Ко входной семантической маски добавляется boundary map, которая содержит instance-wise информацию. Это делается для того, чтобы разные объекты одного класса не сливались в один объект, как на рисунке выше.

Про instance-wise генерацию. Feature encoder network    

Instance-wise генерация при помощи Feature Encoder Network
Instance-wise генерация при помощи Feature Encoder Network

Feature Encoder Network представляет собой encoder-decoder архитектуру (без skip-connections), выход которой подаётся на слой instance-wise average pooling.  Новый слой усредняет признаки для каждого инстанса по boundary map, полученная маска добавляется к семантической маске.

Таким образом, во время инференса для одной и той же семантической маски мы можем сгенерировать разные изображения, подавая на вход сети E, случайные изображения из нашего набора (instance-wise average pooling усредняет признаки, в соответствии с семантической маской).

 Энкодер E обучается совместно с генератором и дискриминатором.

SPADE

[Статья, реализация]

Про архитектуру

Multi-sclae discriminators

Как и в pix2pixHD используются Multi-scale discriminators.

Feature matching + perceptual loss

Также как и в pix2pixHD.

SPADE generator

Генератор сильно отличается от двух предыдущих архитектур. На вход генератору подаётся либо случайный вектор (если используем конфигурацию с VAE) либо маска уменьшенной размерности.

В конфигурации с автоэнкодером вводится дополнительная сеть - вариационный энкодер, на выходе которого и получаем случайный вектор (входной вектор для генератора). Энкодер и генератор, таким образом, формируют VAE.

Архитектура SPADE
Архитектура SPADE

Архитектура состоит из несколько SPADE Res блоков, внутри которых применяется новый слой нормализации - SPADE.

Spatially-adaptive denormalization (SPADE)
Spatially-adaptive denormalization (SPADE)

Как видно на рисунке маска прогоняется через свёртки для получения параметров гамма и бетта (В BatchNorm2d  либо InstanceNorm2d параметр affine=False) [реализация].

Когда мы применяем InstanceNorm2d на однотонную маску, практически вся информация теряется (см. рисунок ниже). Блок SPADE это исключает.

Генерация изображения по однотонной маске. Результаты у pix2pixHD и SPADE
Генерация изображения по однотонной маске. Результаты у pix2pixHD и SPADE

Spectral Norm

Авторы также добавляют спектральную нормализацию во все архитектуры - генератор, дискриминатор и энкодер. Например, для свёрток в SPADEResnetBlock и слоёв нормализации Batch/InstanceNorm.

\textbf W_{SN}=\frac{\textbf W}{\sigma(\textbf W)}, \sigma(\textbf W) = \max\limits_{\textbf h: \textbf h\ne0} \frac {\|\textbf W \textbf h\|_{2}}{\|\textbf h\|_{2}}.

Спектральная норма (сигма) по сути измеряет, как сильно матрица W может растянуть любой ненулевой вектор h, и равна максимальному сингулярному значению матрицы W [proof].

Спектральная нормализация стабилизирует процесс обучения ГАНов за счёт того, что делает константу Липшица L = 1 для всех слоев.

Про обучение

Обучается SPADE таким же образом, как и pix2pixHD, только вместо LSGAN используется Hinge loss.

Unsupervised image-to-image translation

Cycle-GAN

[Cтатья, реализация]

(a) Перевод изображения из домена X в Y и наоборот, (b) cycle-consistency loss для домена X, (c) cycle-consistency loss для домена Y
(a) Перевод изображения из домена X в Y и наоборот, (b) cycle-consistency loss для домена X, (c) cycle-consistency loss для домена Y

Про архитектуру

Архитектуры генератора и дискриминатора в Cycle-GAN такие же, как и в pix2pix. Меняется только постановка задачи и процедура обучения.

Про обучение

Мы хотим перевести изображения из домена X в домен Y, сохранив при этом содержимое изображения. И наоборот, из домена Y в X. Т.е. решаем задачу style transfer.

Создаём два генератора G : X -> Y, F: Y -> X и два дискриминатора Dx, Dy. Каждую пару генератора и дискриминатора обучаем, например, через minimax game. И добавляем ещё две лосс-функции cycle consistency losses, необходимые для сохранения содержимого изображения.

\mathcal{L}_{cyc}(G, F)=   \mathbb{E}_{x\sim p_{data}(x)}[\|F(G(x)) - x\|_{1}] + \mathbb{E}_{y\sim p_{data}(y)}[\|G(F(y)) - y\|_{1}].
Примеры сгенерированных и восстановленных изображений
Примеры сгенерированных и восстановленных изображений

UGATIT

[Статья, реализация]

Архитектура UGATIT
Архитектура UGATIT

Про архитектуру

UGATIT архитектура похожа на Cycle GAN. Она состоит из двух генераторов (по одному на каждый домен) и 4 дискриминаторов (по 2 на каждый домен).

Дискриминаторы одного домена - это два PatchGAN разной глубины (таким образом, реализуется multi-scale дискриминатор).

Генератор представляет собой Resnet-based архитектуру (без skip connections).

Внутри Resnet блоки с AdaLIN (adaptive layer instance normalization)

AdaLIN(a,\gamma, \beta) = \gamma \cdot  (\rho\cdot \hat a_{I} + (1 - \rho)\cdot \hat a_{L}) + \beta,\hat a_{I}=\frac{a - \mu_{I}}{\sqrt {\sigma^2_{I} + \epsilon}}, \hat a_{L}=\frac{a - \mu_{L}}{\sqrt {\sigma^2_{L} + \epsilon}},\rho \leftarrow clip_{[0,1]} (\rho-\tau\ \Delta\rho),

где μI , μL and σI , σL - это channel-wise и layer-wise мат. ожидание и стандартное отклонений активаций соответственно, γ и β - параметры, генерируемые fully connected слоем, τ - это learning rate и ∆ρ - градиент, определённый оптимайзером. Значения ρ ограничены отрезком [0, 1].

Про обучение

Внутри дискриминатора есть дополнительная классификационная голова. Ей на вход подаются 2D карты признаков, по ним с помощью avg_pooling и fully connected слоёв считаются веса и умножаются на соответствующие карты признаков.

Во время инференса по этим картам строятся class activation map (CAM; attention feature map; heatmap)

Полученный 2d тензор и есть выход дополнительной классификационной головы. Получаем + 4 дополнительных лосса для обучения генератора:

G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device))

G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device))
G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device))

И дискриминатора:

D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) + self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))

D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))
D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) + self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))

Основные лоссы (minimax game):

G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))

G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))

D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))

D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))
D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))

Также как и в cycle gan добавляются cycle consistency losses

L^{s\rightarrow t}_{cycle}=\mathbb{E}_{x \sim X_{s}}[|x-G_{t \rightarrow s}(G_{s \rightarrow t}(x))|_{1}].

И identity loss (для сохранения содержимого):

L^{s\rightarrow t}_{identity}=\mathbb{E}_{x \sim X_{t}}[|x-(G_{s \rightarrow t}(x))|_{1}].
fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

В генераторы, как и в дискриминаторы, добавляются дополнительные классификационные головы, которая предсказывают принадлежит ли изображение домену А или домену B, в зависимости от генератора.

G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))
Пример сгенерированных CAM: (a) - исходное изображение; (b) - CAM генератора (c-d) - CAM локального и глобального дискриминатора (e) - сгенерированное изображение с CAM головой; (f) - сгенерированное изображение без CAM головы
Пример сгенерированных CAM: (a) - исходное изображение; (b) - CAM генератора (c-d) - CAM локального и глобального дискриминатора (e) - сгенерированное изображение с CAM головой; (f) - сгенерированное изображение без CAM головы

Финальный лосс:

\min\limits_{G} \max\limits_{D} \lambda_{1}L_{lsgan}+\lambda_{2}L_{cycle} + \lambda_{3}L_{identity} + \lambda_{4}L_{cam}

MUNIT

[Статья, реализация]

Архитектура генератора MUNIT
Архитектура генератора MUNIT

Про архитектуру

Данная архитектура значительно отличается от двух предыдущих. В отличие от них, MUNIT кодирует исходное изображение на содержимое и стиль при помощи Content and Style Encoder.

Style Encoder состоит из strided свёрток, global average pooling и fully connected слоя. Instance Normalization в Style Encoder не используется, т.к. он удаляет оригинальное мат. ожидание и дисперсию признаков, которые содержат информацию о стиле.

Content encoder состоит из тех же свёрток и residual blocks.

Декодер по содержимому и стилю восстанавливает исходное изображение. Он состоит из свёрток и интерполяции (nearest-neighbor interpolation) для повышения размерности. Внутри residual blocks используется Adaptive Instance Normalization.

AdaIN(z,\gamma,\beta) = \gamma (\frac{z-\mu(z)}{\sigma(z)}) + \beta,

где z - это активация предыдущего свёрточного слоя, μ и σ - это channel-wise мат. ожидание и стандартное отклонение активаций, γ и β - это параметры, генерируемые fully connected слоем.

Дискриминатор - multi-scale дискриминатор с LSGAN лоссом (в minimax game)

Про обучение

Также как и в pix2pixHD используется perceptual loss. Перед вычислением дистанции к эмбэддингам применяется Instance Normalization, чтобы удалить информацию о стиле.

MUNIT состоит из двух автоэнкодеров (красные и синие стрелочки соответственно)
MUNIT состоит из двух автоэнкодеров (красные и синие стрелочки соответственно)

MUNIT обучается одновременно и как автоэнкодер и как GAN. Для преобразования изображения из домена А в B и наоборот они пропускаются через style и content энкодер. У каждого домена свои энкодеры. Затем, чтобы получить изображение из домена B(с2 - содержимое; s2 - стиль), но с содержимым изображения из домена А (с1 - содержимое; s1 - стиль), мы заменяем s1 на s2 и пару (с1, s2) пропускаем через декодер для домена B. Аналогично для преобразования из B в А.

Для обучения как обычно используется adversarial loss (minimax game), reconstruction loss (cycle consistency losses) для изображений, как и в двух предыдущих моделях, и reconstruction loss для содержимого и стиля.

\min\limits_{E_{1}, E_{2}, G_{1}, G_{2}} \max\limits_{D_{1}, D_{2}} \mathcal{L}(E_{1}, E_{2}, G_{1}, G_{2}, D_{1}, D_{2}) =  \mathcal{L}^{x_{1}}_{GAN} + \mathcal{L}^{x_{2}}_{GAN} + \lambda_{x}( \mathcal{L}^{x_{1}}_{recon} +  \mathcal{L}^{x_{2}}_{recon}) +  \\ \lambda_{c}( \mathcal{L}^{c_{1}}_{recon} +  \mathcal{L}^{c_{2}}_{recon}) + \lambda_{x}( \mathcal{L}^{x_{1}}_{recon} + \mathcal{L}^{x_{2}}_{recon}) \mathcal{L}^{x_{1}}_{recon} =\mathbb{E}_{x_{1} \sim p(x_{1})}[\| G_{1}(E^{c}_{1}(x_{1}), E^{s}_{1}(x_{1})) - x_{1} \|_{1}] \mathcal{L}^{c_{1}}_{recon} =\mathbb{E}_{c_{1} \sim p(c_{1}), s_{2} \sim q(s_{2}))}[\| E^c_{2}(G_{2}(c1, s2)) - c_{1} \|_{1}] \mathcal{L}^{s_{2}}_{recon} =\mathbb{E}_{c_{1} \sim p(c_{1}), s_{2} \sim q(s_{2}))}[\| E^s_{2}(G_{2}(c1, s2)) - s_{2} \|_{1}] \mathcal{L}^{x_{2}}_{GAN} =\mathbb{E}_{c_{1} \sim p(c_{1}), s_{2} \sim q(s_{2})}[log(1-D_{2}(G_{2}(c_{1}, s_{2})))] + \mathbb{E}_{x_{2} \sim p(x_{2})}[logD_{2}(x_{2})].

Что получилось у нас

И так мы разобрали основные, на мой взгляд, архитектуры в image-to-image translation, а теперь мне бы хотелось рассказать, зачем мы это применяли у себя и как это нам помогло.

Для чего применяем

Нередко у нас появляется задача - обучить Unet/LinkNet модель, чтобы точно находить объекты и определять их размеры. Разметка таких объектов крайне кропотливое занятие, а новые домены могут сильно отличаться от тех, с которыми мы уже умеем работать.

Руда на конвейере
Руда на конвейере
Флотация
Флотация

Как вариант, можно применять Domain Adaptation методы (DA), но иногда домены отличаются довольно сильно, например, руда и пузыри. В таком случае DA не поможет. Поэтому мы захотели научиться генерировать синтетические изображения по бинарной маске, имея всего лишь небольшое количество размеченных изображений (~ 100).

Генерация масок

До сих пор мы ничего не говорили про генерацию самих масок, хотя это тоже довольно важный этап для получения разнообразных и качественных изображений.

Мы попробовали обучить ProgressiveGAN, но получилось плохо. В масках было мало разнообразия.

Пример сгенерированных масок ProgressiveGAN
Пример сгенерированных масок ProgressiveGAN

Наилучшие результаты получились с масками, которые генерировали сегментационные модели, и с рандомно расставленными по пустому изображению масками объектов (из небольшого набора).

Генерация изображений по маскам

Мы попробовали все приведённые в статье архитектуры и наилучшие результаты показали SPADE (среди supervised архитектур) и UGATIT (среди unsupervised).

Результат генерации SPADE на плотной маске
Результат генерации SPADE на плотной маске
Результат генерации SPADE на разряжённой маске (на чёрном фоне появились артефакты)
Результат генерации SPADE на разряжённой маске (на чёрном фоне появились артефакты)
Результат генерации SPADE; слева - семантическая маска, сгенерированная UNet; в центре - реальное изображение; справа - синтетическое; Здесь можно заметить, что синтетическое изображение выглядит реалистично и, плюс к этому, оно лучше соответствует маске, чем реальное
Результат генерации SPADE; слева - семантическая маска, сгенерированная UNet; в центре - реальное изображение; справа - синтетическое; Здесь можно заметить, что синтетическое изображение выглядит реалистично и, плюс к этому, оно лучше соответствует маске, чем реальное
Результат генерации UGATIT на плотной маске
Результат генерации UGATIT на плотной маске
Результат генерации UGATIT на разряжённой маске
Результат генерации UGATIT на разряжённой маске

Таким образом, мы сгенерировали набор синтетических данных (порядка 10 000 изображений), преодобучили UNet модель на них, а затем дообучили на реальных данных. Это дало ~1-2% прироста в точности по метрикам precision, recall (для объектов, не для пикселей) по сравнению с моделью, обученной только на реальных данных.

Заключение

Генеративные сети можно применять для предобучения моделей либо в качестве аугментации. Сгенерированные, таким образом, пары  - маска и картинка - лучше соответствуют друг другу, чем пара - реальное изображения и маска, сгенерированная сегментационной моделью. Но совсем без разметки не обойтись.

Качество unsupervised методов сильно ниже supervised. Последние же хоть и генерируют реалистичные изображения, но недостаточно разнообразные.

Также стоит отметить, что и у supervised и у unsupervised архитектур появляются артефакты на разряжённых масках.

To be discussed

А используете ли Вы генеративные сети у себя на работе? И если да, то какие архитектуры?

Tags:
Hubs:
Total votes 12: ↑12 and ↓0+12
Comments2

Articles

Information

Website
www.nornickel.ru
Registered
Employees
over 10,000 employees
Location
Россия