#!/usr/bin/env python
# coding: utf-8

# In[1]:


# `get_ipython()` is only defined inside an IPython/Jupyter session; this
# line enables inline rendering of matplotlib figures in the notebook.
get_ipython().run_line_magic('matplotlib', 'inline')


# # Scatterplot with Categories

# A function to quickly produce a scatter plot colored by categories from a
# pandas `DataFrame` or NumPy `ndarray` object.

# > from mlxtend.plotting import category_scatter

# ## Overview

# ### References
#
# -

# ## Example 1 - Category Scatter from Pandas DataFrames

# In[2]:


import pandas as pd
from io import StringIO

# One record per line: a header row followed by six (label, x, y) rows.
csvfile = """label,x,y
class1,10.0,8.04
class1,10.5,7.30
class2,8.3,5.5
class2,8.1,5.9
class3,3.5,3.5
class3,3.8,5.1"""

# Wrap the string in a file-like object so `read_csv` can parse it.
df = pd.read_csv(StringIO(csvfile))
df  # bare expression: displayed as the cell output in a notebook


# Plotting the data where the categories are determined by the unique values
# in the label column `label_col`. The `x` and `y` values are simply the
# column names of the DataFrame that we want to plot.

# In[3]:


import matplotlib.pyplot as plt
from mlxtend.plotting import category_scatter

fig = category_scatter(x='x', y='y', label_col='label',
                       data=df, legend_loc='upper left')


# ## Example 2 - Category Scatter from NumPy Arrays

# In[4]:


import numpy as np
from io import BytesIO

# Same data as Example 1, without a header row and with numeric labels.
csvfile = """1,10.0,8.04
1,10.5,7.30
2,8.3,5.5
2,8.1,5.9
3,3.5,3.5
3,3.8,5.1"""

# The text is encoded to bytes and wrapped in a file-like object before
# being handed to `genfromtxt`.
ary = np.genfromtxt(BytesIO(csvfile.encode()), delimiter=',')
ary  # bare expression: displayed as the cell output in a notebook


# Now, pretending that the first column represents the labels, and the second
# and third columns represent the `x` and `y` values, respectively.

# In[5]:


import matplotlib.pyplot as plt
from mlxtend.plotting import category_scatter

# For an array, `x`, `y` and `label_col` are integer column positions rather
# than column names.
fig = category_scatter(x=1, y=2, label_col=0,
                       data=ary, legend_loc='upper left')


# ## API

# In[6]:


# Print the generated API documentation for `category_scatter`. The path is
# relative to the current working directory.
# NOTE(review): assumes the markdown file is UTF-8 encoded; confirm.
with open('../../api_modules/mlxtend.plotting/category_scatter.md', 'r',
          encoding='utf-8') as f:
    print(f.read())