Skip to content

Commit 33988a1

Browse files
authored
Minor bugfixes for LSTM example. (#246)
1 parent f379bfd commit 33988a1

2 files changed

Lines changed: 12 additions & 3 deletions

File tree

cpp/lstm/dga_detection/lstm_dga_detection_predict.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
* To keep the model small and the code fast, we use `float` as a datatype
1616
* instead of the default `double`.
1717
*/
18+
19+
// This must be defined to avoid RNN::serialize() throwing an error---we know
20+
// what we are doing and have manually registered the layer types we care about.
21+
#define MLPACK_ANN_IGNORE_SERIALIZATION_WARNING
1822
#include <mlpack.hpp>
1923

2024
// To keep compilation time and program size down, we only register
@@ -142,10 +146,11 @@ int main(int argc, char** argv)
142146
const float benignLikelihood = ComputeLikelihood(benignOutput, response);
143147
const float maliciousLikelihood = ComputeLikelihood(maliciousOutput,
144148
response);
149+
const float score = benignLikelihood - maliciousLikelihood;
145150

146151
if (benignLikelihood > maliciousLikelihood)
147-
cout << "benign" << endl;
152+
cout << "benign (score " << score << ")" << std::endl;
148153
else
149-
cout << "malicious" << endl;
154+
cout << "malicious (score " << -score << ")" << std::endl;
150155
}
151156
}

cpp/lstm/dga_detection/lstm_dga_detection_train.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
* Train the model on the `dga_domains.csv` file in the data/ directory in the
2727
* repository (once you have run `scripts/download_data_set.py`).
2828
*/
29+
30+
// This must be defined to avoid RNN::serialize() throwing an error---we know
31+
// what we are doing and have manually registered the layer types we care about.
32+
#define MLPACK_ANN_IGNORE_SERIALIZATION_WARNING
2933
#include <mlpack.hpp>
3034

3135
// To keep compilation time and program size down, we only register
@@ -115,7 +119,7 @@ void PrepareData(const vector<string>& domains,
115119
size_t ComputeCorrect(const arma::fcube& benignPredictions,
116120
const arma::fcube& maliciousPredictions,
117121
const arma::fcube& labels,
118-
const arma::uvec& sequenceLengths,
122+
const arma::urowvec& sequenceLengths,
119123
const bool malicious)
120124
{
121125
size_t correct = 0;

0 commit comments

Comments
 (0)