summaryrefslogtreecommitdiffstats
path: root/crypto/pcbc.c
blob: fe704775f88ff0fa526a197201c1a751a8535611 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
/*
 * PCBC: Propagating Cipher Block Chaining mode
 *
 * Copyright (C) 2006 Red Hat, Inc. All Rights Reserved.
 * Written by David Howells (dhowells@redhat.com)
 *
 * Derived from cbc.c
 * - Copyright (c) 2006 Herbert Xu <herbert@gondor.apana.org.au>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 *
 */

#include <crypto/algapi.h>
#include <linux/err.h>
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/scatterlist.h>
#include <linux/slab.h>

/*
 * Per-transform context of a PCBC instance.
 *
 * The only state kept is the single-block cipher the mode is layered on:
 * it is obtained in crypto_pcbc_init_tfm(), keyed in crypto_pcbc_setkey(),
 * used for every block in the encrypt/decrypt helpers and released in
 * crypto_pcbc_exit_tfm().
 */
struct crypto_pcbc_ctx {
	struct crypto_cipher *child;	/* underlying single-block cipher */
};

/*
 * Key the PCBC transform by handing the key to the underlying cipher.
 *
 * The request flags of @parent are mirrored onto the child cipher before the
 * key is set, and the child's result flags are copied back onto @parent
 * afterwards (whether or not setting the key succeeded).
 *
 * Returns the value returned by crypto_cipher_setkey() for the child.
 */
static int crypto_pcbc_setkey(struct crypto_tfm *parent, const u8 *key,
			      unsigned int keylen)
{
	struct crypto_pcbc_ctx *pcbc = crypto_tfm_ctx(parent);
	struct crypto_cipher *cipher = pcbc->child;
	int ret;

	/* Drop stale request flags, then inherit the parent's. */
	crypto_cipher_clear_flags(cipher, CRYPTO_TFM_REQ_MASK);
	crypto_cipher_set_flags(cipher,
				crypto_tfm_get_flags(parent) &
				CRYPTO_TFM_REQ_MASK);

	ret = crypto_cipher_setkey(cipher, key, keylen);

	/* Propagate the child's result flags to the caller-visible tfm. */
	crypto_tfm_set_flags(parent,
			     crypto_cipher_get_flags(cipher) &
			     CRYPTO_TFM_RES_MASK);

	return ret;
}

/*
 * PCBC-encrypt one walk segment whose source and destination buffers differ.
 *
 * walk->iv carries the chaining value (previous plaintext XOR previous
 * ciphertext) and is updated in place, so it is already correct for the next
 * segment when this function returns.  The first block is processed
 * unconditionally; further blocks are processed while at least one full
 * block remains.
 *
 * Returns the number of bytes of the segment left unprocessed.
 */
static int crypto_pcbc_encrypt_segment(struct blkcipher_desc *desc,
				       struct blkcipher_walk *walk,
				       struct crypto_cipher *tfm)
{
	void (*encrypt_block)(struct crypto_tfm *, u8 *, const u8 *) =
		crypto_cipher_alg(tfm)->cia_encrypt;
	int bsize = crypto_cipher_blocksize(tfm);
	unsigned int remaining = walk->nbytes;
	u8 *in = walk->src.virt.addr;
	u8 *out = walk->dst.virt.addr;
	u8 *chain = walk->iv;

	for (;;) {
		/* cipher input = plaintext ^ chaining value */
		crypto_xor(chain, in, bsize);
		encrypt_block(crypto_cipher_tfm(tfm), out, chain);

		/* next chaining value = ciphertext ^ plaintext */
		memcpy(chain, out, bsize);
		crypto_xor(chain, in, bsize);

		in += bsize;
		out += bsize;

		remaining -= bsize;
		if (remaining < bsize)
			break;
	}

	return remaining;
}

/*
 * PCBC-encrypt one walk segment in place (source and destination are the
 * same buffer).
 *
 * Because the ciphertext overwrites the plaintext, each plaintext block is
 * saved in a temporary buffer first so that the next chaining value
 * (plaintext XOR ciphertext) can still be formed.
 *
 * walk->iv is used directly as the chaining buffer and is therefore already
 * up to date when the loop finishes; no copy back is needed.  (A former
 * trailing memcpy(walk->iv, iv, bsize) copied the buffer onto itself, which
 * is undefined for memcpy() and served no purpose.)
 *
 * The first block is processed unconditionally; further blocks are
 * processed while at least one full block remains.
 *
 * Returns the number of bytes of the segment left unprocessed.
 */
static int crypto_pcbc_encrypt_inplace(struct blkcipher_desc *desc,
				       struct blkcipher_walk *walk,
				       struct crypto_cipher *tfm)
{
	void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
		crypto_cipher_alg(tfm)->cia_encrypt;
	int bsize = crypto_cipher_blocksize(tfm);
	unsigned int nbytes = walk->nbytes;
	u8 *src = walk->src.virt.addr;
	u8 *iv = walk->iv;
	u8 tmpbuf[bsize];

	do {
		/* keep the plaintext: src is about to be overwritten */
		memcpy(tmpbuf, src, bsize);
		/* cipher input = plaintext ^ chaining value */
		crypto_xor(iv, src, bsize);
		fn(crypto_cipher_tfm(tfm), src, iv);
		/* next chaining value = plaintext ^ ciphertext */
		memcpy(iv, tmpbuf, bsize);
		crypto_xor(iv, src, bsize);

		src += bsize;
	} while ((nbytes -= bsize) >= bsize);

	return nbytes;
}

/*
 * Encrypt @nbytes bytes from the @src scatterlist into @dst in PCBC mode.
 *
 * The data is walked segment by segment; each segment is handed to the
 * in-place helper when source and destination map to the same address and
 * to the two-buffer helper otherwise.  The number of bytes a helper leaves
 * unprocessed is reported back to the walker.
 *
 * Returns the status of the last walker call.
 */
static int crypto_pcbc_encrypt(struct blkcipher_desc *desc,
			       struct scatterlist *dst, struct scatterlist *src,
			       unsigned int nbytes)
{
	struct crypto_pcbc_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
	struct crypto_cipher *cipher = ctx->child;
	struct blkcipher_walk walk;
	int ret;

	blkcipher_walk_init(&walk, dst, src, nbytes);
	ret = blkcipher_walk_virt(desc, &walk);

	while (walk.nbytes) {
		if (walk.src.virt.addr != walk.dst.virt.addr)
			nbytes = crypto_pcbc_encrypt_segment(desc, &walk,
							     cipher);
		else
			nbytes = crypto_pcbc_encrypt_inplace(desc, &walk,
							     cipher);

		ret = blkcipher_walk_done(desc, &walk, nbytes);
	}

	return ret;
}

/*
 * PCBC-decrypt one walk segment whose source and destination buffers differ.
 *
 * For every block the ciphertext is decrypted into the destination and XORed
 * with the chaining value to recover the plaintext; the next chaining value
 * is ciphertext XOR plaintext.
 *
 * walk->iv is used directly as the chaining buffer and is therefore already
 * up to date when the loop finishes; no copy back is needed.  (A former
 * trailing memcpy(walk->iv, iv, bsize) copied the buffer onto itself, which
 * is undefined for memcpy() and served no purpose.)
 *
 * The first block is processed unconditionally; further blocks are
 * processed while at least one full block remains.
 *
 * Returns the number of bytes of the segment left unprocessed.
 */
static int crypto_pcbc_decrypt_segment(struct blkcipher_desc *desc,
				       struct blkcipher_walk *walk,
				       struct crypto_cipher *tfm)
{
	void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
		crypto_cipher_alg(tfm)->cia_decrypt;
	int bsize = crypto_cipher_blocksize(tfm);
	unsigned int nbytes = walk->nbytes;
	u8 *src = walk->src.virt.addr;
	u8 *dst = walk->dst.virt.addr;
	u8 *iv = walk->iv;

	do {
		/* plaintext = D(ciphertext) ^ chaining value */
		fn(crypto_cipher_tfm(tfm), dst, src);
		crypto_xor(dst, iv, bsize);
		/* next chaining value = ciphertext ^ plaintext */
		memcpy(iv, src, bsize);
		crypto_xor(iv, dst, bsize);

		src += bsize;
		dst += bsize;
	} while ((nbytes -= bsize) >= bsize);

	return nbytes;
}

/*
 * PCBC-decrypt one walk segment in place (source and destination are the
 * same buffer).
 *
 * Because the plaintext overwrites the ciphertext, each ciphertext block is
 * saved in a temporary buffer first so that the next chaining value
 * (ciphertext XOR plaintext) can still be formed.
 *
 * walk->iv is used directly as the chaining buffer and is therefore already
 * up to date when the loop finishes; no copy back is needed.  (A former
 * trailing memcpy(walk->iv, iv, bsize) copied the buffer onto itself, which
 * is undefined for memcpy() and served no purpose.)
 *
 * The first block is processed unconditionally; further blocks are
 * processed while at least one full block remains.
 *
 * Returns the number of bytes of the segment left unprocessed.
 */
static int crypto_pcbc_decrypt_inplace(struct blkcipher_desc *desc,
				       struct blkcipher_walk *walk,
				       struct crypto_cipher *tfm)
{
	void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
		crypto_cipher_alg(tfm)->cia_decrypt;
	int bsize = crypto_cipher_blocksize(tfm);
	unsigned int nbytes = walk->nbytes;
	u8 *src = walk->src.virt.addr;
	u8 *iv = walk->iv;
	u8 tmpbuf[bsize];

	do {
		/* keep the ciphertext: src is about to be overwritten */
		memcpy(tmpbuf, src, bsize);
		/* plaintext = D(ciphertext) ^ chaining value */
		fn(crypto_cipher_tfm(tfm), src, src);
		crypto_xor(src, iv, bsize);
		/* next chaining value = ciphertext ^ plaintext */
		memcpy(iv, tmpbuf, bsize);
		crypto_xor(iv, src, bsize);

		src += bsize;
	} while ((nbytes -= bsize) >= bsize);

	return nbytes;
}

/*
 * Decrypt @nbytes bytes from the @src scatterlist into @dst in PCBC mode.
 *
 * The data is walked segment by segment; each segment is handed to the
 * in-place helper when source and destination map to the same address and
 * to the two-buffer helper otherwise.  The number of bytes a helper leaves
 * unprocessed is reported back to the walker.
 *
 * Returns the status of the last walker call.
 */
static int crypto_pcbc_decrypt(struct blkcipher_desc *desc,
			       struct scatterlist *dst, struct scatterlist *src,
			       unsigned int nbytes)
{
	struct crypto_pcbc_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
	struct crypto_cipher *cipher = ctx->child;
	struct blkcipher_walk walk;
	int ret;

	blkcipher_walk_init(&walk, dst, src, nbytes);
	ret = blkcipher_walk_virt(desc, &walk);

	while (walk.nbytes) {
		if (walk.src.virt.addr != walk.dst.virt.addr)
			nbytes = crypto_pcbc_decrypt_segment(desc, &walk,
							     cipher);
		else
			nbytes = crypto_pcbc_decrypt_inplace(desc, &walk,
							     cipher);

		ret = blkcipher_walk_done(desc, &walk, nbytes);
	}

	return ret;
}

/*
 * Initialise a PCBC transform: instantiate the underlying cipher from the
 * spawn stored in the instance context and remember it in the tfm context.
 *
 * Returns 0 on success, or the error encoded in the pointer returned by
 * crypto_spawn_cipher().
 */
static int crypto_pcbc_init_tfm(struct crypto_tfm *tfm)
{
	struct crypto_instance *inst = (void *)tfm->__crt_alg;
	struct crypto_pcbc_ctx *ctx = crypto_tfm_ctx(tfm);
	struct crypto_cipher *child;

	child = crypto_spawn_cipher(crypto_instance_ctx(inst));
	if (!IS_ERR(child)) {
		ctx->child = child;
		return 0;
	}

	return PTR_ERR(child);
}

/*
 * Tear down a PCBC transform: release the underlying cipher stored in the
 * tfm context.
 */
static void crypto_pcbc_exit_tfm(struct crypto_tfm *tfm)
{
	struct crypto_pcbc_ctx *pcbc = crypto_tfm_ctx(tfm);
	struct crypto_cipher *cipher = pcbc->child;

	crypto_free_cipher(cipher);
}

static struct crypto_instance *crypto_pcbc_alloc(struct rtattr **tb)
{
	struct crypto_instance *inst;
	struct crypto_alg *alg;
	int err;

	err = crypto_check_attr_type(tb, CRYPTO_ALG_TYPE_BLKCIPHER);
	if (err)
		return ERR_PTR(err);

	alg = crypto_get_attr_alg(tb, CRYPTO_ALG_TYPE_CIPHER,
				  CRYPTO_ALG_TYPE_MASK);
	if (IS_ERR(alg))
		return ERR_PTR(PTR_ERR(alg));

	inst = crypto_alloc_instance("pcbc", alg);
	if (IS_ERR(inst))
		goto out_put_alg;

	inst->alg.cra_flags = CRYPTO_ALG_TYPE_BLKCIPHER;
	inst->alg.cra_priority = alg->cra_priority;
	inst->alg.cra_blocksize = alg->cra_blocksize;
	inst->alg.cra_alignmask = alg->cra_alignmask;
	inst->alg.cra_type = &crypto_blkcipher_type;

	/* We access the data as u32s when xoring. */
	inst->alg.cra_alignmask |= __alignof__(u32) - 1;
an>keys():
				print PF_WJOINT
				if event['handle'] == "kfree_skb":
					print PF_KFREE_SKB % \
						(diff_msec(base_t,
						event['comm_t']),
						event['location'])
				elif event['handle'] == "consume_skb":
					print PF_CONS_SKB % \
						diff_msec(base_t,
							event['comm_t'])
			print PF_JOINT

def trace_begin():
	"""Parse the script arguments into the global display options.

	Recognised arguments: 'tx', 'rx', 'dev=<name>' and 'debug'; anything
	else is ignored.  When neither 'tx' nor 'rx' ends up selected, both
	are enabled.
	"""
	global show_tx
	global show_rx
	global dev
	global debug

	# sys.argv[0] is skipped
	for arg in sys.argv[1:]:
		if arg == 'tx':
			show_tx = 1
		elif arg == 'rx':
			show_rx = 1
		elif arg[:4] == 'dev=':
			dev = arg[4:]
		elif arg == 'debug':
			debug = 1
	# no direction requested: show both
	if (show_tx, show_rx) == (0, 0):
		show_tx = 1
		show_rx = 1

def trace_end():
	"""Replay all buffered events in time order and print the report.

	Events are sorted by their timestamp field, each one is passed to the
	handler registered for its event name (events with an unknown name
	are skipped), and then the receive hunks, transmit records and debug
	counters are printed according to show_rx, show_tx and debug.

	The sort uses key= instead of a cmp function and every print is a
	single-argument print(...) call, so this function behaves the same
	under Python 2 and Python 3.
	"""
	handlers = {
		'irq__softirq_exit': handle_irq_softirq_exit,
		'irq__softirq_entry': handle_irq_softirq_entry,
		'irq__softirq_raise': handle_irq_softirq_raise,
		'irq__irq_handler_entry': handle_irq_handler_entry,
		'irq__irq_handler_exit': handle_irq_handler_exit,
		'napi__napi_poll': handle_napi_poll,
		'net__netif_receive_skb': handle_netif_receive_skb,
		'net__netif_rx': handle_netif_rx,
		'skb__skb_copy_datagram_iovec': handle_skb_copy_datagram_iovec,
		'net__net_dev_queue': handle_net_dev_queue,
		'net__net_dev_xmit': handle_net_dev_xmit,
		'skb__kfree_skb': handle_kfree_skb,
		'skb__consume_skb': handle_consume_skb,
	}
	# order all events in time (stable: ties keep their arrival order)
	all_event_list.sort(key=lambda event_info: event_info[EINFO_IDX_TIME])
	# process all events
	for event_info in all_event_list:
		handler = handlers.get(event_info[EINFO_IDX_NAME])
		if handler is not None:
			handler(event_info)
	# display receive hunks
	if show_rx:
		for hunk in receive_hunk_list:
			print_receive(hunk)
	# display transmit hunks
	if show_tx:
		print("   dev    len      Qdisc        "
			"       netdevice             free")
		for skb in tx_free_list:
			print_transmit(skb)
	if debug:
		print("debug buffer status")
		print("----------------------------")
		print("xmit Qdisc:remain:%d overflow:%d" %
			(len(tx_queue_list), of_count_tx_queue_list))
		print("xmit netdevice:remain:%d overflow:%d" %
			(len(tx_xmit_list), of_count_tx_xmit_list))
		print("receive:remain:%d overflow:%d" %
			(len(rx_skb_list), of_count_rx_skb_list))

# called from perf when it finds a corresponding event
def irq__softirq_entry(name, context, cpu, sec, nsec, pid, comm, vec):
	"""Buffer a softirq entry event, but only for the NET_RX vector."""
	if symbol_str("irq__softirq_entry", "vec", vec) == "NET_RX":
		timestamp = nsecs(sec, nsec)
		all_event_list.append(
			(name, context, cpu, timestamp, pid, comm, vec))

def irq__softirq_exit(name, context, cpu, sec, nsec, pid, comm, vec):
	"""Buffer a softirq exit event, but only for the NET_RX vector."""
	# NOTE(review): the vector name is looked up under the
	# "irq__softirq_entry" event, not "irq__softirq_exit" - presumably
	# both events share the same symbol table; confirm.
	if symbol_str("irq__softirq_entry", "vec", vec) == "NET_RX":
		timestamp = nsecs(sec, nsec)
		all_event_list.append(
			(name, context, cpu, timestamp, pid, comm, vec))

def irq__softirq_raise(name, context, cpu, sec, nsec, pid, comm, vec):
	"""Buffer a softirq raise event, but only for the NET_RX vector."""
	# NOTE(review): the vector name is looked up under the
	# "irq__softirq_entry" event, not "irq__softirq_raise" - presumably
	# both events share the same symbol table; confirm.
	if symbol_str("irq__softirq_entry", "vec", vec) == "NET_RX":
		timestamp = nsecs(sec, nsec)
		all_event_list.append(
			(name, context, cpu, timestamp, pid, comm, vec))

def irq__irq_handler_entry(name, context, cpu, sec, nsec, pid, comm,
			irq, irq_name):
	"""Buffer a hard-irq handler entry event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm, irq, irq_name))

def irq__irq_handler_exit(name, context, cpu, sec, nsec, pid, comm, irq, ret):
	"""Buffer a hard-irq handler exit event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm, irq, ret))

def napi__napi_poll(name, context, cpu, sec, nsec, pid, comm, napi, dev_name):
	"""Buffer a napi_poll event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm, napi, dev_name))

def net__netif_receive_skb(name, context, cpu, sec, nsec, pid, comm, skbaddr,
			skblen, dev_name):
	"""Buffer a netif_receive_skb event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm,
		 skbaddr, skblen, dev_name))

def net__netif_rx(name, context, cpu, sec, nsec, pid, comm, skbaddr,
			skblen, dev_name):
	"""Buffer a netif_rx event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm,
		 skbaddr, skblen, dev_name))

def net__net_dev_queue(name, context, cpu, sec, nsec, pid, comm,
			skbaddr, skblen, dev_name):
	"""Buffer a net_dev_queue event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm,
		 skbaddr, skblen, dev_name))

def net__net_dev_xmit(name, context, cpu, sec, nsec, pid, comm,
			skbaddr, skblen, rc, dev_name):
	"""Buffer a net_dev_xmit event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm,
		 skbaddr, skblen, rc, dev_name))

def skb__kfree_skb(name, context, cpu, sec, nsec, pid, comm,
			skbaddr, protocol, location):
	"""Buffer a kfree_skb event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm,
		 skbaddr, protocol, location))

def skb__consume_skb(name, context, cpu, sec, nsec, pid, comm, skbaddr):
	"""Buffer a consume_skb event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm, skbaddr))

def skb__skb_copy_datagram_iovec(name, context, cpu, sec, nsec, pid, comm,
	skbaddr, skblen):
	"""Buffer a skb_copy_datagram_iovec event with its time in nsecs."""
	timestamp = nsecs(sec, nsec)
	all_event_list.append(
		(name, context, cpu, timestamp, pid, comm, skbaddr, skblen))

def handle_irq_handler_entry(event_info):
	"""Push a new irq record onto the per-CPU list in irq_dic."""
	(name, context, cpu, time, pid, comm, irq, irq_name) = event_info
	# first irq seen on this CPU creates its list
	irq_dic.setdefault(cpu, []).append(
		{'irq':irq, 'name':irq_name, 'cpu':cpu, 'irq_ent_t':time})

def handle_irq_handler_exit(event_info):
	"""Close the most recent irq record of this CPU.

	The last record pushed for the CPU is popped.  If its irq number does
	not match the exiting irq it is discarded.  Otherwise the exit time is
	stored in it and the record is pushed back, but only when it has
	collected an 'event_list'; records without one are dropped.

	An exit event for a CPU that has no pending record (unknown CPU, or
	an empty per-CPU list) is ignored instead of raising IndexError.
	"""
	(name, context, cpu, time, pid, comm, irq, ret) = event_info
	# nothing to match this exit against
	if cpu not in irq_dic.keys() \
	or len(irq_dic[cpu]) == 0:
		return
	irq_record = irq_dic[cpu].pop()
	if irq != irq_record['irq']:
		return
	irq_record.update({'irq_ext_t':time})
	# if an irq doesn't include NET_RX softirq, drop.
	if 'event_list' in irq_record.keys():
		irq_dic[cpu].append(irq_record)

def handle_irq_softirq_raise(event_info):
	"""Attach a 'sirq_raise' event to the newest irq record of this CPU.

	Does nothing when the CPU has no pending irq record.
	"""
	(name, context, cpu, time, pid, comm, vec) = event_info
	pending = irq_dic.get(cpu)
	if not pending:
		return
	# newest record stays at the end of the list; its event list is
	# created on first use
	irq_record = pending[-1]
	irq_record.setdefault('event_list', []).append(
		{'time':time, 'event':'sirq_raise'})

def handle_irq_softirq_entry(event_info):
	"""Start a fresh softirq record for this CPU in net_rx_dic.

	Any previous record for the same CPU is replaced.
	"""
	(name, context, cpu, time, pid, comm, vec) = event_info
	softirq = {'sirq_ent_t':time}
	softirq['event_list'] = []
	net_rx_dic[cpu] = softirq

def handle_irq_softirq_exit(event_info):
	"""Finish the softirq of this CPU and emit a receive hunk.

	The per-CPU entries of irq_dic and net_rx_dic are removed in every
	case.  A hunk is appended to receive_hunk_list only when there was a
	softirq record and at least one irq record for the CPU.
	"""
	(name, context, cpu, time, pid, comm, vec) = event_info
	irq_list = irq_dic.pop(cpu, [])
	softirq = net_rx_dic.pop(cpu, None)
	if softirq is None or irq_list == []:
		return
	# merge information related to a NET_RX softirq
	receive_hunk_list.append({
		'sirq_ent_t':softirq['sirq_ent_t'], 'sirq_ext_t':time,
		'irq_list':irq_list, 'event_list':softirq['event_list']})

def handle_napi_poll(event_info):
	"""Record a napi_poll in the softirq record of this CPU, if any."""
	(name, context, cpu, time, pid, comm, napi, dev_name) = event_info
	if cpu not in net_rx_dic:
		return
	net_rx_dic[cpu]['event_list'].append(
		{'event_name':'napi_poll', 'dev':dev_name, 'event_t':time})

def handle_netif_rx(event_info):
	"""Attach a 'netif_rx' event to the newest irq record of this CPU.

	Does nothing when the CPU has no pending irq record.
	"""
	(name, context, cpu, time, pid, comm,
		skbaddr, skblen, dev_name) = event_info
	pending = irq_dic.get(cpu)
	if not pending:
		return
	# newest record stays at the end of the list; its event list is
	# created on first use
	irq_record = pending[-1]
	irq_record.setdefault('event_list', []).append(
		{'time':time, 'event':'netif_rx',
		 'skbaddr':skbaddr, 'skblen':skblen, 'dev_name':dev_name})

def handle_netif_receive_skb(event_info):
	"""Record a received skb in the softirq record of this CPU.

	The same record is also put at the front of rx_skb_list; when that
	list grows beyond buffer_budget its last element is discarded and
	the overflow counter is incremented.  Nothing happens when the CPU
	has no softirq record.
	"""
	global of_count_rx_skb_list

	(name, context, cpu, time, pid, comm,
		skbaddr, skblen, dev_name) = event_info
	if cpu not in net_rx_dic:
		return
	rec_data = {'event_name':'netif_receive_skb',
		    'event_t':time, 'skbaddr':skbaddr, 'len':skblen}
	net_rx_dic[cpu]['event_list'].append(rec_data)
	rx_skb_list.insert(0, rec_data)
	if len(rx_skb_list) > buffer_budget:
		rx_skb_list.pop()
		of_count_rx_skb_list += 1

def handle_net_dev_queue(event_info):
	"""Remember a queued skb at the front of tx_queue_list.

	When the list grows beyond buffer_budget its last element is
	discarded and the overflow counter is incremented.
	"""
	global of_count_tx_queue_list

	(name, context, cpu, time, pid, comm,
		skbaddr, skblen, dev_name) = event_info
	tx_queue_list.insert(
		0, {'dev':dev_name, 'skbaddr':skbaddr, 'len':skblen,
		    'queue_t':time})
	if len(tx_queue_list) > buffer_budget:
		tx_queue_list.pop()
		of_count_tx_queue_list += 1

def handle_net_dev_xmit(event_info):
	"""Move a transmitted skb from tx_queue_list to tx_xmit_list.

	Only events with rc == 0 are considered.  The first queued skb with
	the same address gets its 'xmit_t' set and is moved to the front of
	tx_xmit_list; when that list grows beyond buffer_budget its last
	element is discarded and the overflow counter is incremented.
	"""
	global of_count_tx_xmit_list

	(name, context, cpu, time, pid, comm,
		skbaddr, skblen, rc, dev_name) = event_info
	if rc != 0: # anything but NETDEV_TX_OK
		return
	for idx, skb in enumerate(tx_queue_list):
		if skb['skbaddr'] != skbaddr:
			continue
		skb['xmit_t'] = time
		tx_xmit_list.insert(0, skb)
		del tx_queue_list[idx]
		if len(tx_xmit_list) > buffer_budget:
			tx_xmit_list.pop()
			of_count_tx_xmit_list += 1
		return

def handle_kfree_skb(event_info):
	"""Account a kfree_skb against the first list that knows the skb.

	The lists are searched in this order and only the first match is
	acted upon:
	  - tx_queue_list: the entry is simply removed;
	  - tx_xmit_list: the entry gets 'free_t' and moves to tx_free_list;
	  - rx_skb_list: the entry is tagged with handle "kfree_skb" plus
	    comm, pid and time, and removed from the list.
	"""
	(name, context, cpu, time, pid, comm,
		skbaddr, protocol, location) = event_info
	for idx, skb in enumerate(tx_queue_list):
		if skb['skbaddr'] == skbaddr:
			del tx_queue_list[idx]
			return
	for idx, skb in enumerate(tx_xmit_list):
		if skb['skbaddr'] == skbaddr:
			skb['free_t'] = time
			tx_free_list.append(skb)
			del tx_xmit_list[idx]
			return
	for idx, rec_data in enumerate(rx_skb_list):
		if rec_data['skbaddr'] == skbaddr:
			rec_data.update({'handle':"kfree_skb",
					'comm':comm, 'pid':pid, 'comm_t':time})
			del rx_skb_list[idx]
			return

def handle_consume_skb(event_info):
	"""Move the first transmitted skb with this address to tx_free_list.

	The entry gets its 'free_t' set to the event time.  Nothing happens
	when tx_xmit_list holds no skb with that address.
	"""
	(name, context, cpu, time, pid, comm, skbaddr) = event_info
	for idx, skb in enumerate(tx_xmit_list):
		if skb['skbaddr'] != skbaddr:
			continue
		skb['free_t'] = time
		tx_free_list.append(skb)
		del tx_xmit_list[idx]
		return

def handle_skb_copy_datagram_iovec(event_info):
	"""Tag the first received skb with this address as copied out.

	The matching rx_skb_list entry is updated with handle
	"skb_copy_datagram_iovec" plus comm, pid and time, and then removed
	from the list.  Nothing happens when no entry matches.
	"""
	(name, context, cpu, time, pid, comm, skbaddr, skblen) = event_info
	for idx, rec_data in enumerate(rx_skb_list):
		if skbaddr != rec_data['skbaddr']:
			continue
		rec_data.update({'handle':"skb_copy_datagram_iovec",
				'comm':comm, 'pid':pid, 'comm_t':time})
		del rx_skb_list[idx]
		return