#|hide #|default_exp maker #|export from __future__ import annotations #|export from nbdev.config import * from nbdev.imports import * from fastcore.script import * from fastcore.basics import * from fastcore.imports import * from execnb.nbio import * import ast,contextlib from collections import defaultdict from pprint import pformat from textwrap import TextWrapper #|hide from fastcore.test import * from pdb import set_trace from importlib import reload #|export def find_var(lines, varname): "Find the line numbers where `varname` is defined in `lines`" start = first(i for i,o in enumerate(lines) if o.startswith(varname)) if start is None: return None,None empty = ' ','\t' if start==len(lines)-1 or lines[start+1][:1] not in empty: return start,start+1 end = first(i for i,o in enumerate(lines[start+1:]) if o[:1] not in empty) return start,len(lines) if end is None else (end+start+1) t = '''a_=(1, 2, 3) b_=3''' test_eq(find_var(t.splitlines(), 'a_'), (0,3)) test_eq(find_var(t.splitlines(), 'b_'), (4,5)) #|export def read_var(code, varname): "Eval and return the value of `varname` defined in `code`" lines = code.splitlines() start,end = find_var(lines, varname) if start is None: return None res = [lines[start].split('=')[-1].strip()] res += lines[start+1:end] try: return eval('\n'.join(res)) except SyntaxError: raise Exception('\n'.join(res)) from None test_eq(read_var(t, 'a_'), (1,2,3)) test_eq(read_var(t, 'b_'), 3) #|export def update_var(varname, func, fn=None, code=None): "Update the definition of `varname` in file `fn`, by calling `func` with the current definition" if fn: fn = Path(fn) code = fn.read_text() lines = code.splitlines() v = read_var(code, varname) res = func(v) start,end = find_var(lines, varname) del(lines[start:end]) lines.insert(start, f"{varname} = {res}") code = '\n'.join(lines) if fn: fn.write_text(code) else: return code g = exec_new(t) test_eq((g['a_'],g['b_']), ((1,2,3),3)) t2 = update_var('a_', lambda o:0, code=t) exec(t2, g) 
test_eq((g['a_'],g['b_']), (0,3))
t3 = update_var('b_', lambda o:0, code=t)
exec(t3, g)
test_eq((g['a_'],g['b_']), ((1,2,3),0))

#|export
class ModuleMaker:
    "Helper class to create exported library from notebook source cells"
    def __init__(self, dest, name, nb_path, is_new=True, parse=True):
        dest,nb_path = Path(dest),Path(nb_path)
        # presumably stores the init args as attributes (`self.is_new`/`self.parse` are read in `make`)
        store_attr()
        # Dotted module name -> path of the `.py` file under `dest`
        self.fname = dest/(name.replace('.','/') + ".py")
        if is_new: dest.mkdir(parents=True, exist_ok=True)
        else: assert self.fname.exists(), f"{self.fname} does not exist"
        # Notebook path relative to `dest`; used in the file banner and in each cell header
        self.dest2nb = nb_path.relpath(dest)
        self.hdr = f"# %% {self.dest2nb}"

mm = ModuleMaker(dest='tmp', name='test.testing', nb_path=Path.cwd()/'01_export.ipynb', is_new=True)
mm.fname

#|export
def decor_id(d):
    "`id` attr of decorator, regardless of whether called as function or bare"
    # Bare decorator nodes carry `id` directly; called ones (`@dec(...)`) carry it on `func`
    return d.id if hasattr(d, 'id') else nested_attr(d, 'func.id', '')

#|export
_def_types = ast.FunctionDef,ast.AsyncFunctionDef,ast.ClassDef
_assign_types = ast.AnnAssign, ast.Assign, ast.AugAssign

def _val_or_id(it):
    "Elements of the list/tuple assigned in `it`: the constant's value, else the symbol's `id`"
    if sys.version_info < (3,8): return [getattr(o, 's', getattr(o, 'id', None)) for o in it.value.elts]
    else:return [getattr(o, 'value', getattr(o, 'id', None)) for o in it.value.elts]
def _all_targets(a):
    "Flatten a destructuring target (`a,b = ...`) into its elements"
    return L(getattr(a,'elts',a))
def _targets(o):
    "Targets of assignment `o` as a list: `AnnAssign`/`AugAssign` have a single `target`, `Assign` has `targets`"
    return [o.target] if isinstance(o, (ast.AnnAssign,ast.AugAssign)) else o.targets
def _filt_dec(x): return decor_id(x).startswith('patch')
def _wants(o):
    "True for function/class definitions that are not decorated with a `patch*` decorator"
    return isinstance(o,_def_types) and not any(L(o.decorator_list).filter(_filt_dec))

#|export
def retr_exports(trees):
    "Names exported by the top-level nodes in `trees`: public defs and assignment targets, plus anything in `_all_`"
    # include anything mentioned in "_all_", even if otherwise private
    # NB: "_all_" can include strings (names), or symbols, so we look for "id" or "value"
    assigns = trees.filter(risinstance(_assign_types))
    # `_targets` is used rather than `.targets` since annotated/augmented assignments only have `.target`
    all_assigns = assigns.filter(lambda o: getattr(_targets(o)[0],'id',None)=='_all_')
    all_vals = all_assigns.map(_val_or_id).concat()
    syms = trees.filter(_wants).attrgot('name')

    # assignment targets (NB: can be multiple, e.g. "a=b=c", and/or destructuring e.g "a,b=(1,2)")
    assign_targs = L(L(_targets(assn)).map(_all_targets).concat() for assn in assigns).concat()
    # Drop missing names and private (underscore-prefixed) names
    exports = (assign_targs.attrgot('id')+syms).filter(lambda o: o and o[0]!='_')
    return (exports+all_vals).unique()

#|export
@patch
def make_all(self:ModuleMaker, cells):
    "Create `__all__` with all exports in `cells`"
    if cells is None: return ''
    return retr_exports(L(cells).map(NbCell.parsed_).concat())

#|export
def make_code_cells(*ss):
    "Build notebook cells from the sources `ss`"
    return dict2nb({'cells':L(ss).map(mk_cell)}).cells

nb = make_code_cells("from __future__ import print_function", "def a():...", "def b():...",
                     "c=d=1", "_f=1", "_g=1", "_h=1", "_all_=['_g', _h]", "@patch\ndef h(self:ca):...")
test_eq(set(mm.make_all(nb)), set(['a','b','c','d', '_g', '_h']))

#|export
def relative_import(name, fname, level=0):
    "Convert a module `name` to a name relative to `fname`"
    # Only absolute imports are handled
    assert not level
    sname = name.replace('.','/')
    # Nothing in common with `fname`: keep the absolute name
    if not(os.path.commonpath([sname,fname])): return name
    rel = os.path.relpath(sname, fname)
    if rel==".": return "."
    # Each `../` step up becomes one extra leading dot
    res = rel.replace(f"..{os.path.sep}", ".")
    # If anything other than dots remains, a module name follows, so prefix the leading dot
    if not all(o=='.' for o in res): res='.'+res
    return res.replace(os.path.sep, ".")

test_eq(relative_import('nbdev.core', "xyz"), 'nbdev.core')
test_eq(relative_import('nbdev.core', 'nbdev'), '.core')
_p = Path('fastai')
test_eq(relative_import('fastai.core', _p/'vision'), '..core')
test_eq(relative_import('fastai.core', _p/'vision/transform'), '...core')
test_eq(relative_import('fastai.vision.transform', _p/'vision'), '.transform')
test_eq(relative_import('fastai.notebook.core', _p/'data'), '..notebook.core')
test_eq(relative_import('fastai.vision', _p/'vision'), '.')
test_eq(relative_import('fastai', _p), '.')
test_eq(relative_import('fastai', _p/'vision'), '..')
test_eq(relative_import('fastai', _p/'vision/transform'), '...')

#|export
# Based on https://github.com/thonny/thonny/blob/master/thonny/ast_utils.py
def _mark_text_ranges(
    source: str|bytes, # Source code to add ranges to
):
    "Adds `end_lineno` and `end_col_offset` to each `node` recursively. Used for Python 3.7 compatibility"
    from asttokens.asttokens import ASTTokens

    # We need to reparse the source to get a full tree to walk
    root = ast.parse(source)
    ASTTokens(source, tree=root)
    for child in ast.walk(root):
        if hasattr(child,"last_token"):
            child.end_lineno,child.end_col_offset = child.last_token.end

        # Some tokens stay without end info
        if hasattr(child,"lineno") and (not hasattrs(child, ["end_lineno","end_col_offset"])):
            child.end_lineno, child.end_col_offset = child.lineno, child.col_offset+2

    return root.body

#|export
def update_import(source, tree, libname, f=relative_import):
    "Rewrite the module of each top-level `from ... import` in `source` using `f`; returns the new lines, or `None` if nothing to do"
    if not tree: return
    # Before Python 3.8 the nodes lack end positions, so re-parse and add them
    if sys.version_info < (3,8): tree = _mark_text_ranges(source)
    imps = L(tree).filter(risinstance(ast.ImportFrom))
    if not imps: return
    src = source.splitlines(True)
    for imp in imps:
        nmod = f(imp.module, libname, imp.level)
        lin = imp.lineno-1
        # Text of the import statement on its first line
        sec = src[lin][imp.col_offset:imp.end_col_offset]
        newsec = re.sub(f"(from +){'.'*imp.level}{imp.module}", fr"\1{nmod}", sec)
        src[lin] = src[lin].replace(sec,newsec)
    return src

@patch
def import2relative(cell:NbCell, libname):
    "Update the source of `cell` in place with the imports rewritten by `update_import`"
    src = update_import(cell.source, cell.parsed_(), libname)
    if src: cell.set_source(src)

ss = "from nbdev.export import *\nfrom nbdev.a.b import *"
cell = make_code_cells([ss])[0]
cell.import2relative('nbdev')
test_eq(cell.source, 'from .export import *\nfrom .a.b import *')

cell = make_code_cells([ss])[0]
cell.import2relative('nbdev/a')
test_eq(cell.source, 'from ..export import *\nfrom .b import *')

#|export
@patch
def _last_future(self:ModuleMaker, cells):
    "Returns the location of a `__future__` in `cells`"
    trees = cells.map(NbCell.parsed_)
    # One past the index of the last cell with a `from __future__ import ...`; `max` of nothing -> 0
    try: return max(i for i,tree in enumerate(trees) if tree and any(
         isinstance(t,ast.ImportFrom) and t.module=='__future__' for t in tree))+1
    except ValueError: return 0

#|export
def _import2relative(cells, lib_name=None):
    "Converts `cells` to use `import2relative` based on `lib_name`"
    if lib_name is None: lib_name = get_config().lib_name
    for cell in cells: cell.import2relative(lib_name)

#|export
@patch
def make(self:ModuleMaker, cells, all_cells=None, lib_path=None):
    "Write module containing `cells` with `__all__` generated from `all_cells`"
    if all_cells is None: all_cells = cells
    cells,all_cells = L(cells),L(all_cells)
    if self.parse:
        if not lib_path: lib_path = get_config().path('lib_path')
        # Directory of the module relative to the parent of the library dir
        mod_dir = os.path.relpath(self.fname.parent, Path(lib_path).parent)
        _import2relative(all_cells, mod_dir)
    # Existing module: append to it instead of creating it
    if not self.is_new: return self._make_exists(cells, all_cells)

    self.fname.parent.mkdir(exist_ok=True, parents=True)
    last_future = 0
    if self.parse:
        _all = self.make_all(all_cells)
        last_future = self._last_future(cells) if len(all_cells)>0 else 0
        # Wrap `__all__` at 120 chars; continuation lines are indented by 11 spaces
        tw = TextWrapper(width=120, initial_indent='', subsequent_indent=' '*11, break_long_words=False)
        all_str = '\n'.join(tw.wrap(str(_all)))
    with self.fname.open('w') as f:
        f.write(f"# AUTOGENERATED! DO NOT EDIT! File to edit: {self.dest2nb}.")
        # Cells up to the last `__future__` import are written before `__all__`
        if last_future > 0: write_cells(cells[:last_future], self.hdr, f)
        if self.parse: f.write(f"\n\n# %% auto 0\n__all__ = {all_str}")
        # NOTE(review): the last argument is presumably a cell-numbering offset -- confirm against `write_cells`
        write_cells(cells[last_future:], self.hdr, f, 1 if last_future>0 else 0)
        f.write('\n')

cells = make_code_cells("from __future__ import print_function", "#|export\ndef a(): ...", "def b(): ...")
mm.make(cells, L([cells[1]]))
show_src(Path('tmp/test/testing.py').read_text())

am = ModuleMaker(dest='tmp', name='test.testing_noall', nb_path=Path.cwd()/'01_export.ipynb', is_new=True, parse=False)
am.fname

cells = make_code_cells("from __future__ import print_function", "#|export\ndef a(): ...", "#|export\nclass A:")
am.make(cells)
show_src(Path('tmp/test/testing_noall.py').read_text())

#|export
@patch
def _update_all(self:ModuleMaker, all_cells, alls):
    "Formatted `__all__` made of the existing `alls` plus the exports of `all_cells`"
    return pformat(alls + self.make_all(all_cells), width=160)

@patch
def _make_exists(self:ModuleMaker, cells, all_cells=None):
    "`make` for `is_new=False`"
    # Extend `__all__` in the existing file, then append the new cells to it
    if all_cells and self.parse:
        update_var('__all__', partial(self._update_all, all_cells), fn=self.fname)
    with self.fname.open('a') as f: write_cells(cells, self.hdr, f)

c2 = make_code_cells("def c(): ...", "def d(): ...")
mm = ModuleMaker(dest='tmp', name='test.testing', nb_path=Path.cwd()/'01_export.ipynb', is_new=False)
mm.make(c2, c2)
show_src(Path('tmp/test/testing.py').read_text())

g = exec_import('.tmp.test.testing', '*')
for s in "a c d".split(): assert s in g, s
assert 'b' not in g
assert g['a']() is None

#|export
def basic_export_nb2(fname, name, dest=None):
    "A basic exporter to bootstrap nbdev using `ModuleMaker`"
    if dest is None: dest = get_config().path('lib_path')
    # Only cells whose source starts with an `#|export` directive are exported
    cells = L(c for c in read_nb(fname).cells if re.match(r'#\|\s*export', c.source))
    ModuleMaker(dest=dest, name=name, nb_path=fname).make(cells)

#|eval: false
path = Path('../nbdev')
(path/'read.py').unlink(missing_ok=True)
(path/'maker.py').unlink(missing_ok=True)

add_init(path)
cfg = get_config()

basic_export_nb2('01_read.ipynb', 'read')
# Export this notebook's own module
basic_export_nb2('02_maker.ipynb', 'maker')

# Import the freshly exported module and check that `ModuleMaker` is defined and listed in `__all__`
g = exec_import('nbdev', 'maker')
assert g['maker'].ModuleMaker
assert 'ModuleMaker' in g['maker'].__all__