GitHub - DmitryUlyanovAGE Код статті; Змагальні мережі генераторів-кодерів

Використовуйте скрипт age.py для навчання моделі. Ось найважливіші параметри:

статті

  • --набір даних: один із [celeba, cifar10, imagenet, svhn, mnist]
  • --dataroot: для наборів даних, включених до torchvision, це каталог, куди буде завантажено все; для наборів даних imagenet, celeba це шлях до каталогу з папками train і val всередині.
  • --image_size:
  • --save_dir: шлях до папки, де будуть зберігатися контрольні точки
  • --nz: розмірність прихованого простору
  • -- batch_size: розмір партії. За замовчуванням 64.
  • --netG: файл .py із визначенням генератора. Шукали в каталозі моделей
  • --netE: .py файл із визначенням генератора. Шукали в каталозі моделей
  • --netG_chp: шлях до контрольної точки генератора для завантаження
  • --netE_chp: шлях до контрольної точки кодера, з якої завантажується
  • --nepoch: кількість епохи, яку потрібно запустити
  • --start_epoch: номер епохи, з якого починається. Корисно для дооснащення.
  • --e_updates: План оновлення кодера.; KL_fake:, KL_real:, match_z:, match_x: .
  • --g_updates: План оновлення генератора.; KL_fake:, match_z:, match_x: .

І різні аргументи:

  • --worker: кількість працівників завантажувача даних.
  • --ngf: контролює кількість каналів у генераторі
  • --ndf: контролює кількість каналів у кодері
  • --beta1: параметр оптимізатора ADAM
  • --процесор: не використовувати графічний процесор
  • --критерій: Параметричний параметр або непараметричний непараметричний спосіб обчислення KL. Параметричний вписує Гауса в дані, непараметричний базується на найближчих сусідах. За замовчуванням: парам .
  • --KL: Який KL обчислити: qp або pq. За замовчуванням qp .
  • --шум: сфера для рівномірного на кулі або гауссова. Сфера за замовчуванням .
  • --match_z: втрата для використання як втрата при реконструкції у прихованому просторі. L1 | L2 | cos. За замовчуванням cos .
  • --match_x: втрата для використання як відновлення втрат у просторі даних. L1 | L2 | cos. За замовчуванням L1 .
  • --drop_lr: з кожним періодом drop_lr знижується рівень навчання.
  • --save_every: контролює, як часто зберігаються проміжні результати. За замовчуванням 50 .
  • --manual_seed: випадкове насіння. За замовчуванням 123 .