Learning
레슨 4 / 8·20분

모델 학습하기

모델 학습은 데이터를 준비하고, 학습 과정을 모니터링하며, 최적의 성능을 달성하기 위해 하이퍼파라미터를 조정하는 과정입니다. 이 레슨에서는 데이터 전처리부터 학습 시각화, 모델 저장까지 다룹니다.

데이터 전처리와 정규화

javascript
import * as tf from '@tensorflow/tfjs';

// 데이터 정규화: 값을 0~1 범위로 변환
function normalizeData(data) {
  const min = data.min();
  const max = data.max();
  return data.sub(min).div(max.sub(min));
}

// 학습/검증 데이터 분리
function splitData(xs, ys, ratio) {
  const numExamples = xs.shape[0];
  const splitIdx = Math.floor(numExamples * ratio);

  const trainXs = xs.slice([0], [splitIdx]);
  const trainYs = ys.slice([0], [splitIdx]);
  const valXs = xs.slice([splitIdx]);
  const valYs = ys.slice([splitIdx]);

  return { trainXs, trainYs, valXs, valYs };
}

// 사용 예시
const rawData = tf.tensor2d([
  [150, 60], [170, 75], [160, 65], [180, 85]
]);
const normalized = normalizeData(rawData);
normalized.print();

학습 콜백과 모니터링

javascript
const model = tf.sequential({
  layers: [
    tf.layers.dense({ units: 16, activation: 'relu', inputShape: [2] }),
    tf.layers.dense({ units: 1, activation: 'sigmoid' }),
  ]
});

model.compile({
  optimizer: tf.train.adam(0.001),
  loss: 'binaryCrossentropy',
  metrics: ['accuracy'],
});

// 학습 콜백 설정
const history = await model.fit(trainXs, trainYs, {
  epochs: 50,
  validationData: [valXs, valYs],
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      console.log(
        "에포크 " + (epoch + 1) +
        " - loss: " + logs.loss.toFixed(4) +
        " - val_loss: " + logs.val_loss.toFixed(4) +
        " - accuracy: " + logs.acc.toFixed(4)
      );
    },
    onTrainEnd: () => {
      console.log("학습 완료!");
    },
  },
});

// 학습 이력 확인
console.log("최종 loss:", history.history.loss.slice(-1)[0]);
console.log("최종 val_loss:", history.history.val_loss.slice(-1)[0]);

EarlyStopping 콜백

javascript
// 과적합 방지를 위한 조기 종료
const earlyStop = tf.callbacks.earlyStopping({
  monitor: 'val_loss',
  patience: 5,  // 5 에포크 동안 개선 없으면 중단
});

await model.fit(trainXs, trainYs, {
  epochs: 200,
  validationData: [valXs, valYs],
  callbacks: [earlyStop],
});

모델 저장과 불러오기

javascript
// 브라우저 로컬 스토리지에 저장
await model.save('localstorage://my-model');

// 불러오기
const loadedModel = await tf.loadLayersModel('localstorage://my-model');

// IndexedDB에 저장
await model.save('indexeddb://my-model');

// 파일 다운로드로 저장
await model.save('downloads://my-model');
// model.json과 weights.bin 파일이 다운로드됩니다

// HTTP 서버에 업로드
await model.save('http://localhost:3000/upload-model');
💡

학습 시 반드시 validationData를 설정하여 과적합 여부를 모니터링하세요. 학습 loss는 줄어드는데 검증 loss가 증가한다면 과적합이 발생하고 있는 것입니다. EarlyStopping으로 적절한 시점에 학습을 중단할 수 있습니다.