In this project we build a spam detection filter. The dataset consists of volunteered text messages collected in a study in Singapore, plus spam texts from a UK spam-reporting site. The data comes from the UCI Machine Learning Repository's SMS Spam Collection dataset: https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
from pyspark.sql import SparkSession

# Create (or reuse) a Spark session for this NLP job.
spark = SparkSession.builder.appName('nlp').getOrCreate()

# The raw file is tab-separated with no header row:
# _c0 holds the ham/spam label, _c1 the message text.
raw_path = "dbfs:/FileStore/shared_uploads/dizhen@hsph.harvard.edu/SMSSpamCollection"
data = (
    spark.read
    .option("sep", "\t")
    .option("inferSchema", "true")
    .csv(raw_path)
    .withColumnRenamed('_c0', 'class')
    .withColumnRenamed('_c1', 'text')
)
data.show()
+-----+--------------------+ |class| text| +-----+--------------------+ | ham|Go until jurong p...| | ham|Ok lar... Joking ...| | spam|Free entry in 2 a...| | ham|U dun say so earl...| | ham|Nah I don't think...| | spam|FreeMsg Hey there...| | ham|Even my brother i...| | ham|As per your reque...| | spam|WINNER!! As a val...| | spam|Had your mobile 1...| | ham|I'm gonna be home...| | spam|SIX chances to wi...| | spam|URGENT! You have ...| | ham|I've been searchi...| | ham|I HAVE A DATE ON ...| | spam|XXXMobileMovieClu...| | ham|Oh k...i'm watchi...| | ham|Eh u remember how...| | ham|Fine if thats th...| | spam|England v Macedon...| +-----+--------------------+ only showing top 20 rows
Clean the data
from pyspark.sql.functions import length

# Attach a per-message character count as an extra engineered feature
# (average length differs sharply between ham and spam — see below).
data = data.withColumn('length', length(data.text))
data.show()
+-----+--------------------+------+ |class| text|length| +-----+--------------------+------+ | ham|Go until jurong p...| 111| | ham|Ok lar... Joking ...| 29| | spam|Free entry in 2 a...| 155| | ham|U dun say so earl...| 49| | ham|Nah I don't think...| 61| | spam|FreeMsg Hey there...| 147| | ham|Even my brother i...| 77| | ham|As per your reque...| 160| | spam|WINNER!! As a val...| 157| | spam|Had your mobile 1...| 154| | ham|I'm gonna be home...| 109| | spam|SIX chances to wi...| 136| | spam|URGENT! You have ...| 155| | ham|I've been searchi...| 196| | ham|I HAVE A DATE ON ...| 35| | spam|XXXMobileMovieClu...| 149| | ham|Oh k...i'm watchi...| 26| | ham|Eh u remember how...| 81| | ham|Fine if thats th...| 56| | spam|England v Macedon...| 155| +-----+--------------------+------+ only showing top 20 rows
# Average of every numeric column (only `length` here) per class:
# spam messages run roughly twice as long as ham on average.
data.groupBy('class').avg().show()
+-----+-----------------+ |class| avg(length)| +-----+-----------------+ | ham| 71.4545266210897| | spam|138.6706827309237| +-----+-----------------+
Feature Transformations
from pyspark.ml.feature import Tokenizer,StopWordsRemover, CountVectorizer,IDF,StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vector
# --- Feature pipeline stages ---
# Split the raw message into lowercase word tokens.
tokenizer = Tokenizer(inputCol="text", outputCol="token_text")
# Drop common English stop words that carry little signal.
stopremove = StopWordsRemover(inputCol="token_text", outputCol="stop_tokens")
# Raw term counts over the remaining tokens.
count_vec = CountVectorizer(inputCol="stop_tokens", outputCol="c_vec")
# Re-weight term counts by inverse document frequency.
idf = IDF(inputCol="c_vec", outputCol="tf_idf")
# Encode the string label as numeric (ham -> 0.0, spam -> 1.0 per the output below).
ham_spam_to_num = StringIndexer(inputCol="class", outputCol="label")
# Concatenate the TF-IDF vector and message length into one feature vector.
clean_up = VectorAssembler(inputCols=["tf_idf", "length"], outputCol="features")
Modeling
from pyspark.ml.classification import NaiveBayes

# Multinomial Naive Bayes with Laplace smoothing — these are the
# estimator's defaults, spelled out for clarity. A strong baseline
# for sparse bag-of-words text features.
nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
Pipeline
from pyspark.ml import Pipeline

# Chain all preprocessing stages. Fitting learns the vocabulary,
# IDF weights, and the label index from the dataset; transform then
# produces the numeric `label` and assembled `features` columns.
prep_stages = [ham_spam_to_num, tokenizer, stopremove, count_vec, idf, clean_up]
data_prep_pipe = Pipeline(stages=prep_stages)
cleaner = data_prep_pipe.fit(data)
clean_data = cleaner.transform(data)
Training and Evaluation
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Keep only what the classifier needs: numeric label + assembled features.
clean_data = clean_data.select('label', 'features')
clean_data.show()
+-----+--------------------+ |label| features| +-----+--------------------+ | 0.0|(13424,[7,11,31,6...| | 0.0|(13424,[0,24,297,...| | 1.0|(13424,[2,13,19,3...| | 0.0|(13424,[0,70,80,1...| | 0.0|(13424,[36,134,31...| | 1.0|(13424,[10,60,139...| | 0.0|(13424,[10,53,103...| | 0.0|(13424,[125,184,4...| | 1.0|(13424,[1,47,118,...| | 1.0|(13424,[0,1,13,27...| | 0.0|(13424,[18,43,120...| | 1.0|(13424,[8,17,37,8...| | 1.0|(13424,[13,30,47,...| | 0.0|(13424,[39,96,217...| | 0.0|(13424,[552,1697,...| | 1.0|(13424,[30,109,11...| | 0.0|(13424,[82,214,47...| | 0.0|(13424,[0,2,49,13...| | 0.0|(13424,[0,74,105,...| | 1.0|(13424,[4,30,33,5...| +-----+--------------------+ only showing top 20 rows
# 70/30 train/test split. A fixed seed makes the split — and therefore
# the reported metric — reproducible across runs (the original call had
# no seed, so every run produced a different split and score).
(training, testing) = clean_data.randomSplit([0.7, 0.3], seed=42)
# Fit the Naive Bayes classifier on the training portion only.
spam_predictor = nb.fit(training)
# Sanity-check the raw input schema (class: string, text: string, length: int).
data.printSchema()
root |-- class: string (nullable = true) |-- text: string (nullable = true) |-- length: integer (nullable = true)
# Score the held-out test set; the model adds rawPrediction,
# probability, and prediction columns alongside label/features.
test_results = spam_predictor.transform(testing)
test_results.show()
+-----+--------------------+--------------------+--------------------+----------+ |label| features| rawPrediction| probability|prediction| +-----+--------------------+--------------------+--------------------+----------+ | 0.0|(13424,[0,1,2,41,...|[-1058.4236651267...|[1.0,2.0218993518...| 0.0| | 0.0|(13424,[0,1,14,31...|[-216.37757434172...|[1.0,4.6921539302...| 0.0| | 0.0|(13424,[0,1,15,20...|[-686.30933948850...|[1.0,4.7480435449...| 0.0| | 0.0|(13424,[0,1,21,27...|[-1009.2354637663...|[1.0,3.4420307844...| 0.0| | 0.0|(13424,[0,1,24,31...|[-341.12632144915...|[1.0,2.9801386551...| 0.0| | 0.0|(13424,[0,1,27,88...|[-1545.8052318883...|[0.99674540942030...| 0.0| | 0.0|(13424,[0,1,30,12...|[-612.31523960048...|[1.0,1.7729020150...| 0.0| | 0.0|(13424,[0,1,146,1...|[-250.28159301013...|[0.94419859225993...| 0.0| | 0.0|(13424,[0,1,874,1...|[-96.083169031673...|[0.99999996864916...| 0.0| | 0.0|(13424,[0,1,874,1...|[-97.770226263077...|[0.99999997561199...| 0.0| | 0.0|(13424,[0,2,3,6,9...|[-3296.5916245644...|[1.0,3.8692072210...| 0.0| | 0.0|(13424,[0,2,4,5,1...|[-2484.9794043685...|[1.0,3.2695873089...| 0.0| | 0.0|(13424,[0,2,4,5,1...|[-1610.4821878915...|[1.0,3.9135371465...| 0.0| | 0.0|(13424,[0,2,4,8,1...|[-1313.1520641704...|[1.0,5.4727792597...| 0.0| | 0.0|(13424,[0,2,4,8,2...|[-563.23303848389...|[1.0,9.1756624825...| 0.0| | 0.0|(13424,[0,2,4,25,...|[-425.86125539239...|[1.0,2.9319383384...| 0.0| | 0.0|(13424,[0,2,4,44,...|[-1892.9037879435...|[1.0,5.2089698504...| 0.0| | 0.0|(13424,[0,2,7,8,1...|[-469.57902691008...|[0.99999999433492...| 0.0| | 0.0|(13424,[0,2,7,11,...|[-727.83890019230...|[1.0,2.9949190670...| 0.0| | 0.0|(13424,[0,2,7,11,...|[-1411.8150184110...|[1.0,6.3381546067...| 0.0| +-----+--------------------+--------------------+--------------------+----------+ only showing top 20 rows
# BUG FIX: MulticlassClassificationEvaluator defaults to metricName="f1",
# so the original code printed an F1 score while claiming "Accuracy".
# Request accuracy explicitly so the metric matches the message.
acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")
acc = acc_eval.evaluate(test_results)
print("Accuracy of model at predicting spam was: {}".format(acc))
Accuracy of model at predicting spam was: 0.9183086542776753