diff --git a/bin/run.sh b/bin/run.sh index 968c9dd..27ffbff 100644 --- a/bin/run.sh +++ b/bin/run.sh @@ -1,6 +1,4 @@ #!/bin/bash -# -#conda activate tf_gpu input_h5ad=../data/example.h5ad outdir=./ @@ -8,7 +6,7 @@ outdir=./ ## train the model for learning_rate in 0.001 0.0001 do - python ./run_pred.py --input_h5ad $input_h5ad --outdir $outdir --train train --predict embedding --learning_rate $learning_rate - python ./run_pred.py --input_h5ad $input_h5ad --outdir $outdir --train train --predict embedding --learning_rate $learning_rate --dis dis + python ./run_pred.py --input_h5ad $input_h5ad --outdir $outdir --train train --predict embedding --learning_rate $learning_rate --group celltype + python ./run_pred.py --input_h5ad $input_h5ad --outdir $outdir --train train --predict embedding --learning_rate $learning_rate --group celltype --dis dis done diff --git a/bin/run_pred.py b/bin/run_pred.py index 9ca5ae3..a1f838a 100644 --- a/bin/run_pred.py +++ b/bin/run_pred.py @@ -6,7 +6,7 @@ import numpy as np; import pandas as pd from sklearn import preprocessing, metrics -from sklearn.linear_model import RidgeCV +from sklearn.neighbors import NearestNeighbors import scipy from scipy.sparse import csc_matrix, coo_matrix import anndata as ad @@ -25,14 +25,15 @@ Loss = reconstr_loss of each domain + prediction_error of co-assay data from one domain to the other. """ -def pred_normexp(autoencoder, rna_data, outdir, target_species, target_batch, output_prefix, batch_size=128): +def pred_normexp(autoencoder, rna_data_query, rna_data, target_species, dis, output_prefix, batch_size=128): """ - predict denoised scRNA-seq profiles that is projected to a reference (target species and target batch) so that they are directly comparable + predict denoised scRNA-seq profiles that is projected to the target species Parameters ---------- - rna_data: scRNA expression anndata format - target_species: the target species to translate all data to - target_batch: the target batch to translate all data to + rna_data: original scRNA data + rna_data_query: original scRNA query data used to predict the corresponding group and batch in the target species + target_species: the target species to translate the query data to + dis: indicates whether the discriminator is used in the model output_prefix: output prefix Output @@ -40,13 +41,18 @@ def pred_normexp(autoencoder, rna_data, outdir, target_species, target_batch, ou imputed_rna: ncell x ngenes (the file might be big since it denoises the original data) """ imputed_rna = {} - for batch_id in range(0, rna_data.shape[0]//batch_size +1): - index_range = list(range(batch_size*batch_id, min(batch_size*(batch_id+1), rna_data.shape[0]))) - target_encoding = np.tile(rna_data[(rna_data.obs.species==target_species) & (rna_data.obs.batch==target_batch),].obsm['encoding'].to_numpy()[1,:], (len(index_range),1)) - imputed_rna[batch_id] = autoencoder.predict_normexp(rna_data[index_range,:].X.todense(), rna_data[index_range,:].obsm['encoding'].to_numpy(), target_encoding); + # swap species factor to the target encoding + target_species_encoding = rna_data[rna_data.obs.species==target_species,][:rna_data_query.shape[0],].obsm['encoding'].to_numpy() + target_encoding = rna_data_query.obsm['encoding'].to_numpy() + target_encoding[:, :len(rna_data.obs.species.unique())] = target_species_encoding[:, :len(rna_data.obs.species.unique())] + if dis == 'dis': + target_encoding[:, -len(rna_data.obs.species.unique()): ] = target_species_encoding[:, -len(rna_data.obs.species.unique()): ] + for batch_id in range(0, rna_data_query.shape[0]//batch_size +1): + index_range = list(range(batch_size*batch_id, min(batch_size*(batch_id+1), rna_data_query.shape[0]))) + imputed_rna[batch_id] = autoencoder.predict_normexp(rna_data_query[index_range,:].X.todense(), rna_data_query[index_range,:].obsm['encoding'].to_numpy(), target_encoding[index_range,:]); imputed_rna = np.concatenate([v for k,v in sorted(imputed_rna.items())], axis=0) - np.savetxt(output_prefix +'imputation_'+str(target_species)+str(target_batch)+'.txt', imputed_rna, delimiter='\t', fmt='%1.10f') + np.savetxt(output_prefix +'imputation_'+str(target_species)+ '.txt', imputed_rna, delimiter='\t', fmt='%1.10f') @@ -73,16 +79,17 @@ def pred_embedding(autoencoder, rna_data, output_prefix, batch_size=128): np.savetxt(output_prefix+'embedding.txt', encoded_rna, delimiter='\t', fmt='%1.5f') batch_label = list(rna_data.obs.batch) - celltype_label = list(rna_data.obs.celltype) + group_label = list(rna_data.obs.group) species_label = list(rna_data.obs.species) sc_rna_combined_embedding = StandardScaler().fit_transform(encoded_rna) reducer = umap.UMAP(n_neighbors=10, min_dist=0.001) embedding = reducer.fit_transform(sc_rna_combined_embedding) - #print(embedding.shape) - umap_embedding_mat = np.concatenate((embedding, np.array(batch_label)[:, None], np.array(celltype_label)[:, None], np.array(species_label)[:, None]), axis=1) + umap_embedding_mat = np.concatenate((embedding, np.array(batch_label)[:, None], np.array(group_label)[:, None], np.array(species_label)[:, None]), axis=1) - np.savetxt(output_prefix+'_umap.txt', umap_embedding_mat, delimiter='\t', fmt='%s') + np.savetxt(output_prefix+'umap.txt', umap_embedding_mat, delimiter='\t', fmt='%s') + + ## plot UMAP #os.system('Rscript ./plot_umap.R '+output_prefix) @@ -160,98 +167,144 @@ def convert_batch_to_onehot(input_dataset_list, dataset_list, output_name=''): -def calc_cor_mmd(x, y, nsubsample = 0, logscale=True, norm_x = 'norm', norm_y = '', return_mmd = False, gamma=0): - """ - return correlation and MMD between original (x) and predicted (y) scRNA px_scale - gamma: 0: +def compute_lisi(X, label, perplexity = 30): """ - x = np.asarray(x) - y = np.asarray(y) - if norm_x == 'norm': - ## compare with normalized true profile - lib = x.sum(axis=1, keepdims=True) - x = x / lib - if norm_y == 'norm': - ## compare with normalized true profile - lib = y.sum(axis=1, keepdims=True) - y = y / lib - if logscale: - x = np.log1p(x) - y = np.log1p(y) - - if return_mmd is False: - #if nsubsample >0: - # mmd = mmd_rbf(x[np.random.choice(x.shape[0], size=nsubsample, replace=False),], y[np.random.choice(y.shape[0], size=nsubsample, replace=False),]) - #else: - # mmd = mmd_rbf(x, y) - #pearson_r_bulk, pearson_p_bulk = scipy.stats.pearsonr(np.mean(x, axis=0), np.mean(y, axis=0)) - pearson_r_bulk_list = [] - mmd_list = [] - for nrand in range(10): - mmd = mmd_rbf(x[np.random.choice(x.shape[0], size=nsubsample, replace=False),], y[np.random.choice(y.shape[0], size=nsubsample, replace=False),], gamma=gamma) - pearson_r_bulk, pearson_p_bulk = scipy.stats.pearsonr(np.mean(x[np.random.choice(x.shape[0], size=nsubsample, replace=False),], axis=0), np.mean(y[np.random.choice(y.shape[0], size=nsubsample, replace=False),], axis=0)) - pearson_r_bulk_list.append(pearson_r_bulk) - mmd_list.append(mmd) - - #return(pearson_r_bulk, mmd) - return(pearson_r_bulk_list, mmd_list) - else: - if nsubsample >0: - mmd = mmd_rbf(x[np.random.choice(x.shape[0], size=nsubsample, replace=False),], y[np.random.choice(y.shape[0], size=nsubsample, replace=False),], gamma=gamma) - else: - mmd = mmd_rbf(x, y, gamma=gamma) - return(mmd) + adapted from https://github.com/slowkow/harmonypy/blob/master/harmonypy/lisi.py + Compute the Local Inverse Simpson Index (LISI) for each column in metadata. + LISI is a statistic computed for each item (row) in the data matrix X. + The following example may help to interpret the LISI values. -def mmd_rbf(X, Y, gamma=0): - """MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2)) - Arguments: - X {[n_sample1, dim]} -- [X matrix] - Y {[n_sample2, dim]} -- [Y matrix] - Keyword Arguments: - gamma {float} -- [kernel parameter] (default: {1.0}) - Returns: - [scalar] -- [MMD value] - """ - ## calculate mean pairwise euclidean distance - if gamma == 0: - sigma = np.median(euclidean_distances(np.concatenate((X, Y)))) - gamma = 0.5/sigma**2 - XX = metrics.pairwise.rbf_kernel(X, X, gamma) - YY = metrics.pairwise.rbf_kernel(Y, Y, gamma) - XY = metrics.pairwise.rbf_kernel(X, Y, gamma) - return XX.mean() + YY.mean() - 2 * XY.mean() + Suppose one of the columns in metadata is a categorical variable with 3 categories. + - If LISI is approximately equal to 3 for an item in the data matrix, + that means that the item is surrounded by neighbors from all 3 + categories. + - If LISI is approximately equal to 1, then the item is surrounded by + neighbors from 1 category. + + The LISI statistic is useful to evaluate whether multiple datasets are + well-integrated by algorithms such as Harmony [1]. -def pred(outdir, input_h5ad, sim_url, target_species, target_batch, dispersion, embed_dim, nlayer, dropout_rate, learning_rate, hidden_frac, kl_weight, discriminator_weight, epsilon, patience, my_epochs, nepoch_warmup, nepoch_klstart, batch_size, train, evaluate, predict, test_species='', test_celltype='', dis=''): + [1]: Korsunsky et al. 2019 doi: 10.1038/s41592-019-0619-0 + """ + n_cells = len(label) + n_labels = len(label) + # We need at least 3 * n_neigbhors to compute the perplexity + knn = NearestNeighbors(n_neighbors = perplexity * 3, algorithm = 'kd_tree').fit(X) + distances, indices = knn.kneighbors(X) + # Don't count yourself + indices = indices[:,1:] + distances = distances[:,1:] + # Save the result + labels = pd.Categorical(label) + n_categories = len(labels.categories) + simpson = compute_simpson(distances.T, indices.T, labels, n_categories, perplexity) + lisi_df = np.mean(1 / simpson) + return lisi_df + + + +def compute_simpson( + distances: np.ndarray, + indices: np.ndarray, + labels: pd.Categorical, + n_categories: int, + perplexity: float, + tol: float=1e-5 +): """ - train/load the Polarbear model + adapted from https://github.com/slowkow/harmonypy/blob/master/harmonypy/lisi.py + """ + n = distances.shape[1] + P = np.zeros(distances.shape[0]) + simpson = np.zeros(n) + logU = np.log(perplexity) + # Loop through each cell. + for i in range(n): + beta = 1 + betamin = -np.inf + betamax = np.inf + # Compute Hdiff + P = np.exp(-distances[:,i] * beta) + P_sum = np.sum(P) + if P_sum == 0: + H = 0 + P = np.zeros(distances.shape[0]) + else: + H = np.log(P_sum) + beta * np.sum(distances[:,i] * P) / P_sum + P = P / P_sum + Hdiff = H - logU + n_tries = 50 + for t in range(n_tries): + # Stop when we reach the tolerance + if abs(Hdiff) < tol: + break + # Update beta + if Hdiff > 0: + betamin = beta + if not np.isfinite(betamax): + beta *= 2 + else: + beta = (beta + betamax) / 2 + else: + betamax = beta + if not np.isfinite(betamin): + beta /= 2 + else: + beta = (beta + betamin) / 2 + # Compute Hdiff + P = np.exp(-distances[:,i] * beta) + P_sum = np.sum(P) + if P_sum == 0: + H = 0 + P = np.zeros(distances.shape[0]) + else: + H = np.log(P_sum) + beta * np.sum(distances[:,i] * P) / P_sum + P = P / P_sum + Hdiff = H - logU + # distancesefault value + if H == 0: + simpson[i] = -1 + # Simpson's index + for label_category in labels.categories: + ix = indices[:,i] + q = labels[ix] == label_category + if np.any(q): + P_sum = np.sum(P[q]) + simpson[i] += P_sum * P_sum + return simpson + + + +def pred(outdir, input_h5ad, sim_url, target_species, target_batch, target_group, group, dispersion, embed_dim, nlayer, dropout_rate, learning_rate, hidden_frac, kl_weight, discriminator_weight, epsilon, patience, my_epochs, nepoch_warmup, nepoch_klstart, batch_size, train, predict, dis=''): + """ + train/load the Icebear model + Usage: cross-species alignment, and cross-species prediction on missing cell types/tissues + Parameters ---------- outdir: output directory - input_h5ad: input data, in h5ad format. obs variable should include species, batch - test_species: held out species for evaluation purpose - test_celltype: held out cell type for evaluation purpose - train: "train" or "", train model + input_h5ad: input data, in h5ad format. obs variable should include species and batch + + Output + ---------- + For cross-species alignment and denoising, no data needs to be held out as test set. + For cross-species imputation, we hold out a pre-specified cell type to evaluate the performance. """ os.system('mkdir -p '+ outdir) ## ====================================== - ## load data and convert batch, species to condition vectors obsm['encoding'] - """ - The script serves as two purposes - cross-species alignment/data denoising, or cross-species prediction on missing cell types/tissues - In cross-species alignment and denoising, no data needs to be held out as test set. - In cross-species imputation, we hold out a pre-specified cell type to evaluate the performance. - """ + ## load data and convert factors (e.g., batch, species) to condition vectors in obsm['encoding'] logging.info('Loading data...') nsubsample = 2000 rna_data = ad.read_h5ad(input_h5ad) - + ## make species and batch encoding using one-hot encoding for obs_i in ['species', 'batch']: batch_encoding = pd.DataFrame(convert_batch_to_onehot(list(rna_data.obs[obs_i]), dataset_list=list(rna_data.obs[obs_i].unique())).todense()) batch_encoding.index = rna_data.obs.index @@ -259,8 +312,19 @@ def pred(outdir, input_h5ad, sim_url, target_species, target_batch, dispersion, rna_data.obsm['encoding'] = batch_encoding else: rna_data.obsm['encoding'] = pd.concat([rna_data.obsm['encoding'], batch_encoding], axis=1) - rna_data.obs["tissue"] = rna_data.obs.celltype + + if group != '' and group in rna_data.obs.columns: + rna_data.obs['group'] = rna_data.obs[group] + else: + rna_data.obs['group'] = '' + + ## when needed, include organ as another factor similar to the batch factor, though this function can be achieved by assigning the organ/tissue column as batch + #if tissuefactor in rna_data.obs.columns: + # batch_encoding = pd.DataFrame(convert_batch_to_onehot(list(rna_data.obs[tissuefactor]), dataset_list=list(rna_data.obs[tissuefactor].unique())).todense()) + # batch_encoding.index = rna_data.obs.index + # rna_data.obsm['encoding'] = pd.concat([rna_data.obsm['encoding'], batch_encoding], axis=1) + ## append the one-hot-encoded species encoding for subsequent use if dis == 'dis': ## add one hot encoded species append_encoding = pd.DataFrame(convert_batch_to_onehot(list(rna_data.obs.species), dataset_list=list(rna_data.obs.species.unique())).todense()) @@ -276,17 +340,13 @@ def pred(outdir, input_h5ad, sim_url, target_species, target_batch, dispersion, ## ====================================== ## define train, validation and test - rna_data_test = rna_data[(rna_data.obs.species==target_species) & (rna_data.obs.batch==target_batch),] - rna_data_train = rna_data[~ ((rna_data.obs.species==target_species) & (rna_data.obs.batch==target_batch)),] + rna_data_test = rna_data[(rna_data.obs.species==target_species) & (rna_data.obs.batch==target_batch) & (rna_data.obs.group==target_group),] + rna_data_train = rna_data[~ ((rna_data.obs.species==target_species) & (rna_data.obs.batch==target_batch) & (rna_data.obs.group==target_group)),] rna_data_train.obs['rank'] = range(0, rna_data_train.shape[0]) random.seed(101) rna_data_val_index = random.sample(range(rna_data_train.shape[0]), min(int(rna_data_train.shape[0] * 0.1), nsubsample*10)) - ## define validation set with consideration of dataset info, if available - #rna_data_val_index = [] - #for dataset_i in rna_data_train.obs.dataset.unique(): - #rna_data_val_index.extend(random.sample(list(rna_data_train.obs.loc[rna_data_train.obs['dataset']==dataset_i]['rank']), min(int(sum(rna_data_train.obs.dataset==dataset_i) * 0.1), nsubsample))) rna_data_val = rna_data_train[rna_data_val_index,:] @@ -306,15 +366,14 @@ def pred(outdir, input_h5ad, sim_url, target_species, target_batch, dispersion, data_val = rna_data_val[val_index,:].X.tocsr() batch_val = rna_data_val[val_index,:].obsm['encoding'].to_numpy() - data_test = rna_data_test.X.tocsr() - batch_test = rna_data_test.obsm['encoding'].to_numpy() - print('== data imported ==') sys.stdout.flush() sim_metric_all = [] save_model=True + ## ====================================== + ## train the model if train == 'train': logging.info('Training model...') tf.reset_default_graph() @@ -335,60 +394,51 @@ def pred(outdir, input_h5ad, sim_url, target_species, target_batch, dispersion, plot_loss_epoch(sim_url, dis) - if evaluate=='evaluate' or predict=='predict': - logging.info('Loading model with dropout_rate=0...') - tf.reset_default_graph() - if dis=='': - autoencoder = RNAAE(input_dim_x=data_train.shape[1], batch_dim_x=batch_train.shape[1], embed_dim_x=embed_dim, dispersion=dispersion, nlayer=nlayer, dropout_rate=0, output_model=sim_url, learning_rate_x=learning_rate, nlabel=nlabel, discriminator_weight=discriminator_weight, epsilon=epsilon, hidden_frac=hidden_frac, kl_weight=kl_weight); - else: - autoencoder = RNAAEdis(input_dim_x=data_train.shape[1], batch_dim_x=batch_train.shape[1], embed_dim_x=embed_dim, dispersion=dispersion, nlayer=nlayer, dropout_rate=0, output_model=sim_url, learning_rate_x=learning_rate, nlabel=nlabel, discriminator_weight=discriminator_weight, epsilon=epsilon, hidden_frac=hidden_frac, kl_weight=kl_weight); - - iter_list, reconstr_loss_list, kl_loss_list, discriminator_loss_list, val_reconstr_loss_list, val_kl_loss_list, val_discriminator_loss_list = autoencoder.train(data_train, batch_train, data_val, batch_val, nepoch_warmup, patience, nepoch_klstart, output_model=sim_url, my_epochs=my_epochs, batch_size=batch_size, nlayer=nlayer, kl_weight=kl_weight, dropout_rate=dropout_rate, save_model=save_model); - - del data_train - del batch_train - del data_val - del batch_val - - if evaluate == 'evaluate': - logging.info('Evaluating predictions...') - ## get MMD based on validation set, this is used to select best model + ## calculate cross-species alignment performance on the validation set, this is used to select the best model + logging.info('Evaluating predictions and get the best performing model...') output_prefix = sim_url + '_eval_' - val_mmd_embedding = [] sc_val_embedding = autoencoder.predict_embedding(rna_data_val.X.todense(), rna_data_val.obsm['encoding'].to_numpy()) # standardize each dimension of embedding sc_val_embedding = (sc_val_embedding - sc_val_embedding.mean(axis=0)) / (sc_val_embedding.std(axis=0)) - # calculate pairwise MMD and get the mean + # calculate lisi score based on species labels species = list(rna_data_val.obs.species.unique()) - for i in range(len(species)-1): - for j in range(i, len(species)): - if i