#!/usr/bin/python3

"""
This script is part of the apt-notifier package.

Usage: update-Origins-Pattern [<file>|options]

Update the Unattended-Upgrade::Origins-Pattern for most currently enabled
repos with the exceptions of the Debian Backports and Debian Experimental
repos.

File:

    File to be updated. Typically this would be the /etc/apt/apt.conf.d/
    51unattended-upgrades-mx file, but it can be any file the user has
    read/write permissions to.

    If no file argument is provided, output is written to stdout.

Options:

    -d, --default
        Default to updating the Origins-Pattern in the
        /etc/apt/apt.conf.d/51unattended-upgrades-origins file.

    -h, --help
        Show this help.

"""

import json
import os, re, sys
import subprocess
from subprocess import run
import argparse
import hashlib

# defaults
default_output = '/etc/apt/apt.conf.d/51unattended-upgrades-origins'
prog = "update-Origins-Pattern"

# usage line shown by argparse (note the leading blank)
usage = " " + prog + " [<file>|options]"

# description shown at the top of the --help output
description = (
    "\n"
    "\n"
    "Update the Unattended-Upgrade::Origins-Pattern for most currently enabled\n"
    "repos with the exceptions of the Debian Backports and Debian Experimental\n"
    "repos - and also the MX testrepo if enabled.\n"
)

# help text of the optional positional <file> argument
help_file = (
    "\n"
    f"File to be updated. Typically this would be the {default_output} \n"
    "file but can be any file user has read/write permissions to.\n"
    "\n"
    "If no file argument provided, output will be written to stdout.\n"
)

# help text of the -d/--default switch
help_default = (
    "\n"
    f"update the Origins-Pattern in the file {default_output}.\n"
)

# initiate argparser
parser = argparse.ArgumentParser(
    prog=prog,
    usage=usage,
    description=description,
)

# The positional <file> is optional.  None (rather than a magic 'stdout'
# string) marks "not given", so a file literally named 'stdout' still works.
parser.add_argument("file", help=help_file, type=str, nargs='?', default=None)
parser.add_argument("-d", "--default", help=help_default, action="store_true")

# parse command line arguments
args = parser.parse_args()

# Resolve the output target from the parsed arguments instead of from
# len(sys.argv): counting raw argv entries left 'output' undefined for
# invocations such as "prog -- file" or "prog ''".
#   -d/--default  -> default_output (wins over a <file> argument)
#   <file>        -> that file
#   neither       -> the special value 'stdout'
if args.file == '':
    parser.error("file name must not be empty")

if args.default:
    target = default_output
else:
    target = args.file

if target is None:
    output = 'stdout'
elif target.startswith('/'):
    output = target
else:
    # make relative names explicit; this also keeps a file that is named
    # 'stdout' distinguishable from the special 'stdout' value
    output = './' + target

def d_print(text=''):
    """
    Print *text* to stderr when debugging is enabled.

    Debugging is enabled by setting the environment variable DEBUG_ME to
    any non-empty value; without it the call prints nothing.
    """
    # Reading the environment does not raise for a constant key, so no
    # exception handling is needed here.
    if os.environ.get('DEBUG_ME'):
        print(text, file=sys.stderr)

# debug output: report the resolved output target
d_print("Debug: Will write Origins-Pattern into: '{}'".format(output))


# check debian codename
# based on get_distro_codename() used within unattended-upgrades script
def debian_codename(path="/etc/debian_version"):
    """
    Return the codename of the Debian release found in *path*.

    The first line of the file is reduced to the part before the first
    '.' and before the first '/', e.g. "12.5" -> "12" and
    "trixie/sid" -> "trixie".  Known major version numbers are mapped to
    their codename; any other value is returned unchanged.

    Returns "n/a" when the file cannot be read or its first line is empty.
    """
    known_releases = {
        "9": "stretch",
        "10": "buster",
        "11": "bullseye",
        "12": "bookworm",
        "13": "trixie",
    }
    try:
        with open(path, "r") as f:
            first_line = f.readline().strip()
    except OSError:
        # missing or unreadable file: codename cannot be determined
        return "n/a"

    if not first_line:
        return "n/a"

    version = first_line.split(".")[0].split("/")[0]
    return known_releases.get(version, version)

# read current apt policy
try:
    policy = run(['apt-cache', 'policy'], capture_output=True, text=True)
except OSError as e:
    print(f"Error: unable to run 'apt-cache policy': {e}", file=sys.stderr)
    sys.exit(1)

# Do not continue with partial data: an empty pattern written to the
# configuration file would silently drop all origins.
if policy.returncode != 0:
    print(f"Error: 'apt-cache policy' failed: {policy.stderr.strip()}",
          file=sys.stderr)
    sys.exit(1)
pol = policy.stdout

# Walk through the policy output and pair every 'release ...' line with the
# pin line (the line starting with the pin priority) directly preceding it.
# Pairing by position instead of by even/odd index keeps all later pairs
# aligned when a pin line has no release line.
pins = []
rels = []
unmatched = 0
pending_pin = None
for raw_line in pol.splitlines():
    entry = raw_line.strip()
    # entry[:1] is '' for blank lines, so no IndexError here
    if entry[:1].isdigit():
        if pending_pin is not None:
            # previous pin line never got a release line
            unmatched += 1
        pending_pin = entry
    elif entry.startswith('release'):
        if pending_pin is None:
            # release line without a pin line in front of it
            unmatched += 1
            continue
        # remove from pin line the last two fields: architecture and 'Packages'
        pins.append(' '.join(pending_pin.split()[0:-2]))
        # remove leading 'release' from the release line
        rels.append(re.sub('^release ', '', entry).strip())
        pending_pin = None
if pending_pin is not None:
    unmatched += 1

# Warnings go to stderr so they never end up in the generated pattern
# when it is written to stdout.
if unmatched:
    print("Warning: mismatch in number of pin- and release-lines from apt-cache policy !",
          file=sys.stderr)

# put lines into a (release-line -> pin-line) dictionary
relspins = dict(zip(rels, pins))
d_print(json.dumps(relspins, indent=4))

# Codename of the Debian release the running system is based on.
# Looked up once here instead of re-reading the file for every origin.
running_codename = debian_codename()
d_print(f"Debug: debian_codename(): {running_codename}")

# create a dict with slimmed-down release lines as keys; each value holds
#   'dict'    : the release line split into its key/value fields
#   'pins'    : the pin lines seen for this release line (used as a set)
#   'enabled' : whether the origin goes into the pattern uncommented
origins = {}
for k, v in relspins.items():
    pin = v

    # Split the release line into its "key=value" fields and filter by
    # field key.  Filtering whole fields (instead of a regex on the full
    # line) cannot accidentally cut into the value of another field.
    fields = [f for f in k.split(',') if f]

    # remove binary architecture (b=) and version (v=) from the release line
    fields = [f for f in fields if f.partition('=')[0] not in ('b', 'v')]
    rel = ','.join(fields)

    # components handling for MX Repositories: keep components (c=),
    # needed to disable the MX testrepo; remove them everywhere else
    if 'o=MX repository' not in rel:
        fields = [f for f in fields if f.partition('=')[0] != 'c']
        rel = ','.join(fields)

    # nothing left that could identify an origin
    if not rel:
        continue

    # add 'dict' and 'pins' dictionary to rel-key in origins dict
    if rel not in origins:
        origins[rel] = {'dict': {}, 'pins': {}}
    entry = origins[rel]

    # split on the first '=' only, so values may contain '=' themselves;
    # fragments without '=' are kept in rel but have no key to store
    entry['dict'] = dict(f.split('=', 1) for f in rel.split(',') if '=' in f)
    entry['enabled'] = True

    # add pin-line to release dict entry as 'pins'-key
    if pin not in entry['pins']:
        entry['pins'][pin] = 'x'

    # Note:
    # MX repository for testrepo and experimental temp repo do get
    # identified by code'n'ame and 'c'omponents, resp.:
    # MX test-repo  "o=MX repository,a=mx,n=buster,l=MX repository,c=test"
    # MX temp-repo  "o=MX repository,a=mx,n=temp,l=MX repository,c=main"

    # use 'check' dictionary to find origin lines to be disabled
    check = entry['dict']
    origin = check.get('o')

    # Disable Unattended Upgrades from MX temp- and test-repos
    if origin == 'MX repository':
        if check.get('n') == 'temp' or check.get('c') == 'test':
            entry['enabled'] = False

    # Disable Unattended Upgrades from Debian backports and
    # Debian experimental repos
    if origin in ('Debian Backports', 'Debian Experimental'):
        entry['enabled'] = False

    # Disable Unattended Upgrades from Debian Multimedia repos
    if 'Unofficial Multimedia Packages' in rel:
        entry['enabled'] = False

    # Disable Unattended Upgrades from Debian repos with Debian releases
    # other than the one the running system is based upon.
    if origin == 'Debian':
        d_print(f"Debug: rel: {rel}")
        d_print(f"Debug: check['n']: {check.get('n')}")
        if 'n' in check and running_codename:
            # Compare the whole codename, not its single characters.
            # Substring match on purpose: accepts e.g. 'bookworm-updates'
            # and 'bookworm-security' on a bookworm based system.
            entry['enabled'] = running_codename in check['n']

d_print(json.dumps(origins, indent=4))

# Build the Unattended-Upgrade::Origins-Pattern text.
import datetime
# timestamp for the header line of the generated file
now = datetime.datetime.now().strftime("%Y-%m-%d  %H:%M:%S")

uuop_lines = ["Unattended-Upgrade::Origins-Pattern {"]
for k in sorted(origins):
    ok = origins[k]
    # release lines without an origin field cannot be matched; skip them
    if 'o' not in ok['dict']:
        continue

    # list the pin lines this origin was found with as comments
    uuop_lines.extend(f"//  {pin_line}" for pin_line in ok['pins'])

    # disabled origins are written commented out
    prefix = '  ' if ok['enabled'] else '//'
    uuop_lines.append(f'{prefix}  "{k}";')
uuop_lines.append("};")

uuop = '\n'.join(uuop_lines) + '\n'
d_print(uuop)

headline = f"// generated by {prog}"
header = f"""{headline} at {now}
//
"""
# display generated output on stdout
if output == 'stdout':
    print(header + uuop)
    sys.exit(0)

# Read the existing file without its two header lines (the header carries
# a timestamp and would otherwise always differ).
try:
    with open(output, 'r', encoding='utf-8', errors='replace') as f:
        old_lines = f.read().splitlines()
except FileNotFoundError:
    # no existing file: always write
    old_body = None
except OSError as e:
    # e.g. permission denied or target is a directory
    print(e, file=sys.stderr)
    sys.exit(e.errno or 1)
else:
    if len(old_lines) > 1:
        # drop the "// generated by ..." line
        if headline in old_lines[0]:
            old_lines = old_lines[1:]
        # drop the following comment line of the header
        if old_lines[0].startswith('//'):
            old_lines = old_lines[1:]
    old_body = '\n'.join(old_lines) + '\n'

# Write out only if the content has changed.  The strings are compared
# directly; hashing them first added nothing and the ASCII encoding step
# failed on non-ASCII origin names.
if old_body != uuop:
    try:
        with open(output, 'w', encoding='utf-8') as f:
            f.write(header + uuop)
    except OSError as e:
        print(e, file=sys.stderr)
        # errno may be None, which would mean exit status 0
        sys.exit(e.errno or 1)

sys.exit(0)
