@@ -60,6 +60,8 @@ class GPUSampler : public SamplerObj {
6060 uniform_samples_host_ = NDArray::Empty ({max_num_sample}, dtype_f32_, device_cpu);
6161 sample_indices_host_ = NDArray::Empty ({max_num_sample}, dtype_i32_, device_cpu);
6262 top_p_host_ = NDArray::Empty ({max_num_sample}, dtype_f32_, device_cpu);
63+ top_p_init_pivots_host_ =
64+ NDArray::Empty ({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device_cpu);
6365 top_prob_offsets_host_ = NDArray::Empty ({max_num_sample * 5 }, dtype_i32_, device_cpu);
6466 draft_tokens_host_ = NDArray::Empty ({max_num_sample}, dtype_i32_, device_cpu);
6567 token_tree_first_child_host_ = NDArray::Empty ({max_num_sample}, dtype_i32_, device_cpu);
@@ -73,6 +75,8 @@ class GPUSampler : public SamplerObj {
7375 uniform_samples_device_ = NDArray::Empty ({max_num_sample}, dtype_f32_, device);
7476 sample_indices_device_ = NDArray::Empty ({max_num_sample}, dtype_i32_, device);
7577 top_p_device_ = NDArray::Empty ({max_num_sample}, dtype_f32_, device);
78+ top_p_init_pivots_device_ =
79+ NDArray::Empty ({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device);
7680 top_prob_offsets_device_ = NDArray::Empty ({max_num_sample * 5 }, dtype_i32_, device);
7781 draft_tokens_device_ = NDArray::Empty ({max_num_sample}, dtype_i32_, device);
7882 token_tree_first_child_device_ = NDArray::Empty ({max_num_sample}, dtype_i32_, device);
@@ -118,21 +122,35 @@ class GPUSampler : public SamplerObj {
118122 return probs_on_device;
119123 }
120124
121- // - Argsort the probability.
122- Array<NDArray> argsort_results = gpu_argsort_probs_func_ (probs_on_device);
123- ICHECK_EQ (argsort_results.size (), 2 );
124- NDArray sorted_probs_on_device = argsort_results[0 ];
125- NDArray sorted_indices_on_device = argsort_results[1 ];
126-
127- // - Copy auxiliary array for top-p.
125+ // - Copy auxiliary array for top-p and initial pivots.
128126 NDArray top_p_host = top_p_host_.CreateView ({num_probs}, dtype_f32_);
129127 NDArray top_p_device = top_p_device_.CreateView ({num_probs}, dtype_f32_);
130128 CopyArray (/* src=*/ top_p_host, /* dst=*/ top_p_device, copy_stream_);
129+
130+ NDArray top_p_init_pivots_host =
131+ top_p_init_pivots_host_.CreateView ({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);
132+ NDArray top_p_init_pivots_device =
133+ top_p_init_pivots_device_.CreateView ({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);
134+ const float * p_top_p = static_cast <const float *>(top_p_host->data );  // host-side top_p values, one per distribution; only read here
135+ float * p_top_p_init_pivots = static_cast <float *>(top_p_init_pivots_host->data );  // host pivot buffer, indexed as [i * num_top_p_cutoff_pivots_ + k]
136+ for (int i = 0 ; i < num_probs; ++i) {  // fill the pivot slots for each of the num_probs distributions
137+ if (1 - p_top_p[i] >= 0.02 ) {  // case A: the tail mass (1 - top_p) is at least 0.02
138+ p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] =
139+ std::min (1 - p_top_p[i], static_cast <float >(0.5 ));  // slot 0: 1 - top_p, capped at 0.5
140+ p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1 ] = 0.02 ;  // slot 1: fixed constant
141+ p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2 ] = 0.01 ;  // slot 2: fixed constant
142+ } else {  // case B: tail mass below 0.02, so the pivots are scaled from it instead
143+ p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = 1 - p_top_p[i];  // slot 0: the tail mass itself
144+ p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1 ] = (1 - p_top_p[i]) / 2 ;  // slot 1: half of it
145+ p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2 ] = (1 - p_top_p[i]) / 4 ;  // slot 2: a quarter of it
146+ }  // NOTE(review): slots 0..2 are written unconditionally, so this relies on num_top_p_cutoff_pivots_ >= 3
147+ }  // NOTE(review): presumably these are starting cutoff guesses for the renormalization kernel -- confirm against the kernel
148+ CopyArray (/* src=*/ top_p_init_pivots_host, /* dst=*/ top_p_init_pivots_device, copy_stream_);  // stage the filled pivots onto the device view via copy_stream_
131149 SyncCopyStream (device_, compute_stream_, copy_stream_);
132150
133151 // - Renormalize the prob with top p.
134152 NDArray renormed_probs_on_device =
135- gpu_renormalize_by_top_p_func_ (probs_on_device, sorted_probs_on_device, top_p_device );
153+ gpu_renormalize_by_top_p_func_ (probs_on_device, top_p_device, top_p_init_pivots_device );
136154
137155 RECORD_EVENT (trace_recorder_, request_ids, " finish renormalization by top p" );
138156 return renormed_probs_on_device;
@@ -500,6 +518,9 @@ class GPUSampler : public SamplerObj {
500518 << " GPU sampler requires the top_p values for each prob distribution are the same." ;
501519 }
502520 }
521+ for (int i = 0 ; i < num_probs; ++i) {  // apply a lower bound to every top_p entry before returning
522+ p_top_p[i] = std::max (p_top_p[i], eps_);  // values below eps_ are raised to eps_; NOTE(review): presumably to avoid a zero top_p downstream -- confirm
523+ }
503524 return need_top_p;
504525 }
505526
@@ -665,6 +686,7 @@ class GPUSampler : public SamplerObj {
665686 NDArray uniform_samples_host_;
666687 NDArray sample_indices_host_;
667688 NDArray top_p_host_;
689+ NDArray top_p_init_pivots_host_;
668690 NDArray top_prob_offsets_host_;
669691 NDArray draft_tokens_host_;
670692 NDArray token_tree_first_child_host_;
@@ -678,6 +700,7 @@ class GPUSampler : public SamplerObj {
678700 NDArray uniform_samples_device_;
679701 NDArray sample_indices_device_;
680702 NDArray top_p_device_;
703+ NDArray top_p_init_pivots_device_;
681704 NDArray top_prob_offsets_device_;
682705 NDArray draft_tokens_device_;
683706 NDArray token_tree_first_child_device_;
@@ -691,6 +714,7 @@ class GPUSampler : public SamplerObj {
691714 // The device stream for copying auxiliary data structure to GPU.
692715 TVMStreamHandle copy_stream_ = nullptr ;
693716 const float eps_ = 1e-5 ;
717+ const int num_top_p_cutoff_pivots_ = 3 ;  // pivots stored per distribution; second dimension of the top_p_init_pivots_ host/device buffers
694718};
695719
696720Sampler Sampler::CreateGPUSampler (int max_num_sample, int vocab_size, FunctionTable* ft,
0 commit comments