Spark's MLlib machine learning library provides a collaborative filtering implementation based on Alternating Least Squares (ALS). The implementation in MLlib has these parameters: numBlocks (the number of blocks used to parallelize computation), rank (the number of latent factors in the model), maxIter (the maximum number of iterations to run), regParam (the regularization parameter), implicitPrefs (whether to use the explicit-feedback or implicit-feedback variant of ALS), alpha (the baseline confidence for implicit preference observations), and nonnegative (whether to apply nonnegativity constraints to the least squares problem).
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('rec').getOrCreate()
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
data = spark.read.csv("dbfs:/FileStore/shared_uploads/dizhen@hsph.harvard.edu/movielens_ratings.csv",inferSchema=True,header=True)
data.head()
Out[5]: Row(movieId=2, rating=3.0, userId=0)
data.describe().show()
+-------+------------------+------------------+------------------+
|summary|           movieId|            rating|            userId|
+-------+------------------+------------------+------------------+
|  count|              1501|              1501|              1501|
|   mean| 49.40572951365756|1.7741505662891406|14.383744170552964|
| stddev|28.937034065088994| 1.187276166124803| 8.591040424293272|
|    min|                 0|               1.0|                 0|
|    max|                99|               5.0|                29|
+-------+------------------+------------------+------------------+
(training, test) = data.randomSplit([0.8, 0.2])
als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating")
model = als.fit(training)
predictions = model.transform(test)
predictions.show()
+-------+------+------+-----------+
|movieId|rating|userId| prediction|
+-------+------+------+-----------+
|      2|   2.0|     1|  1.9871931|
|      1|   1.0|     6| 0.28594762|
|      4|   1.0|     7|   2.146959|
|      0|   1.0|     8|  1.8356189|
|      4|   2.0|     8| 0.87620103|
|      2|   3.0|     9|  2.1391335|
|      4|   1.0|     9|  2.4214845|
|      0|   1.0|    11| -1.3260899|
|      2|   1.0|    12|  3.1861515|
|      3|   1.0|    13|  2.1452992|
|      4|   2.0|    13|  0.7788123|
|      2|   1.0|    15|  2.4739377|
|      2|   1.0|    17|  -2.657514|
|      3|   1.0|    17|  0.2842377|
|      2|   2.0|    20| -1.2231252|
|      3|   2.0|    22|  0.5278863|
|      0|   1.0|    23| 0.34326315|
|      4|   1.0|    23|  1.9485532|
|      0|   1.0|    27|-0.45441574|
|      0|   3.0|    28|  0.7916339|
+-------+------+------+-----------+
only showing top 20 rows
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))
Root-mean-square error = 1.783041436897024
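A note on evaluation: because randomSplit can place a user or movie in the test set that never appears in the training set, ALS may produce NaN predictions for those rows, which makes the RMSE itself NaN. The ALS estimator exposes a coldStartStrategy parameter for exactly this; a minimal sketch of using it, reusing the variable names from the cells above:

# Drop test rows whose user/item was unseen during training so RMSE stays finite
als = ALS(maxIter=5, regParam=0.01,
          userCol="userId", itemCol="movieId", ratingCol="rating",
          coldStartStrategy="drop")   # default is "nan"
model = als.fit(training)
predictions = model.transform(test)  # rows with unseen IDs are dropped here
rmse = evaluator.evaluate(predictions)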
single_user = test.filter(test['userId']==11).select(['movieId','userId'])
single_user.show()
+-------+------+
|movieId|userId|
+-------+------+
|      0|    11|
|      9|    11|
|     12|    11|
|     20|    11|
|     25|    11|
|     43|    11|
|     47|    11|
|     51|    11|
|     66|    11|
|     75|    11|
|     81|    11|
|     97|    11|
|     99|    11|
+-------+------+
recommendations = model.transform(single_user)
recommendations.orderBy('prediction', ascending=False).show()
+-------+------+----------+
|movieId|userId|prediction|
+-------+------+----------+
|     47|    11| 2.4248514|
|     20|    11| 1.6464235|
|     43|    11| 1.3852543|
|      9|    11|  1.372124|
|     97|    11| 1.0738724|
|     25|    11| 1.0302699|
|     99|    11|0.95755816|
|     51|    11|0.62003374|
|     81|    11| 0.5865185|
|     12|    11|   0.44638|
|     75|    11|0.27664962|
|     66|    11|0.22641006|
|      0|    11|-1.3260899|
+-------+------+----------+
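Besides transforming a hand-built DataFrame of (movieId, userId) pairs, the fitted ALSModel also provides built-in recommendation methods. A brief sketch (the choice of 5 recommendations is arbitrary):

# Top-5 recommendations for just user 11 (expects a DataFrame with a userId column)
user_subset = test.filter(test['userId'] == 11).select('userId').distinct()
model.recommendForUserSubset(user_subset, 5).show(truncate=False)

# Top-5 recommendations for every user seen during training
model.recommendForAllUsers(5).show(5, truncate=False)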
Mapping of meal IDs (mealskew) to meal names used for the Meal_Info dataset; note the np.nan key, which leaves some mealskew values null:
{ 2. : "Chicken Curry",
3. : "Spicy Chicken Nuggest",
5. : "Hamburger",
9. : "Taco Surprise",
11. : "Meatloaf",
12. : "Ceaser Salad",
15. : "BBQ Ribs",
17. : "Sushi Plate",
19. : "Cheesesteak Sandwhich",
21. : "Lasagna",
23. : "Orange Chicken",
26. : "Spicy Beef Plate",
27. : "Salmon with Mashed Potatoes",
28. : "Penne Tomatoe Pasta",
29. : "Pork Sliders",
30. : "Vietnamese Sandwich",
31. : "Chicken Wrap",
np.nan: "Cowboy Burger",
4. : "Pretzels and Cheese Plate",
6. : "Spicy Pork Sliders",
13. : "Mandarin Chicken PLate",
14. : "Kung Pao Chicken",
16. : "Fried Rice Plate",
8. : "Chicken Chow Mein",
10. : "Roasted Eggplant ",
18. : "Pepperoni Pizza",
22. : "Pulled Pork Plate",
0. : "Cheese Pizza",
1. : "Burrito",
7. : "Nachos",
24. : "Chili",
20. : "Southwest Salad",
25.: "Roast Beef Sandwich"}
data = spark.read.csv("dbfs:/FileStore/shared_uploads/dizhen@hsph.harvard.edu/Meal_Info.csv",inferSchema=True,header=True)
(training, test) = data.randomSplit([0.8, 0.2])
als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="mealskew", ratingCol="rating")
model = als.fit(training)
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<command-1680652028931040> in <module>
      1 als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="mealskew", ratingCol="rating")
----> 2 model = als.fit(training)
...
Py4JJavaError: An error occurred while calling o894.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 218.0 failed 1 times, most recent failure: Lost task 0.0 in stage 218.0 (TID 251) (ip-10-172-165-175.us-west-2.compute.internal executor driver): org.apache.spark.SparkException: Failed to execute user defined function (ALSModelParams$$Lambda$6638/1722650802: (double) => int)
...
Caused by: java.lang.IllegalArgumentException: ALS only supports values in Integer range for columns userId and mealskew. Value null was not numeric.
  at org.apache.spark.ml.recommendation.ALSModelParams.$anonfun$checkedCast$1(ALS.scala:105)
  at org.apache.spark.ml.recommendation.ALSModelParams.$anonfun$checkedCast$1$adapted(ALS.scala:90)
  ... 29 more
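The root cause at the bottom of the traceback is that ALS requires the user and item ID columns to contain non-null values in integer range, and the np.nan entry in the mapping leaves some mealskew values null. A minimal fix, presumably applied before the successful cells below, is to drop the null rows and refit (a sketch under that assumption):

# Remove rows with null mealskew (or any other null), then split and refit
data = data.na.drop()
(training, test) = data.randomSplit([0.8, 0.2])
als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="mealskew", ratingCol="rating")
model = als.fit(training)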
predictions = model.transform(test)
predictions.show()
+-------+------+------+--------+--------------------+----------+
|movieId|rating|userId|mealskew|           meal_name|prediction|
+-------+------+------+--------+--------------------+----------+
|      2|   2.0|     1|     2.0|       Chicken Curry| 1.9871931|
|      3|   1.0|     1|     3.0|Spicy Chicken Nug...|0.90050435|
|      4|   3.0|     2|     4.0|Pretzels and Chee...|  2.912253|
|      0|   1.0|     5|     0.0|        Cheese Pizza| 1.1937842|
|      3|   1.0|     7|     3.0|Spicy Chicken Nug...|  1.367688|
|      4|   1.0|     7|     4.0|Pretzels and Chee...|  2.146959|
|      3|   2.0|     8|     3.0|Spicy Chicken Nug...| 1.7796037|
|      2|   3.0|     9|     2.0|       Chicken Curry| 2.1391335|
|      3|   1.0|     9|     3.0|Spicy Chicken Nug...| 1.0622486|
|      2|   1.0|    12|     2.0|       Chicken Curry| 3.1861515|
|      4|   1.0|    12|     4.0|Pretzels and Chee...|0.73995054|
|      3|   1.0|    13|     3.0|Spicy Chicken Nug...| 2.1452992|
|      1|   4.0|    15|     1.0|             Burrito|  2.993206|
|      0|   1.0|    19|     0.0|        Cheese Pizza| 1.0666925|
|      1|   1.0|    19|     1.0|             Burrito| 0.7900078|
|      0|   1.0|    20|     0.0|        Cheese Pizza| 1.0853959|
|      1|   1.0|    20|     1.0|             Burrito| 1.0895385|
|      0|   1.0|    23|     0.0|        Cheese Pizza|0.34326315|
|      2|   1.0|    23|     2.0|       Chicken Curry| 0.9516133|
|      1|   1.0|    26|     1.0|             Burrito| 0.8212584|
+-------+------+------+--------+--------------------+----------+
only showing top 20 rows
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))
Root-mean-square error = 0.8463110939439068