scikit-learnでk-means法によるクラスタリングしてみる

ロウ
2022-01-18
ロウ
2022-01-18

ライブラリをインポートする

```
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import silhouette_score
```

データセットをロードする

```
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target_names[iris.target]
print(df.head())
# 結果:
# sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
# 0     5.1         3.5              1.4              0.2           setosa
# 1     4.9         3.0              1.4              0.2           setosa
# 2     4.7         3.2              1.3              0.2           setosa
# 3     4.6         3.1              1.5              0.2           setosa
# 4     5.0         3.6              1.4              0.2           setosa
```

最適なクラスター数を選択する

```
cols = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]

for n_cluster in [3,4,5,6]:
    kmeans = KMeans(n_clusters=n_cluster).fit(df[cols])
    silhouette_avg = silhouette_score(df[cols], kmeans.labels_)
    print('Silhouette Score for %i Clusters: %0.4f' % (n_cluster, silhouette_avg))

# 結果:
# Silhouette Score for 3 Clusters: 0.5528
# Silhouette Score for 4 Clusters: 0.4981
# Silhouette Score for 5 Clusters: 0.4912
# Silhouette Score for 6 Clusters: 0.3648
```

k-means法によるクラスタリング

```
kmeans = KMeans(n_clusters=3).fit(df[cols])
labels = kmeans.labels_
print(labels)
# 結果:
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
# 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
# 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 2 2 2 2 1 2 2 2 2
# 2 2 1 1 2 2 2 2 1 2 1 2 1 2 2 1 1 2 2 2 2 2 1 2 2 2 2 1 2 2 2 1 2 2 2 1 2
# 2 1]

df['cluster'] = labels
print(df.head())
# 結果:
# sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target cluster
# 0          5.1          3.5          1.4          0.2          setosa          0
# 1          4.9          3.0          1.4          0.2          setosa          0
# 2          4.7          3.2          1.3          0.2          setosa          0
# 3          4.6          3.1          1.5          0.2          setosa          0
# 4          5.0          3.6          1.4          0.2          setosa          0
```