# -*- coding: utf-8 -*-
""" Copyright 2020 Marco Arrigoni
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull
from scipy.spatial.qhull import QhullError
from sklearn.neighbors import NearestNeighbors
def group_structures_in_clusters(ordered_structures, cluster_labels):
    """
    Collect structure names into buckets keyed by their cluster label.

    Parameters
    ----------
    ordered_structures : list
        Structure names, ordered like the dataset: the i-th row of the
        dataset corresponds to ``ordered_structures[i]``.
    cluster_labels : list
        ``cluster_labels[i]`` is the label of the cluster that
        ``ordered_structures[i]`` was assigned to.

    Returns
    -------
    groups : defaultdict
        Maps each cluster label to the list of structure names carrying
        that label, in the order in which they appear in the input.
    """
    clustered = defaultdict(list)
    for position, cluster in enumerate(cluster_labels):
        # Labels drive the iteration; names are looked up by position.
        clustered[cluster].append(ordered_structures[position])
    return clustered
def find_centroids(data, labels):
    """ Finds the centroids locations of the clusters.

    Parameters
    ----------
    data : 2D array-like
        The dataset. Each row represents one structure.
    labels : 1D array-like
        ``labels[i]`` is the index of the cluster where ``data[i]`` belongs.
        A label with value -1 is considered to represent noise. Its centroid
        will not be returned.

    Returns
    -------
    centroids : dict
        The keys are the cluster indices (as returned by ``np.unique``), the
        values are the coordinates of the corresponding centroid, i.e. the
        mean of the rows of ``data`` carrying that label.
    """
    # Coerce to arrays so that plain Python lists are accepted too: with a
    # list, ``labels != -1`` would be a single bool and the boolean-mask
    # indexing below would fail.
    data = np.asarray(data)
    labels = np.asarray(labels)

    centroids = {}
    for lab in np.unique(labels[labels != -1]):
        centroids[lab] = data[labels == lab].mean(axis=0)
    return centroids
def get_structure_group_index(structure_name, groups):
    """ Look up the cluster that contains a given structure.

    Parameters
    ----------
    structure_name : string
        the name of the structure to search for
    groups : dict
        key:value pairs of cluster indices and a list of the name of the
        structures belonging to that cluster

    Returns
    -------
    key : int or None
        the index of the first cluster (in the iteration order of ``groups``)
        whose member list contains ``structure_name``; ``None`` when no
        cluster contains it.
    """
    matches = (
        cluster
        for cluster, members in groups.items()
        if structure_name in members
    )
    # Only the first hit is consumed; later groups are never inspected.
    return next(matches, None)
def write_clustered_structures(groups, key):
    """ Write on a text file all the structures belonging to a given cluster

    The file is created in the current working directory and is named
    ``clusters_<key>.txt``; it contains one structure name per line.

    Parameters
    ----------
    groups : dict
        key:value pairs of cluster indices and a list of the name of the
        structures belonging to that cluster
    key : int
        the cluster index

    Returns
    -------
    fname : string or None
        the name of the just-written text file, or ``None`` when ``key`` is
        not a cluster index present in ``groups`` (in which case nothing is
        written).
    """
    # Guard clause: make the "unknown cluster" outcome explicit instead of
    # falling through to a ``return`` of a name that was never assigned.
    if key not in groups:
        return None

    fname = 'clusters_' + str(key) + '.txt'
    with open(fname, 'w') as f:
        f.writelines(item + '\n' for item in groups[key])
    return fname
def plot_cluster_plot(data, labels, title, ordered_structures,
                      plot_kwargs=None, cmap=cm.jet, show_names=True,
                      plot_chull=False, plot_centroids=True, ax=None):
    """ Make a scatter plot of the clusters.

    Only the first two columns of ``data`` are drawn. If ``data`` has a
    single column, the points are drawn along the x axis at y = 0.

    Parameters
    ----------
    data : 2D array-like
        The dataset. Each row represents one structure
    labels : 1D array-like
        ``labels[i]`` is the index of the cluster where ``data[i]`` belongs.
        A label with value -1 is considered to represent noise. Its points
        are represented by crosses.
    title : string
        the plot title
    ordered_structures : list
        the i-th element is the structure name for ``data[i]``
    plot_kwargs : dict or None. Default None
        key:value pairs to fine-tune the plot. Only ``'figsize'`` is read
        (default ``(18, 8)``), and only when a new figure is created.
    cmap : matplotlib cmap instance. Default cm.jet
        the colormap to be used in the plot
    show_names : bool. Default True
        if True, the structure names will be shown in the plot
    plot_chull : bool. Default False
        if True, plots also the convex hull of points in the cluster
        (two or more columns only). Clusters for which the hull cannot be
        built are drawn without it.
    plot_centroids: bool. Default True
        if True, the centroids of each cluster are also plotted
        (two or more columns only)
    ax : matplotlib Axes instance or None. Default is None
        the axes for the plot. If None, a new figure and axes are created
        and ``plt.show()`` is called at the end. If an axes is given, the
        plot is drawn on it and showing the figure is left to the caller.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    data = np.asarray(data)
    labels = np.asarray(labels)

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=plot_kwargs.get('figsize', (18, 8)))
    ax.set_title(title)

    two_dimensional = data.shape[1] > 1
    noise_mask = labels == -1
    unique_labels = np.unique(labels[~noise_mask])

    # Keep the clusters in a plain list: they generally have different
    # sizes, and recent NumPy versions refuse to build an array from
    # ragged sequences.
    clusters = [data[labels == lab] for lab in unique_labels]
    sizes = np.array([len(clust) for clust in clusters], dtype=int)
    # Largest clusters first, so they get the lowest zorder and the smaller
    # ones are drawn on top of them.
    order = np.argsort(-sizes)
    clusters = [clusters[j] for j in order]
    unique_labels = unique_labels[order]
    colors = cmap(np.linspace(0, 1, len(unique_labels)))

    # Noise points: black crosses underneath everything else.
    noise = data[noise_mask]
    noise_y = noise[:, 1] if two_dimensional else np.zeros(noise.shape[0])
    ax.scatter(noise[:, 0], noise_y,
               color='black', marker='+', alpha=0.3, zorder=0)

    for i, (lab, points, col) in enumerate(
            zip(unique_labels, clusters, colors)):
        label = 'Cluster ' + str(lab)
        col = np.atleast_1d(col)

        if not two_dimensional:
            ax.scatter(points[:, 0], np.zeros(points.shape[0]),
                       color=col, alpha=0.6, label=label, zorder=i + 1)
            continue

        ax.scatter(points[:, 0], points[:, 1], color=col, alpha=0.6,
                   label=label, zorder=i + 1)
        if plot_chull:
            try:
                chull = ConvexHull(points[:, :2])
            except QhullError:
                # The hull could not be computed for this cluster:
                # deliberately skip the outline and keep plotting.
                pass
            else:
                for simplex in chull.simplices:
                    ax.plot(points[simplex, 0], points[simplex, 1],
                            color=col, linewidth=3, alpha=1)
        if plot_centroids:
            centroid = points.mean(axis=0)
            ax.scatter(centroid[0], centroid[1], color=col,
                       marker='X', s=200, edgecolors='black',
                       zorder=i + 10)

    if show_names:
        for i, struct in enumerate(ordered_structures):
            # NOTE(review): this is a list holding the last two
            # '_'-separated tokens and is passed to annotate as-is;
            # confirm whether a joined string was intended.
            name = struct.split('_')[-2:]
            if two_dimensional:
                ax.annotate(name, data[i, :2])
            else:
                # Stagger the text vertically, alternating above and below
                # the axis, so neighbouring names overlap less.
                ax.annotate(name, (data[i, 0], 0.005 + 0.0005*i*(-1)**i))

    leg = ax.legend()
    # ``legendHandles`` was renamed to ``legend_handles`` in newer
    # Matplotlib releases; support both.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for lh in handles:
        # Legend markers are made fully opaque for readability.
        lh.set_alpha(1)

    if owns_figure:
        plt.show()
def make_reachability_plot(optics_instance, x_lims=None):
    """ Draw the reachability plot of a fitted scikit learn OPTICS instance.

    A new 14x6 figure is created and the values of ``reachability_``, taken
    in the order given by ``ordering_``, are drawn as black bars of unit
    width.

    Parameters
    ----------
    optics_instance : scikit learn OPTICS instance
        a trained instance; its ``labels_``, ``reachability_`` and
        ``ordering_`` attributes are read
    x_lims : tuple
        x limits to be plotted. If None, the x axis spans from 0 to the
        number of samples.
    """
    n_samples = len(optics_instance.labels_)
    bar_positions = range(n_samples)

    plt.figure(figsize=(14, 6))
    ordered_reachability = optics_instance.reachability_[
        optics_instance.ordering_]
    plt.bar(bar_positions, ordered_reachability,
            align='edge', width=1, color='black')

    if x_lims is not None:
        plt.xlim(*x_lims)
    else:
        plt.xlim(0, n_samples)

    plt.xlabel('Ordered samples', fontsize=20)
    plt.locator_params(axis='x', nbins=20)
    plt.ylabel('Reachability', fontsize=20)
def calculate_k_distances(dataset, k, algorithm='auto', leaf_size=30,
                          metric='minkowski', p=2, metric_params=None):
    """ Compute the sorted k-NN distance of every point in the data set.

    The search is delegated to sklearn.neighbors.NearestNeighbors; see its
    documentation for the meaning of ``algorithm``, ``leaf_size``,
    ``metric``, ``p`` and ``metric_params``, which are forwarded unchanged.
    The model is fitted on ``dataset`` and then queried with the same
    ``dataset``.

    Parameters
    ----------
    dataset : 2D array
        the dataset
    k : int
        the nearest neighbor number to consider

    Returns
    -------
    k_distances: 1D array
        for each point in the dataset, the distance found in the k-th
        column of the neighbour query, sorted in descending order.
    """
    search_options = {
        'n_neighbors': k,
        'algorithm': algorithm,
        'leaf_size': leaf_size,
        'metric': metric,
        'p': p,
        'metric_params': metric_params,
    }
    fitted = NearestNeighbors(**search_options).fit(dataset)
    neighbour_distances, _ = fitted.kneighbors(dataset)

    # Column k - 1 is the k-th (last) neighbour returned for each point.
    kth_column = neighbour_distances[:, k - 1]
    # Negate, sort ascending, negate back: descending order.
    return -np.sort(-kth_column)