"""Hierarchical clustering utilities for ASE ``Atoms`` structures.

The helpers in this module group the atoms of a single structure into
clusters using only the pairwise interatomic distances, and offer small
conveniences to view, plot, or split the structure by cluster.
"""

import matplotlib.pyplot as plt
import numpy as np
from ase.atoms import Atoms
from ase.io import read  # noqa: F401  (unused in this module; kept from the original import list)
from ase.visualize import view
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform


def compute_cluster_labels(structure, num_clusters):
    """
    Compute hierarchical clustering labels for an ASE Atoms structure.

    Use case: Identification of inherently different parts of a single structure,
    i.e. separate slabs, specific phases, etc.
    Atomic distances (minimum-image convention) are the sole defining metric
    used for clustering.

    NOTE(review): SciPy documents Ward linkage as strictly defined only for
    Euclidean distances. Minimum-image distances in a periodic cell are not
    guaranteed to be Euclidean -- confirm that this is acceptable for the
    intended use cases.

    Parameters:
        structure (Atoms): ASE Atoms object.
        num_clusters (int): Maximum number of clusters to form.

    Returns:
        np.ndarray: Cluster label for each atom. Labels start at 1 (scipy
            convention), not 0.

    Raises:
        ValueError: If ``structure`` is not an ASE Atoms object.
    """
    if not isinstance(structure, Atoms):
        raise ValueError(
            "Invalid input for structure. Please provide an ASE Atoms object."
        )

    n_atoms = len(structure)
    if n_atoms < 2:
        # linkage() cannot operate on fewer than two observations; every atom
        # (if there is one at all) trivially belongs to cluster 1.
        return np.ones(n_atoms, dtype=np.int32)

    # Full (N, N) matrix of pairwise distances.
    distance_matrix = structure.get_all_distances(mic=True)

    # linkage() interprets a 2-D array as N observation vectors, not as a
    # distance matrix. Precomputed distances must be passed in condensed
    # (upper-triangle) form, otherwise the clustering is performed on the
    # rows of the matrix instead of on the distances themselves.
    # checks=False avoids spurious failures from floating-point asymmetry.
    condensed_distances = squareform(distance_matrix, checks=False)

    linkage_matrix = linkage(condensed_distances, method="ward")

    # Cut the dendrogram so that at most ``num_clusters`` clusters remain.
    return fcluster(linkage_matrix, num_clusters, criterion="maxclust")


def ase_view_clusters(structure, n_clusters, target_cluster_label=1):
    """
    Visualize a specific cluster in an ASE Atoms structure.

    Parameters:
        structure (Atoms): ASE Atoms object.
        n_clusters (int): Number of clusters to form.
        target_cluster_label (int): Target cluster label to visualize
            (NOTE: STARTS AT 1, NOT 0 (scipy))
    """
    cluster_labels = compute_cluster_labels(structure, n_clusters)
    # Integer indices of the atoms that carry the requested label.
    indices_of_cluster = np.flatnonzero(cluster_labels == target_cluster_label)
    # Visualize the cluster using ASE's view function
    view(structure[indices_of_cluster])


def plot_clusters(structure, n_clusters=1, projection=(1, 2), figsize=(30, 10)):
    """
    Plot clusters in a 2D scatter plot based on hierarchical clustering.

    Parameters:
        structure (Atoms): ASE Atoms object.
        n_clusters (int): Number of clusters to form.
        projection (sequence): Two integers selecting the Cartesian axes
            (columns of ``structure.positions``) used for the scatter plot.
        figsize (tuple): Figure size.

    Returns:
        None
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Compute cluster labels
    cluster_labels = compute_cluster_labels(structure, n_clusters)

    first_axis, second_axis = projection[0], projection[1]
    positions = structure.positions

    # One scatter series per cluster so that each gets its own colour and
    # legend entry.
    for cluster_label in np.unique(cluster_labels):
        cluster_data = positions[cluster_labels == cluster_label]
        ax.scatter(
            cluster_data[:, first_axis],
            cluster_data[:, second_axis],
            label=f"Cluster {cluster_label}",
        )

    ax.set_xlabel(f"Axis {first_axis}")
    ax.set_ylabel(f"Axis {second_axis}")
    ax.set_title(
        f"2D Projection (Axis {first_axis}-{second_axis}) \nHierarchical Clusters "
    )
    # Place the legend just outside the right edge of the axes.
    ax.legend(loc=[1.05, 0.9])

    # Set aspect ratio to be equal
    ax.set_aspect("equal")

    # Set axis limits to be tight
    ax.autoscale()


def get_structure_clusters(structure, n_clusters=2):
    """
    Split an ASE Atoms structure into multiple structures based on hierarchical clustering.

    Parameters:
        structure (Atoms): ASE Atoms object.
        n_clusters (int): Number of clusters to form.

    Returns:
        list: List of ASE Atoms structures, one per cluster, ordered by
            ascending cluster label.
    """
    cluster_labels = compute_cluster_labels(structure, n_clusters)
    # Use every distinct label in turn (the label must not be overwritten
    # inside the loop, otherwise each entry would be the same cluster).
    # Indexing an Atoms object with an index array yields a new Atoms object,
    # so no explicit copy of the full structure is needed.
    return [
        structure[np.flatnonzero(cluster_labels == cluster_label)]
        for cluster_label in np.unique(cluster_labels)
    ]