본문 바로가기
AI/파이토치(Pytorch)

[딥러닝][파이토치] 이그나이트_엔진(Ignite_Engine)

by Hyen4110 2021. 9. 3.

이 글은 아래 튜토리얼을 따라서 공부한 내용을 정리한 글입니다 :) 

 > Pytorch Ignite Concepts

https://pytorch.org/ignite/concepts.html

 

Ignite Your Networks! — PyTorch-Ignite v0.4.6 Documentation

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.

pytorch.org

 > Convolutional Neural Networks for Classifying Fashion-MNIST Dataset using Ignite

https://colab.research.google.com/github/pytorch/ignite/blob/master/examples/notebooks/FashionMNIST.ipynb

 

1. 파이토치 이그나이트(Ignite)

1.1 파이토치 이그나이트(Ignite)란?

: 파이토치 이그나이트(Ignite)는 파이토치의 신경망을 유연하고 투명하게 훈련하고 평가하는 데 도움이 되는 고급 라이브러리입니다.

 

2. 이그나이트의 주요 클래스

2.1 엔진(Engine)

2.1.1. 엔진(Engine)이란?

: 엔진은 데이터를 주어진 숫자 만큼의 처리(processing function)를 반복하고 결과값을 반환하는 클래스입니다.

  (딥러닝에서 흔히 이야기할 수있는 예로, 모델을 학습시키는 학습기를 하나의 엔진이라고 볼 수있겠습니다.)

 

: 개념에 대한 설명을 코드로 아래 처럼 정리할 수있습니다.

   - max_epoch 만큼 while문을 반복하면서  process_fuction을 실행하고,

    실행한 iteration수(="iter_coutner")가 max_epochs가 되면 while문을 종료합니다. 

    (추가적으로 실행한 iteration수가 데이터 크기와 같아지면, 역시 while문을 종료합니다)

while epoch < max_epochs:
	# max_epochs : 최대 에포크 수(반복하여 실행할 숫자)
    
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            output = process_function(batch)
            iter_counter += 1
        except StopIteration:
            data_iter = iter(data)

        if iter_counter == epoch_length:
        # epoch_length : 데이터크기 ==len(data)
            break

 

* max_epochs와 epoch_length
 
- max_epochs : 최대 에포크 수(반복 횟수)
- epoch_length : 데이터의 크기  = len(data)

-> [Q] epoch_length는 len(data)로 대체할 수있는데 왜 굳이 변수로 존재하는가?
: [A] 만약 데이터의 크기를 알 수 없는 경우 data iterator가 과부하가 걸려서 중단될 수 있으므로, 이런 경우 epoch_length 임의로 지정하여 사용한다.

 : 위에서 말씀드린 것처럼, 모델 훈련기(model trainer)는 훈련 데이터셋을 여러번 반복하여 돌면서 모델의 파라미터를 업데이트하기 때문에 엔진(Engine)에 해당된다고 할 수 있습니다.

 : 지도학습에 해당하는 모델 훈련기(model_trainer)는 아래 코드와 같이 간단하게 써볼 수 있습니다. 

def train_step(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

 

: 위의 간단한 예시 외에도 학습 로직에 필요한 다양성은 아래 코드와 같이 필요에 맞게 사용자가 정의하여 추가할 수있습니다. 

model_1 = ...
model_2 = ...
# ...
optimizer_1 = ...
optimizer_2 = ...
# ...
criterion_1 = ...
criterion_2 = ...
# ...

def train_step(trainer, batch):

    data_1 = batch["data_1"]
    data_2 = batch["data_2"]
    # ...

    model_1.train()
    optimizer_1.zero_grad()
    loss_1 = forward_pass(data_1, model_1, criterion_1)
    loss_1.backward()
    optimizer_1.step()
    # ...

    model_2.train()
    optimizer_2.zero_grad()
    loss_2 = forward_pass(data_2, model_2, criterion_2)
    loss_2.backward()
    optimizer_2.step()
    # ...

    # User can return any type of structure.
    return {
        "loss_1": loss_1,
        "loss_2": loss_2,
        # ...
    }

trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

 

3. 엔진 클래스(Engine Class)

3.1 엔진 클래스

ignite.engine.engine.Engine(process_function)

: 엔진(Engine) 클래스는 데이터의 미니배치마다 주어진 함수(process_function)를 반복합니다.

 

항목   설명
Parameters process_function : 엔진에서 iteration의 현재 batch마다 수행되는 함수로, 엔진의 state에 데이터를 저장하여 반환합니다.

- state : 내부적으로 사용자가 정의한 state를 event handlers 간에 전달하는데 사용되는 객체로, 엔진과 함께 생성되며 그 속성은(예: state.iteration, state.epoch) 매 run() 마다 초기화 됩니다.

- last_event_name : 엔진에 의하여 유도된 마지막 이벤트 이름

 

<Basic Trainer>

def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")
trainer.run(data_loader, max_epochs=5)

> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01

 

<Basic Evalutator>

from ignite.metrics import Accuracy

def predict_on_batch(engine, batch)
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)

    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)

 

3.1 엔진 클래스의 메소드

add_event_handler

add_event_handler
(event_namehandler*args**kwargs)

:지정된 
Event가 발생할 때 실행할 Event Handler를 추가합니다.
 

 

Methods 설명
 

 
fire_event 주어진 Event와 관련된 모든 Handler를 실행합니다.
has_event_handler 지정된 Event에 지정된 Handler가 있는지 확인합니다.
load_state_dict state_dict로 엔진을 세팅합니다. 
on add_event_handler에 대한 데코레이터 바로 가기
register_events 실행할 수 있는 Event를 추가합니다.
run 전달된 데이터에 대해 process_function을 실행합니다.
set_data 데이터 set
state_dict 엔진 상태를 포함하는 사전을 반환합니다
: "seed", "epoch_length", "max_epochs" 및 "iteration" 및 engine.state_dict_user_keys에 의해 정의된 기타 상태 값
terminate 엔진에 종료 신호를 보내 현재 반복 후에 실행을 완전히 종료합니다.
terminate_epoch 현재 반복 후에 현재 에포크를 종료하도록 엔진에 종료 신호를 보냅니다.

 

댓글