sklearn's plot_confusion_matrix 에러
2023. 7. 27. 00:45ㆍComputer Tips
728x90
1. 현상
plot_confusion_matrix 함수를 사용할 때 아래 에러메시지 발생
ImportError: cannot import name 'plot confusion matrix' from 'sklearn.metrics' (/home/ubuntu/mnt/brian/lib/python3.8/site-packages/sklearn/metrics/__init__.py
2. 해결방법
이전 버젼에는 사용할 수 있었던 sklearn.metrics.plot_confusion_matrix가 없어지고, 아래와 같이 ConfusionMatrixDisplay가 들어온 걸 확인할 수 있다. (scikit-learn 1.3.0 기준)
from sklearn.metrics import ConfusionMatrixDisplay
사용하는 방법은 아래와 같다.
728x90
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, svm
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names
# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)
np.set_printoptions(precision=2)
# Plot non-normalized confusion matrix
titles_options = [
("Confusion matrix, without normalization", None),
("Normalized confusion matrix", "true"),
]
for title, normalize in titles_options:
### ConfusionMatrixDisplay 인스턴스를 정의한다.
disp = ConfusionMatrixDisplay.from_estimator(
classifier, # 학습한 classifier
X_test, # 테스트 X 데이터
y_test, # X_test의 예측값과 비교를 위한 정답 y_test 데이터
display_labels=class_names, # 클래스 이름
cmap=plt.cm.Blues, # color_map
normalize=normalize, # 결과값 normalize 여부 (true or None)
)
disp.ax_.set_title(title) # 제목 세팅
print(title) # 별도로 제목 출력
print(disp.confusion_matrix) # 일반적인 confusion matrix 출력
plt.show() # ConfusionMatrixDisplay 인스턴스로 정의된 이미지 출력
3. 결과
기존의 함수와 같이 confusion matrix를 출력할 수 있다.
반응형
'Computer Tips' 카테고리의 다른 글
[linux] 파일들만 검색해서 압축(zip)하기 (0) | 2023.08.14 |
---|---|
[리눅스] 리눅스 터미널에서 구글 드라이브 대용량 파일 다운로드 (0) | 2023.02.18 |
[linux] add-apt-repository command not found 해결방법 (0) | 2023.02.10 |
[python] plotly 코드에 이상없는데 결과 시각화가 안 될 때 (0) | 2023.01.18 |
local에 연결된 git remote 백업, 변경하기 (0) | 2023.01.11 |