
Commit 796ae1c
local change to export llama to qnn
1 parent d3326a2

File tree: 4 files changed, +111 −24 lines

backends/qualcomm/quantizer/quantizer.py

Lines changed: 52 additions & 1 deletion
@@ -15,7 +15,8 @@
 from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange
 from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
 from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer
-
+from executorch.backends.qualcomm.passes.convert_constants_to_attrs import ConvertConstantsToAttrs
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch import Tensor
 from torch._ops import OpOverload
 from torch.ao.quantization.observer import (
@@ -378,8 +379,58 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = DecomposeScaledDotProductAttention()(model).graph_module
         model = DecomposeSilu()(model).graph_module
         model = ReplaceInfBuffer()(model).graph_module
+        # ConvertConstantsToAttrs(model)
+        self._lift_constant_scalar_operands(model)
+        # model = ConvertBinaryOpsWithScalar()(model).graph_module
 
         return model
 
+    def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:
+        # print("running _lift_constant_scalar_operands...")
+        for n in gm.graph.nodes:
+            # if n.name == "mul_78":
+            #     print(" n.name: ", n.name)
+
+            if n.op != "call_function" or n.target not in (
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.mul.Tensor,
+                torch.ops.aten.mul.Scalar,
+                torch.ops.aten.rsub.Scalar,
+            ):
+                continue
+
+            # print(" handling n: ", n, " n.target: ", n.target, " n.args: ", n.args)
+            const_arg = None
+            non_const_arg = None
+            for arg in n.args:
+                if isinstance(arg, torch.fx.Node):
+                    non_const_arg = arg
+                else:
+                    const_arg = arg
+
+            if non_const_arg is None or const_arg is None:
+                continue
+
+            # print(" n's args are all constant: ", n)
+            tensor_constant = torch.tensor([const_arg], dtype=torch.float32)
+            tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
+                gm
+            )
+            gm.register_buffer(tensor_constant_name, tensor_constant)
+
+            fake_mode = n.meta["val"].fake_mode
+            with gm.graph.inserting_before(n):
+                get_attr_node = gm.graph.get_attr(tensor_constant_name)
+                get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)
+
+            if n.target == torch.ops.aten.rsub.Scalar:
+                n.args = (get_attr_node, non_const_arg) + n.args[2:]
+                n.target = torch.ops.aten.sub.Tensor
+            else:
+                n.args = (non_const_arg, get_attr_node) + n.args[2:]
+
+        gm.recompile()
+
     def validate(self, model: GraphModule) -> None:
         pass
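For context on the new pass above, here is a minimal sketch of the kind of graph it rewrites. The toy module, the torch.export capture, and the buffer name are illustrative assumptions, not part of this commit or of the quantizer's actual capture path.

import torch

class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 0.5  # bare Python scalar operand on the aten multiply

ep = torch.export.export(Scale(), (torch.randn(4),))
print(ep.graph_module.graph)
# Before the pass, the multiply node typically carries the raw float 0.5 as a
# non-Node argument. _lift_constant_scalar_operands registers that value as a
# buffer (e.g. _tensor_constant_0), inserts a get_attr node for it, and
# rewrites the op's args so both operands are tensors the annotators can see;
# rsub.Scalar additionally becomes sub.Tensor with its operands swapped.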

backends/qualcomm/quantizer/utils.py

Lines changed: 15 additions & 3 deletions
@@ -7,6 +7,7 @@
 from typing import Callable, Dict, List, Optional, Sequence
 
 import torch
+from torch._subclasses import FakeTensor
 
 from torch._ops import OpOverload
 
@@ -41,6 +42,13 @@ def decorator(annotator: Callable):
 
     return decorator
 
+def _is_input_non_float_tensor(node: Node):
+    """Check if the input is not a float tensor, so that we can skip quantization for the node
+    since observers only work with float Tensors
+    """
+    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
+        return True
+    return node.meta["val"].dtype != torch.float32
 
 def _is_annotated(nodes: List[Node]):
     """
@@ -115,6 +123,7 @@ def annotate_single_in_single_out(
 
 
 def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
+    print(f"annotate_binary running for node {node}...")
     if _is_annotated([node]):
         return
 
@@ -123,12 +132,14 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
 
     input_qspec_map = {}
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if isinstance(input_act0, Node) and not _is_input_non_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec
+    print(" input_act0: ", input_act0, " _is_input_non_float_tensor: ", _is_input_non_float_tensor(input_act0))
 
     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if isinstance(input_act1, Node) and not _is_input_non_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec
+    print(" input_act1: ", input_act1, " _is_input_non_float_tensor: ", _is_input_non_float_tensor(input_act1))
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
@@ -147,7 +158,8 @@ def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar])
+# @register_annotator([torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar])
+@register_annotator([torch.ops.aten.mul.Tensor])
 def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
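A small sketch of why the _is_input_non_float_tensor guard matters: integer-valued inputs (index or position tensors, for example) carry a non-float FakeTensor in node.meta["val"], and annotate_binary now leaves them out of input_qspec_map instead of attaching an observer. The toy module and the torch.export capture are assumptions for illustration only.

import torch

class AddPos(torch.nn.Module):
    def forward(self, x, pos):
        return x + pos  # pos stands in for an integer index/position input

ep = torch.export.export(AddPos(), (torch.randn(4), torch.arange(4)))
for node in ep.graph_module.graph.nodes:
    if node.op == "placeholder":
        val = node.meta.get("val")
        print(node.name, getattr(val, "dtype", None))
# Expected roughly: x -> torch.float32 (gets an input observer),
# pos -> torch.int64 (_is_input_non_float_tensor returns True, so it is skipped).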

examples/models/llama2/export_llama_lib.py

Lines changed: 5 additions & 4 deletions
@@ -647,17 +647,17 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
             generate_qnn_executorch_compiler_spec(
                 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
-                soc_model=QcomChipset.SM8650,  # default to SM8650
+                soc_model=QcomChipset.SM8450,  # default to SM8650
                 backend_options=backend_options,
                 debug=False,
                 saver=False,
             ),
             skip_node_id_set={},
-            skip_node_op_set={},
+            skip_node_op_set={"aten.unsqueeze_copy.default", "aten.permute_copy.default"},
         )
     )
     # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-    _transform(builder_exported_to_edge.export_program())
+    _transform(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
@@ -678,7 +678,8 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             logging.info("Generated etrecord.bin")
         else:
             builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
-
+            print("graph after to_backend")
+            builder.edge_manager.exported_program().graph.print_tabular()
         if args.profile_memory:
             generate_memory_trace(builder.export_program, "memory_profile.json")
 
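The print_tabular call added above dumps the post-to_backend graph one fx node per row. A standalone sketch of the same call on a toy module (Tiny and the torch.export capture are assumptions, not the llama export path; print_tabular needs the tabulate package installed):

import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) * 2.0

ep = torch.export.export(Tiny(), (torch.randn(2, 2),))
# Prints opcode / name / target / args / kwargs per node, which is what makes
# it easy to spot which ops were delegated to QNN and which stayed on CPU.
ep.graph_module.graph.print_tabular()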

examples/models/llama2/llama_transformer.py

Lines changed: 39 additions & 16 deletions
@@ -14,7 +14,7 @@
 import torch.nn.functional as F
 
 from torch import nn
-
+import math
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -216,15 +216,23 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id
 
-        causal_mask = torch.tril(
-            torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
-                dtype=torch.bool,
-                device="cpu",
-            )
+        # causal_mask = torch.tril(
+        #     torch.ones(
+        #         self.max_seq_len,
+        #         self.max_seq_len,
+        #         dtype=torch.bool,
+        #         device="cpu",
+        #     )
+        # )
+        # self.register_buffer("mask", causal_mask, persistent=False)
+        mask = torch.full(
+            (1, 1, args.max_seq_len, args.max_seq_len),
+            float("-inf"),
+            device="cpu",
         )
-        self.register_buffer("mask", causal_mask, persistent=False)
+
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)
 
         if self.use_kv_cache:
             self.kv_cache = KVCache(
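The two mask constructions encode the same causal structure; a quick sketch of the equivalence (the small sequence length T is illustrative):

import torch

T = 4
# Old form: boolean lower-triangular mask, True where attention is allowed.
bool_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
# New form: additive mask, 0 on/below the diagonal and -inf strictly above it,
# which is what gets summed into the score matrix in the forward pass below.
add_mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)

assert torch.equal(add_mask[0, 0] == 0, bool_mask)
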
@@ -264,18 +272,33 @@ def forward(
             v = v.transpose(1, 2)
 
             k, v = self.kv_cache.update(input_pos, k, v)
-            mask = self.mask[None, None, input_pos]
+            mask = torch.squeeze(self.mask, [0, 1])
+            mask = mask[None, None, input_pos]
+            # mask = self.mask[None, None, input_pos]
 
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)
-            y = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask, dropout_p=0.0
-            )
 
-            y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
-            y = self.wo(y)
-            return y
+            scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+            scores = scores + mask
+            scores = F.softmax(scores.float(), dim=-1).type_as(q)
+            output = torch.matmul(
+                scores, v
+            )  # (bs, n_local_heads, seqlen, head_dim)
+
+            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+
+            output = self.wo(output)
+            return output
+            # y = F.scaled_dot_product_attention(
+            #     q, k, v, attn_mask=mask, dropout_p=0.0
+            # )
+
+            # y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+            # y = self.wo(y)
+            # return y
         else:
             from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache  # noqa
 
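A standalone sketch checking that the decomposed attention above (explicit matmul, additive mask, softmax) matches F.scaled_dot_product_attention on the same inputs; the shapes are arbitrary and not the model's real configuration:

import math

import torch
import torch.nn.functional as F

bs, heads, seqlen, head_dim = 1, 2, 4, 8
q = torch.randn(bs, heads, seqlen, head_dim)
k = torch.randn(bs, heads, seqlen, head_dim)
v = torch.randn(bs, heads, seqlen, head_dim)
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

# Decomposed path: scale, add the causal mask, softmax, weighted sum over v.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
manual = torch.matmul(scores, v)

reference = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(torch.allclose(manual, reference, atol=1e-5))  # expected: True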
