#!/usr/bin/python
"""Ask the antitheft server which software version I should be running."""
# Copyright (C) 2007-2012 One Laptop Per Child Association, Inc.
# Licensed under the terms of the GNU GPL v2 or later; see COPYING for details.
# Written by C. Scott Ananian <cscott@laptop.org>
from __future__ import with_statement
from __future__ import division
from urllib import urlencode
from urllib2 import urlopen, URLError
from binascii import hexlify
import bitfrost.util.json as json
from bitfrost.contents.utils import open_envel, UnifiedContents
from bitfrost.update import perform_update, inhibit_suspend, check_signature
from bitfrost.leases.keys import OATS_KEYS, LEASE_KEYS
from bitfrost.leases.crypto import date_cmp, verify_act
import bitfrost.leases.errors
from random import SystemRandom
from subprocess import call, check_call
import os, os.path, sys, shutil
import hashlib
import re
import time, calendar
from datetime import datetime, timedelta
import logging

# Hostname of the OATS server queried when ANTITHEFT_SERVER_FILE does not exist.
DEFAULT_ANTITHEFT_SERVER = 'antitheft.laptop.org'
# Optional list of OATS server hostnames, one per line. When this file
# exists it replaces DEFAULT_ANTITHEFT_SERVER.
ANTITHEFT_SERVER_FILE='/etc/oats-server'
# If this file exists, the local school server (XS) is not queried.
ANTITHEFT_IGNORE_XS_FLAG='/etc/oats-ignore-xs'
# If this file exists, signatures on server responses are not verified.
ANTITHEFT_IGNORE_SIGNATURE_FLAG='/etc/oats-ignore-signature'

# The mtime of this file records the last time we attempted to pass our
# randomly_do_nothing logic (it is touched on every run).
LAST_ATTEMPT_FILE = '/security/update-attempt'

# The mtime of this file records the last time we passed the
# randomly_do_nothing logic and went on to make a query. It doesn't imply
# that the query actually succeeded. If it is missing, the next run is
# treated as a first run and always queries.
LAST_QUERY_FILE = '/security/update-query'

# Desired number of query attempts per (30-day) month, stored as an integer;
# rewritten from the check frequency in a server's 'update' reply.
INTERVAL_FILE = '/security/update-interval'
# Per-machine update stream; takes precedence over DEFAULT_STREAM_FILE.
STREAM_FILE = '/security/update-stream'
# Default update stream for this build, if any.
DEFAULT_STREAM_FILE = '/etc/olpc-update/update-stream'
# The current activation lease; deleted if the laptop is reported stolen.
LEASE_FILE='/security/lease.sig'

def try_unlink(path):
    """Remove path if possible; silently ignore filesystem errors.

    Missing files, directories and permission problems are all ignored
    (os.unlink reports them as OSError), but unrelated exceptions such as
    KeyboardInterrupt or SystemExit are no longer swallowed.
    """
    try:
        os.unlink(path)
    except OSError:
        pass

def shutdown_on_stolen():
    """React to a 'stolen' verdict: drop any lease, power off, and exit.

    The lease file is removed on a best-effort basis (try_unlink), the
    poweroff command is run and waited for, and the process then exits
    via sys.exit().
    """
    try_unlink(LEASE_FILE)
    poweroff_cmd = ['/usr/sbin/poweroff']
    call(poweroff_cmd)
    sys.exit()

# FIXME: #11666
def find_devicetree():
    """Return the root directory of the firmware device tree.

    /proc/device-tree is preferred when it exists; otherwise /ofw is used.
    """
    proc_tree = '/proc/device-tree'
    return proc_tree if os.path.exists(proc_tree) else '/ofw'

def read_ofw(dev_path):
    """Read a node under mfg-data in the OFW device tree.

    Returns the node's contents with trailing newlines and NUL bytes
    stripped, or None if the node is not present.
    """
    node = os.path.join(find_devicetree(), 'mfg-data', dev_path)
    if os.path.exists(node):
        with open(node) as fh:
            contents = fh.read()
        return contents.rstrip("\n\0")
    return None

def current_version():
    """Return the contents hash of the currently running build, or None.

    The running build is found through the /versions/running symlink. Its
    basename names an entry under /versions/contents/, whose
    UnifiedContents hash is returned. This is best-effort: any failure
    yields None, so callers can simply omit the version. Failures are
    logged at debug level instead of being silently discarded, and
    KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    try:
        shortname = os.path.basename(os.readlink('/versions/running'))
        return UnifiedContents('/versions/contents/%s' % shortname) \
               .contents_hash()
    except Exception:
        logging.debug('Could not determine the running version',
                      exc_info=True)
        return None

def check_credential(data, credential, valid_keys, serialnum):
    """Verify that at least one signature in `credential` covers `data`.

    Skips verification entirely (returning True) when
    ANTITHEFT_IGNORE_SIGNATURE_FLAG exists. Otherwise every signature in
    the 'sig' envelope is checked against valid_keys.

    Returns True on success; raises RuntimeError if none of the
    signatures match our keys.
    """
    if os.path.exists(ANTITHEFT_IGNORE_SIGNATURE_FLAG):
        logging.debug("Ignoring signature in response")
        return True

    # sig01: sha256 keyid data\n
    #  3 2 2    6  1  64 1 N   1
    # Every signature is checked (no short-circuit). check_signature is
    # expected never to raise; at least one result must be true.
    outcomes = [check_signature(json.write(data), sig, valid_keys, serialnum)
                for sig in open_envel('sig', 1, credential)]
    if not any(outcomes):
        raise RuntimeError("No signatures match our keys")
    return True

def timestr_to_secs(timestr):
    """Convert a UTC timestamp like '20120131T235959Z' to epoch seconds.

    The string is parsed with the format '%Y%m%dT%H%M%SZ' and interpreted
    as UTC, so neither the local timezone nor DST affects the result.
    """
    parsed = time.strptime(timestr, '%Y%m%dT%H%M%SZ')
    # calendar.timegm() only consults the year..second fields, so the DST
    # flag in the struct_time never influences the result.
    return calendar.timegm(parsed)

def clamp(n):
    """Return n limited to the closed interval [0, 1]."""
    at_most_one = min(n, 1)
    return max(0, at_most_one)

def randomly_do_nothing(increase_probability):
    """Randomly decide whether this run should skip its query.

    Compares the mtimes of LAST_ATTEMPT_FILE and LAST_QUERY_FILE with the
    desired rate in INTERVAL_FILE (query attempts per 30-day month;
    defaults to 30, roughly once a day, if the file is missing or invalid).

    increase_probability: if true, the basic query probability is
    multiplied by 4 (e.g. when the current lease has expired).

    Returns True if this run should do nothing and False if it should go
    on to query. Always returns False when LAST_QUERY_FILE does not exist
    (first run).
    """

    if not os.path.exists(LAST_QUERY_FILE):
        logging.debug('First run detected; forcing query attempt')
        return False

    def filetime(f, now):
        """Return the mtime of f in floating-point epoch seconds, or None
        if f cannot be stat'ed or its mtime lies in the future."""
        try:
            t = os.stat(f).st_mtime
        except OSError:
            return None # couldn't find file.
        return t if t <= now else None # file modified in future!

    now = time.time() # current time (in fp seconds)
    last_attempt = filetime(LAST_ATTEMPT_FILE, now)
    last_query = filetime(LAST_QUERY_FILE, now)
    # if we get an error, assume the last attempt was 15 minutes ago.
    if last_attempt is None: last_attempt = now - (15*60)
    # if we get an error, assume that the last query happened at the time
    # of the last attempt.
    if last_query is None: last_query = last_attempt
    # interval is how many attempts per (30-day) month should be made.
    try:
        with open(INTERVAL_FILE) as f:
            interval = int(f.read().strip())
    except (IOError, OSError, ValueError):
        interval = 30 # check once a day as a fallback.
    logging.debug('Last attempt: %d sec ago; last query: %d sec ago; interval: %d', now-last_attempt, now-last_query, interval)
    # normalize to attempts per second.
    interval = interval / (30*24*60*60.)
    basic_prob = clamp(interval * (now - last_attempt))

    if increase_probability:
        # increase probability 4x
        # (for use e.g. when lease has expired)
        basic_prob = clamp(basic_prob * 4.)

    # If it's been too long since the last query (approaching 10x the
    # expected interval), slide the query probability up towards 10x the
    # basic rate. LAST_QUERY_FILE marks the last query *attempt*, which
    # is not necessarily a successful one.
    upper_prob = clamp((10.*interval) * (now - last_attempt))
    slider     = clamp(     interval  * (now - last_query) / 10.)
    adj_prob = basic_prob + slider*max(0, upper_prob - basic_prob)

    # roll the dice!
    logging.debug('Chances of a query: %3.1f%%', adj_prob * 100)
    r = SystemRandom().random()
    return r > adj_prob   # if r is higher, do nothing.

def get_oats_urls():
    """Return the ordered list of OATS query URLs to try.

    The local school server comes first unless ANTITHEFT_IGNORE_XS_FLAG
    exists. It is followed by one URL per non-blank line of
    ANTITHEFT_SERVER_FILE, or by DEFAULT_ANTITHEFT_SERVER when that file
    does not exist.
    """
    oats_urls = []
    if not os.path.exists(ANTITHEFT_IGNORE_XS_FLAG):
        # The local XS is queried first
        # unless we are told not to.
        # The XS is local, fast and beautiful...
        # listen to its voice.
        oats_urls.append('http://schoolserver/antitheft/1/')

    if os.path.exists(ANTITHEFT_SERVER_FILE):
        with open(ANTITHEFT_SERVER_FILE) as f:
            hosts = [line.strip() for line in f.read().split("\n")]
        # Skip blank lines: they (or an empty file) would otherwise yield
        # the bogus URL 'http:///antitheft/1/'.
        oats_urls.extend('http://' + host + '/antitheft/1/'
                         for host in hosts if host)
    else:
        oats_urls.append('http://' + DEFAULT_ANTITHEFT_SERVER + '/antitheft/1/')

    return oats_urls

def touch(f, ignore_errors=False):
    """Create f if it is missing and set its mtime to now, like touch(1).

    If ignore_errors is true, filesystem errors are suppressed; otherwise
    they propagate. On Python 2, open() reports failures as IOError rather
    than OSError, so both must be caught for ignore_errors to work (e.g.
    when the parent directory is missing or read-only).
    """
    try:
        if not os.path.exists(f):
            # Append mode creates the file but never truncates it, should
            # it have appeared since the exists() check.
            open(f, 'a').close()
        os.utime(f, None) # set mtime to current time
    except (IOError, OSError):
        if not ignore_errors:
            raise

def update_lease(new_lease):
    """
    Atomically put a new lease in place at /security/lease.sig

    The lease is written to LEASE_FILE + ".new", flushed and fsync'ed, and
    then renamed over LEASE_FILE, so a crash cannot leave an empty or
    partially written lease in place. The previous lease is copied to
    LEASE_FILE + ".bak" on a best-effort basis.
    """

    logging.info('Switching to new lease')
    tmp_path = LEASE_FILE + ".new"
    try_unlink(tmp_path)
    with open(tmp_path, "w") as new_fd:
        new_fd.write(new_lease)
        # Data must be on disk before the rename makes it the live lease.
        new_fd.flush()
        os.fsync(new_fd.fileno())

    try:
        shutil.copy(LEASE_FILE, LEASE_FILE + ".bak")
    except (IOError, OSError, shutil.Error):
        # don't abort lease update if backup failed
        pass

    os.rename(tmp_path, LEASE_FILE)

class UpdateQuery:
    """One OATS query: its request parameters, the reply, and reply checks.

    The constructor collects the parameters POSTed to every server. run()
    sends them to one URL and leaves the verified reply in self.response
    (an empty list when there was no usable reply).
    """

    # Seconds to wait on a server before giving up. Without a timeout, an
    # unresponsive server could stall the whole run (during which suspend
    # is inhibited) indefinitely.
    TIMEOUT = 60

    def __init__(self, serialnum, uuid, stream=None):
        self.serialnum = serialnum
        self.uuid = uuid
        # Fresh random nonce: the server must echo it (see nonce_is_bad)
        # and it is mixed into the 'stolen' digest (see is_stolen).
        self.nonce = hexlify(os.urandom(16))

        self.params = { 'serialnum': self.serialnum, 'nonce': self.nonce }

        vhash = current_version()
        if vhash is not None:
            self.params['version'] = vhash

        # Update stream: the build default, overridden by the per-machine
        # STREAM_FILE, overridden in turn by an explicit `stream` argument.
        for path in (DEFAULT_STREAM_FILE, STREAM_FILE):
            value = self._read_stripped(path)
            if value is not None:
                self.params['stream'] = value
        if stream is not None:
            self.params['stream'] = stream # manually force an update stream

        # add free disk space (in KiB) in /versions/pristine
        try:
            st = os.statvfs('/versions/pristine/')
            # f_bfree is counted in units of the fragment size f_frsize.
            self.params['freespace'] = str(st.f_bfree * st.f_frsize // 1024)
        except OSError:
            pass # not a critical parameter; be safe.

    @staticmethod
    def _read_stripped(path):
        """Return the whitespace-stripped contents of path, or None if the
        file cannot be read."""
        try:
            with open(path) as f:
                return f.read().strip()
        except (IOError, OSError):
            return None

    def run(self, url, valid_keys=None):
        """POST the query to url and store the verified reply.

        valid_keys defaults to OATS_KEYS. On network errors or a non-200
        status, self.response is left as an empty list. Errors from
        decoding or signature checking propagate to the caller.
        """
        if valid_keys is None:
            valid_keys = OATS_KEYS

        logging.info('Querying %s', url)
        self.response = []

        # XXX: should use a handler to handle hashcash stuff here.
        try:
            resp = urlopen(url, urlencode(self.params), timeout=self.TIMEOUT)
        except URLError as e:
            logging.info("URL error %s", getattr(e, 'reason', e))
            return

        try:
            if resp.code != 200:
                logging.info('Bad HTTP status code: %d', resp.code)
                return
            body = resp.read()
        finally:
            resp.close()

        data, credential = open_envel('oatc-signed-resp', 1, json.read(body))
        check_credential(data, credential, valid_keys, self.serialnum)
        self.response = open_envel('oatc-resp', 1, data)

    def is_stolen(self):
        """Return True if the reply carries a valid 'stolen' marker.

        The marker must equal sha256("<uuid>:<nonce>:STOLEN") as a hex
        digest, tying it to this machine and this particular query.
        """
        if 'stolen' not in self.response:
            return False

        digest = hashlib.sha256(self.uuid + ":" + self.nonce + ":STOLEN").hexdigest()
        return self.response['stolen'] == digest

    def nonce_is_bad(self):
        """Return True unless the reply echoes this query's nonce."""
        return 'nonce' not in self.response or self.nonce != self.response['nonce']

class UpdateQueryApp:
    """Drive one olpc-update run against the configured OATS servers.

    Servers are queried in turn (see run_queries) until one of them gives
    us a lease and/or update information. Replies can also correct the
    system clock, report the laptop stolen, or delegate to another server.
    """

    def __init__(self, options):
        self.serialnum = read_ofw('SN')
        self.uuid = read_ofw('U#')
        self.options = options
        # Seconds to sleep before following a 'delegate' reply; doubled on
        # every delegation to limit worst-case loop traffic.
        self.delay = 1

        # We stop iterating the URL list when we find a server that gives
        # us a lease and/or update info.
        self.finished = False

        # self.urls is used as a stack (run_queries pops from the end), so
        # store the URLs reversed. Each entry is (url, keylist); a keylist
        # of None means the query uses its default keys.
        self.urls = [(url, None) for url in reversed(get_oats_urls())]

        # Set class member variables:
        # lease_expiry:
        # The expiry date/time of a valid lease we found on the system, only
        # when it has not expired. If this value is not None, the system is
        # considered to have a good, valid lease.
        # Note that lease_expiry will also be None in the case when the security
        # system is disabled.
        #
        # lease_expired:
        # True if we found an expired lease on the system.
        self.examine_current_lease()

    def process_response(self):
        """Act on the reply stored in self.query.response.

        In order, this method:
          - rejects empty replies and replies with a bad nonce;
          - shuts down on a 'stolen' verdict;
          - resets the clock if it is more than 2 hours off the server's
            'time';
          - considers a new 'lease';
          - re-checks the lease after a clock change;
          - handles an 'update' or follows a 'delegate'.

        Sets self.finished when the server gave a valid lease or update
        information.
        """
        if not self.query.response:
            logging.info('No response from server')
            return

        if self.query.nonce_is_bad():
            logging.error('Bad nonce in reply')
            return

        if self.query.is_stolen():
            shutdown_on_stolen()

        clock_updated = False
        lease_was_good = False
        response = self.query.response
        if 'time' in response:
            logging.info('Server time: %s', response['time'])
            serverepoch = timestr_to_secs(response['time'])

            now = time.time()
            # The main goal is to reset the clock in case it is really off
            # due to RTC battery or tinkering.
            # We only set the time if it's off by more than 2 hours.
            # This prevents messing with NTP.
            if abs(now - serverepoch) > 3600 * 2:
                lease_was_good = self.lease_expiry is not None
                newtimestr = time.strftime('%Y-%m-%d %H:%M:%S +0000', time.gmtime(serverepoch))
                logging.info("Setting hwclock to %s", newtimestr)
                clock_updated = True
                check_call(['/usr/sbin/hwclock',
                            '--set', '--date', newtimestr])
                check_call(['/usr/sbin/hwclock',
                            '--hctosys'])

        if 'lease' in response:
            try:
                if self.maybe_update_lease():
                    self.finished = True
            except Exception:
                # don't abort a sw update if the leasing doesn't work out
                logging.exception("Failed to update lease")

        if clock_updated:
            # If we had a good lease, but after updating the clock (and
            # installing any new lease from the server) it is no longer good,
            # shut down immediately.
            # We also consider the system stolen, on the assumption that we
            # have detected a clock rollback attack.
            self.examine_current_lease()
            if lease_was_good and self.lease_expiry is None:
                shutdown_on_stolen()

        if 'update' in response:
            self.handle_update()
            self.finished = True
            return

        if 'delegate' in response:
            new_url, new_key = response['delegate']
            time.sleep(self.delay) # limit worst case loop traffic.
            self.delay = self.delay * 2
            self.urls.append((new_url, [new_key]))

    def run(self):
        """Run one query cycle and return the process exit status.

        Optionally sleeps a random time first. With --auto it may decide at
        random to skip querying (returning 0). LAST_ATTEMPT_FILE is touched
        either way. Otherwise returns the status from run_queries().
        """
        if self.options.sleep > 0:
            # sleep some number of minutes to prevent all queries from being
            # synchronized
            logging.info('Sleeping for a bit.')
            time.sleep(SystemRandom().uniform(0, self.options.sleep*60))

        try:
            if self.options.auto:
                if randomly_do_nothing(self.lease_expired):
                    logging.info('Not time for next query yet.')
                    return 0 # don't check.
        finally:
            touch(LAST_ATTEMPT_FILE, ignore_errors=True)

        self.query = UpdateQuery(self.serialnum, self.uuid, self.options.stream)
        return self.run_queries()

    @inhibit_suspend
    def run_queries(self):
        """Query the servers in self.urls until one gives a usable answer.

        Returns 0 if some server gave us a lease and/or update information,
        and 1 if none did.
        """
        touch(LAST_QUERY_FILE, ignore_errors=True)

        # We query all URLs until we find a server that gives us a lease
        # and/or update information. This means that we even try the next
        # server in the list in the case when a server gave a blank-ish
        # response, in addition to the case when the server simply didn't
        # respond.
        while self.urls:
            url, keylist = self.urls.pop()
            # Catch Exception rather than everything: the SystemExit raised
            # by shutdown_on_stolen() (and KeyboardInterrupt) must end the
            # run instead of being logged and followed by further queries.
            try:
                self.query.run(url, keylist)
            except Exception:
                logging.exception('Failed to send query')
                continue

            try:
                self.process_response()
            except Exception:
                logging.exception('Failed to process response')

            if self.finished:
                break

        if not self.finished:
            logging.info('Could not contact any OAT server')
            return 1
        return 0


    def handle_update(self):
        """Record the server's check frequency and install a requested update.

        Nothing is installed if the requested version is already running,
        or if its priority is 'low' and --force was not given.
        """
        vhash, check_frequency, priority, hints = self.query.response['update']
        logging.info('Requested update %s with priority %s', vhash, priority)
        try: # try to record the requested check_frequency.
            with open(INTERVAL_FILE, 'w') as f:
                f.write(str(check_frequency))
        except Exception:
            # best-effort: a failure here must not block the update.
            logging.debug('Could not record check frequency', exc_info=True)
        if vhash == current_version():
            logging.info('Already up to date.')
            return # don't need to upgrade
        if priority == 'low':
            if self.options.force:
                logging.info('Forcing low priority update.')
            else:
                logging.info('Skipping low priority update; use --force to override.')
                return # don't need to upgrade

        # invoke updater.
        # XXX: once we've found that we want to update, we should try more often?
        logging.info('Performing update: %s %s %s', vhash, priority, hints)
        perform_update(vhash, priority, hints, self.options.verbosity)

    def maybe_update_lease(self):
        """
        Consider new lease data as a lease candidate to replace the existing one.
        Does not use the new lease if it fails security checks, or if the existing
        lease has an expiry date further in the future.
        Returns True if the lease was valid.
        """

        logging.info('Considering new lease from server')

        # cryptographically verify new lease
        new_lease = self.query.response['lease']
        try:
            new_expiry = verify_act(self.serialnum, self.uuid, new_lease,
                                    LEASE_KEYS)
        except (bitfrost.leases.errors.InvalidLeaseData, bitfrost.leases.errors.NoLeaseFound, bitfrost.leases.errors.VerificationFailure) as e:
            logging.info("New lease is bad: %r", e)
            return False
        except bitfrost.leases.errors.LeaseExpired:
            logging.info("New lease has already expired, ignoring.")
            return False
        except Exception:
            logging.exception("Unexpected exception from bitfrost")
            return False

        if self.lease_expiry is None:
            logging.info('Existing lease expired, broken or missing; installing new one')
            update_lease(new_lease)
            return True

        logging.debug('Current lease expires %s, new lease expires %s', self.lease_expiry, new_expiry)

        if date_cmp(new_expiry, self.lease_expiry) <= 0:
            logging.info('Staying with existing lease')
            return True

        update_lease(new_lease)
        return True

    def examine_current_lease(self):
        """Refresh self.lease_expiry and self.lease_expired from LEASE_FILE.

        lease_expiry is the expiry returned by verify_act for a valid,
        unexpired lease, and None otherwise (missing, unreadable, invalid or
        expired lease). lease_expired is True only when verify_act reports
        the lease as expired.
        """
        self.lease_expiry = None
        self.lease_expired = False

        try:
            with open(LEASE_FILE, "r") as fd:
                current_lease = fd.read()
        except (IOError, OSError):
            return

        try:
            self.lease_expiry = verify_act(self.serialnum, self.uuid,
                                           current_lease, LEASE_KEYS)
        except (bitfrost.leases.errors.InvalidLeaseData, bitfrost.leases.errors.NoLeaseFound, bitfrost.leases.errors.VerificationFailure):
            # bad lease
            return
        except bitfrost.leases.errors.LeaseExpired:
            self.lease_expired = True
        except Exception:
            logging.exception("Unexpected exception from bitfrost")


def main():
    """Command-line entry point for the OATS update query.

    Parses the command line, prints version information and exits if
    --version was given, and refuses to run unless invoked as root. It
    then maps the -v count to a root-logger level and exits with the
    status returned by UpdateQueryApp(options).run().
    """
    from optparse import OptionParser
    from bitfrost.update import VERSION_INFO

    parser = OptionParser(usage="""
 %prog [options]
 %prog --help""")
    parser.add_option(
        '-a', '--auto',
        action='store_true', dest='auto', default=False,
        help="use randomness to ensure that we don't check too often.")
    parser.add_option(
        '-s', '--sleep',
        action='store', type='float', dest='sleep', default=0,
        metavar='MINUTES',
        help='sleep for a random period up to the specified limit before starting; this helps avoid synchronized queries.')
    parser.add_option(
        '-v',
        action='count', dest='verbosity', default=0,
        help='display verbose progress information.')
    parser.add_option(
        '-f', '--force',
        action='store_true', dest='force', default=False,
        help='Force update even if low priority.')
    parser.add_option(
        '--stream',
        action='store', dest='stream', metavar='STREAM', default=None,
        help='Force update to the given stream.')
    parser.add_option(
        '--version',
        action='store_true', dest='version', default=False,
        help="display version and license information.")
    options, _args = parser.parse_args()

    if options.version:
        print(VERSION_INFO)
        sys.exit(0)

    if os.getuid() != 0:
        parser.error('Must be run as root.')

    # -v count -> root logger level: none = CRITICAL only, -v = INFO,
    # anything more = DEBUG.
    levels = {0: logging.CRITICAL, 1: logging.INFO}
    logging.getLogger().setLevel(levels.get(options.verbosity, logging.DEBUG))

    sys.exit(UpdateQueryApp(options).run())

if __name__ == '__main__':
    main()
