모형을 훈련시키다 보면, fit을 쓸 때도 있고, train_on_batch()를 쓸 때도 있고, 뭔 차이가 있길래 이렇게 다르게 쓰는지 마음이 답답할 때가 있습니다. 사실 이 차이는 매우 간단한 이야기인데 말이죠,
fit()은 말 그대로 학습입니다. 여러 개의 epoch도 할 수 있습니다. 그런데, train_on_batch()는 batch라는 용어 때문에 조금 헷갈릴 순 있는데, 이걸 저라면 train_on_data()라고 이름 지었을 것입니다. 무슨 의미냐면 주어진 데이터만큼 학습한다는 의미입니다.
왜 이런 게 생겼느냐 하면, 조금 씩 학습시킬 필요가 있거나, 추가 학습을 할 때에는 train_on_batch()를 이용해서 학습하면 fit보다 더 세밀하게 학습시킬 수 있어서 이런 걸 이용하는 거라고 생각하면 좋겠습니다. 그러니까 fit()은 train_on_batch() 보다 더 큰 개념이라고 생각하면 좋고,
def fit(x, y, batch_size, epochs=1):
for epoch in range(epochs): # epochs
for loop in range(len(x)/batch_size) : # loops for 1 epoch
for batch_x, batch_y in batch(x, y, batch_size): # batch size
model.train_on_batch(batch_x, batch_y)
정확하지는 않지만, 개념적으로 이런 식의 pesudo 코드로 이해해도 좋을 것 같긴 합니다. 그냥 이 정도의 컨셉이구나 정도 이해하면 좋겠네요.
train_on_batch는 언제 만나게 되냐면, GAN할 때 처음 만나게 될 거라 생각합니다. 어쨌든 이걸 이해하면 왜 GAN 훈련을 그런 식으로 시키는지도 이해하기 쉽지 않을까 하는 기대가 있습니다.
댓글