#!/usr/bin/env python3

# Jailhouse, a Linux-based partitioning hypervisor
#
# Copyright (c) Siemens AG, 2015-2016
#
# Authors:
#  Jan Kiszka <jan.kiszka@siemens.com>
#
# This work is licensed under the terms of the GNU GPL, version 2.  See
# the COPYING file in the top-level directory.

import argparse
import gzip
import os
import struct
import sys

# Imports from the directory containing this script must be done before the
# following line, which redirects sys.path[0] to the parent directory
sys.path[0] = os.path.dirname(os.path.abspath(__file__)) + "/.."
from pyjailhouse.cell import JailhouseCell
import pyjailhouse.config_parser as config_parser

# Optional base directory for the loader binary: when set to a non-empty
# path, the script loads <libexecdir>/jailhouse/linux-loader.bin instead of
# looking for the loader relative to the script's own location.
libexecdir = None


def page_align(value, page_size=0x1000):
    """Round value up to the next multiple of page_size.

    The computation masks off the low bits, so page_size has to be a power
    of two.
    """
    offset_mask = page_size - 1
    return (value + offset_mask) & ~offset_mask


def unpack_cstring(blob):
    """Return the NUL-terminated string found at the start of blob.

    Each byte is mapped to the character with the same code point, i.e. the
    data is decoded as Latin-1. The terminator itself is not part of the
    result.

    Raises RuntimeError if blob contains no NUL byte.
    """
    # bytes() is a no-op for bytes input and gives other buffer types a
    # find() method.
    data = bytes(blob)
    end = data.find(b'\0')
    if end < 0:
        raise RuntimeError('Invalid DTB')
    return data[:end].decode('latin-1')


class DTBHeader:
    """The fixed 40-byte header at the start of a device tree blob.

    All fields are big-endian 32-bit words. They are exposed as plain
    attributes so that DTB.get() can patch offsets and sizes before the
    header is serialized again.
    """

    MAGIC = 0xd00dfeed

    def __init__(self, blob):
        """Parse the header from the first 40 bytes of blob.

        Raises RuntimeError if the magic word does not match. Terminates the
        program with exit status 1 if the blob's version is older than 17 or
        its last compatible version is newer than 17.
        """
        (magic, self.total_size, self.off_dt_struct, self.off_dt_strings,
         self.off_mem_rsvmap, self.version, self.last_comp_version,
         self.boot_cpuid_phys, self.size_dt_strings, self.size_dt_struct) = \
            struct.unpack_from('>10I', blob)
        # get() unconditionally emits the magic, so refuse input that does not
        # carry it instead of silently "repairing" a non-DTB file.
        if magic != DTBHeader.MAGIC:
            raise RuntimeError('Invalid DTB')
        if self.version < 17 or self.last_comp_version > 17:
            print('Unsupported DTB version %d' % self.version,
                  file=sys.stderr)
            sys.exit(1)
        self.header_blob = blob

    def get(self):
        """Return the header serialized from the current attribute values."""
        return struct.pack('>10I', DTBHeader.MAGIC, self.total_size,
                           self.off_dt_struct, self.off_dt_strings,
                           self.off_mem_rsvmap, self.version,
                           self.last_comp_version, self.boot_cpuid_phys,
                           self.size_dt_strings, self.size_dt_struct)


class DTBReserveMap:
    """Memory reservation block of a DTB, kept verbatim in self.blob."""

    def __init__(self, blob):
        """Copy all 16-byte entries up to and including the terminator.

        blob has to start at the first entry. An entry whose second 64-bit
        word (the size) is zero ends the map. struct.error is raised when
        blob ends before such an entry was found.
        """
        entry_size = 16
        end = 0
        while True:
            (size,) = struct.unpack_from('>8xQ', blob[end:end + entry_size])
            end += entry_size
            if size == 0:
                break
        self.blob = blob[0:end]


class DTBStrings:
    """Strings block of a DTB: NUL-terminated names addressed by byte offset."""

    def __init__(self, blob):
        self.blob = blob

    def add(self, string):
        """Append string (plus terminator) and return its byte offset."""
        offset = len(self.blob)
        self.blob += string.encode() + b'\0'
        return offset

    def get(self, off):
        """Return the string that starts at byte offset off.

        The search for the terminator happens on the raw bytes and only the
        selected name is decoded. Offsets therefore stay correct even if the
        table contains multi-byte characters, and a lookup no longer decodes
        the whole table. Without a terminator the rest of the table is
        returned.
        """
        end = self.blob.find(b'\0', off)
        if end < 0:
            end = len(self.blob)
        return self.blob[off:end].decode()


# Tokens of the DTB structure block. They are read and written as big-endian
# 32-bit words: DTBNode.parse() dispatches on them, and the get() methods of
# DTBProperty, DTBNode and DTB emit them when serializing.
OF_DT_BEGIN_NODE = 0x00000001
OF_DT_END_NODE = 0x00000002
OF_DT_PROP = 0x00000003
OF_DT_END = 0x00000009


class DTBProperty:
    """One device tree property: a name from the strings table plus raw data."""

    def __init__(self, name_off, value, strings):
        self.name_off = name_off
        self.name = strings.get(name_off)
        self.data = value

    @staticmethod
    def parse(blob, strings):
        """Parse a property whose header starts at blob[0].

        blob begins right after the OF_DT_PROP token. Returns the tuple
        (property, consumed bytes); the consumed size covers the 8-byte
        header and the value rounded up to a multiple of 4.
        """
        header_size = 8
        (value_size, name_off) = struct.unpack_from('>II', blob)
        value = blob[header_size:header_size + value_size]
        padded_size = (value_size + 3) & ~3
        return (DTBProperty(name_off, value, strings),
                header_size + padded_size)

    @staticmethod
    def create(name, value, strings):
        """Create a property, registering name in the strings table."""
        name_off = strings.add(name)
        return DTBProperty(name_off, value, strings)

    def set_value(self, value):
        self.data = value

    def get(self):
        """Serialize as token, header and value, zero-padded to 4 bytes."""
        value_size = len(self.data)
        header = struct.pack('>III', OF_DT_PROP, value_size, self.name_off)
        # -size & 3 is the distance to the next multiple of 4 (0 if aligned)
        padding = bytearray(-value_size & 3)
        return header + self.data + padding


class DTBNode:
    """A device tree node with its properties and child nodes."""

    # Token that carries no data and has to be skipped by parsers. It is not
    # among the module-level OF_DT_* constants, so it is kept here.
    _OF_DT_NOP = 0x00000004

    def __init__(self, name, children, properties, strings):
        self.name = name
        self.children = children
        self.properties = properties
        self.strings = strings

    @staticmethod
    def parse(blob, strings):
        """Parse a node whose name starts at blob[0].

        blob begins right after the node's OF_DT_BEGIN_NODE token. Returns
        the tuple (node, consumed bytes), where the consumed size includes
        the closing OF_DT_END_NODE token.

        Raises RuntimeError on an unknown token.
        """
        name = unpack_cstring(blob)
        # next field is word-aligned
        length = ((len(name) + 1) + 3) & ~3

        children = []
        properties = []

        while True:
            # read in place instead of copying the remaining blob per token
            (token,) = struct.unpack_from('>I', blob, length)
            length += 4
            if token == OF_DT_BEGIN_NODE:
                (child, child_length) = DTBNode.parse(blob[length:], strings)
                length += child_length
                children.append(child)
            elif token == OF_DT_PROP:
                (prop, prop_length) = DTBProperty.parse(blob[length:], strings)
                length += prop_length
                properties.append(prop)
            elif token == DTBNode._OF_DT_NOP:
                # padding token without payload, e.g. left behind when a
                # node or property was deleted in place
                continue
            elif token == OF_DT_END_NODE:
                break
            else:
                raise RuntimeError('Invalid DTB')

        return (DTBNode(name, children, properties, strings), length)

    def find(self, path):
        """Return the descendant addressed by the relative path, or None.

        path is a '/'-separated list of child names without leading slash;
        the empty path addresses this node itself.

        Raises RuntimeError if several children share a requested name.
        """
        if path == '':
            return self
        (childname, _, subpath) = path.partition('/')
        matches = [c for c in self.children if c.name == childname]
        if len(matches) > 1:
            raise RuntimeError('Invalid DTB')
        if not matches:
            return None
        return matches[0].find(subpath)

    def add_node(self, name):
        """Append an empty child node called name and return it."""
        child = DTBNode(name, [], [], self.strings)
        self.children.append(child)
        return child

    def get_prop(self, name):
        """Return the property called name, or None if there is none.

        Raises RuntimeError if the name occurs more than once.
        """
        matches = [p for p in self.properties if p.name == name]
        if len(matches) > 1:
            raise RuntimeError('Invalid DTB')
        return matches[0] if matches else None

    def add_prop(self, name, value):
        """Append a new property; name is added to the strings table."""
        prop = DTBProperty.create(name, value, self.strings)
        self.properties.append(prop)

    def get(self):
        """Serialize the node, its properties and then its children."""
        encoded_name = self.name.encode()
        # Size the padded name field by encoded bytes so that the terminating
        # NUL always fits, even if encoding changes the length.
        namelen = ((len(encoded_name) + 1) + 3) & ~3
        blob = struct.pack('>I%ds' % namelen, OF_DT_BEGIN_NODE, encoded_name)
        for prop in self.properties:
            blob += prop.get()
        for child in self.children:
            blob += child.get()
        return blob + struct.pack('>I', OF_DT_END_NODE)


class DTB:
    """In-memory representation of a device tree blob.

    The blob is split into header, reservation map, strings table and node
    tree. get() reassembles these parts in the order header, reservation
    map, structure block, strings block.
    """

    def __init__(self, blob):
        """Parse blob.

        Raises RuntimeError if the structure block does not start with a
        node; the component parsers may raise or exit on their own.
        """
        self.header = DTBHeader(blob[0:40])

        self.rsvmap = DTBReserveMap(blob[self.header.off_mem_rsvmap:])

        end_idx = self.header.off_dt_strings + self.header.size_dt_strings
        self.strings = DTBStrings(blob[self.header.off_dt_strings:end_idx])

        start_idx = self.header.off_dt_struct
        end_idx = start_idx + self.header.size_dt_struct

        (begin_token,) = struct.unpack_from('>I', blob[start_idx:end_idx])
        if begin_token != OF_DT_BEGIN_NODE:
            raise RuntimeError('Invalid DTB')
        (self.root_node, _) = DTBNode.parse(blob[start_idx+4:end_idx],
                                            self.strings)

    @staticmethod
    def _relative_path(path):
        """Strip everything up to and including the first '/' of path."""
        return path[path.find('/')+1:]

    def get_prop(self, path, name):
        """Return the raw value of property name at node path, or None.

        Raises RuntimeError if the node does not exist.
        """
        node = self.root_node.find(self._relative_path(path))
        if not node:
            raise RuntimeError('DTB is missing node %s' % path)
        prop = node.get_prop(name)
        return prop.data if prop else None

    def set_prop(self, path, name, value):
        """Set property name of node path to value.

        A missing property is created. A missing node is created as well,
        provided that its parent exists.

        Raises RuntimeError if the parent of a missing node does not exist.
        """
        subpath = self._relative_path(path)
        node = self.root_node.find(subpath)
        if not node:
            # Split off only the last component and work on the normalized
            # path so that nodes below the first level can be created, too.
            if '/' in subpath:
                (parent_path, new_node) = subpath.rsplit('/', 1)
            else:
                (parent_path, new_node) = ('', subpath)
            parent = self.root_node.find(parent_path)
            if not parent:
                raise RuntimeError('DTB is missing node /%s' % parent_path)
            node = parent.add_node(new_node)
        prop = node.get_prop(name)
        if prop:
            prop.set_value(value)
        else:
            node.add_prop(name, value)

    def get(self):
        """Serialize the tree and return the complete blob.

        The header fields describing offsets and sizes are updated as a side
        effect.
        """
        nodes = self.root_node.get() + struct.pack('>I', OF_DT_END)
        self.header.off_mem_rsvmap = 40
        self.header.off_dt_struct = 40 + len(self.rsvmap.blob)
        self.header.size_dt_struct = len(nodes)
        self.header.off_dt_strings = self.header.off_dt_struct + len(nodes)
        self.header.size_dt_strings = len(self.strings.blob)
        self.header.total_size = self.header.off_dt_strings + \
            len(self.strings.blob)

        return self.header.get() + self.rsvmap.blob + nodes + self.strings.blob


class X86:
    """x86 back end: boots the kernel via a zero page placed at 0x1000."""

    name = 'x86'

    def setup(self, args, config):
        """Read the kernel and build the boot parameter blob self.params.

        self.params consists of the zero page, the Jailhouse setup data and
        the NUL-terminated kernel command line, in this order.

        Raises RuntimeError if the cell's cpu_reset_address is not 0.
        """
        if config.cpu_reset_address != 0:
            raise RuntimeError('x86 requires cpu_reset_address=0')

        self.kernel_image = args.kernel.read()

        self._zero_page = X86ZeroPage(self.kernel_image, args.initrd,
                                      args.kernel_decomp_factor, config)

        setup_data = x86_gen_setup_data(config)

        # The zero page occupies 0x1000 bytes, the setup data follows it and
        # the command line comes last. Use self rather than the script's
        # global instance so that the class works on its own.
        header = self._zero_page.setup_header
        header.set_setup_data(self.params_address() + 0x1000)
        header.set_cmd_line_ptr(header.setup_data + len(setup_data))

        cmdline = args.cmdline.encode() if args.cmdline else b''
        self.params = self._zero_page.get_data() + setup_data + cmdline + \
            b'\0'

    def write_params(self, args, config):
        """Write self.params to args.write_params and print boot commands."""
        args.write_params.write(self.params)
        args.write_params.close()

        print('Boot parameters written. Start Linux with the following '
              'commands (adjusting paths as needed):\n'
              '\n'
              'jailhouse cell create %s\n'
              'jailhouse cell load %s linux-loader.bin -a 0x%x %s -a 0x%x ' %
              (args.config.name, config.name, self.loader_address(),
               args.kernel.name, self._zero_page.kernel_load_addr),
              end='')
        if args.initrd:
            print('%s -a 0x%x ' % (args.initrd.name,
                                   self._zero_page.setup_header.ramdisk_image),
                  end='')
        print('%s -a 0x%x' % (args.write_params.name, self.params_address()))
        print('jailhouse cell start %s' % config.name)

    def loader_address(self):
        return 0

    def params_address(self):
        return 0x1000

    def kernel_address(self):
        return self._zero_page.kernel_load_addr

    def dtb_address(self):
        # no device tree on x86
        return None

    def dtb(self):
        return None

    def ramdisk_address(self):
        return self._zero_page.setup_header.ramdisk_image


class ARMCommon:
    """Boot preparation shared by the ARM and ARM64 back ends.

    Subclasses have to provide get_uncompressed_kernel(), kernel_alignment(),
    get_kernel_offset() and default_decompression_factor().
    """

    def setup(self, args, config):
        """Lay out DTB, kernel and initrd in cell memory and patch the DTB.

        The images are placed, in this order, into the first RAM region of
        the cell that is large enough. The device tree receives a /memory
        node describing the cell's RAM and, if requested, the command line
        and the initrd location in /chosen.

        Terminates the program with exit status 1 if no device tree was
        specified or no RAM region is large enough. Raises RuntimeError if
        the device tree lacks #address-cells or #size-cells in its root.
        """
        self._cpu_reset_address = config.cpu_reset_address

        (self.kernel_image,
         self._kernel_gz) = self.get_uncompressed_kernel(args.kernel)
        kernel_size = page_align(len(self.kernel_image))
        kernel_load_offset = self.get_kernel_offset(self.kernel_image)
        image_size = kernel_load_offset + kernel_size + self.kernel_alignment()

        ramdisk_size = 0
        if args.initrd:
            ramdisk_size = os.fstat(args.initrd.fileno()).st_size
            # leave sufficient space between the kernel and the initrd
            decompression_factor = (args.kernel_decomp_factor or
                                    self.default_decompression_factor())
            decompression_space = decompression_factor * kernel_size
            kernel_size += decompression_space
            image_size += decompression_space
            # NOTE(review): the initrd is loaded right behind the kernel's
            # decompression space, so scaling its size by the factor looks
            # like over-reservation - confirm the intent before changing.
            image_size += ramdisk_size * (1 + decompression_factor)

        if not args.dtb:
            print('No device tree specified', file=sys.stderr)
            sys.exit(1)
        # add one page to the DTB size to be safe when extending it
        dtb_size = page_align(os.fstat(args.dtb.fileno()).st_size) + 0x1000
        image_size += dtb_size

        # add some pages in case the region contains the loader
        image_size += 0x2000

        ram_regions = [region for region in config.memory_regions
                       if region.is_ram() and region.size >= image_size]
        if not ram_regions:
            print('No space found to load all images', file=sys.stderr)
            sys.exit(1)

        self._dtb_addr = ram_regions[0].virt_start
        if self._dtb_addr == 0:
            # leave room for the loader
            self._dtb_addr += 0x2000
        self._kernel_addr = page_align(self._dtb_addr + dtb_size,
                                       self.kernel_alignment())
        self._kernel_addr += kernel_load_offset
        self._ramdisk_addr = self._kernel_addr + kernel_size

        self.params = 'kernel=0x%x dtb=0x%x' % (self._kernel_addr,
                                                self._dtb_addr)
        self.params = self.params.encode()

        # From here on the instance attribute shadows the dtb() method.
        self.dtb = DTB(args.dtb.read())

        addr_format = '>' + self._cell_format('#address-cells')
        reg_format = addr_format + self._cell_format('#size-cells')

        self.dtb.set_prop('/memory', 'device_type', 'memory'.encode() + b'\0')
        reg = bytearray(0)
        for region in config.memory_regions:
            # Filter out non-RAM regions and small ones at the start of the
            # cell address space that is used for the loader.
            if region.is_ram() and \
               (region.virt_start > 0 or region.size > 0x10000):
                reg += struct.pack(reg_format, region.virt_start, region.size)
        self.dtb.set_prop('/memory', 'reg', reg)

        if args.cmdline:
            self.dtb.set_prop('/chosen', 'bootargs',
                              args.cmdline.encode() + b'\0')
        if args.initrd:
            # Use self rather than the script's global instance so that the
            # class works on its own.
            start_addr = self.ramdisk_address()
            self.dtb.set_prop('/chosen', 'linux,initrd-start',
                              struct.pack(addr_format, start_addr))
            self.dtb.set_prop('/chosen', 'linux,initrd-end',
                              struct.pack(addr_format,
                                          start_addr + ramdisk_size))

    def _cell_format(self, name):
        """Return the struct code for the root node's cell count property.

        A count of 2 selects a 64-bit value ('Q'), anything else a 32-bit
        one ('I').
        """
        prop = self.dtb.get_prop('/', name)
        if not prop:
            raise RuntimeError('Invalid DTB')
        (cells,) = struct.unpack_from('>I', prop)
        return 'Q' if cells == 2 else 'I'

    def params_address(self):
        return self._cpu_reset_address + 0x1000

    def write_params(self, args, config):
        """Write the patched DTB to args.write_params, print boot commands."""
        args.write_params.write(self.dtb.get())
        args.write_params.close()

        print('Modified device tree written. Start Linux with the following '
              'commands (adjusting paths as needed):\n'
              '\n'
              'jailhouse cell create %s\n'
              'jailhouse cell load %s linux-loader.bin -a 0x%x -s "%s" '
              '-a 0x%x %s -a 0x%x ' %
              (args.config.name, config.name, self.loader_address(),
               self.params.decode(), self.params_address(),
               args.kernel.name + ('-unzipped' if self._kernel_gz else ''),
               self._kernel_addr), end='')
        if args.initrd:
            print('%s -a 0x%x ' % (args.initrd.name, self._ramdisk_addr),
                  end='')
        print('%s -a 0x%x' % (args.write_params.name, self._dtb_addr))
        print('jailhouse cell start %s' % config.name)
        if self._kernel_gz:
            print('\nNote that %s is zipped and requires decompression first.'
                  % args.kernel.name)

    def loader_address(self):
        return self._cpu_reset_address

    def kernel_address(self):
        return self._kernel_addr

    def dtb_address(self):
        return self._dtb_addr

    # NOTE(review): setup() stores the parsed tree in the instance attribute
    # "dtb", which hides this method; the script accesses arch.dtb.get()
    # directly. Kept for interface parity with X86 - consider renaming one of
    # the two.
    def dtb(self):
        return self.dtb.get()

    def ramdisk_address(self):
        return self._ramdisk_addr


class ARM(ARMCommon):
    """Parameters of the 32-bit ARM back end."""

    name = 'arm'

    @staticmethod
    def get_uncompressed_kernel(kernel):
        """Return (image, False): the file content is used as it is."""
        image = kernel.read()
        return (image, False)

    @staticmethod
    def kernel_alignment():
        """Return the alignment of the kernel load address (one 4K page)."""
        return 0x1000

    @staticmethod
    def get_kernel_offset(kernel):
        """Return the offset added to the aligned load address: always 0."""
        return 0

    @staticmethod
    def default_decompression_factor():
        """Return the factor used when none is given on the command line."""
        return 4


class ARM64(ARMCommon):
    """Parameters of the 64-bit ARM back end."""

    name = 'arm64'

    @staticmethod
    def get_uncompressed_kernel(kernel):
        """Return (image, was_gzipped) for the open kernel file.

        The file is first read as gzip stream. If that fails with IOError,
        it is rewound and its content returned unmodified.
        """
        try:
            image = gzip.GzipFile(fileobj=kernel).read()
        except IOError:
            kernel.seek(0)
            return (kernel.read(), False)
        return (image, True)

    @staticmethod
    def kernel_alignment():
        """Return the alignment of the kernel load address (2 MB)."""
        return 0x200000

    @staticmethod
    def get_kernel_offset(kernel_image):
        """Return the little-endian 64-bit value at byte 8 of the image."""
        return struct.unpack_from('<Q', kernel_image, 8)[0]

    @staticmethod
    def default_decompression_factor():
        """Return the factor used when none is given on the command line."""
        return 1


def resolve_arch(defined_arch=None):
    """Return a new back-end instance for the requested architecture.

    defined_arch is one of 'x86', 'arm' or 'arm64'. If it is None, the
    architecture is derived from the machine name reported by os.uname().

    Terminates the program with exit status 1 if the name is unknown or the
    machine is not supported.
    """
    classes_by_name = {'x86': X86, 'arm': ARM, 'arm64': ARM64}
    classes_by_machine = {'i686': X86, 'x86_64': X86, 'armv7l': ARM,
                          'aarch64': ARM64}

    if defined_arch is not None:
        arch_class = classes_by_name.get(defined_arch)
        if arch_class is None:
            print('Unknown architecture', file=sys.stderr)
            sys.exit(1)
    else:
        arch_class = classes_by_machine.get(os.uname()[4])
        if arch_class is None:
            print('Unsupported architecture', file=sys.stderr)
            sys.exit(1)
    return arch_class()


# see linux/Documentation/x86/boot.txt
class X86SetupHeader:
    """Setup header of an x86 kernel image, starting at offset 0x1f0.

    The fields this script needs are available as attributes. self.data
    holds a modifiable copy of the complete header; the set_*() methods keep
    attribute and copy in sync. Field addresses passed around in this class
    are offsets into the kernel image, as used by the boot protocol.
    """

    BASE_OFFSET = 0x1f0
    # Explicitly little-endian and unpadded so that the result does not
    # depend on the host the script runs on.
    _HEADER_FORMAT = '<xB2xI8xH14xB7xII8xI4xI20xI4xQQ'

    def __init__(self, kernel_image):
        """Parse the header out of kernel_image."""
        base = X86SetupHeader.BASE_OFFSET
        parse_size = struct.calcsize(X86SetupHeader._HEADER_FORMAT)
        (self.setup_sects,
         self.syssize,
         self.jump,
         self.type_of_loader,
         self.ramdisk_image,
         self.ramdisk_size,
         self.cmd_line_ptr,
         self.kernel_alignment,
         self.payload_offset,
         self.setup_data,
         self.pref_address) = \
            struct.unpack(X86SetupHeader._HEADER_FORMAT,
                          kernel_image[base:base+parse_size])

        # The high byte of the jump field is the distance from 0x202 to the
        # end of the header.
        self.size = 0x202 + (self.jump >> 8) - X86SetupHeader.BASE_OFFSET
        self.data = bytearray(kernel_image[base:base+self.size])

    def set_value_in_data(self, fmt, offset, value):
        """Pack value with struct format fmt at image offset `offset`."""
        struct.pack_into(fmt, self.data, offset - X86SetupHeader.BASE_OFFSET,
                         value)

    def set_type_of_loader(self, value):
        self.type_of_loader = value
        self.set_value_in_data('<B', 0x210, value)

    def set_ramdisk_image(self, value):
        self.ramdisk_image = value
        self.set_value_in_data('<I', 0x218, value)

    def set_ramdisk_size(self, value):
        self.ramdisk_size = value
        self.set_value_in_data('<I', 0x21c, value)

    def set_cmd_line_ptr(self, value):
        self.cmd_line_ptr = value
        self.set_value_in_data('<I', 0x228, value)

    def set_kernel_alignment(self, value):
        self.kernel_alignment = value
        # 32-bit field, as in _HEADER_FORMAT: a wider store would overwrite
        # the bytes at 0x234..0x237.
        self.set_value_in_data('<I', 0x230, value)

    def set_setup_data(self, value):
        self.setup_data = value
        self.set_value_in_data('<Q', 0x250, value)

    def set_pref_address(self, value):
        self.pref_address = value
        self.set_value_in_data('<Q', 0x258, value)

    def get_data(self):
        """Return the (possibly modified) header as bytearray."""
        return self.data


def get_kernel_compression_method(payload_magic):
    """Identify the kernel payload format by its leading magic bytes.

    Returns 'gzip', 'bzip2', 'lzma', 'xz', 'lz4', 'lzo', 'uncompressed' (the
    payload starts with an ELF header) or 'unknown'. Input that is too short
    to match any signature yields 'unknown' as well.
    """
    signatures = (
        (b'\x1f\x8b', 'gzip'),
        (b'\x1f\x9e', 'gzip'),
        (b'\x42\x5a', 'bzip2'),
        (b'\x5d\x00', 'lzma'),
        (b'\xfd\x37', 'xz'),
        (b'\x02\x21', 'lz4'),
        (b'\x89\x4c', 'lzo'),
        (b'\x7f\x45\x4c\x46', 'uncompressed'),
    )
    magic = bytes(payload_magic)
    for (prefix, method) in signatures:
        if magic.startswith(prefix):
            return method
    return 'unknown'


def mem_region_as_e820(mem):
    """Return the 20-byte E820 entry (start, size, type) describing mem.

    RAM regions get type 1, all others type 2 (reserved). The entry is
    packed little-endian and without padding, independent of the host the
    script runs on.
    """
    E820_RAM = 1
    E820_RESERVED = 2

    region_type = E820_RAM if mem.is_ram() else E820_RESERVED
    return struct.pack('<QQI', mem.virt_start, mem.size, region_type)


# see linux/Documentation/x86/zero-page.txt
class X86ZeroPage:
    """Boot parameter page ("zero page") handed to an x86 kernel.

    Holds the patched setup header of the kernel and the list of memory
    regions that are reported to the kernel as E820 map.
    """

    # Space to leave behind the kernel load address for decompression, as
    # multiple of the kernel image size, by compression method.
    # "Bzip2 uses a large amount of memory."
    # And that causes a larger binary, thus a smaller factor.
    _DECOMPRESSION_FACTORS = {
        'bzip2': 4,
        'lz4': 6,
        'lzo': 6,
        'gzip': 7,
        'lzma': 8,
    }
    # xz or unknown
    _DEFAULT_DECOMPRESSION_FACTOR = 9

    def __init__(self, kernel_image, initrd, kernel_decomp_factor, config):
        """Derive load addresses and collect the E820 regions.

        initrd is an open file or None. kernel_decomp_factor overrides the
        factor derived from the compression method unless it is None or 0.

        Terminates the program with exit status 1 if the configuration has
        more than 128 regions to report.
        """
        self.setup_header = X86SetupHeader(kernel_image)

        prot_image_offs = (self.setup_header.setup_sects + 1) * 512

        self.kernel_load_addr = self.setup_header.pref_address - \
            prot_image_offs

        self.setup_header.set_kernel_alignment(self.setup_header.pref_address)
        self.setup_header.set_type_of_loader(0xff)

        ramdisk_size = 0
        ramdisk_load_addr = 0
        if initrd:
            kernel_size = len(kernel_image)
            ramdisk_size = os.fstat(initrd.fileno()).st_size

            offs = prot_image_offs + self.setup_header.payload_offset
            payload_magic = bytearray(kernel_image[offs:offs+4])
            compression_method = get_kernel_compression_method(payload_magic)

            # Unless specified, derive decompression factor from method.
            # An exact lookup avoids accidental substring matches.
            decompression_factor = kernel_decomp_factor or \
                X86ZeroPage._DECOMPRESSION_FACTORS.get(
                    compression_method,
                    X86ZeroPage._DEFAULT_DECOMPRESSION_FACTOR)
            decompression_space = decompression_factor * kernel_size

            ramdisk_load_addr = page_align(self.kernel_load_addr +
                                           decompression_space)

        self.setup_header.set_ramdisk_image(ramdisk_load_addr)
        self.setup_header.set_ramdisk_size(ramdisk_size)

        self.e820_entries = []
        for region in config.memory_regions:
            if region.is_ram() or region.is_comm_region():
                if len(self.e820_entries) >= 128:
                    print('Too many memory regions', file=sys.stderr)
                    sys.exit(1)
                self.e820_entries.append(region)

    def get_data(self):
        """Return the page as bytearray of 0x1000 bytes.

        Layout: number of E820 entries at 0x1e8, setup header at 0x1f0, E820
        map at 0x2d0; everything else is zero.
        """
        data = bytearray(0x1e8) + \
            struct.pack('B', len(self.e820_entries)) + \
            bytearray(X86SetupHeader.BASE_OFFSET - 0x1e9) + \
            self.setup_header.get_data() + \
            bytearray(0x2d0 - X86SetupHeader.BASE_OFFSET -
                      self.setup_header.size)
        for region in self.e820_entries:
            data += mem_region_as_e820(region)
        return data + bytearray(0x1000 - len(data))


def x86_gen_setup_data(config):
    """Return the Jailhouse setup data blob for the given cell config.

    The blob is a 16-byte header (8 zero bytes, type, payload length)
    followed by the payload: version, compatible version, 20 zero bytes, the
    standard-IOAPIC flag, MAX_CPUS zero bytes and a flags word. Bit n of the
    flags is set if the cell owns the 8-port range of the n-th platform UART
    (0x3f8, 0x2f8, 0x3e8, 0x2e8).

    NOTE(review): the zeroed ranges presumably are fields the hypervisor
    fills in later - confirm against the consumer of this structure.
    """
    SETUP_TYPE_JAILHOUSE = 6
    MAX_CPUS = 255
    SETUP_DATA_VERSION = 2
    SETUP_DATA_COMPAT_VERSION = 1
    PLATFORM_UARTS = (0x3f8, 0x2f8, 0x3e8, 0x2e8)
    # Explicitly little-endian and unpadded so that the result does not
    # depend on the host the script runs on.
    header_format = '<8xII'
    payload_format = '<HH12x8xB%uxI' % MAX_CPUS

    standard_ioapic = 0
    if any(irqchip.is_standard() for irqchip in config.irqchips):
        standard_ioapic = 1

    flags = 0
    for pio_region in config.pio_regions:
        if pio_region.base in PLATFORM_UARTS and pio_region.length == 8:
            flags |= 1 << PLATFORM_UARTS.index(pio_region.base)

    header = struct.pack(header_format, SETUP_TYPE_JAILHOUSE,
                         struct.calcsize(payload_format))
    payload = struct.pack(payload_format, SETUP_DATA_VERSION,
                          SETUP_DATA_COMPAT_VERSION, standard_ioapic, flags)
    return header + payload


# Remember how the script was really invoked: the loader lookup below needs
# the unmodified path, which may contain dashes in its directory part.
_script_path = sys.argv[0]

# pretend to be part of the jailhouse tool
sys.argv[0] = sys.argv[0].replace('-', ' ')

parser = argparse.ArgumentParser(description='Boot Linux in a non-root cell.')
parser.add_argument('config', metavar='CELLCONFIG',
                    type=argparse.FileType('rb'),
                    help='cell configuration file')
parser.add_argument('kernel', metavar='KERNEL',
                    type=argparse.FileType('rb'),
                    help='image of the kernel to be booted')
parser.add_argument('-d', '--dtb', metavar='DTB',
                    type=argparse.FileType('rb'),
                    help='device tree for the kernel [arm/arm64 only]')
parser.add_argument('-i', '--initrd', metavar='INITRD',
                    type=argparse.FileType('rb'),
                    help='initrd/initramfs for the kernel')
parser.add_argument('-c', '--cmdline', metavar='"CMDLINE"',
                    help='kernel command line')
parser.add_argument('-w', '--write-params', metavar='PARAMS_FILE',
                    type=argparse.FileType('wb'),
                    help='only parse cell configuration, write out '
                         'parameters into the specified file and print '
                         'required jailhouse cell commands to boot Linux '
                         'to the console')
parser.add_argument('-a', '--arch', metavar='ARCH',
                    help='target architecture')
parser.add_argument('-k', '--kernel-decomp-factor', metavar='FACTOR',
                    type=int,
                    help='decompression factor of the kernel image, used to '
                         'reserve space between the kernel and the initramfs')

try:
    args = parser.parse_args()
except IOError as e:
    print(e.strerror, file=sys.stderr)
    sys.exit(1)

arch = resolve_arch(args.arch)

try:
    config = config_parser.CellConfig(args.config.read())
except RuntimeError as e:
    print(str(e) + ": " + args.config.name, file=sys.stderr)
    sys.exit(1)

arch.setup(args, config)

if args.write_params:
    # only emit the parameter file and the commands to boot manually
    arch.write_params(args, config)
else:
    if libexecdir:
        linux_loader = libexecdir + '/jailhouse/linux-loader.bin'
    else:
        linux_loader = os.path.abspath(os.path.dirname(_script_path)) + \
            '/../inmates/tools/' + arch.name + '/linux-loader.bin'

    with open(linux_loader, mode='rb') as loader_file:
        loader_image = loader_file.read()

    cell = JailhouseCell(config)
    cell.load(loader_image, arch.loader_address())
    cell.load(arch.kernel_image, arch.kernel_address())
    if arch.dtb_address():
        cell.load(arch.dtb.get(), arch.dtb_address())
    if args.initrd:
        cell.load(args.initrd.read(), arch.ramdisk_address())
    cell.load(arch.params, arch.params_address())
    cell.start()
