ShuShu


import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sheshe import ShuShu

X, y = load_iris(return_X_y=True)
sh = ShuShu().fit(X, y)
sh.plot_classes(X, y)
plt.show()

Optimizador basado en gradiente que busca máximos locales de una puntuación escalar o probabilidades de clase. Ejecuta una optimización por clase cuando se proporcionan etiquetas y también puede operar sobre funciones de puntuación definidas por el usuario, devolviendo centroides de cada modo descubierto.

Formulación matemática

ShuShu realiza actualizaciones de ascenso por gradiente x_{t+1} = x_t + η∇f(x_t) hasta que ‖∇f(x_t)‖ < tol o se alcanza un número máximo de iteraciones.

Cuando no hay gradientes analíticos, los estima mediante SPSA: ∂f/∂x_i ≈ (f(x + cΔ) - f(x - cΔ))/(2cΔ_i) con perturbaciones de Rademacher Δ_i ∈ {−1,1}.

Ejemplo


from sheshe import ShuShu
ss = ShuShu()
ss.fit(X, y)
labels = ss.predict(X)

Ejemplos de uso


from sheshe import ShuShu

shu = ShuShu(random_state=0)
shu.fit(X, y)                      # fit

from sheshe import ShuShu

shu = ShuShu(random_state=0)
shu.fit_predict(X, y)              # fit_predict

from sheshe import ShuShu

shu = ShuShu(random_state=0)
shu.fit_transform(X, y)            # fit_transform

from sheshe import ShuShu

shu = ShuShu(random_state=0).fit(X, y)
shu.transform(X)                   # transform

from sheshe import ShuShu

shu = ShuShu(random_state=0).fit(X, y)
shu.predict(X)                     # predict

from sheshe import ShuShu

shu = ShuShu(random_state=0).fit(X, y)
shu.predict_proba(X)               # predict_proba

from sheshe import ShuShu

shu = ShuShu(random_state=0).fit(X, y)
shu.decision_function(X)           # decision_function

from sheshe import ShuShu

shu = ShuShu(random_state=0).fit(X, y)
shu.predict_regions(X)             # predict_regions

from sheshe import ShuShu

shu = ShuShu(random_state=0).fit(X, y)
shu.score(X, y)                    # score

from sheshe import ShuShu

shu = ShuShu(random_state=0).fit(X, y)
shu.save("shu.joblib")             # save

from sheshe import ShuShu

shu = ShuShu.load("shu.joblib")

Ejemplos adicionales


from sklearn.datasets import load_iris
from sheshe import ShuShu

X, y = load_iris(return_X_y=True)
sh = ShuShu(random_state=0).fit(X, y)
print(sh.summary_tables()[0][["class_label", "n_clusters"]])

import numpy as np
def paraboloid(Z):
    return -np.linalg.norm(Z - 1.0, axis=1)

sc = ShuShu(random_state=0).fit(np.random.rand(100, 2), score_fn=paraboloid)
print(sc.centroids_)

from sklearn.linear_model import LogisticRegression
model = LogisticRegression(max_iter=200).fit(X, y)
ShuShu(random_state=0).fit(X, y, score_model=model)

Parámetros

  • clusterer_factory (callable o None, por defecto None): fábrica que devuelve el optimizador interno. Cuando es None se usa una implementación por defecto.
  • random_state (int o None, por defecto None): semilla para reproducibilidad.
  • **clusterer_kwargs – argumentos adicionales enviados al optimizador interno.

Métodos

  • fit(X, y=None, score_fn=None, ...) – entrena el optimizador. Cuando se proporciona y, la función de puntuación se aprende por clase; de lo contrario debe proporcionarse score_fn.
  • fit_predict(X, y=None, **kwargs) – envoltorio de conveniencia que llama a fit y luego predict.
  • predict(X) – devuelve etiquetas de clase o ids de cluster.
  • predict_proba(X) – probabilidades de clase (tras entrenar con etiquetas).
  • decision_function(X) – puntuaciones de decisión en bruto.
  • transform(X) – matriz de pertenencia/afinidad.
  • fit_transform(X, y=None, **kwargs) – combina fit y transform.
  • plot_pairs(X, y=None, max_pairs=None, show_histograms=False) – gráficos de dispersión para pares de características con histogramas marginales opcionales.
  • plot_classes(X, y, grid_res=200, contour_levels=None, max_paths=20, show_paths=True) – visualiza superficies de puntuación por clase.
  • plot_pair_3d(X, y=None, features=(i, j), ax=None, fig=None, grid=64) – renderizado 3D de la función de puntuación.
  • get_cluster(cluster_id, with_geometry=False) – recupera información sobre un cluster descubierto.