summaryrefslogtreecommitdiffstats
path: root/shuffle_truncate.py
blob: a7219a55636cd75468b4ef085ab67342c0a2ee6c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python

import numpy
import os
import sys
import optparse

from util import load_binary_file

# Short alias so the option table below stays compact.
o = optparse.make_option

# Command-line options understood by this script.
opts = [
    o('-c', '--cut-off',
      dest='cutoff', action='store', type='int',
      help='max number of samples to use'),
    o(None, '--count',
      dest='count', action='store_true',
      help='just report the number of samples in each file'),
]

# Values the parser falls back to when an option is not given.
defaults = dict(cutoff=None, count=False)

# Placeholder; rebound to the parsed option values when run as a script.
options = None

def load_files(fnames):
    """Load every file named in `fnames` via load_binary_file.

    Returns a list with one loaded object per name, in input order.
    """
    loaded = []
    for fname in fnames:
        loaded.append(load_binary_file(fname))
    return loaded

def shuffle_truncate(arrays, fnames, target_length=None):
    """Shuffle the given arrays and cut them down to a common length.

    arrays        -- sequence of arrays (anything numpy.random.shuffle
                     accepts and that supports len() and slicing).
                     Arrays longer than the selected length are shuffled
                     IN PLACE.
    fnames        -- names paired with `arrays` by position; used only
                     in the progress messages.
    target_length -- number of samples to keep from each array.  When
                     omitted (None or 0), the length of the shortest
                     array is used instead.

    Returns a list holding the leading slice of every array.  An array
    shorter than `target_length` is returned at its full length.  An
    empty `arrays` yields an empty list.
    """
    # Determine how many samples we can use.
    if target_length:
        shortest = target_length
    elif len(arrays) == 0:
        # min() of an empty sequence raises ValueError; there is simply
        # nothing to select from.
        return []
    else:
        shortest = min(len(a) for a in arrays)
        print("Selecting %d samples from each data file." % shortest)

    # Make sure we'll select samples from all
    # parts of the data file.
    for a, n in zip(arrays, fnames):
        if len(a) > shortest:
            # Gotta be uniformly shuffled, otherwise the slice below
            # would always pick the head of the file.
            print("Shuffling %s ..." % n)
            numpy.random.shuffle(a)
        else:
            # Every sample is kept anyway, so the order does not matter.
            print("Not shuffling %s (too few samples)." % n)

    # Now select the same number of samples from each file.
    return [a[:shortest] for a in arrays]

def store_files(arrays, fnames):
    """Write each array to the file with the matching name.

    arrays -- objects providing a tofile(fileobj) method, e.g. numpy
              arrays.
    fnames -- destination paths, paired with `arrays` by position.
              Files are opened in 'wb' mode, so existing content is
              overwritten.
    """
    for a, fn in zip(arrays, fnames):
        print('Storing %s.' % fn)
        # The with-block closes the file even if tofile() raises.
        with open(fn, 'wb') as fd:
            a.tofile(fd)

def target_file(fname, want_ext):
    """Return the output file name derived from `fname`.

    The directory part of `fname` is dropped.  Unless `want_ext` is
    None, the last extension (if any) is replaced by `want_ext`, which
    is given without its leading dot.
    """
    base = os.path.basename(fname)
    if want_ext is None:
        return base
    stem = os.path.splitext(base)[0]
    return "%s.%s" % (stem, want_ext)

def shuffle_truncate_store(files, cutoff=None, ext='sbn'):
    """Load `files`, shuffle/truncate them together and store the result.

    files  -- paths of the input data files.
    cutoff -- passed to shuffle_truncate as target_length.
    ext    -- extension for the output names built by target_file.

    The outputs are named after the inputs' base names, so they are
    written relative to the current working directory.
    """
    loaded = load_files(files)
    kept = shuffle_truncate(loaded, files, target_length=cutoff)

    out_names = []
    for path in files:
        out_names.append(target_file(path, ext))

    store_files(kept, out_names)

def shuffle_truncate_store_individually(files, cutoff):
    """Shuffle, truncate and store each file on its own.

    files  -- paths of the input data files.
    cutoff -- number of samples to keep from each file.

    A progress line is printed for every file.  A file is skipped when
    its '.sbn' target (as named by target_file) already exists or when
    the input file is empty.
    """
    total = len(files)
    # Zero-pad the running index to the width of the total count.
    fmt = "%%0%dd" % len(str(total))
    for i, f in enumerate(files):
        # Format first, then print: on Python 3 the form
        # `print (...) % (...)` would apply % to print's return value.
        print(("[" + fmt + "/%d] %s") % (i+1, total, os.path.basename(f)))
        sys.stdout.flush()
        name = target_file(f, 'sbn')
        fs = os.stat(f)
        if os.path.exists(name):
            print("Skipping since %s exists." % name)
        elif fs.st_size == 0:
            print("Skipping since trace is empty.")
        else:
            shuffle_truncate_store([f], cutoff=cutoff)

def report_sample_counts(files):
    """Print the number of samples found in each of `files`.

    Every file is loaded with load_binary_file and one line is printed
    for it: a zero-padded running index, the total number of files, the
    length of the loaded data and the file name.
    """
    total = len(files)
    # Zero-pad the running index to the width of the total count.
    fmt = "%%0%dd" % len(str(total))
    for i, f in enumerate(files):
        d = load_binary_file(f)
        # Format first, then print: on Python 3 the form
        # `print (...) % (...)` would apply % to print's return value.
        print(("[" + fmt + "/%d] %8d %s") % (i+1, total, len(d), f))
        sys.stdout.flush()
        # Drop the reference before the next file is loaded.
        del d

if __name__ == '__main__':
    # Script entry point: parse the command line, then dispatch on the
    # requested mode.
    parser = optparse.OptionParser(option_list=opts)
    parser.set_defaults(**defaults)
    (options, files) = parser.parse_args()

    if not files:
        print("Usage: shuffle_truncate.py data1.bin data2.bin data3.bin ...")
    elif options.count:
        # --count: only report how many samples each file holds.
        report_sample_counts(files)
    elif options.cutoff:
        # --cut-off N: process every file independently, N samples each.
        shuffle_truncate_store_individually(files, options.cutoff)
    else:
        # Default: truncate all files to the length of the shortest one.
        shuffle_truncate_store(files)