Recently, I started using the LensKit framework. The framework is designed for building recommender systems. It contains a few useful recommendation algorithms, such as item-item collaborative filtering and matrix factorization. However, the framework lacks documentation and examples.
I needed to use the SimpleEvaluator class, but I could not find any relevant documentation for the class or a good example of how to use it.
I found one example in the project's repository, but that code is written for LensKit 3. Here is my example for LensKit 2; I hope it helps:
I needed to use the SimpleEvaluator class, but I could not find any relevant documentation for the class or a good example of how to use it.
I found one example in the project's repository, but that code is written for LensKit 3. Here is my example for LensKit 2; I hope it helps:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package bionic; | |
import org.grouplens.lenskit.ItemScorer; | |
import org.grouplens.lenskit.baseline.BaselineScorer; | |
import org.grouplens.lenskit.baseline.ItemMeanRatingItemScorer; | |
import org.grouplens.lenskit.baseline.UserMeanBaseline; | |
import org.grouplens.lenskit.baseline.UserMeanItemScorer; | |
import org.grouplens.lenskit.core.LenskitConfiguration; | |
import org.grouplens.lenskit.data.source.DataSource; | |
import org.grouplens.lenskit.data.source.GenericDataSource; | |
import org.grouplens.lenskit.data.text.DelimitedColumnEventFormat; | |
import org.grouplens.lenskit.data.text.EventFormat; | |
import org.grouplens.lenskit.data.text.RatingEventType; | |
import org.grouplens.lenskit.data.text.TextEventDAO; | |
import org.grouplens.lenskit.eval.data.crossfold.CrossfoldTask; | |
import org.grouplens.lenskit.eval.metrics.topn.ItemSelectors; | |
import org.grouplens.lenskit.eval.metrics.topn.NDCGTopNMetric; | |
import org.grouplens.lenskit.eval.metrics.topn.PrecisionRecallTopNMetric; | |
import org.grouplens.lenskit.eval.traintest.SimpleEvaluator; | |
import org.grouplens.lenskit.iterative.IterationCount; | |
import org.grouplens.lenskit.knn.item.ItemItemScorer; | |
import org.grouplens.lenskit.knn.user.UserUserItemScorer; | |
import org.grouplens.lenskit.mf.funksvd.FeatureCount; | |
import org.grouplens.lenskit.mf.funksvd.FunkSVDItemScorer; | |
import org.grouplens.lenskit.mf.funksvd.FunkSVDUpdateRule; | |
import org.grouplens.lenskit.mf.funksvd.RuntimeUpdate; | |
import org.grouplens.lenskit.transform.normalize.BaselineSubtractingUserVectorNormalizer; | |
import org.grouplens.lenskit.transform.normalize.MeanCenteringVectorNormalizer; | |
import org.grouplens.lenskit.transform.normalize.UserVectorNormalizer; | |
import org.grouplens.lenskit.transform.normalize.VectorNormalizer; | |
import org.hamcrest.Matchers; | |
import java.io.File; | |
public class SimpleEvaluatorExample { | |
//number of folds in k-fold cross-validation | |
private static final int CROSSFOLD_NUMBER = 1; | |
//number of ratings to hide for each user | |
private static final int HOLDOUT_NUMBER = 5; | |
//ndcg@n, precision@n, recall@n | |
private static final int AT_N = 2; | |
//rating threshold. Ratings > threshold - relevant, otherwise - irrelevant | |
private static final double THRESHOLD = 3.0; | |
private static final String DATASET_PATH = "D:\\bigdata\\movielens\\fake\\all_ratings_extended"; | |
private static final String TRAIN_TEST_FOLDER_NAME = "task"; | |
//paths for output files | |
private static final String OUTPUT_PATH = "./results/out.csv"; | |
private static final String OUTPUT_USER_PATH = "./results/user.csv"; | |
private static final String OUTPUT_ITEM_PATH = "./results/item.csv"; | |
public static void main(String args[]) { | |
//create evaluator | |
SimpleEvaluator evaluator = new SimpleEvaluator(); | |
//setting up parameters | |
EventFormat eventFormat = new DelimitedColumnEventFormat(new RatingEventType()); | |
DataSource dataSource = new GenericDataSource("split", new TextEventDAO(new File(DATASET_PATH), eventFormat)); | |
CrossfoldTask task = new CrossfoldTask(TRAIN_TEST_FOLDER_NAME); | |
task.setHoldout(HOLDOUT_NUMBER); | |
task.setPartitions(CROSSFOLD_NUMBER); | |
task.setSource(dataSource); | |
evaluator.addDataset(task); | |
//user-based collaborative filtering | |
LenskitConfiguration userUser = new LenskitConfiguration(); | |
userUser.bind(ItemScorer.class).to(UserUserItemScorer.class); | |
userUser.bind(BaselineScorer.class, ItemScorer.class).to(UserMeanItemScorer.class); | |
userUser.bind(UserMeanBaseline.class, ItemScorer.class).to(ItemMeanRatingItemScorer.class); | |
userUser.within(UserVectorNormalizer.class).bind(VectorNormalizer.class).to(MeanCenteringVectorNormalizer.class); | |
evaluator.addAlgorithm("useruser", userUser); | |
//item-based collaborative filtering | |
LenskitConfiguration itemItem = new LenskitConfiguration(); | |
itemItem.bind(ItemScorer.class).to(ItemItemScorer.class); | |
itemItem.bind(BaselineScorer.class, ItemScorer.class).to(UserMeanItemScorer.class); | |
itemItem.bind(UserMeanBaseline.class, ItemScorer.class).to(ItemMeanRatingItemScorer.class); | |
itemItem.bind(UserVectorNormalizer.class).to(BaselineSubtractingUserVectorNormalizer.class); | |
evaluator.addAlgorithm("itemitem", itemItem); | |
//matrix factorization | |
LenskitConfiguration SVD = new LenskitConfiguration(); | |
SVD.bind(ItemScorer.class).to(FunkSVDItemScorer.class); | |
SVD.bind(UserVectorNormalizer.class).to(BaselineSubtractingUserVectorNormalizer.class); | |
SVD.bind(BaselineScorer.class, ItemScorer.class).to(UserMeanItemScorer.class); | |
SVD.bind(UserMeanBaseline.class, ItemScorer.class).to(ItemMeanRatingItemScorer.class); | |
SVD.bind(RuntimeUpdate.class, FunkSVDUpdateRule.class).to(FunkSVDUpdateRule.class); | |
SVD.set(FeatureCount.class).to(4); | |
SVD.set(IterationCount.class).to(10000); | |
evaluator.addAlgorithm("SVD", SVD); | |
//output | |
evaluator.setOutputPath(OUTPUT_PATH); | |
evaluator.setUserOutputPath(OUTPUT_USER_PATH); | |
evaluator.setPredictOutputPath(OUTPUT_ITEM_PATH); | |
//evaluation metrics | |
evaluator.addMetric(new NDCGTopNMetric(AT_N + "", "", AT_N, ItemSelectors.allItems(), ItemSelectors.trainingItems())); | |
evaluator.addMetric(new PrecisionRecallTopNMetric(AT_N + "", "", AT_N, ItemSelectors.allItems(), ItemSelectors.trainingItems(), ItemSelectors.testRatingMatches(Matchers.greaterThanOrEqualTo(THRESHOLD)))); | |
try { | |
evaluator.call(); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
} |
No comments:
Post a Comment