Source code for clinamen.clustering.misc

# -*- coding: utf-8 -*-
""" Copyright 2020 Marco Arrigoni

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull
from scipy.spatial.qhull import QhullError

from sklearn.neighbors import NearestNeighbors


def group_structures_in_clusters(ordered_structures, cluster_labels):
    """ Collect structure names into per-cluster lists.

    Parameters
    ----------
    ordered_structures : list
        Structure names, ordered like the dataset: the i-th row of the
        dataset is the structure ``ordered_structures[i]``
    cluster_labels : list
        ``cluster_labels[i]`` is the label of the cluster that
        ``ordered_structures[i]`` belongs to.

    Returns
    -------
    groups : defaultdict
        maps each cluster label to the list of structures assigned to it,
        in the order in which they appear in ``ordered_structures``.
    """
    clusters = defaultdict(list)
    position = 0
    # The labels drive the iteration: one structure is looked up per label.
    for cluster_id in cluster_labels:
        clusters[cluster_id].append(ordered_structures[position])
        position += 1
    return clusters
def find_centroids(data, labels):
    """ Find the centroid location of every cluster.

    Parameters
    ----------
    data : 2D array-like
        The dataset. Each row represents one structure
    labels : 1D array-like
        ``labels[i]`` is the index of the cluster where ``data[i]`` belongs.
        A label with value -1 is considered to represent noise.
        Its centroid will not be returned.

    Returns
    -------
    centroids : dict
        the keys are the cluster indices, the values are the coordinates
        of the centroid, i.e. the mean over the rows of that cluster.
        Empty if there are no non-noise labels.
    """
    # Coerce the inputs so that plain Python lists work as well: on a list,
    # ``labels != -1`` is a single bool instead of an element-wise mask.
    data = np.asarray(data)
    labels = np.asarray(labels)

    centroids = dict()
    for lab in np.unique(labels[labels != -1]):
        centroids[lab] = data[labels == lab].mean(axis=0)
    return centroids
def get_structure_group_index(structure_name, groups):
    """ Look up the cluster that contains a given structure.

    Parameters
    ----------
    structure_name : string
        the name of the structure
    groups : dict
        key:value pairs of cluster indices and a list of the name of the
        structures belonging to that cluster

    Returns
    -------
    key : int or None
        the index of the first cluster (in the dict's iteration order)
        whose list contains ``structure_name``; None if no cluster does.
    """
    for cluster_index, members in groups.items():
        if structure_name in members:
            return cluster_index
    return None
def write_clustered_structures(groups, key):
    """ Write on a text file all the structures belonging to a given cluster

    The file is named ``clusters_<key>.txt``, is created in the current
    working directory (overwriting any existing file with that name) and
    contains one structure name per line.

    Parameters
    ----------
    groups : dict
        key:value pairs of cluster indices and a list of the name of the
        structures belonging to that cluster
    key : int
        the cluster index

    Returns
    -------
    fname : string or None
        the name of the just-written text file, or None (and no file is
        written) if ``key`` is not a cluster index in ``groups``.
    """
    # Membership test instead of ``groups[key]``: a defaultdict would
    # otherwise be silently extended with an empty cluster.
    if key not in groups:
        return None

    fname = 'clusters_' + str(key) + '.txt'
    with open(fname, 'w') as f:
        f.writelines(item + '\n' for item in groups[key])
    return fname
def plot_cluster_plot(data, labels, title, ordered_structures,
                      plot_kwargs=None, cmap=cm.jet, show_names=True,
                      plot_chull=False, plot_centroids=True, ax=None):
    """ Make a scatter plot of the clusters.

    Only the first two columns of ``data`` are plotted. If ``data`` has a
    single column, the points are drawn along the line y = 0.
    Clusters are drawn from the largest to the smallest.

    Parameters
    ----------
    data : 2D array
        The dataset. Each row represents one structure
    labels : 1D array
        ``labels[i]`` is the index of the cluster where ``data[i]`` belongs.
        A label with value -1 is considered to represent noise.
        Its points are represented by crosses.
    title : string
        the plot title
    ordered_structures : list
        the i-th element is the structure name for ``data[i]``
    plot_kwargs : dict or None. Default None
        key:value pairs to fine-tune the plot. Only ``'figsize'``
        (default ``(18, 8)``) is read, and only when a new figure is made.
    cmap : matplotlib cmap instance. Default cm.jet
        the colormap to be used in the plot
    show_names : bool. Default True
        if True, the structure names will be shown in the plot. The label
        is made of the last two ``'_'``-separated fields of the name.
    plot_chull : bool. Default False
        if True, plots also the convex hull of points in the cluster
        (two-dimensional data only). Clusters for which Qhull fails are
        drawn without a hull.
    plot_centroids: bool. Default True
        if True, the centroids of each cluster are also plotted
        (two-dimensional data only)
    ax : matplotlib Axes instance or None. Default is None
        the axes for the plot. If None, a new figure and axes are created
        and the figure is shown before returning. If an Axes is given,
        the plot is drawn on it and showing it is left to the caller.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    # Honour a caller-supplied Axes; only build a figure of our own when
    # none was passed.
    created_figure = ax is None
    if created_figure:
        figsize = plot_kwargs.get('figsize', (18, 8))
        _, ax = plt.subplots(figsize=figsize)
    ax.set_title(title)

    data = np.asarray(data)
    labels = np.asarray(labels)
    noise_mask = labels == -1
    is_2d = data.shape[1] > 1

    unique_labels = np.unique(labels[~noise_mask])
    # Clusters generally have different sizes, so they are kept in a list:
    # packing them into one ndarray would build a ragged array.
    clusters = [data[labels == lab] for lab in unique_labels]
    sizes = np.array([len(clust) for clust in clusters], dtype=int)
    order = np.argsort(-sizes)  # largest cluster first
    clusters = [clusters[j] for j in order]
    unique_labels = unique_labels[order]

    colors = cmap(np.linspace(0, 1, len(unique_labels)))

    # Noise points: black crosses underneath everything else.
    noise_x = data[noise_mask, 0]
    noise_y = data[noise_mask, 1] if is_2d else np.zeros(noise_x.shape[0])
    ax.scatter(noise_x, noise_y, color='black', marker='+', alpha=0.3,
               zorder=0)

    for i, (lab, points, col) in enumerate(zip(unique_labels, clusters,
                                               colors)):
        label = 'Cluster ' + str(lab)
        col = np.atleast_1d(col)

        if not is_2d:
            ax.scatter(points[:, 0], np.zeros(points.shape[0]), color=col,
                       alpha=0.6, label=label, zorder=i + 1)
            continue

        ax.scatter(points[:, 0], points[:, 1], color=col, alpha=0.6,
                   label=label, zorder=i + 1)

        if plot_chull:
            try:
                chull = ConvexHull(points[:, :2])
            except QhullError:
                # Qhull could not build a hull for this cluster: skip it.
                chull = None
            if chull is not None:
                for simplex in chull.simplices:
                    ax.plot(points[simplex, 0], points[simplex, 1],
                            color=col, linewidth=3, alpha=1)

        if plot_centroids:
            centroid = points.mean(axis=0)
            ax.scatter(centroid[0], centroid[1], color=col, marker='X',
                       s=200, edgecolors='black', zorder=i + 10)

    if show_names:
        for i, struct in enumerate(ordered_structures):
            # Join the two trailing fields into a string; passing the raw
            # list would be rendered as the list's repr.
            name = '_'.join(struct.split('_')[-2:])
            if is_2d:
                ax.annotate(name, data[i, :2])
            else:
                # Alternate the labels above/below the axis, with a growing
                # offset, to limit overlaps.
                ax.annotate(name, (data[i, 0], 0.005 + 0.0005*i*(-1)**i))

    leg = ax.legend()
    # ``legend_handles`` is the current attribute name; fall back to the
    # older ``legendHandles`` spelling when it is not available.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for lh in handles:
        lh.set_alpha(1)  # legend markers fully opaque

    if created_figure:
        plt.show()
def make_reachability_plot(optics_instance, x_lims=None):
    """ Draw the reachability plot of a trained scikit learn OPTICS instance

    A new 14x6 figure is created and the reachability values, taken in
    the cluster ordering of the instance, are drawn as black bars of
    unit width. The figure is neither shown nor returned.

    Parameters
    ----------
    optics_instance : scikit learn OPTICS instance
        a trained instance; its ``labels_``, ``reachability_`` and
        ``ordering_`` attributes are read
    x_lims : tuple
        x limits to be plotted. If None, the whole range of samples
        is plotted.
    """
    n_samples = len(optics_instance.labels_)
    positions = range(n_samples)

    plt.figure(figsize=(14, 6))
    ordered_reachability = \
        optics_instance.reachability_[optics_instance.ordering_]
    plt.bar(positions, ordered_reachability, align='edge', width=1,
            color='black')

    limits = (0, n_samples) if x_lims is None else x_lims
    plt.xlim(*limits)

    plt.xlabel('Ordered samples', fontsize=20)
    plt.locator_params(axis='x', nbins=20)
    plt.ylabel('Reachability', fontsize=20)
def calculate_k_distances(dataset, k, algorithm='auto', leaf_size=30,
                          metric='minkowski', p=2, metric_params=None):
    """ Compute the k-NN distance of every point in the data set.

    The neighbors search is delegated to sklearn.neighbors.NearestNeighbors:
    ``algorithm``, ``leaf_size``, ``metric``, ``p`` and ``metric_params``
    are forwarded unchanged, so look at its documentation for their meaning.

    NOTE(review): the fitted dataset is also used as the query set, so each
    point is presumably counted among its own neighbors (at distance 0);
    confirm that this is the intended definition of the k-NN distance.

    Parameters
    ----------
    dataset : 2D array
        the dataset
    k : int
        the nearest neighbor number to consider

    Returns
    -------
    k_distances: 1D array
        the k-NN distances for each point in the dataset sorted in
        descending order.
    """
    searcher = NearestNeighbors(n_neighbors=k, algorithm=algorithm,
                                leaf_size=leaf_size, metric=metric, p=p,
                                metric_params=metric_params).fit(dataset)
    distances, _ = searcher.kneighbors(dataset)

    # Column k - 1 of the distance matrix holds the k-th neighbor distance.
    kth_column = distances[:, k - 1]
    # Descending order: sort the negated values, then flip the sign back.
    return np.negative(np.sort(np.negative(kth_column)))