Skip to content

Commit 9c024d5

Browse files
committed
fixed bug where graph attributes were stored to the trace at time=0
1 parent bf6cf2b commit 9c024d5

File tree

2 files changed

+30
-26
lines changed

2 files changed

+30
-26
lines changed

pyreason/scripts/interpretation/interpretation.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@ def reason(interpretations_node, interpretations_edge, tmax, rules, nodes, edges
210210
else:
211211
# Check for inconsistencies (multiple facts)
212212
if check_consistent_node(interpretations_node, comp, (l, bnd)):
213-
u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, mode='fact')
213+
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
214+
u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, mode=mode)
214215

215216
update = u or update
216217
# Update convergence params
@@ -249,7 +250,8 @@ def reason(interpretations_node, interpretations_edge, tmax, rules, nodes, edges
249250
else:
250251
# Check for inconsistencies
251252
if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
252-
u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, mode='fact')
253+
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
254+
u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, mode=mode)
253255

254256
update = u or update
255257
# Update convergence params
@@ -575,7 +577,7 @@ def _satisfies_threshold(num_neigh, num_qualified_component, threshold):
575577

576578

577579
@numba.njit(cache=True)
578-
def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, mode):
580+
def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, mode):
579581
updated = False
580582
# This is to prevent a key error in case the label is a specific label
581583
try:
@@ -592,17 +594,18 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
592594
updated_bnds.append(world.world[l])
593595

594596
# Add to rule trace if update happened and add to atom trace if necessary
595-
rule_trace.append((numba.types.int8(t_cnt), numba.types.int8(fp_cnt), comp, l, world.world[l].copy()))
596-
if atom_trace:
597-
# Mode can be fact or rule, updation of trace will happen accordingly
598-
if mode=='fact':
599-
qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
600-
qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
601-
name = facts_to_be_applied_trace[idx]
602-
_update_rule_trace_node(rule_trace_atoms, qn, qe, prev_bnd, name)
603-
elif mode=='rule':
604-
qn, qe, name = rules_to_be_applied_trace[idx]
605-
_update_rule_trace_node(rule_trace_atoms, qn, qe, prev_bnd, name)
597+
if save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact':
598+
rule_trace.append((numba.types.int8(t_cnt), numba.types.int8(fp_cnt), comp, l, world.world[l].copy()))
599+
if atom_trace:
600+
# Mode can be fact or rule, updation of trace will happen accordingly
601+
if mode=='fact' or mode=='graph-attribute-fact':
602+
qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
603+
qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
604+
name = facts_to_be_applied_trace[idx]
605+
_update_rule_trace_node(rule_trace_atoms, qn, qe, prev_bnd, name)
606+
elif mode=='rule':
607+
qn, qe, name = rules_to_be_applied_trace[idx]
608+
_update_rule_trace_node(rule_trace_atoms, qn, qe, prev_bnd, name)
606609

607610

608611
# Update complement of predicate (if exists) based on new knowledge of predicate
@@ -653,7 +656,7 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
653656

654657

655658
@numba.njit(cache=True)
656-
def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, mode):
659+
def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, mode):
657660
updated = False
658661
# This is to prevent a key error in case the label is a specific label
659662
try:
@@ -670,16 +673,17 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
670673
updated_bnds.append(world.world[l])
671674

672675
# Add to rule trace if update happened and add to atom trace if necessary
673-
rule_trace.append((numba.types.int8(t_cnt), numba.types.int8(fp_cnt), comp, l, world.world[l].copy()))
674-
if atom_trace:
675-
# Mode can be fact or rule, updation of trace will happen accordingly
676-
if mode=='fact':
677-
qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
678-
name = facts_to_be_applied_trace[idx]
679-
_update_rule_trace_edge(rule_trace_atoms, qn, prev_bnd, name)
680-
elif mode=='rule':
681-
qn, name = rules_to_be_applied_trace[idx]
682-
_update_rule_trace_edge(rule_trace_atoms, qn, prev_bnd, name)
676+
if save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact':
677+
rule_trace.append((numba.types.int8(t_cnt), numba.types.int8(fp_cnt), comp, l, world.world[l].copy()))
678+
if atom_trace:
679+
# Mode can be fact or rule, updation of trace will happen accordingly
680+
if mode=='fact' or mode=='graph-attribute-fact':
681+
qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
682+
name = facts_to_be_applied_trace[idx]
683+
_update_rule_trace_edge(rule_trace_atoms, qn, prev_bnd, name)
684+
elif mode=='rule':
685+
qn, name = rules_to_be_applied_trace[idx]
686+
_update_rule_trace_edge(rule_trace_atoms, qn, prev_bnd, name)
683687

684688

685689
# Update complement of predicate (if exists) based on new knowledge of predicate

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setup(
1010
name = 'pyreason',
11-
version = '1.2.0',
11+
version = '1.2.1',
1212
author = 'Dyuman Aditya',
1313
author_email = '[email protected]',
1414
description = 'An explainable inference software supporting annotated, real valued, graph based and temporal logic',

0 commit comments

Comments
 (0)