Friday, 8 January 2016

Lenskit code example

Recently, I started using the LensKit framework, which is designed for building recommender systems. It contains a few useful recommendation algorithms, such as item-item collaborative filtering and matrix factorization. However, the framework lacks documentation and examples.
I needed to use the SimpleEvaluator class, but I could not find relevant documentation on the class or a good example of how to use it.


I found one example in the project's repository, but that code is written for LensKit 3. Here is my example for LensKit 2; I hope it helps:
package bionic;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.baseline.BaselineScorer;
import org.grouplens.lenskit.baseline.ItemMeanRatingItemScorer;
import org.grouplens.lenskit.baseline.UserMeanBaseline;
import org.grouplens.lenskit.baseline.UserMeanItemScorer;
import org.grouplens.lenskit.core.LenskitConfiguration;
import org.grouplens.lenskit.data.source.DataSource;
import org.grouplens.lenskit.data.source.GenericDataSource;
import org.grouplens.lenskit.data.text.DelimitedColumnEventFormat;
import org.grouplens.lenskit.data.text.EventFormat;
import org.grouplens.lenskit.data.text.RatingEventType;
import org.grouplens.lenskit.data.text.TextEventDAO;
import org.grouplens.lenskit.eval.data.crossfold.CrossfoldTask;
import org.grouplens.lenskit.eval.metrics.topn.ItemSelectors;
import org.grouplens.lenskit.eval.metrics.topn.NDCGTopNMetric;
import org.grouplens.lenskit.eval.metrics.topn.PrecisionRecallTopNMetric;
import org.grouplens.lenskit.eval.traintest.SimpleEvaluator;
import org.grouplens.lenskit.iterative.IterationCount;
import org.grouplens.lenskit.knn.item.ItemItemScorer;
import org.grouplens.lenskit.knn.user.UserUserItemScorer;
import org.grouplens.lenskit.mf.funksvd.FeatureCount;
import org.grouplens.lenskit.mf.funksvd.FunkSVDItemScorer;
import org.grouplens.lenskit.mf.funksvd.FunkSVDUpdateRule;
import org.grouplens.lenskit.mf.funksvd.RuntimeUpdate;
import org.grouplens.lenskit.transform.normalize.BaselineSubtractingUserVectorNormalizer;
import org.grouplens.lenskit.transform.normalize.MeanCenteringVectorNormalizer;
import org.grouplens.lenskit.transform.normalize.UserVectorNormalizer;
import org.grouplens.lenskit.transform.normalize.VectorNormalizer;
import org.hamcrest.Matchers;
import java.io.File;
public class SimpleEvaluatorExample {
//number of folds in k-fold cross-validation
private static final int CROSSFOLD_NUMBER = 1;
//number of ratings to hide for each user
private static final int HOLDOUT_NUMBER = 5;
//ndcg@n, precision@n, recall@n
private static final int AT_N = 2;
//rating threshold. Ratings > threshold - relevant, otherwise - irrelevant
private static final double THRESHOLD = 3.0;
private static final String DATASET_PATH = "D:\\bigdata\\movielens\\fake\\all_ratings_extended";
private static final String TRAIN_TEST_FOLDER_NAME = "task";
//paths for output files
private static final String OUTPUT_PATH = "./results/out.csv";
private static final String OUTPUT_USER_PATH = "./results/user.csv";
private static final String OUTPUT_ITEM_PATH = "./results/item.csv";
public static void main(String args[]) {
//create evaluator
SimpleEvaluator evaluator = new SimpleEvaluator();
//setting up parameters
EventFormat eventFormat = new DelimitedColumnEventFormat(new RatingEventType());
DataSource dataSource = new GenericDataSource("split", new TextEventDAO(new File(DATASET_PATH), eventFormat));
CrossfoldTask task = new CrossfoldTask(TRAIN_TEST_FOLDER_NAME);
task.setHoldout(HOLDOUT_NUMBER);
task.setPartitions(CROSSFOLD_NUMBER);
task.setSource(dataSource);
evaluator.addDataset(task);
//user-based collaborative filtering
LenskitConfiguration userUser = new LenskitConfiguration();
userUser.bind(ItemScorer.class).to(UserUserItemScorer.class);
userUser.bind(BaselineScorer.class, ItemScorer.class).to(UserMeanItemScorer.class);
userUser.bind(UserMeanBaseline.class, ItemScorer.class).to(ItemMeanRatingItemScorer.class);
userUser.within(UserVectorNormalizer.class).bind(VectorNormalizer.class).to(MeanCenteringVectorNormalizer.class);
evaluator.addAlgorithm("useruser", userUser);
//item-based collaborative filtering
LenskitConfiguration itemItem = new LenskitConfiguration();
itemItem.bind(ItemScorer.class).to(ItemItemScorer.class);
itemItem.bind(BaselineScorer.class, ItemScorer.class).to(UserMeanItemScorer.class);
itemItem.bind(UserMeanBaseline.class, ItemScorer.class).to(ItemMeanRatingItemScorer.class);
itemItem.bind(UserVectorNormalizer.class).to(BaselineSubtractingUserVectorNormalizer.class);
evaluator.addAlgorithm("itemitem", itemItem);
//matrix factorization
LenskitConfiguration SVD = new LenskitConfiguration();
SVD.bind(ItemScorer.class).to(FunkSVDItemScorer.class);
SVD.bind(UserVectorNormalizer.class).to(BaselineSubtractingUserVectorNormalizer.class);
SVD.bind(BaselineScorer.class, ItemScorer.class).to(UserMeanItemScorer.class);
SVD.bind(UserMeanBaseline.class, ItemScorer.class).to(ItemMeanRatingItemScorer.class);
SVD.bind(RuntimeUpdate.class, FunkSVDUpdateRule.class).to(FunkSVDUpdateRule.class);
SVD.set(FeatureCount.class).to(4);
SVD.set(IterationCount.class).to(10000);
evaluator.addAlgorithm("SVD", SVD);
//output
evaluator.setOutputPath(OUTPUT_PATH);
evaluator.setUserOutputPath(OUTPUT_USER_PATH);
evaluator.setPredictOutputPath(OUTPUT_ITEM_PATH);
//evaluation metrics
evaluator.addMetric(new NDCGTopNMetric(AT_N + "", "", AT_N, ItemSelectors.allItems(), ItemSelectors.trainingItems()));
evaluator.addMetric(new PrecisionRecallTopNMetric(AT_N + "", "", AT_N, ItemSelectors.allItems(), ItemSelectors.trainingItems(), ItemSelectors.testRatingMatches(Matchers.greaterThanOrEqualTo(THRESHOLD))));
try {
evaluator.call();
} catch (Exception e) {
e.printStackTrace();
}
}
}
view raw gistfile1.txt hosted with ❤ by GitHub

No comments:

Post a Comment