Han 2025 - GRPO Example

Notebook Tuning GPT-OSS

This notebook shows how to fine-tune gpt-oss-20b using Unsloth. The parameters used are:

max_seq_length = 768
lora_rank = 4
lora_alpha = lora_rank * 2
load_in_4bit = True
offload_embedding = True

The toy problem is to get gpt-oss to write an optimized matrix multiplication (matmul) function in pure Python that uses only the standard library. Optimized libraries such as numpy will naturally be much faster, but there is still some room to improve on a naive pure-Python implementation.

First, we generate some random matrices. Note that we also need A_list and B_list, plain list[list[float]] copies of A and B, to pass to our pure-Python function.

import numpy as np
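def generate_random_matrices(seed = 3407, n = 256):
    random_state = np.random.RandomState(seed)
    # Matrix dimensions n, k, m are each drawn uniformly from 1..n
    n, k, m = random_state.randint(1, n+1, size = 3)
    # Draw the entries from the same seeded generator, so the seed fixes the values as well as the shapes
    A = random_state.uniform(-10, 10, size = (n, k))
    B = random_state.uniform(-10, 10, size = (k, m))
    return A, A.tolist(), B, B.tolist()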

For example, a kernel generated by GPT-5 is:

# Kernel generated by GPT-5
def matmul(A, B):
    B_transpose = list(zip(*B))
    return [
        [
            sum(a*b for a, b in zip(row, col)) 
            for col in B_transpose
        ] for row in A
    ]
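
For comparison, the naive implementation that the task hopes to improve on is a plain triple loop. The sketch below is our own illustration (naive_matmul and transposed_matmul are hypothetical names, not from the notebook): it checks that a triple-loop version and a version structured like the GPT-5 kernel both agree with np.matmul. The transposed version is typically faster in CPython because its inner loop runs through zip and sum instead of repeated list indexing, but that is an expectation, not a measured result.

```python
import numpy as np

def naive_matmul(A, B):
    # Textbook triple loop: C[i][j] = sum over p of A[i][p] * B[p][j]
    n, k, m = len(A), len(B), len(B[0])
    C = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0
            for p in range(k):
                acc += A[i][p] * B[p][j]
            C[i][j] = acc
    return C

def transposed_matmul(A, B):
    # Same idea as the GPT-5 kernel: transpose B once, then take row-column dot products
    B_T = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in B_T] for row in A]

rng = np.random.RandomState(0)
A = rng.uniform(-10, 10, size = (5, 7))
B = rng.uniform(-10, 10, size = (7, 3))
expected = A @ B
for fn in (naive_matmul, transposed_matmul):
    result = np.array(fn(A.tolist(), B.tolist()))
    assert np.allclose(result, expected), fn.__name__
```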

Preventing Cheating

GRPO (Group Relative Policy Optimization) trains the model with RL driven by a reward function. To get the desired behaviour, we need a reward function that discourages cheating. For example, the LLM might simply import torch or numpy to get an optimized matmul.

Firstly, we write a function called check_only_stdlib_imports that, given a string containing a Python function (let's call it fn_string), checks whether it imports anything outside the Python standard library. We omit the definition here as it is rather involved.

Secondly, given fn_string, we want to compile it into a Python function that cannot access anything in the notebook's global namespace. The following code does so:

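import types

def create_locked_down_function(function: str):
    output_function = {}
    # Run the source with empty globals; its definitions land in output_function
    exec(function, {}, output_function)
    new_matmul = output_function["matmul"]
    # Rebuild the function around its code object with an empty globals dict
    new_matmul = types.FunctionType(new_matmul.__code__, {})
    return new_matmul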

Some explanations:

  • exec(string: str, globals: dict, locals: dict)
    • Executes the code in string and stores any top-level definitions it makes (here, matmul) in locals
    • Passing an empty dict as globals means the executed code cannot see the notebook's global namespace
  • types.FunctionType(code, globals)
    • Creates a new function from a compiled code object (here, matmul.__code__)
    • The second argument sets the globals the new function sees, here an empty dict
  • So this function turns fn_string into a Python function that cannot access the notebook's global variables or already-imported modules such as np. Builtins such as sum, zip and __import__ remain available, so an import statement inside the function body would still work; this is why check_only_stdlib_imports is needed as well

Benchmarking

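Now we create the actual benchmarking function. We gloss over some of the details, but at a high level we call the benchmark method of a Benchmarker instance, e.g. benchmarker.benchmark(matmul_fn, [(A, B)]), to get statistics about the matrix multiplication function under test. The second argument holds one tuple of arguments per loop, so with the default loops = 1 it is a list containing a single (A, B) tuple.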

import os, gc, time, statistics
import signal
from contextlib import contextmanager

class TimeoutError(Exception): pass

@contextmanager
def time_limit(seconds):
    def _handler(signum, frame):
        raise TimeoutError(f"Timed out after {seconds}s")
    old = signal.signal(signal.SIGALRM, _handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0.0)
        signal.signal(signal.SIGALRM, old)

class Benchmarker:
    def __init__(self, trials = 3, loops = 1, timeout = 30):
        self.buffer = np.zeros(2 * 1024 * 1024 * 1024, dtype = np.uint8)
        self.trials = trials
        self.loops = loops
        assert timeout > 0 # Cannot be 0 since it won't work!
        self.timeout = timeout

    def thrash(self):
        # Edit the buffer to wipe cache lines
        self.buffer ^= 1
        return int(self.buffer[::4096].sum())

    def benchmark(self, function, arguments):
        assert len(arguments) == self.loops
        samples = []
        exceptions = []
        timed_out = 0
        for _ in range(self.trials):
            gc.collect(); gc.disable(); self.thrash()
            t_start = time.perf_counter_ns()
            for i in range(self.loops):
                try:
                    with time_limit(self.timeout):
                        function(*arguments[i])
                except TimeoutError as e:
                    timed_out += 1
                except Exception as e:
                    exceptions.append(str(e))
            t_end = time.perf_counter_ns()
            gc.enable()
            samples.append((t_end - t_start) // max(1, self.loops))
        return {
            "median_ns": int(statistics.median(samples)),
            "mean_ns": int(statistics.fmean(samples)),
            "stdev_ns": int(statistics.pstdev(samples) if len(samples) > 1 else 0),
            "exceptions" : exceptions,
            "timeouts" : timed_out,
        }

Some other details:

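  • A timeout (implemented with SIGALRM, so it only works on Unix-like systems and in the main thread) interrupts a function that runs for too long; timeouts and other exceptions are counted and returned rather than crashing the benchmark
  • Garbage collection is run before each trial and then disabled while it runs, so GC pauses do not distort the timing
  • Before each trial, every byte of a large 2 GB buffer is flipped, which evicts data from the L1, L2 and L3 CPU caches so that a trial cannot benefit from data left in the cache by the previous one (note that this buffer needs 2 GB of RAM)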
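
The SIGALRM-based timeout can be tried on its own. The sketch below repeats time_limit, raising Python's built-in TimeoutError rather than the custom class, and uses it to stop a hypothetical runaway kernel (runaway_matmul) after 0.2 seconds. Like the original, it only works on Unix-like systems and in the main thread, because Python runs signal handlers only in the main thread.

```python
import signal
import time
from contextlib import contextmanager

@contextmanager
def time_limit(seconds):
    # Arrange for SIGALRM to fire after `seconds`; the handler raises inside the running code
    def _handler(signum, frame):
        raise TimeoutError(f"Timed out after {seconds}s")
    old = signal.signal(signal.SIGALRM, _handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Cancel the timer and restore the previous handler
        signal.setitimer(signal.ITIMER_REAL, 0.0)
        signal.signal(signal.SIGALRM, old)

def runaway_matmul(A, B):
    while True:  # simulates a generated kernel that never finishes
        pass

start = time.perf_counter()
try:
    with time_limit(0.2):
        runaway_matmul([[1.0]], [[1.0]])
    timed_out = False
except TimeoutError:
    timed_out = True
elapsed = time.perf_counter() - start
print("timed out:", timed_out, "after", round(elapsed, 2), "s")
```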

Generate matmul function

Now we ask gpt-oss to generate a matmul function, like so:

from transformers import TextStreamer

prompt = """
Create a new fast matrix multiplication function using only native Python code.
You are given a list of list of numbers.
Output your new function in backticks using the format below:
```python
def matmul(A, B):
    return ...
```
""".strip()

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
    reasoning_effort = "low",
)
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 1.0,
    max_new_tokens = 512,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

Which will give something like:

...

```python
def matmul(A, B):
    """
    Multiply two matrices A and B where both are given as
    lists of lists (rows) and are compatible for multiplication.
    Returns the resulting matrix as a list of lists.
    """
    # Transpose B so we can access its columns as rows.
    B_T = [list(col) for col in zip(*B)]          # O(n*m) time
    result = []

    for row in A:                                 # for each row in A
        res_row = []
        for col in B_T:                           # for each column of B
            # Compute dot product of row and column
            dot = sum(a * b for a, b in zip(row, col))
            res_row.append(dot)
        result.append(res_row)

    return result
```
...

Reward functions

Now we build up the reward functions, which assign a score to each new generation from gpt-oss. The flow is:

  • Extract the function from a generation text
  • Use create_locked_down_function to turn it into a callable Python function
  • Use Benchmarker to test how fast it is
  • Using benchmark statistics, assign some scores

First, we define extract_function(generation_text: str), which extracts the function from the generated text by looking for the triple-backtick code fence. Its definition is omitted for brevity.
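
The definition of extract_function is omitted in the text. A minimal hypothetical sketch, assuming the model puts its answer in a fenced block like the one requested in the prompt, could use a regular expression and return the first fenced block that defines matmul, or None if there is none:

```python
import re

# Matches a ```python ... ``` (or bare ``` ... ```) fenced block; DOTALL lets "." span newlines
_FENCE = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL)

def extract_function(generation_text: str):
    # Hypothetical sketch: first fenced block that defines matmul, else None
    for block in _FENCE.findall(generation_text):
        if "def matmul(" in block:
            return block.strip() + "\n"
    return None
```

Returning None for generations without a usable block is what lets the reward functions below assign such completions a fixed score.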

Next, we create multiple reward functions that test different aspects. For example, function_works checks that the generated code parses and can be turned into a function without errors (it does not call the function yet). Note that each reward function receives a list of completions and returns one score per completion.

def function_works(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        print(function)
        if function is not None:
            ok, info = check_only_stdlib_imports(function)
        if function is None or "error" in info:
            score = -2.0
        else:
            try:
                new_matmul = create_locked_down_function(function)
                score = 1.0
            except:
                score = -0.5
        scores.append(score)
    return scores

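Another reward function, correctness_check, runs the function on random matrices and compares its output with np.matmul. Two error measures returned by calculate_difference, amax_error and mse_error, each contribute between -3.0 and +3.0, so more accurate functions score higher. Code that cannot be compiled scores 0, and code that raises an error when called scores -2.0.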

def correctness_check(completions, **kwargs):
    scores = []
    # Generate some random matrices with dimensions of at most 128
    A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 128)
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        if function is not None:
            ok, info = check_only_stdlib_imports(function)
        if function is None or "error" in info:
            scores.append(0)
            continue
        try:
            new_matmul = create_locked_down_function(function)
        except:
            scores.append(0)
            continue
        try:
            pred = new_matmul(A_list.copy(), B_list.copy())
        except:
            # Failed!
            scores.append(-2.0)
            continue
        true = np.matmul(A, B)
        amax_error, mse_error = calculate_difference(pred, true)

        # Check correctness and score!
        machine_epsilon = 100*np.finfo(np.float64).eps
        if   amax_error >= 3:   score = -3.0
        elif amax_error >= 2:   score = -2.5
        elif amax_error >= 1:   score = -2.0
        elif amax_error >= 0.5: score = -1.0
        elif amax_error >= 100*machine_epsilon: score = 0.0
        elif amax_error >= machine_epsilon: score = 1.0
        else: score = 3.0

        if   mse_error >= 3:   score += -3.0
        elif mse_error >= 2:   score += -2.5
        elif mse_error >= 1:   score += -2.0
        elif mse_error >= 0.5: score += -1.0
        elif mse_error >= 100*machine_epsilon: score += 0.0
        elif mse_error >= machine_epsilon: score += 1.0
        else: score += 3.0
        scores.append(score)
    return scores
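
correctness_check calls a calculate_difference helper that the text does not define. A hypothetical sketch, assuming it should return the maximum absolute error and the mean squared error between pred and the numpy result, and should report the worst possible errors when pred has the wrong shape or is not numeric:

```python
import numpy as np

def calculate_difference(pred, true):
    # Hypothetical helper: (max absolute error, mean squared error) between pred and true.
    # Returns (inf, inf) when pred cannot be compared with true (ragged, non-numeric,
    # wrong shape, or containing inf/nan), so correctness_check assigns its lowest score.
    try:
        pred = np.asarray(pred, dtype = np.float64)
    except (TypeError, ValueError):
        return float("inf"), float("inf")
    if pred.shape != true.shape:
        return float("inf"), float("inf")
    diff = pred - true
    if not np.all(np.isfinite(diff)):
        return float("inf"), float("inf")
    return float(np.max(np.abs(diff))), float(np.mean(diff ** 2))
```

Returning infinity for unusable outputs puts both error measures in their worst band, so such a function gets -6.0 in total.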

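Finally, we have the speed check function, which rewards fast implementations. Both the LLM's function and np.matmul are benchmarked on the same random matrices, and the score is based on the ratio of their median run times:

  • If the LLM implementation is faster than the numpy implementation, the score is positive: the speedup divided by 100 (e.g. 2x faster gives +0.02)
  • If the LLM implementation is slower than the numpy implementation, the score is negative: minus the slowdown divided by 100 (e.g. 2x slower gives -0.02)
  • The score is clipped to the range [-10, 10]
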
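import gc
import torch

# Shared Benchmarker instance used below (default settings assumed); creating it allocates the 2 GB buffer once
benchmarker = Benchmarker()

def speed_check(completions, **kwargs):
    scores = []
    # Generate some random matrices with dimensions of at most 256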
    A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 256)
    numpy_results = benchmarker.benchmark(np.matmul, [(A, B)])
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        if function is not None:
            ok, info = check_only_stdlib_imports(function)
        if function is None or "error" in info:
            scores.append(0)
            continue
        try:
            new_matmul = create_locked_down_function(function)
        except:
            scores.append(0)
            continue
        new_results = benchmarker.benchmark(new_matmul, [(A_list.copy(), B_list.copy())])

        # Get score and clip to -10, 10
        negative = -(new_results["median_ns"] / numpy_results["median_ns"]) / 100
        positive = +(numpy_results["median_ns"] / new_results["median_ns"]) / 100
        score = negative if new_results["median_ns"] >= numpy_results["median_ns"] else positive
        if score >= 10:  score = 10
        if score <= -10: score = -10
        scores.append(score)
    # Free memory to counteract OOMs
    gc.collect()
    torch.cuda.empty_cache()
    return scores
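
The scoring lines in speed_check can be pulled out into a small pure function to check the arithmetic in the bullets above. speed_score is our own name; the logic copies the negative/positive/clip lines from speed_check:

```python
def speed_score(new_ns: float, numpy_ns: float) -> float:
    # Mirrors speed_check: signed run-time ratio divided by 100, clipped to [-10, 10]
    if new_ns >= numpy_ns:
        score = -(new_ns / numpy_ns) / 100   # slower than (or equal to) numpy: negative
    else:
        score = +(numpy_ns / new_ns) / 100   # faster than numpy: positive
    return max(-10.0, min(10.0, score))

for new_ns in (500, 2_000, 1_000_000, 10_000_000):
    print(new_ns, speed_score(new_ns, numpy_ns = 1_000))
```

Because the clip kicks in at a 1000x slowdown (1000 / 100 = 10), every candidate that is more than 1000x slower than numpy receives the same -10, so the speed reward can only rank candidates that stay below that slowdown.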

GRPOTrainer

Finally, we can prepare our dataset and set up the GRPOTrainer.

First, we create a dataset with just the prompt. We only need the prompt because we're running in dynamic mode: GRPOTrainer samples a few generations for each prompt and scores them with the reward functions. This represents true on-policy RL.

Note also that we use reasoning_effort="low".

from datasets import Dataset
dataset = Dataset.from_list([{"prompt" : [{"role": "user", "content": prompt.strip()}], "answer" : 0, "reasoning_effort": "low"}]*1000)
maximum_length = len(tokenizer(prompt.strip())["input_ids"])
print(maximum_length)
dataset[0]

Now we set up the GRPOTrainer configuration.

max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    temperature = 1.0,
    learning_rate = 5e-5,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 2, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 100,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    # per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)

Finally, we can train the model. Note that on a T4 GPU a single generation can take around 5 minutes, so it may take hours before you see progress on this task.

# For optional training + evaluation
# new_dataset = dataset.train_test_split(test_size = 0.01)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        function_works,
        no_cheating,
        correctness_check,
        speed_check,
    ],
    args = training_args,
    train_dataset = dataset,

    # For optional training + evaluation
    # train_dataset = new_dataset["train"],
    # eval_dataset = new_dataset["test"],
)
trainer.train()
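
Conceptually, TRL's GRPOTrainer calls every function in reward_funcs on the same group of completions and adds their scores together (GRPOConfig's reward_weights can weight them; by default every weight is 1). It then turns each summed reward into an advantage by subtracting the mean reward of the completions sampled for the same prompt and dividing by their standard deviation. The sketch below illustrates that standard GRPO group normalization for one prompt with num_generations = 2; it is a simplified illustration, not TRL's code, the reward values are hypothetical, and details such as the epsilon and whether the standard deviation is applied (TRL's scale_rewards option) vary between versions.

```python
import statistics

def group_advantages(rewards_per_func, eps = 1e-4):
    # rewards_per_func: one list per reward function, each holding one score per
    # completion in the group (all completions sampled for the same prompt)
    totals = [sum(scores) for scores in zip(*rewards_per_func)]  # equal weights
    mean = statistics.fmean(totals)
    std = statistics.stdev(totals) if len(totals) > 1 else 0.0  # sample standard deviation
    return totals, [(t - mean) / (std + eps) for t in totals]

# Hypothetical scores for two completions of the same prompt:
# completion 0 is correct but slow, completion 1 has a syntax error
rewards = [
    [1.0, -2.0],     # function_works
    [1.0, -20.0],    # no_cheating (hypothetical penalty)
    [6.0, 0.0],      # correctness_check
    [-10.0, 0.0],    # speed_check
]
totals, advantages = group_advantages(rewards)
print(totals)       # [-2.0, -22.0]
print(advantages)   # approximately [0.707, -0.707]
```

With only two generations per prompt, any two different totals map to advantages of about ±0.707 regardless of how large the gap is, and identical totals give zero advantage, i.e. no learning signal from that prompt. This is one reason to raise num_generations when memory allows.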

After training, we can use our original prompt to see the new function that the LLM generates.

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
    reasoning_effort = "low",
)

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 1.0,
    max_new_tokens = 1024,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)