#!/usr/bin/env python

import os
import re
import shutil
import sys
import tempfile

# Magic signature written at the very start of the header file
# (see HeaderSection.write).
CAB_HEADER = "MSCE"

# Maps the architecture names accepted by the "arch" config setting to
# the numeric code that is written into the header section.
ARCHITECTURES = {
    "shx-sh3":        103,
    "shx-sh4":        104,
    "i386":           386,
    "i486":           486,
    "i586":           586,
    "powerpc-601":    601,
    "powerpc-603":    603,
    "powerpc-604":    604,
    "powerpc-620":    620,
    "powerpc-mpc821": 821,
    "arm720":         1824,
    "arm820":         2080,
    "arm920":         2336,
    "strongarm":      2577,
    "mips-r4000":     4000,
    "sh3":            10003,
    "sh3e":           10004,
    "sh4":            10005,
    "alpha-21064":    21064,
    "arm7tdmi":       70001,
}

# Maps the $(NAME) path variables understood by expand_path() to the
# %CEn% install-directory strings they expand to.  The trailing comment
# on each entry gives the directory it refers to.
#
# Note that several names share a code: the START_* entries reuse the
# codes of the corresponding WINDOWS_* entries, and START uses the same
# code (%CE17%) as FAVORITES.
DIR_VARIABLES = {
    "PROGRAMS": "%CE1%",                # \Program Files
    "WINDOWS": "%CE2%",                 # \Windows
    "DESKTOP": "%CE3%",                 # \Windows\Desktop
    "STARTUP": "%CE4%",                 # \Windows\StartUp
    "DOCUMENTS": "%CE5%",               # \My Documents
    "PROGRAMS_ACCESSORIES": "%CE6%",    # \Program Files\Accessories
    "PROGRAMS_COMMUNICATIONS": "%CE7%", # \Program Files\Communications
    "PROGRAMS_GAMES": "%CE8%",          # \Program Files\Games
    "PROGRAMS_OUTLOOK": "%CE9%",        # \Program Files\Pocket Outlook
    "PROGRAMS_OFFICE": "%CE10%",        # \Program Files\Office
    "WINDOWS_PROGRAMS": "%CE11%",       # \Windows\Programs
    "WINDOWS_ACCESSORIES": "%CE12%",    # \Windows\Programs\Accessories
    "WINDOWS_COMMUNICATIONS": "%CE13%", # \Windows\Programs\Communications
    "WINDOWS_GAMES": "%CE14%",          # \Windows\Programs\Games
    "FONTS": "%CE15%",                  # \Windows\Fonts
    "RECENT": "%CE16%",                 # \Windows\Recent
    "FAVORITES": "%CE17%",              # \Windows\Favorites

    "START_PROGRAMS": "%CE11%",         # \Windows\Start Menu\Programs
    "START_ACCESSORIES": "%CE12%",      # \Windows\Start Menu\Accessories
    "START_COMMUNICATIONS": "%CE13%",   # \Windows\Start Menu\Communications
    "START_GAMES": "%CE14%",            # \Windows\Start Menu\Games
    "START": "%CE17%",                  # \Windows\Start Menu
}

# Write the low 16 bits of `value` to `f`, least significant byte first.

def write_int16(f, value):
    low = value & 0xff
    high = (value >> 8) & 0xff
    f.write(chr(low) + chr(high))

# Write the low 32 bits of `value` to `f`, least significant byte first.

def write_int32(f, value):
    octets = [(value >> shift) & 0xff for shift in (0, 8, 16, 24)]
    f.write("".join([chr(octet) for octet in octets]))

# Append NUL characters to a string so that its length becomes a
# multiple of 4.  Between one and four NULs are added: a string whose
# length is already a multiple of 4 still gains four.

def pad_string(s):
    nul_count = 4 - len(s) % 4
    return s + "\x00" * nul_count

# The fixed-size section at the start of the header file.  It holds the
# signature, the target architecture, the entry count and offset of
# every other section, and references to three "special" strings
# (application name, provider, unsupported) stored in the dictionary.

class HeaderSection:

    def __init__(self, cab_header):
        self.cab_header = cab_header

        # All filled in later by set_meta():
        self.arch = None
        self.app_name = None
        self.provider = None
        self.unsupported = None

    # The header section always occupies exactly 100 bytes.

    def __len__(self):
        return 100

    # Record the package metadata.  `arch` must be a key of
    # ARCHITECTURES; the three strings are registered in the string
    # dictionary so that write() can refer to them by offset.

    def set_meta(self, arch, app_name, provider, unsupported):

        arch_code = ARCHITECTURES.get(arch)

        if arch_code is None:
            raise Exception("Unknown architecture '%s'" % arch)

        self.arch = arch_code

        strings = self.cab_header.dictionary

        self.app_name = app_name
        strings.get(app_name)

        self.provider = provider
        strings.get(provider)

        self.unsupported = unsupported
        strings.get(unsupported)

    # Write the section to the output stream.  set_meta() must have
    # been called first.

    def write(self, stream):

        owner = self.cab_header

        # Basic header: signature, then the total length of the header
        # file and the architecture code among some constant fields.

        stream.write(CAB_HEADER)
        write_int32(stream, 0)
        write_int32(stream, len(owner))
        write_int32(stream, 0)
        write_int32(stream, 1)
        write_int32(stream, self.arch)

        # Minimum Windows CE version: six fields, all left as zero.

        for _ in range(6):
            write_int32(stream, 0)

        # Every section apart from this one, in file order:

        others = [s for s in owner.sections if s is not self]

        # Number of entries in each of the other sections:

        for section in others:
            write_int16(stream, section.num_entries())

        # Offset of each of the other sections:

        for section in others:
            write_int32(stream, owner.get_section_offset(section))

        # Special strings, each as (offset within the header file,
        # length including the terminating NUL):

        dictionary = owner.dictionary
        dictionary_offset = owner.get_section_offset(dictionary)

        for text in (self.app_name, self.provider, self.unsupported):
            write_int16(stream, dictionary_offset + dictionary.get_offset(text))
            write_int16(stream, len(text) + 1)

        # Two left-over fields of unknown use:

        write_int16(stream, 0)
        write_int16(stream, 0)

# The string dictionary section: every distinct string used by the
# other sections is stored here once and referred to by a numeric ID.

class StringDictionary:
    def __init__(self, cab_header):
        self.cab_header = cab_header
        self.strings = {}        # string -> (ID, offset of its entry)
        self.string_list = []    # (ID, NUL-padded string), in ID order
        self.index = 1           # ID to give to the next new string
        self.length = 0          # size of the section so far, in bytes

    # Get the length of the dictionary, in bytes.

    def __len__(self):
        return self.length

    # Get the number of entries in the dictionary.

    def num_entries(self):
        return len(self.strings)

    # Get the ID for the given string, adding it if necessary.

    def get(self, s):
        entry = self.strings.get(s)

        if entry is None:

            # A new string: it is stored NUL-padded, preceded by a
            # four byte (ID, length) prefix.

            padded = pad_string(s)
            entry = (self.index, self.length)

            self.strings[s] = entry
            self.string_list.append((entry[0], padded))
            self.length += 4 + len(padded)
            self.index += 1

        return entry[0]

    # Get the offset of a particular string within the dictionary.
    # The offset points at the string data, past the four byte prefix.

    def get_offset(self, s):
        return 4 + self.strings[s][1]

    # Write the dictionary to the output stream.

    def write(self, stream):
        for string_id, padded in self.string_list:
            write_int16(stream, string_id)
            write_int16(stream, len(padded))
            stream.write(padded)

# The directory section: the list of directories that files are
# installed into.  Directory names are matched case-insensitively.

class DirectoryList:
    def __init__(self, cab_header):
        self.cab_header = cab_header
        self.directories_list = []   # (ID, [string IDs]), in ID order
        self.directories = {}        # lower-cased path -> ID
        self.length = 0              # size of the section, in bytes
        self.index = 1               # ID to give to the next directory

    # Get the length of this section, in bytes.

    def __len__(self):
        return self.length

    # Get the number of directories in the list.

    def num_entries(self):
        return len(self.directories_list)

    # Find whether the specified directory exists in the list.
    # Returns its ID, or None if it has not been added.

    def find(self, dir):
        return self.directories.get(dir.lower())

    # Get the ID for the given directory, adding it if necessary.

    def get(self, dir):

        key = dir.lower()

        # Add new directory?

        if key not in self.directories:

            # The path is not split into components: the whole path
            # is stored as a single dictionary string.  This must be
            # a real list (not a lazy map object), because its length
            # is needed here and it is iterated again by write().

            dictionary = self.cab_header.dictionary
            dir_path = [ dictionary.get(dir) ]

            self.directories[key] = self.index
            self.directories_list.append((self.index, dir_path))

            # ID + length fields + zero terminator, plus the string IDs.
            self.length += 6 + 2 * len(dir_path)
            self.index += 1

        return self.directories[key]

    # Write the directory list to the specified stream.

    def write(self, stream):
        for i, dir_path in self.directories_list:
            write_int16(stream, i)
            write_int16(stream, 2 * len(dir_path) + 2)

            for string_id in dir_path:
                write_int16(stream, string_id)

            # Zero terminates the list of string IDs.
            write_int16(stream, 0)

# The file section: one entry per installed file, recording the
# directory it goes in, its NUL-padded name, its file number and flags.

class FileList:
    def __init__(self, cab_header):
        self.cab_header = cab_header
        self.files = []     # (ID, directory ID, padded name, file no, flags)
        self.index = 1      # ID to give to the next file
        self.length = 0     # size of the section so far, in bytes

    # Get the length of this section, in bytes.

    def __len__(self):
        return self.length

    # Query whether the file list contains a particular file.
    # Returns the file number, or None if the file is not in the list.

    def find(self, filename):
        dirname, _, leaf = filename.rpartition("\\")

        # Names are stored padded, so compare against the padded form.

        wanted_name = pad_string(leaf)
        wanted_dir = self.cab_header.directory_list.find(dirname)

        # A file cannot be present if its directory is not known.

        if wanted_dir is None:
            return None

        # Search the list of files:

        for entry in self.files:
            if entry[1] == wanted_dir and entry[2] == wanted_name:
                return entry[3]

        return None

    # Get the number of entries in the file list

    def num_entries(self):
        return len(self.files)

    # Add a file to the list.  The directory part of the name is added
    # to the directory list if it is not already there.

    def add(self, filename, file_no, flags=0):

        dirname, _, leaf = filename.rpartition("\\")

        padded_name = pad_string(leaf)
        dir_id = self.cab_header.directory_list.get(dirname)

        entry = (self.index, dir_id, padded_name, file_no, flags)
        self.files.append(entry)

        # Twelve bytes of fixed fields, followed by the padded name.

        self.length += 12 + len(padded_name)
        self.index += 1

    # Write this section to the output stream.

    def write(self, stream):

        for file_id, dir_id, padded_name, file_no, flags in self.files:
            write_int16(stream, file_id)
            write_int16(stream, dir_id)
            write_int16(stream, file_no)
            write_int32(stream, flags)
            write_int16(stream, len(padded_name))
            stream.write(padded_name)

# TODO: not implemented.  Placeholder section (registry hives, judging
# by the name) that is always empty: it has no entries, is zero bytes
# long and writes nothing.

class RegHiveList:
    def num_entries(self):
        return 0

    def __len__(self):
        return 0

    def write(self, stream):
        return None

# Placeholder section (registry keys, judging by the name) that is
# always empty: it has no entries, is zero bytes long and writes nothing.

class RegKeyList:
    def num_entries(self):
        return 0

    def __len__(self):
        return 0

    def write(self, stream):
        return None

# The link section: shortcuts that point at an installed file or
# directory.

class LinkList:
    def __init__(self, cab_header):
        self.cab_header = cab_header
        self.links = []     # (ID, target type, target ID, base dir, path)
        self.length = 0     # size of the section so far, in bytes
        self.index = 1      # ID to give to the next link

    # Get the length of this section, in bytes.

    def __len__(self):
        return self.length

    # Get the number of links in the list.

    def num_entries(self):
        return len(self.links)

    # Determine the target type (dir/file) and ID.  Returns (1, id) for
    # a file and (0, id) for a directory; raises an Exception if the
    # target is neither a known file nor a known directory.

    def __find_target(self, target):
        file_id = self.cab_header.file_list.find(target)

        if file_id is not None:
            return 1, file_id

        dir_id = self.cab_header.directory_list.find(target)

        if dir_id is not None:
            return 0, dir_id

        raise Exception("Link target '%s' not found" % target)

    # Add a link at path `destination` pointing at `target`, which must
    # already have been added as a file or directory.

    def add(self, target, destination):

        target_type, target_id = self.__find_target(target)

        dest_path = destination.split("\\")

        # Leading \:

        if dest_path[0] == "":
            dest_path = dest_path[1:]

        # An empty destination leaves nothing to examine below.

        if not dest_path:
            raise Exception("Link destination '%s' is empty" % destination)

        # %CEn% to specify the install root is handled differently for
        # links than it is for files/dirs: the number is stored in its
        # own field rather than as part of the path.

        match = re.match(r"\%CE(\d+)\%", dest_path[0])

        if match:
            base_dir = int(match.group(1))
            dest_path = dest_path[1:]
        else:
            base_dir = 0

        # Map dirs that make up the path to strings.  This must be a
        # real list (not a lazy map object), because its length is
        # needed here and it is iterated again by write().

        dictionary = self.cab_header.dictionary
        dest_path = [ dictionary.get(part) for part in dest_path ]

        self.links.append((self.index, target_type, target_id,
                           base_dir, dest_path))
        self.index += 1

        # Six fixed fields + zero terminator, plus the string IDs.
        self.length += 14 + 2 * len(dest_path)

    # Write this section to the output stream.

    def write(self, stream):

        for i, target_type, target_id, base_dir, dest_path in self.links:

            write_int16(stream, i)
            write_int16(stream, 0)
            write_int16(stream, base_dir)
            write_int16(stream, target_id)
            write_int16(stream, target_type)
            write_int16(stream, 2 * len(dest_path) + 2)

            for string_id in dest_path:
                write_int16(stream, string_id)

            # Zero terminates the list of string IDs.
            write_int16(stream, 0)

# The header file stored inside the CAB.  It is a sequence of sections
# written back to back; the first (HeaderSection) records the entry
# counts and offsets of all the others.

class CabHeaderFile:
    def __init__(self):
        self.dictionary = StringDictionary(self)
        self.directory_list = DirectoryList(self)
        self.file_list = FileList(self)

        # The sections, in the order in which they are written:

        self.sections = [
            HeaderSection(self),
            self.dictionary,
            self.directory_list,
            self.file_list,
            RegHiveList(),
            RegKeyList(),
            LinkList(self)
        ]

    # Set the package metadata; the arguments are passed straight
    # through to HeaderSection.set_meta.

    def set_meta(self, *args):
        header_section = self.get_section(HeaderSection)
        header_section.set_meta(*args)

    # Add a file to be installed at the given path.

    def add_file(self, filename, file_no, flags=0):
        files_section = self.get_section(FileList)
        files_section.add(filename, file_no, flags)

    # Add a link at `destination` pointing at `target`.

    def add_link(self, target, destination):
        links_section = self.get_section(LinkList)
        links_section.add(target, destination)

    # Get the first section that is an instance of the given class.
    # Raises an Exception if there is none.

    def get_section(self, section_class):
        for section in self.sections:
            if isinstance(section, section_class):
                return section

        raise Exception("Can't find section of class %s" % (section_class,))

    # Get the offset of a section from the start of the file, in bytes:
    # the combined length of all the sections that precede it.  Raises
    # an Exception if the section is not part of this file.

    def get_section_offset(self, section):
        offset = 0

        for s in self.sections:
            if section is s:
                return offset
            offset += len(s)

        raise Exception("Section %s not found in list" % (section,))

    # Get the total length of the file, in bytes.

    def __len__(self):
        return sum(len(s) for s in self.sections)

    # Write all sections to the stream, checking after each one that
    # the number of bytes written matches the length it reports (the
    # offsets stored in the header depend on those lengths).  The
    # stream must support tell() and is assumed to start at position 0.

    def write(self, stream):
        old_pos = 0
        for section in self.sections:
            section.write(stream)
            pos = stream.tell()
            if pos != old_pos + len(section):
                raise Exception("Section is %i bytes long, but %i written" % \
                                (len(section), pos - old_pos))
            old_pos = pos

# Builds a CAB file from a configuration dictionary.  The config must
# contain "app_name", "provider" and "files" (install path -> source
# file); "arch", "unsupported" and "links" (link path -> target path)
# are optional.  The CAB itself is produced by the external "lcab" tool.

class CabFile:
    def __init__(self, config):
        self.cab_header = CabHeaderFile()

        self.__process_meta(config)
        self.__process_files(config["files"])

        if "links" in config:
            self.__process_links(config["links"])

    # Metadata.  The architecture defaults to "strongarm" and the
    # unsupported string to ""; the other two settings are mandatory.

    def __process_meta(self, config):
        arch = config.get("arch") or "strongarm"
        app_name = config.get("app_name")
        provider = config.get("provider")
        unsupported = config.get("unsupported") or ""

        if app_name is None or provider is None:
            raise Exception("Application name and provider must be specified")

        self.cab_header.set_meta(arch, app_name, provider, unsupported)
        self.app_name = app_name

    # Get the shortened 8.3 filename used for the specified file
    # within the CAB: the first eight alphanumeric characters of the
    # base name (left-padded with zeros if there are fewer), with the
    # file number as a three digit extension.

    def __shorten_name(self, filename, file_no):

        # Strip down to base filename without extension:

        basename = os.path.basename(filename)

        if "." in basename:
            basename = basename.rpartition(".")[0]

        # Remove non-alphanumeric characters:

        cleaned_name = "".join([c for c in basename if c.isalnum()])
        short_name = cleaned_name[:8].rjust(8, "0")

        return "%s.%03i" % (short_name, file_no)

    # Process the list of files to install.  self.files is indexed by
    # file number; slot 0 belongs to the header file, which is named
    # after the application.

    def __process_files(self, files):
        self.files = [ self.app_name ]

        for filename, source_file in files.items():
            file_no = len(self.files)
            filename = expand_path(filename)
            self.cab_header.add_file(filename, file_no)
            self.files.append(source_file)

    # Process the list of links:

    def __process_links(self, links):
        for destination, target in links.items():
            target = expand_path(target)
            destination = expand_path(destination)
            self.cab_header.add_link(target, destination)

    # Write the header file into the given directory.  Returns a list
    # containing the name of the file written.

    def __write_header(self, dir):

        basename = self.__shorten_name(self.files[0], 0)
        filename = os.path.join(dir, basename)

        # The header is binary data, so the file must be opened in
        # binary mode: text mode would translate newline bytes on
        # platforms that distinguish the two.

        stream = open(filename, "wb")

        try:
            self.cab_header.write(stream)
        finally:
            stream.close()

        return [ filename ]

    # Copy the files to install into the given directory under their
    # shortened names.  Returns the list of files written.

    def __write_files(self, dir):

        result = []

        for file_no in range(1, len(self.files)):
            source_file = self.files[file_no]
            basename = self.__shorten_name(source_file, file_no)
            filename = os.path.join(dir, basename)

            shutil.copy(source_file, filename)
            result.append(filename)

        return result

    # Output to a file.  Raises an Exception if lcab cannot be run or
    # reports failure.

    def write(self, filename):

        temp_dir = tempfile.mkdtemp()

        try:
            header = self.__write_header(temp_dir)
            files = self.__write_files(temp_dir)

            # The header goes first, followed by the files in reverse
            # order of file number.
            files.reverse()

            args = [ "lcab", "-n" ] + header + files + [ filename ]

            status = os.spawnlp(os.P_WAIT, "lcab", *args)

            if status != 0:
                raise Exception("lcab failed with exit status %i" % status)
        finally:

            # Clean up, whether or not the CAB was created:

            shutil.rmtree(temp_dir, ignore_errors=True)

# Convert a path from the config file into the form used in the CAB:
# "/" separators become "\" and each $(NAME) variable is replaced by
# its entry in DIR_VARIABLES.  Raises an Exception if a variable name
# is not known.

def expand_path(filename):

    # Replace Unix-style / path separators with DOS-style \

    filename = filename.replace("/", "\\")

    # Expand $(xyz) path variables to their Windows equivalents:

    def replace_var(match):
        var_name = match.group(1)

        if var_name not in DIR_VARIABLES:
            raise Exception("Unknown variable '%s'" % var_name)

        return DIR_VARIABLES[var_name]

    return re.sub(r"\$\((.*?)\)", replace_var, filename)

# Read a configuration file.  The file is Python source: it is executed
# and the resulting global namespace is returned as a dictionary, so
# each top-level name assigned in the file becomes a config setting.

def read_config_file(filename):
    f = open(filename)

    try:
        data = f.read()
    finally:
        f.close()

    # NOTE: this runs the config file as arbitrary Python code.  Only
    # use config files that come from a trusted source.

    prog = compile(data, filename, "exec")
    result = {}
    eval(prog, result)

    return result

# List the files that the output CAB depends on: print the source file
# of every entry in the "files" setting of the given config file, one
# per line.

def print_dependencies(filename):
    config = read_config_file(filename)

    # Parenthesised so that the line is valid whether print is a
    # statement or a function.

    for source_file in config["files"].values():
        print(source_file)

# Command line entry point.  `argv` defaults to sys.argv and is either
# [prog, config file, output file], to build a CAB, or
# [prog, "-d", config file], to list the files the CAB depends on.
# Returns the process exit status: 0 on success, 1 on a usage error.

def main(argv=None):
    if argv is None:
        argv = sys.argv

    if len(argv) < 3:
        # The usage message goes to stderr so that it can never be
        # mistaken for the dependency list that -d prints to stdout.
        sys.stderr.write("Usage: %s <config file> <output file>\n" % argv[0])
        return 1

    if argv[1] == "-d":
        print_dependencies(argv[2])
    else:
        config = read_config_file(argv[1])

        cab_file = CabFile(config)
        cab_file.write(argv[2])

    return 0

# Only run when executed as a script, so that the module can be
# imported without side effects.

if __name__ == "__main__":
    sys.exit(main())