#|hide
#|default_exp maker
#|export
from __future__ import annotations
Create one or more modules from selected notebook cells
#|export
from nbdev.config import *
from nbdev.imports import *
from fastcore.script import *
from fastcore.basics import *
from fastcore.imports import *
from execnb.nbio import *
import ast,contextlib
from collections import defaultdict
from pprint import pformat
from textwrap import TextWrapper
#|hide
from fastcore.test import *
from pdb import set_trace
from importlib import reload
from nbdev.showdoc import *
These functions let us find and modify the definitions of variables in Python modules.
#|export
def find_var(lines, varname):
    "Find the line numbers where `varname` is defined in `lines`"
    # Returns a half-open `(start, end)` range of indices into `lines`, or `(None, None)` if not found.
    # The defining line must start with `varname` followed (optionally after spaces) by `=` or `:`, so
    # that looking up `a` does not match a line defining `ab` or `a_b`. Lines after it that start with a
    # space or tab are treated as continuation lines of the same definition.
    pat = re.compile(rf"{re.escape(varname)}\s*[=:]")
    start = next((i for i,o in enumerate(lines) if pat.match(o)), None)
    if start is None: return None,None
    empty = ' ','\t'
    if start==len(lines)-1 or lines[start+1][:1] not in empty: return start,start+1
    end = next((i for i,o in enumerate(lines[start+1:]) if o[:1] not in empty), None)
    return start,len(lines) if end is None else (end+start+1)
# `find_var` returns the half-open `(start, end)` range of line indices covering a definition,
# including any following continuation lines that start with a space or tab.
# NOTE(review): these asserts rely on the continuation lines of `t` being indented and on a blank line
# before `b_=3`; confirm that formatting survived extraction of this cell.
t = '''a_=(1,
2,
3)
b_=3'''
test_eq(find_var(t.splitlines(), 'a_'), (0,3))
test_eq(find_var(t.splitlines(), 'b_'), (4,5))
#|export
def read_var(code, varname):
    "Eval and return the value of `varname` defined in `code`"
    # Returns None if `varname` isn't defined in `code`; raises `Exception` carrying the offending source
    # if the definition isn't a valid Python expression.
    # NB: this `eval`s source text, so only use it on trusted code (such as modules this file generates).
    lines = code.splitlines()
    start,end = find_var(lines, varname)
    if start is None: return None
    # Split on the *first* `=` only, so values that themselves contain `=` (e.g. `dict(a=1)`) stay whole;
    # if there is no `=` at all, fall back to the whole line as before
    lhs,eq,rhs = lines[start].partition('=')
    res = [(rhs if eq else lhs).strip()]
    res += lines[start+1:end]
    try: return eval('\n'.join(res))
    except SyntaxError: raise Exception('\n'.join(res)) from None
# `read_var` evaluates the right-hand side of the definition together with its continuation lines
test_eq(read_var(t, 'a_'), (1,2,3))
test_eq(read_var(t, 'b_'), 3)
#|export
def update_var(varname, func, fn=None, code=None):
    "Update the definition of `varname` in file `fn`, by calling `func` with the current definition"
    # Pass either `fn` (a path: the file is rewritten in place and None is returned) or `code` (a string:
    # the updated code is returned). The whole definition, including continuation lines, is replaced by a
    # single `varname = <func(current value)>` line.
    # Raises ValueError if `varname` is not defined.
    if fn:
        fn = Path(fn)
        # Python source is UTF-8 by default; don't depend on the platform's locale encoding
        code = fn.read_text(encoding='utf-8')
    lines = code.splitlines()
    start,end = find_var(lines, varname)
    # Without this check, `del lines[None:None]` would silently empty the list before `insert` failed
    if start is None: raise ValueError(f"`{varname}` is not defined in {fn or 'code'}")
    res = func(read_var(code, varname))
    lines[start:end] = [f"{varname} = {res}"]
    code = '\n'.join(lines)
    if fn: fn.write_text(code, encoding='utf-8')
    else: return code
# Check `update_var` by executing the code before and after replacing each definition with `0`:
# updating the multi-line `a_` or the single-line `b_` must leave the other variable untouched
g = exec_new(t)
test_eq((g['a_'],g['b_']), ((1,2,3),3))
t2 = update_var('a_', lambda o:0, code=t)
exec(t2, g)
test_eq((g['a_'],g['b_']), (0,3))
t3 = update_var('b_', lambda o:0, code=t)
exec(t3, g)
test_eq((g['a_'],g['b_']), ((1,2,3),0))
#|export
class ModuleMaker:
    "Helper class to create exported library from notebook source cells"
    def __init__(self, dest, name, nb_path, is_new=True, parse=True):
        # Convert to `Path` *before* `store_attr`, which stores the arguments' current values on `self`
        dest,nb_path = Path(dest),Path(nb_path)
        store_attr()
        # Dotted module name -> file under `dest`, e.g. 'test.testing' -> dest/test/testing.py
        mod_file = name.replace('.', '/') + ".py"
        self.fname = dest/mod_file
        if not is_new:
            # Adding to an existing module, so it has to be there already
            assert self.fname.exists(), f"{self.fname} does not exist"
        else: dest.mkdir(parents=True, exist_ok=True)
        # Notebook path as seen from the module's folder; used in every generated cell header
        self.dest2nb = nb_path.relpath(self.fname.parent)
        self.hdr = "# %% " + str(self.dest2nb)
In order to export a notebook, we need a way to create a Python file. `ModuleMaker` fills that role. Pass in the directory where you want the module created, the name of the module, and the path of the notebook source. Set `is_new` to `True` if this is a new file being created (rather than an existing file being added to). The location of the saved module will be in `fname`. Finally, if the source in the notebooks should not be parsed by Python (such as partial class declarations in cells), `parse` should be set to `False`.

> Note: If doing so, then the `__all__` generation will be turned off as well.
# The dotted module name maps to nested folders under `dest`
mm = ModuleMaker(dest='tmp', name='test.testing', nb_path=Path.cwd()/'01_export.ipynb', is_new=True)
mm.fname
Path('tmp/test/testing.py')
#|export
def decor_id(d):
    "`id` attr of decorator, regardless of whether called as function or bare"
    # A bare decorator (`@patch`) is a node with an `id`; a called one (`@patch_to(A)`) keeps the `id`
    # on its `func`. Anything else (e.g. `@mod.deco`) gives ''.
    if hasattr(d, 'id'): return d.id
    return getattr(getattr(d, 'func', None), 'id', '')
#|export
_def_types = ast.FunctionDef,ast.AsyncFunctionDef,ast.ClassDef
_assign_types = ast.AnnAssign, ast.Assign, ast.AugAssign
def _val_or_id(it):
if sys.version_info < (3,8): return [getattr(o, 's', getattr(o, 'id', None)) for o in it.value.elts]
else:return [getattr(o, 'value', getattr(o, 'id', None)) for o in it.value.elts]
def _all_targets(a): return L(getattr(a,'elts',a))
def _filt_dec(x): return decor_id(x).startswith('patch')
def _wants(o): return isinstance(o,_def_types) and not any(L(o.decorator_list).filter(_filt_dec))
#|export
def retr_exports(trees):
    "Names to put in `__all__` for the top-level AST nodes `trees` (an `L`)"
    # `ast.Assign` has a list of `targets`, while `ast.AnnAssign`/`ast.AugAssign` have a single `target`;
    # reading `.targets` on the latter raised AttributeError for cells like `x: int = 1` or `x += 1`
    def _targets(o): return o.targets if hasattr(o, 'targets') else [o.target]
    # include anything mentioned in "_all_", even if otherwise private
    # NB: "_all_" can include strings (names), or symbols, so we look for "id" or "value"
    assigns = trees.filter(risinstance(_assign_types))
    # (`_all_: list` with no value has nothing to read, so skip it)
    all_assigns = assigns.filter(lambda o: getattr(_targets(o)[0],'id',None)=='_all_' and o.value is not None)
    all_vals = all_assigns.map(_val_or_id).concat()
    syms = trees.filter(_wants).attrgot('name')
    # assignment targets (NB: can be multiple, e.g. "a=b=c", and/or destructuring e.g "a,b=(1,2)")
    assign_targs = L(L(_targets(assn)).map(_all_targets).concat() for assn in assigns).concat()
    # public names only: drop targets without a plain `id` (attributes, subscripts) and names starting with `_`
    exports = (assign_targs.attrgot('id')+syms).filter(lambda o: o and o[0]!='_')
    return (exports+all_vals).unique()
#|export
@patch
def make_all(self:ModuleMaker, cells):
    "Create `__all__` with all exports in `cells`"
    # `None` means no cells were given; otherwise pool the parsed top-level nodes of every cell
    if cells is None: return ''
    trees = L(cells).map(NbCell.parsed_).concat()
    return retr_exports(trees)
#|export
def make_code_cells(*ss):
    "Build a notebook from the sources `ss` (one cell each, via `mk_cell`) and return its cells"
    nb = dict2nb({'cells': L(ss).map(mk_cell)})
    return nb.cells
We want to add an `__all__` to the top of the exported module. This method autogenerates it from all code in `cells`.
# Private names are excluded unless listed in `_all_`, and `@patch`-ed functions are never exported
nb = make_code_cells("from __future__ import print_function", "def a():...", "def b():...",
"c=d=1", "_f=1", "_g=1", "_h=1", "_all_=['_g', _h]", "@patch\ndef h(self:ca):...")
test_eq(set(mm.make_all(nb)), set(['a','b','c','d', '_g', '_h']))
#|export
def relative_import(name, fname, level=0):
    "Rewrite module `name` as an import relative to the package folder `fname` (unchanged if unrelated)"
    assert not level
    mod_path = name.replace('.', '/')
    # No common prefix means `name` lives outside `fname`'s tree, so it stays absolute
    if not os.path.commonpath([mod_path, fname]): return name
    rel = os.path.relpath(mod_path, fname)
    if rel == ".": return "."
    dotted = rel.replace(f"..{os.path.sep}", ".")
    # A result made only of dots ('..', '...') already names a parent package; otherwise add the leading '.'
    if any(ch != '.' for ch in dotted): dotted = '.' + dotted
    return dotted.replace(os.path.sep, ".")
# Modules outside `fname`'s tree stay absolute; otherwise each `..` level adds one more leading dot
test_eq(relative_import('nbdev.core', "xyz"), 'nbdev.core')
test_eq(relative_import('nbdev.core', 'nbdev'), '.core')
_p = Path('fastai')
test_eq(relative_import('fastai.core', _p/'vision'), '..core')
test_eq(relative_import('fastai.core', _p/'vision/transform'), '...core')
test_eq(relative_import('fastai.vision.transform', _p/'vision'), '.transform')
test_eq(relative_import('fastai.notebook.core', _p/'data'), '..notebook.core')
test_eq(relative_import('fastai.vision', _p/'vision'), '.')
test_eq(relative_import('fastai', _p), '.')
test_eq(relative_import('fastai', _p/'vision'), '..')
test_eq(relative_import('fastai', _p/'vision/transform'), '...')
#|export
# Based on https://github.com/thonny/thonny/blob/master/thonny/ast_utils.py
def _mark_text_ranges(
    source: str|bytes, # Source code to add ranges to
):
    "Parse `source` and give every node `end_lineno`/`end_col_offset`; returns the module body. Used for Python 3.7 compatibility"
    from asttokens.asttokens import ASTTokens
    # Re-parse so there is a complete tree to walk
    tree = ast.parse(source)
    ASTTokens(source, tree=tree)  # marks nodes with token info such as `last_token`
    for node in ast.walk(tree):
        if hasattr(node, "last_token"):
            node.end_lineno, node.end_col_offset = node.last_token.end
        # Some nodes get no token info; give those a nominal end position
        lacks_end = not all(hasattr(node, a) for a in ["end_lineno", "end_col_offset"])
        if hasattr(node, "lineno") and lacks_end:
            node.end_lineno, node.end_col_offset = node.lineno, node.col_offset+2
    return tree.body
#|export
def update_import(source, tree, libname, f=relative_import):
    "Rewrite absolute `from X import ...` statements in `source` using `f(X, libname, level)`"
    # `tree` is the parsed top-level AST of `source`. Returns the updated source as a list of lines
    # (line endings kept), or None when there is nothing to rewrite.
    if not tree: return
    if sys.version_info < (3,8): tree = _mark_text_ranges(source)
    # Already-relative imports (`from .x import y`, `from . import y`) need no conversion. Passing them
    # on would trip `relative_import`'s `assert not level`, and `from . import y` has no module name at all.
    imps = [o for o in tree if isinstance(o, ast.ImportFrom) and not o.level]
    if not imps: return
    src = source.splitlines(True)
    # Rewrite right-to-left so editing one import can't shift the column offsets of a later one on the same line
    for imp in sorted(imps, key=lambda o: (o.lineno, o.col_offset), reverse=True):
        nmod = f(imp.module, libname, imp.level)
        lin = imp.lineno-1
        line = src[lin]
        # `end_col_offset` belongs to the statement's *last* line, so only use it for single-line imports;
        # the `from <module>` part is normally on the first line
        end = imp.end_col_offset if imp.end_lineno==imp.lineno else len(line)
        sec = line[imp.col_offset:end]
        # Escape the module name (its dots are regex wildcards otherwise)
        newsec = re.sub(rf"(from\s+){re.escape(imp.module)}", lambda m: m.group(1)+nmod, sec, count=1)
        src[lin] = line[:imp.col_offset] + newsec + line[end:]
    return src
@patch
def import2relative(cell:NbCell, libname):
    "Rewrite the absolute `from ... import` statements in `cell` relative to the package folder `libname`, in place"
    new_src = update_import(cell.source, cell.parsed_(), libname)
    if not new_src: return
    cell.set_source(new_src)
# Imports into the library become relative to the module's folder (`nbdev`, then `nbdev/a`)
ss = "from nbdev.export import *\nfrom nbdev.a.b import *"
cell = make_code_cells([ss])[0]
cell.import2relative('nbdev')
test_eq(cell.source, 'from .export import *\nfrom .a.b import *')
cell = make_code_cells([ss])[0]
cell.import2relative('nbdev/a')
test_eq(cell.source, 'from ..export import *\nfrom .b import *')
#|export
@patch
def _last_future(self:ModuleMaker, cells):
    "Index just past the last cell in `cells` containing a `from __future__` import (0 if there is none)"
    def _has_future(tree):
        return bool(tree) and any(isinstance(t,ast.ImportFrom) and t.module=='__future__' for t in tree)
    future_idxs = [i for i,tree in enumerate(cells.map(NbCell.parsed_)) if _has_future(tree)]
    return future_idxs[-1]+1 if future_idxs else 0
#|export
def _import2relative(cells, lib_name=None):
    "Call `import2relative` on each of `cells`, using `lib_name` (default: the configured `lib_name`)"
    target = get_config().lib_name if lib_name is None else lib_name
    for c in cells: c.import2relative(target)
#|export
@patch
def make(self:ModuleMaker, cells, all_cells=None, lib_path=None):
    "Write module containing `cells` with `__all__` generated from `all_cells`"
    # `all_cells` defaults to `cells`. When parsing, imports in `all_cells` are first made relative to the
    # module's folder. If `is_new` is False, the work is delegated to `_make_exists` (append to the module).
    if all_cells is None: all_cells = cells
    cells,all_cells = L(cells),L(all_cells)
    if self.parse:
        if not lib_path: lib_path = get_config().path('lib_path')
        # Module folder relative to the library's parent, in the form `import2relative` expects (e.g. `nbdev/a`)
        mod_dir = os.path.relpath(self.fname.parent, Path(lib_path).parent)
        _import2relative(all_cells, mod_dir)
    if not self.is_new: return self._make_exists(cells, all_cells)

    self.fname.parent.mkdir(exist_ok=True, parents=True)
    last_future = 0
    if self.parse:
        _all = self.make_all(all_cells)
        # `__all__` is written below even when `all_cells` is empty, and `from __future__` imports must come
        # before it, so always locate them (previously skipped for empty `all_cells` -> SyntaxError)
        last_future = self._last_future(cells)
        # Continuation lines line up under the opening bracket of `__all__ = [` (11 characters)
        tw = TextWrapper(width=120, initial_indent='', subsequent_indent=' '*11, break_long_words=False)
        all_str = '\n'.join(tw.wrap(str(_all)))
    # Python source is UTF-8 by default, whereas `open`'s default encoding depends on the platform
    with self.fname.open('w', encoding='utf-8') as f:
        f.write(f"# AUTOGENERATED! DO NOT EDIT! File to edit: {self.dest2nb}.")
        if last_future > 0: write_cells(cells[:last_future], self.hdr, f)
        if self.parse: f.write(f"\n\n# %% auto 0\n__all__ = {all_str}")
        # NOTE(review): judging by the example output, this offset shifts the cell numbers written in the
        # `# %%` headers whenever a `__future__` cell was split off — confirm this is intended
        write_cells(cells[last_future:], self.hdr, f, 1 if last_future>0 else 0)
        f.write('\n')
# Only `cells[1]` feeds `__all__`; the `__future__` import is written before `__all__`
cells = make_code_cells("from __future__ import print_function", "#|export\ndef a(): ...", "def b(): ...")
mm.make(cells, L([cells[1]]))
show_src(Path('tmp/test/testing.py').read_text())
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../01_export.ipynb.
# %% ../../01_export.ipynb 0
from __future__ import print_function
# %% auto 0
__all__ = ['a']
# %% ../../01_export.ipynb 2
#|export
def a(): ...
# %% ../../01_export.ipynb 3
def b(): ...
Pass `all_cells=[]` if you want an empty `__all__`, or `parse=False` if you don't want any `__all__` added.

Passing `parse=False` is also handy when writing broken-up functions or classes that `ast.parse` might not like but that you still want exported, such as having a cell with:

```python
#|export
class A:
```

Note that by doing so we cannot properly generate an `__all__`, so we assume that it is unwanted.
# With `parse=False`, cells are written verbatim (even ones `ast.parse` would reject) and no `__all__` line is added
am = ModuleMaker(dest='tmp', name='test.testing_noall', nb_path=Path.cwd()/'01_export.ipynb', is_new=True, parse=False)
am.fname
Path('tmp/test/testing_noall.py')
cells = make_code_cells("from __future__ import print_function", "#|export\ndef a(): ...", "#|export\nclass A:")
am.make(cells)
show_src(Path('tmp/test/testing_noall.py').read_text())
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../01_export.ipynb.
# %% ../../01_export.ipynb 0
from __future__ import print_function
# %% ../../01_export.ipynb 1
#|export
def a(): ...
# %% ../../01_export.ipynb 2
#|export
class A:
#|export
@patch
def _update_all(self:ModuleMaker, all_cells, alls):
    "Pretty-print the existing `__all__` names `alls` extended with the exports of `all_cells`"
    combined = alls + self.make_all(all_cells)
    return pformat(combined, width=160)
@patch
def _make_exists(self:ModuleMaker, cells, all_cells=None):
    "`make` for `is_new=False`"
    # When parsing and `all_cells` is non-empty, merge its exports into the module's existing `__all__`;
    # then append `cells` to the end of the module.
    if all_cells and self.parse: update_var('__all__', partial(self._update_all, all_cells), fn=self.fname)
    # Python source is UTF-8 by default, whereas `open`'s default encoding depends on the platform
    with self.fname.open('a', encoding='utf-8') as f: write_cells(cells, self.hdr, f)
If `is_new=False`, then the additional definitions are added to the bottom, and any existing `__all__` is updated with the newly-added symbols.
# Appending to the existing module: the new cells go at the end and `__all__` gains `c` and `d`
c2 = make_code_cells("def c(): ...", "def d(): ...")
mm = ModuleMaker(dest='tmp', name='test.testing', nb_path=Path.cwd()/'01_export.ipynb', is_new=False)
mm.make(c2, c2)
show_src(Path('tmp/test/testing.py').read_text())
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../01_export.ipynb.
# %% ../../01_export.ipynb 0
from __future__ import print_function
# %% auto 0
__all__ = ['a', 'c', 'd']
# %% ../../01_export.ipynb 2
#|export
def a(): ...
# %% ../../01_export.ipynb 3
def b(): ...
# %% ../../01_export.ipynb 0
def c(): ...
# %% ../../01_export.ipynb 1
def d(): ...
# Star-importing the generated module checks it is valid Python and that `b` (not in `__all__`) isn't exported
g = exec_import('tmp.test.testing', '*')
for s in "a c d".split(): assert s in g, s
assert 'b' not in g
assert g['a']() is None
#|export
def _basic_export_nb2(fname, name, dest=None):
    "A basic exporter to bootstrap nbdev using `ModuleMaker`"
    # Writes every cell whose source starts with an `#|export...` directive to module `name` under `dest`
    # (default: the configured `lib_path`).
    dest = get_config().path('lib_path') if dest is None else dest
    export_re = re.compile(r'#\|\s*export')
    nb_cells = read_nb(fname).cells
    cells = L(c for c in nb_cells if export_re.match(c.source))
    ModuleMaker(dest=dest, name=name, nb_path=fname).make(cells)
#|hide
#|eval: false
# Bootstrap: regenerate nbdev's own `config` and `maker` modules with the basic exporter, then check
# that the freshly written `maker` module imports and lists `ModuleMaker` in its `__all__`
path = Path('../nbdev')
(path/'config.py').unlink(missing_ok=True)
(path/'maker.py').unlink(missing_ok=True)
add_init(path)
cfg = get_config()
_basic_export_nb2('01_config.ipynb', 'config')
_basic_export_nb2('02_maker.ipynb', 'maker')
g = exec_import('nbdev', 'maker')
assert g['maker'].ModuleMaker
assert 'ModuleMaker' in g['maker'].__all__