Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding feature for search result diversity re-ranking and evaluation - submission for the IRDM group project #1

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
157 changes: 157 additions & 0 deletions src/uk/ac/ucl/panda/BatchGetDocTermStats.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package uk.ac.ucl.panda;

/**
* @author Yiwei Chen
* @author Yifei Rong
*/

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;

import uk.ac.ucl.panda.indexing.io.IndexReader;
import uk.ac.ucl.panda.retrieval.Searcher;
import uk.ac.ucl.panda.utility.io.DocNameExtractor;
import uk.ac.ucl.panda.utility.structure.TermFreqVector;

public class BatchGetDocTermStats extends GetDocTermStats {

private ArrayList<Integer> totalDocWords;

public BatchGetDocTermStats(String index) throws IOException,
ClassNotFoundException {
super(index);
}
public BatchGetDocTermStats() throws ClassNotFoundException, IOException {
super();
}

public ArrayList<HashMap<String, Integer>> GetDocLevelStats(HashSet<String> docIDs)
throws IOException, ClassNotFoundException {
IndexReader rdr = IndexReader.open(cindex);
Searcher search = new Searcher(cindex);
DocNameExtractor xt = new DocNameExtractor("docname");
ArrayList<HashMap<String, Integer>> results = new ArrayList<HashMap<String, Integer>>();
for (int j = 0; j < rdr.maxDoc(); j++) {
String docName = xt.docName(search, j);
if (docIDs.contains(docName)) {
HashMap<String, Integer> termstats = new HashMap<String, Integer>();
int docid = j;
if (rdr.isDeleted(docid)) {
return null;
}
TermFreqVector tTerms = null;
TermFreqVector bTerms = null;
tTerms = rdr.getTermFreqVector(docid, docDataField1);
bTerms = rdr.getTermFreqVector(docid, docDataField2);
if (tTerms != null) {
if (type == true) {
String Atterms[] = tTerms.getTerms();
int AtFreq[] = tTerms.getTermFrequencies();
for (int i = 0; i < Atterms.length; i++) {
String id = Atterms[i];
termstats.put(id, AtFreq[i]);
}
}
}
if (bTerms != null) {
if (type == true) {
String Abterms[] = bTerms.getTerms();
int AbFreq[] = bTerms.getTermFrequencies();
for (int i = 0; i < Abterms.length; i++) {
String id = Abterms[i];
if (termstats.containsKey(id)) {
// int updateScore = (Integer) ( (Integer)
// termstats.get(id) + AbFreq[i]);
termstats
.put(id, (Integer) ((Integer) termstats
.get(id) + AbFreq[i]));
} else {
// eprop.put(Abterms[i], AbFreq[i]);
termstats.put(id, AbFreq[i]);
}
}
}
}
results.add(termstats);
}
}
return results;
}

public ArrayList<HashMap<String, Integer>> GetDocLevelStats(HashMap<String, Integer> docIDs)
throws IOException, ClassNotFoundException {
IndexReader rdr = IndexReader.open(cindex);
Searcher search = new Searcher(cindex);
DocNameExtractor xt = new DocNameExtractor("docname");
totalDocWords = new ArrayList<Integer>();
ArrayList<HashMap<String, Integer>> results = new ArrayList<HashMap<String, Integer>>();
ArrayList<Integer> index = new ArrayList<Integer>();
for (int j = 0; j < rdr.maxDoc(); j++) {
String docName = xt.docName(search, j);
int totalWords = 0;
if (docIDs.containsKey(docName)) {
HashMap<String, Integer> termstats = new HashMap<String, Integer>();
int docid = j;
if (rdr.isDeleted(docid)) {
return null;
}
TermFreqVector tTerms = null;
TermFreqVector bTerms = null;
tTerms = rdr.getTermFreqVector(docid, docDataField1);
bTerms = rdr.getTermFreqVector(docid, docDataField2);
if (tTerms != null) {
if (type == true) {
String Atterms[] = tTerms.getTerms();
int AtFreq[] = tTerms.getTermFrequencies();
for (int i = 0; i < Atterms.length; i++) {
String id = Atterms[i];
termstats.put(id, AtFreq[i]);
totalWords += AtFreq[i];
}
}
}
if (bTerms != null) {
if (type == true) {
String Abterms[] = bTerms.getTerms();
int AbFreq[] = bTerms.getTermFrequencies();
for (int i = 0; i < Abterms.length; i++) {
String id = Abterms[i];
if (termstats.containsKey(id)) {
termstats.put(id, (Integer) ((Integer) termstats.get(id) + AbFreq[i]));
totalWords += AbFreq[i];
} else {
termstats.put(id, AbFreq[i]);
totalWords += AbFreq[i];
}
}
}
}
results.add(termstats);
index.add(docIDs.get(docName));
totalDocWords.add(totalWords);
}
}
ArrayList<HashMap<String, Integer>> tempVector = new ArrayList<HashMap<String, Integer>>();
ArrayList<Integer> tempWords = new ArrayList<Integer>();
tempVector = results;
tempWords = totalDocWords;

for(int i=0; i<results.size(); i++){
results.add(index.get(i), tempVector.get(i));
results.remove(index.get(i)+1);
totalDocWords.add(index.get(i), tempWords.get(i));
totalDocWords.remove(index.get(i)+1);
}
return results;
}

public ArrayList<Integer> getTotalDocWords() {
return totalDocWords;
}
public void setTotalDocWords(ArrayList<Integer> totalDocWords) {
this.totalDocWords = totalDocWords;
}

}
97 changes: 90 additions & 7 deletions src/uk/ac/ucl/panda/Panda.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
* -V --version print version information
* -i --index index a collection
* -b --batch retrieve for batch, must be followed by argument of the form a:i:b, where a is the starting value b is the end value and i is the increment
* -r --reranking perform a reranking evaluation, followed by an argument specifying the reranking method {mmr|portfolio},
* and the underlying scroing model is specified by -m or 0 by default.
* -br --batchreranking batch reranking, format: method:a:i:b, a:i:b is like that in -b.
* -e --evaluate evaluates the results
* -v --var get var and mean for each query
* -m --model specify model number (must be followed by an integer and used in conjunction with other arguments)
Expand Down Expand Up @@ -51,7 +54,22 @@ public class Panda {
* Specifies whether to perform trec_eval like evaluation.
*/
protected boolean evaluation;

/**
* Specifies whether to perform a reranking task and evaluation.
*/
protected boolean reranking;

/**
* Specifies batch reranking
*/
protected boolean batch_reranking;

/**
* The reranking method, "mmr" or "portfolio"
*/
protected String reranking_method;

/**
* Specifies batch.
*/
Expand Down Expand Up @@ -89,6 +107,10 @@ protected void usage() {
System.out.println(" -b --batch retrieve for batch, must be followed by argument of the form a:i:b, where a is the starting value,");
System.out.println(" b is the end value and i is the increment");
System.out.println(" -e --evaluate evaluates the results");

System.out.println(" -r --reranking perform a reranking evaluation, followed by an argument specifying the reranking method {mmr|portfolio},");
System.out.println(" and the underlying scroing model is specified by -m or 0 by default.");
System.out.println(" -br --batchreranking batch reranking, format: method:a:i:b, a:i:b is like that in -b.");

System.out.println(" -v --var get var and mean for each query");
System.out.println(" var/results with the specified qrels file");
Expand Down Expand Up @@ -142,6 +164,8 @@ protected int processOptions(String[] args) {
return ERROR_NO_ARGUMENTS;
boolean hasModelFlag = false;
boolean isBatch = false;
boolean isReranking = false;
boolean isBatchReranking = false;
int pos = 0;
while (pos < args.length) {
if(hasModelFlag){
Expand Down Expand Up @@ -169,6 +193,39 @@ protected int processOptions(String[] args) {
batchIncrement = 0;
return ERROR_WRONG_BATCH_ARGUMENT;
}
} else if (isBatchReranking) {
isBatchReranking = false;
String[] vars = args[pos].split(":");
if(vars.length != 4)
return ERROR_WRONG_BATCH_ARGUMENT;
try{
reranking_method = vars[0].toLowerCase();
if(!reranking_method.equals("mmr") && ! reranking_method.equals("portfolio")) {
return ERROR_WRONG_RERANKING_ARGUMENT;
}
batchA = Double.parseDouble(vars[1]);
batchIncrement = Double.parseDouble(vars[2]);
batchB = Double.parseDouble(vars[3]);
}catch(NumberFormatException e){
reranking_method = null;
modelNumber = -1;
batchA = 0;
batchB = 0;
batchIncrement = 0;
return ERROR_WRONG_BATCH_ARGUMENT;
}
} else if (isReranking){
isReranking = false;
try{
reranking_method = args[pos].toLowerCase();
if(!reranking_method.equals("mmr") && ! reranking_method.equals("portfolio")) {
return ERROR_WRONG_RERANKING_ARGUMENT;
}
}catch(NumberFormatException e){
reranking_method = null;
modelNumber = -1;
return ERROR_WRONG_RERANKING_ARGUMENT;
}
}else if (args[pos].equals("-h") || args[pos].equals("--help"))
printHelp = true;

Expand All @@ -180,6 +237,12 @@ else if (args[pos].equals("-v") || args[pos].equals("--var"))

else if (args[pos].equals("-e") || args[pos].equals("--evaluate")) {
evaluation = true;
} else if (args[pos].equals("-r") || args[pos].equals("--reranking")) {
reranking = true;
isReranking = true;
} else if (args[pos].equals("-br") || args[pos].equals("--batchreranking")) {
batch_reranking = true;
isBatchReranking = true;
} else if (args[pos].equals("-b") || args[pos].equals("--batch")) {
batch = true;
isBatch = true;
Expand Down Expand Up @@ -250,7 +313,7 @@ else if (evaluation) {
TrecRetrieval trecsearch = new TrecRetrieval();
trecsearch.process(index, topics, qrels, appProp
.getProperty("panda.var"), modelNumber);
} else if (variance) {
} else if (reranking) {
buf = FileReader.openFileReader(appProp.getProperty("panda.etc")
+ fileseparator + "IndexDir.config");
String index = buf.readLine();
Expand All @@ -263,15 +326,30 @@ else if (evaluation) {
+ fileseparator + "Qrels.config");
String qrels = buf.readLine();

// System.out.println(index);
// System.out.println(topics);:q!
// System.out.println(qrels);

// reranking
TrecRetrieval trecsearch = new TrecRetrieval();
trecsearch.process_var(index, topics, qrels, appProp
.getProperty("panda.var"));
if (modelNumber < 0) modelNumber = 0;
trecsearch.process_reranking(index, topics, qrels, appProp
.getProperty("panda.var"), reranking_method, modelNumber);
} else if (batch_reranking) {
buf = FileReader.openFileReader(appProp.getProperty("panda.etc")
+ fileseparator + "IndexDir.config");
String index = buf.readLine();

buf = FileReader.openFileReader(appProp.getProperty("panda.etc")
+ fileseparator + "Topics.config");
String topics = buf.readLine();

buf = FileReader.openFileReader(appProp.getProperty("panda.etc")
+ fileseparator + "Qrels.config");
String qrels = buf.readLine();

// batch reranking
TrecRetrieval trecsearch = new TrecRetrieval();
if (modelNumber < 0) modelNumber = 0;
trecsearch.batch_reranking(index, topics, qrels, appProp
.getProperty("panda.var"), reranking_method, modelNumber,
batchA, batchB, batchIncrement);
} else if (variance) {
buf = FileReader.openFileReader(appProp.getProperty("panda.etc")
+ fileseparator + "IndexDir.config");
Expand Down Expand Up @@ -412,6 +490,10 @@ public void applyOptions(int status, Properties appProp)
System.err
.println("You entered an incorrect format for the batch argument, -b must be followed by argument of the form a:i:b, where a is the starting value b is the end value and i is the increment");
break;
case ERROR_WRONG_RERANKING_ARGUMENT:
System.err
.println("You entered an incorrect format for the reranking argument, -r must be followed by argument of the form method:model, where method is {mmr|portfolio} and model is the number of underlying scoring model.");
break;
case ARGUMENTS_OK:
default:
run(appProp);
Expand All @@ -438,5 +520,6 @@ public void applyOptions(int status, Properties appProp)
protected static final int ERROR_LANGUAGEMODEL_NOT_RETRIEVE = 17;
protected static final int ERROR_WRONG_MODEL_NUMBER = 18;
protected static final int ERROR_WRONG_BATCH_ARGUMENT = 19;
protected static final int ERROR_WRONG_RERANKING_ARGUMENT = 20;

}
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ public void log(String title, int paddLines, PrintWriter logger, String prefix)
* @param logger Logger.
* @param prefix prefix before each log line.
*/
public void batch_log(String title, int paddLines, PrintWriter logger, String prefix) {
public void batch_log(String title, int paddLines, PrintWriter logger, String prefix, boolean showMAP) {
logger.println();
if (title!=null && title.trim().length()>0) {
logger.println("a= "+title);
Expand All @@ -448,7 +448,7 @@ public void batch_log(String title, int paddLines, PrintWriter logger, String pr
nf.setGroupingUsed(true);
logger.print(fracFormat(nf.format(getMRR()))+'\t'); //MRR
logger.print(fracFormat(nf.format(getRecall()))+'\t');//Recall
//logger.print(fracFormat(nf.format(getMAP()))+'\t');//MAP
if (showMAP) logger.print(fracFormat(nf.format(getMAP()))+'\t');//MAP
// 1-call
logger.print(fracFormat(nf.format(getOneCall()))+'\t');
// 2-call
Expand Down Expand Up @@ -509,6 +509,10 @@ public void batch_log(String title, int paddLines, PrintWriter logger, String pr
logger.println();
}
}

public void batch_log(String title, int paddLines, PrintWriter logger, String prefix) {
batch_log(title, paddLines, logger, prefix, false); // no MAP by default
}

private static String padd = " ";
private String format(String s, int minLen) {
Expand Down
Loading