Learning
레슨 3 / 8·20분

학습과 예측

모델을 구축한 후에는 compile로 학습 설정을 하고, fit으로 데이터를 학습시킵니다. 학습이 완료되면 predict로 새로운 데이터에 대한 예측을 수행할 수 있습니다.

모델 컴파일

javascript
// 모델 학습 전에 반드시 컴파일
model.compile({
  optimizer: 'adam',              // 최적화 알고리즘
  loss: 'categoricalCrossentropy', // 손실 함수 (다중 분류)
  metrics: ['accuracy'],          // 평가 지표
});

// 회귀 모델 컴파일
regressionModel.compile({
  optimizer: tf.train.adam(0.01),  // 학습률 직접 지정
  loss: 'meanSquaredError',        // 평균 제곱 오차 (회귀)
});

모델 학습 (fit)

javascript
// 학습 데이터 준비
const xs = tf.tensor2d([
  [0, 0], [0, 1], [1, 0], [1, 1]
]);
const ys = tf.tensor2d([
  [1, 0], [0, 1], [0, 1], [1, 0]
]);

// 모델 학습
const history = await model.fit(xs, ys, {
  epochs: 100,          // 전체 데이터 반복 횟수
  batchSize: 2,         // 한 번에 처리할 샘플 수
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      // 매 에포크 종료 시 호출
      console.log(
        "에포크 " + (epoch + 1) + ": loss = " + logs.loss.toFixed(4) +
        ", accuracy = " + logs.acc.toFixed(4)
      );
    },
  },
});

// 학습 이력 확인
console.log('최종 손실:', history.history.loss.slice(-1)[0]);

예측 수행

javascript
// 새로운 데이터로 예측
const input = tf.tensor2d([[0, 1]]);
const prediction = model.predict(input);

prediction.print();
// 예: Tensor [[0.12, 0.88]]  → 두 번째 클래스일 확률 88%

// 예측 결과를 JavaScript로 가져오기
const probabilities = await prediction.data();
console.log('확률:', Array.from(probabilities));

// 가장 높은 확률의 클래스 찾기
const classIndex = prediction.argMax(-1).dataSync()[0];
console.log('예측 클래스:', classIndex);

데이터 제너레이터 활용

javascript
// 대용량 데이터는 제너레이터로 처리
function* dataGenerator() {
  for (let i = 0; i < 1000; i++) {
    const x = Math.random();
    const y = 2 * x + 1 + (Math.random() - 0.5) * 0.1;
    yield { xs: [x], ys: [y] };
  }
}

const dataset = tf.data.generator(dataGenerator)
  .batch(32)
  .shuffle(100);

// 데이터셋으로 학습
await model.fitDataset(dataset, {
  epochs: 50,
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      if ((epoch + 1) % 10 === 0) {
        console.log("에포크 " + (epoch + 1) + ": loss = " + logs.loss.toFixed(4));
      }
    },
  },
});
💡

epochs는 전체 데이터를 몇 번 반복할지, batchSize는 한 번의 업데이트에 몇 개의 샘플을 사용할지를 결정합니다. 너무 큰 epochs는 과적합(overfitting)을, 너무 작은 값은 과소적합(underfitting)을 유발할 수 있습니다.