4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
@@ -47,6 +47,10 @@ jobs:
commands: |
echo "::add-matcher::.github/workflows/matchers/pylint.json"
tox -e lint
- name: "mypy"
commands: |
echo "::add-matcher::.github/workflows/matchers/mypy.json"
tox -e mypy

steps:
- name: "Harden Runner"
16 changes: 16 additions & 0 deletions .github/workflows/matchers/mypy.json
@@ -0,0 +1,16 @@
{
"problemMatcher": [
{
"owner": "mypy",
"pattern": [
{
"regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$",
"file": 1,
"line": 2,
"severity": 3,
"message": 4
}
]
}
]
}
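A quick way to validate the matcher is to run its regexp against a line of real mypy output. A minimal sketch, with the sample error line made up for illustration:

```python
import re

# Same pattern as in the matcher above, with JSON escaping removed.
PATTERN = re.compile(r"^(.+):(\d+):\s(error|warning):\s(.+)$")

# Hypothetical mypy output line.
sample = 'fms_mo/calib.py:485: error: Item "None" has no attribute "f_code"  [union-attr]'

m = PATTERN.match(sample)
assert m is not None
file, line, severity, message = m.groups()
print(file, line, severity, message, sep=" | ")
```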
4 changes: 2 additions & 2 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
@@ -14,7 +14,7 @@
"""Implement FMS adapter for INT8xINT8 checkpoints"""

# Standard
from typing import Mapping
from typing import Mapping, MutableMapping

# Third Party
from fms.utils import serialization
@@ -47,7 +47,7 @@ def _int8_qparams_aiu(


def _add_defaults_and_concat(
new_sd: dict[str, torch.Tensor],
new_sd: MutableMapping[str, torch.Tensor],
modules_seen: set[str],
) -> None:
"""
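Context for the `Mapping` to `MutableMapping` switch: `Mapping` exposes no `__setitem__` or `setdefault`, so mypy rejects in-place writes through a parameter annotated that way. A minimal sketch of the distinction, with illustrative names not taken from the PR:

```python
from typing import Mapping, MutableMapping

def total(sd: Mapping[str, int]) -> int:
    return sum(sd.values())  # reads type-check fine through Mapping

def add_defaults(sd: MutableMapping[str, int]) -> None:
    # Writes need MutableMapping; through Mapping, mypy reports
    # "Unsupported target for indexed assignment".
    sd.setdefault("zero_shift", 0)  # illustrative key
    sd["n_defaults"] = 1            # illustrative key
```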
3 changes: 2 additions & 1 deletion fms_mo/calib.py
@@ -482,7 +482,8 @@ def qmodel_calib(
return model

DPorDDPdevices = None
if "qmodel_prep" not in sys._getframe().f_back.f_code.co_name:
f_back = sys._getframe().f_back
if f_back and "qmodel_prep" not in f_back.f_code.co_name:
model.to(currDev)
qcfg["wasDPmodel"] = qcfg.get("wasDPmodel", isinstance(model, nn.DataParallel))
qcfg["wasDDPmodel"] = qcfg.get(
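The new guard reflects that `f_back` is `Optional[FrameType]`: it is `None` for the outermost frame, so the unguarded `sys._getframe().f_back.f_code` access is what mypy flagged. A minimal sketch of the pattern (CPython-specific, since `sys._getframe` is an implementation detail; the helper name is hypothetical):

```python
import sys

def caller_name() -> str:
    """Name of the calling function, or "<none>" at the top of the stack."""
    f_back = sys._getframe().f_back  # Optional[FrameType]
    return f_back.f_code.co_name if f_back else "<none>"
```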
6 changes: 3 additions & 3 deletions fms_mo/custom_ext_kernels/utils.py
@@ -74,10 +74,10 @@
# Third Party
import torch.library as lib

reg_op = partial(lib.custom_op, mutates_args=())
reg_op = partial(lib.custom_op, mutates_args=()) # type: ignore[attr-defined]
reg_op_func = lib.define # NOTE this is func, not decorator
kernel_impl = lib.register_kernel
reg_fake = lib.register_fake
kernel_impl = lib.register_kernel # type: ignore[attr-defined]

[Check failure on line 79 in fms_mo/custom_ext_kernels/utils.py · GitHub Actions / lint: mypy] Incompatible types in assignment (expression has type "Callable[[str | OpOverload | CustomOpDef, str | Sequence[str] | None, Callable[..., Any] | None, DefaultNamedArg(Library | None, 'lib')], Any]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'device_types'), DefaultNamedArg(Any, 'func')], Any]") [assignment]
reg_fake = lib.register_fake # type: ignore[attr-defined]

[Check failure on line 80 in fms_mo/custom_ext_kernels/utils.py · GitHub Actions / lint: mypy] Incompatible types in assignment (expression has type "Callable[[str | OpOverload | CustomOpDef, Callable[..., Any] | None, DefaultNamedArg(Library | None, 'lib'), DefaultNamedArg(int, '_stacklevel')], Any]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'func')], Any]") [assignment]

else:
raise RuntimeError("Custom Op registration only works for >PT2.1")
@@ -532,7 +532,7 @@

# Register op
@reg_op(f"{namespace}::exv1_i4f16")
def exv1_i4f16(x: torch.Tensor, q4: int, q4_width: int) -> torch.Tensor:

"""q4 is the handle, i.e. INT "address", to packed weight, not a tensor"""
[Check failure on line 535 in fms_mo/custom_ext_kernels/utils.py · GitHub Actions / lint: mypy] Missing return statement [empty-body]

# Generic implementation
@@ -558,7 +558,7 @@

# Register op
@reg_op(f"{namespace}::exv2_i4f16")
def exv2_i4f16(
x: torch.Tensor, q4: int, q4_width: int, force_cuda: bool
) -> torch.Tensor:
"""q4 is the handle, i.e. INT "address", to packed weight, not a tensor"""
[Check failure on line 561 in fms_mo/custom_ext_kernels/utils.py · GitHub Actions / lint: mypy] Missing return statement [empty-body]
@@ -587,7 +587,7 @@
# Wrappers for better graph representation, i.e. try to pass real tensors instead of just
# one handle, even though the extra terms will not be used in the external Kernel...
@reg_op(f"{namespace}::exv2_i4f16_fxinputs")
[Check failure on line 590 in fms_mo/custom_ext_kernels/utils.py · GitHub Actions / lint: mypy] Missing return statement [empty-body]
def exv2_i4f16_fxinputs(
x: torch.Tensor,
qw: torch.Tensor,
qzeros: torch.Tensor,
@@ -848,7 +848,7 @@
use_inductor = "default"
if use_inductor:
with torch.no_grad():
mod = torch.compile(mod, mode=use_inductor)

[Check failure on line 851 in fms_mo/custom_ext_kernels/utils.py · GitHub Actions / lint: mypy] Incompatible types in assignment (expression has type "Callable[[VarArg(Any), KwArg(Any)], Any]", variable has type Module) [assignment]
mod(**exam_inp_eval)

logger.info(f"\nModel lowered {'and compiled' if use_inductor else ''}.\n{mod}")
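The three `[empty-body]` failures flagged above come from op-registration stubs that declare `-> torch.Tensor` but contain only a docstring, which mypy treats as an empty body. A minimal reproduction plus two common ways to quiet it (illustrative sketches, not necessarily the fix this PR settles on):

```python
import torch

def op_stub(x: torch.Tensor) -> torch.Tensor:  # type: ignore[empty-body]
    """Docstring-only body: mypy reports 'Missing return statement [empty-body]'."""

def op_stub_alt(x: torch.Tensor) -> torch.Tensor:
    """Raising makes the body provably non-returning, so mypy is satisfied."""
    raise NotImplementedError
```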
8 changes: 5 additions & 3 deletions fms_mo/quant/ptq.py
@@ -2631,8 +2631,10 @@ def reset_bn(module: nn.BatchNorm2d):
Function not currently used.
"""
if module.track_running_stats:
module.running_mean.zero_()
module.running_var.fill_(1 - module.eps)
if (running_mean := module.running_mean) is not None:
running_mean.zero_()
if (running_var := module.running_var) is not None:
running_var.fill_(1 - module.eps)
# we do not reset the number of tracked batches here
if module.affine:
nn.init.ones_(module.weight)
@@ -2651,7 +2653,7 @@ def reset_bn(module: nn.BatchNorm2d):
bn_affine = True # FrozenBN doesn't have .affine property
except:
BNofInteret = (nn.BatchNorm2d, nn.BatchNorm1d)
AbsorbLayers = (nn.Conv2d, nn.Linear)
AbsorbLayers = (nn.Conv2d, nn.Linear) # type: ignore[assignment]


def search_fold_and_remove_bn(model, mod_folded):
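A note on null-checking tensors, relevant to `reset_bn` above and the quantizer change below: `if tensor:` invokes `Tensor.__bool__`, which raises at runtime for tensors with more than one element, so an explicit `is not None` comparison is the safe presence check. A minimal sketch:

```python
from typing import Optional

import torch

stats: Optional[torch.Tensor] = torch.zeros(4)

# `if stats:` would raise "Boolean value of Tensor with more than one
# element is ambiguous"; compare against None explicitly instead.
if (running_mean := stats) is not None:
    running_mean.zero_()
```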
21 changes: 12 additions & 9 deletions fms_mo/quant/quantizers.py
@@ -23,6 +23,7 @@
"""

# pylint: disable=too-many-return-statements
# mypy: disable-error-code="assignment"

# Standard
from collections.abc import Mapping
@@ -3895,7 +3896,8 @@ def forward(self, x: torch.Tensor):
self.delta = torch.nn.Parameter(delta)
else:
delta, zero_point = self.init_quantization_scale(x, self.channel_wise)
self.delta.fill_(delta)
if (self_data := self.delta) is not None:
self_data.fill_(delta)
self.zero_point.fill_(zero_point)
self.inited = True

@@ -3906,7 +3908,8 @@ def forward(self, x: torch.Tensor):
return x_dequant

def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False):
delta, zero_point = None, None
# delta, zero_point = 1.0, 0
# initialization seems unnecessary; commented out to avoid None-induced type check errors
if channel_wise:
x_clone = x.clone().detach()
n_channels = x_clone.shape[0]
@@ -3935,7 +3938,7 @@ def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False):
x_min = x_min * (self.n_bits + 2) / 8
x_max = x_max * (self.n_bits + 2) / 8

x_absmax = max(abs(x_min), x_max)
x_absmax = max(abs(x_min), x_max) # type: ignore [call-overload]
if self.sym:
x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax

@@ -3960,7 +3963,7 @@
if score < best_score:
best_score = score
delta = (new_max - new_min) / (2**self.n_bits - 1)
zero_point = (-new_min / delta).round()
zero_point = (-new_min / delta).round() # type: ignore[union-attr]
else:
raise NotImplementedError

@@ -4035,8 +4038,8 @@ def __init__(
self.reset_ReSig_param(multimodal)

self.beta = 2 / 3
self.Wshape = None
self.reshape2 = None
self.Wshape: list[int] = list()
self.reshape2: list[Any] = list()

def forward(self, x):
if self.useSAWB:
@@ -4583,7 +4586,7 @@ def transformers_prepare_input(
if isinstance(data, Mapping):
return type(data)(
{k: transformers_prepare_input(v, dev=dev) for k, v in data.items()}
)
) # type: ignore[call-arg]
if isinstance(data, (tuple, list)):
return type(data)(transformers_prepare_input(v, dev=dev) for v in data)
if isinstance(data, torch.Tensor):
@@ -5389,7 +5392,7 @@ def __init__(
if "e4m3" in q_mode:
self.float8_dtype = torch.float8_e4m3fn
elif "e5m2" in q_mode:
self.float8_dtype = torch.float8_e5m2G
self.float8_dtype = torch.float8_e5m2
else:
raise ValueError("FP8 only supports e4m3 and e5m2")
self.emulate = emulate
@@ -5451,7 +5454,7 @@ def custom_fp8_quantizer(
mantissa_bits: int = 3,
use_subnormal: bool = False,
scale_to_max: bool = False,
) -> torch.Tensor:
):
"""Convert tensor tensor to FP8 format, remanining in decimal form (no binary conversion)
and using some clever manipulation to round each tensor values to the closest representable
FP8 value.
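For reference on the `float8_e5m2G` typo fixed above: PyTorch exposes the two FP8 interchange formats as `torch.float8_e4m3fn` and `torch.float8_e5m2`. A minimal sketch mirroring the branch in the diff (the function name is illustrative):

```python
import torch

def pick_float8(q_mode: str) -> torch.dtype:
    if "e4m3" in q_mode:
        return torch.float8_e4m3fn  # 4 exponent / 3 mantissa bits, finite-only
    if "e5m2" in q_mode:
        return torch.float8_e5m2  # 5 exponent / 2 mantissa bits
    raise ValueError("FP8 only supports e4m3 and e5m2")

x8 = torch.randn(4).to(pick_float8("fp8_e4m3"))  # cast for storage/emulation
```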
4 changes: 2 additions & 2 deletions fms_mo/utils/qconfig_utils.py
@@ -15,7 +15,7 @@

# Standard
from pathlib import Path
from typing import Any
from typing import Any, Dict
import json
import logging
import os
@@ -149,7 +149,7 @@ def qconfig_init(recipe: str = None, args: Any = None):
otherwise use constantLR as default
"""

qcfg = {}
qcfg: Dict[str, Any] = {}
# 1. create a dict with default values
qcfg["mapping"] = {
nn.Conv2d: {"from": nn.Conv2d, "to": QConv2d, "otherwise": QConv2d},
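The explicit annotation matters because mypy pins an unannotated dict to the type of its first assignments, and `qcfg` later holds values of many types. A minimal illustration, with hypothetical keys:

```python
from typing import Any, Dict

cfg = {"mapping": {"lr": 0.1}}  # inferred as dict[str, dict[str, float]]
# cfg["n_bits"] = 8  # mypy: Incompatible types in assignment

qcfg: Dict[str, Any] = {}  # explicit annotation admits heterogeneous values
qcfg["mapping"] = {"lr": 0.1}
qcfg["n_bits"] = 8  # OK
```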
51 changes: 30 additions & 21 deletions fms_mo/utils/torchscript_utils.py
@@ -55,7 +55,8 @@ def parse_operation(op_str: str):
operands = op_str[
last_open_parenthesis_index + 1 : last_close_parenthesis_index
].split(",")
operands = [operand.strip() for operand in operands] if operands != [""] else None
# pylint: disable=line-too-long
operands = [operand.strip() for operand in operands] if operands != [""] else None # type: ignore[assignment]
return operator, operands


@@ -178,9 +179,14 @@ def __init__(self, node_input, dictionary_of_nodes: dict):
)
operator, operands = parse_operation(op_str)
if "aten::_conv" in op_str:
self.ch_in = list(native_torchscript_node.inputs())[0].type().sizes()
# NOTE: Needed for finding shortcut convolutions later
self.ch_out = list(native_torchscript_node.outputs())[0].type().sizes()
if native_torchscript_node:
self.ch_in = (
list(native_torchscript_node.inputs())[0].type().sizes()
)
# NOTE: Needed for finding shortcut convolutions later
self.ch_out = (
list(native_torchscript_node.outputs())[0].type().sizes()
)
else:
node_def = node_input_repr
op_str, operator, operands = None, None, None
@@ -200,31 +206,34 @@
working_str = node_input_repr[start_index:end_index]
start_index = end_index + 2

node_instance.name, node_instance.obj = working_str.split(" : ")
node_instance.name = node_instance.name.strip()
# pylint: disable=line-too-long
node_instance.name, node_instance.obj = working_str.split(" : ") # type: ignore[attr-defined]
node_instance.name = node_instance.name.strip() # type: ignore[attr-defined]
if native_torchscript_outputs:
if node_instance.name not in native_torchscript_outputs:
# pylint: disable=line-too-long
if node_instance.name not in native_torchscript_outputs: # type: ignore[attr-defined]
# pylint: disable=line-too-long
logger.error(
f"Node def {node_instance.name} not in nativeTSoutputs "
f"Node def {node_instance.name} not in nativeTSoutputs " # type: ignore[attr-defined]
f"{native_torchscript_outputs}"
)
node_instance.Op = op_str
node_instance.Op = op_str # type: ignore[attr-defined]
if node_def_in_one_line > 1:
node_instance.unpackIdx = node_index
node_instance.unpackIdx = node_index # type: ignore[attr-defined]
if line_number:
node_instance.lineno = line_number
node_instance.operator = operator
node_instance.lineno = line_number # type: ignore[attr-defined]
node_instance.operator = operator # type: ignore[attr-defined]
# This is the name of parents, not the pointer to the parent nodes
node_instance.parents = operands
node_instance.parents_ptr = []
node_instance.scope = scope_repr
node_instance.modname = module_name
node_instance.children = []
node_instance.children_ptr = []
node_instance.TSparents = native_torchscript_parents
node_instance.TSoutputs = native_torchscript_outputs
node_instance.parents = operands # type: ignore[attr-defined]
node_instance.parents_ptr = [] # type: ignore[attr-defined]
node_instance.scope = scope_repr # type: ignore[attr-defined]
node_instance.modname = module_name # type: ignore[attr-defined]
node_instance.children = [] # type: ignore[attr-defined]
node_instance.children_ptr = [] # type: ignore[attr-defined]
node_instance.TSparents = native_torchscript_parents # type: ignore[attr-defined]
node_instance.TSoutputs = native_torchscript_outputs # type: ignore[attr-defined]
# graph.dictionary_of_nodes will keep a record of all the nodes
dictionary_of_nodes[node_instance.name] = node_instance
dictionary_of_nodes[node_instance.name] = node_instance # type: ignore[attr-defined]

def __repr__(self):
return f"{self.name} "
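A possible follow-up for this file (a sketch of an alternative, not part of this PR): declaring the dynamically assigned attributes on the node class lets mypy see them, which would remove the need for the per-line `attr-defined` ignores above. The attribute names below are taken from the assignments in the diff; the types are guesses:

```python
from typing import Any, Optional

class Node:
    # Declared up front so mypy accepts the assignments __init__ performs.
    name: str
    obj: str
    Op: Optional[str]
    unpackIdx: int
    lineno: int
    operator: Optional[str]
    parents: Optional[list[str]]
    parents_ptr: list[Any]
    scope: str
    modname: str
    children: list[Any]
    children_ptr: list[Any]
```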
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -132,7 +132,7 @@ known-local-folder=["fms_mo","tests"]
[tool.mypy]
mypy_path = [""]
packages = ["fms_mo", "tests"]
disable_error_code = []
disable_error_code = ["import-not-found", "import-untyped", "no-any-return"]
# TODO: tighten MyPy checks by enabling these checks over time.
check_untyped_defs = false
disallow_incomplete_defs = false
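Of the codes disabled here, `no-any-return` is the one most worth revisiting later: it fires when a function with a concrete declared return type returns an expression mypy can only type as `Any`. A minimal reproduction (hypothetical function, standard library only):

```python
from typing import Any
import json

def load_threshold(path: str) -> float:
    with open(path) as f:
        data: Any = json.load(f)  # json.load returns Any
    # With the check enabled, mypy reports:
    #   Returning Any from function declared to return "float"  [no-any-return]
    return data["threshold"]
```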