#|hide
#|default_exp maker
#|export
from __future__ import annotations
Create one or more modules from selected notebook cells
#|export
from nbdev.read import *
from nbdev.imports import *
from fastcore.script import *
from fastcore.basics import *
from fastcore.imports import *
from execnb.nbio import *
import ast,contextlib
from collections import defaultdict
from pprint import pformat
from textwrap import TextWrapper
#|hide
from fastcore.test import *
from pdb import set_trace
from importlib import reload
These functions let us find and modify the definitions of variables in Python modules.
#|export
def find_var(lines, varname):
    """Find the line numbers where `varname` is defined in `lines`

    Returns a half-open `(start, end)` pair of indices into `lines`, or `(None, None)` if no line
    begins with `varname` as a whole name. The definition runs from its first line up to (but not
    including) the next line that does not start with a space or tab.
    """
    def _defines(line):
        # A plain prefix test would let `a_` match `a_b = 1`, so also require that the
        # character following the name cannot continue an identifier.
        if not line.startswith(varname): return False
        nxt = line[len(varname):len(varname)+1]
        return not (nxt.isalnum() or nxt == '_')

    start = next((i for i,o in enumerate(lines) if _defines(o)), None)
    if start is None: return None,None
    empty = ' ','\t'
    # Indented lines continue the definition; it ends at the first unindented (or blank) line,
    # or at the end of `lines` if there is none.
    end = next((i for i in range(start+1, len(lines)) if lines[i][:1] not in empty), len(lines))
    return start,end
# Fixture: `a_` is assigned a tuple written across several lines, followed by a one-line assignment to `b_`
t = '''a_=(1,
2,
3)
b_=3'''
# `find_var` returns a (start, end) pair of line indices where `end` is exclusive
test_eq(find_var(t.splitlines(), 'a_'), (0,3))
test_eq(find_var(t.splitlines(), 'b_'), (4,5))
#|export
def read_var(code, varname):
    """Eval and return the value of `varname` defined in `code`

    Returns `None` when `varname` is not defined in `code`. Raises `Exception` (carrying the
    offending source text) when the definition cannot be parsed as a statement with a value.
    """
    lines = code.splitlines()
    start,end = find_var(lines, varname)
    if start is None: return None
    stmt = '\n'.join(lines[start:end])
    # Let the parser locate the right-hand side instead of splitting the text on `=`:
    # text splitting mangles values that themselves contain `=` (e.g. `dict(x=1)` or `'k=v'`).
    try: node = ast.parse(stmt).body[0]
    except SyntaxError: raise Exception(stmt) from None
    value = getattr(node, 'value', None)
    if value is None: raise Exception(stmt)
    # NOTE(review): this evaluates source text - only call it on trusted code.
    return eval(compile(ast.Expression(value), '<read_var>', 'eval'))
# `read_var` evaluates the definition, so Python values (not source text) come back
test_eq(read_var(t, 'a_'), (1,2,3))
test_eq(read_var(t, 'b_'), 3)
#|export
def update_var(varname, func, fn=None, code=None):
    """Update the definition of `varname` in file `fn`, by calling `func` with the current definition

    The source is read from `fn` when given, otherwise from `code`. `func` receives the current
    value and its result is interpolated into `{varname} = {result}`, which replaces the whole
    existing definition. With `fn` the file is rewritten in place and nothing is returned;
    otherwise the updated source is returned. Raises `ValueError` if `varname` is not defined.
    """
    if fn:
        fn = Path(fn)
        code = fn.read_text()
    lines = code.splitlines()
    start,end = find_var(lines, varname)
    # Fail early and clearly: with `start=None` the slice below would address the *entire* list.
    if start is None: raise ValueError(f"`{varname}` is not defined in {fn or 'the given code'}")
    res = func(read_var(code, varname))
    # Swap the (possibly multi-line) old definition for the single-line new one
    lines[start:end] = [f"{varname} = {res}"]
    code = '\n'.join(lines)
    if fn: fn.write_text(code)
    else: return code
# Run the fixture and check the starting values in the resulting namespace `g`
g = exec_new(t)
test_eq((g['a_'],g['b_']), ((1,2,3),3))
# Replace only `a_` (a multi-line definition); `b_` must keep its value
t2 = update_var('a_', lambda o:0, code=t)
exec(t2, g)
test_eq((g['a_'],g['b_']), (0,3))
# Replace only `b_`, starting again from the unmodified fixture `t`
t3 = update_var('b_', lambda o:0, code=t)
exec(t3, g)
test_eq((g['a_'],g['b_']), ((1,2,3),0))
#|export
class ModuleMaker:
    "Helper class to create exported library from notebook source cells"
    def __init__(self, dest, name, nb_path, is_new=True, parse=True):
        # Convert to `Path` first so that `store_attr` records the converted values
        dest = Path(dest)
        nb_path = Path(nb_path)
        store_attr()
        # Dotted module name -> nested path below `dest`
        self.fname = dest/(name.replace('.','/') + ".py")
        if not is_new: assert self.fname.exists(), f"{self.fname} does not exist"
        else: dest.mkdir(parents=True, exist_ok=True)
        # Notebook location as seen from `dest`; used in the per-cell header comment
        self.dest2nb = nb_path.relpath(dest)
        self.hdr = f"# %% {self.dest2nb}"
In order to export a notebook, we need a way to create a Python file. `ModuleMaker` fills that role. Pass in the directory where you want the module created, the name of the module, the path of the notebook source, and set `is_new` to `True` if this is a new file being created (rather than an existing file being added to). The location of the saved module will be in `fname`. Finally, if the source in the notebooks should not be parsed by Python (such as partial class declarations in cells), `parse` should be set to `False`.
Note: If doing so, then the `__all__` generation will be turned off as well.
# The dotted module name `test.testing` becomes a nested path below `dest`
mm = ModuleMaker(dest='tmp', name='test.testing', nb_path=Path.cwd()/'01_export.ipynb', is_new=True)
mm.fname
Path('tmp/test/testing.py')
#|export
def decor_id(d):
    "`id` attr of decorator, regardless of whether called as function or bare"
    # Bare form (`@dec`): the node itself carries `id`
    if hasattr(d, 'id'): return d.id
    # Called form (`@dec(...)`): look on the node's `func`; '' when that has no `id` either
    return nested_attr(d, 'func.id', '')
#|export
_def_types = ast.FunctionDef,ast.AsyncFunctionDef,ast.ClassDef
_assign_types = ast.AnnAssign, ast.Assign, ast.AugAssign
def _val_or_id(it):
if sys.version_info < (3,8): return [getattr(o, 's', None) for o in it.value.elts]
else:return [getattr(o, 'value', getattr(o, 'id', None)) for o in it.value.elts]
def _all_targets(a):
    "Targets of `a` as an `L`: its `elts` when it has them (destructuring), otherwise `a` itself"
    return L(getattr(a, 'elts', a))

def _filt_dec(x):
    "Whether the id of decorator `x` starts with `patch`"
    return decor_id(x).startswith('patch')

def _wants(o):
    "Whether `o` is a function/class definition carrying no `patch`-style decorator"
    if not isinstance(o, _def_types): return False
    patched = L(o.decorator_list).filter(_filt_dec)
    return not any(patched)
#|export
def _targets(o):
    "Targets of assignment `o` as a list (`AnnAssign`/`AugAssign` have one `target`, `Assign` has `targets`)"
    return [o.target] if isinstance(o, (ast.AnnAssign, ast.AugAssign)) else o.targets

def retr_exports(trees):
    """Names to export from the top-level statements in `trees`

    Collects public (non-underscore) assignment targets and function/class names (skipping
    `patch`-decorated definitions), plus everything listed in `_all_`. Returns the unique names.
    """
    # include anything mentioned in "_all_", even if otherwise private
    # NB: "_all_" can include strings (names), or symbols, so we look for "id" or "value"
    assigns = trees.filter(risinstance(_assign_types))
    # A bare annotation (`x: int`) binds nothing, so it must not end up in `__all__`
    assigns = assigns.filter(lambda o: not (isinstance(o, ast.AnnAssign) and o.value is None))
    all_assigns = assigns.filter(lambda o: getattr(_targets(o)[0],'id',None)=='_all_')
    all_vals = all_assigns.map(_val_or_id).concat()
    syms = trees.filter(_wants).attrgot('name')
    # assignment targets (NB: can be multiple, e.g. "a=b=c", and/or destructuring e.g "a,b=(1,2)")
    assign_targs = L(L(_targets(assn)).map(_all_targets).concat() for assn in assigns).concat()
    exports = (assign_targs.attrgot('id')+syms).filter(lambda o: o and o[0]!='_')
    return (exports+all_vals).unique()
#|export
@patch
def make_all(self:ModuleMaker, cells):
    "Create `__all__` with all exports in `cells`"
    if cells is not None:
        # Parse every cell and flatten the per-cell statement lists into a single list
        stmts = cells.map(NbCell.parsed_).concat()
        return retr_exports(stmts)
    return ''
#|export
def make_code_cell(code):
    "A code cell with `code` as its source and no execution count"
    return AttrDict(source=code, cell_type="code", execution_count=None)

def make_code_cells(*ss):
    "The cells of a notebook built from one code cell per item of `ss`"
    nb = dict2nb({'cells': L(ss).map(make_code_cell)})
    return nb.cells
We want to add an `__all__` to the top of the exported module. This method autogenerates it from all code in `cells`.
# `_f` is private and dropped; `_g` is private but named in `_all_`, so it is kept;
# `h` is `@patch`-decorated and therefore not exported; `c=d=1` contributes both targets
nb = make_code_cells("from __future__ import print_function", "def a():...", "def b():...",
"c=d=1", "_f=1", "_g=1", "_all_=['_g']", "@patch\ndef h(self:ca):...")
test_eq(set(mm.make_all(nb)), set(['a','b','c','d', '_g']))
#|export
def relative_import(name, fname, level=0):
    """Convert a module `name` to a name relative to `fname`

    `name` is an absolute dotted module name and `fname` the path of the package directory the
    import is written from. Returns the dotted relative form (e.g. `..core`), or `name` unchanged
    when the two share no leading path component. `level` must be 0: already-relative imports
    are not supported.
    """
    assert not level
    sname = name.replace('.','/')
    if not(os.path.commonpath([sname,fname])): return name
    rel = os.path.relpath(sname, fname)
    if rel==".": return "."
    parts = rel.split(os.path.sep)
    # Count the leading `..` components. One dot means "this package", and each step up adds one
    # more - this also holds when `rel` consists *only* of `..` components (target is an ancestor).
    n_up = 0
    while n_up < len(parts) and parts[n_up] == os.path.pardir: n_up += 1
    return "."*(n_up+1) + ".".join(parts[n_up:])
# No shared leading path component: the name is returned unchanged
test_eq(relative_import('nbdev.core', "xyz"), 'nbdev.core')
# Same package: a single leading dot
test_eq(relative_import('nbdev.core', 'nbdev'), '.core')
_p = Path('fastai')
# Each directory level to climb adds one more leading dot
test_eq(relative_import('fastai.core', _p/'vision'), '..core')
test_eq(relative_import('fastai.core', _p/'vision/transform'), '...core')
test_eq(relative_import('fastai.vision.transform', _p/'vision'), '.transform')
test_eq(relative_import('fastai.notebook.core', _p/'data'), '..notebook.core')
# Module and directory coincide
test_eq(relative_import('fastai.vision', _p/'vision'), '.')
#|export
# Based on https://github.com/thonny/thonny/blob/master/thonny/ast_utils.py
def _mark_text_ranges(
    source: str|bytes, # Source code to add ranges to
):
    "Adds `end_lineno` and `end_col_offset` to each `node` recursively. Used for Python 3.7 compatibility"
    # Imported inside the function - presumably so `asttokens` is only required when this path runs
    from asttokens.asttokens import ASTTokens
    # We need to reparse the source to get a full tree to walk
    root = ast.parse(source)
    # NOTE(review): `ASTTokens(..., tree=root)` is expected to attach the `last_token` read below - confirm
    ASTTokens(source, tree=root)
    for child in ast.walk(root):
        if hasattr(child,"last_token"):
            # `last_token.end` unpacks into the end line and end column
            child.end_lineno,child.end_col_offset = child.last_token.end
        # Some tokens stay without end info
        if hasattr(child,"lineno") and (not hasattrs(child, ["end_lineno","end_col_offset"])):
            # Fallback: same line as the start, two columns past the start column
            child.end_lineno, child.end_col_offset = child.lineno, child.col_offset+2
    # Return the top-level statements rather than the module node
    return root.body
#|export
def update_import(source, tree, libname, f=relative_import):
    """Rewrite the top-level `from ... import` statements of `source` using `f`

    `tree` is the list of parsed top-level statements of `source`. For every `ImportFrom` node,
    `f(module, libname, level)` supplies the replacement module name. Returns `source` as a list
    of lines (line endings kept) with the statements rewritten, or `None` when `tree` is empty or
    contains no `from` imports.
    """
    if not tree: return
    if sys.version_info < (3,8): tree = _mark_text_ranges(source)
    imps = L(tree).filter(risinstance(ast.ImportFrom))
    if not imps: return
    src = source.splitlines(True)
    # Work from the last statement to the first, so that rewriting one statement never shifts the
    # column offset of another statement on the same line that is still to be processed.
    for imp in list(imps)[::-1]:
        nmod = f(imp.module, libname, imp.level)
        lin = imp.lineno-1
        # Only the statement's *first* line holds the module name; slicing it with the end column
        # (which belongs to the last line of a multi-line import) would cut the name off.
        head,stmt = src[lin][:imp.col_offset],src[lin][imp.col_offset:]
        # Escape the name: the dots in a module path are literal, not regex wildcards
        pat = "^(from +)" + re.escape('.'*imp.level + (imp.module or ''))
        # A function replacement keeps `nmod` from being read as part of a group reference
        src[lin] = head + re.sub(pat, lambda m: m.group(1)+nmod, stmt, count=1)
    return src
@patch
def import2relative(cell:NbCell, libname):
    "Replace the cell's source with the result of `update_import` for `libname`, if it produced one"
    new_src = update_import(cell.source, cell.parsed_(), libname)
    # `update_import` gives nothing back when there is nothing to rewrite; leave the cell alone then
    if not new_src: return
    cell.set_source(new_src)
# Two absolute imports from the same library, at different package depths
ss = "from nbdev.export import *\nfrom nbdev.a.b import *"
cell = make_code_cells([ss])[0]
# Relative to the library root, both become single-dot imports
cell.import2relative('nbdev')
test_eq(cell.source, 'from .export import *\nfrom .a.b import *')
# Use a fresh cell, since `import2relative` modifies the cell's source
cell = make_code_cells([ss])[0]
# Relative to the `a` sub-package, `export` sits one level up and `b` is a sibling
cell.import2relative('nbdev/a')
test_eq(cell.source, 'from ..export import *\nfrom .b import *')
#|export
@patch
def _last_future(self:ModuleMaker, cells):
    "Returns the location of a `__future__` in `cells`"
    def _has_future(tree):
        # Unparsed/empty cells never count
        if not tree: return False
        return any(isinstance(t,ast.ImportFrom) and t.module=='__future__' for t in tree)

    trees = cells.map(NbCell.parsed_)
    hits = [i for i,tree in enumerate(trees) if _has_future(tree)]
    # One past the last cell with a `__future__` import; 0 when no cell has one
    return hits[-1]+1 if hits else 0
#|export
def _import2relative(cells, lib_name=None):
    "Converts `cells` to use `import2relative` based on `lib_name`"
    # Fall back to the configured library name only when none was passed
    target = get_config().lib_name if lib_name is None else lib_name
    for c in cells:
        c.import2relative(target)
#|export
@patch
def make(self:ModuleMaker, cells, all_cells=None, lib_name=None):
    "Write module containing `cells` with `__all__` generated from `all_cells`"
    # NOTE(review): `lib_name` is accepted but never read in this body
    if all_cells is None: all_cells = cells
    if self.parse:
        # Imports are made relative to the module's directory, expressed from the parent of the configured `lib_path`
        libnm = get_config().path('lib_path')
        mod_dir = os.path.relpath(self.fname.parent, libnm.parent)
        _import2relative(all_cells, mod_dir)
    # Existing module: append to it instead of writing a new file
    if not self.is_new: return self._make_exists(cells, all_cells)
    self.fname.parent.mkdir(exist_ok=True, parents=True)
    last_future = 0
    if self.parse:
        _all = self.make_all(all_cells)
        # Cells up to `last_future` are written before `__all__`
        last_future = self._last_future(cells) if len(all_cells)>0 else 0
        # Wrap long export lists; continuation lines are indented by 11 spaces, the width of `__all__ = [`
        tw = TextWrapper(width=120, initial_indent='', subsequent_indent=' '*11, break_long_words=False)
        all_str = '\n'.join(tw.wrap(str(_all)))
    with self.fname.open('w') as f:
        f.write(f"# AUTOGENERATED! DO NOT EDIT! File to edit: {self.dest2nb}.")
        if last_future > 0: write_cells(cells[:last_future], self.hdr, f)
        # `all_str` only exists when parsing is enabled, hence the matching guard
        if self.parse: f.write(f"\n\n# %% auto 0\n__all__ = {all_str}")
        write_cells(cells[last_future:], self.hdr, f, 1 if last_future>0 else 0)
        f.write('\n')
def _print_file(fname, mx=None):
    "Print the stripped contents of `fname`, cut to `mx` characters (9999 when `mx` is `None`)"
    txt = Path(fname).read_text().strip()
    print(txt[:ifnone(mx,9999)])
# All three cells are written, but only the second one is passed as `all_cells` for `__all__`
cells = make_code_cells("from __future__ import print_function", "#|export\ndef a(): ...", "def b(): ...")
mm.make(cells, L([cells[1]]))
print(Path('tmp/test/testing.py').read_text())
# AUTOGENERATED! DO NOT EDIT! File to edit: ../01_export.ipynb. # %% ../01_export.ipynb 0 from __future__ import print_function # %% auto 0 __all__ = ['a'] # %% ../01_export.ipynb 2 #|export def a(): ... # %% ../01_export.ipynb 3 def b(): ...
Pass `all_cells=[]` or `parse=False` if you don't want any `__all__` added.
Passing `parse=False` is also handy when writing broken-up functions or classes that `ast.parse` might not like but that you still want exported, such as having one cell with the contents of:
#|export
class A:
Note that by doing so we cannot properly generate an `__all__`, so we assume that it is unwanted.
# Same destination as before, but with parsing switched off
am = ModuleMaker(dest='tmp', name='test.testing_noall', nb_path=Path.cwd()/'01_export.ipynb', is_new=True, parse=False)
am.fname
Path('tmp/test/testing_noall.py')
# The last cell (`class A:` with no body) is not valid Python on its own
cells = make_code_cells("from __future__ import print_function", "#|export\ndef a(): ...", "#|export\nclass A:")
am.make(cells)
print(Path('tmp/test/testing_noall.py').read_text())
# AUTOGENERATED! DO NOT EDIT! File to edit: ../01_export.ipynb. # %% ../01_export.ipynb 0 from __future__ import print_function # %% ../01_export.ipynb 1 #|export def a(): ... # %% ../01_export.ipynb 2 #|export class A:
#|export
@patch
def _update_all(self:ModuleMaker, all_cells, alls):
    "Formatted list made of the existing `alls` followed by the exports found in `all_cells`"
    merged = alls + self.make_all(all_cells)
    return pformat(merged, width=160)
@patch
def _make_exists(self:ModuleMaker, cells, all_cells=None):
    "`make` for `is_new=False`"
    if all_cells and self.parse:
        # Rewrite the `__all__` already in the file so it also lists the new exports
        updater = partial(self._update_all, all_cells)
        update_var('__all__', updater, fn=self.fname)
    # The new cells go at the end of the existing module
    with self.fname.open('a') as out:
        write_cells(cells, self.hdr, out)
If is_new=False
then the additional definitions are added to the bottom, and any existing __all__
is updated with the newly-added symbols.
# Re-open the module written above (`is_new=False`) and add two more definitions to it
c2 = make_code_cells("def c(): ...", "def d(): ...")
mm = ModuleMaker(dest='tmp', name='test.testing', nb_path=Path.cwd()/'01_export.ipynb', is_new=False)
mm.make(c2, c2)
print(Path('tmp/test/testing.py').read_text())
# AUTOGENERATED! DO NOT EDIT! File to edit: ../01_export.ipynb. # %% ../01_export.ipynb 0 from __future__ import print_function # %% auto 0 __all__ = ['a', 'c', 'd'] # %% ../01_export.ipynb 2 #|export def a(): ... # %% ../01_export.ipynb 3 def b(): ... # %% ../01_export.ipynb 0 def c(): ... # %% ../01_export.ipynb 1 def d(): ...
# Star-import the generated module: `a`, `c` and `d` were passed as `all_cells`, `b` never was
g = exec_import('.tmp.test.testing', '*')
for s in "a c d".split(): assert s in g, s
assert 'b' not in g
assert g['a']() is None
#|export
def basic_export_nb2(fname, name, dest=None):
    "A basic exporter to bootstrap nbdev using `ModuleMaker`"
    if dest is None: dest = get_config().path('lib_path')
    # Keep only the cells whose source begins with an `#|export` directive
    exported = [c for c in read_nb(fname).cells if re.match(r'#\|\s*export', c.source)]
    maker = ModuleMaker(dest=dest, name=name, nb_path=fname)
    maker.make(L(exported))
#|eval: false
# Bootstrap: remove the previously generated modules, then regenerate them from their notebooks
path = Path('../nbdev')
(path/'read.py').unlink(missing_ok=True)
(path/'maker.py').unlink(missing_ok=True)
add_init(path)
cfg = get_config()
basic_export_nb2('01_read.ipynb', 'read')
basic_export_nb2('02_maker.ipynb', 'maker')
# Sanity check: the regenerated module imports and lists `ModuleMaker` in its `__all__`
g = exec_import('nbdev', 'maker')
assert g['maker'].ModuleMaker
assert 'ModuleMaker' in g['maker'].__all__