
Conversation

@xiaolil1
Contributor

@xiaolil1 xiaolil1 commented Jun 15, 2025

This is the pull request for the SYCL Kernels targeting the XPU backend.

  • It features the implementation of the "dequantize_blockwise", "dequantize_4bit", and "dequant & gemv_4bit fusion" kernels (see the usage sketch below).
  • The target low-precision quantization data types are NF4, FP4, and General8bits.
  • This PR aims to eliminate the dependency on IPEX and to improve performance.
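
A minimal usage sketch of the path these kernels serve (not part of the PR; it assumes a bitsandbytes build that ships the XPU native library and a PyTorch build with XPU support):

```python
# Minimal usage sketch (not from the PR): exercises the dequantize path that the
# new SYCL kernels cover. Assumes a bitsandbytes build that ships the XPU native
# library and a PyTorch build with a working "xpu" device.
import torch
import bitsandbytes.functional as F

device = "xpu"
A = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)

# quantize_4bit still runs on PyTorch/Triton ops on XPU (see the review
# discussion below); dequantize_4bit is one of the ops this PR implements in SYCL.
q, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")
A_dq = F.dequantize_4bit(q, state)

print((A - A_dq).abs().mean())  # average quantization error; small but nonzero
```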

@matthewdouglas matthewdouglas added Low Priority (will be worked on after all priority issues) Intel labels Jun 17, 2025
@matthewdouglas matthewdouglas self-assigned this Jun 17, 2025
@matthewdouglas matthewdouglas self-requested a review June 17, 2025 16:19
@matthewdouglas matthewdouglas added this to the v0.48.0 milestone Jun 17, 2025
@fengyuan14

Can we use a more accurate title for the commit? Otherwise reviewers may get confused and think that all SYCL kernels are included in this PR.

* fix sycl nd

Signed-off-by: jiqing-feng <[email protected]>

* fix tests

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
@Egor-Krivov
Contributor

Hi @Egor-Krivov . Could you share your device name? I can pass all tests on Intel(R) Data Center GPU Max 1550. = 2362 passed, 1540 skipped, 184 deselected, 24 xfailed, 31 warnings in 1081.07s (0:18:01) =

(triton) (base) jovyan@jupyter-ekrivov:~/triton/unsloth$ sycl-ls
[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Data Center GPU Max 1100 12.60.7 [1.3.27642]
[opencl:cpu][opencl:0] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6438Y+ OpenCL 3.0 (Build 0) [2024.18.12.0.05_160000]
[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1100 OpenCL 3.0 NEO  [23.43.27642.69]



# SYCL should be faster on XPU, so first check whether it is available.
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
Contributor

@Egor-Krivov Egor-Krivov Jul 7, 2025

Currently you either pick all methods from SYCL or all methods from Triton. However, the SYCL implementation is currently missing these methods, which are available in Triton:

quantize_blockwise
quantize_4bit

I suggest we keep using these Triton methods even with SYCL, since that's the only option on XPU for now (see the illustrative sketch below).
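
An illustrative sketch of this per-op fallback idea (hypothetical helper, not the actual bitsandbytes dispatch code):

```python
# Hypothetical helper, for illustration only: pick the SYCL symbol when the
# native library provides it, otherwise fall back to the Triton implementation.
def pick_backend(op_name, sycl_lib, triton_ops):
    if sycl_lib is not None and hasattr(sycl_lib, op_name):
        return getattr(sycl_lib, op_name)
    return triton_ops[op_name]  # e.g. quantize_blockwise / quantize_4bit today

# Usage sketch (names are hypothetical):
# fn = pick_backend("quantize_blockwise", sycl_lib=lib, triton_ops=triton_kernels)
```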

Contributor Author

These two kernels don't affect QLoRA performance; they currently run with PyTorch ops by default, and we will implement them with SYCL kernels later.

@Egor-Krivov
Contributor

The implementation is missing the following methods:

void cgemm_4bit_inference_naive_fp16(
    int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

void cgemm_4bit_inference_naive_bf16(
    int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
    int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

void cgemm_4bit_inference_naive_fp32(
    int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

cgemm_4bit_inference_naive_bf16 is used for text generation in the basic unsloth tutorial. Here is a stack trace:

Traceback (most recent call last):
  File "/home/jovyan/triton/unsloth/bench_unsloth.py", line 150, in <module>
    outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/peft/peft_model.py", line 1968, in generate
    outputs = self.base_model.generate(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 1821, in unsloth_fast_generate
    output = self._old_generate(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/transformers/generation/utils.py", line 2625, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/transformers/generation/utils.py", line 3609, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 1253, in _CausalLM_fast_forward
    outputs = fast_forward_inference(
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 1186, in LlamaModel_fast_forward_inference_custom
    X, present_key_value = attention_fast_forward_inference(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 264, in LlamaAttention_fast_forward_inference
    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/kernels/utils.py", line 637, in fast_linear_forward
    out = fast_gemv(X, W, W_quant, out = out)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/kernels/utils.py", line 483, in fast_gemv
    fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
  File "/home/jovyan/triton/bitsandbytes/bitsandbytes/cextension.py", line 60, in throw_on_call
    raise RuntimeError(
RuntimeError: Method 'cgemm_4bit_inference_naive_bf16' not available in CPU-only version of bitsandbytes.
Reinstall with GPU support or use CUDA-enabled hardware.
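
A hedged sketch for probing whether a given native symbol is actually usable, assuming ErrorHandlerMockBNBNativeLibrary is importable from bitsandbytes.cextension (the module shown in the traceback and in the diff context above):

```python
# Hedged sketch: on a CPU-only install, `lib` is the error-handler mock from
# cextension.py (seen in the traceback); its attributes resolve but raise
# RuntimeError when called, so check the type before trusting hasattr().
from bitsandbytes.cextension import ErrorHandlerMockBNBNativeLibrary, lib

def has_native(symbol: str) -> bool:
    if isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
        return False
    return hasattr(lib, symbol)

print(has_native("cgemm_4bit_inference_naive_bf16"))  # CUDA naive-gemm entry point
print(has_native("cgemv_4bit_inference_fp16"))        # XPU entry point (see reply below)
```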

@jiqing-feng
Contributor

Hi @Egor-Krivov . Could you share your script to get this error?

@xiaolil1
Contributor Author

xiaolil1 commented Jul 8, 2025

The implementation is missing the following methods:

void cgemm_4bit_inference_naive_fp16(
    int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

void cgemm_4bit_inference_naive_bf16(
    int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
    int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

void cgemm_4bit_inference_naive_fp32(
    int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
    int ldc, int blocksize, cudaStream_t stream
) {
    gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

cgemm_4bit_inference_naive_bf16 is used for text generation in the basic unsloth tutorial. Here is a stack trace:

Traceback (most recent call last):
  File "/home/jovyan/triton/unsloth/bench_unsloth.py", line 150, in <module>
    outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/peft/peft_model.py", line 1968, in generate
    outputs = self.base_model.generate(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 1821, in unsloth_fast_generate
    output = self._old_generate(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/transformers/generation/utils.py", line 2625, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/transformers/generation/utils.py", line 3609, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.conda/envs/unsloth/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 1253, in _CausalLM_fast_forward
    outputs = fast_forward_inference(
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 1186, in LlamaModel_fast_forward_inference_custom
    X, present_key_value = attention_fast_forward_inference(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/models/llama.py", line 264, in LlamaAttention_fast_forward_inference
    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/kernels/utils.py", line 637, in fast_linear_forward
    out = fast_gemv(X, W, W_quant, out = out)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/triton/unsloth/unsloth/kernels/utils.py", line 483, in fast_gemv
    fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
  File "/home/jovyan/triton/bitsandbytes/bitsandbytes/cextension.py", line 60, in throw_on_call
    raise RuntimeError(
RuntimeError: Method 'cgemm_4bit_inference_naive_bf16' not available in CPU-only version of bitsandbytes.
Reinstall with GPU support or use CUDA-enabled hardware.

@Egor-Krivov, these kernels have already been implemented with SYCL.

For "gemm_4bit_inference" you need to call the "cgemv_4bit_inference_**" entry points instead. You can refer to the kernel dispatch in "csrc/pythonInterface.cpp":
void cgemv_4bit_inference_fp16(
    int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out,
    int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
) {
    gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
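
A hedged sketch of the caller-side dispatch this implies (not unsloth's actual code; the bf16/fp32 XPU entry-point names are assumed by analogy with the fp16 signature above):

```python
# Hedged sketch: pick the 4-bit gemv entry point by device and dtype. The CUDA
# names come from the missing-methods list above; the XPU bf16/fp32 names are
# assumed by analogy with cgemv_4bit_inference_fp16.
import torch
from bitsandbytes.cextension import lib

_CUDA_FN = {
    torch.float16: "cgemm_4bit_inference_naive_fp16",
    torch.bfloat16: "cgemm_4bit_inference_naive_bf16",
    torch.float32: "cgemm_4bit_inference_naive_fp32",
}
_XPU_FN = {
    torch.float16: "cgemv_4bit_inference_fp16",
    torch.bfloat16: "cgemv_4bit_inference_bf16",  # assumed name
    torch.float32: "cgemv_4bit_inference_fp32",   # assumed name
}

def select_gemv_4bit(dtype: torch.dtype, device_type: str):
    table = _XPU_FN if device_type == "xpu" else _CUDA_FN
    return getattr(lib, table[dtype])
```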

@jiqing-feng
Contributor

Hi @matthewdouglas . Could you please trigger the CI for this PR? Thanks!

@xiaolil1
Contributor Author

xiaolil1 commented Jul 8, 2025

This PR is ready for review now; please reach out to us if there are any other questions. Thanks!

@Egor-Krivov
Contributor

Hi @Egor-Krivov . Could you share your script to get this error?

I'm working on performance testing of unsloth right now.

These methods are used by the CUDA implementation here:
https://github.com/unslothai/unsloth/blob/6ac4e2e36f2f8bd0bc63a6eb85afa7097948ff3d/unsloth/kernels/utils.py#L173
For XPU we will need to provide an implementation as well, I think.

I am working with a POC branch (not merged upstream): https://github.com/leizhenyuan/unsloth/blob/7bed913255f611e220c2d219ee988c179ed98033/unsloth/kernels/utils.py#L154
In the POC branch this method can be called.

For me the call happens in the last two lines of my script, which is essentially a copy of the unsloth tutorial:

from unsloth import FastLanguageModel
import torch
import time

device = 'xpu:0'

max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 2x faster
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # 4bit for 405b!
    "unsloth/Mistral-Small-Instruct-2409",     # Mistral 22b 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/gemma-2-9b-bnb-4bit",
    "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!

    "unsloth/Llama-3.2-1B-bnb-4bit",           # NEW! Llama 3.2 models
    "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
    "unsloth/Llama-3.2-3B-bnb-4bit",
    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",

    "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" # NEW! Llama 3.3 70B!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "unsloth/Llama-3.2-3B-Instruct", # or choose "unsloth/Llama-3.2-1B-Instruct"
    model_name = "unsloth/Llama-3.2-3B-Instruct", # or choose "unsloth/Llama-3.2-1B-Instruct"
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # EGOR
    device_map={"": device}, # Use this to set the device for the model
    # attn_implementation="eager",
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

model = model.to(device)

# import pdb
# pdb.set_trace()
# model_devices = set(model.hf_device_map.values())
# model.hf_device_map = {0: 'xpu'}


from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }

from datasets import load_dataset
dataset = load_dataset("mlabonne/FineTome-100k", split = "train")



from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)


from trl import SFTConfig, SFTTrainer
from transformers import DataCollatorForSeq2Seq, TrainerCallback

# from bench_tools import tm
import pprint
from collections import defaultdict

class LatencyCallback(TrainerCallback):
    def __init__(self):
        self.step_start_time = None
    
    def on_step_begin(self, args, state, control, **kwargs):
        self.step_start_time = time.time()
    
    def on_step_end(self, args, state, control, **kwargs):
        if self.step_start_time is not None:
            step_latency = time.time() - self.step_start_time
            print(f"Step {state.global_step}: Latency = {step_latency:.4f} seconds")
            print()
        # self.times.append(get_time())
        # if len(self.times) > 1:
        #     print("Token latency: {:.1f} ms".format(1000 * (self.times[-1] - self.times[-2])))

        # if len(self.times) % 10 == 3 and self.print_median:
        #     ts = np.array(self.times)
        #     diff = ts[1:] - ts[:-1]
        #     # print("Token latency:", 1000 * diff, "ms")
        #     print("Token latency median:", np.median(1000 * diff), "ms")
        # print("Total accumulators:", {k: 1000* sum(v) for k, v in self.acc.items()}, "ms")
        # import pdb
        # pdb.set_trace()
        # print("Total accumulators:")# , {k: 1000* v for k, v in tm.get_results().items()}, "ms")
        # results = tm.get_results()
        # results = {k: f'{1000 * v:.2f}ms' for k, v in results.items()}
        # pprint.pprint(results)
        # tm.reset()

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    callbacks = [LatencyCallback()],
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 60,
        learning_rate = 2e-4,
        logging_steps = 1,
        # optim = "adamw_8bit",
        optim = "adamw_torch",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)

print(tokenizer.decode(trainer.train_dataset[5]["input_ids"]))

space = tokenizer(" ", add_special_tokens = False).input_ids[0]
print(tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5]["labels"]]))

trainer_stats = trainer.train()


from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

messages = [
    {"role": "user", "content": "Continue the fibonnaci sequence: 1, 1, 2, 3, 5, 8,"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize = True,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
).to(device)

outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True,
                         temperature = 1.5, min_p = 0.1)
tokenizer.batch_decode(outputs)

@github-actions

github-actions bot commented Jul 8, 2025

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng
Contributor

Hi @matthewdouglas . The lint test failed; see this comment for the error and fix. Do you know how to skip the xpu kernels in the typo check?

* skip test for xpu ops

Signed-off-by: jiqing-feng <[email protected]>

* fix lint

Signed-off-by: jiqing-feng <[email protected]>

* skip typo for xpu

Signed-off-by: jiqing-feng <[email protected]>

* skip

Signed-off-by: jiqing-feng <[email protected]>

* skip

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
@jiqing-feng
Contributor

Hi @matthewdouglas . Please trigger the tests and review this PR. Thanks!

@matthewdouglas matthewdouglas removed the Low Priority (will be worked on after all priority issues) label Sep 2, 2025
@matthewdouglas matthewdouglas modified the milestones: v0.48.0, v0.49.0 Sep 2, 2025
jiqing-feng and others added 3 commits September 4, 2025 10:42
Signed-off-by: jiqing-feng <[email protected]>
# Description

The version comparison expression references the .release property on only one side of the comparison, which leads to comparing a tuple against a Version object (a corrected sketch follows the error message below).

# Error message
```
The 8-bit optimizer is not available on your device, only available on CUDA for now.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Traceback (most recent call last):
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/unsloth_validation/run.py", line 1, in <module>
    import unsloth
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/__init__.py", line 235, in <module>
    from .models import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/__init__.py", line 15, in <module>
    from .llama     import FastLlamaModel
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/llama.py", line 23, in <module>
    from ._utils import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/_utils.py", line 89, in <module>
    from unsloth_zoo.patching_utils import (
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth_zoo/patching_utils.py", line 629, in <module>
    import transformers.integrations.bitsandbytes
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py", line 20, in <module>
    import bitsandbytes as bnb
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/__init__.py", line 39, in <module>
    from .backends.xpu import ops as xpu_ops
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/backends/xpu/ops.py", line 17, in <module>
    if version.parse(torch.__version__).release >= version.parse("2.9"):
TypeError: '>=' not supported between instances of 'tuple' and 'Version'
```
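
A minimal sketch of the corrected comparison (assumed shape of the fix, based on the description above): compare release tuples on both sides instead of a tuple against a Version.

```python
import torch
from packaging import version

# Compare release tuples on both sides, e.g. (2, 9, 0) >= (2, 9), instead of
# comparing a tuple against a packaging Version object.
if version.parse(torch.__version__).release >= version.parse("2.9").release:
    pass  # take the torch >= 2.9 XPU code path
```
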
@matthewdouglas matthewdouglas modified the milestones: v0.49.0, v0.48.0 Sep 15, 2025
@matthewdouglas
Member

matthewdouglas commented Sep 15, 2025

Hi all,

There are a few small lint issues to fix (I'll take care of them!)

Apart from that, it would be great if we could add the XPU backend build to our existing workflow in .github/workflows/python-package.yml so that it can be included in our release wheels.

Thanks!

@matthewdouglas
Member

As discussed on Slack, we can follow up with separate PRs for things like packaging.

@matthewdouglas matthewdouglas merged commit 1813b05 into bitsandbytes-foundation:main Sep 15, 2025
46 of 47 checks passed