aboutsummaryrefslogtreecommitdiffstats
path: root/schedcat/util/csv.py
blob: 4482f0a67d13f9ab7b893b4c1a127eb1157e6cf4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import absolute_import

import csv

from .storage import storage

def load_columns(fname,
                 convert=lambda x: x,
                 expect_uniform=True):
    """Load a file of CSV data. The first row is assumed to contain
    column labels. These labels (with surrounding whitespace stripped)
    can then be used to reference individual columns.

    x = load_columns(...)
    x.by_name -> dict: column label -> list of converted cell values
    x.by_idx  -> list: column index in the file -> list of converted values
    x.columns -> dict: column label -> column index in the file
    x.name    -> the fname argument

    fname          -- path of the CSV file, or an already-open file object
                      (which is read but left open for the caller to close)
    convert        -- function applied to every data cell (default: identity)
    expect_uniform -- if true, every non-blank data row must have exactly
                      as many fields as the header row

    Blank rows are skipped. If a label occurs more than once, by_name and
    columns refer to its last occurrence; by_idx still holds every column.

    Raises IOError if there is no header row, or if expect_uniform is set
    and a data row has the wrong number of fields.
    """
    # Accept both byte-string and unicode paths (type(u'') is `unicode` on
    # Python 2 and `str` on Python 3); anything else is treated as a file.
    if isinstance(fname, (str, type(u''))):
        f = open(fname)
        owns_file = True
    else:
        # assume we got a file object
        f = fname
        owns_file = False
    try:
        rows = list(csv.reader(f))
    finally:
        # Close files we opened ourselves, even if parsing fails.
        if owns_file:
            f.close()

    if not rows or not rows[0]:
        raise IOError("missing column labels in header row (%s)" % (fname,))

    # infer column labels; on duplicate labels the last occurrence wins
    header = rows[0]
    width = len(header)
    col_idx = dict((key.strip(), i) for i, key in enumerate(header))

    data = rows[1:]

    if expect_uniform:
        # Start counting at 2 because the header is line 1. This is the
        # file line of the record provided no quoted field spans lines.
        for lineno, row in enumerate(data, 2):
            if row and len(row) != width:
                msg = "expected uniform row length (%s:%d)" % (fname, lineno)
                raise IOError(msg)  # bad row length

    # One list of converted values per column, skipping blank rows.
    # by_col_name shares the very same list objects as by_col_idx.
    by_col_idx = [[convert(row[i]) for row in data if row]
                  for i in range(width)]
    by_col_name = dict((key, by_col_idx[i]) for key, i in col_idx.items())

    return storage(name=fname, columns=col_idx,
                   by_name=by_col_name, by_idx=by_col_idx)