diff --git a/bengrn/base.py b/bengrn/base.py
index 03b39e0..fc5ed3f 100644
--- a/bengrn/base.py
+++ b/bengrn/base.py
@@ -286,19 +286,30 @@ def train_classifier(
         adj, da, random_state=0, train_size=train_size, shuffle=shuffle
     )
     print("doing classification....")
-    clf = LogisticRegression(
-        penalty="l1",
-        C=C,
-        solver="saga",
+    from sklearn.linear_model import RidgeClassifier
+
+    clf = RidgeClassifier(
+        alpha=C,
+        fit_intercept=False,
         class_weight=class_weight,
+        # solver="saga",
         max_iter=max_iter,
-        n_jobs=8,
-        fit_intercept=False,
-        **kwargs,
+        positive=True,
     )
+    # clf = LogisticRegression(
+    #     penalty="l1",
+    #     C=C,
+    #     solver="saga",
+    #     class_weight=class_weight,
+    #     max_iter=max_iter,
+    #     n_jobs=8,
+    #     fit_intercept=False,
+    #     verbose=10,
+    #     **kwargs,
+    # )
     clf.fit(X_train, y_train)
     pred = clf.predict(X_test)
-    epr = compute_epr(clf, X_test, y_test)
+    # epr = compute_epr(clf, X_test, y_test)
     metrics = {
         "used_heads": (clf.coef_ != 0).sum(),
         "precision": (pred[y_test == 1] == 1).sum() / (pred == 1).sum(),
@@ -306,7 +317,7 @@ def train_classifier(
         "recall": (pred[y_test == 1] == 1).sum() / y_test.sum(),
         "predicted_true": pred.sum(),
         "number_of_true": y_test.sum(),
-        "epr": epr,
+        # "epr": epr,
     }
     if doplot:
         print("metrics", metrics)
@@ -314,11 +325,11 @@ def train_classifier(
             clf, X_test, y_test, plot_chance_level=True
         )
         plt.show()
-    adj = grn.varp["GRN"]
     if return_full:
-        grn.varp["classified"] = clf.predict_proba(
-            adj.reshape(-1, adj.shape[-1])
-        ).reshape(len(grn.var), len(grn.var), 2)[:, :, 1]
+        adj = grn.varp["GRN"]
+        grn.varp["classified"] = clf.predict(adj.reshape(-1, adj.shape[-1])).reshape(
+            len(grn.var), len(grn.var)
+        )
     return grn, metrics, clf
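
For context on the last hunk: RidgeClassifier, unlike LogisticRegression, exposes no predict_proba, so the return_full branch switches to predict(), which returns one hard label per edge rather than a (n_edges, 2) probability array; the output therefore reshapes to a plain (n_genes, n_genes) matrix with no trailing class slice. A minimal, self-contained sketch of that round-trip, using toy shapes and illustrative names (n_genes, n_heads) rather than bengrn's real data:

    import numpy as np
    from sklearn.linear_model import RidgeClassifier

    n_genes, n_heads = 5, 8
    rng = np.random.default_rng(0)
    # Toy stand-in for grn.varp["GRN"]: per-edge scores from n_heads attention heads.
    adj = rng.random((n_genes, n_genes, n_heads))
    labels = rng.integers(0, 2, n_genes * n_genes)

    # Same settings as the patched code; positive=True constrains coefficients
    # to be non-negative (sklearn >= 1.0 routes this to the lbfgs solver).
    clf = RidgeClassifier(alpha=1.0, fit_intercept=False, positive=True)
    clf.fit(adj.reshape(-1, n_heads), labels)

    # predict() yields shape (n_genes * n_genes,), i.e. one 0/1 label per edge,
    # which folds back into the square adjacency matrix.
    classified = clf.predict(adj.reshape(-1, n_heads)).reshape(n_genes, n_genes)
    print(classified.shape)  # (5, 5)

If ranked scores were still needed downstream, RidgeClassifier's decision_function() would be the drop-in substitute for predict_proba, returning signed margins instead of probabilities.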