# Copyright (c) Meta Platforms, Inc. and affiliates.
# SPDX-License-Identifier: LGPL-2.1-or-later

from drgn.helpers.linux.cpumask import for_each_possible_cpu
from drgn.helpers.linux.percpu import (
    per_cpu,
    per_cpu_ptr,
    percpu_counter_sum,
    percpu_counter_sum_positive,
)
from tests.linux_kernel import LinuxKernelTestCase, prng32, skip_unless_have_test_kmod


class TestPerCpu(LinuxKernelTestCase):
    """Tests for the ``per_cpu()`` and ``per_cpu_ptr()`` helpers."""

    def test_per_cpu(self):
        """Each possible CPU's ``runqueues`` entry should describe that CPU."""
        for cpu_id in for_each_possible_cpu(self.prog):
            runqueue = per_cpu(self.prog["runqueues"], cpu_id)
            try:
                recorded_cpu = runqueue.cpu
            except AttributeError:
                # Before Linux kernel commit cac5cefbade9 ("sched/smp: Make SMP
                # unconditional") (in v6.17), struct rq::cpu only exists if
                # CONFIG_SMP=y. Without it, fall back to checking that the
                # runqueue's idle task pointer is init_task.
                self.assertEqual(
                    runqueue.idle, self.prog["init_task"].address_of_()
                )
            else:
                self.assertEqual(recorded_cpu, cpu_id)

    @skip_unless_have_test_kmod
    def test_per_cpu_module_static(self):
        """A statically defined per-CPU variable in the test module.

        The expected per-CPU values are the ``prng32("PCPU")`` sequence. This
        presumably mirrors how the test module initialized the variable;
        confirm against the module source.
        """
        static_var = self.prog["drgn_test_percpu_static"]
        cpu_and_value_pairs = zip(for_each_possible_cpu(self.prog), prng32("PCPU"))
        for cpu_id, wanted in cpu_and_value_pairs:
            self.assertEqual(per_cpu(static_var, cpu_id), wanted)

    @skip_unless_have_test_kmod
    def test_per_cpu_module_dynamic(self):
        """A dynamically allocated per-CPU pointer in the test module.

        Each CPU's slot is dereferenced with ``[0]``. The expected values are
        the ``prng32("pcpu")`` sequence, presumably matching the module's
        initialization.
        """
        dynamic_ptr = self.prog["drgn_test_percpu_dynamic"]
        cpu_and_value_pairs = zip(for_each_possible_cpu(self.prog), prng32("pcpu"))
        for cpu_id, wanted in cpu_and_value_pairs:
            slot = per_cpu_ptr(dynamic_ptr, cpu_id)
            self.assertEqual(slot[0], wanted)


@skip_unless_have_test_kmod
class TestPercpuCounter(LinuxKernelTestCase):
    """Tests for the ``percpu_counter_sum*()`` helpers on test-module counters."""

    def test_percpu_counter_sum(self):
        """The raw sum should be reported as-is, including negative totals."""
        positive_counter = self.prog["drgn_test_percpu_counter"]
        self.assertEqual(percpu_counter_sum(positive_counter), 13)

        negative_counter = self.prog["drgn_test_percpu_counter_negative"]
        self.assertEqual(percpu_counter_sum(negative_counter), -66)

    def test_percpu_counter_sum_positive(self):
        """The positive variant should clamp a negative total to 0."""
        positive_counter = self.prog["drgn_test_percpu_counter"]
        self.assertEqual(percpu_counter_sum_positive(positive_counter), 13)

        negative_counter = self.prog["drgn_test_percpu_counter_negative"]
        self.assertEqual(percpu_counter_sum_positive(negative_counter), 0)
