이 글은 아래 튜토리얼을 따라서 공부한 내용을 정리한 글입니다 :)
> Pytorch Ignite Concepts
https://pytorch.org/ignite/concepts.html
> Convolutional Neural Networks for Classifying Fashion-MNIST Dataset using Ignite
1. 파이토치 이그나이트(Ignite)
1.1 파이토치 이그나이트(Ignite)란?
: 파이토치 이그나이트(Ignite)는 파이토치의 신경망을 유연하고 투명하게 훈련하고 평가하는 데 도움이 되는 고급 라이브러리입니다.
2. 이그나이트의 주요 클래스
2.1 엔진(Engine)
2.1.1. 엔진(Engine)이란?
: 엔진은 데이터를 주어진 숫자 만큼의 처리(processing function)를 반복하고 결과값을 반환하는 클래스입니다.
(딥러닝에서 흔히 이야기할 수있는 예로, 모델을 학습시키는 학습기를 하나의 엔진이라고 볼 수있겠습니다.)
: 개념에 대한 설명을 코드로 아래 처럼 정리할 수있습니다.
- max_epoch 만큼 while문을 반복하면서 process_fuction을 실행하고,
실행한 iteration수(="iter_coutner")가 max_epochs가 되면 while문을 종료합니다.
(추가적으로 실행한 iteration수가 데이터 크기와 같아지면, 역시 while문을 종료합니다)
while epoch < max_epochs:
# max_epochs : 최대 에포크 수(반복하여 실행할 숫자)
data_iter = iter(data)
while True:
try:
batch = next(data_iter)
output = process_function(batch)
iter_counter += 1
except StopIteration:
data_iter = iter(data)
if iter_counter == epoch_length:
# epoch_length : 데이터크기 ==len(data)
break
* max_epochs와 epoch_length - max_epochs : 최대 에포크 수(반복 횟수) - epoch_length : 데이터의 크기 = len(data) -> [Q] epoch_length는 len(data)로 대체할 수있는데 왜 굳이 변수로 존재하는가? : [A] 만약 데이터의 크기를 알 수 없는 경우 data iterator가 과부하가 걸려서 중단될 수 있으므로, 이런 경우 epoch_length 임의로 지정하여 사용한다. |
: 위에서 말씀드린 것처럼, 모델 훈련기(model trainer)는 훈련 데이터셋을 여러번 반복하여 돌면서 모델의 파라미터를 업데이트하기 때문에 엔진(Engine)에 해당된다고 할 수 있습니다.
: 지도학습에 해당하는 모델 훈련기(model_trainer)는 아래 코드와 같이 간단하게 써볼 수 있습니다.
def train_step(trainer, batch):
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return loss.item()
trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
: 위의 간단한 예시 외에도 학습 로직에 필요한 다양성은 아래 코드와 같이 필요에 맞게 사용자가 정의하여 추가할 수있습니다.
model_1 = ...
model_2 = ...
# ...
optimizer_1 = ...
optimizer_2 = ...
# ...
criterion_1 = ...
criterion_2 = ...
# ...
def train_step(trainer, batch):
data_1 = batch["data_1"]
data_2 = batch["data_2"]
# ...
model_1.train()
optimizer_1.zero_grad()
loss_1 = forward_pass(data_1, model_1, criterion_1)
loss_1.backward()
optimizer_1.step()
# ...
model_2.train()
optimizer_2.zero_grad()
loss_2 = forward_pass(data_2, model_2, criterion_2)
loss_2.backward()
optimizer_2.step()
# ...
# User can return any type of structure.
return {
"loss_1": loss_1,
"loss_2": loss_2,
# ...
}
trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
3. 엔진 클래스(Engine Class)
3.1 엔진 클래스
ignite.engine.engine.Engine(process_function)
: 엔진(Engine) 클래스는 데이터의 미니배치마다 주어진 함수(process_function)를 반복합니다.
항목 | 설명 | |
Parameters | process_function | : 엔진에서 iteration의 현재 batch마다 수행되는 함수로, 엔진의 state에 데이터를 저장하여 반환합니다. - state : 내부적으로 사용자가 정의한 state를 event handlers 간에 전달하는데 사용되는 객체로, 엔진과 함께 생성되며 그 속성은(예: state.iteration, state.epoch) 매 run() 마다 초기화 됩니다. - last_event_name : 엔진에 의하여 유도된 마지막 이벤트 이름 |
<Basic Trainer>
def update_model(engine, batch):
inputs, targets = batch
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
return loss.item()
trainer = Engine(update_model)
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
batch_loss = engine.state.output
lr = optimizer.param_groups[0]['lr']
e = engine.state.epoch
n = engine.state.max_epochs
i = engine.state.iteration
print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")
trainer.run(data_loader, max_epochs=5)
> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01
<Basic Evalutator>
from ignite.metrics import Accuracy
def predict_on_batch(engine, batch)
model.eval()
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
return y_pred, y
evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)
3.1 엔진 클래스의 메소드
add_event_handler add_event_handler(event_name, handler, *args, **kwargs) :지정된 Event가 발생할 때 실행할 Event Handler를 추가합니다. |
Methods | 설명 |
fire_event | 주어진 Event와 관련된 모든 Handler를 실행합니다. |
has_event_handler | 지정된 Event에 지정된 Handler가 있는지 확인합니다. |
load_state_dict | state_dict로 엔진을 세팅합니다. |
on | add_event_handler에 대한 데코레이터 바로 가기 |
register_events | 실행할 수 있는 Event를 추가합니다. |
run | 전달된 데이터에 대해 process_function을 실행합니다. |
set_data | 데이터 set |
state_dict | 엔진 상태를 포함하는 사전을 반환합니다 : "seed", "epoch_length", "max_epochs" 및 "iteration" 및 engine.state_dict_user_keys에 의해 정의된 기타 상태 값 |
terminate | 엔진에 종료 신호를 보내 현재 반복 후에 실행을 완전히 종료합니다. |
terminate_epoch | 현재 반복 후에 현재 에포크를 종료하도록 엔진에 종료 신호를 보냅니다. |
'AI > 파이토치(Pytorch)' 카테고리의 다른 글
[파이토치] 텐서 기초 (0) | 2021.09.08 |
---|---|
[딥러닝][파이토치] 이그나이트 이벤트(Ignite Events) (0) | 2021.09.03 |
[NLP][파이토치] seq2seq - Decoder(디코더) (0) | 2021.08.30 |
[NLP][파이토치] seq2seq - Attention(어텐션) (0) | 2021.08.30 |
[NLP][파이토치] seq2seq - Encoder(인코더) (0) | 2021.08.29 |
댓글