Another Approach to Wrapping Integer Arithmetic in Python



When you write a lot of Python that is intended to mimic arithmetic ported from C or assembly code, one giant advantage of Python sometimes becomes a burden: integers have arbitrary precision, so the wrap-around behavior that fixed-width integer types provide for free has to be added by hand, which can be a truly arduous process. I recently had an idea for doing this in a slightly less agonizing way. It is still a prototype and may need further extension, but I am already pretty happy with it.

The idea is to implement a decorator that uses Python's ast module to apply a bitmask to the result of each arithmetic operation in the body of the decorated function. Here's the code:
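```python
import ast
import functools
import inspect
import textwrap


def masked(mask: int):
    '''
    Convert arithmetic operations that occur within the decorated function body in such a way that
    the result is reduced using the given bitmask. The result of every addition, subtraction,
    multiplication, left shift, and exponentiation is combined with the mask via a bitwise AND.
    '''
    def decorator(function):
        # Dedent the source so that methods and nested functions can be parsed, too.
        code = textwrap.dedent(inspect.getsource(function))
        tree = ast.parse(code, mode='exec')

        class Postprocessor(ast.NodeTransformer):
            name = None

            def visit_BinOp(self, node: ast.BinOp):
                node = self.generic_visit(node)
                if not isinstance(node.op, (ast.Add, ast.Mult, ast.Sub, ast.LShift, ast.Pow)):
                    return node
                # Wrap the operation: (left op right) & mask
                return ast.BinOp(node, ast.BitAnd(), ast.Constant(mask))

            def visit_FunctionDef(self, node: ast.FunctionDef):
                # Rename the outermost definition before descending into its body;
                # otherwise a nested function would be renamed instead.
                if self.name is None:
                    node.name = self.name = F'__wrapped_{node.name}'
                    for k, deco in enumerate(node.decorator_list):
                        # Decorators may be calls, plain names, or attributes (module.masked).
                        target = deco.func if isinstance(deco, ast.Call) else deco
                        ident = getattr(target, 'id', getattr(target, 'attr', None))
                        if ident == masked.__name__:
                            # Decorators above @masked are applied to our return value by the
                            # original definition, so drop them together with @masked itself.
                            del node.decorator_list[:k + 1]
                            break
                return self.generic_visit(node)

        pp = Postprocessor()
        fixed = ast.fix_missing_locations(pp.visit(tree))
        # Execute in the decorated function's own module globals so that global names still
        # resolve as before; the new function is collected in `namespace`. Variables from an
        # enclosing function scope are not preserved by this recompilation.
        namespace = {}
        exec(compile(fixed, function.__code__.co_filename, 'exec'), function.__globals__, namespace)
        return functools.wraps(function)(namespace[pp.name])

    return decorator
```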
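The step `del node.decorator_list[:k + 1]` relies on the order in which stacked decorators run: Python applies them bottom-up. Anything listed below `@masked(...)` has therefore already been applied to the function that `masked` receives (and is re-applied to the recompiled copy, since it stays in the source), while anything listed above it is applied afterwards to whatever `masked` returns, so it must not appear in the recompiled source a second time. The hypothetical `tag` decorator below exists only to make that application order visible:

```python
# Hypothetical decorator factory `tag`: it records when it is applied, so the
# order in which stacked decorators run becomes visible.
order = []


def tag(label: str):
    def decorator(function):
        order.append(label)  # runs when the decorator is applied to the function
        return function
    return decorator


@tag('outer')  # listed first, applied last
@tag('inner')  # listed last, applied first
def f() -> None:
    pass


print(order)  # ['inner', 'outer']
```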
With this decorator, you can now write:
```python
@masked(0xFFFF)
def test(x: int) -> int:
    return x * 0xBAAD
```
This leaves the code exactly as it is, except for binary operations that perform addition, subtraction, multiplication, left shift, or exponentiation: each of these is wrapped in one additional bitwise AND with the mask, so the body above effectively becomes return (x * 0xBAAD) & 0xFFFF.
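To check which operations actually get rewritten, the same transformation can be applied to a source string and printed with `ast.unparse` (available from Python 3.9). The sketch below is a minimal illustration, not part of the original code: the helper `mask_source` and the standalone transformer `MaskBinOps` are hypothetical names that mirror `visit_BinOp` from the decorator. It also shows a gap worth knowing about: an augmented assignment such as `h += 1` is an `ast.AugAssign` node rather than an `ast.BinOp`, so it is left unmasked.

```python
import ast

# Hypothetical standalone transformer mirroring Postprocessor.visit_BinOp above.
MASKED_OPS = (ast.Add, ast.Sub, ast.Mult, ast.LShift, ast.Pow)


class MaskBinOps(ast.NodeTransformer):
    def __init__(self, mask: int):
        self.mask = mask

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        node = self.generic_visit(node)  # rewrite nested operations first
        if not isinstance(node.op, MASKED_OPS):
            return node
        return ast.BinOp(node, ast.BitAnd(), ast.Constant(self.mask))


def mask_source(source: str, mask: int) -> str:
    '''Hypothetical helper: return the rewritten source as a string.'''
    tree = MaskBinOps(mask).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


src = '''
def test(x: int) -> int:
    return x * 0xBAAD

def counter(h: int) -> int:
    h += 1  # ast.AugAssign, not ast.BinOp: left unmasked
    return h

def counter_masked(h: int) -> int:
    h = h + 1  # ast.BinOp: gets masked
    return h
'''

rewritten = mask_source(src, 0xFFFF)
print(rewritten)  # constants are printed in decimal, e.g. x * 47789 & 65535

namespace = {}
exec(rewritten, namespace)
print(namespace['test'](3))                 # 12295 (0x23007 & 0xFFFF)
print(namespace['counter'](0xFFFF))         # 65536, no wrap-around
print(namespace['counter_masked'](0xFFFF))  # 0, wrapped
```

In practice this means code ported from C should use the spelled-out form h = h + 1 inside a masked function when wrap-around is required.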
