#|default_exp test #|export from fastcore.imports import * from collections import Counter from contextlib import redirect_stdout #|hide from nbdev.showdoc import * from fastcore.nb_imports import * #|export def test_fail(f, msg='', contains='', args=None, kwargs=None): "Fails with `msg` unless `f()` raises an exception and (optionally) has `contains` in `e.args`" args, kwargs = args or [], kwargs or {} try: f(*args, **kwargs) except Exception as e: assert not contains or contains in str(e) return assert False,f"Expected exception but none raised. {msg}" def _fail(): raise Exception("foobar") test_fail(_fail, contains="foo") def _fail(): raise Exception() test_fail(_fail) def _fail_args(a): if a == 5: raise ValueError test_fail(_fail_args, args=(5,)) test_fail(_fail_args, kwargs=dict(a=5)) #|export def test(a, b, cmp, cname=None): "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails" if cname is None: cname=cmp.__name__ assert cmp(a,b),f"{cname}:\n{a}\n{b}" test([1,2],[1,2], operator.eq) test_fail(lambda: test([1,2],[1], operator.eq)) test([1,2],[1], operator.ne) test_fail(lambda: test([1,2],[1,2], operator.ne)) show_doc(all_equal) test(['abc'], ['abc'], all_equal) test_fail(lambda: test(['abc'],['cab'], all_equal)) show_doc(equals) test([['abc'],['a']], [['abc'],['a']], equals) test([['abc'],['a'],'b', [['x']]], [['abc'],['a'],'b', [['x']]], equals) # supports any depth and nested structure #|export def nequals(a,b): "Compares `a` and `b` for `not equals`" return not equals(a,b) test(['abc'], ['ab' ], nequals) #|export def test_eq(a,b): "`test` that `a==b`" test(a,b,equals, cname='==') test_eq([1,2],[1,2]) test_eq([1,2],map(int,[1,2])) test_eq(array([1,2]),array([1,2])) test_eq(array([1,2]),array([1,2])) test_eq([array([1,2]),3],[array([1,2]),3]) test_eq(dict(a=1,b=2), dict(b=2,a=1)) test_fail(lambda: test_eq([1,2], 1), contains="==") test_fail(lambda: test_eq(None, np.array([1,2])), contains="==") test_eq({'a', 'b', 'c'}, {'c', 'a', 
'b'}) #|hide import pandas as pd import torch df1 = pd.DataFrame(dict(a=[1,2],b=['a','b'])) df2 = pd.DataFrame(dict(a=[1,2],b=['a','b'])) df3 = pd.DataFrame(dict(a=[1,2],b=['a','c'])) test_eq(df1,df2) test_eq(df1.a,df2.a) test_fail(lambda: test_eq(df1,df3), contains='==') class T(pd.Series): pass test_eq(df1.iloc[0], T(df2.iloc[0])) # works with subclasses test_eq(torch.zeros(10), torch.zeros(10, dtype=torch.float64)) test_eq(torch.zeros(10), torch.ones(10)-1) test_fail(lambda:test_eq(torch.zeros(10), torch.ones(1, 10)), contains='==') test_eq(torch.zeros(3), [0,0,0]) #|export def test_eq_type(a,b): "`test` that `a==b` and are same type" test_eq(a,b) test_eq(type(a),type(b)) if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b)) test_eq_type(1,1) test_fail(lambda: test_eq_type(1,1.)) test_eq_type([1,1],[1,1]) test_fail(lambda: test_eq_type([1,1],(1,1))) test_fail(lambda: test_eq_type([1,1],[1,1.])) #|export def test_ne(a,b): "`test` that `a!=b`" test(a,b,nequals,'!=') test_ne([1,2],[1]) test_ne([1,2],[1,3]) test_ne(array([1,2]),array([1,1])) test_ne(array([1,2]),array([1,1])) test_ne([array([1,2]),3],[array([1,2])]) test_ne([3,4],array([3])) test_ne([3,4],array([3,5])) test_ne(dict(a=1,b=2), ['a', 'b']) test_ne(['a', 'b'], dict(a=1,b=2)) #|export def is_close(a,b,eps=1e-5): "Is `a` within `eps` of `b`" if hasattr(a, '__array__') or hasattr(b,'__array__'): return (abs(a-b) 0 else '') test_stdout(lambda: print('hi'), 'hi') test_fail(lambda: test_stdout(lambda: print('hi'), 'ho')) test_stdout(lambda: 1+1, '') test_stdout(lambda: print('hi there!'), r'^hi.*!$', regex=True) #|export def test_warns(f, show=False): with warnings.catch_warnings(record=True) as w: f() assert w, "No warnings raised" if show: for e in w: print(f"{e.category}: {e.message}") test_warns(lambda: warnings.warn("Oh no!")) test_fail(lambda: test_warns(lambda: 2+2), contains='No warnings raised') test_warns(lambda: warnings.warn("Oh no!"), show=True) #|export TEST_IMAGE = 'images/puppy.jpg' 
im = Image.open(TEST_IMAGE).resize((128,128))
im

#|export
TEST_IMAGE_BW = 'images/mnist3.png'

im = Image.open(TEST_IMAGE_BW).resize((128,128))
im

#|export
def test_fig_exists(ax):
    "Assert that `ax` is set and that its figure's canvas renders to a non-empty ARGB buffer"
    assert ax and len(ax.figure.canvas.tostring_argb()) > 0

fig, ax = plt.subplots()
ax.imshow(array(im))
test_fig_exists(ax)

#|export
class ExceptionExpected:
    "Context manager checking that its body raises `ex` (and, when `regex` is given, that it matches the exception's args)"
    def __init__(self, ex=Exception, regex=''):
        self.ex = ex
        self.regex = regex

    def __enter__(self): pass

    def __exit__(self, type, value, traceback):
        # `value` is None when the body finished without raising, so the isinstance check fails.
        ok = isinstance(value, self.ex)
        if ok and self.regex:
            # The regex is searched in the formatted text of the exception's `args` tuple.
            ok = re.search(self.regex, f'{value.args}') is not None
        if not ok:
            raise TypeError(f"Expected {self.ex.__name__}({self.regex}) not raised.")
        # Returning True swallows the expected exception.
        return True

def _tst_1():
    assert False, "This is a test"

def _tst_2():
    raise SyntaxError

with ExceptionExpected():
    _tst_1()
with ExceptionExpected(ex=AssertionError, regex="This is a test"):
    _tst_1()
with ExceptionExpected(ex=SyntaxError):
    _tst_2()

#|export
exception = ExceptionExpected()

with exception:
    _tst_1()

#|hide
def _f():
    with ExceptionExpected():
        1
test_fail(partial(_f))

def _f():
    with ExceptionExpected(SyntaxError):
        assert False
test_fail(partial(_f))

def _f():
    with ExceptionExpected(AssertionError, "Yes"):
        assert False, "No"
test_fail(partial(_f))

#|hide
#|eval: false
from nbdev import nbdev_export
nbdev_export()