Skip to content

Commit d24fea6

Browse files
authored
INC Bench pruning support (#295)
Signed-off-by: bmyrcha <[email protected]>
1 parent 5b9be25 commit d24fea6

File tree

90 files changed

+5216
-764
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+5216
-764
lines changed

neural_compressor/ux/components/benchmark/execute_benchmark.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,6 @@ def execute_benchmark(data: Dict[str, Any]) -> None:
6464
project_id = benchmark_details["project_id"]
6565
project_details = ProjectAPIInterface.get_project_details({"id": project_id})
6666

67-
BenchmarkAPIInterface.update_benchmark_status(
68-
{
69-
"id": benchmark_id,
70-
"status": ExecutionStatus.WIP,
71-
},
72-
)
73-
7467
response_data = execute_real_benchmark(
7568
request_id=request_id,
7669
project_details=project_details,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright (c) 2022 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Pruning configuration generator class."""
16+
from typing import Any
17+
18+
from neural_compressor.ux.components.config_generator.config_generator import ConfigGenerator
19+
from neural_compressor.ux.utils.workload.config import Config
20+
from neural_compressor.ux.utils.workload.evaluation import Accuracy, Evaluation, Metric
21+
from neural_compressor.ux.utils.workload.pruning import Pruning
22+
23+
24+
class PruningConfigGenerator(ConfigGenerator):
25+
"""PruningConfigGenerator class."""
26+
27+
def __init__(self, *args: Any, **kwargs: Any) -> None:
28+
"""Initialize configuration generator."""
29+
super().__init__(*args, **kwargs)
30+
data = kwargs.get("data", {})
31+
self.pruning_configuration: dict = data["pruning_details"]
32+
33+
def generate(self) -> None:
34+
"""Generate yaml config file."""
35+
config = Config()
36+
config.load(self.predefined_config_path)
37+
config.quantization = None
38+
config.model = self.generate_model_config()
39+
config.evaluation = self.generate_evaluation_config()
40+
config.pruning = self.generate_pruning_config()
41+
config.dump(self.config_path)
42+
43+
def generate_evaluation_config(self) -> Evaluation:
44+
"""Generate evaluation configuration."""
45+
evaluation = Evaluation()
46+
evaluation.accuracy = Accuracy()
47+
48+
if self.metric:
49+
evaluation.accuracy.metric = Metric(self.metric)
50+
51+
evaluation.accuracy.dataloader = self.generate_dataloader_config(batch_size=1)
52+
evaluation.set_accuracy_postprocess_transforms(self.transforms)
53+
return evaluation
54+
55+
def generate_pruning_config(self) -> Pruning:
56+
"""Generate graph optimization configuration."""
57+
pruning = Pruning(self.pruning_configuration)
58+
if pruning.train is not None:
59+
pruning.train.dataloader = self.generate_dataloader_config(batch_size=1)
60+
pruning.train.set_postprocess_transforms(self.transforms)
61+
return pruning
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright (c) 2022 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Configuration type parser."""
16+
from typing import Any
17+
18+
19+
class PruningConfigParser:
20+
"""Pruning configuration parser class."""
21+
22+
def parse(self, input_data: list) -> dict:
23+
"""Parse configuration."""
24+
raise NotImplementedError
25+
26+
def generate_tree(self, input_data: dict) -> list:
27+
"""Generate tree from pruning configuration."""
28+
parsed_tree = self.parse_entry(input_data)
29+
return parsed_tree
30+
31+
def parse_entry(self, input_data: dict) -> Any:
32+
"""Parse configuration entry to tree element."""
33+
config_tree = []
34+
for key, value in input_data.items():
35+
if key in ["train", "approach"] and value is None:
36+
continue
37+
parsed_entry = {"name": key}
38+
if isinstance(value, dict):
39+
children = self.parse_entry(value)
40+
parsed_entry.update({"children": children})
41+
elif isinstance(value, list):
42+
for list_entry in value:
43+
parsed_list_entries = self.parse_entry(list_entry)
44+
parsed_entry.update({"children": parsed_list_entries})
45+
else:
46+
parsed_entry.update({"value": value})
47+
config_tree.append(parsed_entry)
48+
return config_tree
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright (c) 2022 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# flake8: noqa
16+
# mypy: ignore-errors
17+
"""pruning_support
18+
19+
Revision ID: 644ec953a7dc
20+
Revises: 6ece06672ed3
21+
Create Date: 2022-12-09 17:22:17.310141
22+
23+
"""
24+
import sqlalchemy as sa
25+
from alembic import op
26+
from sqlalchemy.orm import sessionmaker
27+
28+
from neural_compressor.ux.components.db_manager.db_manager import DBManager
29+
from neural_compressor.ux.components.db_manager.db_models.optimization_type import OptimizationType
30+
from neural_compressor.ux.components.db_manager.db_models.precision import (
31+
Precision,
32+
precision_optimization_type_association,
33+
)
34+
from neural_compressor.ux.utils.consts import OptimizationTypes, Precisions
35+
36+
db_manager = DBManager()
37+
Session = sessionmaker(bind=db_manager.engine)
38+
39+
# revision identifiers, used by Alembic.
40+
revision = "644ec953a7dc"
41+
down_revision = "6ece06672ed3"
42+
branch_labels = None
43+
depends_on = None
44+
45+
46+
def upgrade():
47+
# ### commands auto generated by Alembic - please adjust! ###
48+
with Session.begin() as db_session:
49+
pruning_optimization_id = OptimizationType.add(
50+
db_session=db_session,
51+
name=OptimizationTypes.PRUNING.value,
52+
)
53+
fp32_precision_id = Precision.get_precision_by_name(
54+
db_session=db_session,
55+
precision_name=Precisions.FP32.value,
56+
)[0]
57+
58+
query = precision_optimization_type_association.insert().values(
59+
precision_id=fp32_precision_id,
60+
optimization_type_id=pruning_optimization_id,
61+
)
62+
db_session.execute(query)
63+
64+
op.create_table(
65+
"pruning_details",
66+
sa.Column("id", sa.Integer(), nullable=False),
67+
sa.Column("train", sa.String(), nullable=True),
68+
sa.Column("approach", sa.String(), nullable=True),
69+
sa.Column(
70+
"created_at",
71+
sa.DateTime(),
72+
server_default=sa.text("(CURRENT_TIMESTAMP)"),
73+
nullable=False,
74+
),
75+
sa.Column("modified_at", sa.DateTime(), nullable=True),
76+
sa.PrimaryKeyConstraint("id", name=op.f("pk_pruning_details")),
77+
)
78+
with op.batch_alter_table("pruning_details", schema=None) as batch_op:
79+
batch_op.create_index(batch_op.f("ix_pruning_details_id"), ["id"], unique=True)
80+
81+
op.create_table(
82+
"example",
83+
sa.Column("id", sa.Integer(), nullable=False),
84+
sa.Column("name", sa.String(length=50), nullable=False),
85+
sa.Column("framework", sa.Integer(), nullable=False),
86+
sa.Column("domain", sa.Integer(), nullable=False),
87+
sa.Column("dataset_type", sa.String(length=50), nullable=False),
88+
sa.Column("model_url", sa.String(length=250), nullable=False),
89+
sa.Column("config_url", sa.String(length=250), nullable=False),
90+
sa.Column("created_at", sa.DateTime(), nullable=False),
91+
sa.ForeignKeyConstraint(["domain"], ["domain.id"], name=op.f("fk_example_domain_domain")),
92+
sa.ForeignKeyConstraint(
93+
["framework"], ["framework.id"], name=op.f("fk_example_framework_framework")
94+
),
95+
sa.PrimaryKeyConstraint("id", name=op.f("pk_example")),
96+
)
97+
with op.batch_alter_table("example", schema=None) as batch_op:
98+
batch_op.create_index(batch_op.f("ix_example_id"), ["id"], unique=False)
99+
100+
with op.batch_alter_table("model", schema=None) as batch_op:
101+
batch_op.add_column(
102+
sa.Column(
103+
"supports_pruning",
104+
sa.Boolean(),
105+
default=False,
106+
nullable=True,
107+
),
108+
)
109+
op.execute("UPDATE model SET supports_pruning = false")
110+
111+
with op.batch_alter_table("model", schema=None) as batch_op:
112+
batch_op.alter_column("supports_pruning", nullable=False)
113+
114+
with op.batch_alter_table("optimization", schema=None) as batch_op:
115+
batch_op.add_column(sa.Column("pruning_details_id", sa.Integer(), nullable=True))
116+
batch_op.create_foreign_key(
117+
batch_op.f("fk_optimization_pruning_details_id_pruning_details"),
118+
"pruning_details",
119+
["pruning_details_id"],
120+
["id"],
121+
)
122+
123+
# ### end Alembic commands ###
124+
125+
126+
def downgrade():
127+
# ### commands auto generated by Alembic - please adjust! ###
128+
with op.batch_alter_table("optimization", schema=None) as batch_op:
129+
batch_op.drop_constraint(
130+
batch_op.f("fk_optimization_pruning_details_id_pruning_details"), type_="foreignkey"
131+
)
132+
batch_op.drop_column("pruning_details_id")
133+
134+
with op.batch_alter_table("model", schema=None) as batch_op:
135+
batch_op.drop_column("supports_pruning")
136+
137+
with op.batch_alter_table("example", schema=None) as batch_op:
138+
batch_op.drop_index(batch_op.f("ix_example_id"))
139+
140+
op.drop_table("example")
141+
with op.batch_alter_table("pruning_details", schema=None) as batch_op:
142+
batch_op.drop_index(batch_op.f("ix_pruning_details_id"))
143+
144+
op.drop_table("pruning_details")
145+
# ### end Alembic commands ###

neural_compressor/ux/components/db_manager/db_models/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class Model(Base):
4646
output_nodes = Column(String(250), nullable=False, default="")
4747
supports_profiling = Column(Boolean, nullable=False, default=False)
4848
supports_graph = Column(Boolean, nullable=False, default=False)
49+
supports_pruning = Column(Boolean, nullable=False, default=False)
4950
created_at = Column(DateTime, nullable=False, default=func.now())
5051

5152
project: Any = relationship("Project", back_populates="models")
@@ -104,6 +105,7 @@ def add(
104105
domain_flavour_id: int,
105106
supports_profiling: bool,
106107
supports_graph: bool,
108+
supports_pruning: bool,
107109
) -> int:
108110
"""
109111
Add model to database.
@@ -123,6 +125,7 @@ def add(
123125
domain_flavour_id=domain_flavour_id,
124126
supports_profiling=supports_profiling,
125127
supports_graph=supports_graph,
128+
supports_pruning=supports_pruning,
126129
)
127130
db_session.add(new_model)
128131
db_session.flush()
@@ -206,5 +209,6 @@ def build_info(model: Any) -> dict:
206209
"output_nodes": json.loads(model.output_nodes),
207210
"supports_profiling": model.supports_profiling,
208211
"supports_graph": model.supports_graph,
212+
"supports_pruning": model.supports_pruning,
209213
"created_at": str(model.created_at),
210214
}

0 commit comments

Comments
 (0)