#hide
#default_exp merge
from nbdev.showdoc import show_doc

#export
from nbdev.imports import *
from fastcore.script import *

#hide
tst_nb="""{
 "cells": [
  {
   "cell_type": "code",
<<<<<<< HEAD
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z=3\n",
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
=======
   "execution_count": 5,
>>>>>>> a7ec1b0bfb8e23b05fd0a2e6cafcb41cd0fb1c35
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
<<<<<<< HEAD
     "execution_count": 7,
=======
     "execution_count": 5,
>>>>>>> a7ec1b0bfb8e23b05fd0a2e6cafcb41cd0fb1c35
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x=3\n",
    "y=3\n",
    "x+y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}"""
print(tst_nb)

#export
class _NotebookFormatError(ValueError, IndexError):
    "The cell list of a notebook could not be located (also an `IndexError`, which is what the unchecked scan used to raise)"

#export
def _leading_ws(line):
    "Leading whitespace of `line`"
    return line[:len(line)-len(line.lstrip())]

#export
def _seek(lines, i, pred, what):
    "Index of the first of `lines` at or after `i` for which `pred` holds; raise `_NotebookFormatError` naming `what` if there is none"
    for k in range(i, len(lines)):
        if pred(lines[k]): return k
    raise _NotebookFormatError(f'Malformed notebook: could not find {what}')

#export
def extract_cells(raw_txt):
    """Manually extract cells in potential broken json `raw_txt`

    The text is scanned line by line instead of being parsed, because conflict
    markers make it invalid JSON. Returns `(start, cells, end)`:

    - `start`: every line up to and including the `"cells"` line
    - `cells`: the text of each cell, each one ending with a comma (one is
      appended to a cell that does not have it, i.e. the last one)
    - `end`: every line from the one closing the cell list onwards

    The indentation of the cell list is inferred from the `"cells"` line and
    the indentation of the cells from the first cell, so the layout does not
    have to use one fixed indent width. Raises `_NotebookFormatError` if the
    cell list, a cell boundary or the end of the list cannot be found.
    """
    lines = raw_txt.split('\n')
    i = _seek(lines, 0, lambda l: l.lstrip().startswith('"cells"'), 'the "cells" entry')
    # The list is closed by `],` at the same indentation as its key
    closing = _leading_ws(lines[i]) + '],'
    i += 1
    start = '\n'.join(lines[:i])
    cells,ws = [],None
    # `closing` is only looked for between cells: deeper `],` lines belong to a cell
    while i >= len(lines) or lines[i] != closing:
        if ws is None:
            # First cell: a line holding only `{` tells us how cells are indented
            i = _seek(lines, i, lambda l: l.strip() == '{', 'the start of the first cell')
            ws = _leading_ws(lines[i])
        else: i = _seek(lines, i, lambda l: l == ws + '{', 'the start of the next cell or the end of the "cells" list')
        # Nested braces are indented deeper, so the first `}` at cell indent ends the cell
        j = _seek(lines, i, lambda l: l.startswith(ws + '}'), 'the end of a cell')
        c = '\n'.join(lines[i:j+1])
        if not c.endswith(','): c += ','
        cells.append(c)
        i = j+1
    end = '\n'.join(lines[i:])
    return start,cells,end

start,cells,end = extract_cells(tst_nb)
test_eq(len(cells), 3)
test_eq(cells[0], """  {
   "cell_type": "code",
<<<<<<< HEAD
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z=3\n",
    "z"
   ]
  },""")

#hide
#Test the whole text is there
#We add a , to the last cell (because we might add some after for merge conflicts at the end, so we need to remove it)
test_eq(tst_nb, '\n'.join([start] + cells[:-1] + [cells[-1][:-1]] + [end]))

#hide
#The indentation is inferred from the text, so a wider JSON indent works as well
tst_wide = '{\n  "cells": [\n    {\n      "source": []\n    }\n  ],\n  "nbformat": 4\n}'
test_eq(extract_cells(tst_wide), ('{\n  "cells": [', ['    {\n      "source": []\n    },'], '  ],\n  "nbformat": 4\n}'))

#export
def get_md_cell(txt):
    "A markdown cell with `txt`"
    # `json.dumps` adds the surrounding quotes and escapes quotes, backslashes and
    # control characters, so any `txt` gives a valid JSON string
    return '''  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ''' + json.dumps(txt, ensure_ascii=False) + '''
   ]
  },'''

tst = '''  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A bit of markdown"
   ]
  },'''
assert get_md_cell("A bit of markdown") == tst

#hide
#Text that needs escaping still gives a cell that can be parsed back
assert json.loads(get_md_cell('a "quoted" name\r')[:-1])['source'] == ['a "quoted" name\r']

#export
# Start, middle and end markers of a conflict, indexed by the conflict state `cf`
conflicts = '<<<<<<< ======= >>>>>>>'.split()

#export
def _split_cell(cell, cf, names):
    """Split `cell` between `conflicts` given state in `cf`, save `names` of branches if seen

    `cf` is the state at the start of `cell`: 0 outside of a conflict, 1 after a
    `<<<<<<<` line, 2 after a `=======` line. Returns `(v1, v2, cf, names)` where
    `v1` holds the lines that are outside of a conflict or in its first part,
    `v2` the lines that are outside of a conflict or in its second part, `cf` is
    the state at the end of `cell` and `names` is the (mutated) list received.
    """
    res1,res2 = [],[]
    for line in cell.split('\n'):
        # Only the marker expected in the current state is recognized
        if line.startswith(conflicts[cf]):
            # `cf//2` is 0 for `<<<<<<<` and 1 for `>>>>>>>`. `=======` also maps to
            # slot 0, normally filled already by the `<<<<<<<` line seen before it.
            # `line[8:]` is what follows the 7 characters of the marker and a space
            if names[cf//2] is None: names[cf//2] = line[8:]
            cf = (cf+1)%3
            continue
        if cf<2:    res1.append(line)
        if cf%2==0: res2.append(line)
    return '\n'.join(res1),'\n'.join(res2),cf,names

#hide
tst = '\n'.join(['a', f'{conflicts[0]} HEAD', 'b', conflicts[1], 'c', f'{conflicts[2]} lala', 'd'])
v1,v2,cf,names = _split_cell(tst, 0, [None,None])
assert v1 == 'a\nb\nd'
assert v2 == 'a\nc\nd'
assert cf == 0
assert names == ['HEAD', 'lala']

#hide
tst = '\n'.join(['a', f'{conflicts[0]} HEAD', 'b', conflicts[1], 'c', f'{conflicts[2]} lala', 'd', f'{conflicts[0]} HEAD', 'e'])
v1,v2,cf,names = _split_cell(tst, 0, [None,None])
assert v1 == 'a\nb\nd\ne'
assert v2 == 'a\nc\nd'
assert cf == 1
assert names == ['HEAD', 'lala']

#hide
tst = '\n'.join(['a', f'{conflicts[0]} HEAD', 'b', conflicts[1], 'c', f'{conflicts[2]} lala', 'd', f'{conflicts[0]} HEAD', 'e', conflicts[1]])
v1,v2,cf,names = _split_cell(tst, 0, [None,None])
assert v1 == 'a\nb\nd\ne'
assert v2 == 'a\nc\nd'
assert cf == 2
assert names == ['HEAD', 'lala']

#hide
tst = '\n'.join(['b', conflicts[1], 'c', f'{conflicts[2]} lala', 'd'])
v1,v2,cf,names = _split_cell(tst, 1, ['HEAD',None])
assert v1 == 'b\nd'
assert v2 == 'c\nd'
assert cf == 0
assert names == ['HEAD', 'lala']

#hide
tst = '\n'.join(['c', f'{conflicts[2]} lala', 'd'])
v1,v2,cf,names = _split_cell(tst, 2, ['HEAD',None])
assert v1 == 'd'
assert v2 == 'c\nd'
assert cf == 0
assert names == ['HEAD', 'lala']

#export
# Matches the start of a conflict at the beginning of any line
_re_conflict = re.compile(r'^<<<<<<<', re.MULTILINE)

#hide
assert _re_conflict.search('a\nb\nc') is None
assert _re_conflict.search('a\n<<<<<<<\nc') is not None

#export
def same_inputs(t1, t2):
    """Test if the cells described in `t1` and `t2` have the same inputs

    `t1` and `t2` are cell texts ending with a comma. The result is `False` when
    either is empty, is not valid JSON once that last character is removed, or
    has no `source`.
    """
    if not t1 or not t2: return False
    try:
        c1,c2 = json.loads(t1[:-1]),json.loads(t2[:-1])
        return c1['source']==c2['source']
    # Invalid JSON (`JSONDecodeError` is a `ValueError`), no `source` key, or JSON that is not an object
    except (ValueError, KeyError, TypeError): return False

ts = ['''  {
   "cell_type": "code",
   "source": [
    "'''+code+'''"
   ]
  },''' for code in ["a=1", "b=1", "a=1"]]
assert same_inputs(ts[0],ts[2])
assert not same_inputs(ts[0], ts[1])

#export
def analyze_cell(cell, cf, names, prev=None, added=False, fast=True, trust_us=True):
    """Analyze and solve conflicts in `cell`

    `cf` and `names` are the conflict state and branch names as in `_split_cell`.
    `prev` collects the second version of the cells seen since a conflict opened,
    and `added` records whether marker cells were inserted. Returns
    `(text, cf, names, prev, added)` to be passed on to the call for the next cell.

    With `fast`, a conflict whose two versions have the same inputs is resolved by
    keeping a single version: the second one if `trust_us`, the first one otherwise.
    """
    # Nothing to do for a cell outside of any conflict
    if cf==0 and _re_conflict.search(cell) is None: return cell,cf,names,prev,added
    old_cf = cf
    v1,v2,cf,names = _split_cell(cell, cf, names)
    if fast and same_inputs(v1,v2):
        # The conflict opens and closes in this cell: the chosen version is the result
        if old_cf==0 and cf==0: return (v2 if trust_us else v1),cf,names,prev,added
        v1,v2 = (v2,v2) if trust_us else (v1,v1)
    res = []
    if old_cf == 0:
        # A conflict opens in this cell: flag it with a markdown cell
        added=True
        res.append(get_md_cell(f'`{conflicts[0]} {names[0]}`'))
    res.append(v1)
    if cf ==0:
        # The conflict closes in this cell: emit the second version of every cell it spans
        res.append(get_md_cell(f'`{conflicts[1]}`'))
        if prev is not None: res += prev
        res.append(v2)
        res.append(get_md_cell(f'`{conflicts[2]} {names[1]}`'))
        prev = None
    # Still inside the conflict: hold the second version until it closes
    else: prev = [v2] if prev is None else prev + [v2]
    return '\n'.join([r for r in res if len(r) > 0]),cf,names,prev,added

tst = '\n'.join(['a', f'{conflicts[0]} HEAD', 'b', conflicts[1], 'c'])
c,cf,names,prev,added = analyze_cell(tst, 0, [None,None], None, False,fast=False)
test_eq(c, get_md_cell('`<<<<<<< HEAD`')+'\na\nb')
test_eq(cf, 2)
test_eq(names, ['HEAD', None])
test_eq(prev, ['a\nc'])
test_eq(added, True)

#export
@call_parse
def nbdev_fix_merge(
    fname:str, # A notebook filename to fix
    fast:bool=True, # Fast fix: automatically fix the merge conflicts in outputs or metadata
    trust_us:bool=True # Use local outputs/metadata when fast merging
):
    "Fix merge conflicts in notebook `fname`"
    fname=Path(fname)
    # Keep a copy of the conflicted file next to it before rewriting it in place
    shutil.copy(fname, fname.with_suffix('.ipynb.bak'))
    # Notebooks are UTF-8 JSON: do not let the platform default encoding decide
    with open(fname, 'r', encoding='utf-8') as f: raw_text = f.read()
    start,cells,end = extract_cells(raw_text)
    res = [start]
    cf,names,prev,added = 0,[None,None],None,False
    for cell in cells:
        c,cf,names,prev,added = analyze_cell(cell, cf, names, prev, added, fast=fast, trust_us=trust_us)
        res.append(c)
    # Drop empty entries first, so the comma removed is the one of the last cell actually written
    res = [r for r in res if len(r) > 0]
    if res[-1].endswith(','): res[-1] = res[-1][:-1]
    with open(fname, 'w', encoding='utf-8') as f: f.write('\n'.join([r for r in res+[end] if len(r) > 0]))
    if fast and not added: print("Successfully merged conflicts!")
    else: print("One or more conflict remains in the notebook, please inspect manually.")

#hide
from nbdev.export import notebook2script
notebook2script()