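I want to color points in two different ways depending on their label. Specifically, when label = 2 the points should be black, and all other points should follow cmap='Dark2'. My only simple idea was to draw the label-2 points again on top of the others, with code like this: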
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from collections import Counter
iris = datasets.load_iris()
X = iris.data
y = iris.target
df = pd.DataFrame(X, columns=iris.feature_names)
fig, ax = plt.subplots(figsize=(12, 8))
points = ax.scatter(df.values[:, 0],
                    df.values[:, 1],
                    c=y,
                    cmap='Dark2')  # other points follow this colormap
for i in range(len(y)):
    if y[i] == 2:
        # when label == 2, draw the point again in black
        ax.scatter(df.values[i, 0], df.values[i, 1], c='k')
handles, _ = points.legend_elements()
labels = sorted([f'{item}: {count}' for item, count in Counter(y).items()])
ax.legend(handles, labels, loc="lower right", title='clusters')
plt.show()
Now the problem is that the black points still follow the previous colormap (they come out gray instead of black). How can I fix this?
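A likely explanation for the gray: with c=y and no vmin/vmax, scatter autoscales its norm to the label range 0..2 and samples Dark2 at the normalized positions, and legend_elements() builds its handles from that same cmap and norm. Label 2 therefore lands on Dark2's last color, a gray, regardless of what is drawn on top of it afterwards. The minimal sketch below assumes a Normalize(vmin=0, vmax=2), which is what the autoscaling produces for these labels, and computes the color each label is mapped to.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is needed for this check
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_hex

# Assumption: labels 0, 1, 2, normalized the way scatter(c=y) autoscales them
cmap = plt.get_cmap("Dark2")
norm = Normalize(vmin=0, vmax=2)

# Color the colormap assigns to each label, which is also what its legend entry uses
label_colors = {label: to_hex(cmap(norm(label))) for label in (0, 1, 2)}
print(label_colors)
```

Label 2 comes out as '#666666', i.e. gray, which matches the symptom described above.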

Posted on 2021-08-11 22:55:00
There may be a more elegant solution, but you can add one more marker to the legend manually:
import pandas as pd
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
from sklearn import datasets
from collections import Counter
iris = datasets.load_iris()
X = iris.data
y = iris.target
df = pd.DataFrame(X, columns = iris.feature_names)
fig, ax = plt.subplots(figsize=(12,8))
df1 = df[y!= 2]
points = ax.scatter(df1.values[:,0], df1.values[:,1], marker = 'o', c = y[y!=2], cmap='Dark2')
df2 = df[y== 2]
points2 = ax.scatter(df2.values[:,0], df2.values[:,1], marker = 'o', color = 'k')
handles, _ = points.legend_elements()
labels = sorted([f'{item}: {count}' for item, count in Counter(y).items()])
one_more = mlines.Line2D([], [], color='k', marker='o', linestyle='None', markersize = handles[0].get_ms())
ax.legend(handles + [one_more], labels, loc = "lower right",title = 'clusters')
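plt.show()
This produces a plot in which the label-2 points are black and the legend includes a black marker for them.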
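One side effect worth checking in this answer: the first scatter only sees labels 0 and 1, so its norm autoscales to [0, 1] and label 1 is now mapped to Dark2's last color (the same gray) instead of the color it had when all three labels were plotted together. Passing vmin=0, vmax=2 to that first ax.scatter call should keep the original colors. The sketch below compares the legend colors with and without pinning the range; the toy arrays are hypothetical stand-ins for the label-0 and label-1 iris points.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this check
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_hex

# Hypothetical toy data standing in for the y != 2 points (labels 0 and 1 only)
xs = np.arange(6, dtype=float)
ys = np.array([0.0, 1.0, 0.0, 1.0, 0.0, 1.0])
labels = np.array([0, 0, 0, 1, 1, 1])

def legend_colors(**scatter_kwargs):
    """Return the hex color of each handle produced by legend_elements()."""
    fig, ax = plt.subplots()
    pts = ax.scatter(xs, ys, c=labels, cmap="Dark2", **scatter_kwargs)
    handles, _ = pts.legend_elements()
    colors = [to_hex(h.get_color()) for h in handles]
    plt.close(fig)
    return colors

auto = legend_colors()                  # norm autoscaled to [0, 1]
pinned = legend_colors(vmin=0, vmax=2)  # norm fixed to the full label range 0..2
print(auto, pinned)
```

With the autoscaled norm, label 1 gets '#666666'; with the pinned range it keeps the Dark2 color it had in the question's plot.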

Posted on 2021-08-11 22:26:59
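You can first plot the records where y != 2, and then the records where y == 2. For the second call, since you want all of those points to be black, set color='black' instead of using c, because these data points do not have different values to map through a colormap.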
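The c versus color distinction in isolation, on made-up points (the answer's full code follows): when c is an array of numbers, scatter maps each value through cmap and norm and keeps the values on the collection; when color is a single color spec, every point gets that fixed color and no colormap is involved at all.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this check
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

fig, ax = plt.subplots()
# c= with numeric values: colors come from the colormap; the values are stored for mapping
mapped = ax.scatter([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], c=[0, 1, 2, 1, 0], cmap="Dark2")
# color= with one spec: every point is black and nothing is colormapped
fixed = ax.scatter([0, 1, 2, 3, 4], [4, 3, 2, 1, 0], color="black")

print(mapped.get_array())                # the values used for color mapping
print(fixed.get_array())                 # None: there is nothing to map
print(to_hex(fixed.get_facecolor()[0]))  # the single fixed face color
plt.close(fig)
```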
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from collections import Counter
iris = datasets.load_iris()
X = iris.data
y = iris.target
df = pd.DataFrame(X, columns=iris.feature_names)
fig, ax = plt.subplots(figsize=(12, 8))
mask = y != 2  # boolean mask: True for every point whose label is not 2
points = ax.scatter(df.values[mask, 0],
                    df.values[mask, 1],
                    c=y[mask], cmap='Dark2')
# label-2 points all share one color, so use color= instead of c=
p2 = ax.scatter(df.values[~mask, 0], df.values[~mask, 1], color='black')
handles, _ = points.legend_elements()
labels = sorted([f'{item}: {count}' for item, count in Counter(y[mask]).items()])
ax.legend([*handles, p2], [*labels, f'2: {np.sum(y == 2)}'], loc="lower right", title='clusters')
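plt.show()
The output is a scatter plot with the y == 2 points drawn in black and a matching black entry in the legend.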
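If a single scatter call is preferred, another option (not used in either answer above) is to build a three-color ListedColormap from Dark2 with black in the slot for label 2. legend_elements() then reports black for that label directly, so no extra legend handle is needed. A sketch on hypothetical toy data, assuming the labels are exactly 0, 1 and 2:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it when plotting interactively
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
from matplotlib.colors import ListedColormap, to_hex

# Hypothetical toy data standing in for df.values[:, 0], df.values[:, 1] and y
xs = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
ys = np.array([2.0, 1.0, 4.0, 3.0, 6.0, 5.0])
labels = np.array([0, 0, 1, 1, 2, 2])

# Assumption: labels are exactly 0, 1, 2. Reuse the Dark2 colors that labels 0 and 1
# get in the original plot (colormap positions 0.0 and 0.5) and put black in slot 2.
dark2 = plt.get_cmap("Dark2")
cmap = ListedColormap([dark2(0.0), dark2(0.5), "black"])

fig, ax = plt.subplots(figsize=(12, 8))
# vmin/vmax pin label k to slot k of the three-color colormap
points = ax.scatter(xs, ys, c=labels, cmap=cmap, vmin=0, vmax=2)
handles, _ = points.legend_elements()
legend_labels = [f"{item}: {count}" for item, count in sorted(Counter(labels).items())]
ax.legend(handles, legend_labels, loc="lower right", title="clusters")

# Colors of the legend handles, to confirm label 2 is black
handle_colors = [to_hex(h.get_color()) for h in handles]
print(handle_colors)
```

With the iris data, replace xs, ys and labels with df.values[:, 0], df.values[:, 1] and y, remove the matplotlib.use line, and finish with plt.show().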

https://stackoverflow.com/questions/68749432