본문 바로가기

Data Science/Machine Learning

[핸즈온 머신러닝 2/E] 3장. 분류

이번에 포스팅할 3장에서는 "분류(Classification)"에 대해서 설명해보려 한다.

 

우선 설명을 위해 예시로 사용할 데이터는 MNIST 데이터 셋으로, 아래와 같은 숫자 이미지 데이터이다.

 

본격적으로 시작하기에 앞서, MNIST 데이터 셋을 "특성들로만 이루어진 데이터 셋""target 값들로만 이루어진 데이터 셋"으로 분리해주겠다.

 

▶ 이진 분류기 훈련

  • 교재에서는 확률적 경사 하강법(SGD; Stochastic Gradient Descent) 분류기숫자 5를 식별하는 것을 예시로 보여주고 있다.

◆ 확률적 경사 하강법 분류기(SGDClassifier)

  • 매우 큰 데이터 셋을 효율적으로 처리하는 장점이 있다.
    • 한 번에 하나씩 훈련 샘플을 독립적으로 처리하기 때문!
  • 사이킷런 SGDClassifier 클래스를 사용하면 된다.
    • SGDClassifier는 훈련하는 데 무작위성을 사용한다.
    • 따라서 random_state 매개변수를 지정해주는 것이 좋다.

 

 

▶ 성능 측정 지표

1. 정확도(Accuracy)

  • 정확하게 예측한 비율
  • 하지만 정확도를 분류기의 성능 측정 지표로 선호하지는 않는다.
    • 불균형 데이터 셋의 경우, 클래스 치우침 현상 때문에 정확도가 무조건 높게 나올 수가 있다.

2. 오차 행렬(Confusion Matrix)

  • 오차 행렬을 만들려면 실제 target과 비교할 수 있도록 먼저 예측값을 만들어주어야 한다.
    • cross_val_predict( ) 함수를 사용하면 k-fold 교차 검증을 수행하지만, 평가 점수가 아닌 각 테스트 fold에서 얻은 예측을 반환해준다.
  • 그 다음, confusion_matrix( ) 함수target 클래스예측 클래스를 넣어주면 된다.

  • 오차 행렬의 실제 클래스를 의미하고, 예측한 클래스를 의미한다.
    • 위 코드의 오차 행렬 결과를 해석해보겠다.
    • 첫 번째 행'5 아님' 이미지(negative class)에 대한 것으로, 53,057개를 '5 아님'으로 정확하게 분류(true negative)했고 나머지 1,522개는 '5'라고 잘못 분류(false positive)했다.
    • 두 번째 행'5' 이미지(positive class)에 대한 것으로,  1,325개를 '5 아님'으로 잘못 분류(false negative)했고 나머지 4,096개를 정확히 '5'라고 분류(true positive)했다.

3. 정밀도(Precision)와 재현율(Recall)

  • 오차 행렬이 많은 정보를 제공해주지만, 가끔 더 요약된 지표가 필요할 때도 있다.
  • 이럴 경우 사용하는 지표가 바로 정밀도재현율이다.
    • 정밀도(Precision) : 모델이 True 라고 분류한 것 중에서 실제 True인 것의 비율
    • 재현율(Recall) : 실제 True인 것 중에서 모델이 True라고 예측한 것의 비율
      • 민감도(Sensitivity)라고도 한다.
  • 정밀도와 재현율은 성능 측정 지표로 같이 사용하는 것이 일반적이다.

정밀도와 재현율 식
정밀도와 재현율에 대한 이해를 돕기 위한 예시 그림

  • True Positive(TP) : 실제 True인 정답을 True라고 예측 (정답)
  • False Positive(FP) : 실제 False인 정답을 True라고 예측 (오답)
  • False Negative(FN) : 실제 True인 정답을 False라고 예측 (오답)
  • True Negative(TN) : 실제 False인 정답을 False라고 예측 (정답)

 

4. F1 점수(F1 score)

  • F1 score정밀도와 재현율의 조화 평균이다.

F1 score

 

  • 정밀도와 재현율이 비슷한 분류기에서는 F1 score 값이 높다.
  • 하지만 F1 score가 항상 best 성능 측정 지표는 아니다.
  • 상황에 따라 정밀도가 중요할 수도 있고, 재현율이 중요할 수도 있다.
    • ex) 재현율(혹은 민감도)가 중요한 경우: 악성 종양 분류
      • 종양을 검사하는데 악성 종양을 양성이라고 판단하면, 그 환자는 치료할 기회를 놓치게 되고 목숨까지 위험해진다. 반대로 양성 종양을 악성이라고 판단하면 환자가 불필요한 비용을 지불해야하지만, 치료 시기를 놓쳐서 목숨이 위험한 상황은 발생하지 않게 된다.
    • ex) 정밀도가 중요한 경우: 스팸 메일 분류
      • 수신 받은 메일이 정상 메일임에도 불구하고 스팸 메일로 판단하면, 정말로 중요한 메일을 받지 못하는 상황이 발생한다. 반면에 수신 받은 메일이 스팸 메일임에도 불구하고 정상 메일로 판단하면, 비록 스팸 메일을 받게 되는 불편함(?)은 있지만 중요한 메일을 받지 못할 상황은 면할 수 있다.

5. 정밀도/재현율 Trade-Off

  • 정밀도와 재현율 사이에는 다음과 같은 관계가 성립한다.
    • 임곗값을 높이면 정밀도는 높아지지만 재현율은 낮아지고, 임곗값을 낮추면 재현율은 높아지지만 정밀도가 낮아진다.
    • 이해를 돕기 위해 MNIST 데이터 셋을 예시로 첨부하였다.

  • 따라서 안타깝게도 정밀도와 재현율을 동시에 높은 값으로 얻을 수는 없으며, 이러한 현상을 정밀도/재현율 트레이드 오프라고 한다.

◆ 그렇다면 적절한 임곗값(threshold)을 어떻게 정하면 될까?

  • 먼저 cross_val_predict( ) 함수를 사용해서 훈련 데이터 셋에 있는 모든 샘플의 점수를 구해야 한다.
    • 단, 주의할 점은 예측 결과를 반환받는 것이 아니라 결정 점수를 반환받도록 지정해야 한다.
  • 그리고 받환 받은 결정 점수로 precision_recall_curve( ) 함수를 사용하여, 가능한 모든 임곗값에 대해 정밀도와 재현율을 계산하고, 정밀도와 재현율matplotlib을 사용해서 그려본다.

  • 위 방법 외에도, 아래와 같이 재현율에 대한 정밀도 곡선을 그려보면 좋은 정밀도/재현율 Trade-off를 선택할 수 있다.

 

  • 위의 곡선을 살펴보면, 재현율 80% 근처에서 정밀도가 급격하게 줄어들기 시작한다.
  • 이러한 하강점 직전을 정밀도/재현율 Trade-Off로 선택하는 것이 좋다.
    • 위 곡선의 경우, 재현율이 60% 정도인 지점을 선택한다.

6. ROC 곡선(ROC Curve) 이진 분류기에서 많이 사용하는 성능 측정 지표!!

  • 이진 분류에서 일반적으로 많이 사용하는 성능 측정 지표이다.
  • ROC 곡선은 False Positive Rate(FPR)에 대한 True Positive Rate(TPR)의 곡선이다.
    • 여기서 FPR양성으로 잘못 분류된 음성 샘플의 비율을 의미한다.
    • 즉, 1에서 음성으로 정확하게 분류한 음성 샘플의 비율인 True Negative Rate(TNR)을 뺀 값이다.
    • True Negative Rate(TNR)특이도(Specificity)라고 한다.
  • 따라서 ROC 곡선민감도(또는 재현율)에 대한 (1 - 특이도) 그래프라고 이해하면 되겠다.

모든 가능한 임곗값에서 진짜 양성 비율(재현율 = TPR)에 대한 거짓 양성 비율(FPR)을 나타낸 ROC 곡선

  • ROC 곡선에서도 trade-off가 존재한다.
    • ROC 곡선을 살펴보면, 재현율(TPR)이 높을수록 거짓 양성 비율(FPR)이 늘어나는 것을 알 수 있다.
  • 점선완전한 랜덤 분류기의 ROC 곡선을 의미하며, 점선에서 최대한 멀리 떨어질수록 좋은 분류기이다.

◆ AUC(Area Under the Curve)란 무엇인가?

  • AUC란 ROC 곡선 아래의 면적을 의미하며, 이 값을 사용하여 분류기의 성능을 측정할 수 있다.
  • 완벽한 분류기ROC 곡선의 AUC가 1이고, 완전한 랜덤 분류기ROC 곡선의 AUC가 0.5이다.
    • 여기서 랜덤 분류기란 훈련 데이터셋의 class 비율에 따라 무작위로 예측하는 것을 말한다.
  • 사이킷런 roc_auc_score( ) 함수를 사용하면 쉽게 구할 수 있다.

※ 일반적으로 양성 class가 드물거나 False Negative(FN)보다 False Positive(FP)가 더 중요할 때(ex. 악성 종양 분류)는 PR 곡선(정밀도-재현율 곡선)을 사용하고, 그렇지 않으면 ROC 곡선을 사용한다.

 

 

▶ 다중 분류

  • 다중 분류란 두 개이상의 class를 구별하는 것을 말한다.
    • <다중 분류 모델>
      • SGD(Stochastic Gradient Descent)
      • 랜덤 포레스트
      • 나이브 베이즈
      • 등등...
    • <이진 분류 모델>
      • 로지스틱 회귀
      • 서포트 벡터 머신
      • 등등...
  • 이진 분류기를 여러 개 사용해서 다중 class를 분류하는 기법도 많다.
    • 1.  OvR(One-versus-the-Rest) 또는 OvA(One-versus-All) 전략
      • ex) 이미지를 분류할 때, 각 분류기의 결정 점수 중에서 가장 높은 것을 class로 선택한다.
      • 대부분의 이진 분류 알고리즘에서는 OvO보다 OvR을 선호한다.
    • 2. OvO(One-versus-One) 전략
      • ex) 0과 1 구별, 0과 2 구별, 1과 2 구별 등과 같이 각 숫자의 조합마다 이진 분류기를 훈련시킨다.
      • 각 분류기의 훈련에 전체 훈련 데이트 셋 중, 구별할 두 class에 해당하는 샘플만 필요하다는 장점이 있다.
      • 특히 서포트 벡터 머신과 같이 훈련 데이트 세트의 크기에 민감한 모델의 경우, 몇 개의 분류기를 훈련시키는 것보다 작은 훈련 데이터 셋에서 많은 분류기를 훈련시키는 쪽이 빠르므로 OvO 전략을 선호한다.
    • 다중 class 분류 작업에 이진 분류 알고리즘을 선택하면, 사이킷런이 알고리즘에 따라 자동으로 OvR 또는 OvO를 실행해준다.
      • 만일 사이킷런에서 OvR 또는 OvO를 지정해서 사용하려면, OneVsRestClassifier 또는 OneVsOneClassifier를 사용하면 된다.

 

 

▶ 에러 분석

  • 가능성이 높은 모델을 하나 찾았다고 가정하고, 해당 모델의 성능을 높이기 위한 방법으로 에러 분석이 있다.
  • 우선 오차 행렬(Confusion Matrix)를 확인해 볼 수 있는데, 다음과 같이 matplotlib의 matshow( ) 함수를 사용해서 이미지로 표현하면 보기가 편리하다.

배열에서 가장 큰 값은 흰색으로, 가장 작은 값은 검은색으로 정규화되어 그려진다.

  • 다음으로 오차 행렬의 각 값을 대응되는 class의 이미지 개수로 나누어 에러 비율을 비교한다.
    • 에러의 절대 개수가 아닌 에러 비율을 비교하는 이유는, 개수로 비교하면 이미지가 많은 class가 상대적으로 나쁘게 보이기 때문이다.
    • 다른 항목은 그대로 유지하고 주대각선만 0으로 채워서 다시 그려보면 다음과 같다.

class 8의 열이 상당히 밝은데, 이는 많은 이미지가 8로 잘못 분류되었음을 암시한다. 하지만 class 8의 행은 그리 나쁘지 않은데, 이는 실제 8이 적절히 8로 분류되었다는 것을 말해준다.

  • 위와 같이 오차 행렬을 분석하면 분류기의 성능 향상 방안에 대한 insight를 얻을 수 있다.

 

 

▶ 다중 레이블 분류

  • 분류기가 샘플마다 여러 개의 class를 출력하는 것이다.
    • ex) 얼굴 인식 분류기 → 같은 사진에 여러 사람이 등장한다면? 인식된 사람마다 하나씩 꼬리표(tag)를 붙여야 한다!!
  • 쉽게 말해, target label 값이 여러 개인 경우에 대해서 분류하는 것을 말한다.
  • 각각의 target label에 속한 샘플 수를 기반으로 가중치를 줄 수도 있다.

 

 

▶ 다중 출력 분류(또는 다중 출력 다중 클래스 분류)

  • 다중 레이블 분류에서 한 레이블이 다중 클래스가 될 수 있도록 일반화 한 것이다.
  • 즉, 분류기의 출력이 다중 레이블이고 각 레이블은 값을 여러 개 갖는 경우를 말한다.
    • ex) 이미지 잡음 제거 시스템

 

여기까지 해서 3장 분류에 대한 내용 정리를 마치도록 하겠다.

 

★ 참고 자료

- 핸즈온 머신러닝 2/E 교재

- 파이썬 머신러닝 완벽 가이드 교재