Learning
레슨 5 / 8·20분

사전학습 모델 활용

TensorFlow.js는 이미 학습된 강력한 모델들을 바로 사용할 수 있도록 제공합니다. 사전학습 모델을 활용하면 대규모 데이터셋 없이도 높은 성능의 AI 기능을 웹 앱에 추가할 수 있습니다.

사전학습 모델 패키지

  • @tensorflow-models/mobilenet : 이미지 분류 (1000개 카테고리)
  • @tensorflow-models/coco-ssd : 객체 탐지 (80개 카테고리)
  • @tensorflow-models/posenet / @tensorflow-models/pose-detection : 신체 자세 추정
  • @tensorflow-models/blazeface / @tensorflow-models/face-landmarks-detection : 얼굴 인식
  • @tensorflow-models/toxicity : 텍스트 독성 분류
  • @tensorflow-models/universal-sentence-encoder : 텍스트 임베딩

MobileNet 이미지 분류

javascript
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';

// 모델 로드 (최초 1회, 캐시됨)
const model = await mobilenet.load({
  version: 2,
  alpha: 1.0,  // 모델 크기 (0.25, 0.50, 0.75, 1.0)
});

// HTML 이미지 요소로 분류
const img = document.getElementById('myImage');
const predictions = await model.classify(img);

// 결과 출력
predictions.forEach(p => {
  console.log(p.className + ": " + (p.probability * 100).toFixed(1) + "%");
});
// 예: "golden retriever: 92.3%"
//     "Labrador retriever: 4.1%"
//     "tennis ball: 1.2%"

COCO-SSD 객체 탐지

javascript
import * as cocoSsd from '@tensorflow-models/coco-ssd';

const model = await cocoSsd.load();

// 이미지에서 객체 탐지
const img = document.getElementById('myImage');
const predictions = await model.detect(img);

predictions.forEach(pred => {
  console.log(
    pred.class + " (" + (pred.score * 100).toFixed(1) + "%) - " +
    "위치: [" + pred.bbox.join(', ') + "]"
  );
  // bbox: [x, y, width, height]
});

// 캔버스에 바운딩 박스 그리기
const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');
ctx.drawImage(img, 0, 0);

predictions.forEach(pred => {
  ctx.strokeStyle = '#00ff00';
  ctx.lineWidth = 2;
  ctx.strokeRect(...pred.bbox);
  ctx.fillStyle = '#00ff00';
  ctx.font = '16px Arial';
  ctx.fillText(
    pred.class + " " + (pred.score * 100).toFixed(0) + "%",
    pred.bbox[0], pred.bbox[1] - 5
  );
});

전이 학습 (Transfer Learning)

javascript
import * as mobilenet from '@tensorflow-models/mobilenet';

// MobileNet 특성 추출기 로드
const mobilenetModel = await mobilenet.load();

// 중간 레이어 출력을 특성으로 사용
const activation = mobilenetModel.infer(img, true);
// shape: [1, 1024] - 1024차원 특성 벡터

// 커스텀 분류기 구축
const customModel = tf.sequential();
customModel.add(tf.layers.dense({
  units: 64,
  activation: 'relu',
  inputShape: [1024],
}));
customModel.add(tf.layers.dense({
  units: 3,  // 커스텀 클래스 수
  activation: 'softmax',
}));

customModel.compile({
  optimizer: tf.train.adam(0.0001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});

// 소량의 데이터로 커스텀 분류기 학습
await customModel.fit(features, labels, {
  epochs: 20,
  batchSize: 16,
});
💡

사전학습 모델은 최초 로드 시 네트워크에서 가중치를 다운로드하므로 시간이 걸립니다. 로딩 상태를 사용자에게 표시하세요. 전이 학습을 사용하면 소량의 데이터만으로도 특정 도메인에 맞는 높은 성능의 모델을 만들 수 있습니다.