-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'fork-origin/devel' into devel
- Loading branch information
Showing
3 changed files
with
113 additions
and
23 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import numpy as np | ||
from sklearn.linear_model import LogisticRegression | ||
from os.path import join | ||
import quapy as qp | ||
from quapy.protocol import UPP | ||
from quapy.method.aggregative import KDEyML | ||
|
||
DEBUG = True | ||
|
||
qp.environ["SAMPLE_SIZE"] = 100 if DEBUG else 500 | ||
val_repeats = 100 if DEBUG else 500 | ||
test_repeats = 100 if DEBUG else 500 | ||
if DEBUG: | ||
qp.environ["DEFAULT_CLS"] = LogisticRegression() | ||
|
||
test_results = {} | ||
val_choice = {} | ||
|
||
bandwidth_range = np.linspace(0.01, 0.20, 20) | ||
if DEBUG: | ||
bandwidth_range = np.linspace(0.01, 0.20, 10) | ||
|
||
def datasets(): | ||
for dataset_name in qp.datasets.UCI_MULTICLASS_DATASETS[:4]: | ||
dataset = qp.datasets.fetch_UCIMulticlassDataset(dataset_name) | ||
if DEBUG: | ||
dataset = dataset.reduce(random_state=0) | ||
yield dataset | ||
|
||
|
||
def experiment_dataset(dataset): | ||
train, test = dataset.train_test | ||
test_gen = UPP(test, repeats=test_repeats) | ||
|
||
# bandwidth chosen during model selection in validation | ||
train_tr, train_va = train.split_stratified(random_state=0) | ||
kdey = KDEyML(random_state=0) | ||
modsel = qp.model_selection.GridSearchQ( | ||
model=kdey, | ||
param_grid={'bandwidth': bandwidth_range}, | ||
protocol=UPP(train_va, repeats=val_repeats), | ||
refit=False, | ||
n_jobs=-1 | ||
).fit(train_tr) | ||
chosen_bandwidth = modsel.best_params_['bandwidth'] | ||
modsel_choice = float(chosen_bandwidth) | ||
|
||
# results in test | ||
print(f"testing KDEy in {dataset.name}") | ||
dataset_results = [] | ||
for b in bandwidth_range: | ||
kdey = KDEyML(bandwidth=b, random_state=0) | ||
kdey.fit(train) | ||
|
||
mae = qp.evaluation.evaluate(kdey, protocol=test_gen, error_metric='mae', verbose=True) | ||
print(f'bandwidth={b}: {mae:.5f}') | ||
dataset_results.append((float(b), float(mae))) | ||
|
||
return modsel_choice, dataset_results | ||
|
||
def plot_bandwidth(val_choice, test_results): | ||
for dataset_name in val_choice.keys(): | ||
import matplotlib.pyplot as plt | ||
|
||
bandwidths, results = zip(*test_results[dataset_name]) | ||
|
||
# Crear la gráfica | ||
plt.figure(figsize=(8, 6)) | ||
|
||
# Graficar los puntos de datos | ||
plt.plot(bandwidths, results, marker='o') | ||
|
||
# Agregar la línea vertical en bandwidth_chosen | ||
plt.axvline(x=val_choice[dataset_name], color='r', linestyle='--', label=f'Bandwidth elegido: {val_choice[dataset_name]}') | ||
|
||
# Agregar etiquetas y título | ||
plt.xlabel('Bandwidth') | ||
plt.ylabel('Resultado') | ||
plt.title('Gráfica de Bandwidth vs Resultado') | ||
|
||
# Mostrar la leyenda | ||
plt.legend() | ||
|
||
# Mostrar la gráfica | ||
plt.grid(True) | ||
plt.show() | ||
|
||
|
||
for dataset in datasets(): | ||
if DEBUG: | ||
result_path = f'./results/debug/{dataset.name}.pkl' | ||
else: | ||
result_path = f'./results/{dataset.name}.pkl' | ||
|
||
modsel_choice, dataset_results = qp.util.pickled_resource(result_path, experiment_dataset, dataset) | ||
val_choice[dataset.name] = modsel_choice | ||
test_results[dataset.name] = dataset_results | ||
|
||
print(f'Dataset = {dataset.name}') | ||
print(modsel_choice) | ||
print(dataset_results) | ||
|
||
plot_bandwidth(val_choice, test_results) | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters