scribble

ottocho's blog

Home About GitHub

01 Sep 2016
Convert HTML table with rowspan-td to Python DICT

解析带有 rowspan 的 HTML table 为 Python 字典数据

需求来源是:Apple 设备上报的设备类型标记并不是 iPhone 6 这样直观的名称,而是 iPhone5,1 这种形式的标识符,二者的对应表格如下:

https://www.theiphonewiki.com/wiki/Models

于是准备解析获取到字典内容。遇到了一点小问题,故记录一下。

在此表中出现了带有 rowspan 属性的 td,因此无法简单地遍历 tr 和 td 来直接获取数据。因此在此做了一点微末的工作。

思路是这样的:

  1. 先用 th 获得准确的数据列数和数据标题
  2. 使用一个字典,记录带有 rowspan 的 td 的内容和使用情况(rowspan 的数字表示此 td 应该被使用多少次)
  3. 在遍历 tr 的过程中,优先使用还没被耗尽的带 rowspan 的 td,再使用当前行中不带 rowspan 的 td

代码全文如下:

#coding:utf8

"""
Author:         ottocho
Filename:       getmodels.py
Last modified:  2016-08-22 02:18
Description:
    iPhone wiki about models and generations
    get model->generation map
"""

import sys
import urllib2

from pprint import pprint
from bs4 import BeautifulSoup

# Default page fetched by get_page_content().
URL = 'https://www.theiphonewiki.com/wiki/Models'

def get_page_content(url=URL):
    '''
    Download `url` and return the raw response body.

    url: address to fetch; defaults to the module-level URL.

    Any error raised by urllib2.urlopen() or by reading the response
    propagates to the caller.
    '''
    response = urllib2.urlopen(url)
    try:
        return response.read()
    finally:
        # always release the connection, even when read() fails
        response.close()

def parse_html(html):
    '''
    Parse every `wikitable` table in `html` into a list of row dicts.

    html: the HTML document, as accepted by BeautifulSoup.

    return format:
    { brand: list_of_rows }

    `brand` is the first child of the link inside the closest <h2> that
    precedes the table as a sibling.  Each row is a dict mapping a column
    title (taken from the <th> cells of the table's first row) to the
    text of the matching cell.  A cell carrying a `rowspan` attribute is
    repeated in the following rows it covers.  `colspan` is not handled.

    illustrative example (real values depend on the page):
    {
        u'iPod touch': [
            {
                u'Generation': u'iPod touch',
                u'"A" Number': u'A1213',
                u'Identifier': u'iPod1,1',
                ...
            },
            ...
        ],
        u'iPhone': [... ]
        ...
    }

    Raises AttributeError when a table has no preceding <h2> sibling or
    that <h2> contains no <a>.
    '''
    soup = BeautifulSoup(html, 'html.parser')
    all_table_data = {}

    for table in soup.find_all('table', class_='wikitable'):
        # get the `brand` from the head(H2 tag)
        _h2 = table.find_previous_sibling('h2')
        brand = _h2.find('a').contents[0]

        rows = table.find_all('tr')
        # the first row holds the column titles and fixes the column count
        names_column = []
        if rows:
            names_column = [_first_text(th) for th in rows[0].find_all('th')]
        all_table_data[brand] = _parse_rows(rows[1:], names_column)
    return all_table_data

def _parse_rows(rows, names_column):
    '''
    Convert data rows (<tr> tags) into dicts keyed by `names_column`,
    expanding cells that span several rows.
    '''
    n_column = len(names_column)
    # column index -> [td, number of further rows it still has to fill];
    # only columns currently covered by a rowspan cell have an entry
    pending = {}
    table_data = []

    for tr in rows:
        cells = iter(tr.find_all('td'))
        collected = []
        for i in range(n_column):
            carried = pending.get(i)
            if carried is not None:
                # this column is still covered by a cell from an earlier row
                td = carried[0]
                carried[1] -= 1
                if carried[1] == 0:
                    del pending[i]
            else:
                td = next(cells, None)
                # A spanning cell is registered only here, at the moment it
                # is consumed from the row, so the row's own cells stay
                # aligned with the remaining columns.
                if td is not None and td.has_attr('rowspan'):
                    remaining = int(td.attrs['rowspan']) - 1
                    if remaining > 0:
                        pending[i] = [td, remaining]
            collected.append(td)

        # A row that cannot fill every column (e.g. a header-only row or a
        # row using colspan) cannot be mapped to the titles; skip it.
        if any(td is None for td in collected):
            continue

        _row_data = {}
        for name, td in zip(names_column, collected):
            # prefer the text of the first link in the cell, if there is one
            link = td.find('a')
            _row_data[name] = _first_text(td if link is None else link)
        table_data.append(_row_data)
    return table_data

def _first_text(node):
    '''
    Return the stripped leading text of `node`; fall back to all of its
    text when it is empty or starts with a nested tag.
    '''
    contents = node.contents
    # bs4 text nodes subclass the text type (unicode on Python 2, str on
    # Python 3); nested tags do not
    if contents and isinstance(contents[0], type(u'')):
        return contents[0].strip()
    return node.get_text().strip()

def print_table_csv(table_data):
    '''
    Print the identifier -> generation mapping of every brand to stdout.

    table_data: { brand: list_of_rows } where every row is a dict with at
    least the keys 'Identifier' and 'Generation'.

    Writes a header line followed by one `Identifier|Generation|Device`
    line per distinct identifier of each brand (Device is the brand).
    Brands and identifiers are emitted in sorted order so the output is
    reproducible.  Raises KeyError when a row lacks one of the two keys.
    '''
    spliter = '|'
    # print(...) with a single argument is valid on both Python 2 and 3
    print(spliter.join(('Identifier', 'Generation', 'Device')))
    for brand in sorted(table_data):
        # the dict collapses identifiers repeated over several rows;
        # the last row seen for an identifier wins
        _bd = dict((_d['Identifier'], _d['Generation'])
                   for _d in table_data[brand])
        for _i in sorted(_bd):
            print(spliter.join((_i, _bd[_i], brand)))

def main():
    '''
    Script entry point: download the default page, parse its tables and
    print the resulting mapping.
    '''
    print_table_csv(parse_html(get_page_content()))

# Run main() only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

Til next time,
at 10:23

scribble

Home About GitHub