From d917e869c13bc21cd47ad6f516628865bc218ece Mon Sep 17 00:00:00 2001
From: Thomas Kramer <code@tkramer.ch>
Date: Mon, 29 Apr 2024 18:28:38 +0200
Subject: [PATCH] graph router: pack arguments into GraphRoutingProblem class

---
 src/lclayout/graphrouter/graphrouter.py | 29 ++++++++--
 src/lclayout/graphrouter/hv_router.py   | 71 ++++++++++---------------
 src/lclayout/graphrouter/inspect.py     | 24 +++------
 src/lclayout/graphrouter/pathfinder.py  | 41 ++++----------
 src/lclayout/router.py                  | 17 +++---
 5 files changed, 79 insertions(+), 103 deletions(-)

diff --git a/src/lclayout/graphrouter/graphrouter.py b/src/lclayout/graphrouter/graphrouter.py
index 0b6f43c..b4e32ec 100644
--- a/src/lclayout/graphrouter/graphrouter.py
+++ b/src/lclayout/graphrouter/graphrouter.py
@@ -7,17 +7,19 @@ import networkx as nx
 
 from typing import Any, Dict, List, AbstractSet, Optional, Iterable
 
+class GraphRoutingProblem:
+    """
+    Representation of a multi-signal routing problem in a graph.
+    """
 
-class GraphRouter:
-
-    def route(self,
+    def __init__(self, 
               graph: nx.Graph,
               signals: Dict[Any, List[Any]],
               reserved_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
               node_conflict: Optional[Dict[Any, AbstractSet[Any]]] = None,
               equivalent_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
               is_virtual_node_fn=None
-              ) -> Iterable[Dict[Any, nx.Graph]]:
+                 ):
         """
 
         :param graph: Routing graph.
@@ -28,4 +30,23 @@ class GraphRouter:
         :param is_virtual_node_fn: Function that returns True iff the argument is a virtual node.
         :return: Returns a dict mapping signal names to routing trees.
         """
+        
+        self.graph: nx.Graph = graph
+        self.signals: Dict[Any, List[Any]] = signals
+        self.reserved_nodes: Optional[Dict[Any, AbstractSet[Any]]] = reserved_nodes
+        self.node_conflict: Optional[Dict[Any, AbstractSet[Any]]] = node_conflict
+        self.equivalent_nodes: Optional[Dict[Any, AbstractSet[Any]]] = equivalent_nodes
+        self.is_virtual_node_fn=is_virtual_node_fn
+
+        if self.node_conflict is None:
+            self.node_conflict = dict()
+        if self.equivalent_nodes is None:
+            self.equivalent_nodes = dict()
+        if self.reserved_nodes is None:
+            self.reserved_nodes = dict()
+
+
+class GraphRouter:
+
+    def route(self, routing_problem: GraphRoutingProblem) -> Iterable[Dict[Any, nx.Graph]]:
         pass
diff --git a/src/lclayout/graphrouter/hv_router.py b/src/lclayout/graphrouter/hv_router.py
index 9f65fc8..fb4bd9a 100644
--- a/src/lclayout/graphrouter/hv_router.py
+++ b/src/lclayout/graphrouter/hv_router.py
@@ -8,7 +8,7 @@ from itertools import chain, combinations, product
 
 from typing import *
 import logging
-from .graphrouter import GraphRouter
+from .graphrouter import GraphRouter, GraphRoutingProblem
 
 logger = logging.getLogger(__name__)
 
@@ -24,19 +24,12 @@ class HVGraphRouter(GraphRouter):
         self.orientation_change_penalty = orientation_change_penalty
 
     def route(self,
-              graph: nx.Graph,
-              signals: Dict[Any, List[Any]],
-              reserved_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              node_conflict: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              equivalent_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              is_virtual_node_fn=None
+              routing_problem: GraphRoutingProblem
               ) -> Iterable[Dict[Any, nx.Graph]]:
         return _route_hv(self.sub_graphrouter,
-                         graph,
-                         signals=signals,
-                         reserved_nodes=reserved_nodes,
-                         node_conflict=node_conflict,
-                         is_virtual_node_fn=is_virtual_node_fn)
+                         routing_problem,
+                         orientation_change_penalty=self.orientation_change_penalty
+                         )
 
 
 def _build_hv_routing_graph(graph: nx.Graph, orientation_change_penalty=1) -> Tuple[nx.Graph, Dict, Dict]:
@@ -139,12 +132,8 @@ def _flatten_hv_graph(hv_graph: nx.Graph, reverse_mapping: Dict) -> nx.Graph:
 
 
 def _route_hv(router: GraphRouter,
-              graph: nx.Graph,
-              signals: Dict[Any, List[Any]],
+              pr: GraphRoutingProblem,
               orientation_change_penalty: float = 1,
-              node_conflict: Dict[Any, Set[Any]] = None,
-              reserved_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              is_virtual_node_fn=None,
               **kw) -> Iterable[Dict[Any, nx.Graph]]:
     """ Global routing with corner avoidance.
     Corners (changes between horizontal/vertical tracks) are avoided by transforming the routing graph `G`
@@ -154,28 +143,22 @@ def _route_hv(router: GraphRouter,
     ----------
     :param graph: Routing graph with edge orientation information.
             Edge orientation must be stored in the 'orientation' field of the networkx edge data.
-    :param signals: A dict mapping signal names to signal terminals.
-    :param orientation_change_penalty: Cost for changes between different orientations.
-    :param reserved_nodes: An optional dict which specifies nodes that are reserved for a specific net.
-    Dict[net_name, set of nodes].
+    :param routing_problem: 
     :param kw: Parameters to be passed to underlying routing function.
 
     """
 
-    assert isinstance(signals, dict)
+    assert isinstance(pr.signals, dict)
     logger.debug('Start global routing with corner avoidance.')
 
     H, node_mapping, node_mapping_reverse = _build_hv_routing_graph(
-        graph,
+        pr.graph,
         orientation_change_penalty=orientation_change_penalty
     )
     reserved_nodes_h = None
-    if reserved_nodes is not None:
+    if pr.reserved_nodes is not None:
         reserved_nodes_h = {net: list(chain(*(node_mapping[n].values() for n in nodes))) for net, nodes in
-                            reserved_nodes.items()}
-
-    if node_conflict is None:
-        node_conflict = dict()
+                            pr.reserved_nodes.items()}
 
     # For each node find other nodes that are equivalent when mapped back.
     equivalent_nodes = {
@@ -191,25 +174,28 @@ def _route_hv(router: GraphRouter,
         conflicts = set()
         conflicts.update(node_mapping[n_g].values())
 
-        if n_g in node_conflict:
-            conflicts_g = set(node_conflict[n_g])
+        if n_g in pr.node_conflict:
+            conflicts_g = set(pr.node_conflict[n_g])
             for n in conflicts_g:
                 conflicts.update(node_mapping[n].values())
 
         node_conflict_h[n_h] = conflicts
 
-    signals_h = {net: [node_mapping[t][None] for t in terminals] for net, terminals in signals.items()}
+    signals_h = {net: [node_mapping[t][None] for t in terminals] for net, terminals in pr.signals.items()}
 
     def _is_virtual_node_fn(n) -> bool:
-        return is_virtual_node_fn(node_mapping_reverse[n])
+        return pr.is_virtual_node_fn(node_mapping_reverse[n])
 
     assert nx.is_connected(H)
-    solutions = router.route(H, signals_h,
+
+    hv_problem = GraphRoutingProblem(graph=H, signals=signals_h,
                                    reserved_nodes=reserved_nodes_h,
                                    node_conflict=node_conflict_h,
                                    equivalent_nodes=equivalent_nodes,
-                                   is_virtual_node_fn=_is_virtual_node_fn,
-                                   **kw)
+                                   is_virtual_node_fn=_is_virtual_node_fn
+                               )
+    
+    solutions = router.route(hv_problem, **kw)
     for routing_trees_h in solutions:
     
         if routing_trees_h is None:
@@ -226,13 +212,12 @@ def _route_hv(router: GraphRouter,
 
         # Assert that reserved_nodes is respected.
         # A node that is reserved for a signal should not be used by another signal.
-        if reserved_nodes:
-            for net, nodes in reserved_nodes.items():
-                for rt_net, rt in routing_trees.items():
-                    if net != rt_net:
-                        for n in nodes:
-                            assert n not in rt.nodes, \
-                                "Node %s is reserved for net %s but has been used for net %s." % (
-                                    n, net, rt_net)
+        for net, nodes in pr.reserved_nodes.items():
+            for rt_net, rt in routing_trees.items():
+                if net != rt_net:
+                    for n in nodes:
+                        assert n not in rt.nodes, \
+                            "Node %s is reserved for net %s but has been used for net %s." % (
+                                n, net, rt_net)
 
         yield routing_trees
diff --git a/src/lclayout/graphrouter/inspect.py b/src/lclayout/graphrouter/inspect.py
index beabe3f..b72fade 100644
--- a/src/lclayout/graphrouter/inspect.py
+++ b/src/lclayout/graphrouter/inspect.py
@@ -11,7 +11,7 @@ import networkx as nx
 
 from typing import *
 import logging
-from .graphrouter import GraphRouter
+from .graphrouter import GraphRouter, GraphRoutingProblem
 
 import matplotlib.pyplot as plt
 
@@ -33,22 +33,12 @@ class InspectRouter(GraphRouter):
         self._graph = None
         self._is_virtual_node_fn = None
 
-    def route(self,
-              graph: nx.Graph,
-              signals: Dict[Any, List[Any]],
-              reserved_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              node_conflict: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              equivalent_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              is_virtual_node_fn=None
-              ) -> Iterable[Dict[Any, nx.Graph]]:
-
-        self._graph = graph
-        self._is_virtual_node_fn = is_virtual_node_fn
-
-        solutions = self.sub_graphrouter.route(
-                            graph=graph, signals=signals, reserved_nodes=reserved_nodes, 
-                            node_conflict=node_conflict, equivalent_nodes=equivalent_nodes, 
-                            is_virtual_node_fn=is_virtual_node_fn)
+    def route(self, routing_problem: GraphRoutingProblem) -> Iterable[Dict[Any, nx.Graph]]:
+
+        self._graph = routing_problem.graph
+        self._is_virtual_node_fn = routing_problem.is_virtual_node_fn
+
+        solutions = self.sub_graphrouter.route(routing_problem)
 
         for routes in solutions:
             self._inspect(routes)
diff --git a/src/lclayout/graphrouter/pathfinder.py b/src/lclayout/graphrouter/pathfinder.py
index c710b59..090770e 100644
--- a/src/lclayout/graphrouter/pathfinder.py
+++ b/src/lclayout/graphrouter/pathfinder.py
@@ -1,11 +1,10 @@
-# Copyright 2019-2020 Thomas Kramer.
-# SPDX-FileCopyrightText: 2022 Thomas Kramer
+# SPDX-FileCopyrightText: 2019-2024 Thomas Kramer
 #
 # SPDX-License-Identifier: CERN-OHL-S-2.0
 
 import networkx as nx
 import numpy as np
-from .graphrouter import GraphRouter
+from .graphrouter import GraphRouter, GraphRoutingProblem
 from .signal_router import SignalRouter
 from .multi_via_router import MultiViaRouter
 
@@ -33,40 +32,18 @@ class PathFinderGraphRouter(GraphRouter):
         self.detail_router = detail_router
         self.max_iterations = max_iterations
 
-    def route(self,
-              graph: nx.Graph,
-              signals: Dict[Any, List[Any]],
-              reserved_nodes: Optional[Dict] = None,
-              node_conflict: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              equivalent_nodes: Optional[Dict[Any, AbstractSet[Any]]] = None,
-              is_virtual_node_fn=None
-              ) -> Dict[Any, nx.Graph]:
+    def route(self, p: GraphRoutingProblem) -> Dict[Any, nx.Graph]:
         """ Route multiple signals in the graph.
         Based on PathFinder algorithm.
-
-        Parameters
-        ----------
-        :param is_virtual_node_fn: A function which tells wether a node is 'virtual'.
-        :param equivalent_nodes: An optional mapping from a node n to a set of nodes which are equivalent to the node n.
-            This is used for the HVGraphRouter which splits some nodes into multiple nodes which are mutually exclusive.
-        :param graph : networkx.Graph
-                        Graph representing the routing grid.
-        :param signals : Dict[node name, List[node]]
-                        Signals to be routed. Each signal is represented by its terminal nodes.
-        :param reserved_nodes: An optional dict which specifies nodes that are reserved for a specific net.
-        Dict[net_name, set of nodes].
-        :param node_conflict: Dict[node, Set[node]]
-        Tells which other nodes are blocked by a node. A node might block its direct neigbhours to ensure minimum spacing.
-
         :returns : A list of `networkx.Graph`s representing the routes of each signal.
         """
         return _route(self.detail_router,
-                      graph,
-                      signals=signals,
-                      reserved_nodes=reserved_nodes,
-                      node_conflict=node_conflict,
-                      equivalent_nodes=equivalent_nodes,
-                      is_virtual_node_fn=is_virtual_node_fn,
+                      p.graph,
+                      signals=p.signals,
+                      reserved_nodes=p.reserved_nodes,
+                      node_conflict=p.node_conflict,
+                      equivalent_nodes=p.equivalent_nodes,
+                      is_virtual_node_fn=p.is_virtual_node_fn,
                       max_iterations=self.max_iterations)
 
 
diff --git a/src/lclayout/router.py b/src/lclayout/router.py
index a02b824..a54e50e 100644
--- a/src/lclayout/router.py
+++ b/src/lclayout/router.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: CERN-OHL-S-2.0
 
-from .graphrouter.graphrouter import GraphRouter
+from .graphrouter.graphrouter import GraphRouter, GraphRoutingProblem
 from .routing_graph import *
 from . import tech_util
 from .lvs import lvs
@@ -591,12 +591,15 @@ class DefaultRouter():
                 virtual_terminal_nodes = {net: virtual_terminal_nodes[net] for net in routing_nets}
 
             # Invoke router.
-            solutions = self.router.route(graph,
-                                              signals=virtual_terminal_nodes,
-                                              reserved_nodes=reserved_nodes,
-                                              node_conflict=conflicts,
-                                              is_virtual_node_fn=_is_virtual_node_fn
-                                              )
+            routing_problem = GraphRoutingProblem(
+                            graph,
+                            signals=virtual_terminal_nodes,
+                            reserved_nodes=reserved_nodes,
+                            node_conflict=conflicts,
+                            is_virtual_node_fn=_is_virtual_node_fn
+                        )
+
+            solutions = self.router.route(routing_problem)
 
             for routing_trees in solutions:
                 if routing_trees is None:
-- 
GitLab