Changes from all commits
26 commits
53e6ff0
initial commit. test passed.
proboscis Oct 7, 2019
154809a
Create README.md
proboscis Oct 7, 2019
547136c
add ./pytest_cache
proboscis Oct 7, 2019
6f2ee5c
fixed wrong import statements
proboscis Oct 7, 2019
cadf924
bugfix on zipped series
proboscis Oct 7, 2019
e1700a2
removed get_value from Trace constructor as it is redundant
proboscis Oct 7, 2019
9afceea
add many features such as: improved visualization, interact() method…
proboscis Nov 28, 2019
bb87280
trying to make multiprocessing work on slicing generator
proboscis Dec 18, 2019
2b2e091
commit before tree alternation
proboscis Jan 26, 2020
66788fb
added cmap and iwidget
proboscis Jan 30, 2020
1143af4
add loguru dependency
proboscis Feb 6, 2020
db1bea7
implemented auto_image
proboscis Mar 18, 2020
8901ca6
now AutoData is picklable
proboscis May 25, 2020
3ab698d
added auto function
proboscis May 26, 2020
2bec7b4
add tutorial
proboscis May 26, 2020
32c2a2e
add readme
proboscis Jul 1, 2020
cf89ff4
edit readme
proboscis Jul 1, 2020
d4a82d0
add tuple conversion
proboscis Oct 23, 2020
332dbbf
tuple conversion smart search is implemented
proboscis Oct 23, 2020
8a303a5
add conversions
proboscis Dec 4, 2020
b6409bd
fix tuple conversion
proboscis Dec 14, 2020
4e9e8a1
remove comment
proboscis Dec 14, 2020
412650b
implemented caching system for astar
proboscis Jan 18, 2021
c64b394
working commit
proboscis Jan 30, 2021
05dbb0d
Merge branch 'master' into origin_master
proboscis May 30, 2023
b0c3db9
Bump notebook from 6.4.7 to 6.4.12
dependabot[bot] May 30, 2023
346 changes: 346 additions & 0 deletions coconut/astar.coco
@@ -0,0 +1,346 @@
import heapq
import os
import shelve
from itertools import chain
from os.path import expanduser
from pprint import pformat

import dill  # for pickling inner lambdas
from loguru import logger
from tqdm.autonotebook import tqdm

from data_tree.coconut.monad import try_monad, Try, Success, Failure
from data_tree.util import DefaultShelveCache
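
# An Edge is a single conversion step: applying `f` maps a value whose state is `src`
# to a value whose state is `dst`, at the given search `cost`.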
data Edge(src,dst,f,cost,name="unnamed")
class NoRouteException(Exception)
class ConversionError(Exception)
class Conversion:
def __init__(self,edges):
self.edges = edges

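    # Calling a Conversion applies each edge's function in order. If an edge raises, the
    # original input is pickled to last_erroneous_conversion.pkl so the failure can be reproduced.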
def __call__(self,x):
init_x = x
for e in self.edges:
try:
x = e.f(x)
except Exception as ex:
import inspect
import pickle
logger.error(f"caught an exception. inspecting..{ex}")
logger.warning(f"saving erroneous conversion for debug")
info= dict(
start = self.edges[0].src,
end = self.edges[-1].dst,
x = init_x
)
logger.debug(f"conversion info = start:{info['start']},x:{info['x']}")
with open("last_erroneous_conversion.pkl","wb") as f:
pickle.dump(info,f)
source = inspect.getsource(e.f)
logger.warning("saved last conversion error cause")
raise ConversionError(f"exception in edge:{e.name} \n paths:{[e.name for e in self.edges]} \n x:{x} \n edge source:{source}") from ex
return x

def __getitem__(self,item):
if isinstance(item,int):
edges = [(self.edges[item])]
else:
edges = self.edges[item]
return Conversion(edges)

def __repr__(self):
if len(self.edges):
start = self.edges[0].src
end = self.edges[-1].dst
else:
start = "None"
end = "None"
info = dict(
name="Conversion",
start = start,
end = end,
path = [e.name for e in self.edges],
cost = sum([e.cost for e in self.edges])
)
return pformat(info)
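    # trace() re-applies the edges one by one and yields the intermediate value after each step,
    # which is handy for locating the edge where a pipeline goes wrong.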
def trace(self,tgt):
x = tgt
for e in self.edges:
x = e.f(x)
yield dict(edge= e,x=x)

def new_conversion(edges):
return Conversion(edges)
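# Heap entries are ordered by score alone; plain (score, data) tuples would fail to compare
# whenever two entries tie on score and the payloads themselves are not orderable.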
class _HeapContainer:
def __init__(self,score,data):
self.score = score
self.data = data
def __lt__(self,other):
return self.score < other.score

def _astar(
start,
matcher,
neighbors,
max_depth = 100):
"""
neighbors: node->[(mapper,next_node,cost,name)]
"""
to_visit = []
scores = dict()
scores[start] = 0#heuristics(start,None)
heapq.heappush(to_visit,
_HeapContainer(scores[start],(start,[])))
visited = 0
bar = tqdm(desc="solving with astar")
while to_visit:
bar.update(1)
hc = heapq.heappop(to_visit)
score = hc.score
(pos,trace) = hc.data
visited += 1
#print(f"visit:{pos}")
#print(f"{((trace[-1].a,trace[-1].name) if trace else 'no trace')}")
#print(f"visit:{trace[-1] if trace else 'no trace'}")
if len(trace) >= max_depth: # terminate search on max_depth
continue
if matcher(pos): # reached a goal
logger.debug(f"found after {visited} visits.")
logger.debug(f"search result:\n{new_conversion(trace)}")
bar.close()
return trace
for mapper,next_node,cost,name in neighbors(pos):
            assert isinstance(cost,int), f"cost is not an int. cost:{cost},name:{name},pos:{pos}"
new_trace = trace + [Edge(pos,next_node,mapper,cost,name)]
try:
new_score = scores[pos] + cost #+ heuristics(next_node,end)
except Exception as e:
logger.error(f"pos:{pos},cost:{cost},next_node:{next_node}")
raise e
if next_node in scores and scores[next_node] <= new_score:
continue

else:
scores[next_node] = new_score
heapq.heappush(to_visit,_HeapContainer(new_score,(next_node,new_trace)))

raise NoRouteException(f"no route found from {start} matching {matcher}")



def _astar_direct(
start,
end,
neighbors,
smart_neighbors,
heuristics,
edge_cutter,
max_depth = 100,
silent=False
):
"""
neighbors: node->[(mapper,next_node,cost,name)]
"""
to_visit = []
scores = dict()
scores[start] = heuristics(start,end)
heapq.heappush(to_visit,
_HeapContainer(scores[start],(start,[])))
visited = 0
bar = tqdm(desc="solving with astar_direct")
last_bar_update = visited
while to_visit:
if visited - last_bar_update > 1000:
bar.update(visited-last_bar_update)
last_bar_update = visited
hc = heapq.heappop(to_visit)
score = hc.score
(pos,trace) = hc.data
if not silent:
#logger.info(f"{score}:{pos}")
#bar.write(f"score:{score}")
pass

#logger.debug(f"visiting:{pos}")
visited += 1
#print(f"visit:{pos}")
#print(f"{((trace[-1].a,trace[-1].name) if trace else 'no trace')}")
#print(f"visit:{trace[-1] if trace else 'no trace'}")
if len(trace) >= max_depth: # terminate search on max_depth
continue
if pos == end: # reached a goal
if not silent:
logger.debug(f"found after {visited} visits.")
logger.debug(f"search result:\n{new_conversion(trace)}")
bar.close()
return trace
normal_nodes = list(neighbors(pos))
smart_nodes = list(smart_neighbors(pos,end))
#logger.debug(f"{normal_nodes},{smart_nodes}")
if visited % 10000 == 0:
msg = str(pos)[:50]
bar.set_description(f"""pos:{msg:<50}""")
for i,(mapper,next_node,cost,name) in enumerate(chain(normal_nodes,smart_nodes)):
            assert isinstance(cost,int), f"cost is not an int:{pos}"
new_trace = trace + [Edge(pos,next_node,mapper,cost,name)]
try:
new_score = scores[pos] + cost + heuristics(next_node,end)
except Exception as e:
logger.error(f"pos:{pos},cost:{cost},next_node:{next_node}")
raise e
if next_node in scores and scores[next_node] <= new_score:
continue
elif(edge_cutter(pos,next_node,end)):
continue
else:
scores[next_node] = new_score
heapq.heappush(to_visit,_HeapContainer(new_score,(next_node,new_trace)))

raise NoRouteException(f"no route found from {start} matching {end}. searched {visited} nodes.")

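# Compose each raw search with new_conversion (wrap the resulting edge list in a Conversion),
# then pipe through try_monad so callers receive a Success/Failure value instead of an exception.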
astar = (_astar ..> new_conversion) |> try_monad
astar_direct = (_astar_direct ..> new_conversion) |> try_monad

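# Defaults: a zero heuristic (plain Dijkstra-style search) and an edge cutter that never prunes.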
def _zero_heuristics(x,y):
return 0

def _no_cutter(x,y,end):
return False

MAX_MEMO = 2**(20)

class AStarSolver:
"""
to make this picklable, you have to have these caches on global variable.
use getstate and setstate
actually, having to pickle this solver every time you pass auto_data is not feasible.
so, lets have auto_data to hold solver in global variable and never pickle AStarSolver!.
so forget about lru_cache pickling issues.
"""
def __init__(self,
rules=None,
smart_rules=None,
heuristics = _zero_heuristics,
edge_cutter=_no_cutter,
cache_path=os.path.join(expanduser("~"),".cache/autodata.shelve")
):
"""
rules: List[Rule]
Rule: (state)->List[(converter:(data)->data,new_state,cost,conversion_name)]
"""
from lru import LRU
self.rules = rules if rules is not None else []
self.smart_rules = smart_rules if smart_rules is not None else []
self.heuristics = heuristics
self.edge_cutter = edge_cutter
self.neighbors_memo = LRU(MAX_MEMO) # you cannot pickle a lru.LRU object, thus you cannot pickle this class for multiprocessing.
self.search_memo = LRU(MAX_MEMO)
self.smart_neighbors_memo = LRU(MAX_MEMO)
self.direct_search_cache = DefaultShelveCache(self._search_direct,cache_path)
self.direct_search_memo = LRU(MAX_MEMO)

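    # neighbors/smart_neighbors collect candidate edges by consulting every registered rule;
    # results are memoized per node (or per (node, end) pair) in the LRU caches above.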
def neighbors(self,node):
if node in self.neighbors_memo:
return self.neighbors_memo[node]

res = []
for rule in self.rules:
edges = rule(node)
if edges is not None:
res += edges
self.neighbors_memo[node] = res
return res
def smart_neighbors(self,node,end):
if (node,end) in self.smart_neighbors_memo:
return self.smart_neighbors_memo[(node,end)]
res = []
for rule in self.smart_rules:
edges = rule(node,end)
if edges is not None:
res += edges
self.smart_neighbors_memo[(node,end)] = res
#logger.debug(res)
return res

def invalidate_cache(self):
self.neighbors_memo.clear()
self.search_memo.clear()
self.direct_search_memo.clear()

def add_rule(self,f):
self.rules.append(f)
self.invalidate_cache()

def search(self,start,matcher):
#problem is that you can't hash matcher
#let's use id of matcher for now.
q = (start,id(matcher))
if q in self.search_memo:
res = self.search_memo[q]
else:
logger.debug(f"searching from {start} for matching {matcher}")
res = astar(
start=start,
matcher=matcher,
neighbors=self.neighbors,
)
self.search_memo[q] = res

case res:
match Success(res):
return res
match Failure(e,trc):
raise e
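    # Rebuild a full Conversion from cached (src, dst) hops by re-running the direct search for each hop.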
def _research_from_edges(self,edges):
searched_edges = list(chain(*[self._search_direct((src,dst,True)).edges for src,dst in edges]))
#logger.info(f"searched_edges:{searched_edges}")
return Conversion(searched_edges)

def _search_direct(self,q):
        # The converter functions themselves cannot be persisted in the shelve cache,
        # so only the (src, dst) path is stored and the conversion is re-searched when loaded.
start,end,silent = q
if not silent:
logger.debug(f"searching {start} to {end}")
res = astar_direct(
start=start,
end=end,
neighbors=self.neighbors,
smart_neighbors=self.smart_neighbors,
heuristics=self.heuristics,
edge_cutter=self.edge_cutter,
silent=silent
)
case res:
match Success(res):
return res
match Failure(e,trc):
raise e
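    # Three-level lookup: the in-memory LRU first, then the on-disk shelve cache (which stores only
    # (src, dst) paths), and finally a fresh A* search whose result is written back to both caches.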
def search_direct(self,start,end,silent=False):
key = (start,end,silent)
if key in self.direct_search_memo:
return self.direct_search_memo[key]
elif key in self.direct_search_cache:
edges = self.direct_search_cache[key]
conversion = self._research_from_edges(edges)
self.direct_search_memo[key] = conversion
logger.debug(f"researched_conversion:\n{conversion}")
return conversion
else:
conversion = self._search_direct(key)
self.direct_search_cache[key] = [(e.src,e.dst) for e in conversion.edges]
self.direct_search_memo[key] = conversion
            # Every sub-path of this conversion could be memoized as well, but since those intermediate
            # states live in a different space than the incoming queries, it would not speed anything up.
            # If the search were aware of casting, it could consult the cache first.

return conversion

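    # Try each candidate end state in order and return the first conversion that succeeds.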
def search_direct_any(self,start,ends):
for cand in ends:
try:
res = self.search_direct(start,cand)
return res
except Exception as e:
pass
raise NoRouteException(f"no route found from {start} to any of {ends}")