
"""
Uninformed search algorithms.

By Connelly Barnes, 2005.  Version 1.0.2.  Public domain.

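All search algorithms return the path as a list of states, running from
the initial state to a goal state.  If the graph argument is True, states
are recorded in a visited set (so they must be hashable), and duplicate
states are not revisited.  Most searches signal failure (no goal found
before the frontier is exhausted) by raising IndexError.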
"""

import heapq
from collections import deque

# Names exported by `from <module> import *`.  The helpers flatten and
# successors_function are not listed here and so are not exported.
__all__ = ['State', 'path_cost', 'breadth_first', 'uniform_cost',
           'depth_first', 'depth_limited', 'iterative_deepening',
           'bidirectional']

class State:
  """
  Abstract interface for a search state: a subclass describes one state of
  the environment and the transitions out of (and, optionally, into) it.

  With graph search enabled, states are stored in sets, and bidirectional
  search stores them as dict keys.  Implementations should therefore be
  hashable, with __eq__ and __hash__ agreeing.
  """

  def is_goal(self):
    """
    Return True exactly when this state satisfies the goal condition.
    """
    raise NotImplementedError()

  def successors(self):
    """
    Return an iterable of the states reachable from this one in one step.
    """
    raise NotImplementedError()

  def predecessors(self):
    """
    Return an iterable of the states that reach this one in one step.
    Within this module only bidirectional search calls it.
    """
    raise NotImplementedError()

  def cost(self, other):
    """
    Return the cost of the transition self -> other.
    Within this module only uniform_cost and path_cost call it.
    """
    raise NotImplementedError()


def path_cost(path):
  """
  Sum of edge costs for a given path (iterable) of states.

  Each consecutive pair of states (a, b) contributes a.cost(b).  A path
  with fewer than two states costs 0.
  """
  # Materialize first so any iterable works, not only indexable sequences.
  states = list(path)
  # Pair each state with its successor.  zip works on Python 2 and 3,
  # whereas xrange exists only on Python 2.
  return sum(a.cost(b) for a, b in zip(states, states[1:]))


def flatten(L):
  """
  Unwind a nested linked-list path into a flat list, oldest element first.

  Each node is a pair (rest, last), given as a list or tuple, and an empty
  sequence ends the chain.  For example, [[[[],1],2],3] becomes [1,2,3].
  The search functions build their paths in this shape as nested
  ((...), state) tuples.
  """
  newest_first = []
  node = L
  while len(node):
    node, item = node[0], node[-1]
    newest_first.append(item)
  newest_first.reverse()
  return newest_first


def successors_function(state, graph):
  """
  Get normal or loop-breaking successor function as appropriate.

  Returns a one-argument function s -> iterable of successor states.

  - graph=False: the function simply returns s.successors().
  - graph=True: the function closes over a visited set seeded with `state`
    (the initial state).  It returns, as a list, only states never handed
    out before, and marks each one visited.  Each state is therefore
    produced at most once over the lifetime of the returned function.
  """
  if not graph:
    # Dispatch on each state's own class so that State subclasses
    # overriding successors() are honoured.  Looking the method up once on
    # state.__class__ would apply the initial state's implementation to
    # every state.
    def plain(s):
      return s.successors()
    return plain
  # Seed with the initial state so it is never re-generated as a successor.
  visited = set([state])
  def unseen(s):
    fresh = []
    for x in s.successors():
      # Mark while iterating so a state listed twice in one result is
      # returned only once.
      if x not in visited:
        visited.add(x)
        fresh.append(x)
    return fresh
  return unseen


def breadth_first(initial_state, graph=False):
  """
  Breadth-first search.

  Frontier paths are expanded in FIFO order and goal-tested when dequeued,
  so the first goal found is a shallowest one.  Returns that path as a list
  of states from initial_state.  Raises IndexError if the frontier empties
  without a goal being found.
  """
  expand = successors_function(initial_state, graph)
  frontier = deque()
  # Each entry is a linked-list path (rest, state); see flatten.
  frontier.append(((), initial_state))
  while True:
    node = frontier.popleft()
    current = node[-1]
    if current.is_goal():
      return flatten(node)
    for child in expand(current):
      frontier.append((node, child))


def uniform_cost(initial_state, graph=False):
  """
  Uniform cost search.

  Always expands the cheapest frontier path first, using State.cost for
  edge weights.  The returned path (a list of states from initial_state to
  a goal) therefore has minimal total cost, provided edge costs are
  non-negative.

  With graph=True, each state is expanded at most once: a state is skipped
  when popped again after its first expansion, and successors that were
  already expanded are not pushed.

  Raises IndexError if the frontier empties without a goal being found.
  """
  # Priority queue (binary heap) of (cost, tiebreak, path) entries.  The
  # strictly increasing tiebreak counter ensures equal-cost entries are
  # never compared by path.  Comparing paths would fall through to
  # comparing State objects: a TypeError on Python 3, and an arbitrary
  # order on Python 2.
  tiebreak = 0
  q = [(0, tiebreak, ((), initial_state))]
  visited = set()
  while True:
    (cost, _, path) = heapq.heappop(q)
    state = path[-1]
    if graph:
      if state in visited:
        continue
      visited.add(state)
    if state.is_goal():
      return flatten(path)
    for x in state.successors():
      if graph and x in visited:
        continue
      tiebreak += 1
      heapq.heappush(q, (cost + state.cost(x), tiebreak, (path, x)))


def depth_first(initial_state, graph=False):
  """
  Depth-first search.

  Runs depth_limited with no effective depth bound and returns its result:
  a list of states from initial_state to a goal.  With graph=False on a
  state space containing cycles, the search may never terminate.
  """
  # float('inf') compares greater than every integer depth on Python 2 and
  # 3.  The old sentinel () relied on Python 2's cross-type ordering and
  # raises TypeError on Python 3.
  return depth_limited(initial_state, float('inf'), graph)


def depth_limited(initial_state, max_depth, graph=False):
  """
  Depth-limited search (depth-first).

  Explores states depth-first and tries successors in the order they are
  generated.  A popped state is goal-tested first, and is expanded only if
  its depth (the number of transitions from initial_state) is below
  max_depth.  Returns the first goal path found as a list of states.
  Raises IndexError once the stack is exhausted without reaching a goal.
  """
  expand = successors_function(initial_state, graph)
  stack = [(0, ((), initial_state))]   # LIFO of (depth, linked-list path).
  while True:
    depth, path = stack.pop()
    current = path[-1]
    if current.is_goal():
      return flatten(path)
    if not depth < max_depth:
      continue
    children = [(depth + 1, (path, child)) for child in expand(current)]
    # Reverse so the first-generated successor sits on top of the stack.
    children.reverse()
    stack.extend(children)


def iterative_deepening(initial_state, graph=False):
  """
  Iterative deepening search (depth-first).

  Runs depth-limited passes with limits 0, 1, 2, ... (in the same
  exploration order as depth_limited).  Returns the path, as a list of
  states, from the first pass that reaches a goal.

  Raises IndexError if a pass finishes without finding a goal and without
  ever hitting the depth limit.  In that case every deeper pass would
  explore exactly the same states, so no goal is reachable.  With
  graph=False on a state space containing cycles, the limit is always hit,
  so the search can still run forever when no goal exists.
  """
  depth = 0
  while True:
    found, cut_off = _depth_limited_pass(initial_state, depth, graph)
    if found is not None:
      return found
    if not cut_off:
      raise IndexError('iterative_deepening: no goal state is reachable')
    depth += 1


def _depth_limited_pass(initial_state, max_depth, graph):
  """
  One depth-limited pass for iterative_deepening.

  Returns (path, cut_off):
  - path is the goal path as a list of states, or None if no goal was found.
  - cut_off is True if some non-goal state was left unexpanded because its
    depth reached max_depth.

  Exhaustion is reported via the return value instead of an IndexError.
  As a result, an IndexError raised inside State methods propagates rather
  than being mistaken for "search deeper".
  """
  successors = successors_function(initial_state, graph)
  stack = [(0, ((), initial_state))]   # Stack of (depth, path) tuples.
  cut_off = False
  while stack:
    (depth, path) = stack.pop()
    state = path[-1]
    if state.is_goal():
      return flatten(path), cut_off
    if depth < max_depth:
      # Reverse so the first-generated successor is explored first.
      stack.extend([(depth + 1, (path, x)) for x in successors(state)][::-1])
    else:
      cut_off = True
  return None, cut_off


def bidirectional(initial_state, goal_state, graph=False):
  """
  Bidirectional search (two simultaneous breadth-first searches).

  Alternates one expansion forward from initial_state (via successors())
  with one expansion backward from goal_state (via predecessors()) until
  the two search trees share a state.  Returns the joined path as a list of
  states from initial_state to goal_state.

  With graph=True, states already recorded in a tree are not enqueued
  again in that tree.  Raises IndexError if either frontier empties before
  the trees meet.
  """
  q1 = deque([((), initial_state)])  # FIFO of paths for the forward tree.
  q2 = deque([((), goal_state)])     # FIFO of paths for the backward tree.
  visited1 = {}                      # Maps state -> first path reaching it.
  visited2 = {}
  while True:
    path1 = q1.popleft()
    path2 = q2.popleft()
    state1 = path1[-1]
    state2 = path2[-1]
    # Keep the first path recorded for each state.  Each tree is searched
    # breadth-first, so a path popped later for the same state is never
    # shorter.
    visited1.setdefault(state1, path1)
    visited2.setdefault(state2, path2)
    if state1 in visited2:
      # The meeting state ends the forward half and starts the reversed
      # backward half.  Drop it from the forward half ([:-1]) so it is
      # listed only once.
      return (flatten(visited1[state1])[:-1] +
              flatten(visited2[state1])[::-1])
    if state2 in visited1:
      return (flatten(visited1[state2])[:-1] +
              flatten(visited2[state2])[::-1])
    if not graph:
      q1.extend([(path1, x) for x in state1.successors()])
      q2.extend([(path2, x) for x in state2.predecessors()])
    else:
      q1.extend([(path1, x) for x in state1.successors() if
                 x not in visited1])
      q2.extend([(path2, x) for x in state2.predecessors() if
                 x not in visited2])
