diff --git a/GPSlog_Basic/app/src/main/assets/trainedmodel.model b/GPSlog_Basic/app/src/main/assets/trainedmodel.model new file mode 100644 index 0000000..4b76831 --- /dev/null +++ b/GPSlog_Basic/app/src/main/assets/trainedmodel.model @@ -0,0 +1,13 @@ +svm_type c_svc +kernel_type polynomial +degree 3 +gamma 0.8 +coef0 10.0 +nr_class 2 +total_sv 2 +rho -2.0186747393962623 +label 1 0 +nr_sv 1 1 +SV +0.02565161951916031 1:0.230769230769231 2:0.0804797306508295 3:0.103448275862069 4:0.307692307692308 5:0.692307692307692 +-0.02565161951916031 1:0.0102040816326531 2:0.00702047705546885 3:0.0967741935483871 4:0.76530612244898 5:0.785714285714286 diff --git a/GPSlog_Basic/app/src/main/java/com/example/gpslog/MainActivity.java b/GPSlog_Basic/app/src/main/java/com/example/gpslog/MainActivity.java index fdde815..fc9acc6 100644 --- a/GPSlog_Basic/app/src/main/java/com/example/gpslog/MainActivity.java +++ b/GPSlog_Basic/app/src/main/java/com/example/gpslog/MainActivity.java @@ -72,7 +72,7 @@ protected void onCreate(Bundle savedInstanceState) { db = new DatabaseHandler(this); hmmClassifier = new HMMClassifier(); - svmClassifier = new SVMClassifier(db); + svmClassifier = new SVMClassifier(db, getApplicationContext()); startButton = (Button) findViewById(R.id.startButton); stopButton = (Button) findViewById(R.id.stopButton); diff --git a/GPSlog_Basic/app/src/main/java/com/example/gpslog/SVMClassifier.java b/GPSlog_Basic/app/src/main/java/com/example/gpslog/SVMClassifier.java index d005dbb..5fa8012 100644 --- a/GPSlog_Basic/app/src/main/java/com/example/gpslog/SVMClassifier.java +++ b/GPSlog_Basic/app/src/main/java/com/example/gpslog/SVMClassifier.java @@ -4,51 +4,72 @@ * Created by timad on 6/28/2017. */ +import android.content.Context; +import android.content.res.AssetManager; import android.util.Log; -import android.widget.Toast; + +import java.io.*; import java.util.ArrayList; +import java.util.Arrays; + import libsvm.*; + // This class is used to classify stop and go segments using the svm algorithm public class SVMClassifier { - svm_parameter param; - svm_model model; DatabaseHandler dbh; + Context myContext; + svm_model model; + BufferedReader br; - // The constructor sets the model, and the svm parameters, as well as the databasehandler passed in - public SVMClassifier(DatabaseHandler dbh){ - + // Constructor sets dbh and context for asset use from what is passed into the SVMClassifier + public SVMClassifier(DatabaseHandler dbh, Context context){ this.dbh = dbh; - param = new svm_parameter(); - model = new svm_model(); - - param.kernel_type = svm_parameter.POLY; - //param.svm_type = null; - param.degree = 10; - //param.gamma = null; - //param.coef0 = null; - - model.param = param; - //model.nr_class = null; - //model.l = null; - model.SV = null; - model.sv_coef = null; - model.rho = null; - model.probA = null; - model.probB = null; - model.sv_indices = null; - + this.myContext = context; } - // Marks the Tracks that are stop and go as being "stop and go" in the data table - public void classifyStopAndGo(){ + // Runs the classifier and assigns a value to each track in the database as 0=notstopngo 1=stopngo + public void classifyStopAndGo() { + + // Load model file as a buffered reader then pass it to svm_load: Runtime error will occur on svm_predict if either exception occurs + try { + AssetManager assetManager = myContext.getAssets(); + br = new BufferedReader(new InputStreamReader(assetManager.open("trainedmodel.model"), "UTF-8")); + this.model = svm.svm_load_model(br); + } catch(FileNotFoundException ex) { + ex.printStackTrace(); + Log.e("Model Asset Error", "FileNotFound"); + } catch(IOException ex) { + ex.printStackTrace(); + Log.e("Model Asset Error", "IOException"); + } + + // Collect segments from database handler ArrayList segments = getSegments(); + + /** Create an array of node Arrays called allSegmentNodes, to be passed to svm_predict. + Each segmentNodes array represents the characteristic features of a segment. + This process is necessary to set up the data such that it can be read by the svm library **/ + svm_node[][] allSegmentNodes = new svm_node[segments.size()][5]; + for (int i = 0; i < segments.size(); i++) { + svm_node[] segmentNodes = createSegmentNodeArray(segments.get(i)); + allSegmentNodes[i] = segmentNodes; + } + + // Predict each node in allSegmentNodes and store the prediction in predictedValues array + double[] predictedValues = new double[segments.size()]; //:TODO Put this in a separate thread to improve performance? + for (int i = 0; i < allSegmentNodes.length; i++) { + predictedValues[i] = svm.svm_predict(model, allSegmentNodes[i]); + } + + // Alter the database to reflect which tracks are stop and go + updateDatabase(predictedValues, this.dbh); } - // Returns an array of Segments, which are distinct parts of the trip - // where the Track's hidden state is either ACCELERATION or STOPPED - // Each Segment will have characteristic features that can be used by SVM + /** Returns an array of Segments from the database, which are distinct parts of the trip + The Track's hidden state is either ACCELERATION or STOPPED + Each Segment will have characteristic features that can be used by SVM **/ private ArrayList getSegments(){ ArrayList tracks= (ArrayList)dbh.getAllTracks(); @@ -74,4 +95,27 @@ else if (!isSegment && tracks.get(i).hiddenState != HMMClassifier.FREEFLOW) { return segments; } + // Returns an svm node from index/value passed + private svm_node createSvmNode(int index, double value) { + svm_node node = new svm_node(); + node.index = index; + node.value = value; + return node; + }; + + // Creates an array of nodes that represent the soft-normalized values of the segment passed in + private svm_node[] createSegmentNodeArray(Segment segment) { + svm_node[] nodeArray = new svm_node[5]; + nodeArray[0] = createSvmNode(1, segment.getStopsWRTime()); + nodeArray[1] = createSvmNode(2, segment.getStopsWRSpace()); + nodeArray[2] = createSvmNode(3, segment.getPeakSpeedWRMaxSpeed()); + nodeArray[3] = createSvmNode(4, segment.getMaxSingleStopWRTime()); + nodeArray[4] = createSvmNode(5, segment.getTotStopTimeWRTime()); + return nodeArray; + } + + private void updateDatabase(double[] predictions, DatabaseHandler dbh) { + System.out.println(Arrays.toString(predictions)); + //TODO: Complete this method + } } diff --git a/GPSlog_Basic/app/src/main/java/com/example/gpslog/Segment.java b/GPSlog_Basic/app/src/main/java/com/example/gpslog/Segment.java index 48518d8..f7bdcf7 100644 --- a/GPSlog_Basic/app/src/main/java/com/example/gpslog/Segment.java +++ b/GPSlog_Basic/app/src/main/java/com/example/gpslog/Segment.java @@ -17,6 +17,9 @@ // of either ACCELERATION or STOPPED public class Segment { + // This will determine whether the tracks in the Segment are sent + public boolean toSend; + // These are the 5 features that will be passed to SVM private double stopsWRTime; private double stopsWRSpace; @@ -58,7 +61,7 @@ public Segment(){ public Segment(ArrayList tracks){ this(); try { - setFeatures(tracks); + this.setFeatures(tracks); } catch (ParseException e){ System.err.println("ParseException: " + e.getMessage()); @@ -184,8 +187,9 @@ public double getStopsWRTime() { // Sets the stops WR time based on the values passed in public void setStopsWRTime(int stops, long time) { - if (time != 0) { - this.stopsWRTime = stops / time; + double stopsD = stops; + if (time != 0 && stopsD != 0) { + this.stopsWRTime = stopsD / time; } else { this.stopsWRTime = 0.0; @@ -199,7 +203,7 @@ public double getStopsWRSpace() { // Sets the stops WR distance based on the values passed in public void setStopsWRSpace(int stops, double distance) { - if (distance != 0) { + if (distance != 0 && stops != 0) { this.stopsWRSpace = stops / distance; } else { @@ -229,8 +233,10 @@ public double getMaxSingleStopWRTime() { // Sets the Max Single Stop WR to Time based on the values passed in public void setMaxSingleStopWRTime(long maxSingleStop, long segTimeLength) { - if (maxSingleStop != 0 && segTimeLength != 0) { - this.maxSingleStopWRTime = maxSingleStop / segTimeLength; + double maxSingleStopD = maxSingleStop; + + if (maxSingleStopD != 0 && segTimeLength != 0) { + this.maxSingleStopWRTime = maxSingleStopD / segTimeLength; } else { this.maxSingleStopWRTime = 0.0; @@ -244,8 +250,9 @@ public double getTotStopTimeWRTime() { //Sets the total stopped time WR to time based on the values passed in public void setTotStopTimeWRTime(long totStopTime, long segTimeLength) { - if (totStopTime != 0 && segTimeLength != 0) { - this.totStopTimeWRTime = totStopTime / segTimeLength; + double totStopTimeD = totStopTime; + if (totStopTimeD != 0 && segTimeLength != 0) { + this.totStopTimeWRTime = totStopTimeD / segTimeLength; } else { this.totStopTimeWRTime = 0.0;