sklearn's plot_confusion_matrix 에러

2023. 7. 27. 00:45Computer Tips

728x90

1. 현상

plot_confusion_matrix 함수를 사용할 때 아래 에러메시지 발생

 

ImportError: cannot import name 'plot confusion matrix' from 'sklearn.metrics' (/home/ubuntu/mnt/brian/lib/python3.8/site-packages/sklearn/metrics/__init__.py

 

2. 해결방법

이전 버젼에는 사용할 수 있었던 sklearn.metrics.plot_confusion_matrix가 없어지고, 아래와 같이 ConfusionMatrixDisplay가 들어온 걸 확인할 수 있다. (scikit-learn 1.3.0 기준)

from sklearn.metrics import ConfusionMatrixDisplay

 

사용하는 방법은 아래와 같다.

출처: https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

728x90
import matplotlib.pyplot as plt
import numpy as np

from sklearn import datasets, svm
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)

np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
titles_options = [
    ("Confusion matrix, without normalization", None),
    ("Normalized confusion matrix", "true"),
]
for title, normalize in titles_options:

	### ConfusionMatrixDisplay 인스턴스를 정의한다.
    disp = ConfusionMatrixDisplay.from_estimator(
        classifier, # 학습한 classifier
        X_test, # 테스트 X 데이터
        y_test, # X_test의 예측값과 비교를 위한 정답 y_test 데이터
        display_labels=class_names, # 클래스 이름
        cmap=plt.cm.Blues, # color_map
        normalize=normalize, # 결과값 normalize 여부 (true or None)
    )
    disp.ax_.set_title(title) # 제목 세팅

    print(title) # 별도로 제목 출력
    print(disp.confusion_matrix) # 일반적인 confusion matrix 출력

plt.show() # ConfusionMatrixDisplay 인스턴스로 정의된 이미지 출력

3. 결과

기존의 함수와 같이 confusion matrix를 출력할 수 있다.

반응형