Expand ngram_range to 100
akikuno committed Oct 31, 2024
1 parent 0a6f661 commit b42e3c6
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions src/DAJIN2/core/preprocess/sequence_error_handler.py
@@ -19,26 +19,21 @@
 ###############################################################################
 
 
-def parse_midsv_from_csv(csv_tags: list[list[str]]) -> str:
-    midsv_seq = []
-    for tag in csv_tags:
-        if tag.startswith("N") or tag.startswith("n"):
-            midsv_seq.append("N")
-        else:
-            midsv_seq.append("M")
-    return "".join(midsv_seq)
+def convert_nm_tag(csv_tags: list[str]) -> str:
+    """Create a tag string with N for truncated bases and M for mapped bases."""
+    return "".join(["N" if tag.startswith("N") or tag.startswith("n") else "M" for tag in csv_tags])


 def detect_sequence_error_reads_in_control(ARGS) -> None:
     # Convert CSV strings to MIDSV tags
     midsv_control = io.read_jsonl(Path(ARGS.tempdir, ARGS.control_name, "midsv", "control", "control.jsonl"))
-    midsv_tags, qnames = zip(*[(parse_midsv_from_csv(m["CSSPLIT"].split(",")), m["QNAME"]) for m in midsv_control])
+    nm_tags, qnames = zip(*[(convert_nm_tag(m["CSSPLIT"].split(",")), m["QNAME"]) for m in midsv_control])

-    # Vectorize the MIDSV tags using TF-IDF with character-level 3-grams
-    vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(3, 3))
-    X = vectorizer.fit_transform(midsv_tags)
+    # Vectorize the MIDSV tags using TF-IDF with character-level N-grams
+    vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(100, 100))
+    X = vectorizer.fit_transform(nm_tags)
     # Append the number of matches to X as an extra feature
-    match_counts = np.array([tag.count("M") for tag in midsv_tags], dtype=int)
+    match_counts = np.array([tag.count("M") for tag in nm_tags], dtype=int)
     X = hstack([X, match_counts.reshape(-1, 1)])
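
To see why the commit widens ngram_range from (3, 3) to (100, 100): character 100-grams only match when two reads share an identical 100-character stretch of their N/M profile, which makes the TF-IDF features far more specific than 3-grams. A self-contained sketch on toy data (not the DAJIN2 pipeline itself):

```python
# Minimal sketch: TF-IDF over character 100-grams of N/M strings,
# with the raw "M" count appended as an extra dense feature.
import numpy as np
from scipy.sparse import hstack
from sklearn.feature_extraction.text import TfidfVectorizer

nm_tags = ["M" * 300, "M" * 150 + "N" * 150, "N" * 300]  # toy reads

vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(100, 100))
X = vectorizer.fit_transform(nm_tags)  # one column per distinct 100-gram
match_counts = np.array([tag.count("M") for tag in nm_tags], dtype=int)
X = hstack([X, match_counts.reshape(-1, 1)])  # sparse features + dense column
print(X.shape)  # (3, n_distinct_100grams + 1)
```

One side effect worth noting: reads shorter than 100 bases yield no 100-grams at all, so they are represented only by the appended match count.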

     # Apply KMeans clustering for binary classification based on the similarity of MIDSV.
@@ -49,7 +44,7 @@ def detect_sequence_error_reads_in_control(ARGS) -> None:
     match_counts = {0: [], 1: []}
 
     # Count occurrences of "M" for each label
-    for midsv_tag, label in zip(midsv_tags, labels):
+    for midsv_tag, label in zip(nm_tags, labels):
         match_counts[label].append(midsv_tag.count("M"))
 
     # Determine which label corresponds to the sequences with fewer matches and treat them as sequence errors.
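
The code that acts on this comment is collapsed in the diff view; a plausible sketch of the decision, with all names other than match_counts, labels, and qnames assumed rather than taken from the commit:

```python
# Hypothetical sketch: the cluster with the lower mean "M" count is
# treated as the sequence-error cluster.
import numpy as np

mean_matches = {label: np.mean(counts) for label, counts in match_counts.items()}
error_label = min(mean_matches, key=mean_matches.get)
qnames_with_error = {q for q, label in zip(qnames, labels) if label == error_label}
```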
@@ -83,17 +78,17 @@ def detect_sequence_error_reads_in_sample(ARGS) -> None:

     midsv_control = io.read_jsonl(Path(ARGS.tempdir, ARGS.control_name, "midsv", "control", "control.jsonl"))
     midsv_errors = (m for m in midsv_control if m["QNAME"] in qnames_with_sequence_error_control)
-    midsv_tags_error = [parse_midsv_from_csv(m["CSSPLIT"].split(",")) for m in midsv_errors]
+    nm_tags_error = [convert_nm_tag(m["CSSPLIT"].split(",")) for m in midsv_errors]
 
     # Randomly sample up to 100 error sequences
     random.seed(1)
-    midsv_tags_error = random.sample(midsv_tags_error, min(len(midsv_tags_error), 100))
+    nm_tags_error = random.sample(nm_tags_error, min(len(nm_tags_error), 100))

     path_midsv_sample = Path(ARGS.tempdir, ARGS.sample_name, "midsv", "control", f"{ARGS.sample_name}.jsonl")
     midsv_sample = io.read_jsonl(path_midsv_sample)
-    midsv_tags_sample = [parse_midsv_from_csv(m["CSSPLIT"].split(",")) for m in midsv_sample]
+    nm_tags_sample = [convert_nm_tag(m["CSSPLIT"].split(",")) for m in midsv_sample]
 
-    similarity_scores = cdist(midsv_tags_sample, midsv_tags_error, scorer=JaroWinkler.normalized_similarity)
+    similarity_scores = cdist(nm_tags_sample, nm_tags_error, scorer=JaroWinkler.normalized_similarity)
     most_similar_scores = np.max(similarity_scores, axis=1)

     midsv_sample = io.read_jsonl(path_midsv_sample)
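
Given the scorer=JaroWinkler.normalized_similarity argument, cdist here is evidently rapidfuzz's process.cdist, which computes all pairwise similarities between the sample and error tags; the row-wise max then scores how closely each sample read resembles its nearest known error read. A runnable sketch on toy strings:

```python
# Minimal sketch of the similarity screening, using rapidfuzz.
import numpy as np
from rapidfuzz.distance import JaroWinkler
from rapidfuzz.process import cdist

nm_tags_sample = ["MMMMMMNN", "NNNNNNNN"]  # toy N/M strings
nm_tags_error = ["NNNNNNNM", "NNNNNNNN"]

scores = cdist(nm_tags_sample, nm_tags_error, scorer=JaroWinkler.normalized_similarity)
most_similar = np.max(scores, axis=1)  # best match against any error read
print(most_similar)  # values near 1.0 flag reads that resemble known errors
```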
@@ -124,7 +119,9 @@ def split_fastq_by_sequence_error(ARGS, is_control: bool = False) -> None:
     else:
         NAME = ARGS.sample_name
 
-    path_qnames_without_sequence_error = Path(ARGS.tempdir, NAME, "sequence_error", "qnames_without_sequence_error.txt")
+    path_qnames_without_sequence_error = Path(
+        ARGS.tempdir, NAME, "sequence_error", "qnames_without_sequence_error.txt"
+    )
     qnames_without_error = set(path_qnames_without_sequence_error.read_text().splitlines())
 
     path_fastq = Path(ARGS.tempdir, NAME, "fastq", f"{NAME}.fastq.gz")
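
The remainder of this hunk is collapsed, but in outline it writes out only the reads whose names appear in qnames_without_error. A minimal sketch of that kind of FASTQ filtering, assuming plain 4-line records; the helper name and I/O layout are illustrative, not the commit's code:

```python
# Hypothetical sketch: keep FASTQ records whose read name (the token after
# "@" on the header line) is in the qnames_without_error set.
import gzip
from pathlib import Path

def filter_fastq(path_in: Path, path_out: Path, keep: set[str]) -> None:
    with gzip.open(path_in, "rt") as fin, gzip.open(path_out, "wt") as fout:
        while True:
            record = [fin.readline() for _ in range(4)]  # FASTQ = 4 lines/record
            if not record[0].strip():
                break
            qname = record[0].split()[0].lstrip("@")
            if qname in keep:
                fout.writelines(record)
```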