Created 7/9/2014 by KO </br> Implements propensity-score matching and eventually will implement balance diagnostics

In [1]:
%matplotlib inline
import math
import numpy as np
import scipy
from scipy.stats import binom, hypergeom
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression

Goal: find the average treatment effect in the treatment group (ATT) on RE78.

Import the data: controls and treated from Lalonde/Dehejia papers. Here's what the site says about the data:

The variables from left to right are: treatment indicator (1 if treated, 0 if not treated), age, education, Black (1 if black, 0 otherwise), Hispanic (1 if Hispanic, 0 otherwise), married (1 if married, 0 otherwise), nodegree (1 if no degree, 0 otherwise), RE74 (earnings in 1974), RE75 (earnings in 1975), and RE78 (earnings in 1978).

http://users.nber.org/%7Erdehejia/nswdata2.html

In [2]:
names = ['Treated', 'Age', 'Education', 'Black', 'Hispanic', 'Married',
         'Nodegree', 'RE74', 'RE75', 'RE78']
treated = pd.read_table('nswre74_treated.txt', sep = '\s+',
                        header = None, names = names)
control = pd.read_table('nswre74_control.txt', sep='\s+', 
                        header = None, names = names)
data = pd.concat([treated, control])
data.head()
Out[2]:
Treated Age Education Black Hispanic Married Nodegree RE74 RE75 RE78
0 1 37 11 1 0 1 1 0 0 9930.0460
1 1 22 9 0 1 0 1 0 0 3595.8940
2 1 30 12 1 0 0 0 0 0 24909.4500
3 1 27 11 1 0 0 1 0 0 7506.1460
4 1 33 8 1 0 0 1 0 0 289.7899

Compute propensity scores to start. Then we need to separate the treated and controls again (preserve original indexing) in order to match them.

Note, this section might need some fine-tuning to make it match Dehejia and Wahba (see their appendix for how they computed propensity scores)

In [3]:
propensity = LogisticRegression()
propensity = propensity.fit(data[names[1:-1]], data.Treated)
pscore = propensity.predict_proba(data[names[1:-1]])[:,1] # The predicted propensities by the model
print pscore[:5]

data['Propensity'] = pscore
#pscore = pd.Series(data = pscore, index = data.index)
[ 0.42716293  0.25617646  0.54874013  0.37386481  0.40217168]

Implement one-to-one matching, caliper without replacement. Variants of the method are examined in the following paper. This is something to explore further. </br> Austin, P. C. (2014), A comparison of 12 algorithms for matching on the propensity score. Statist. Med., 33: 1057–1069. doi: 10.1002/sim.6004

In [4]:
def Match(groups, propensity, caliper = 0.05):
    ''' 
    Inputs:
    groups = Treatment assignments.  Must be 2 groups
    propensity = Propensity scores for each observation. Propensity and groups should be in the same order (matching indices)
    caliper = Maximum difference in matched propensity scores. For now, this is a caliper on the raw
            propensity; Austin reccommends using a caliper on the logit propensity.
    
    Output:
    A series containing the individuals in the control group matched to the treatment group.
    Note that with caliper matching, not every treated individual may have a match.
    '''

    # Check inputs
    if any(propensity <=0) or any(propensity >=1):
        raise ValueError('Propensity scores must be between 0 and 1')
    elif not(0<caliper<1):
        raise ValueError('Caliper must be between 0 and 1')
    elif len(groups)!= len(propensity):
        raise ValueError('groups and propensity scores must be same dimension')
    elif len(groups.unique()) != 2:
        raise ValueError('wrong number of groups')
        
        
    # Code groups as 0 and 1
    groups = groups == groups.unique()[0]
    N = len(groups)
    N1 = groups.sum(); N2 = N-N1
    g1, g2 = propensity[groups == 1], (propensity[groups == 0])
    # Check if treatment groups got flipped - treatment (coded 1) should be the smaller
    if N1 > N2:
       N1, N2, g1, g2 = N2, N1, g2, g1 
        
        
    # Randomly permute the smaller group to get order for matching
    morder = np.random.permutation(N1)
    matches = pd.Series(np.empty(N1))
    matches[:] = np.NAN
    
    for m in morder:
        dist = abs(g1[m] - g2)
        if dist.min() <= caliper:
            matches[m] = dist.argmin()
            g2 = g2.drop(matches[m])
    return (matches)

    
    
    
    
In [5]:
stuff = Match(data.Treated, data.Propensity)
g1, g2 = data.Propensity[data.Treated==1], data.Propensity[data.Treated==0]
# test ValueError
#badtreat = data.Treated + data.Hispanic
#Match(badtreat, pscore)
stuff[:5]
Out[5]:
0     25
1    213
2     75
3     14
4    210
dtype: float64

Here's the result: if we put the propensity scores of the treatment and matched controls side-by-side, we see that they're matched pretty well.

In [6]:
zip(g1, g2[stuff])
Out[6]:
[(0.42716292681386048, 0.42581880946572531),
 (0.25617646165856756, 0.25484822320097977),
 (0.54874012613984235, 0.54874012613984235),
 (0.37386481070884442, 0.37424224379754362),
 (0.40217167818759586, 0.4025488132226952),
 (0.37341040912703743, 0.37341040912703743),
 (0.53301175942990109, 0.53492670949658028),
 (0.38451503327005687, 0.38489660886065941),
 (0.50914521709464833, 0.5038713907925424),
 (0.61794972732004938, 0.58527783340947481),
 (0.36708091086829236, 0.36693596150764657),
 (0.52310609372279693, 0.52399253077902697),
 (0.37001375859203811, 0.37001375859203811),
 (0.41040938689881862, 0.40263858902000565),
 (0.37295623086366536, 0.37330571977466881),
 (0.36206521830641381, 0.36206521830641381),
 (0.53663014121627661, 0.53887945917796798),
 (0.37046647062835597, 0.37046647062835597),
 (0.57103382954029125, 0.5709591033853374),
 (0.53976237257756132, 0.53976237257756132),
 (0.36543122767419073, 0.36543122767419073),
 (0.59343787140217641, 0.58088215756113837),
 (0.4386969080721812, 0.44247414429559007),
 (0.36753212799132406, 0.36753212799132406),
 (0.35997778168476563, 0.35997778168476563),
 (0.40954977660006303, 0.40954977660006303),
 (0.36963807095849838, 0.36918577756921234),
 (0.26069791214312099, 0.26100872287425286),
 (0.35789562808324343, 0.35789562808324343),
 (0.36753212799132406, 0.36708091086829236),
 (0.35789562808324343, 0.35789562808324343),
 (0.45658813935562875, nan),
 (0.40082653085458858, 0.40082653085458858),
 (0.52624903693537084, nan),
 (0.5375136651601029, 0.5375136651601029),
 (0.56484900646250547, 0.56420807943645801),
 (0.40038494333600216, 0.39999799811496028),
 (0.56561608250077011, 0.56652020912997003),
 (0.46330566204377877, nan),
 (0.37257932252559373, 0.37257932252559373),
 (0.52850447140293089, 0.53075874267386913),
 (0.3969042231346866, 0.39612686819916004),
 (0.36790692445253587, 0.36753212799132406),
 (0.27913948507677838, 0.27881522993582841),
 (0.35915945500472685, 0.35868111759952492),
 (0.40082653085458858, 0.40168028755982904),
 (0.36790692445253587, 0.36790692445253587),
 (0.36288641657881326, 0.36288641657881326),
 (0.40038494333600216, 0.40079387221848178),
 (0.53301175942990109, 0.57005720877929322),
 (0.3913433578168215, 0.39168310001168555),
 (0.41393321094396696, 0.4127054899751052),
 (0.35500470600275558, 0.35500470600275558),
 (0.5375136651601029, 0.57053301873387885),
 (0.4117397433545737, 0.41137419523919089),
 (0.35789562808324343, 0.35789562808324343),
 (0.40567343869825451, 0.40603741546027916),
 (0.47360442597144675, 0.44944907059288414),
 (0.59691150969843831, 0.58464545042501748),
 (0.39304342081844873, 0.39301309024977354),
 (0.36790692445253587, 0.36835856598771649),
 (0.39350085125328266, 0.39350085125328266),
 (0.39612686819916004, 0.39775517139652095),
 (0.37386481070884442, 0.37386481070884442),
 (0.41434922364455296, 0.41302230654596012),
 (0.38619843594707792, 0.38619843594707792),
 (0.37386481070884442, 0.37424224379754362),
 (0.41608090759002542, 0.41309626451423542),
 (0.47723165395085809, 0.47774167074970658),
 (0.54112719501669537, 0.55768626228483698),
 (0.38405562703559115, 0.38451503327005687),
 (0.57901039501721874, 0.57317432336471086),
 (0.38016333753745757, 0.38041199464841013),
 (0.5375136651601029, 0.55545300341207382),
 (0.52810275907963666, 0.55097984688635149),
 (0.39953799040539978, 0.39938108342140327),
 (0.53075874267386913, 0.56180397045598252),
 (0.37129563329740667, 0.37046647062835597),
 (0.40178416325708044, 0.40255931536012274),
 (0.36333543761891568, 0.36333543761891568),
 (0.5375136651601029, 0.5375136651601029),
 (0.5217350446165081, nan),
 (0.59732989436222506, 0.58907212092192596),
 (0.53663014121627661, nan),
 (0.38919010277019828, 0.38919010277019828),
 (0.39744213629889608, 0.39736320517754298),
 (0.26207005517968529, 0.26138340190382492),
 (0.37174898822506325, 0.37174898822506325),
 (0.55984202089567947, nan),
 (0.38489660886065941, 0.38514373934059354),
 (0.41393321094396696, 0.41126954472315269),
 (0.35789562808324343, 0.35789562808324343),
 (0.36963807095849838, 0.36963807095849838),
 (0.52850447140293089, 0.52822407944383731),
 (0.41051397115899052, 0.41051397115899052),
 (0.5687410368098067, 0.56887204494312937),
 (0.36498110138737444, 0.36498110138737444),
 (0.57546114010719607, 0.57669759720271807),
 (0.37469705231202188, 0.37469705231202188),
 (0.27200031650600354, 0.27659164918684331),
 (0.37257932252559373, 0.37257932252559373),
 (0.37681958989925324, 0.37681958989925324),
 (0.36963807095849838, 0.36963807095849838),
 (0.35500470600275558, 0.35500470600275558),
 (0.35707939391258053, 0.35707939391258053),
 (0.41882782270999969, 0.4199216409379678),
 (0.53663014121627661, 0.53663014121627661),
 (0.36288641657881326, 0.36288641657881326),
 (0.38024100782850595, 0.38062096495388931),
 (0.39397088933484964, 0.3917332458962281),
 (0.32878224824302671, 0.32857587970056046),
 (0.24604313298890371, 0.24531002119584097),
 (0.52145634998553481, 0.52341300607717067),
 (0.35215316763290705, 0.35293545509098728),
 (0.55291558624596349, 0.55453051526410713),
 (0.53812498686079091, 0.53760415714737053),
 (0.53165840610880255, nan),
 (0.53178741742330415, 0.53526343055701198),
 (0.36732044941345232, 0.36726529273432967),
 (0.50016320180970986, 0.54233385721831306),
 (0.54568333247369738, 0.54649842890599332),
 (0.37088397590008554, 0.37046647062835597),
 (0.37807518272603041, 0.37811089138712795),
 (0.54505617352098923, 0.54778717953003553),
 (0.37187026316150906, 0.37123230389724349),
 (0.42907939387160898, 0.42621298956406267),
 (0.503576849031942, nan),
 (0.41499892319370452, 0.40530744104465677),
 (0.34867332496059367, 0.34855987148037076),
 (0.35124881591719082, 0.34951011622716577),
 (0.368354966402969, 0.36835856598771649),
 (0.39388148189718369, 0.3939643292973713),
 (0.42287750361949428, 0.42323937900753988),
 (0.27508239938161005, 0.28098466748341883),
 (0.41651421693564455, 0.42037816699807562),
 (0.4167397609523012, 0.41919527728474437),
 (0.46806147084466776, 0.45878647462448441),
 (0.36978946463533807, 0.36963807095849838),
 (0.49379330696466373, nan),
 (0.43357631858871326, 0.43257515170041527),
 (0.34692254178035581, 0.34623057969596582),
 (0.46757669724338247, 0.46341413004896059),
 (0.37685357018163268, 0.37681958989925324),
 (0.4455565269264718, 0.44617327640136217),
 (0.61886657105467635, 0.60079264899057727),
 (0.37718188114200613, 0.37705709921672609),
 (0.40934354192118261, 0.40954977660006303),
 (0.39573595296507968, 0.39432477073157096),
 (0.44255500877946502, 0.42053379991459011),
 (0.36811932082281856, 0.36790692445253587),
 (0.42116861754671925, 0.42103202468255674),
 (0.47650112195239025, 0.47887757800201047),
 (0.39476901876612058, 0.39217707695803344),
 (0.53190717776280616, 0.54227884676845073),
 (0.49379443849344062, 0.45360505623822056),
 (0.4872454286920998, 0.48602144945577064),
 (0.41089388876758753, 0.41048788604880759),
 (0.53019919491977685, 0.52624903693537084),
 (0.42206423034004498, 0.42342902381310754),
 (0.42156230780201276, 0.42182411349268001),
 (0.41195948127964338, 0.41102028893247478),
 (0.41575142347870969, 0.41004419660549557),
 (0.37949264038566599, 0.37894686448575243),
 (0.44597592218322879, 0.44343522976592487),
 (0.54700647957502357, nan),
 (0.46839575084837853, 0.42584389373973836),
 (0.30118504354309478, 0.30266228063350109),
 (0.39535751543292041, 0.39479876377674628),
 (0.46967422150989402, 0.47016706817874493),
 (0.40711739797786478, 0.40603741546027916),
 (0.43016221895202345, 0.40603741546027916),
 (0.59539063185170427, 0.57227359694234925),
 (0.25033235153545463, 0.2506659230345058),
 (0.41813865251776078, 0.40263858902000565),
 (0.2969316019282166, 0.29322046060925033),
 (0.46046038221301638, 0.45526853386815697),
 (0.52687531615177041, 0.54771545246673836),
 (0.50045740817131545, 0.51770231215857909),
 (0.63404475620095613, 0.58689845316771205),
 (0.376835082633315, 0.37604736111069387),
 (0.59996973878777604, 0.57989993766179626),
 (0.47324609795733941, 0.47791911265297649),
 (0.53695349937194692, 0.53707042977012709),
 (0.60108829050619672, nan),
 (0.6729751518671101, 0.63068079328671744)]
In [ ]: