aboutsummaryrefslogtreecommitdiffstats
path: root/plot_exps.py
blob: 8fbef99e11c5bc31fef2fe45fecac6ae59b208e7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python
from __future__ import print_function

import matplotlib.pyplot as plot
import os
import shutil as sh
import sys
from collections import namedtuple
from optparse import OptionParser
from parse.col_map import ColMap,ColMapBuilder
from parse.dir_map import DirMap
from plot.style import StyleMap
from multiprocessing import Pool, cpu_count

def parse_args():
    parser = OptionParser("usage: %prog [options] [csv_dir]...")

    parser.add_option('-o', '--out-dir', dest='out_dir',
                      help='directory for plot output', default='plot-data')
    parser.add_option('-f', '--force', action='store_true', default=False,
                      dest='force', help='overwrite existing data')

    return parser.parse_args()

ExpDetails = namedtuple('ExpDetails', ['variable', 'value', 'title',
                                       'out', 'node'])
OUT_FORMAT = 'pdf'

def get_details(node, path, out_dir):
    '''Decode a @path into details about a single experiment.'''
    out = "_".join(path) if path else "plot"
    out = "%s/%s.%s" % (out_dir, out, OUT_FORMAT)

    value = path.pop(0) if path else None
    variable = path.pop(0) if path else None

    title  = value.capitalize() if value else ""
    title += " by %s" % variable if variable else ""
    title += " (%s)" % (", ".join(path)) if path else ""

    return ExpDetails(variable, value, title, out, node)

def plot_by_variable(details):
    '''Plot each .csv files under @plot_node as a line on a shared plot.'''

    builder = ColMapBuilder()
    config_nodes = []

    # Generate mapping of (column)=>(line property to vary) for consistently
    # formatted plots
    for line_path, line_node in details.node.children.iteritems():
        encoded = line_path[:line_path.index(".csv")]
        line_config = ColMap.decode(encoded)

        for k, v in line_config.iteritems():
            builder.try_add(k, v)
        config_nodes += [(line_config, line_node)]

    col_map   = builder.build()
    style_map = StyleMap(col_map.columns(), col_map.get_values())

    figure = plot.figure()
    axes = figure.add_subplot(111)

    # Create a line for each file node and its configuration
    for line_config, line_node in config_nodes:
        # Create line style to match this configuration
        style  = style_map.get_style(line_config)
        values = sorted(line_node.values, key=lambda tup: tup[0])
        xvalues, yvalues = zip(*values)

        plot.plot(xvalues, yvalues, style.fmt())

    axes.set_title(details.title)

    lines, labels = zip(*style_map.get_key())
    axes.legend(tuple(lines), tuple(labels), prop={'size':10})

    axes.set_ylabel(details.value)
    axes.set_xlabel(details.variable)
    axes.set_xlim(0, axes.get_xlim()[1] + 1)
    axes.set_ylim(0, axes.get_ylim()[1] + 1)

    plot.savefig(details.out, format=OUT_FORMAT)

def plot_dir(data_dir, out_dir, force):
    sys.stderr.write("Reading data...\n")
    dir_map = DirMap.read(data_dir)

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    sys.stderr.write("Plotting...\n")

    # Count total plots for % counter
    num_plots = len([x for x in dir_map.leafs(1)])

    plot_details = []
    for plot_path, plot_node in dir_map.leafs(1):
        details = get_details(plot_node, plot_path, out_dir)

        if force or not os.path.exists(details.out):
            plot_details += [details]

    procs = min(len(plot_details), cpu_count()/2)
    pool  = Pool(processes=procs)
    enum  = pool.imap_unordered(plot_by_variable, plot_details)
    for i, _ in enumerate(enum):
        sys.stderr.write('\r {0:.2%}'.format(float(i)/num_plots))
    sys.stderr.write('\n')

def main():
    opts, args = parse_args()
    args = args or [os.getcwd()]

    if opts.force and os.path.exists(opts.out_dir):
        sh.rmtree(opts.out_dir)
    if not os.path.exists(opts.out_dir):
        os.mkdir(opts.out_dir)

    for dir in args:
        plot_dir(dir, opts.out_dir, opts.force)

if __name__ == '__main__':
    main()