diff --git a/core/pom.xml b/core/pom.xml
index 93d5b34..baa9bb3 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -151,16 +151,10 @@
${basedir}/src/main/resources
true
-
- ${svmSgdModel.value}/**
-
${project.build.directory}
-
- ${svmSgdModel.value}.zip
-
@@ -206,26 +200,6 @@
maven-assembly-plugin
2.6
-
- zipSVMWithSGDModel
- generate-resources
-
- single
-
-
- false
- posix
- ${svmSgdModel.value}
- ${project.build.directory}
-
-
-
- ${basedir}/src/main/assembly/zipSVMWithSGDModel.xml
-
-
-
-
-
generateDistribution
package
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/Learner.java b/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/Learner.java
deleted file mode 100644
index 8a752a3..0000000
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/Learner.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License"); you
- * may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.sdap.mudrod.ssearch.ranking;
-
-import org.apache.sdap.mudrod.driver.SparkDriver;
-import org.apache.spark.SparkContext;
-import org.apache.spark.mllib.classification.SVMModel;
-import org.apache.spark.mllib.regression.LabeledPoint;
-
-import java.io.Serializable;
-
-/**
- * Supports the ability to importing classifier into memory
- */
-public class Learner implements Serializable {
- /**
- *
- */
- private static final long serialVersionUID = 1L;
- SVMModel model = null;
- transient SparkContext sc = null;
-
- /**
- * Constructor to load in spark SVM classifier
- *
- * @param classifierName classifier type
- * @param skd an instance of spark driver
- * @param svmSgdModel path to a trained model
- */
- public Learner(SparkDriver skd, String svmSgdModel) {
- sc = skd.sc.sc();
- sc.addFile(svmSgdModel, true);
- model = SVMModel.load(sc, svmSgdModel);
- }
-
- /**
- * Method of classifying instance
- *
- * @param p the instance that needs to be classified
- * @return the class id
- */
- public double classify(LabeledPoint p) {
- return model.predict(p.features());
- }
-
-}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/SparkFormatter.java b/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/SparkFormatter.java
deleted file mode 100644
index 8c7fa7f..0000000
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/SparkFormatter.java
+++ /dev/null
@@ -1,47 +0,0 @@
-package org.apache.sdap.mudrod.ssearch.ranking;
-
-import java.io.*;
-import java.text.DecimalFormat;
-
-public class SparkFormatter {
- DecimalFormat NDForm = new DecimalFormat("#.###");
-
- public SparkFormatter() {
- }
-
- public void toSparkSVMformat(String inputCSVFileName, String outputTXTFileName) {
- File file = new File(outputTXTFileName);
- if (file.exists()) {
- file.delete();
- }
- try {
- file.createNewFile();
-
- try (FileWriter fw = new FileWriter(outputTXTFileName);
- BufferedWriter bw = new BufferedWriter(fw);
- BufferedReader br = new BufferedReader(new FileReader(inputCSVFileName));) {
-
- String line = null;
- line = br.readLine(); //header
- while ((line = br.readLine())!= null) {
- String[] list = line.split(",");
- String output = "";
- Double label = Double.parseDouble(list[list.length - 1].replace("\"", ""));
- if (label == -1.0) {
- output = "0 ";
- } else if (label == 1.0) {
- output = "1 ";
- }
-
- for (int i = 0; i < list.length - 1; i++) {
- int index = i + 1;
- output += index + ":" + NDForm.format(Double.parseDouble(list[i].replace("\"", ""))) + " ";
- }
- bw.write(output + "\n");
- }
- }
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
-}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/SparkSVM.java b/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/SparkSVM.java
deleted file mode 100644
index 0d0eb8d..0000000
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/SparkSVM.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License"); you
- * may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.sdap.mudrod.ssearch.ranking;
-
-import org.apache.sdap.mudrod.main.MudrodEngine;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.classification.SVMModel;
-import org.apache.spark.mllib.classification.SVMWithSGD;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-
-public class SparkSVM {
-
- private SparkSVM() {
- //public constructor
- }
-
- public static void main(String[] args) {
- MudrodEngine me = new MudrodEngine();
-
- JavaSparkContext jsc = me.startSparkDriver().sc;
-
- String path = SparkSVM.class.getClassLoader().getResource("inputDataForSVM_spark.txt").toString();
- JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
-
- // Run training algorithm to build the model.
- int numIterations = 100;
- final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations);
-
- // Save and load model
- model.save(jsc.sc(), SparkSVM.class.getClassLoader().getResource("javaSVMWithSGDModel").toString());
-
- jsc.sc().stop();
-
- }
-
-}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/weblog/pre/RankingTrainDataGenerator.java b/core/src/main/java/org/apache/sdap/mudrod/weblog/pre/RankingTrainDataGenerator.java
deleted file mode 100644
index de41d56..0000000
--- a/core/src/main/java/org/apache/sdap/mudrod/weblog/pre/RankingTrainDataGenerator.java
+++ /dev/null
@@ -1,54 +0,0 @@
-package org.apache.sdap.mudrod.weblog.pre;
-
-import org.apache.sdap.mudrod.discoveryengine.DiscoveryStepAbstract;
-import org.apache.sdap.mudrod.driver.ESDriver;
-import org.apache.sdap.mudrod.driver.SparkDriver;
-import org.apache.sdap.mudrod.weblog.structure.session.RankingTrainData;
-import org.apache.sdap.mudrod.weblog.structure.session.SessionExtractor;
-import org.apache.spark.api.java.JavaRDD;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Properties;
-
-public class RankingTrainDataGenerator extends DiscoveryStepAbstract {
-
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(RankingTrainDataGenerator.class);
-
- public RankingTrainDataGenerator(Properties props, ESDriver es, SparkDriver spark) {
- super(props, es, spark);
- // TODO Auto-generated constructor stub
- }
-
- @Override
- public Object execute() {
- // TODO Auto-generated method stub
- LOG.info("Starting generate ranking train data.");
- startTime = System.currentTimeMillis();
-
- String rankingTrainFile = "E:\\Mudrod_input_data\\Testing_Data_4_1monthLog+Meta+Onto\\traing.txt";
- try {
- SessionExtractor extractor = new SessionExtractor();
- JavaRDD rankingTrainDataRDD = extractor.extractRankingTrainData(this.props, this.es, this.spark);
-
- JavaRDD rankingTrainData_JsonRDD = rankingTrainDataRDD.map(f -> f.toJson());
-
- rankingTrainData_JsonRDD.coalesce(1, true).saveAsTextFile(rankingTrainFile);
-
- } catch (Exception e) {
- e.printStackTrace();
- }
-
- endTime = System.currentTimeMillis();
- LOG.info("Ranking train data generation complete. Time elapsed {} seconds.", (endTime - startTime) / 1000);
- return null;
- }
-
- @Override
- public Object execute(Object o) {
- // TODO Auto-generated method stub
- return null;
- }
-
-}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/weblog/process/ClickStreamAnalyzer.java b/core/src/main/java/org/apache/sdap/mudrod/weblog/process/ClickStreamAnalyzer.java
index 193115e..91913ad 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/weblog/process/ClickStreamAnalyzer.java
+++ b/core/src/main/java/org/apache/sdap/mudrod/weblog/process/ClickStreamAnalyzer.java
@@ -18,7 +18,6 @@
import org.apache.sdap.mudrod.driver.SparkDriver;
import org.apache.sdap.mudrod.main.MudrodConstants;
import org.apache.sdap.mudrod.semantics.SVDAnalyzer;
-import org.apache.sdap.mudrod.ssearch.ClickStreamImporter;
import org.apache.sdap.mudrod.utils.LinkageTriple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -60,10 +59,7 @@ public Object execute() {
props.getProperty(MudrodConstants.CLICKSTREAM_SVD_PATH));
List tripleList = svd.calTermSimfromMatrix(props.getProperty(MudrodConstants.CLICKSTREAM_SVD_PATH));
svd.saveToES(tripleList, props.getProperty(MudrodConstants.ES_INDEX_NAME), MudrodConstants.CLICK_STREAM_LINKAGE_TYPE);
-
- // Store click stream in ES for the ranking use
- ClickStreamImporter cs = new ClickStreamImporter(props, es, spark);
- cs.importfromCSVtoES();
+
}
} catch (Exception e) {
LOG.error("Encountered an error during execution of ClickStreamAnalyzer.", e);
diff --git a/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/Session.java b/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/Session.java
index 2c917a6..7a2ef8b 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/Session.java
+++ b/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/Session.java
@@ -203,7 +203,7 @@ public List getClickStreamList(String indexName, String type, Strin
* @return an instance of session tree structure
* @throws UnsupportedEncodingException UnsupportedEncodingException
*/
- private SessionTree getSessionTree(String indexName, String type, String sessionID) throws UnsupportedEncodingException {
+ public SessionTree getSessionTree(String indexName, String type, String sessionID) throws UnsupportedEncodingException {
SearchResponse response = es.getClient().prepareSearch(indexName).setTypes(type).setQuery(QueryBuilders.termQuery("SessionID", sessionID)).setSize(100).addSort("Time", SortOrder.ASC)
.execute().actionGet();
@@ -261,31 +261,4 @@ private JsonElement getRequests(String cleanuptype, String sessionID) throws Uns
}
return gson.toJsonTree(requestList);
}
-
- /**
- * getClickStreamList: Extracted ranking training data from current session.
- *
- * @param indexName an index from which to obtain ranked training data.
- * @param cleanuptype: Session type name in Elasticsearch
- * @param sessionID: Session ID
- * @return Click stram data list
- * {@link ClickStream}
- */
- public List getRankingTrainData(String indexName, String cleanuptype, String sessionID) {
- SessionTree tree = null;
- try {
- tree = this.getSessionTree(indexName, cleanuptype, sessionID);
- } catch (UnsupportedEncodingException e) {
- LOG.error("Error whilst retreiving Session Tree: {}", e);
- }
-
- List trainData = new ArrayList<>();
- try {
- trainData = tree.getRankingTrainData(indexName);
- } catch (UnsupportedEncodingException e) {
- LOG.error("Error whilst retreiving ranking training data: {}", e);
- }
-
- return trainData;
- }
}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionExtractor.java b/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionExtractor.java
index f7eb602..8b6dc9d 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionExtractor.java
+++ b/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionExtractor.java
@@ -205,7 +205,7 @@ public List call(List v1, List v2) throws Exception {
* a log index name
* @return list of session names
*/
- protected List getSessions(Properties props, ESDriver es, String logIndex) {
+ public List getSessions(Properties props, ESDriver es, String logIndex) {
String cleanupType = MudrodConstants.CLEANUP_TYPE;
String sessionStatType = MudrodConstants.SESSION_STATS_TYPE;
@@ -383,102 +383,4 @@ public Tuple2> call(String sessionitem) throws Exception {
}
});
}
-
- /**
- * extractClickStreamFromES:Extract click streams from logs stored in
- * Elasticsearch
- *
- * @param props
- * the Mudrod configuration
- * @param es
- * the Elasticsearch drive
- * @param spark
- * the spark driver
- * @return clickstream list in JavaRDD format {@link ClickStream}
- */
- public JavaRDD extractRankingTrainData(Properties props, ESDriver es, SparkDriver spark) {
-
- List queryList = this.extractRankingTrainData(props, es);
- return spark.sc.parallelize(queryList);
-
- }
-
- /**
- * getClickStreamList:Extract click streams from logs stored in Elasticsearch.
- *
- * @param props
- * the Mudrod configuration
- * @param es
- * the Elasticsearch driver
- * @return clickstream list {@link ClickStream}
- */
- protected List extractRankingTrainData(Properties props, ESDriver es) {
- List logIndexList = es.getIndexListWithPrefix(props.getProperty(MudrodConstants.LOG_INDEX));
-
- LOG.info(logIndexList.toString());
-
- List result = new ArrayList<>();
- for (String logIndex : logIndexList) {
- List sessionIdList;
- try {
- sessionIdList = this.getSessions(props, es, logIndex);
- Session session = new Session(props, es);
- for (String aSessionIdList : sessionIdList) {
- String[] sArr = aSessionIdList.split(",");
- List datas = session.getRankingTrainData(sArr[1], sArr[2], sArr[0]);
- result.addAll(datas);
- }
- } catch (Exception e) {
- LOG.error("Error which extracting ranking train data: {}", e);
- }
- }
-
- return result;
- }
-
- protected JavaRDD extractRankingTrainDataInParallel(Properties props, SparkDriver spark, ESDriver es) {
-
- List logIndexList = es.getIndexListWithPrefix(props.getProperty(MudrodConstants.LOG_INDEX));
-
- LOG.info(logIndexList.toString());
-
- List sessionIdList = new ArrayList<>();
- for (String logIndex : logIndexList) {
- List tmpsessionList = this.getSessions(props, es, logIndex);
- sessionIdList.addAll(tmpsessionList);
- }
-
- JavaRDD sessionRDD = spark.sc.parallelize(sessionIdList, 16);
-
- JavaRDD clickStreamRDD = sessionRDD.mapPartitions(
- new FlatMapFunction, RankingTrainData>() {
- /**
- *
- */
- private static final long serialVersionUID = 1L;
-
- @Override
- public Iterator call(Iterator arg0) throws Exception {
- ESDriver tmpES = new ESDriver(props);
- tmpES.createBulkProcessor();
-
- Session session = new Session(props, tmpES);
- List clickstreams = new ArrayList<>();
- while (arg0.hasNext()) {
- String s = arg0.next();
- String[] sArr = s.split(",");
- List clicks = session.getRankingTrainData(sArr[1], sArr[2], sArr[0]);
- clickstreams.addAll(clicks);
- }
- tmpES.destroyBulkProcessor();
- tmpES.close();
- return clickstreams.iterator();
- }
- });
-
- LOG.info("Clickstream number: {}", clickStreamRDD.count());
-
- return clickStreamRDD;
- }
-
}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionTree.java b/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionTree.java
index 5531f83..abe1ee9 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionTree.java
+++ b/core/src/main/java/org/apache/sdap/mudrod/weblog/structure/session/SessionTree.java
@@ -90,6 +90,14 @@ public SessionTree(Properties props, ESDriver es, String sessionID, String clean
this.sessionID = sessionID;
this.cleanupType = cleanupType;
}
+
+ public SessionNode getRoot(){
+ return this.root;
+ }
+
+ public String getSessionId(){
+ return this.sessionID;
+ }
/**
* insert: insert a node into the session tree.
@@ -427,7 +435,7 @@ private List getViewNodes(SessionNode node) {
return viewnodes;
}
- private List getQueryNodes(SessionNode node) {
+ public List getQueryNodes(SessionNode node) {
return this.getNodes(node, MudrodConstants.SEARCH_MARKER);
}
@@ -447,76 +455,4 @@ private List getNodes(SessionNode node, String nodeKey) {
return nodes;
}
-
- /**
- * Obtain the ranking training data.
- *
- * @param indexName the index from whcih to obtain the data
- * @return {@link ClickStream}
- * @throws UnsupportedEncodingException if there is an error whilst
- * processing the ranking training data.
- */
- public List getRankingTrainData(String indexName) throws UnsupportedEncodingException {
-
- List trainDatas = new ArrayList<>();
-
- List queryNodes = this.getQueryNodes(this.root);
- for (SessionNode querynode : queryNodes) {
- List children = querynode.getChildren();
-
- LinkedHashMap datasetOpt = new LinkedHashMap<>();
- int ndownload = 0;
- for (SessionNode node : children) {
- if ("dataset".equals(node.getKey())) {
- Boolean bDownload = false;
- List nodeChildren = node.getChildren();
- for (SessionNode aNodeChildren : nodeChildren) {
- if ("ftp".equals(aNodeChildren.getKey())) {
- bDownload = true;
- ndownload += 1;
- break;
- }
- }
- datasetOpt.put(node.datasetId, bDownload);
- }
- }
-
- // method 1: The priority of download data are higher
- if (datasetOpt.size() > 1 && ndownload > 0) {
- // query
- RequestUrl requestURL = new RequestUrl();
- String queryUrl = querynode.getRequest();
- String infoStr = requestURL.getSearchInfo(queryUrl);
- String query = null;
- try {
- query = es.customAnalyzing(props.getProperty(MudrodConstants.ES_INDEX_NAME), infoStr);
- } catch (InterruptedException | ExecutionException e) {
- throw new RuntimeException("Error performing custom analyzing", e);
- }
- Map filter = RequestUrl.getFilterInfo(queryUrl);
-
- for (String datasetA : datasetOpt.keySet()) {
- Boolean bDownloadA = datasetOpt.get(datasetA);
- if (bDownloadA) {
- for (String datasetB : datasetOpt.keySet()) {
- Boolean bDownloadB = datasetOpt.get(datasetB);
- if (!bDownloadB) {
-
- String[] queries = query.split(",");
- for (String query1 : queries) {
- RankingTrainData trainData = new RankingTrainData(query1, datasetA, datasetB);
- trainData.setSessionId(this.sessionID);
- trainData.setIndex(indexName);
- trainData.setFilter(filter);
- trainDatas.add(trainData);
- }
- }
- }
- }
- }
- }
- }
-
- return trainDatas;
- }
}
diff --git a/pom.xml b/pom.xml
index 5e31926..98ccf51 100644
--- a/pom.xml
+++ b/pom.xml
@@ -136,6 +136,7 @@
core
+ ranking
service
web
diff --git a/ranking/.gitignore b/ranking/.gitignore
new file mode 100644
index 0000000..ca94514
--- /dev/null
+++ b/ranking/.gitignore
@@ -0,0 +1,3 @@
+/target/
+/bin/
+/lib/
diff --git a/ranking/pom.xml b/ranking/pom.xml
new file mode 100644
index 0000000..428b03b
--- /dev/null
+++ b/ranking/pom.xml
@@ -0,0 +1,215 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.sdap
+ mudrod-parent
+ 0.0.1-SNAPSHOT
+ ../
+
+
+ mudrod-ranking
+
+ Mudrod :: Ranking
+ Mudrod ranking algorithm implementation.
+
+
+
+ javaSVMWithSGDModel
+ 1.0.0-beta
+ nd4j-native-platform
+
+
+
+
+
+ org.apache.sdap
+ mudrod-core
+ ${project.version}
+
+
+ org.nd4j
+ nd4j-native-platform
+ ${nd4j.version}
+
+
+ org.nd4j
+ ${nd4j.backend}
+ ${nd4j.version}
+
+
+ org.deeplearning4j
+ deeplearning4j-core
+ ${nd4j.version}
+
+
+
+
+
+
+ ${basedir}/src/main/resources
+ true
+
+ ${svmSgdModel.value}/**
+
+
+
+
+ ${project.build.directory}
+
+ ${svmSgdModel.value}.zip
+
+
+
+
+ ${basedir}/../
+ META-INF
+
+ LICENSE.txt
+ NOTICE.txt
+
+
+
+
+
+
+
+ org.codehaus.mojo
+ appassembler-maven-plugin
+ 1.10
+
+
+ package
+
+ assemble
+
+
+
+
+ flat
+ lib
+
+
+ org.apache.sdap.mudrod.main.MudrodEngine
+
+ mudrod-engine
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-assembly-plugin
+ 2.6
+
+
+ zipSVMWithSGDModel
+ generate-resources
+
+ single
+
+
+ false
+ posix
+ ${svmSgdModel.value}
+ ${project.build.directory}
+
+
+
+ ${basedir}/src/main/assembly/zipSVMWithSGDModel.xml
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+ 3.0.2
+
+
+
+ test-jar
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+ 3.0.0
+
+
+ package
+
+ shade
+
+
+
+
+
+
+ org.apache.sdap.mudrod.main.MudrodEngine
+
+ ${implementation.build}
+
+
+
+
+
+
+ *:*
+
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
+
+
+ ${project.artifactId}-uber-${project.version}
+
+
+
+
+
+
+
+
+
+
+
+ release
+
+
+
+ ${basedir}/../
+
+ ${project.build.directory}/apidocs/META-INF
+
+
+ LICENSE.txt
+ NOTICE.txt
+
+
+
+
+
+
+
+
diff --git a/ranking/src/main/assembly/zipSVMWithSGDModel.xml b/ranking/src/main/assembly/zipSVMWithSGDModel.xml
new file mode 100644
index 0000000..6f277f7
--- /dev/null
+++ b/ranking/src/main/assembly/zipSVMWithSGDModel.xml
@@ -0,0 +1,24 @@
+
+
+
+ zipSVMWithSGDModel
+ ${svmSgdModel.value}
+
+ zip
+
+
+
+ ${basedir}/src/main/resources/${svmSgdModel.value}
+ .
+
+
+
\ No newline at end of file
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/Dispatcher.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Dispatcher.java
similarity index 99%
rename from core/src/main/java/org/apache/sdap/mudrod/ssearch/Dispatcher.java
rename to ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Dispatcher.java
index 611c76b..511b8c4 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/Dispatcher.java
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Dispatcher.java
@@ -11,7 +11,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.sdap.mudrod.ssearch;
+package org.apache.sdap.mudrod.ranking.common;
import org.apache.sdap.mudrod.discoveryengine.MudrodAbstract;
import org.apache.sdap.mudrod.driver.ESDriver;
diff --git a/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/LearnerFactory.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/LearnerFactory.java
new file mode 100644
index 0000000..96f2f29
--- /dev/null
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/LearnerFactory.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License"); you
+ * may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sdap.mudrod.ranking.common;
+
+import org.apache.sdap.mudrod.discoveryengine.MudrodAbstract;
+import org.apache.sdap.mudrod.driver.ESDriver;
+import org.apache.sdap.mudrod.driver.SparkDriver;
+import org.apache.sdap.mudrod.main.MudrodConstants;
+import org.apache.sdap.mudrod.ranking.dlrank.DLRankLearner;
+import org.apache.sdap.mudrod.ranking.ranksvm.RankSVMLearner;
+import org.apache.spark.SparkContext;
+import org.apache.spark.mllib.classification.SVMModel;
+import org.apache.spark.mllib.regression.LabeledPoint;
+
+import java.io.Serializable;
+import java.util.Properties;
+
+/**
+ * Creates a rank learner according to the configuration.
+ */
+public class LearnerFactory extends MudrodAbstract {
+
+ public LearnerFactory(Properties props, ESDriver es, SparkDriver spark) {
+ super(props, es, spark);
+ }
+
+ public RankLearner createLearner() {
+ /*if ("1".equals(props.getProperty(MudrodConstants.RANKING_ML)))
+ return new SVMLearner(props, es, spark, props.getProperty(MudrodConstants.RANKING_MODEL));
+
+ return null;*/
+ return new RankSVMLearner(props, es, spark, props.getProperty(MudrodConstants.RANKING_MODEL));
+ //return new DLRankLearner(props, es, spark, "");
+ }
+}
diff --git a/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/RankLearner.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/RankLearner.java
new file mode 100644
index 0000000..61663cf
--- /dev/null
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/RankLearner.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License"); you
+ * may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sdap.mudrod.ranking.common;
+
+import org.apache.sdap.mudrod.discoveryengine.MudrodAbstract;
+import org.apache.sdap.mudrod.driver.ESDriver;
+import org.apache.sdap.mudrod.driver.SparkDriver;
+import org.apache.sdap.mudrod.ranking.traindata.ExpertRankTrainData;
+import org.apache.spark.SparkContext;
+import org.apache.spark.mllib.classification.SVMModel;
+import org.apache.spark.mllib.regression.LabeledPoint;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Properties;
+
+/**
+ * Learns ranking weights from training data and predicts the ranking of search results.
+ */
+public abstract class RankLearner extends MudrodAbstract {
+
+ public RankLearner(Properties props, ESDriver es, SparkDriver spark) {
+ super(props, es, spark);
+ }
+
+ public abstract String customizeData(String sourceDir, String outFileName);
+
+ public abstract void train(String trainFile);
+
+ public abstract void evaluate(String testFile);
+
+ public abstract double predict(double[] value);
+
+ public abstract void save();
+
+ public abstract void load(String model);
+}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/Ranker.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Ranker.java
similarity index 91%
rename from core/src/main/java/org/apache/sdap/mudrod/ssearch/Ranker.java
rename to ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Ranker.java
index af7e6a9..5df0509 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/Ranker.java
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Ranker.java
@@ -11,14 +11,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.sdap.mudrod.ssearch;
+package org.apache.sdap.mudrod.ranking.common;
import org.apache.sdap.mudrod.discoveryengine.MudrodAbstract;
import org.apache.sdap.mudrod.driver.ESDriver;
import org.apache.sdap.mudrod.driver.SparkDriver;
import org.apache.sdap.mudrod.main.MudrodConstants;
-import org.apache.sdap.mudrod.ssearch.ranking.Learner;
-import org.apache.sdap.mudrod.ssearch.structure.SResult;
+import org.apache.sdap.mudrod.ranking.structure.SResult;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -32,12 +31,13 @@
public class Ranker extends MudrodAbstract implements Serializable {
private static final long serialVersionUID = 1L;
transient List resultList = new ArrayList<>();
- Learner le = null;
+ RankLearner le = null;
public Ranker(Properties props, ESDriver es, SparkDriver spark) {
super(props, es, spark);
- if("1".equals(props.getProperty(MudrodConstants.RANKING_ML)))
- le = new Learner(spark, props.getProperty(MudrodConstants.RANKING_MODEL));
+
+ LearnerFactory factory = new LearnerFactory(props, es, spark);
+ le = factory.createLearner();
}
/**
@@ -160,11 +160,9 @@ public int compare(SResult o1, SResult o2) {
}
double[] ins = instList.stream().mapToDouble(i -> i).toArray();
- LabeledPoint insPoint = new LabeledPoint(99.0, Vectors.dense(ins));
- int prediction = (int)le.classify(insPoint);
+ int prediction = (int)le.predict(ins);
return prediction;
}
}
-
-}
+}
\ No newline at end of file
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/Searcher.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Searcher.java
similarity index 98%
rename from core/src/main/java/org/apache/sdap/mudrod/ssearch/Searcher.java
rename to ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Searcher.java
index ce0183a..7b5551c 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/Searcher.java
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/Searcher.java
@@ -11,7 +11,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.sdap.mudrod.ssearch;
+package org.apache.sdap.mudrod.ranking.common;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
@@ -20,7 +20,7 @@
import org.apache.sdap.mudrod.discoveryengine.MudrodAbstract;
import org.apache.sdap.mudrod.driver.ESDriver;
import org.apache.sdap.mudrod.driver.SparkDriver;
-import org.apache.sdap.mudrod.ssearch.structure.SResult;
+import org.apache.sdap.mudrod.ranking.structure.SResult;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.index.query.BoolQueryBuilder;
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/package-info.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/package-info.java
similarity index 94%
rename from core/src/main/java/org/apache/sdap/mudrod/ssearch/package-info.java
rename to ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/package-info.java
index b635b64..2c47e90 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/package-info.java
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/common/package-info.java
@@ -15,4 +15,4 @@
* This package includes classes for semantic search, such as click stream importer,
* query dispatcher, semantic searcher, and ranker (ranksvm, ordinal/linear regression)
*/
-package org.apache.sdap.mudrod.ssearch;
+package org.apache.sdap.mudrod.ranking.common;
diff --git a/ranking/src/main/java/org/apache/sdap/mudrod/ranking/dlrank/DLRankLearner.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/dlrank/DLRankLearner.java
new file mode 100644
index 0000000..2325b6c
--- /dev/null
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/dlrank/DLRankLearner.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License"); you
+ * may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sdap.mudrod.ranking.dlrank;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Properties;
+
+import org.apache.sdap.mudrod.driver.ESDriver;
+import org.apache.sdap.mudrod.driver.SparkDriver;
+import org.apache.sdap.mudrod.main.MudrodEngine;
+import org.apache.sdap.mudrod.ranking.common.RankLearner;
+import org.apache.sdap.mudrod.ranking.common.LearnerFactory;
+import org.apache.sdap.mudrod.ranking.ranksvm.RankSVMLearner;
+import org.apache.sdap.mudrod.ranking.ranksvm.SparkFormatter;
+import org.apache.sdap.mudrod.ranking.traindata.RankTrainDataFactory;
+import org.datavec.api.records.reader.RecordReader;
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
+import org.datavec.api.split.FileSplit;
+import org.datavec.api.util.ClassPathResource;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+
+public class DLRankLearner extends RankLearner {
+
+ // Network used by train() and evaluate(); train() builds it when it is still null.
+ // NOTE(review): load(...) is called from the constructor and may also set it - confirm.
+ MultiLayerNetwork model = null;
+
+ /**
+ * Creates a deep-learning rank learner and immediately passes
+ * {@code dlModel} to {@code load}.
+ *
+ * @param props
+ * properties forwarded to the superclass constructor
+ * @param es
+ * ESDriver instance forwarded to the superclass constructor
+ * @param spark
+ * SparkDriver instance forwarded to the superclass constructor
+ * @param dlModel
+ * argument handed to {@code load}; presumably the path to a trained
+ * model - confirm against the load implementation
+ */
+ public DLRankLearner(Properties props, ESDriver es, SparkDriver spark, String dlModel) {
+ super(props, es, spark);
+ load(dlModel);
+ }
+
+  /**
+   * Creates the "experts" rank training data for {@code sourceDir} and hands it
+   * to {@link ND4JFormatter#toND4Jformat} together with a target CSV path.
+   *
+   * @param sourceDir
+   *          directory passed to the training data factory
+   * @param outFileName
+   *          base name (without extension) of the CSV file to produce
+   * @return path of the CSV file, located in the parent directory of the
+   *         generated training data file
+   */
+  @Override
+  public String customizeData(String sourceDir, String outFileName) {
+    String expertData = new RankTrainDataFactory(props, es, spark)
+        .createRankTrainData("experts", sourceDir);
+
+    // The converted file is placed next to the generated training data.
+    String csvPath = new File(expertData).getParent()
+        + System.getProperty("file.separator")
+        + outFileName
+        + ".csv";
+
+    new ND4JFormatter().toND4Jformat(expertData, csvPath);
+    return csvPath;
+  }
+
+  /**
+   * Trains the network on a CSV file, building the network first when none
+   * exists yet.
+   *
+   * <p>The network created here has one dense hidden layer (7 inputs, 20 ReLU
+   * nodes) and a 2-way softmax output layer trained with negative
+   * log-likelihood loss and Nesterov momentum (learning rate 0.01, momentum
+   * 0.9). The CSV is read comma-delimited with its first line skipped; the
+   * label is taken from column 0 and has 2 possible values. Training runs for
+   * 30 passes with a batch size of 50.
+   *
+   * @param trainFile
+   *          path of the CSV training file
+   * @throws IllegalStateException
+   *           if the training file cannot be read, or the thread is
+   *           interrupted while the reader is being initialised
+   */
+  @Override
+  public void train(String trainFile) {
+    // Layout of the training data, shared by the network and the iterator.
+    final int labelIndex = 0;
+    final int numOutputs = 2;
+
+    if (model == null) {
+      final int seed = 123;
+      final double learningRate = 0.01;
+      final int numInputs = 7;
+      final int numHiddenNodes = 20;
+      MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+          .seed(seed)
+          .updater(new Nesterovs(learningRate, 0.9))
+          .list()
+          .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
+              .weightInit(WeightInit.XAVIER)
+              .activation(Activation.RELU)
+              .build())
+          .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
+              .weightInit(WeightInit.XAVIER)
+              .activation(Activation.SOFTMAX)
+              .nIn(numHiddenNodes).nOut(numOutputs).build())
+          .pretrain(false).backprop(true).build();
+
+      model = new MultiLayerNetwork(conf);
+      model.init();
+    }
+
+    final int numLinesToSkip = 1;
+    final char delimiter = ',';
+    final int batchSize = 50;
+    final int nEpochs = 30;
+
+    // The reader is closed once training ends; a read failure now aborts
+    // training instead of fitting against an uninitialised reader.
+    try (RecordReader rr = new CSVRecordReader(numLinesToSkip, delimiter)) {
+      rr.initialize(new FileSplit(new File(trainFile)));
+
+      DataSetIterator trainIter =
+          new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numOutputs);
+      for (int n = 0; n < nEpochs; n++) {
+        model.fit(trainIter);
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException("Unable to read training file " + trainFile, e);
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag so callers can still observe the interruption.
+      Thread.currentThread().interrupt();
+      throw new IllegalStateException("Interrupted while reading training file " + trainFile, e);
+    }
+  }
+
+  /**
+   * Evaluates the current network against a CSV file and prints the
+   * evaluation statistics to standard output.
+   *
+   * <p>The CSV is read comma-delimited with its first line skipped; the label
+   * is taken from column 0 and has 2 possible values. Records are scored in
+   * batches of 50.
+   *
+   * @param testFile
+   *          path of the CSV test file
+   * @throws IllegalStateException
+   *           if no model is available, the test file cannot be read, or the
+   *           thread is interrupted while the reader is being initialised
+   */
+  @Override
+  public void evaluate(String testFile) {
+    if (model == null) {
+      // Fail with a clear message instead of a NullPointerException further down.
+      throw new IllegalStateException("No model available; train or load a model before evaluating");
+    }
+
+    final int batchSize = 50;
+    final int labelIndex = 0;
+    final int numOutputs = 2;
+    final int numLinesToSkip = 1;
+    final char delimiter = ',';
+
+    Evaluation eval = new Evaluation(numOutputs);
+    // The reader is closed after scoring; a read failure now aborts the
+    // evaluation instead of iterating an uninitialised reader.
+    try (RecordReader rrTest = new CSVRecordReader(numLinesToSkip, delimiter)) {
+      rrTest.initialize(new FileSplit(new File(testFile)));
+      DataSetIterator testIter =
+          new RecordReaderDataSetIterator(rrTest, batchSize, labelIndex, numOutputs);
+
+      System.out.println("Evaluate model....");
+      while (testIter.hasNext()) {
+        DataSet batch = testIter.next();
+        INDArray features = batch.getFeatureMatrix();
+        INDArray labels = batch.getLabels();
+        INDArray predicted = model.output(features, false);
+        eval.eval(labels, predicted);
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException("Unable to read test file " + testFile, e);
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag so callers can still observe the interruption.
+      Thread.currentThread().interrupt();
+      throw new IllegalStateException("Interrupted while reading test file " + testFile, e);
+    }
+    System.out.println(eval.stats());
+  }
+
+ @Override
+ public double predict(double[] value) {
+ int nRows = 1;
+ int nColumns = value.length;
+ INDArray features = Nd4j.zeros(nRows, nColumns);
+ for(int i=0; i 0; i--) {
+ list[i] = list[i - 1].replace("\"", "");
+ }
+ list[0] = output;
+
+ csvOutput.writeNext(list); // Write this array to the file
+ line = br.readLine();
+ }
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ClickStreamImporter.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/ClickstreamImporter.java
similarity index 74%
rename from core/src/main/java/org/apache/sdap/mudrod/ssearch/ClickStreamImporter.java
rename to ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/ClickstreamImporter.java
index 546fae5..0620d98 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ClickStreamImporter.java
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/ClickstreamImporter.java
@@ -11,7 +11,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.sdap.mudrod.ssearch;
+package org.apache.sdap.mudrod.ranking.evaluate;
import org.apache.sdap.mudrod.discoveryengine.MudrodAbstract;
import org.apache.sdap.mudrod.driver.ESDriver;
@@ -48,22 +48,19 @@ public ClickStreamImporter(Properties props, ESDriver es, SparkDriver spark) {
*/
public void addClickStreamMapping() {
XContentBuilder mapping;
- String clickStreamMatrixType = props.getProperty(MudrodConstants.CLICK_STREAM_MATRIX_TYPE);
try {
- mapping = jsonBuilder()
- .startObject()
- .startObject(clickStreamMatrixType)
- .startObject("properties")
- .startObject("query").field("type", "string").field("index", "not_analyzed").endObject()
- .startObject("dataID").field("type", "string").field("index", "not_analyzed").endObject()
- .endObject()
- .endObject()
- .endObject();
+ mapping = jsonBuilder().startObject().startObject(
+ props.getProperty(MudrodConstants.CLICK_STREAM_MATRIX_TYPE)).startObject(
+ "properties").startObject("query").field("type", "string").field(
+ "index", "not_analyzed").endObject().startObject("dataID").field(
+ "type", "string").field("index", "not_analyzed").endObject()
+
+ .endObject().endObject().endObject();
es.getClient().admin().indices().preparePutMapping(
props.getProperty(MudrodConstants.ES_INDEX_NAME)).setType(
- clickStreamMatrixType).setSource(
- mapping).execute().actionGet();
+ props.getProperty(MudrodConstants.CLICK_STREAM_MATRIX_TYPE)).setSource(
+ mapping).execute().actionGet();
} catch (IOException e) {
e.printStackTrace();
}
@@ -73,9 +70,8 @@ public void addClickStreamMapping() {
* Method to import click stream CSV into Elasticsearch
*/
public void importfromCSVtoES() {
- String clickStreamMatrixType = props.getProperty(MudrodConstants.CLICK_STREAM_MATRIX_TYPE);
- String esIndexName = props.getProperty(MudrodConstants.ES_INDEX_NAME);
- es.deleteType(esIndexName, clickStreamMatrixType);
+ es.deleteType(props.getProperty(MudrodConstants.ES_INDEX_NAME),
+ props.getProperty(MudrodConstants.CLICK_STREAM_MATRIX_TYPE));
es.createBulkProcessor();
BufferedReader br = null;
@@ -90,7 +86,8 @@ public void importfromCSVtoES() {
String[] clicks = line.split(cvsSplitBy);
for (int i = 1; i < clicks.length; i++) {
if (!"0.0".equals(clicks[i])) {
- IndexRequest ir = new IndexRequest(esIndexName, clickStreamMatrixType)
+ IndexRequest ir = new IndexRequest(props.getProperty(MudrodConstants.ES_INDEX_NAME),
+ props.getProperty(MudrodConstants.CLICK_STREAM_MATRIX_TYPE))
.source(jsonBuilder().startObject().field("query", clicks[0]).field(
"dataID", dataList[i]).field("clicks", clicks[i]).endObject());
es.getBulkProcessor().add(ir);
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/Evaluator.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/Evaluator.java
similarity index 98%
rename from core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/Evaluator.java
rename to ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/Evaluator.java
index 0efb82f..0c4e3fd 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/Evaluator.java
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/Evaluator.java
@@ -11,7 +11,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.sdap.mudrod.ssearch.ranking;
+package org.apache.sdap.mudrod.ranking.evaluate;
import java.util.Collections;
import java.util.Comparator;
diff --git a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/TrainingImporter.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/TrainingImporter.java
similarity index 98%
rename from core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/TrainingImporter.java
rename to ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/TrainingImporter.java
index ff55c85..5cae4a1 100644
--- a/core/src/main/java/org/apache/sdap/mudrod/ssearch/ranking/TrainingImporter.java
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/evaluate/TrainingImporter.java
@@ -11,7 +11,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.sdap.mudrod.ssearch.ranking;
+package org.apache.sdap.mudrod.ranking.evaluate;
import org.apache.sdap.mudrod.discoveryengine.MudrodAbstract;
import org.apache.sdap.mudrod.driver.ESDriver;
diff --git a/ranking/src/main/java/org/apache/sdap/mudrod/ranking/ranksvm/RankSVMLearner.java b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/ranksvm/RankSVMLearner.java
new file mode 100644
index 0000000..a0f2c47
--- /dev/null
+++ b/ranking/src/main/java/org/apache/sdap/mudrod/ranking/ranksvm/RankSVMLearner.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License"); you
+ * may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sdap.mudrod.ranking.ranksvm;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Comparator;
+import java.util.Properties;
+
+import org.apache.sdap.mudrod.driver.ESDriver;
+import org.apache.sdap.mudrod.driver.SparkDriver;
+import org.apache.sdap.mudrod.main.MudrodEngine;
+import org.apache.sdap.mudrod.ranking.common.RankLearner;
+import org.apache.sdap.mudrod.ranking.common.LearnerFactory;
+import org.apache.sdap.mudrod.ranking.traindata.RankTrainDataFactory;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.classification.SVMModel;
+import org.apache.spark.mllib.classification.SVMWithSGD;
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+
+import scala.Tuple2;
+
+/**
+ * Learn ranking weights with SVM model
+ */
+public class RankSVMLearner extends RankLearner {
+ /**
+ * Serialization version identifier.
+ */
+ private static final long serialVersionUID = 1L;
+ // SVM model; train() assigns the result of SVMWithSGD.train(...) to it.
+ // NOTE(review): load(...) is called from the constructor and may also set it - confirm.
+ SVMModel model = null;
+ // Spark context obtained from the SparkDriver in the constructor; transient, so it is not serialized.
+ transient SparkContext sc = null;
+
+ /**
+ * Creates an SVM rank learner: keeps the Spark context of the given driver
+ * and immediately passes {@code svmSgdModel} to {@code load}.
+ *
+ * @param props
+ * properties forwarded to the superclass constructor
+ * @param es
+ * ESDriver instance forwarded to the superclass constructor
+ * @param spark
+ * SparkDriver instance; forwarded to the superclass constructor and
+ * used as the source of the Spark context
+ * @param svmSgdModel
+ * path to a trained model, handed to {@code load}
+ */
+ public RankSVMLearner(Properties props, ESDriver es, SparkDriver spark, String svmSgdModel) {
+ super(props, es, spark);
+ sc = spark.sc.sc();
+ load(svmSgdModel);
+ }
+
+  /**
+   * Creates the "experts" rank training data for {@code sourceDir} and hands it
+   * to {@link SparkFormatter#toSparkSVMformat} together with a target text
+   * file path.
+   *
+   * @param sourceDir
+   *          directory passed to the training data factory
+   * @param outFileName
+   *          base name (without extension) of the text file to produce
+   * @return path of the {@code .txt} file, located in the parent directory of
+   *         the generated training data file
+   */
+  @Override
+  public String customizeData(String sourceDir, String outFileName) {
+    String expertData = new RankTrainDataFactory(props, es, spark)
+        .createRankTrainData("experts", sourceDir);
+
+    // The converted file is placed next to the generated training data.
+    String svmPath = new File(expertData).getParent()
+        + System.getProperty("file.separator")
+        + outFileName
+        + ".txt";
+
+    new SparkFormatter().toSparkSVMformat(expertData, svmPath);
+    return svmPath;
+  }
+
+  /**
+   * Trains the SVM model with stochastic gradient descent (100 iterations) on
+   * a LIBSVM-formatted file and stores the result in {@code model}.
+   *
+   * @param trainFile
+   *          path of the LIBSVM-formatted training file
+   */
+  @Override
+  public void train(String trainFile) {
+    final int numIterations = 100;
+    // loadLibSVMFile already returns the RDD of labeled points that
+    // SVMWithSGD.train consumes, so no raw JavaRDD round trip is needed.
+    model = SVMWithSGD.train(MLUtils.loadLibSVMFile(sc, trainFile), numIterations);
+  }
+
+ @Override
+ public void evaluate(String testFile) {
+ JavaRDD data = MLUtils.loadLibSVMFile(sc, testFile).toJavaRDD();
+ // Run training algorithm to build the model.
+ JavaRDD> scoreAndLabels = data.map(p->{
+ double score = model.predict(p.features());
+ return new Tuple2<>(score, p.label());
+ });
+ BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(scoreAndLabels.rdd());
+ System.out.println("Area under ROC = " + metrics.areaUnderROC());
+ long correctNum = scoreAndLabels.filter(new Function, Boolean>(){
+ @Override
+ public Boolean call(Tuple2