Skip to content
6 changes: 3 additions & 3 deletions pyprophet/cli/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,13 @@ def score(
ctx.obj["LOG_HEADER"],
)

# Validate file type and subsample ratio; subsample_ratio is currently only applicable to "parquet_split" and "parquet_split_multi". If this combination is not met, warn and set subsample_ratio to 1.0
# Validate file type and subsample ratio. OSW, parquet, parquet_split, and parquet_split_multi all support subsampling
if (
config.file_type not in ["parquet", "parquet_split", "parquet_split_multi"]
config.file_type not in ["osw", "parquet", "parquet_split", "parquet_split_multi"]
and subsample_ratio < 1.0
):
logger.warning(
"Semi-supervised learning on a subset of the data, and then applying the weights to the full data is currently only supported for `parquet_split` and `parquet_split_multi` files.\nFor `osw`, you need to manually subsample the osw using the `subsample` module.\nSetting subsample_ratio to 1.0.",
"Semi-supervised learning on a subset of the data, and then applying the weights to the full data is currently only supported for OSW, `parquet`, `parquet_split`, and `parquet_split_multi` files.\nFor TSV and other formats, you need to manually prepare a subsampled input file.\nSetting subsample_ratio to 1.0.",
)
config.subsample_ratio = 1.0

Expand Down
37 changes: 37 additions & 0 deletions pyprophet/io/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"""

import glob
import math
import os
import pickle
import sys
Expand Down Expand Up @@ -993,6 +994,7 @@ class BaseOSWReader(BaseReader):

Methods:
read(): Read data from the input file based on the alogorithm.
_init_duckdb_views(): Initialize DuckDB views with optional subsampling.
"""

def __init__(self, config: BaseIOConfig):
Expand All @@ -1006,6 +1008,41 @@ def read(self) -> pd.DataFrame:
"The read method must be implemented in subclasses of BaseOSWReader."
)

def _init_duckdb_views(self, con):
    """
    Prepare DuckDB session state for reading an OSW file, including the
    optional precursor-subsampling table.

    When ``self.subsample_ratio`` < 1.0, a TEMP table named
    ``sampled_precursor_ids`` is created holding a deterministic
    pseudo-random subset of precursor IDs (ordered by ``hash(ID)``).
    Subclasses can then restrict their feature views with:
    ``WHERE PRECURSOR_ID IN (SELECT PRECURSOR_ID FROM sampled_precursor_ids)``.

    Args:
        con: DuckDB connection with the OSW database attached as 'osw'.
    """
    # Nothing to prepare when the full data set is used.
    if self.subsample_ratio >= 1.0:
        return

    logger.info(
        f"Subsampling data for semi-supervised learning. Ratio: {self.subsample_ratio:.2f}"
    )

    # Size the sample from the distinct precursor count; always keep at
    # least one precursor so downstream queries are never empty by design.
    total_precursors = con.execute(
        "SELECT COUNT(DISTINCT ID) FROM osw.PRECURSOR"
    ).fetchone()[0]
    limit = max(1, math.ceil(total_precursors * self.subsample_ratio))

    # hash(ID) gives a deterministic shuffle, so repeated runs sample the
    # same precursors for a given file and ratio.
    con.execute(
        f"""
        CREATE TEMP TABLE sampled_precursor_ids AS
        SELECT DISTINCT ID AS PRECURSOR_ID
        FROM osw.PRECURSOR
        ORDER BY hash(ID)
        LIMIT {limit}
        """
    )

    sampled_count = con.execute(
        "SELECT COUNT(*) FROM sampled_precursor_ids"
    ).fetchone()[0]
    logger.info(f"Sampled {sampled_count} precursor IDs")


@dataclass
class BaseOSWWriter(BaseWriter):
Expand Down
38 changes: 30 additions & 8 deletions pyprophet/io/scoring/osw.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def read(self) -> pd.DataFrame:
con.execute("INSTALL sqlite_scanner;")
con.execute("LOAD sqlite_scanner;")
con.execute(f"ATTACH DATABASE '{self.infile}' AS osw (TYPE sqlite);")
self._init_duckdb_views(con)
return self._read_using_duckdb(con)
except ModuleNotFoundError as e:
logger.warning(
Expand Down Expand Up @@ -148,15 +149,27 @@ def _fetch_tables_duckdb(self, con):
).fetchdf()
return tables

def _get_precursor_filter_clause(self):
"""
Return a WHERE/AND clause fragment for filtering by sampled precursor IDs when subsampling is enabled.
Returns empty string if no subsampling, otherwise returns a clause like:
" AND f.PRECURSOR_ID IN (SELECT PRECURSOR_ID FROM sampled_precursor_ids)"
"""
if self.subsample_ratio < 1.0:
return " AND f.PRECURSOR_ID IN (SELECT PRECURSOR_ID FROM sampled_precursor_ids)"
return ""

def _fetch_ms2_features_duckdb(self, con):
if not check_duckdb_table(con, "main", "FEATURE_MS2"):
raise click.ClickException(
f"MS2-level feature table not present in file.\nTable Info:\n{self._fetch_tables_duckdb(con)}"
)

filter_clause = self._get_precursor_filter_clause()

if self.glyco:
con.execute(
"""
f"""
CREATE OR REPLACE VIEW ms2_table AS
SELECT
fm.*,
Expand Down Expand Up @@ -189,11 +202,12 @@ def _fetch_ms2_features_duckdb(self, con):
WHERE t.DETECTING = 1
GROUP BY tpm.PRECURSOR_ID
) ts ON f.PRECURSOR_ID = ts.PRECURSOR_ID
WHERE 1=1{filter_clause}
"""
)
else:
con.execute(
"""
f"""
CREATE OR REPLACE VIEW ms2_table AS
SELECT
fm.*,
Expand All @@ -216,6 +230,7 @@ def _fetch_ms2_features_duckdb(self, con):
WHERE t.DETECTING = 1
GROUP BY tpm.PRECURSOR_ID
) ts ON f.PRECURSOR_ID = ts.PRECURSOR_ID
WHERE 1=1{filter_clause}
"""
)

Expand Down Expand Up @@ -243,17 +258,19 @@ def _fetch_ms1_features_duckdb(self, con):
rc = self.config.runner
glyco = rc.glyco
ipf_max_rank = rc.ipf_max_peakgroup_rank
filter_clause = self._get_precursor_filter_clause()

if not glyco:
con.execute(
"""
f"""
CREATE OR REPLACE VIEW ms1_table AS
SELECT fm.*, f.RUN_ID, f.PRECURSOR_ID, f.EXP_RT,
p.CHARGE AS PRECURSOR_CHARGE, p.DECOY,
f.RUN_ID || '_' || f.PRECURSOR_ID AS GROUP_ID
FROM osw.FEATURE_MS1 fm
INNER JOIN osw.FEATURE f ON fm.FEATURE_ID = f.ID
INNER JOIN osw.PRECURSOR p ON f.PRECURSOR_ID = p.ID
WHERE 1=1{filter_clause}
ORDER BY f.RUN_ID, p.ID, f.EXP_RT
"""
)
Expand Down Expand Up @@ -281,7 +298,7 @@ def _fetch_ms1_features_duckdb(self, con):
FROM osw.PRECURSOR_GLYCOPEPTIDE_MAPPING pgm
INNER JOIN osw.GLYCOPEPTIDE gp ON pgm.GLYCOPEPTIDE_ID = gp.ID
) g ON f.PRECURSOR_ID = g.PRECURSOR_ID
WHERE s.RANK <= {ipf_max_rank}
WHERE s.RANK <= {ipf_max_rank}{filter_clause}
ORDER BY f.RUN_ID, p.ID, f.EXP_RT
"""
)
Expand All @@ -303,6 +320,7 @@ def _fetch_transition_features_duckdb(self, con):
)

rc = self.config.runner
filter_clause = self._get_precursor_filter_clause()
con.execute(
f"""
CREATE OR REPLACE VIEW transition_table AS
Expand All @@ -324,7 +342,7 @@ def _fetch_transition_features_duckdb(self, con):
AND s.PEP <= {rc.ipf_max_peakgroup_pep}
AND ft.VAR_ISOTOPE_OVERLAP_SCORE <= {rc.ipf_max_transition_isotope_overlap}
AND ft.VAR_LOG_SN_SCORE > {rc.ipf_min_transition_sn}
AND p.DECOY = 0
AND p.DECOY = 0{filter_clause}
"""
)
df = con.execute(
Expand All @@ -342,8 +360,10 @@ def _fetch_alignment_features_duckdb(self, con):
raise click.ClickException(
f"MS2-level feature alignment table not present in file.\nTable Info:\n{self._fetch_tables_duckdb(con)}"
)

filter_clause = self._get_precursor_filter_clause()
con.execute(
"""
f"""
CREATE OR REPLACE VIEW alignment_table AS
SELECT
fa.ALIGNMENT_ID AS ALIGNMENT_ID, fa.RUN_ID,
Expand All @@ -359,6 +379,7 @@ def _fetch_alignment_features_duckdb(self, con):
fa.PEAK_INTENSITY_RATIO AS VAR_PEAK_INTENSITY_RATIO,
fa.ALIGNED_FEATURE_ID || '_' || fa.PRECURSOR_ID AS GROUP_ID
FROM osw.FEATURE_MS2_ALIGNMENT fa
WHERE 1=1{filter_clause}
ORDER BY fa.RUN_ID, fa.PRECURSOR_ID, fa.REFERENCE_RT
"""
)
Expand Down Expand Up @@ -692,10 +713,10 @@ def save_results(self, result, pi0):
def save_weights(self, weights):
"""
Save the weights to a SQLite database based on the classifier type.
If classifier is "LDA", weights are saved to PYPROPHET_WEIGHTS table.
If classifier is "LDA" or "SVM", weights are saved to PYPROPHET_WEIGHTS table.
If classifier is "XGBoost", weights are saved to PYPROPHET_XGB or GLYCOPEPTIDEPROPHET_XGB table based on glyco and level.
"""
if self.classifier == "LDA":
if self.classifier in ("LDA", "SVM"):
weights["level"] = self.level
con = sqlite3.connect(self.outfile)

Expand All @@ -722,6 +743,7 @@ def save_weights(self, weights):
# print(weights)

weights.to_sql("PYPROPHET_WEIGHTS", con, index=False, if_exists="append")
con.commit()

Comment thread
singjc marked this conversation as resolved.
elif self.classifier == "XGBoost":
con = sqlite3.connect(self.outfile)
Expand Down
Loading
Loading