 # under the License.
 # pylint: disable=invalid-name
 """Device config class to hold information about the target hardware"""
-from typing import Tuple, List, Dict
+from typing import Tuple, List, Dict, Optional
 from functools import reduce
 
 import math
@@ -332,6 +332,7 @@ def _get_input_block(
 
     def get_kernel_steps(
         self,
+        op_type: str,
         dilated_kernel_h: int,
         dilated_kernel_w: int,
         ifm_dtype: str,
@@ -341,6 +342,9 @@ def get_kernel_steps(
 
         Parameters
         ----------
+        op_type: str
+            The NPU primitive operator, e.g. "ethosu_pooling"
         dilated_kernel_h: int
             Height of dilated kernel
         dilated_kernel_w: int
@@ -355,18 +359,23 @@ def get_kernel_steps(
         List[int]
             List where each entry contains the number of elements in one of the subkernels
         """
+        if op_type == "ethosu_binary_elementwise":
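+            # Elementwise operations have no kernel to traverse, so a single step suffices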
+            return [1]
+
         subkernels = self._get_subkernels(dilated_kernel_h, dilated_kernel_w)
 
         # Determine the number of kernel steps per subkernel
         kernel_steps = []
         for y, x in subkernels:
             subkernel_elements = x * y
-            if is_partkernel:
-                # Part-kernel-first traversal
+            if op_type == "ethosu_conv2d" and is_partkernel:
+                # Part-kernel-first traversal for conv2d
                 divisor = 4 if ifm_dtype == "int8" else 2
                 kernel_steps.append(int(_round_up_div(subkernel_elements, divisor)))
+            elif op_type == "ethosu_depthwise_conv2d":
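+                # Depthwise convolution packs four kernel elements into each step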
+                kernel_steps.append(int(_round_up_div(subkernel_elements, 4)))
             else:
-                # Depth-first traversal
+                # Depth-first traversal for conv2d or pooling
                 kernel_steps.append(int(subkernel_elements))
 
         return kernel_steps
@@ -430,11 +439,133 @@ def is_partkernel(
 
         return part_kernel_first_utilization > depth_first_utilization or ifm_channels <= 8
 
+    def get_elementwise_block_config(
+        self,
+        ifm_propagator: Propagator,
+        ifm2_propagator: Optional[Propagator],
+        op_attrs: Dict,
+        ofm_shape: List[int],
+        output_layout: str,
+        input_layout: str,
+        input2_layout: Optional[str],
+        ifm_dtype: str,
+        ofm_dtype: str,
+    ) -> List[BlockConfig]:
+        """Get a suitable block config for an elementwise operator
+
+        Parameters
+        ----------
+        ifm_propagator: Propagator,
+            The propagator containing the data dependencies between input and output
+        ifm2_propagator: Propagator,
+            The propagator containing the data dependencies between input2 and output
+        op_attrs: Dict,
+            Dictionary containing operator attributes
+        ofm_shape: List[int],
+            Shape of the output tensor
+        output_layout: str,
+            The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
+        input_layout: str,
+            The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+        input2_layout: str,
+            The layout of the Input2 Feature Map tensor. Can be "NHWC" or "NHCWB16".
+        ifm_dtype: str,
+            Datatype of the Input Feature Map tensor (IFM)
+        ofm_dtype: str,
+            Datatype of the Output Feature Map tensor (OFM)
+
+        Returns
+        -------
+        List[BlockConfig]
+            List containing a single suitable block config
+        """
+        block_config = []
+        output_shape = [int(a) for a in ofm_shape]
+
+        op_type = op_attrs.get("op")
+        op_str = op_attrs.get("op_str")
+        activation = op_attrs.get("activation", "NONE")
+
+        input_bytewidth = 1 if ifm_dtype == "int8" else 2 if ifm_dtype == "int16" else 4
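+        # Element size in bytes for the given input datatype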
+        banks_available = self._total_banks - self._reserved_banks
+        if activation == "LUT" and not self._lut_reserved:
+            banks_available -= 2
+
+        # Split the block in half until it fits into SHRAM
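+        # The candidate block is halved along the height first, then the width, then the depth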
+        if output_layout == "NHCWB16":
+            split_order = (a for a in [1, 3, 2])
+            output_block = [
+                output_shape[0],
+                min(output_shape[1], self._max_block_shape.height),
+                min(output_shape[2] * output_shape[4], self._max_block_shape.depth),
+                min(output_shape[3], self._max_block_shape.width),
+                16,
+            ]
+        else:
+            split_order = (a for a in [1, 2, 3])
+            output_block = [
+                output_shape[0],
+                min(output_shape[1], self._max_block_shape.height),
+                min(output_shape[2], self._max_block_shape.width),
+                min(output_shape[3], self._max_block_shape.depth),
+            ]
+        split_axis = next(split_order)
+        while True:
+            # Create stripe config for output block
+            offset = [0] * len(output_block)
+            stripes = [1] * len(output_block)
+            order = [1, 2, 4, 3, 0] if output_layout == "NHCWB16" else [1, 2, 3, 4]
+            output_stripe_config = StripeConfig(
+                output_block, output_block, output_block, order, stripes, offset
+            )
+
+            # Propagate the output to obtain the two input blocks
+            input_block = _Shape(ifm_propagator.propagate(output_stripe_config).shape, input_layout)
+            if ifm2_propagator:
+                input2_block = _Shape(
+                    ifm2_propagator.propagate(output_stripe_config).shape, input2_layout
+                )
+            else:
+                # Unary elementwise
+                input2_block = _Shape([0, 0, 0, 0])
+
+            input_block.round_up(self._input_micro_block)
+            input2_block.round_up(self._input_micro_block)
+
+            # Banks required for input block
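+            # (the factor of two below is assumed to account for double buffering in SHRAM)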
+            input_bytes = input_block.area() * self._align(input_block.depth * input_bytewidth, 8)
+            input_banks = _round_up_div(input_bytes, self._bank_size_bytes) * 2
+            input_banks = _round_up(input_banks, self._input_granularity)
+
+            # Banks required for input2 block
+            input2_bytes = input2_block.area() * self._align(
+                input2_block.depth * input_bytewidth, 8
+            )
+            input2_banks = _round_up_div(input2_bytes, self._bank_size_bytes) * 2
+            input2_banks = _round_up(input2_banks, self._input_granularity)
+
+            # Check whether or not both IFMs fit into SHRAM
+            if (input_banks + input2_banks) <= banks_available:
+                output_cycles = self._get_output_cycles(
+                    op_type, op_str, ifm_dtype, ofm_dtype, activation
+                )
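+                # _get_output_cycles appears to give a per-element estimate; scale it by the block volume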
+                output_cycles *= reduce(lambda a, b: a * b, output_block, 1)
+                output_cycles = int(math.ceil(output_cycles))
+                block_config.append(BlockConfig(output_block, 0, output_cycles))
+                break
+
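+            # Did not fit: once the current axis cannot be halved further, move to the next one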
+            if output_block[split_axis] == 1:
+                split_axis = next(split_order)
+
+            output_block[split_axis] = _round_up_div(output_block[split_axis], 2)
+
+        return block_config
+
     def get_valid_block_configs(
         self,
         ifm_propagator: Propagator,
         op_attrs: Dict,
-        output_shape: List[int],
+        ofm_shape: List[int],
         ofm_channels: int,
         ifm_channels: int,
         output_layout: str,
@@ -452,7 +583,7 @@ def get_valid_block_configs(
             The propagator containing the data dependencies between input and output
         op_attrs: Dict,
             Dictionary containing operator attributes
-        output_shape: List[int],
+        ofm_shape: List[int],
             Shape of the output tensor
         ofm_channels: int,
             Number of output channels
@@ -487,9 +618,9 @@ def get_valid_block_configs(
 
         subkernel_transform = ifm_propagator.transform
         if output_layout == "NHCWB16":
-            output_shape = _Shape([1, output_shape[1], output_shape[3], ofm_channels])
+            output_shape = _Shape([1, ofm_shape[1], ofm_shape[3], ofm_channels])
         else:
-            output_shape = _Shape(output_shape)
 
         if input_layout == "NHCWB16":
             subkernel_transform[1][-1] = min(
@@ -571,6 +702,7 @@ def get_valid_block_configs(
 
                     input_block_shape = _Shape(input_block.shape, input_layout)
                     input_block_shape.round_up(self._input_micro_block)
+
                     output_block_shape = _Shape(output_block, output_layout)
 
                     if op_type == "ethosu_conv2d":
@@ -592,12 +724,11 @@ def get_valid_block_configs(
                     acc_banks = _round_up(acc_banks, self._accumulator_granularity[acc_bytewidth])
 
                     if (input_banks + acc_banks) <= banks_available:
-
                         output_cycles = self._get_output_cycles(
                             op_type, op_str, ifm_dtype, ofm_dtype, activation
                         )
                         output_cycles *= reduce(lambda a, b: a * b, output_block, 1)
-                        output_cycles = int(_round_up(output_cycles, 1))
+                        output_cycles = int(math.ceil(output_cycles))
                         compute_cycles = self._estimate_compute_cycles_per_block(
                             op_type,
                             output_block_shape,
@@ -634,16 +765,17 @@ def _estimate_compute_cycles_per_block(
         num_quantum_z = _round_up_div(block_shape.depth, self._micro_block.depth)
         num_quantum_xy = num_quantum_x * num_quantum_y
 
-        kernel_steps = self.get_kernel_steps(kernel_h, kernel_w, ifm_dtype, is_partkernel)
+        kernel_steps = self.get_kernel_steps(op_type, kernel_h, kernel_w, ifm_dtype, is_partkernel)
 
         wd_cycles = self._get_weight_decoder_cycles(op_type)
         delay_cycles = self._get_delay_cycles(op_type, ifm_dtype)
         cycle_quantum = 4
 
         compute_cycles = 0
         for subkernel_steps in kernel_steps:
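+            # Pooling is assumed to need only a single pass per subkernel; other ops need one per kernel step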
+            subkernel_cycles = 1 if op_type == "ethosu_pooling" else subkernel_steps
             compute_cycles += (
-                max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_steps * num_quantum_z
+                max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_cycles * num_quantum_z
             )
 
             if num_quantum_xy == 1: