# Source: Wikipedia User:The Anome/code/nhpn to kml.py

# Copyright (c) The Anome 2012

# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

### This is a quick-hack decoder for the U.S. National Highway Planning Network
### shapefiles at http://www.fhwa.dot.gov/planning/nhpn/
###
### Please note that this is early-stage test code only.

import string, urllib, re, math

from osgeo import ogr
import nhpn_values

# Debugging cap on the number of segments to scan; only referenced by the
# commented-out early-return in scan_file(), so it has no effect in a real run.
SEGMENTS_BEFORE_STOP = 1000

# Scan the entire NHPN shapefile, building feature dicts and point lists
# (note: the input is a shapefile; the KML is what we eventually emit).
def scan_file():
    """Read NHPNLine.shp and fill the module-level tables.

    For every line-geometry feature n, stores:
      featuredicts[n] -- dict of attribute name -> (decoded) value
      line_entries[n] -- list of (x, y) coordinate pairs for the line
    """
    global featuredicts, line_entries
    driver = ogr.GetDriverByName("ESRI Shapefile")

    # Run this from within the same directory as the shapefiles...
    source = driver.Open("NHPNLine.shp", 0)
    layer = source.GetLayer(0)
    
    #print "Hello world!", layer.GetFeatureCount(), layer.GetExtent()

    n = 0
    feature = layer.GetNextFeature()
    while feature:
        n += 1
        # Progress indicator: this scan covers hundreds of thousands of features
        if (n%5000) == 0:
            print "done", n
            # for debugging: disable in real run
            #if n >= SEGMENTS_BEFORE_STOP:
            #    return

        # Collect this feature's attribute fields, translating coded values
        # to readable ones via the nhpn_values lookup table where possible
        featuredict = {}
        count = feature.GetFieldCount()
        for offset in range(count):
            fieldname, fieldval = feature.GetFieldDefnRef(offset).GetName(), feature.GetField(offset)
            key = "%s %s" % (fieldname, fieldval)
            fieldval = nhpn_values.valuecode.get(key, fieldval)
            # Drop empty/None fields entirely rather than storing them
            if fieldval != None:
                featuredict[fieldname] =  fieldval
        
        if 1:  # historical guard, kept to preserve indentation history
            geom = feature.GetGeometryRef()

            # Geometry type 2 is OGR's line-string code; skip everything else
            if geom.GetGeometryType() == 2:
                line_entry = []
                for p in range(geom.GetPointCount()):
                    line_entry.append((geom.GetX(p), geom.GetY(p)))
                featuredicts[n] = featuredict
                line_entries[n] = line_entry
                # print featuredict
                
#            geom.Destroy() -- can't do this: causes segmentation fault!

        feature.Destroy()
        feature = layer.GetNextFeature()
    source.Destroy()

# Six decimal places of degrees is roughly 0.1 m -- plenty for road data
KML_POINT_FORMAT = "%.6f,%.6f"

def format_xy_as_text(pairs):
    """Format an iterable of (x, y) float pairs as space-separated KML
    coordinate text, e.g. "1.000000,2.000000 3.000000,4.000000"."""
    # " ".join replaces the deprecated string.join (removed in Python 3)
    return " ".join([KML_POINT_FORMAT % p for p in pairs])

def gen_road_names(featuredict, n):
    """Return the list of Wikipedia-style display names for a road feature.

    Looks at the SIGNT1..3 / SIGNN1..3 attribute pairs, keeps only the
    recognized route types, and reformats them to Wikipedia naming
    conventions.  Every name is passed through check_safe_xmltext, which
    raises if it contains non-whitelisted characters.  The feature number
    n is currently unused but kept for the callers' sake.
    """
    names = []
    for idx in [1, 2, 3]:
        roadname = featuredict.get("SIGNN%d" % idx)
        roadtype = featuredict.get("SIGNT%d" % idx)

        # Ignore every road type but the following
        if roadtype not in ["US Route", "Interstate", "State Route"]:
            roadtype = None

        # Question: is a "State Route" the same as a "State Road"?

        # Now reformat names to Wikipedia conventions
        if roadtype == "US Route":
            roadtype = "U.S. Route"

        # "State Routes" seem to have different names state-to-state: these
        # need fixing, and then cross-checking with Wikipedia conventions
        if roadtype == "State Route":
            roadtype = featuredict["STFIPS"] + " " + roadtype

        if roadtype != None:
            names.append(" ".join([str(roadtype), str(roadname)]))
    # List comprehension instead of map(): same result, works on Python 3 too
    return [check_safe_xmltext(name) for name in names]

def gen_files_index():
    """Build the per-road indexes from the scanned feature tables.

    Fills road_names (feature number -> list of display names) and
    file_entries (display name -> list of feature numbers carrying it).
    Must run after scan_file() has populated featuredicts.
    """
    global featuredicts, line_entries, file_entries, road_names
    for n in featuredicts:
        road_names[n] = gen_road_names(featuredicts[n], n)
        for name in road_names[n]:
            # setdefault/append is O(1) per entry; the old
            # get(name, []) + [n] recopied the whole list each time
            file_entries.setdefault(name, []).append(n)

# Source of unique segment identifiers; starts high so the ids are
# visually distinct from feature numbers.
gensym_val = 1000000

def gensym():
    """Hand out the next unique integer identifier."""
    global gensym_val
    gensym_val = gensym_val + 1
    return gensym_val

# Segment merging logic goes here...

# Pythagorean distance between two points
def dist(point1, point2):
    """Return the Euclidean distance between two (x, y) tuples."""
    x1, y1 = point1
    x2, y2 = point2
    # math.hypot computes sqrt(dx*dx + dy*dy) without overflowing or
    # underflowing the squared intermediate values
    return math.hypot(x2 - x1, y2 - y1)

# Debugging routine
# This performs an exhaustive scan of all pairs of endpoints
def scan_segments_endpoint_distances(segments):
    # Make an exhaustive list of all endpoints
    seg_endpoints = []
    for seg_id in segments:
        (placemark_name, placemark_description, seg_entries) = segments[seg_id]
        seg_endpoints.append((seg_id, "start", seg_entries[0]))
        seg_endpoints.append((seg_id, "end", seg_entries[-1]))

    # Having made the index, scan through it for all endpoints
    can_reduce_by = 0
    for seg1 in seg_endpoints:
        for seg2 in seg_endpoints:
            if seg1 > seg2:
                seg_id1, which1, seg_point1 = seg1
                seg_id2, which1, seg_point2 = seg2
                if dist(seg_point1, seg_point2) < 0.0001:
                    print "distance between", seg1, seg2, "=", dist(seg_point1, seg_point2)
                    desc1, desc2 = (segments[seg_id1][1], segments[seg_id2][1])
                    print "descriptions:", (desc1, desc2)
                    if desc1 == desc2:
                        can_reduce_by += 1
    return can_reduce_by

KEY_FORMAT = KML_POINT_FORMAT # operate at this precision for endpoint matching

# Endpoint bin key: "rounded-point|placemark name|placemark description" --
# only endpoints whose keys collide are candidates for merging.
BIN_FORMAT = "%s|%s|%s"

# Attempt to merge a number of pointlist segments for a single KML
# file.  Does not need to get them all first time, and attempting to do
# so complicates the logic substantially: much easier to do multiple
# passes until no change.
def merge_segments(segments):
    """Perform one end-to-end merging pass over a bag of segments.

    segments maps id -> (placemark_name, placemark_description, pointlist).
    Two segments are joined when an endpoint of one coincides (at
    KEY_FORMAT precision) with an endpoint of the other and their
    name/description match.  Mutates and returns the same dict; callers
    repeat until the segment count stops shrinking.
    """
    # Empty segment list, or only one segment? Then there's nothing to do
    if len(segments) <= 1:
        return segments

    # Bin every endpoint under (rounded point, name, description); only
    # endpoints sharing a bin can possibly be merged
    seg_endpoint_bins = {}
    for seg_id in segments:
        placemark_name, placemark_description, seg_entries = segments[seg_id]

        start_point = (seg_id, "start", seg_entries[0])
        end_point = (seg_id, "end", seg_entries[-1])

        start_bin = BIN_FORMAT % (KEY_FORMAT % start_point[2], placemark_name, placemark_description)
        end_bin = BIN_FORMAT % (KEY_FORMAT % end_point[2], placemark_name, placemark_description)

        # setdefault/append avoids recopying the bin list on every insert
        seg_endpoint_bins.setdefault(start_bin, []).append(start_point)
        seg_endpoint_bins.setdefault(end_bin, []).append(end_point)

    # Make sure we only merge a segment once in a pass
    segments_merged = {}

    # Now we should have a limited number of collisions
    for bin_key in seg_endpoint_bins.keys():
        if len(seg_endpoint_bins[bin_key]) < 2:
            continue

        # Only look at the first two points in the bin; any further
        # coincidences are picked up on a later pass
        seg_id1, which1, point1 = seg_endpoint_bins[bin_key][0]
        seg_id2, which2, point2 = seg_endpoint_bins[bin_key][1]

        # Get the actual segment data
        seg_name1, seg_desc1, seg_entries1 = segments[seg_id1]
        seg_name2, seg_desc2, seg_entries2 = segments[seg_id2]

        # Double-check for compatibility.  BUG FIX: the original compared
        # seg_name1 against itself, so a name mismatch could never trip.
        if (seg_name1, seg_desc1) != (seg_name2, seg_desc2):
            raise Exception("attempt to merge dissimilar segments: should never happen")

        # Don't try to merge segments we have changed: their entries in the
        # endpoint table will be bogus -- we will get them in the next iteration...
        if (seg_id1 in segments_merged) or (seg_id2 in segments_merged):
            continue

        # Don't try to merge empty segments
        if (len(seg_entries1) == 0) or (len(seg_entries2) == 0):
            continue

        # Make a note that these segments are going to change, invalidating
        # their endpoints in the endpoint table
        segments_merged[seg_id1] = 1
        segments_merged[seg_id2] = 1

        # Now we want to get them in end/start order, with the matching
        # point at the middle, so we can simply concatenate
        if which1 == "start":
            seg_entries1 = [x for x in reversed(seg_entries1)]
            which1 = "end"
        if which2 == "end":
            seg_entries2 = [x for x in reversed(seg_entries2)]
            which2 = "start"

        # Concatenate (dropping the duplicated shared point), and reduce
        # the second sequence to a stub.  Stubbing means segments do not
        # disappear mid-scan, thus no need to manage key deletion on the fly
        segments[seg_id1] = (seg_name1, seg_desc1, seg_entries1 + seg_entries2[1:])
        segments[seg_id2] = (seg_name2, seg_desc2, [])

    # Scan is finished, so purge all the empty segments we just created
    empties_ids = [x for x in segments.keys() if segments[x][2] == []]
    for id in empties_ids:
        del segments[id]

    return segments

def wikitag(x):
    """Return an HTML anchor linking x to its English Wikipedia article.

    The visible text is x (validated by check_safe_xmltext); the link
    target is the wiki-style page name (spaces -> underscores, URL-quoted,
    then validated by check_safe_filename).  Raises if either check fails.
    """
    pagename = check_safe_xmltext(x)
    # str.replace instead of the deprecated string.replace helper
    link = check_safe_filename(urllib.quote_plus(x.replace(" ", "_"), ""))
    return '<a href="http://en.wikipedia.org/wiki/%s">%s</a>' % (link, pagename)

def html_escape(text):
    """Escape &, < and > so text can be embedded in HTML/XML content.

    The ampersand must be replaced first, otherwise the ampersands
    introduced by the other two substitutions would be double-escaped.
    """
    text = text.replace("&", "&amp;")
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    return text

def gen_file(roadname):
    """Build and return the complete KML document text for one road name.

    Gathers every feature indexed under roadname, merges touching
    segments until no further reduction is possible, and formats the
    result as a KML string of Placemark/LineString elements.
    """
    # Generate a set of road segments, with their labels
    segments = {}
    for n in file_entries[roadname]:
        all_roads = road_names[n]
        other_roads = [r for r in all_roads if r != roadname]
        placemark_name = check_safe_xmltext(roadname)
        # Description links this road (and any co-signed roads) to Wikipedia
        if other_roads:
            placemark_description = html_escape(wikitag(roadname) + " with " + string.join([wikitag(x) for x in other_roads], ", "))
        else:
            placemark_description = html_escape(wikitag(roadname))

        # Let's treat this as a bag of logical segments, with no implied order
        segments[gensym()] = (placemark_name, placemark_description, line_entries[n])
                        
    # Now repeat joining passes until no further good is done
    n_segments_at_start = len(segments.keys())
    while 1:
        n_segments_before = len(segments.keys())
        segments = merge_segments(segments)
        n_segments_after = len(segments.keys())
        if n_segments_before == n_segments_after:
            break
    print repr(roadname), n_segments_at_start, n_segments_after, "reduction %f" % (float(n_segments_at_start - n_segments_after)/n_segments_at_start)
    
    # Then do an exhaustive scan to see what we missed
#    can_reduce_by = scan_segments_endpoint_distances(segments)
#    print "SEGMENTS LEFT", n_segments_after, "CAN REDUCE", can_reduce_by, "LEAVING", n_segments_after-can_reduce_by
#    print

    # Format this list of segments into a KML string
    data = ['<?xml version="1.0" encoding="utf-8" ?>\n<kml xmlns="http://www.opengis.net/kml/2.2">\n<Document>']
    for segk in segments.keys():
        (placemark_name, placemark_description, seg_entries) = segments[segk]
        data.append("<Placemark>\n<name>%s</name>\n<description>%s</description>" % (placemark_name, placemark_description))
        data.append("<LineString>\n<coordinates>")
        data.append(format_xy_as_text(seg_entries))
        data.append("</coordinates>\n</LineString>\n</Placemark>")
    data.append("</Document>\n</kml>\n")
    return string.join(data, "\n")

# Belt-and-braces that filenames are OK, since we are getting these strings out of a database
def check_safe_filename(x):
    """Validate x as a wiki-safe filename and return it unchanged.

    Raises Exception if x is not already in canonical wiki form
    (underscore-separated, no leading/trailing/doubled separators) or if
    it contains any character outside the [-A-Za-z0-9._] whitelist.
    """
    # Canonicalize: underscores -> spaces, collapse/trim whitespace, rejoin
    # with underscores; a safe name must already equal its canonical form
    if "_".join(x.replace("_", " ").split()) != x:
        raise Exception("filename was unwikisafe")
    # findall returns [x] exactly when x is one unbroken run of whitelisted
    # characters (and is non-empty)
    if re.findall(r"[-A-Za-z0-9._]+", x) != [x]:
        raise Exception("filename was unsafe")
    return x

# We have a strict whitelist about the characters we will allow in our KML text
def check_safe_xmltext(x):
    """Return x unchanged when it consists solely of whitelisted
    characters; raise Exception otherwise."""
    # findall yields [x] exactly when the whole (non-empty) string is a
    # single run of characters from the whitelist
    if re.findall(r"[-A-Za-z0-9., ]+", x) != [x]:
        raise Exception("text was unsafe: " + repr(x))
    return x

##
## Main program
## 

# Module-level tables, filled in by scan_file() and gen_files_index():
featuredicts = {}  # feature number -> dict of decoded shapefile attributes
line_entries = {}  # feature number -> list of (x, y) points for the line
file_entries = {}  # road display name -> list of feature numbers on that road
road_names = {}  # feature number -> list of road display names
 
def main():
    """Decode the shapefile, index the road names, then write one KML
    file per road into the output/ directory (which must already exist)."""
    scan_file()
    gen_files_index()
    for name in file_entries:
        filename = "output/" + check_safe_filename(urllib.quote_plus(name.replace(" ", "_"), "") + ".kml")
        # with-block closes (and flushes) the file promptly; the original
        # open(filename, "w").write(...) leaked the file handle
        with open(filename, "w") as kml_file:
            kml_file.write(gen_file(name))

if __name__ == "__main__":
    main()