""" This file contains all code for allowing the debugging of threads. """

from bdb import Breakpoint
import linecache
import inspect
import os
import sys
import thread
import threading

from pydb.fns import get_confirmation
from pydb.gdb import Gdb
from mpdb import MPdb, Exit    # Needed for calls to do_info

class UnknownThreadID(Exception):
    """Exception type for a thread identifier that is not recognised."""

# Maps a thread's name (threading.currentThread().getName()) to the
# debugger object tracing it: the main debugger for the thread that called
# init(), and an MTracer for each thread whose tracing started afterwards.
tt_dict = {}
# thread.get_ident() of the thread whose debugger event was handled most
# recently. Used to print the "[Switching to Thread ...]" notice and to mark
# the current thread with '* ' in print_thread_details().
current_thread = None

# This global variable keeps track of the 'main' debugger (set once by
# init()). When one of the MTracer objects encounters a breakpoint in its
# thread, it uses the main debugger's lock and stdout, and reports through
# it, so that the frame that thread stopped at can be examined.
_main_debugger = None

# These global variables keep track of 'global state' in the
# debugger. init() snapshots them from the main debugger so that each
# newly created MTracer can start with the same state.
# NOTE(review): g_args is not referenced in this part of the file.
g_args = None
# The main debugger's breakpoint table; each MTracer takes a shallow copy.
g_breaks = {}
g_commands = {}    # commands to execute at breakpoints (set by init())
g_dirname = None
# NOTE(review): g_main_dirname is never assigned in this part of the file;
# init() stores main_dirname into g_dirname instead.
g_main_dirname = None
g_program_sys_argv = None
# Prompt given to each new MTracer.
g_prompt = "(MPdb) "
g_search_path = None
g_sys_argv = None

class MTracer(Gdb):
    """ A class to trace a thread other than the one that called init().

    Each instance starts with a shallow copy of the main debugger's
    breakpoint table (g_breaks) taken when the tracer is created. This is
    what lets, for instance, a breakpoint inside a thread's run() method
    stop that thread. All instances write to the main debugger's stdout and
    serialise their work on the main debugger's lock.
    """
    def __init__(self):
        Gdb.__init__(self, stdout=_main_debugger.stdout)
        self.thread = threading.currentThread()
        self.running = True
        self.reset()

        # Copy some state snapshotted from the main debugger by init().
        self.breaks = dict(g_breaks)
        self._program_sys_argv = g_program_sys_argv
        self._sys_argv = g_sys_argv
        self.prompt = g_prompt

        # XXX This needs fixing, hack so that when a new MTracer
        # is created for this thread we don't stop at the bottom
        # frame.
        self.botframe = 1

    def trace_dispatch(self, frame, event, arg):
        """ Override this method so that we can lock the main debugger
        whilst we're performing our debugging operations.

        Returns whatever Gdb.trace_dispatch returns. Python installs a
        trace function's return value as the local trace function for the
        frame, so that value must be passed back to keep line events (and
        therefore line breakpoints) flowing for this thread.
        """
        if not self.running:
            # The user has requested and confirmed a quit, nothing
            # left to do but exit the entire debugger. Every
            # thread calls sys.exit() in order to suppress any
            # exceptions.
            sys.exit()

        _main_debugger.lock.acquire()
        try:
            # Normally exceptions just unwind up the execution
            # stack and mpdb enters post mortem. It's more convenient
            # here to just see if an exception occurred and drop the
            # user into a command loop.
            if sys.exc_info()[2]:
                self.user_exception(frame, sys.exc_info())
                # Once we've finished accepting user commands, exit.
                sys.exit()
            return Gdb.trace_dispatch(self, frame, event, arg)
        finally:
            _main_debugger.lock.release()

    def user_line(self, frame):
        """ Override this method from pydb.pydbbdb.Bdb to make
        it thread-safe.
        """
        _check_and_switch_current()
        Gdb.user_line(self, frame)

    def user_call(self, frame, args):
        """ Override pydb.pydbbdb.Bdb.user_call and make it
        thread-safe.
        """
        _check_and_switch_current()
        Gdb.user_call(self, frame, args)

    def user_exception(self, frame, exc_info):
        """ Override this method from pydb.pydbbdb.Bdb to make
        it thread-safe.

        exc_info is the (type, value, traceback) triple, as returned by
        sys.exc_info().
        """
        exc_type, value, traceback = exc_info
        Gdb.msg(self, 'Program raised exception %s' % exc_type.__name__)
        _check_and_switch_current()
        Gdb.user_exception(self, frame, (exc_type, value, traceback))

    def do_thread(self, arg):
        # Delegate to the module-level 'thread' command handler.
        do_thread(arg)

    def do_info(self, arg):
        # Delegate to the module-level 'info' handler (adds 'info thread').
        do_info(arg)

    def do_quit(self, arg):
        """ Override this method to ask the user whether they really
        want to quit.
        """
        if self.running:
            ret = get_confirmation(self,
                            'The program is running. Exit anyway? (y or n) ')
            if ret:
                # tt_dict holds debugger objects, so compare against this
                # tracer (not self.thread, a Thread object) to skip the
                # current thread. The current thread then leaves via
                # sys.exit() below.
                for tracer in tt_dict.values():
                    if tracer is not self:
                        Gdb.do_quit(tracer, None)
                sys.exit()
            Gdb.msg(_main_debugger, 'Not confirmed.')

def trace_dispatch_init(frame, event, arg):
    """ This function is called by a sys.settrace when a thread is running
    for the first time. Setup this thread with a tracer object and
    set this thread's tracing function to that object's trace_dispatch
    method.

    The event that triggered this call is also routed through the new
    tracer, and the tracer's return value is handed back to Python. That
    value becomes the local trace function for FRAME (the thread's first
    frame). Returning None here would leave that frame untraced.
    """
    tracer = MTracer()
    name = threading.currentThread().getName()

    # Mutating the shared dict needs no 'global' declaration.
    tt_dict[name] = tracer

    sys.settrace(tracer.trace_dispatch)
    return tracer.trace_dispatch(frame, event, arg)

def init(debugger):
    """ This function initialises thread debugging. It sets up a tracing
    function for newly created threads so that they call trace_dispatch_init,
    which hooks them up with a MTracer object. The argument 'debugger' is
    the debugger that is debugging the MainThread, i.e. the Python
    interpreter.

    Only the first call adopts 'debugger' and patches it. Every call
    registers the calling thread's name in tt_dict and (re)installs the
    thread trace hook.
    """
    # g_prompt must be declared here: without it the assignment below would
    # create a local, and new MTracers would never see the main prompt.
    global _main_debugger, g_breaks, g_commands, g_dirname, g_program_sys_argv
    global g_prompt, g_search_path, g_sys_argv, current_thread, tt_dict

    if _main_debugger is None:
        _main_debugger = debugger

        current_thread = thread.get_ident()

        # This lock must be acquired when a MTracer object
        # places a call to _main_debugger.user_*
        _main_debugger.lock = threading.Lock()

        # Copy some state from the main debugger so that
        # newly created debuggers can have the same.
        g_breaks = _main_debugger.breaks
        g_commands = _main_debugger.commands
        # NOTE(review): main_dirname is stored in g_dirname, not in
        # g_main_dirname -- confirm which global is intended.
        g_dirname = _main_debugger.main_dirname
        g_prompt = _main_debugger.prompt
        g_program_sys_argv = _main_debugger._program_sys_argv
        g_search_path = _main_debugger.search_path
        g_sys_argv = _main_debugger._sys_argv

        # Replace some of the mpdb methods with thread-safe ones
        _main_debugger.user_line = user_line
        _main_debugger.user_call = user_call
        _main_debugger.user_exception = user_exception
        _main_debugger.user_return = user_return

        _main_debugger.do_break = do_break
        _main_debugger.do_thread = do_thread
        _main_debugger.do_info = do_info

    tt_dict[threading.currentThread().getName()] = _main_debugger
    threading.settrace(trace_dispatch_init)

def _check_and_switch_current():
    """Make the calling thread the current thread, announcing any change.

    When the calling thread differs from the one recorded in the
    current_thread global, print "[Switching to Thread <ident>]" through
    the main debugger and record the caller as current. Otherwise do
    nothing.
    """
    global current_thread
    me = thread.get_ident()
    if current_thread != me:
        Gdb.msg(_main_debugger, '\n[Switching to Thread %s]' % me)
        current_thread = me

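# The functions below are installed by init() in place of the main
# debugger's methods of the same name. The user_* hooks are made
# thread-safe by serialising on the Lock object that init() stores on the
# main debugger, which every thread must contend for.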

def user_line(frame):
    """Thread-safe replacement for the main debugger's user_line hook.

    Holds the main debugger's lock while announcing any thread switch and
    delegating to Gdb.user_line on the main debugger.
    """
    guard = _main_debugger.lock
    guard.acquire()
    try:
        _check_and_switch_current()
        Gdb.user_line(_main_debugger, frame)
    finally:
        guard.release()


def user_call(frame, arg):
    """Thread-safe replacement for the main debugger's user_call hook.

    Holds the main debugger's lock while announcing any thread switch and
    delegating to Gdb.user_call on the main debugger.
    """
    guard = _main_debugger.lock
    guard.acquire()
    try:
        _check_and_switch_current()
        Gdb.user_call(_main_debugger, frame, arg)
    finally:
        guard.release()
    
def user_return(frame, return_value):
    """Thread-safe replacement for the main debugger's user_return hook.

    Holds the main debugger's lock while announcing any thread switch and
    delegating to Gdb.user_return on the main debugger.
    """
    guard = _main_debugger.lock
    guard.acquire()
    try:
        _check_and_switch_current()
        Gdb.user_return(_main_debugger, frame, return_value)
    finally:
        guard.release()
    
def user_exception(frame, (exc_type, exc_value, exc_traceback)):
    _main_debugger.lock.acquire()
    try:
        _check_and_switch_current()
        Gdb.user_exception(_main_debugger, frame, (exc_type, exc_value,
                                               exc_traceback))
    finally:
        _main_debugger.lock.release()

def do_break(arg, temporary=0):
    """Override the break command so that, by default,
    breakpoints are set in all threads.

    A thread-specific breakpoint is requested with

        break LINESPEC thread THREADID

    where THREADID is a key of tt_dict (the name under which a thread's
    debugger was registered). The breakpoint is then set only in that
    thread's debugger. An unknown THREADID or a malformed 'thread' clause
    is reported as an error and no breakpoint is set.
    """
    words = arg.split()
    # Match 'thread' as a whole word: a substring test would also fire on
    # linespecs such as 'threading.py:42'.
    if 'thread' in words:
        pos = words.index('thread')
        # Expect exactly: one or more linespec words, 'thread', one ID.
        if pos == 0 or pos != len(words) - 2:
            Gdb.errmsg(_main_debugger,
                       'Usage: break linespec thread threadID')
            return
        linespec = ' '.join(words[:pos])
        thread_id = words[pos + 1]
        try:
            tracer = tt_dict[thread_id]
        except KeyError:
            Gdb.errmsg(_main_debugger, 'Invalid thread ID')
            return
        Gdb.do_break(tracer, linespec, temporary)
        return

    Gdb.do_break(_main_debugger, arg, temporary)
    
def do_thread(arg):
    """Use this command to switch between threads.
The new thread ID must be currently known.

List of thread subcommands:

thread apply -- Apply a command to a list of threads

Type "help thread" followed by thread subcommand name for full documentation.
Command name abbreviations are allowed if unambiguous.
"""
    # The docstring above is the user-visible 'help thread' text.
    # With no argument, report the calling thread's ident. With
    # 'apply THREADID COMMAND', run COMMAND in the debugger registered under
    # THREADID in tt_dict. Switching threads by ID is not implemented here;
    # any other argument is ignored.
    if not arg:
        Gdb.msg(_main_debugger,
                '[Current thread is (Thread %s)]' %
                thread.get_ident())
        return

    # Split into at most three parts so COMMAND keeps its own spacing.
    words = arg.split(None, 2)
    if not words or words[0] != 'apply':
        return
    if len(words) < 3:
        Gdb.errmsg(_main_debugger,
                   'Please specify a command following the thread ID')
        return

    thread_id, cmd = words[1], words[2]
    try:
        tracer = tt_dict[thread_id]
    except KeyError:
        Gdb.errmsg(_main_debugger, 'Invalid thread ID')
        return
    # Run outside the try block so errors raised by the command itself are
    # not mistaken for an unknown thread ID.
    # XXX Maybe this wants to be, tr.do_cmd(params)?
    tracer.onecmd(cmd)

def print_thread_details():
    """Print one summary line per thread reported by sys._current_frames().

    The output here should be as close to the expected output of
    'thread' in GDB. Each line has the form

        <marker><thread ident> in <function>() at <file>:<line>

    where <marker> is '* ' for the thread recorded in current_thread and two
    spaces otherwise. The frame summarised is the frame that
    sys._current_frames() reports for that thread.
    """
    # current_thread is only read here, so no 'global' declaration is needed.
    for ident, frame in sys._current_frames().items():
        if ident == current_thread:
            marker = "* "
        else:
            marker = "  "
        code = frame.f_code
        Gdb.msg(_main_debugger, "%s%s in %s() at %s:%s" % (
            marker, ident, code.co_name, code.co_filename, frame.f_lineno))


# Info subcommand and helper functions
def do_info(arg):
    """ Extends mpdb.do_info to include information about
    thread commands.
    """
    # The docstring above is the user-visible 'help info' text.
    # 'info thread' and 'info threads' (GDB's spelling), abbreviated to at
    # least three characters, print per-thread details. Everything else,
    # including an empty or whitespace-only argument, goes to Gdb.do_info.
    args = arg.split()
    if not args:
        Gdb.do_info(_main_debugger, arg)
        return

    subcommand = args[0]
    if len(subcommand) > 2 and 'threads'.startswith(subcommand):
        print_thread_details()
    else:
        Gdb.do_info(_main_debugger, arg)


