Skip to content

Commit 0aeca05

Browse files
authored
ENH harden Method and Operator node audits (#482)
* ENH harden Method and Operator node audits * ... * ... * trigger CI * add tests * CI timeout * changelog * codecov count change * trigger CI
1 parent 53cabd2 commit 0aeca05

File tree

8 files changed

+84
-16
lines changed

8 files changed

+84
-16
lines changed

.codecov.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ codecov:
33
branch: main
44
require_ci_to_pass: true
55
notify:
6-
after_n_builds: 12
6+
after_n_builds: 21
77
wait_for_ci: true
88
ignore:
99
- "skops/_min_dependencies.py" # This file is not tested, and won't be.

.github/workflows/build-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
]
2828

2929
# Timeout: https://stackoverflow.com/a/59076067/4521646
30-
timeout-minutes: 15
30+
timeout-minutes: 30
3131

3232
steps:
3333
# The following two steps are workarounds to retrieve the "real" commit

docs/changes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ skops Changelog
1111

1212
v0.12
1313
-----
14+
- `huggingface_hub` dependency is now optional. :pr:`462` by `Adrin Jalali`_.
15+
- Objects' `__reduce__` is used when the output of it is of the form
16+
`(type, (constructor_args,))` where type is the same as the `type(obj)`.
17+
:pr:`467` by `Adrin Jalali`_.
18+
- `MethodNode` and `OperatorNode` have a hardened audit now, removing certain security
19+
vulnerabilities. :pr:`482` by `Adrin Jalali`_.
1420

1521
v0.11
1622
-----

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ filterwarnings = [
9090
"ignore:The ExtraTreesQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning",
9191
# BaseEstimator._validate_data deprecation warning in sklearn 1.6 #TODO can be removed when a new release of quantile-forest is out
9292
"ignore:`BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7:FutureWarning",
93+
# This comes from matplotlib somehow
94+
"ignore:'mode' parameter is deprecated and will be removed in Pillow 13:DeprecationWarning",
9395
]
9496
addopts = "--cov=skops --cov-report=term-missing --doctest-modules"
9597

skops/io/_audit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import io
44
from contextlib import contextmanager
5-
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union
5+
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Type, Union
66

77
from ._protocol import PROTOCOL
88
from ._utils import LoadContext, get_module, get_type_paths
@@ -39,7 +39,7 @@ def check_type(module_name: str, type_name: str, trusted: Sequence[str]) -> bool
3939
return module_name + "." + type_name in trusted
4040

4141

42-
def audit_tree(tree: Node) -> None:
42+
def audit_tree(tree: Node, trusted: Iterable[str] | None) -> None:
4343
"""Audit a tree of nodes.
4444
4545
A tree is safe if it only contains trusted types.
@@ -54,7 +54,8 @@ def audit_tree(tree: Node) -> None:
5454
UntrustedTypesFoundException
5555
If the tree contains an untrusted type.
5656
"""
57-
unsafe = tree.get_unsafe_set()
57+
trusted = trusted or set()
58+
unsafe = tree.get_unsafe_set() - set(trusted)
5859
if unsafe:
5960
raise UntrustedTypesFoundException(unsafe)
6061

skops/io/_general.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -509,13 +509,15 @@ def method_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
509509
# dependent on a specific instance of an object.
510510
# It stores the state of the object the method is bound to,
511511
# and prepares both to be persisted.
512+
owner = obj.__self__
513+
func_name = obj.__func__.__name__
512514
res = {
513-
"__class__": obj.__class__.__name__,
515+
"__class__": owner.__class__.__name__,
514516
"__module__": get_module(obj),
515517
"__loader__": "MethodNode",
516518
"content": {
517-
"func": obj.__func__.__name__,
518-
"obj": get_state(obj.__self__, save_context),
519+
"func": func_name,
520+
"obj": get_state(owner, save_context),
519521
},
520522
}
521523
return res
@@ -529,13 +531,32 @@ def __init__(
529531
trusted: Optional[Sequence[str]] = None,
530532
) -> None:
531533
super().__init__(state, load_context, trusted)
534+
obj = get_tree(state["content"]["obj"], load_context, trusted=trusted)
535+
if self.module_name != obj.module_name or self.class_name != obj.class_name:
536+
raise ValueError(
537+
f"Expected object of type {self.module_name}.{self.class_name}, got"
538+
f" {obj.module_name}.{obj.class_name}. This is probably due to a"
539+
" corrupted or a malicious file."
540+
)
532541
self.children = {
533-
"obj": get_tree(state["content"]["obj"], load_context, trusted=trusted),
542+
"obj": obj,
534543
"func": state["content"]["func"],
535544
}
536545
# TODO: what do we trust?
537546
self.trusted = self._get_trusted(trusted, [])
538547

548+
def get_unsafe_set(self) -> set[str]:
549+
res = super().get_unsafe_set()
550+
obj_node = self.children["obj"]
551+
res.add(
552+
obj_node.module_name # type: ignore
553+
+ "."
554+
+ obj_node.class_name # type: ignore
555+
+ "."
556+
+ self.children["func"]
557+
)
558+
return res
559+
539560
def _construct(self):
540561
loaded_obj = self.children["obj"].construct()
541562
method = getattr(loaded_obj, self.children["func"])
@@ -658,6 +679,11 @@ def __init__(
658679
trusted: Optional[Sequence[str]] = None,
659680
) -> None:
660681
super().__init__(state, load_context, trusted)
682+
if self.module_name != "operator":
683+
raise ValueError(
684+
f"Expected module 'operator', got {self.module_name}. This is probably"
685+
" due to a corrupted or a malicious file."
686+
)
661687
self.trusted = self._get_trusted(trusted, [])
662688
self.children["attrs"] = get_tree(state["attrs"], load_context, trusted=trusted)
663689

skops/io/_persist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def load(file: str | Path, trusted: Optional[Sequence[str]] = None) -> Any:
148148
schema = json.loads(input_zip.read("schema.json"))
149149
load_context = LoadContext(src=input_zip, protocol=schema["protocol"])
150150
tree = get_tree(schema, load_context, trusted=trusted)
151-
audit_tree(tree)
151+
audit_tree(tree, trusted=trusted)
152152
instance = tree.construct()
153153

154154
return instance
@@ -188,7 +188,7 @@ def loads(data: bytes, trusted: Optional[Sequence[str]] = None) -> Any:
188188
schema = json.loads(zip_file.read("schema.json"))
189189
load_context = LoadContext(src=zip_file, protocol=schema["protocol"])
190190
tree = get_tree(schema, load_context, trusted=trusted)
191-
audit_tree(tree)
191+
audit_tree(tree, trusted=trusted)
192192
instance = tree.construct()
193193

194194
return instance

skops/io/tests/test_audit.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
import io
22
import json
3+
import operator
34
import re
45
from contextlib import suppress
56
from zipfile import ZipFile
67

78
import pytest
89
from sklearn.linear_model import LogisticRegression
10+
from sklearn.preprocessing import FunctionTransformer
911

1012
from skops.io import dumps, get_untrusted_types
1113
from skops.io._audit import Node, audit_tree, check_type, get_tree, temp_setattr
12-
from skops.io._general import DictNode, JsonNode, ObjectNode, dict_get_state
14+
from skops.io._general import (
15+
DictNode,
16+
JsonNode,
17+
MethodNode,
18+
ObjectNode,
19+
OperatorFuncNode,
20+
dict_get_state,
21+
method_get_state,
22+
operator_func_get_state,
23+
)
1324
from skops.io._utils import LoadContext, SaveContext, get_state, gettype
1425

1526

@@ -46,26 +57,26 @@ def test_audit_tree_untrusted():
4657
"Untrusted types found in the file: ['test_audit.CustomType']."
4758
),
4859
):
49-
audit_tree(node)
60+
audit_tree(node, None)
5061

5162
# there shouldn't be an error with trusted=everything
5263
node = DictNode(state, LoadContext(None, -1), trusted=["test_audit.CustomType"])
53-
audit_tree(node)
64+
audit_tree(node, None)
5465

5566
untrusted_list = get_untrusted_types(data=dumps(var))
5667
assert untrusted_list == ["test_audit.CustomType"]
5768

5869
# passing the type would fix it.
5970
node = DictNode(state, LoadContext(None, -1), trusted=untrusted_list)
60-
audit_tree(node)
71+
audit_tree(node, None)
6172

6273

6374
def test_audit_tree_defaults():
6475
# test that the default types are trusted
6576
var = {"a": 1, 2: "b"}
6677
state = dict_get_state(var, SaveContext(None, 0, {}))
6778
node = DictNode(state, LoadContext(None, -1), trusted=None)
68-
audit_tree(node)
79+
audit_tree(node, None)
6980

7081

7182
@pytest.mark.parametrize(
@@ -170,3 +181,25 @@ def test_format_json_node(inp, expected):
170181
state = get_state(inp, SaveContext(None))
171182
node = JsonNode(state, LoadContext(None, -1))
172183
assert node.format() == expected
184+
185+
186+
def test_method_node_invalid_state():
187+
# Test that MethodNode raises a ValueError if the state is invalid.
188+
# The __class__ and __module__ should match what's inside the content.
189+
var = FunctionTransformer().fit
190+
state = method_get_state(var, SaveContext(None, 0, {}))
191+
state["content"]["obj"]["__class__"] = "foo"
192+
load_context = LoadContext(None, -1)
193+
194+
with pytest.raises(ValueError, match="Expected object of type"):
195+
MethodNode(state, load_context, trusted=None)
196+
197+
198+
def test_operator_func_node_invalid_state():
199+
var = operator.methodcaller("fit")
200+
state = operator_func_get_state(var, SaveContext(None, 0, {}))
201+
state["__module__"] = "foo"
202+
load_context = LoadContext(None, -1)
203+
204+
with pytest.raises(ValueError, match="Expected module 'operator'"):
205+
OperatorFuncNode(state, load_context, trusted=None)

0 commit comments

Comments
 (0)