about summary refs log tree commit diff stats
path: root/ft-shuffle-truncate
blob: 44ace395d9362ebc111d66c4c69ef9d6c6d75fd8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python

import numpy
import os
import sys
import optparse

def load_binary_file(fname, dtype='float32', modify=False):
    """Map the raw binary file *fname* into memory as a flat array.

    fname  -- path of the file to load.
    dtype  -- numpy dtype used to interpret the bytes of the file.
    modify -- if true, the file is mapped with mode 'r+' (read-write);
              otherwise with mode 'c' (copy-on-write), so in-memory
              changes such as shuffling are not written back to the file.

    Returns a numpy.memmap over the file, or an empty numpy array of the
    requested dtype if the file has zero size.
    """
    if os.stat(fname).st_size == 0:
        # Nothing to map.  Return an empty *array* (not a plain list) so
        # that callers can still use array methods such as tofile().
        return numpy.array([], dtype=dtype)

    return numpy.memmap(fname, dtype=dtype,
                        mode='r+' if modify else 'c')

# Shorthand for building optparse option objects.
o = optparse.make_option

# Command-line options understood by this script.
opts = [
    o("-c", "--cut-off", dest="cutoff", action="store", type="int",
      help="max number of samples to use"),
    o("--count", dest="count", action="store_true",
      help="just report the number of samples in each file"),
    o("--only-min", dest="only_min", action="store_true",
      help="When counting, report only the minimum number of samples."),
    o("--output-dir", dest="output_dir", action="store",
      help="directory where output files should be stored."),
]

# Values for options that are not given on the command line.
defaults = dict(
    cutoff=None,
    count=False,
    only_min=False,
    output_dir=None,
)

# Parsed command-line options; assigned when the file is run as a script.
options = None

def load_files(fnames):
    """Load every file named in *fnames* and return the results as a list.

    Each file is loaded with load_binary_file() using its default
    arguments; the result list is in the same order as *fnames*.
    """
    return list(map(load_binary_file, fnames))

def shuffle_truncate(arrays, fnames, target_length=None):
    """Shuffle the sample arrays and cut them down to a common length.

    arrays        -- sequence of sample arrays.  Arrays that are shuffled
                     are shuffled in place.
    fnames        -- names matching *arrays*, used only in progress messages.
    target_length -- number of samples to keep from each array.  If it is
                     not given (or is zero), the length of the shortest
                     array is used.

    Returns a list holding the leading samples of each array.  An array
    with fewer samples than requested is returned at its full length.
    An empty *arrays* yields an empty list.
    """
    if len(arrays) == 0:
        # min() below would fail on an empty sequence.
        return []

    # Determine how many samples we can use.
    if target_length:
        shortest = target_length
    else:
        shortest = min(len(a) for a in arrays)
        print("Selecting %d samples from each data file." % shortest)

    # Make sure we'll select samples from all
    # parts of the data file.
    for a, n in zip(arrays, fnames):
        if len(a) > shortest:
            # Only a prefix is kept below, so the array has to be
            # uniformly shuffled first.
            print("Shuffling %s ..." % n)
            numpy.random.shuffle(a)
        else:
            # Every sample is kept anyway (or there are not enough).
            print("Not shuffling %s." % n)

    # Now select the same number of samples from each file.
    return [a[:shortest] for a in arrays]

def store_files(arrays, fnames):
    """Write each array to the file of the matching name as raw binary data.

    arrays -- sequence of objects providing a tofile() method
              (e.g. numpy arrays).
    fnames -- output paths, paired with *arrays* by position.  Existing
              files are overwritten.
    """
    for a, fn in zip(arrays, fnames):
        print('Storing %s.' % fn)
        # 'with' closes the file even if tofile() raises.
        with open(fn, 'wb') as fd:
            a.tofile(fd)

def target_file(fname, want_ext):
    """Work out the output path that corresponds to the input *fname*.

    The file is placed in options.output_dir if that is set, otherwise in
    the directory of *fname*.  If *want_ext* is not None, the extension of
    the file name is replaced by *want_ext* (given without a leading dot).
    """
    base = os.path.basename(fname)
    if want_ext is not None:
        stem = os.path.splitext(base)[0]
        base = "%s.%s" % (stem, want_ext)
    directory = options.output_dir or os.path.dirname(fname)
    return os.path.join(directory, base)

def shuffle_truncate_store(files, cutoff=None, ext='sf32'):
    """Load *files*, shuffle and truncate them together, and store the result.

    files  -- paths of the input data files.
    cutoff -- passed to shuffle_truncate() as the target length.
    ext    -- extension given to the output files (see target_file()).
    """
    shuffled = shuffle_truncate(load_files(files), files,
                                target_length=cutoff)
    destinations = [target_file(src, ext) for src in files]
    store_files(shuffled, destinations)

def shuffle_truncate_store_individually(files, cutoff):
    """Shuffle, truncate and store each of *files* independently.

    files  -- paths of the input data files.
    cutoff -- maximum number of samples to keep from each file.

    A file is skipped if its output file already exists or if the input
    file is empty.  A progress line is printed for every file.
    """
    # The existence check and the store step must agree on the output
    # extension; otherwise already-existing output is never detected.
    ext = 'sf32'
    # Zero-pad the running index to the width of the total file count.
    fmt = "%%0%dd" % len(str(len(files)))
    for i, f in enumerate(files):
        print(("[" + fmt + "/%d] %s") % (i+1, len(files),
                                         os.path.basename(f)))
        sys.stdout.flush()
        name = target_file(f, ext)
        fs = os.stat(f)
        if os.path.exists(name):
            print("Skipping since %s exists." % name)
        elif fs.st_size == 0:
            print("Skipping since trace is empty.")
        else:
            shuffle_truncate_store([f], cutoff=cutoff, ext=ext)

def report_sample_counts(files):
    """Print the number of samples found in each of *files*.

    With options.only_min set, only the smallest sample count is printed;
    otherwise one line per file is printed with a running index, the
    sample count and the file name.
    """
    counts = []
    # Zero-pad the running index to the width of the total file count.
    fmt = "%%0%dd" % len(str(len(files)))
    for i, f in enumerate(files):
        d = load_binary_file(f)
        counts.append(len(d))
        if not options.only_min:
            print(("[" + fmt + "/%d] %8d %s") % (i+1, len(files), len(d), f))
            sys.stdout.flush()
        # Drop the reference to the loaded data before the next file.
        del d
    if options.only_min:
        print(min(counts))

if __name__ == '__main__':
    # FIXME: would be nicer with argparse
    parser = optparse.OptionParser(option_list=opts)
    parser.set_defaults(**defaults)
    # Rebinds the module-level 'options' that the functions above read.
    (options, files) = parser.parse_args()

    if not files:
        print("Usage: ft-shuffle-truncate data1.float32 data2.float32 data3.float32 ...")
    elif options.count:
        # Only report sample counts; nothing is written.
        report_sample_counts(files)
    elif options.cutoff:
        # With a cut-off, every file is processed on its own.
        shuffle_truncate_store_individually(files, options.cutoff)
    else:
        # Without a cut-off, all files are truncated to the shortest one.
        shuffle_truncate_store(files)