Below we create a dummy trainer and evaluator for quick experimentation. We added print statements at different events to show how things work. There is no meaningful training logic in the code, so it runs fast. This setup can easily be modified to test other ideas.
Optionally, to install the latest version:
!pip install --upgrade pytorch-ignite
import torch
import ignite
print(torch.__version__, ignite.__version__)
from ignite.engine import Engine, Events
from ignite.utils import setup_logger, logging
# Dummy data and schedule for quick experimentation: 10 "batches" per
# training epoch, 4 batches per evaluation run, 5 epochs total.
train_data = range(10)
eval_data = range(4)
max_epochs = 5
def train_step(engine, batch):
    """Dummy training step: no real work, just report progress for this batch."""
    state = engine.state
    progress = f"{state.epoch} / {state.max_epochs} | {state.iteration}"
    print(f"{progress} - batch: {batch}", flush=True)
# Trainer runs train_step on every batch of train_data.
trainer = Engine(train_step)
# Enable trainer logger for a debug mode
# trainer.logger = setup_logger("trainer", level=logging.DEBUG)
# Evaluator does nothing per batch; it exists only to demonstrate events.
evaluator = Engine(lambda e, b: None)
@trainer.on(Events.EPOCH_COMPLETED(every=2))
def run_validation():
    """Run the evaluator on eval_data after every 2nd training epoch."""
    state = trainer.state
    print(f"{state.epoch} / {state.max_epochs} | {state.iteration} - run validation", flush=True)
    evaluator.run(eval_data)
@trainer.on(Events.ITERATION_COMPLETED(every=7))
def log_events_filtering__every():
    """Demonstrate the `every=` event filter: fires on every 7th iteration."""
    state = trainer.state
    print(f"{state.epoch} / {state.max_epochs} | {state.iteration} - calling log_events_filtering__every", flush=True)
@trainer.on(Events.EPOCH_COMPLETED(once=3))
def log_events_filtering__once():
    """Demonstrate the `once=` event filter: fires only when epoch 3 completes."""
    state = trainer.state
    print(f"{state.epoch} / {state.max_epochs} | {state.iteration} - calling log_events_filtering__once", flush=True)
def custom_event_filter(engine, event):
    """Custom event filter: accept evaluator iterations 1 and 3, but only
    while the (global) trainer is in epoch 2.

    Note: deliberately reads the global `trainer`, not the `engine` argument
    (which is the evaluator this filter is attached to).
    """
    return trainer.state.epoch == 2 and event in (1, 3)
@evaluator.on(Events.ITERATION_COMPLETED(event_filter=custom_event_filter))
def log_events_filtering__event_filter():
    """Demonstrate `event_filter=`: fires only when custom_event_filter accepts."""
    t_state = trainer.state
    print(f"{t_state.epoch} / {t_state.max_epochs} | {evaluator.state.iteration} - calling log_events_filtering__event_filter", flush=True)
# Kick off the dummy training loop; all the handlers above fire during it.
trainer.run(train_data, max_epochs=max_epochs)
import time
import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.utils import setup_logger, logging
def pprint(*args, **kwargs):
    """Print with a "Rank N:" prefix.

    Each rank sleeps proportionally to its rank first, so lines from
    different processes tend not to interleave in the shared console.
    """
    current_rank = idist.get_rank()
    time.sleep(0.1 * current_rank)
    print(f"Rank {current_rank}:", end=" ")
    print(*args, **kwargs)
def run(local_rank):
    """Per-process entry point for the distributed demo.

    Seeds torch per rank, builds a dummy trainer, and runs it with a
    barrier at the end of each epoch so all processes stay in lockstep.
    """
    rank = idist.get_rank()
    # Different seed per rank so each process would see different randomness.
    torch.manual_seed(12 + rank)

    # Same dummy data/schedule as the single-process example above.
    train_data = range(10)
    eval_data = range(4)
    max_epochs = 5

    def train_step(engine, batch):
        # Rank-prefixed progress report; no real training logic.
        state = engine.state
        pprint(f"{state.epoch} / {state.max_epochs} | {state.iteration} - batch: {batch}", flush=True)

    trainer = Engine(train_step)

    @trainer.on(Events.EPOCH_COMPLETED)
    def sync():
        # Synchronize all processes at the end of every epoch.
        idist.barrier()

    trainer.run(train_data, max_epochs=max_epochs)
idist.spawn("gloo", run, (), nproc_per_node=2)