Skip to content

Commit

Permalink
best IM model search (unoptimized) and print out ion mobility as feature
Browse files Browse the repository at this point in the history
  • Loading branch information
yangkl96 committed Jul 16, 2024
1 parent 272e9d8 commit 036dab0
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 7 deletions.
11 changes: 10 additions & 1 deletion src/main/java/External/KoinaMethods.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public class KoinaMethods {
public HashSet<String> peptideSet = new HashSet<>();
public HashMap<String, LinkedList<Integer>> scanNums = new HashMap<>();
public HashMap<String, LinkedList<String>> peptides = new HashMap<>();
public HashSet<String> peptideSetIM = new HashSet<>();
public HashMap<String, LinkedList<Integer>> scanNumsIM = new HashMap<>();
public HashMap<String, LinkedList<String>> peptidesIM = new HashMap<>();

//decoys
public HashSet<String> peptideSetDecoys = new HashSet<>();
Expand All @@ -63,10 +66,16 @@ public void getTopPeptides() throws IOException {
for (int j = 0; j < pmMatcher.pinFiles.length; j++) {
File pinFile = pmMatcher.pinFiles[j];
PinReader pinReader = new PinReader(pinFile.getAbsolutePath());
LinkedList[] topPSMs = pinReader.getTopPSMs(numTopPSMs);
LinkedList[] topPSMs = pinReader.getTopPSMs(numTopPSMs, false);
peptideSet.addAll(topPSMs[0]);
scanNums.put(pmMatcher.mzmlFiles[j].getName(), topPSMs[1]);
peptides.put(pmMatcher.mzmlFiles[j].getName(), topPSMs[0]);
if (Constants.useIM) {
LinkedList[] topPSMsIM = pinReader.getTopPSMs(numTopPSMs, true);
peptideSetIM.addAll(topPSMsIM[0]);
scanNumsIM.put(pmMatcher.mzmlFiles[j].getName(), topPSMsIM[1]);
peptidesIM.put(pmMatcher.mzmlFiles[j].getName(), topPSMsIM[0]);
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/main/java/Features/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ private static HashMap<String, String> makeCamelToUnderscore() {
map.put("predictedRT", "predicted_RT");
map.put("deltaIMLOESS", "delta_IM_loess");
map.put("deltaIMLOESSnormalized", "delta_IM_loess_normalized");
map.put("ionmobility", "ion_mobility");
map.put("IMprobabilityUnifPrior", "IM_probability_unif_prior");
map.put("predictedIM", "predicted_IM");
map.put("detectProtSpearmanDiff", "detect_prot_spearman_diff");
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/Features/FeatureCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ public void run() {
break;
// case "predictedIM":
// break;
// case "ionmobility":
// break;
case "ionmobility":
break;
// case "y_matched_intensity":
// break;
// case "b_matched_intensity":
Expand Down
140 changes: 138 additions & 2 deletions src/main/java/Features/MainClass.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import External.KoinaModelCaller;
import External.NCEcalibrator;
import kotlin.jvm.functions.Function1;
import org.checkerframework.checker.units.qual.A;

import java.io.*;
import java.lang.reflect.Field;
Expand Down Expand Up @@ -507,7 +506,7 @@ public static void main(String[] args) throws Exception {
float expRT = msn.RT;
predRTs.add(predRT);
expRTs.add(expRT);
} catch (Exception e) {}
} catch (Exception ignored) {}
}
}
double[][] rts = new double[2][expRTs.size()];
Expand Down Expand Up @@ -686,7 +685,136 @@ public static void main(String[] args) throws Exception {
Constants.rtModel = "";
}

HashMap<String, float[]> datapointsIM = new HashMap<>();
if (Constants.useIM) {
if (Constants.findBestImModel) {
printInfo("Searching for best IM model for your data");
ArrayList<String> consideredModels = new ArrayList<>();
String[] imSearchModels = Constants.imSearchModelsString.split(",");
for (String model : imSearchModels) {
if (modelMapper.containsKey(model.toLowerCase())) {
consideredModels.add(modelMapper.get(model.toLowerCase()));
} else {
printError("No IM model called " + model + ". Exiting.");
System.exit(0);
}
}
printInfo("Searching the following models: ");
printInfo(String.valueOf(consideredModels));

String jsonOutFolder = Constants.outputDirectory + File.separator + "best_model";
MyFileUtils.createWholeDirectory(jsonOutFolder);

for (String model : consideredModels) {
PredictionEntryHashMap imPreds = null;
if (model.equals("DIA-NN")) { //mode for DIA-NN
MyFileUtils.createWholeDirectory(jsonOutFolder + File.separator + model);

km.writeFullPeptideFile(jsonOutFolder + File.separator + model + File.separator +
"spectraRT_full.tsv", model, km.peptideSetIM);
String inputFile = jsonOutFolder + File.separator + model + File.separator + "spectraRT.tsv";
FileWriter myWriter = new FileWriter(inputFile);
myWriter.write("peptide" + "\t" + "charge\n");
HashSet<String> peptides = new HashSet<>();
for (String pep : km.peptideSetIM) {
String[] pepSplit = pep.split(",");
String line = pepSplit[1] + "\t" + pepSplit[0].split("\\|")[1] + "\n";
if (!peptides.contains(line)) {
myWriter.write(line);
peptides.add(line);
}
}
myWriter.close();

String predFileString = DiannModelCaller.callModel(inputFile, false);
imPreds = LibraryPredictionMapper.createLibraryPredictionMapper(
predFileString, "DIA-NN", executorService).getPreds();
} else { //mode for koina
HashSet<String> allHits = km.writeFullPeptideFile(
jsonOutFolder + File.separator + model + "_full.tsv", model,
km.peptideSetIM);
imPreds = km.getKoinaPredictions(allHits, model, 30,
jsonOutFolder + File.separator + model,
jsonOutFolder + File.separator + model + "_full.tsv");
}

//collect the data for testing
ArrayList<Float> expIMs = new ArrayList<>();
ArrayList<Float> predIMs = new ArrayList<>();
for (int j = 0; j < pmMatcher.mzmlReaders.length; j++) {
MzmlReader mzmlReader = pmMatcher.mzmlReaders[j];
LinkedList<Integer> thisScanNums = km.scanNumsIM.get(pmMatcher.mzmlFiles[j].getName());
LinkedList<String> thisPeptides = km.peptidesIM.get(pmMatcher.mzmlFiles[j].getName());

for (int k = 0; k < thisScanNums.size(); k++) {
try {
int scanNum = thisScanNums.get(k);
String peptide = thisPeptides.get(k).split(",")[0];
MzmlScanNumber msn = mzmlReader.getScanNumObject(scanNum);

float predIM = imPreds.get(peptide).IM;
float expIM = msn.IM;
predIMs.add(predIM);
expIMs.add(expIM);
} catch (Exception ignored) {}
}
}
double[][] ims = new double[2][expIMs.size()];
for (int i = 0; i < expIMs.size(); i++) {
//put pred first so it can be normalized to exp scale
ims[0][i] = Math.round(predIMs.get(i) * 100d) / 100d;
ims[1][i] = Math.round(expIMs.get(i) * 100d) / 100d;
}

//testing
String[] bandwidths = Constants.loessBandwidth.split(",");
float[] floatBandwidths = new float[bandwidths.length];
for (int i = 0; i < bandwidths.length; i++) {
floatBandwidths[i] = Float.parseFloat(bandwidths[i]);
}
Object[] bandwidth_loess_imdiffs = StatMethods.gridSearchCV(ims, floatBandwidths);
Function1<Double, Double> loess = (Function1<Double, Double>) bandwidth_loess_imdiffs[1];
float[] imdiffs = (float[]) bandwidth_loess_imdiffs[2];
float mse = StatMethods.meanSquaredError(imdiffs);
printInfo(model + " has root mean squared error of " +
String.format("%.4f", Math.sqrt(mse)));
Arrays.sort(imdiffs);
datapointsIM.put(model, imdiffs);
}

//get best model //TODO: test which method is more accurate
HashMap<String, Integer> votes = new HashMap<>();
for (String model : datapointsIM.keySet()) {
votes.put(model, 0);
}
for (int i = 0; i < 10; i++) {
String bestModel = "";
float bestImDiff = Float.MAX_VALUE;

for (String model : datapointsIM.keySet()) {
float[] imDiffs = datapointsIM.get(model);
float deltaIM = imDiffs[imDiffs.length - 1 - i];
if (deltaIM < bestImDiff) {
bestImDiff = deltaIM;
bestModel = model;
}
}

votes.put(bestModel, votes.get(bestModel) + 1);
}

int mostVotes = 0;
for (String model : votes.keySet()) {
int vote = votes.get(model);
if (vote > mostVotes) {
mostVotes = vote;
Constants.imModel = model;
}
}

printInfo("IM model chosen is " + Constants.imModel);
}

if (Constants.imModel.isEmpty()) {
Constants.imModel = "DIA-NN";
}
Expand Down Expand Up @@ -1033,6 +1161,7 @@ public static void main(String[] args) throws Exception {
intersection.retainAll(Constants.imFeatures);
if (intersection.isEmpty()) {
featureLL.add("deltaIMLOESS");
featureLL.add("ionmobility");
}
} else {
for (String feature : Constants.imFeatures) {
Expand Down Expand Up @@ -1403,6 +1532,7 @@ public static void main(String[] args) throws Exception {
List<String> lines = Files.readAllLines(Paths.get(Constants.paramsList));
AtomicBoolean spectrafound = new AtomicBoolean(false);
AtomicBoolean rtfound = new AtomicBoolean(false);
AtomicBoolean imfound = new AtomicBoolean(false);

// Stream the list, and if a line starts with the specified prefix,
// replace everything after the prefix with newSuffix
Expand All @@ -1415,6 +1545,9 @@ public static void main(String[] args) throws Exception {
} else if (line.trim().startsWith("rtModel")) {
rtfound.set(true);
return "rtModel=" + Constants.rtModel;
} else if (line.trim().startsWith("imModel")) {
imfound.set(true);
return "imModel=" + Constants.imModel;
}
return line;
})
Expand All @@ -1427,6 +1560,9 @@ public static void main(String[] args) throws Exception {
if (!rtfound.get()) {
modifiedLines.add("rtModel=" + Constants.rtModel);
}
if (!imfound.get()) {
modifiedLines.add("imModel=" + Constants.imModel);
}

// Write the modified lines back to the file
Files.write(Paths.get(Constants.paramsList), modifiedLines);
Expand Down
12 changes: 11 additions & 1 deletion src/main/java/Features/PinReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,18 @@ public String[] createJSON(File mzmlFile, PinMzmlMatcher pmm, String modelFormat
return peps.toArray(new String[0]);
}

public LinkedList[] getTopPSMs(int num) throws IOException {
public LinkedList[] getTopPSMs(int num, boolean charge2forIM) throws IOException {
LinkedList<String> PSMs = new LinkedList<>();
LinkedList<Integer> scanNums = new LinkedList<>();

//quicker way of getting top NCE
PriorityQueue<Float> topEscores = new PriorityQueue<>(num, (a, b) -> Float.compare(b, a));
while (next(true)) {
if (charge2forIM) {
if (getColumn("charge_2").equals("0")) {
continue;
}
}
float escore = Float.parseFloat(getEScore());
if (topEscores.size() < num) {
topEscores.offer(escore);
Expand All @@ -488,6 +493,11 @@ public LinkedList[] getTopPSMs(int num) throws IOException {
float eScoreCutoff = topEscores.peek();

while (next(true) && PSMs.size() < num) {
if (charge2forIM) {
if (getColumn("charge_2").equals("0")) {
continue;
}
}
float escore = Float.parseFloat(getEScore());
if (escore <= eScoreCutoff) {
PeptideFormatter pf = getPep();
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/Features/PinWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ public void write() throws IOException {
formattedWrite("predicted_IM", pepObj.IM);
break;
case "ionmobility":
formattedWrite("ionmobility", pepObj.scanNumObj.IM);
formattedWrite("ion_mobility", pepObj.scanNumObj.IM);
break;
case "y_matched_intensity":
formattedWrite("y_matched_intensity", pepObj.matchedIntensities.get("y"));
Expand Down

0 comments on commit 036dab0

Please sign in to comment.