%matplotlib inline %reload_ext autoreload %autoreload 2 from fastai.conv_learner import * from fastai.dataset import * from cgan.options.train_options import * opt = TrainOptions().parse(['--dataroot', '/data0/datasets/cyclegan/horse2zebra', '--nThreads', '8', '--no_dropout', '--niter', '100', '--niter_decay', '100', '--name', 'nodrop', '--gpu_ids', '2']) from cgan.options.train_options import TrainOptions from cgan.data.data_loader import CreateDataLoader from cgan.models.models import create_model data_loader = CreateDataLoader(opt) dataset = data_loader.load_data() dataset_size = len(data_loader) dataset_size model = create_model(opt) # opt.niter=9 # opt.niter_decay=1 total_steps = 0 for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): epoch_start_time = time.time() iter_data_time = time.time() epoch_iter = 0 for i, data in tqdm(enumerate(dataset)): iter_start_time = time.time() if total_steps % opt.print_freq == 0: t_data = iter_start_time - iter_data_time total_steps += opt.batchSize epoch_iter += opt.batchSize model.set_input(data) model.optimize_parameters() if total_steps % opt.display_freq == 0: save_result = total_steps % opt.update_html_freq == 0 if total_steps % opt.print_freq == 0: errors = model.get_current_errors() t = (time.time() - iter_start_time) / opt.batchSize if total_steps % opt.save_latest_freq == 0: print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) model.save('latest') iter_data_time = time.time() if epoch % opt.save_epoch_freq == 0: print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) model.save('latest') model.save(epoch) print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) model.update_learning_rate() def show_img(im, ax=None, figsize=None): if not ax: fig,ax = plt.subplots(figsize=figsize) ax.imshow(im) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) return ax def get_one(data): model.set_input(data) model.test() return list(model.get_current_visuals().values()) model.save(201) test_ims = [] for i,o in enumerate(dataset): if i>10: break test_ims.append(get_one(o)) def show_grid(ims): fig,axes = plt.subplots(2,3,figsize=(9,6)) for i,ax in enumerate(axes.flat): show_img(ims[i], ax); fig.tight_layout() for i in range(8): show_grid(test_ims[i]) #! wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip