diff --git a/rdflib/graph.py b/rdflib/graph.py index d74dd85cf..857491a2e 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -847,26 +847,32 @@ def set( def subjects( self, predicate: Union[None, Path, _PredicateType] = None, - object: Optional[_ObjectType] = None, + object: Optional[Union[_ObjectType, List[_ObjectType]]] = None, unique: bool = False, ) -> Generator[_SubjectType, None, None]: """A generator of (optionally unique) subjects with the given - predicate and object""" - if not unique: - for s, p, o in self.triples((None, predicate, object)): - yield s + predicate and object(s)""" + # if the object is a list of Nodes, yield results from subject() call for each + if isinstance(object, list): + for obj in object: + for s in self.subjects(predicate, obj, unique): + yield s else: - subs = set() - for s, p, o in self.triples((None, predicate, object)): - if s not in subs: + if not unique: + for s, p, o in self.triples((None, predicate, object)): yield s - try: - subs.add(s) - except MemoryError as e: - logger.error( - f"{e}. Consider not setting parameter 'unique' to True" - ) - raise + else: + subs = set() + for s, p, o in self.triples((None, predicate, object)): + if s not in subs: + yield s + try: + subs.add(s) + except MemoryError as e: + logger.error( + f"{e}. Consider not setting parameter 'unique' to True" + ) + raise def predicates( self, @@ -894,27 +900,32 @@ def predicates( def objects( self, - subject: Optional[_SubjectType] = None, + subject: Optional[Union[_SubjectType, List[_SubjectType]]] = None, predicate: Union[None, Path, _PredicateType] = None, unique: bool = False, ) -> Generator[_ObjectType, None, None]: """A generator of (optionally unique) objects with the given - subject and predicate""" - if not unique: - for s, p, o in self.triples((subject, predicate, None)): - yield o + subject(s) and predicate""" + if isinstance(subject, list): + for subj in subject: + for o in self.objects(subj, predicate, unique): + yield o else: - objs = set() - for s, p, o in self.triples((subject, predicate, None)): - if o not in objs: + if not unique: + for s, p, o in self.triples((subject, predicate, None)): yield o - try: - objs.add(o) - except MemoryError as e: - logger.error( - f"{e}. Consider not setting parameter 'unique' to True" - ) - raise + else: + objs = set() + for s, p, o in self.triples((subject, predicate, None)): + if o not in objs: + yield o + try: + objs.add(o) + except MemoryError as e: + logger.error( + f"{e}. Consider not setting parameter 'unique' to True" + ) + raise def subject_predicates( self, object: Optional[_ObjectType] = None, unique: bool = False diff --git a/test/test_graph/test_graph_generators.py b/test/test_graph/test_graph_generators.py index 0d89c9b7f..bec7ccb4c 100644 --- a/test/test_graph/test_graph_generators.py +++ b/test/test_graph/test_graph_generators.py @@ -75,3 +75,19 @@ def test_parse_berners_lee_card_into_graph(): assert len(list(graph.subjects(unique=True))) == no_of_unique_subjects assert len(list(graph.predicates(unique=True))) == no_of_unique_predicates assert len(list(graph.objects(unique=True))) == no_of_unique_objects + + +def test_subjects_multi(): + graph = Graph() + add_stuff(graph) + assert len([subj for subj in graph.subjects(LIKES, [CHEESE, PIZZA])]) == 5 + assert len([subj for subj in graph.subjects(LIKES, [])]) == 0 + assert len([subj for subj in graph.subjects(LIKES | HATES, [CHEESE, PIZZA])]) == 6 + + +def test_objects_multi(): + graph = Graph() + add_stuff(graph) + assert len([obj for obj in graph.objects([TAREK, BOB], LIKES)]) == 6 + assert len([obj for obj in graph.objects([], LIKES)]) == 0 + assert len([obj for obj in graph.objects([TAREK, BOB], LIKES | HATES)]) == 8