import attr
import funcy as fn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from bidict import bidict
from IPython.display import Image, display, SVG
import networkx as nx
import pydot
import pandas as pd
from collections import Counter
import dfa
from dfa.utils import find_subset_counterexample, find_equiv_counterexample
from dfa_identify import find_dfa, find_dfas
from diss.planners.product_mc import ProductMC
from diss.concept_classes.dfa_concept import DFAConcept
from diss.domains.gridworld_naive import GridWorldNaive as World
from diss.domains.gridworld_naive import GridWorldState as State
from diss import search, LabeledExamples, GradientGuidedSampler, ConceptIdException
from pprint import pprint
from itertools import combinations
from tqdm.notebook import trange
from IPython.display import clear_output
from IPython.display import HTML as html_print
from functools import reduce
sns.set_context('paper')
sns.set_style('darkgrid')
sns.set_palette('Set2')
np.set_printoptions(precision=3)
from diss.experiment import PartialDFAIdentifier, ignore_white, PARTIAL_DFA, BASE_EXAMPLES
from diss.experiment import view_dfa
from diss import diss
from diss import DemoPrefixTree as PrefixTree
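# `analyze` drives the DISS search generator for `n_iters` iterations: it prints each proposed
# concept's size and energy, plots the gradient and sampling weights reported in the metadata
# (when present), renders the labeled examples and the conjectured DFA, and records per-iteration
# summary statistics (unnormalized exp(-energy) mass, median / min / cumulative energy) that are
# returned as a DataFrame together with the explored concepts sorted by energy.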
def analyze(search, n_iters):
concept2energy = {} # Explored concepts + associated energy
partial_masses = []
median_energies = []
min_energies = []
total_energies = []
# Run Search and collect concepts, energy, and POI.
for i, (data, concept, metadata) in zip(trange(n_iters, desc='DISS'), search):
print(f'==========={i}================')
print('size', concept.size)
score = metadata['energy']
print('energy', score)
if 'grad' in metadata:
print('surprisal', metadata.get('surprisal'))
grad = metadata['grad']
sns.set(rc={"figure.figsize":(10, 2)})
sns.barplot(x=np.arange(len(grad)), y=np.array(grad) / np.abs(grad).max())
plt.xticks(rotation=45)
plt.show()
weights = metadata['weights']
sns.set(rc={"figure.figsize":(10, 2)})
sns.barplot(x=np.arange(len(weights)), y=np.array(weights))
plt.xticks(rotation=45)
plt.show()
print('pivot', metadata['pivot'])
print("conjecture:")
print(f"{metadata['conjecture']}")
print(f'data')
data = metadata['data']
data @= identifer.base_examples # Force labels of prior examples.
buff = ''
for lbl, split in [(True, data.positive), (False, data.negative)]:
buff += f'------------- {lbl} --------------<br>'
for word in sorted(split, key=len):
obs = '\n'.join(map(tile, word))
buff += f'{obs}<br>'
display(html_print(buff))
concept2energy[concept] = metadata['energy']
view_dfa(concept)
energies = list(concept2energy.values())
        partial_masses.append(sum(np.exp(-x) for x in energies))  # Record unnormalized mass
median_energies.append(np.median(energies))
min_energies.append(np.min(energies))
total_energies.append(sum(energies))
    sorted_concepts = sorted(concept2energy, key=concept2energy.get)
    p = 0
    for c in sorted_concepts:
p += np.exp(-concept2energy[c])
print('energy', concept2energy[c])
view_dfa(c)
if p > 0.99:
break
df = pd.DataFrame(data={
'probability mass explored': partial_masses,
'median energies': median_energies,
'min energies': min_energies,
'cumulative energy': total_energies,
'iteration': list(range(1, len(total_energies) + 1)),
})
return df, sorted_concepts
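# The 0.99 cutoff above relies on the Boltzmann-style weighting used throughout `analyze`: each
# explored concept carries unnormalized mass exp(-energy), so low-energy DFAs dominate.  A minimal
# worked example with hypothetical energies (not values from any run):
_example_energies = np.array([3.0, 4.0, 10.0])   # hypothetical energies
_example_mass = np.exp(-_example_energies)
print(_example_mass / _example_mass.sum())        # ~[0.73, 0.27, 0.00]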
from diss.experiment.planner import GridWorldPlanner
planner = GridWorldPlanner.from_string(
buff="""y....g..
........
.b.b...r
.b.b...r
.b.b....
.b.b....
rrrrrr.r
g.y.....""",
start=(3, 5),
slip_prob=1/32,
horizon=15,
policy_cache='diss_experiment.shelve',
)
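# The planner wraps an 8x8 gridworld parsed from the ASCII map above (the letters presumably
# abbreviate tile colors: y = yellow, g = green, b = blue, r = red, '.' = white), starting at
# cell (3, 5) with slip probability 1/32, a 15-step horizon, and planner policies cached in
# 'diss_experiment.shelve'.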
SENSOR = planner.gw.sensor
DYN = planner.gw.dyn
This can all seem pretty abstract, so let's visualize the way the sensor sees the board.
from IPython.display import HTML as html_print
COLOR_ALIAS = {
'white': 'white',
'yellow': '#ffff00',
'red': '#ff8b8b',
'blue': '#afafff',
'green' : '#8ff45d'
}
def tile(color='black'):
color = COLOR_ALIAS.get(color, color)
s = ' '*4
return f"<text style='border: solid 1px;background-color:{color}'>{s}</text>"
def print_map():
"""Scan the board row by row and print colored tiles."""
order = range(1, 9)
buffer = ''
for y in order:
chars = (tile(planner.gw.ap_at_state(x, y)) for x in order)
buffer += ' '.join(chars) + '<br>'
display(html_print(buffer))
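# Compose the dynamics with the sensor ('>>' is presumably sequential composition in the
# underlying circuit library), so state trajectories can be read off as color observations.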
DYN_SENSE = DYN >> SENSOR
def print_trc(trc, idx=0):
obs = planner.lift_path(trc, flattened=False, compress=False)
actions = [x['a'] for x in trc[1:]]
obs = map(tile, obs)
display(
html_print(f'trc {idx}: ' + ''.join(''.join(x) for x in zip(actions, obs)) + '\n')
)
print_map()
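# Demonstrations are encoded as a start cell followed by one dict per time step: 'a' is the
# ego action, 'c' is presumably the outcome of the environment's slip coin, and 'EOE_ego'
# flags the final step.  TRC4 takes 14 actions; TRC5 is a shorter demonstration.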
TRC4 = [
(3, 5),
{'a': '↑', 'c': 0},
{'a': '↑', 'c': 1},
{'a': '↑', 'c': 1},
{'a': '→', 'c': 1},
{'a': '↑', 'c': 1},
{'a': '↑', 'c': 1},
{'a': '→', 'c': 1},
{'a': '→', 'c': 1},
{'a': '→', 'c': 1},
{'a': '←', 'c': 1},
{'a': '←', 'c': 1},
{'a': '←', 'c': 1},
{'a': '←', 'c': 1},
{'a': '←', 'c': 1, 'EOE_ego': 1},
]
TRC5 = [
(3, 5),
{'a': '↑', 'c': 1},
{'a': '↑', 'c': 1},
#{'a': '↑', 'c': 1},
#{'a': '↑', 'c': 1},
{'a': '↑', 'c': 1},
{'a': '←', 'c': 1},
{'a': '←', 'c': 1, 'EOE_ego': 1},
]
print(len(TRC4))
print_trc(TRC4, 4)
print_trc(TRC5, 5)
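# TRC4 and TRC5 are the expert demonstrations handed to DISS below: both are used in the
# monolithic setting, while the incremental setting only uses a truncated TRC4.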
BASE_EXAMPLES
import random
env_yellow = dfa.DFA(
start=False,
inputs={'blue', 'green', 'red', 'yellow'},
outputs={True, False},
label=lambda s: s,
transition=lambda s, c: s | (c == 'yellow'),
)
universal = dfa.DFA(
start=True,
inputs={'blue', 'green', 'red', 'yellow'},
outputs={True, False},
label=lambda s: s,
transition=lambda s, c: True,
)
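# Quick sanity check on the two DFAs (a small sketch; assumes dfa.DFA exposes `label`, which
# returns the acceptance label after reading a word, as in the mvcisback/dfa package):
assert env_yellow.label(('blue', 'yellow', 'red'))   # saw yellow -> accepting
assert not env_yellow.label(('blue', 'red'))         # never saw yellow -> rejecting
assert universal.label(())                           # `universal` accepts every word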
MONOLITHIC = True
identifer = PartialDFAIdentifier(
partial = universal if MONOLITHIC else PARTIAL_DFA,
base_examples = LabeledExamples(negative=[], positive=[]) if MONOLITHIC else BASE_EXAMPLES,
try_reach_avoid=True,
)
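# In the monolithic condition, identification starts from the trivial `universal` DFA with no
# prior labeled examples; in the incremental condition it starts from PARTIAL_DFA and reuses
# BASE_EXAMPLES as already-labeled words.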
def to_chain(c, t, psat):
return planner.plan(c, t, psat, monolithic=MONOLITHIC, use_rationality=True)
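# Sweep the gradient-guided sampler temperature beta = 2^i for i in -10..10 plus beta = inf,
# running DISS five times per temperature for n_iters iterations; the per-iteration statistics
# are concatenated and written to experiment_{mono|inc}_beta.json below.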
n_iters = 100
to_demo = planner.to_demo
dfs = []
for i in fn.chain(range(-10, 11, 1), [float('inf')]):
print(f'beta = 2^{i}')
for _ in range(5):
dfa_search = diss(
demos=[to_demo(TRC4), to_demo(TRC5)] if MONOLITHIC else [to_demo(TRC4[:-1])],
to_concept=identifer,
to_chain=to_chain,
competency=lambda *_: 10,
lift_path=planner.lift_path,
n_iters=n_iters,
reset_period=30,
surprise_weight=1,
size_weight=1/50,
sgs_temp=2**i,
example_drop_prob=1/20, #1e-2,
synth_timeout=20,
)
df, found_concepts = analyze(dfa_search, n_iters)
df['treatment'] = r'$\beta = 2^{' + f'{i}' + '}$'
df['logtemp'] = i
df['iteration'] = df.index
dfs.append(df)
df = pd.concat(dfs, ignore_index=True)
df['experiment'] = 'Monolithic' if MONOLITHIC else 'Incremental'
df.to_json(f'experiment_{"mono" if MONOLITHIC else "inc"}_beta.json')
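# Baseline: replace DISS's guided search with plain enumeration over the same concept class
# (presumably in order of increasing DFA size), for both the monolithic and incremental settings.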
from diss.experiment import concept_class
enumeration_dfs = []
for n_iters, monolithic in [(100, True), (40, False)]:
def to_chain(c, t, psat):
return planner.plan(c, t, psat, monolithic=monolithic, use_rationality=True)
pos_examples_mono = LabeledExamples(negative=[], positive=[('blue', 'green', 'yellow'), ('yellow',)])
pos_examples_inc = BASE_EXAMPLES @ LabeledExamples(negative=[], positive=[('blue', 'green', 'yellow'), ('yellow',)])
identifer = PartialDFAIdentifier(
partial = universal if monolithic else PARTIAL_DFA,
base_examples = pos_examples_mono if monolithic else pos_examples_inc
)
dfa_search = concept_class.enumerative_search(
demos=[to_demo(TRC4), to_demo(TRC5)] if monolithic else [to_demo(TRC4[:-1])],
identifer=identifer,
to_chain=to_chain,
competency=lambda *_: 0.8,
n_iters=n_iters,
surprise_weight=1, # Rescale surprise to make comparable to size.
size_weight=1/50,
)
df3, _ = analyze(dfa_search, n_iters)
df3['experiment'] = 'Monolithic' if monolithic else 'Incremental'
df3['treatment'] = 'enumeration'
df3['iteration'] = df3.index
enumeration_dfs.append(df3)
df_mono = pd.read_json('experiment_mono_beta.json')
df_inc = pd.read_json('experiment_inc_beta.json')
# Normalize all energies between 0 and 1.
for tmp1, tmp2 in zip([df_mono, df_inc], enumeration_dfs):
U1, U2 = tmp1['min energies'], tmp2['min energies']
U_min = min(U1.min(), U2.min())
U_max = max(U1.max(), U2.max())
for tmp in [tmp1, tmp2]:
U = tmp['min energies']
tmp['U'] = (U - U_min) / (U_max- U_min)
df_mono['experiment'] = 'Monolithic'
df_inc['experiment'] = 'Incremental'
diss_dfs = [df_mono, df_inc]
enumeration_dfs[0]['experiment'] = 'Monolithic'
enumeration_dfs[1]['experiment'] = 'Incremental'
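# Plot: per experiment, the normalized minimum-energy DFA found so far vs. iteration, with one
# coolwarm line per temperature (median over the five runs) and the enumeration baseline dashed
# in black; each figure is saved as a .pgf file.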
sns.set(rc={"figure.figsize":(10, 6)})
for df_diss, df_enum, iters, experiment in zip(diss_dfs, enumeration_dfs, [80, 40], ['Monolithic', 'Incremental']):
if experiment == 'Incremental':
plot = plt.scatter(list(range(21)), list(range(-10, 11)), c=list(range(-10, 11)), cmap='coolwarm')
plt.clf()
cbar = plt.colorbar(plot, extend='max')
cbar.ax.set_ylabel(r'$\ln \beta$', rotation=270)
hdl = plt.plot(df_enum['iteration'], df_enum['U'], '--', c='black', label='enumerate')
grid = sns.lineplot(
data=df_diss, x='iteration', y='U',
palette='coolwarm', hue='treatment', legend=False,
estimator=np.median, ci=None,
)
plt.title(f'{experiment=}')
plt.xlim(0, iters)
plt.xlabel('Iteration')
    plt.ylabel('(normalized) minimum energy DFA found')
plt.legend()
plt.savefig(f'mass_{experiment}.pgf')
plt.show()