Skip to content

Commit

Permalink
new classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Jul 3, 2024
1 parent 6480ec4 commit cb9e267
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions bengrn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,39 +286,50 @@ def train_classifier(
adj, da, random_state=0, train_size=train_size, shuffle=shuffle
)
print("doing classification....")
clf = LogisticRegression(
penalty="l1",
C=C,
solver="saga",
from sklearn.linear_model import RidgeClassifier

clf = RidgeClassifier(
alpha=C,
fit_intercept=False,
class_weight=class_weight,
# solver="saga",
max_iter=max_iter,
n_jobs=8,
fit_intercept=False,
**kwargs,
positive=True,
)
# clf = LogisticRegression(
# penalty="l1",
# C=C,
# solver="saga",
# class_weight=class_weight,
# max_iter=max_iter,
# n_jobs=8,
# fit_intercept=False,
# verbose=10,
# **kwargs,
# )
clf.fit(X_train, y_train)
pred = clf.predict(X_test)
epr = compute_epr(clf, X_test, y_test)
# epr = compute_epr(clf, X_test, y_test)
metrics = {
"used_heads": (clf.coef_ != 0).sum(),
"precision": (pred[y_test == 1] == 1).sum() / (pred == 1).sum(),
"random_precision": y_test.sum() / len(y_test),
"recall": (pred[y_test == 1] == 1).sum() / y_test.sum(),
"predicted_true": pred.sum(),
"number_of_true": y_test.sum(),
"epr": epr,
# "epr": epr,
}
if doplot:
print("metrics", metrics)
PrecisionRecallDisplay.from_estimator(
clf, X_test, y_test, plot_chance_level=True
)
plt.show()
adj = grn.varp["GRN"]
if return_full:
grn.varp["classified"] = clf.predict_proba(
adj.reshape(-1, adj.shape[-1])
).reshape(len(grn.var), len(grn.var), 2)[:, :, 1]
adj = grn.varp["GRN"]
grn.varp["classified"] = clf.predict(adj.reshape(-1, adj.shape[-1])).reshape(
len(grn.var), len(grn.var), 2
)[:, :, 1]
return grn, metrics, clf


Expand Down

0 comments on commit cb9e267

Please sign in to comment.